In [1]:
import numpy as np
import nibabel as nib
%matplotlib inline
import matplotlib.pyplot as plt

from utils.roi_measures import mad, ssim
from utils.extraction import extract_patches
from utils.reconstruction import perform_voting

  from ._conv import register_converters as _register_converters


In [9]:
import numpy as np

from keras import backend as K
from keras.layers import Activation, Input, PReLU, Flatten, Dense, Cropping3D, Dropout
from keras.layers.convolutional import Conv3D, MaxPooling3D
from keras.layers.convolutional import Conv3DTranspose as Deconv3D
from keras.layers.core import Permute, Reshape
from keras.layers.merge import add, concatenate
from keras.layers.normalization import BatchNormalization
from keras.models import Model

# Use channels-first tensors, i.e. (batch, channels, x, y, z). The layers in this
# cell depend on it: BatchNormalization normalises axis 1 and the output reshape
# expects the class axis right after the batch axis.
# `set_image_data_format('channels_first')` is the Keras 2 spelling of the legacy
# `set_image_dim_ordering('th')` call.
K.set_image_data_format('channels_first')

def generate_uresnet_model(input_shape, output_shape, num_classes=4, scale=1):
    """Assemble the 3D U-shaped residual network and return it as a Keras Model.

    The contracting path applies three conv-core / max-pooling stages with
    ``scale * 32``, ``scale * 64`` and ``scale * 128`` filters, followed by a
    ``scale * 256`` core at the bottom. The expanding path upsamples three
    times with transposed convolutions and merges each level with the matching
    contracting-path features by addition (see ``merge_add``).

    Parameters
    ----------
    input_shape : tuple
        Shape of one input sample, handed to ``Input(shape=...)``.
    output_shape : tuple
        Forwarded unchanged to ``organise_output``.
    num_classes : int
        Accepted for interface compatibility only: this function never reads
        it. The number of output channels comes from the default of
        ``get_conv_fc``.
    scale : number
        Multiplier applied to every filter count.

    Returns
    -------
    keras.models.Model
        Model mapping the input patch to the tensor built by ``organise_output``.
    """
    def width(base):
        # Filter count for a level, scaled and cast exactly like the rest of the file.
        return np.int32(scale * base)

    net_input = Input(shape=input_shape)

    # Contracting path: each level is a conv core, pooled before the next one.
    enc1 = get_res_conv_core(net_input, width(32))
    enc2 = get_res_conv_core(get_max_pooling_layer(enc1), width(64))
    enc3 = get_res_conv_core(get_max_pooling_layer(enc2), width(128))
    bottom = get_res_conv_core(get_max_pooling_layer(enc3), width(256))

    # Expanding path, deepest level first. At this level the upsampled features
    # go through an extra conv core before being merged with enc3.
    dec3 = get_res_conv_core(get_deconv_layer(bottom, width(128)), width(128))
    dec3 = get_res_conv_core(merge_add(enc3, dec3), width(128))

    up2 = get_deconv_layer(dec3, width(64))
    dec2 = get_res_conv_core(merge_add(enc2, up2), width(64))

    up1 = get_deconv_layer(dec2, width(32))
    dec1 = get_res_conv_core(merge_add(enc1, up1), width(32))

    # Per-voxel class scores, then rearranged for the loss.
    prediction = organise_output(get_conv_fc(dec1), output_shape)

    return Model(inputs=[net_input], outputs=[prediction])

def merge_add(a, b) :
    """Sum two tensors element-wise, then apply batch normalisation and PReLU.

    Batch normalisation runs over axis 1 of the summed tensor.
    """
    summed = add([a, b])
    normalised = BatchNormalization(axis=1)(summed)
    activation = PReLU()
    return activation(normalised)

def get_res_conv_core(input, num_filters) :
    """Run `input` through parallel 3x3x3 and 1x1x1 convolutions and merge them.

    Both branches use `num_filters` filters and 'same' padding; their outputs
    are combined by ``merge_add``.
    """
    # Branch order (3 first, then 1) matches the order the layers are created in.
    branches = [
        Conv3D(num_filters, kernel_size=(size, size, size), padding='same')(input)
        for size in (3, 1)
    ]
    return merge_add(*branches)

def get_max_pooling_layer(input) :
    """Apply 3D max pooling with a 2x2x2 window to `input`."""
    pooling = MaxPooling3D(pool_size=(2, 2, 2))
    return pooling(input)

def get_deconv_layer(input, num_filters) :
    """Apply a 3D transposed convolution with 2x2x2 kernel and stride 2 to `input`."""
    factor = (2, 2, 2)
    upsampler = Deconv3D(num_filters, kernel_size=factor, strides=factor)
    return upsampler(input)

def get_conv_fc(input, num_filters=4) :
    """Apply a 1x1x1 convolution with `num_filters` outputs, followed by PReLU."""
    scores = Conv3D(num_filters, kernel_size=(1, 1, 1))(input)
    activation = PReLU()
    return activation(scores)

def organise_output(input, output_shape) :
    """Flatten the spatial axes, move classes last and apply a softmax.

    `input` is expected to be channels-first, i.e. (batch, classes, x, y, z).
    The result has shape (batch, voxels, classes), with the softmax taken over
    the last axis.

    The reshape target is read from `input` itself rather than hard-coded, so
    the function works for any class count and patch size.

    Parameters
    ----------
    input : Keras tensor
        Per-voxel class scores in channels-first layout.
    output_shape : tuple
        Not used; kept so existing callers keep working.

    Returns
    -------
    Keras tensor of shape (batch, voxels, classes).
    """
    static_shape = K.int_shape(input)
    num_channels = static_shape[1]
    spatial_dims = static_shape[2:]

    # Reshape accepts a single -1, which covers spatial sizes unknown at build time.
    if None in spatial_dims :
        num_voxels = -1
    else :
        num_voxels = int(np.prod(spatial_dims))

    pred = Reshape((num_channels, num_voxels))(input)
    pred = Permute((2, 1))(pred)
    return Activation('softmax')(pred)

In [14]:
# Build and compile the segmentation network for single-channel 32x32x32 patches.
scale = 1
num_classes = 4
curr_patch_shape = (32, 32, 32)
patch_shape = (1, ) + curr_patch_shape
# (voxels per patch, classes): layout of the per-voxel softmax the model emits.
output_shape = (np.prod(curr_patch_shape), num_classes)
# Pass by keyword: the third positional parameter of generate_uresnet_model is
# num_classes, so a positional `scale` would land in the wrong parameter.
model = generate_uresnet_model(
    patch_shape, output_shape, num_classes=num_classes, scale=scale)
model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics=['acc'])

In [12]:
# Subject-level split: about 70% of the subject ids go to training, the rest to
# validation. A fixed seed keeps the split identical across kernel restarts.
SPLIT_SEED = 0

# Materialise a list: np.random shuffles in place and cannot permute a Python 3 `range`.
a_range = list(range(1, 100))
# A private RandomState leaves NumPy's global random state untouched.
np.random.RandomState(SPLIT_SEED).shuffle(a_range)

N = len(a_range)
# Plain int instead of np.uint8, which silently wraps once N * 0.70 exceeds 255.
N_train = int(np.ceil(N * 0.70))

range_train = a_range[:N_train]
range_val = a_range[N_train:]

In [16]:
import os
import sys

import nibabel as nib
from keras.utils import to_categorical

file_general_pattern = 'OAS2_0{0:03}_MR{1}'
dataset_skull_location = 'datasets/OASIS/OASIS2/SKULL/{}.nii.gz'
dataset_histogram_location = 'datasets/OASIS/OASIS2/MATCHED_HISTOGRAM/{}.nii.gz'

step = (16, 16, 16)
num_classes = 4
voxels_per_patch = np.prod(curr_patch_shape)
# A patch is kept only when more than 40% of its voxels are non-zero.
threshold = np.int32(0.40 * voxels_per_patch)


def _stack_chunks(chunks, item_shape):
    """Join per-scan arrays into one float32 array, most recently loaded scan first.

    Returns an empty (0, ) + item_shape array when no scan contributed anything.
    """
    if not chunks:
        return np.empty((0, ) + tuple(item_shape), dtype='float32')
    # Reversed so the result has the same row order as prepending scan by scan.
    return np.concatenate(chunks[::-1], axis=0)


# Per-scan results are collected in lists and joined once after the loop;
# stacking onto the full array on every scan would re-copy everything each time.
train_subjects = set(range_train)
ref_chunks = {'train': [], 'val': []}
out_chunks = {'train': [], 'val': []}

first_subject = True
for i in range(1, 100) :
    # Progress log, written as space-separated tokens on a single line.
    sys.stdout.write('{}{} :'.format('' if first_subject else ' ', i))
    first_subject = False

    subset = 'train' if i in train_subjects else 'val'

    for j in range(1, 5) :
        k = j + 1
        filename = dataset_histogram_location.format(file_general_pattern.format(i, j))
        seg_filename = dataset_skull_location.format(file_general_pattern.format(i, str(j)+ '_seg'))

        # Skip the session unless both the intensity volume and its segmentation
        # exist, so inputs and targets always stay aligned.
        if not (os.path.exists(filename) and os.path.exists(seg_filename)) :
            continue

        volume_init = nib.load(filename).get_data()

        # Select patches by how much non-zero content they hold.
        mask_patches = extract_patches(volume_init != 0, curr_patch_shape, step)
        useful_patches = np.sum(mask_patches, axis=(1, 2, 3)) > threshold
        N = np.sum(useful_patches)

        del mask_patches

        # Network inputs: the selected intensity patches with a channel axis of 1.
        mov_patches = extract_patches(volume_init, curr_patch_shape, step)
        mov_patches = mov_patches[useful_patches].reshape((-1, 1, ) + curr_patch_shape)
        ref_chunks[subset].append(mov_patches.astype('float32'))
        del mov_patches

        volume_init = nib.load(seg_filename).get_data()

        # Network targets: the same patch positions taken from the label volume,
        # one-hot encoded per voxel.
        mov_prob_patches = extract_patches(volume_init, curr_patch_shape, step)
        mov_prob_patches = mov_prob_patches[useful_patches].reshape((-1, 1, voxels_per_patch))

        labels_train = np.empty((N, voxels_per_patch, num_classes), dtype='float32')
        for l in range(N) :
            labels_train[l] = to_categorical(mov_prob_patches[l].flatten(), num_classes)

        out_chunks[subset].append(labels_train)
        del labels_train, mov_prob_patches

        sys.stdout.write(' {}->{}'.format(j, k))
    sys.stdout.flush()
sys.stdout.write('\n')

# pop() drops each chunk list as soon as its array is built, limiting peak memory.
ref_train = _stack_chunks(ref_chunks.pop('train'), (1, ) + curr_patch_shape)
out_train = _stack_chunks(out_chunks.pop('train'), (voxels_per_patch, num_classes))
ref_val = _stack_chunks(ref_chunks.pop('val'), (1, ) + curr_patch_shape)
out_val = _stack_chunks(out_chunks.pop('val'), (voxels_per_patch, num_classes))
del ref_chunks, out_chunks

1 : 1->2 2->3 2 : 1->2 2->3 3->4 3 : 4 : 1->2 2->3 5 : 1->2 2->3 3->4 6 : 7 : 1->2 3->4 4->5 8 : 1->2 2->3 9 : 1->2 2->3 10 : 1->2 2->3 11 : 12 : 1->2 2->3 3->4 13 : 1->2 2->3 3->4 14 : 1->2 2->3 15 : 16 : 1->2 2->3 17 : 1->2 3->4 4->5 18 : 1->2 3->4 4->5 19 : 20 : 1->2 2->3 3->4 21 : 1->2 2->3 22 : 1->2 2->3 23 : 1->2 2->3 24 : 25 : 26 : 1->2 2->3 27 : 1->2 2->3 3->4 4->5 28 : 1->2 2->3 29 : 1->2 2->3 30 : 1->2 2->3 31 : 1->2 2->3 3->4 32 : 1->2 2->3 33 : 34 : 1->2 2->3 3->4 4->5 35 : 1->2 2->3 36 : 1->2 3->4 4->5 37 : 1->2 2->3 3->4 4->5 38 : 39 : 1->2 2->3 40 : 1->2 2->3 3->4 41 : 1->2 2->3 3->4 42 : 1->2 2->3 43 : 1->2 2->3 44 : 1->2 2->3 3->4 45 : 1->2 2->3 46 : 1->2 2->3 47 : 1->2 2->3 48 : 1->2 2->3 3->4 4->5 49 : 1->2 2->3 3->4 50 : 1->2 2->3 51 : 1->2 2->3 3->4 52 : 1->2 2->3 53 : 1->2 2->3 54 : 1->2 2->3 55 : 1->2 2->3 56 : 1->2 2->3 57 : 1->2 2->3 3->4 58 : 1->2 2->3 3->4 59 : 60 : 1->2 2->3 61 : 1->2 2->3 3->4 62 : 1->2 2->3 3->4 63 : 1->2 2->3 64 : 1->2 2->3 3->4 65 : 66 :

In [17]:
# Standardise both splits with the mean and standard deviation of the training
# patches only.
# NOTE: ref_train / ref_val are rebound here, so running this cell a second time
# standardises the already standardised data again.
train_mean, train_std = ref_train.mean(), ref_train.std()

ref_train = np.divide(np.subtract(ref_train, train_mean), train_std)
ref_val = np.divide(np.subtract(ref_val, train_mean), train_std)

In [18]:
# Show the training-set statistics used for standardisation.
print('%s %s' % (train_mean, train_std))

786.1171 587.41235


In [19]:
import os

from keras.callbacks import EarlyStopping, ModelCheckpoint

patience = 10
weights_path = 'models/ag_segmenter.h5'

# Make sure the checkpoint folder exists before training starts, so the first
# checkpoint write cannot fail on a missing directory.
weights_dir = os.path.dirname(weights_path)
if weights_dir and not os.path.isdir(weights_dir):
    os.makedirs(weights_dir)

# Stop once the monitored quantity has not improved for `patience` epochs, and
# keep only the weights of the best epoch on disk.
stopper = EarlyStopping(patience=patience)
checkpointer = ModelCheckpoint(weights_path, save_best_only=True, save_weights_only=True)

N = len(ref_train)
# Keep the History object so the per-epoch metrics remain available afterwards.
history = model.fit(
    ref_train, out_train,
    validation_data=(ref_val, out_val), epochs=100,
    callbacks=[checkpointer, stopper])

Train on 41651 samples, validate on 15018 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100


<keras.callbacks.History at 0x7f60e4012c50>

In [20]:
# Reload the weights from the checkpoint file written during training.
model.load_weights(filepath='models/ag_segmenter.h5')