# Exercise 2: Detection of liver tumors

In [1]:
%pip install --quiet celluloid

In [2]:
%pip install --quiet torchio

In [3]:
%pip install --quiet monai

In [4]:
%matplotlib inline
from pathlib import Path
import nibabel as nib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.patches as mpatches
import numpy as np
import cv2
import tarfile
import re

from celluloid import Camera
from IPython.display import HTML

import os
import random
import shutil

import torchio as tio
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.transforms import functional as F
from torchvision.transforms import Grayscale
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import warnings
from typing import Optional

# Data Loading

Download dataset

In [5]:
from google.colab import drive
drive.mount("/content/drive", force_remount=True)
%cd /content/drive/MyDrive/AI_for_Healthcare/Exercise2/

Mounted at /content/drive
/content/drive/.shortcut-targets-by-id/158PtvUr8lfOwjSGt_oaovb-hZuMvQxOl/AI_for_Healthcare/Exercise2


In [6]:
# tar_path = 'Task03_Liver.tar'

# with tarfile.open(tar_path, 'r') as tar:
#     tar.extractall()


In [7]:
root = Path("Task03_Liver/imagesTr/")
label = Path("Task03_Liver/labelsTr/")

In [8]:
def img_path_to_label_path(path):
    """
    replace 'imagesTr' with 'labelsTr' in the path
    and get the subject id from 'imagesTr' for later use to find corresponding label in 'labelsTr'
    """
    parts = list(path.parts)
    parts[parts.index("imagesTr")] = "labelsTr"
    label_path = Path(*parts)

    # get subject id from filename
    number_match = re.search(r'liver_(\d+)\.nii\.gz', str(label_path))
    if number_match:
        number = int(number_match.group(1))
        return number, label_path
    else:
        return None, label_path

In [9]:
class MySubject(tio.Subject):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # Add or override any methods as needed
    def check_consistent_attribute(
        self,
        attribute: str,
        relative_tolerance: float = 1e-6,
        absolute_tolerance: float = 1e-6,
        message: Optional[str] = None,
    ) -> None:
        pass

    def check_consistent_spatial_shape(self) -> None:
        pass
    def check_consistent_orientation(self) -> None:
        pass
    def check_consistent_affine(self) -> None:
        pass

In [10]:
ct_folder = Path("Task03_Liver/imagesTr/")

subjects_paths = list(ct_folder.glob("liver_*"))
subjects = []

for subject_path in subjects_paths:
    label_path = img_path_to_label_path(subject_path)[1]
    subject = tio.Subject({"CT":tio.ScalarImage(subject_path), "Label":tio.LabelMap(label_path)})
    subjects.append(subject)

In [11]:
tio.Subject.relative_attribute_tolerance = 1000 # applies to all instances since it is a static attribute
tio.Subject.absolute_attribute_tolerance = 1000

In [12]:
subjects[1]['CT'].orientation

('R', 'A', 'S')

In [13]:
for subject in subjects:
    assert subject["CT"].orientation == ("R", "A", "S")

All of our subjects have the same CT orientation.

In [14]:
len(subjects)

131

# Data Preprocessing

In [15]:
preprocess = tio.Compose([
            tio.ToCanonical(),
            tio.CropOrPad((256, 256, 200)), # crop/pad to this shape so all samples same size
            tio.RescaleIntensity((-1, 1)), # normalize intensity range to -1 to 1
            tio.Clamp(out_min= -150, out_max= 250), # HU-value clipping
            tio.Resample('CT')
            ])

augmentation = tio.RandomAffine(scales=(0.9, 1.1), degrees=(-10, 10))

train_transform = tio.Compose([preprocess, augmentation])
val_transform = preprocess
test_transform = preprocess

## Split the data set into training and testing sets

We have a total of 131 subjects. We use 105 subjects for training, and 13 subjects for each of the val and test datasets.

In [23]:
# train_dataset = tio.SubjectsDataset(subjects[:105], transform=train_transform)
# val_dataset = tio.SubjectsDataset(subjects[105:118], transform=val_transform)
# test_dataset = tio.SubjectsDataset(subjects[118:], transform=test_transform)

train_dataset = tio.SubjectsDataset(subjects[:3], transform=train_transform)
val_dataset = tio.SubjectsDataset(subjects[3:5], transform=val_transform)
test_dataset = tio.SubjectsDataset(subjects[5:7], transform=test_transform)

Process the data in batches

In [24]:
batch_size = 1

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


# Training the RetinaNet Model

Note for evaluation:
Individual lesions are defined as 3D connected components within an image. A lesion is considered
detected if the predicted lesion has sufficient overlap with its corresponding reference lesion, measured
as the intersection over the union of their respective segmentation masks. It allows for a count of true
positive, false positive, and false-negative detection, from which we compute the precision and recall of
lesion detection.

In [25]:
def mask_to_bbox(masks):
    batch_boxes = []
    for mask in masks:
        image_boxes = []
        image_labels = []
        for label_value in torch.unique(mask):
            if label_value == 0:  # skip the background label
                continue
            indices = torch.nonzero(mask == label_value, as_tuple=True)
            if len(indices[0]) > 0:  # check if the indices tensor is not empty
                x_min, x_max = torch.min(indices[1]), torch.max(indices[1])
                y_min, y_max = torch.min(indices[0]), torch.max(indices[0])

                # make sure bbox has positive height and width
                if x_max > x_min and y_max > y_min:
                    # 2D bounding box has format [y_min, x_min, y_max, x_max]
                    bounding_box = torch.tensor([x_min.item(), y_min.item(), x_max.item(), y_max.item()], dtype=torch.float32)

                    label_value = label_value.item()

                    image_boxes.append(bounding_box)
                    image_labels.append(label_value)

        if len(image_boxes) == 0:
            # add empty box if there are no objects in slice
            batch_boxes.append({'boxes': torch.zeros((0,4),dtype=torch.float32).to(device), 'labels': torch.tensor([0], dtype=torch.int64).to(device)})
        else:
            image_boxes = torch.stack(image_boxes).to(device)
            image_labels = torch.tensor(image_labels, dtype=torch.int64).to(device)
            batch_boxes.append({'boxes': image_boxes, 'labels': image_labels})

    return batch_boxes

Training loop:

In [26]:
from torchvision.models.detection import retinanet_resnet50_fpn_v2

num_classes = 3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 5

model = retinanet_resnet50_fpn_v2(pretrained=True)
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)



In [27]:
checkpoint_path = 'Checkpoints/'

In [34]:
# Training loop
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0


    # iterate over the training data
    for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):

        # get the input and target tensors for the batch
        inputs = batch['CT']['data']
        targets = batch['Label']['data']

        batch_size, channels, height, width, depth = inputs.size()

        # convert to 3 channel input to work with 2d retinanet model
        inputs = inputs.expand(batch_size, 3, height, width, depth)
        targets = targets.expand(batch_size, 3, height, width, depth)

        # reshape inputs into required size
        inputs = inputs.permute(0, 4, 1, 2, 3)
        inputs = inputs.flatten(0,1)
        inputs = [tensor.squeeze(1).to(device) for tensor in inputs]

        targets = targets.permute(0, 4, 1, 2, 3)
        targets = targets.flatten(0,1)
        targets = [tensor.squeeze(1) for tensor in targets]

        # convert the GT segmentation masks to bounding boxes
        ground_truth_boxes = mask_to_bbox(targets)


        # Forward pass
        optimizer.zero_grad()

        split_size = 10
        inputs = [inputs[i:i+split_size] for i in range(0, len(inputs), split_size)]
        ground_truth_boxes = [ground_truth_boxes[i:i+split_size] for i in range(0, len(ground_truth_boxes), split_size)]

        for mini_inputs, mini_gt_boxes in zip(inputs, ground_truth_boxes):

            output_dict = model(mini_inputs, mini_gt_boxes)
            loss_cls = output_dict['classification']
            loss_reg = output_dict['bbox_regression']
            loss = loss_cls + loss_reg

            loss.backward()
            optimizer.step()

            # calculate loss per slice and add to total
            train_loss += loss.item() * len(mini_inputs)

            del mini_inputs, mini_gt_boxes, output_dict, loss_cls, loss_reg
            torch.cuda.empty_cache()

    # calculate the average training loss for the epoch
    train_loss = train_loss / len(train_dataset)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}')

    # Validation loop
    # model.eval()  # Set the model to evaluation mode
    val_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):

            # get the input and target tensors for the batch
            inputs = batch['CT']['data']
            targets = batch['Label']['data']

            batch_size, channels, height, width, depth = inputs.size()

            # convert to 3 channel input to work with 2d retinanet model
            inputs = inputs.expand(batch_size, 3, height, width, depth)
            targets = targets.expand(batch_size, 3, height, width, depth)

            # reshape inputs into required size
            inputs = inputs.permute(0, 4, 1, 2, 3)
            inputs = inputs.flatten(0,1)
            inputs = [tensor.squeeze(1).to(device) for tensor in inputs]

            targets = targets.permute(0, 4, 1, 2, 3)
            targets = targets.flatten(0,1)
            targets = [tensor.squeeze(1) for tensor in targets]

            # convert the GT segmentation masks to bounding boxes
            ground_truth_boxes = mask_to_bbox(targets)

            split_size = 10
            inputs = [inputs[i:i+split_size] for i in range(0, len(inputs), split_size)]
            ground_truth_boxes = [ground_truth_boxes[i:i+split_size] for i in range(0, len(ground_truth_boxes), split_size)]

            for mini_inputs, mini_gt_boxes in zip(inputs, ground_truth_boxes):

                output_dict = model(mini_inputs, mini_gt_boxes)
                loss_cls = output_dict['classification']
                loss_reg = output_dict['bbox_regression']
                loss = loss_cls + loss_reg

                # calculate loss per slice and add to total
                val_loss += loss.item() * len(mini_inputs)

                del mini_inputs, mini_gt_boxes, output_dict, loss_cls, loss_reg
                torch.cuda.empty_cache()

        # calculate the average validation loss for the epoch
        val_loss = val_loss / len(val_dataset)

        print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}')

        torch.save(model.state_dict(), f'{checkpoint_path}/model_epoch{epoch+1}.pth')

Epoch 1/5: 100%|██████████| 3/3 [02:38<00:00, 52.70s/it]


Epoch [1/5], Train Loss: 28.6641


Epoch 1/5: 100%|██████████| 2/2 [01:01<00:00, 30.54s/it]


Epoch [1/5], Validation Loss: 530.4664


Epoch 2/5: 100%|██████████| 3/3 [02:37<00:00, 52.39s/it]


Epoch [2/5], Train Loss: 1269.6610


Epoch 2/5: 100%|██████████| 2/2 [01:01<00:00, 30.69s/it]


Epoch [2/5], Validation Loss: 25172.8809


Epoch 3/5: 100%|██████████| 3/3 [02:36<00:00, 52.21s/it]


Epoch [3/5], Train Loss: 1057.1961


Epoch 3/5: 100%|██████████| 2/2 [01:01<00:00, 30.63s/it]


Epoch [3/5], Validation Loss: 12011.6842


Epoch 4/5: 100%|██████████| 3/3 [02:35<00:00, 51.75s/it]


Epoch [4/5], Train Loss: 595.0103


Epoch 4/5: 100%|██████████| 2/2 [01:00<00:00, 30.28s/it]


Epoch [4/5], Validation Loss: 8955.6092


Epoch 5/5: 100%|██████████| 3/3 [02:34<00:00, 51.58s/it]


Epoch [5/5], Train Loss: 777.2659


Epoch 5/5: 100%|██████████| 2/2 [00:59<00:00, 29.85s/it]


Epoch [5/5], Validation Loss: 66561.5092


# Evaluate Model

In [None]:
# model = retinanet_resnet50_fpn_v2(pretrained=True)
# model.load_state_dict(torch.load('Checkpoing/model_epoch5.pth'))
# model.eval()

In [None]:
def calculate_iou(boxA, boxB):
    if len(boxA) == 0 or len(boxB) == 0:
        return 0.0

    # calculate IoU
    iou = ops.box_iou(boxA, boxB)

    return iou

In [None]:
model.eval()
test_loss = 0.0
total_iou = 0
total_tumors = 0
correct_tumors = 0
tp = 0
tn = 0
fp = 0
fn = 0

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Testing'):

        # get the input and target tensors for the batch
        inputs = batch['CT']['data']
        targets = batch['Label']['data']

        batch_size, channels, height, width, depth = inputs.size()

        # convert to 3 channel input to work with 2d retinanet model
        inputs = inputs.expand(batch_size, 3, height, width, depth)
        targets = targets.expand(batch_size, 3, height, width, depth)

        # reshape inputs into required size
        inputs = inputs.permute(0, 4, 1, 2, 3)
        inputs = inputs.flatten(0,1)
        inputs = [tensor.squeeze(1).to(device) for tensor in inputs]

        targets = targets.permute(0, 4, 1, 2, 3)
        targets = targets.flatten(0,1)
        targets = [tensor.squeeze(1) for tensor in targets]

        # convert the GT segmentation masks to bounding boxes
        ground_truth_boxes = mask_to_bbox(targets)

        split_size = 10
        inputs = [inputs[i:i+split_size] for i in range(0, len(inputs), split_size)]
        ground_truth_boxes = [ground_truth_boxes[i:i+split_size] for i in range(0, len(ground_truth_boxes), split_size)]

        for mini_inputs, mini_gt_boxes in zip(inputs, ground_truth_boxes):

            output_dicts = model(mini_inputs, mini_gt_boxes)
            # loss_cls = output_dict['classification']
            # loss_reg = output_dict['bbox_regression']
            # loss = loss_cls + loss_reg

            # # calculate loss per slice and add to total
            # test_loss += loss.item() * len(mini_inputs)

            for output_dict, gt_box in zip(output_dicts, mini_gt_boxes):
                # Get the predicted boxes, labels, and scores
                predicted_boxes = output_dict['boxes']
                predicted_labels = output_dict['labels']
                predicted_scores = output_dict['scores']

                for box, label in zip(predicted_boxes, predicted_labels):
                    # classify tumors
                    iou = calculate_iou(gt_box['boxes'], box)
                    if torch.is_tensor(iou):
                        is_tumor = (iou > 0.5).any().item()
                    else:
                        is_tumor = iou > 0.5


                    if is_tumor:
                        total_tumors += 1
                        if predicted_labels.item() == 1:  # Assuming tumor class is 1
                            correct_tumors += 1

                    if is_tumor and gt_box['labels'].item() == 2:
                        tp += 1
                    elif is_tumor and (gt_box['labels'].item() == 0 or gt_box['labels'].item() == 1):
                        fn += 1
                    elif not is_tumor and (gt_box['labels'].item() == 0 or gt_box['labels'].item() == 1):
                        fp += 1
                    elif not is_tumor and (gt_box['labels'].item() == 0):
                        tn += 1

            del mini_inputs, mini_gt_boxes, output_dict
            torch.cuda.empty_cache()

    # calculate the average test loss for the epoch
    test_loss = test_loss / len(val_dataset)

    print(f'Epoch [{epoch+1}/{num_epochs}], Test Loss: {test_loss:.4f}')

# calculate metrics
accuracy = correct_tumors / total_tumors
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1_score = 2 * (precision * recall) / (precision + recall)

test_loss = test_loss / len(test_dataset)

print(f'Test Loss: {test_loss:.4f}')
print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1_score:.4f}')
