In [1]:
from PIL import Image
import scipy.ndimage
import os
import numpy as np
# import sys

In [2]:
import gradio as gr

In [3]:
import torch
import torchvision.transforms as transforms

In [4]:
from model_edge2hats import Generator, ResNetBlock

In [5]:
# JavaScript snippet handed to `gr.Blocks(js=...)` further down.
# It inspects the page URL and, unless the `__theme` query parameter is already
# 'light', sets it to 'light' and navigates to the updated URL (i.e. reloads the
# page once with the light theme selected).
js_func = """
    function refresh() {
        const url = new URL(window.location);
        if (url.searchParams.get('__theme') !== 'light') {
            url.searchParams.set('__theme', 'light');
            window.location.href = url.href;
        }
    }
    """

In [6]:
# Global configuration and the Gradio components shared by the interface.
resolution = 500  # Side length (pixels) of the displayed output image
# NOTE(review): nz / ngf presumably mirror the generator's latent and feature
# sizes used at training time -- confirm against model_edge2hats.
nz = 100
ngf = 64

# Two fixed brush colours: near-black "#010101" (default) and light grey "#F0F0F0".
edge_brush = gr.Brush(default_size = 4, colors = ["#010101", "#F0F0F0"], color_mode = "fixed", default_color = "#010101")

# Drawing canvas. `sources` must be an iterable of source names; the previous
# `("upload")` was just the string "upload" in parentheses, so it is spelled
# here as a real one-element tuple.
canvas = gr.ImageEditor(label = "Canvas", show_label = False, interactive = True, type = "numpy",
                        image_mode = "RGB", sources = ("upload",), brush = edge_brush, eraser = False, transforms = [])

# Read-only output image, shown at resolution x resolution.
output_canvas = gr.Image(label = "Output", show_label = False, interactive = False, height = resolution, width = resolution)

In [7]:
# Preprocessing for the sketch image: resize and centre-crop to 128x128,
# convert to a tensor, then normalise with mean 0.5 / std 0.5.
transform_edge = transforms.Compose(
    [
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)

# Checkpoint location, relative to the working directory.
model_sub_path = os.path.join("models", "50.pth")

# Build a fresh generator, then copy in the weights of the object stored in the
# checkpoint (the file holds an object exposing `state_dict()`, loaded onto CPU).
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model = Generator(1)
model_tmp = torch.load(model_sub_path, map_location="cpu")
model.load_state_dict(model_tmp.state_dict())

<All keys matched successfully>

In [8]:
def resize_template(orignal, size = None):
    """Scale an image array so its longer side equals the target size, then pad to a square.

    The two leading axes are treated as height and width. Both are scaled by
    the same factor with ``scipy.ndimage.zoom``; any further axes (for example
    a channel axis) keep their length. The shorter side is then padded at the
    bottom / right with the constant value 255 so the first two axes both have
    the target length.

    Parameters
    ----------
    orignal : numpy.ndarray
        Input image with at least two axes. (The parameter name keeps its
        original spelling so existing keyword callers continue to work.)
    size : int, optional
        Target side length. When omitted, the module-level ``resolution`` is
        used, which matches the previous behaviour.

    Returns
    -------
    numpy.ndarray
        The rescaled array, padded with 255 where needed.
    """
    target = resolution if size is None else size
    height, width = orignal.shape[:2]
    scale_factor = target / max(height, width)

    # Zoom only the spatial axes; a factor of 1 leaves trailing axes untouched.
    extra_axes = orignal.ndim - 2
    zoom_factors = (scale_factor, scale_factor) + (1,) * extra_axes
    square_array = scipy.ndimage.zoom(orignal, zoom_factors)

    # Guard against a negative pad should rounding ever overshoot the target.
    down_pad = max(target - square_array.shape[0], 0)
    right_pad = max(target - square_array.shape[1], 0)
    if down_pad == 0 and right_pad == 0:
        return square_array

    pad_width = ((0, down_pad), (0, right_pad)) + ((0, 0),) * extra_axes
    return np.pad(square_array,
                  pad_width,
                  mode = "constant",
                  constant_values = 255)

In [9]:
def _binarize_sketch(composite, background):
    """Collapse the canvas composite into a single-channel 0/255 uint8 edge map."""
    channel_sum = np.sum(composite, axis=2)
    if background is None or background.sum() == 0:
        # Blank canvas: only pixels whose channels sum to 3 (the "#010101"
        # brush colour) count as strokes; everything else becomes white.
        edge_map = np.where(channel_sum == 3, 0, 255)
    else:
        # Uploaded background: pure white (3 * 255) stays white, every other
        # pixel becomes black.
        # NOTE(review): strokes drawn with the light "#F0F0F0" brush colour
        # therefore also turn black here -- confirm this is intended.
        edge_map = np.where(channel_sum == 3 * 255, 255, 0)
    return edge_map.astype(np.uint8)


def _pad_to_square(edge_map):
    """Centre a 2-D edge map on a square canvas, filling the margins with 255."""
    height, width = edge_map.shape
    if height == width:
        return edge_map
    side = max(height, width)
    top_pad = (side - height) // 2
    bottom_pad = side - height - top_pad
    left_pad = (side - width) // 2
    right_pad = side - width - left_pad
    return np.pad(edge_map,
                  ((top_pad, bottom_pad), (left_pad, right_pad)),
                  mode="constant",
                  constant_values=255)


def color_fill(image_array):
    """Generate an image from the sketch currently on the canvas.

    Parameters
    ----------
    image_array : dict
        Editor value with at least the keys ``"background"`` and
        ``"composite"``. The composite is expected to be an H x W x 3 array
        (the channels are summed along axis 2).

    Returns
    -------
    numpy.ndarray
        uint8 image: the generator output min-max rescaled to 0..255 and
        enlarged by ``resolution / 128`` along its first two axes.

    Raises
    ------
    gr.Error
        If nothing was submitted (no editor value or no composite image).
    """
    if image_array is None or image_array.get("composite") is None:
        raise gr.Error("Draw or upload an edge sketch before submitting.")

    edge_map = _binarize_sketch(image_array["composite"], image_array["background"])
    edge_map = _pad_to_square(edge_map)

    image_tensor = transform_edge(Image.fromarray(edge_map)).unsqueeze(0)
    # Fresh latent vector per call; its length comes from the shared `nz`
    # setting rather than a repeated literal.
    noise = torch.randn(1, nz, 1, 1)

    model.eval()
    with torch.no_grad():
        generated = model(noise, image_tensor)[0]

    # Min-max rescale to [0, 1]; the epsilon avoids dividing by zero when the
    # output is constant.
    low = float(generated.min())
    high = float(generated.max())
    generated = generated.sub_(low).div_(max(high - low, 1e-5))

    # Convert to a channels-last uint8 array.
    generated = generated.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    # Enlarge the two spatial axes for display; the channel axis is left alone.
    zoom_factor = [resolution / 128, resolution / 128, 1]
    return scipy.ndimage.zoom(generated, zoom_factor)

In [10]:
# Build the UI: a Blocks page (with the light-theme JS hook) wrapping a single
# Interface that maps the drawing canvas to the generated image.
with gr.Blocks(js = js_func) as demo:
    draw_interface = gr.Interface(
        color_fill,
        theme = "default",
        title = "Edge2Hats",
        # String form of the flagging switch; the boolean form is deprecated.
        allow_flagging = "never",
        inputs = canvas,
        outputs = output_canvas,
    )

# Launch once the Blocks context is closed. The Interface is already part of
# `demo`, so it is not passed to launch(): launch's first positional parameter
# is `inline`, not a component.
# NOTE: server_name "0.0.0.0" together with share=True exposes the app beyond localhost.
demo.launch(server_name = "0.0.0.0", server_port = 3308, share = True)



Running on local URL:  http://0.0.0.0:3308
IMPORTANT: You are using gradio version 4.24.0, however version 4.29.0 is available, please upgrade.
--------
IMPORTANT: You are using gradio version 4.24.0, however version 4.29.0 is available, please upgrade.
--------
Running on public URL: https://4aa1cabf216e23a075.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
