In [1]:
import numpy as np
import nibabel as nib

from utils.extraction import extract_patches
from unet import generate_uresnet_model

Using TensorFlow backend.


In [2]:
scale = 1
patch_shape = (32, 32, 32)
input_shape = (1, ) + patch_shape
output_shape = (np.product(patch_shape), 4)
model = generate_uresnet_model(input_shape, output_shape, scale)
model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics=['acc'])

In [3]:
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 1, 32, 32, 32 0                                            
__________________________________________________________________________________________________
conv3d_1 (Conv3D)               (None, 32, 32, 32, 3 896         input_1[0][0]                    
__________________________________________________________________________________________________
conv3d_2 (Conv3D)               (None, 32, 32, 32, 3 64          input_1[0][0]                    
__________________________________________________________________________________________________
add_1 (Add)                     (None, 32, 32, 32, 3 0           conv3d_1[0][0]                   
                                                                 conv3d_2[0][0]                   
__________

In [4]:
a_range = range(1, 100)
np.random.shuffle(a_range)

N = len(a_range)
N_train = np.uint8(np.ceil(N * 0.70))

range_train = a_range[:N_train]
range_val = a_range[N_train:]

In [5]:
import os
import nibabel as nib
from keras.utils import to_categorical

file_general_pattern = 'OAS2_0{0:03}_MR{1}'
dataset_skull_location = '/mnt/harddisk/datasets/OASIS/SKULL/{}.nii.gz'
dataset_histogram_location = '/mnt/harddisk/datasets/OASIS/MATCHED_HISTOGRAM/{}.nii.gz'

step = (16, 16, 16)
num_classes = 4
threshold = np.int32(0.40 * np.prod(patch_shape[:]))
ref_train = np.empty((0, 1, ) + patch_shape)
out_train = np.empty((0, np.prod(patch_shape), num_classes))
ref_val = np.empty((0, 1, ) + patch_shape)
out_val = np.empty((0, np.prod(patch_shape), num_classes))
for i in a_range :
    print '{} :'.format(i),
    
    for j in range(1, 5) :
        k = j + 1
        filename = dataset_histogram_location.format(file_general_pattern.format(i, j))
        seg_filename = dataset_skull_location.format(file_general_pattern.format(i, str(j)+ '_seg'))
        
        if not os.path.exists(filename) :
            continue

        volume_init = nib.load(filename).get_data()

        mask_patches = extract_patches(volume_init != 0, patch_shape, step)
        useful_patches = np.sum(mask_patches, axis=(1, 2, 3)) > threshold
        N = np.sum(useful_patches)
        
        del mask_patches

        mov_patches = extract_patches(volume_init, patch_shape, step)
        mov_patches = mov_patches[useful_patches].reshape((-1, 1, ) + patch_shape)
        if i in range_train :
            ref_train = np.vstack((mov_patches, ref_train)).astype('float32')
        else :
            ref_val = np.vstack((mov_patches, ref_val)).astype('float32')
        del mov_patches

        volume_init = nib.load(seg_filename).get_data()

        mov_prob_patches = extract_patches(volume_init, patch_shape, step)
        mov_prob_patches = mov_prob_patches[useful_patches].reshape((-1, 1, np.prod(patch_shape)))
        
        labels_train = np.empty((N, np.prod(patch_shape), 4))
        for l in range(N) :
            labels_train[l] = to_categorical(mov_prob_patches[l].flatten(), 4)
    
        if i in range_train :
            out_train = np.vstack((labels_train, out_train)).astype('float32')
        else :
            out_val = np.vstack((labels_train, out_val)).astype('float32')
        del labels_train, mov_prob_patches
        ######################################################################################
        print '{}->{}'.format(j, k),
print

4 : 1->2 2->3 1 : 1->2 2->3 35 : 1->2 2->3 75 : 1->2 2->3 86 : 1->2 2->3 92 : 1->2 2->3 36 : 1->2 3->4 4->5 34 : 1->2 2->3 3->4 4->5 89 : 1->2 3->4 14 : 1->2 2->3 9 : 1->2 2->3 16 : 1->2 2->3 73 : 1->2 2->3 3->4 4->5 43 : 1->2 2->3 19 : 94 : 1->2 2->3 32 : 1->2 2->3 99 : 1->2 2->3 74 : 61 : 1->2 2->3 3->4 20 : 1->2 2->3 3->4 84 : 53 : 1->2 2->3 8 : 1->2 2->3 6 : 17 : 1->2 3->4 4->5 90 : 1->2 2->3 3->4 87 : 1->2 2->3 31 : 1->2 2->3 3->4 47 : 1->2 2->3 82 : 60 : 1->2 2->3 88 : 1->2 2->3 42 : 1->2 2->3 77 : 1->2 2->3 22 : 1->2 2->3 72 : 49 : 1->2 2->3 3->4 97 : 1->2 2->3 96 : 1->2 2->3 40 : 1->2 2->3 3->4 95 : 1->2 2->3 3->4 76 : 1->2 2->3 3->4 91 : 1->2 2->3 85 : 1->2 2->3 46 : 1->2 2->3 57 : 1->2 2->3 3->4 25 : 62 : 1->2 2->3 3->4 13 : 1->2 2->3 3->4 37 : 1->2 2->3 3->4 4->5 38 : 21 : 1->2 2->3 66 : 1->2 2->3 78 : 1->2 2->3 3->4 56 : 1->2 2->3 2 : 1->2 2->3 3->4 59 : 98 : 1->2 2->3 15 : 67 : 1->2 2->3 3->4 4->5 3 : 71 : 1->2 2->3 33 : 26 : 1->2 2->3 45 : 1->2 2->3 64 : 1->2 2->3 3->4 55

In [6]:
train_mean = ref_train.mean()
train_std = ref_train.std()
params = {'train_mean' : train_mean, 'train_std' : train_std}

np.save('models/ag_segmenter_o1o2.npy', params)

In [7]:
ref_train = (ref_train - train_mean) / train_std
ref_val = (ref_val - train_mean) / train_std

In [8]:
print train_mean, train_std

789.9806 587.54895


In [9]:
from keras.callbacks import EarlyStopping, ModelCheckpoint

patience = 10

stopper = EarlyStopping(patience=patience)
checkpointer = ModelCheckpoint('models/ag_segmenter_o1o2.h5', save_best_only=True, save_weights_only=True)

N = len(ref_train)
model.fit(
    ref_train, out_train,
    validation_data=(ref_val, out_val), epochs=100,
    callbacks=[checkpointer, stopper])

Train on 38738 samples, validate on 17931 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100


<keras.callbacks.History at 0x7f35102e9e50>