In [9]:
!pip install gradio



In [2]:
# Mount Google Drive and work from a fresh local copy of the project so that
# imports (sod_model*, etc.) and relative checkpoint paths resolve against it.
from google.colab import drive
drive.mount('/content/drive')

import shutil, os

PROJECT_DIR = "/content/drive/MyDrive/SOD_PROJECT_FINAL"
LOCAL_PROJECT_DIR = "/content/project"

# Remove any previous copy first (pure-Python equivalent of `rm -rf`), so files
# deleted from Drive do not linger locally; copytree requires the destination
# not to exist.
shutil.rmtree(LOCAL_PROJECT_DIR, ignore_errors=True)
shutil.copytree(PROJECT_DIR, LOCAL_PROJECT_DIR)

os.chdir(LOCAL_PROJECT_DIR)
print("Loaded project files:", os.listdir())

Mounted at /content/drive
Loaded project files: ['sod_model_deep.py', 'data_loader.py', 'sod_model_dropout.py', 'sod_model.py', 'evaluate.py', 'train.py', 'checkpoints']


In [10]:
import torch
import numpy as np
import cv2
from PIL import Image
import gradio as gr
import time
import torchvision.transforms as transforms


from sod_model import SimpleUNet
from sod_model_dropout import UNetDropout
from sod_model_deep import UNetDeep

In [11]:
def get_transform(img_size):
    """Build the inference-time preprocessing pipeline.

    Args:
        img_size: side length in pixels; every image is resized to a square
            ``img_size x img_size`` (aspect ratio is not preserved).

    Returns:
        A ``transforms.Compose`` that resizes a PIL image and then converts it
        with ``ToTensor`` (CHW float tensor scaled to [0, 1]).

    Note:
        No mean/std normalization is applied; presumably this matches the
        training pipeline — TODO confirm against data_loader.py.
    """
    target_hw = (img_size, img_size)
    pipeline = [
        transforms.Resize(target_hw),
        transforms.ToTensor(),
    ]
    return transforms.Compose(pipeline)

In [12]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# One row per selectable model: (dropdown label, model class, checkpoint path).
# Checkpoint paths are relative to the current working directory.
_MODEL_REGISTRY = (
    ("Simple U-Net", SimpleUNet, "checkpoints/best_model_224.pth"),
    ("U-Net Dropout", UNetDropout, "checkpoints/model_dropout.pth"),
    ("Deep U-Net", UNetDeep, "checkpoints/model_deep.pth"),
)

# Label -> class, and label -> checkpoint path, in registry order.
MODEL_CLASSES = {label: cls for label, cls, _ in _MODEL_REGISTRY}
MODEL_PATHS = {label: path for label, _, path in _MODEL_REGISTRY}

In [13]:
# Instantiated models keyed by dropdown label, so each checkpoint is read from
# disk and moved to DEVICE once per kernel session rather than on every request.
_MODEL_CACHE = {}


def _get_model(model_name):
    """Return the eval-mode model for ``model_name``, loading it on first use."""
    model = _MODEL_CACHE.get(model_name)
    if model is None:
        model = MODEL_CLASSES[model_name]().to(DEVICE)
        # The loaded object goes straight into load_state_dict, so it must be a
        # plain state dict; weights_only=True refuses arbitrary pickled objects.
        state_dict = torch.load(
            MODEL_PATHS[model_name], map_location=DEVICE, weights_only=True
        )
        model.load_state_dict(state_dict)
        model.eval()
        _MODEL_CACHE[model_name] = model
    return model


def predict(image, model_name):
    """Run saliency detection on one image with the selected model.

    Args:
        image: PIL image from the Gradio upload widget, or None when the user
            submits without uploading anything.
        model_name: a key of MODEL_CLASSES / MODEL_PATHS (dropdown choice).

    Returns:
        Tuple of
            input image resized to 224x224 (RGB uint8 array),
            predicted mask (uint8 array, min-max scaled to 0..255),
            colour heatmap blended over the input (RGB uint8 array),
            inference time as a string, e.g. "0.012 seconds".

    Raises:
        gr.Error: if no image was uploaded or no valid model was selected, so
            Gradio shows a message instead of a server-side exception.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if model_name not in MODEL_CLASSES:
        raise gr.Error("Please choose a model.")

    size = 224  # inference resolution; assumed to match training — TODO confirm

    orig = image.convert("RGB")
    img_tensor = get_transform(size)(orig).unsqueeze(0).to(DEVICE)

    model = _get_model(model_name)

    # CUDA kernels run asynchronously; synchronize around the forward pass so
    # the measured time covers the actual computation, not just the launch.
    use_cuda = img_tensor.is_cuda
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        pred = model(img_tensor)
    if use_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    # squeeze() drops the batch (and presumably a single channel) dimension —
    # assumes the model emits one saliency map per image; TODO confirm.
    # Min-max scaling works for logits or probabilities alike, but discards
    # the absolute confidence level.
    pred = pred.squeeze().cpu().numpy()
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-6)
    mask = (pred * 255).astype(np.uint8)

    img_np = np.array(orig.resize((size, size)))
    # applyColorMap returns BGR; convert it to RGB *before* blending so both
    # inputs to addWeighted share the PIL image's RGB channel order.
    heatmap_rgb = cv2.cvtColor(
        cv2.applyColorMap(mask, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB
    )
    overlay = cv2.addWeighted(img_np, 0.6, heatmap_rgb, 0.4, 0)

    return (
        img_np,
        mask,
        overlay,
        f"{elapsed:.3f} seconds"
    )

In [None]:
# Gradio UI: upload an image, pick a model, and see the input, the predicted
# saliency mask, the heatmap overlay, and the forward-pass time.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        # Pre-select the first model so predict() never receives None.
        gr.Dropdown(
            list(MODEL_CLASSES.keys()),
            value=next(iter(MODEL_CLASSES)),
            label="Choose Model",
        ),
    ],
    outputs=[
        gr.Image(label="Input"),
        gr.Image(label="Predicted Mask"),
        gr.Image(label="Overlayed"),
        gr.Text(label="Inference Time")
    ],
    title="SOD Neural Network Demo",
    description="Upload an image to see saliency detection results."
)

# On hosted notebooks (Colab) Gradio forces share=True anyway; setting it
# explicitly makes the public link intentional. debug=True keeps this cell
# running (blocking "Run All") so server errors are printed below it.
demo.launch(share=True, debug=True)

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://4e4bdd004d9cc3cfd3.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/uvicorn/protocols/http/h11_impl.py", line 403, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/fastapi/applications.py", line 1134, in __call__
    await super().__call__(scope, receive, send)
  File "/usr/local/lib/python3.12/dist-packages/starlette/applications.py", line 107, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.12/dist-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/usr/local/lib/python3.12/dist-packages/starlette/middleware/errors.py",