In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES']="1"
import glob
import openslide
import numpy as np
import pandas as pd
import imgaug as ia
import imgaug.augmenters as iaa
import matplotlib.pyplot as plt

from pathlib import Path
from collections import Counter
from collections import defaultdict
from multiprocessing import Queue, Process
from multiprocessing import Pool, cpu_count
from sklearn.model_selection import train_test_split

In [58]:
# TensorFlow setup using the TF 1.x API (enable_eager_execution / ConfigProto / Session).
import tensorflow as tf
tf.enable_eager_execution()
config = tf.ConfigProto()
# Let TensorFlow grow its GPU memory allocation on demand instead of reserving it all up front.
config.gpu_options.allow_growth=True
session = tf.Session(config=config)
# Make this session the one the Keras backend uses.
tf.keras.backend.set_session(session)

def chunker(seq, size):
    """Lazily split `seq` into consecutive slices of at most `size` items.

    Returns a generator; the final slice is shorter when `len(seq)` is not a
    multiple of `size`. `seq` must support `len()` and slicing.
    """
    starts = range(0, len(seq), size)
    return (seq[start:start + size] for start in starts)

def divide_round_up(n, d):
    """Integer division of `n` by `d`, rounded up (ceiling) for positive integers."""
    padding = d - 1
    return (n + padding) // d

def build_model():
    """Build and compile the binary tile classifier.

    A trainable NASNetMobile backbone (ImageNet weights, no top) is applied to
    96x96x3 inputs. Its feature map is summarised three ways (global max
    pooling, global average pooling, flatten), the three vectors are
    concatenated, passed through 50% dropout and a single sigmoid unit.

    Returns:
        A compiled `tf.keras.models.Model` using binary cross-entropy, Adam
        (learning rate 1e-4) and binary accuracy / precision / recall metrics.
    """
    image_input = tf.keras.layers.Input((96,96,3))
    backbone = tf.keras.applications.nasnet.NASNetMobile(
        include_top=False,
        input_tensor=image_input,
        weights='imagenet',
    )
    backbone.trainable = True

    features = backbone(image_input)
    summaries = [
        tf.keras.layers.GlobalMaxPooling2D()(features),
        tf.keras.layers.GlobalAveragePooling2D()(features),
        tf.keras.layers.Flatten()(features),
    ]
    head = tf.keras.layers.Concatenate(axis=-1)(summaries)
    head = tf.keras.layers.Dropout(0.5)(head)
    prediction = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(head)

    classifier = tf.keras.models.Model(image_input, prediction)
    classifier.compile(
        loss=tf.keras.losses.binary_crossentropy,
        metrics=[
            tf.keras.metrics.BinaryAccuracy(),
            tf.keras.metrics.Precision(),
            tf.keras.metrics.Recall(),
        ],
        optimizer=tf.keras.optimizers.Adam(lr=0.0001),
    )
    return classifier
# Build the compiled classifier once at module level; `train` and `main` read this global.
model = build_model()
    
def get_augmenter():
    """Build the imgaug augmentation pipeline.

    Returns:
        An `iaa.Sequential` that applies, in random order, horizontal/vertical
        flips, an optional mild affine transform, and between 0 and 5 of the
        listed secondary augmenters.
    """
    # Wrap an augmenter so it is applied to only 50% of the images.
    sometimes = lambda aug: iaa.Sometimes(0.5, aug)
    seq = iaa.Sequential(
        [
            # apply the following augmenters to most images
            iaa.Fliplr(0.5),  # horizontally flip 50% of all images
            iaa.Flipud(0.2),  # vertically flip 20% of all images
            sometimes(iaa.Affine(
                scale={"x": (0.9, 1.1), "y": (0.9, 1.1)},
                # scale images to 90-110% of their size, individually per axis
                translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},  # translate by -10 to +10 percent (per axis)
                rotate=(-10, 10),  # rotate by -10 to +10 degrees
                shear=(-5, 5),  # shear by -5 to +5 degrees
                order=[0, 1],  # use nearest neighbour or bilinear interpolation (fast)
                cval=(0, 255),  # if mode is constant, use a cval between 0 and 255
                mode=ia.ALL  # use any of scikit-image's warping modes
            )),
            # execute 0 to 5 of the following (less important) augmenters per image
            # don't execute all of them, as that would often be way too strong
            iaa.SomeOf((0, 5),
                    [
                        sometimes(iaa.Superpixels(p_replace=(0, 1.0), n_segments=(20, 200))),
                        # convert images into their superpixel representation
                        iaa.OneOf([
                            iaa.GaussianBlur((0, 1.0)),  # blur images with a sigma between 0 and 1.0
                            iaa.AverageBlur(k=(3, 5)),
                            # blur image using local means with kernel sizes between 3 and 5
                            iaa.MedianBlur(k=(3, 5)),
                            # blur image using local medians with kernel sizes between 3 and 5
                        ]),
                        iaa.Sharpen(alpha=(0, 1.0), lightness=(0.9, 1.1)),  # sharpen images
                        iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0)),  # emboss images
                        # search either for all edges or for directed edges,
                        # blend the result with the original image using a blobby mask
                        iaa.SimplexNoiseAlpha(iaa.OneOf([
                            iaa.EdgeDetect(alpha=(0.5, 1.0)),
                            iaa.DirectedEdgeDetect(alpha=(0.5, 1.0), direction=(0.0, 1.0)),
                        ])),
                        iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.01 * 255), per_channel=0.5),
                        # add gaussian noise to images
                        iaa.OneOf([
                            iaa.Dropout((0.01, 0.05), per_channel=0.5),  # randomly remove 1% to 5% of the pixels
                            iaa.CoarseDropout((0.01, 0.03), size_percent=(0.01, 0.02), per_channel=0.2),
                        ]),
                        iaa.Invert(0.01, per_channel=True),  # invert color channels
                        iaa.Add((-2, 2), per_channel=0.5),
                        # change brightness of images (by -2 to 2 of original value)
                        iaa.AddToHueAndSaturation((-1, 1)),  # change hue and saturation
                        # either change the brightness of the whole image (sometimes
                        # per channel) or change the brightness of subareas
                        iaa.OneOf([
                            iaa.Multiply((0.9, 1.1), per_channel=0.5),
                            iaa.FrequencyNoiseAlpha(
                                exponent=(-1, 0),
                                first=iaa.Multiply((0.9, 1.1), per_channel=True),
                                second=iaa.ContrastNormalization((0.9, 1.1))
                            )
                        ]),
                        sometimes(iaa.ElasticTransformation(alpha=(0.5, 3.5), sigma=0.25)),
                        # move pixels locally around (with random strengths)
                        sometimes(iaa.PiecewiseAffine(scale=(0.01, 0.05))),
                        # sometimes move parts of the image around
                        sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.1)))
                    ],
                    random_order=True
                    )
        ],
        random_order=True
    )
    return seq
# Module-level augmentation pipeline built by get_augmenter().
seq = get_augmenter()

def sequential_batch_generator(pd_data, batch_size):
    """Endlessly yield `(images, labels)` batches from `pd_data` in row order.

    Args:
        pd_data: DataFrame with an 'image' column (one image array per row)
            and a 'class' column (integer labels).
        batch_size: Number of rows per batch; the last batch of each pass may
            be smaller.

    Yields:
        Tuple `(x, y)` where `x` stacks the NASNet-preprocessed images of the
        batch and `y` holds the labels as an integer array.

    Raises:
        ValueError: On first iteration if `pd_data` is empty (otherwise the
            generator would spin forever without yielding).
    """
    if len(pd_data) == 0:
        raise ValueError('pd_data is empty; cannot generate batches')
    while True:
        # Honour the requested batch size (was hard-coded to 32) so it stays
        # consistent with the caller's steps_per_epoch computation.
        for batch in chunker(pd_data, batch_size):
            data = [tf.keras.applications.nasnet.preprocess_input(x) for x in batch['image']]
            # `np.int` was an alias of the builtin `int` and is gone from newer NumPy.
            yield np.array(data), np.array(batch['class']).astype(int)

In [3]:
# Directory holding the whole-slide images; open_all_slides() appends '<slide_name>.mrxs' to it.
SLIDE_DIR = '/mnt/data/scans/AI scans/Mammy/'

In [None]:
def extract_slide_name_from_filename(filename, context=False):
    """Derive the slide name from a coordinate-map file path.

    The directory and the final extension are dropped, then up to four
    trailing '-'-separated fields (five when `context` is truthy) are removed
    from the remaining stem.
    """
    stem, _extension = os.path.splitext(os.path.basename(filename))
    trailing_fields = 5 if context else 4
    return stem.rsplit('-', trailing_fields)[0]

In [4]:
def get_slides(context_size):
    """List the coordinate-map files for the requested context size.

    A falsy `context_size`, or one equal to 0 or 1, selects every '*.gz' file
    under the 'normal' sub-directory; any other value selects the
    '*context<context_size>.gz' files under 'macro'.
    """
    MAP_DIR = '/home/matejg/wsi_maps/Mammy/level0/'
    wants_context = bool(context_size) and context_size not in [0,1]
    if wants_context:
        pattern = MAP_DIR + 'macro/*context{}.gz'.format(context_size)
    else:
        pattern = MAP_DIR + 'normal/*.gz'
    return glob.glob(pattern)
    
def open_all_slides(slides, context=False):
    """Open the '.mrxs' slide behind every coordinate-map path.

    Returns a dict mapping each slide name (derived from the map filename) to
    the handle returned by `openslide.open_slide` for
    `SLIDE_DIR + name + '.mrxs'`.
    """
    slide_names = (extract_slide_name_from_filename(path, context) for path in slides)
    return {
        name: openslide.open_slide(SLIDE_DIR + name + '.mrxs')
        for name in slide_names
    }

In [14]:
def split_slides(slides, ratio=0.15, random_state=None):
    """Split slides into train, validation and test subsets.

    The test and validation subsets each receive `int(len(slides) * ratio)`
    slides; the remainder is the training subset.

    Args:
        slides: Sequence of slide identifiers/paths to split.
        ratio: Fraction of all slides assigned to each of validation and test.
        random_state: Optional seed forwarded to both `train_test_split`
            calls so the split is reproducible across kernel restarts. The
            default `None` keeps the previous unseeded behaviour.

    Returns:
        Tuple `(train, valid, test)`.
    """
    test_size = int(len(slides) * ratio)
    train_val, test = train_test_split(slides, test_size=test_size, random_state=random_state)
    train, valid = train_test_split(train_val, test_size=test_size, random_state=random_state)
    return train, valid, test

def get_sampler(slides, context=False):
    """Build the per-class, per-slide lookup used for sampling.

    Each path in `slides` is read with `pd.read_pickle`; the rows of its
    'class' column equal to 0 go under 'normal' and those equal to 1 under
    'cancer', keyed by the slide name derived from the filename.

    Returns:
        `{'normal': {slide_name: Series}, 'cancer': {slide_name: Series}}`.
    """
    normal_by_slide = {}
    cancer_by_slide = {}
    for path in slides:
        name = extract_slide_name_from_filename(path, context)
        labels = pd.read_pickle(path)['class']
        normal_by_slide[name] = labels[labels == 0]
        cancer_by_slide[name] = labels[labels == 1]
    return {'normal': normal_by_slide, 'cancer': cancer_by_slide}

In [36]:
def generate_train_epoch_coords(sampler, size):
    """Draw `size` unique training samples according to the sampling scheme.

    Args:
        sampler: Mapping `label -> slide_name -> Series` as produced by
            `get_sampler`.
        size: Number of unique `(coord, label)` samples to draw.

    Returns:
        `defaultdict(set)` mapping slide name to a set of `(coord, label)`
        tuples, holding `size` entries in total.

    Raises:
        ValueError: If `size` exceeds the number of rows available in
            `sampler`, since that many unique samples can never be drawn.
    """
    available = sum(
        len(rows)
        for slides_of_label in sampler.values()
        for rows in slides_of_label.values()
    )
    if size > available:
        raise ValueError(
            'requested {} unique samples but sampler only holds {}'.format(size, available)
        )

    static_sample = defaultdict(set)
    for no in range(size):
        if (no+1) % 5000 == 0:
            print('{:,}\r'.format(no+1), end='')
        slide_name, coord, label = sample(sampler)
        # The set stores (coord, label) tuples, so membership must be tested
        # with the same tuple; testing the bare coord never matched and let
        # duplicates silently shrink the epoch.
        while (coord, label) in static_sample[slide_name]:
            slide_name, coord, label = sample(sampler)
        static_sample[slide_name].add((coord, label))
    return static_sample

def sample(sampler):
    """Draw one random row from the sampler.

    A label is picked uniformly from the sampler's keys, then a slide
    uniformly among that label's slides, then a row position uniformly within
    that slide's Series.

    Returns:
        Tuple `(slide_name, row_index_label, row_value)`.
    """
    label = np.random.choice(list(sampler.keys()))
    slides_of_label = sampler[label]
    slide_name = np.random.choice(list(slides_of_label.keys()))
    rows = slides_of_label[slide_name]
    position = np.random.randint(low=0, high=len(rows))
    return slide_name, rows.index[position], rows.values[position]

In [12]:
def extract_epoch_tiles(static_sample):
    """Read the tile image for every sampled coordinate, slide by slide.

    For each slide the `(coord, label)` pairs are visited in sorted order;
    each tile is fetched with `extract_tile`, converted to RGB and stored as
    a NumPy array. Progress is printed on a single, overwritten line.

    Returns:
        `{'image': [array, ...], 'class': [label, ...]}` with aligned lists.
    """
    images, labels = [], []
    slide_count = len(static_sample)
    for position, (slide_name, samples) in enumerate(static_sample.items(), start=1):
        print('{} {}/{}\r'.format(slide_name, position, slide_count), end='')
        for coord, label in sorted(samples):
            tile = extract_tile(slide_name, coord)
            images.append(np.array(tile.convert('RGB')))
            labels.append(label)
    return {'image': images, 'class': labels}

def extract_tile(slide_name, col_row):
    """Read one 96x96 tile from an already-open slide.

    `col_row` is a 1-based `(col, row)` pair; it is converted to a pixel
    location by multiplying `(index - 1)` by the 32-pixel centre size. The
    region is read at pyramid level 2 from the module-level `open_slides`
    mapping.
    """
    CENTER_SIZE = 32
    TILE_SIZE = 96

    col, row = col_row
    top_left = ((col - 1) * CENTER_SIZE, (row - 1) * CENTER_SIZE)
    slide = open_slides[slide_name]
    return slide.read_region(location=top_left, level=2, size=(TILE_SIZE, TILE_SIZE))

In [25]:
def get_epoch_pd(epoch_size, batch_size, queue):
    """Prepare one epoch of shuffled training tiles (used as a prefetch worker).

    Samples coordinates, extracts the tiles, shuffles them into a DataFrame
    and computes the number of steps per epoch. When `queue` is truthy the
    `(dataframe, steps_per_epoch)` pair is put on it and nothing is returned;
    otherwise the pair is returned directly.

    NOTE(review): `generate_train_epoch_coords` is called with a single
    argument here, while the definition visible in this notebook takes
    `(sampler, size)` -- confirm which definition/sampler is intended.
    """
    steps_per_epoch = divide_round_up(epoch_size, batch_size)
    epoch_coords = generate_train_epoch_coords(epoch_size)
    epoch_tiles = extract_epoch_tiles(epoch_coords)
    tiles_frame = pd.DataFrame.from_dict(epoch_tiles)
    shuffled_tiles = tiles_frame.sample(frac=1).reset_index(drop=True)
    payload = (shuffled_tiles, steps_per_epoch)
    if not queue:
        return payload
    queue.put(payload)
    print('--- FINISHED QUEUE PACK ---')

In [52]:
def train(train_tiles_pd, batch_size, steps_per_epoch):
    """ Fit the module-level `model` on batches generated from `train_tiles_pd`.

    Runs `model.fit_generator` for `steps_per_epoch` steps over
    `sequential_batch_generator(train_tiles_pd, batch_size)` and reports the
    elapsed time. The `%time` prefix is an IPython line magic, so this
    function is only valid inside an IPython/Jupyter session.
    """
    %time _ = model.fit_generator(sequential_batch_generator(train_tiles_pd, batch_size), steps_per_epoch=steps_per_epoch, verbose=1)

In [109]:
def main():
    """Train for several epochs while a worker process prepares the next epoch's tiles.

    For each epoch the prepared `(dataframe, steps_per_epoch)` pair is taken
    from the queue, the finished worker is joined, a new worker is started
    for the following epoch (except after the last one), and the
    module-level `model` is fitted on the current data.
    """
    BATCH_SIZE = 32
    # The loop previously hard-coded 10 iterations while this constant (set
    # to 3) went unused; the constant now carries the value actually run.
    TOTAL_EPOCHS = 10
    EPOCH_SIZE = 20000

    q = Queue(2)
    worker = Process(target=get_epoch_pd, args=(EPOCH_SIZE, BATCH_SIZE, q))
    worker.start()

    for epoch_id in range(TOTAL_EPOCHS):
        print('== EPOCH #{} =='.format(epoch_id))
        train_tiles_pd, steps_per_epoch = q.get()
        # Join only after the item has been consumed from the queue, so the
        # worker's queue feeder can flush and the process can exit.
        worker.join()
        if epoch_id + 1 < TOTAL_EPOCHS:
            # Prefetch the next epoch while this one trains; nothing is
            # prepared after the final epoch, so no result is left unconsumed.
            worker = Process(target=get_epoch_pd, args=(EPOCH_SIZE, BATCH_SIZE, q))
            worker.start()
        # Use the local BATCH_SIZE; the lowercase `batch_size` was never defined here.
        model.fit_generator(sequential_batch_generator(train_tiles_pd, BATCH_SIZE), steps_per_epoch=steps_per_epoch, verbose=1)

== EPOCH #0 ==
--- FINISHED QUEUE PACK ---
--- FINISHED QUEUE PACK ---
== EPOCH #1 ==
--- FINISHED QUEUE PACK ---............] - ETA: 1:48 - loss: 0.0292 - binary_accuracy: 0.9906 - precision: 0.9896 - recall: 0.99165,000
== EPOCH #2 ==
--- FINISHED QUEUE PACK ---............] - ETA: 1:48 - loss: 0.0083 - binary_accuracy: 0.9974 - precision: 0.9983 - recall: 0.99655,000
== EPOCH #3 ==
--- FINISHED QUEUE PACK ---............] - ETA: 1:49 - loss: 0.0189 - binary_accuracy: 0.9923 - precision: 0.9917 - recall: 0.99275,000
== EPOCH #4 ==
--- FINISHED QUEUE PACK ---............] - ETA: 1:53 - loss: 0.0260 - binary_accuracy: 0.9908 - precision: 0.9896 - recall: 0.99195,000
== EPOCH #5 ==
--- FINISHED QUEUE PACK ---............] - ETA: 1:56 - loss: 0.0168 - binary_accuracy: 0.9935 - precision: 0.9935 - recall: 0.99335,000
== EPOCH #6 ==
--- FINISHED QUEUE PACK ---............] - ETA: 1:56 - loss: 0.0168 - binary_accuracy: 0.9940 - precision: 0.9943 - recall: 0.99355,000
== EPOCH #7 ==
--- FINI