# RFA-U-Net Visualization

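This notebook demonstrates how to load the pre-trained RFA-U-Net model, run inference on an OCT image, and visualize the segmentation results with the `plot_boundaries` function. The code cell below uses random placeholder tensors for the image and its ground-truth mask. Replace them with a real OCT image and mask sized to match `config["image_size"]` to obtain meaningful results.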

In [None]:
import torch
from src.rfa_u_net import AttentionUNetViT, plot_boundaries

# Configuration: model hyperparameters passed to AttentionUNetViT.
# NOTE(review): "retfound_weights_path" points at the same file as the fine-tuned
# checkpoint loaded below; confirm how AttentionUNetViT uses this key.
config = dict(
    image_size=224,      # input height/width in pixels
    hidden_dim=1024,
    patch_size=16,
    num_channels=3,      # channels of the input image tensor
    num_classes=2,       # channels of the mask tensors
    retfound_weights_path="weights/rfa_unet_best.pth",
)

# Load the model
RFA_UNET_WEIGHTS_PATH = "weights/rfa_unet_best.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AttentionUNetViT(config).to(device)

# Load pre-trained RFA-U-Net weights.
# torch.load unpickles the file, which can execute arbitrary code: only load
# checkpoints from a trusted source.
checkpoint = torch.load(RFA_UNET_WEIGHTS_PATH, map_location=device)

# Some training scripts save {"model_state_dict": ..., "optimizer": ...} instead of a
# bare state dict; unwrap such a wrapper if present, otherwise use the object as-is.
state_dict = checkpoint
if isinstance(checkpoint, dict):
    for wrapper_key in ("model_state_dict", "state_dict", "model"):
        if isinstance(checkpoint.get(wrapper_key), dict):
            state_dict = checkpoint[wrapper_key]
            break

# strict=False tolerates key mismatches, so report them explicitly; otherwise a
# mismatched checkpoint would silently leave parts of the model randomly initialized.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"Missing keys ({len(load_result.missing_keys)}), first few: {load_result.missing_keys[:10]}")
if load_result.unexpected_keys:
    print(f"Unexpected keys ({len(load_result.unexpected_keys)}), first few: {load_result.unexpected_keys[:10]}")

model.eval()

# Dummy input (replace with actual OCT image and mask).
# A seeded local generator makes the placeholders reproducible across
# Restart & Run All without touching torch's global RNG state.
dummy_rng = torch.Generator().manual_seed(0)
image_size = config["image_size"]

# Batch of 1, config["num_channels"] channels, image_size x image_size
image = torch.randn(
    (1, config["num_channels"], image_size, image_size), generator=dummy_rng
).to(device)

# A ground-truth mask is categorical, not Gaussian noise: draw one class label per
# pixel, then one-hot encode it to (batch, num_classes, H, W) so every pixel has
# exactly one active channel with value 1.0.
mask_labels = torch.randint(
    0, config["num_classes"], (1, image_size, image_size), generator=dummy_rng
)
mask = (
    torch.nn.functional.one_hot(mask_labels, num_classes=config["num_classes"])
    .permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
    .float()
    .to(device)
)

# Run inference: forward pass with autograd disabled, then squash the raw outputs
# into (0, 1) element-wise.
# NOTE(review): sigmoid scores each of the output channels independently; confirm
# this matches the training objective (softmax over channels is the alternative).
with torch.no_grad():
    output = model(image)
predicted_mask = output.sigmoid()

# Visualize results: input image, reference mask and predicted mask side by side.
# NOTE(review): the tensors are still on `device`; confirm plot_boundaries handles
# moving them to CPU before plotting.
plot_boundaries(image, mask, predicted_mask)