In [5]:
from PIL import Image
import numpy as np
import ipywidgets as widgets
from io import BytesIO
from ipycanvas import Canvas, hold_canvas, MultiCanvas
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import torch.fft
from torchvision import transforms

In [6]:
class DownscaleBlock(nn.Module):
    """Stride-2 convolution followed by LeakyReLU(0.2); halves height and width.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 maps a spatial size H to H // 2.
        layers = [
            nn.Conv2d(in_channels, out_channels, 4, 2, 1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        # Kept as a Sequential named ``conv`` so state_dict keys remain
        # ``conv.0.weight`` / ``conv.0.bias`` and saved checkpoints still load.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv + LeakyReLU to ``x``."""
        downscaled = self.conv(x)
        return downscaled

class ResidualBlock(nn.Module):
    """Identity skip around conv3x3 -> ReLU -> conv3x3; shape-preserving.

    Args:
        in_channels: channel count of the input (and of the output).
    """

    def __init__(self, in_channels):
        super().__init__()
        channels = in_channels
        # 3x3 kernels with stride 1 and padding 1 keep the spatial size, so the
        # branch output can be added straight back onto the input.
        first = nn.Conv2d(channels, channels, 3, 1, 1)
        activation = nn.ReLU(inplace=True)
        second = nn.Conv2d(channels, channels, 3, 1, 1)
        # Sequential indices 0 and 2 carry parameters: state_dict keys are
        # conv.0.* and conv.2.*, matching existing checkpoints.
        self.conv = nn.Sequential(first, activation, second)

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        residual = self.conv(x)
        return x + residual

class UpscaleBlock(nn.Module):
    """Stride-2 transposed convolution followed by ReLU; doubles height and width.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by the transposed convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 maps a spatial size H to 2 * H.
        deconv = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1)
        # Output is non-negative because of the trailing ReLU.
        self.conv = nn.Sequential(deconv, nn.ReLU(inplace=True))

    def forward(self, x):
        """Apply transposed conv + ReLU to ``x``."""
        upscaled = self.conv(x)
        return upscaled
    
class Generator(nn.Module):
    """Two-stage inpainting network: a coarse fill followed by a refinement pass.

    Both stages have the same layout:
    - an encoder of three ``DownscaleBlock``s (6 -> 64 -> 128 -> 256 channels);
    - three 256-channel ``ResidualBlock``s;
    - a decoder of three ``UpscaleBlock``s (256 -> 128 -> 64 -> 3 channels).

    Each stage sees its current image concatenated with ``mask`` along dim 1, so
    image and mask channels together must total 6. Spatial dims should be
    divisible by 8 so the decoder output lines up with the input.

    The attribute names ``downscale1`` ... ``upscale2`` define the ``state_dict``
    keys and must not change, or saved checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()
        # Built stage 1 first, then stage 2, in the original order, so parameter
        # initialisation consumes the RNG in the same sequence.
        self.downscale1 = self._make_encoder()
        self.residual1 = self._make_bottleneck()
        self.upscale1 = self._make_decoder()
        self.downscale2 = self._make_encoder()
        self.residual2 = self._make_bottleneck()
        self.upscale2 = self._make_decoder()

    @staticmethod
    def _make_encoder():
        """Three stride-2 downscales: 6 -> 64 -> 128 -> 256 channels."""
        return nn.Sequential(
            DownscaleBlock(6, 64),
            DownscaleBlock(64, 128),
            DownscaleBlock(128, 256),
        )

    @staticmethod
    def _make_bottleneck():
        """Three residual blocks at 256 channels."""
        return nn.Sequential(*(ResidualBlock(256) for _ in range(3)))

    @staticmethod
    def _make_decoder():
        """Three stride-2 upscales: 256 -> 128 -> 64 -> 3 channels."""
        return nn.Sequential(
            UpscaleBlock(256, 128),
            UpscaleBlock(128, 64),
            UpscaleBlock(64, 3),
        )

    @staticmethod
    def _fill_holes(base, mask, encoder, bottleneck, decoder):
        """Run one stage on ``base`` and write its prediction where ``mask`` is 0."""
        prediction = decoder(bottleneck(encoder(torch.cat((base, mask), dim=1))))
        return base + prediction * (1 - mask)

    def forward(self, x, mask):
        """Inpaint ``x`` where ``mask`` is 0 and keep it where ``mask`` is 1.

        Args:
            x: image batch. The notebook passes a normalised 3-channel image.
            mask: 1 for known pixels and 0 for holes, with the same spatial size
                as ``x``. Its channels plus ``x``'s channels must equal 6.

        Returns:
            The refined image. For a binary mask, known pixels equal ``x``.
        """
        known = x * mask
        coarse = self._fill_holes(known, mask, self.downscale1, self.residual1, self.upscale1)
        return self._fill_holes(coarse, mask, self.downscale2, self.residual2, self.upscale2)
    
# Preprocessing for the generator input.
# - ToTensor turns an HWC uint8 array into a CHW float tensor in [0, 1].
# - Normalize then standardises each channel. These constants match the common
#   ImageNet statistics; presumably the checkpoint was trained with the same
#   ones (confirm).
# stop_drawing undoes the normalisation with the same mean/std before display.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [7]:
# Side length (pixels) of both canvases; uploads are cropped to this square.
image_size = 256
# Initial brush width for mask strokes (changed via the Stroke Size widget).
stroke_size = 5

image = np.array([])
# Layer 0 shows the working image; layer 1 holds the semi-transparent strokes
# that stop_drawing reads back as the inpainting mask. sync_image_data lets
# get_image_data()/to_file() read the pixels back into Python.
left_canvas = MultiCanvas(2, width=image_size, height=image_size, sync_image_data=True)
# Preview of the inpainted result.
right_canvas = Canvas(width=image_size, height=image_size, sync_image_data=True)
left_canvas[1].global_alpha = 0.5
left_canvas[1].sync_image_data = True
is_drawing = False
is_uploaded = False
output = widgets.Output()
position = [0, 0]
mask = np.array([])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use the detected device instead of hard-coding "cuda", and remap the
# checkpoint's tensors onto it, so the notebook also runs on CPU-only machines.
generator = Generator().to(device)
# NOTE: torch.load unpickles the file; only load trusted checkpoints
# (recent PyTorch versions also accept weights_only=True).
generator.load_state_dict(torch.load("generator_new_50.pth", map_location=device))
# Inference only: put the model in evaluation mode.
generator.eval()

def upload_image(change):
    """FileUpload observer: load the chosen image into the left canvas.

    ``change.new`` holds the widget's value, a sequence of uploaded files, each
    a dict with a ``"content"`` entry. The first file is:
    - converted to RGB;
    - cropped to the top-left ``image_size`` x ``image_size`` region;
    - stored in the global ``image`` and drawn on layer 0 of ``left_canvas``.

    The preview canvas is cleared and drawing is enabled.
    """
    global image, is_uploaded
    with output:
        if not change.new:
            # Nothing to load (e.g. the widget's value was cleared).
            return
        left_canvas.clear()
        raw = Image.open(BytesIO(change.new[0]["content"]))
        # convert("RGB") turns palette, greyscale and RGBA uploads into three
        # channels. stop_drawing builds a same-shaped mask and the generator
        # expects 3 image + 3 mask channels.
        image = np.array(raw.convert("RGB").crop((0, 0, image_size, image_size)))
        left_canvas[0].put_image_data(image)
        right_canvas.clear()
        is_uploaded = True

def paint_on_left(x, y):
    """Mouse-move handler: extend the current stroke on the mask layer.

    While the button is held (``is_drawing``), draws a red segment of width
    ``stroke_size`` from the previous pointer position to (x, y) on layer 1 of
    ``left_canvas``. The pointer position is recorded on every move, drawing or
    not. Does nothing until an image has been uploaded.
    """
    if is_uploaded:
        with output, hold_canvas():
            layer = left_canvas[1]
            layer.stroke_style = "red"
            layer.line_width = stroke_size
            if is_drawing:
                layer.begin_path()
                layer.move_to(*position)
                layer.line_to(x, y)
                layer.stroke()
                layer.close_path()
            # Update the shared list in place; other handlers read it.
            position[:] = [x, y]

def confirm_preview(change):
    """Confirm button handler: accept the inpainted preview as the new image.

    Copies the preview from ``right_canvas`` into the global ``image``, shows it
    on layer 0 of ``left_canvas``, and clears both the stroke layer and the
    preview. Does nothing until an image has been uploaded, or when the preview
    canvas is empty.
    """
    global image
    if not is_uploaded:
        return
    with output:
        preview = right_canvas.get_image_data()
        if not preview[:, :, 3].any():
            # Fully transparent preview (nothing generated yet, or it was
            # reverted). Confirming it would replace the image with black.
            return
        with hold_canvas():
            # get_image_data() returns RGBA. Keep only RGB so the next
            # inference receives a 3-channel image (and builds a 3-channel mask).
            image = preview[:, :, :3].copy()
            left_canvas[0].put_image_data(image)
            left_canvas[1].clear()
            right_canvas.clear()

def revert_changes(change):
    """Revert button handler: discard the current strokes and the preview.

    Clears the stroke layer of ``left_canvas`` and all of ``right_canvas``, and
    resets ``mask`` to zeros with its current shape and dtype. The working
    ``image`` is left untouched. Does nothing until an image has been uploaded.
    """
    global mask
    if is_uploaded:
        with hold_canvas():
            left_canvas[1].clear()
            mask = np.zeros(mask.shape, dtype=mask.dtype)
            right_canvas.clear()

def start_drawing(x, y):
    """Mouse-down handler: begin a stroke at (x, y).

    Does nothing until an image has been uploaded.
    """
    global is_drawing
    if not is_uploaded:
        return
    # Anchor the stroke at the press location. Otherwise the first segment drawn
    # by paint_on_left starts from whatever position was last recorded.
    position[0] = x
    position[1] = y
    is_drawing = True

def stop_drawing(x, y):
    """Mouse-up handler: end the stroke and preview the inpainted result.

    Steps:
    1. Read the stroke layer of ``left_canvas`` into the boolean global
       ``mask`` (True wherever anything was painted).
    2. Run ``generator`` on the current ``image`` with those pixels as holes.
    3. Draw the result on ``right_canvas``.

    Does nothing until an image has been uploaded.
    """
    global is_drawing, mask
    if not is_uploaded:
        return
    with output:
        is_drawing = False
        # Any pixel with non-zero alpha on the stroke layer is a hole.
        mask = left_canvas[1].get_image_data()[:, :, 3] != 0

        # The generator takes 3 image + 3 mask channels. Drop alpha from RGBA
        # images and expand single-channel ones.
        if image.ndim == 3:
            rgb = np.ascontiguousarray(image[:, :, :3])
        else:
            rgb = np.stack([image] * 3, axis=-1)

        # 1 = known pixel, 0 = hole, one mask channel per colour channel.
        known = np.ones(rgb.shape, dtype=np.float32)
        known[mask] = 0
        image_tensor = transform(rgb).unsqueeze(0).to(device)
        mask_tensor = torch.from_numpy(known).permute(2, 0, 1).unsqueeze(0).to(device)

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            result = generator(image_tensor, mask_tensor).squeeze(0).cpu()

        # Undo the Normalize step of `transform`. Then clamp so values outside
        # [0, 1] saturate instead of wrapping around in the uint8 cast.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        result = (result * std + mean).clamp(0.0, 1.0)
        # CHW -> HWC uint8 array, the layout put_image_data draws.
        preview = (result.permute(1, 2, 0).numpy() * 255).round().astype(np.uint8)

        with hold_canvas():
            right_canvas.put_image_data(preview, 0, 0)
    
def stroke_changed(change):
    """Stroke-size observer: store the widget's new value as the brush width."""
    global stroke_size
    new_width = change.new
    stroke_size = new_width

def save_result(change, filename="result.png"):
    """Save button handler: write the preview on ``right_canvas`` to a PNG file.

    Args:
        change: the Button instance passed by ``on_click`` (unused).
        filename: destination path; must end in ``.png``. Defaults to
            ``"result.png"``.

    Does nothing until an image has been uploaded. Errors (e.g. nothing to save)
    are shown in the ``output`` widget.
    """
    if not is_uploaded:
        return
    # Button callbacks don't print to the cell, so route messages and
    # tracebacks through the Output widget.
    with output:
        # ipycanvas' Canvas.to_file(filename) writes the synced PNG data itself
        # and returns nothing. It relies on sync_image_data=True.
        right_canvas.to_file(filename)
        print(f"Result saved as {filename}")


In [8]:
# Controls: upload an image, accept / discard / save the preview.
upload_button = widgets.FileUpload(accept="image/*", multiple=False)
upload_button.observe(upload_image, names="value")

confirm_button = widgets.Button(description="Confirm")
confirm_button.on_click(confirm_preview)

revert_button = widgets.Button(description="Revert")
revert_button.on_click(revert_changes)

save_button = widgets.Button(description="Save")
save_button.on_click(save_result)

# Bounded so the brush width can never be zero or negative. The canvas ignores
# such line widths, so the brush would silently keep its old width.
stroke_size_change = widgets.BoundedIntText(
    value=stroke_size,
    min=1,
    max=image_size,
    description='Stroke Size:',
    disabled=False,
)
stroke_size_change.observe(stroke_changed, names='value')

# Mouse handlers for painting the mask on the left canvas.
left_canvas.on_mouse_down(start_drawing)
left_canvas.on_mouse_move(paint_on_left)
left_canvas.on_mouse_up(stop_drawing)

display(widgets.HBox([left_canvas, right_canvas]))
display(stroke_size_change)
display(widgets.HBox([upload_button, confirm_button, revert_button, save_button]))
display(output)

HBox(children=(MultiCanvas(height=256, image_data=b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x0…

IntText(value=5, description='Stroke Size:')

HBox(children=(FileUpload(value=(), accept='image/*', description='Upload'), Button(description='Confirm', sty…

Output()