# **PyTorch tutorial: Brain tumour segmentation**

Learning outcomes:
*   Train a basic machine learning model using PyTorch.
*   Have an appreciation for the steps taken in a machine learning workflow of an AI research project.



## **Import necessary libraries**

In [None]:
import h5py
import os
import torch
import shutil
import random
import time
import numpy as np
import matplotlib.pyplot as plt

from skimage import io, transform
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch import nn
from torch.optim import AdamW
from torch.optim import lr_scheduler

# **Data**

## Get Dataset

The dataset is downloaded from [figshare](https://figshare.com/articles/dataset/brain_tumor_dataset/1512427).

This brain tumor dataset contains 3064 T1 MRI slices from 233 patients with three kinds of brain tumor: Meningioma (708 slices), Glioma (1426 slices), Pituitary tumor (930 slices).

The data is stored in MATLAB data format (one .mat file per image). Each file stores a struct containing the following fields for an image:
*   cjdata.label:
    *   1 for meningioma
    *   2 for glioma
    *   3 for pituitary tumor.
*   cjdata.PID: patient ID.
*   cjdata.image: image data.
*   cjdata.tumorBorder: a vector storing the coordinates of discrete points on tumor border.
*   cjdata.tumorMask: a binary image with 1s indicating tumor region.

Only image and tumorMask are used in this project.

In [None]:
# Create dataset folder in /content/ folder
!mkdir /content/brain_tumour_dataset/
!mkdir /content/brain_tumour_dataset/BrainTumorData/

# Download dataset from figshare
!wget https://figshare.com/ndownloader/files/3381290/brainTumorDataPublic_1-766.zip https://figshare.com/ndownloader/files/3381293/brainTumorDataPublic_1533-2298.zip https://figshare.com/ndownloader/files/3381296/brainTumorDataPublic_767-1532.zip https://figshare.com/ndownloader/files/3381302/brainTumorDataPublic_2299-3064.zip
!unzip /content/'*.zip*' -d /content/brain_tumour_dataset/BrainTumorData/

## Data visualization

Visualize 4 random images and their corresponding masks.

In [None]:
# Show ncol random slices, one per column: overlay (row 0), image (row 1), mask (row 2).
# The .mat files are numbered 1.mat ... 3064.mat, so indices are drawn from 1..3064
# (0 would point to a file that does not exist).
# NOTE: this cell reads from BrainTumorData/ directly, so it must run before the
# files are moved into the train/val/test sub-folders further down.
ncol = 4
rand_ndx = random.sample(range(1, 3065), ncol)
fig, ax = plt.subplots(nrows=3, ncols=ncol, figsize=(20, 10))
for i, n in enumerate(rand_ndx):
    path = os.path.join("/content/brain_tumour_dataset/BrainTumorData/", f"{n}.mat")
    # Read the arrays inside the context manager so the HDF5 file is closed afterwards.
    with h5py.File(path, 'r') as f:
        cjdata = f.get('cjdata')
        image = cjdata.get('image')[()]
        mask = cjdata.get('tumorMask')[()]
    ax[0][i].imshow(image, cmap='gray')
    ax[0][i].imshow(mask, cmap='gray', alpha=0.3)
    ax[0][i].set_title('Overlay')
    ax[1][i].imshow(image, cmap='gray')
    ax[1][i].set_title('Image')
    ax[2][i].imshow(mask, cmap='gray')
    ax[2][i].set_title('Mask')

## Dataset class

Create a [custom Dataset class](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html).

The 2 most important methods to override from the PyTorch Dataset class are:
*   \__len__ so that len(dataset) returns the size of the dataset.
*   \__getitem__ to support the indexing such that dataset[i] can be used to get *ith* sample.

To avoid loading all the images at once, we initialize the dataset with the paths to the images and load them only when \__getitem__ is called (through the dataloader).





In [None]:
class BrainMRIDataset(Dataset):
    """Brain MRI segmentation dataset backed by a directory of .mat files.

    Each file must be readable with h5py (i.e. an HDF5-based .mat file) and
    contain a ``cjdata`` group with ``image`` and ``tumorMask`` datasets.
    Files are only opened when a sample is requested, so the whole dataset is
    never held in memory at once.
    """

    def __init__(self, root_dir, transform=None):
        """
        Arguments:
            root_dir (string): Directory with all the images (.mat files).
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        # os.listdir order is arbitrary, so sort to get a deterministic
        # index -> file mapping; keep only .mat files so stray files or
        # sub-directories in root_dir are ignored.
        self.images = sorted(
            name for name in os.listdir(self.root_dir)
            if name.endswith('.mat')
            and os.path.isfile(os.path.join(self.root_dir, name))
        )
        self.transform = transform

    def __len__(self):
        """Number of .mat files found in ``root_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``{'image': ndarray, 'mask': ndarray}``.

        The optional ``transform`` is applied to the sample dict before it is
        returned.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        filename = os.path.join(self.root_dir, self.images[idx])
        # Read both arrays inside the context manager so the file handle is
        # released as soon as the data has been copied into numpy arrays.
        with h5py.File(filename, 'r') as f:
            cjdata = f.get('cjdata')
            image = cjdata.get('image')[()]
            mask = cjdata.get('tumorMask')[()]

        sample = {'image': image, 'mask': mask}

        if self.transform:
            sample = self.transform(sample)

        return sample

## Transform class

The transforms help to preprocess the data so that it fits the model inputs format.

They can also be used for data augmentation (crop, rotation, etc).

In this project, we create a resize transform and a 'toTensor' transform to go from numpy arrays to tensors as expected by the model.


In [None]:
class Resize(object):
    """Resize the image and mask of a sample to a common size.

    Args:
        output_size (tuple or int): Target size. A tuple is used verbatim as
            (height, width). An int is matched to the shorter image edge and
            the longer edge is scaled so the aspect ratio is kept.

    The resized image is replicated into 3 channels (H x W x 3) and the mask
    gets a trailing channel axis (H x W x 1); both are multiplied by 255.
    The incoming ``sample`` dict is updated in place and returned.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        img, msk = sample['image'], sample['mask']
        height, width = img.shape[:2]

        size = self.output_size
        if not isinstance(size, int):
            target_h, target_w = size
        elif height > width:
            # Portrait: width is the short edge.
            target_h, target_w = size * height / width, size
        else:
            # Landscape or square: height is the short edge.
            target_h, target_w = size, size * width / height

        shape = (int(target_h), int(target_w))

        # Grayscale slice -> 3 identical channels.
        resized_img = np.stack([transform.resize(img, shape) * 255.0] * 3, axis=-1)
        # Mask keeps a single channel.
        resized_msk = (transform.resize(msk, shape) * 255.0)[..., np.newaxis]

        sample['image'] = resized_img
        sample['mask'] = resized_msk
        return sample

class ToTensor(object):
    """Convert the sample's image and mask ndarrays to torch tensors.

    Both arrays are moved from numpy's H x W x C layout to torch's
    C x H x W layout. The sample dict is updated in place and returned.
    """

    def __call__(self, sample):
        # Transpose both arrays first, then wrap them as tensors.
        channel_first = {
            key: sample[key].transpose((2, 0, 1)) for key in ('image', 'mask')
        }
        for key, arr in channel_first.items():
            sample[key] = torch.from_numpy(arr)
        return sample

## Split Data and create Data sets

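In deep learning, three sets are typically needed: training, validation and test. We call the train_test_split function from the scikit-learn library twice to split the data. First, 10% of the files go to the test set. Then 20% of the remaining files go to the validation set, which leaves roughly 72% / 18% / 10% for train / validation / test. The split is done on the image file names, i.e. per slice rather than per patient, so slices from the same patient can end up in different sets.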

[Reminder on data split](https://medium.com/syntaxerrorpub/understanding-the-difference-between-training-test-and-validation-sets-in-machine-learning-c59feec6483b#:~:text=In%20summary%2C%20training%2C%20testing%2C,model%20selection%20and%20hyperparameter%20tuning.): 'The training set is used to train the model; the test set evaluates its performance on unseen data; and the validation set aids in model selection and hyperparameter tuning.'

In [None]:
# os.listdir order is platform-dependent, so sort it to make random_state=123
# give the same split everywhere. Keep only .mat files so that re-running this
# cell after the train/val/test folders exist does not pick up those directories.
# NOTE(review): the split is per slice, not per patient, so slices of the same
# patient can land in both train and test; consider a split grouped by cjdata.PID.
all_files = sorted(
    f for f in os.listdir("/content/brain_tumour_dataset/BrainTumorData/")
    if f.endswith('.mat')
)

# Split train and test data (10% of the files go to the test set)
train_data, test_data = train_test_split(all_files, test_size = 0.1, random_state=123)

# Split train and validation data (20% of the remaining files go to validation)
train_data, val_data = train_test_split(train_data, test_size = 0.2, random_state=123)

Then, we create the 3 split folders and move images accordingly.

In [None]:
root_dir = '/content/brain_tumour_dataset/BrainTumorData/'

# Move every file of each split into its own sub-folder: train/, val/, test/.
for split_name, split_files in (('train', train_data), ('val', val_data), ('test', test_data)):
    split_dir = os.path.join(root_dir, split_name)
    os.makedirs(split_dir, exist_ok=True)
    for fname in split_files:
        shutil.move(os.path.join(root_dir, fname), os.path.join(split_dir, fname))

Finally, we create 3 dataset instances, passing the path to the corresponding folder and the previously mentioned transforms.

*transforms.Compose* chains the transforms so that they are applied sequentially to each sample.

In [None]:
# One dataset per split folder; each gets its own resize-to-256 + to-tensor pipeline.
split_root = '/content/brain_tumour_dataset/BrainTumorData'
train_dataset, val_dataset, test_dataset = (
    BrainMRIDataset(
        root_dir=os.path.join(split_root, split),
        transform=transforms.Compose([Resize(256), ToTensor()]),
    )
    for split in ('train', 'val', 'test')
)

## Data Loaders

The dataset objects are used to create the data loaders. The default DataLoader class from PyTorch is used with batch size 4 (i.e. passing 4 images at a time to the network).

*Shuffle* is set to *True* for the train loader so that the training samples are reshuffled at every epoch.

In [None]:
# Training loader: batches of 4, reshuffled every epoch, loaded by 2 worker processes.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)

In [None]:
# Test loader: batches of 4 in a fixed order (no shuffling), 2 worker processes.
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2)

In [None]:
# Validation loader: batches of 4 in a fixed order (no shuffling), 2 worker processes.
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

# **Model**

Torch hub provides a [U-Net model](https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/) with pretrained weights for brain MRI.

In [None]:
# Pretrained U-Net from Torch Hub: 3 input channels (Resize replicates the
# grayscale slice into 3 channels), 1 output channel (tumour mask).
model = torch.hub.load(
    'mateuszbuda/brain-segmentation-pytorch',
    'unet',
    in_channels=3,
    out_channels=1,
    init_features=32,
    pretrained=True,
)

The model is sent to the GPU, if one is available, to speed up the computations.

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# **Loss**

We combine the Dice and BCE (binary cross-entropy) losses into the criterion we want to minimize during training. More information on the choice of loss can be found [here](https://www.linkedin.com/pulse/in-depth-exploration-loss-functions-deep-learning-kiran-dev-yadav/).

In [None]:
def dice_loss(inputs, target, smooth=1.0):
    """Soft Dice loss, computed per sample and averaged over the batch.

    Args:
        inputs: predictions; the first dimension is the batch.
        target: ground-truth masks with the same number of elements per sample.
        smooth: additive smoothing term that keeps the ratio defined when both
            prediction and target are empty (default 1.0).

    Returns:
        Scalar tensor ``1 - mean(dice)``, where each sample's dice is
        ``(2 * sum(inputs * target) + smooth) / (sum(inputs) + sum(target) + smooth)``.
    """
    num = target.size(0)
    # Flatten everything except the batch dimension.
    inputs = inputs.reshape(num, -1)
    target = target.reshape(num, -1)
    intersection = (inputs * target).sum(1)
    dice = (2.0 * intersection + smooth) / (inputs.sum(1) + target.sum(1) + smooth)
    return 1 - dice.mean()


def bce_dice_loss(inputs, target):
    """Binary cross-entropy (mean over all elements) plus soft Dice loss.

    ``inputs`` must already be probabilities in [0, 1]: binary_cross_entropy
    works on probabilities, not logits.
    """
    bce = nn.functional.binary_cross_entropy(inputs, target)
    return bce + dice_loss(inputs, target)

In [None]:
# Objective minimised during training (and reported as the validation loss):
# BCE + soft Dice, as defined in bce_dice_loss.
criterion = bce_dice_loss

# **Optimizer**

The AdamW optimizer updates the model weights from the back-propagated gradients. A StepLR learning rate scheduler halves the learning rate (gamma = 0.5) every 5 epochs, since the scheduler is stepped once per epoch.

In [None]:
# AdamW with lr=1e-3 (AdamW's own default). Adam-style optimizers need small
# learning rates; a value as large as 0.1 would quickly destroy the pretrained
# U-Net weights instead of fine-tuning them.
optimizer = AdamW(model.parameters(), lr=1e-3)
# Halve the learning rate every 5 scheduler steps. scheduler.step() is called
# once per epoch in the training loop, so this means every 5 epochs.
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# **Training**

First, the loss histories are created and the best validation loss is initialized to infinity.

In [None]:
# Per-epoch mean train / validation losses, plus the best validation loss seen
# so far (starts at +inf so any finite first validation loss is saved).
loss_history, loss_history_val = [], []
best_loss_val = np.inf

We train the model for 30 epochs. After each epoch, the weights are saved to /content/best_model.pth if the validation loss is lower than the best validation loss obtained so far.

In [None]:
print("Start train…")

num_epochs = 30
for epoch in range(num_epochs):
    start_time = time.time()

    # --- Training pass: one gradient update per training batch ---
    model.train()
    loss_running = []
    for sample in train_dataloader:
        x = sample['image'].float().to(device)
        y = sample['mask'].float().to(device)

        pred = model(x)
        loss = criterion(pred, y)
        loss_running.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_history.append(np.mean(loss_running))

    # --- Validation pass: eval mode, no autograd graph ---
    model.eval()
    loss_val_running = []
    with torch.no_grad():
        for sample in val_dataloader:
            x_val = sample['image'].float().to(device)
            y_val = sample['mask'].float().to(device)
            # Call the module (not .forward()) so registered hooks still run.
            pred_val = model(x_val)
            loss_val_running.append(criterion(pred_val, y_val).item())

    curr_loss_val = np.mean(loss_val_running)
    loss_history_val.append(curr_loss_val)

    # Keep only the weights with the lowest validation loss seen so far.
    if curr_loss_val < best_loss_val:
        best_loss_val = curr_loss_val
        torch.save(model.state_dict(), r'/content/best_model.pth')

    # StepLR is stepped once per epoch.
    scheduler.step()

    # Print the results
    print("epoch", epoch,
          "train loss", loss_history[-1],
          "val loss", loss_history_val[-1],
          "epoch duration", time.time() - start_time)

Plot the training and validation loss curves.

In [None]:
# Train vs. validation loss, one point per epoch.
plt.figure(figsize=(15, 7))
for values, label in ((loss_history, 'train loss'), (loss_history_val, 'val loss')):
    plt.plot(values, label=label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# **Inference**

The best model weights can now be loaded and used to predict the masks of the test images (unseen during training).



## Load best model weights

In [None]:
# Restore the weights with the lowest validation loss. Tensors are loaded onto
# the CPU first and then loaded into `model` (which may live on the GPU).
best_state = torch.load('/content/best_model.pth', map_location='cpu')
model.load_state_dict(best_state)

## Visualization

Visualize the predicted and ground-truth masks for 4 random test batches (the first image of each selected batch is shown).

In [None]:
def plot_mask(mask_3d_array, scan, axx, title):
    """Draw ``scan`` in grayscale on ``axx`` with the rounded mask overlaid.

    Only index 0 along the first axis of each tensor is shown, so both are
    expected to be channel-first (C x H x W) single-sample tensors.
    """
    first_mask_channel = mask_3d_array.cpu().detach().numpy()[0]
    first_scan_channel = scan.cpu().detach().numpy()[0]
    axx.imshow(first_scan_channel, cmap='gray')
    # Rounding turns mask values in [0, 1] into a hard 0/1 mask.
    axx.imshow(np.round(first_mask_channel), cmap='gray', alpha=0.5)
    axx.set_title(title)

In [None]:
# Prediction (top row) vs. ground truth (bottom row) for the first image of
# ncol randomly chosen test batches.
dataloader = test_dataloader
ncol = 4
rand_ndx = random.sample(range(0, len(dataloader)), ncol)
fig, ax = plt.subplots(nrows=2, ncols=ncol, figsize=(20, 10))
model.eval()
i = 0
with torch.no_grad():  # inference only: no autograd graph needed
    for n, sample in enumerate(dataloader):
        if n not in rand_ndx:
            continue
        x = sample['image'].float().to(device)
        y = sample['mask'].float().to(device)
        # Call the module (not .forward()) so registered hooks still run.
        pred = model(x)
        plot_mask(pred[0], x[0], ax[0][i], 'Prediction')
        plot_mask(y[0], x[0], ax[1][i], 'Groundtruth')
        i += 1
        if i == ncol:
            break  # every column is filled; no need to read the remaining batches

## Accuracy

We define the dice_metric function, the Dice coefficient $2|A \cap B| / (|A| + |B|)$, to measure the segmentation accuracy of the model after training.

In [None]:
def dice_metric(inputs, target):
    """Dice coefficient ``2 * sum(A * B) / (sum(A) + sum(B))``.

    The sums run over every element of the arrays, so a whole batch is scored
    at once. Returns 1.0 when both ``inputs`` and ``target`` sum to zero (two
    empty masks agree perfectly).
    """
    overlap = (target * inputs).sum()
    target_total = target.sum()
    inputs_total = inputs.sum()
    if target_total == 0 and inputs_total == 0:
        return 1.0
    return 2.0 * overlap / (target_total + inputs_total)

In [None]:
def compute_acc(dataloader, model, threshold=None):
    """Mean Dice coefficient of ``model`` over the batches of ``dataloader``.

    Args:
        dataloader: yields dicts with 'image' and 'mask' tensors.
        model: segmentation network whose output is compared with the mask.
        threshold (float, optional): if given, predictions are binarised with
            ``pred > threshold`` before scoring (e.g. 0.5 for a hard Dice).
            With the default ``None`` the raw model outputs are scored (soft Dice).

    Returns:
        The mean over batches of ``dice_metric`` (each batch scored as a whole).

    The model is put in eval mode for scoring and its previous train/eval mode
    is restored afterwards.
    """
    was_training = model.training
    model.eval()
    acc = []
    try:
        # No autograd graph is needed to score predictions.
        with torch.no_grad():
            for sample in dataloader:
                x = sample['image'].float().to(device)
                y = sample['mask'].float().to(device)
                pred = model(x)
                if threshold is not None:
                    pred = (pred > threshold).float()
                acc.append(dice_metric(pred.cpu().numpy(), y.cpu().numpy()))
    finally:
        model.train(was_training)

    return np.mean(acc)

We compute the accuracy on each dataset. The accuracy on the test set should be used as a reference to assess the model performance.

In [None]:
# Train set
acc_train = compute_acc(train_dataloader, model)
print(f'Accuracy on the train set is {acc_train}')

In [None]:
# Validation set
acc_val = compute_acc(val_dataloader, model)
print(f'Accuracy on the validation set is {acc_val}')

In [None]:
# Test set
acc_test = compute_acc(test_dataloader, model)
print(f'Accuracy on the test set is {acc_test}')

# Acknowledgement:

*   https://github.com/seyma-tas/Brain-Tumor-Segmentation-Project/blob/master/2AdamW_DICE_BrainTumorGenesis.ipynb

