# RetinaNet_ResNet50_fpn Network for Cell detection

In [None]:
import os
import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import cv2
import torchvision
from torchvision.models.detection import retinanet_resnet50_fpn
from torchvision.models.detection.anchor_utils import AnchorGenerator
from sklearn.model_selection import train_test_split
import albumentations as A

# Custom Dataset class

In the classification network, the `__getitem__` method returned a single cell at a time. In the detection network, it returns a larger crop of an image (300 x 300 by default), together with the bounding boxes of the cells inside that crop and their corresponding labels.

In [None]:
class MyDataset(Dataset):
    """Detection dataset that serves random fixed-size crops of annotated images.

    Every image listed in ``df`` contributes ``num_samples`` items. The crop
    position is drawn at random on each access, so requesting the same index
    twice generally yields two different crops.

    Args:
        df: Annotation frame with one row per box and the columns
            ``filename``, ``min_x``, ``min_y``, ``max_x``, ``max_y``, ``label``.
        transform: Callable invoked as
            ``transform(image=..., bboxes=..., labels=...)`` that returns a
            mapping with the keys ``image``, ``bboxes`` and ``labels``. Boxes
            are handed over as absolute pixel ``[x_min, y_min, x_max, y_max]``
            values relative to the crop (albumentations' ``pascal_voc`` format).
        num_samples: Number of random crops served per image.
        crop_width: Crop width in pixels.
        crop_height: Crop height in pixels.
        image_dir: Directory the ``filename`` entries are resolved against.
    """

    def __init__(self, df, transform, num_samples=10, crop_width = 300, crop_height = 300,
                 image_dir='replace w/ path'):
        self.df = df
        self.num_samples = num_samples  # number of crops per image
        self.crop_width = crop_width
        self.crop_height = crop_height
        self.transform = transform
        self.image_dir = image_dir
        self.image_names = df['filename'].unique()

    def is_within_crop(self, x, y, crop_x, crop_y, crop_w, crop_h):
        """Return True when point (x, y) lies inside the half-open crop window."""
        return crop_x <= x < crop_x + crop_w and crop_y <= y < crop_y + crop_h

    def adjust_coordinates(self, x, y, crop_x, crop_y):
        """Translate image coordinates into coordinates relative to the crop origin."""
        return x - crop_x, y - crop_y

    def clamp_bbox(self, bbox, width, height):
        """Clamp ``[x_min, y_min, x_max, y_max, label]`` to ``[0, width-1] x [0, height-1]``."""
        x_min, y_min, x_max, y_max, label = bbox
        x_min = max(0, min(x_min, width - 1))
        y_min = max(0, min(y_min, height - 1))
        x_max = max(0, min(x_max, width - 1))
        y_max = max(0, min(y_max, height - 1))
        return [x_min, y_min, x_max, y_max, label]

    def __len__(self):
        return len(self.image_names) * self.num_samples

    def _crop_annotations(self, img_df, crop_x, crop_y):
        """Collect the boxes of ``img_df`` that fall into the crop at (crop_x, crop_y).

        Returns:
            ``(bboxes, labels)`` where ``bboxes`` is a list of
            ``[x_min, y_min, x_max, y_max]`` floats in crop pixel coordinates
            and ``labels`` holds 1 for every non-zero source label, else 0.
        """
        bboxes = []
        labels = []
        rows = zip(img_df['min_x'], img_df['min_y'],
                   img_df['max_x'], img_df['max_y'], img_df['label'])

        for min_x, min_y, max_x, max_y, label in rows:
            # A box is considered only if at least one of its corners is in the crop.
            corners = ((max_x, max_y), (min_x, min_y), (min_x, max_y), (max_x, min_y))
            if not any(
                self.is_within_crop(x, y, crop_x, crop_y, self.crop_width, self.crop_height)
                for x, y in corners
            ):
                continue

            adj_min_x, adj_min_y = self.adjust_coordinates(min_x, min_y, crop_x, crop_y)
            adj_max_x, adj_max_y = self.adjust_coordinates(max_x, max_y, crop_x, crop_y)
            x_min, y_min, x_max, y_max, _ = self.clamp_bbox(
                [adj_min_x, adj_min_y, adj_max_x, adj_max_y, label],
                self.crop_width, self.crop_height,
            )

            # Clamping can collapse a box that only touches the crop border.
            if x_min < x_max and y_min < y_max:
                bboxes.append([float(x_min), float(y_min), float(x_max), float(y_max)])
                labels.append(1 if label != 0 else 0)

        # Ensure at least one bounding box is present: a crop without any
        # annotation gets one box covering the whole crop with label 0.
        if not bboxes:
            bboxes.append([0.0, 0.0, float(self.crop_width), float(self.crop_height)])
            labels.append(0)

        return bboxes, labels

    @staticmethod
    def _to_sample(image, bboxes, labels):
        """Convert an HWC image plus its boxes/labels into ``(image_tensor, target)``.

        uint8 images are rescaled to the 0-1 range, which is the input range
        the torchvision detection models document. Boxes always come out with
        shape ``(N, 4)``, including ``N == 0`` when the augmentation dropped
        every box.
        """
        image_array = np.asarray(image)
        image_tensor = torch.tensor(image_array, dtype=torch.float32).permute(2, 0, 1)
        if image_array.dtype == np.uint8:
            image_tensor = image_tensor / 255.0

        boxes = np.asarray(bboxes, dtype=np.float32)
        boxes = boxes[:, :4] if boxes.size else boxes.reshape(0, 4)

        target = {
            'boxes': torch.tensor(boxes, dtype=torch.float32),
            'labels': torch.tensor(np.asarray(labels), dtype=torch.int64),
        }
        return image_tensor, target

    def __getitem__(self, index):
        """Return ``(image, target)`` for a fresh random crop of the indexed image.

        Returns:
            ``image``: float32 tensor of shape ``(3, crop_height, crop_width)``.
            ``target``: dict with ``boxes`` (float32, ``(N, 4)``, pixel
            ``[x_min, y_min, x_max, y_max]``) and ``labels`` (int64, ``(N,)``),
            both taken from the output of ``transform`` so they match the
            augmented image.

        Raises:
            FileNotFoundError: If the image cannot be read.
            ValueError: If the image is smaller than the requested crop.
        """
        # Consecutive blocks of num_samples indices map to the same image.
        img_name = self.image_names[index // self.num_samples]
        img_path = os.path.join(self.image_dir, img_name)
        image = cv2.imread(img_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image: {img_path}")

        height, width, _ = image.shape
        if width < self.crop_width or height < self.crop_height:
            raise ValueError(
                f"Image {img_path} ({width}x{height}) is smaller than the "
                f"{self.crop_width}x{self.crop_height} crop"
            )

        # cropping randomly
        crop_x = np.random.randint(0, width - self.crop_width + 1)
        crop_y = np.random.randint(0, height - self.crop_height + 1)
        crop_image = image[crop_y:crop_y + self.crop_height, crop_x:crop_x + self.crop_width]
        crop_image = cv2.cvtColor(crop_image, cv2.COLOR_BGR2RGB)

        # annotations for the current image
        img_df = self.df[self.df['filename'] == img_name]
        bboxes, labels = self._crop_annotations(img_df, crop_x, crop_y)

        # Applying transformations; the augmented boxes (not the input boxes)
        # are what must accompany the augmented image.
        transformed = self.transform(image=crop_image, bboxes=bboxes, labels=labels)
        return self._to_sample(transformed['image'], transformed['bboxes'], transformed['labels'])

# Augmentation Methods

Geometric augmentations are complicated for detection tasks, because the bounding boxes always have to be transformed together with the image. Therefore, we use the albumentations library, which handles this for us.

In [None]:
# Augmentation pipeline shared by image and boxes: random flips, a rotation
# of up to 90 degrees and a brightness/contrast jitter. Boxes are declared as
# 'pascal_voc' and their class ids travel in the 'labels' field.
transform = A.Compose(
    transforms=[
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=90, p=0.5),
        A.RandomBrightnessContrast(p=0.2),
    ],
    bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']),
)

# Initialising the model

For this task, we utilise the retinanet_resnet50_fpn model, which is a RetinaNet with a ResNet-50 FPN backbone. RetinaNet uses anchor boxes to predict the bounding boxes. The model also uses the focal loss internally, hence we don't need to define a custom loss function.

More information about the model can be found on torchvision documentation.

In [None]:
# Pretrained RetinaNet whose detection head is swapped for a fresh one with
# two output classes and a custom anchor layout.
num_classes = 2  # background=0 , cell=1
model = retinanet_resnet50_fpn(pretrained=True)
in_features = model.backbone.out_channels

# Anchor boxes: one size per entry (five entries), three aspect ratios each.
anchor_generator = AnchorGenerator(
    sizes=tuple((size,) for size in (32, 64, 128, 256, 512)),
    aspect_ratios=tuple((0.5, 1.0, 2.0) for _ in range(5)),
)
num_anchors = anchor_generator.num_anchors_per_location()[0]

# The new head has to be sized for the anchors produced by the generator above.
model.anchor_generator = anchor_generator
model.head = torchvision.models.detection.retinanet.RetinaNetHead(
    in_channels=in_features,
    num_classes=num_classes,
    num_anchors=num_anchors,
)

# Freezing the backbone: none of its parameters will receive gradients.
model.backbone.requires_grad_(False)

# Optimizer & Splitting the dataset

Adam optimizer

Train 70 % , Validation 15 %, Test 15 %

In [None]:
# Adam over the model's parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
my_dataset = pd.read_csv('annotation_frame.csv')

# Split by image, not by annotation row. MyDataset groups rows by 'filename',
# so a row-wise split would put the same image into several splits (leakage)
# and give each split only part of that image's boxes.
image_names = my_dataset['filename'].unique()
train_names, holdout_names = train_test_split(image_names, test_size=0.3, random_state=56)
val_names, test_names = train_test_split(holdout_names, test_size=0.5, random_state=56)

traindata = my_dataset[my_dataset['filename'].isin(train_names)]
validata = my_dataset[my_dataset['filename'].isin(val_names)]
testdata = my_dataset[my_dataset['filename'].isin(test_names)]

train_dataset = MyDataset(traindata, transform=transform)
val_dataset = MyDataset(validata, transform=transform)
test_dataset = MyDataset(testdata, transform=transform)

def custom_length(batch):
    """Collate function: regroup (image, target) samples into two parallel tuples.

    Nothing is stacked here, so every image keeps its own target entry.
    """
    images, targets = tuple(zip(*batch))
    return (images, targets)

# Batches of four samples, regrouped by custom_length; only the training
# loader shuffles.
trainloader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=custom_length,
)
valiloader = DataLoader(
    dataset=val_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=custom_length,
)
testloader = DataLoader(
    dataset=test_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=custom_length,
)

# Training & Validating Loops

We'll use the built-in focal loss as the loss function. But we also need a detection metric to evaluate the accuracy of the predicted bounding boxes. Therefore, we will use mean average precision (mAP).

In [None]:
def train_one_epoch(model, dataloader, optimizer):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Detection model that returns a dict of losses when called with
            images and targets in training mode.
        dataloader: Yields ``(images, targets)`` batches, where ``images`` is a
            sequence of equally sized tensors and each target is a dict with
            ``'boxes'`` and ``'labels'``.
        optimizer: Optimizer that updates the model's parameters.

    Returns:
        Sum of the per-batch total losses divided by the number of batches.
    """
    model.train()
    total_loss = 0.0
    progress = tqdm(dataloader, desc='Training')

    for batch_images, batch_targets in progress:
        stacked = torch.stack(batch_images)
        clean_targets = [
            {'boxes': entry['boxes'], 'labels': entry['labels']}
            for entry in batch_targets
        ]

        # reset gradients, forward pass, back propagation, parameter update
        optimizer.zero_grad()
        batch_loss = sum(model(stacked, clean_targets).values())
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()

    return total_loss / len(dataloader)

def valid_one_epoch(model, dataloader):
    """Evaluate ``model`` on ``dataloader``: mean validation loss and mAP metrics.

    torchvision detection models return the loss dict only in training mode
    and the detections only in eval mode. Each batch is therefore run twice
    under ``torch.no_grad()``: once in train mode for the loss and once in
    eval mode for the predictions fed to the metric.

    Args:
        model: Detection model (loss dict in train mode, list of prediction
            dicts in eval mode).
        dataloader: Yields ``(images, targets)`` batches; every target is a
            dict with ``'boxes'`` and ``'labels'``.

    Returns:
        ``(epoch_loss, mAP_val)`` where ``epoch_loss`` is the mean per-batch
        loss and ``mAP_val`` is the result of ``MeanAveragePrecision.compute()``.
    """
    mAP = MeanAveragePrecision(iou_type='bbox')
    val_loss = 0.0

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc='Validation'):
            images = torch.stack(inputs)
            targets = [{'boxes': t['boxes'], 'labels': t['labels']} for t in targets]

            # Loss: only available in training mode. No gradients are tracked
            # and no optimizer step is taken, so weights are not updated.
            # NOTE(review): layers with train-time state (e.g. non-frozen
            # BatchNorm, dropout) would still behave as in training here —
            # confirm the model has none.
            model.train()
            loss_dict = model(images, targets)
            val_loss += sum(loss_dict.values()).item()

            # Predictions: only available in eval mode.
            model.eval()
            outputs = model(images)
            mAP.update(outputs, targets)

    # Leave the model in eval mode, also when the dataloader was empty.
    model.eval()

    mAP_val = mAP.compute()
    epoch_loss = val_loss / len(dataloader)

    return epoch_loss, mAP_val

In [None]:
def n_epoch(model, train_loader, valid_loader, optimizer, n_epochs=5):
    """Train for ``n_epochs`` epochs, validating after each and keeping the best model.

    The model weights are written to ``best_model_detection.pth`` whenever the
    validation mAP improves on the best value seen so far.

    Args:
        model: Detection model to train.
        train_loader: Loader passed to ``train_one_epoch``.
        valid_loader: Loader passed to ``valid_one_epoch``.
        optimizer: Optimizer passed to ``train_one_epoch``.
        n_epochs: Number of epochs to run.

    Returns:
        ``(train_losses, valid_losses, valid_mAPs)``, one entry per epoch;
        ``valid_mAPs`` holds whatever ``valid_one_epoch`` returned as metric.
    """
    train_losses = []
    valid_losses = []
    valid_mAPs = []
    best_mAP = -1

    for epoch in range(n_epochs):
        print(f'Epoch {epoch + 1}/{n_epochs}')

        # Training
        train_loss = train_one_epoch(model, train_loader, optimizer)
        train_losses.append(train_loss)
        print(f"Training loss: {train_loss}")

        # Validation
        valid_loss, mAP_val = valid_one_epoch(model, valid_loader)
        valid_losses.append(valid_loss)
        valid_mAPs.append(mAP_val)
        print(f"Validation Loss: {valid_loss}")
        print(f"Mean Average Precision: {mAP_val}")

        # The metric may arrive as a dict of values (as torchmetrics returns);
        # a dict cannot be ordered, so compare on its scalar 'map' entry.
        map_score = mAP_val['map'] if isinstance(mAP_val, dict) else mAP_val
        map_score = float(map_score)

        if map_score > best_mAP:
            best_mAP = map_score
            torch.save(model.state_dict(), 'best_model_detection.pth')

    return train_losses, valid_losses, valid_mAPs

# Training

In [None]:
# Run the training/validation loop with the default number of epochs.
train_losses, valid_losses, valid_mAPs = n_epoch(
    model=model,
    train_loader=trainloader,
    valid_loader=valiloader,
    optimizer=optimizer,
)