# Nucleus challenge using a CNN

In [None]:
import glob
import os.path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
# --- Run configuration -------------------------------------------------------
# Upper bound on the number of training images that are loaded; None loads the full set.
MAX_TRAIN_INSTANCES = None
# True: train a fresh model and save it; False: load the model stored in 'simple_gen.h5'.
REDO_TRAINING = True
# NOTE(review): BATCH_SIZE does not appear to be read by the training code in this
# notebook (the generator yields one image per step) — confirm before relying on it.
BATCH_SIZE = 2
# Training schedule; the steps per epoch are additionally capped at the number of images.
STEPS_PER_EPOCH = 100
N_EPOCHS = 2

## Load data

In [None]:
# Root folder of the competition data: the label CSV and the stage1_* image folders
# are looked up below it.
dataDir = 'data/data-science-bowl-2018/'

### Load labels

In [None]:
# Official per-nucleus labels. Each 'EncodedPixels' string is turned into a list of ints
# (alternating run starts and run lengths, the same layout rle_encoding produces).
labels_csv = os.path.join(dataDir, 'stage1_train_labels.csv/stage1_train_labels.csv')
train_labels = pd.read_csv(labels_csv)
train_labels['EncodedPixels'] = train_labels['EncodedPixels'].map(
    lambda encoded: list(map(int, encoded.split(' '))))
train_labels.head()

### Load training paths and meta info

In [None]:
all_images = glob.glob(os.path.join(dataDir, 'stage1_*', '*', '*', '*.png'))
img_df = pd.DataFrame({'path': all_images})
if img_df.empty:
    # Fail early with a clear message instead of an IndexError on .iloc[0] below.
    raise FileNotFoundError('No PNG images found under ' + os.path.join(dataDir, 'stage1_*'))


def split_path(in_path):
    """Split *in_path* into its components.

    The path is normalised and split on the OS separator, so paths built by
    ``os.path.join``/``glob`` on Windows (backslashes) are handled as well as
    POSIX paths; a plain ``split('/')`` only works for the latter.
    """
    return os.path.normpath(in_path).split(os.sep)


# Expected layout: <dataDir>/<stage>_<split>/<ImageId>/<ImageType>/<file>.png
# The helpers below index the components from the end, so the depth of dataDir
# does not matter; the printed indices are the same negative indices.
example_parts = split_path(img_df['path'].iloc[0])
print('An exemplary data path with indices of split:')
print(*((i - len(example_parts), part) for i, part in enumerate(example_parts)), sep='\n', end='\n\n')


def img_id(in_path):
    """Return the image id: the grandparent directory of the PNG file."""
    return split_path(in_path)[-3]


def img_type(in_path):
    """Return the parent directory of the PNG file ('images' or 'masks')."""
    return split_path(in_path)[-2]


def img_group(in_path):
    """Return the split part of the '<stage>_<split>' folder (e.g. 'train' or 'test')."""
    return split_path(in_path)[-4].split('_')[1]


def img_stage(in_path):
    """Return the stage part of the '<stage>_<split>' folder (e.g. 'stage1')."""
    return split_path(in_path)[-4].split('_')[0]


img_df['ImageId'] = img_df['path'].map(img_id)
img_df['ImageType'] = img_df['path'].map(img_type)
img_df['TrainingSplit'] = img_df['path'].map(img_group)
img_df['Stage'] = img_df['path'].map(img_stage)

# DataFrame.info() prints its report itself and returns None; wrapping it in print()
# only added a stray "None" line.
img_df.info()
img_df.head()

### Create dataframe with training data (image and mask paths)

In [None]:
#%%time

# One row per (Stage, ImageId) training instance with the lists of its image and mask paths.
train_df = img_df.query('TrainingSplit=="train"')
group_cols = ['Stage', 'ImageId']
train_rows = []

count = 0
for count, (group, rows) in enumerate(train_df.groupby(group_cols), start=1):
    # Only the first MAX_TRAIN_INSTANCES groups are kept (all of them when it is None).
    if MAX_TRAIN_INSTANCES is not None and count > MAX_TRAIN_INSTANCES:
        break
    c_row = dict(zip(group_cols, group))
    c_row['images'] = rows.query('ImageType == "images"')['path'].values.tolist()
    c_row['masks'] = rows.query('ImageType == "masks"')['path'].values.tolist()
    train_rows.append(c_row)

train_img_df = pd.DataFrame(train_rows)

In [None]:
# One row per (Stage, ImageId): lists with the image path(s) and all mask paths.
train_img_df.head()

### Load training images

In [None]:
%%time

from skimage.io import imread


def read(in_img_list):
    """Read the single image referenced by *in_img_list*.

    Parameters
    ----------
    in_img_list : list of str
        Paths collected for one training/test instance; exactly one is expected.

    Returns
    -------
    numpy.ndarray
        The image as returned by ``imread``.

    Raises
    ------
    ValueError
        If the list does not contain exactly one path.
    """
    # The previous assert built its message from ``in_img_list.shape``, which raised an
    # AttributeError for plain lists (what the dataframes hold) and hid the real problem;
    # an assert is also stripped under ``python -O``.
    if len(in_img_list) != 1:
        raise ValueError('expected exactly one image for this instance, got {}: {}'.format(
            len(in_img_list), in_img_list))
    return imread(in_img_list[0])

# Keep only the first IMG_CHANNELS channels of every image (presumably dropping an
# alpha channel of RGBA PNGs — verify against the data).
IMG_CHANNELS = 3  # restrict pixels to RGB
train_img_df['images'] = train_img_df['images'].map(lambda paths: read(paths)[:, :, :IMG_CHANNELS])

## Investigate images

### Analyze intensity distributions

The instances form groups and could be handled separately.

In [None]:
# Per-image mean intensities, summarised in the pair plot below.
train_img_df['Red'] = train_img_df['images'].map(lambda x: np.mean(x[:,:,0]))
train_img_df['Green'] = train_img_df['images'].map(lambda x: np.mean(x[:,:,1]))
train_img_df['Blue'] = train_img_df['images'].map(lambda x: np.mean(x[:,:,2]))
train_img_df['Gray'] = train_img_df['images'].map(lambda x: np.mean(x))
# Subtract in floating point: for unsigned integer images (imread typically returns uint8
# for 8-bit PNGs) x[:,:,0] - x[:,:,2] wraps around, e.g. 10 - 200 gives 66 instead of -190.
train_img_df['Red-Blue'] = train_img_df['images'].map(
    lambda x: np.mean(x[:,:,0].astype(np.float64) - x[:,:,2]))

In [None]:
# Pairwise scatter plots of the per-image intensity statistics computed above.
sns.pairplot(train_img_df[['Gray', 'Red', 'Green', 'Blue', 'Red-Blue']])

### Image dimensions

In [None]:
# How often each image shape (height, width, channels) occurs among the training images.
train_img_df['images'].map(lambda x: x.shape).value_counts()

## Analysis using a single combined mask

For training, the individual nucleus masks of each image are simply superimposed into a single combined mask; the individual masks are later recovered by extracting connected regions from the predicted mask. This is expected to introduce some inaccuracies, e.g. touching nuclei end up in one region.

### Get labels

#### Load the masks and save them in dataframe

In [None]:
%%time

def read_and_stack(in_img_list):
    """Load every mask in *in_img_list* and add them into one combined mask.

    The masks are summed pixel-wise and divided by 255.0, so (assuming 0/255-valued
    mask files) a pixel covered by exactly one mask becomes 1.0.
    """
    layers = np.stack([imread(path) for path in in_img_list], axis=0)
    return layers.sum(axis=0) / 255.0

# Replace each list of mask paths by the combined mask as an integer array.
train_img_df['masks'] = train_img_df['masks'].map(lambda paths: read_and_stack(paths).astype(int))

In [None]:
# The 'images' and 'masks' columns now hold arrays instead of path lists.
train_img_df.head(2)

#### Show some of the pictures with their labels

In [None]:
n_img = 6
fig, m_axs = plt.subplots(2, n_img, figsize = (12, 4))
sampled_rows = train_img_df.sample(n_img)
# Each grid column shows one random training image (top) and its combined mask (bottom).
for (c_row_idx, c_row), column_axes in zip(sampled_rows.iterrows(), m_axs.T):
    for ax, key, prefix in zip(column_axes, ('images', 'masks'), ('Microscope ', 'Labeled ')):
        ax.imshow(c_row[key])
        ax.axis('off')
        ax.set_title(prefix + str(c_row_idx))

### Create RLE encoding

#### Create and test conversion function

In [None]:
from skimage.morphology import label # label regions


def rle_encoding(x):
    '''
    Run-length encode a single binary mask.

    x: numpy array (or array-like) of shape (height, width); pixels equal to 1
       (or True) are mask, everything else is background.
    Returns the run-length encoding as a flat list of Python ints
    [start_1, length_1, start_2, length_2, ...]. Starts are 1-based pixel positions
    counted down the columns first (top-to-bottom, then left-to-right).
    An empty mask gives an empty list.
    '''
    # .T before flattening sets the order down-then-right
    pixels = np.asarray(x).T.flatten() == 1
    # Pad with background on both sides so every run has a detectable start and end.
    padded = np.concatenate(([False], pixels, [False]))
    # Positions where the value flips: even entries are run starts (0-based),
    # odd entries are the exclusive run ends.
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    starts, ends = edges[0::2], edges[1::2]
    return np.column_stack((starts + 1, ends - starts)).ravel().tolist()

def prob_to_rles(x, cut_off = 0.5):
    """Yield one RLE (via rle_encoding) per connected region of ``x > cut_off``.

    Regions come from ``label``. When nothing exceeds the cut-off, pixel (0, 0) is
    marked as a region, so at least one RLE is produced for every image.
    """
    lab_img = label(x > cut_off)
    n_regions = lab_img.max()
    if n_regions < 1:
        lab_img[0, 0] = 1  # ensure at least one prediction per image
        n_regions = 1
    for region_id in range(1, n_regions + 1):
        yield rle_encoding(lab_img == region_id)

#### Compare the RLEs in the label file with the RLEs computed from the combined true mask

In [None]:
from tqdm import tqdm

def _count_rle_matches(pred_rles, true_rles):
    """Compare two collections of RLEs value by value; return (match, mismatch).

    Both collections are sorted by their first start pixel and paired in that order.
    ``zip_longest`` is used on both levels, so surplus masks or run values on either
    side count as mismatches (a plain ``zip`` silently dropped them, which let an image
    with missing or extra nuclei count as perfectly encoded).
    """
    from itertools import zip_longest

    match, mismatch = 0, 0
    pairs = zip_longest(sorted(pred_rles, key=lambda rle: list(rle[:1])),
                        sorted(true_rles, key=lambda rle: list(rle[:1])),
                        fillvalue=[])
    for pred_rle, true_rle in pairs:
        for pred_value, true_value in zip_longest(pred_rle, true_rle):
            if pred_value == true_value:
                match += 1
            else:
                mismatch += 1
    return match, mismatch


def _show_mask_examples(idx_list, title_prefix):
    """Plot images (top row) and combined masks (bottom row) for index labels of train_img_df."""
    if not idx_list:
        return
    # squeeze=False keeps m_axs 2-D even for a single column, so m_axs.T always yields
    # (image axis, mask axis) pairs; with one column the unpacking used to fail.
    fig, m_axs = plt.subplots(2, len(idx_list), figsize = (12, 6), squeeze = False)
    # idx_list holds index labels (from iterrows), hence .loc rather than .iloc.
    for (_, d_row), (c_im, c_lab) in zip(train_img_df.loc[idx_list].iterrows(), m_axs.T):
        c_im.imshow(d_row['images'])
        c_im.axis('off')
        c_im.set_title('Img ' + d_row['ImageId'][:8])

        c_lab.imshow(d_row['masks'])
        c_lab.axis('off')
        c_lab.set_title(title_prefix + ' ' + d_row['ImageId'][:8])


def check_match(max_rows=101, n_img_max=5):
    """Check that the combined training masks reproduce the official RLE labels.

    For at most *max_rows* training images, the combined mask is split into connected
    regions with ``prob_to_rles`` and compared, run value by run value, with the RLEs in
    ``train_labels``. Prints the overall value accuracy and the share of images encoded
    without any mismatch, then plots up to *n_img_max* failing and succeeding examples.

    Parameters
    ----------
    max_rows : int
        Maximum number of training images to check (the former hard-coded limit: 101).
    n_img_max : int
        Maximum number of examples plotted per category.
    """
    match, mismatch = 0, 0
    perfect_masks, imperfect_masks = [], []
    count = 0

    for idx, row in tqdm(train_img_df.iterrows(), total=min(max_rows, len(train_img_df))):
        # Count processed rows instead of comparing the index label, so the limit
        # also holds when train_img_df does not have a 0..n-1 index.
        if count >= max_rows:
            break
        count += 1
        train_row_rles = list(prob_to_rles(row['masks']))
        tl_rles = train_labels.loc[train_labels['ImageId'] == row['ImageId'], 'EncodedPixels']
        row_match, row_mismatch = _count_rle_matches(train_row_rles, tl_rles)
        match += row_match
        mismatch += row_mismatch
        (imperfect_masks if row_mismatch else perfect_masks).append((idx, row['ImageId']))

    # Guard both ratios against division by zero when nothing was compared.
    if match + mismatch:
        print('Matches: %d, Mismatches: %d, Accuracy: %2.1f%%' % (match, mismatch, 100*match/(match+mismatch)))
    if count:
        print('Fully correct masks: {} / {} = {:.1f}%'.format(len(perfect_masks), count,
                                                           100*len(perfect_masks)/count))

    # Show the last n_img_max examples of each category.
    n_keep = max(n_img_max, 0)
    bad_idx = [i for i, _ in imperfect_masks]
    good_idx = [i for i, _ in perfect_masks]
    print('Some failing and some successfull mask encodings:')
    _show_mask_examples(bad_idx[max(len(bad_idx) - n_keep, 0):], 'Bad')
    _show_mask_examples(good_idx[max(len(good_idx) - n_keep, 0):], 'Good')
    
# Run the RLE round-trip check on (a subset of) the training set.
check_match()

### Build simple CNN

#### Set up CNN structure

In [None]:
from keras.models import Sequential
from keras.layers import BatchNormalization, Conv2D, UpSampling2D, Lambda
# Small fully convolutional network: the input has no fixed height/width, every
# convolution uses 'same' padding, and the last layer gives one sigmoid value per pixel,
# so the predicted mask has the spatial size of the input image.
simple_cnn = Sequential()
simple_cnn.add(BatchNormalization(input_shape = (None, None, IMG_CHANNELS),
                                  name = 'NormalizeInput'))
# Hidden layers use ReLU. Without an activation, the stacked convolutions compose into a
# single linear filter, and the final sigmoid would be the only non-linearity.
simple_cnn.add(Conv2D(8, kernel_size = (3,3), padding = 'same', activation = 'relu'))
simple_cnn.add(Conv2D(8, kernel_size = (3,3), padding = 'same', activation = 'relu'))
# use dilations to get a slightly larger field of view
simple_cnn.add(Conv2D(16, kernel_size = (3,3), dilation_rate = 2, padding = 'same', activation = 'relu'))
simple_cnn.add(Conv2D(16, kernel_size = (3,3), dilation_rate = 2, padding = 'same', activation = 'relu'))
simple_cnn.add(Conv2D(32, kernel_size = (3,3), dilation_rate = 3, padding = 'same', activation = 'relu'))

# the final processing: per-pixel (1x1) feature mixing, then one sigmoid probability per pixel
simple_cnn.add(Conv2D(16, kernel_size = (1,1), padding = 'same', activation = 'relu'))
simple_cnn.add(Conv2D(1, kernel_size = (1,1), padding = 'same', activation = 'sigmoid'))
simple_cnn.summary()

#### Define custom loss to match competition objective

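Use the Dice score, see [here](https://arxiv.org/pdf/1707.00478.pdf). The factor 2 from the paper is omitted, so for binary masks $A$ (label) and $B$ (prediction) the value is $|A \cap B| / (|A| + |B|)$, which lies in $[0, 0.5]$. This is not exactly the IoU (Intersection over Union) $J = |A \cap B| / |A \cup B|$, but (ignoring the smoothing term) it equals $J / (1 + J)$ and is therefore a monotonic function of it. This formulation deviates from the one in the [Kaggle evaluation description](https://www.kaggle.com/c/data-science-bowl-2018#evaluation) because here all masks are merged into one.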

In [None]:
from keras import backend as K


# Additive smoothing term used by dice() and dice_coef() below: it is added to numerator
# and denominator, so the ratio stays defined (no division by zero) when both masks are empty.
smooth = 0.01


def dice(y_true, y_pred):
    """NumPy counterpart of ``dice_coef`` for evaluating predictions outside Keras.

    Returns ``(sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``
    over all elements (no factor 2). Both inputs may be arrays or array-likes of the
    same shape.
    """
    # np.sum runs in compiled code; the builtin sum() iterated over every pixel in Python.
    y_true_f = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred_f = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    return (intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)


def dice_coef(y_true, y_pred):
    """Differentiable (Keras backend) smoothed overlap score without the factor 2.

    Computes (sum(T * P) + smooth) / (sum(T) + sum(P) + smooth) over all elements.
    """
    true_flat, pred_flat = K.flatten(y_true), K.flatten(y_pred)
    overlap = K.sum(true_flat * pred_flat)
    denominator = K.sum(true_flat) + K.sum(pred_flat) + smooth
    return (overlap + smooth) / denominator


def dice_coef_loss(y_true, y_pred):
    """Loss to minimise: the negated ``dice_coef`` (more overlap means lower loss)."""
    score = dice_coef(y_true, y_pred)
    return -score


# Adam on the negated overlap score; the score itself, accuracy and MSE are reported as metrics.
simple_cnn.compile(
    loss = dice_coef_loss,
    optimizer = 'adam',
    metrics = [dice_coef, 'acc', 'mse'],
)

### Train the model

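Train on one image at a time: each training step uses a single image (batch size 1). One epoch consists of `min(STEPS_PER_EPOCH, number of training images)` steps, so an epoch covers all images only if `STEPS_PER_EPOCH` is at least the number of training images.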

In [None]:
def simple_gen(in_df=None):
    """Yield (image, mask) training pairs forever, one image per batch.

    Parameters
    ----------
    in_df : pandas.DataFrame, optional
        Frame with an 'images' column of (H, W, C) arrays and a 'masks' column of
        (H, W) arrays. Defaults to the global ``train_img_df`` (looked up when the
        generator is first advanced).

    Yields
    ------
    tuple of numpy.ndarray
        Image batch of shape (1, H, W, C) and mask batch of shape (1, H, W, 1).
        The rows are cycled endlessly, as needed by ``fit_generator`` with a fixed
        number of steps per epoch.

    Raises
    ------
    ValueError
        If the frame has no rows (previously this spun forever without yielding).
    """
    if in_df is None:
        in_df = train_img_df
    if len(in_df) == 0:
        raise ValueError('simple_gen needs at least one row to generate training pairs')
    while True:
        for _, c_row in in_df.iterrows():
            # add the batch axis (and, for the mask, the trailing channel axis)
            yield c_row['images'][np.newaxis], c_row['masks'][np.newaxis, :, :, np.newaxis]

# Draw one (image, mask) pair from a fresh generator to check the array shapes.
nxt = next(simple_gen())
image_batch, mask_batch = nxt
print('Elements in each generated object:', len(nxt))
print('Shape of instance data: ', image_batch.shape)
print('Shape of instance label:', mask_batch.shape)

In [None]:
%%time

import datetime
import subprocess
import time

from keras.models import load_model
import h5py

if REDO_TRAINING:
    import shutil  # portable file copy instead of shelling out to `cp`

    # One image per step; an epoch is capped at the number of training images.
    # NOTE(review): fit_generator is deprecated in newer Keras/TensorFlow releases
    # (fit accepts generators there) — kept for the Keras version this notebook targets.
    simple_cnn.fit_generator(simple_gen(), min(STEPS_PER_EPOCH, train_img_df.shape[0]), epochs = N_EPOCHS)
    # Keep a time-stamped copy and refresh 'simple_gen.h5', the file loaded when
    # REDO_TRAINING is False.
    timeStamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    saveName = 'simple_gen_' + timeStamp + '.h5'
    simple_cnn.save(saveName)
    # shutil.copyfile works on every OS and raises on failure; subprocess.call(['cp', ...])
    # required a `cp` binary and silently ignored a non-zero exit status.
    shutil.copyfile(saveName, 'simple_gen.h5')

else:
    # The custom loss/metric must be passed so Keras can deserialize the compiled model.
    simple_cnn = load_model('simple_gen.h5', custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef})

### Evaluate model performance

#### Show predictions vs labels

In [None]:
# Predict the per-pixel probability map for a few random training images.
n_img = 3

display_training_img_df = train_img_df.sample(n_img)

display_training_img_df['predictions'] = display_training_img_df['images'].map(
    lambda image: simple_cnn.predict(image[np.newaxis])[0, :, :, 0])

In [None]:
from skimage.morphology import closing, opening, disk


def clean_img(x):
    """Morphologically clean a predicted mask.

    A closing with a radius-1 disk first fills thin cracks; an opening with a
    radius-3 disk then removes small isolated dots.
    """
    closed = closing(x, disk(1))
    return opening(closed, disk(3))


fig, m_axs = plt.subplots(4, n_img, figsize = (12, 16))
# Each grid column: image, true mask, raw prediction and cleaned prediction (with scores).
for (_, d_row), column_axes in zip(display_training_img_df.iterrows(), m_axs.T):
    true_mask = d_row['masks']
    dirty_im = d_row['predictions']
    clean_im = clean_img(dirty_im)
    panels = (
        (d_row['images'], 'Microscope'),
        (true_mask, 'Label'),
        (dirty_im, 'Prediction\n Dice {:.2f}'.format(dice(true_mask, dirty_im))),
        (clean_im, 'Clean prediction\n Dice {:.2f}'.format(dice(true_mask, clean_im))),
    )
    for ax, (picture, title) in zip(column_axes, panels):
        ax.imshow(picture)
        ax.axis('off')
        ax.set_title(title)

### Apply to test set

#### Load test images

In [None]:
%%time

# One row per (Stage, ImageId) test instance, then load its single image.
test_df = img_df.query('TrainingSplit=="test"')
group_cols = ['Stage', 'ImageId']
test_rows = [
    dict(zip(group_cols, group),
         images=rows.query('ImageType == "images"')['path'].values.tolist())
    for group, rows in test_df.groupby(group_cols)
]
test_img_df = pd.DataFrame(test_rows)

test_img_df['images'] = test_img_df['images'].map(lambda paths: read(paths)[:, :, :IMG_CHANNELS])
print(test_img_df.shape[0], 'images to process')
print(test_img_df.sample(1))

#### Check test image dimensions

In [None]:
# How often each image shape (height, width, channels) occurs among the test images.
test_img_df['images'].map(lambda x: x.shape).value_counts()

#### Make predictions

In [None]:
%%time


# Per-pixel probability map for every test image (batch axis added, channel axis dropped).
test_img_df['masks'] = test_img_df['images'].map(
    lambda image: simple_cnn.predict(image[np.newaxis])[0, :, :, 0])

#### Show some predictions

In [None]:
n_img = 3

fig, m_axs = plt.subplots(3, n_img, figsize = (12, 10))
# Each grid column: test image, raw prediction and cleaned prediction.
for (_, d_row), column_axes in zip(test_img_df.sample(n_img).iterrows(), m_axs.T):
    probability_map = d_row['masks']
    panels = (
        (d_row['images'], 'Microscope'),
        (probability_map, 'Predicted'),
        (clean_img(probability_map), 'Clean'),
    )
    for ax, (picture, title) in zip(column_axes, panels):
        ax.imshow(picture)
        ax.axis('off')
        ax.set_title(title)

#### Convert predictions to RLEs

In [None]:
# Clean each predicted map and split it into one RLE per connected region.
test_img_df['rles'] = test_img_df['masks'].map(lambda probs: list(prob_to_rles(clean_img(probs))))

In [None]:
# One submission row per predicted region: the image id and its space-separated RLE.
out_pred_list = [
    dict(ImageId=c_row['ImageId'], EncodedPixels = ' '.join(np.array(c_rle).astype(str)))
    for _, c_row in test_img_df.iterrows()
    for c_rle in c_row['rles']
]

out_pred_df = pd.DataFrame(out_pred_list)
print(out_pred_df.shape[0], 'regions found for', test_img_df.shape[0], 'images')
out_pred_df.sample(3)

In [None]:
# Write the submission CSV. to_csv does not create missing directories, so make sure
# the output folder exists first.
out_path = 'results/result_cnn_single_mask.csv'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
out_pred_df[['ImageId', 'EncodedPixels']].to_csv(out_path, index = False)