# Automated Gleason Grading

## Imports

In [16]:
import os

import openslide
import random
import seaborn as sns
import cv2

# Torch packages
import torch
from torch.utils.data import Dataset
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F

# General packages
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import PIL
from IPython.display import Image, display

# Plotly for the interactive viewer (see last section)
import plotly.graph_objs as go

## Load Dataset

In [17]:
# Root folder of the PANDA training data.
DATA_PATH = '../../ganz/data/panda_dataset'

# Whole-slide images and their label masks live in sibling sub-folders.
data_dir, mask_dir = (f'{DATA_PATH}/train_images',
                      f'{DATA_PATH}/train_label_masks')

# Label tables. `train` is indexed by image_id so rows can be looked up
# with train.loc[image_id, column].
train = pd.read_csv(f'{DATA_PATH}/train.csv').set_index('image_id')
test = pd.read_csv(f'{DATA_PATH}/test.csv')
submission = pd.read_csv(f'{DATA_PATH}/sample_submission.csv')

In [25]:
# Train/evaluate on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)
# # determine if we will be pinning memory during data loading
# PIN_MEMORY = True if DEVICE == "cuda" else False

cuda


## Visualize Data

### Sample Images

In [19]:
# Hand-picked sample slide ids (15 of them, filling a 5x3 display grid).
images = """
    07a7ef0ba3bb0d6564a73f4f3e1c2293
    037504061b9fba71ef6e24c48c6df44d
    035b1edd3d1aeeffc77ce5d248a01a53
    059cbf902c5e42972587c8d17d49efed
    06a0cbd8fd6320ef1aa6f19342af2e68
    06eda4a6faca84e84a781fee2d5f47e1
    0a4b7a7499ed55c71033cefb0765e93d
    0838c82917cd9af681df249264d2769c
    046b35ae95374bfb48cdca8d7c83233f
    074c3e01525681a275a42282cd21cbde
    05abe25c883d508ecc15b6e857e59f32
    05f4e9415af9fdabc19109c980daf5ad
    060121a06476ef401d8a21d6567dee6d
    068b0e3be4c35ea983f77accf8351cc8
    08f055372c7b8a7e1df97c6586542ac8
""".split()

### Display Images

In [20]:
def display_images(slides, ncols=3):
    """Show a fixed 256x256 level-0 patch from each slide, titled with its labels.

    Args:
        slides: iterable of image ids. Each id needs ``{id}.tiff`` under
            ``data_dir`` and a row in ``train`` (used for the title).
        ncols: number of grid columns. Rows are added as needed, so any
            number of slides fits; 15 slides reproduces the original 5x3
            grid at 18x22 inches.
    """
    from contextlib import closing

    slides = list(slides)
    # At least one row so plt.subplots gets a valid shape for empty input.
    nrows = max(1, -(-len(slides) // ncols))
    # squeeze=False keeps `ax` 2-D even when there is only one row.
    fig, ax = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4.4 * nrows),
                           squeeze=False)
    for i, slide in enumerate(slides):
        axis = ax[i // ncols, i % ncols]
        # closing() releases the slide handle even if reading/plotting fails.
        with closing(openslide.OpenSlide(os.path.join(data_dir, f'{slide}.tiff'))) as image:
            # Fixed level-0 offset; it may land on background for some slides.
            patch = image.read_region((1780, 1950), 0, (256, 256))
            axis.imshow(patch)
        axis.axis('off')

        data_provider = train.loc[slide, 'data_provider']
        isup_grade = train.loc[slide, 'isup_grade']
        gleason_score = train.loc[slide, 'gleason_score']
        axis.set_title(f"ID: {slide}\nSource: {data_provider} "
                       f"ISUP: {isup_grade} Gleason: {gleason_score}")

    # Hide grid cells that did not receive a slide.
    for axis in ax.flat[len(slides):]:
        axis.axis('off')
    fig.tight_layout()
    plt.show()

In [None]:
# Render one sample patch per slide in the `images` list.
display_images(slides=images)

### Visualize Masks

In [21]:
def display_masks(slides, ncols=3):
    """Show the lowest-resolution label mask of each slide, titled with its labels.

    Args:
        slides: iterable of image ids. Each id needs ``{id}_mask.tiff`` under
            ``mask_dir`` and a row in ``train`` (used for the title).
        ncols: number of grid columns. Rows are added as needed, so any
            number of slides fits; 15 slides reproduces the original 5x3
            grid at 18x22 inches.
    """
    from contextlib import closing

    slides = list(slides)
    # At least one row so plt.subplots gets a valid shape for empty input.
    nrows = max(1, -(-len(slides) // ncols))
    # squeeze=False keeps `ax` 2-D even when there is only one row.
    fig, ax = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4.4 * nrows),
                           squeeze=False)
    # One colour per mask value 0..5; vmin/vmax below pin value k to colour k.
    # Built once instead of once per slide.
    cmap = matplotlib.colors.ListedColormap(['black', 'gray', 'green', 'yellow', 'orange', 'red'])
    for i, slide in enumerate(slides):
        axis = ax[i // ncols, i % ncols]
        # closing() releases the mask handle even if reading fails.
        with closing(openslide.OpenSlide(os.path.join(mask_dir, f'{slide}_mask.tiff'))) as mask:
            # Read the whole image at the last (smallest) pyramid level.
            mask_data = mask.read_region((0, 0), mask.level_count - 1, mask.level_dimensions[-1])
            # The mask value is taken from channel 0 of the multi-channel image.
            labels = np.asarray(mask_data)[:, :, 0]
        axis.imshow(labels, cmap=cmap, interpolation='nearest', vmin=0, vmax=5)
        axis.axis('off')

        data_provider = train.loc[slide, 'data_provider']
        isup_grade = train.loc[slide, 'isup_grade']
        gleason_score = train.loc[slide, 'gleason_score']
        axis.set_title(f"ID: {slide}\nSource: {data_provider} "
                       f"ISUP: {isup_grade} Gleason: {gleason_score}")

    # Hide grid cells that did not receive a slide.
    for axis in ax.flat[len(slides):]:
        axis.axis('off')
    # Lay out once, after every title exists (was re-run on each iteration).
    fig.tight_layout()
    plt.show()

In [None]:
# Render the label mask of every slide in the `images` list.
display_masks(slides=images)

## Patch Segmentation

### Initialize Globals

In [22]:
# Model shape: input channels, output classes, and number of U-Net levels.
# NOTE(review): slides are read via OpenSlide.read_region (multi-channel),
# UNet defaults to 3 input channels, and the mask colormap covers values
# 0-5 -- confirm NUM_CHANNELS / NUM_CLASSES before wiring them into a model.
NUM_CHANNELS, NUM_CLASSES, NUM_LEVELS = 1, 1, 3

# Optimiser / training-loop settings.
INIT_LR = 0.001
NUM_EPOCHS = 40
BATCH_SIZE = 64

# Side length in pixels of the square patches cut from each slide
# (despite the BATCH_ prefix, these are patch dimensions, not batch sizes).
BATCH_WIDTH = BATCH_HEIGHT = 256

# Threshold used to filter weak predictions.
THRESHOLD = 0.5

# Output artefacts, all under BASE_OUTPUT.
# NOTE(review): the model filename looks inherited from another project.
BASE_OUTPUT = "../output"
MODEL_PATH = f"{BASE_OUTPUT}/unet_tgs_salt.pth"
PLOT_PATH = f"{BASE_OUTPUT}/plot.png"
TEST_PATHS = f"{BASE_OUTPUT}/test_paths.txt"

### Create Training Set

In [23]:
def create_training_set(slides):
    """Tile each slide and its label mask into non-overlapping level-0 patches.

    Patches are BATCH_WIDTH x BATCH_HEIGHT pixels on a regular grid starting
    at the top-left corner. A partial strip at the right/bottom edge is
    dropped. Patches are emitted column by column (x outer, y inner). The
    grid is computed from the image dimensions and applied to the mask.

    Args:
        slides: iterable of image ids with ``{id}.tiff`` in ``data_dir`` and
            ``{id}_mask.tiff`` in ``mask_dir``.

    Returns:
        (train_imgs, train_masks): two equal-length lists of the images
        returned by ``OpenSlide.read_region``; entry k of both lists covers
        the same region.

    NOTE(review): every patch of every slide is held in memory at once;
    this can get very large for full-resolution slides.
    """
    from contextlib import closing

    train_imgs = []
    train_masks = []
    patch_size = (BATCH_WIDTH, BATCH_HEIGHT)
    for slide in slides:
        # closing() guarantees both handles are released (they previously leaked).
        with closing(openslide.OpenSlide(os.path.join(data_dir, f'{slide}.tiff'))) as image, \
                closing(openslide.OpenSlide(os.path.join(mask_dir, f'{slide}_mask.tiff'))) as mask:
            width, height = image.dimensions
            # Upper bounds are rounded down to a whole number of patches.
            for x in range(0, width - width % BATCH_WIDTH, BATCH_WIDTH):
                for y in range(0, height - height % BATCH_HEIGHT, BATCH_HEIGHT):
                    train_imgs.append(image.read_region((x, y), 0, patch_size))
                    train_masks.append(mask.read_region((x, y), 0, patch_size))
    return train_imgs, train_masks

In [24]:
# Smoke-run the tiler on a single sample slide (the second id in `images`).
create_training_set(slides=images[1:2])

img level 0 dimension x: 10496
img level 0 dimension y: 24832
mask level 0 dimension x: 10496
mask level 0 dimension y: 24832


([], [])

### Dataset Class

In [None]:
class SegmentationDatasetv0(Dataset):
    """Map-style dataset over pre-loaded, index-aligned image/mask pairs.

    Args:
        img_list: sequence of input patches.
        mask_list: sequence of masks; ``mask_list[i]`` belongs to ``img_list[i]``.

    Raises:
        ValueError: if the two sequences differ in length, which would
            otherwise surface later as an IndexError or silently dropped masks.
    """

    def __init__(self, img_list, mask_list):
        if len(img_list) != len(mask_list):
            raise ValueError(
                f"img_list and mask_list differ in length "
                f"({len(img_list)} != {len(mask_list)})")
        self.img_list = img_list
        self.mask_list = mask_list

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.img_list)

    def __getitem__(self, idx):
        # Return the (image, mask) pair at `idx`; previously this returned an
        # unrelated/undefined name `image` instead of the indexed image.
        img = self.img_list[idx]
        mask = self.mask_list[idx]
        return (img, mask)

### U-Net Modules

In [None]:
class Block(Module):
    """Two unpadded 3x3 convolutions with a ReLU in between.

    With Conv2d's default padding of 0, each convolution trims one pixel
    from every border, so an HxW input yields (H-4)x(W-4). No activation
    follows the second convolution.
    """

    def __init__(self, inChannels, outChannels):
        super().__init__()
        # Attribute names are the state_dict keys; keep them stable.
        self.conv1 = Conv2d(inChannels, outChannels, 3)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, 3)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)

In [None]:
class Encoder(Module):
    """Contracting path: a stack of Blocks, each followed by 2x2 max-pooling.

    ``forward`` returns every Block's output (taken before its pooling),
    ordered from the first Block to the last.
    """

    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # One Block per consecutive pair of channel counts.
        self.encBlocks = ModuleList(
            [Block(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])])
        self.pool = MaxPool2d(2)

    def forward(self, x):
        features = []
        for enc_block in self.encBlocks:
            x = enc_block(x)
            features.append(x)
            # Pool before the next Block (the final pooled tensor is unused).
            x = self.pool(x)
        return features

In [None]:
class Decoder(Module):
    """Expanding path of the U-Net.

    Each stage upsamples 2x with a stride-2 transposed convolution,
    concatenates the matching encoder feature map (center-cropped to the
    upsampled size) along dim 1, and runs a Block over the result.
    """

    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        self.channels = channels
        pairs = list(zip(channels[:-1], channels[1:]))
        # Upsamplers first, then blocks, matching the original registration order.
        self.upconvs = ModuleList(
            [ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in pairs])
        self.dec_blocks = ModuleList(
            [Block(c_in, c_out) for c_in, c_out in pairs])

    def forward(self, x, encFeatures):
        # encFeatures[stage] is the skip connection for decoder stage `stage`.
        for stage, (upconv, dec_block) in enumerate(zip(self.upconvs, self.dec_blocks)):
            x = upconv(x)
            skip = self.crop(encFeatures[stage], x)
            x = dec_block(torch.cat([x, skip], dim=1))
        return x

    def crop(self, encFeatures, x):
        # Center-crop the encoder map to x's spatial size (last two dims).
        _, _, height, width = x.shape
        return CenterCrop([height, width])(encFeatures)

In [None]:
class UNet(Module):
    """Encoder/decoder segmentation network with a 1x1 convolution head.

    Args:
        encChannels: encoder channel counts; the first entry is the number
            of input channels.
        decChannels: decoder channel counts; the first entry must equal
            ``encChannels[-1]`` because the deepest encoder output is the
            decoder's input.
        nbClasses: number of output channels produced by the head.
        retainDim: if True, resize the output map to ``outSize``.
        outSize: (height, width) of the output when ``retainDim`` is True.
    """

    def __init__(self, encChannels=(3, 16, 32, 64),
                 decChannels=(64, 32, 16),
                 nbClasses=1, retainDim=True,
                 outSize=(BATCH_HEIGHT, BATCH_WIDTH)):
        super().__init__()
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)
        # 1x1 conv mapping decoder features to nbClasses output channels.
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize

    # Indentation fixed: this method was indented with 2 spaces inside a
    # tab-indented class, which made the whole class fail to parse.
    def forward(self, x):
        encFeatures = self.encoder(x)
        # Deepest feature map feeds the decoder; the remaining maps, deepest
        # first, serve as skip connections.
        reversed_features = encFeatures[::-1]
        decFeatures = self.decoder(reversed_features[0], reversed_features[1:])
        # Named seg_map to avoid shadowing the builtin `map`.
        seg_map = self.head(decFeatures)
        # The unpadded convolutions shrink the map; optionally resize it back.
        if self.retainDim:
            seg_map = F.interpolate(seg_map, self.outSize)
        return seg_map