# Brain tumour image segmentation

## The imaging dataset

The dataset is curated from the brain imaging dataset in the [Medical Decathlon Challenge](http://medicaldecathlon.com/). To save storage and reduce computational cost, 2D image slices are extracted from the T1-Gd contrast-enhanced 3D brain volumes and downsampled.

The dataset consists of a training set and a test set. Each image is of dimension 120 x 120, with a corresponding label map of the same dimension. There are four classes in the label map:

- 0: background
- 1: edema
- 2: non-enhancing tumour
- 3: enhancing tumour

### Notebook initialisation:

In [None]:
# Import libraries
import tarfile
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import numpy as np
import time
import os
import random
import matplotlib.pyplot as plt
from matplotlib import colors

# Unpack the dataset archive into the current working directory. The context
# manager closes the archive even if extraction fails part-way.
# NOTE(review): extractall() trusts the member paths stored in the archive;
# only run this on an archive obtained from a trusted source.
with tarfile.open('BrainTumour_2D.tar.gz') as datafile:
    datafile.extractall()

### Visualising a random set of 4 training images along with their label maps

In [None]:
# One colour per label value: 0 background (black), 1 edema (green),
# 2 non-enhancing tumour (blue), 3 enhancing tumour (red).
seg_cmap = colors.ListedColormap(['black', 'green', 'blue', 'red'])
train_im_dir = 'BrainTumour_2D/training_images/'
train_seg_dir = 'BrainTumour_2D/training_labels/'
images = os.listdir(train_im_dir)
for _ in range(4):
    # An image and its label map are looked up under the same file name.
    c = random.choice(images)
    image = os.path.join(train_im_dir, c)
    smap = os.path.join(train_seg_dir, c)
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
    axes[0].imshow(imageio.imread(image), cmap='gray')
    axes[0].set_title('Brain MR Image')
    # Pin the colour range to the full label range. Without vmin/vmax, imshow
    # rescales to the labels present in this slice, so e.g. a slice holding
    # only labels {0, 1} would draw label 1 in the last colour (red).
    axes[1].imshow(imageio.imread(smap), cmap=seg_cmap,
                   vmin=0, vmax=seg_cmap.N - 1)
    axes[1].set_title('Brain MR Segmentation Map')
    plt.show()

### Implementing a dataset class that reads the imaging dataset and returns items (pairs of images and label maps) as training batches

In [None]:
def normalise_intensity(image, thres_roi=1.0):
    """Standardise image intensities using foreground statistics.

    The foreground is every pixel whose value is at or above the
    ``thres_roi``-th percentile of the image. The mean and standard deviation
    of those pixels are used to shift and scale the *whole* image.

    Args:
        image: Array of pixel intensities.
        thres_roi: Percentile (0-100) used as the lower foreground bound.

    Returns:
        Array of the same shape as ``image`` holding
        ``(image - mean) / (std + 1e-6)``.
    """
    lower_bound = np.percentile(image, thres_roi)
    foreground = image[image >= lower_bound]
    mean = np.mean(foreground)
    std = np.std(foreground)
    # The small constant avoids a division by zero on constant images.
    return (image - mean) / (std + 1e-6)


class BrainImageSet(Dataset):
    """In-memory set of 2D brain images and, unless deployed, their label maps.

    Every file found in ``image_path`` is read at construction time. When
    ``deploy`` is false, a label map with the same file name is read from
    ``label_path`` for each image.

    Args:
        image_path: Directory containing the image files.
        label_path: Directory containing the label maps. Not read when
            ``deploy`` is true.
        deploy: If true, no label maps are loaded and ``None`` is returned in
            place of labels.
    """

    def __init__(self, image_path, label_path='', deploy=False):
        self.image_path = image_path
        self.label_path = label_path
        self.deploy = deploy
        self.images = []
        self.labels = []

        # Sorted so that the index -> file mapping does not depend on the
        # order in which the file system lists the directory.
        for image_name in sorted(os.listdir(image_path)):
            # Read the image
            self.images.append(imageio.imread(os.path.join(image_path, image_name)))

            # Read the label map stored under the same file name
            if not self.deploy:
                self.labels.append(imageio.imread(os.path.join(label_path, image_name)))

    def __len__(self):
        """Return the number of images in the set."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the intensity-normalised image at ``idx`` and its label map.

        Both have dimension XY. The label is ``None`` in deploy mode, where
        no label maps were loaded.
        """
        image = normalise_intensity(self.images[idx])
        label = None if self.deploy else self.labels[idx]
        return image, label

    def get_random_batch(self, batch_size):
        """Sample ``batch_size`` items uniformly at random, with replacement.

        Items are fetched through ``__getitem__`` so that batches receive the
        same intensity normalisation as single items.

        Returns:
            A tuple ``(images, labels)``: ``images`` is an array of dimension
            NCXY with a single channel, ``labels`` is an array of dimension
            NXY, or ``None`` in deploy mode.
        """
        images, labels = [], []
        for _ in range(batch_size):
            image, label = self[random.randrange(len(self.images))]
            # Add the channel axis: XY -> CXY with C = 1
            images.append(image[np.newaxis])
            labels.append(label)

        images = np.array(images)
        labels = None if self.deploy else np.array(labels)
        return images, labels

### Applying the [U-Net architecture](https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28) to build the segmentation model

In [None]:
""" U-net """
class UNet(nn.Module):
    """U-Net style encoder-decoder producing per-pixel class scores.

    The encoder has four double-convolution stages; stages 2-4 halve the
    spatial size with a stride-2 convolution while the channel count doubles
    (``num_filter`` -> 8 * ``num_filter``). The decoder upsamples three times
    with transposed convolutions, concatenating the matching encoder feature
    map (skip connection) before each double convolution. A final 1x1
    convolution maps the ``num_filter`` decoder channels to
    ``output_channel`` raw scores (no activation), as expected by
    ``nn.CrossEntropyLoss``.

    Args:
        input_channel: Number of channels of the input image.
        output_channel: Number of output channels (classes).
        num_filter: Number of feature channels of the first encoder stage.

    Note:
        Input height and width should be divisible by 8; otherwise the
        upsampled feature maps do not match the skip connections in size.
    """

    def __init__(self, input_channel=1, output_channel=1, num_filter=16):
        super().__init__()

        # BatchNorm: by default during training this layer keeps running estimates
        # of its computed mean and variance, which are then used for normalization
        # during evaluation.

        n = num_filter

        # Encoder path: n -> 2n -> 4n -> 8n channels
        self.conv1 = self._double_conv(input_channel, n)
        self.conv2 = self._double_conv(n, 2 * n, stride=2)
        self.conv3 = self._double_conv(2 * n, 4 * n, stride=2)
        self.conv4 = self._double_conv(4 * n, 8 * n, stride=2)

        # Decoder path: each upconv takes the upsampled map concatenated with
        # the skip connection, hence twice the channels of its output.
        self.up3 = nn.ConvTranspose2d(8 * n, 4 * n, kernel_size=2, stride=2)
        self.upconv3 = self._double_conv(8 * n, 4 * n)

        self.up2 = nn.ConvTranspose2d(4 * n, 2 * n, kernel_size=2, stride=2)
        self.upconv2 = self._double_conv(4 * n, 2 * n)

        self.up1 = nn.ConvTranspose2d(2 * n, n, kernel_size=2, stride=2)
        self.upconv1 = self._double_conv(2 * n, n)

        # Output head: without it the network would emit num_filter channels
        # regardless of output_channel.
        self.out_conv = nn.Conv2d(n, output_channel, kernel_size=1)

    @staticmethod
    def _double_conv(in_channels, out_channels, stride=1):
        """Build two 3x3 Conv-BatchNorm-ReLU layers; ``stride`` applies to the first."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        """Map a batch of images (NCXY) to class scores (N, output_channel, X, Y)."""
        # Encoder, keeping the feature maps needed by the skip connections
        conv1_skip = self.conv1(x)
        conv2_skip = self.conv2(conv1_skip)
        conv3_skip = self.conv3(conv2_skip)
        x = self.conv4(conv3_skip)

        # Decoder: upsample, concatenate the skip along the channel axis, convolve
        x = self.up3(x)
        x = torch.cat((x, conv3_skip), dim=1)
        x = self.upconv3(x)

        x = self.up2(x)
        x = torch.cat((x, conv2_skip), dim=1)
        x = self.upconv2(x)

        x = self.up1(x)
        x = torch.cat((x, conv1_skip), dim=1)
        x = self.upconv1(x)

        return self.out_conv(x)

### Training the segmentation model:

In [None]:
# CUDA device (falls back to the CPU when CUDA is not available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device: {0}'.format(device))

# Build the model
num_class = 4
model = UNet(input_channel=1, output_channel=num_class, num_filter=16)
model = model.to(device)
params = list(model.parameters())

# Directory for model checkpoints; exist_ok avoids a separate existence check
model_dir = 'saved_models'
os.makedirs(model_dir, exist_ok=True)

# Optimizer
optimizer = optim.Adam(params, lr=1e-3)

# Segmentation loss
criterion = nn.CrossEntropyLoss()

# Datasets
train_set = BrainImageSet('BrainTumour_2D/training_images', 'BrainTumour_2D/training_labels')
test_set = BrainImageSet('BrainTumour_2D/test_images', 'BrainTumour_2D/test_labels')

# Train the model
num_iter = 10000
train_batch_size = 16
eval_batch_size = 16
start = time.time()
for it in range(1, 1 + num_iter):
    model.train()

    # Get a batch of images and labels
    images, labels = train_set.get_random_batch(train_batch_size)
    images, labels = torch.from_numpy(images), torch.from_numpy(labels)
    images, labels = images.to(device, dtype=torch.float32), labels.to(device, dtype=torch.long)
    logits = model(images)

    # Perform optimisation and print out the training loss
    optimizer.zero_grad()
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()
    print(f"Iteration: [{it}/{num_iter}] - Loss: {loss.item():.4f}")

    # Evaluate on one random batch of test images every 100 iterations
    if it % 100 == 0:
        model.eval()
        # Disabling gradient calculation during inference to reduce memory consumption
        with torch.no_grad():
            test_images, test_labels = test_set.get_random_batch(eval_batch_size)
            test_images, test_labels = torch.from_numpy(test_images), torch.from_numpy(test_labels)
            test_images, test_labels = test_images.to(device, dtype=torch.float32), test_labels.to(device, dtype=torch.long)
            test_logits = model(test_images)
            test_loss = criterion(test_logits, test_labels).item()
            print(f"Iteration [{it}/{num_iter}] - Test Loss: {test_loss:.4f} - Evaluated on {eval_batch_size} random test images")

    # Save the model weights every 5000 iterations
    if it % 5000 == 0:
        torch.save(model.state_dict(), os.path.join(model_dir, 'model_{0}.pt'.format(it)))
print('Training took {:.3f}s in total.'.format(time.time() - start))

### Applying the trained model to a random set of 4 test images and visualising the automated segmentation:

In [None]:
import matplotlib.pyplot as plt

# One row per test image: input, predicted segmentation, ground truth
fig, axes = plt.subplots(4, 3, figsize=(12, 12))
images_test, labels_test = test_set.get_random_batch(4)
images_test, labels_test = torch.from_numpy(images_test), torch.from_numpy(labels_test)
images_test, labels_test = images_test.to(device, dtype=torch.float32), labels_test.to(device, dtype=torch.long)

# Keep the model on the same device as the inputs (also works without CUDA)
# and switch to evaluation mode so BatchNorm uses its running estimates.
model.to(device)
model.eval()
with torch.no_grad():
    y_hat = model(images_test)

# Predicted label per pixel: index of the highest score along the channel axis
predictions = y_hat.argmax(dim=1).cpu()

# Pin the colour range to the full label range; otherwise imshow rescales the
# colours to whichever labels happen to be present in each map.
label_range = dict(vmin=0, vmax=seg_cmap.N - 1)
for i in range(len(images_test)):
    test_image = images_test[i].cpu().squeeze()
    ground_truth = labels_test[i].cpu().squeeze()
    prediction = predictions[i]
    axes[i, 0].imshow(test_image, cmap='gray')
    axes[i, 0].set_title('Test Image')
    axes[i, 1].imshow(prediction, cmap=seg_cmap, **label_range)
    axes[i, 1].set_title('Automated Segmentation')
    axes[i, 2].imshow(ground_truth, cmap=seg_cmap, **label_range)
    axes[i, 2].set_title('Ground Truth Segmentation')

plt.tight_layout()
plt.show()

## Evaluation:

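The model works fairly well, with a final training loss of ~0.0788 and a final test loss of ~0.1236 after 10000 iterations. Note that each reported test loss is computed on a single random batch of 16 test images rather than on the whole test set.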

On visual inspection, the model seems to do well overall at segmenting the labelled regions.
It is fairly good at determining whether edema is present in the scan; however, at times it struggles to identify enhancing tumours and may confuse them with non-enhancing tumours.