# Demo Inference Notebook
This notebook demonstrates how to load a trained U-Net model and run inference on a sample optical coherence tomography (OCT) image.

In [None]:
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
from PIL import Image

from unet import UNet  # adjust import based on your code structure

# Load the sample image as single-channel grayscale ('L' mode) so it matches
# the model's n_channels=1 input.
img = Image.open('../data/sample_images/sample_oct.png').convert('L')

# Load trained model weights (replace with your checkpoint path).
# NOTE(review): torch.load unpickles the checkpoint file, which can execute
# arbitrary code. Only load checkpoints from a trusted source; on
# torch >= 1.13 consider passing weights_only=True.
n_classes = 2
model = UNet(n_channels=1, n_classes=n_classes)
state_dict = torch.load('../models/unet_checkpoint.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # inference mode (affects dropout / batch-norm layers, if present)

# Run inference.
# ToTensor turns the grayscale PIL image into a float tensor of shape
# (1, H, W) scaled to [0, 1]; unsqueeze(0) adds the batch axis -> (1, 1, H, W).
# NOTE(review): assumes the model was trained on inputs scaled the same way
# and that H and W are compatible with the network's downsampling — confirm.
x = T.ToTensor()(img).unsqueeze(0)
with torch.no_grad():
    pred = model(x)  # presumably (1, n_classes, H, W) logits — TODO confirm

# Per-pixel class index for the single batch element. Indexing with [0]
# (rather than squeeze()) keeps both spatial axes even when H or W is 1.
mask = torch.argmax(pred, dim=1)[0].numpy()

plt.subplot(1, 2, 1)
plt.imshow(img, cmap='gray')
plt.title('Input OCT')
plt.subplot(1, 2, 2)
# Fix the colour range to the class indices so a constant mask (e.g. all
# foreground) is not autoscaled into a misleading uniform dark image.
plt.imshow(mask, cmap='gray', vmin=0, vmax=n_classes - 1)
plt.title('Predicted Mask')
plt.show()