In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2

from sod_model import SODModel
from data_loader import get_dataloaders


In [None]:
# Prefer a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = SODModel().to(device)
# map_location lets a checkpoint that was saved on a GPU be restored on a
# CPU-only runtime (and vice versa).
# NOTE: torch.load is pickle-based; only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load("best_model.pth", map_location=device)
)
# Switch to inference behaviour (matters for layers such as dropout or
# batch-norm, if SODModel contains any).
model.eval()

print("Model loaded on:", device)


In [None]:
def preprocess_image(img_path, size=(224, 224)):
    """Read an image from disk and turn it into a model-ready tensor.

    Args:
        img_path: Path to an image file that OpenCV can decode.
        size: Target ``(width, height)`` handed to ``cv2.resize``. Defaults to
            ``(224, 224)``, the size this notebook has always used.

    Returns:
        ``(img, tensor)`` where ``img`` is the resized RGB image as returned by
        OpenCV (``uint8``, shape ``(H, W, 3)``), kept for display, and
        ``tensor`` is a ``float32`` tensor of shape ``(1, 3, H, W)`` with
        values in ``[0, 1]``, placed on the notebook-level ``device``.

    Raises:
        ValueError: If OpenCV cannot read ``img_path`` (missing file,
            unreadable file, or unsupported/corrupt image data).
    """
    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None rather than raising; without
    # this check the error would surface later as an opaque cvtColor assertion.
    if img is None:
        raise ValueError(
            f"Could not read image {img_path!r} "
            "(missing, unreadable, or not a supported image format)"
        )
    # OpenCV decodes to BGR; convert to RGB so matplotlib displays it correctly.
    # NOTE(review): presumably this also matches the training pipeline's channel
    # order — confirm against data_loader.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size)
    # Only a [0, 1] rescale is applied, no mean/std normalisation.
    # NOTE(review): confirm this matches what SODModel was trained on.
    img_norm = img.astype(np.float32) / 255.0
    # HWC -> CHW, then add a leading batch dimension of 1.
    tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).to(device)
    return img, tensor


In [None]:
def predict_mask(model, tensor):
    """Run ``model`` on ``tensor`` with autograd disabled; return a NumPy array.

    Every singleton dimension of the output is squeezed away, so a
    ``(1, 1, H, W)`` output becomes ``(H, W)``. The values are returned exactly
    as the model produced them; no sigmoid or thresholding is applied here.
    """
    with torch.no_grad():
        raw_output = model(tensor)
    squeezed = raw_output.squeeze()
    return squeezed.cpu().numpy()


In [None]:
def visualize(img, mask):
    """Show the input image, the predicted mask, and a green overlay side by side.

    Args:
        img: RGB image of shape ``(H, W, 3)``. Float images are expected in
            ``[0, 1]``; integer images (e.g. ``uint8``) are rescaled by 1/255.
        mask: 2-D saliency map, expected in ``[0, 1]``. If its spatial size
            differs from ``img`` it is resized to ``(H, W)`` before overlaying.
    """
    img = np.asarray(img)
    if np.issubdtype(img.dtype, np.integer):
        # With an integer image, the assignment below would truncate the float
        # mask values and the [0, 1] clip would then wipe out the picture.
        img = img.astype(np.float32) / 255.0

    mask = np.asarray(mask, dtype=np.float32)
    height, width = img.shape[:2]
    if mask.shape != (height, width):
        # Resize so the overlay broadcast works when the model's output
        # resolution differs from the image being displayed.
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_LINEAR)

    # Highlight salient pixels by raising the green channel to at least the mask value.
    overlay = img.copy()
    overlay[:, :, 1] = np.maximum(overlay[:, :, 1], mask)
    overlay = np.clip(overlay, 0, 1)

    panels = [
        (img, "Input", {}),
        # Pin the grey scale to [0, 1]; otherwise imshow stretches it to the
        # mask's own min/max and a weak prediction looks fully white.
        (mask, "Prediction", {"cmap": "gray", "vmin": 0, "vmax": 1}),
        (overlay, "Overlay", {}),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (image, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    plt.show()


In [None]:
# google.colab is only importable inside a Colab runtime; this cell is Colab-only.
from google.colab import files

# Opens a browser file picker. The returned dict is keyed by uploaded filename;
# the loop below relies on the uploads also being present in the working
# directory so preprocess_image can read them by name.
uploaded = files.upload()

for name in uploaded:
    img, tensor = preprocess_image(name)
    pred = predict_mask(model, tensor)
    # img comes back as 0-255 pixel values; visualize clips to [0, 1], so rescale.
    visualize(img / 255.0, pred)
