In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES']="0"

import tensorflow as tf
tf.enable_eager_execution()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

from tensorflow import keras
sess = tf.Session(config=config)
import numpy as np

from PIL import Image, ImageDraw
from collections import Counter, defaultdict, namedtuple
from sklearn.model_selection import train_test_split
from pathlib import Path
import openslide as ops
import pandas as pd
import glob
import cv2

DIR = '/home/matejg/Project/crc_ml/data/processed/'

In [31]:
def create_context_model(context_size=3):
    inputs = [keras.layers.Input(shape=(96,96,3)) for _ in range(context_size*context_size)]

    x = keras.applications.NASNetMobile(include_top=False, input_tensor=tf.keras.layers.Input(shape=(96,96,3)), weights='imagenet') 
    gmax = tf.keras.layers.GlobalMaxPooling2D()
    gavg = tf.keras.layers.GlobalAveragePooling2D()
    flat = tf.keras.layers.Flatten()
    con = tf.keras.layers.Concatenate(axis=-1)
    drop = tf.keras.layers.Dropout(0.5)
    out = tf.keras.layers.Dense(1, activation='sigmoid')
    
    outs = []
    for inp in inputs:
        f = x(inp)
        f1 = gmax(f)
        f2 = gavg(f)
        f3 = flat(f)
        f = con([f1,f2,f3])
        f = drop(f)
        f = out(f)
        outs.append(f)
        
    out = keras.layers.Concatenate(axis=-1)(outs)
    out = keras.layers.Dense(1, activation='sigmoid')(out)
    model = keras.Model(inputs, out)
    return model

def create_model():
    inputs = tf.keras.layers.Input((96,96,3))
    nasnet_model = tf.keras.applications.nasnet.NASNetMobile(include_top=False, input_tensor=inputs, weights='imagenet')
    nasnet_model.trainable=True

    x = nasnet_model(inputs)
    out1 = tf.keras.layers.GlobalMaxPooling2D()(x)
    out2 = tf.keras.layers.GlobalAveragePooling2D()(x)
    out3 = tf.keras.layers.Flatten()(x)
    out = tf.keras.layers.Concatenate(axis=-1)([out1, out2, out3])
    out = tf.keras.layers.Dropout(0.5)(out)
    out = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid, name="3_")(out)

    model = tf.keras.models.Model(inputs, out)
    return model

In [None]:
def create_macrotile_mask(slide_name):
    RAW_WSI_DIR = '/home/matejg/Project/crc_ml/data/raw/Prostata/'
    SCALE_FACTOR = 1
    OUTPUT_DIR = '/home/matejg/macrotiles/masks/'
    
    slide_file_no_ext = Path(RAW_WSI_DIR + slide_name)
    slide = ops.open_slide(str(slide_file_no_ext.with_suffix('.mrxs')))
    
    level = slide.get_best_level_for_downsample(SCALE_FACTOR)
    
    rectangle_image = Image.new('L', size=slide.level_dimensions[level], color='BLACK')
    rectangle_draw = ImageDraw.Draw(rectangle_image)
    
    tiles = get_tiles_for_slide_all(slide_name)
    tiles = set(['-'.join(path.rsplit('-', 7)[:4]) for path in tiles])
    tiles = list(filter(lambda x: check_context(x, 3, tiles), tiles))
    
    for tile in tiles:
        _, _, r, c = tile.rsplit('-', 3)
        r = int(r[1:])
        c = int(c[1:])
        x = ((c-1)*299+1)//SCALE_FACTOR
        y = ((r-1)*299+1)//SCALE_FACTOR
        size = 299 // SCALE_FACTOR
        rectangle_draw.rectangle([(x, y), (x+size, y+size)], fill=(255))
    
    fig_filepath = OUTPUT_DIR + slide_name + '-macrotile-mask.png'
    rectangle_image.save(fp=fig_filepath)

In [None]:
"""
Creates giga-pixel macrotile mask
"""
for slide in slide_name_iterator_all():
    create_macrotile_mask(slide)

In [None]:
def create_macrotile_image(slide_name):
    RAW_WSI_DIR = '/home/matejg/Project/crc_ml/data/raw/Prostata/'
    SCALE_FACTOR = 40
    OUTPUT_DIR = '/home/matejg/macrotiles/'

    slide_file_no_ext = Path(RAW_WSI_DIR + slide_name)
    slide = ops.open_slide(str(slide_file_no_ext.with_suffix('.mrxs')))

    level = slide.get_best_level_for_downsample(SCALE_FACTOR)

    large_w, large_h = slide.dimensions
    small_w = np.floor(large_w / SCALE_FACTOR).astype(int)
    small_h = np.floor(large_h / SCALE_FACTOR).astype(int)

    wsi_image = slide.read_region((0,0), level, slide.level_dimensions[level])
    wsi_image = wsi_image.convert('RGBA')
    wsi_image = wsi_image.resize((small_w, small_h))
    wsi_draw = ImageDraw.Draw(wsi_image)

    tiles = get_tiles_for_slide_all(slide_name)
    tiles = set(['-'.join(path.rsplit('-', 7)[:4]) for path in tiles])
    tiles = list(filter(lambda x: check_context(x, 3, tiles), tiles))

    for tile in tiles:
        _, _, r, c = tile.rsplit('-', 3)
        r = int(r[1:])
        c = int(c[1:])
        x = ((c-1)*299+1)//SCALE_FACTOR
        y = ((r-1)*299+1)//SCALE_FACTOR
        size = 299 // SCALE_FACTOR
        wsi_draw.rectangle([(x, y), (x+size, y+size)], fill=(255,255,0,175))

    fig_filepath = OUTPUT_DIR + slide_name + '-macrotile.png'
    wsi_image.save(fp=fig_filepath)

In [None]:
"""
Creates giga-pixel macrotile images
"""
for slide in slide_name_iterator_all():
    create_macrotile_image(slide)

In [18]:
import imgaug as ia
import imgaug.augmenters as iaa
def get_augmenter():
    sometimes = lambda aug: iaa.Sometimes(0.5, aug)
    seq = iaa.Sequential(
        [
            # apply the following augmenters to most images
            iaa.Fliplr(0.5),  # horizontally flip 50% of all images
            iaa.Flipud(0.2),  # vertically flip 20% of all images
            sometimes(iaa.Affine(
                scale={"x": (0.9, 1.1), "y": (0.9, 1.1)},
                # scale images to 80-120% of their size, individually per axis
                translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},  # translate by -20 to +20 percent (per axis)
                rotate=(-10, 10),  # rotate by -45 to +45 degrees
                shear=(-5, 5),  # shear by -16 to +16 degrees
                order=[0, 1],  # use nearest neighbour or bilinear interpolation (fast)
                cval=(0, 255),  # if mode is constant, use a cval between 0 and 255
                mode=ia.ALL  # use any of scikit-image's warping modes (see 2nd image from the top for examples)
            )),
            # execute 0 to 5 of the following (less important) augmenters per image
            # don't execute all of them, as that would often be way too strong
            iaa.SomeOf((0, 5),
                    [
                        sometimes(iaa.Superpixels(p_replace=(0, 1.0), n_segments=(20, 200))),
                        # convert images into their superpixel representation
                        iaa.OneOf([
                            iaa.GaussianBlur((0, 1.0)),  # blur images with a sigma between 0 and 3.0
                            iaa.AverageBlur(k=(3, 5)),
                            # blur image using local means with kernel sizes between 2 and 7
                            iaa.MedianBlur(k=(3, 5)),
                            # blur image using local medians with kernel sizes between 2 and 7
                        ]),
                        iaa.Sharpen(alpha=(0, 1.0), lightness=(0.9, 1.1)),  # sharpen images
                        iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0)),  # emboss images
                        # search either for all edges or for directed edges,
                        # blend the result with the original image using a blobby mask
                        iaa.SimplexNoiseAlpha(iaa.OneOf([
                            iaa.EdgeDetect(alpha=(0.5, 1.0)),
                            iaa.DirectedEdgeDetect(alpha=(0.5, 1.0), direction=(0.0, 1.0)),
                        ])),
                        iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.01 * 255), per_channel=0.5),
                        # add gaussian noise to images
                        iaa.OneOf([
                            iaa.Dropout((0.01, 0.05), per_channel=0.5),  # randomly remove up to 10% of the pixels
                            iaa.CoarseDropout((0.01, 0.03), size_percent=(0.01, 0.02), per_channel=0.2),
                        ]),
                        iaa.Invert(0.01, per_channel=True),  # invert color channels
                        iaa.Add((-2, 2), per_channel=0.5),
                        # change brightness of images (by -10 to 10 of original value)
                        iaa.AddToHueAndSaturation((-1, 1)),  # change hue and saturation
                        # either change the brightness of the whole image (sometimes
                        # per channel) or change the brightness of subareas
                        iaa.OneOf([
                            iaa.Multiply((0.9, 1.1), per_channel=0.5),
                            iaa.FrequencyNoiseAlpha(
                                exponent=(-1, 0),
                                first=iaa.Multiply((0.9, 1.1), per_channel=True),
                                second=iaa.ContrastNormalization((0.9, 1.1))
                            )
                        ]),
                        sometimes(iaa.ElasticTransformation(alpha=(0.5, 3.5), sigma=0.25)),
                        # move pixels locally around (with random strengths)
                        sometimes(iaa.PiecewiseAffine(scale=(0.01, 0.05))),
                        # sometimes move parts of the image around
                        sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.1)))
                    ],
                    random_order=True
                    )
        ],
        random_order=True
    )
    return seq

def get_test_augmenter(augment_type):
    seq = None
    if augment_type == 'horizontal':
        print('Horizontal augment.')
        seq = iaa.Sequential([iaa.Fliplr(1)])
    elif augment_type == 'vertical':
        print('Vertical augment')
        seq = iaa.Sequential([iaa.Flipud(1)])
    elif augment_type == 'both':
        print('Horizontal+Vertical augment')
        seq = iaa.Sequential([iaa.Flipud(1), iaa.Fliplr(1)])
    return seq

In [19]:
def get_dataset_for_slide(tile_index, slide_name='*', split_type='train'):
    tiles = get_tiles_for_slide(slide_name, split_type)
    center_tiles = list(filter(lambda x: generate_context(x, 3, tile_index), tiles))
    center_labels = [1 if tile.split('/')[9] == 'cancer' else 0 for tile in center_tiles]
    
    proba = {}
    for k,v in Counter(center_labels).items():
        proba[k] = 0.5/v
    center_weights = [proba[label] for label in center_labels]
    pd_data = pd.DataFrame({'filename': center_tiles, 'class': center_labels, 'weight': center_weights})
    return pd_data

In [20]:
def get_tiles_for_slide(slide_name='*', split_type='train'):
    if split_type == 'train':
        part1 = glob.glob(DIR + 'train_slides/{}/*/*.png'.format(slide_name))
        part2 = glob.glob(DIR + 'valid_slides/{}/*/*.png'.format(slide_name))
    else:
        part1 = glob.glob(DIR + 'test_slides/{}/*/*.png'.format(slide_name))
        part2 = glob.glob(DIR + 'visual_slides/{}/*/*.png'.format(slide_name))
    tiles = part1 + part2
    return tiles

def create_inverted_index(split_type='train'):
    Tile = namedtuple('Tile', 'split_type slide_name label_name tile_name')
    tiles = get_tiles_for_slide(split_type=split_type)
    tile_index = defaultdict(dict)
    for tile in tiles:
        tile_parts = Tile(*Path(tile).parts[-4:])
        row_col_id = ''.join(tile_parts.tile_name.rsplit('-', 7)[2:4])
        tile_index[row_col_id][tile_parts.slide_name] = tile
    return tile_index

def generate_context(filename, context_size, tile_index):
    
    assert context_size % 2 == 1, 'context_size must be odd number'
    
    if context_size == 1:
        return [filename]
    
    slide_name, _, row, col, *_ = os.path.basename(filename).rsplit('-', 7)
    row = int(row[1:])
    col = int(col[1:])
    
    # TODO: Rename all tiles
    if not slide_name.startswith('P-'):
        slide_name = 'P-' + slide_name
        
    row_col_id = lambda r, c: 'r{}c{}'.format(r,c)

    context_limit = (context_size - 1) // 2
    context_range = range(-context_limit, context_limit+1)
    
    context = []
    try:
        for r_offset in context_range:
            for c_offset in context_range:
                new_r = row + r_offset
                new_c = col + c_offset
                context_tile = tile_index[row_col_id(new_r, new_c)][slide_name]
                context.append(context_tile)
    except KeyError:
        #print('Tiles {} for slide {} not found.'.format(row_col_id(new_r, new_c), slide_name))
        return []
                
    return context

In [21]:
def get_train_valid_split(tile_index, valid_ratio=0.03, slide_name='*', split_type='train'):
    pd_data = get_dataset_for_slide(tile_index, slide_name, split_type)
    train_df, valid_df = train_test_split(pd_data, test_size=valid_ratio)

    valid_df = valid_df.copy()
    valid_df['weight'] = valid_df['weight'].div(valid_df['weight'].sum(axis=0), axis=0)

    train_df = train_df.copy()
    train_df['weight'] = train_df['weight'].div(train_df['weight'].sum(axis=0), axis=0)
    
    return train_df, valid_df

In [23]:
""" TRAIN FEED """
def get_random_batch(tile_index, pd_data, batch_size=32):
    seq = get_augmenter()
    
    while True:
        batch = pd_data.sample(n=batch_size, replace=True, weights=pd_data['weight'])
        batch_context = [generate_context(fn, 1, tile_index) for fn in batch.filename]
        batch_context = np.array([np.array([cv2.resize(cv2.imread(im), dsize=(96,96), interpolation=cv2.INTER_CUBIC) for im in im_batch]) for im_batch in batch_context])
        batch_context = batch_context.transpose((1,0,2,3,4))
        batch_context = [tf.keras.applications.nasnet.preprocess_input(seq.augment_images(batch)) for batch in batch_context]
        yield batch_context, batch['class'].values

        
""" TEST FEED """
def _chunker(seq, size):
    return (seq[pos:pos + size] for pos in range(0, len(seq), size))

def _divide_round_up(n, d):
    return (n + (d - 1))//d

def get_sequential_batch(tile_index, pd_data, augment=None, batch_size=32):
    
    if augment in ['horizontal','vertical','both']:
        seq = get_test_augmenter(augment)

    #pd_data = get_dataset_for_slide(tile_index, slide_name, split_type)
    
    while True:                
        for batch in _chunker(pd_data, batch_size):
            batch_context = [generate_context(fn, 3, tile_index) for fn in batch.filename]
            batch_context = np.array([np.array([cv2.resize(cv2.imread(im), dsize=(96,96), interpolation=cv2.INTER_CUBIC) for im in im_batch]) for im_batch in batch_context])
            batch_context = batch_context.transpose((1,0,2,3,4))
            if augment:
                batch_context = [tf.keras.applications.nasnet.preprocess_input(seq.augment_images(batch)) for batch in batch_context]
            else:
                batch_context = [tf.keras.applications.nasnet.preprocess_input(batch) for batch in batch_context]
            yield batch_context, batch['class'].values
            
def get_sequential_metadata(tile_index, slide_name='*', split_type='test', batch_size=32):
    pd_data = get_dataset_for_slide(tile_index, slide_name, split_type)
    return _divide_round_up(len(pd_data), batch_size), pd_data['class'].values

In [24]:
_train_tile_index = create_inverted_index()
_test_tile_index = create_inverted_index('test')

In [25]:
def calculate_metrics(y_pred, y_true):
    quality_metrics = {}

    # BINARY CROSSENTROPY
    m = tf.keras.losses.BinaryCrossentropy()
    quality_metrics['loss'] = m(y_true, y_pred)

    # BINARY ACCURACY
    m = tf.keras.metrics.BinaryAccuracy()
    m.update_state(y_true, y_pred)
    quality_metrics['accuracy'] = m.result().numpy()

    # PRECISION
    m = tf.keras.metrics.Precision()
    m.update_state(y_true, y_pred)
    quality_metrics['precision'] = m.result().numpy()

    # RECALL
    m = tf.keras.metrics.Recall()
    m.update_state(y_true, y_pred)
    quality_metrics['recall'] = m.result().numpy()

    return quality_metrics

In [None]:
def train(model, tile_index, steps, train_df, valid_df, batch_size=32, epochs=1, callbacks=[]):
    valid_steps = _divide_round_up(len(valid_df), batch_size)
    
    model.fit_generator(generator=get_random_batch(tile_index, train_df, batch_size),
                        validation_data=get_sequential_batch(tile_index, valid_df, batch_size=batch_size),
                        steps_per_epoch=steps, 
                        validation_steps=valid_steps,
                        callbacks=callbacks,
                        verbose=1, 
                        epochs=epochs)

def test(model, tile_index, batch_size=32):
    predicts = 1
    steps, labels = get_sequential_metadata(tile_index, batch_size=batch_size)
    test_df = get_dataset_for_slide(tile_index, '*', 'test')
    for augment_type in [None, 'horizontal', 'vertical', 'both']:
        predicts *= model.predict_generator(get_sequential_batch(tile_index, test_df, augment=augment_type, batch_size=batch_size), 
                                            steps=steps, verbose=1).ravel()

    predicts = predicts ** 0.25
    quality_metrics = calculate_metrics(predicts.ravel(), labels.ravel())
    return quality_metrics

In [None]:
train_df, valid_df = get_train_valid_split(_train_tile_index)

print('Creating a model')
model = create_model()
model.compile(optimizer=tf.keras.optimizers.Adam(0.0001), loss=tf.keras.losses.binary_crossentropy, metrics=[tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
model.summary()
model_ckpt = tf.keras.callbacks.ModelCheckpoint('/home/matejg/nasnet.ckpt', monitor='val_binary_accuracy', verbose=1, save_best_only=True, save_weights_only=True)

experiments = {}
for experiment in range(15):
    print('Beginning training...')
    train(model, _train_tile_index, 2000, train_df, valid_df, epochs=7, callbacks=[model_ckpt])
    
    print('Beginning testing...')
    experiments['{}_late'.format(experiment)] = test(model, _test_tile_index)
    
    model.load_weights('/home/matejg/nasnet.ckpt')
    print('Beginning testing...')
    experiments['{}_early'.format(experiment)] = test(model, _test_tile_index)

Creating a model
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_11 (InputLayer)           (None, 96, 96, 3)    0                                            
__________________________________________________________________________________________________
NASNet (Model)                  (None, 3, 3, 1056)   4269716     input_11[0][0]                   
__________________________________________________________________________________________________
global_max_pooling2d_1 (GlobalM (None, 1056)         0           NASNet[1][0]                     
__________________________________________________________________________________________________
global_average_pooling2d_1 (Glo (None, 1056)         0           NASNet[1][0]                     
____________________________________________________________________________________________

Beginning testing...
Horizontal augment.
Vertical augment
Horizontal+Vertical augment
Beginning training...
Epoch 1/7

Epoch 00001: val_binary_accuracy did not improve from 0.92604
Epoch 2/7

Epoch 00002: val_binary_accuracy did not improve from 0.92604
Epoch 3/7

Epoch 00003: val_binary_accuracy did not improve from 0.92604
Epoch 4/7

Epoch 00004: val_binary_accuracy did not improve from 0.92604
Epoch 5/7

Epoch 00005: val_binary_accuracy did not improve from 0.92604
Epoch 6/7

Epoch 00006: val_binary_accuracy improved from 0.92604 to 0.95690, saving model to /home/matejg/nasnet.ckpt

Consider using a TensorFlow optimizer from `tf.train`.
Epoch 7/7

Epoch 00007: val_binary_accuracy did not improve from 0.95690
Beginning testing...
Horizontal augment.
Vertical augment
Horizontal+Vertical augment
Beginning testing...
Horizontal augment.
Vertical augment
Horizontal+Vertical augment
Beginning training...
Epoch 1/7

Epoch 00001: val_binary_accuracy did not improve from 0.95690
Epoch 2/7

E

Beginning testing...
Horizontal augment.
Vertical augment
Horizontal+Vertical augment
Beginning testing...
Horizontal augment.
Vertical augment
Horizontal+Vertical augment
Beginning training...
Epoch 1/7

Epoch 00001: val_binary_accuracy did not improve from 0.95967
Epoch 2/7

Epoch 00002: val_binary_accuracy did not improve from 0.95967
Epoch 3/7

Epoch 00003: val_binary_accuracy did not improve from 0.95967
Epoch 4/7

Epoch 00004: val_binary_accuracy did not improve from 0.95967
Epoch 5/7

Epoch 00005: val_binary_accuracy did not improve from 0.95967
Epoch 6/7

Epoch 00006: val_binary_accuracy did not improve from 0.95967
Epoch 7/7

Epoch 00007: val_binary_accuracy did not improve from 0.95967
Beginning testing...
Horizontal augment.
Vertical augment
Horizontal+Vertical augment
Beginning testing...
Horizontal augment.
Vertical augment
Horizontal+Vertical augment
Beginning training...
Epoch 1/7

Epoch 00001: val_binary_accuracy did not improve from 0.95967
Epoch 2/7

Epoch 00002: val_

In [37]:
a_total, p_total, r_total = 0,0,0
for k,v in sorted(experiments.items()):
    if k.endswith('_early'):
        print('{:2}: A: {:.4f}% | P: {:.4f}% | R: {:.4f}%'.format(k, v['accuracy'], v['precision'], v['recall']))
        a_total += v['accuracy']
        p_total += v['precision']
        r_total += v['recall']
print('A_avg: {:.4f}% | P_avg: {:.4f} | R_avg: {:.4f}'.format(a_total/15, p_total/15, r_total/15))
 
print()

a_total, p_total, r_total = 0,0,0
for k,v in sorted(experiments.items()):
    if k.endswith('_late'):
        print('{:2}: A: {:.4f}% | P: {:.4f}% | R: {:.4f}%'.format(k, v['accuracy'], v['precision'], v['recall']))
        a_total += v['accuracy']
        p_total += v['precision']
        r_total += v['recall']
print('A_avg: {:.4f}% | P_avg: {:.4f} | R_avg: {:.4f}'.format(a_total/15, p_total/15, r_total/15))

0_early: A: 0.7848% | P: 0.8141% | R: 0.7657%
10_early: A: 0.7345% | P: 0.8760% | R: 0.5767%
11_early: A: 0.7345% | P: 0.8760% | R: 0.5767%
12_early: A: 0.7345% | P: 0.8760% | R: 0.5767%
13_early: A: 0.7345% | P: 0.8760% | R: 0.5767%
14_early: A: 0.7345% | P: 0.8760% | R: 0.5767%
1_early: A: 0.7640% | P: 0.8719% | R: 0.6461%
2_early: A: 0.7484% | P: 0.8534% | R: 0.6296%
3_early: A: 0.7484% | P: 0.8534% | R: 0.6296%
4_early: A: 0.7775% | P: 0.8406% | R: 0.7117%
5_early: A: 0.7775% | P: 0.8406% | R: 0.7117%
6_early: A: 0.7345% | P: 0.8760% | R: 0.5767%
7_early: A: 0.7345% | P: 0.8760% | R: 0.5767%
8_early: A: 0.7345% | P: 0.8760% | R: 0.5767%
9_early: A: 0.7345% | P: 0.8760% | R: 0.5767%
A_avg: 0.7474% | P_avg: 0.8639 | R_avg: 0.6190

0_late: A: 0.7783% | P: 0.7689% | R: 0.8270%
10_late: A: 0.7587% | P: 0.7345% | R: 0.8476%
11_late: A: 0.7674% | P: 0.8335% | R: 0.6968%
12_late: A: 0.7655% | P: 0.8044% | R: 0.7319%
13_late: A: 0.7555% | P: 0.7699% | R: 0.7631%
14_late: A: 0.7438% | P: 0.7

In [None]:
for k,v in history.items():
    print('{:2}: A: {:.4f}% | P: {:.4f}% | R: {:.4f}%'.format(k, v['accuracy'], v['precision'], v['recall']))

In [None]:
# A: 0.8027    P: 0.8672    R: 0.7379
# A: 0.7836    P: 0.8032    R: 0.7795
# A: 0.8027    P: 0.8473    R: 0.7379