This notebook implements a neural network based on the modified U-Net architecture from the DeepHarmony paper (pictured below). In addition, it integrates batch normalization layers throughout the network and uses a compound loss function made up of MS-SSIM and L1 terms.


Importing the necessary modules

In [16]:
import tensorflow as tf
from os import scandir
import skimage.io as io 
import skimage.transform as trans
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.activations import *
from keras.metrics import mean_squared_error as mse
from keras import backend as keras 
import matplotlib.pyplot as plt
import numpy as np
from random import random
from mriPrep import Preprocessing


Definition of the network class, including the loss function and the convolution helpers

In [None]:
class Unet:
    """Modified U-Net (DeepHarmony-style) for 128x128 single-channel images.

    The network has four stride-2 down-sampling stages (16 -> 128 filters),
    a 256-filter bottleneck, and four stride-2 up-sampling stages. Each
    up-sampling stage is concatenated with the matching encoder features.
    A final skip connection brings in the raw input, and a 1x1 convolution
    produces one output channel. Every 3x3 "side" and 4x4 "down" convolution
    is followed by batch normalization.
    """

    # initialization, including declaration of input sizes, coefficients for loss function
    def __init__(self, coe=None):
        """Store the input shape and the compound-loss weights.

        Args:
            coe: optional pair ``(w_ssim, w_l1)`` that weights the two loss
                terms. If omitted, both weights are drawn with ``random()``
                as in the original code. That makes training runs
                non-reproducible, so pass explicit weights to fix them.

        Raises:
            ValueError: if ``coe`` does not contain exactly two weights.
        """
        self.input_size = (128, 128, 1)
        self.coe = [random(), random()] if coe is None else list(coe)
        if len(self.coe) != 2:
            raise ValueError("coe must contain exactly two weights (ssim, l1)")

    # compound loss function
    def comp_loss(self, y_true, y_pred):
        """Compound SSIM + L1 loss, usable directly as a Keras ``loss``.

        Keras invokes a loss as ``loss(y_true, y_pred)``. The previous
        signature took no arguments and read attributes that were never set.

        Args:
            y_true: target image batch, shaped like the model output
                (``(batch, 128, 128, 1)`` for this network).
            y_pred: predicted image batch of the same shape.

        Returns:
            Scalar tensor:
            ``coe[0] * (1 - mean SSIM) + coe[1] * mean(|y_true - y_pred|)``.
        """
        # tf.image.ssim requires both images to share a dtype.
        y_true = tf.cast(y_true, y_pred.dtype)
        # max_val=1.0 assumes intensities are scaled to [0, 1].
        # TODO confirm against the preprocessing output.
        # NOTE(review): the header specifies MS-SSIM. tf.image.ssim_multiscale
        # with its defaults (5 scales, filter_size=11) needs larger images than
        # 128x128, because the coarsest scale would be 8x8. Single-scale SSIM
        # is therefore kept until the scale count or filter size is chosen.
        ssim_term = 1.0 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
        # L1 term, as documented for this network (previously MSE / L2).
        l1_term = tf.reduce_mean(tf.abs(y_true - y_pred))
        return self.coe[0] * ssim_term + self.coe[1] * l1_term

    # side convolution
    def convolution(self, inp, n_filters):
        """3x3 same-padded ReLU convolution followed by batch normalization."""
        conv = Conv2D(n_filters, 3, activation="relu", padding="same", kernel_initializer="he_normal")(inp)
        # The input is channels-last (input_size ends with the channel dim),
        # so the feature axis is -1. axis=1 would normalize over image rows.
        return BatchNormalization(axis=-1, momentum=0.99, epsilon=0.0001)(conv)

    # down convolution
    def down_convolution(self, inp, n_filters):
        """4x4 stride-2 ReLU convolution (halves height/width) plus batch norm."""
        conv = Conv2D(n_filters, (4, 4), activation="relu", padding="same", strides=(2, 2), kernel_initializer="he_normal")(inp)
        return BatchNormalization(axis=-1, momentum=0.99, epsilon=0.0001)(conv)

    # up convolution
    def up_convolution(self, inp, n_filters, conv_features):
        """4x4 stride-2 transposed conv (doubles height/width), then skip-concat.

        ``conv_features`` is the encoder tensor at the target resolution. It is
        concatenated along the channel axis with the up-sampled tensor.
        """
        # Stride 2 doubles the spatial size so it matches the skip tensor.
        # Fractional strides such as (0.5, 0.5) are not valid for Conv2DTranspose.
        deconvolution = Conv2DTranspose(n_filters, (4, 4), activation="relu", padding="same", strides=(2, 2), kernel_initializer="he_normal")(inp)
        # Concatenate is a layer: construct it with the axis, then call it on
        # the list of tensors.
        return Concatenate(axis=3)([conv_features, deconvolution])

    # definition of neural network
    def unet(self, pretrained_weights=None):
        """Build, compile and return the model.

        Args:
            pretrained_weights: optional path to a weights file. It is loaded
                with ``Model.load_weights`` after the model is built. The
                previous version accepted this argument but ignored it.

        Returns:
            The compiled Keras ``Model``. It maps ``input_size`` images to a
            single-channel image of the same spatial size. The previous
            version built the model but returned ``None``.
        """
        inputs = Input(self.input_size)

        # encoder: side conv, then stride-2 down conv, at each resolution
        conv1 = self.convolution(inputs, 16)        # 128x128
        conv2 = self.down_convolution(conv1, 16)    # 64x64
        conv3 = self.convolution(conv2, 32)
        conv4 = self.down_convolution(conv3, 32)    # 32x32
        conv5 = self.convolution(conv4, 64)
        conv6 = self.down_convolution(conv5, 64)    # 16x16
        conv7 = self.convolution(conv6, 128)
        conv8 = self.down_convolution(conv7, 128)   # 8x8

        # bottleneck side
        conv9 = self.convolution(conv8, 256)

        # decoder: up-sample and merge with the matching encoder side output
        conv10 = self.up_convolution(conv9, 128, conv7)    # 16x16
        conv11 = self.convolution(conv10, 128)
        conv12 = self.up_convolution(conv11, 64, conv5)    # 32x32
        conv13 = self.convolution(conv12, 64)
        conv14 = self.up_convolution(conv13, 32, conv3)    # 64x64
        conv15 = self.convolution(conv14, 32)
        conv16 = self.up_convolution(conv15, 16, conv1)    # 128x128

        # side and merge: final skip connection from the raw input
        conv17 = self.convolution(conv16, 16)
        merge = Concatenate(axis=3)([inputs, conv17])

        # final side: 1x1 conv down to a single output channel
        conv18 = Conv2D(1, 1, activation="relu", padding="same", kernel_initializer="he_normal")(merge)

        model = Model(inputs=inputs, outputs=conv18)
        # "accuracy" has no meaning for a continuous image-regression target,
        # so report mean absolute error instead.
        model.compile(optimizer=Adam(learning_rate=1e-4), loss=self.comp_loss, metrics=["mae"])
        model.summary()

        if pretrained_weights is not None:
            model.load_weights(pretrained_weights)

        return model


Importing data, running the preprocessing steps, and splitting into training and validation sets

In [None]:
# Build the project preprocessor (mriPrep.Preprocessing, not visible here).
# It is given a source directory and an output directory for preprocessed scans.
preprocessing = Preprocessing(
    "./data/modified/",
    "./data/preprocessed/",
)
# Run correctBias on a single sample scan. The method name suggests
# bias-field correction; confirm the details in mriPrep.
preprocessing.correctBias(
    "data/modified/scanner2_sub-CC110069_T1w.nii"
)


Training the model

In [None]:
# The import cell binds `keras` to `keras.backend`
# (`from keras import backend as keras`), so `keras.callbacks` is not reachable
# through that name. Import the callback class directly instead.
from keras.callbacks import ModelCheckpoint

callbacks = [
    # save_best_only monitors "val_loss" by default, so fit() must receive
    # validation data for this checkpoint to be written.
    ModelCheckpoint("oxford_segmentation.h5", save_best_only=True),
]

# Train the model, doing validation at the end of each epoch.
epochs = 20
# Assumes `model` was built and compiled in an earlier cell.
# TODO(review): no cell above produces training/validation arrays yet. Pass
# them here, e.g. fit(x_train, y_train, validation_data=(x_val, y_val), ...).
model.fit(epochs=epochs, callbacks=callbacks)