In [None]:
################################################################################
#Written by Sean Harris
#
# This notebook contains code to train a CNN on preprocessed gigapixel cancer
# biopsy tissue slides.
################################################################################

###############################################################################
#NOTES
#Google Drive appears to throttle rapid reads from publicly shared files, so
#instead I need a code block that copies roughly 100 GB of images locally,
#trains on them for maybe 10 epochs, then deletes them and loads a new group.
#Loading each group should take a few hours. Alternatively, do the copying
#manually and train normally.

#The model also does not appear to be learning, so the preprocessing that creates
#the patch coordinates is probably broken. It seems to find tissue correctly, so
#the fault is probably in the part that assigns the labels.


In [None]:
#IMPORTS

!apt-get install openslide-tools
!pip install openslide-python
 
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from openslide import open_slide, __library_version__ as openslide_version
import os
from PIL import Image
from skimage.color import rgb2gray
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import os
import gdown
from zipfile import ZipFile
import math
import random
import sklearn
from skimage.transform import resize
import cv2
from tensorflow.keras.applications.inception_v3 import preprocess_input
import random
import copy
import scipy
import torch
import torch.nn.functional as F
from skimage import feature
from keras import metrics
import time

 
# Mount Google Drive at /content/drive so files stored there can be read by path.
from google.colab import drive
drive.mount("/content/drive")
# Module-level folder prefix read by int_to_path when it builds slide and mask
# paths. Paths are formed by plain string concatenation, so the trailing slash
# is required.
slide_root='/content/drive/My Drive/ADL_Final/'

In [None]:
#USEFUL FUNCTIONS

def read_slide(slide, x, y, level, width, height, as_float=False):
    """Read one rectangular region of a slide as an RGB array.

    Args:
        slide: object exposing ``read_region(location, level, size)``
            (an OpenSlide handle in this notebook).
        x, y: region origin, forwarded to ``read_region`` as ``(x, y)``.
            OpenSlide interprets this location in level-0 coordinates.
        level: pyramid level to read from.
        width, height: size of the region, in pixels of ``level``.
        as_float: when True the result is float32, otherwise the dtype is
            whatever ``np.asarray`` produces for the image.

    Returns:
        Array of shape ``(height, width, 3)``.
    """
    # read_region returns an image with an alpha channel; keep only RGB.
    region = slide.read_region((x, y), level, (width, height)).convert('RGB')
    pixels = np.asarray(region, dtype=np.float32) if as_float else np.asarray(region)
    assert pixels.shape == (height, width, 3)
    return pixels

def find_tissue_pixels(image, intensity=0.8):
    """List the pixels whose grayscale value is at or below ``intensity``.

    Args:
        image: array indexed as ``(row, col, channel)``; converted to
            grayscale with ``rgb2gray``.
        intensity: inclusive upper threshold on the grayscale value.

    Returns:
        List of ``(row, col)`` index pairs, in row-major order.
    """
    gray = rgb2gray(image)
    n_rows, n_cols = image.shape[0], image.shape[1]
    assert gray.shape == (n_rows, n_cols)
    # Darker-than-threshold pixels are the ones kept.
    rows, cols = np.nonzero(gray <= intensity)
    return list(zip(rows, cols))

def aug(image1,image2):
    """Apply the same random rotation and flips to a pair of images.

    Both images get an identical transform: a random number (1 to 4) of
    90-degree rotations, then, each with probability one half, a left-right
    flip and an up-down flip.

    Returns:
        Tuple ``(image1, image2)`` of the transformed arrays.
    """
    quarter_turns = random.randint(1,4)
    image1, image2 = (np.rot90(img, quarter_turns) for img in (image1, image2))

    # One random draw per flip, left-right first and up-down second.
    for flip in (np.fliplr, np.flipud):
        if random.random() > .5:
            image1, image2 = flip(image1), flip(image2)
    return image1,image2

def context_coords(j,i,level,slide):
    """Level-0 origin of the context window surrounding patch ``(j, i)``.

    The patch is the 299x299 tile with grid index ``(j, i)`` at ``level``.
    The context window is a 299x299 tile read at ``level + 1``, so it covers
    a larger level-0 area. The returned origin centres that window on the
    patch, then shifts it as needed to stay inside the level-0 image.

    Args:
        j: patch column index (x direction) at ``level``.
        i: patch row index (y direction) at ``level``.
        level: pyramid level of the patch; ``level + 1`` must also exist.
        slide: object exposing ``level_downsamples`` and ``level_dimensions``.

    Returns:
        Tuple ``(x_pos, y_pos)`` of ints in level-0 coordinates.

    Raises:
        IndexError: if the slide has no ``level + 1``.
    """
    input_shape = (299,299,3)
    patch_scale = slide.level_downsamples[level]
    context_scale = slide.level_downsamples[level + 1]

    def _axis_origin(index, axis):
        # Level-0 origin and extent of the patch along this axis.
        patch_origin = int(index * input_shape[axis] * patch_scale)
        patch_extent = input_shape[axis] * patch_scale
        # Level-0 extent of the context window along this axis.
        context_extent = input_shape[axis] * context_scale

        # Centre the window on the patch. When the two levels differ by
        # exactly 2x this equals half a patch width; using the real extents
        # keeps it centred for any downsample ratio.
        margin = int(np.floor(.5 * (context_extent - patch_extent)))

        # Largest origin that keeps the whole window inside the slide.
        upper = slide.level_dimensions[0][axis] - context_extent

        # Clamp the top first and the bottom last, so a slide smaller than
        # the window gives 0 rather than a negative coordinate.
        origin = max(0, min(patch_origin - margin, upper))
        return int(np.floor(origin))

    return _axis_origin(j, 0), _axis_origin(i, 1)

def int_to_path(k):
    """Build the slide and tumor-mask file paths for slide number ``k``.

    The number is zero-padded to at least three digits and inserted into
    ``'Copy of tumor_<NNN>.tif'`` and ``'Copy of tumor_<NNN>_mask.tif'``,
    both prefixed with the module-level ``slide_root``.

    Args:
        k: non-negative integer slide number.

    Returns:
        Tuple ``(slide_path, tumor_mask_path)`` of strings.

    Raises:
        ValueError: if ``k`` is negative. The original fell through every
            branch and failed with an UnboundLocalError instead.
    """
    if k < 0:
        raise ValueError('slide number must be non-negative, got ' + str(k))
    # Zero-padding reproduces the former >=100 / >=10 / >=0 branches.
    stem = slide_root + 'Copy of tumor_' + f'{k:03d}'
    slide_path = stem + '.tif'
    tumor_mask_path = stem + '_mask.tif'
    return slide_path,tumor_mask_path

In [None]:
#GENERATOR

def generate_patches(slide_root='/content/drive/My Drive/ADL_Final/data/',
                     patch_root='/content/drive/My Drive/ADL_Final/coords/2/',
                     level = 2,
                     batch_size=32):
    """Endlessly yield batches of (patch, context) image pairs with labels.

    For every file in ``slide_root`` that has both ``<name>_tumor.npy`` and
    ``<name>_normal.npy`` in ``patch_root`` (``<name>`` being the file name
    minus its last four characters), patches are drawn at random from the
    tumor list (label 1) or the normal list (label 0) with equal probability
    until either list runs out. Each patch is read at ``level`` and paired
    with a context window read at ``level + 1``. Both get the same random
    augmentation. A partially filled batch carries over to the next slide.

    Args:
        slide_root: directory containing the slide files.
        patch_root: directory containing the coordinate ``.npy`` files; each
            row is used as ``[j, i]`` patch grid indices.
        level: pyramid level the patches are read from.
        batch_size: number of patches per yielded batch.

    Yields:
        ``([patches, contexts], labels)`` where ``patches`` and ``contexts``
        are arrays of shape ``(batch_size, 299, 299, 3)`` and ``labels`` has
        shape ``(batch_size,)``.

    Raises:
        ValueError: if a full pass over ``slide_root`` produces no patch at
            all. Without this check the generator would loop forever without
            yielding.
    """
    input_shape = (299,299,3)
    inputs      = []
    targets     = []
    contexts    = []
    slides = os.listdir(slide_root)
    while True:
        patches_this_pass = 0
        for slide_name in slides:
            stem = slide_name[0:-4]
            pos_patch_coords_path = os.path.join(patch_root, stem + '_tumor.npy')
            neg_patch_coords_path = os.path.join(patch_root, stem + '_normal.npy')

            # Check for coordinates before opening the slide, so files with
            # no coordinate lists (e.g. masks) are never opened.
            if not (os.path.isfile(pos_patch_coords_path)
                    and os.path.isfile(neg_patch_coords_path)):
                continue

            pos_patches = np.load(pos_patch_coords_path).tolist()
            neg_patches = np.load(neg_patch_coords_path).tolist()
            np.random.shuffle(neg_patches)
            np.random.shuffle(pos_patches)

            slide = open_slide(os.path.join(slide_root, slide_name))
            try:
                scale = slide.level_downsamples[level]
                # Stop at the first exhausted list so the 50/50 draw stays balanced.
                while pos_patches and neg_patches:
                    if random.random() < .5:
                        j, i = neg_patches.pop()[:2]
                        label = 0
                    else:
                        j, i = pos_patches.pop()[:2]
                        label = 1

                    patch = read_slide(slide,
                                       x=int(j * input_shape[0] * scale),
                                       y=int(i * input_shape[1] * scale),
                                       level=level,
                                       width=input_shape[0],
                                       height=input_shape[1])

                    # Use the requested level here; it was hard-coded to 2,
                    # which misplaced the context for any other level.
                    x_pos,y_pos = context_coords(j,i,level,slide)
                    patch_context = read_slide(slide,
                                               x=x_pos,
                                               y=y_pos,
                                               level=level + 1,
                                               width=input_shape[0],
                                               height=input_shape[1])

                    patch, patch_context = aug(patch,patch_context)

                    inputs.append(patch)
                    contexts.append(patch_context)
                    targets.append(label)
                    patches_this_pass += 1

                    if len(inputs) >= batch_size:
                        yield ([np.array(inputs),np.array(contexts)], np.array(targets))
                        inputs = []
                        targets = []
                        contexts = []
            finally:
                # Release the handle even if the consumer stops the generator
                # mid-slide; the loop reopens every slide on each pass.
                slide.close()

        if patches_this_pass == 0:
            raise ValueError(
                'generate_patches found no slide in ' + repr(slide_root) +
                ' with non-empty tumor and normal coordinate files in ' +
                repr(patch_root))


In [None]:
#BUILD KERAS CNN MODEL
 
# Two-tower binary classifier. generate_patches yields [patches, contexts],
# so input1 receives the patch and input2 the lower-resolution context.
input1 = keras.Input(shape=(299, 299, 3))
input2 = keras.Input(shape=(299, 299, 3))

# One InceptionV3 backbone per input, without the classification head and with
# global max pooling so each emits a flat feature vector. The two sub-models
# are renamed so they stay distinguishable inside the combined model.
inception1 = keras.applications.InceptionV3(include_top=False, pooling='max')
inception1._name = 'inception1'
inception2 = keras.applications.InceptionV3(include_top=False, pooling='max')
inception2._name = 'inception2'

# preprocess_input applies the InceptionV3 pixel scaling inside the graph,
# so the generator can feed raw RGB patches.
flat1 = (inception1(preprocess_input(input1)))
flat2 = (inception2(preprocess_input(input2)))

concat = keras.layers.concatenate([flat1,flat2])
dropout = keras.layers.Dropout(0.5)(concat)

# Single sigmoid unit, trained against the 0/1 labels from the generator.
output = keras.layers.Dense(1, activation='sigmoid')(dropout)

model = keras.Model([input1,input2],output)

# Linear (power=1.0) decay from 2e-3 to 1e-7 over the first 6000 optimizer steps.
lr = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate = .002,
    decay_steps = 6000,
    end_learning_rate = 0.0000001,
    power=1.0,
)

opt = tf.keras.optimizers.RMSprop(
    learning_rate=lr,
    rho=0.9,
    momentum=.1,
    #epsilon=1,
)
model.compile(
    optimizer=opt,
    loss='binary_crossentropy',
    metrics=[
        # Name the metrics explicitly. Keras otherwise de-duplicates the
        # auto-generated names ('auc', 'auc_1', ...) each time this cell is
        # re-run in the same kernel, and a checkpoint callback monitoring
        # 'auc' then no longer finds its metric.
        metrics.BinaryAccuracy(name='binary_accuracy'),
        metrics.AUC(name='auc'),
    ],
)
model.summary()
#keras.utils.plot_model(model)

In [None]:
# Resume from weights saved by an earlier run. Requires `model` to be built
# already and the file to exist on the mounted Drive.
# NOTE(review): presumably the file was written for this same architecture - confirm.
model.load_weights('/content/drive/My Drive/ADL_Final/checkpoint7/weights8.h5')

In [None]:
#TRAIN THE CNN

# Directory of per-slide coordinate lists. The same path is used to size an
# epoch and to feed the generator, so the two cannot drift apart.
coords_dir = '/content/drive/My Drive/ADL_Final/coords/2/'

# Write weights only when the training AUC improves. `monitor` must match the
# name under which the AUC metric was registered when the model was compiled.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='/content/drive/My Drive/ADL_Final/checkpoint7/weights9.h5',
    save_weights_only=True,
    monitor='auc',
    mode='max',
    save_best_only=True)

batch_size = 32

# Count tumor patches across all '<slide>_tumor.npy' files. Matching on the
# suffix (the naming generate_patches looks up) replaces the former
# path[18:-4] slice, which only worked for one exact filename-prefix length.
total_tumor_patches = 0
for path in os.listdir(coords_dir):
    if path.endswith('_tumor.npy'):
        total_tumor_patches += len(np.load(os.path.join(coords_dir, path)))

# The generator draws tumor and normal patches with equal probability, so an
# epoch is sized at about two patches per tumor patch. Integer division keeps
# steps_per_epoch a plain int rather than the NumPy float np.floor returned.
# NOTE(review): this counts every coordinate file in coords_dir, including
# any whose slide is not currently in the generator's slide directory.
steps = (2 * total_tumor_patches) // batch_size
if steps < 1:
    raise ValueError(
        'Not enough tumor patches in ' + repr(coords_dir) +
        ' to fill one batch of ' + str(batch_size))

model.fit(
    x = generate_patches(patch_root=coords_dir, batch_size=batch_size, level=2),
    epochs = 10,
    steps_per_epoch = steps,
    callbacks=[model_checkpoint_callback],
    # Keras ignores `shuffle` for generator input; the generator shuffles itself.
    shuffle=False
)

In [None]:
# Sanity check: open one slide and show a small preview of it.
# The handle `a` is left open at module level.
a = open_slide('/content/drive/My Drive/ADL_Final/Copy of tumor_091.tif')
# get_thumbnail is asked for a size of (300, 300); the result is passed straight to imshow.
plt.imshow(a.get_thumbnail((300,300)))