In [1]:
import numpy as np
import nibabel as nib
%matplotlib inline
import matplotlib.pyplot as plt

from utils.roi_measures import mad, ssim
from utils.extraction import extract_patches
from utils.reconstruction import perform_voting

  from ._conv import register_converters as _register_converters


In [2]:
import numpy as np

from keras import backend as K
from keras.layers import Activation, Input, PReLU, Flatten, Dense, Cropping3D, Dropout
from keras.layers.convolutional import Conv3D, MaxPooling3D
from keras.layers.convolutional import Conv3DTranspose as Deconv3D
from keras.layers.core import Permute, Reshape
from keras.layers.merge import add, concatenate
from keras.models import Model

# Use channels-first tensors, i.e. (batch, channels, d1, d2, d3) for 3D data.
# This is the Keras 2 replacement for the deprecated set_image_dim_ordering('th').
K.set_image_data_format('channels_first')

def generate_uresnet_model(input_shape, output_shape, num_classes=4, scale=1):
    """Build a three-level 3D residual U-Net for patch-wise voxel classification.

    Args:
        input_shape: shape of one input sample without the batch axis,
            e.g. ``(1, 32, 32, 32)`` for a single-channel channels-first patch.
        output_shape: forwarded unchanged to ``organise_output``.
        num_classes: NOTE(review): not used by this function body. The number
            of output channels comes from ``get_conv_fc``'s default
            ``num_filters``.
        scale: multiplier on every layer's filter count (32/64/128/256 at
            ``scale=1``). The product is truncated with ``np.int32``.

    Returns:
        An uncompiled ``Model`` mapping the input patch to the output of
        ``organise_output``.
    """
    def width(base_filters):
        # Filter count for one level after applying the width multiplier.
        return np.int32(scale * base_filters)

    inputs = Input(shape=input_shape)

    # Contracting path: residual block followed by max pooling, three times.
    enc1 = get_res_conv_core(inputs, width(32))
    down1 = get_max_pooling_layer(enc1)
    enc2 = get_res_conv_core(down1, width(64))
    down2 = get_max_pooling_layer(enc2)
    enc3 = get_res_conv_core(down2, width(128))
    down3 = get_max_pooling_layer(enc3)

    # Bottleneck.
    bottom = get_res_conv_core(down3, width(256))

    # Expanding path. Upsample with a transposed convolution, then fuse with the
    # encoder output of the same resolution by addition (+ReLU in merge_add).
    # Refine with a residual block.
    dec3 = get_res_conv_core(get_deconv_layer(bottom, width(128)), width(128))
    dec3 = get_res_conv_core(merge_add(enc3, dec3), width(128))

    dec2 = get_deconv_layer(dec3, width(64))
    dec2 = get_res_conv_core(merge_add(enc2, dec2), width(64))

    dec1 = get_deconv_layer(dec2, width(32))
    dec1 = get_res_conv_core(merge_add(enc1, dec1), width(32))

    scores = organise_output(get_conv_fc(dec1), output_shape)
    return Model(inputs=[inputs], outputs=[scores])

def merge_add(a, b) :
    """Sum two tensors element-wise and pass the result through a ReLU."""
    summed = add([a, b])
    return Activation('relu')(summed)

def get_res_conv_core(input, num_filters) :
    """Residual-style block: 3x3x3 and 1x1x1 convolutions of the same input, summed.

    Both branches produce ``num_filters`` channels with ``'same'`` padding,
    so they can be added. ``merge_add`` then applies a ReLU to the sum.
    """
    spatial_branch = Conv3D(num_filters, kernel_size=(3, 3, 3), padding='same')
    projection_branch = Conv3D(num_filters, kernel_size=(1, 1, 1), padding='same')
    return merge_add(spatial_branch(input), projection_branch(input))

def get_max_pooling_layer(input) :
    """Apply 2x2x2 max pooling to ``input``.

    Keras's default strides equal ``pool_size``, so this downsamples each
    spatial axis by 2.
    """
    pooling = MaxPooling3D(pool_size=(2, 2, 2))
    return pooling(input)

def get_deconv_layer(input, num_filters) :
    """Learned 2x upsampling: transposed 3D convolution, 2x2x2 kernel, stride 2."""
    upsample = Deconv3D(num_filters, kernel_size=(2, 2, 2), strides=(2, 2, 2))
    return upsample(input)

def get_conv_fc(input, num_filters=4) :
    """Per-voxel classifier head: 1x1x1 convolution to ``num_filters`` channels, then ReLU.

    NOTE(review): in generate_uresnet_model this output feeds the softmax in
    organise_output, so every class score is clipped at 0 before
    normalisation. A linear activation is the usual choice here. Left as-is
    because existing trained weights were produced with this architecture.
    """
    class_scores = Conv3D(num_filters, kernel_size=(1, 1, 1))(input)
    relu = Activation('relu')
    return relu(class_scores)

def organise_output(input, output_shape) :
    """Flatten per-voxel class maps into a ``(voxels, classes)`` softmax.

    Args:
        input: classifier output laid out channels-first, i.e. ``num_classes``
            channels followed by the spatial axes.
        output_shape: per-sample target shape ``(num_voxels, num_classes)``,
            e.g. ``(32*32*32, 4)``.

    Returns:
        Tensor of shape ``(batch, num_voxels, num_classes)`` with a softmax
        over the class axis, matching one-hot targets of the same layout.
    """
    # Derive the reshape from output_shape instead of fixing it at
    # (4, 32*32*32), so other patch sizes and class counts work.
    # int() guards against numpy integer types coming from np.prod.
    num_voxels, num_classes = (int(dim) for dim in output_shape)
    pred = Reshape((num_classes, num_voxels))(input)
    # Move the class axis last so the softmax normalises per voxel.
    pred = Permute((2, 1))(pred)
    return Activation('softmax')(pred)

Using TensorFlow backend.


In [17]:
# Model width multiplier: scales the filter count of every layer.
scale = 0.5
curr_patch_shape = (32, 32, 32)
# One input channel, channels-first.
patch_shape = (1, ) + curr_patch_shape
# (voxels per patch, classes). np.prod replaces the deprecated np.product.
output_shape = (np.prod(curr_patch_shape), 4)
# Pass scale by keyword. The third positional parameter is num_classes, so
# passing it positionally bound 0.5 to num_classes and left scale at 1.
model = generate_uresnet_model(patch_shape, output_shape, scale=scale)
model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics=['acc'])

In [13]:
import os
import nibabel as nib
from keras.utils import to_categorical

from medical_data import cdr_info, nwbv_info, diff_info

import sys

# Filename templates. {0} is the subject id, {1}/{2} are the visit numbers
# (MR{1}, MR{2}), and {3} is the tag inserted between them. 'to' names the
# directory and 'halfwayto' the volume file.
file_general_pattern = 'OAS2_0{0:03}_MR{1}_{3}_OAS2_0{0:03}_MR{2}'
dataset_location = 'datasets/OASIS/OASIS2/REG/{}/{}.nii.gz'

step = (32, 32, 32)
num_classes = 4
num_voxels = np.prod(curr_patch_shape)
# Keep a patch only if more than 30% of its voxels are non-zero.
threshold = np.int32(0.30 * num_voxels)

# Collect per-volume chunks and concatenate once after the loop. Growing the
# arrays with np.vstack for every volume re-copies everything gathered so far.
# The empty seeds keep shapes and dtypes well defined if no volume is found.
ref_chunks = [np.empty((0, 1, ) + curr_patch_shape, dtype='float32')]
out_chunks = [np.empty((0, num_voxels, num_classes), dtype='float32')]
for i in range(1, 100) :
    sys.stdout.write('{} : '.format(i))

    for j in range(1, 5) :
        k = j + 1
        mov_filename = dataset_location.format(
            file_general_pattern.format(i, j, k, 'to'),
            file_general_pattern.format(i, j, k, 'halfwayto'))
        mov_prob_filename = dataset_location.format(
            file_general_pattern.format(i, j, k, 'to'),
            file_general_pattern.format(i, j, k, 'halfwayto') + '_brain_seg')

        if not os.path.exists(mov_filename) :
            continue

        # Normalise in floating point. Under Python 2, an integer-typed volume
        # divided by its integer max would be floor-divided down to 0/1.
        volume_init = np.asarray(nib.load(mov_filename).get_data(), dtype=np.float64)
        volume_max = volume_init.max()
        if volume_max == 0 :
            # 0/0 would give NaN everywhere. NaN != 0 would then mark every
            # patch as useful and inject NaN patches, so skip this volume.
            continue
        volume_init = volume_init / volume_max

        mask_patches = extract_patches(volume_init != 0, curr_patch_shape, step)
        useful_patches = np.sum(mask_patches, axis=(1, 2, 3)) > threshold
        N = np.sum(useful_patches)
        del mask_patches

        mov_patches = extract_patches(volume_init, curr_patch_shape, step)
        ref_chunks.append(
            mov_patches[useful_patches].reshape((-1, 1, ) + curr_patch_shape).astype('float32'))
        del mov_patches

        volume_init = nib.load(mov_prob_filename).get_data()
        mov_prob_patches = extract_patches(volume_init, curr_patch_shape, step)[useful_patches]

        # One-hot encode all voxels of all kept patches in a single call. Row-major
        # ravel keeps each patch's voxels contiguous, so the reshape restores
        # the (patch, voxel, class) layout.
        labels_train = to_categorical(mov_prob_patches.ravel(), num_classes)
        out_chunks.append(
            labels_train.reshape((N, num_voxels, num_classes)).astype('float32'))
        del labels_train, mov_prob_patches
        ######################################################################################
        sys.stdout.write('{}->{} '.format(j, k))
sys.stdout.write('\n')

# Each volume's patches go before those of earlier volumes (most recently read
# first). This order decides which patches land in the validation_split tail.
ref_train = np.concatenate(ref_chunks[::-1])
out_train = np.concatenate(out_chunks[::-1])
del ref_chunks, out_chunks

1 : 1->2 2 : 1->2 2->3 3 : 4 : 1->2 5 : 1->2 2->3 6 : 7 : 3->4 8 : 1->2 9 : 1->2 10 : 1->2 11 : 12 : 1->2 2->3 13 : 1->2 2->3 14 : 1->2 15 : 16 : 1->2 17 : 3->4 18 : 3->4 19 : 20 : 1->2 2->3 21 : 1->2 22 : 1->2 23 : 1->2 24 : 25 : 26 : 1->2 27 : 1->2 2->3 3->4 28 : 1->2 29 : 1->2 30 : 1->2 31 : 1->2 2->3 32 : 1->2 33 : 34 : 1->2 2->3 3->4 35 : 1->2 36 : 3->4 37 : 1->2 2->3 3->4 38 : 39 : 1->2 40 : 1->2 2->3 41 : 1->2 2->3 42 : 1->2 43 : 1->2 44 : 1->2 2->3 45 : 1->2 46 : 1->2 47 : 1->2 48 : 1->2 2->3 3->4 49 : 1->2 2->3 50 : 1->2 51 : 1->2 2->3 52 : 1->2 53 : 1->2 54 : 1->2 55 : 1->2 56 : 1->2 57 : 1->2 2->3 58 : 1->2 2->3 59 : 60 : 1->2 61 : 1->2 2->3 62 : 1->2 2->3 63 : 1->2 64 : 1->2 2->3 65 : 66 : 1->2 67 : 1->2 2->3 3->4 68 : 1->2 69 : 1->2 70 : 1->2 2->3 3->4 71 : 1->2 72 : 73 : 1->2 2->3 3->4 74 : 75 : 1->2 76 : 1->2 2->3 77 : 1->2 78 : 1->2 2->3 79 : 1->2 2->3 80 : 1->2 2->3 81 : 1->2 82 : 83 : 84 : 85 : 1->2 86 : 1->2 87 : 1->2 88 : 1->2 89 : 90 : 1->2 2->3 91 : 1->2 92 : 1->2

In [14]:
# Standardise the inputs using global statistics of all collected patches.
# NOTE(review): not idempotent. Re-running this cell standardises ref_train a
# second time.
# NOTE(review): the statistics include the patches that fit() later holds out
# via validation_split.
train_mean = ref_train.mean()
train_std = ref_train.std()
# In-place arithmetic gives the same values without allocating full-size temporaries.
ref_train -= train_mean
ref_train /= train_std

In [15]:
# Show the normalisation constants. Presumably any data later fed to the model
# must be standardised with these same values (confirm in downstream cells).
print train_mean, train_std

0.23081103 0.18944699


In [18]:
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Stop training after `patience` epochs without improvement. Keep only the
# best epoch's weights on disk.
# NOTE(review): both callbacks rely on Keras's default monitored quantity
# (val_loss in Keras 2). Confirm for the installed version.
patience = 3
checkpointer = ModelCheckpoint(
    'models/ag_segmenter.h5',
    save_best_only=True,
    save_weights_only=True)
stopper = EarlyStopping(patience=patience)

N = len(ref_train)
# NOTE(review): Keras takes the validation slice from the end of the arrays,
# before any shuffling. This makes the patch order from the loading cell matter.
model.fit(
    x=ref_train,
    y=out_train,
    epochs=40,
    validation_split=0.3,
    callbacks=[checkpointer, stopper])

Train on 4274 samples, validate on 1832 samples
Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40


<keras.callbacks.History at 0x7fca40ec3d90>

In [19]:
# Restore the best checkpoint written during fit (ModelCheckpoint with
# save_best_only=True) instead of keeping the weights from the final epoch.
model.load_weights('models/ag_segmenter.h5')