In [None]:
# Import libraries
from PIL import Image
import torch
from torch.utils.data import DataLoader
from data_loader import get_dataloaders
from sod_model import SOD_CNN
import matplotlib.pyplot as plt
import torchvision.transforms as T

In [None]:
# Pick the compute device: use the GPU when CUDA is available, otherwise
# fall back to the CPU. The choice is printed so it is visible in the output.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

In [None]:
# Build the network, move it to the selected device, restore the trained
# weights from disk, and switch to evaluation mode for inference.
model = SOD_CNN()
model = model.to(device)
# NOTE: torch.load unpickles the checkpoint file, so only load checkpoints
# from a trusted source. map_location remaps the stored tensors onto `device`.
model.load_state_dict(
    torch.load("best_sod_model.pth", map_location=device)
)
model.eval()

In [None]:
# Preprocessing applied to every input image before it is fed to the model.
transform = T.Compose(
    [
        # Resize to a fixed 224x224 input (aspect ratio is not preserved).
        T.Resize(size=(224, 224)),
        # Convert the PIL image to a channels-first float tensor.
        T.ToTensor(),
    ]
)

In [None]:
# Function to run demo on a single image
def run_demo(image_path):
    """Run the loaded model on a single image and plot the prediction.

    Displays a three-panel figure: the preprocessed input image, the
    predicted mask in grayscale, and the mask blended over the input with
    the 'jet' colormap.

    Relies on the notebook-level ``model``, ``transform`` and ``device``.

    Args:
        image_path: Path (or file-like object) accepted by ``PIL.Image.open``.

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    # Open inside a context manager so the underlying file handle is closed
    # as soon as convert() has produced an in-memory RGB copy.
    with Image.open(image_path) as source_image:
        img = source_image.convert("RGB")

    # Preprocess once and reuse the tensor for both inference and display
    # (the pipeline is deterministic, so running it twice was redundant).
    img_tensor = transform(img)
    input_tensor = img_tensor.unsqueeze(0).to(device)  # add batch dimension

    # Inference
    with torch.no_grad():
        output = model(input_tensor)

    # Assumes the model returns a single tensor that squeezes down to a
    # 2-D (H, W) map that imshow can render -- TODO confirm against SOD_CNN.
    output_mask = output.squeeze().cpu().numpy()
    # Channels-first -> channels-last, the layout imshow expects.
    img_np = img_tensor.permute(1, 2, 0).numpy()

    # Visualization
    fig, (ax_input, ax_mask, ax_overlay) = plt.subplots(1, 3, figsize=(10, 3))

    ax_input.imshow(img_np)
    ax_input.set_title("Input Image")
    ax_input.axis('off')

    ax_mask.imshow(output_mask, cmap='gray')
    ax_mask.set_title("Predicted Mask")
    ax_mask.axis('off')

    # Draw the image first, then the semi-transparent mask on top of it.
    ax_overlay.imshow(img_np)
    ax_overlay.imshow(output_mask, cmap='jet', alpha=0.5)
    ax_overlay.set_title("Overlay")
    ax_overlay.axis('off')

    plt.show()

In [None]:
# Run the demo on one image.
# NOTE(review): "image_path" looks like a placeholder -- replace it with the
# path to a real image file before running this cell.
run_demo(image_path="image_path")
