**The goal of this project is to build a Computer Vision Segmentation Model to segment Retinal Blood Vessels.**

**Installing required ```libraries``` for the project-**

In [None]:
%pip install torch torchvision torchsummary matplotlib pillow

**Importing those ```libraries```-**

In [13]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import os
from PIL import Image

**Creating a Custom ``RetinalDataset`` Class in ``PyTorch`` for the Training Retinal Images-**
1. **Input Shape-** The ```input_shape``` of the retinal images is initially **3⨯512⨯512.**
2. **Dataset Filepaths-**
    - ``image_ds_path`` stores the path to the retinal images
    - ``mask_ds_path`` stores the path to the corresponding retinal masks.
3. **Defining the ``RetinalDataset`` Class-**
    - Creating a custom class by extending the ``Dataset`` class from ``torch.utils.data``.
    - ``__init__`` initializes the dataset's **attributes** such as filepaths and ``transforms`` using specified **parameters**.
    - ``__len__`` returns the total number of samples in the ``RetinalDataset`` **(80).**
    - ``__getitem__`` retrieves the **grayscale** image and mask at a specified ``idx``, applies the ``transforms`` to each, and then returns **transformed_image** and **transformed_mask.**
4. **Composing the Transformations-**
    - ``transforms.Resize((512, 512))`` resizes every image and mask to **512⨯512**, i.e. their original **W⨯H.**
    - ``transforms.ToTensor()`` converts each image to a tensor and **normalizes** the pixel values to be in **[0,1].**

In [14]:
# Nominal (channels, height, width) of the retinal images, as described in the
# notebook text above.
input_shape = torch.tensor([3, 512, 512])

# Folders holding the training images and their vessel masks.
image_ds_path, mask_ds_path = (
    "C:/Users/User/Retinal_Vessel_Segmentation/train/image/",
    "C:/Users/User/Retinal_Vessel_Segmentation/train/mask/",
)

class RetinalDataset(Dataset):
    """Paired retinal image / vessel-mask dataset read from two folders.

    Every file in ``image_folder`` is paired with the file of the same name in
    ``mask_folder``. Both are opened as single-channel ("L") PIL images and
    passed through the same ``transforms`` callable.

    The list of file names is collected when the dataset is constructed, so
    ``len()`` and indexing work immediately. ``retinal_id_list()`` may be
    called again to re-scan the folder.

    Args:
        image_folder: Directory containing the retinal images.
        mask_folder: Directory containing the masks, named like their images.
        transforms: Callable applied to the image and, separately, to the mask.

    Raises:
        OSError: Propagated from ``os.listdir`` if ``image_folder`` cannot be
            read.
    """

    def __init__(self, image_folder, mask_folder, transforms):
        self.image_folder = image_folder
        self.mask_folder = mask_folder
        self.transforms = transforms
        self.retinal_ids = []
        # Populate eagerly. Leaving the list empty until someone remembers to
        # call retinal_id_list() makes len() report 0 and indexing fail.
        self.retinal_id_list()

    def __len__(self):
        """Return the number of image files found in ``image_folder``."""
        return len(self.retinal_ids)

    def retinal_id_list(self):
        """(Re)scan ``image_folder`` and return the sorted list of file names.

        The list is rebuilt on every call rather than appended to, so calling
        this more than once never produces duplicate entries. Sub-directories
        (e.g. ``.ipynb_checkpoints``) are skipped, and names are sorted so a
        given index always refers to the same file.
        """
        self.retinal_ids = sorted(
            name
            for name in os.listdir(self.image_folder)
            if os.path.isfile(os.path.join(self.image_folder, name))
        )
        return self.retinal_ids

    def __getitem__(self, idx):
        """Return ``(transformed_image, transformed_mask)`` for sample ``idx``."""
        file_name = self.retinal_ids[idx]
        image_path = os.path.join(self.image_folder, file_name)
        mask_path = os.path.join(self.mask_folder, file_name)

        # convert() loads the pixel data and returns a new image object, so the
        # underlying file handles can be released as soon as it returns.
        with Image.open(image_path) as image_file:
            image = image_file.convert('L')
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert('L')

        transformed_image = self.transforms(image)
        transformed_mask = self.transforms(mask)

        return transformed_image, transformed_mask
    

# Preprocessing shared by images and masks: resize to 512x512, then convert
# the PIL image to a tensor with values scaled to [0, 1].
retinal_transforms = transforms.Compose(
    [transforms.Resize(size=(512, 512)), transforms.ToTensor()]
)

# Training dataset built from the image/mask folders defined above.
RetinalData = RetinalDataset(
    image_folder=image_ds_path,
    mask_folder=mask_ds_path,
    transforms=retinal_transforms,
)

**Displaying the first Retinal Image and the respective Blood Vessel Mask using ``matplotlib.pyplot``.**

In [None]:
# Fetch the first (image, mask) pair. Index 0 on each tensor drops the leading
# channel axis so imshow receives a 2-D array.
image_tensor, mask_tensor = RetinalData[0]
image = image_tensor[0].detach().cpu()
mask = mask_tensor[0].detach().cpu()

print(image_tensor.size())

# Side-by-side panels: retina on the left, vessel mask on the right.
fig, (retina_ax, mask_ax) = plt.subplots(1, 2, figsize=(8, 4))

retina_ax.imshow(image, cmap="gray")
retina_ax.set_title("Retina")

mask_ax.imshow(mask, cmap="gray")
mask_ax.set_title("Retinal Mask")

fig.tight_layout()
plt.show()

In [7]:
class DoubleConvolutional(nn.Module):
    """Two stacked ``3x3 conv -> batch norm -> ReLU`` stages.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes
    (``in_channels -> out_channels -> out_channels``).

    Each convolution has its own ``BatchNorm2d`` layer (``bn1`` / ``bn2``).
    Sharing a single layer between the two stages would force two different
    activation distributions to share one set of running statistics and one
    set of scale/shift parameters.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_c = in_channels
        self.out_c = out_channels

        self.conv1 = nn.Conv2d(in_channels=self.in_c, out_channels=self.out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=self.out_c)

        self.conv2 = nn.Conv2d(in_channels=self.out_c, out_channels=self.out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=self.out_c)

        # ReLU holds no state, so one instance can safely serve both stages.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        """Apply both conv/BN/ReLU stages and return the resulting feature map."""
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x

class Encoder(nn.Module):
    """Down-sampling stage: a double convolution followed by 2x2 max pooling.

    ``forward`` returns a pair: the feature map before pooling and the pooled
    feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_c, self.out_c = in_channels, out_channels
        self.conv = DoubleConvolutional(in_channels=in_channels, out_channels=out_channels)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, input):
        """Return ``(features, pooled_features)`` for ``input``."""
        features = self.conv(input)
        return features, self.pool(features)
    
class Decoder(nn.Module):
    """Up-sampling stage: transposed conv, concatenate the skip map, double conv.

    The transposed convolution (kernel 2, stride 2) maps ``in_channels`` to
    ``out_channels``. The result is concatenated with ``skip`` along the
    channel axis and reduced back to ``out_channels`` by a double convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_c, self.out_c = in_channels, out_channels
        self.up = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=2,
            stride=2,
            padding=0,
        )
        # The convolution expects out_channels from the up-sampled map plus
        # out_channels from the skip map.
        self.conv = DoubleConvolutional(in_channels=out_channels * 2, out_channels=out_channels)

    def forward(self, input, skip):
        """Up-sample ``input``, merge it with ``skip`` and return the fused map."""
        upsampled = self.up(input)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)

In [8]:
class UNET(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    The channel width doubles at every encoder stage
    (``base_channels``, 2x, 4x, 8x, then 16x in the bottleneck) and is halved
    again by each decoder stage. A final 1x1 convolution maps to
    ``out_channels``. No activation is applied to the output, so ``forward``
    returns raw logits.

    The defaults reproduce the original fixed architecture: one input channel,
    one output channel, widths 32/64/128/256/512.

    Args:
        in_channels: Channels of the input image (default 1).
        out_channels: Channels of the output map (default 1).
        base_channels: Width of the first encoder stage (default 32).

    Note:
        The input height and width should be divisible by 16: the four 2x2
        poolings halve the resolution four times, and the up-sampled maps must
        match the skip maps exactly for the concatenation to succeed.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=32):
        super().__init__()
        # Widths per level; with base_channels=32 this is 32, 64, 128, 256, 512.
        w1, w2, w3, w4, w_bottleneck = (base_channels * (2 ** level) for level in range(5))

        # Contracting path.
        self.e1 = Encoder(in_channels=in_channels, out_channels=w1)
        self.e2 = Encoder(in_channels=w1, out_channels=w2)
        self.e3 = Encoder(in_channels=w2, out_channels=w3)
        self.e4 = Encoder(in_channels=w3, out_channels=w4)

        # Bottleneck at the lowest resolution.
        self.b = DoubleConvolutional(in_channels=w4, out_channels=w_bottleneck)

        # Expanding path.
        self.d1 = Decoder(in_channels=w_bottleneck, out_channels=w4)
        self.d2 = Decoder(in_channels=w4, out_channels=w3)
        self.d3 = Decoder(in_channels=w3, out_channels=w2)
        self.d4 = Decoder(in_channels=w2, out_channels=w1)

        # 1x1 convolution producing the per-pixel logits.
        self.outputs = nn.Conv2d(in_channels=w1, out_channels=out_channels, kernel_size=1, padding=0)

    def forward(self, input):
        """Return per-pixel logits with the same height and width as ``input``."""
        # Each encoder returns (skip feature map, pooled feature map).
        s1, p1 = self.e1(input)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b = self.b(p4)

        # Decoders consume the skip maps in reverse order (deepest first).
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        output = self.outputs(d4)

        return output

In [None]:
RetinalSegmentor = UNET()
# torchsummary's `device` argument defaults to "cuda": on a machine with a GPU
# it would feed CUDA tensors to this model, which has not been moved off the
# CPU yet. Pin the summary to the CPU so the cell works on any machine.
summary(model=RetinalSegmentor, input_size=(1, 512, 512), device="cpu")

In [10]:
def show_image_and_mask(image, true_mask, pred_mask):
    """Plot the input image, the ground-truth mask and the predicted mask in one row.

    All three arguments are tensors; each is squeezed, moved to the CPU and
    converted to a NumPy array before being drawn in grayscale. The predicted
    mask is additionally detached from the autograd graph.
    """
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))

    # (title, tensor) for each panel, left to right.
    panels = (
        ("Input Image", image.squeeze()),
        ("Ground Truth Mask", true_mask.squeeze()),
        ("Generated Mask", pred_mask.squeeze().detach()),
    )

    for axis, (title, tensor) in zip(axs, panels):
        axis.imshow(tensor.cpu().numpy(), cmap='gray')
        axis.set_title(title)
        axis.axis("off")

    plt.show()

In [11]:
# Hyperparameters.
batch_size = 8
epochs = 10
lr = 3e-4

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Shuffled mini-batches over the training dataset.
train_loader = DataLoader(RetinalData, batch_size=batch_size, shuffle=True)

# BCEWithLogitsLoss takes raw logits, matching the model's un-activated output.
r_loss_fn = nn.BCEWithLogitsLoss()
r_optimizer = torch.optim.Adam(params=RetinalSegmentor.parameters(), lr=lr)

# Receives one value per epoch from train_model.
losses = list()

def train_model(model, dataloader, optimizer, loss_fn, device, num_epochs, loss_tracker):
    """Train ``model`` for ``num_epochs`` passes over ``dataloader``.

    After every epoch the summed batch loss is appended to ``loss_tracker``,
    a preview of the prediction for the first sample of the last batch is
    shown, and the average batch loss is printed.

    Args:
        model: Network returning logits with the same shape as the masks.
        dataloader: Iterable of ``(inputs, masks)`` batches that supports ``len()``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        loss_fn: Callable ``loss_fn(outputs, masks)`` returning a scalar tensor.
        device: Device the model and the batches are moved to.
        num_epochs: Number of passes over ``dataloader``.
        loss_tracker: List that receives one entry per epoch: the SUM of the
            batch losses of that epoch (not the average).

    Raises:
        ValueError: If ``dataloader`` contains no batches.

    Note:
        The model is left in eval mode when the function returns.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Without this guard the preview would hit an undefined `outputs` and
        # the average would divide by zero.
        raise ValueError("dataloader yielded no batches; is the dataset empty?")

    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        print(f"Epoch [{epoch + 1}/{num_epochs}]")

        for batch_idx, data in enumerate(dataloader):
            inputs, masks = data
            inputs, masks = inputs.to(device), masks.to(device)

            outputs = model(inputs)
            loss = loss_fn(outputs, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            print(f"  Batch [{batch_idx + 1}/{num_batches}]: Loss = {loss.item():.4f}")

        loss_tracker.append(epoch_loss)

        # Preview with a fresh forward pass in eval mode, so the figure shows
        # the model as it stands after this epoch's last optimizer step rather
        # than the stale training-mode output computed before that step.
        model.eval()
        with torch.no_grad():
            preview_logits = model(inputs[:1])
            generated_mask = torch.sigmoid(preview_logits[0])  # Apply sigmoid for binary segmentation
            show_image_and_mask(inputs[0], masks[0], generated_mask)
        print(f"Epoch [{epoch + 1}/{num_epochs}] Average Loss: {epoch_loss / num_batches:.4f}\n")

In [None]:
# Run the training loop; one loss value per epoch is appended to `losses`.
train_model(
    model=RetinalSegmentor,
    dataloader=train_loader,
    optimizer=r_optimizer,
    loss_fn=r_loss_fn,
    device=device,
    num_epochs=epochs,
    loss_tracker=losses,
)

In [None]:
# One x position per recorded epoch. Deriving it from `losses` keeps x and y
# the same length however many epochs were run, and using a separate name
# leaves the integer `epochs` hyperparameter intact so the training cell can
# be re-run.
epoch_numbers = range(1, len(losses) + 1)

fig, ax = plt.subplots()

# Creating a plot
ax.plot(epoch_numbers, losses, label='Training Loss', color='#228B22', linewidth=2)

# Labeling the axes and the title
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Retinal Blood Vessel Semantic Segmentation Loss')

# Show the legend
ax.legend()

# Display the plot
plt.show()