# RFA-U-Net Visualization

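This notebook demonstrates how to load the pre-trained RFA-U-Net model, run inference on an OCT image, and visualize the segmentation results with the `plot_boundaries` function. By default it uses a randomly generated placeholder image and mask; replace them with a real OCT image and its ground-truth mask to obtain meaningful results.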

In [None]:
import torch
import os
from src.rfa_u_net import AttentionUNetViT, plot_boundaries

# Configuration: hyper-parameters passed to AttentionUNetViT.
# NOTE(review): presumably these must match the architecture the loaded checkpoint was trained with — confirm.
config = dict(
    image_size=224,    # input height/width in pixels (the dummy input below is 224x224)
    hidden_dim=1024,   # presumably the ViT encoder embedding width — confirm
    patch_size=16,     # presumably the ViT patch size in pixels — confirm
    num_channels=3,    # input channels (the dummy input below is RGB)
    num_classes=2,     # output / mask channels
    # NOTE(review): despite the key name this points at the RFA-U-Net checkpoint, not RETFound weights — confirm intended.
    retfound_weights_path="weights/rfa_unet_best.pth",
)

# Load the model
SEED = 0
torch.manual_seed(SEED)  # make the random initialisation (and any parameters not covered by a checkpoint) reproducible

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AttentionUNetViT(config).to(device)

# Optionally load pre-trained weights:
#   'none'     -> keep the random initialisation
#   'retfound' -> RETFound_oct_weights.pth (download from https://github.com/rmaphoh/RETFound_MAE
#                 and place it in the weights/ directory)
#   'rfa-unet' -> rfa_unet_best.pth in the weights/ directory
weights_type = 'rfa-unet'  # Options: 'none', 'retfound', 'rfa-unet'
weights_path = None  # None -> use the default path for weights_type; set a string to override
DEFAULT_WEIGHTS_PATHS = {
    'retfound': 'weights/RETFound_oct_weights.pth',
    'rfa-unet': 'weights/rfa_unet_best.pth',
}

if weights_type == 'none':
    print("Using random initialization for the model weights.")
elif weights_type in DEFAULT_WEIGHTS_PATHS:
    if weights_path is None:
        weights_path = DEFAULT_WEIGHTS_PATHS[weights_type]
    if not os.path.exists(weights_path):
        raise FileNotFoundError(
            f"Weights file not found: {weights_path}. Please download the appropriate weights file and "
            f"place it in the weights/ directory, or set `weights_path` in this cell to a custom location."
        )
    # torch.load unpickles the file, which can execute arbitrary code: only load checkpoints
    # from trusted sources (consider weights_only=True on recent PyTorch versions).
    checkpoint = torch.load(weights_path, map_location=device)
    # strict=False tolerates partial matches, but it also hides a checkpoint that matches
    # nothing (e.g. one wrapped in an outer dict), so report what did not line up.
    load_result = model.load_state_dict(checkpoint, strict=False)
    print(f"Loaded weights from {weights_path}")
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Warning: {len(load_result.missing_keys)} model parameter(s) were not found in the checkpoint "
            f"and {len(load_result.unexpected_keys)} checkpoint entr(y/ies) were ignored."
        )
else:
    # Fail loudly instead of silently falling back to random weights on a typo.
    raise ValueError(
        f"Unknown weights_type {weights_type!r}; expected one of 'none', 'retfound', 'rfa-unet'."
    )
model.eval()

# Note: pre-trained weights are optional; with weights_type = 'none' the model keeps its random
# initialisation (this notebook only runs inference, it does not train).
# The command-line script downloads rfa_unet_best.pth to the weights/ directory automatically when run
# with --weights_type rfa-unet (unless a custom path is given with --weights_path); this notebook does
# not download anything, so the weights file must already exist.
# Example command with custom weights path:
# python src/rfa-u-net.py --image_dir path/to/data/images --mask_dir path/to/data/masks --weights_type rfa-unet --weights_path /custom/path/rfa_unet_best.pth
# NOTE(review): the notebook imports `src.rfa_u_net`, but this command uses `src/rfa-u-net.py`
# (hyphens vs. underscores) — confirm the actual filename.

# Dummy input (replace with a real, preprocessed OCT image and its ground-truth mask).
# A dedicated, seeded generator keeps the placeholder data reproducible across runs.
dummy_generator = torch.Generator().manual_seed(0)
height = width = config["image_size"]
image = torch.randn(
    (1, config["num_channels"], height, width), generator=dummy_generator
).to(device)  # batch of 1
# A mask must hold 0/1 labels, not Gaussian noise: draw one class per pixel and one-hot
# encode it to shape (1, num_classes, H, W).
# NOTE(review): assumes plot_boundaries expects a one-hot mask with mutually exclusive classes — confirm.
labels = torch.randint(0, config["num_classes"], (1, height, width), generator=dummy_generator)
mask = (
    torch.nn.functional.one_hot(labels, num_classes=config["num_classes"])
    .permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
    .float()
    .to(device)
)

# Run inference
with torch.no_grad():
    output = model(image)
    # NOTE(review): a per-channel sigmoid treats the classes independently; confirm this matches
    # how the model was trained (vs. a softmax over the class dimension).
    predicted_mask = torch.sigmoid(output)

# Visualize results
plot_boundaries(image, mask, predicted_mask)