# Computer Vision

In the following, you will import a public dataset and write Python classes that use it to train a neural network. Afterwards, you will define a neural network, train it, and evaluate its performance. Neural networks train much faster on GPUs, which are available on Colab. To make sure you have access to one:

1. In the menu above, select `Runtime` > `Change runtime type`
2. In the `Hardware accelerator` box, select GPU
3. Again in the menu above, select `Runtime` > `Restart runtime`...
4. Rerun the `!nvidia-smi` line below and confirm that a GPU is available

In [None]:
!nvidia-smi

## Exercise 1 - The Dataset

### 1. Let's copy the dataset.

We will use a publicly available dataset called [PanNuke](https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke). Uploading the dataset directly to Colab is slow, and the uploaded files are deleted every time your Colab runtime is disconnected.

#### Upload the file to your Google Drive and mount it here

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Afterwards, create a folder `folds` in the root folder of your Google Drive and upload `fold1.zip` into it.
Let's see what is in that folder and unzip the large zip file into the Colab storage.

In [None]:
!ls -l /content/gdrive/MyDrive/folds

!unzip /content/gdrive/MyDrive/folds/fold1.zip -d /content/


### 2. Let's look at the data
The zip contains 3 binary files: images, masks, and types. Each file holds one large array saved in NumPy's `.npy` format. Let's see what each file contains.

In [1]:
# First a few imports
import numpy as np
import torch
from torch.utils.data import Dataset
import os, random
#
import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
imgs = np.load("/content/fold1/images.npy")
print(imgs.shape)
print(imgs.dtype)

In [2]:
tps = np.load("/content/fold1/types.npy")
print(tps.shape)
tps[:10]

(2656,)


array(['Breast', 'Breast', 'Breast', 'Breast', 'Breast', 'Breast',
       'Breast', 'Breast', 'Breast', 'Breast'], dtype='<U13')

In [None]:
masks = np.load("/content/fold1/masks.npy")
print(masks.shape)
print(masks.dtype)
print(np.max(masks))

The arrays are stored channels-last (`[N,H,W,C]`). We need to move the channel axis so that they are in the `[N,C,H,W]` format expected by the PyTorch models we will develop during this course:

In [None]:
imgs = np.rollaxis(imgs, 3, 1)
masks = np.rollaxis(masks, 3, 1)
print(imgs.shape)
print(masks.shape)
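To see what `np.rollaxis(x, 3, 1)` does to the axes, here is a small sketch on a dummy array (the array and its sizes are made up for illustration). It also shows `np.moveaxis`, which NumPy's documentation recommends over `np.rollaxis` for new code; both return views rather than copies, so no pixel data is duplicated.

```python
import numpy as np

# Dummy batch in the on-disk layout [N, H, W, C]: 2 images, 4x5 pixels, 3 channels.
x = np.arange(2 * 4 * 5 * 3).reshape(2, 4, 5, 3)

# np.rollaxis(x, 3, 1): take axis 3 (channels) and insert it before axis 1.
a = np.rollaxis(x, 3, 1)
# np.moveaxis states the same move as explicit source and destination positions.
b = np.moveaxis(x, 3, 1)

print(a.shape, b.shape)  # (2, 3, 4, 5) (2, 3, 4, 5)
```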

The arrays above hold the images, the tissue types, and the segmentation masks. A segmentation mask marks the pixels of each cell with a different integer value. The masks have 6 channels, one for each of the 5 cell types plus a final `any` channel; we define their names in a separate array:

In [None]:
CELL_TYPES = ["neoplastic", "inflammatory", "softtissue", "dead", "epithelial", "any"]

We will first view some example segmentation masks:

In [None]:
def cmap_discrete():
    cmap = plt.cm.jet  # define the colormap
    cmaplist = [cmap(i) for i in range(cmap.N)]
    random.shuffle(cmaplist)
    cmaplist = [(0, 0, 0, 1.0)] + cmaplist
    # create the new map
    cmap = mpl.colors.LinearSegmentedColormap.from_list('cmap', cmaplist, cmap.N+1)
    return cmap

def show_example_images(imgs, masks, tps, n=10):
    plt.figure(figsize=(12, n*2))
    indices = random.sample(range(len(masks)), k=n)
    cmap = cmap_discrete()
    for i, ind in enumerate(indices):
        mask = masks[ind]
        tissue = tps[ind]
        img = imgs[ind]
        _ = plt.subplot(n, 7, 7*i+1)
        # imshow expects floats in [0, 1] or integers in [0, 255]; pixel values here are 0-255
        plt.imshow(np.rollaxis(img, 0, 3).astype(np.uint8))
        plt.title(tissue, fontsize=8)
        plt.axis("off")
        for j in range(len(CELL_TYPES)):
            _ = plt.subplot(n, 7, 7*i + 2 + j)
            plt.imshow(mask[j,:,:],  cmap=cmap, interpolation='nearest')
            plt.title(tissue + " " + CELL_TYPES[j], fontsize=8)
            plt.axis("off")

    plt.show()


In [None]:
show_example_images(imgs, masks, tps)

How many images of each tissue type, and how many cells of each type, does the dataset contain? We will list the counts and plot them as bar charts, starting with the tissue types:

In [None]:
counts = np.unique(tps, return_counts=True)
print(list(zip(counts[0],counts[1])))

In [None]:
plt.figure(figsize=(6, 6))
plt.barh(counts[0], counts[1])
plt.show()

Next, the same for each cell type. The last entry in `CELL_TYPES` is a binary segmentation map of all cell types, so we omit it from the cell counts below. The counts have to be extracted from the segmentation masks:

In [None]:
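cell_counts = []
for i, ct in enumerate(CELL_TYPES[:-1]):
  # count the distinct non-zero instance labels in each image; np.max(m) would only
  # equal the count if the labels in every channel ran consecutively from 1
  per_image = np.array([len(np.unique(m[m > 0])) for m in masks[:, i, :, :]])
  cell_counts.append(np.sum(per_image))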

plt.figure(figsize=(6,3))
plt.barh(CELL_TYPES[:-1], cell_counts)
plt.show()
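Counting cells via `np.max` of a mask is only correct if the instance labels in every channel run consecutively from 1. If that assumption does not hold for your copy of the data (for example, if one numbering of instances is shared across all channels of an image), counting distinct non-zero labels is more robust. A small sketch using a made-up 4 x 4 mask; `count_instances` is a hypothetical helper name:

```python
import numpy as np

def count_instances(instance_mask):
    """Number of distinct non-zero labels in one instance mask (0 = background)."""
    labels = np.unique(instance_mask)
    return int(np.count_nonzero(labels))

# Hypothetical mask with two cells labelled 3 and 7 (labels need not be 1..k).
m = np.zeros((4, 4))
m[0, 0:2] = 3
m[2:4, 2:4] = 7
print(count_instances(m), int(m.max()))  # 2 7
```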

### 3. Dataset class for an ML model

We will use [PyTorch](https://pytorch.org/) throughout our exercises. We start by constructing a [dataset class](https://pytorch.org/docs/stable/data.html) that feeds the data to the deep learning model. Here is an abstract segmentation dataset class, `SegmDataset`:

In [None]:
from torch.utils.data import Dataset


class SegmDataset(Dataset):
    def __init__(self, root_dir, imgs, masks, tps, augment, is_binary):
        self.root_dir = root_dir
        self.augment = augment
        self.is_binary = is_binary

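        if root_dir != "":
            self.imgs = np.load(os.path.join(root_dir, "images.npy"))
            self.masks = np.load(os.path.join(root_dir, "masks.npy"))
            self.tissues = np.load(os.path.join(root_dir, "types.npy"))
            # arrays on disk are [N,H,W,C]; move channels to [N,C,H,W] like the in-memory arrays
            self.imgs = np.rollaxis(self.imgs, 3, 1)
            self.masks = np.rollaxis(self.masks, 3, 1)
        else:
            self.imgs = imgs
            self.masks = masks
            self.tissues = tps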
        self.imgs = self.imgs.astype(np.float32) / 255


    def __len__(self):
        return len(self.imgs)


    def transform(self, img, mask):
        i = random.randint(0,3)
        if i == 1:
            img = np.flip(img, axis=1)
            mask = np.flip(mask, axis=1)
        elif i == 2:
            img = np.flip(img, axis=2)
            mask = np.flip(mask, axis=2)
        elif i == 3:
            img = np.flip(np.flip(img, axis=1), axis=2)
            mask = np.flip(np.flip(mask, axis=1), axis=2)

        i = random.randint(0,3)
        img = np.rot90(img, k=i, axes=(1,2))
        mask = np.rot90(mask, k=i, axes=(1,2))

        return img.copy(), mask.copy()

    def __getitem__(self, idx):
        pass

A `Dataset` implements the `__getitem__` and `__len__` methods, which let us index into it and iterate over it just as we would over a Python list.
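A minimal sketch of the `Dataset` protocol, assuming a toy class (`ToyDataset`, hypothetical) that returns the same kind of dictionary as the exercise classes. It also shows what PyTorch's default collate function does with such dictionaries when a `DataLoader` batches them: NumPy arrays are stacked into one tensor, while strings are gathered into a list.

```python
import numpy as np
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical dataset returning {'image', 'tissue'} dictionaries."""

    def __init__(self, n=4):
        self.imgs = np.random.rand(n, 3, 8, 8).astype(np.float32)
        self.tissues = [f"tissue{i}" for i in range(n)]

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        return {"image": self.imgs[idx], "tissue": self.tissues[idx]}

toy_ds = ToyDataset()
print(len(toy_ds), toy_ds[0]["image"].shape)  # 4 (3, 8, 8)

# Plain iteration works like a list: __getitem__(0), (1), ... until IndexError.
names = [d["tissue"] for d in toy_ds]

batch = next(iter(DataLoader(toy_ds, batch_size=2)))
print(batch["image"].shape, batch["tissue"])  # torch.Size([2, 3, 8, 8]) ['tissue0', 'tissue1']
```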

### 4. Your task
Implement the `__getitem__()` method. You are allowed to implement further helper functions:

In [None]:
class SegmDatasetBinary(SegmDataset):
# Class that returns a dictionary of three values:
# {'image': image,
#  'mask': binary segmentation mask,
#  'tissue': name_of_tissue }
# in the __getitem__ function
    def __init__(self, root_dir, imgs, masks, tps, augment=True):
        super().__init__(root_dir, imgs, masks, tps, augment, True)

    ### TODO ###
    def __getitem__(self, idx):
      pass

In [None]:
class SegmDatasetMultiClass(SegmDataset):
# Class that returns a dictionary of three values:
# {'image': image,
#  'mask': segmentation mask in which different cell types are marked with following integer values (neoplastic:1, inflammatory:2, softtissue:3, dead:4, epithelial:5),
#  'tissue': name of tissue }
# in the __getitem__ function
    def __init__(self, root_dir, imgs, masks, tps, augment=True):
        super().__init__(root_dir, imgs, masks, tps, augment, False)

    ### TODO ###
    def __getitem__(self, idx):
        pass
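One possible way to derive the two kinds of target mask from the 6-channel instance masks, sketched as a standalone helper (`masks_to_labels` is a hypothetical name, not part of the exercise template). It assumes channels 0-4 follow `CELL_TYPES[:5]` and ignores the last channel. A `(1, H, W)` uint8 mask fits the rest of this notebook: `transform` flips and rotates along axes 1 and 2, the plotting helpers roll the mask to `(H, W, 1)`, and the training loop squeezes axis 1 before computing the loss. Inside `__getitem__` you would combine it with `self.imgs[idx]`, apply `self.transform` when `self.augment` is set, and return the dictionary described in the class comments.

```python
import numpy as np

def masks_to_labels(mask6):
    """Collapse a (6, H, W) instance mask into training targets.

    Assumes channels 0-4 follow CELL_TYPES[:5]; channel 5 is ignored.
    Returns (binary, multiclass), each shaped (1, H, W) with dtype uint8:
      binary:     1 = any cell, 0 = background
      multiclass: 0 = background, 1..5 = neoplastic, inflammatory,
                  softtissue, dead, epithelial
    """
    multiclass = np.zeros(mask6.shape[1:], dtype=np.uint8)
    for k in range(5):
        # a later channel overwrites an earlier one if two cell types ever overlap
        multiclass[mask6[k] > 0] = k + 1
    binary = (multiclass > 0).astype(np.uint8)
    return binary[np.newaxis], multiclass[np.newaxis]
```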

### 5. Task 2

As the bar plots above show, our dataset is strongly imbalanced: of the 19 tissue types, breast tissue alone accounts for over 30% of the images. A model trained on such data risks overfitting to breast tissue. We will therefore create a new abstract dataset class that oversamples the under-represented tissues, so that images of non-breast tissues are fed to the model more often than their actual image count.

In [None]:
class SegmDatasetBalanced(SegmDataset):
    def __init__(self, root_dir, imgs, masks, tps, augment, is_binary, balance_to):
        super().__init__(root_dir, imgs, masks, tps, augment, is_binary)
        tissue_names, counts = np.unique(self.tissues, return_counts=True)

        # upsample every tissue (with replacement) to at least balance_to * count of the largest tissue
        counts_max = np.max(counts)
        indexes = {}
        for ts in tissue_names:
            ts_ind = np.where(self.tissues == ts)[0]
            missing_n = int(counts_max * balance_to - ts_ind.size)
            if missing_n > 0:
                additional_ts_ind = np.random.choice(ts_ind, size=missing_n, replace=True)
                ts_ind = np.concatenate((ts_ind, additional_ts_ind))
            indexes[ts] = ts_ind
        # all_indexes maps a position in the balanced dataset to an index into self.imgs
        self.all_indexes = np.concatenate(list(indexes.values()))
        np.random.shuffle(self.all_indexes)

    def __len__(self):
        return len(self.all_indexes)

    def __getitem__(self, idx):
        pass

The parameter `balance_to` is the fraction of the largest tissue class (breast, in this dataset) up to which every other tissue is upsampled. For example, with `balance_to=0.5`, kidney, bladder, colon, etc. are resampled with replacement until each has at least 50% as many images as breast.
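The balancing rule can be checked on a toy example. The sketch below re-implements the index construction from `SegmDatasetBalanced.__init__` as a standalone function (`balanced_indices` is a hypothetical name) and uses NumPy's `Generator` API so the result is reproducible. With 8 breast, 2 colon and 1 kidney image and `balance_to=0.5`, the target is 0.5 x 8 = 4 images per tissue, so colon gains 2 resampled images, kidney gains 3, and breast is left unchanged: 16 indices in total. A balanced `__getitem__` would then presumably read the image at `self.all_indexes[idx]` rather than at `idx`.

```python
import numpy as np

def balanced_indices(tissues, balance_to, rng=None):
    """Indices into `tissues`, with every tissue upsampled to >= balance_to * max count."""
    rng = np.random.default_rng() if rng is None else rng
    names, counts = np.unique(tissues, return_counts=True)
    target = counts.max() * balance_to
    parts = []
    for name in names:
        idx = np.flatnonzero(tissues == name)
        missing = int(target - idx.size)
        if missing > 0:
            # draw the extra samples with replacement from this tissue's own images
            idx = np.concatenate([idx, rng.choice(idx, size=missing, replace=True)])
        parts.append(idx)
    out = np.concatenate(parts)
    rng.shuffle(out)
    return out

toy_tissues = np.array(["Breast"] * 8 + ["Colon"] * 2 + ["Kidney"] * 1)
idx = balanced_indices(toy_tissues, 0.5, np.random.default_rng(0))
names, counts = np.unique(toy_tissues[idx], return_counts=True)
print(len(idx), dict(zip(names.tolist(), counts.tolist())))
# 16 {'Breast': 8, 'Colon': 4, 'Kidney': 4}
```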

Implement the `__getitem__()` method. You are allowed to implement further helper functions:

In [None]:
class SegmDatasetBinaryBalanced(SegmDatasetBalanced):
# Class that returns a dictionary of three values:
# {'image': image,
#  'mask': binary segmentation mask,
#  'tissue': name_of_tissue }
# in the __getitem__ function
    def __init__(self, root_dir, imgs, masks, tps, augment, balance_to):
        super().__init__(root_dir, imgs, masks, tps, augment, True, balance_to)

    ### TODO ###
    def __getitem__(self, idx):
        pass

class SegmDatasetMultiClassBalanced(SegmDatasetBalanced):
# Class that returns a dictionary of three values:
# {'image': image,
#  'mask': segmentation mask in which different cell types are marked with following integer values (neoplastic:1, inflammatory:2, softtissue:3, dead:4, epithelial:5),
#  'tissue': name of tissue }
# in the __getitem__ function
    def __init__(self, root_dir, imgs, masks, tps, augment, balance_to):
        super().__init__(root_dir, imgs, masks, tps, augment, False, balance_to)

    ### TODO ###
    def __getitem__(self, idx):
        pass

### 6. Here is how we will test your code

We will use some plotting functions to inspect the data returned by the different dataset classes.

In [None]:
CELL_TYPE_COLORS = [(255,64,52), (50, 168, 82), (132, 46, 158), (46, 214, 217), (247, 235, 5)]


def composite(background, foreground, alpha):
    for color in range(3):
        background[:, :, color] = alpha * foreground[:, :, color] + background[:, :, color] * (1 - alpha)

    return background.astype(np.uint8)

def overlay_binary_mask(img, mask):
    mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.float32)
    mask_rgb[:, :, 0] = mask[:, :, 0] * 255  # paint the mask into the red channel
    alpha = mask[:,:,0] * 0.25
    return composite(img*255, mask_rgb, alpha)


def overlay_binary_contours(img, mask):
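    # findContours needs a contiguous single-channel uint8 image
    m = np.ascontiguousarray(mask[:, :, 0] > 0, dtype=np.uint8)
    contours, _ = cv2.findContours(m, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)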
    contour_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    cv2.drawContours(contour_rgb, contours, -1, CELL_TYPE_COLORS[0], 3)
    alpha = np.zeros((mask.shape[0], mask.shape[1]))
    alpha[np.max(contour_rgb, axis=2) > 0] = 0.75
    return composite(img*255, contour_rgb, alpha)

def overlay_multiclass_contours(img, mask):
    contour_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    for i in range(len(CELL_TYPE_COLORS)):
        m = mask[:,:,0] == i + 1
        contours, _  = cv2.findContours(m.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(contour_rgb, contours, -1, CELL_TYPE_COLORS[i], 3)
    alpha = np.zeros((mask.shape[0], mask.shape[1]))
    alpha[np.max(contour_rgb, axis=2) > 0] = 0.9
    return composite(img*255, contour_rgb, alpha)

In [None]:
def test_dset(dataset):
    plt.figure(figsize=(6, 6))
    n = 12
    indices = random.sample(range(len(dataset)), k=n)

    for i, ind in enumerate(indices):
      datapoint = dataset[ind]
      img, mask, ts = datapoint['image'], datapoint['mask'], datapoint['tissue']
      _ = plt.subplot(n // 4, 4, i+1)
      img = np.rollaxis(img, 0, 3)
      mask = np.rollaxis(mask, 0, 3)
      im = overlay_binary_contours(img,mask) if dataset.is_binary else overlay_multiclass_contours(img,mask)
      plt.imshow(im)
      plt.title(ts, fontsize=8)
      plt.axis("off")

    plt.show()


In [None]:
def plot_dset_counts(dataset):
    tissues, counts = np.unique([dataset.tissues[i] for i in dataset.all_indexes], return_counts=True)
    plt.figure(figsize=(6, 6))
    plt.barh(tissues, counts)
    plt.show()

In [None]:
dataset = SegmDatasetMultiClassBalanced("", imgs, masks, tps, True, 0.25)

print(len(dataset))

In [None]:
test_dset(dataset)

In [None]:
plot_dset_counts(dataset)

### 7. Final task

Now the dataset class becomes more complicated: instead of images and their segmentation masks, the dataset should return crops of individual cells together with their cell type label:

In [None]:
class SingleCellDataset(Dataset):
# This class should return via __getitem__ function following dictionary:
# {'image': a w x h-sized cropped image of individual cell
# 'cell_type': type of the respective cell
# 'tissue': tissue of origin }
    def __init__(self, root_dir, imgs, masks, tps, augment=True, random_pad=5, margin=10, w=64, h=64):
        """
        margin: margin around the cell contours
        random_pad: maximum number by which we vary the margin size
        w,h: width and height of the output image
        """
        self.root_dir = root_dir
        self.augment = augment
        self.h = h
        self.w = w
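        if root_dir != "":
            self.imgs = np.load(os.path.join(root_dir, "images.npy"))
            self.masks = np.load(os.path.join(root_dir, "masks.npy"))
            self.tissues = np.load(os.path.join(root_dir, "types.npy"))
            # arrays on disk are [N,H,W,C]; move channels to [N,C,H,W] like the in-memory arrays
            self.imgs = np.rollaxis(self.imgs, 3, 1)
            self.masks = np.rollaxis(self.masks, 3, 1)
        else:
            self.imgs = imgs
            self.masks = masks
            self.tissues = tps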
        self.imgs = self.imgs.astype(np.float32) / 255
        self.tissue_types = np.unique(self.tissues)
        self.random_pad = random_pad

        ## index: (cell type, tissue) : [ (img_i, x_min, x_max, y_min, y_max), ...]
        indices = {}
        for ct in CELL_TYPES[:5]:
            for tt in self.tissue_types:
                indices[(ct, tt)] = []

        for i in range(self.masks.shape[0]):
            tt = self.tissues[i]
            for i_ct, ct in enumerate(CELL_TYPES[:5]):
                mask = self.masks[i,i_ct,:,:]
                cells = np.unique(mask)      # sorted instance labels
                cells = cells[cells != 0]    # drop the background label
                for ic in cells:
                    ys, xs = np.where(mask == ic)
                    x_min, x_max, y_min, y_max = np.min(xs), np.max(xs), np.min(ys), np.max(ys)
                    # skip cells touching the image border, since they are only partially visible
                    if (x_min > 0) and (x_max < mask.shape[1]-1) and (y_min > 0) and (y_max < mask.shape[0] -1):
                        x_min = max(0, x_min - margin)
                        x_max = min(mask.shape[1], x_max + margin)
                        y_min = max(0, y_min - margin)
                        y_max = min(mask.shape[0], y_max + margin)
                        indices[(ct, tt)].append((i, x_min, x_max, y_min, y_max))
        self.indices_d = indices
        self.indices = list(indices.values())
        self.labels = list(indices.keys())
        a = list(map(len, self.indices))
        self.cs = np.cumsum(a)


    def __len__(self):
        return int(self.cs[-1])

    ### TODO ###
    def transform(self, img):
        pass

    def pad_img(self, img):
        pass

    def __getitem__(self, idx):
        pass
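Two pieces that a `SingleCellDataset.__getitem__` could be built from, sketched under assumptions (both function names are hypothetical). `locate` turns a flat index into a position in `self.indices` using the cumulative sums stored in `self.cs`; `np.searchsorted` with `side="right"` automatically skips empty (cell type, tissue) groups. `crop_and_pad` cuts a bounding box out of a `(C, H, W)` image, jitters each side by up to `random_pad` pixels, and centres the crop on a zero-padded `h x w` canvas. Center-cropping oversized cells is one design choice among several; resizing with `cv2.resize` would be another.

```python
import numpy as np

def locate(cs, idx):
    """Map a flat index to (group, offset), given cs = np.cumsum(group_sizes)."""
    group = int(np.searchsorted(cs, idx, side="right"))
    offset = int(idx - (cs[group - 1] if group > 0 else 0))
    return group, offset


def crop_and_pad(img, box, w=64, h=64, random_pad=5, rng=None):
    """Cut one cell out of a (C, H, W) image and centre it on a zero (C, h, w) canvas.

    box = (x_min, x_max, y_min, y_max), e.g. a margin-expanded bounding box.
    Each side of the box is moved outwards or inwards by up to random_pad pixels.
    """
    rng = np.random.default_rng() if rng is None else rng
    _, H, W = img.shape
    x_min, x_max, y_min, y_max = box
    dx0, dx1, dy0, dy1 = rng.integers(-random_pad, random_pad + 1, size=4)
    # Jitter the box, clip it to the image, and keep it at least 1 px wide/high.
    x0 = int(np.clip(x_min - dx0, 0, W - 1))
    x1 = int(np.clip(x_max + dx1, x0 + 1, W))
    y0 = int(np.clip(y_min - dy0, 0, H - 1))
    y1 = int(np.clip(y_max + dy1, y0 + 1, H))
    crop = img[:, y0:y1, x0:x1]
    # If the crop is larger than the target, keep its central h x w window.
    cy = max(0, (crop.shape[1] - h) // 2)
    cx = max(0, (crop.shape[2] - w) // 2)
    crop = crop[:, cy:cy + h, cx:cx + w]
    # Zero-pad symmetrically up to (h, w).
    out = np.zeros((img.shape[0], h, w), dtype=img.dtype)
    oy = (h - crop.shape[1]) // 2
    ox = (w - crop.shape[2]) // 2
    out[:, oy:oy + crop.shape[1], ox:ox + crop.shape[2]] = crop
    return out
```

In `__getitem__`, something like `g, k = locate(self.cs, idx)`, followed by `ct, tt = self.labels[g]` and `i, x_min, x_max, y_min, y_max = self.indices[g][k]`, would give the source image index, the bounding box, and both labels.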

And this is how we will test this class:

In [None]:
def test_singlecell_dset(dataset):
    plt.figure(figsize=(6, 6))
    n = 8
    indices = random.sample(range(len(dataset)), k=n)

    for i, ind in enumerate(indices):
      datapoint = dataset[ind]
      img, ct, ts = datapoint['image'], datapoint['cell_type'], datapoint['tissue']
      img = (img*255)
      img = img.astype(np.uint8)
      img = np.rollaxis(img, 0, 3)
      _ = plt.subplot(n*2 // 4, 4, 2*i+1)

      plt.imshow(img)
      plt.title(ct + " " + ts, fontsize=8)
      plt.axis("off")

      datapoint = dataset[ind]
      img, ct, ts = datapoint['image'], datapoint['cell_type'], datapoint['tissue']
      img = (img*255)
      img = img.astype(np.uint8)
      img = np.rollaxis(img, 0, 3)
      _ = plt.subplot(n*2 // 4, 4, 2*i+2)
      plt.imshow(img)
      plt.title(ct + " " + ts, fontsize=8)
      plt.axis("off")

    plt.show()

In [None]:
sc_dataset = SingleCellDataset("", imgs, masks, tps, margin=10)

In [None]:
test_singlecell_dset(sc_dataset)

## Exercise 2 - Model Definition and Training

To free some RAM, please run the following cell before continuing.

In [None]:
del(dataset)
del(test_dset)
del(sc_dataset)
del(test_singlecell_dset)

Some further imports

In [None]:
import time
from random import sample
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix

### 1. Load and prepare the data

We will load the data and make a train-validation-test dataset split.

In [None]:
imgs = np.load("/content/fold1/images.npy")
tps = np.load("/content/fold1/types.npy")
masks = np.load("/content/fold1/masks.npy")
imgs = np.rollaxis(imgs, 3, 1)
masks = np.rollaxis(masks, 3, 1)
print(imgs.shape)
print(masks.shape)

In [None]:
rd = np.random.sample(imgs.shape[0])
in_test = rd >=0.8
in_val = np.logical_and(rd >=0.6, rd < 0.8)
in_train = rd < 0.6
imgs_train, masks_train, tissues_train =  imgs[in_train].copy(), masks[in_train].copy(), tps[in_train].copy()
imgs_val, masks_val, tissues_val = imgs[in_val].copy(), masks[in_val].copy(), tps[in_val].copy()
imgs_test, masks_test, tissues_test = imgs[in_test].copy(), masks[in_test].copy(), tps[in_test].copy()

print("Train set size:      (%i, %i, %i, %i)" % imgs_train.shape)
print("Validation set size: (%i, %i, %i, %i)" % imgs_val.shape)
print("Test set size:       (%i, %i, %i, %i)" % imgs_test.shape)
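The split above assigns each image independently, so small tissue classes can end up unevenly represented across the train, validation and test sets. If you want each tissue to appear in roughly the same proportion in every split, one option is a stratified split with scikit-learn's `train_test_split` and its `stratify` argument. A sketch, where `stratified_split` is a hypothetical helper; it passes absolute set sizes to avoid floating-point rounding of fractions, and stratification requires every tissue to have at least two images in each splitting step:

```python
import numpy as np
from sklearn.model_selection import train_test_split

def stratified_split(tissues, val_frac=0.2, test_frac=0.2, seed=0):
    """Return (train, val, test) index arrays with tissue proportions preserved."""
    n = len(tissues)
    idx = np.arange(n)
    n_test = int(round(test_frac * n))
    n_val = int(round(val_frac * n))
    # first carve out the test set, then the validation set from the remainder
    idx_rest, idx_test = train_test_split(
        idx, test_size=n_test, stratify=tissues, random_state=seed)
    idx_train, idx_val = train_test_split(
        idx_rest, test_size=n_val, stratify=tissues[idx_rest], random_state=seed)
    return idx_train, idx_val, idx_test

toy_tissues = np.array(["Breast"] * 50 + ["Colon"] * 25 + ["Kidney"] * 25)
tr, va, te = stratified_split(toy_tissues)
print(len(tr), len(va), len(te))  # 60 20 20
```

With the real data this would be used as `idx_train, idx_val, idx_test = stratified_split(tps)`, followed by `imgs_train = imgs[idx_train]` and so on.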


### 2. Define and create the network

We will use a basic implementation of [UNet](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/) and train it to perform semantic segmentation.

In [None]:
import torch.nn.functional as F
import torch
from torch import nn

class UNet(nn.Module):
    def __init__(
        self,
        in_channels=1,
        n_classes=2,
        depth=5,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode='upconv',
    ):
        """
        Implementation of
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597

        The version used in the original paper corresponds to
        padding=False, batch_norm=False, up_mode='upconv' (with depth=5, wf=6)

        Args:
            in_channels (int) : number of input channels
            n_classes (int)   : number of output channels
            depth (int)       : depth of the network
            wf (int)          : number of filters in the first layer is 2**wf
            padding (bool)    : if True, apply padding such that the input shape
                                is the same as the output.
                                This may introduce artifacts
            batch_norm (bool) : Use BatchNorm after layers with an
                                activation function
            up_mode (str)     : one of 'upconv' or 'upsample'.
                                'upconv' will use transposed convolutions for
                                learned upsampling.
                                'upsample' will use bilinear upsampling.
        """
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels
        self.down_path = nn.ModuleList()

        # this is where the network itself is created! according to your value of
        # 'depth', you add down_layers and then up_layers on the way back
        for i in range(depth):
            self.down_path.append(
                UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(
                UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:
                blocks.append(x)
                x = F.max_pool2d(x, 2)

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        return self.last(x)


class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        block.append(nn.Conv2d(out_size, out_size, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out

Here we parametrize the network. Depending on whether you perform binary or multiclass segmentation, set the `n_classes` parameter to 2 or 6 (background plus 1 or 5 cell classes). We also parametrize the size of the UNet; for initial test runs, make the network smaller to speed up training.

In [None]:
network_depth = 3 # how deep should the UNet be
wf = 4            # the first layer has 2**wf filters
batch_size = 4
num_epochs = 50   # number of training epochs
gpuid = 0
n_classes   = 6   # number of classes in the data mask that we'll aim to predict
in_channels = 3   # input channel of the data, RGB = 3

torch.manual_seed(10100)

param_stamp = f'd{network_depth}_wf{wf}_b{batch_size}_e{num_epochs}'

print("torch cuda is available:", torch.cuda.is_available())

# specify if we should use a GPU (cuda) or only the CPU
if torch.cuda.is_available():
    print(torch.cuda.get_device_properties(gpuid))
    torch.cuda.set_device(gpuid)
    device = torch.device(f'cuda:{gpuid}')
else:
    device = torch.device('cpu')

# build the model according to the parameters specified above and move it to the selected device,
# then print out the number of trainable parameters
model = UNet(
    n_classes   = n_classes,
    in_channels = in_channels,
    depth       = network_depth,
    wf          = wf
).to(device)
print(f"total params: \t{sum([np.prod(p.size()) for p in model.parameters()])}")
print(f'\nNetwork depth: {network_depth}\nwf: {wf}\n'
      f'batch size: {batch_size}\n{num_epochs} epochs\n')
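A quick way to see what `network_depth` and `wf` mean for the `UNet` class above, sketched with two hypothetical helpers: level `i` of the down path has `2**(wf + i)` filters, and with `padding=True` the output has the same size as the input only if height and width are divisible by `2**(depth - 1)`, because each of the `depth - 1` max-pooling steps halves them. PanNuke tiles are 256 x 256 pixels, which satisfies this for the depths used in this notebook.

```python
def unet_channel_widths(depth, wf):
    """Channels per level of the UNet above: (down path, up path)."""
    down = [2 ** (wf + i) for i in range(depth)]
    up = [2 ** (wf + i) for i in reversed(range(depth - 1))]
    return down, up


def output_matches_input(h, w, depth):
    """True if h and w survive depth-1 rounds of 2x2 max pooling without rounding."""
    factor = 2 ** (depth - 1)
    return h % factor == 0 and w % factor == 0


print(unet_channel_widths(3, 4))  # ([16, 32, 64], [32, 16])
print(output_matches_input(256, 256, 3), output_matches_input(250, 250, 3))  # True False
```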


Now that we have the network, we create dataset instances for the training, validation and test sets.

In [None]:
datasets_train_val = {"train": SegmDatasetMultiClass("", imgs_train, masks_train, tissues_train),
                      "val": SegmDatasetMultiClass("", imgs_val, masks_val, tissues_val, augment=False) }
dataset_test = SegmDatasetMultiClass("", imgs_test, masks_test, tissues_test, augment=False)

dataloaders = {}
for phase, ds in datasets_train_val.items():
    dataloaders[phase] = DataLoader(
        ds,
        batch_size  = batch_size,
        shuffle     = (phase == "train"),  # shuffling only matters for training
        num_workers = 0,
        pin_memory  = True
    )
test_loader = DataLoader(dataset_test, batch_size=batch_size)  # shuffle=False keeps the test order


### 3. Network training

We will now train the network. The two remaining things to define are the optimizer and the loss function. The basic optimizer is [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), but you can also experiment with [other optimizers](https://pytorch.org/docs/stable/optim.html#algorithms).

In [None]:
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Adam is an example alternative to SGD:
#optim = torch.optim.Adam(model.parameters())

We will use [cross entropy](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) as loss function.

In [None]:
criterion = nn.CrossEntropyLoss(reduction='none').to(device)
# reduction='none' returns one loss value per pixel ([N, H, W]) instead of a single averaged scalar
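A small check of what `reduction='none'` returns, using random logits and labels in place of real network output (the sizes are arbitrary). The target must be `[N, H, W]` integer labels, which is why the training loop squeezes the channel axis of the mask; averaging the per-pixel losses gives the same value as the default `reduction='mean'` when no class weights are used.

```python
import torch
from torch import nn

ce = nn.CrossEntropyLoss(reduction="none")
logits = torch.randn(2, 6, 4, 4)            # [N, n_classes, H, W] raw network outputs
target = torch.randint(0, 6, (2, 1, 4, 4))  # [N, 1, H, W] integer labels, as from the dataset
per_pixel = ce(logits, target.squeeze(1))   # target has to be [N, H, W]
print(per_pixel.shape)  # torch.Size([2, 4, 4])

mean_loss = nn.CrossEntropyLoss()(logits, target.squeeze(1))
print(torch.allclose(per_pixel.mean(), mean_loss))  # True
```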

And here is the training loop. In each epoch we collect and report loss values.

In [None]:
torch.manual_seed(10100)

loss_log = {'train': list(), 'val': list()}
acc_log  = {'train': list(), 'val': list()}

start_time  = time.time()
model = model.float()
best_loss = np.inf  # np.Infinity was removed in NumPy 2.0
phases = list(dataloaders.keys())

for epoch in range(num_epochs):
    # zero out epoch based performance variables
    all_acc  = {key: 0 for key in phases}
    all_loss = {key: torch.zeros(0).to(device) for key in phases}
    cmatrix  = {key: np.zeros((n_classes, n_classes)) for key in phases}
    marker = ''

    for phase in phases:  # iterate through both training and validation states

        if phase == 'train':
            model.train()  # Set model to training mode
        else:  # when in eval mode, we don't want parameters to be updated
            model.eval()  # Set model to evaluate mode

        for batch in dataloaders[phase]:
            input_imgs = batch['image']
            input_masks = batch['mask']

            input_imgs = input_imgs.to(device=device, dtype=torch.float32)
            input_masks = input_masks.to(device=device, dtype=torch.long)

            # compute gradients only in the training phase;
            # disabling them for validation saves memory and speeds up inference
            with torch.set_grad_enabled(phase == 'train'):

                prediction = model(input_imgs)  # [N, Nclass, H, W]
                loss_matrix = criterion(prediction, torch.squeeze(input_masks, 1))  # [N, H, W]
                loss = loss_matrix.mean()  # average the per-pixel losses

                if phase == "train":  # in train mode, backpropagate and update the weights
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                    train_loss = loss

                all_loss[phase] = torch.cat((all_loss[phase], loss.detach().view(1, -1)))

                if phase == "val":  # in the validation phase, accumulate the confusion matrix
                    p = prediction.detach().cpu().numpy()
                    cpredflat = np.argmax(p, axis=1).flatten()
                    input_masks = input_masks.cpu().numpy().flatten()

                    cmatrix[phase] = cmatrix[phase] + confusion_matrix(input_masks, cpredflat, labels=range(n_classes))

        all_acc[phase] = (cmatrix[phase] / (cmatrix[phase].sum() + 0.0000001)).trace()
        all_loss[phase] = all_loss[phase].cpu().numpy().mean()

        loss_log[phase].append(all_loss[phase])
        acc_log[phase].append(all_acc[phase])

    if all_loss['val'] < best_loss:
      marker = '***'
      best_loss = all_loss['val']

    print( '[%3d/%3d] (%3d%%)  |  LOSS - train: %.4f    val: %.4f  %s' % (
        epoch + 1,
        num_epochs,
        (epoch + 1) / num_epochs * 100,
        all_loss["train"],
        all_loss["val"],
        marker
        ),
        end="\n"
    )
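The accuracy computed in the loop is the trace of the normalized confusion matrix, i.e. the fraction of pixels whose predicted class equals the true class (the loop adds a tiny epsilon to the denominator). A toy example with made-up labels; since most pixels are background, per-class recall (diagonal divided by row sums) is often more informative than overall pixel accuracy:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# made-up labels for 8 pixels and 3 classes
y_true = np.array([0, 0, 1, 1, 2, 2, 2, 0])
y_pred = np.array([0, 1, 1, 1, 2, 0, 2, 0])
cm = confusion_matrix(y_true, y_pred, labels=list(range(3)))  # rows: true, cols: predicted

pixel_acc = (cm / cm.sum()).trace()                 # same quantity as all_acc in the loop
recall_per_class = cm.diagonal() / cm.sum(axis=1)   # fraction of each class found
print(pixel_acc)  # 0.75
print(recall_per_class)
```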

Now we plot the loss values and find the epoch with the best (lowest) validation loss.

In [None]:
best_epoch = np.argmin(loss_log["val"])  # 0-based index into loss_log

# report epochs 1-based, matching the training log above
print(f'Best epoch: {best_epoch + 1} with a loss of {np.min(loss_log["val"])}')

plt.plot(loss_log['train'], label='train')
plt.plot(loss_log['val'], label='val')
plt.vlines(best_epoch, 0, 1, linestyles='dotted', alpha=.5)
plt.ylim(0,1)
plt.legend()
plt.show()
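The training loop marks the epoch with the lowest validation loss, but `model` keeps the weights from the last epoch, and those are what the evaluation below uses. If you want to evaluate the best epoch instead, one option is to keep an in-memory copy of its `state_dict`. A sketch, where `BestCheckpoint` is a hypothetical helper demonstrated on a tiny `nn.Linear`; the `copy.deepcopy` matters because `state_dict()` returns references to the live parameter tensors.

```python
import copy

import torch
from torch import nn

class BestCheckpoint:
    """Keep an in-memory copy of the weights with the lowest validation loss."""

    def __init__(self):
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.best_state = None

    def update(self, model, val_loss, epoch):
        """Store a deep copy of the model's weights if val_loss is a new minimum."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.best_state = copy.deepcopy(model.state_dict())
            return True
        return False

    def restore(self, model):
        """Load the stored best weights back into the model (no-op if none stored)."""
        if self.best_state is not None:
            model.load_state_dict(self.best_state)

demo_model = nn.Linear(2, 1)                 # stand-in for the UNet
ckpt = BestCheckpoint()
ckpt.update(demo_model, 0.5, epoch=0)        # new best: weights copied
saved = demo_model.weight.detach().clone()
with torch.no_grad():
    demo_model.weight.add_(1.0)              # simulate further training
ckpt.update(demo_model, 0.7, epoch=1)        # worse loss: nothing stored
ckpt.restore(demo_model)                     # back to the epoch-0 weights
print(ckpt.best_epoch, torch.equal(demo_model.weight, saved))  # 0 True
```

In the training loop this would amount to calling `ckpt.update(model, all_loss['val'], epoch)` at the end of each epoch and `ckpt.restore(model)` before running inference on the test set.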

### 4. Model evaluation

Finally we will run segmentation on the test set and inspect the results. First we run the inference:

In [None]:
res_masks = np.zeros((imgs_test.shape[0], imgs_test.shape[2], imgs_test.shape[3]), dtype=np.uint8)

with torch.no_grad():
    model.eval()
    for batch_i, batch in enumerate(test_loader):
        batch_imgs = batch['image']
        batch_masks = batch['mask']

        batch_imgs = batch_imgs.to(device=device, dtype=torch.float32)
        batch_masks = batch_masks.to(device=device, dtype=torch.long)

        prediction = model(batch_imgs)  # [N, Nclass, H, W]
        sem_preds = np.argmax(F.softmax(prediction, dim=1).cpu().numpy(), axis=1)
        res_masks[batch_i*batch_size:(batch_i+1)*batch_size] = sem_preds


Let's plot some example segmentation masks.

In [None]:
n = 12
plt.figure(figsize=(6, 6))
if len(res_masks.shape) < 4:
    res_masks = np.expand_dims(res_masks, 3)
for i, ind in enumerate(random.sample(range(imgs_test.shape[0]), n)):
    _ = plt.subplot(4, n // 4, i+1)
    img = np.rollaxis(imgs_test[ind], 0, 3).astype(np.float32) / 255
    mask = res_masks[ind]
    im = overlay_binary_contours(img,mask) if dataset_test.is_binary else overlay_multiclass_contours(img,mask)
    plt.imshow(im)
    plt.title("", fontsize=8)
    plt.axis("off")

plt.show()

Finally, we calculate some accuracy metrics on the test set:
- [Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) for each cell class
- fraction of detected cells: a ground-truth cell counts as detected if at least 50% of its pixels are predicted as any cell class
- fraction of correctly classified cells: a ground-truth cell counts as correctly classified if at least 50% of its pixels are assigned its true class

In [None]:
smooth = 1.
def dice_coef(y_true, y_pred):
    intersection = np.sum(y_true * y_pred)
    return (2. * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)
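As a worked example of the smoothed Dice coefficient, 2|A ∩ B| / (|A| + |B|) with a smoothing term of 1 added to numerator and denominator, the function is reproduced here so the snippet runs on its own. With one overlapping pixel and two positive pixels in each mask, the plain Dice would be 2 * 1 / (2 + 2) = 0.5; the smoothing raises it to 3 / 5 = 0.6, and two empty masks score 1 instead of causing a division by zero.

```python
import numpy as np

smooth = 1.
def dice_coef(y_true, y_pred):
    intersection = np.sum(y_true * y_pred)
    return (2. * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)

y_true = np.array([1, 1, 0, 0], dtype=np.uint8)
y_pred = np.array([1, 0, 1, 0], dtype=np.uint8)
print(dice_coef(y_true, y_pred))            # 0.6
print(dice_coef(np.zeros(4), np.zeros(4)))  # 1.0
```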

In [None]:
def cell_detection(masks, preds, cutoff=0.5):
    cell_acc = np.zeros((3, len(CELL_TYPES) - 1))  # rows: detected, correctly classified, total cells
    for i, ct in enumerate(CELL_TYPES[:5]):
        for img_i in range(masks.shape[0]):
            mask = masks[img_i,i,:,:]
            cns = np.unique(mask)
            cns = cns[cns != 0]
            pred_mask = preds[img_i]
            for cn in cns:
                pred_pixels = pred_mask[mask == cn]
                cell_size = np.sum(mask == cn)
                prop_detected = np.sum(pred_pixels > 0) / cell_size
                prop_correct_class = np.sum(pred_pixels == i+1) / cell_size
                cell_acc[0,i] += prop_detected >= cutoff
                cell_acc[1,i] += prop_correct_class >= cutoff
                cell_acc[2,i] += 1
    return cell_acc

def accuracy_metrics(masks, preds):
    cell_acc = cell_detection(masks, preds)
    dice_classes = np.zeros(len(CELL_TYPES) - 1)
    for cl in range(len(CELL_TYPES) - 1):  # the 5 cell types; class label = channel index + 1
        y_true = (masks[:,cl,:,:] > 0).astype(np.uint8)
        y_pred = (preds == cl+1).astype(np.uint8)
        dice_classes[cl] = dice_coef(y_true, y_pred)
    for i, ct in enumerate(CELL_TYPES[:5]):
        print("%s\tdice: %.3f\tdetected: %.3f\tcorrect class: %.3f" % (ct.ljust(12, ' '), dice_classes[i], cell_acc[0,i]/cell_acc[2,i], cell_acc[1,i]/cell_acc[2,i]))

In [None]:
accuracy_metrics(masks_test, np.squeeze(res_masks))

### 5. Your Task

Experiment with the training (hyperparameters, optimizers, the different versions of the dataset class, length of training) so that the segmentation on the test set is as good as possible. Any post-processing of the results is allowed. Save your notebook and network with the command below.


In [None]:
torch.save(model.state_dict(), 'cp.pth')
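To reuse the saved weights later, for example in a fresh runtime, the model has to be rebuilt with the same architecture arguments (`depth`, `wf`, `n_classes`, `in_channels`) before loading the state dict. A sketch of the round trip, with a small `nn.Conv2d` standing in for the UNet and a temporary path used instead of `cp.pth`:

```python
import os
import tempfile

import torch
from torch import nn

demo_net = nn.Conv2d(3, 6, kernel_size=1)   # stand-in for the trained UNet
path = os.path.join(tempfile.mkdtemp(), "cp.pth")
torch.save(demo_net.state_dict(), path)

# Rebuild with identical arguments, then load the weights (map to CPU if no GPU is present).
restored = nn.Conv2d(3, 6, kernel_size=1)
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # switch BatchNorm/Dropout layers to inference behaviour
```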