In [7]:
import glob
import os
import shutil
import subprocess
import tempfile

import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import tensorflow as tf
from focal_loss import BinaryFocalLoss
from tensorflow.keras.losses import BinaryFocalCrossentropy

print(tf.__version__)

# `nvidia-smi` is only present where an NVIDIA driver is installed; look it up
# first instead of letting the shell fail with "command not found".
nvidia_smi = shutil.which('nvidia-smi')
if nvidia_smi:
    gpu_info = subprocess.run([nvidia_smi], capture_output=True, text=True).stdout
else:
    gpu_info = 'nvidia-smi not found on PATH'
print(gpu_info)

# Stable (non-experimental) device-listing API.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

2.17.0
zsh:1: command not found: nvidia-smi
Num GPUs Available:  0


In [31]:
# Dropout rate used by the Dropout layer inside every residual block of the U-Net.
Dropout = 0.1

# Convolution filters per U-Net level (f1 = shallowest, f6 = bottleneck);
# the width doubles at each level: 10, 20, 40, 80, 160, 320.
f1, f2, f3, f4, f5, f6 = (10 * 2 ** level for level in range(6))

# HU window mapped onto [0, 1] by normalize (ICH and SAH models).
HU_min, HU_max = -500.0, 500.0

# Narrower HU window mapped onto [0, 1] by normalize2 (IVH model).
HU_min2, HU_max2 = 0.0, 200.0

# Normalization for ICH and SAH models
def normalize(image, hu_min=None, hu_max=None):
    """Map HU values linearly onto [0, 1], clipping everything outside the window.

    Args:
        image: array-like of Hounsfield-unit values (any shape).
        hu_min: lower edge of the window; defaults to the module-level HU_min,
            read at call time.
        hu_max: upper edge of the window; defaults to the module-level HU_max,
            read at call time.

    Returns:
        A new array of the same shape (the input is not modified). With the
        default Python-float bounds, float32 input stays float32. NaN stays NaN.

    Raises:
        ValueError: if the upper bound does not exceed the lower bound.
    """
    lo = HU_min if hu_min is None else hu_min
    hi = HU_max if hu_max is None else hu_max
    if hi <= lo:
        raise ValueError(f'HU window upper bound ({hi}) must exceed lower bound ({lo})')
    scaled = (np.asarray(image) - lo) / (hi - lo)
    if isinstance(scaled, np.ndarray):
        # Clip in place: no extra full-size temporaries (the boolean masks the
        # old version built), which matters when normalizing the whole dataset.
        np.clip(scaled, 0.0, 1.0, out=scaled)
        return scaled
    # 0-d input produces a NumPy scalar, which cannot be an `out` target.
    return np.clip(scaled, 0.0, 1.0)

# Normalization for IVH model
def normalize2(image, hu_min=None, hu_max=None):
    """Map HU values linearly onto [0, 1] using the IVH window, clipping outside it.

    Args:
        image: array-like of Hounsfield-unit values (any shape).
        hu_min: lower edge of the window; defaults to the module-level HU_min2,
            read at call time.
        hu_max: upper edge of the window; defaults to the module-level HU_max2,
            read at call time.

    Returns:
        A new array of the same shape (the input is not modified). With the
        default Python-float bounds, float32 input stays float32. NaN stays NaN.

    Raises:
        ValueError: if the upper bound does not exceed the lower bound.
    """
    lo = HU_min2 if hu_min is None else hu_min
    hi = HU_max2 if hu_max is None else hu_max
    if hi <= lo:
        raise ValueError(f'HU window upper bound ({hi}) must exceed lower bound ({lo})')
    scaled = (np.asarray(image) - lo) / (hi - lo)
    if isinstance(scaled, np.ndarray):
        # Clip in place: no extra full-size temporaries (the boolean masks the
        # old version built), which matters when normalizing the whole dataset.
        np.clip(scaled, 0.0, 1.0, out=scaled)
        return scaled
    # 0-d input produces a NumPy scalar, which cannot be an `out` target.
    return np.clip(scaled, 0.0, 1.0)


# smooth: additive term in dice_coeff's numerator and denominator.
# alpha, gamma: parameters passed to the binary focal loss.
smooth, alpha, gamma = 0.01, 0.25, 2.0

def dice_coeff(y_true, y_pred):
    """Soft Dice coefficient pooled over the whole batch.

    Both tensors are flattened, so every pixel of every sample contributes to a
    single batch-level score (not a mean of per-sample Dice values). The
    module-level `smooth` is added to numerator and denominator so the score
    stays defined when both masks are empty.

    y_true is cast to y_pred's dtype first: masks may arrive as bool/int (the
    loading code stores them as np.bool_), and TF cannot multiply bool by float.
    """
    y_pred_f = tf.reshape(y_pred, [-1])
    y_true_f = tf.reshape(tf.cast(y_true, y_pred_f.dtype), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return score

def dice_loss(y_true, y_pred):
    """Dice loss: one minus dice_coeff, so perfect overlap gives (almost) zero."""
    coeff = dice_coeff(y_true, y_pred)
    return 1 - coeff

# Binary focal loss used by combined_loss (IVH model).
# NOTE(review): per the Keras docs, `alpha` is only applied when
# apply_class_balancing=True (default False), so alpha has no effect as
# configured here. Confirm whether class balancing was intended.
focal_loss = BinaryFocalCrossentropy(
    apply_class_balancing=False,
    alpha=alpha,
    gamma=gamma,
)

def combined_loss(y_true, y_pred):
    """Dice loss plus binary focal loss (the loss noted for IVH model training)."""
    return dice_loss(y_true, y_pred) + focal_loss(y_true, y_pred)


In [None]:

def _conv3x3(filters, strides=(1, 1)):
    """Return a 3x3 ReLU Conv2D layer with 'same' padding (the conv used throughout the U-Net)."""
    return tf.keras.layers.Conv2D(filters, (3, 3), activation='relu', padding='same', strides=strides)


def _residual_block(x, filters, dropout):
    """Conv -> Dropout -> Conv -> Conv, add the first conv's output back, then one more Conv.

    Spatial size is preserved and the output has `filters` channels. Layers are
    created in the same order as the original hand-unrolled code, which should
    keep Keras' auto-generated layer names unchanged.
    """
    first = _conv3x3(filters)(x)
    y = tf.keras.layers.Dropout(dropout)(first)
    y = _conv3x3(filters)(y)
    y = _conv3x3(filters)(y)
    y = tf.keras.layers.add([y, first])
    return _conv3x3(filters)(y)


def get_model(input_shape=(512, 512, 1), filters=None, dropout=None):
    """Build the residual U-Net used for the ICH/IVH/SAH segmentation models.

    Encoder: the input is BatchNormalized, then every level except the last runs
    a residual block followed by a stride-2 3x3 conv and BatchNormalization
    (halving H and W). The last level is the bottleneck residual block.
    Decoder: per level, a stride-2 Conv2DTranspose, concatenation with the
    matching encoder output, BatchNormalization and a residual block.
    A final 3x3 conv + BatchNormalization feeds a 1x1 sigmoid conv (1 channel).

    Args:
        input_shape: (H, W, C) of the input. H and W must be divisible by
            2 ** (len(filters) - 1) so the skip connections line up.
        filters: filter counts per level, shallowest first; defaults to the
            module-level (f1, f2, f3, f4, f5, f6), read at call time.
        dropout: Dropout rate inside each residual block; defaults to the
            module-level Dropout, read at call time.

    Returns:
        An uncompiled tf.keras.Model.

    Raises:
        ValueError: if `filters` is empty or H/W are not suitably divisible.
    """
    if filters is None:
        filters = (f1, f2, f3, f4, f5, f6)
    if dropout is None:
        dropout = Dropout
    filters = tuple(filters)
    if not filters:
        raise ValueError('filters must contain at least one level')
    factor = 2 ** (len(filters) - 1)
    if any(dim is not None and dim % factor for dim in input_shape[:2]):
        raise ValueError(f'input height/width {tuple(input_shape[:2])} must be divisible by {factor}')

    inputs = tf.keras.layers.Input(shape=input_shape)
    x = tf.keras.layers.BatchNormalization()(inputs)

    # Contraction path: keep each level's residual output for the skip connections.
    skips = []
    for n_filters in filters[:-1]:
        level_out = _residual_block(x, n_filters, dropout)
        skips.append(level_out)
        x = _conv3x3(n_filters, strides=(2, 2))(level_out)
        x = tf.keras.layers.BatchNormalization()(x)

    # Bottleneck.
    x = _residual_block(x, filters[-1], dropout)

    # Expansive path: deepest skip first.
    for n_filters, skip in zip(reversed(filters[:-1]), reversed(skips)):
        x = tf.keras.layers.Conv2DTranspose(n_filters, (2, 2), activation='relu', strides=(2, 2), padding='same')(x)
        x = tf.keras.layers.concatenate([x, skip])
        x = tf.keras.layers.BatchNormalization()(x)
        x = _residual_block(x, n_filters, dropout)

    x = _conv3x3(filters[0])(x)
    x = tf.keras.layers.BatchNormalization()(x)
    outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(x)

    return tf.keras.Model(inputs=[inputs], outputs=[outputs])

model_name = 'ICH'  # one of 'ICH','IVH','SAH'
if model_name not in ('ICH', 'IVH', 'SAH'):
    raise ValueError(f"model_name must be one of 'ICH', 'IVH', 'SAH'; got {model_name!r}")

model = get_model()

# ICH and SAH models train on Dice loss alone; the IVH model on Dice + focal loss.
# Deriving it from model_name keeps the loss in sync with the checkpoint name.
training_loss = combined_loss if model_name == 'IVH' else dice_loss
model.compile(optimizer='adam', loss=training_loss, metrics=[dice_coeff])
model.summary()

# Save the model to <model_name>.h5 whenever val_loss reaches a new minimum.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_name + '.h5',
    monitor='val_loss',
    mode='min',
    verbose=1,
    save_best_only=True,
)

In [None]:
images = []
masks = []

# Sorted so the slice order (and therefore which patients fall into the
# validation_split tail) is reproducible; os.listdir order is arbitrary.
# Only visible sub-directories are kept: the loader reads data/<patient>/...,
# so stray entries such as .DS_Store or .ipynb_checkpoints would crash it.
patients = sorted(
    name for name in os.listdir('data/')
    if not name.startswith('.') and os.path.isdir(os.path.join('data/', name))
)
print(len(patients))

In [None]:
def _load_slices(path):
    """Load a NIfTI file, move axis 2 to the front and add a trailing channel axis.

    For data stored as (X, Y, Z) this returns an array of shape (Z, X, Y, 1).
    """
    volume = np.array(nib.load(path).dataobj)
    return np.expand_dims(np.rollaxis(volume, 2, 0), 3)


def _resize_slices(slices, interpolation):
    """Resize every (H, W, 1) slice to 512x512 with cv2.resize; returns (N, 512, 512, 1) float32."""
    resized = [
        cv2.resize(s[:, :, 0].astype(np.float32), (512, 512), interpolation=interpolation)
        for s in slices
    ]
    return np.stack(resized)[..., np.newaxis]


# Reset here so re-running this cell does not append every patient twice.
images = []
masks = []

for patient in patients:
    patient_dir = os.path.join('data', patient)

    image = _load_slices(os.path.join(patient_dir, patient.split('.')[0] + '.nii')).astype(np.float32)
    if image.shape[1:] != (512, 512, 1):
        print('RESIZING')
        image = _resize_slices(image, cv2.INTER_LINEAR)  # cv2.resize's default
        print(image.shape)

    seg = _load_slices(os.path.join(patient_dir, 'segmentation.nii')).astype(np.bool_)
    if seg.shape[1:] != (512, 512, 1):
        print('RESIZING')
        # Nearest-neighbour keeps labels binary. Bilinear + astype(bool) would
        # turn every partially covered boundary pixel into foreground (dilation).
        seg = _resize_slices(seg, cv2.INTER_NEAREST).astype(np.bool_)
        print(seg.shape)

    # Catch misaligned image/mask volumes here rather than as silently
    # shifted training pairs after concatenation.
    if len(image) != len(seg):
        raise ValueError(f'{patient}: {len(image)} image slices but {len(seg)} mask slices')

    images.append(image)
    masks.append(seg)
    print(patient, ' done')

In [None]:
x = np.concatenate(images)
y = np.concatenate(masks)

# Pick the HU window for the model being trained: ICH and SAH use normalize
# (-500..500 HU), IVH uses normalize2 (0..200 HU). Driven by model_name so it
# cannot drift from the compiled loss / checkpoint name.
x = normalize2(x) if model_name == 'IVH' else normalize(x)

assert len(x) == len(y), 'image and mask slice counts differ'
print(x.shape, y.shape)

In [None]:
# Rotation augmentation: every slice and its mask are added at 0, 90, 180 and
# 270 degrees, rotating in the plane of axes 1 and 2 of the batched slice.
# The four copies of a slice stay adjacent in the lists.
images = []
masks = []

for i in range(len(x)):
    base_image = np.expand_dims(x[i], 0)
    images.extend(np.rot90(base_image, quarter_turns, (1, 2)) for quarter_turns in range(4))
    base_mask = np.expand_dims(y[i], 0)
    masks.extend(np.rot90(base_mask, quarter_turns, (1, 2)) for quarter_turns in range(4))

print(len(images), len(masks))

In [None]:
# Stack the per-slice (1, H, W, 1) arrays along the batch axis.
x = np.concatenate(images, axis=0)
y = np.concatenate(masks, axis=0)

In [None]:
# Per the Keras Model.fit docs, validation_split holds out the last 20% of x/y
# (selected before shuffling); only the remaining samples are shuffled.
history = model.fit(
    x,
    y,
    batch_size=32,
    epochs=200,
    validation_split=0.2,
    callbacks=[checkpoint],
)