# Uncorking the I/O Bottleneck of Bio-Imaging

## Part 2 - Feeding your DL Pipelines with DASK

In the first section, you saw how we can speed up the loading and decoding of large images by using a high-performance image loader such as cuCIM. You also saw how we can use multi-threading and multiprocessing to reduce the latency of loading a large image. So long as the image format and loader support loading regions of interest, we can use different processes or threads to simultaneously get different parts of the image into memory.

Often, the image loading can be the bottleneck that slows the whole processing pipeline down. Of course it is often not just the loading that we need to do. There may be a need to do some preprocessing on the image and, for digital pathology, we may need to threshold each region to ensure that we are not wasting time processing empty or background regions of the Whole Slide. There may also be some image transformation or augmentation to do. All of these operations can become the part that slows everything else down and results in under-utilised GPUs, if not dealt with efficiently.

We are going to work through an example in which we use a well-known network architecture - the [variational autoencoder](https://en.wikipedia.org/wiki/Variational_autoencoder) (VAE) - to learn how to distill each patch into a compact representation and then reconstruct the input from this compact representation to test just how well the network is performing. This technique can be a really useful way of converting an unwieldy high-dimensional image into a more compact set of features, which are sufficient for certain analyses. Alternatively, we could simply downsample the image, but this process gets rid of a lot of important, fine-detailed information. A VAE, on the other hand, is trained to reconstruct its input, so it learns to retain the information that matters most for the reconstruction and to discard the rest (what counts as important depends on the loss function used).

We start off by importing the libraries we need.

In [None]:
import numpy as np
import random
import os
import matplotlib.pyplot as plt
from cucim import CuImage

Next we will import a couple of Dask-related dependencies and then create a local Dask 'cluster'. As you will see in the cell below we can specify a port with which to connect to a dashboard. We can also specify whether our cluster should use processes or threads (in the same way as they were used in the previous notebook) i.e. 'processes=True' means that we will be using multiprocessing.

In [None]:
from dask.distributed import as_completed
from dask.distributed import Client, LocalCluster

# Setup a local cluster.
cluster = LocalCluster(dashboard_address=":8789", processes=True)
client = Client(cluster)
client


You should see that once a client has been created for the Dask cluster we can output some information about it, such as the total number of threads and workers that are available to the cluster. We can also use the address of the dashboard to access the metrics that Dask's dashboard provides.

When you are working on a local machine you can just copy and paste the Dashboard URL shown above into your browser or into the dashboard's URL search box at the top.
However, because we are connected to a cloud instance, you will need to paste a slightly different URL. This is formed by copying the first part of the address that is showing up for this page in your browser and then appending a colon, the port (8789) and the suffix '/status'.
N.B. The URL may look a little different but you should take the part up to and including the .com

e.g. http://ec2-54-146-232-42.compute-1.amazonaws.com/lab/lab/tree/Notebook_2.ipynb

would be truncated to 'http://ec2-54-146-232-42.compute-1.amazonaws.com' and have ':8789/status' appended

to give: http://ec2-54-146-232-42.compute-1.amazonaws.com:8789/status
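
The same truncate-and-append steps can be sketched in Python with the standard library. The helper name `dashboard_url` is hypothetical (it is not part of Dask or Jupyter), and the port default simply mirrors the 8789 used when the cluster was created.

```python
from urllib.parse import urlsplit

def dashboard_url(notebook_url: str, port: int = 8789) -> str:
    """Build the Dask dashboard URL from the notebook's browser URL."""
    parts = urlsplit(notebook_url)
    # hostname excludes the path and any port already present in the notebook URL
    return f"{parts.scheme}://{parts.hostname}:{port}/status"

print(dashboard_url(
    "http://ec2-54-146-232-42.compute-1.amazonaws.com/lab/lab/tree/Notebook_2.ipynb"
))
# http://ec2-54-146-232-42.compute-1.amazonaws.com:8789/status
```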

If you paste your version of this into the Dask Dashboard URL box you should see the various dashboard options appear as orange rectangles. Each time you click on one, a new tab will be created with that specific metric displayed. You can drag and drop them into a panel on the right of this notebook. It is recommended to include the 'Progress' and 'Task Stream' items, which will show the progress of some later computations.

![Dask](images/dask.png)



As before, we can now load a thumbnail of the image to show what we will be working with. In this case we are working with an image from the [CAMELYON 17 challenge dataset](https://camelyon17.grand-challenge.org/data/). The dataset is open access, can be downloaded from a number of sources, and is made available under [CC0](https://creativecommons.org/publicdomain/zero/1.0).

The dataset contains whole-slide images (WSI) of hematoxylin and eosin (H&E) stained lymph node sections.

In [None]:
# Load the image header
input_file = "data/patient_100_node_0.tif"
wsi = CuImage(input_file)

# Extract the image metadata
sizes=wsi.metadata["cucim"]["resolutions"]
levels = sizes["level_count"]

# Dimensions of the whole Slide at full-resolution
w = sizes["level_dimensions"][0][0]
h = sizes["level_dimensions"][0][1]

# Dimensions of the Whole Slide at lowest resolution
wt = sizes["level_dimensions"][levels-1][0]
ht = sizes["level_dimensions"][levels-1][1]

# Load and decode the pixels at the lowest resolution
wsi_thumb = wsi.read_region(location=(0,0), size=(wt,ht), level=levels-1)

# Show the image
plt.figure(figsize=(10,10))
plt.imshow(wsi_thumb)
plt.title('patient_100_node_0.tif')
print("Width = {}, Height = {}".format(w, h))
plt.show()

So, you should notice that the majority of the image is actually empty background, which means that any processing we do on those regions is not going to provide any useful information. In addition, because of the size of the image (billions of pixels), we will need to split it into patches to ingest it into our VAE.

The traditional way of accomplishing this might be to pre-tile the images, saving their coordinates in some sort of metadata file and discarding the empty regions. There are many different algorithms to ascertain whether each patch might be useful foreground information or uninformative background, but for this exercise we will use the variance of the pixel intensities. If all the pixels are the same, then the variance will be 0 - regardless of the specific pixel intensity (for each color channel).
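
As a small illustration of the variance idea on synthetic data (not on the slide itself), the sketch below compares a flat grey patch, a flat coloured patch and a noisy patch. The patch contents are made up for illustration. Note one subtlety it exposes: a flat coloured patch has zero variance within each channel, but a non-zero variance when all channels are pooled together.

```python
import numpy as np

rng = np.random.default_rng(0)

# Flat grey patch: every value in every channel is identical.
flat_grey = np.full((64, 64, 3), 240, dtype=np.uint8)

# Flat coloured patch: each channel is constant, but the channels differ.
flat_colour = np.empty((64, 64, 3), dtype=np.uint8)
flat_colour[...] = (230, 180, 200)

# Noisy patch: random intensities stand in for textured tissue.
noisy = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)

# With no axis argument, var() pools every pixel and channel into one number.
print(flat_grey.var())                 # 0.0
# Per-channel variance reduces over the two spatial axes only.
print(flat_colour.var(axis=(0, 1)))    # [0. 0. 0.]
print(flat_colour.var() > 0, noisy.var() > flat_colour.var())  # True True
```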


### Task 1

Create a function that accepts a 3D array (Width, Height, Color Channels) and returns True if the variance of its values is above a fixed threshold and False if it is below. The skeleton is provided below. Remember that we have 2 dimensions and 3 channels and we need to compute the variance across all of these values ([solution](solutions/solution2_1.py)).

In [None]:
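# evaluates whether the block contains tissue to analyse
# TODO - replace None with a default threshold value (the test cell calls threshold(patch))
def threshold(arr, threshold_value=None):
    
    # TODO - check whether there is sufficient variance in the input
    pass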


To test this function we will load a few patches from the test image above and see whether they produce the correct result. You may need to change the default threshold value to get the required result.

In [None]:
patch1 = np.array(wsi.read_region(location=(41500,30000), size=(64,64), level=levels-1))
patch2 = np.array(wsi.read_region(location=(41500,2000), size=(64,64), level=levels-1))
patch3 = np.array(wsi.read_region(location=(80000,78000), size=(64,64), level=levels-1))

patch1_result = threshold(patch1)
patch2_result = threshold(patch2)
patch3_result = threshold(patch3)

print("patch1 > threshold is {}".format(patch1_result))
print("patch2 > threshold is {}".format(patch2_result))
print("patch3 > threshold is {}".format(patch3_result))
      
fig, ax = plt.subplots(1,3,figsize = (10,10))
ax[0].imshow(patch1)
ax[1].imshow(patch2)
ax[2].imshow(patch3)

# patches 1 and 2 should be above the threshold, patch 3 below it
if patch1_result and patch2_result and not patch3_result:
    print("Correct!")
else:
    print("Not quite...")

So, hopefully, we now have a basic working threshold function. The next task is to create a pipeline to look at the whole slide and emit a list of all the tiles that are above the threshold so that we can do something with them.

In this next step we are going to introduce [DASK](https://docs.dask.org/en/stable/), which is a very useful tool for breaking large tasks into lots of smaller chunks to reduce overall latency. This may sound quite familiar - because that is precisely what we were doing with Python's multiprocessing and multithreading tools in the previous exercises. DASK provides features that resemble some of these functions, but it also provides a swathe of other benefits including:

* A rich set of visualization tools to monitor the status of your running code
* Integrations and compatibility with many other tools from the Data Science ecosystem
* Abstractions that provide powerful but simple to use concurrency
 

When it comes to concurrency, DASK provides two main tools - [Futures](https://docs.dask.org/en/stable/futures.html) and [Delayed](https://docs.dask.org/en/stable/delayed.html) functions.

Futures are used to asynchronously process results, with the results becoming available when the computation has completed. Delayed functions are used to 'lazily' compute values, as the results of prior computations or inputs become available.

Let's look at a toy example. Imagine that we want to sum a series of integers. Naively, you'd iterate over the elements one at a time, adding each to a running total, so the run time is proportional to the number of elements. A better way is to add every other element to its neighbour, then repeat on the results until only one element is left. The total number of additions is the same, but the additions within each round are independent, so with enough workers the number of sequential steps falls to about log2(N). By providing a few basic commands you can let Dask figure out the execution graph for you. Let's look at a concrete example.
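
Before bringing Dask in, the pairwise scheme can be sketched in plain Python to count how many rounds it needs. This is only an illustration: the function name `pairwise_sum` is hypothetical, and the additions within a round run one after another here, whereas a parallel scheduler could run them at the same time.

```python
def pairwise_sum(values):
    """Sum a sequence by repeatedly adding neighbouring pairs.

    Returns (total, rounds), where rounds is the number of sequential steps
    needed if every addition within a round runs in parallel.
    """
    level = list(values)
    if not level:
        return 0, 0
    rounds = 0
    while len(level) > 1:
        # Every addition in this round is independent of the others.
        nxt = [level[i] + level[i + 1] for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:
            # An odd element out is carried unchanged into the next round.
            nxt.append(level[-1])
        level = nxt
        rounds += 1
    return level[0], rounds

total, rounds = pairwise_sum(range(1, 17))
print(total, rounds)  # 136 4
```

Sixteen inputs need 15 additions either way, but only 4 rounds (16, 8, 4, 2, 1 values) rather than 15 sequential steps.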

We can write the code to do the adding for us using a Dask Delayed function. This means that before the result is calculated a graph is constructed and Dask will map this graph onto the available compute (e.g. Processes, Threads or GPUs)


In [None]:
import dask
from dask import delayed

@dask.delayed
def add(x, y):
    return x + y

a = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
b = []
c = []
d = []

for i in range(0,16,2):
    b.append(add(a[i],a[i+1]))
    
for i in range(0,8,2):
    c.append(add(b[i],b[i+1]))
    
for i in range(0,4,2):
    d.append(add(c[i],c[i+1]))
    
e=add(d[0],d[1])
    
e.visualize()

At this point, no computation has been done - just the graph construction. By doing this up-front, a more efficient graph can be created. You can see that the graph shows how the additions at each phase can be done in parallel, but also how each subsequent addition depends only on its ancestors. To actually do the computation, we need to execute a compute() command.

In [None]:
e.compute()

We can now apply the same technique to the task of loading the Whole Slide Image. In this example we will use Dask Futures. To do this we create batches of inputs to process and let Dask map them to available compute resources using the map function. Because this executes asynchronously, we need to wait until each individual task has completed before handing the result on for further processing. In this case, we will use the as_completed function, which yields results as they come in. 
This sort of dynamic execution allows for more efficient concurrent execution, which can significantly reduce overall latency.
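
Dask's futures interface extends Python's standard `concurrent.futures`, so the submit-then-collect pattern can be sketched with the standard library alone. The following is a minimal stand-in, not Dask code: `slow_square` and its sleep times are made up for illustration, and a `ThreadPoolExecutor` takes the place of the Dask `Client`.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def slow_square(x):
    """Stand-in for an expensive task; larger inputs sleep for less time."""
    time.sleep(0.05 * (4 - x))
    return x * x

with ThreadPoolExecutor(max_workers=4) as pool:
    # Submitting returns a future immediately; the work runs in the background.
    futures = [pool.submit(slow_square, x) for x in range(4)]
    # as_completed yields each future as it finishes, which need not be
    # the order in which the tasks were submitted.
    results = [f.result() for f in as_completed(futures)]

print(sorted(results))  # [0, 1, 4, 9]
```

Because results are consumed as they arrive, a slow task does not hold up the handling of tasks that have already finished.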

Normally, when processing a Whole Slide Image, you'd probably threshold at a reduced resolution. In this case we are going to do it at full resolution, but in 2 steps. First, we will threshold the whole image in 256 x 256 patches and add the coordinates of the patches that are above the threshold to a list. When this stage has completed, we will do a second thresholding pass, this time on the 64x64 tiles within each of those above-threshold patches.

At the end of the process we will have a single list containing all the 64x64 tiles that are of interest for further analysis.

In [None]:
patch_size = 256
tile_size = 64

# iterate over a set of regions from which to threshold
def process_patch(start_loc_list):
    ps = patch_size
    slide = CuImage(input_file)
    res = []
    for start_loc in start_loc_list:
        # you can usually do thresholding at a higher reduction factor
        level = 0
        region = np.array(slide.read_region(start_loc, [ps , ps], level))
        if threshold(region):
            res.append((start_loc[0], start_loc[1]))
        
    return res

# As the results are processed, put them into a list
def compile_results(futures):
    patches = []

    for future in as_completed(futures):
        res1 = future.result()
        if res1:
            for patch in res1:
                patches.append(patch)
                
    return patches

When we execute the cell below it will run through each patch, and evaluate whether it is above the threshold we defined earlier. To reduce latency we let Dask map the various tiles to evaluate to the available workers (which might be threads, processes or GPU nodes).

Notice that the dashboard will show the status of each task and the state of each worker as the computation unfolds.

In [None]:
%%time
patch_size = 256
num_chunks = 64

start_loc_data = [(sx, sy)
                  for sy in range(0, h, patch_size)
                      for sx in range(0, w, patch_size)]

chunk_size = len(start_loc_data) // num_chunks

start_loc_list = [start_loc_data[i:i+chunk_size]  for i in range(0, len(start_loc_data), chunk_size)]
future_result1 = list(client.map(process_patch, start_loc_list))
patches = compile_results(future_result1)
                 
print("Number of 256 x 256 patches found = {}".format(len(patches)))

So, we just thresholded the whole WSI at full resolution!
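
The chunking step used in the cell above can be examined on a toy grid. The sketch below is one possible variant, not the notebook's code: the helper `chunk` is hypothetical, and it uses ceiling division, which keeps the number of chunks at or below `num_chunks` and avoids a zero step size when there are fewer items than chunks (floor division, as used above, can produce a few more chunks than requested).

```python
import math

def chunk(items, num_chunks):
    """Split items into at most num_chunks contiguous chunks of near-equal size."""
    size = max(1, math.ceil(len(items) / num_chunks))
    return [items[i:i + size] for i in range(0, len(items), size)]

# Toy 1024 x 768 "slide" with 256-pixel patches gives a 4 x 3 grid of origins.
origins = [(sx, sy)
           for sy in range(0, 768, 256)
               for sx in range(0, 1024, 256)]

chunks = chunk(origins, 5)
print(len(origins), [len(c) for c in chunks])  # 12 [3, 3, 3, 3]
```

Each chunk becomes one task, so fewer, larger chunks reduce scheduling overhead while more, smaller chunks balance the load across workers more evenly.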

### Task 2

For a finer-grained threshold on each 64x64 tile within these patches, create a second phase of computation that will process the patches and output a similar list of (64x64) tiles. You should use the previous patch stage as a reference ([solution](solutions/solution2_2.py)).

In [None]:
%%time

# TODO iterate over a set of patches from which to threshold
def process_tile(...):
    pass

# TODO create a list of tiles to threshold from each patch returned

# TODO map the tiles to threshold to the process tile function

# TODO process the results
tiles = compile_results(...)

print("Number of 64 x 64 tiles found = {}".format(len(tiles)))

You should see about 408,915 tiles (depending on how you set the threshold value)

Now we can load a few random tiles to check that they look okay. Hopefully they all contain some tissue!

In [None]:
fig, ax = plt.subplots(1,3,figsize = (15,15))
i = random.randint(0,len(tiles)-1)
ax[0].imshow(tiles[i][2])
i = random.randint(0,len(tiles)-1)
ax[1].imshow(tiles[i][2])
i = random.randint(0,len(tiles)-1)
ax[2].imshow(tiles[i][2])

One common technique for combatting the huge amount of data we need to ingest for a single whole slide image is to reduce each tile down to a much smaller set of features that capture the essence of the tile. We could use some statistical measures such as mean pixel intensity, variance etc. The problem is that these features may not capture the unique nature of the images we are using here. 

Instead, what we will do is to use a Variational AutoEncoder that has been trained in an unsupervised way to encode and then decode each tile, using a loss function that compares the original image to the decoded version. Once trained, we then just retain the encoder part of the network, which reduces our inputs down to just 32 features (aka latent variables). These features play a role loosely analogous to the principal components of the tiles.
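
In symbols, and as a sketch rather than a full derivation, the model code later in this notebook implements two ingredients from Kingma and Welling (arXiv:1312.6114, the paper cited in that code). The encoder outputs a mean $\mu$ and a log-variance $\log\sigma^{2}$ for each of the 32 latent variables, and a latent sample is drawn as:

$$
z = \mu + \sigma \odot \epsilon, \qquad \sigma = \exp\!\left(\tfrac{1}{2}\log\sigma^{2}\right), \qquad \epsilon \sim \mathcal{N}(0, I)
$$

The loss adds a reconstruction term (binary cross-entropy between the input $x$ and its reconstruction $\hat{x}$) to a KL-divergence term that pulls the latent distribution towards a standard normal:

$$
\mathcal{L} = \mathrm{BCE}(\hat{x}, x) + D_{\mathrm{KL}}, \qquad D_{\mathrm{KL}} = -\tfrac{1}{2}\sum_{j=1}^{32}\left(1 + \log\sigma_{j}^{2} - \mu_{j}^{2} - \sigma_{j}^{2}\right)
$$

The code additionally sums over the batch and divides the KL term by a fixed constant, which changes the relative weight of the two terms but not their form.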

This can make analysis of a Whole Slide a more tractable problem because we can hugely reduce the amount of information to process. By removing the background data and then converting each 64x64 tile into a 32x1 tensor we have reduced the information on this slide by a factor of about 12,000. That's quite a significant reduction but we are still left with 11,848,000 data points.
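
The per-tile part of that reduction is simple arithmetic, sketched below using only figures quoted in the text. Encoding alone accounts for a factor of 384 per tile; the rest of the overall factor comes from discarding background tiles, and so depends on the slide and on the threshold chosen.

```python
tile_side, channels, latent_dim = 64, 3, 32

# Raw values in one 64x64 RGB tile, and the reduction from encoding it.
values_per_tile = tile_side * tile_side * channels
per_tile_factor = values_per_tile // latent_dim

# 11,848,000 data points at 32 per tile implies this many retained tiles.
data_points = 11_848_000
tiles_implied = data_points // latent_dim

print(values_per_tile, per_tile_factor, tiles_implied)  # 12288 384 370250
```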

In this instance we are not going to train the VAE because it would take too long but you can examine the PyTorch model code and we will run it to do inference on a few tiles, so that you can see what its output looks like.
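
When reading the model code, it helps to check why `h_dim` is 1024 for 64x64 tiles. The sketch below applies the standard output-size formulas for unpadded convolutions and transposed convolutions to the kernel sizes and strides used in the model. The helper names `conv_out` and `deconv_out` are hypothetical and are not PyTorch functions.

```python
def conv_out(n, kernel, stride):
    """Output side length of a convolution with no padding or dilation."""
    return (n - kernel) // stride + 1

def deconv_out(n, kernel, stride):
    """Output side length of a transposed convolution with no padding."""
    return (n - 1) * stride + kernel

# Encoder: four Conv2d layers, each with kernel_size=4 and stride=2.
side = 64
encoder_sides = []
for _ in range(4):
    side = conv_out(side, kernel=4, stride=2)
    encoder_sides.append(side)

# The last encoder layer has 256 channels, flattened into one vector.
flattened = 256 * side * side

# Decoder: starts from a 1x1 map; kernel sizes 5, 5, 6, 6, all with stride=2.
side = 1
decoder_sides = []
for k in (5, 5, 6, 6):
    side = deconv_out(side, kernel=k, stride=2)
    decoder_sides.append(side)

print(encoder_sides, flattened, decoder_sides)
# [31, 14, 6, 2] 1024 [5, 13, 30, 64]
```

The flattened encoder output has 1024 values, matching `h_dim`, and the decoder arrives back at the 64x64 tile size.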

![VAE](images/VAE.png)

In [None]:
# Execute this cell to build the VAE model in PyTorch
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class UnFlatten(nn.Module):
    def forward(self, input, size=1024):
        return input.view(input.size(0), size, 1, 1)
    
class VAE2(nn.Module):
    def __init__(self, image_channels=3, h_dim=1024, z_dim=32):
        super(VAE2, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),
            nn.ReLU(),
            Flatten()
        )
        
        self.fc1 = nn.Linear(h_dim, z_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(z_dim, h_dim)
        
        self.decoder = nn.Sequential(
            UnFlatten(),
            nn.ConvTranspose2d(h_dim, 128, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),
            nn.Sigmoid(),
        )
        
    def reparameterize(self, mu, logvar):
        # Sample z = mu + std * eps with eps ~ N(0, I).
        # randn_like keeps eps on the same device and dtype as std.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def bottleneck(self, h):
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def encode(self, x):
        h = self.encoder(x)
        z, mu, logvar = self.bottleneck(h)
        return z, mu, logvar

    def decode(self, z):
        z = self.fc3(z)
        #print(z.shape)
        z = self.decoder(z)
        return z

    def forward(self, x):
        z, mu, logvar = self.encode(x)
        z = self.decode(z)
        return z, mu, logvar
    
model2 = VAE2(image_channels=3)
model2 = torch.nn.DataParallel(model2, device_ids=[0])
model2.cuda()

# Sum (rather than average) the per-pixel reconstruction error
reconstruction_function = nn.BCELoss(reduction='sum')

def loss_function2(recon_x, x, mu, logvar):
    BCE = reconstruction_function(recon_x, x)

    # https://arxiv.org/abs/1312.6114 (Appendix B)
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Fixed scaling constant for the KL term
    KLD /= 1600 * 32 * 32
    return BCE + KLD

optimizer2 = optim.Adam(model2.parameters(), lr=1e-3)


Next we can test the trained VAE with a few samples. To do this we can load the weights from a previously trained model which were saved into the 'M2_WEIGHTS.PT' file.
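
Before a tile reaches the network, the cell below moves its colour channels to the front and scales the values to the range 0 to 1. A minimal sketch of that conversion on a synthetic tile (the tile contents are made up for illustration):

```python
import numpy as np

# Synthetic RGB tile in height x width x channel layout, pure red.
tile_hwc = np.zeros((64, 64, 3), dtype=np.uint8)
tile_hwc[..., 0] = 255

# Channels-first layout, as expected by PyTorch convolution layers.
tile_chw = np.moveaxis(tile_hwc, 2, 0)

# Scale to [0, 1] to match the range of the decoder's Sigmoid output.
scaled = tile_chw.astype(np.float32) / 255

print(tile_chw.shape, scaled[0].min(), scaled[1].max())  # (3, 64, 64) 1.0 0.0
```

Moving the axis back with `np.moveaxis(tile_chw, 0, 2)` recovers the original layout, which is what the plotting code does before calling `imshow`.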

In [None]:
model2 = VAE2(image_channels=3)
model2 = torch.nn.DataParallel(model2, device_ids=[0])
dct = torch.load('data/M2_WEIGHTS.PT')
model2.load_state_dict(dct['state_dict'])
model2.cuda()

model2.eval()

bs = len(tiles) // 128

batch_list = [i for i in range(len(tiles))]
random.shuffle(batch_list)
batch=np.zeros((512,3,tile_size,tile_size),np.uint8)

if len(batch_list)>=512:

    for p in range(512):
        j = batch_list.pop()
        tile = np.moveaxis(tiles[j][2],2,0)
        batch[p]=tile
        

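    # Convert to a float tensor on the GPU and scale to [0, 1]
    tiles_cuda = torch.from_numpy(batch).float().cuda()
    tiles_cuda = tiles_cuda / 255

    # Inference only, so skip gradient tracking
    with torch.no_grad():
        recon_batch, _, _ = model2(tiles_cuda)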

    tensor1 = recon_batch.cpu().detach().numpy() * 255
    tensor1 = tensor1.astype(np.uint8)
    tensor2 = tiles_cuda.cpu().numpy() * 255
    tensor2 = tensor2.astype(np.uint8)
    
    print("Real Inputs")
    fig, ax = plt.subplots(1,4,figsize = (10,10))
    ax[0].imshow(np.moveaxis(tensor2[20,:,:,:],0,2))
    ax[1].imshow(np.moveaxis(tensor2[25,:,:,:],0,2))
    ax[2].imshow(np.moveaxis(tensor2[120,:,:,:],0,2))
    ax[3].imshow(np.moveaxis(tensor2[127,:,:,:],0,2))
    plt.show()
    
    print("Regenerated")
    fig, ax = plt.subplots(1,4,figsize = (10,10))
    ax[0].imshow(np.moveaxis(tensor1[20,:,:,:],0,2))
    ax[1].imshow(np.moveaxis(tensor1[25,:,:,:],0,2))
    ax[2].imshow(np.moveaxis(tensor1[120,:,:,:],0,2))
    ax[3].imshow(np.moveaxis(tensor1[127,:,:,:],0,2))

    plt.show()

So, with only 32 values the VAE can do a reasonable job of representing each tile. This VAE has only been trained on a couple of WSIs but with more training (and more latent variables) you could get a much clearer rendering of the original. For this exercise we are not going to worry too much about the quality of the VAE.

For now, a Pandas DataFrame has already been saved with the location of every tile and the 32 latent variables computed by the VAE for each of those tiles. In the next notebook we will load up this DataFrame and see how it can be used for a variety of analyses.