# Mask RCNN Model for Segmenting Breast Ultrasound Images in Tumour Detection

The following Jupyter Notebook uses the PyTorch machine learning framework to fine-tune and test a pre-trained Mask Region-based Convolutional Neural Network (Mask RCNN), documented here: https://pytorch.org/vision/main/models/generated/torchvision.models.detection.maskrcnn_resnet50_fpn.html

The dataset used consists of images from breast ultrasound scans that are categorized into three classes depending on the tumour (or lack thereof): normal, benign, and malignant.

However, as you will notice during the data preparation segment, we will not be making use of the normal images. The model will be trained on benign and malignant classes only.

Each ultrasound scan comes with at least one mask image that marks where the tumour is.

The goal of the model is, given an ultrasound image, to segment the tumour and return a mask of its location.

The Notebook is divided as follows:

* Data preparation

* Creating the custom Dataset

* Training and testing the model on a single sample

* Instantiate our Dataset, DataLoader, hyperparameters, and model

* Train the model

* Evaluate and output performance

* Save the model

**Note**: This follows the same pattern and logic as the Penn-Fudan tutorial by PyTorch, seen here: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html


## How to run

1. Click on `Runtime` from the bar above.
2. Click on `change runtime type`.
3. Select `GPU` as your hardware accelerator and save.
4. Connect to a runtime by clicking on `Connect` in the top right-hand side.
5. Click on `Runtime` again and `Run all`.

Importing all necessary libraries.

In [None]:
import os
import shutil
import numpy as np
import sys
import matplotlib.pyplot as plt 
import torch
import torchvision
import torch.utils.data
import random
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torch.optim.lr_scheduler import StepLR
from PIL import Image, ImageChops

In [None]:
%%shell

# download TorchVision repo to use some files from references/detection
git clone https://github.com/pytorch/vision.git
cd vision
git checkout v0.8.2

cp references/detection/utils.py ../
cp references/detection/transforms.py ../
cp references/detection/coco_eval.py ../
cp references/detection/engine.py ../
cp references/detection/coco_utils.py ../

In [None]:
# !pip uninstall torch -y
# !pip uninstall torchtext -y
# !pip uninstall torchdata -y
# !pip uninstall torchaudio -y
!pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
!pip install cython
!pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'

In [None]:
from engine import train_one_epoch, evaluate
import utils
import transforms as T

## Data preparation

Download data zip file and create directories.

In [None]:
!gdown 1LljpoDlVfLoowaG6qAq_rCzX7W6wVjql

In [None]:
%%shell
unzip data.zip

In [None]:
%%shell
mkdir Dataset/
cd Dataset/
mkdir UltrasoundImages
mkdir MaskImages

In [None]:
# converts a mask image to greyscale and sets every non-zero (tumour) pixel to `index`,
# so that each mask of a scan can be given its own instance id (1, 2, or 3)
def img_to_index_array(path, index):
    mask = Image.open(path).convert('L')
    # np.array makes a writable copy; np.asarray on a PIL image can be read-only
    mask_asarray = np.array(mask)
    mask_asarray[mask_asarray>0] = index

    return mask_asarray
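
To illustrate how several per-lesion masks can end up in one label image, here is a minimal NumPy sketch. The helper name `combine_masks` and the toy arrays are assumptions made for illustration and are not part of the notebook. The notebook's own cell further down adds the indexed arrays together; this variant writes each index with boolean indexing instead, so overlapping lesions cannot produce an unintended label value.

```python
import numpy as np

def combine_masks(binary_masks):
    """Merge binary masks (H x W, non-zero = lesion) into one label image.

    Pixels of the k-th mask receive the value k + 1; 0 stays background.
    Hypothetical helper, not part of the notebook.
    """
    label_image = np.zeros(binary_masks[0].shape, dtype=np.uint8)
    for instance_id, binary_mask in enumerate(binary_masks, start=1):
        # Later masks overwrite earlier ones where they overlap.
        label_image[binary_mask > 0] = instance_id
    return label_image

# Two toy masks that overlap in the pixel at row 0, column 1.
first = np.array([[255, 255, 0],
                  [0,   0,   0]], dtype=np.uint8)
second = np.array([[0, 255, 0],
                   [0, 255, 255]], dtype=np.uint8)

labels = combine_masks([first, second])
print(labels)
```

With summation, the overlapping pixel would have received the value 3 and been read as a third instance; here it simply belongs to the second mask.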

Creates a tuple combining each ultrasound image with its corresponding mask image (or images, when a scan has more than one mask).

Copies and renames the images into either the ultrasound folder or the mask folder.

After doing so, the following structure will be created:

```
Dataset/
  MaskImages/
    Mask0000.png
    Mask0001.png
    Mask0002.png
    Mask0003.png
    ...
  UltrasoundImages/
    Image0000.png
    Image0001.png
    Image0002.png
    Image0003.png
    ...
```
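
The copy loops in the next cell build each file name from a prefix and a counter padded with zeros to four digits. As a rough sketch, the same naming can be expressed with a format specifier; the helper `numbered_name` is a hypothetical name used only for this example.

```python
def numbered_name(prefix, index, width=4, extension=".png"):
    """Return e.g. 'Image0007.png' for prefix='Image', index=7."""
    return f"{prefix}{index:0{width}d}{extension}"

names = [numbered_name("Image", i) for i in (0, 9, 10, 99, 100)]
print(names)
```

Zero-padding matters because the Dataset class later pairs images and masks by sorting the two directory listings: with padded counters, alphabetical order and numeric order agree.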

In [None]:
path_benign = 'ultrasound/benign/'
path_malignant = 'ultrasound/malignant/'
path = [path_benign, path_malignant]


data = []
for item in path:
  img_dir = os.listdir(item)
  for image in img_dir:
    if 'mask' not in image:
      ultrasound_image = image 
      mask_image = image.replace(')', ')_mask')
      # skip scans without a mask, so a stale entry from a previous image is never reused
      if mask_image not in img_dir:
        continue
      img_msk_tuple = [item + ultrasound_image, item + mask_image]
      # some scans have additional masks named ..._mask_1 and ..._mask_2
      for mask_id in ['1', '2']:
        mask_image_extra = image.replace(')', ')_mask_'+ mask_id)
        if mask_image_extra in img_dir:
          img_msk_tuple.append(item + mask_image_extra)
      data.append(tuple(img_msk_tuple))

# ultrasound images
for i, item in enumerate(data):
  # zero-pad the counter to four digits: Image0000.png, Image0001.png, ...
  shutil.copy(item[0], os.path.join('Dataset/UltrasoundImages/', f'Image{i:04d}.png'))


# mask images
index = [0,1,2,3] # 0 is a placeholder
i = 0
for item in data:
  if len(item) - 1 == 1:  # first index is an ultrasound image, therefore -1
    mask_asarray = img_to_index_array(item[1], index[1])
  elif len(item) - 1 == 2:
    mask_asarray_1 = img_to_index_array(item[1], index[1])
    mask_asarray_2 = img_to_index_array(item[2], index[2])
    mask_asarray = mask_asarray_1 +  mask_asarray_2
  else:
    mask_asarray_1 = img_to_index_array(item[1], index[1])
    mask_asarray_2 = img_to_index_array(item[2], index[2])
    mask_asarray_3 = img_to_index_array(item[3], index[3]) 
    mask_asarray = mask_asarray_1 +  mask_asarray_2 + mask_asarray_3

  mask_composite = Image.fromarray(mask_asarray) 

  # zero-pad the counter to four digits: Mask0000.png, Mask0001.png, ...
  mask_composite.save(os.path.join('Dataset/MaskImages/', f'Mask{i:04d}.png'))

  i = i+1


Here is one example of an image in the dataset with its corresponding segmentation mask.

The cell after it shows the same mask with a colour palette applied.

In [None]:
ultrasound_image = Image.open('/content/Dataset/UltrasoundImages/Image0001.png')
mask_image = Image.open('/content/Dataset/MaskImages/Mask0001.png')

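# np.array returns a writable copy; np.asarray on a PIL image can be read-only
mask_array = np.array(mask_image)
mask_array[mask_array >= 1] = 255  # show every instance in white
mask_image = Image.fromarray(mask_array)

# Image.LANCZOS is the current name of the filter formerly called Image.ANTIALIAS
ultrasound_image = ultrasound_image.resize((500, 400), Image.LANCZOS)
mask_image = mask_image.resize((500, 400), Image.LANCZOS)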

ultrasound_image.show(), mask_image.show()

In [None]:
mask = Image.open('/content/Dataset/MaskImages/Mask0001.png')
mask_withPalette = mask.convert('P')
mask_withPalette.putpalette([
    0, 0, 0, # black background
    255, 0, 0, # index 1 is red
    255, 255, 0, # index 2 is yellow
    255, 153, 0, # index 3 is orange
])
mask_withPalette

## Creating our Dataset class

In [None]:
class UltrasoundDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # load all image files, sorted so that each ultrasound image lines up with its corresponding mask image
        self.imgs = list(sorted(os.listdir(os.path.join(root, "UltrasoundImages"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "MaskImages"))))

    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, "UltrasoundImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "MaskImages", self.masks[idx])
        
        img = Image.open(img_path).convert("RGB")
        # mask is not converted to RGB because each color corresponds to a different instance with 0 being the background
        mask = Image.open(mask_path)

        mask = np.array(mask)
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]

        # split the color-encoded mask into a set of binary masks
        masks = mask == obj_ids[:, None, None]

        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)

Let's take a look at the output of our Dataset class: a `PIL` image and a target dictionary containing the data the model requires.

In [None]:
dataset = UltrasoundDataset('Dataset/')
dataset[0]
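
The core of `__getitem__` is the step that turns one label image into a stack of binary masks and one bounding box per instance. The following NumPy sketch reproduces that step on a toy array, assuming the same `[xmin, ymin, xmax, ymax]` box layout; the function name `masks_and_boxes` and the input values are made up for illustration.

```python
import numpy as np

def masks_and_boxes(label_image):
    """Split a label image into binary masks and [xmin, ymin, xmax, ymax] boxes."""
    instance_ids = np.unique(label_image)
    instance_ids = instance_ids[instance_ids != 0]  # 0 is the background
    # Broadcasting compares the (H, W) image against each id, giving (N, H, W).
    binary_masks = label_image == instance_ids[:, None, None]
    boxes = []
    for binary_mask in binary_masks:
        rows, cols = np.where(binary_mask)
        boxes.append([int(cols.min()), int(rows.min()),
                      int(cols.max()), int(rows.max())])
    return binary_masks, boxes

label_image = np.array([
    [0, 1, 1, 0, 0],
    [0, 1, 1, 0, 2],
    [0, 0, 0, 0, 2],
], dtype=np.uint8)

binary_masks, boxes = masks_and_boxes(label_image)
print(binary_masks.shape)
print(boxes)
```

Note that the second instance is a single pixel wide, so its box has `xmin == xmax`. Boxes with zero width or height are likely to be rejected by detection models during training, so very thin masks may need to be filtered out or padded.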

As we are using a pre-trained model, we need to fine-tune it on our specific dataset.

The following function takes the number of classes we have and returns an instance segmentation Mask RCNN model whose prediction heads are sized to fit our data.

In [None]:
def get_instance_segmentation_model(num_classes):
  
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model

## Transforms and training/testing the model on a single sample

In [None]:
def get_transform(train):
    transforms = []
    # converts the image, a PIL image, into a PyTorch Tensor
    transforms.append(T.ToTensor())
    if train:
        # during training, randomly flip the training images and ground truth for data augmentation
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)
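
A horizontal flip used for detection has to transform the targets along with the image, otherwise the boxes and masks no longer match the pixels. The sketch below shows the idea in plain NumPy, assuming `[xmin, ymin, xmax, ymax]` boxes; `horizontal_flip` is a hypothetical helper written for this example and is not the `T.RandomHorizontalFlip` implementation itself.

```python
import numpy as np

def horizontal_flip(image, boxes, masks):
    """Flip an (H, W) image, its [xmin, ymin, xmax, ymax] boxes, and (N, H, W) masks."""
    width = image.shape[-1]
    flipped_image = image[..., ::-1]
    flipped_masks = masks[..., ::-1]
    flipped_boxes = boxes.copy()
    # The old right edge becomes the new left edge and vice versa.
    flipped_boxes[:, 0] = width - boxes[:, 2]
    flipped_boxes[:, 2] = width - boxes[:, 0]
    return flipped_image, flipped_boxes, flipped_masks

image = np.arange(12, dtype=np.float32).reshape(3, 4)
boxes = np.array([[0.0, 0.0, 1.0, 2.0]])
masks = np.zeros((1, 3, 4), dtype=np.uint8)
masks[0, :, :2] = 1  # object occupies the two left-most columns

flipped_image, flipped_boxes, flipped_masks = horizontal_flip(image, boxes, masks)
print(flipped_boxes)
print(flipped_masks[0])
```

After the flip, the object sits in the two right-most columns and the box has moved with it, while the vertical coordinates are unchanged.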

Before iterating over our Dataset, we will take a look at what a single forward pass looks like. As in the PyTorch tutorial, the cell below uses a pre-trained Faster RCNN for this quick check.

In [None]:
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
dataset = UltrasoundDataset('Dataset', get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=0,
    collate_fn=utils.collate_fn
)

# training
images, targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
output = model(images, targets)   # in training mode, returns a dictionary of losses

# inference/testing
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x) # returns predictions
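
The DataLoader above is given `utils.collate_fn` because images in a detection batch can differ in size, and each target holds a different number of boxes, so they cannot be stacked into one tensor. The sketch below mirrors that behaviour on stand-in data; the strings and dictionaries are placeholders, not real images or targets.

```python
def collate_fn(batch):
    """Turn [(image, target), ...] into ((image, ...), (target, ...))."""
    return tuple(zip(*batch))

# Stand-ins for (image, target) pairs; the strings mark images of different sizes.
batch = [
    ("image_300x400", {"labels": [1]}),
    ("image_500x400", {"labels": [1, 1]}),
]

images, targets = collate_fn(batch)
print(images)
print(len(targets))
```

Each batch therefore arrives as a tuple of images and a tuple of target dictionaries, which is the list-like input the detection models accept.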

## Defining our hyperparameters and training our model

After setting up our data, Dataset, and transforms, it's time to put everything together.

In [None]:
# use our Dataset and defined transformations
dataset = UltrasoundDataset('Dataset', get_transform(train=True))
dataset_test = UltrasoundDataset('Dataset', get_transform(train=False))

# split the Dataset in train and test set
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

# define training and testing DataLoader
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=0,
    collate_fn=utils.collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=0,
    collate_fn=utils.collate_fn)

In [None]:
# check on the number of examples in each Dataset
len(dataset), len(dataset_test)
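
The split above shuffles the sample indices once and holds out the last 50 for testing. Because `torch.randperm` is not seeded in this notebook, the split changes from run to run. A minimal sketch of a reproducible variant, using only the standard library, could look like this; `split_indices`, the seed value, and the sample count of 200 are assumptions chosen for the example.

```python
import random

def split_indices(num_samples, num_test, seed=0):
    """Shuffle sample indices and reserve the last num_test for testing."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    return indices[:-num_test], indices[-num_test:]

# 200 is an arbitrary example size, not the number of images in the dataset.
train_indices, test_indices = split_indices(num_samples=200, num_test=50)
print(len(train_indices), len(test_indices))
```

Using the same index list for both subsets, as the notebook does, is what keeps the training and test sets disjoint.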

Instantiate the model along with the optimizer and learning rate scheduler.

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# our dataset has two classes only - background and tumour
num_classes = 2

# create the model using our helper function
model = get_instance_segmentation_model(num_classes).to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)

# and a learning rate scheduler which decreases the learning rate by 10x every 3 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=3,
                                               gamma=0.1)
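
To make the effect of the scheduler concrete, the step decay can be written as $\text{lr}(e) = \text{lr}_0 \cdot \gamma^{\lfloor e / s \rfloor}$ with $\text{lr}_0 = 0.005$, $\gamma = 0.1$, and $s = 3$. The sketch below evaluates this formula for the epochs used in this notebook; `step_lr` is a hypothetical helper, and it ignores the warm-up that `train_one_epoch` applies during the first epoch.

```python
def step_lr(initial_lr, epoch, step_size=3, gamma=0.1):
    """Learning rate in effect during `epoch` (0-based) under a step decay."""
    return initial_lr * gamma ** (epoch // step_size)

schedule = [step_lr(0.005, epoch) for epoch in range(10)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.1e}")
```

Over 10 epochs the learning rate therefore drops three times, and the final epoch runs at one thousandth of the initial value.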

We can now train the model.

We will be using 10 epochs and the `train_one_epoch` helper function from torchvision's detection reference scripts (the `engine.py` file copied earlier).

In [None]:
# let's train it for 10 epochs
num_epochs = 10

for epoch in range(num_epochs):
  
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)

    # update the learning rate
    lr_scheduler.step()

    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)

## Evaluation

After completing training, we can take a look at what the model outputs when passing it a test image.

In [None]:
import random
# pick one image from the test set
img, mask = dataset_test[30]

# put the model in evaluation mode
model.eval()
with torch.no_grad():
    prediction = model([img.to(device)])

The next cells show the ground truth mask and then the `prediction`, which is a list with one dictionary per input image holding the predicted `boxes`, `labels`, `scores`, and `masks`.

In [None]:
plt.imshow(mask['masks'].squeeze(dim=0).cpu().numpy())

In [None]:
prediction

We can now examine the test image we gave our model and its predicted mask.

In [None]:
Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())

In [None]:
Image.fromarray(prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy())
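
The predicted masks are soft: each has shape `(1, H, W)` with values between 0 and 1, and every detection carries a confidence score. A common post-processing step is to keep only confident detections and binarise their masks. The sketch below does this on NumPy arrays that stand in for the tensors; `select_masks`, both thresholds of 0.5, and the example values are assumptions made for illustration.

```python
import numpy as np

def select_masks(prediction, score_threshold=0.5, mask_threshold=0.5):
    """Keep confident detections and binarise their soft masks.

    `prediction` mimics one output dictionary, with NumPy arrays in place of
    tensors: 'scores' has shape (N,), 'masks' has shape (N, 1, H, W).
    """
    keep = prediction["scores"] >= score_threshold
    soft_masks = prediction["masks"][keep, 0]  # (K, H, W), values in [0, 1]
    return (soft_masks >= mask_threshold).astype(np.uint8)

# Made-up values for illustration only.
prediction = {
    "scores": np.array([0.9, 0.3]),
    "masks": np.array([
        [[[0.8, 0.6], [0.2, 0.1]]],
        [[[0.7, 0.7], [0.7, 0.7]]],
    ]),
}

binary_masks = select_masks(prediction)
print(binary_masks.shape)
print(binary_masks[0])
```

Only the first detection survives the score filter, and its mask keeps the two pixels whose probability is at least 0.5.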

The following function will take in our model and produce `n` predictions and display them.

In [None]:
def display_predictions(model, n=10):

  ultrasound_files = []
  ground_truth_files = []
  pred_files = []

  for i in range(n):

    r = random.randint(0, len(dataset_test)-1)
    img, mask = dataset_test[r]

    model.eval()
    with torch.no_grad():
      prediction = model([img.to(device)])

    ultrasound_files.append(img)
    ground_truth_files.append(mask)
    pred_files.append(prediction)

  fig, ax = plt.subplots(n, 3, figsize=(15, 30))

  for idx, ground_truth_image in enumerate(ultrasound_files):

    if idx == 0:
      ax[idx, 0].set_title("Ultrasound image")
      ax[idx, 1].set_title("Ground truth mask")
      ax[idx, 2].set_title("Predicted mask")

    # fetch the ultrasound image for the corresponding index above
    ultrasound_image = ultrasound_files[idx]
    ultrasound_image = ultrasound_image.permute(1, 2, 0)

    # fetch the ground truth masks and merge all instances into one binary image
    ground_truth_image = ground_truth_files[idx]['masks'].max(dim=0).values.cpu().numpy()

    # fetch the highest-scoring predicted mask, or a blank image if nothing was detected
    pred_masks = pred_files[idx][0]['masks']
    if len(pred_masks) > 0:
      pred_image = pred_masks[0, 0].mul(255).byte().cpu().numpy()
    else:
      pred_image = np.zeros(tuple(ultrasound_image.shape[:2]), dtype=np.uint8)

    ax[idx, 0].imshow(ultrasound_image)
    ax[idx, 0].axis(False)

    ax[idx, 1].imshow(ground_truth_image)
    ax[idx, 1].axis(False)

    ax[idx, 2].imshow(pred_image)
    ax[idx, 2].axis(False)


In [None]:
display_predictions(model)
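
The side-by-side plots give a visual impression only. The COCO-style metrics printed by `evaluate` are built on intersection over union (IoU), and the same quantity can be computed for a single pair of masks. The following is a minimal NumPy sketch; `mask_iou` and the toy masks are assumptions for illustration, and returning 1.0 for two empty masks is a convention chosen here.

```python
import numpy as np

def mask_iou(predicted, ground_truth):
    """Intersection over union of two binary masks of equal shape."""
    predicted = predicted.astype(bool)
    ground_truth = ground_truth.astype(bool)
    union = np.logical_or(predicted, ground_truth).sum()
    if union == 0:
        return 1.0  # both masks are empty; treated as a perfect match here
    intersection = np.logical_and(predicted, ground_truth).sum()
    return float(intersection / union)

ground_truth = np.array([[1, 1, 0],
                         [1, 1, 0]])
predicted = np.array([[0, 1, 1],
                      [0, 1, 1]])
print(mask_iou(predicted, ground_truth))
```

Here the masks share 2 pixels out of 6 covered in total, so the IoU is one third.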

## Saving our model for later use if needed

In [None]:
def save_checkpoint(state, filename="MyMaskRCNN_Model.pth"):
    torch.save(state, filename)

In [None]:
state = {
    "state_dict": model.state_dict(),
    "optimizer":optimizer.state_dict(),
}
save_checkpoint(state)