In [None]:
#@markdown <center><h1>Draw Inpaint Mask</h1></center>

# Image to paint over; the page requests it from "/files/" + this path.
input_image = "content/test.png" # @param {"type":"string"}
# False: white strokes on a black mask. True: black strokes on a white mask.
invert = False # @param {"type":"boolean"}

# json.dumps/quote (used inside the f-string below) keep image paths that
# contain quotes, spaces or "<" from breaking the generated <script>.
import json
from urllib.parse import quote

# HTML/CSS/JS for the mask editor. This is an f-string: literal CSS/JS braces
# are doubled, single braces interpolate the Python values above.
paint_interface = f"""
<style>
  @import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css');

  body {{
    background-color: #2D2D2D;
    font-family: Arial, sans-serif;
    color: #FFFFFF;
    margin: 0;
    display: flex;
    justify-content: center;
    align-items: center;

  }}
  #toolbar {{
    display: flex;
    flex-wrap: wrap;
    justify-content: center;
    align-items: center;
    padding: 10px 0;
    margin: 10px auto;
    background-color: #3C3F41;
    border-radius: 8px;
    width: 100%;
    max-width: 512px;
    box-shadow: 0 4px 10px rgba(0, 0, 0, 0.3);
  }}
  .button {{
    padding: 10px 15px;
    margin: 5px;
    font-size: 14px;
    cursor: pointer;
    border-radius: 5px;
    color: #FFFFFF;
    background-color: #5E5E5E;
    border: none;
    font-weight: bold;
    transition: 0.3s;
  }}
  .button:hover {{
    background-color: #757575;
  }}
  .button.active {{
    background-color: #FFF7E0;
    color: #202124;
  }}
  .button:disabled{{
    cursor: default;
    background-color: #A9A9A9;
  }}
  .slider-container {{
    display: flex;
    align-items: center;
    margin: 10px 0;
  }}
  #brushSizeSlider {{
    margin-left: 10px;
    cursor: pointer;
    width: 200px;
  }}
  #canvasContainer {{
    margin: 20px auto;
    border: 2px solid #5E5E5E;
    border-radius: 8px;
    position: relative;
    display: inline-block;
    width: 512px;
    height: 512px;
  }}
  canvas {{
    display: block;
    cursor: crosshair;
  }}
  #imageOverlay {{
    position: absolute;
    top: 0;
    left: 0;
    pointer-events: none;
    opacity: 0.5;
    z-index: 0;
  }}
</style>
<div id="toolbar">
  <button class="button active" id="drawButton" onclick="setTool('draw')"><i class="fa fa-paint-brush"></i> Draw</button>
  <button class="button" id="eraseButton" onclick="setTool('erase')"><i class="fa fa-eraser"></i> Erase</button>
  <button class="button" id="clearButton" onclick="clearCanvas()"><i class="fa fa-trash"></i> Clear</button>
  <button class="button" id="saveButton" onclick="saveCanvas()"><i class="fa fa-check"></i> Save</button>
  <button class="button" id="undoButton" title="Undo" onclick="undo()" disabled><i class="fa fa-undo"></i></button>
  <button class="button" id="redoButton" title="Redo" onclick="redo()" disabled><i class="fa fa-rotate-right"></i></button>
  <div class="slider-container">
    <label for="brushSizeSlider">Brush Size: <span id="brushSizeDisplay">20</span></label>
    <input id="brushSizeSlider" type="range" min="1" max="100" value="20">
  </div>
</div>
<div id="canvasContainer">
  <img id="imageOverlay" />
  <canvas id="paintCanvas"></canvas>
</div>
<script>
  const canvas = document.getElementById('paintCanvas');
  const ctx = canvas.getContext('2d');
  const imageOverlay = document.getElementById('imageOverlay');
  const canvasContainer = document.getElementById('canvasContainer');
  const brushSizeSlider = document.getElementById('brushSizeSlider');
  const brushSizeDisplay = document.getElementById('brushSizeDisplay');
  const undoButton = document.getElementById('undoButton');
  const redoButton = document.getElementById('redoButton');
  const saveButton = document.getElementById('saveButton');

  // Colors baked in from the Python `invert` flag.
  const BACKGROUND_COLOR = '{'white' if invert else 'black'}';
  const PAINT_COLOR = '{'black' if invert else 'white'}';
  const TOOL_BUTTONS = {{ draw: 'drawButton', erase: 'eraseButton' }};

  let tool = 'draw';
  let drawing = false;
  let brushSize = Number(brushSizeSlider.value);

  // Undo/redo history as PNG data URLs of the whole canvas.
  const undoStack = [];
  const redoStack = [];
  // True while a history snapshot is being decoded back onto the canvas;
  // snapshots taken in that window would capture the stale canvas.
  let restoring = false;

  function highlightActiveButton(selectedTool) {{
    document.querySelectorAll('.button').forEach((button) => button.classList.remove('active'));
    const buttonId = TOOL_BUTTONS[selectedTool];
    if (buttonId) {{
      document.getElementById(buttonId).classList.add('active');
    }}
  }}

  function setTool(selectedTool) {{
    tool = selectedTool;
    highlightActiveButton(tool);
  }}

  // Paints the whole canvas with the background color and drops any open path.
  function fillBackground() {{
    ctx.globalCompositeOperation = 'source-over';
    ctx.fillStyle = BACKGROUND_COLOR;
    ctx.fillRect(0, 0, canvas.width, canvas.height);
    ctx.beginPath();
  }}

  const backgroundImage = new Image();
  backgroundImage.onload = function () {{
    const aspectRatio = backgroundImage.width / backgroundImage.height;
    let width;
    let height;
    // The shorter side is scaled to 512px; the other keeps the aspect ratio.
    if (backgroundImage.width > backgroundImage.height) {{
      height = 512;
      width = height * aspectRatio;
    }} else {{
      width = 512;
      height = width / aspectRatio;
    }}
    // Canvas dimensions are integers; round once so the overlay matches exactly.
    width = Math.round(width);
    height = Math.round(height);

    canvas.width = width;
    canvas.height = height;
    imageOverlay.width = width;
    imageOverlay.height = height;
    canvasContainer.style.width = `${{width}}px`;
    canvasContainer.style.height = `${{height}}px`;

    fillBackground();
    imageOverlay.src = backgroundImage.src;
  }};
  backgroundImage.onerror = function () {{
    // Without the image there is nothing meaningful to paint or save.
    canvasContainer.textContent = `Could not load image: ${{backgroundImage.src}}`;
    saveButton.disabled = true;
  }};
  backgroundImage.src = {json.dumps('/files/' + quote(input_image))};

  brushSizeSlider.addEventListener('input', (event) => {{
    brushSize = Number(event.target.value);
    brushSizeDisplay.textContent = brushSize;
  }});

  function pointerPosition(event) {{
    const rect = canvas.getBoundingClientRect();
    return {{ x: event.clientX - rect.left, y: event.clientY - rect.top }};
  }}

  function strokeColor() {{
    return tool === 'erase' ? BACKGROUND_COLOR : PAINT_COLOR;
  }}

  function startStroke(event) {{
    if (restoring || event.button !== 0) return;
    saveState();
    drawing = true;
    const {{ x, y }} = pointerPosition(event);
    // Stamp a dot so a click without moving still marks the mask.
    ctx.globalCompositeOperation = 'source-over';
    ctx.fillStyle = strokeColor();
    ctx.beginPath();
    ctx.arc(x, y, brushSize / 2, 0, 2 * Math.PI);
    ctx.fill();
    // Start the stroke path at the press point so the first segment connects to it.
    ctx.beginPath();
    ctx.moveTo(x, y);
  }}

  function draw(event) {{
    if (!drawing) return;
    const {{ x, y }} = pointerPosition(event);
    ctx.lineWidth = brushSize;
    ctx.lineCap = 'round';
    ctx.globalCompositeOperation = 'source-over';
    ctx.strokeStyle = strokeColor();
    ctx.lineTo(x, y);
    ctx.stroke();
    // Restart the path here so each stroke() only renders the newest segment.
    ctx.beginPath();
    ctx.moveTo(x, y);
  }}

  function endStroke() {{
    drawing = false;
    ctx.beginPath();
  }}

  canvas.addEventListener('mousedown', startStroke);
  canvas.addEventListener('mousemove', draw);
  // Listen on window so releasing the button outside the canvas still ends the stroke.
  window.addEventListener('mouseup', endStroke);
  // Drop the open path on leave so re-entering does not draw a connecting line.
  canvas.addEventListener('mouseleave', () => ctx.beginPath());

  function clearCanvas() {{
    if (restoring) return;
    // Snapshot first so Clear can be undone like any stroke.
    saveState();
    fillBackground();
  }}

  function saveCanvas() {{
    const dataURL = canvas.toDataURL('image/png');
    google.colab.kernel.invokeFunction('notebook.save_image', [dataURL], {{}})
      .catch((err) => console.error('Saving the mask failed:', err));
  }}

  function updateUndoRedoButtons() {{
    undoButton.disabled = undoStack.length === 0;
    redoButton.disabled = redoStack.length === 0;
  }}

  function saveState() {{
    undoStack.push(canvas.toDataURL());
    redoStack.length = 0;
    updateUndoRedoButtons();
  }}

  // Draws a history snapshot back onto the canvas. Image decoding is
  // asynchronous, so `restoring` blocks other history/drawing actions until done.
  function restoreState(dataURL) {{
    restoring = true;
    const img = new Image();
    img.onload = function () {{
      ctx.globalCompositeOperation = 'source-over';
      ctx.clearRect(0, 0, canvas.width, canvas.height);
      ctx.drawImage(img, 0, 0);
      ctx.beginPath();
      restoring = false;
    }};
    img.onerror = function () {{
      restoring = false;
    }};
    img.src = dataURL;
  }}

  // Pushes the current canvas onto `target` and restores the newest snapshot of `source`.
  function stepHistory(source, target) {{
    if (restoring || source.length === 0) return;
    target.push(canvas.toDataURL());
    restoreState(source.pop());
    updateUndoRedoButtons();
  }}

  function undo() {{
    stepHistory(undoStack, redoStack);
  }}

  function redo() {{
    stepHistory(redoStack, undoStack);
  }}

  updateUndoRedoButtons();
</script>
"""

from IPython.display import HTML, display
from google.colab import output
from base64 import b64decode
from PIL import Image, ImageOps
import io

def save_image(data_url, filename="inpaint_mask.png"):
    """Decode a base64 image data URL sent from the page and save it to disk.

    Args:
        data_url: A ``data:image/...;base64,<payload>`` string, such as the
            one produced by the page's ``canvas.toDataURL('image/png')``.
        filename: Destination path. Relative paths resolve against the current
            working directory. Defaults to ``inpaint_mask.png``.

    Returns:
        A short status message naming the saved file.

    Raises:
        ValueError: If ``data_url`` is not a base64-encoded image data URL.
    """
    header, sep, encoded = data_url.partition(",")
    if not sep or not header.startswith("data:image/") or not header.endswith(";base64"):
        raise ValueError("expected a base64-encoded image data URL")
    binary_data = b64decode(encoded)
    # The context manager releases the decoded image once it has been written.
    with Image.open(io.BytesIO(binary_data)) as img:
        img.save(filename)
    return f"Mask saved as {filename}"

# Expose save_image to the page under the name its saveCanvas() invokes,
# then render the drawing UI into the cell output.
output.register_callback(
    'notebook.save_image',
    save_image,
)
display(HTML(paint_interface))
