In [1]:
import gradio as gr
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image, ImageOps
from scipy.ndimage import center_of_mass

In [2]:
model = tf.keras.models.load_model('model.h5')



In [3]:
def center_image(img_arr):
    mask = img_arr < 128
    if not np.any(mask):  
        return img_arr
    cy, cx = center_of_mass(mask)
    if np.isnan(cx) or np.isnan(cy):
        return img_arr
    shiftx = np.round(img_arr.shape[1] / 2.0 - cx).astype(int)
    shifty = np.round(img_arr.shape[0] / 2.0 - cy).astype(int)
    return np.roll(np.roll(img_arr, shifty, axis=0), shiftx, axis=1)

In [5]:
def thicken(arr, k=2):
    kernel = np.ones((k, k), np.uint8)
    return cv2.dilate(arr, kernel, iterations=1)

def preprocess(img):
    im = Image.fromarray(img)

    if im.mode == "RGBA":  
        _, _, _, a = im.split()
        im = a

    im = im.convert("L")               
    im = ImageOps.invert(im)           
    bbox = im.getbbox()
    if bbox:
        im = im.crop(bbox)             
    im = im.resize((20, 20), Image.LANCZOS)
    im = ImageOps.pad(im, (28, 28), color=255, centering=(0.5, 0.5))

    arr = np.array(im)
    arr = (arr < 128).astype(np.uint8) * 255  
    arr = center_image(arr)                   
    arr = thicken(arr, k=2)                   

    return arr

In [6]:
def sketch_recognition(img):
    if img is None:
        return {}, None

    composite = img["composite"]
    arr = preprocess(composite)

    x = arr.astype("float32") / 255.0
    x = x.reshape(-1, 28, 28, 1)

    preds = model.predict(x).tolist()[0]

    return {str(i): preds[i] for i in range(10)}, arr


In [9]:
interface = gr.Interface(
    fn=sketch_recognition,
    inputs=gr.Sketchpad(type="numpy"),  
    outputs=[
        gr.Label(num_top_classes=3),              
        gr.Image(type="numpy", image_mode="L")    
    ]
)

In [11]:
interface.launch()

Rerunning server... use `close()` to stop if you need to change `launch()` parameters.
----

To create a public link, set `share=True` in `launch()`.




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 61ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 65ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 57ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 69ms/step
