<a href="https://colab.research.google.com/github/outofray/3D_Unet_CT/blob/main/3D_Unet_CT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import keras
from keras.models import Model
from keras.layers import Dense, Dropout, Activation, Flatten, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, Conv3D, MaxPooling3D, Conv3DTranspose
from keras.layers import Input, merge, UpSampling2D,BatchNormalization
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

import matplotlib.pyplot as plt
import skimage.io as io
from glob import glob
import pandas as pd
import nibabel as nib
import numpy as np
import random as r
import cv2

In [None]:
import os
from google.colab import drive

# Mount Google Drive so the data paths under /content/gdrive used below resolve.
drive.mount("/content/gdrive")

Mounted at /content/gdrive


In [None]:
!pip install simpleitk

Collecting simpleitk
  Downloading SimpleITK-2.1.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (48.4 MB)
[K     |████████████████████████████████| 48.4 MB 31 kB/s 
[?25hInstalling collected packages: simpleitk
Successfully installed simpleitk-2.1.1


In [None]:
import SimpleITK as sitk

# Convert CT image to Array 

In [None]:
def to_array(path, end, z_slice_number=4):
    """Load image volumes matching ``path + end`` and cut them into fixed-depth chunks.

    Every matched file is read with SimpleITK, standardized per volume
    (mean subtracted, divided by the standard deviation) and split along
    its first axis into consecutive, non-overlapping chunks of
    ``z_slice_number`` slices.

    Args:
        path: Path prefix; concatenated with ``end`` to build the glob pattern.
        end: Pattern suffix appended to ``path`` (e.g. a file extension).
        z_slice_number: Number of slices per chunk.

    Returns:
        A float32 array holding one entry per chunk; each chunk has a leading
        singleton axis, i.e. shape ``(1, z_slice_number, H, W)`` for a 3D
        volume. If nothing is matched (or no volume is deep enough) the result
        is an empty array of shape ``(0,)``.
    """
    # get locations
    files = glob(path + end, recursive=True)

    # Fixed seed: the same list length is always shuffled the same way.
    r.seed(42)
    r.shuffle(files)

    chunks = []
    for file in files:
        image = sitk.ReadImage(file)
        # image = sitk.Shrink(image, [1, 4, 4])  # optional: shrink in-plane to save memory

        volume = sitk.GetArrayFromImage(image)

        # Standardization. A constant volume has std == 0; dividing would fill
        # the result with NaN/inf, so in that case only the mean is removed.
        mean, std = volume.mean(), volume.std()
        volume = volume - mean
        if std > 0:
            volume = volume / std
        # The original discarded the result of astype(); keep it so the
        # per-file intermediates are float32, like the returned array.
        volume = volume.astype(np.float32)

        depth = volume.shape[0]
        # A chunk starting at `start` is kept only when
        # start + z_slice_number + 1 < depth (same rule as the original loop,
        # which stopped at the first start that failed it).
        # NOTE(review): this drops a final chunk that would still fit; left
        # unchanged because seg_to_array uses the same rule and image/label
        # chunks must stay paired.
        for start in range(0, depth - z_slice_number - 1, z_slice_number):
            chunk = volume[start:start + z_slice_number, :, :]
            chunks.append(np.expand_dims(chunk, axis=0))

    return np.array(chunks, np.float32)

In [None]:
def seg_to_array(path, end, label=1, z_slice_number=4):
    """Load segmentation volumes matching ``path + end`` as fixed-depth chunks.

    Mirrors ``to_array`` (same shuffle seed, same chunking rule) so that the
    chunks produced for images and labels line up one-to-one.

    Args:
        path: Path prefix; concatenated with ``end`` to build the glob pattern.
        end: Pattern suffix appended to ``path``.
        label: Target label. For ``label == 1`` the mask is binarized to
            1 where the voxel equals 1 and 0 elsewhere. Any other value
            leaves the mask values untouched (no other target is implemented).
            Defaults to 1 so callers that omit it keep working.
        z_slice_number: Number of slices per chunk.

    Returns:
        A float32 array with one entry per chunk, each with a leading
        singleton axis (``(1, z_slice_number, H, W)`` for a 3D volume), or an
        empty array of shape ``(0,)`` when nothing was loaded.
    """
    # get locations
    files = glob(path + end, recursive=True)

    # Same seed as to_array, so equally long file lists are permuted identically.
    r.seed(42)
    r.shuffle(files)

    chunks = []
    for file in files:
        image = sitk.ReadImage(file)
        # image = sitk.Shrink(image, [1, 4, 4])  # optional: shrink in-plane to save memory

        mask = sitk.GetArrayFromImage(image)

        # current segmentation as label 1, keep if statement for future target.
        # The original evaluated `img[img == 1]` and threw the result away,
        # so the mask was never actually restricted to the target label.
        if label == 1:
            mask = (mask == 1)

        mask = mask.astype(np.float32)

        depth = mask.shape[0]
        # Keep a chunk only when start + z_slice_number + 1 < depth, exactly
        # like to_array, so both functions emit the same number of chunks.
        for start in range(0, depth - z_slice_number - 1, z_slice_number):
            chunk = mask[start:start + z_slice_number, :, :]
            chunks.append(np.expand_dims(chunk, axis=0))

    return np.array(chunks, np.float32)

In [None]:
def read_dataset(csv_file, folder, dataset, z_slice_number=16):
    """Load every case of one split into a single array of chunks.

    The CSV must provide an ``ID`` and a ``DATASET`` column. Rows whose
    ``DATASET`` equals ``dataset`` are selected; for each of them the file
    prefix ``folder + ID`` is loaded with ``seg_to_array`` when the path
    contains the substring ``"LABEL"`` (case-sensitive), otherwise with
    ``to_array``.

    Args:
        csv_file: Path of the CSV listing the cases.
        folder: Folder prefix that is concatenated with each ID.
        dataset: Value of the ``DATASET`` column to select.
        z_slice_number: Number of slices per chunk passed to the loaders.

    Returns:
        The per-case chunk arrays concatenated along axis 0 (float32, as
        returned by the loaders). Cases are stacked last-ID-first, matching
        the order the previous ``np.append(data, arr)`` implementation
        produced. If nothing could be loaded, an empty array of shape
        ``(0, 1, z_slice_number, 512, 512)`` is returned.
    """
    import warnings

    df = pd.read_csv(csv_file)
    df = df[df["DATASET"] == dataset]
    # No squeeze(): with a single matching row it collapses to a scalar that
    # has no tolist().
    dataset_ids = df["ID"].tolist()

    per_case = []
    for ID in dataset_ids:
        path = folder + str(ID)
        if "LABEL" in path:
            data = seg_to_array(path=path, end="_label.nii.gz", label=1,
                                z_slice_number=z_slice_number)
        else:
            # Previously this ran unconditionally and overwrote the label result.
            data = to_array(path=path, end=".nii.gz",
                            z_slice_number=z_slice_number)

        if data.size == 0:
            # Nothing matched / volume too shallow: an empty (0,) array cannot
            # be concatenated with 5-D chunks, so skip it but tell the user.
            warnings.warn("read_dataset: no data loaded for %r" % path)
            continue
        per_case.append(data)

    if not per_case:
        return np.empty((0, 1, z_slice_number, 512, 512), dtype=np.float32)

    # One concatenation instead of re-copying the growing array for every case.
    return np.concatenate(per_case[::-1], axis=0)

### Convert training set data to array
Repeat the same process for the validation and test sets (`dataset = 2` and `dataset = 3`).

In [None]:
# Load the image volumes of the selected split into one array of chunks.
csv_file = "/content/gdrive/MyDrive/your classification file list.csv"
folder = "/content/gdrive/MyDrive/your image folder/"
dataset = 1  # train=1, val=2, test=3

train = read_dataset(csv_file=csv_file, folder=folder, dataset=dataset)

In [None]:
# Load the segmentation masks of the selected split into one array of chunks.
# read_dataset switches to the label loader only when the path contains the
# (case-sensitive) substring "LABEL".
csv_file = "/content/gdrive/MyDrive/your classification file list.csv"
folder = "/content/gdrive/MyDrive/your segmentation label folder/"

dataset = 1  # train=1, val=2, test=3

train_seg = read_dataset(csv_file=csv_file, folder=folder, dataset=dataset)

# 3D U-NET MODEL

In [None]:
def dice_coef(y_true, y_pred, smooth=0.005):
    """Smoothed Dice coefficient computed over all elements of both inputs.

    Both inputs are flattened, so the score is computed over everything
    passed in at once (not per sample):

        (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)

    Args:
        y_true: Ground-truth mask.
        y_pred: Predicted mask / probabilities.
        smooth: Constant added to numerator and denominator; keeps the ratio
            defined when both inputs sum to zero.

    Returns:
        Scalar Dice score.
    """
    # Align dtypes before multiplying: the masks may arrive as float64 arrays
    # while the predictions are float32.
    pred_dtype = getattr(y_pred, "dtype", None)
    if pred_dtype is not None:
        y_true = K.cast(y_true, pred_dtype)

    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus the Dice coefficient of the two inputs."""
    score = dice_coef(y_true, y_pred)
    return 1 - score

In [None]:
K.set_image_data_format('channels_first')


def _conv_bn_pair(tensor, filters):
    """Apply two (3x3x3 Conv3D + ReLU -> BatchNormalization) stages.

    Returns a tuple ``(conv_out, bn_out)``: the output of the second
    convolution (used for the skip connections) and its batch-normalized
    version (fed to the next layer).
    """
    for _ in range(2):
        conv_out = Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(tensor)
        tensor = BatchNormalization(axis=1)(conv_out)
    return conv_out, tensor


def unet(input_shape=(1, 16, 32, 32)):
    """Build and compile the 3D U-Net.

    Architecture: four encoder stages (64, 128, 256, 512 filters), each
    followed by 2x2x2 max pooling; a 1024-filter bottleneck; four decoder
    stages that upsample with a stride-2 Conv3DTranspose, concatenate the
    matching encoder convolution output on the channel axis and apply two
    conv/BN stages; and a final 1x1x1 sigmoid convolution with one channel.
    Layers are created in the same order as in the previous hand-unrolled
    version.

    Args:
        input_shape: Shape of one sample, channels first:
            ``(channels, depth, height, width)``. Because of the four 2x
            poolings, each of depth/height/width must be a multiple of 16.

    Returns:
        A Keras ``Model`` compiled with Adam (learning rate 1e-4),
        ``dice_coef_loss`` as loss and ``dice_coef`` as metric.

    Raises:
        ValueError: If ``input_shape`` does not have four entries or a
            spatial size is not a multiple of 16.
    """
    if len(input_shape) != 4:
        raise ValueError(
            "input_shape must be (channels, depth, height, width), got %r" % (input_shape,))
    if any(dim is not None and dim % 16 != 0 for dim in input_shape[1:]):
        # Otherwise the upsampled tensors no longer match the skip tensors
        # and concatenate() fails with a hard-to-read shape error.
        raise ValueError(
            "depth, height and width must be multiples of 16, got %r" % (input_shape,))

    inputs = Input(input_shape)

    # Encoder: keep the convolution output of every stage for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        conv_out, x = _conv_bn_pair(x, filters)
        skips.append(conv_out)
        x = MaxPooling3D((2, 2, 2), padding='same')(x)

    # Bottleneck
    _, x = _conv_bn_pair(x, 1024)

    # Decoder: upsample, concatenate with the matching encoder stage on the
    # channel axis (axis=1 because the data format is channels_first).
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = Conv3DTranspose(filters, (2, 2, 2), strides=(2, 2, 2), padding='same')(x)
        x = concatenate([x, skip], axis=1)
        _, x = _conv_bn_pair(x, filters)

    outputs = Conv3D(1, (1, 1, 1), activation='sigmoid')(x)

    model = Model(inputs=[inputs], outputs=[outputs])

    # `lr` is the deprecated spelling of this argument; `learning_rate` is the
    # supported one.
    model.compile(optimizer=Adam(learning_rate=1e-4), loss=dice_coef_loss, metrics=[dice_coef])

    return model

In [None]:
# Build and compile the network defined by unet().
model = unet()

# Print the layer-by-layer summary of the model.
model.summary()

# Training

In [None]:
class LossHistory(keras.callbacks.Callback):
    """Callback that records the training loss per batch and the validation loss per epoch.

    Attributes:
        losses: Value of ``logs['loss']`` after every training batch.
        val_losses: Value of ``logs['val_loss']`` after every epoch (``None``
            for an epoch whose logs carry no ``val_loss``).
    """

    def on_train_begin(self, logs=None):
        # Reset both histories at the start of every fit() call.
        self.losses = []
        self.val_losses = []

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))

    def on_epoch_end(self, epoch, logs=None):
        # Validation metrics are only reported in the epoch-end logs; reading
        # 'val_loss' in on_batch_end (as before) only ever appended None.
        logs = logs or {}
        self.val_losses.append(logs.get('val_loss'))

loss_history = LossHistory()

In [None]:
# NOTE(review): Drive is mounted at /content/gdrive earlier in this notebook,
# but this path is under /content/drive — confirm the checkpoint really ends
# up on Drive. The weight-loading cell further down reads the same path.
checkpoint_filepath = "/content/drive/MyDrive/best_model.hdf5"

# Write a checkpoint only when the monitored validation loss improves.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
    verbose=1,
)

# Halve the learning rate after 10 epochs without val_loss improvement,
# never going below 1e-8.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=10,
    min_lr=1e-8,
)

callbacks = [model_checkpoint, reduce_lr, loss_history]

In [None]:
# Train for 20 epochs with 25% of the samples held out for validation.
# NOTE(review): `new_train` / `new_train_seg` are not defined in any visible
# cell of this notebook (the loading cells create `train` / `train_seg`) —
# confirm where they come from and that their shape matches the model input.
history = model.fit(new_train, new_train_seg, validation_split=0.25, batch_size=16, epochs=20, shuffle=True, callbacks=callbacks)

Epoch 1/20

In [None]:
# Plot the training vs. validation curve of each tracked quantity,
# one figure per metric (Dice coefficient first, then the loss).
for metric, title, ylabel in (
    ('dice_coef', 'Model Dice', 'Dice'),
    ('loss', 'Model loss', 'Loss'),
):
    plt.plot(history.history[metric])
    plt.plot(history.history['val_' + metric])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Val'], loc='upper left')  # was misspelled 'Va.' on the loss plot
    plt.show()


# Visualizing Model prediction

In [None]:
# Restore the best checkpoint written by the ModelCheckpoint callback.
# Reuses `checkpoint_filepath` from the callbacks cell so the save and load
# locations cannot drift apart.
model.load_weights(checkpoint_filepath)

In [None]:
# selecting inference data: `i` is the first-axis index of the sample that
# the visualization cells below display and run through the model
i=200

In [None]:
# Run the model once on the selected training sample. The previous version
# called predict() again for every slice, and fed it `train[i]` although the
# panels show (and the model was fitted on) `new_train`.
pred = model.predict(np.expand_dims(new_train[i], axis=0))
pred[pred < 0.99] = 0  # zero out every voxel predicted below 0.99

# One figure per slice along the depth axis: input | ground truth | prediction.
for j in range(pred.shape[2]):
    fig = plt.figure(figsize=(9, 9))

    plt.subplot(1, 3, 1)
    plt.imshow(new_train[i][0][j])
    plt.title('Input')

    plt.subplot(1, 3, 2)
    plt.imshow(new_train_seg[i][0][j])
    plt.title('Ground truth')

    plt.subplot(1, 3, 3)
    plt.imshow(pred[0][0][j])
    plt.title('Prediction')

    plt.show()

In [None]:
# Run the model once on the selected validation sample; the prediction does
# not depend on the slice index, so it is computed outside the loop.
# NOTE(review): `val` / `val_seg` are not created in any visible cell —
# presumably loaded like the training set with dataset = 2; confirm.
pred = model.predict(np.expand_dims(val[i], axis=0))
pred[pred < 0.99] = 0  # zero out every voxel predicted below 0.99

# One figure per slice along the depth axis: input | ground truth | prediction.
for j in range(pred.shape[2]):
    fig = plt.figure(figsize=(9, 9))

    plt.subplot(1, 3, 1)
    plt.imshow(val[i][0][j])
    plt.title('Input')

    plt.subplot(1, 3, 2)
    plt.imshow(val_seg[i][0][j])
    plt.title('Ground truth')

    plt.subplot(1, 3, 3)
    plt.imshow(pred[0][0][j])
    plt.title('Prediction')

    plt.show()