This is a simple training walkthrough for the paper published at ISBI 2018, entitled: "Real-time automatic fetal brain extraction in fetal MRI by deep learning".
Please do not forget to cite this work using the following BibTeX entry:
```
@inproceedings{salehi2018real,
  title={Real-time automatic fetal brain extraction in fetal mri by deep learning},
  author={Salehi, Seyed Sadegh Mohseni and Hashemi, Seyed Raein and Velasco-Annis, Clemente and Ouaalam, Abdelhakim and Estroff, Judy A and Erdogmus, Deniz and Warfield, Simon K and Gholipour, Ali},
  booktitle={Biomedical Imaging (ISBI 2018), 2018 IEEE 15th International Symposium on},
  pages={720--724},
  year={2018},
  organization={IEEE}
}
```

Need more info? No problem! Contact me at sadegh.msalehi@gmail.com.

In [None]:
# importing needed libraries. 
# you can pip install all of them.
import os

from medpy.io import load
import numpy as np

import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d, upsample_2d

import tensorflow as tf

In [None]:
# Defining parameters

dataPath = './data/' # root folder that is walked to find the training volumes
modelPath = './model/' # folder where the model checkpoint is saved

# Which loss function to train with. The options handled in this notebook are:
# cross_entropy: apply cross entropy to each pixel separately and average the result
# weighted_cross_entropy: apply cross entropy to each pixel separately and take a weighted average
#                         based on the ratio of the classes
# dice: apply the Dice coefficient and minimize 1 - dice
# Tversky: not implemented in this notebook; very useful for highly imbalanced data (like 3D MS lesion detection)
loss_method = 'weighted_cross_entropy' # name of the selected loss (one of the options listed above)

batch_size = 16 # number of slices per training batch; also fixes the size of the `weights` placeholder
display_step = 20 # report the loss/accuracy (and save a checkpoint) every `display_step` steps

# Network Parameters
tf.reset_default_graph() # start from an empty default graph before the placeholders are created
width = 256
height = 256
n_channels = 1
n_classes = 2 # total classes (brain, non-brain)

# x: placeholder for the input images, shape [batch, width, height, n_channels].
# y: placeholder for the one-hot labels, shape [batch, width, height, n_classes].
# lr: placeholder for the learning rate, so that it can be changed as training moves forward.
# weights: per-pixel weights for one full batch (batch_size * width * height values);
#          meant for the weighted_cross_entropy loss.
x = tf.placeholder(tf.float32, [None, width, height, n_channels])
y = tf.placeholder(tf.float32, [None, width, height, n_classes])
lr = tf.placeholder(tf.float32)
weights = tf.placeholder(tf.float32, [batch_size*width*height])

# Total number of slices we are going to train on. It is hard-coded (not the best
# implementation): the image/label arrays are pre-allocated with this size.
NumberOfSamples = 1259

In [None]:
# Batch generator used during training. One can use Keras and forget about this function.
def generate_batch():
    """Yield ``(image_batch, label_batch)`` tuples for training.

    For every group of slice indices produced by ``generate_samples()`` the
    matching entries of the module-level ``images`` and ``labels`` arrays are
    selected, each sample is passed through ``augment_sample`` and the
    resulting pair of batches is yielded.
    """
    for batch_ids in generate_samples():
        # Selecting with the index group gives this batch its own arrays
        # (presumably copies, since the ids form an index array - confirm),
        # which are then overwritten sample by sample below.
        batch_images = images[batch_ids]
        batch_labels = labels[batch_ids]
        for idx, (img, lbl) in enumerate(zip(batch_images, batch_labels)):
            batch_images[idx], batch_labels[idx] = augment_sample(img, lbl)
        yield batch_images, batch_labels

# Choose random slices.
def generate_samples():
    """Yield groups of ``batch_size`` random slice indices.

    The indices ``0 .. NumberOfSamples - 1`` are reshuffled once per epoch
    (1000 epochs in total) and handed out in consecutive chunks of
    ``batch_size``. Only full chunks are produced, so the remainder of each
    permutation is dropped.
    """
    total = NumberOfSamples
    epochs = 1000
    batches_per_epoch = int(total / batch_size)
    for _ in range(epochs):
        shuffled = np.random.permutation(total)
        for start in range(0, batches_per_epoch * batch_size, batch_size):
            yield shuffled[start:start + batch_size]

# Want augmentation (rotation, translation, etc.)? Do it on the fly: write your augmentation here.
# Right now no augmentation is applied. :)
def augment_sample(image, label):
    """Return ``(image, label)`` unchanged.

    This is the hook for on-the-fly augmentation; at the moment it passes the
    very same objects it received straight back as a tuple.
    """
    return image, label

In [None]:
# Design your model here.
#
# NOTE: the default graph is reset in the parameter cell, *before* the
# placeholders (x, y, lr, weights) are created. It must not be reset again
# here: doing so would leave the placeholders in the old graph while the
# initializer, session and saver created later are bound to a new, empty one.
# To rebuild the network from scratch, re-run the parameter cell first.

# U-Net.
# Contracting path: two 3x3 convolutions per level, then 2x2 max pooling.
conv1 = conv_2d(x, 32, 3, activation='relu', padding='same', regularizer="L2")
conv1 = conv_2d(conv1, 32, 3, activation='relu', padding='same', regularizer="L2")
pool1 = max_pool_2d(conv1, 2)

conv2 = conv_2d(pool1, 64, 3, activation='relu', padding='same', regularizer="L2")
conv2 = conv_2d(conv2, 64, 3, activation='relu', padding='same', regularizer="L2")
pool2 = max_pool_2d(conv2, 2)

conv3 = conv_2d(pool2, 128, 3, activation='relu', padding='same', regularizer="L2")
conv3 = conv_2d(conv3, 128, 3, activation='relu', padding='same', regularizer="L2")
pool3 = max_pool_2d(conv3, 2)

conv4 = conv_2d(pool3, 256, 3, activation='relu', padding='same', regularizer="L2")
conv4 = conv_2d(conv4, 256, 3, activation='relu', padding='same', regularizer="L2")
pool4 = max_pool_2d(conv4, 2)

# Bottleneck.
conv5 = conv_2d(pool4, 512, 3, activation='relu', padding='same', regularizer="L2")
conv5 = conv_2d(conv5, 512, 3, activation='relu', padding='same', regularizer="L2")

# Expanding path: upsample by 2, concatenate with the feature map of the
# matching contracting level along the channel axis (axis=3), then apply two
# 3x3 convolutions.
up6 = upsample_2d(conv5, 2)
up6 = tflearn.layers.merge_ops.merge([up6, conv4], 'concat', axis=3)
conv6 = conv_2d(up6, 256, 3, activation='relu', padding='same', regularizer="L2")
conv6 = conv_2d(conv6, 256, 3, activation='relu', padding='same', regularizer="L2")

up7 = upsample_2d(conv6, 2)
up7 = tflearn.layers.merge_ops.merge([up7, conv3], 'concat', axis=3)
conv7 = conv_2d(up7, 128, 3, activation='relu', padding='same', regularizer="L2")
conv7 = conv_2d(conv7, 128, 3, activation='relu', padding='same', regularizer="L2")

up8 = upsample_2d(conv7, 2)
up8 = tflearn.layers.merge_ops.merge([up8, conv2], 'concat', axis=3)
conv8 = conv_2d(up8, 64, 3, activation='relu', padding='same', regularizer="L2")
conv8 = conv_2d(conv8, 64, 3, activation='relu', padding='same', regularizer="L2")

up9 = upsample_2d(conv8, 2)
up9 = tflearn.layers.merge_ops.merge([up9, conv1], 'concat', axis=3)
conv9 = conv_2d(up9, 32, 3, activation='relu', padding='same', regularizer="L2")
conv9 = conv_2d(conv9, 32, 3, activation='relu', padding='same', regularizer="L2")

# 1x1 convolution producing one raw score (logit) per class for every pixel.
# The number of output channels follows n_classes instead of a hard-coded 2.
pred = conv_2d(conv9, n_classes, 1, activation='linear', padding='valid')

# Logits flattened to one row per pixel.
pred_reshape = tf.reshape(pred, [-1, n_classes])

In [None]:
# Load images and labels.
# Both arrays are pre-allocated for NumberOfSamples slices and filled volume by volume.
images = np.zeros((NumberOfSamples, width, height, n_channels))
labels = np.zeros((NumberOfSamples, width, height, n_classes))

slice_counter = 0  # index of the next free slice slot in `images` / `labels`
for path, subdirs, files in os.walk(dataPath):
    for name in files:
        # Only image volumes are picked up here; the matching mask file is
        # expected next to the image, named 'mask' + <image file name>.
        if "mask" in name or "nii" not in name:
            continue

        image_data, image_header = load(os.path.join(path, name))  # Load data
        mask_data, mask_header = load(os.path.join(path, 'mask' + name))  # Load mask
        image_data = np.moveaxis(image_data, -1, 0)  # Bring the last dim to the first
        mask_data = np.moveaxis(mask_data, -1, 0)  # Bring the last dim to the first

        stop = slice_counter + image_data.shape[0]
        if stop > NumberOfSamples:
            raise ValueError(
                "Found more slices than NumberOfSamples (%d) while loading %s"
                % (NumberOfSamples, os.path.join(path, name)))

        # Scale the volume by its maximum. An all-zero volume is left as
        # zeros instead of dividing by zero.
        peak = np.max(image_data)
        if peak != 0:
            images[slice_counter:stop, :, :, 0] = image_data / peak

        # Make the labels one-hot: channel 0 is the mask, channel 1 its complement.
        labels[slice_counter:stop, :, :, 0] = mask_data
        labels[slice_counter:stop, :, :, 1] = 1 - mask_data

        # Advance to the next free slot so the following volume does not
        # overwrite the slices that were just stored.
        slice_counter = stop

if slice_counter != NumberOfSamples:
    # Unfilled slots stay all-zero in both arrays and would still be sampled.
    print("Warning: loaded %d slices but NumberOfSamples is %d"
          % (slice_counter, NumberOfSamples))

In [None]:
######################################
# Define loss and optimizer
# Flatten logits and one-hot labels to one row per pixel of a full batch.
pred_reshape = tf.reshape(pred, [batch_size * width * height, n_classes])
y_reshape = tf.reshape(y, [batch_size * width * height, n_classes])

if loss_method == 'cross_entropy':
    # Plain average of the per-pixel cross entropy.
    error = tf.nn.softmax_cross_entropy_with_logits(labels=y_reshape, logits=pred_reshape)
    cost = tf.reduce_mean(error)

elif loss_method == 'weighted_cross_entropy':
    # Weighted average of the per-pixel cross entropy: every pixel is weighted
    # by the weight of its true class, fed through the `weights` placeholder.
    error = tf.nn.softmax_cross_entropy_with_logits(labels=y_reshape, logits=pred_reshape)
    cost = tf.reduce_sum(error * weights) / tf.reduce_sum(weights)

elif loss_method == 'dice':
    # Dice is defined on probabilities, so the logits go through a softmax first.
    prob_reshape = tf.nn.softmax(pred_reshape)
    intersection = tf.reduce_sum(prob_reshape * y_reshape)
    # The +1 terms smooth the ratio and avoid a division by zero.
    dice = (2 * intersection + 1) / (tf.reduce_sum(prob_reshape) + tf.reduce_sum(y_reshape) + 1)
    cost = 1 - dice

else:
    raise NotImplementedError("Unknown loss_method: %r" % (loss_method,))

optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(cost)

# Evaluate model: fraction of pixels whose predicted class matches the label.
correct_pred = tf.equal(tf.argmax(pred, -1), tf.argmax(y, -1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()

# Class weights: inverse class frequency over all loaded labels, softened by
# the exponent 0.3 and normalized to sum to one.
arg_labels = np.argmax(labels, axis=-1)
class_weights = np.zeros(n_classes)
for i in range(n_classes):
    class_weights[i] = 1 / (np.mean(arg_labels == i) ** 0.3)
class_weights /= np.sum(class_weights)

sess = tf.Session()
sess.run(init)

saver = tf.train.Saver()
# Make sure the checkpoint folder exists before anything is saved into it.
if not os.path.isdir(modelPath):
    os.makedirs(modelPath)
model_path = os.path.join(modelPath, 'Unet.ckpt')

# Initial learning rate; it is fed through the `lr` placeholder at every step.
learning_rate = 0.00001

In [None]:
#######################Train###################################
# One optimization step per batch. Every `display_step` steps the loss and
# accuracy of the current batch are printed and a checkpoint is written.
for step2, (image_batch, label_batch) in enumerate(generate_batch()):
    # Per-pixel weights: look up the weight of each pixel's true class.
    label_vect = np.reshape(np.argmax(label_batch, axis=-1), [batch_size * width * height])
    weight_vect = class_weights[label_vect]
    # Fit training using batch data
    feed_dict = {x: image_batch, y: label_batch, weights: weight_vect, lr: learning_rate}
    loss, acc, _ = sess.run([cost, accuracy, optimizer], feed_dict=feed_dict)
    if step2 % display_step == 0:
        print("Step %d, Minibatch Loss=%0.6f , Training Accuracy=%0.5f "
              % (step2, loss, acc))

        # Save the variables to disk.
        saver.save(sess, model_path)
    # Decay the learning rate by 10% every 2000 steps. Step 0 is skipped so
    # that training actually starts at the configured learning rate.
    if step2 > 0 and step2 % 2000 == 0:
        learning_rate *= 0.9