In [None]:
#Mount Drive
# Mounts Google Drive at /content/drive so the "/content/drive/Shared drives/..."
# paths used throughout this notebook are readable. google.colab only exists
# inside Google Colab, so this cell (and the notebook) is Colab-specific.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#Restart runtime after running once
# (presumably so the freshly installed/upgraded packages are the ones imported — confirm)
# NOTE(review): catalyst is installed unpinned with -U, i.e. whatever is latest.
# The code below uses catalyst.dl.callbacks, MulticlassDiceMetricCallback,
# utils.set_global_seed and runner.train(main_metric=..., fp16=...), which look
# like an older Catalyst API — TODO pin the catalyst version this was run with.

%pip install segmentation-models-pytorch==0.1.0
%pip install -U catalyst

In [None]:
#Dependencies

#Handles data 
import json
import numpy as np
import matplotlib.pyplot as plt
import cv2
import glob
from operator import itemgetter
import pickle

#Torch utilities 
from typing import List
from pathlib import Path
from torch.utils.data import Dataset
import torch

#Data Loader utilities 
import collections
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

#Model building and training 
import segmentation_models_pytorch as smp
from torch import nn

from catalyst.contrib.nn import DiceLoss, IoULoss
from torch import optim
from torch.optim import AdamW
from catalyst import utils

from catalyst.contrib.nn import RAdam, Lookahead
from catalyst.dl import SupervisedRunner

from catalyst.dl.callbacks import DiceCallback, IouCallback, \
  CriterionCallback, AccuracyCallback, MulticlassDiceMetricCallback

In [None]:
#Set seed for better reproducibility 
# Single seed for the run; Catalyst's set_global_seed presumably seeds the
# python/numpy/torch RNGs (verify against the installed Catalyst version).
SEED = 42
utils.set_global_seed(SEED)
# Ask Catalyst to configure cuDNN for deterministic execution.
utils.prepare_cudnn(deterministic=True)

In [None]:
#Define and establish a dataset class
class SegmentationDataset(Dataset):
    """Serves image/mask pairs from two pre-saved ``.npy`` arrays.

    Each item's last axis is moved to the front (``H x W x C`` ->
    ``C x H x W``, assuming channels-last storage) and converted to a float
    tensor. A 2-D mask item (``H x W``) first gets a trailing channel axis, so
    it is returned as ``1 x H x W``.

    Args:
        image_arr_path: path to a ``.npy`` file holding all images
            (first axis indexes samples).
        mask_arr_path: path to a ``.npy`` file holding all masks, aligned with
            the images; ``None`` builds an image-only dataset whose items
            contain just ``"image"`` (e.g. for inference).
    """

    def __init__(
        self,
        image_arr_path,
        mask_arr_path=None,
    ) -> None:
        self.images = np.load(image_arr_path)
        # No mask file -> unlabeled dataset; __getitem__ then omits "mask".
        self.masks = None if mask_arr_path is None else np.load(mask_arr_path)

    def __len__(self) -> int:
        return len(self.images)

    @staticmethod
    def _to_chw_tensor(array) -> torch.Tensor:
        """Convert an ``H x W`` or ``H x W x C`` array to a ``C x H x W`` float tensor."""
        array = np.asarray(array)
        if array.ndim == 2:
            # Single-channel data stored without a channel axis.
            array = array[..., np.newaxis]
        return torch.from_numpy(np.transpose(array, (2, 0, 1))).float()

    def __getitem__(self, idx: int) -> dict:
        """Return ``{"image": tensor}`` plus ``"mask"`` when masks are available."""
        result = {"image": self._to_chw_tensor(self.images[idx])}
        if self.masks is not None:
            result["mask"] = self._to_chw_tensor(self.masks[idx])
        return result

In [None]:
#Load dataset once to enable visualization prior to model training 
# NOTE(review): both paths are the "Allan add details" placeholder — replace them
# with the real image .npy and mask .npy files before running.
dset = SegmentationDataset(image_arr_path="/content/drive/Shared drives/Allan add details", 
                           mask_arr_path="/content/drive/Shared drives/Allan add details")

In [None]:
#Show sizes of the image and mask 
# Displays (image tensor shape, mask tensor shape, number of samples) for item 0.
out = dset[0]
out["image"].shape, out["mask"].shape, len(dset)

In [None]:
#Show an image 
# The dataset returns sample 40 channels-first (C x H x W); transpose it back to
# H x W x C for matplotlib. The BGR -> RGB conversion assumes the arrays were
# saved in OpenCV's BGR channel order.
image = np.asarray(dset[40]['image']).transpose(1, 2, 0).astype(np.uint8)
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

In [None]:
#Show associated mask 
# np.squeeze drops singleton axes (e.g. a 1 x H x W mask -> H x W) so imshow
# receives a 2-D label map (assumes single-channel masks — TODO confirm).
mask = np.squeeze(dset[40]['mask'])
plt.imshow(mask)

In [None]:
#Set up U-Net with EfficientNet backbone pretrained on ImageNet

ENCODER = 'efficientnet-b3'
ENCODER_WEIGHTS = 'imagenet'
# NOTE(review): DEVICE is not referenced anywhere else in this notebook; the
# runner gets its device from utils.get_device() instead.
DEVICE = 'cuda'
# No output activation: the model emits raw logits, which is what the
# CrossentropyND training criterion requires ("NO NONLINEARITY").
ACTIVATION = None

model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    # Number of output channels, i.e. segmentation classes.
    classes=3, 
    activation=ACTIVATION,
)

In [None]:
#Define loaders 

def get_loaders(
    images: np.ndarray,
    masks: np.ndarray,
    image_arr_path: str,
    mask_arr_path: str,
    random_state: int,
    valid_size: float = 0.1,
    batch_size: int = 12,
    num_workers: int = 4,
    ) -> dict:
    """Split images/masks into train and validation sets and wrap them in DataLoaders.

    Args:
        images: array of all images (first axis indexes samples). If ``None``,
            it is loaded from ``image_arr_path``.
        masks: array of all masks, aligned with ``images``. If ``None``, it is
            loaded from ``mask_arr_path``.
        image_arr_path: ``.npy`` file to load images from when ``images`` is None.
        mask_arr_path: ``.npy`` file to load masks from when ``masks`` is None.
        random_state: seed for the train/validation split.
        valid_size: fraction of samples held out for validation.
        batch_size: batch size of both loaders.
        num_workers: DataLoader worker processes for both loaders.

    Returns:
        OrderedDict with ``"train"`` then ``"valid"`` DataLoaders.

    Raises:
        ValueError: if images and masks hold a different number of samples.
    """
    # np.asarray (not np.array) avoids copying arrays the caller already holds.
    np_images = np.asarray(np.load(image_arr_path) if images is None else images)
    np_masks = np.asarray(np.load(mask_arr_path) if masks is None else masks)
    if len(np_images) != len(np_masks):
        raise ValueError(
            f"images ({len(np_images)}) and masks ({len(np_masks)}) must have the same length"
        )

    indices = np.arange(len(np_images))
    train_indices, valid_indices = train_test_split(
        indices, test_size=valid_size, random_state=random_state, shuffle=True
    )

    def _subset(idx):
        # Build the dataset straight from the in-memory arrays; calling
        # SegmentationDataset(path, path) would re-read both .npy files from
        # disk only for the loaded data to be overwritten.
        dataset = SegmentationDataset.__new__(SegmentationDataset)
        dataset.images = np_images[idx]
        dataset.masks = np_masks[idx]
        return dataset

    loaders = collections.OrderedDict()
    # Shuffle training batches every epoch; keep validation order fixed.
    for name, idx, shuffle in (
        ("train", train_indices, True),
        ("valid", valid_indices, False),
    ):
        loaders[name] = DataLoader(
            _subset(idx),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=False,
        )

    return loaders

In [None]:
#Get loaders  
# NOTE(review): all four paths are the "Allan add details" placeholder — images
# and masks must point at different .npy files. The arrays are loaded here AND
# the same paths are passed again as image_arr_path/mask_arr_path.
# random_state controls the train/validation split (independent of SEED).

loaders = get_loaders(
    images=np.load("/content/drive/Shared drives/Allan add details"),
    masks=np.load("/content/drive/Shared drives/Allan add details"),
    image_arr_path="/content/drive/Shared drives/Allan add details",
    mask_arr_path="/content/drive/Shared drives/Allan add details",
    random_state=420,
    valid_size=0.1,
    batch_size=3,
    num_workers=2,
)

In [None]:
#    Helpful code taken from Joseph Chen 
#
#    Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import torch
from torch import nn
import numpy as np

def sum_tensor(inp, axes, keepdim=False):
    """Sum ``inp`` over every axis listed in ``axes`` (duplicates are ignored).

    With ``keepdim=True`` reduced axes remain as size-1 dimensions; otherwise
    they are removed, highest index first, so that the indices still waiting to
    be reduced keep pointing at the intended dimensions.
    """
    unique_axes = [int(a) for a in np.unique(axes).astype(int)]
    if not keepdim:
        for axis in reversed(unique_axes):
            inp = inp.sum(axis)
        return inp
    for axis in unique_axes:
        inp = inp.sum(axis, keepdim=True)
    return inp

def softmax_helper(x):
    """Numerically stable softmax over dim 1 (the class axis of ``(b, c, ...)`` outputs).

    The per-position maximum over dim 1 is subtracted before exponentiating so
    large logits cannot overflow; broadcasting expands the reduced tensors.
    """
    shifted = x - x.max(1, keepdim=True)[0]
    exps = torch.exp(shifted)
    return exps / exps.sum(1, keepdim=True)

class CrossentropyND(nn.CrossEntropyLoss):
    """
    Network has to have NO NONLINEARITY!

    Cross-entropy for N-dimensional outputs. ``inp`` holds raw logits shaped
    ``(b, c, *spatial)``; ``target`` holds one class index per spatial position
    (e.g. ``(b, 1, *spatial)`` or ``(b, *spatial)``) and may be a float tensor,
    since it is cast to ``long``. Both are flattened to ``(-1, c)`` / ``(-1,)``
    and passed to ``nn.CrossEntropyLoss``, whose constructor arguments this
    class accepts unchanged.
    """
    def forward(self, inp, target):
        target = target.long()
        num_classes = inp.size()[1]

        # Move the class axis last: (b, c, d1, ..., dk) -> (b, d1, ..., dk, c),
        # so each row of the flattened view is one position's logits.
        inp = inp.permute(0, *range(2, inp.dim()), 1)
        inp = inp.contiguous().view(-1, num_classes)

        # reshape, not view: the long-cast target keeps the input's strides and
        # may be non-contiguous (e.g. a transposed mask), which view(-1) rejects.
        target = target.reshape(-1)

        return super().forward(inp, target)

def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
    """
    Compute soft true-positive, false-positive and false-negative sums per class.

    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))
    :param net_output: per-class scores, used as-is (apply any softmax beforehand)
    :param gt: label map or one-hot target (shapes above); label maps are one-hot encoded here
    :param axes: axes to sum over; defaults to all spatial axes (2, 3, ...)
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return: tuple (tp, fp, fn), each reduced over ``axes``
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    # Building the one-hot target needs no gradient.
    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            # (b, x, y(, z)) label map -> (b, 1, x, y(, z)).
            gt = gt.reshape((shp_y[0], 1, *shp_y[1:]))

        if all(i == j for i, j in zip(net_output.shape, gt.shape)):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            # Allocate on net_output's device directly (CPU, any CUDA index or
            # another backend) instead of on CPU plus a CUDA-only move.
            y_onehot = torch.zeros(shp_x, device=net_output.device)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        # mask[:, :1] is (b, 1, ...) and broadcasts over the class axis, zeroing
        # every class at invalid pixels (same as masking each class separately).
        valid = mask[:, :1]
        tp = tp * valid
        fp = fp * valid
        fn = fn * valid

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2

    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return tp, fp, fn


class SoftDiceLoss(nn.Module):
    """Soft Dice loss, returned as the *negated* mean soft Dice coefficient.

    Args:
        apply_nonlin: optional callable applied to the network output first
            (e.g. a softmax); ``None`` uses the output unchanged.
        batch_dice: if True, tp/fp/fn are also pooled over the batch axis,
            giving one score per class instead of one per (sample, class).
        do_bg: if False, class index 0 (presumably background) is excluded
            from the mean.
        smooth: additive smoothing term in numerator and denominator.
        square: forwarded to ``get_tp_fp_fn`` (square terms before summing).
    """

    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True,
                 smooth=1., square=False):
        super().__init__()
        self.apply_nonlin = apply_nonlin
        self.batch_dice = batch_dice
        self.do_bg = do_bg
        self.smooth = smooth
        self.square = square

    def forward(self, x, y, loss_mask=None):
        # Reduction axes come from the raw output's rank: the spatial axes,
        # plus the batch axis when batch_dice is set.
        reduce_axes = list(range(2, len(x.shape)))
        if self.batch_dice:
            reduce_axes = [0] + reduce_axes

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        tp, fp, fn = get_tp_fp_fn(x, y, reduce_axes, loss_mask, self.square)
        dice = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)

        if not self.do_bg:
            # The class axis is 0 after batch pooling and 1 otherwise.
            dice = dice[1:] if self.batch_dice else dice[:, 1:]

        return -dice.mean()


class DC_and_CE_loss(nn.Module):
    """Cross-entropy (CrossentropyND) plus soft Dice (SoftDiceLoss) on raw logits.

    The Dice term applies ``softmax_helper`` to the network output itself, so
    the network must output logits. SoftDiceLoss returns the *negated* Dice
    coefficient, so the "sum" aggregate is effectively CE minus soft Dice.

    Args:
        soft_dice_kwargs: extra keyword arguments for SoftDiceLoss
            (``None`` means none).
        ce_kwargs: keyword arguments for CrossentropyND (``None`` means none).
        aggregate: how to combine the two terms; only ``"sum"`` is implemented.
    """
    def __init__(self, soft_dice_kwargs=None, ce_kwargs=None, aggregate="sum"):
        super(DC_and_CE_loss, self).__init__()
        self.aggregate = aggregate
        self.ce = CrossentropyND(**(ce_kwargs or {}))
        self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **(soft_dice_kwargs or {}))

    def forward(self, net_output, target):
        # Validate the mode first so an unsupported aggregate fails fast with a
        # useful message instead of after two full loss evaluations.
        if self.aggregate != "sum":
            raise NotImplementedError(
                f"aggregate={self.aggregate!r} is not supported; only 'sum' is implemented"
            )
        dc_loss = self.dc(net_output, target)
        ce_loss = self.ce(net_output, target)
        return ce_loss + dc_loss

In [None]:
#Define loss criterion
# Criteria are passed to the runner as a dict; "CE" is the criterion_key used by
# the CriterionCallback in the callbacks cell.
criterion = {
    "CE": CrossentropyND(),
}

#Configure model optimization 
learning_rate = 0.001 
# NOTE(review): the next three constants are not referenced anywhere in this
# notebook, and AdamW below uses weight_decay=0.01 rather than
# optimizer_weight_decay (0.0003) — decide which value is intended.
encoder_learning_rate = 0.0005
encoder_weight_decay = 0.00003 
optimizer_weight_decay = 0.0003 
optim_factor = 0.25
optim_patience = 2 

# lr comes from learning_rate (same value as the former literal 0.001), so
# tuning the constant above now actually changes training.
optimizer = AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)

#Use scheduler for improved results
# Scales the LR by optim_factor once the monitored metric has not improved for
# optim_patience scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=optim_factor, patience=optim_patience)

num_epochs = 10
device = utils.get_device()

# input_key / input_target_key match the "image" / "mask" keys returned by
# SegmentationDataset.__getitem__.
runner = SupervisedRunner(device=device, input_key="image", input_target_key="mask")

In [None]:
#Establish calculations during training through Catalyst callbacks
callbacks = [
        # Evaluates criterion["CE"] against each batch's "mask" and records it
        # under the "loss" prefix — the main_metric used by runner.train.
        CriterionCallback(
            input_key="mask",
            prefix="loss",
            criterion_key="CE"
        ),
        # Multiclass Dice against "mask"; reported as a metric (presumably not
        # part of the optimized loss — confirm for the installed Catalyst).
        MulticlassDiceMetricCallback(input_key="mask")
        ]

In [None]:
#Run training loop
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=callbacks,
    # NOTE(review): relative path. Assuming Colab's default cwd (/content) it
    # resolves to /content/content/full_model2 on the VM disk, not on Drive, so
    # checkpoints vanish when the runtime resets. The prediction cell loads
    # "content/full_model2/checkpoints/best.pth" — keep both paths in sync.
    logdir='content/full_model2', #this logdir must be changed with every new run
    num_epochs=num_epochs,
    # "loss" (from CriterionCallback) is the metric to minimize; presumably it
    # also selects the best.pth checkpoint.
    main_metric="loss",
    minimize_metric=True,
    fp16=None,    
    verbose=True,
)

In [None]:
#Test model on test dataset
# NOTE(review): placeholder paths — point these at the held-out test image and
# mask .npy files (the mask file is needed by the comparison plots below).
test_data = SegmentationDataset("/content/drive/Shared drives/Allan add details", 
                                "/content/drive/Shared drives/Allan add details")

In [None]:
#Create loader for predictions
# shuffle=False keeps batches in dataset order, so row i of `predictions` lines
# up with test_data[i] in the plotting cell.
infer_loader = DataLoader(
    test_data,
    batch_size=12,
    shuffle=False,
    num_workers=4
)

In [20]:
#Get model predictions on test dataset
# Run the best checkpoint over the test loader and stack each batch's logits
# (moved to CPU) into one (num_samples, num_classes, H, W) array.
checkpoint_path = "content/full_model2/checkpoints/best.pth"
predictions = np.vstack([
    batch_output["logits"].cpu().numpy()
    for batch_output in runner.predict_loader(loader=infer_loader, resume=checkpoint_path)
])

print(type(predictions))
print(predictions.shape)

<class 'numpy.ndarray'>
(149, 3, 480, 640)


In [None]:
#Display sample prediction results 

# Maps panel title -> the image artist returned by plt.imshow for that panel.
pred_results = {}
n_predictions = len(predictions)
columns = 5

# 18 random test indices over the full range [0, n_predictions); the previous
# hard-coded bounds (low=1, high=148) could never pick the first or last sample.
rand_nums = np.random.randint(low=0, high=n_predictions, size=18)
# Always show samples 141 and 30 first (in that order).
rand_nums = np.insert(rand_nums, 0, 30)
rand_nums = np.insert(rand_nums, 0, 141)

for num in rand_nums:
    sample = test_data[num]

    # Raw image: channels-first tensor -> H x W x C uint8 for display.
    image = np.asarray(sample['image']).transpose(1, 2, 0).astype(np.uint8)

    # Ground-truth mask with singleton axes squeezed away.
    mask = np.squeeze(sample['mask'])

    # Predicted label map: argmax over the class axis of this sample's
    # (C, H, W) logits.
    pred_mask = np.argmax(np.asarray(predictions[num]), axis=0)

    # (title, array) for each panel; the BGR -> RGB conversion assumes the raw
    # images were saved in OpenCV channel order.
    panels = [
        ('Raw Image {}'.format(num), cv2.cvtColor(image, cv2.COLOR_BGR2RGB)),
        ('Ground Truth {}'.format(num), mask),
        ('Predicted Mask {}'.format(num), pred_mask),
    ]

    plt.figure(figsize=(30,30))
    # subplot needs an integer row count; ceiling division replaces the old
    # float `len(images) / columns + 1`, which recent matplotlib rejects.
    n_rows = -(-len(panels) // columns)
    for i, (label, panel) in enumerate(panels):
        image_plot = plt.subplot(n_rows, columns, i + 1)
        image_plot.set_title(label)
        pred_results[label] = plt.imshow(panel)
    # Render each figure as soon as it is complete instead of accumulating
    # 20 open 30x30-inch figures until the cell finishes.
    plt.show()

In [None]:
#Display dictionary of sample test results
# NOTE(review): the values are the artists returned by plt.imshow, so this shows
# their reprs rather than the underlying image/mask arrays.
pred_results

In [None]:
#Pickle sample test results
# `with` closes the file even if pickling fails part-way (the manual
# open/close left the handle open on error).
# NOTE(review): the values are matplotlib artists, so pickling presumably drags
# in the figures they belong to; consider saving the underlying arrays instead.
with open("pred_results.pkl", "wb") as f:
    pickle.dump(pred_results, f)