# RetinaNet with PyTorch

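The dataset used in this notebook is the Kaggle "Retinal Vessel Segmentation" collection, available at https://www.kaggle.com/datasets/ipythonx/retinal-vessel-segmentation/data. The cells below expect its DRIVE subset to be extracted under `archive/DRIVE` (with `training` and `test` sub-directories).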

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

In [None]:
# Device configuration: prefer CUDA, fall back to MPS, and finally to the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f"Using {device} device")

# Load Data

In [None]:
class RetinaDataset(Dataset):
    """Index-aligned dataset of (image, mask, manual annotation) triples.

    Args:
        imgs: Sequence of input images.
        masks: Sequence of masks; ``masks[i]`` belongs to ``imgs[i]``.
        manual: Sequence of manual annotations; ``manual[i]`` belongs to
            ``imgs[i]``.
        transform: Optional callable applied separately to the image, the
            mask and the manual annotation of each sample.
    """

    def __init__(self, imgs, masks, manual, transform=None):
        self.imgs = imgs
        self.masks = masks
        self.manual = manual
        self.transform = transform

    def __len__(self):
        """Return the number of samples (the number of images)."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return the ``(img, mask, manual)`` triple stored at ``idx``.

        When a transform was supplied, each of the three items is passed
        through it before being returned.
        """
        img = self.imgs[idx]
        mask = self.masks[idx]
        manual = self.manual[idx]

        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)
            manual = self.transform(manual)

        # The sample must be returned explicitly; without this the dataset
        # yields None and a DataLoader cannot collate or unpack the batch.
        return img, mask, manual

In [None]:
def read_data(path, dataset="DRIVE"):
    """Load every image, mask and manual annotation found under ``path``.

    Files are read with ``plt.imread`` from these sub-folders of ``path``:
    ``images`` (inputs), ``mask`` (masks) and ``manual1`` or ``1st_manual``
    (manual annotations). Any other entry in ``path`` is ignored.

    Args:
        path: Directory that contains the sub-folders listed above.
        dataset: Currently unused; kept so existing calls keep working.

    Returns:
        A tuple ``(imgs, masks, manuals)`` of lists of loaded arrays.
    """
    imgs = []
    masks = []
    manuals = []
    # Map each recognised folder name to the list that collects its files.
    buckets = {
        "images": imgs,
        "mask": masks,
        "manual1": manuals,
        "1st_manual": manuals,
    }
    # os.listdir returns entries in arbitrary order. Sorting makes the load
    # order deterministic, so that files sharing a sortable name prefix end up
    # at the same index in all three lists.
    for folder in sorted(os.listdir(path)):
        bucket = buckets.get(folder)
        if bucket is None:
            continue
        folder_path = os.path.join(path, folder)
        for file_name in sorted(os.listdir(folder_path)):
            bucket.append(plt.imread(os.path.join(folder_path, file_name)))
    return imgs, masks, manuals

In [None]:
# Each call returns an (imgs, masks, manuals) tuple of lists for one split.
train_data = read_data("archive/DRIVE/training")
test_data = read_data("archive/DRIVE/test")

In [None]:
# Preprocessing pipeline: convert to a PIL image, resize to 224x224, then
# convert to a tensor.
# NOTE(review): RetinaDataset applies this same pipeline to the masks and
# manual annotations; Resize's default interpolation may blend label values
# at edges - confirm that non-binary targets are acceptable for the loss.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
# Index 0/1/2 of each split are the images, masks and manual annotations.
train_dataset = RetinaDataset(train_data[0], train_data[1], train_data[2], transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# NOTE(review): the test loader is shuffled as well, so evaluation order is
# not reproducible between runs - confirm this is intended.
test_dataset = RetinaDataset(test_data[0], test_data[1], test_data[2], transform)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True)

# Image Visualization

In [None]:
def show_image(img, mask, manual):
    """Display an image, its mask and its manual annotation side by side.

    The three inputs are drawn left to right in a single row of axes titled
    "Image", "Mask" and "Manual", with the axis decorations switched off.
    """
    panels = (("Image", img), ("Mask", mask), ("Manual", manual))
    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    for axis, (title, picture) in zip(axes, panels):
        axis.imshow(picture)
        axis.set_title(title)
        axis.axis("off")
    plt.show()

In [None]:
# Preview the first training sample: image, mask and manual annotation.
show_image(train_data[0][0], train_data[1][0], train_data[2][0])

# Model

In [None]:
import torch.nn as nn
from torchvision.models import ResNet50_Weights
from torchsummary import summary

In [None]:
class RetinaNet(nn.Module):
    """
    Binary segmentation network built on a pretrained ResNet-50 trunk.

    Pipeline: ResNet-50 (minus its last two child modules) -> a small
    FPN-style neck that reduces 2048 channels to 256 and upsamples by 2 ->
    a three-layer convolutional head producing one channel -> sigmoid.

    NOTE(review): the only upsampling here is the single x2 step in the neck,
    so the output map is presumably much smaller than the input image -
    confirm that the loss/targets account for this.
    """

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)
        # Keep everything except the final two children (the fully connected
        # classifier and the module right before it).
        trunk_layers = list(self.resnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)

        # Feature Pyramid Network (FPN) style neck
        neck_layers = [
            nn.Conv2d(2048, 256, kernel_size=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.fpn = nn.Sequential(*neck_layers)

        # Segmentation head; the last 1x1 convolution emits a single channel
        # for binary segmentation.
        head_layers = [
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=1),
        ]
        self.segmentation_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the sigmoid-activated single-channel segmentation map for ``x``."""
        # backbone -> neck -> head, then squash the logits into (0, 1).
        logits = self.segmentation_head(self.fpn(self.backbone(x)))
        return torch.sigmoid(logits)

In [None]:
model = RetinaNet()
# The model is still on the CPU at this point, while torchsummary defaults to
# device="cuda" and then feeds a CUDA tensor whenever CUDA is available.
# Pin the summary to the CPU so input and weights always share a device.
summary(model, (3, 224, 224), device="cpu")
model = model.to(device)

In [None]:
from loss import DiceLoss

In [None]:
# Loss and optimizer
# DiceLoss comes from the project-local `loss` module; Adam updates every
# parameter registered on the model with a learning rate of 0.001.
criterion = DiceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Train the model
total_step = len(train_loader)
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    for i, (images, masks, manuals) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device)
        manuals = manuals.to(device)

        # Forward pass
        # NOTE(review): the loss is computed against `masks`, while `manuals`
        # is moved to the device but never used - confirm which of the two is
        # the intended segmentation target.
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report every 100 steps and also on the last step of each epoch;
        # otherwise an epoch with fewer than 100 batches never prints anything.
        if (i + 1) % 100 == 0 or (i + 1) == total_step:
            print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')