# Nucleus challenge using a CNN

In [None]:
import glob
import os.path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
# Notebook-wide settings.
MAX_TRAIN_INSTANCES = None  # cap on the number of training instances collected; None keeps the entire set
REDO_TRAINING = False  # True: fit the model and save it; False: load the previously saved model
STEPS_PER_EPOCH = 10  # generator steps per epoch passed to fit_generator
N_EPOCHS = 1  # number of epochs passed to fit_generator

## Load data

In [None]:
# Root folder of the competition data; the label CSV and the stage1_* image
# folders are looked up relative to it.
dataDir = 'data/data-science-bowl-2018/'

### Load labels

In [None]:
# Load the training label table and turn every space-separated
# 'EncodedPixels' string into a list of ints.
labels_csv = os.path.join(dataDir, 'stage1_train_labels.csv/stage1_train_labels.csv')
train_labels = pd.read_csv(labels_csv)
train_labels['EncodedPixels'] = train_labels['EncodedPixels'].map(
    lambda encoded: list(map(int, encoded.split(' '))))
train_labels.head()

### Load training paths and meta info

In [None]:
# Collect every PNG matched by <dataDir>/stage1_*/<dir>/<dir>/<file>.png.
all_images = glob.glob(os.path.join(dataDir, 'stage1_*', '*', '*', '*.png'))
if not all_images:
    # Fail with a clear message instead of an IndexError further down.
    raise FileNotFoundError('no PNG files found below ' + dataDir)
img_df = pd.DataFrame({'path': all_images})


def _path_parts(in_path):
    """Split a path into its components, whatever separator the OS uses."""
    return os.path.normpath(in_path).split(os.sep)


def img_id(in_path):
    """Return the image id: the directory two levels above the file."""
    return _path_parts(in_path)[-3]


def img_type(in_path):
    """Return the name of the directory that directly contains the file."""
    return _path_parts(in_path)[-2]


def img_group(in_path):
    """Return the part after the first '_' in the 'stage1_*' directory name."""
    return _path_parts(in_path)[-4].split('_')[1]


def img_stage(in_path):
    """Return the part before the first '_' in the 'stage1_*' directory name."""
    return _path_parts(in_path)[-4].split('_')[0]


print('An exemplary data path with indices of split:')
example_parts = _path_parts(img_df['path'].iloc[0])
# Negative indices are derived from the real number of components rather than
# a hard-coded depth, so they stay correct if dataDir is nested differently.
print(*((idx - len(example_parts), part) for idx, part in enumerate(example_parts)),
      sep='\n', end='\n\n')

img_df['ImageId'] = img_df['path'].map(img_id)
img_df['ImageType'] = img_df['path'].map(img_type)
img_df['TrainingSplit'] = img_df['path'].map(img_group)
img_df['Stage'] = img_df['path'].map(img_stage)

# DataFrame.info() prints its report itself and returns None, so it is not
# wrapped in print().
img_df.info()
img_df.head()

### Create dataframe with training data (image and mask paths)

In [None]:
%%time

# Collapse the per-file rows into one record per (Stage, ImageId) that lists
# the image paths and the mask paths of that training instance.
train_df = img_df[img_df['TrainingSplit'] == 'train']
train_rows = []
group_cols = ['Stage', 'ImageId']

count = 0
for count, (group, rows) in enumerate(train_df.groupby(group_cols), start=1):
    # Stop once the optional instance cap has been reached.
    if MAX_TRAIN_INSTANCES is not None and count > MAX_TRAIN_INSTANCES:
        break
    c_row = dict(zip(group_cols, group))
    c_row['images'] = rows.loc[rows['ImageType'] == 'images', 'path'].tolist()
    c_row['masks'] = rows.loc[rows['ImageType'] == 'masks', 'path'].tolist()
    train_rows.append(c_row)

train_img_df = pd.DataFrame(train_rows)

In [None]:
# Preview the first rows of the per-instance training table.
train_img_df.head()

### Load training images

In [None]:
%%time

from skimage.io import imread


def read(in_img_list):
    """Load the single image of a training instance.

    Parameters
    ----------
    in_img_list : sequence of str
        Paths of the instance's image files; exactly one is expected.

    Returns
    -------
    The array returned by ``imread`` for that path.

    Raises
    ------
    AssertionError
        If ``in_img_list`` does not hold exactly one path.
    """
    # The message is built from len(): plain lists have no ``.shape``, so the
    # previous message raised AttributeError instead of reporting the problem.
    assert len(in_img_list) == 1, (
        'expected exactly one image for this training instance, got '
        + str(len(in_img_list)))
    return imread(in_img_list[0])

IMG_CHANNELS = 3  # keep only the first three channels (RGB)
# Replace each list of image paths by the loaded pixel array, cut down to the
# first IMG_CHANNELS channels.
train_img_df['images'] = train_img_df['images'].map(
    lambda paths: read(paths)[:, :, :IMG_CHANNELS])

## Investigate images

### Analyze intensity distributions

In [None]:
# Mean intensity of each colour channel and of all channels together,
# one value per image.
train_img_df['Red'] = train_img_df['images'].map(lambda x: np.mean(x[:, :, 0]))
train_img_df['Green'] = train_img_df['images'].map(lambda x: np.mean(x[:, :, 1]))
train_img_df['Blue'] = train_img_df['images'].map(lambda x: np.mean(x[:, :, 2]))
train_img_df['Gray'] = train_img_df['images'].map(lambda x: np.mean(x))
# Cast to float before subtracting: on unsigned-integer pixel arrays
# (presumably uint8 from imread -- verify) red - blue wraps around wherever
# blue is brighter than red, which inflates the mean difference.
train_img_df['Red-Blue'] = train_img_df['images'].map(
    lambda x: np.mean(x[:, :, 0].astype(np.float64) - x[:, :, 2]))

In [None]:
# Pairwise scatter plots of the per-image mean intensities.
intensity_columns = ['Gray', 'Red', 'Green', 'Blue', 'Red-Blue']
sns.pairplot(train_img_df[intensity_columns])

### Image dimensions

In [None]:
# Count how many images share each array shape.
train_img_df['images'].map(np.shape).value_counts()

## Analysis using a single combined mask

### Get labels

#### Load the masks and save them in dataframe

In [None]:
%%time

def read_and_stack(in_img_list):
    """Combine the mask images of one training instance into a single array.

    Every path in ``in_img_list`` is read with ``imread``; the resulting
    arrays are summed element-wise and the sum is divided by 255.0.
    """
    mask_stack = np.stack([imread(mask_path) for mask_path in in_img_list], axis=0)
    return mask_stack.sum(axis=0) / 255.0

# Replace each list of mask paths by the combined mask, converted to int.
train_img_df['masks'] = train_img_df['masks'].map(
    lambda mask_paths: read_and_stack(mask_paths).astype(int))

In [None]:
# Preview the table again; 'masks' now holds the combined mask arrays.
train_img_df.head()

#### Show some of the pictures with their labels

In [None]:
n_img = 6
# Never request more samples than there are rows: DataFrame.sample raises when
# n exceeds the frame length, which a small MAX_TRAIN_INSTANCES makes possible.
n_shown = min(n_img, len(train_img_df))
# squeeze=False keeps m_axs two-dimensional even when only one column is drawn,
# so the transposed iteration below always yields (image axis, mask axis) pairs.
fig, m_axs = plt.subplots(2, n_shown, figsize=(12, 4), squeeze=False)
for (c_row_idx, c_row), (c_im, c_lab) in zip(train_img_df.sample(n_shown).iterrows(),
                                             m_axs.T):
    # Top row: the microscope image.
    c_im.imshow(c_row['images'])
    c_im.axis('off')
    c_im.set_title('Microscope ' + str(c_row_idx))

    # Bottom row: the combined mask of the same instance.
    c_lab.imshow(c_row['masks'])
    c_lab.axis('off')
    c_lab.set_title('Labeled ' + str(c_row_idx))

### Build simple CNN

#### Set up CNN structure

In [None]:
from keras.models import Sequential
from keras.layers import BatchNormalization, Conv2D, UpSampling2D, Lambda
# Stack of convolutions that all use padding='same'; the input shape leaves
# height and width unspecified (None).
simple_cnn = Sequential([
    BatchNormalization(input_shape=(None, None, IMG_CHANNELS), name='NormalizeInput'),
    Conv2D(8, kernel_size=(3, 3), padding='same'),
    Conv2D(8, kernel_size=(3, 3), padding='same'),
    # dilated kernels give a slightly larger field of view
    Conv2D(16, kernel_size=(3, 3), dilation_rate=2, padding='same'),
    Conv2D(16, kernel_size=(3, 3), dilation_rate=2, padding='same'),
    Conv2D(32, kernel_size=(3, 3), dilation_rate=3, padding='same'),
    # final processing: 1x1 convolutions down to a single sigmoid output channel
    Conv2D(16, kernel_size=(1, 1), padding='same'),
    Conv2D(1, kernel_size=(1, 1), padding='same', activation='sigmoid'),
])
simple_cnn.summary()

#### Define custom loss to match competition objective

Use Dice score, see [here](https://arxiv.org/pdf/1707.00478.pdf).

In [None]:
from keras import backend as K
smooth = 1.  # added to numerator and denominator of the Dice ratio


def dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient of two tensors.

    Both inputs are flattened; the result is
    (2 * sum(true * pred) + smooth) / (sum(true) + sum(pred) + smooth).
    """
    truth_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)
    overlap = K.sum(truth_flat * pred_flat)
    total = K.sum(truth_flat) + K.sum(pred_flat)
    return (2. * overlap + smooth) / (total + smooth)
def dice_coef_loss(y_true, y_pred):
    """Return the negated Dice coefficient, for use as a loss."""
    score = dice_coef(y_true, y_pred)
    return -score

# Optimise the negated Dice coefficient; also report Dice, accuracy and MSE.
simple_cnn.compile(
    optimizer='adam',
    loss=dice_coef_loss,
    metrics=[dice_coef, 'acc', 'mse'],
)

### Train the model

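Use one image at a time for training: each training step processes a single image. An epoch covers every image exactly once only if `steps_per_epoch` equals the number of training images; the training cell runs `STEPS_PER_EPOCH` steps per epoch instead.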

In [None]:
def simple_gen(df=None):
    """Yield single-image training batches forever.

    Parameters
    ----------
    df : pandas.DataFrame, optional
        Frame with an 'images' and a 'masks' column holding one array per
        row. Defaults to the module-level ``train_img_df``.

    Yields
    ------
    tuple
        ``(image, mask)``: the image with a new leading axis, the mask with a
        new leading axis and a new trailing axis.

    Raises
    ------
    ValueError
        If the frame has no rows; the endless loop would otherwise spin
        without ever yielding.
    """
    if df is None:
        df = train_img_df
    if df.empty:
        raise ValueError('cannot generate batches from an empty data frame')
    while True:
        # Walk the two columns in lockstep; no per-row Series is needed.
        for image, mask in zip(df['images'], df['masks']):
            yield np.expand_dims(image, 0), np.expand_dims(np.expand_dims(mask, -1), 0)

# Pull one batch from the generator and report what it contains.
nxt = next(simple_gen())
first_images, first_labels = nxt
print('Elements in each generated object:', len(nxt))
print('Shape of instance data: ', first_images.shape)
print('Shape of instance label:', first_labels.shape)

In [None]:
%%time

from keras.models import load_model
import h5py

if not REDO_TRAINING:
    # Reuse the stored model; the custom loss and metric are passed in so they
    # can be resolved by name while loading.
    simple_cnn = load_model(
        'simple_gen.h5',
        custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef},
    )
else:
    # To pass over every training image once per epoch, use
    # steps_per_epoch=train_img_df.shape[0] instead of STEPS_PER_EPOCH.
    simple_cnn.fit_generator(simple_gen(), steps_per_epoch=STEPS_PER_EPOCH, epochs=N_EPOCHS)
    simple_cnn.save('simple_gen.h5')

### Apply to test set

## Analysis using separate masks