In [115]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import ipywidgets as widgets
from io import BytesIO
from ipycanvas import Canvas, hold_canvas, MultiCanvas

In [181]:
# Canvas geometry and default pen width.
image_size = 224
stroke_size = 5

# Left canvas: two stacked layers. Layer [0] holds the working image; layer [1]
# is a half-transparent overlay (global_alpha 0.5) that receives the red strokes.
left_canvas = MultiCanvas(2, width=image_size, height=image_size, sync_image_data=True)
left_canvas[1].global_alpha = 0.5
left_canvas[1].sync_image_data = True  # overlay pixels are read back to build the mask

# Right canvas: preview of the image with the painted pixels blacked out.
right_canvas = Canvas(width=image_size, height=image_size, sync_image_data=True)

# Collects prints and tracebacks raised inside the event handlers.
output = widgets.Output()

# Mutable editor state shared by the handlers defined below.
image = np.array([])   # working image; replaced on upload and on confirm
mask = np.array([])    # painted-pixel mask; rebuilt on mouse-up
position = [0, 0]      # last pointer position seen on the left canvas
is_drawing = False     # True between mouse-down and mouse-up on the left canvas
is_uploaded = False    # most handlers are no-ops until an image is loaded

def upload_image(change):
    """FileUpload observer: load the uploaded file as the new working image.

    ``change.new`` is the widget's new ``value``. Here it is expected to be a
    tuple of upload dicts (the ipywidgets 8 layout, matching the
    ``FileUpload(value=(), ...)`` repr); only the first file is used.

    The image is cropped to the top-left ``image_size`` x ``image_size``
    region and converted to RGBA. Grayscale ("L") and palette ("P") files
    otherwise become 2-D arrays that lack the channel axis the mask preview
    indexes (``image[mask, :3]``). Both canvases are then reset and the
    editor is marked as ready.
    """
    global image, is_uploaded
    # The observer also fires when the widget value is reset to an empty
    # tuple; there is nothing to load then (and ``[0]`` would raise).
    if not change.new:
        return
    with output:
        content = change.new[0]["content"]
        # Decode before touching the canvases so an unreadable file leaves
        # the current image on screen instead of a blank canvas.
        cropped = Image.open(BytesIO(content)).crop((0, 0, image_size, image_size))
        image = np.array(cropped.convert("RGBA"))
        left_canvas.clear()
        left_canvas[0].put_image_data(image)
        right_canvas.clear()
        is_uploaded = True

def paint_on_left(x, y):
    """Mouse-move handler for the left canvas: extend the red mask stroke.

    Does nothing until an image has been uploaded. On every move the overlay
    layer's pen (red, ``stroke_size`` wide) is set. A segment from the last
    recorded pointer position to (x, y) is drawn only while ``is_drawing`` is
    set. The pointer position is always updated so the next segment starts
    where this one ended.
    """
    if not is_uploaded:
        return
    with output, hold_canvas():
        overlay = left_canvas[1]
        overlay.stroke_style = "red"
        overlay.line_width = stroke_size
        if is_drawing:
            last_x, last_y = position
            overlay.begin_path()
            overlay.move_to(last_x, last_y)
            overlay.line_to(x, y)
            overlay.stroke()
            overlay.close_path()
        # Mutate in place: other handlers share this same list object.
        position[:] = [x, y]

def confirm_preview(change):
    """Confirm-button handler: make the right-hand preview the working image.

    Copies the preview pixels into ``image``, draws them on the bottom layer
    of the left canvas, and clears both the stroke overlay and the preview.
    Does nothing before an upload, or when there is no visible preview to
    apply.
    """
    global image
    if not is_uploaded:
        return
    with output:
        preview = right_canvas.get_image_data()
        # The preview canvas is cleared on upload, confirm and revert.
        # Confirming it in that state (e.g. a second Confirm click) would
        # replace the working image with fully transparent pixels.
        if preview is None or not preview[..., 3].any():
            return
        with hold_canvas():
            image = preview.copy()
            left_canvas[0].put_image_data(image)
            left_canvas[1].clear()
            right_canvas.clear()

def revert_changes(change):
    """Revert-button handler: discard the strokes painted since the last confirm.

    Clears the red overlay layer and the right-hand preview, and resets
    ``mask`` to zeros with the same shape and dtype. The working image on the
    bottom layer is left untouched. No-op before an image is uploaded.
    """
    global mask
    if is_uploaded:
        with hold_canvas():
            left_canvas[1].clear()
            mask = np.zeros_like(mask)
            right_canvas.clear()

def start_drawing(x, y):
    """Mouse-down handler: begin a new stroke at (x, y).

    Records the press location as the stroke's starting point, so the first
    segment drawn by ``paint_on_left`` begins where the button went down.
    Otherwise it would begin at the last hover position, which can be stale
    (e.g. a press with no preceding mouse-move). No-op before an image is
    uploaded.
    """
    global is_drawing
    if not is_uploaded:
        return
    # Mutate in place: the move handler reads this same list object.
    position[:] = [x, y]
    is_drawing = True

def stop_drawing(x, y):
    """Mouse-up handler: finish the stroke and refresh the right-hand preview.

    Reads back the overlay layer (``left_canvas[1]``) and marks every pixel
    with non-zero alpha in the global ``mask``. It then shows a copy of
    ``image`` on the right canvas with those pixels' first three channels set
    to 0. ``image`` itself is not modified; the Confirm button applies the
    preview. Clears ``is_drawing``. No-op before an image is uploaded.
    """
    global is_drawing, mask
    if not is_uploaded:
        return
    with output:
        is_drawing = False
        overlay = left_canvas[1].get_image_data()
        # NOTE(review): image data is synced from the browser; guard against
        # it not being available yet instead of crashing on None.
        if overlay is None:
            return
        mask = overlay[:, :, 3] != 0
        preview = image.copy()
        preview[mask, :3] = 0
        with hold_canvas():
            right_canvas.put_image_data(preview)
    
def stroke_changed(change):
    """Stroke-size field observer: store the new pen width.

    ``paint_on_left`` reads ``stroke_size`` on every mouse move, so the new
    width applies from the next drawn segment onward.
    """
    global stroke_size
    new_width = change.new
    stroke_size = new_width

def save_result(change):
    """Save-button handler: write the current result to ``result.png``.

    If the right canvas holds a visible preview, that preview is saved.
    Otherwise the working ``image`` is saved; this covers the case right
    after Confirm, which clears the preview. The file goes to the current
    working directory. No-op before an image is uploaded. Messages and
    errors are sent to the ``output`` widget.
    """
    if not is_uploaded:
        return
    with output:
        # ipycanvas' Canvas.to_file() requires a filename and returns nothing,
        # so the pixels are fetched and PNG-encoded with PIL instead.
        preview = right_canvas.get_image_data()
        has_preview = preview is not None and preview[..., 3].any()
        result = preview if has_preview else image
        Image.fromarray(result).save("result.png")
        print("Result saved as result.png")


In [182]:
# Controls: upload, confirm/revert the painted mask, save, and pen width.
upload_button = widgets.FileUpload(accept="image/*", multiple=False)
upload_button.observe(upload_image, names="value")

confirm_button = widgets.Button(description="Confirm")
confirm_button.on_click(confirm_preview)

revert_button = widgets.Button(description="Revert")
revert_button.on_click(revert_changes)

save_button = widgets.Button(description="Save")
save_button.on_click(save_result)

stroke_size_change = widgets.IntText(value=stroke_size, description='Stroke Size:', disabled=False)
stroke_size_change.observe(stroke_changed, 'value')

# Mouse handlers that paint the mask on the left canvas.
left_canvas.on_mouse_down(start_drawing)
left_canvas.on_mouse_move(paint_on_left)
left_canvas.on_mouse_up(stop_drawing)


def _stop_if_drawing(x, y):
    """Finish an in-progress stroke when the pointer leaves the left canvas."""
    if is_drawing:
        stop_drawing(x, y)


# Releasing the button outside the canvas sends no mouse-up to it. Without
# this, is_drawing would stay set and plain hovering would keep painting.
left_canvas.on_mouse_out(_stop_if_drawing)

display(widgets.HBox([left_canvas, right_canvas]))
display(stroke_size_change)
display(widgets.HBox([upload_button, confirm_button, revert_button, save_button]))
display(output)

HBox(children=(MultiCanvas(height=224, image_data=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xe0\x00\x0…

IntText(value=5, description='Stroke Size:')

HBox(children=(FileUpload(value=(), accept='image/*', description='Upload'), Button(description='Confirm', sty…

Output()