<a href="https://colab.research.google.com/github/Aldrin-Fanir/Hippocampal-Region-Segmentation-UNet/blob/main/HippocampalRegionUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Required Libraries

In [None]:
import copy
import os
import random
import zipfile
from math import atan2, cos, sin, sqrt, pi, log

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from numpy import linalg as LA
from torch import optim, nn
from torch.utils.data import DataLoader, random_split
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from tqdm import tqdm

# Drive Mount: Mount Google Drive

In [None]:
# Mount Google Drive so files under MyDrive are reachable from this runtime.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Directory of Images and Masks

In [None]:
# Dataset root on the mounted Google Drive. HippocampalDataset expects it to
# contain the cFos_NeuN_dFos_dataset_images and cFos_NeuN_dFos_dataset_masks
# sub-folders.
root_path = '/content/drive/MyDrive/HippocampalRegionSegmentationUNet/Cohort1-Multiplexed Dataset'

In [None]:
# Sanity check before loading data: if this prints False, fix root_path first.
path_exists = os.path.exists(root_path)
print("Is path correct?", path_exists)

Is path correct? True


# Hippocampal Region Dataset

In [None]:
class HippocampalDataset(Dataset):
  """Paired (image, mask) dataset for hippocampal-region segmentation.

  Images are listed from ``<root_path>/cFos_NeuN_dFos_dataset_images`` and
  masks from ``<root_path>/cFos_NeuN_dFos_dataset_masks``. Each listing is
  sorted and the two are paired by position, so image ``i`` goes with mask
  ``i``. This assumes both folders hold exactly one file per sample, with
  names that sort in the same order (TODO confirm against the dataset).

  Args:
    root_path: Directory containing the image and mask sub-folders.
    limit: Optional maximum number of pairs to expose; ``None`` keeps all.

  Raises:
    ValueError: If the two folders contain a different number of files.
      Positional pairing would be silently wrong in that case.
  """

  def __init__(self, root_path, limit = None):
    self.root_path = root_path
    self.limit = limit

    image_dir = os.path.join(root_path, "cFos_NeuN_dFos_dataset_images")
    mask_dir = os.path.join(root_path, "cFos_NeuN_dFos_dataset_masks")
    all_images = sorted(os.path.join(image_dir, name) for name in os.listdir(image_dir))
    all_masks = sorted(os.path.join(mask_dir, name) for name in os.listdir(mask_dir))

    # Pairing is purely positional, so differing counts mean misaligned pairs.
    if len(all_images) != len(all_masks):
      raise ValueError(
          f"Found {len(all_images)} images but {len(all_masks)} masks; "
          "images and masks are paired by sorted position, so the counts must match."
      )

    self.images = all_images[: self.limit]
    self.masks = all_masks[: self.limit]

    # Images: resize (torchvision's default interpolation), then convert to a
    # float tensor with ToTensor.
    self.transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    # Masks: nearest-neighbour resize so label values are copied rather than
    # blended at region borders.
    self.mask_transform = transforms.Compose([
        transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.NEAREST),
        transforms.ToTensor()
    ])

    if self.limit is None:
      self.limit = len(self.images)


  def __getitem__(self, index):
    """Return the ``(image, mask)`` tensor pair for sample ``index``.

    The image is converted to 3-channel RGB and the mask to single-channel
    grayscale ("L") before resizing to 256x256 and tensor conversion.
    """
    with Image.open(self.images[index]) as raw_img:
      img = raw_img.convert("RGB")
    with Image.open(self.masks[index]) as raw_mask:
      mask = raw_mask.convert("L")

    return self.transform(img), self.mask_transform(mask)

  def __len__(self):
    """Number of pairs, capped at ``limit``."""
    return min(len(self.images), self.limit)


# Test Dataset (visual sanity check of image/mask pairs)

In [None]:
# Visual sanity check: plot a few randomly chosen image/mask pairs.
# Plotting every sample would emit one figure per image in the dataset, so the
# preview stops after ``num_preview_batches`` batches (one pair per batch).
num_preview_batches = 5

dataset = HippocampalDataset(root_path, limit = None)

loader = DataLoader(dataset, batch_size = 1, shuffle = True)

for i, (images, masks) in enumerate(loader):
  if i >= num_preview_batches:
    break

  print(f'Batch {i+1}')
  print(f'Image Shape: {images.shape}')
  print(f'Mask Shape: {masks.shape}')
  plt.figure(figsize=(12, 6))

  #Train Image: channels-first tensor -> (H, W, C) array for matplotlib.
  plt.subplot(1, 2, 1)
  plt.imshow(images[0].permute(1, 2, 0).cpu().numpy())
  plt.title(f'Original Image {i+1}')
  plt.axis('off')

  #Train Mask: drop the single channel axis and show it in grayscale.
  plt.subplot(1, 2, 2)
  plt.imshow(masks[0].squeeze(0).cpu().numpy(), cmap = 'gray')
  plt.title(f'Mask Image {i+1}')
  plt.axis('off')

  plt.show()

# Create Model

# Double Convolution

In [None]:
class DoubleConv(nn.Module):
  """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

  ``padding = 1`` preserves the spatial size, so (N, in_channels, H, W)
  maps to (N, out_channels, H, W).
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    layers = []
    # First conv changes the channel count; the second keeps it.
    for conv_in in (in_channels, out_channels):
      layers.append(nn.Conv2d(conv_in, out_channels, kernel_size = 3, padding = 1))
      layers.append(nn.ReLU(inplace = True))
    # The attribute name and layer order define the state_dict keys
    # (conv_op.0.*, conv_op.2.*), so keep them stable for saved checkpoints.
    self.conv_op = nn.Sequential(*layers)

  def forward(self, x):
    """Run both conv + ReLU stages on ``x``."""
    features = self.conv_op(x)
    return features

# Downsampling

In [None]:
class DownSample(nn.Module):
  """Encoder stage: ``DoubleConv`` followed by 2x2 max pooling (stride 2).

  ``forward`` returns ``(features, pooled)``. The full-resolution features
  serve as the skip connection, and the pooled map feeds the next stage.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.conv = DoubleConv(in_channels, out_channels)
    self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)

  def forward(self, x):
    skip = self.conv(x)
    return skip, self.pool(skip)

# Upsampling

In [None]:
class UpSample(nn.Module):
  """Decoder stage: 2x transposed-conv upsampling, skip concat, ``DoubleConv``.

  The transposed convolution maps ``in_channels`` to ``in_channels // 2``
  channels. ``DoubleConv`` expects ``in_channels`` inputs after concatenation,
  so the skip tensor is assumed to carry ``in_channels // 2`` channels.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size = 2, stride = 2)
    self.conv = DoubleConv(in_channels, out_channels)

  def forward(self, x1, x2):
    """Upsample ``x1``, align it to ``x2`` spatially, concatenate, and convolve.

    Args:
      x1: Features from the previous (coarser) decoder stage.
      x2: Skip-connection features from the matching encoder stage.

    Returns:
      ``self.conv`` applied to ``cat([upsampled x1, x2], dim=1)``.
    """
    x1 = self.up(x1)
    # If max pooling upstream floored an odd height/width, the upsampled map
    # is smaller than the skip and torch.cat would fail. Zero-pad it to the
    # skip's size, splitting any odd difference toward the bottom/right.
    # This is a no-op when the sizes already match.
    diff_h = x2.size(2) - x1.size(2)
    diff_w = x2.size(3) - x1.size(3)
    if diff_h or diff_w:
      x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                      diff_h // 2, diff_h - diff_h // 2])
    x = torch.cat([x1, x2], 1)

    return self.conv(x)

# UNet Architecture

In [None]:
class UNet(nn.Module):
  """U-Net: four encoder stages, a bottleneck, four decoder stages, 1x1 head.

  Channel widths run 64 -> 128 -> 256 -> 512 -> 1024 (bottleneck) and back
  down to 64. The final 1x1 convolution emits ``num_classes`` channels, and
  no activation is applied in this module.

  Args:
    in_channels: Number of channels in the input image.
    num_classes: Number of output channels.
  """

  def __init__(self, in_channels, num_classes):
    super().__init__()

    # Encoder. Attribute names double as state_dict keys, so keep them stable.
    self.down_convolution_1 = DownSample(in_channels, 64)
    self.down_convolution_2 = DownSample(64, 128)
    self.down_convolution_3 = DownSample(128, 256)
    self.down_convolution_4 = DownSample(256, 512)

    self.bottle_neck = DoubleConv(512, 1024)

    # Decoder: each stage consumes one skip connection from the encoder.
    self.up_convolution_1 = UpSample(1024, 512)
    self.up_convolution_2 = UpSample(512, 256)
    self.up_convolution_3 = UpSample(256, 128)
    self.up_convolution_4 = UpSample(128, 64)

    self.output = nn.Conv2d(64, out_channels = num_classes, kernel_size = 1)

  def forward(self, x):
    encoders = (self.down_convolution_1, self.down_convolution_2,
                self.down_convolution_3, self.down_convolution_4)
    decoders = (self.up_convolution_1, self.up_convolution_2,
                self.up_convolution_3, self.up_convolution_4)

    skips = []
    for encoder in encoders:
      skip, x = encoder(x)
      skips.append(skip)

    x = self.bottle_neck(x)

    # The deepest skip pairs with the first decoder stage, and so on outward.
    for decoder, skip in zip(decoders, reversed(skips)):
      x = decoder(x, skip)

    return self.output(x)

# Training the Model

In [None]:
# Seeded generator so the train/test/validation split below is reproducible.
generator = torch.Generator().manual_seed(42)
train_dataset = HippocampalDataset(root_path)

In [None]:
# 80/20 split of the full dataset. This rebinds ``train_dataset`` from the full
# dataset to its 80% subset, so re-running only this cell would split the
# already-reduced subset again.
train_dataset, test_dataset = random_split(
    train_dataset, lengths = [0.8, 0.2], generator = generator
)

In [None]:
# Halve the held-out 20% into test and validation subsets. This reuses the same
# generator (its state has advanced after the previous split), so the result is
# reproducible only when the cells run in order. Like the cell above, this
# rebinds ``test_dataset``.
test_dataset, val_dataset = random_split(
    test_dataset, lengths = [0.5, 0.5], generator = generator
)

**Running this experiment with CUDA**

In [None]:
# Choose the compute device and the DataLoader worker count.
# Worker processes run on CPU cores, so the "4 per GPU" heuristic is capped at
# the core count reported by the OS. Asking for more workers than cores gains
# nothing and can trigger a DataLoader warning.
device = "cuda" if torch.cuda.is_available() else "cpu"

num_workers = 1
if device == "cuda":
  num_workers = min(torch.cuda.device_count() * 4, os.cpu_count() or 1)

**Now we set up our model with the AdamW optimizer and BCEWithLogitsLoss**

In [None]:
# Hyper-parameters, data loaders, model, optimizer and loss.
Learning_Rate = 3e-4
batch_size = 16

# Pinned host memory only speeds up host-to-GPU copies. On a CPU-only runtime,
# DataLoader warns and ignores it, so enable it only when CUDA is in use.
pin_memory = device == "cuda"

train_dataloader = DataLoader(dataset = train_dataset, num_workers = num_workers, pin_memory = pin_memory, batch_size = batch_size, shuffle = True)
test_dataloader = DataLoader(dataset = test_dataset, num_workers = num_workers, pin_memory = pin_memory, batch_size = batch_size, shuffle = False)
val_dataloader = DataLoader(dataset = val_dataset, num_workers = num_workers, pin_memory = pin_memory, batch_size = batch_size, shuffle = False)

# RGB input (3 channels) and a single-channel output map.
model = UNet(in_channels = 3, num_classes = 1).to(device)
optimizer = optim.AdamW(model.parameters(), lr = Learning_Rate)
# BCEWithLogitsLoss applies the sigmoid itself, so the model's raw outputs
# (no activation in UNet.forward) are passed to it directly.
criterion = nn.BCEWithLogitsLoss()

In [None]:
# Report how many samples ended up in each split: train, test, validation.
for split in (train_dataset, test_dataset, val_dataset):
  print(len(split))

239
59
29


**Evaluating Segmentation Performance with the Dice Metric**

In [None]:
def dice_coefficient(prediction, target, epsilon = 1e-07, threshold = 0.0):
  """Hard Dice score between thresholded predictions and a target mask.

  ``prediction`` is binarised as ``prediction > threshold``. With the default
  threshold of 0.0 this matches a sigmoid probability of 0.5 when
  ``prediction`` holds logits. The score is pooled over every element of the
  tensors (the whole batch at once), not averaged per sample.

  Args:
    prediction: Raw model outputs; must broadcast against ``target``.
    target: Ground-truth mask (presumably values in [0, 1] — confirm).
    epsilon: Smoothing term. It keeps the division defined and makes an empty
      prediction against an empty target score 1.
    threshold: Decision threshold applied to ``prediction``.

  Returns:
    A 0-dim tensor holding the Dice coefficient. It is computed without
    autograd tracking because it is a metric, not a training objective.
  """
  with torch.no_grad():
    binary = (prediction > threshold).to(prediction.dtype)
    intersection = torch.sum(binary * target)
    union = torch.sum(binary) + torch.sum(target)
    return (2.0 * intersection + epsilon) / (union + epsilon)

In [None]:
# Release unused memory held by PyTorch's CUDA caching allocator before training.
torch.cuda.empty_cache()

# Training

In [None]:
# Train for ``epochs`` epochs. After each epoch the model is evaluated on the
# validation split. The per-epoch mean loss and Dice for both splits are
# appended to the history lists below and printed.
epochs = 10
train_losses = []
train_dcs = []

val_losses = []
val_dcs = []

for epoch in tqdm(range(epochs)):
  # ---- Training pass ----
  model.train()
  train_running_loss = 0.0
  train_running_dc = 0.0
  train_samples = 0

  for img_mask in tqdm(train_dataloader, position = 0, leave=True):
    img = img_mask[0].float().to(device)
    mask = img_mask[1].float().to(device)
    batch_n = img.size(0)

    optimizer.zero_grad()
    y_pred = model(img)

    loss = criterion(y_pred, mask)
    # Dice is only reported, never optimised, so detach to keep it off the graph.
    dc = dice_coefficient(y_pred.detach(), mask)

    loss.backward()
    optimizer.step()

    # Both values are computed over the whole batch. Weighting by batch size
    # keeps a smaller final batch from being over-weighted in the epoch mean.
    train_running_loss += loss.item() * batch_n
    train_running_dc += dc.item() * batch_n
    train_samples += batch_n

  # max(..., 1) avoids dividing by zero if the loader yields no batches.
  train_loss = train_running_loss / max(train_samples, 1)
  train_dc = train_running_dc / max(train_samples, 1)

  train_losses.append(train_loss)
  train_dcs.append(train_dc)

  # ---- Validation pass (no parameter updates, no autograd) ----
  model.eval()
  val_running_loss = 0.0
  val_running_dc = 0.0
  val_samples = 0

  with torch.no_grad():
    for img_mask in tqdm(val_dataloader, position=0, leave = True):
      img = img_mask[0].float().to(device)
      mask = img_mask[1].float().to(device)
      batch_n = img.size(0)

      y_pred = model(img)
      loss = criterion(y_pred , mask)
      dc = dice_coefficient(y_pred, mask)

      val_running_loss += loss.item() * batch_n
      val_running_dc += dc.item() * batch_n
      val_samples += batch_n

  val_loss = val_running_loss / max(val_samples, 1)
  val_dc = val_running_dc / max(val_samples, 1)

  val_losses.append(val_loss)
  val_dcs.append(val_dc)

  print("-" * 30)
  print(f"Training Loss EPOCH {epoch + 1}: {train_loss:.4f}")
  print(f"Training DICE EPOCH {epoch + 1}: {train_dc:.4f}")
  print("\n")
  print(f"Validation Loss EPOCH {epoch + 1}: {val_loss:.4f}")
  print(f"Validation DICE EPOCH {epoch + 1}: {val_dc:.4f}")
  print("-" * 30)

100%|██████████| 15/15 [00:04<00:00,  3.37it/s]
100%|██████████| 2/2 [00:01<00:00,  1.91it/s]
 10%|█         | 1/10 [00:05<00:49,  5.51s/it]

------------------------------
Training Loss EPOCH 1: 0.5652
Training DICE EPOCH 1: 0.0103


Validation Loss EPOCH 1: 0.3763
Validation DICE EPOCH 1: 0.0000
------------------------------


100%|██████████| 15/15 [00:03<00:00,  3.86it/s]
100%|██████████| 2/2 [00:01<00:00,  1.91it/s]
 20%|██        | 2/10 [00:10<00:41,  5.18s/it]

------------------------------
Training Loss EPOCH 2: 0.3274
Training DICE EPOCH 2: 0.0000


Validation Loss EPOCH 2: 0.2535
Validation DICE EPOCH 2: 0.0000
------------------------------


100%|██████████| 15/15 [00:04<00:00,  3.72it/s]
100%|██████████| 2/2 [00:01<00:00,  1.82it/s]
 30%|███       | 3/10 [00:15<00:36,  5.16s/it]

------------------------------
Training Loss EPOCH 3: 0.2325
Training DICE EPOCH 3: 0.0966


Validation Loss EPOCH 3: 0.2099
Validation DICE EPOCH 3: 0.5405
------------------------------


100%|██████████| 15/15 [00:03<00:00,  3.82it/s]
100%|██████████| 2/2 [00:01<00:00,  1.95it/s]
 40%|████      | 4/10 [00:20<00:30,  5.08s/it]

------------------------------
Training Loss EPOCH 4: 0.2202
Training DICE EPOCH 4: 0.4504


Validation Loss EPOCH 4: 0.1976
Validation DICE EPOCH 4: 0.4808
------------------------------


100%|██████████| 15/15 [00:03<00:00,  3.87it/s]
100%|██████████| 2/2 [00:01<00:00,  1.85it/s]
 50%|█████     | 5/10 [00:25<00:25,  5.04s/it]

------------------------------
Training Loss EPOCH 5: 0.2102
Training DICE EPOCH 5: 0.4326


Validation Loss EPOCH 5: 0.1924
Validation DICE EPOCH 5: 0.5543
------------------------------


100%|██████████| 15/15 [00:04<00:00,  3.73it/s]
100%|██████████| 2/2 [00:01<00:00,  1.96it/s]
 60%|██████    | 6/10 [00:30<00:20,  5.04s/it]

------------------------------
Training Loss EPOCH 6: 0.2045
Training DICE EPOCH 6: 0.5198


Validation Loss EPOCH 6: 0.1911
Validation DICE EPOCH 6: 0.5879
------------------------------


100%|██████████| 15/15 [00:03<00:00,  3.87it/s]
100%|██████████| 2/2 [00:01<00:00,  1.94it/s]
 70%|███████   | 7/10 [00:35<00:14,  5.00s/it]

------------------------------
Training Loss EPOCH 7: 0.2010
Training DICE EPOCH 7: 0.5363


Validation Loss EPOCH 7: 0.1782
Validation DICE EPOCH 7: 0.6066
------------------------------


100%|██████████| 15/15 [00:03<00:00,  3.78it/s]
100%|██████████| 2/2 [00:01<00:00,  1.93it/s]
 80%|████████  | 8/10 [00:40<00:10,  5.00s/it]

------------------------------
Training Loss EPOCH 8: 0.2051
Training DICE EPOCH 8: 0.5289


Validation Loss EPOCH 8: 0.1845
Validation DICE EPOCH 8: 0.5808
------------------------------


100%|██████████| 15/15 [00:03<00:00,  3.83it/s]
100%|██████████| 2/2 [00:01<00:00,  1.90it/s]
 90%|█████████ | 9/10 [00:45<00:04,  5.00s/it]

------------------------------
Training Loss EPOCH 9: 0.2063
Training DICE EPOCH 9: 0.5070


Validation Loss EPOCH 9: 0.1900
Validation DICE EPOCH 9: 0.5962
------------------------------


100%|██████████| 15/15 [00:03<00:00,  3.79it/s]
100%|██████████| 2/2 [00:01<00:00,  1.91it/s]
100%|██████████| 10/10 [00:50<00:00,  5.05s/it]

------------------------------
Training Loss EPOCH 10: 0.2017
Training DICE EPOCH 10: 0.5222


Validation Loss EPOCH 10: 0.1892
Validation DICE EPOCH 10: 0.5931
------------------------------



