# RFA-U-Net Visualization

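This notebook demonstrates how to load the pre-trained RFA-U-Net model, run inference on a sample OCT image, and visualize the segmentation results with the `plot_boundaries` function.

Before running, check two things. First, the checkpoint must exist at `weights/rfa_unet_best.pth`; otherwise, update `retfound_weights_path` in the configuration. Second, set `test_image_dir` and `test_mask_dir` to your own OCT test images and masks, since the notebook ships with placeholder paths.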

In [None]:
import torch
import os
from src.rfa_u_net import AttentionUNetViT, plot_boundaries, OCTDataset, val_test_transform

# Configuration passed to AttentionUNetViT; image_size is also handed to OCTDataset.
config = dict(
    image_size=224,
    # NOTE(review): hidden_dim / patch_size presumably configure the ViT encoder — confirm.
    hidden_dim=1024,
    patch_size=16,
    num_channels=3,
    num_classes=2,
    # Checkpoint loaded into the model below. NOTE(review): despite the key name,
    # this points at the full RFA-U-Net checkpoint, not bare RETFound weights.
    retfound_weights_path="weights/rfa_unet_best.pth",
)

# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fail fast: verify the checkpoint exists before constructing the model.
weights_path = config['retfound_weights_path']
if not os.path.isfile(weights_path):
    raise FileNotFoundError(f"Weights file not found: {weights_path}. Please provide the correct path.")

model = AttentionUNetViT(config).to(device)

# Load pre-trained weights.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load(weights_path, map_location=device)
model.load_state_dict(checkpoint, strict=False)

# strict=False does not report missing/unexpected keys (PyTorch only collects them
# when strict=True), so a checkpoint that matches nothing would silently leave the
# model randomly initialised. Compare key sets explicitly instead.
model_keys = set(model.state_dict().keys())
checkpoint_keys = set(checkpoint.keys())
if not model_keys & checkpoint_keys:
    raise RuntimeError(
        f"Checkpoint {weights_path} shares no parameter names with the model; "
        "it may be wrapped (e.g. under a 'state_dict' key) or belong to a different architecture."
    )
missing_keys = model_keys - checkpoint_keys
unexpected_keys = checkpoint_keys - model_keys
if missing_keys:
    print(f"Warning: {len(missing_keys)} model entries not found in checkpoint (left at initial values).")
if unexpected_keys:
    print(f"Warning: {len(unexpected_keys)} checkpoint entries not used by the model.")
print(f"Loaded weights from {weights_path}")

# Switch to evaluation mode for inference.
model.eval()

# Load a real test image and mask.
# Replace these placeholders with your own OCT test data before running.
test_image_dir = "/path/to/your/test/images"
test_mask_dir  = "/path/to/your/test/masks"

# Validate paths up front so a forgotten placeholder gives an actionable error
# instead of an opaque failure inside OCTDataset.
missing_dirs = [d for d in (test_image_dir, test_mask_dir) if not os.path.isdir(d)]
if missing_dirs:
    raise FileNotFoundError(
        f"Directory not found: {', '.join(missing_dirs)}. "
        "Set test_image_dir and test_mask_dir to your OCT test images and masks."
    )

sample_index = 0  # which dataset sample to visualize
dataset = OCTDataset(test_image_dir, test_mask_dir, config['image_size'], transform=val_test_transform)
img, gt = dataset[sample_index]
# unsqueeze(0) adds a leading size-1 dim (presumably the batch dim the model expects).
image = img.unsqueeze(0).to(device)
mask  = gt.unsqueeze(0).to(device)

# Threshold handed to plot_boundaries (presumably used to binarise the probabilities).
threshold = 0.5

# Forward pass without building an autograd graph.
with torch.no_grad():
    output = model(image)
    # NOTE(review): sigmoid scores each output channel independently; with
    # num_classes=2 a softmax over channels may be what training used — confirm.
    predicted_mask = output.sigmoid()

# Overlay ground-truth and predicted boundaries on the input image.
plot_boundaries(image, mask, predicted_mask, threshold)