# Automated Gleason Grading

## Imports

In [1]:
import os

import openslide
import seaborn as sns
import cv2

# Torch packages
import torch
from torch.utils.data import Dataset
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms

# General packages
import pandas as pd
import numpy as np
import numpy.random as random
import matplotlib
import matplotlib.pyplot as plt
import PIL
import math
import time
from IPython.display import Image, display
from tqdm import tqdm

# Plotly for the interactive viewer
import plotly.graph_objs as go

  from .autonotebook import tqdm as notebook_tqdm


## Load Dataset

In [3]:
# Root directory of the dataset on disk
DATA_PATH = '../../ganz/data/panda_dataset'

# Sub-directories holding the slide images and their label masks
data_dir = DATA_PATH + '/train_images'
mask_dir = DATA_PATH + '/train_label_masks'


# Label tables; the training table is re-indexed by image id so rows can be
# looked up directly with train.loc[<image_id>]
train = pd.read_csv(DATA_PATH + '/train.csv').set_index('image_id')
test = pd.read_csv(DATA_PATH + '/test.csv')
submission = pd.read_csv(DATA_PATH + '/sample_submission.csv')

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)
# # determine if we will be pinning memory during data loading
# PIN_MEMORY = True if DEVICE == "cuda" else False

cuda


## Visualize Data

### Sample Images

In [4]:
# Hand-picked subset of 15 image ids, kept for quick experiments. It is
# currently disabled in favour of the full training index below.
# all_train_img_names = [
#     '07a7ef0ba3bb0d6564a73f4f3e1c2293',
#     '037504061b9fba71ef6e24c48c6df44d',
#     '035b1edd3d1aeeffc77ce5d248a01a53',
#     '059cbf902c5e42972587c8d17d49efed',
#     '06a0cbd8fd6320ef1aa6f19342af2e68',
#     '06eda4a6faca84e84a781fee2d5f47e1',
#     '0a4b7a7499ed55c71033cefb0765e93d',
#     '0838c82917cd9af681df249264d2769c',
#     '046b35ae95374bfb48cdca8d7c83233f',
#     '074c3e01525681a275a42282cd21cbde',
#     '05abe25c883d508ecc15b6e857e59f32',
#     '05f4e9415af9fdabc19109c980daf5ad',
#     '060121a06476ef401d8a21d6567dee6d',
#     '068b0e3be4c35ea983f77accf8351cc8',
#     '08f055372c7b8a7e1df97c6586542ac8'
# ]

# Use every id in the index of the training table (one entry per row)
all_train_img_names = list(train.index)
# Print the first ten ids as a quick sanity check
print(all_train_img_names[:10])

['0005f7aaab2800f6170c399693a96917', '000920ad0b612851f8e01bcc880d9b3d', '0018ae58b01bdadc8e347995b69f99aa', '001c62abd11fa4b57bf7a6c603a11bb9', '001d865e65ef5d2579c190a0e0350d8f', '002a4db09dad406c85505a00fb6f6144', '003046e27c8ead3e3db155780dc5498e', '0032bfa835ce0f43a92ae0bbab6871cb', '003a91841da04a5a31f808fb5c21538a', '003d4dd6bd61221ebc0bfb9350db333f']


### Display Images

In [20]:
def display_images(slides, max_slides=15):
    """Show one 256x256 level-0 patch per slide in a three-column grid.

    The patch is always read at the fixed location (1780, 1950). Each panel
    is titled with the slide id and its ``data_provider``, ``isup_grade`` and
    ``gleason_score`` values from the ``train`` table.

    Args:
        slides: Iterable of image ids. Each id needs a ``<id>.tiff`` file in
            ``data_dir`` and a row in ``train``.
        max_slides: Maximum number of slides to draw (default 15, i.e. the
            original 5x3 grid). Extra ids are ignored instead of overflowing
            the grid. Pass ``None`` to draw all of them.
    """
    slides = list(slides)[:max_slides]
    if not slides:
        # nothing to draw; avoid creating an empty figure
        return

    n_cols = 3
    n_rows = math.ceil(len(slides) / n_cols)
    # 4.4 inches per row reproduces the original (18, 22) size for 5 rows;
    # squeeze=False keeps `ax` two-dimensional even for a single row
    f, ax = plt.subplots(n_rows, n_cols, figsize=(18, 4.4 * n_rows),
                         squeeze=False)

    # hide every frame up front so unused cells in the last row stay blank
    for axis in ax.flat:
        axis.axis('off')

    for axis, slide in zip(ax.flat, slides):
        # the context manager closes the slide even if read_region raises
        with openslide.OpenSlide(os.path.join(data_dir, f'{slide}.tiff')) as image:
            # spacing = 1 / (float(image.properties['tiff.XResolution']) / 10000)
            patch = image.read_region((1780, 1950), 0, (256, 256))
        axis.imshow(patch)

        image_id = slide
        data_provider = train.loc[slide, 'data_provider']
        isup_grade = train.loc[slide, 'isup_grade']
        gleason_score = train.loc[slide, 'gleason_score']
        axis.set_title(f"ID: {image_id}\nSource: {data_provider} ISUP: {isup_grade} Gleason: {gleason_score}")

    plt.show()

In [None]:
# Preview only the first 15 slides: the training list holds thousands of ids,
# far more than a single figure can show.
display_images(all_train_img_names[:15])

### Visualize Masks

In [21]:
def display_masks(slides, max_slides=15):
    """Show the lowest-resolution label mask of each slide in a 3-column grid.

    The first channel of every mask is drawn with a fixed six-colour map
    covering the label values 0..5. Each panel is titled with the slide id
    and its ``data_provider``, ``isup_grade`` and ``gleason_score`` values
    from the ``train`` table.

    Args:
        slides: Iterable of image ids. Each id needs a ``<id>_mask.tiff`` file
            in ``mask_dir`` and a row in ``train``.
        max_slides: Maximum number of masks to draw (default 15, i.e. the
            original 5x3 grid). Extra ids are ignored instead of overflowing
            the grid. Pass ``None`` to draw all of them.
    """
    slides = list(slides)[:max_slides]
    if not slides:
        # nothing to draw; avoid creating an empty figure
        return

    n_cols = 3
    n_rows = math.ceil(len(slides) / n_cols)
    # 4.4 inches per row reproduces the original (18, 22) size for 5 rows;
    # squeeze=False keeps `ax` two-dimensional even for a single row
    f, ax = plt.subplots(n_rows, n_cols, figsize=(18, 4.4 * n_rows),
                         squeeze=False)

    # the colour map is identical for every panel, so build it once
    cmap = matplotlib.colors.ListedColormap(['black', 'gray', 'green', 'yellow', 'orange', 'red'])

    # hide every frame up front so unused cells in the last row stay blank
    for axis in ax.flat:
        axis.axis('off')

    for axis, slide in zip(ax.flat, slides):
        # the context manager closes the mask even if read_region raises
        with openslide.OpenSlide(os.path.join(mask_dir, f'{slide}_mask.tiff')) as mask:
            # read the whole mask at its last (smallest) pyramid level
            mask_data = mask.read_region((0, 0), mask.level_count - 1, mask.level_dimensions[-1])
            labels = np.asarray(mask_data)[:, :, 0]

        axis.imshow(labels, cmap=cmap, interpolation='nearest', vmin=0, vmax=5)

        image_id = slide
        data_provider = train.loc[slide, 'data_provider']
        isup_grade = train.loc[slide, 'isup_grade']
        gleason_score = train.loc[slide, 'gleason_score']
        axis.set_title(f"ID: {image_id}\nSource: {data_provider} ISUP: {isup_grade} Gleason: {gleason_score}")

    # one layout pass after all panels exist is enough
    f.tight_layout()
    plt.show()

In [None]:
# Preview only the first 15 masks: the training list holds thousands of ids,
# far more than a single figure can show.
display_masks(all_train_img_names[:15])

## Patch Segmentation

### Initialize Globals

In [5]:
# number of channels in the input, number of classes, and number of levels
# in the U-Net model
NUM_CHANNELS, NUM_CLASSES, NUM_LEVELS = 1, 1, 3

# optimisation settings: initial learning rate, epoch count, batch size
INIT_LR = 1e-3
NUM_EPOCHS = 40
BATCH_SIZE = 64

# input patches are square, 256 pixels on each side
PATCH_WIDTH = PATCH_HEIGHT = 256

# probability cut-off used to filter weak predictions
THRESHOLD = 0.5

# split fraction handed to the train/validation partitioning
VAL_SPLIT = 0.85

# base output directory
BASE_OUTPUT = "../output"
# locations of the serialized model, the training plot, and the list of
# testing image paths, all inside the base output directory
MODEL_PATH = BASE_OUTPUT + "/unet_tgs_salt.pth"
PLOT_PATH = BASE_OUTPUT + "/plot.png"
TEST_PATHS = BASE_OUTPUT + "/test_paths.txt"

### Create Training Set

In [6]:
def create_image_set(slides):
    """Tile each slide and its mask into aligned, non-overlapping patches.

    Every slide is walked column by column at level 0 in steps of
    ``PATCH_WIDTH`` x ``PATCH_HEIGHT``. A partial tile at the right or bottom
    edge is dropped. The image patch and the mask patch are read at the same
    coordinates, so ``imgs[k]`` and ``masks[k]`` cover the same region.

    Args:
        slides: Iterable of image ids. Each id needs ``<id>.tiff`` in
            ``data_dir`` and ``<id>_mask.tiff`` in ``mask_dir``.

    Returns:
        Tuple ``(imgs, masks)`` of two equally long lists holding the objects
        returned by ``read_region``.

    NOTE(review): everything is kept in memory. A full-resolution slide
    yields thousands of patches, so call this only with a handful of ids.
    """
    imgs = []
    masks = []
    for slide in slides:
        image_path = os.path.join(data_dir, f'{slide}.tiff')
        mask_path = os.path.join(mask_dir, f'{slide}_mask.tiff')
        # context managers release both file handles even when a read fails
        with openslide.OpenSlide(image_path) as image, \
                openslide.OpenSlide(mask_path) as mask:
            patch_size = (PATCH_WIDTH, PATCH_HEIGHT)
            width, height = image.dimensions
            # largest multiples of the patch size that fit inside the slide
            max_x = width - (width % PATCH_WIDTH)
            max_y = height - (height % PATCH_HEIGHT)
            for x in range(0, max_x, PATCH_WIDTH):
                for y in range(0, max_y, PATCH_HEIGHT):
                    imgs.append(image.read_region((x, y), 0, patch_size))
                    masks.append(mask.read_region((x, y), 0, patch_size))
    return imgs, masks

In [24]:
# Tile a single slide (the second id in the list) into image/mask patches
create_image_set(all_train_img_names[1:2])

img level 0 dimension x: 10496
img level 0 dimension y: 24832
mask level 0 dimension x: 10496
mask level 0 dimension y: 24832


([], [])

### Dataset Class

In [7]:
class SegmentationDataset(Dataset):
    """Random-patch dataset over whole-slide images and their label masks.

    All slides and masks are opened once at construction. A fixed list of
    ``pseudo_epoch_length`` random (slide, top-left coordinate) pairs is then
    drawn, and item ``i`` always returns the patch for pair ``i``.

    NOTE(review): ``__getitem__`` returns the raw arrays produced by
    ``read_region`` (uint8, RGBA per the OpenSlide documentation). Convert
    them to the dtype and layout the model expects before use.
    """

    def __init__(self, wsi_names, pseudo_epoch_length: int = 1024):
        """Open the slides and pre-sample the patch coordinates.

        Args:
            wsi_names: Sequence of image ids to draw patches from.
            pseudo_epoch_length: Number of patches that make up one "epoch".
        """
        self.wsi_names = wsi_names
        self.pseudo_epoch_length = pseudo_epoch_length

        # opens all slides and stores them in slide_dict
        self.slide_dict = self.make_slide_dict(wsi_names=wsi_names)

        # samples a list of patch coordinates and annotations
        self.sample_dict = self.sample_coord_list(
            wsi_names=self.wsi_names,
            pseudo_epoch_length=self.pseudo_epoch_length)

    def make_slide_dict(self, wsi_names):
        """Open every slide and mask once.

        Returns:
            Dict mapping each id to ``{'slide', 'mask', 'size'}``, where
            ``size`` is the level-0 ``dimensions`` of the slide.
        """
        slide_dict = {}
        for wsi_name in tqdm(wsi_names, total=len(wsi_names), desc='Make Slide Dict'):
            if wsi_name in slide_dict:
                # duplicate id: keep the handles that are already open
                continue
            slide = openslide.OpenSlide(os.path.join(data_dir, f'{wsi_name}.tiff'))
            mask = openslide.OpenSlide(os.path.join(mask_dir, f'{wsi_name}_mask.tiff'))
            slide_dict[wsi_name] = {
                'slide': slide,
                'mask': mask,
                'size': slide.dimensions,
            }
        return slide_dict

    def sample_coord_list(self, wsi_names, pseudo_epoch_length):
        """Draw the patch list and index it by sample number.

        ``wsi_names`` is accepted for interface compatibility; sampling always
        uses ``self.wsi_names``.

        Returns:
            Dict ``{index: {'filename': id, 'coordinates': [x, y]}}``.
        """
        # sample random coordinates
        filenames, coords = self._sample_random_coords(pseudo_epoch_length)

        # bring everything in one dict
        return {
            index: {'filename': filename, 'coordinates': coord}
            for index, (filename, coord) in enumerate(zip(filenames, coords))
        }

    def _sample_random_coords(self, pseudo_epoch_length):
        """Pick slides with replacement and one top-left corner per pick.

        Returns:
            Tuple ``(filenames, coords)`` of two equally long lists.
        """
        # `random` is numpy.random in this file; tolist() turns the numpy
        # scalars back into plain Python objects
        filenames = random.choice(self.wsi_names,
                                  size=pseudo_epoch_length,
                                  replace=True).tolist()
        coords = []
        for filename in filenames:
            width, height = self.slide_dict[filename]['size']
            # randint's upper bound is exclusive and must exceed the lower
            # bound, so clamp it for slides no larger than one patch
            high = (max(width - PATCH_WIDTH, 1), max(height - PATCH_HEIGHT, 1))
            xy = random.randint(low=(0, 0), high=high, size=2)
            coords.append([int(value) for value in xy])
        return filenames, coords

    def __len__(self):
        # return the number of total samples
        return self.pseudo_epoch_length

    def __getitem__(self, index):
        """Return the ``(image_patch, mask_patch)`` arrays for sample ``index``."""
        # copy so callers cannot alter the stored coordinates
        coords = self.sample_dict[index]['coordinates'].copy()
        filename = self.sample_dict[index]['filename']

        # load patch and mask from the same location
        img = self.load_image(filename, coords, 'slide')
        mask = self.load_image(filename, coords, 'mask')

        return img, mask

    def load_image(self, filename, coords, type):
        """Loads an image patch from a slide and returns it as a numpy array.

        Args:
            filename: Image id, a key of ``self.slide_dict``.
            coords: ``[x, y]`` top-left corner in level-0 coordinates.
            type: ``'slide'`` or ``'mask'``, selecting which handle to read.
                The name shadows the builtin but is kept for compatibility.
        """
        slide = self.slide_dict[filename][type]
        img = slide.read_region(coords, size=(PATCH_WIDTH, PATCH_HEIGHT), level=0)
        return np.asarray(img, dtype=np.uint8)

    def close(self):
        """Close every slide and mask handle opened by this dataset."""
        for entry in self.slide_dict.values():
            entry['slide'].close()
            entry['mask'].close()

### UNET Modules

In [8]:
class Block(Module):
    """Two unpadded 3x3 convolutions with a ReLU between them."""

    def __init__(self, inChannels, outChannels):
        super().__init__()
        kernel_size = 3
        # first convolution changes the channel count, second one keeps it
        self.conv1 = Conv2d(inChannels, outChannels, kernel_size)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, kernel_size)

    def forward(self, x):
        """Apply CONV => RELU => CONV to ``x`` and return the result."""
        hidden = self.conv1(x)
        activated = self.relu(hidden)
        return self.conv2(activated)

class Encoder(Module):
    """Stack of convolution blocks, each followed by 2x2 max pooling."""

    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # one block per consecutive (in, out) pair of channel counts
        stages = [Block(n_in, n_out)
                  for n_in, n_out in zip(channels[:-1], channels[1:])]
        self.encBlocks = ModuleList(stages)
        self.pool = MaxPool2d(2)

    def forward(self, x):
        """Return the pre-pooling output of every block, first block first."""
        features = []
        for stage in self.encBlocks:
            # keep the block output, then downsample it for the next stage
            x = stage(x)
            features.append(x)
            x = self.pool(x)
        return features

class Decoder(Module):
    """Upsampling path: transposed conv, skip concatenation, conv block."""

    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        self.channels = channels
        # consecutive (in, out) channel pairs, one per decoder stage
        stage_io = list(zip(channels[:-1], channels[1:]))
        # 2x2 transposed convolutions with stride 2 for upsampling
        self.upconvs = ModuleList(
            [ConvTranspose2d(n_in, n_out, 2, 2) for n_in, n_out in stage_io])
        # blocks applied after the encoder features have been concatenated
        self.dec_blocks = ModuleList(
            [Block(n_in, n_out) for n_in, n_out in stage_io])

    def forward(self, x, encFeatures):
        """Decode ``x`` using ``encFeatures[i]`` as the skip input of stage ``i``."""
        stages = zip(self.upconvs, self.dec_blocks)
        for idx, (upconv, dec_block) in enumerate(stages):
            # upsample the current features
            x = upconv(x)
            # crop the matching encoder features to the upsampled size, join
            # both along the channel axis, and run the decoder block
            skip = self.crop(encFeatures[idx], x)
            x = dec_block(torch.cat([x, skip], dim=1))
        return x

    def crop(self, encFeatures, x):
        """Center-crop ``encFeatures`` to the spatial size of ``x``."""
        _, _, height, width = x.shape
        cropper = CenterCrop([height, width])
        return cropper(encFeatures)
		
class UNet(Module):
    """Encoder/decoder network with a 1x1 convolution head."""

    def __init__(self, encChannels=(3, 16, 32, 64),
                 decChannels=(64, 32, 16),
                 nbClasses=1, retainDim=True,
                 outSize=(PATCH_HEIGHT, PATCH_WIDTH)):
        super().__init__()
        # encoder and decoder halves of the network
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)
        # 1x1 convolution mapping the last decoder channels to the classes
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        # whether, and to which size, the output is resized
        self.retainDim = retainDim
        self.outSize = outSize

    def forward(self, x):
        """Return the segmentation map for ``x``, resized if ``retainDim``."""
        # the decoder starts from the last encoder output and consumes the
        # remaining outputs in reverse order as skip connections
        deepest_first = self.encoder(x)[::-1]
        decoded = self.decoder(deepest_first[0], deepest_first[1:])
        seg_map = self.head(decoded)
        if not self.retainDim:
            return seg_map
        # resize the output to the configured size
        return F.interpolate(seg_map, self.outSize)

### Training

In [9]:
# all_train_imgs, all_train_masks = create_image_set(all_train_img_names)

# split the image ids into two parts: the first gets a VAL_SPLIT share
# (rounded down) and is used for training, the rest is used for validation
n_total = len(all_train_img_names)
split_size = math.floor(VAL_SPLIT * n_total)
split = torch.utils.data.random_split(
    all_train_img_names,
    [split_size, n_total - split_size],
    # fixed seed so the partition is the same on every run
    generator=torch.Generator().manual_seed(42),
)

# turn both parts into plain lists of ids
train_img_names, val_img_names = (list(part) for part in split)
# print(train_img_names[:10])
# print(val_img_names[:10])

['8b36508994d132ac03c5effd255b6eaf', 'aca61df8852daa183fa8cbf30a2e8660', 'a4fe6a765a5d525015f97f2f967caa67', '18e2d22113d3bf1ef0575c90cb22fde6', '037f29b7d62479c7fb257903f401269a', '0fe439d806b88e850a6a69442d295530', '2e512e923950da10c0b4432d08ca0663', 'ef528624c588c2f9d04c79d400f1aacc', 'd4f83e0d9b46e89cd5e9da401f625bce', 'cc3d4d77585447466f27904725bd1d7b']
['8a62c7e1f8070d642c7de7156dcf55bb', 'ce04c1a8d8577bb9620620c0a62edb15', '375771fed439d26b5374aae3f6b1d653', '0ec5bf6cbb0a9195b4db46173203a848', 'ac07fd6531d7536124ceab1990d5227d', 'f35124ad7e0c9d8537a8e57966c06445', 'e692db297233f27312315e8a5c75c32a', 'e3177c4430b780c973ef857a320a211d', '92d641eafe819ad90148027f5da9afd3', '5c3aa6266bfcbf2f9bf057a0a5e39520']


In [10]:
# build one patch dataset per split, each sampling 1024 patches
trainDS = SegmentationDataset(wsi_names=train_img_names, pseudo_epoch_length=1024)
valDS = SegmentationDataset(wsi_names=val_img_names, pseudo_epoch_length=1024)
print(f"[INFO] found {len(trainDS)} samples in the training set...")
print(f"[INFO] found {len(valDS)} samples in the validation set...")

# both loaders share the batch size and worker count; only the training
# loader shuffles
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4)
trainLoader = DataLoader(trainDS, shuffle=True, **loader_kwargs)
valLoader = DataLoader(valDS, shuffle=False, **loader_kwargs)

[INFO] found 9023 wsi in the training set...
[INFO] found 1593 wsi in the test set...


In [11]:
# initialize our UNet model
unet = UNet().to(DEVICE)
# initialize loss function and optimizer
lossFunc = BCEWithLogitsLoss()
opt = Adam(unet.parameters(), lr=INIT_LR)
# number of batches each loader yields per epoch; unlike a floor division by
# the batch size, this also counts a final partial batch
trainSteps = len(trainLoader)
valSteps = len(valLoader)
# initialize a dictionary to store training history
H = {"train_loss": [], "val_loss": []}


def _prepare_batch(x, y):
    """Convert a raw (image, mask) batch into model input and loss target.

    The dataset yields uint8 arrays straight from ``read_region``; the
    default collate stacks them into (batch, height, width, 4) tensors.

    NOTE(review): the target is taken from mask channel 0 and every non-zero
    label is treated as foreground, because the model has a single output
    class. Confirm this is the intended binarisation.
    """
    # drop alpha, move channels first, and scale to [0, 1]
    x = x[..., :3].permute(0, 3, 1, 2).float() / 255.0
    # BCEWithLogitsLoss needs a float target shaped like the prediction
    y = (y[..., :1] > 0).permute(0, 3, 1, 2).float()
    return x.to(DEVICE), y.to(DEVICE)


# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(NUM_EPOCHS)):
    # set the model in training mode
    unet.train()
    # running loss sums, kept as plain floats so that no autograd graph is
    # retained between iterations
    totalTrainLoss = 0.0
    totalValLoss = 0.0

    # loop over the training set
    for x, y in trainLoader:
        # convert the batch and send it to the device
        x, y = _prepare_batch(x, y)
        # perform a forward pass and calculate the training loss
        pred = unet(x)
        loss = lossFunc(pred, y)
        # first, zero out any previously accumulated gradients, then
        # perform backpropagation, and then update model parameters
        opt.zero_grad()
        loss.backward()
        opt.step()
        # add the loss to the total training loss so far
        totalTrainLoss += loss.item()

    # set the model in evaluation mode and switch off autograd
    unet.eval()
    with torch.no_grad():
        # loop over the validation set
        for x, y in valLoader:
            # convert the batch and send it to the device
            x, y = _prepare_batch(x, y)
            # make the predictions and calculate the validation loss
            pred = unet(x)
            totalValLoss += lossFunc(pred, y).item()

    # calculate the average training and validation loss; max() guards
    # against a division by zero when a loader is empty
    avgTrainLoss = totalTrainLoss / max(trainSteps, 1)
    avgValLoss = totalValLoss / max(valSteps, 1)
    # update our training history
    H["train_loss"].append(avgTrainLoss)
    H["val_loss"].append(avgValLoss)
    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, NUM_EPOCHS))
    print("Train loss: {:.6f}, Val loss: {:.4f}".format(
        avgTrainLoss, avgValLoss))

# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

[INFO] training the network...


  0%|          | 0/40 [04:34<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# make sure the target directories exist; savefig and torch.save do not
# create missing folders themselves
for _out_path in (PLOT_PATH, MODEL_PATH):
    os.makedirs(os.path.dirname(_out_path) or ".", exist_ok=True)

# plot the training loss
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["val_loss"], label="val_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")
plt.savefig(PLOT_PATH)
# serialize the model to disk
# NOTE(review): this pickles the whole module object, so loading it later
# needs the same class definitions importable; saving unet.state_dict()
# would be more portable but changes how the file must be loaded.
torch.save(unet, MODEL_PATH)

## Testing

In [None]:
def prepare_plot(origImage, origMask, predMask):
    """Show the image, its original mask and the predicted mask side by side."""
    # one row with three panels
    figure, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 10))
    # pair every picture with the title of its panel, left to right
    panels = (
        (origImage, "Image"),
        (origMask, "Original Mask"),
        (predMask, "Predicted Mask"),
    )
    for axis, (picture, title) in zip(ax, panels):
        axis.imshow(picture)
        axis.set_title(title)
    # tidy up the layout and display the figure
    figure.tight_layout()
    figure.show()

def make_predictions(model, test_image, test_mask):
    """Predict a binary mask for one patch and plot it next to the reference.

    Args:
        model: Network called with a float tensor of shape (1, 3, H, W) on
            ``DEVICE``; its output is treated as logits.
        test_image: Image patch as an (H, W, C) array or an object that
            ``np.asarray`` converts to one. Only the first three channels are
            used, so an RGBA patch is accepted.
        test_mask: Reference mask, passed to ``prepare_plot`` unchanged.
    """
    # set model to evaluation mode
    model.eval()
    # turn off gradient tracking
    with torch.no_grad():
        # keep an untouched copy of the input for visualization
        # test_image = cv2.resize(test_image, (PATCH_WIDTH, PATCH_HEIGHT))
        orig = test_image.copy()

        # drop a possible alpha channel and convert to float, since the
        # convolution weights are float and the model expects 3 channels
        # NOTE(review): the [0, 1] scaling must match the preprocessing used
        # at training time; confirm before trusting the predictions.
        pixels = np.asarray(test_image)[..., :3].astype(np.float32) / 255.0
        # make the channel axis the leading one, add a batch dimension,
        # create a PyTorch tensor, and move it to the current device
        pixels = np.ascontiguousarray(np.transpose(pixels, (2, 0, 1)))
        batch = torch.from_numpy(np.expand_dims(pixels, 0)).to(DEVICE)
        # make the prediction, pass the results through the sigmoid
        # function, and convert the result to a NumPy array
        predMask = model(batch).squeeze()
        predMask = torch.sigmoid(predMask)
        predMask = predMask.cpu().numpy()
        # filter out the weak predictions and convert them to integers
        predMask = (predMask > THRESHOLD) * 255
        predMask = predMask.astype(np.uint8)
        # prepare a plot for visualization; the reference mask is the
        # `test_mask` argument
        prepare_plot(orig, test_mask, predMask)

In [None]:
# # get test_img ids
# test_img_names = list(test.index)
# print(test_img_names[:10])

# random_test_img_names = random.choice(test_img_names, size=10)

# # # create the train and validation datasets
# # trainDS = SegmentationDataset(wsi_names=random_test_img_names, pseudo_epoch_length=1024)
# # valDS = SegmentationDataset(wsi_names=val_img_names, pseudo_epoch_length=1024)
# # print(f"[INFO] found {len(trainDS)} samples in the training set...")
# # print(f"[INFO] found {len(valDS)} samples in the validation set...")

# # create the training and validation data loaders
# trainLoader = DataLoader(trainDS, shuffle=True,
# 	batch_size=BATCH_SIZE, num_workers=4)
# valLoader = DataLoader(valDS, shuffle=False,
# 	batch_size=BATCH_SIZE, num_workers=4)

# # load the image paths in our testing file and randomly select 10 image paths
# print("[INFO] loading up test images...")
# random_test_img_names = random.choice(test_img_names, size=10)
# test_imgs, test_masks = create_image_set(random_test_img_names)

# # load our model from disk and flash it to the current device
# print("[INFO] load up model...")
# unet = torch.load(MODEL_PATH).to(DEVICE)

# # iterate over the randomly selected test image paths
# for test_img, test_mask in zip(test_imgs, test_masks):
# 	# make predictions and visualize the results
# 	make_predictions(unet, test_img, test_mask)