# Retinal Vessel Segmentation with U‑Net

This notebook covers the full pipeline: data loading (ETL), the U‑Net model definition, training, evaluation, and generation of a Streamlit demo app.

In [None]:
# 1. Imports
import os
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from albumentations import Compose, Resize, Normalize, HorizontalFlip, RandomRotate90, ElasticTransform
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import jaccard_score, f1_score
import pickle

# Select the compute device: a CUDA GPU when one is available, the CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)

In [None]:
# 2. ETL & Dataset Definition
class RetinalDataset(Dataset):
    """Paired fundus-image / vessel-mask dataset backed by two directories.

    Every regular file in ``img_dir`` is one sample; its mask is looked up
    under the *same file name* in ``mask_dir``.

    Args:
        img_dir: directory containing the input images.
        mask_dir: directory containing the masks, named like the images.
        transform: callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
            When ``None``, a default pipeline is used: resize to 256x256,
            normalise with mean 0.5 / std 0.5 per channel, convert to tensor.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        # Keep regular files only: sub-directories inside the image folder
        # (e.g. a notebook's ".ipynb_checkpoints") are not samples and would
        # otherwise inflate len() and break __getitem__.
        self.images = sorted(
            name for name in os.listdir(img_dir)
            if os.path.isfile(os.path.join(img_dir, name))
        )
        self.transform = transform if transform is not None else Compose([
            Resize(256, 256),
            Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ToTensorV2()
        ])

    def __len__(self):
        """Return the number of image files found in ``img_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, binarise and transform the sample at position ``idx``.

        Returns:
            ``(image, mask)`` as produced by the transform.

        Raises:
            FileNotFoundError: if the image or its mask cannot be read.
        """
        img_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])

        # cv2.imread signals failure by returning None instead of raising,
        # so check explicitly and report which path was the problem.
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        # OpenCV loads colour images as BGR; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")
        # Binarise: pixels brighter than 127 become 1.0, the rest 0.0.
        mask = (mask > 127).astype('float32')

        augmented = self.transform(image=image, mask=mask)
        return augmented['image'], augmented['mask']

In [None]:
# 3. U‑Net Model Definition
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    ``padding=1`` keeps the spatial size unchanged; the channel count goes
    from ``in_ch`` to ``out_ch`` in the first convolution and stays at
    ``out_ch`` in the second.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First pass maps in_ch -> out_ch, second pass out_ch -> out_ch.
        for src_ch in (in_ch, out_ch):
            stages += [
                nn.Conv2d(src_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order determine the state_dict keys
        # ("conv.0" ... "conv.5"), so both are kept as-is.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv/BN/ReLU stack and return the result."""
        stack = self.conv
        return stack(x)

class UNet(nn.Module):
    """U-Net encoder/decoder producing one logit map per output class.

    Encoder: four ``DoubleConv`` stages (64, 128, 256, 512 channels), each
    followed by a 2x max-pool, then a 1024-channel bottleneck.
    Decoder: four stride-2 transposed convolutions; after each one the
    matching encoder feature map is concatenated along the channel axis and
    passed through a ``DoubleConv``. A final 1x1 convolution maps the 64
    decoder channels to ``n_classes``.

    Args:
        n_channels: number of channels of the input image.
        n_classes: number of output channels.

    Notes:
        ``forward`` returns raw logits, with no activation applied. The code
        that consumes this model in this file trains it with
        ``nn.BCEWithLogitsLoss`` and applies ``torch.sigmoid`` itself before
        thresholding. Applying a sigmoid here as well would squash the
        values twice.

        Input height and width must be divisible by 16 (four 2x poolings);
        otherwise the upsampled maps no longer match the skip connections
        and ``torch.cat`` fails.
    """

    def __init__(self, n_channels=3, n_classes=1):
        super(UNet, self).__init__()
        # Encoder
        self.down1 = DoubleConv(n_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(512, 1024)
        # Decoder: each DoubleConv takes twice the upsampled channel count
        # because the skip connection is concatenated before it.
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv1 = DoubleConv(128, 64)
        self.final = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        """Return the logit map for the image batch ``x``.

        The output keeps the input's spatial size and has ``n_classes``
        channels. Values are unbounded logits; apply ``torch.sigmoid`` to
        obtain probabilities.
        """
        # Encoder path: keep the pre-pool activations for the skip connections.
        d1 = self.down1(x)
        p1 = self.pool1(d1)
        d2 = self.down2(p1)
        p2 = self.pool2(d2)
        d3 = self.down3(p2)
        p3 = self.pool3(d3)
        d4 = self.down4(p3)
        p4 = self.pool4(d4)
        bn = self.bottleneck(p4)

        # Decoder path: upsample, concatenate the skip, refine.
        up4 = self.up4(bn)
        merge4 = torch.cat([up4, d4], dim=1)
        c4 = self.conv4(merge4)
        up3 = self.up3(c4)
        merge3 = torch.cat([up3, d3], dim=1)
        c3 = self.conv3(merge3)
        up2 = self.up2(c3)
        merge2 = torch.cat([up2, d2], dim=1)
        c2 = self.conv2(merge2)
        up1 = self.up1(c2)
        merge1 = torch.cat([up1, d1], dim=1)
        c1 = self.conv1(merge1)

        # No sigmoid here: the loss and the inference code apply it.
        return self.final(c1)

# Build the network, move its parameters to the selected device, and show its layout
model = UNet()
model = model.to(device)
print(model)

In [None]:
# 4. Data Loaders
# No transform is passed, so both splits use the dataset's default pipeline.
train_dataset = RetinalDataset(img_dir='data/train/images', mask_dir='data/train/masks')
val_dataset = RetinalDataset(img_dir='data/val/images', mask_dir='data/val/masks')

# Only the training batches are shuffled; the validation loader uses DataLoader's defaults.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=8)


In [None]:
# 5. Training Loop
# NOTE(review): BCEWithLogitsLoss and the torch.sigmoid call in the validation
# pass both expect `model` to emit raw logits - confirm the model's forward
# does not apply a sigmoid of its own.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 10
# Per-epoch history, plotted after training.
train_losses = []
val_losses = []
val_iou = []
val_dice = []

for epoch in range(num_epochs):
    # ---- optimisation pass over the training set ----
    model.train()
    epoch_loss = 0
    for imgs, masks in train_loader:
        imgs = imgs.to(device)
        # Insert a channel axis at dim 1 before moving the targets to the device.
        masks = masks.unsqueeze(1).to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Mean of the per-batch losses.
    train_losses.append(epoch_loss / len(train_loader))

    # ---- validation pass: eval mode, gradients disabled ----
    model.eval()
    val_loss = 0
    ious = []
    dices = []
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs = imgs.to(device)
            masks = masks.unsqueeze(1).to(device)
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            val_loss += loss.item()
            # Threshold at 0.5 and flatten so sklearn scores the batch pixel-wise.
            preds = (torch.sigmoid(outputs) > 0.5).cpu().numpy().flatten()
            truth = masks.cpu().numpy().flatten()
            ious.append(jaccard_score(truth, preds))
            dices.append(f1_score(truth, preds))
    val_losses.append(val_loss / len(val_loader))
    # Metrics are averaged over batches, not over individual images.
    val_iou.append(np.mean(ious))
    val_dice.append(np.mean(dices))

    print(
        f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_losses[-1]:.4f}"
        f" - Val Loss: {val_losses[-1]:.4f} - IoU: {val_iou[-1]:.4f}"
        f" - Dice: {val_dice[-1]:.4f}"
    )


In [None]:
# 6. Metrics Visualization
# One row, two panels: loss curves on the left, validation metrics on the right.
fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Val Loss')
loss_ax.legend()
loss_ax.set_title('Loss Curves')

metric_ax.plot(val_iou, label='IoU')
metric_ax.plot(val_dice, label='Dice')
metric_ax.legend()
metric_ax.set_title('Validation Metrics')

plt.show()


In [None]:
# 7. Save Model & Transform
os.makedirs('models', exist_ok=True)

# Persist the weights only (state_dict), not the whole module object.
weights = model.state_dict()
torch.save(weights, 'models/unet_retinal.pth')

# Pickle the preprocessing pipeline used by the training dataset next to the weights.
with open('models/transform.pkl', 'wb') as f:
    f.write(pickle.dumps(train_dataset.transform))
print("Model and transform saved.")

In [None]:
# 8. Generate Streamlit App File
# The generated script is self-contained: it carries its own copy of the model
# classes, because a script started with `streamlit run` has no access to the
# classes defined in this notebook session. The attribute names in the embedded
# classes must stay identical to the notebook's so the saved state_dict loads.
streamlit_code = """import streamlit as st
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import pickle


# Model definition (layer names must match the saved state_dict).
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        self.down1 = DoubleConv(n_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(512, 1024)
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv1 = DoubleConv(128, 64)
        self.final = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(self.pool1(d1))
        d3 = self.down3(self.pool2(d2))
        d4 = self.down4(self.pool3(d3))
        bn = self.bottleneck(self.pool4(d4))
        c4 = self.conv4(torch.cat([self.up4(bn), d4], dim=1))
        c3 = self.conv3(torch.cat([self.up3(c4), d3], dim=1))
        c2 = self.conv2(torch.cat([self.up2(c3), d2], dim=1))
        c1 = self.conv1(torch.cat([self.up1(c2), d1], dim=1))
        # Raw logits; the sigmoid is applied at inference time below.
        return self.final(c1)


# Load model
model = UNet().to('cpu')
model.load_state_dict(torch.load('models/unet_retinal.pth', map_location='cpu'))
model.eval()

# Load transform
# NOTE: unpickling can execute arbitrary code - only load a transform.pkl
# that was produced by the training notebook.
with open('models/transform.pkl', 'rb') as f:
    transform = pickle.load(f)

st.title("Retinal Vessel Segmentation")

uploaded = st.file_uploader("Upload a fundus image", type=["png","jpg","jpeg"])
if uploaded is not None:
    image = Image.open(uploaded).convert("RGB")
    img_np = np.array(image)
    augmented = transform(image=img_np)
    input_tensor = torch.unsqueeze(augmented['image'], 0)
    with st.spinner("Segmenting..."):
        with torch.no_grad():
            output = model(input_tensor)
            # Scale the 0/1 mask to 0/255 so the vessels are visible on screen.
            mask = (torch.sigmoid(output)[0,0].numpy() > 0.5).astype('uint8') * 255
    st.image(image, caption='Original', use_column_width=True)
    st.image(mask, caption='Segmented Vessels', use_column_width=True)
"""
with open('streamlit_retinal_app.py', 'w', encoding='utf-8') as f:
    f.write(streamlit_code)
print("Streamlit app file generated: streamlit_retinal_app.py")