# Model Evaluation
This notebook walks through applying an ML model to data it was not trained on, to show you the sorts of things that can go wrong when a model is used on out-of-distribution data.

The model we are using here was trained on lung cancer patients who were imaged in their radiotherapy treatment position for treatment planning. This is a very controlled environment, with well-calibrated CT machines, consistent voxel dimensions and fairly consistent anatomy.

The images we will test with here differ from the training data in several ways:
- Some patients with active COVID infections
- Some patients imaged with/without contrast
- Variable imaging dose (diagnostic CT vs PET-CT)
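Because the training images had consistent voxel dimensions, a cheap first check on any new series, before running the model at all, is whether its voxel spacing falls inside the range seen in training. The helper below is a hypothetical sketch, and the reference ranges in it are made-up placeholders, not statistics from the actual training data:

```python
def out_of_range_fields(metadata, reference_ranges):
    """
    Return the names of fields whose value in `metadata` is missing or lies
    outside the inclusive (low, high) range given in `reference_ranges`.
    """
    flagged = []
    for field, (low, high) in reference_ranges.items():
        value = metadata.get(field)
        if value is None or not (low <= value <= high):
            flagged.append(field)
    return flagged


## HYPOTHETICAL reference ranges in mm - placeholders, NOT the real training statistics.
## In practice you would fill these in from the training set's headers
## (PixelSpacing for rows/columns, the z-gap between ImagePositionPatient values for slices).
reference = {
    "row_spacing_mm": (0.9, 1.2),
    "col_spacing_mm": (0.9, 1.2),
    "slice_spacing_mm": (2.5, 3.0),
}

## A made-up thin-slice scan with small pixels: every field gets flagged
thin_slice_scan = {"row_spacing_mm": 0.7, "col_spacing_mm": 0.7, "slice_spacing_mm": 1.0}
print(out_of_range_fields(thin_slice_scan, reference))
```

A missing field is flagged too, so a series with incomplete headers doesn't silently pass the check.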


I didn't test the model on any of these, so I have no idea if it will perform well or not. I would expect it to fall over a fair bit, but equally we might be surprised how resilient it is.

To run this notebook, start at the top and click the play button beside every chunk of code. This will install the necessary libraries and import them for use, then define a bunch of helper functions. 

Once the cell at the bottom is reached, you should be able to change the file path and see how the model behaves on the different examples. 

In [None]:
%pip install albumentations pydicom pytorch-lightning
%pip install git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as ipyw
import torch
import os
import albumentations as A
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import pydicom

structure_names = ['SpinalCord', 'Lung_R', 'Lung_L', 'Heart', 'Esophagus']
try:
  import google.colab
  IN_COLAB = True
  datapath = "/content/"
except ImportError:
  IN_COLAB = False
  datapath = "./"

In [None]:
def load_image(slices_fpath):
    """
    This functions load the DICOM corresponding to an image, but doesn't actually load the pixels. 
    Instead, we just keep copies of the whole DICOM object for each slice, which then includes other stuff.
    Crucially, we sort the returned list on the patient's position so that the slices and other returned 
    things are in the correct order
    """
    slices = []
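    for slice_fname in os.listdir(slices_fpath):
        try:
            slice_f = pydicom.dcmread(os.path.join(slices_fpath, slice_fname))
            slice_f.pixel_array  ## raises for files without pixel data, e.g. an RTSTRUCT
        except Exception:
            continue  ## not a readable DICOM image, so skip it
        if slice_f.Modality == "RTDOSE":
            continue  ## dose grids have pixel data but aren't CT slices
        slices.append(slice_f)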
    slices = sorted(slices, key=lambda s: s.ImagePositionPatient[-1])
    uids = [s.SOPInstanceUID for s in slices]
    pixels = np.array([(float(s.PixelSpacing[0]), float(s.PixelSpacing[1]))  for s in slices])
    origins = np.array([s.ImagePositionPatient for s in slices])


    img_array = np.zeros((len(slices), *slices[0].pixel_array.shape) , dtype=np.int16)
    
    ## Very important to apply the rescale slope & intercept - pydicom doesn't do it by default
    for idx, img_slice in enumerate(slices):
        img_array[idx, ...] = pydicom.pixel_data_handlers.apply_rescale(img_slice.pixel_array, img_slice)

    return img_array
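The rescale step is the linear transform DICOM uses to map stored integers to Hounsfield units; `apply_rescale` reads the `RescaleSlope` and `RescaleIntercept` of each slice to do it. The numpy sketch below spells out the formula (`to_hounsfield` is an illustrative helper, not part of pydicom):

```python
import numpy as np

def to_hounsfield(stored_values, rescale_slope, rescale_intercept):
    """
    Convert stored CT pixel values to Hounsfield units (HU) with the linear
    DICOM rescale: HU = stored_value * RescaleSlope + RescaleIntercept.
    """
    return np.asarray(stored_values, dtype=np.float64) * rescale_slope + rescale_intercept

## A common convention is slope 1 and intercept -1024, so a stored value of 1024
## is water (0 HU) and a stored 0 is -1024 HU, roughly air
stored = np.array([0, 1024, 1064], dtype=np.uint16)
print(to_hounsfield(stored, rescale_slope=1.0, rescale_intercept=-1024.0).tolist())
```

With that common intercept, skipping the rescale would shift every value by +1024 HU and the default window in `window_level` would land on entirely the wrong tissues.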

In [None]:
def window_level(data, window=350, level=50):
    """
    Apply a window and level transformation to CT slices. 

    The default values are taken from https://radiopaedia.org/articles/windowing-ct?lang=gb and are recommended for visualising the mediastinum.
    
    The returned array has the same shape as the input (NxHxW, a single channel). Values will be in the range 0-255 and the type will be uint8, to mimic a 'normal' image.
    """
    ## calculate high & low edges of level & window
    low_edge  = level - (window//2)
    high_edge = level + (window//2)
    ## use np.clip to clip into that level/window, then adjust to range 0 - 255 and convert to uint8
    windowed_data = (((np.clip(data, low_edge, high_edge) - low_edge)/window) * 255).astype(np.uint8)
    
    return windowed_data

val_transforms = A.Compose([
    A.Normalize(mean=(np.mean([0.485, 0.456, 0.406])), std=(np.mean([0.229, 0.224, 0.225]))) ## Single channel input, so use the mean of the ImageNet channel means and the mean of the channel stds
])
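To make the preprocessing concrete, the sketch below pushes a few representative HU values through a standalone copy of `window_level` and then through the same arithmetic `A.Normalize` applies with the collapsed single-channel ImageNet statistics (the `normalize` helper is ours, written to mirror albumentations' formula rather than call its API). For comparison it also applies a commonly used lung window (W 1500, L -600). Under the default mediastinal window everything below -125 HU, including all of the lung parenchyma, maps to 0, and everything above 225 HU (bone, but also strongly contrast-enhanced vessels) saturates at 255, which is worth keeping in mind for the contrast scans:

```python
import numpy as np

def window_level(data, window=350, level=50):
    """Clip to [level - window//2, level + window//2] and rescale to uint8 in 0-255."""
    low_edge = level - (window // 2)
    high_edge = level + (window // 2)
    return (((np.clip(data, low_edge, high_edge) - low_edge) / window) * 255).astype(np.uint8)

def normalize(img_uint8, max_pixel_value=255.0):
    """Same arithmetic as A.Normalize with one mean/std: (x - mean*max) / (std*max)."""
    mean = np.mean([0.485, 0.456, 0.406])  # ~0.449
    std = np.mean([0.229, 0.224, 0.225])   # ~0.226
    return (img_uint8.astype(np.float32) - mean * max_pixel_value) / (std * max_pixel_value)

## Representative HU values: air, lung parenchyma, water, soft tissue, bone/dense contrast
hu = np.array([-1000, -800, 0, 40, 400])

mediastinal = window_level(hu)                    # W 350, L 50 (the notebook default)
lung = window_level(hu, window=1500, level=-600)  # a common lung window, for comparison
model_input = normalize(mediastinal)              # roughly what the network sees

print(mediastinal.tolist())
print(lung.tolist())
print(model_input.round(2))
```

The mediastinal window gives `[0, 0, 91, 120, 255]` and the lung window `[59, 93, 229, 236, 255]`; after normalisation the 0-255 range reaches the network as roughly -1.99 to 2.44.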

In [None]:
class ImageSliceViewer3D:
    """ 
    ImageSliceViewer3D is for viewing volumetric image slices in jupyter or
    ipython notebooks. 
    
    User can interactively change the slice plane selection for the image and 
    the slice plane being viewed. 

    Argumentss:
    Volume = 3D input image
    figsize = default(8,8), to set the size of the figure
    cmap = default('gray'), string for the matplotlib colormap. You can find 
    more matplotlib colormaps on the following link:
    https://matplotlib.org/users/colormaps.html
    
    """
    
    def __init__(self, volume, mask, figsize=(10,10), cmap='gray'):
        self.volume = volume
        self.mask = mask
        self.figsize = figsize
        self.cmap = cmap
        self.v = [np.min(volume), np.max(volume)]
        
        # Call to select slice plane
        ipyw.interact(self.views)
    
    def views(self):
        self.vol1 = np.transpose(self.volume, [1,2,0])
        self.mask1 = np.transpose(self.mask, [1,2,0])
        maxZ1 = self.vol1.shape[2] - 1
        ipyw.interact(self.plot_slice, 
            z1=ipyw.IntSlider(min=0, max=maxZ1, step=1, continuous_update=False, 
            description='Axial:'),)
    
    def plot_slice(self, z1):
        # Plot the requested axial slice with the segmentation overlaid
        f, ax = plt.subplots(1, 1, figsize=self.figsize)  # a single Axes object, not an array
        ax.imshow(self.vol1[:,:,z1], cmap=plt.get_cmap(self.cmap), 
            vmin=self.v[0], vmax=self.v[1])
        # Mask out the background (label 0) so only labelled structures are tinted
        overlay = np.ma.masked_where(self.mask1[:,:,z1] == 0, self.mask1[:,:,z1])
        ax.imshow(overlay, alpha=0.75, cmap='Pastel2', vmin=1, vmax=5)
        plt.show()
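When scrolling through a large volume it can help to know in advance which axial slices contain a given structure, so you can drag the slider straight there. A small hypothetical helper (not part of the notebook), assuming the label convention used here of 0 for background and `i + 1` for `structure_names[i]`, so label 4 is the heart:

```python
import numpy as np

def slices_containing(segmentation, label):
    """Return the axial slice indices (axis 0) where `label` appears at least once."""
    present = (segmentation == label).any(axis=(1, 2))
    return np.flatnonzero(present).tolist()

## Toy 4-slice volume: label 4 appears on slices 1 and 2 only
seg = np.zeros((4, 3, 3), dtype=np.uint8)
seg[1, 1, 1] = 4
seg[2, 0, 2] = 4
print(slices_containing(seg, 4))
```

In the notebook you would call it as `slices_containing(segmentation, 4)` after running `segment_dicom`.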

In [None]:
## Define the class that will wrap the PyTorch model up for PyTorch Lightning
class LightningFPN(pl.LightningModule):
  def __init__(self):
    super().__init__()
    ## Create the pytorch model 
    self.model = smp.FPN("resnet18", in_channels=1, classes=len(structure_names)+1, encoder_weights='imagenet')
    
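    ## Construct a loss function: soft Dice (1 - DSC), configured for multiple classes.
    ## As written, the background class is included; passing classes=list(range(1, len(structure_names)+1)) would ignore it
    self.loss_fcn = smp.losses.DiceLoss("multiclass", from_logits=True)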

    ## Specify which optimiser to use here
    self.optimizer = torch.optim.Adam

  def forward(self, x):
    return self.model(x)

  def configure_optimizers(self):
    optimizer = self.optimizer(self.parameters(), lr=1e-4)  ## May need to handle other kwargs here!
    ## Reduce the learning rate when the validation loss plateaus for a while - this should improve the model
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}

  def predict(self, x):
    return self.model.predict(x)

  def training_step(self, batch, batch_idx):
    img, msk = batch
    msk_hat = self(img)
    loss = self.loss_fcn(msk_hat, msk.long())
    self.log("loss", loss)
    return loss

  def validation_step(self, batch, batch_idx):
    img, msk = batch
    msk_hat = self(img)
    val_loss = self.loss_fcn(msk_hat, msk.long())
    self.log("val_loss", val_loss)
    return val_loss


## Now create an instance, which builds the prebuilt FPN and wraps it up inside the PyTorch Lightning module:

pl_model = LightningFPN()

## Done!
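The loss above is built on the Dice similarity coefficient (DSC). If you have manual contours for any of the test images, the same metric is a natural way to put a number on how well the model copes out of distribution. A minimal per-structure DSC in numpy, assuming a reference label volume on the same voxel grid as the prediction (this notebook does not load one) and the same label convention, 0 for background and `i + 1` for `structure_names[i]`:

```python
import numpy as np

def dice_per_structure(pred, truth, structure_names):
    """
    Dice similarity coefficient, 2 * |A & B| / (|A| + |B|), for each foreground label.
    Label i + 1 is assumed to correspond to structure_names[i]; 0 is background.
    Returns NaN for a structure that is absent from both volumes.
    """
    scores = {}
    for i, name in enumerate(structure_names, start=1):
        a = pred == i
        b = truth == i
        denom = a.sum() + b.sum()
        scores[name] = 2.0 * np.logical_and(a, b).sum() / denom if denom > 0 else float("nan")
    return scores

## Toy example with three hypothetical structures
pred = np.array([[1, 1, 0], [0, 2, 2]])
truth = np.array([[1, 0, 0], [0, 2, 2]])
print(dice_per_structure(pred, truth, ["A", "B", "C"]))
```

Returning NaN rather than 0 or 1 for a structure missing from both volumes keeps "nothing to find" separate from a genuine miss when you average scores across patients.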

In [None]:
def segment_3d(image, model, transforms, batch_size=8):
    """
    Segment a windowed CT volume (N x H x W, uint8) slice by slice, batch_size slices at a time.
    Returns an N x H x W array of labels (0 = background, 1-5 = the structures).
    """
    ## Do things on the GPU if there is one, otherwise fall back to the (much slower) CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    segmentation = np.zeros_like(image)
    ## Step through the volume in batches; the last batch may be smaller than batch_size
    for b_start in range(0, image.shape[0], batch_size):
        b_stop = min(b_start + batch_size, image.shape[0])
        transformed_image = torch.tensor(transforms(image=image[b_start:b_stop])['image']).to(device)
        ## Add a channel axis (the model expects B x 1 x H x W) and get B x C x H x W logits back
        logits = model.predict(transformed_image[:, None, ...]).cpu()
        probs = torch.nn.functional.softmax(logits, dim=1)  ## softmax over the class dimension
        segmentation[b_start:b_stop, ...] = probs.argmax(dim=1).numpy()
    return segmentation
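`segment_3d` turns the model's `B x C x H x W` logits into a label map by taking a softmax over the class dimension (dimension 1) and then an argmax. The numpy sketch below, using made-up logits, shows that each pixel's class probabilities sum to 1 along that axis, and that the softmax never changes which class wins because it is monotonic. It only matters when you want the per-class probabilities, for example to see how confident the model is on an out-of-distribution scan:

```python
import numpy as np

def softmax(logits, axis):
    """Numerically stable softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

## Fake logits for 2 slices, 6 classes (background + 5 structures), 2x2 pixels: B x C x H x W
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 6, 2, 2))

probs = softmax(logits, axis=1)   # normalise over the class axis
labels = probs.argmax(axis=1)     # B x H x W label map
print(np.allclose(probs.sum(axis=1), 1))
print(np.array_equal(labels, logits.argmax(axis=1)))
```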

In [None]:
def segment_dicom(dicom_path, model):
    """
    Does all the loading and preprocessing for you!
    """
    image = load_image(dicom_path)
    wl_image = window_level(image)
    segmentation = segment_3d(wl_image, model, val_transforms)

    return image, segmentation
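A rough sanity check on an out-of-distribution result is to turn each label into a physical volume and ask whether it is anatomically plausible; a "heart" of a few millilitres is easy to spot. `load_image` computes the in-plane spacing but doesn't return it, so in this hypothetical helper the voxel spacing is an assumed input you would need to supply, e.g. from `PixelSpacing` and the difference in `ImagePositionPatient` between adjacent slices:

```python
import numpy as np

def structure_volumes_ml(segmentation, spacing_mm, structure_names):
    """
    Volume of each labelled structure in millilitres.
    spacing_mm is (slice_spacing, row_spacing, col_spacing), matching the N x H x W axis order.
    Label i + 1 is assumed to correspond to structure_names[i].
    """
    voxel_ml = float(np.prod(spacing_mm)) / 1000.0  # 1 mL = 1000 mm^3
    return {name: int((segmentation == i).sum()) * voxel_ml
            for i, name in enumerate(structure_names, start=1)}

## Toy volume: a 6 x 40 x 40 block labelled 2 (9600 voxels of 3 x 1 x 1 mm)
seg = np.zeros((10, 100, 100), dtype=np.uint8)
seg[2:8, 20:60, 20:60] = 2
print(structure_volumes_ml(seg, (3.0, 1.0, 1.0), ["SpinalCord", "Lung_R"]))
```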

In [None]:
## Here we download the pretrained model weights for this model
!wget https://www.dropbox.com/s/sbgmtd7t344iklx/pretrained_checkpoint.ckpt?dl=0 -O pretrained_checkpoint.ckpt

In [None]:
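## Load the pretrained model from the checkpoint; this should allow us to segment things.
## load_from_checkpoint is a classmethod, so call it on the class: it builds a fresh LightningFPN and loads the weights into it

model = LightningFPN.load_from_checkpoint(os.path.join(datapath, "pretrained_checkpoint.ckpt"))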

In [None]:
## Here we download and extract the test images I prepared
!wget https://www.dropbox.com/s/b42o3o8bpzt55pv/TestImages.tar.gz?dl=0 -O TestImages.tar.gz
!tar -xf TestImages.tar.gz
!rm TestImages.tar.gz

In [None]:
## Check the file paths in the file browser on the left; this should work for the first example
image, segmentation = segment_dicom(os.path.join(datapath, "TestImages", "Image001"), model)


In [None]:
## Now visualise it
vw = ImageSliceViewer3D(image, segmentation)

Now try doing this with the other 9 images. Some of them should give some pretty funky results. I can show you how I downloaded these images if you would like to find more to try this with!
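To run the remaining examples without retyping each path, you can loop over every series folder. A sketch that lists the subdirectories of a folder, assuming each test image sits in its own folder under `TestImages/` as `Image001` does; the loop in the trailing comments reuses `segment_dicom`, `model` and `datapath` from the cells above:

```python
import os
import tempfile

def list_series_dirs(root):
    """Return the sorted paths of the immediate subdirectories of `root`."""
    return sorted(
        os.path.join(root, name)
        for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )

## Quick self-check on a throwaway directory: plain files are ignored, folders come back sorted
with tempfile.TemporaryDirectory() as tmp:
    os.mkdir(os.path.join(tmp, "Image002"))
    os.mkdir(os.path.join(tmp, "Image001"))
    open(os.path.join(tmp, "README.txt"), "w").close()
    demo = [os.path.basename(p) for p in list_series_dirs(tmp)]
print(demo)

## In the notebook, something like:
## for path in list_series_dirs(os.path.join(datapath, "TestImages")):
##     image, segmentation = segment_dicom(path, model)
```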