In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
from keras import datasets, layers, models
from keras.layers import Input, Dense, Conv2D, Flatten, Activation, concatenate, Dropout, Conv2DTranspose, LeakyReLU, Add
from keras.layers import GlobalAveragePooling2D, Lambda, GlobalMaxPooling2D, MaxPooling2D
from keras.layers import BatchNormalization as BN
from keras.utils import Sequence
from keras.models import Model, Sequential
from keras.optimizers import Adam
from keras.models import load_model
from sklearn.metrics import classification_report
from keras.callbacks import ModelCheckpoint

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import os, shutil, glob
import pydicom
import cv2
import math
import pandas as pd
import json

from sklearn.model_selection import train_test_split
import albumentations as albu

# Spatial size (pixels) of the input slices used throughout the notebook.
height, width = 512, 512

In [None]:
# List every file under the Kaggle input mount so the data paths used below can be checked.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(root, name))

In [None]:
# Load the pre-split image/label arrays for the train, validation and test splits.
train_image = np.load("/kaggle/input/project/data/p_train_image.npy")
train_label = np.load("/kaggle/input/project/data/p_train_label.npy")
valid_image = np.load("/kaggle/input/project/data/p_valid_image.npy")
valid_label = np.load("/kaggle/input/project/data/p_valid_label.npy")
test_image = np.load("/kaggle/input/project/data/p_test_image.npy")
test_label = np.load("/kaggle/input/project/data/p_test_label.npy")
# Report (image shape, label shape) for each split.
print(" train data: ({}, {}) \n valid data: ({}, {})\n test data: ({}, {})".format(
    *(arr.shape for arr in (train_image, train_label,
                            valid_image, valid_label,
                            test_image, test_label))))

In [None]:
def normalize_intensity(images, offset=2048.0, scale=2048.0, chunk=256):
    """Return ``(images - offset) / scale`` as float16, computed in float32 chunk by chunk.

    Casting raw intensities to float16 *before* subtracting loses information:
    float16 values in [2048, 4096) are spaced 2 apart, so odd intensities there get
    rounded. Doing the arithmetic in float32 and casting only the (much smaller)
    normalised result keeps the precision. Chunking bounds the float32 temporary to
    ``chunk`` samples instead of a full float32 copy of the split.

    Args:
        images: array whose first axis indexes samples.
        offset: value subtracted from every element.
        scale: divisor applied after subtracting ``offset``.
        chunk: number of samples converted per step.

    Returns:
        float16 array with the same shape as ``images``.
    """
    out = np.empty(images.shape, dtype=np.float16)
    for start in range(0, len(images), chunk):
        block = images[start:start + chunk].astype(np.float32)
        out[start:start + chunk] = (block - offset) / scale
    return out


# NOTE: this cell overwrites its inputs; re-running it without reloading the data
# would normalise the images a second time.
train_image = normalize_intensity(train_image)
valid_image = normalize_intensity(valid_image)
test_image = normalize_intensity(test_image)

# Plain `bool`: the `np.bool` alias was removed in NumPy 1.24.
train_label = train_label.astype(bool)
valid_label = valid_label.astype(bool)
test_label = test_label.astype(bool)

print(" train data: ({}, {}) \n valid data: ({}, {})\n test data: ({}, {})".format(train_image.shape, train_label.shape,
                                                                                valid_image.shape, valid_label.shape,
                                                                                test_image.shape, test_label.shape,) )

In [None]:
from keras.layers import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K

class GroupNormalization(Layer):
    """Group Normalization layer.

    Splits the channel axis into ``groups`` contiguous groups of channels and
    normalises every (sample, group) slice to zero mean / unit variance over its
    remaining axes (spatial positions and the channels inside the group). It then
    applies a per-channel affine transform (``gamma`` scale, ``beta`` offset).
    Statistics never mix samples, so outputs do not depend on batch composition.

    Args:
        groups: number of channel groups; must divide the size of ``axis``.
        axis: channel axis (default -1, i.e. channels-last).
        epsilon: added to the variance for numerical stability.
        center: if True, add a learned per-channel offset ``beta``.
        scale: if True, multiply by a learned per-channel factor ``gamma``.
        beta_initializer / gamma_initializer: initializers for ``beta`` / ``gamma``.
        beta_regularizer / gamma_regularizer: optional regularizers.
        beta_constraint / gamma_constraint: optional weight constraints.
    """

    def __init__(self,
                 groups=16,
                 axis=-1,
                 epsilon=1e-5,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(GroupNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Validate the channel count and create the per-channel ``gamma``/``beta``.

        Raises:
            ValueError: if the channel axis is undefined, smaller than ``groups``,
                or not divisible by ``groups``.
        """
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')

        if dim < self.groups:
            raise ValueError('Number of groups (' + str(self.groups) + ') cannot be '
                             'more than the number of channels (' +
                             str(dim) + ').')

        if dim % self.groups != 0:
            raise ValueError('Number of groups (' + str(self.groups) + ') must '
                             'divide the number of channels (' +
                             str(dim) + ').')

        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, **kwargs):
        """Normalise ``inputs`` per (sample, channel group), then scale and shift."""
        input_shape = K.int_shape(inputs)
        ndim = len(input_shape)
        # Normalise a negative axis (e.g. -1) to its positive index.
        axis = self.axis % ndim
        channels_per_group = input_shape[axis] // self.groups
        tensor_input_shape = K.shape(inputs)

        # Split the channel axis in place into (groups, channels_per_group), e.g.
        # (B, H, W, C) -> (B, H, W, G, C // G) for axis=-1. Every other axis keeps its
        # position, so each group is a contiguous run of channels. Reshaping to
        # (B, G, H, W, C // G) instead groups chunks of rows, not channels, when the
        # channel axis is last.
        group_shape = [tensor_input_shape[i] for i in range(ndim)]
        group_shape[axis:axis + 1] = [self.groups, channels_per_group]
        grouped = K.reshape(inputs, K.stack(group_shape))

        # One mean/variance per (sample, group): reduce over every axis except the
        # batch axis (0) and the group axis. Reducing over all axes would collapse
        # this into a single whole-batch normalisation.
        group_reduction_axes = [i for i in range(ndim + 1) if i not in (0, axis)]
        mean = K.mean(grouped, axis=group_reduction_axes, keepdims=True)
        variance = K.var(grouped, axis=group_reduction_axes, keepdims=True)
        normalized = (grouped - mean) / K.sqrt(variance + self.epsilon)

        # Back to the input layout, then apply the per-channel affine parameters.
        outputs = K.reshape(normalized, tensor_input_shape)
        broadcast_shape = [1] * ndim
        broadcast_shape[axis] = input_shape[axis]
        if self.scale:
            outputs = outputs * K.reshape(self.gamma, broadcast_shape)
        if self.center:
            outputs = outputs + K.reshape(self.beta, broadcast_shape)
        return outputs

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = {
            'groups': self.groups,
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(GroupNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        """Normalisation does not change the shape."""
        return input_shape

In [None]:
def channel_attention(n_channel, ratio):
    """CBAM channel-attention factory.

    Returns a function mapping a channels-last 4-D tensor ``x`` (``n_channel``
    channels) to ``x`` re-weighted per channel. The weights are
    ``sigmoid(MLP(avg_pool(x)) + MLP(max_pool(x)))``, where the two-layer MLP
    (``n_channel // ratio`` hidden units, LeakyReLU) is shared between both pooled
    descriptors.

    Raises:
        ValueError: if ``ratio`` is not positive or does not divide ``n_channel``.
    """
    if ratio <= 0 or n_channel % ratio != 0:
        raise ValueError('Number of n_channel (' + str(n_channel) + ') must be a '
                         'multiple of the ratio (' + str(ratio) + ').')
    # Created once so both pooling branches share the same MLP weights.
    shared_dense1 = Dense(n_channel // ratio)
    shared_dense2 = Dense(n_channel)

    def call(x):
        avg_pool = GlobalAveragePooling2D()(x)
        avg_pool = shared_dense1(avg_pool)
        avg_pool = LeakyReLU(alpha=0.3)(avg_pool)
        avg_pool = shared_dense2(avg_pool)

        max_pool = GlobalMaxPooling2D()(x)
        max_pool = shared_dense1(max_pool)
        max_pool = LeakyReLU(alpha=0.3)(max_pool)
        max_pool = shared_dense2(max_pool)

        merge_out = Add()([avg_pool, max_pool])
        merge_act = Activation('sigmoid')(merge_out)
        # (B, C) -> (B, 1, 1, C): broadcasting over H and W avoids tiling with a
        # static spatial size, so inputs of any (or unknown) H x W work.
        merge_act = Lambda(lambda t: K.expand_dims(K.expand_dims(t, axis=1), axis=1))(merge_act)

        output = Lambda(lambda t: t[0] * t[1])([x, merge_act])
        return output
    return call

In [None]:
def spatial_attention(kernel_size):
    """CBAM spatial-attention factory.

    Returns a function mapping a channels-last 4-D tensor ``x`` to ``x`` multiplied
    by a single-channel sigmoid map. The map is computed by a ``kernel_size``
    convolution over the channel-wise mean and max of ``x`` stacked as 2 channels.
    """
    def call(x):
        # keepdims=True keeps a trailing axis of size 1, so the two maps stack into
        # (B, H, W, 2) instead of being concatenated along the width.
        avg_pool = Lambda(lambda i: K.mean(i, axis=3, keepdims=True))(x)
        max_pool = Lambda(lambda i: K.max(i, axis=3, keepdims=True))(x)
        concate = concatenate([avg_pool, max_pool], axis=-1)
        con_out = Conv2D(filters=1, kernel_size=kernel_size, padding='same')(concate)
        con_act = Activation('sigmoid')(con_out)

        # (B, H, W, C) * (B, H, W, 1) broadcasts the attention map across channels.
        output = Lambda(lambda t: t[0] * t[1])([x, con_act])
        return output
    return call

In [None]:
def cbam_block(ratio = 8):
    """CBAM block factory: channel attention followed by spatial attention (7x7).

    Returns a function taking a channels-last 4-D tensor and returning a tensor of
    the same shape. ``ratio`` must divide the input's channel count (see
    ``channel_attention``).
    """
    def call(x):
        # Static channel count; K.int_shape works for both Keras and tf.keras tensors.
        n_channel = K.int_shape(x)[-1]
        channel = channel_attention(n_channel, ratio)(x)
        refined_feature = spatial_attention(kernel_size=7)(channel)
        return refined_feature
    return call

In [None]:
def convBlock(n_filter, kernel_size):
    """Factory for two stacked (Conv2D 'same' -> GroupNormalization -> ReLU) stages.

    The two Conv2D layers are created when the factory is called. The normalisation
    and activation layers are created each time the returned function is applied.
    """
    convolutions = [Conv2D(filters=n_filter, kernel_size=kernel_size, padding='same')
                    for _ in range(2)]

    def call(x):
        features = x
        for conv in convolutions:
            features = conv(features)
            features = GroupNormalization()(features)
            features = Activation('relu')(features)
        return features
    return call
    

In [None]:
def ResBlock(n_filter, kernel_size, with_conv_shortcut = False):
    """Residual block factory: convBlock -> CBAM, added to a shortcut.

    Args:
        n_filter: filters of the conv path (must be a multiple of 16, the CBAM ratio).
        kernel_size: kernel size for the conv path and for the projection shortcut.
        with_conv_shortcut: if True, project ``x`` with Conv2D + GroupNormalization
            before adding. Otherwise ``x`` is added unchanged, which requires it to
            already have ``n_filter`` channels.
    """
    def call(x):
        conv_out = convBlock(n_filter=n_filter, kernel_size=kernel_size)(x)
        # Attention refines the conv features (previously it was applied to x and
        # conv_out was discarded).
        cbam_out = cbam_block(ratio=16)(conv_out)

        if with_conv_shortcut:
            shortcut = Conv2D(filters=n_filter, kernel_size=kernel_size, padding='same')(x)
            shortcut = GroupNormalization()(shortcut)
        else:
            shortcut = x

        # Add is a layer: instantiate it, then call it on the list of inputs.
        output = Add()([cbam_out, shortcut])
        return output
    return call

In [None]:
def unet_gn(n_filter = 4, input_size = (512,512,1)):
    """U-Net with GroupNormalization conv blocks and a prototype-assembly output head.

    Args:
        n_filter: log2 of the base filter count (the first block has 2**n_filter
            filters; GroupNormalization's 16 groups require 2**n_filter >= 16).
        input_size: input shape without the batch axis, e.g. (512, 512, 1) or
            (None, None, 1). H and W must be divisible by 16 (four 2x2 poolings
            followed by four stride-2 upsamplings that are concatenated back).

    Returns:
        A keras ``Model`` whose output layer ``'seg'`` has shape (batch, H, W, 2):
        a softmax over two classes at every pixel.

    Head: six sigmoid "prototype" maps (layer ``'prototype'``) are combined per
    class with 2 x 6 tanh coefficients (layer ``'reg_proto'``) predicted from
    globally pooled features. The coefficients are broadcast over H x W rather than
    tiled with the global ``height``/``width``, so any input size works. Layers
    holding weights are created in the same order as before, so existing weight
    files still load.
    """
    inputs = Input(input_size)
    # contracting path
    conv1 = convBlock(n_filter=2**n_filter, kernel_size=(3,3))(inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = convBlock(n_filter=2**(n_filter+1), kernel_size=(3,3))(pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = convBlock(n_filter=2**(n_filter+2), kernel_size=(3,3))(pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = convBlock(n_filter=2**(n_filter+3), kernel_size=(3,3))(pool3)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = convBlock(n_filter=2**(n_filter+3), kernel_size=(3,3))(pool4)

    # expansive path
    u6 = Conv2DTranspose(filters=2**(n_filter+3), kernel_size=(2,2), strides=(2,2))(conv5)
    concate6 = concatenate([conv4, u6], axis = 3)
    conv6 = convBlock(n_filter=2**(n_filter+3), kernel_size=(3,3))(concate6)

    u7 = Conv2DTranspose(filters=2**(n_filter+2), kernel_size=(2,2), strides=(2,2))(conv6)
    concate7 = concatenate([conv3, u7], axis = 3)
    conv7 = convBlock(n_filter=2**(n_filter+2), kernel_size=(3,3))(concate7)

    u8 = Conv2DTranspose(filters=2**(n_filter+1), kernel_size=(2,2), strides=(2,2))(conv7)
    concate8 = concatenate([conv2, u8], axis =3)
    conv8 = convBlock(n_filter=2**(n_filter+1), kernel_size=(3,3))(concate8)

    u9 = Conv2DTranspose(filters=2**(n_filter), kernel_size=(2,2), strides=(2,2))(conv8)
    concate9 = concatenate([conv1, u9], axis =3)
    conv9 = convBlock(n_filter=2**(n_filter), kernel_size=(3,3))(concate9)

    # protonet: 6 prototype maps, (B, H, W, 6)
    prototypes = Conv2D(filters=6, kernel_size=(1,1), activation = 'sigmoid', padding = 'same', name='prototype')(conv9)
    # per-image coefficients: (B, 12) after global pooling
    coefficients = Conv2D(filters=12, kernel_size=(3,3), padding='same')(conv9)
    coefficients = LeakyReLU(alpha=0.4)(coefficients)
    coefficients = Conv2D(filters=12, kernel_size=(3,3), padding='same')(coefficients)
    coefficients = LeakyReLU(alpha=0.4)(coefficients)
    coefficients = GlobalAveragePooling2D()(coefficients)
    coefficients = Activation('tanh', name='reg_proto')(coefficients)
    # (B, 12) -> (B, 1, 1, 2, 6): 2 classes x 6 prototype weights, with singleton
    # spatial axes so they broadcast over every pixel.
    coefficients = Lambda(lambda x: K.reshape(x, (-1, 1, 1, 2, 6)))(coefficients)
    # (B, H, W, 6) -> (B, H, W, 1, 6) so it broadcasts against the 2 classes.
    prototypes = Lambda(lambda x: K.expand_dims(x, axis=3))(prototypes)

    # Per-class weighted sum of prototypes: (B, H, W, 2, 6) -> (B, H, W, 2).
    assembly = Lambda(lambda x: K.sum(x[0] * x[1], axis=-1))([coefficients, prototypes])
    assembly = Lambda(lambda x: K.softmax(x, axis=-1), name='seg')(assembly)

    model = Model(inputs=[inputs], outputs=[assembly])
    return model

In [None]:
import keras.backend as K

# Spatial size (pixels) of the slices the loss/metric cell was written for.
height, width = 512, 512

def bce(y_true, y_pred):
    """Binary cross-entropy, averaged over the last axis.

    Predictions are clipped into [1e-11, 1 - K.epsilon()] before taking logs. The
    previous upper bound 0.99999999 rounds to exactly 1.0 in float32, so
    ``log(1 - y_pred)`` could still evaluate log(0) = -inf. With the default
    epsilon (1e-7), ``1 - K.epsilon()`` is a float32 value strictly below 1.
    """
    y_sigmoid_pred = K.clip(y_pred, 1e-11, 1.0 - K.epsilon())
    return -K.mean(y_true * K.log(y_sigmoid_pred) + (1 - y_true) * K.log(1 - y_sigmoid_pred), axis=-1)

def _masked_dice(y_true, y_pred, channel):
    """Mean per-sample soft Dice on one channel, over samples with a non-empty target.

    For each sample: ``2 * sum(t * p) / (sum(t) + sum(p) + 1e-6)`` on
    ``[:, :, :, channel]``. Samples whose ground truth is all zero get weight 0
    (their max is 0), so they do not drag the score down. Returns a scalar; if the
    batch has no positive sample at all, the result is 0.
    """
    # Labels may arrive as bool/int; match the prediction dtype before multiplying.
    y_true = K.cast(y_true, K.dtype(y_pred))
    # (B, H, W) -> (B, H*W) without relying on the global height/width.
    true_flat = K.batch_flatten(y_true[:, :, :, channel])
    pred_flat = K.batch_flatten(y_pred[:, :, :, channel])
    intersection = K.sum(true_flat * pred_flat, axis=1, keepdims=True)
    true_sum = K.sum(true_flat, axis=1, keepdims=True)
    pred_sum = K.sum(pred_flat, axis=1, keepdims=True)
    per_sample = 2 * intersection / (true_sum + pred_sum + 1e-6)
    has_target = K.max(true_flat, axis=1, keepdims=True)
    return K.sum(has_target * per_sample) / (K.sum(has_target) + 1e-6)


def dice_metric(y_true, y_pred):
    """Dice score on channel 0 (target), averaged over samples containing target."""
    return _masked_dice(y_true, y_pred, 0)


def back_dice_metric(y_true, y_pred):
    """Dice score on channel 1 (background), averaged over samples containing it."""
    return _masked_dice(y_true, y_pred, 1)


def dice(y_true, y_pred):
    """Dice loss on channel 0: ``1 - dice_metric``.

    A batch with no target pixels yields a constant 1.0 (no gradient).
    """
    return 1.0 - _masked_dice(y_true, y_pred, 0)

def ce_catogory(y_true, y_pred):
    """Categorical cross-entropy term ``-mean(y_true * log(y_pred))`` over all elements.

    Predictions are clipped into [1e-11, 0.99999999] so the log never sees 0.
    """
    safe_pred = K.clip(y_pred, 1e-11, 0.99999999)
    log_likelihood = y_true * K.log(safe_pred)
    return -K.mean(log_likelihood)

def weight_tar_back_ce(y_true, y_pred, target_weight=1.0, back_weight=0.0):
    """Per-pixel binary cross-entropy on the target and background channels.

    Channel 0 is treated as the target and channel 1 as the background. The loss is
    ``-(target_weight * target_ll + back_weight * back_ll)`` per pixel, shaped
    (batch, H*W, 1).

    The defaults (1.0, 0.0) reproduce the previous behaviour, which returned only the
    target term. ``target_weight=0.75, back_weight=0.25`` gives the weighted variant
    the earlier code computed but did not return.

    Predictions are clipped away from 0 and 1 so ``log`` never yields -inf/NaN
    (the previous version did not clip).
    """
    y_true = K.cast(y_true, K.dtype(y_pred))
    y_pred = K.clip(y_pred, 1e-11, 1.0 - K.epsilon())

    def flatten(t):
        # (B, H, W) -> (B, H*W, 1), independent of the global height/width.
        return K.expand_dims(K.batch_flatten(t), axis=-1)

    def log_likelihood(channel):
        true_c = flatten(y_true[:, :, :, channel])
        pred_c = flatten(y_pred[:, :, :, channel])
        return true_c * K.log(pred_c) + (1 - true_c) * K.log(1 - pred_c)

    loss = -target_weight * log_likelihood(0)
    if back_weight:
        loss = loss - back_weight * log_likelihood(1)
    return loss

In [None]:
import time

# One timestamp for this run. It is reused in the checkpoint filename so the name is
# consistent (a second strftime call could land in a different second).
current_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())

num_epoch = 100
batch = 8
model_fileName = current_time + "_{}_{}_model.h5".format(num_epoch, batch)
self_callback = [
    # Save weights only, and only when the monitored quantity improves
    # (Keras' default monitor is 'val_loss').
    ModelCheckpoint(model_fileName, verbose=0, save_best_only=True, save_weights_only=True)
]

In [None]:
# Reset the Keras backend's global state before building the model, so re-running
# the cells below starts from a fresh graph instead of accumulating old layers.
K.clear_session()

In [None]:
# Build the U-Net with 2**5 = 32 base filters and unspecified spatial size, then
# compile it with the Dice loss on the 'seg' output plus target/background Dice metrics.
# NOTE(review): `lr` is the legacy Adam keyword; newer Keras releases call it `learning_rate`.
unet_model = unet_gn(n_filter=5, input_size = (None, None,1))
unet_model.compile(optimizer=Adam(lr=1e-4), loss={'seg': dice}, metrics={'seg': [dice_metric, back_dice_metric]})

In [None]:
# Layer-by-layer summary of the compiled model (shapes and parameter counts).
unet_model.summary()

In [None]:
# Train on the train split, validating each epoch on the valid split; the
# ModelCheckpoint in self_callback writes the best weights to disk.
history = unet_model.fit(
    train_image,
    train_label,
    validation_data=(valid_image, valid_label),
    batch_size=batch,
    epochs=num_epoch,
    callbacks=self_callback,
)

In [None]:
import time
def plot_history(histiry):
    """Plot training/validation Dice and loss curves side by side and save a PNG.

    Args:
        histiry: the ``History`` object returned by ``Model.fit`` (the parameter name
            is kept for backward compatibility). Its ``.history`` dict must contain
            'dice_metric', 'val_dice_metric', 'loss' and 'val_loss'.

    Returns:
        The matplotlib Figure, also saved as ``<timestamp>_history.png``.
    """
    # Use the argument (the previous body silently read the global `history`).
    curves = histiry.history
    fig, (ax_dice, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))

    ax_dice.plot(curves['dice_metric'])
    ax_dice.plot(curves['val_dice_metric'])
    ax_dice.set_title('model dice')
    ax_dice.set_ylabel('dice')
    ax_dice.set_xlabel('epoch')
    ax_dice.legend(['train', 'valid'], loc='upper left')

    ax_loss.plot(curves['loss'])
    ax_loss.plot(curves['val_loss'])
    ax_loss.set_title('model loss')
    ax_loss.set_ylabel('loss')
    ax_loss.set_xlabel('epoch')
    ax_loss.legend(['train', 'valid'], loc='upper left')

    current_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    fig.savefig(current_time + "_history.png")
    return fig
plot_history(history);

In [None]:
def construct_model(n_filter=5, input_size=(512, 512, 1)):
    """Rebuild the trained architecture so saved weights can be loaded into it.

    ``n_filter`` must match the value used at training time (the training cell
    builds ``unet_gn(n_filter=5, ...)``); the spatial part of ``input_size`` does not
    affect the weights. The previous body called an undefined ``unet`` with
    ``n_filter=4``, which could not load the trained checkpoint.
    """
    model = unet_gn(n_filter=n_filter, input_size=input_size)
    return model
def load_model(weight_file):
    """Build a fresh model with ``construct_model()`` and load ``weight_file`` into it.

    NOTE: this name shadows ``keras.models.load_model`` imported at the top of the
    notebook.
    """
    print("loading model")
    restored = construct_model()
    restored.load_weights(weight_file)
    print("loaded model")
    return restored

In [None]:
import time
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches
def plot_exp(dataset, datalabel, upper_bound=-1):
    """Overlay ground truth and the global ``unet_model``'s prediction on input slices.

    Draws a 5x5 grid. Panel 1 holds the colour legend; each following panel shows
    one slice in grey, channel 0 of its label in magenta and channel 0 of the
    prediction in yellow. The figure is saved as ``<timestamp>.png``.

    Args:
        dataset: image array; ``dataset[i]`` is reshaped to (H, W) with H, W taken
            from ``dataset.shape[1:3]`` (assumes a single channel).
        datalabel: label array indexed as ``datalabel[i, :, :, 0]``.
        upper_bound: slice end applied to ``dataset`` (``dataset[:upper_bound]``), so
            the default -1 skips the last sample. At most 24 samples are drawn
            because the grid has only 24 free panels (more used to raise in
            ``plt.subplot``).
    """
    max_panels = 5 * 5 - 1
    n_shown = min(len(dataset[:upper_bound]), max_panels)
    height_px, width_px = dataset.shape[1], dataset.shape[2]
    # Only run the model on the samples that will actually be drawn.
    predict = unet_model.predict(dataset[:n_shown]) if n_shown else None

    cmap1 = ListedColormap(["gray", "m"])
    cmap2 = ListedColormap(["gray", "yellow"])

    plt.figure(figsize=(20, 20))
    plt.subplot(5, 5, 1)
    plt.imshow(np.zeros((height_px, width_px)), cmap='gray')
    purple_patch = mpatches.Patch(color='m', alpha=0.6, label='y_true')
    yellow_patch = mpatches.Patch(color='y', alpha=0.4, label='y_pred')
    orange_patch = mpatches.Patch(color='orange', alpha=0.5, label='overlap')
    plt.legend(handles=[purple_patch, yellow_patch, orange_patch])

    for i in range(n_shown):
        plt.subplot(5, 5, i + 2)
        plt.imshow(dataset[i].reshape(height_px, width_px), cmap='gray')
        plt.imshow(datalabel[i, :, :, 0].reshape(height_px, width_px), cmap=cmap1, alpha=0.8)
        plt.imshow(predict[i, :, :, 0].reshape(height_px, width_px), cmap=cmap2, alpha=0.6)

    current_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    plt.savefig(current_time + ".png")

In [None]:
# Visualise model predictions against ground truth on the test split.
plot_exp(test_image, test_label)

In [None]:
# Spot-check validation slices 3 and 16: image (left) next to label channel 0 (right).
# Labels carry two channels (the metrics index [..., 0] and [..., 1]), so reshaping a
# whole label to (512, 512) raises ValueError; show the target channel instead.
fig, axes = plt.subplots(2, 2, figsize=(15, 15))
for row, idx in enumerate((3, 16)):
    axes[row, 0].imshow(valid_image[idx].reshape(512, 512), cmap='gray')
    axes[row, 1].imshow(valid_label[idx][..., 0], cmap='gray')