In [8]:
# ==========================================
# 1. Install + imports
# ==========================================
!pip install -q gradio

from google.colab import drive
drive.mount("/content/drive", force_remount=True)

import os, shutil, sys, time
import torch
import numpy as np
import cv2
from PIL import Image
import gradio as gr
from torchvision import transforms


print("Torch version:", torch.__version__)

Mounted at /content/drive
Torch version: 2.9.0+cu126


In [9]:
# ==========================================
# 2. Load project code from Drive
# ==========================================
project_path = "/content/project"

if not os.path.exists(project_path):
    shutil.copytree("/content/drive/MyDrive/SOD_PROJECT_FINAL", project_path)

sys.path.append(project_path)
print("Project files:", os.listdir(project_path))

from sod_model_simple_new import SimpleUNet
from sod_model_dropout import UNetDropout
from sod_model_deep import UNetDeep

Project files: ['sod_model_dropout.py', 'sod_model_simple_new.py', 'test_checkpoints.py', 'sod_model_deep.py', '.ipynb_checkpoints', 'evaluate.py', '.git', '__pycache__', 'data_loader.py', '.gitignore', 'checkpoints', 'sod_model.py', 'train.py']


In [11]:
# ==========================================
# 3. Device, transforms, model paths
# ==========================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

def get_transform(img_size=224):
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])

CKPT_DIR = "/content/drive/MyDrive/SOD_PROJECT_FINAL/checkpoints"

MODEL_CLASSES = {
    "Simple U-Net": SimpleUNet,
    "U-Net Dropout": UNetDropout,
    "Deep U-Net": UNetDeep,
}


MODEL_PATHS = {
    "Simple U-Net": os.path.join(CKPT_DIR, "best_model_224_new.pth"),
    "U-Net Dropout": os.path.join(CKPT_DIR, "model_dropout.pth"),
    "Deep U-Net": os.path.join(CKPT_DIR, "model_deep.pth"),
}

print("Checkpoint files:", os.listdir(CKPT_DIR))

Using device: cuda
Checkpoint files: ['best_model_224.pth', 'model_dropout.pth', 'model_deep.pth', 'last_checkpoint.pth', 'best_model_224_new.pth', 'last_checkpoint_new.pth']


In [12]:
# ==========================================
# 4. Load all 3 models ONCE (no reloading in predict)
# ==========================================
LOADED_MODELS = {}

for name, cls in MODEL_CLASSES.items():
    path = MODEL_PATHS[name]
    print(f"Loading {name} from {path}")
    model = cls().to(DEVICE)
    state = torch.load(path, map_location=DEVICE)
    model.load_state_dict(state, strict=True)
    model.eval()
    LOADED_MODELS[name] = model

print("✅ All three models loaded.")

Loading Simple U-Net from /content/drive/MyDrive/SOD_PROJECT_FINAL/checkpoints/best_model_224_new.pth
Loading U-Net Dropout from /content/drive/MyDrive/SOD_PROJECT_FINAL/checkpoints/model_dropout.pth
Loading Deep U-Net from /content/drive/MyDrive/SOD_PROJECT_FINAL/checkpoints/model_deep.pth
✅ All three models loaded.


In [13]:
# ==========================================
# 5. Prediction function
# ==========================================
tf = get_transform(224)

def predict(image, model_name):
    """
    Perform inference on a single image using the selected model.

    Returns:
      - resized input image (224x224 RGB)
      - predicted saliency map (grayscale)
      - overlay (input + heatmap)
      - inference time string
    """
    # 1) Prepare input
    orig = image.convert("RGB")
    img_tensor = tf(orig).unsqueeze(0).to(DEVICE)

    # 2) Select already-loaded model
    model = LOADED_MODELS[model_name]

    # 3) Forward pass
    start = time.time()
    with torch.no_grad():
        pred = model(img_tensor)
    elapsed = time.time() - start

    # 4) Post-process prediction to [0,1]
    pred = pred.squeeze().cpu().numpy()
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-6)

    # 5) Build overlay
    img_np = np.array(orig.resize((224, 224)))
    saliency_gray = (pred * 255).astype(np.uint8)
    saliency_color = cv2.applyColorMap(saliency_gray, cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(img_np, 0.6, saliency_color, 0.4, 0)

    return (
        img_np,
        saliency_gray,
        cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
        f"{elapsed:.3f} seconds"
    )

In [None]:
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Dropdown(list(MODEL_CLASSES.keys()), label="Choose Model")
    ],
    outputs=[
        gr.Image(label="Input"),
        gr.Image(label="Predicted Mask"),
        gr.Image(label="Overlayed"),
        gr.Text(label="Inference Time")
    ],
    title="SOD Neural Network Demo",
    description="Upload an image to see saliency detection results."
)

demo.launch(debug=True)

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://8d90f75109db57c546.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/uvicorn/protocols/http/h11_impl.py", line 403, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/fastapi/applications.py", line 1134, in __call__
    await super().__call__(scope, receive, send)
  File "/usr/local/lib/python3.12/dist-packages/starlette/applications.py", line 107, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.12/dist-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/usr/local/lib/python3.12/dist-packages/starlette/middleware/errors.py",