<a href="https://colab.research.google.com/github/MorningStarTM/skull-stripping/blob/main/Skull_Stripping-V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from glob import glob
from tqdm import tqdm
from tensorflow.keras.models import Model, load_model
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Conv2D, Flatten, Dense, Conv2DTranspose, MaxPooling2D, ReLU, Input, BatchNormalization, concatenate, Lambda, Activation, Dropout
from tensorflow.keras.metrics import Recall, Precision
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau

In [None]:
#Global variable
BATCH_SIZE = 16
W,H = 256, 256
CHANNEL = 3

# U-NET

In [None]:
#input
inputs = Input((256, 256, 3), 2)
#scale the image between 0 - 255
scale_img = Lambda(lambda x: x/255)(inputs)

#convolutional layer
c1 = Conv2D(16, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(scale_img)
#dropout 
c1 = Dropout(0.1)(c1)
c1 = Conv2D(16, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(c1)
#pooling
p1 = MaxPooling2D((2,2))(c1)

c2 = Conv2D(32, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(p1)
c2 = Dropout(0.1)(c2)
c2 = Conv2D(32, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(c2)
p2 = MaxPooling2D((2,2))(c2)

c3 = Conv2D(64, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(p2)
c3 = Dropout(0.1)(c3)
c3 = Conv2D(64, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(c3)
p3 = MaxPooling2D((2,2))(c3)

c4 = Conv2D(128, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(p3)
c4 = Dropout(0.2)(c4)
c4 = Conv2D(128, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(c4)
p4 = MaxPooling2D((2,2))(c4)

c5 = Conv2D(256, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(p4)
c5 = Dropout(0.2)(c5)
c5 = Conv2D(256, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)
p5 = MaxPooling2D((2,2))(c5)

c6 = Conv2D(512, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(p5)
c6 = Dropout(0.2)(c6)
c6 = Conv2D(512, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(c6)
p6 = MaxPooling2D((2,2))(c6)

u6 = Conv2DTranspose(256, (2,2), strides=(2,2), padding='same')(c6)
u6 = concatenate([u6, c5])
c6 = Conv2D(256, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(u6)
c6 = Dropout(0.2)(c6)
c6 = Conv2D(256, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(c6)

u7 = Conv2DTranspose(128, (2,2), strides=(2,2), padding='same')(c6)
u7 = concatenate([u7, c4])
c7 = Conv2D(128, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(u7)
c7 = Dropout(0.2)(c7)
c7 = Conv2D(128, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(c7)

u8 = Conv2DTranspose(64, (2,2), strides=(2,2), padding='same')(c7)
u8 = concatenate([u8, c3])
c8 = Conv2D(64, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(u8)
c8 = Dropout(0.1)(c8)
c8 = Conv2D(64, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(c8)

u9 = Conv2DTranspose(32, (2,2), strides=(2,2), padding='same')(c8)
u9 = concatenate([u9, c2])
c9 = Conv2D(32, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(u9)
c9 = Dropout(0.1)(c9)
c9 = Conv2D(32, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(c9)

u10 = Conv2DTranspose(16, (2,2), strides=(2,2), padding='same')(c9)
u10 = concatenate([u10, c1])
c10 = Conv2D(128, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(u10)
c10 = Dropout(0.1)(c10)
c10 = Conv2D(128, (3,3), activation='relu', kernel_initializer='he_normal', padding='same')(c10)

outputs = Conv2D(1, (1,1), activation='sigmoid')(c10)

model = Model(inputs=[inputs], outputs=[outputs])

In [None]:
metrics_list = ['accuracy', Recall(), Precision()]

In [None]:
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=metrics_list)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(2, 256, 256, 3)]   0           []                               
                                                                                                  
 lambda (Lambda)                (2, 256, 256, 3)     0           ['input_1[0][0]']                
                                                                                                  
 conv2d (Conv2D)                (2, 256, 256, 16)    448         ['lambda[0][0]']                 
                                                                                                  
 dropout (Dropout)              (2, 256, 256, 16)    0           ['conv2d[0][0]']                 
                                                                                              

# Preparing Dataset

In [None]:
model_path = os.path.join("/content/drive/MyDrive/CNN_Models","unet_for_skull_stripping.h5")
csv_path = os.path.join("/content/drive/MyDrive/DataSet/random","unet_for_skull_stripping.csv")
path = "/content/drive/MyDrive/DataSet/skull_strpping"
split = 0.1

In [None]:
train_path = "/content/drive/MyDrive/DataSet/skull_strpping/train"
test_path = "/content/drive/MyDrive/DataSet/skull_strpping/test"

In [None]:
#load dataset
def load_data(path, split=0.1):
  images = sorted(glob(os.path.join(path, "image_1/*")))
  masks = sorted(glob(os.path.join(path, "mask_1/*")))
  
  return images, masks

In [None]:
#process image function
def read_img(path):
  img = cv2.imread(path, cv2.IMREAD_COLOR)
  img = cv2.resize(img, (256, 256))
  img = img / 255
  img = img.astype(np.float32)
  return img

#process mask function
def read_mask(path):
  mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
  mask = cv2.resize(mask, (256, 256))
  mask = mask / 255
  mask = np.expand_dims(mask, axis=-1)
  mask = mask.astype(np.float32)
  return mask

In [None]:
def preprocess(x, y):
  def f(x, y):
    x = x.decode()
    y = y.decode()
    
    x = read_img(x)
    y = read_mask(y)
    return x, y

  image, mask = tf.numpy_function(f, [x, y], [tf.float32, tf.float32])
  image.set_shape([256, 256, 3])
  mask.set_shape([256, 256, 1])

  return image, mask

In [None]:
def tf_dataset(x, y, batch=8):
  dataset = tf.data.Dataset.from_tensor_slices((x, y))
  dataset = dataset.shuffle(buffer_size=1000)
  dataset = dataset.map(preprocess)
  dataset = dataset.batch(batch)
  dataset = dataset.prefetch(2)
  return dataset

In [None]:
train_images, train_masks = load_data(train_path)
test_images, test_masks = load_data(test_path)
print(f'image: {len(train_images)} - Masks: {len(train_masks)}')
print(f'image: {len(test_images)} - Masks: {len(test_masks)}')

image: 712 - Masks: 712
image: 72 - Masks: 72


In [None]:
train_dataset = tf_dataset(train_images, train_masks, batch=BATCH_SIZE)
test_dataset = tf_dataset(test_images, test_masks, batch=BATCH_SIZE)

In [None]:
Callbacks = [
    ModelCheckpoint(model_path, verbose=1, save_best_only=True), 
    CSVLogger(csv_path),
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=1e-7, verbose=1)
]

In [None]:
model.fit(
    train_dataset,
    epochs=10,
    validation_data=test_dataset, 
    callbacks=Callbacks
    )

Epoch 1/10
Epoch 1: val_loss improved from inf to 0.54497, saving model to /content/drive/MyDrive/CNN_Models/unet_for_skull_stripping.h5
Epoch 2/10
Epoch 2: val_loss improved from 0.54497 to 0.18804, saving model to /content/drive/MyDrive/CNN_Models/unet_for_skull_stripping.h5
Epoch 3/10
Epoch 3: val_loss improved from 0.18804 to 0.16563, saving model to /content/drive/MyDrive/CNN_Models/unet_for_skull_stripping.h5
Epoch 4/10
Epoch 4: val_loss did not improve from 0.16563
Epoch 5/10
Epoch 5: val_loss did not improve from 0.16563
Epoch 6/10
Epoch 6: val_loss improved from 0.16563 to 0.16206, saving model to /content/drive/MyDrive/CNN_Models/unet_for_skull_stripping.h5
Epoch 7/10
Epoch 7: val_loss did not improve from 0.16206
Epoch 8/10
Epoch 8: val_loss did not improve from 0.16206
Epoch 9/10
Epoch 9: val_loss improved from 0.16206 to 0.15804, saving model to /content/drive/MyDrive/CNN_Models/unet_for_skull_stripping.h5
Epoch 10/10
Epoch 10: val_loss did not improve from 0.15804


<keras.callbacks.History at 0x7f418b499bd0>

In [18]:
unet = load_model("/content/drive/MyDrive/CNN_Models/unet_for_skull_stripping.h5")

In [47]:
for img, msk in tqdm(zip(test_images, test_masks), total=len(test_images)):
  image_name = img.split("/")[-1]

  ori_x = cv2.imread(img, cv2.IMREAD_COLOR)
  ori_x = cv2.resize(ori_x, (W,H))
  x = ori_x / 255.0
  x = x.astype(np.float32)
  x = np.expand_dims(x, axis=0)

  ori_y = cv2.imread(msk, cv2.IMREAD_GRAYSCALE)
  ori_y = cv2.resize(ori_y, (W,H))
  ori_y = np.expand_dims(ori_y, axis=-1)
  
  y_pred = model.predict(x)[0] > 0.5  
  y_pred = y_pred.astype(np.int32)
  y_pred = y_pred.reshape(256, 256, 1)

  print(ori_x.shape, ori_y.shape, y_pred.shape)
  save_image_path = f"/content/predicted/{image_name}"

  sep_line = np.ones((H, 10, 3)) * 255
  cat_image = np.concatenate([ori_y, y_pred*255], axis=1)
  cv2.imwrite(save_image_path, cat_image)

  0%|          | 0/72 [00:00<?, ?it/s]



  0%|          | 0/72 [00:00<?, ?it/s]

(256, 256, 3) (256, 256, 1) (256, 256, 1)





ValueError: ignored