# Using MONAI to unlock clinically valuable insights from Digital Pathology

## Part 2 - Nuclei segmentation and classification with MONAI

<img src="images/monai.png" alt="MONAI" style="width: 500px;"/>

In the first section, you saw how we can speed up the loading and decoding of large images by using a high-performance image loader such as cuCIM. You also saw how multi-threading can reduce the latency of loading a large image: as long as the image format and loader support reading regions of interest, different processes or threads can load different parts of the image into memory simultaneously.
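
A minimal sketch of concurrent region reading follows. It assumes a NumPy array stands in for the slide and a hypothetical `read_region(image, location, size)` helper stands in for a real reader such as cuCIM. Each thread fetches one tile. With a real reader, the speed-up comes from overlapping I/O and native decoding; here the point is only the structure of the code.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np

# Stand-in for a whole-slide image: in practice this would be a cuCIM/OpenSlide
# handle, and read_region would decode only the requested region from disk.
slide = np.arange(512 * 512 * 3, dtype=np.uint32).reshape(512, 512, 3)


def read_region(image, location, size):
    """Hypothetical reader: return the region starting at location=(row, col)."""
    row, col = location
    h, w = size
    return image[row:row + h, col:col + w]


tile = 128
locations = [(r, c) for r in range(0, slide.shape[0], tile) for c in range(0, slide.shape[1], tile)]

# Each worker thread fetches a different region concurrently
with ThreadPoolExecutor(max_workers=4) as pool:
    tiles = list(pool.map(lambda loc: read_region(slide, loc, (tile, tile)), locations))

# Reassemble the tiles to check that every region was read correctly
rebuilt = np.zeros_like(slide)
for (r, c), t in zip(locations, tiles):
    rebuilt[r:r + tile, c:c + tile] = t

print(len(tiles), np.array_equal(rebuilt, slide))
```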

Image loading is often the bottleneck that slows the whole processing pipeline down, but it is rarely the only step. The image may need preprocessing and, for digital pathology, each region may need to be thresholded so that we do not waste time processing empty background regions of the whole-slide image. There may also be image transformations or augmentations to apply. If they are not handled efficiently, any of these operations can become the bottleneck and leave the GPUs under-utilised.
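
To make the background-thresholding idea concrete, here is a hedged sketch of a tile-level tissue check. The grey-level cut-off of 220 and the 10% minimum tissue fraction are illustrative assumptions, not values from this tutorial or from MONAI; real pipelines often threshold in another colour space (e.g. HSV saturation) or use Otsu's method.

```python
import numpy as np


def tissue_fraction(tile_rgb, background_level=220):
    """Fraction of pixels darker than background_level in a channel-last uint8 RGB tile.

    Assumption: background on brightfield H&E slides is close to white, so pixels
    whose grey level is below the (hypothetical) background_level count as tissue.
    """
    grey = tile_rgb.astype(np.float32).mean(axis=-1)
    return float((grey < background_level).mean())


def keep_tile(tile_rgb, min_tissue=0.1):
    """Skip a tile unless at least min_tissue of its pixels look like tissue."""
    return tissue_fraction(tile_rgb) >= min_tissue


background = np.full((256, 256, 3), 245, dtype=np.uint8)  # almost-white glass
tissue = background.copy()
tissue[:128] = (180, 80, 160)  # top half stained pink/purple

print(keep_tile(background), keep_tile(tissue), tissue_fraction(tissue))
```

A tile would then be passed to the network only if `keep_tile` returns True.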

We are going to work through an example in which we use the HoVerNet [1] network and post-processing pipeline to detect, localise and classify nuclei and then, in the following notebook, analyse the output.

![image](images/hovernet2.png)

[1] Simon Graham, Quoc Dang Vu, Shan E Ahmed Raza, Ayesha Azam, Yee Wah Tsang, Jin Tae Kwak, Nasir Rajpoot, Hover-Net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images, Medical Image Analysis, 2019 https://doi.org/10.1016/j.media.2019.101563

We start by importing the Python libraries we need.

In [None]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

# General Python libraries
from functools import partial
from matplotlib import pyplot as plt
import logging
import os
import time
from tqdm.auto import tqdm

# torch-related classes
import torch
import numpy as np

# Monai classes
from monai.config import print_config
from monai.data import (
    DataLoader, 
    MaskedPatchWSIDataset, 
    decollate_batch, 
    PatchWSIDataset, 
    IterableDataset,
)

from monai.networks.nets import HoVerNet
from monai.engines import IterationEvents, SupervisedEvaluator
from monai.inferers import SimpleInferer

# Pathology-specific transforms
from monai.apps.pathology.transforms import (
    GenerateWatershedMaskd,
    GenerateInstanceBorderd,
    GenerateDistanceMapd,
    GenerateWatershedMarkersd,
    GenerateInstanceContour,
    GenerateInstanceCentroid,
    GenerateInstanceType,
    HoVerNetInstanceMapPostProcessingd, 
    HoVerNetNuclearTypePostProcessingd,
)

# HoVerNetNuclearTypePostProcessingd is already imported above
from monai.apps.pathology.transforms.post.dictionary import Watershedd

# Generic Transforms
from monai.transforms import (
    Compose,
    ScaleIntensityRanged,
    CastToTyped,
    Lambdad,
    LoadImage,
    LoadImaged,
    EnsureChannelFirst,
    EnsureChannelFirstd,
    ComputeHoVerMapsd,
    BoundingRect,
    ThresholdIntensity,
    NormalizeIntensityd,
    apply_transform,
)

# Event handlers
from monai.handlers import EarlyStopHandler
from monai.utils import convert_to_tensor, first, HoVerNetBranch

print_config()

The output of print_config provides a useful summary of the installed MONAI components and dependencies. If you encounter problems when developing with MONAI, it can be useful to include this output in any issues you raise on the GitHub repo.

As before, we will be working with the same image from the TCGA archive. However, rather than loading the image directly, we are going to look at how MONAI abstracts image loading away. MONAI is a very flexible API, so there are many ways to load data with it.
Sometimes it makes sense to point a MONAI engine at a data source, such as a folder of images, and let it run automatically using a pre-configured pipeline. At other times your requirements may call for some customisation. We will look at a few of these options.

First, let's look at some of the MONAI components that make up a typical pipeline.
We need to define a source of images, which we formulate as a list of dictionaries to fit the way that MONAI operates.

In [None]:
data_list = [
    {"image": "/datasets/dli_gtc_23/data/images/im_test1.nii.gz"},
    {"image": "/datasets/dli_gtc_23/data/images/im_test2.nii.gz"},
    {"image": "/datasets/dli_gtc_23/data/images/im_test3.nii.gz"},
    {"image": "/datasets/dli_gtc_23/data/images/im_test4.nii.gz"},
    {"image": "/datasets/dli_gtc_23/data/images/im_test5.nii.gz"},
    {"image": "/datasets/dli_gtc_23/data/images/im_test6.nii.gz"},
]

We now have a list of six dictionaries, each with an "image" key whose value is a filename. A MONAI Transform can turn these filenames into loaded images. In this first example we use LoadImage, which takes a filename directly; its dictionary-based equivalent, LoadImaged, takes a dictionary instead. Dictionaries offer more flexibility, for example by letting a transform select which items to operate on by their keys.
It is common to combine Transforms into a pipeline, and Compose lets us do that. In this initial case there is only one item in the Compose pipeline, but we could easily chain more together.

In [None]:
trans = Compose([LoadImage(image_only=True)])
img = trans(data_list[0]["image"])
print(type(img), img.shape, img.get_device())

We have a 2k by 2k image with 3 channels, stored channel-last. Let's plot it next.

In [None]:
import matplotlib.pyplot as plt

# Use Matplotlib to display the image (Matplotlib expects channel-last data)
plt.figure(figsize=(5,5))
plt.imshow(np.array(img).astype(int))
plt.title('tcga1.svs')
plt.show()

One thing we need to account for is the channel ordering. MONAI, like PyTorch on which it is based, generally follows the convention that channels are the first dimension of an image. To make sure this is the case, we can use another transform, EnsureChannelFirst. If an image has only a single channel, it may have no separate channel dimension (e.g. shape [100, 100] rather than [1, 100, 100]), in which case this transform creates one. Otherwise it reorders the dimensions; here, channel_dim=-1 tells it that the channels are currently the last dimension.
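
The two cases that EnsureChannelFirst handles can be sketched with NumPy. This is a simplified illustration of the idea, not MONAI's implementation, and `ensure_channel_first` is a hypothetical helper. The last step shows the reverse move that Matplotlib needs.

```python
import numpy as np


def ensure_channel_first(img, channel_dim=None):
    """Simplified illustration of the idea behind EnsureChannelFirst (not MONAI's code).

    channel_dim=None means "no channel axis": one is added at the front.
    Otherwise the given axis is moved to position 0.
    """
    if channel_dim is None:
        return img[np.newaxis, ...]
    return np.moveaxis(img, channel_dim, 0)


rgb = np.zeros((200, 200, 3))   # channel-last RGB image
grey = np.zeros((100, 100))     # single channel, no channel axis

chw = ensure_channel_first(rgb, channel_dim=-1)
print(chw.shape, ensure_channel_first(grey).shape)

# Matplotlib expects channel-last data, so move the channels back before plotting
back = np.moveaxis(chw, 0, -1)
print(back.shape)
```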

In [None]:
trans = Compose([LoadImage(image_only=True), EnsureChannelFirst(channel_dim=-1)])
img = trans(data_list[0]["image"])
print(type(img), img.shape, img.get_device())

The tensor is now in the expected format. Note that Matplotlib uses the channel-last convention, so to plot the image you will need to reorder the dimensions again before displaying it. Note also that the tensor's type is reported as a MONAI MetaTensor. You can find out more about this type in the MONAI documentation (https://docs.monai.io/en/stable/data.html#metatensor); it is a useful feature that lets you examine and add to the metadata associated with an image tensor, such as the name of the source image file, its dimensions or, for whole-slide image processing, the coordinates of the current tile within the WSI.

To make this useful, we would most likely want to integrate the Transforms into a data loading workflow. So far we have used Array Transforms, which operate directly on their inputs. Now we will switch to Dictionary Transforms, which use keys to select which entries of a dictionary to operate on. This is especially useful during training, because images and labels can be handled separately. We are only doing inference here, so we have no labels, but the dictionary form is still convenient.
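
The idea behind dictionary transforms can be sketched in a few lines. `ApplyToKeys` is a hypothetical class, not part of MONAI; it only illustrates how a transform can act on the entries named in `keys` and pass the rest (here, a label) through untouched.

```python
import numpy as np


class ApplyToKeys:
    """Hypothetical, minimal dictionary-style transform (not MONAI's implementation).

    It mirrors the idea behind MONAI's 'd'-suffixed transforms: the wrapped function
    is applied only to the entries named in `keys`; other entries pass through unchanged.
    """

    def __init__(self, keys, func):
        self.keys = [keys] if isinstance(keys, str) else list(keys)
        self.func = func

    def __call__(self, data):
        out = dict(data)  # shallow copy so the input dictionary is not modified
        for key in self.keys:
            out[key] = self.func(out[key])
        return out


sample = {"image": np.full((3, 4, 4), 255.0), "label": np.full((1, 4, 4), 2.0)}
scale = ApplyToKeys(keys="image", func=lambda x: x / 255.0)
result = scale(sample)
print(result["image"].max(), result["label"].max())
```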

First, we redefine the Transforms using their dictionary equivalents, which always carry a 'd' suffix, e.g. LoadImage becomes LoadImaged. PyTorch (through torchvision) also has the concept of transforms, and these can be used alongside MONAI Transforms.

In [None]:
trans = Compose([LoadImaged(keys="image"), EnsureChannelFirstd(keys="image", channel_dim=-1)])

This defines the Transform. Next we need to use it in a data flow, for which we will use a MONAI Dataset.

In [None]:
# Create an iterable from the image list
data_iterator = iter(data_list)

# Create a dataset with the data_iterator and the transforms that we defined
dataset = IterableDataset(data=data_iterator, transform=trans)

Now we can create a DataLoader to produce batches. This should yield 2 batches of 3 images, each image of dimension (3, 2000, 2000), so each printed batch shape is (3, 3, 2000, 2000).
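
Why 2 batches of 3? To our understanding of MONAI's IterableDataset, each DataLoader worker reads the whole stream but keeps only every `num_workers`-th item (item i goes to worker i % num_workers), and each worker then batches its own items. The sketch below reproduces that arithmetic in plain Python; the helper names are hypothetical.

```python
def shard_round_robin(items, num_workers):
    """Assign item i to worker i % num_workers (round-robin split across workers)."""
    return [[item for i, item in enumerate(items) if i % num_workers == wid] for wid in range(num_workers)]


def batches_per_worker(shard, batch_size):
    """Each worker groups its own items into batches of up to batch_size."""
    return [shard[k:k + batch_size] for k in range(0, len(shard), batch_size)]


images = [f"im_test{n}.nii.gz" for n in range(1, 7)]
shards = shard_round_robin(images, num_workers=2)
batches = [b for shard in shards for b in batches_per_worker(shard, batch_size=3)]

print(shards)
print(len(batches), [len(b) for b in batches])
```

With `num_workers=4`, the same six images would split 2, 2, 1, 1 across the workers and arrive as four smaller batches, so the batch size you observe can depend on the worker count.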

In [None]:
dataloader = DataLoader(dataset=dataset, batch_size=3, num_workers=2)
for d in dataloader:
    print(d["image"].shape)

If we wanted to use a map-style Dataset rather than an IterableDataset, we would pass an object that supports indexing and len() instead of an iterator. A plain list such as data_list already provides these, but a custom data source can be wrapped in a class that provides the necessary methods, e.g.:

    class MyData:
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

This would allow us to do something like:

    from monai.data import Dataset

    dataset = Dataset(data=MyData(data_list), transform=trans)
    for item in dataset:
        print(item["image"].shape)



We can also add other Transforms if needed. See if you can add a Transform that normalizes the intensity of the image (use the [documentation](https://docs.monai.io/en/stable/transforms.html#intensity) to find the right Transform) ([solution](solutions/nomalize.py)).

In [None]:
# TODO Add a new transform to the Compose input list
trans = Compose([LoadImaged(keys="image"), EnsureChannelFirstd(keys="image", channel_dim=-1), NormalizeIntensityd(keys="image"),])

#Create the data iterator and dataset
data_iterator = iter(data_list)
dataset = IterableDataset(data=data_iterator, transform=trans)

# Load the data
dataloader = DataLoader(dataset=dataset, batch_size=3, num_workers=2)

# retrieve the first batch
d = first(dataloader)

img = np.array(d["image"][0])
# TODO - Remember that Matplotlib expects channels last - use numpy to do this
img = np.moveaxis(img, 0, -1)
plt.figure(figsize=(5,5))
plt.imshow(img)
plt.title('Normalized Image')
plt.show()
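
The normalization and display steps above can be sketched in NumPy. We assume NormalizeIntensityd with its default arguments applies a whole-image z-score, (x - mean) / std; check the MONAI documentation for the exact options. Because the result contains negative values, and Matplotlib clips float RGB data to [0, 1], a min-max rescale before `imshow` gives a more faithful picture.

```python
import numpy as np


def z_normalize(img):
    """Whole-image z-score, (x - mean) / std (our assumption of the default behaviour)."""
    img = img.astype(np.float32)
    std = img.std()
    return (img - img.mean()) / std if std > 0 else img - img.mean()


def rescale_for_display(img):
    """Min-max rescale to [0, 1] so Matplotlib does not clip negative or >1 float values."""
    lo, hi = img.min(), img.max()
    return (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)


rng = np.random.default_rng(0)
tile = rng.integers(0, 256, size=(64, 64, 3)).astype(np.uint8)  # random channel-last tile
norm = z_normalize(tile)
shown = rescale_for_display(norm)
print(float(norm.mean()), float(norm.std()), shown.min(), shown.max())
```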

For whole-slide image processing, we don't receive patches or tiles as the initial input. Instead, the large image has to be sliced into smaller patches or tiles using a sliding-window approach, and MONAI provides a number of tools to help with this. To show this, we will start by defining a new data list. This time we supply the single WSI file, together with metadata that the MONAI Dataset can use to load only the specified tiles (at a specified resolution level, which defaults to full resolution).

In [None]:
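# Generate a data list that starts near the middle of the slide and specifies
# 256-pixel patches at a stride of 164 pixels (so they overlap) within a 2k by 2k region.
# Note: MONAI's WSI readers treat "location" as (row, column), so x is the row offset here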
width = 87647
height = 52434

x = 25000
y = 47000

data_list = []

for i in range(x, x+2000, 164):
    for j in range(y, y+2000, 164):
        data_list.append({"image": "data/tcga1.svs","location": [i, j], "size": [256,256]})

print(data_list[0:10])

Next, we can use MONAI's PatchWSIDataset to handle patch creation. Here we set cuCIM as the image reader (it is also the default), which means we don't need a separate Transform to load the images. The dataset also handles the channel ordering for us.

In [None]:
dataset = PatchWSIDataset(
    data_list,
    patch_size=256,
    patch_level=0,
    include_label=False,
    reader="cuCIM",
    additional_meta_keys=["location", "size"],
)
print("dataset created")

Now we can create the image batches with a DataLoader. MONAI's DataLoader is based on, and compatible with, PyTorch's DataLoader, so you can use the same features, such as multiple concurrent workers, to optimise data loading.

In [None]:
# Load the image batches
dataloader = DataLoader(dataset=dataset, batch_size=3, num_workers=2)

# retrieve the first batch
d = first(dataloader)

Finally we can show the first batch of 3 images. You should see that they are adjacent tiles.

In [None]:
imgs = np.array(d["image"])
imgs = np.moveaxis(imgs, 1, -1)

print(imgs[0][46:-46,46:-46].shape)

fig, ax = plt.subplots(1, 3, figsize=(10, 10))
ax[0].imshow(imgs[0][46:-46,46:-46].astype(np.uint8))
ax[1].imshow(imgs[1][46:-46,46:-46].astype(np.uint8))
ax[2].imshow(imgs[2][46:-46,46:-46].astype(np.uint8))

So, we have shown a few ways that you can load and process batches of images, but there is plenty more to play around with if you have specific needs that are not met by what we have seen already. 

We will now move on to running inference with HoVerNet.

HoVerNet produces three output branches: a nucleus probability map, a class (type) probability map, and a two-channel map of the horizontal and vertical distances of each nuclear pixel from its nucleus centroid. These outputs need post-processing to convert the raw predictions into cleanly segmented and typed nuclei.
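
As a sketch of how such raw outputs can be turned into maps, the code below uses random toy logits in place of real predictions. It assumes the nucleus and type branches return unnormalised logits along the channel axis, so a softmax is applied first. The 0.5 threshold and the per-pixel `argmax` are simplifications; to our understanding, the MONAI post-processing assigns one type per nucleus instance rather than per pixel.

```python
import numpy as np


def softmax(logits, axis=0):
    """Numerically stable softmax along the channel axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


# Toy channel-first outputs for a single 4x4 tile (random values, not real predictions)
rng = np.random.default_rng(0)
np_logits = rng.normal(size=(2, 4, 4))  # nucleus branch: background / nucleus
tp_logits = rng.normal(size=(5, 4, 4))  # type branch: 5 classes, as in the model below

nucleus_prob = softmax(np_logits)[1]          # probability that each pixel is nucleus
nucleus_mask = nucleus_prob > 0.5             # hypothetical 0.5 threshold
type_map = softmax(tp_logits).argmax(axis=0)  # most likely class per pixel
type_map[~nucleus_mask] = 0                   # assume class 0 = background outside nuclei

print(nucleus_mask.astype(int))
print(type_map)
```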

![image](images/Post_Processing_Workflow.png)

MONAI contains all the capabilities needed for this post-processing, and it is accomplished with (guess what) *Transforms*!

In [None]:
log_dir = "outputs"

torch.cuda.set_device(0)
device = torch.device("cuda")

# Preprocessing transforms
pre_transforms = Compose(
    [
        CastToTyped(keys=["image"], dtype=torch.float32),
        ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
    ]
)

# Dataset of patches
dataset = PatchWSIDataset(
    data_list,
    patch_size=256,
    patch_level=0,
    include_label=False,
    transform=pre_transforms,
    reader="cuCIM",
    additional_meta_keys=["location"],
)
print("dataset created")

# Dataloader
data_loader = DataLoader(dataset, 
                        num_workers=2, 
                        batch_size=8, 
                        pin_memory=True,)
print("dataloader created")

To sanity-check the DataLoader, we can call MONAI's first() utility function, passing the DataLoader as the argument. This returns the first batch.
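
`first()` does not need to loop over the whole DataLoader. A sketch of equivalent behaviour (the name `first_item` is hypothetical) returns as soon as the iterable yields its first element, which the toy generator below confirms by recording how many items were produced.

```python
def first_item(iterable, default=None):
    """Return the first item of an iterable, or default if it is empty."""
    for item in iterable:
        return item
    return default


pulled = []


def numbered_batches(n):
    """Toy stand-in for a DataLoader: records each batch as it is produced."""
    for i in range(n):
        pulled.append(i)
        yield f"batch {i}"


print(first_item(numbered_batches(100)), pulled)
print(first_item([], default="empty"))
```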

In [None]:
first_sample = first(data_loader)

print("image: ")
print("    shape", first_sample["image"].shape)
print("    type: ", type(first_sample["image"]))
print("    dtype: ", first_sample["image"].dtype)
print("    location: ", first_sample["image"][0].meta["location"])
print(f"number of batches: {len(data_loader)}")

This should have produced a batch of 8 images, each with 3 channels and height and width of 256.

Now we can create the HoVerNet model. The parameters below must match the architecture of the pre-trained weights that we load (here, the "fast" mode with 5 output classes).

In [None]:
# Create model 
model = HoVerNet(
    mode="fast", 
    in_channels=3, 
    out_classes=5, 
    act=("relu", {"inplace": True}), 
    norm="batch", 
    dropout_prob=0.0,
)

# Specifies to use the first available GPU
device = torch.device("cuda:0")

#Load the pre-trained weights
model.load_state_dict(torch.load("data/fast.pt")['model'])
model = model.to(device)
model.eval()

inferer = SimpleInferer()

dataloader = DataLoader(dataset=dataset, batch_size=4, num_workers=8)
with torch.no_grad():
    for d in dataloader:
        pred = inferer(inputs=d["image"].to(device), network=model)
        # nucleus_prediction has shape (batch, 2, H, W); we plot channel 1 (nucleus)
        imgs = np.array(pred["nucleus_prediction"][:4].cpu())
        break

print(imgs[0, 1].shape)
fig, ax = plt.subplots(1, 4, figsize=(10, 10))
ax[0].imshow(imgs[0,1,:,:])
ax[1].imshow(imgs[1,1,:,:])
ax[2].imshow(imgs[2,1,:,:])
ax[3].imshow(imgs[3,1,:,:])


We now have some inference results from HoVerNet, but without any post-processing. Preprocessing has already been applied when the data was loaded, scaling the intensity to the range [0, 1].

One important thing to notice is that the output size from HoVerNet (164 x 164) is smaller than the input size (256 x 256), because the network's series of convolutions reduces the spatial size. To compensate, we ingest tiles with overlapping borders: (256 - 164) / 2 = 46 pixels on each side. A benefit of this approach is that tile-border artefacts are less pronounced.
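
The overlap arithmetic can be checked directly. This sketch assumes each patch location is the patch's top-left corner and uses the 256-pixel patches, 164-pixel stride and 2000-pixel region from the data list above. It confirms that the 164 x 164 output regions of neighbouring patches meet exactly.

```python
patch_size, output_size = 256, 164
margin = (patch_size - output_size) // 2  # pixels lost on each side of a patch
stride = output_size                      # step between patch locations in data_list
roi_size = 2000

# Patch offsets along one axis, relative to the ROI origin (as in range(x, x+2000, 164))
offsets = list(range(0, roi_size, stride))


def output_span(offset):
    """Pixels along one axis covered by a patch's central output region,
    assuming `offset` is the patch's top-left corner."""
    start = offset + margin
    return start, start + output_size


spans = [output_span(o) for o in offsets]
# Neighbouring output regions should meet exactly: no gaps, no double coverage
contiguous = all(a[1] == b[0] for a, b in zip(spans, spans[1:]))

print(margin, len(offsets), len(offsets) ** 2, spans[:2], contiguous)
```

Under that assumption there are 13 patches per axis (169 in total), and the stitched predictions cover ROI-relative pixels 46 to 2178 along each axis: the first 46 pixels of the region never receive a prediction, and the last patches extend beyond the 2000-pixel boundary.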

Let's now define some post-processing to clean up the inference predictions. 


## Set up a post-processing pipeline to convert the predictions into the desired outputs
The HoVerNet post-processing uses several transforms to turn pixel-level predictions into maps, contours and images that can be saved to disk:
1. `GenerateWatershedMaskd` creates a binary mask within which to compute the watershed
2. `GenerateInstanceBorderd` generates an instance border map from the horizontal and vertical (HoVer) distance maps
3. `GenerateDistanceMapd` generates a distance map from the watershed mask and the instance border, in which nucleus centres form peaks
4. `GenerateWatershedMarkersd` generates the markers used by the watershed algorithm
5. `Watershedd` uses the watershed algorithm to assign pixels to specific object instances
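
The key idea behind the border step (2) can be sketched with NumPy: inside a nucleus the horizontal and vertical distance values change smoothly, but where two nuclei touch they jump from about +1 to -1, so their gradient peaks at the boundary. This is a simplified illustration, not MONAI's implementation: `np.gradient` stands in for the Sobel filter (whose size, to our understanding, `sobel_kernel_size` controls in the pipeline below), and the toy maps describe two touching 8 x 8 nuclei.

```python
import numpy as np


def minmax(a, eps=1e-8):
    """Rescale to [0, 1]; a (near-)constant array maps to zeros."""
    rng = a.max() - a.min()
    return (a - a.min()) / rng if rng > eps else np.zeros_like(a)


def hv_border(h_map, v_map, fg_mask):
    """Simplified instance-border map from horizontal/vertical distance maps.

    np.gradient stands in for a Sobel filter; the larger of the two normalised
    gradient magnitudes is kept, restricted to the foreground mask.
    """
    grad_h = minmax(np.abs(np.gradient(h_map, axis=1)))
    grad_v = minmax(np.abs(np.gradient(v_map, axis=0)))
    return np.maximum(grad_h, grad_v) * fg_mask


# Two touching 8x8 nuclei side by side; each map holds the normalised distance
# from the owning nucleus' centre, in [-1, 1]
rows, cols = np.mgrid[0:8, 0:16].astype(float)
centre_col = np.where(cols < 8, 3.5, 11.5)
h_map = (cols - centre_col) / 3.5
v_map = (rows - 3.5) / 3.5
border = hv_border(h_map, v_map, fg_mask=np.ones_like(h_map))

print(np.round(border[0], 2))
```

Only the two columns where the nuclei meet (7 and 8) score close to 1, while the smooth interior scores close to 0. Mapping a constant gradient to zero is a choice made for this toy example.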

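We could implement these steps directly or, for an easier life, use the predefined HoVerNet post-processing transforms, HoVerNetInstanceMapPostProcessingd and HoVerNetNuclearTypePostProcessingd, which bundle this pipeline.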

In [None]:
# Postprocessing transforms
post_transforms = Compose(
    [
        HoVerNetInstanceMapPostProcessingd(sobel_kernel_size=21, marker_threshold=0.4, marker_radius=2),
        HoVerNetNuclearTypePostProcessingd(),
    ]
)

In [None]:
dataloader = DataLoader(dataset=dataset, batch_size=4, num_workers=8)
model.eval()
out=[]

with torch.no_grad():
    for d in dataloader:
        pred = inferer(inputs=d["image"].to(device), network=model)

        nu = np.array(pred["nucleus_prediction"].cpu())
        hv = np.array(pred["horizontal_vertical"].cpu())
        tp = np.array(pred["type_prediction"].cpu())
        
        for i in range(len(nu)):
            inputs =  {"nucleus_prediction": nu[i], "horizontal_vertical": hv[i], "type_prediction": tp[i]}
            out.append(post_transforms(inputs))  
                       
        break
        
# Plot out the cleaned up instance map
fig, ax = plt.subplots(1, 4, figsize=(10, 10))
ax[0].imshow(out[0]["instance_map"].squeeze())
ax[1].imshow(out[1]["instance_map"].squeeze())
ax[2].imshow(out[2]["instance_map"].squeeze())
ax[3].imshow(out[3]["instance_map"].squeeze())

In [None]:
# Plot out the type map (nucleus sub-types)
fig, ax = plt.subplots(1, 4, figsize=(10, 10))
ax[0].imshow(out[0]["type_map"].squeeze())
ax[1].imshow(out[1]["type_map"].squeeze())
ax[2].imshow(out[2]["type_map"].squeeze())
ax[3].imshow(out[3]["type_map"].squeeze())

If we look at the dictionary in each output row, we can see that the post-processing has created some additional information (the last three keys).

In [None]:
out[0].keys()

If we look in the instance_info item, we can find the contours and centroids of each nucleus instance found.

In [None]:
out[0]["instance_info"].keys()

This reveals that, for the first patch, there are 4 nucleus instances, with 1-based numeric keys. Looking at the first of these reveals the metadata that the post-processing has computed for each nucleus instance.

In [None]:
out[0]["instance_info"][1].keys()

For the next task, you are going to create a function that writes the centroid coordinates and type of each nucleus into a list of dictionaries, which we can then use to map the various nucleus types found in the 2k x 2k region. To do this you will need to examine the data that we looked at above and find the relevant items to record. ([solution](solutions/get_centroids.py))

In [None]:
# The offset of the current patch - use it to position the
# centroids relative to the Region of Interest offset
current_tile_offset = ()

# Region of Interest Offset
roi_y = 47000
roi_x = 25000

# create a function that will be added
# to the post-processing Transforms
def get_centroids(inst_info):
    
    centroids=[]
    #TODO
    
        
    return centroids
        
# Postprocessing transforms
post_transform_with_centroids = Compose(
    [
        HoVerNetInstanceMapPostProcessingd(sobel_kernel_size=21, marker_threshold=0.4, marker_radius=2),
        HoVerNetNuclearTypePostProcessingd(),
        #TODO use Lambdad to call the get_centroids function,
    ]
)

dataloader = DataLoader(dataset=dataset, batch_size=3, num_workers=8)
model.eval()
out=[]

with torch.no_grad():
    
    centroids = []
    for d in dataloader:
        # TODO get the offset for each image in the batch
        offsets = ...
        pred = inferer(inputs=d["image"].to(device), network=model)
        nu = np.array(pred["nucleus_prediction"].cpu())
        hv = np.array(pred["horizontal_vertical"].cpu())
        tp = np.array(pred["type_prediction"].cpu())
        
        for i in range(len(nu)):
            # TODO set the current offset
            current_tile_offset = ...
            raw_results =  {"nucleus_prediction": nu[i], "horizontal_vertical": hv[i], "type_prediction": tp[i]}
            # TODO apply the postprocessing transform to the raw_results
            # TODO Add the output to a list of centroids in dictionary format

print("Completed")

In [None]:
print(centroids[0:10])

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from cucim import CuImage

# load the image header
wsi = CuImage("data/tcga1.svs")

# Get the resolution meta data
sizes=wsi.metadata["cucim"]["resolutions"]

# Load the image data at this resolution
wsi_thumb = wsi.read_region(location=(47000, 25000), size=(2000,2000), level=0)

centres = np.zeros((len(centroids),3),dtype=int)
label = [""] * len(centroids)

for i, centre in enumerate(centroids):
    #invert row coordinate and swap x/y to match coordinates used in images
    centres[i] = [centre["y"], 2000-centre["x"], centre["type"]-1]  

# plot
cmap = ListedColormap(["blue", "gold", "lawngreen", "red"])
fig, ax = plt.subplots(1,2,figsize = (10,5))

ax[0].scatter(centres[:,0], centres[:,1], c=centres[:,2], cmap=cmap)
ax[0].set(xlim=(0, 2000), xticks=np.arange(0, 2000, 500),
       ylim=(0, 2000), yticks=np.arange(0, 2000, 500))

ax[1].imshow(wsi_thumb)

This sort of output, showing cell types and locations can provide useful clinical information (e.g. counts and densities of mitotic cells) and, as we will see in the next notebook, can also be used to provide raw data for other types of analysis.
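
As an example of turning the centroid list into such numbers, the sketch below counts nuclei per type and converts the counts to densities per mm². It assumes each entry has a "type" key, as in the plotting code above, and a level-0 resolution of 0.25 µm per pixel; that resolution is a hypothetical value for illustration and, in practice, should be read from the slide's metadata. The detections are made up.

```python
from collections import Counter


def counts_and_density(centroids, roi_size_px, microns_per_pixel):
    """Count nuclei per type and convert to nuclei per mm^2 for a square ROI.

    centroids: list of dicts with a "type" key.
    microns_per_pixel: level-0 resolution (an assumed input).
    """
    counts = Counter(c["type"] for c in centroids)
    side_mm = roi_size_px * microns_per_pixel / 1000.0
    area_mm2 = side_mm ** 2
    density = {t: n / area_mm2 for t, n in counts.items()}
    return counts, density


# Hypothetical detections, not real results
example = [{"x": 10, "y": 20, "type": 1}, {"x": 30, "y": 40, "type": 2}, {"x": 55, "y": 60, "type": 1}]
counts, density = counts_and_density(example, roi_size_px=2000, microns_per_pixel=0.25)
print(counts, density)
```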

There are other abstractions that make using HoVerNet even easier, such as Ignite-based Evaluators, which provide useful features such as event handling and early stopping.

In [None]:
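with tqdm(total=len(data_loader)) as pbar:

    # Class used for event handling: when the network's forward pass completes
    # for a batch, print a message and advance the progress bar by one step
    class TestEvalIterEvents:
        def attach(self, engine):
            engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)

        def _forward_completed(self, engine):
            print(f"Forward pass completed for iteration {engine.state.iteration}")
            pbar.update(1)

    # Define some Handlers
    inference_handlers = [
        TestEvalIterEvents(),
    ]

    model = model.to(device)

    # Use an Ignite-based Evaluator. We keep its default iteration function, which
    # passes each batch through the inferer and network and fires FORWARD_COMPLETED.
    # (Supplying iteration_update would replace that step, so the model would not run.)
    inference = SupervisedEvaluator(
        device=device,
        val_data_loader=data_loader,
        network=model,
        inferer=SimpleInferer(),
        val_handlers=inference_handlers,
        amp=True,
        decollate=False,  # keep the batched output as-is
    )

    # This handler can stop iteration during training or evaluation. The score
    # function is constant, so the run stops once 20 iterations pass without improvement
    EarlyStopHandler(
        patience=20, score_function=lambda x: 1.0, epoch_level=False, trainer=inference
    ).attach(inference)

    inference.run()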


You should see a message printed for each iteration as the evaluator runs.

MONAI has a lot more to explore, and its pathology capabilities are steadily growing, so you are encouraged to look at the documentation and tutorials to learn more.

One of the items that we did not cover in this notebook was the use of thresholding to eliminate regions of the image that contain no tissue. MONAI supports this, for example through MaskedPatchWSIDataset (imported at the top of this notebook), which extracts patches only from the tissue regions of a mask, but we have not used it in this tutorial. If you want to experiment further, you could adapt the thresholding done in the previous notebook so that it can be used with the inference examples from this notebook.