# In [None]:
import io
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms.functional as TF
from IPython.display import clear_output
import ipywidgets as widgets

from sod_model import UNet

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Side length, in pixels, of the square image that is fed to the model.
IMG_SIZE = 128


def load_model(weights="best_unet_improved.pth"):
    """Construct the UNet, restore its parameters, and return it ready for inference.

    Args:
        weights: Path of the saved state dict to load.

    Returns:
        The UNet instance placed on ``DEVICE`` and switched to eval mode.
    """
    net = UNet(in_channels=3, base_channels=32)
    net = net.to(DEVICE)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(weights, map_location=DEVICE)
    net.load_state_dict(checkpoint)

    # Inference only: switch the module out of training mode.
    net.eval()
    return net


# Load the network once at start-up so every upload reuses the same instance.
model = load_model()

# Single-file image picker; its ``value`` trait is observed by ``run_demo``.
uploader = widgets.FileUpload(accept='image/*', multiple=False)
display(uploader)


def run_demo(change):
    """Run saliency inference on the uploaded image and plot the result.

    Intended as an observer callback for the ``value`` trait of the
    module-level ``uploader`` widget. It redraws the cell output with the
    uploader, a three-panel figure (resized input, binary mask, red-tinted
    overlay) and the measured forward-pass time.

    Args:
        change: Trait-change notification supplied by ``observe``. It is not
            inspected; the current upload is read from ``uploader.value``.
    """
    uploads = uploader.value
    if not uploads:
        return

    clear_output(wait=True)
    display(uploader)

    # ipywidgets 7 exposes ``value`` as a dict keyed by file name, while
    # ipywidgets 8 exposes a tuple of per-file records; accept both layouts.
    if isinstance(uploads, dict):
        uploads = list(uploads.values())
    file_info = uploads[0]

    # ``content`` may be ``bytes`` or a ``memoryview``; BytesIO accepts either.
    img = Image.open(io.BytesIO(file_info["content"])).convert("RGB")

    img_resized = img.resize((IMG_SIZE, IMG_SIZE))

    # Add a leading batch dimension and move the tensor to the model's device.
    img_tensor = TF.to_tensor(img_resized).unsqueeze(0).to(DEVICE)

    # CUDA kernels are launched asynchronously, so without synchronising the
    # timer could stop before the forward pass has actually finished.
    on_cuda = img_tensor.is_cuda
    if on_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        pred = model(img_tensor)
    if on_cuda:
        torch.cuda.synchronize()
    duration = time.perf_counter() - start

    # NOTE(review): thresholding at 0.5 assumes the model already outputs
    # probabilities in [0, 1] rather than raw logits -- confirm against UNet.
    threshold = 0.5
    pred_bin = (pred > threshold).float()

    img_np = np.array(img_resized)
    # First batch item, first channel.
    # NOTE(review): assumes the output is IMG_SIZE x IMG_SIZE so that it
    # lines up with the resized image below -- confirm.
    pred_np = pred_bin[0, 0].cpu().numpy()

    # Highlight the mask by boosting the red channel where the mask is set.
    overlay = img_np.astype(np.float32) / 255.0
    overlay[..., 0] = np.clip(overlay[..., 0] + pred_np * 0.6, 0, 1)

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))

    axs[0].imshow(img_np)
    axs[0].set_title("Input Image")
    axs[0].axis("off")

    axs[1].imshow(pred_np, cmap="gray")
    axs[1].set_title("Saliency Mask")
    axs[1].axis("off")

    axs[2].imshow(overlay)
    axs[2].set_title("Overlay")
    axs[2].axis("off")

    plt.tight_layout()
    plt.show()

    print(f"Inference time per image: {duration*1000:.2f} ms (on {DEVICE})")


# Re-run the demo whenever the uploader's ``value`` trait changes.
uploader.observe(run_demo, names="value")