# Import libraries

In [None]:
import os
import sys
import warnings
import scipy.misc
import numpy as np
import math
import matplotlib.pyplot as plt
from tqdm import tqdm
from itertools import chain
from skimage.io import imread, imshow, imread_collection, concatenate_images, imsave
from skimage.transform import resize
import cv2
import matplotlib
from scipy import ndimage
from keras.models import Model, load_model
from keras.layers import *
from keras import backend as K
from keras import optimizers
from keras import regularizers
from keras.callbacks import EarlyStopping
from keras.utils import multi_gpu_model 
import tensorflow as tf
warnings.filterwarnings('ignore', category=UserWarning, module='skimage')

# define constants

In [None]:
# Network input geometry: every image is cut to IMG_CHANNELS channels and
# resized to IMG_HEIGHT x IMG_WIDTH before it reaches the model.
IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS = 256, 256, 3

# Data directories (each path ends with a '/').
TRAIN_IM, TRAIN_MASK = './train_im/', './train_mask/'
TEST_IM, TEST_MASK = './test_im/', './test_mask/'

NUM_CLASSES = 4  # mask channels: background, nuclei, chromosomes, ecDNA
NUM_GPUS = 8     # replicas for multi_gpu_model; a value of 1 skips it

# Load training and test images together with their masks

In [None]:
def _load_split(im_dir, mask_dir):
    """Load every image file in `im_dir` together with its mask from `mask_dir`.

    The mask for `<name><ext>` is read from `<mask_dir>/<name>_mask<ext>`.
    Images keep their first IMG_CHANNELS channels and masks their first
    NUM_CLASSES channels; both are resized to IMG_HEIGHT x IMG_WIDTH.

    Returns:
        (images, masks): a uint8 array of shape
        (n, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS) and a bool array of shape
        (n, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES).
    """
    # Sorted so the sample order is reproducible (Keras' validation_split
    # takes the last fraction of the arrays); non-files such as
    # sub-directories are skipped instead of crashing imread.
    filenames = sorted(f for f in os.listdir(im_dir)
                       if os.path.isfile(os.path.join(im_dir, f)))
    n = len(filenames)
    images = np.zeros((n, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
    # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
    masks = np.zeros((n, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), dtype=bool)
    sys.stdout.flush()
    for idx, filename in tqdm(enumerate(filenames), total=n):
        img = imread(os.path.join(im_dir, filename))[:, :, :IMG_CHANNELS]
        images[idx] = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant',
                             preserve_range=True)
        stem, ext = os.path.splitext(filename)
        mask = imread(os.path.join(mask_dir, stem + '_mask' + ext))[:, :, :NUM_CLASSES]
        # Nearest-neighbour resize without anti-aliasing, then threshold.
        # The default (bilinear, Gaussian anti-aliasing when downsampling)
        # leaves small non-zero values along class borders, and storing those
        # into a bool array turns every one of them True, dilating each mask.
        masks[idx] = resize(mask, (IMG_HEIGHT, IMG_WIDTH), order=0,
                            preserve_range=True, anti_aliasing=False) > 0
    return images, masks


# Load training images and masks.
X_train, Y_train = _load_split(TRAIN_IM, TRAIN_MASK)
num_train = len(X_train)

# Load test images and masks.
X_test, Y_test = _load_split(TEST_IM, TEST_MASK)
num_test = len(X_test)

# Compute a loss weight for each class

In [None]:
# Per-channel pixel counts over the whole training set. The mask channels are
# (background, nuclei, chromosomes, ecDNA). Summing the bool array over the
# image/row/column axes gives the same totals as a per-image loop.
alpha = 1  # exponent applied to the count ratios (1 = plain inverse frequency)
channel_counts = Y_train.sum(axis=(0, 1, 2))
back_count, nuc_count, chrom_count, ec_count = (int(c) for c in channel_counts[:4])
print("number of pixels for background, nuclei, chromosomes, ecDNA: ", 
      back_count, nuc_count, chrom_count, ec_count)
tot = back_count + nuc_count + chrom_count + ec_count  # sum of all channel counts

# A zero count would make a weight inf (x / 0) or nan (0 / 0) and silently
# poison the training loss, so fail loudly instead.
for label, count in (('nuclei', nuc_count), ('chromosomes', chrom_count),
                     ('ecDNA', ec_count)):
    if count == 0:
        raise ValueError(
            "no {} pixels in Y_train; cannot compute its class weight".format(label))

# Foreground weights are relative to nuclei: w_c = nuc_count**alpha / count_c**alpha.
# Computed in float so alpha > 1 cannot overflow an integer power.
# Background is fixed at 1 rather than derived from its count.
nuc_ref = float(nuc_count) ** alpha
back_w = 1
nuc_w = nuc_ref / float(nuc_count) ** alpha
chrom_w = nuc_ref / float(chrom_count) ** alpha
ec_w = nuc_ref / float(ec_count) ** alpha
weights = [back_w, nuc_w, chrom_w, ec_w]
print(weights)

# Define loss functions and evaluation metrics

In [None]:
# Custom IoU metric
from keras.metrics import binary_crossentropy

def mIoU(y_true, y_pred):
    """Mean-IoU Keras metric built on TF1's streaming `tf.metrics.mean_iou`.

    y_pred is binarised at each threshold produced by the loop below, scored
    against y_true with `tf.metrics.mean_iou(num_classes=NUM_CLASSES)`, and
    the per-threshold scores are averaged.

    NOTE(review): `np.arange(0.05, 0.1, 0.5)` yields only [0.05], so this is
    a single-threshold IoU at 0.05. If a threshold sweep was intended (e.g.
    np.arange(0.5, 1.0, 0.05)), the arguments look transposed - confirm.
    NOTE(review): mean_iou treats each element's value as a class id. With
    binarised 0/1 inputs only ids 0 and 1 can occur, despite
    num_classes=NUM_CLASSES - confirm this is the intended quantity.
    """
    prec = []
    for t in np.arange(0.05, 0.1, 0.5):
        # Binarise the predicted probabilities into 0/1 labels at threshold t.
        y_pred_ = tf.to_int32(y_pred > t)
        # TF1 streaming metric: returns the current value and an update op.
        score, up_opt = tf.metrics.mean_iou(y_true, y_pred_, num_classes=NUM_CLASSES)
        # Presumably initialises the local (streaming) variables mean_iou has
        # just created; this runs while the graph is being built - verify.
        K.get_session().run(tf.local_variables_initializer())
        # Make reading the score also run the update op.
        with tf.control_dependencies([up_opt]):
            score = tf.identity(score)
        prec.append(score)
    return K.mean(K.stack(prec), axis=-1)

def weighted_loss(original_loss, weights_list):
    """Wrap `original_loss` so each pixel is scaled by the weight of its class.

    A pixel's class is the argmax of `true` over the last (channel) axis, and
    its weight is `weights_list[class]`.

    Args:
        original_loss: callable `(true, pred) -> tensor`; assumed to return a
            tensor broadcastable against the per-pixel weight map (TODO confirm).
        weights_list: one weight per channel of `true`.

    Returns:
        A Keras-compatible loss `lossFunc(true, pred)`.
    """
    def lossFunc(true, pred):
        # Class index of every ground-truth pixel.
        pixel_class = tf.cast(K.argmax(true, axis=-1), tf.int64)
        # One weighted indicator map per class. Each pixel has one argmax
        # class, so adding the maps gives that class's weight per pixel.
        weight_maps = [
            K.cast(K.equal(tf.cast(cls_idx, tf.int64), pixel_class), K.floatx()) * cls_weight
            for cls_idx, cls_weight in enumerate(weights_list)
        ]
        pixel_weight = weight_maps[0]
        for extra in weight_maps[1:]:
            pixel_weight = pixel_weight + extra
        return original_loss(true, pred) * pixel_weight
    return lossFunc

smooth = 1.
# Soft Dice coefficient (the basis of dice_loss).
def dice_coef(y_true, y_pred):
    """Return the soft Dice coefficient between `y_true` and `y_pred`.

    Both tensors are flattened, so the overlap is pooled over the whole batch
    and every channel. `smooth` is added to the numerator and denominator so
    the ratio stays defined when both inputs are all zero.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    numerator = 2. * K.sum(truth * guess) + smooth
    denominator = K.sum(truth) + K.sum(guess) + smooth
    return numerator / denominator

def dice_loss(y_true, y_pred):
    """Dice loss: one minus the soft Dice coefficient."""
    coefficient = dice_coef(y_true, y_pred)
    return 1 - coefficient

def BCE_loss(y_true, y_pred):
    """Binary cross-entropy, delegated to Keras' `binary_crossentropy`."""
    bce = binary_crossentropy(y_true, y_pred)
    return bce

def bce_dice(y_true, y_pred):
    """Combined loss: binary cross-entropy plus Dice loss."""
    bce_term = BCE_loss(y_true, y_pred)
    dice_term = dice_loss(y_true, y_pred)
    return bce_term + dice_term

print('Weight functions compiled')

# build model function

In [None]:
def _double_conv(x, filters, l2_weight=None):
    """Apply two 3x3 'same'-padded ELU convolutions with `filters` channels.

    If `l2_weight` is given, an L2 kernel regularizer of that strength is
    attached to the second convolution only.
    """
    regularizer = regularizers.l2(l2_weight) if l2_weight is not None else None
    x = Conv2D(filters, (3, 3), activation='elu', padding='same')(x)
    return Conv2D(filters, (3, 3), activation='elu', padding='same',
                  kernel_regularizer=regularizer)(x)


def _up_block(x, skip, filters):
    """Upsample `x` 2x with a transposed conv, concatenate `skip`, then double-conv."""
    x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    x = concatenate([x, skip], axis=-1)
    return _double_conv(x, filters)


def build_model(width=32, NUM_CLASSES=4, activation='sigmoid', l2_weight=0.001):
    """Build a U-Net-style encoder/decoder segmentation model.

    Layers are created in the same order as the original hand-unrolled
    version, so weight files saved from either version load into the other.

    Args:
        width: number of filters in the first stage; doubled at each of the
            four down-sampling stages.
        NUM_CLASSES: number of output channels.
        activation: activation of the final 1x1 conv. The default 'sigmoid'
            gives an independent per-channel probability.
        l2_weight: L2 strength on the second conv of encoder stages 2-4;
            None disables it.

    Returns:
        An uncompiled keras `Model` taking (IMG_HEIGHT, IMG_WIDTH,
        IMG_CHANNELS) inputs and producing NUM_CLASSES channels at the same
        spatial size.
    """
    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    # Scale 0-255 pixel values into [0, 1] inside the graph.
    scaled = Lambda(lambda t: t / 255)(inputs)

    # Encoder: four stages, each halving resolution via max-pooling.
    c1 = _double_conv(scaled, width)
    c2 = _double_conv(MaxPooling2D((2, 2))(c1), width * 2, l2_weight)
    c3 = _double_conv(MaxPooling2D((2, 2))(c2), width * 4, l2_weight)
    c4 = _double_conv(MaxPooling2D((2, 2))(c3), width * 8, l2_weight)
    # Bottleneck.
    c5 = _double_conv(MaxPooling2D((2, 2))(c4), width * 16)

    # Decoder: mirror the encoder, concatenating the matching skip tensor.
    c6 = _up_block(c5, c4, width * 8)
    c7 = _up_block(c6, c3, width * 4)
    c8 = _up_block(c7, c2, width * 2)
    c9 = _up_block(c8, c1, width)

    outputs = Conv2D(NUM_CLASSES, (1, 1), activation=activation)(c9)
    return Model(inputs=[inputs], outputs=[outputs])

# build and compile model (multi-GPU support)

In [None]:
model = build_model()
if(NUM_GPUS > 1):
    model = multi_gpu_model(model, gpus=NUM_GPUS)
model.compile(optimizer='Adamax', loss = weighted_loss(bce_dice, weights), metrics = [mIoU])
model.summary()

# train model

In [None]:
earlystopper = EarlyStopping(patience=7, verbose=1)
history = parallel_model.fit(X_train, Y_train, validation_split=0.25, batch_size = 16, 
                             verbose=1, epochs=45, callbacks=[earlystopper])
model_out = parallel_model.layers[-2]
model_out.save_weights(filepath="./ecDNA_model.hdf5")

# plot training results

In [None]:
print(history.history.keys())


def _plot_train_val(hist, key, title, ylabel, out_path):
    """Plot `key` and `val_<key>` from a Keras History dict, save it, and show it."""
    # Start a fresh figure. Under non-interactive backends plt.show() does not
    # close the figure, so the second plot would otherwise draw on the first.
    plt.figure()
    plt.plot(hist[key])
    plt.plot(hist['val_' + key])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.savefig(out_path)
    plt.show()


# The tracked metric is IoU, not accuracy, so label the axis accordingly.
_plot_train_val(history.history, 'mIoU', 'ecDNA IoU score', 'IoU', 'IoU.png')
_plot_train_val(history.history, 'loss', 'ecDNA loss', 'loss', 'loss.png')

# Also save the full model (architecture + weights), not just the weights file

In [None]:
# Rebuild the plain (single-device) architecture, load the trained weights
# into it, and write architecture + weights to a single HDF5 file.
weights_path = "./ecDNA_model.hdf5"
model = build_model()
model.load_weights(weights_path)
model.save('ecDNA_model.h5')

# predict on holdout set

In [None]:
from os import listdir
from os.path import isfile, join

# Run the model over every file in TEST_IM and store the raw per-pixel
# output scores as NumPy arrays under ./results/.
# NOTE(review): predictions are at the model's IMG_HEIGHT x IMG_WIDTH
# resolution, not at each image's original size.
os.makedirs('./results/', exist_ok=True)
onlyfiles = [f for f in listdir(TEST_IM) if isfile(join(TEST_IM, f))]
for fname in onlyfiles:
    # Same preprocessing as the training loader: keep IMG_CHANNELS channels,
    # resize to the network's fixed input size, and store as uint8.
    img = imread(join(TEST_IM, fname))[:, :, :IMG_CHANNELS]
    img = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant',
                 preserve_range=True).astype(np.uint8)
    name = './results/' + fname
    x = np.expand_dims(img, axis=0)
    comb_pred = np.squeeze(model.predict(x, verbose=0))
    # np.save appends '.npy', so 'foo.png' is written as './results/foo.png.npy'.
    np.save(name, comb_pred)