In [2]:
import nibabel as nib
import numpy as np
#import tensorflow
import keras

# Fix random seed for reproducibility?
# Better to follow the advice in Keras FAQ:
#  "How can I obtain reproducible results using Keras during development?"
seed = 7
np.random.seed(seed)

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


# Problem configuration

In [3]:
num_classes = 3

patience = 1
model_filename = 'models/iSeg2017/outrun_step_{}.h5'
csv_filename = 'log/iSeg2017/outrun_step_{}.cvs'

nb_epoch = 20
validation_split = 0.25

class_mapper = {0 : 0, 10 : 0, 150 : 1, 250 : 2}
class_mapper_inv = {0 : 0, 1 : 10, 2 : 150, 3 : 250}

# Utils

In [4]:
# General utils for reading and saving data
def get_filename(set_name, case_idx, input_name, loc='datasets') :
    pattern = '{0}/iSeg2017/iSeg-2017-{1}/subject-{2}-{3}.hdr'
    return pattern.format(loc, set_name, case_idx, input_name)

def get_set_name(case_idx) :
    return 'Training' if case_idx < 11 else 'Testing'

def read_data(case_idx, input_name, loc='datasets') :
    set_name = get_set_name(case_idx)

    image_path = get_filename(set_name, case_idx, input_name, loc)

    return nib.load(image_path)

def read_vol(case_idx, input_name, loc='datasets') :
    image_data = read_data(case_idx, input_name, loc)

    return image_data.get_data()[:, :, :, 0]

def save_vol(segmentation, case_idx, loc='results') :
    set_name = get_set_name(case_idx)
    input_image_data = read_data(case_idx, 'T1')

    segmentation_vol = np.empty(input_image_data.shape)
    segmentation_vol[:144, :192, :256, 0] = segmentation
    
    filename = get_filename(set_name, case_idx, 'label', loc)
    nib.save(nib.analyze.AnalyzeImage(
        segmentation_vol.astype('uint8'), input_image_data.affine), filename)


# Data preparation utils
from keras.utils import np_utils
from sklearn.feature_extraction.image import extract_patches as sk_extract_patches

def extract_patches(volume, patch_shape, extraction_step) :
    patches = sk_extract_patches(
        volume,
        patch_shape=patch_shape,
        extraction_step=extraction_step)

    ndim = len(volume.shape)
    npatches = np.prod(patches.shape[:ndim])
    return patches.reshape((npatches, ) + patch_shape)

def build_set(T1_vols, T2_vols, label_vols, extraction_step=(9, 9, 9)) :
    patch_shape = (27, 27, 27)
    label_selector = [slice(None)] + [slice(9, 18) for i in range(3)]

    # Extract patches from input volumes and ground truth
    x = np.zeros((0, 2, 27, 27, 27))
    y = np.zeros((0, 9 * 9 * 9, num_classes))
    for idx in range(len(T1_vols)) :
        y_length = len(y)

        label_patches = extract_patches(label_vols[idx], patch_shape, extraction_step)
        label_patches = label_patches[label_selector]

        # Select only those who are important for processing
        valid_idxs = np.where(np.sum(label_patches, axis=(1, 2, 3)) != 0)

        # Filtering extracted patches
        label_patches = label_patches[valid_idxs]

        x = np.vstack((x, np.zeros((len(label_patches), 2, 27, 27, 27))))
        y = np.vstack((y, np.zeros((len(label_patches), 9 * 9 * 9, num_classes))))

        for i in range(len(label_patches)) :
            y[i+y_length, :, :] = np_utils.to_categorical(label_patches[i].flatten(), num_classes)

        del label_patches

        # Sampling strategy: reject samples which labels are only zeros
        T1_train = extract_patches(T1_vols[idx], patch_shape, extraction_step)
        x[y_length:, 0, :, :, :] = T1_train[valid_idxs]
        del T1_train

        # Sampling strategy: reject samples which labels are only zeros
        T2_train = extract_patches(T2_vols[idx], patch_shape, extraction_step)
        x[y_length:, 1, :, :, :] = T2_train[valid_idxs]
        del T2_train
    return x, y

# Reconstruction utils
import itertools

def generate_indexes(patch_shape, expected_shape) :
    ndims = len(patch_shape)

    poss_shape = [patch_shape[i+1] * (expected_shape[i] // patch_shape[i+1]) for i in range(ndims-1)]

    idxs = [range(patch_shape[i+1], poss_shape[i] - patch_shape[i+1], patch_shape[i+1]) for i in range(ndims-1)]

    return itertools.product(*idxs)

def reconstruct_volume(patches, expected_shape) :
    patch_shape = patches.shape

    assert len(patch_shape) - 1 == len(expected_shape)

    reconstructed_img = np.zeros(expected_shape)

    for count, coord in enumerate(generate_indexes(patch_shape, expected_shape)) :
        selection = [slice(coord[i], coord[i] + patch_shape[i+1]) for i in range(len(coord))]
        reconstructed_img[selection] = patches[count]

    return reconstructed_img

# Architecture

In [5]:
from keras import backend as K
from keras.layers import Activation
from keras.layers import Input
from keras.layers.advanced_activations import PReLU
from keras.layers.convolutional import Conv3D
from keras.layers.convolutional import Cropping3D
from keras.layers.core import Permute
from keras.layers.core import Reshape
from keras.layers.merge import concatenate
from keras.models import Model

K.set_image_dim_ordering('th')

# For understanding the architecture itself, I recommend checking the following article
# Dolz, J. et al. 3D fully convolutional networks for subcortical segmentation in MRI :
# A large-scale study. Neuroimage, 2017.
def generate_model(num_classes) :
    init_input = Input((2, 27, 27, 27))

    x = Conv3D(5, kernel_size=(3, 3, 3))(init_input)
    x = PReLU()(x)
    x = Conv3D(5, kernel_size=(3, 3, 3))(x)
    x = PReLU()(x)
    x = Conv3D(5, kernel_size=(3, 3, 3))(x)
    x = PReLU()(x)

    y = Conv3D(10, kernel_size=(3, 3, 3))(x)
    y = PReLU()(y)
    y = Conv3D(10, kernel_size=(3, 3, 3))(y)
    y = PReLU()(y)
    y = Conv3D(10, kernel_size=(3, 3, 3))(y)
    y = PReLU()(y)

    z = Conv3D(15, kernel_size=(3, 3, 3))(y)
    z = PReLU()(z)
    z = Conv3D(15, kernel_size=(3, 3, 3))(z)
    z = PReLU()(z)
    z = Conv3D(15, kernel_size=(3, 3, 3))(z)
    z = PReLU()(z)

    x_crop = Cropping3D(cropping=((6, 6), (6, 6), (6, 6)))(x)
    y_crop = Cropping3D(cropping=((3, 3), (3, 3), (3, 3)))(y)

    concat = concatenate([x_crop, y_crop, z], axis=1)

    fc = Conv3D(40, kernel_size=(1, 1, 1))(concat)
    fc = PReLU()(fc)
    fc = Conv3D(20, kernel_size=(1, 1, 1))(fc)
    fc = PReLU()(fc)
    fc = Conv3D(10, kernel_size=(1, 1, 1))(fc)
    fc = PReLU()(fc)

    pred = Conv3D(num_classes, kernel_size=(1, 1, 1))(fc)
    pred = PReLU()(pred)
    pred = Reshape((num_classes, 9 * 9 * 9))(pred)
    pred = Permute((2, 1))(pred)
    pred = Activation('softmax')(pred)

    model = Model(inputs=init_input, outputs=pred)
    model.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['categorical_accuracy'])
    return model

# 1. Initial segmentation

## 1.1 Read data

In [6]:
T1_vols = np.empty((10, 144, 192, 256))
T2_vols = np.empty((10, 144, 192, 256))
label_vols = np.empty((10, 144, 192, 256))
for case_idx in range(1, 11) :
    T1_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'T1')
    T2_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'T2')
    label_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'label')

## 1.2 Pre-processing

In [7]:
## Intensity normalisation (zero mean and unit variance)
T1_mean = T1_vols.mean()
T1_std = T1_vols.std()
T1_vols = (T1_vols - T1_mean) / T1_std
T2_mean = T2_vols.mean()
T2_std = T2_vols.std()
T2_vols = (T2_vols - T2_mean) / T2_std

# Combine labels of BG and CSF
for class_idx in class_mapper :
    label_vols[label_vols == class_idx] = class_mapper[class_idx]

## 1.3 Data preparation

In [8]:
x_train, y_train = build_set(T1_vols, T2_vols, label_vols, (6, 16, 6))

In [9]:

abc=generate_model(3)

#What we have removed so far: images 4 -> 11
#patch size that splits into training and test set was doubled: less precision
#values from training set after 100
x_train=x_train[:100]
y_train=y_train[:100]

abc.fit(x_train,y_train,verbose=1,validation_split=0.1,epochs=20)
#How to use the above NN:

Train on 90 samples, validate on 10 samples
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x23a8293c048>

## 1.4 Configure callbacks

In [9]:
from keras.callbacks import ModelCheckpoint
from keras.callbacks import CSVLogger
from keras.callbacks import EarlyStopping

# Early stopping for reducing over-fitting risk
stopper = EarlyStopping(patience=patience)

# Model checkpoint to save the training results
checkpointer = ModelCheckpoint(
    filepath=model_filename.format(1),
    verbose=0,
    save_best_only=True,
    save_weights_only=True)

# CSVLogger to save the training results in a csv file
csv_logger = CSVLogger(csv_filename.format(1), separator=';')

callbacks = [checkpointer, csv_logger, stopper]

## 1.5 Training

In [40]:
# Build model

model = generate_model(num_classes)

model.fit(
    x_train,
    y_train,
    epochs=nb_epoch,
    validation_split=validation_split,
    verbose=2,
    callbacks=callbacks)

# freeing space
del x_train
del y_train

Train on 12690 samples, validate on 4231 samples
Epoch 1/20
 - 785s - loss: 0.9239 - categorical_accuracy: 0.5531 - val_loss: 0.7795 - val_categorical_accuracy: 0.6649
Epoch 2/20
 - 804s - loss: 0.7565 - categorical_accuracy: 0.6534 - val_loss: 0.7211 - val_categorical_accuracy: 0.6630
Epoch 3/20
 - 802s - loss: 0.6996 - categorical_accuracy: 0.6549 - val_loss: 0.6603 - val_categorical_accuracy: 0.6737
Epoch 4/20
 - 804s - loss: 0.6585 - categorical_accuracy: 0.6733 - val_loss: 0.6246 - val_categorical_accuracy: 0.7008
Epoch 5/20
 - 796s - loss: 0.6296 - categorical_accuracy: 0.6940 - val_loss: 0.6028 - val_categorical_accuracy: 0.6900
Epoch 6/20
 - 794s - loss: 0.6046 - categorical_accuracy: 0.7077 - val_loss: 0.5877 - val_categorical_accuracy: 0.6923
Epoch 7/20
 - 804s - loss: 0.5842 - categorical_accuracy: 0.7206 - val_loss: 0.5608 - val_categorical_accuracy: 0.7372
Epoch 8/20
 - 812s - loss: 0.5709 - categorical_accuracy: 0.7275 - val_loss: 0.5528 - val_categorical_accuracy: 0.7327

Train on 12690 samples, validate on 4231 samples
Epoch 1/20
 - 785s - loss: 0.9239 - categorical_accuracy: 0.5531 - val_loss: 0.7795 - val_categorical_accuracy: 0.6649
Epoch 2/20
 - 804s - loss: 0.7565 - categorical_accuracy: 0.6534 - val_loss: 0.7211 - val_categorical_accuracy: 0.6630
Epoch 3/20
 - 802s - loss: 0.6996 - categorical_accuracy: 0.6549 - val_loss: 0.6603 - val_categorical_accuracy: 0.6737
Epoch 4/20
 - 804s - loss: 0.6585 - categorical_accuracy: 0.6733 - val_loss: 0.6246 - val_categorical_accuracy: 0.7008
Epoch 5/20
 - 796s - loss: 0.6296 - categorical_accuracy: 0.6940 - val_loss: 0.6028 - val_categorical_accuracy: 0.6900
Epoch 6/20
 - 794s - loss: 0.6046 - categorical_accuracy: 0.7077 - val_loss: 0.5877 - val_categorical_accuracy: 0.6923
Epoch 7/20
 - 804s - loss: 0.5842 - categorical_accuracy: 0.7206 - val_loss: 0.5608 - val_categorical_accuracy: 0.7372
Epoch 8/20
 - 812s - loss: 0.5709 - categorical_accuracy: 0.7275 - val_loss: 0.5528 - val_categorical_accuracy: 0.7327
Epoch 9/20
 - 851s - loss: 0.5598 - categorical_accuracy: 0.7341 - val_loss: 0.5477 - val_categorical_accuracy: 0.7284
Epoch 10/20
 - 5835s - loss: 0.5520 - categorical_accuracy: 0.7383 - val_loss: 0.5425 - val_categorical_accuracy: 0.7312
Epoch 11/20
 - 799s - loss: 0.5432 - categorical_accuracy: 0.7437 - val_loss: 0.5283 - val_categorical_accuracy: 0.7534
Epoch 12/20
 - 845s - loss: 0.5353 - categorical_accuracy: 0.7489 - val_loss: 0.5261 - val_categorical_accuracy: 0.7599
Epoch 13/20
 - 861s - loss: 0.5303 - categorical_accuracy: 0.7511 - val_loss: 0.5148 - val_categorical_accuracy: 0.7570
Epoch 14/20
 - 858s - loss: 0.5227 - categorical_accuracy: 0.7558 - val_loss: 0.5061 - val_categorical_accuracy: 0.7658
Epoch 15/20
 - 857s - loss: 0.5162 - categorical_accuracy: 0.7603 - val_loss: 0.5083 - val_categorical_accuracy: 0.7576

In [10]:


'''#Training large data by breaking it into batches

model = generate_model(num_classes)

for e in range(10) :
    print ("epoch %d" %e)
    for step in range(10) :
        x_train, train_y = LoadTrainBatch(32)  #batch_size = 32
        x_train = np.asarray(x_train)
        y_train = np.asarray(y_train) 
        for train_X , train_Y in zip(x_train ,train_y) :
            train_X = train_X.reshape(1, row, col, 3)
            model.fit(train_X, train_Y, nb_epoch=1, callbacks=callbacks,verbose=2,validation_split=validation_split)
            
del train_X            
del train_Y
'''
# Build model

model = generate_model(num_classes)

model.fit(
    x_train,
    y_train,
    epochs=nb_epoch,
    validation_split=validation_split,
    verbose=2,
    batch_size=15,
    callbacks=callbacks) 

# freeing space
del x_train
del y_train 

Train on 12690 samples, validate on 4231 samples
Epoch 1/20
 - 2945s - loss: 0.5372 - categorical_accuracy: 0.7358 - val_loss: 0.4076 - val_categorical_accuracy: 0.8119
Epoch 2/20
 - 2978s - loss: 0.3781 - categorical_accuracy: 0.8286 - val_loss: 0.3377 - val_categorical_accuracy: 0.8493
Epoch 3/20
 - 2982s - loss: 0.3280 - categorical_accuracy: 0.8540 - val_loss: 0.3111 - val_categorical_accuracy: 0.8639
Epoch 4/20
 - 2989s - loss: 0.3017 - categorical_accuracy: 0.8665 - val_loss: 0.3081 - val_categorical_accuracy: 0.8626
Epoch 5/20
 - 2979s - loss: 0.2742 - categorical_accuracy: 0.8795 - val_loss: 0.2818 - val_categorical_accuracy: 0.8744
Epoch 6/20
 - 3088s - loss: 0.2615 - categorical_accuracy: 0.8854 - val_loss: 0.2582 - val_categorical_accuracy: 0.8862
Epoch 7/20
 - 3168s - loss: 0.2437 - categorical_accuracy: 0.8938 - val_loss: 0.2598 - val_categorical_accuracy: 0.8869


def generate_model(num_classes) :
    init_input = Input((2, 27, 27, 27))

    x = Conv3D(5, kernel_size=(3, 3, 3))(init_input)
    x = PReLU()(x)
    x = Conv3D(5, kernel_size=(3, 3, 3))(x)
    x = PReLU()(x)
    x = Conv3D(5, kernel_size=(3, 3, 3))(x)
    x = PReLU()(x)

    y = Conv3D(10, kernel_size=(3, 3, 3))(x)
    y = PReLU()(y)
    y = Conv3D(10, kernel_size=(3, 3, 3))(y)
    y = PReLU()(y)
    y = Conv3D(10, kernel_size=(3, 3, 3))(y)
    y = PReLU()(y)

    z = Conv3D(15, kernel_size=(3, 3, 3))(y)
    z = PReLU()(z)
    z = Conv3D(15, kernel_size=(3, 3, 3))(z)
    z = PReLU()(z)
    z = Conv3D(15, kernel_size=(3, 3, 3))(z)
    z = PReLU()(z)

    x_crop = Cropping3D(cropping=((6, 6), (6, 6), (6, 6)))(x)
    y_crop = Cropping3D(cropping=((3, 3), (3, 3), (3, 3)))(y)

    concat = concatenate([x_crop, y_crop, z], axis=1)

    fc = Conv3D(40, kernel_size=(1, 1, 1))(concat)
    fc = PReLU()(fc)
    fc = Conv3D(20, kernel_size=(1, 1, 1))(fc)
    fc = PReLU()(fc)
    fc = Conv3D(10, kernel_size=(1, 1, 1))(fc)
    fc = PReLU()(fc)

Train on 12690 samples, validate on 4231 samples
Epoch 1/20
 - 2945s - loss: 0.5372 - categorical_accuracy: 0.7358 - val_loss: 0.4076 - val_categorical_accuracy: 0.8119
Epoch 2/20
 - 2978s - loss: 0.3781 - categorical_accuracy: 0.8286 - val_loss: 0.3377 - val_categorical_accuracy: 0.8493
Epoch 3/20
 - 2982s - loss: 0.3280 - categorical_accuracy: 0.8540 - val_loss: 0.3111 - val_categorical_accuracy: 0.8639
Epoch 4/20
 - 2989s - loss: 0.3017 - categorical_accuracy: 0.8665 - val_loss: 0.3081 - val_categorical_accuracy: 0.8626
Epoch 5/20
 - 2979s - loss: 0.2742 - categorical_accuracy: 0.8795 - val_loss: 0.2818 - val_categorical_accuracy: 0.8744
Epoch 6/20
 - 3088s - loss: 0.2615 - categorical_accuracy: 0.8854 - val_loss: 0.2582 - val_categorical_accuracy: 0.8862
Epoch 7/20
 - 3168s - loss: 0.2437 - categorical_accuracy: 0.8938 - val_loss: 0.2598 - val_categorical_accuracy: 0.8869

In [11]:
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 2, 27, 27, 27 0                                            
__________________________________________________________________________________________________
conv3d_1 (Conv3D)               (None, 5, 25, 25, 25 275         input_1[0][0]                    
__________________________________________________________________________________________________
p_re_lu_1 (PReLU)               (None, 5, 25, 25, 25 78125       conv3d_1[0][0]                   
__________________________________________________________________________________________________
conv3d_2 (Conv3D)               (None, 5, 23, 23, 23 680         p_re_lu_1[0][0]                  
__________________________________________________________________________________________________
p_re_lu_2 

In [54]:
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_8 (InputLayer)            (None, 2, 27, 27, 27 0                                            
__________________________________________________________________________________________________
conv3d_92 (Conv3D)              (None, 1, 25, 25, 25 55          input_8[0][0]                    
__________________________________________________________________________________________________
p_re_lu_92 (PReLU)              (None, 1, 25, 25, 25 15625       conv3d_92[0][0]                  
__________________________________________________________________________________________________
conv3d_93 (Conv3D)              (None, 1, 23, 23, 23 28          p_re_lu_92[0][0]                 
__________________________________________________________________________________________________
p_re_lu_93

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_8 (InputLayer)            (None, 2, 27, 27, 27 0                                            
__________________________________________________________________________________________________
conv3d_92 (Conv3D)              (None, 1, 25, 25, 25 55          input_8[0][0]                    
__________________________________________________________________________________________________
p_re_lu_92 (PReLU)              (None, 1, 25, 25, 25 15625       conv3d_92[0][0]                  
__________________________________________________________________________________________________
conv3d_93 (Conv3D)              (None, 1, 23, 23, 23 28          p_re_lu_92[0][0]                 
__________________________________________________________________________________________________
p_re_lu_93 (PReLU)              (None, 1, 23, 23, 23 12167       conv3d_93[0][0]                  
__________________________________________________________________________________________________
conv3d_94 (Conv3D)              (None, 1, 21, 21, 21 28          p_re_lu_93[0][0]                 
__________________________________________________________________________________________________
p_re_lu_94 (PReLU)              (None, 1, 21, 21, 21 9261        conv3d_94[0][0]                  
__________________________________________________________________________________________________
conv3d_95 (Conv3D)              (None, 1, 19, 19, 19 28          p_re_lu_94[0][0]                 
__________________________________________________________________________________________________
p_re_lu_95 (PReLU)              (None, 1, 19, 19, 19 6859        conv3d_95[0][0]                  
__________________________________________________________________________________________________
conv3d_96 (Conv3D)              (None, 1, 17, 17, 17 28          p_re_lu_95[0][0]                 
__________________________________________________________________________________________________
p_re_lu_96 (PReLU)              (None, 1, 17, 17, 17 4913        conv3d_96[0][0]                  
__________________________________________________________________________________________________
conv3d_97 (Conv3D)              (None, 1, 15, 15, 15 28          p_re_lu_96[0][0]                 
__________________________________________________________________________________________________
p_re_lu_97 (PReLU)              (None, 1, 15, 15, 15 3375        conv3d_97[0][0]                  
__________________________________________________________________________________________________
conv3d_98 (Conv3D)              (None, 1, 13, 13, 13 28          p_re_lu_97[0][0]                 
__________________________________________________________________________________________________
p_re_lu_98 (PReLU)              (None, 1, 13, 13, 13 2197        conv3d_98[0][0]                  
__________________________________________________________________________________________________
conv3d_99 (Conv3D)              (None, 1, 11, 11, 11 28          p_re_lu_98[0][0]                 
__________________________________________________________________________________________________
p_re_lu_99 (PReLU)              (None, 1, 11, 11, 11 1331        conv3d_99[0][0]                  
__________________________________________________________________________________________________
conv3d_100 (Conv3D)             (None, 1, 9, 9, 9)   28          p_re_lu_99[0][0]                 
__________________________________________________________________________________________________
cropping3d_15 (Cropping3D)      (None, 1, 9, 9, 9)   0           p_re_lu_94[0][0]                 
__________________________________________________________________________________________________
cropping3d_16 (Cropping3D)      (None, 1, 9, 9, 9)   0           p_re_lu_97[0][0]                 
__________________________________________________________________________________________________
p_re_lu_100 (PReLU)             (None, 1, 9, 9, 9)   729         conv3d_100[0][0]                 
__________________________________________________________________________________________________
concatenate_8 (Concatenate)     (None, 3, 9, 9, 9)   0           cropping3d_15[0][0]              
                                                                 cropping3d_16[0][0]              
                                                                 p_re_lu_100[0][0]                
__________________________________________________________________________________________________
conv3d_101 (Conv3D)             (None, 4, 9, 9, 9)   16          concatenate_8[0][0]              
__________________________________________________________________________________________________
p_re_lu_101 (PReLU)             (None, 4, 9, 9, 9)   2916        conv3d_101[0][0]                 
__________________________________________________________________________________________________
conv3d_102 (Conv3D)             (None, 2, 9, 9, 9)   10          p_re_lu_101[0][0]                
__________________________________________________________________________________________________
p_re_lu_102 (PReLU)             (None, 2, 9, 9, 9)   1458        conv3d_102[0][0]                 
__________________________________________________________________________________________________
conv3d_103 (Conv3D)             (None, 1, 9, 9, 9)   3           p_re_lu_102[0][0]                
__________________________________________________________________________________________________
p_re_lu_103 (PReLU)             (None, 1, 9, 9, 9)   729         conv3d_103[0][0]                 
__________________________________________________________________________________________________
conv3d_104 (Conv3D)             (None, 3, 9, 9, 9)   6           p_re_lu_103[0][0]                
__________________________________________________________________________________________________
p_re_lu_104 (PReLU)             (None, 3, 9, 9, 9)   2187        conv3d_104[0][0]                 
__________________________________________________________________________________________________
reshape_8 (Reshape)             (None, 3, 729)       0           p_re_lu_104[0][0]                
__________________________________________________________________________________________________
permute_8 (Permute)             (None, 729, 3)       0           reshape_8[0][0]                  
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 729, 3)       0           permute_8[0][0]                  
==================================================================================================
Total params: 64,061
Trainable params: 64,061
Non-trainable params: 0
__________________________________________________________________________________________________

## 1.6 Classification

In [12]:
from keras.models import load_model

# Load best model
model = generate_model(num_classes)
model.load_weights(model_filename.format(1))

In [13]:
for case_idx in range(11, 24) :
    T1_test_vol = read_vol(case_idx, 'T1')[:144, :192, :256]
    T2_test_vol = read_vol(case_idx, 'T2')[:144, :192, :256]
    
    x_test = np.zeros((6916, 2, 27, 27, 27))
    x_test[:, 0, :, :, :] = extract_patches(T1_test_vol, patch_shape=(27, 27, 27), extraction_step=(9, 9, 9))
    x_test[:, 1, :, :, :] = extract_patches(T2_test_vol, patch_shape=(27, 27, 27), extraction_step=(9, 9, 9))
    
    x_test[:, 0, :, :, :] = (x_test[:, 0, :, :, :] - T1_mean) / T1_std
    x_test[:, 1, :, :, :] = (x_test[:, 1, :, :, :] - T2_mean) / T2_std

    pred = model.predict(x_test, verbose=2)
    pred_classes = np.argmax(pred, axis=2)
    pred_classes = pred_classes.reshape((len(pred_classes), 9, 9, 9))
    segmentation = reconstruct_volume(pred_classes, (144, 192, 256))
    
    csf = np.logical_and(segmentation == 0, T1_test_vol != 0)
    segmentation[segmentation == 2] = 250
    segmentation[segmentation == 1] = 150
    segmentation[csf] = 10
    
    save_vol(segmentation, case_idx)
    
    print( "Finished segmentation of case # {}" .format(case_idx)) #

print( "Done with Step 1")

Finished segmentation of case # 11
Finished segmentation of case # 12
Finished segmentation of case # 13
Finished segmentation of case # 14
Finished segmentation of case # 15
Finished segmentation of case # 16
Finished segmentation of case # 17
Finished segmentation of case # 18
Finished segmentation of case # 19
Finished segmentation of case # 20
Finished segmentation of case # 21
Finished segmentation of case # 22
Finished segmentation of case # 23
Done with Step 1


# 2. Pseudo-labelling step

## 2.1 Read data

In [14]:
sure = range(0, 10)
unsure = range(11, 23)

T1_vols = np.empty((23, 144, 192, 256))
T2_vols = np.empty((23, 144, 192, 256))
label_vols = np.empty((23, 144, 192, 256))
for case_idx in range(1, 24) :
    loc = 'datasets' if case_idx < 11 else 'results'

    T1_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'T1')[:144, :192, :256]
    T2_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'T2')[:144, :192, :256]
    label_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'label', loc)[:144, :192, :256]

## 2.2 Pre-processing

In [15]:
## Intensity normalisation (zero mean and unit variance)
T1_mean = T1_vols.mean()
T1_std = T1_vols.std()
T1_vols = (T1_vols - T1_mean) / T1_std
T2_mean = T2_vols.mean()
T2_std = T2_vols.std()
T2_vols = (T2_vols - T2_mean) / T2_std

# Combine labels of BG and CSF
for class_idx in class_mapper :
    label_vols[label_vols == class_idx] = class_mapper[class_idx]

## 2.3 Data preparation

In [16]:


x_sure, y_sure = build_set(T1_vols[sure], T2_vols[sure], label_vols[sure], (12,32,12))
x_unsure, y_unsure = build_set(T1_vols[unsure], T2_vols[unsure], label_vols[unsure])

#x_train=x_train[:100]
#y_train=y_train[:100]

x_train = np.vstack((x_sure, x_unsure))
y_train = np.vstack((y_sure, y_unsure))

del x_sure
del x_unsure
del y_sure
del y_unsure

## 2.4 Configure callbacks

In [18]:
from keras.callbacks import ModelCheckpoint
from keras.callbacks import CSVLogger
from keras.callbacks import EarlyStopping

# Early stopping for reducing over-fitting risk
stopper = EarlyStopping(patience=patience)

# Model checkpoint to save the training results
checkpointer = ModelCheckpoint(
    filepath=model_filename.format(2),
    verbose=0,
    save_best_only=True,
    save_weights_only=True)

# CSVLogger to save the training results in a csv file
csv_logger = CSVLogger(csv_filename.format(2), separator=';')

callbacks = [checkpointer, csv_logger, stopper]

## 2.5 Training

In [17]:
# Build model
model = generate_model(num_classes)

model.fit(
    x_train,
    y_train,
    epochs=nb_epoch,
    validation_split=validation_split,
    verbose=2,
    callbacks=callbacks)

# freeing space
del x_train
del y_train

Train on 13983 samples, validate on 4662 samples
Epoch 1/20
 - 845s - loss: 0.9262 - categorical_accuracy: 0.5298 - val_loss: 0.7076 - val_categorical_accuracy: 0.6720
Epoch 2/20
 - 862s - loss: 0.7001 - categorical_accuracy: 0.6453 - val_loss: 0.6428 - val_categorical_accuracy: 0.6802
Epoch 3/20
 - 861s - loss: 0.6492 - categorical_accuracy: 0.6766 - val_loss: 0.6083 - val_categorical_accuracy: 0.7065
Epoch 4/20
 - 862s - loss: 0.5927 - categorical_accuracy: 0.7192 - val_loss: 0.5290 - val_categorical_accuracy: 0.7628
Epoch 5/20
 - 862s - loss: 0.5018 - categorical_accuracy: 0.7797 - val_loss: 0.4419 - val_categorical_accuracy: 0.8181
Epoch 6/20
 - 859s - loss: 0.4466 - categorical_accuracy: 0.8102 - val_loss: 0.4190 - val_categorical_accuracy: 0.8216
Epoch 7/20
 - 830s - loss: 0.4199 - categorical_accuracy: 0.8222 - val_loss: 0.3892 - val_categorical_accuracy: 0.8340
Epoch 8/20
 - 834s - loss: 0.4000 - categorical_accuracy: 0.8312 - val_loss: 0.3747 - val_categorical_accuracy: 0.8490

# 2.5 Results 
Train on 13983 samples, validate on 4662 samples
Epoch 1/20
 - 845s - loss: 0.9262 - categorical_accuracy: 0.5298 - val_loss: 0.7076 - val_categorical_accuracy: 0.6720
Epoch 2/20
 - 862s - loss: 0.7001 - categorical_accuracy: 0.6453 - val_loss: 0.6428 - val_categorical_accuracy: 0.6802
Epoch 3/20
 - 861s - loss: 0.6492 - categorical_accuracy: 0.6766 - val_loss: 0.6083 - val_categorical_accuracy: 0.7065
Epoch 4/20
 - 862s - loss: 0.5927 - categorical_accuracy: 0.7192 - val_loss: 0.5290 - val_categorical_accuracy: 0.7628
Epoch 5/20
 - 862s - loss: 0.5018 - categorical_accuracy: 0.7797 - val_loss: 0.4419 - val_categorical_accuracy: 0.8181
Epoch 6/20
 - 859s - loss: 0.4466 - categorical_accuracy: 0.8102 - val_loss: 0.4190 - val_categorical_accuracy: 0.8216
Epoch 7/20
 - 830s - loss: 0.4199 - categorical_accuracy: 0.8222 - val_loss: 0.3892 - val_categorical_accuracy: 0.8340
Epoch 8/20
 - 834s - loss: 0.4000 - categorical_accuracy: 0.8312 - val_loss: 0.3747 - val_categorical_accuracy: 0.8490
Epoch 9/20
 - 842s - loss: 0.3860 - categorical_accuracy: 0.8378 - val_loss: 0.3521 - val_categorical_accuracy: 0.8616
Epoch 10/20
 - 835s - loss: 0.3740 - categorical_accuracy: 0.8432 - val_loss: 0.3317 - val_categorical_accuracy: 0.8649
Epoch 11/20
 - 837s - loss: 0.3649 - categorical_accuracy: 0.8468 - val_loss: 0.3864 - val_categorical_accuracy: 0.8322

In [21]:
from sklearn.metrics import classification_report
import numpy as np

In [22]:
# Build model
model = generate_model(num_classes)

model.fit(
    x_train,
    y_train,
    epochs=nb_epoch,
    validation_split=validation_split,
    verbose=2,
    batch_size=32,
    callbacks=callbacks)


from sklearn.metrics import classification_report
import numpy as np
preds=np.flatten(model.predict(x_train))
y_train=np.flatten(y_train)
print(classification_report(y_train,preds))


# freeing space
del x_train
del y_train

Train on 13845 samples, validate on 4615 samples
Epoch 1/20
 - 3445s - loss: 0.5010 - categorical_accuracy: 0.7540 - val_loss: 0.4341 - val_categorical_accuracy: 0.7892
Epoch 2/20
 - 3412s - loss: 0.3368 - categorical_accuracy: 0.8483 - val_loss: 0.2886 - val_categorical_accuracy: 0.8730
Epoch 3/20
 - 3418s - loss: 0.2813 - categorical_accuracy: 0.8776 - val_loss: 0.2503 - val_categorical_accuracy: 0.8913
Epoch 4/20
 - 3406s - loss: 0.2426 - categorical_accuracy: 0.8957 - val_loss: 0.2463 - val_categorical_accuracy: 0.8939
Epoch 5/20
 - 3395s - loss: 0.2182 - categorical_accuracy: 0.9065 - val_loss: 0.2161 - val_categorical_accuracy: 0.9064
Epoch 6/20
 - 3400s - loss: 0.2037 - categorical_accuracy: 0.9129 - val_loss: 0.1802 - val_categorical_accuracy: 0.9229
Epoch 7/20
 - 3395s - loss: 0.1968 - categorical_accuracy: 0.9160 - val_loss: 0.1776 - val_categorical_accuracy: 0.9244
Epoch 8/20
 - 3359s - loss: 0.1861 - categorical_accuracy: 0.9206 - val_loss: 0.1680 - val_categorical_accuracy

AttributeError: module 'numpy' has no attribute 'flatten'

2.5 results : 

Train on 13845 samples, validate on 4615 samples
Epoch 1/20
 - 3445s - loss: 0.5010 - categorical_accuracy: 0.7540 - val_loss: 0.4341 - val_categorical_accuracy: 0.7892
Epoch 2/20
 - 3412s - loss: 0.3368 - categorical_accuracy: 0.8483 - val_loss: 0.2886 - val_categorical_accuracy: 0.8730
Epoch 3/20
 - 3418s - loss: 0.2813 - categorical_accuracy: 0.8776 - val_loss: 0.2503 - val_categorical_accuracy: 0.8913
Epoch 4/20
 - 3406s - loss: 0.2426 - categorical_accuracy: 0.8957 - val_loss: 0.2463 - val_categorical_accuracy: 0.8939
Epoch 5/20
 - 3395s - loss: 0.2182 - categorical_accuracy: 0.9065 - val_loss: 0.2161 - val_categorical_accuracy: 0.9064
Epoch 6/20
 - 3400s - loss: 0.2037 - categorical_accuracy: 0.9129 - val_loss: 0.1802 - val_categorical_accuracy: 0.9229
Epoch 7/20
 - 3395s - loss: 0.1968 - categorical_accuracy: 0.9160 - val_loss: 0.1776 - val_categorical_accuracy: 0.9244
Epoch 8/20
 - 3359s - loss: 0.1861 - categorical_accuracy: 0.9206 - val_loss: 0.1680 - val_categorical_accuracy: 0.9279
Epoch 9/20
 - 3353s - loss: 0.1794 - categorical_accuracy: 0.9236 - val_loss: 0.1614 - val_categorical_accuracy: 0.9305
Epoch 10/20
 - 3348s - loss: 0.1741 - categorical_accuracy: 0.9259 - val_loss: 0.1613 - val_categorical_accuracy: 0.9309
Epoch 11/20
 - 3347s - loss: 0.1710 - categorical_accuracy: 0.9273 - val_loss: 0.1643 - val_categorical_accuracy: 0.9302

In [27]:
from sklearn.metrics import classification_report
import numpy as np
preds=np.ndarray.flatten(model.predict(x_train))
#preds=preds.astype(float)
y_train=np.ndarray.flatten(y_train)

y_train=(y_train.astype(int)).astype(str)
preds=(preds.astype(int)).astype(str)
print(classification_report(y_train,preds))

             precision    recall  f1-score   support

          0       0.72      1.00      0.84  26914680
          1       1.00      0.21      0.35  13457340

avg / total       0.81      0.74      0.67  40372020



In [45]:
print(y_train[:10])

print(preds[:10])

['1' '0' '0' '1' '0' '0' '1' '0' '0' '1']
['1' '0' '0' '1' '0' '0' '1' '0' '0' '1']


In [38]:
#test_eval = model.evaluate(y_train,preds, verbose=0)
print(y_train.shape)
print(preds.shape)

(40372020,)
(40372020,)


# Report 
from sklearn.metrics import classification_report

import numpy as np

preds=np.ndarray.flatten(model.predict(x_train))

#preds=preds.astype(float)
y_train=np.ndarray.flatten(y_train)

y_train=(y_train.astype(int)).astype(str)

preds=(preds.astype(int)).astype(str)

print(classification_report(y_train,preds))



         precision    recall  f1-score   support

          0       0.72      1.00      0.84  26914680
          1       1.00      0.21      0.35  13457340

avg / total       0.81      0.74      0.67  40372020

## 2.6 Clasification

In [21]:
from keras.models import load_model

# Load best model
model = generate_model(num_classes)
model.load_weights(model_filename.format(2))

In [22]:
for case_idx in range(11, 24) :
    T1_test_vol = read_vol(case_idx, 'T1')[:144, :192, :256]
    T2_test_vol = read_vol(case_idx, 'T2')[:144, :192, :256]
    
    x_test = np.zeros((6916, 2, 27, 27, 27))
    x_test[:, 0, :, :, :] = extract_patches(T1_test_vol, patch_shape=(27, 27, 27), extraction_step=(9, 9, 9))
    x_test[:, 1, :, :, :] = extract_patches(T2_test_vol, patch_shape=(27, 27, 27), extraction_step=(9, 9, 9))
    
    x_test[:, 0, :, :, :] = (x_test[:, 0, :, :, :] - T1_mean) / T1_std
    x_test[:, 1, :, :, :] = (x_test[:, 1, :, :, :] - T2_mean) / T2_std

    pred = model.predict(x_test, verbose=2)
    pred_classes = np.argmax(pred, axis=2)
    pred_classes = pred_classes.reshape((len(pred_classes), 9, 9, 9))
    segmentation = reconstruct_volume(pred_classes, (144, 192, 256))
    
    csf = np.logical_and(segmentation == 0, T1_test_vol != 0)
    segmentation[segmentation == 2] = 250
    segmentation[segmentation == 1] = 150
    segmentation[csf] = 10
    
    save_vol(segmentation, case_idx, 'refined-results')
    
    print( "Finished segmentation of case # {}".format(case_idx))

print( "Done with Step 2")

Finished segmentation of case # 11
Finished segmentation of case # 12
Finished segmentation of case # 13
Finished segmentation of case # 14
Finished segmentation of case # 15
Finished segmentation of case # 16
Finished segmentation of case # 17
Finished segmentation of case # 18
Finished segmentation of case # 19
Finished segmentation of case # 20
Finished segmentation of case # 21
Finished segmentation of case # 22
Finished segmentation of case # 23
Done with Step 2


In [None]:
"""Skip to content
 
Search or jump to…

Pull requests
Issues
Marketplace
Explore
 @GKKhan21 Sign out
0
0 0 Rickey985/ISeg
 Code  Issues 0  Pull requests 0  Projects 0  Wiki  Insights
ISeg/Data Loader2.py
bf46afd  2 hours ago
@Rickey985 Rickey985 Add files via upload
     
241 lines (183 sloc)  7.8 KB"""


import nibabel as nib
import numpy as np

# Fix random seed for reproducibility?
# Better to follow the advice in Keras FAQ:
#  "How can I obtain reproducible results using Keras during development?"
seed = 7
np.random.seed(seed)

num_classes = 3

patience = 1
model_filename = 'models/iSeg2017/outrun_step_{}.h5'
csv_filename = 'log/iSeg2017/outrun_step_{}.cvs'

nb_epoch = 20
validation_split = 0.25

class_mapper = {0 : 0, 10 : 0, 150 : 1, 250 : 2}
class_mapper_inv = {0 : 0, 1 : 10, 2 : 150, 3 : 250}


# General utils for reading and saving data
def get_filename(set_name, case_idx, input_name, loc='datasets'):
    pattern = '{0}/iSeg2017/iSeg-2017-{1}/subject-{2}-{3}.hdr'
    return pattern.format(loc, set_name, case_idx, input_name)
def get_set_name(case_idx):
    return 'Training' if case_idx < 4 else 'Testing'
def read_data(case_idx, input_name, loc='datasets'):
    set_name = get_set_name(case_idx)
    image_path = get_filename(set_name, case_idx, input_name, loc)
    return nib.load(image_path)
def read_vol(case_idx, input_name, loc='datasets'):
    image_data = read_data(case_idx, input_name, loc)
    return image_data.get_data()[:, :, :, 0]

def save_vol(segmentation, case_idx, loc='results'):
    set_name = get_set_name(case_idx)
    input_image_data = read_data(case_idx, 'T1')
    segmentation_vol = np.empty(input_image_data.shape)
    segmentation_vol[:144, :192, :256, 0] = segmentation
    filename = get_filename(set_name, case_idx, 'label', loc)
    nib.save(nib.analyze.AnalyzeImage(
        segmentation_vol.astype('uint8'), input_image_data.affine), filename)


# Data preparation utils
from keras.utils import np_utils
from sklearn.feature_extraction.image import extract_patches as sk_extract_patches


def extract_patches(volume, patch_shape, extraction_step):
    patches = sk_extract_patches(
        volume,
        patch_shape=patch_shape,
        extraction_step=extraction_step)

    ndim = len(volume.shape)
    npatches = np.prod(patches.shape[:ndim])
    return patches.reshape((npatches,) + patch_shape)


def build_set(T1_vols, T2_vols, label_vols, extraction_step):
    patch_shape = (27, 27, 27)
    label_selector = [slice(None)] + [slice(9, 18) for i in range(3)]

    # Extract patches from input volumes and ground truth
    x = np.zeros((0, 2, 27, 27, 27))
    y = np.zeros((0, 9 * 9 * 9, num_classes))
    for idx in range(len(T1_vols)):
        y_length = len(y)

        label_patches = extract_patches(label_vols[idx], patch_shape, extraction_step)
        label_patches = label_patches[label_selector]

        # Select only those who are important for processing
        valid_idxs = np.where(np.sum(label_patches, axis=(1, 2, 3)) != 0)

        # Filtering extracted patches
        label_patches = label_patches[valid_idxs]

        x = np.vstack((x, np.zeros((len(label_patches), 2, 27, 27, 27))))
        y = np.vstack((y, np.zeros((len(label_patches), 9 * 9 * 9, num_classes))))

        for i in range(len(label_patches)):
            y[i + y_length, :, :] = np_utils.to_categorical(label_patches[i].flatten(), num_classes)

        del label_patches

        # Sampling strategy: reject samples which labels are only zeros
        T1_train = extract_patches(T1_vols[idx], patch_shape, extraction_step)
        x[y_length:, 0, :, :, :] = T1_train[valid_idxs]
        del T1_train

        # Sampling strategy: reject samples which labels are only zeros
        T2_train = extract_patches(T2_vols[idx], patch_shape, extraction_step)
        x[y_length:, 1, :, :, :] = T2_train[valid_idxs]
        del T2_train
    return x, y


# Reconstruction utils
import itertools


def generate_indexes(patch_shape, expected_shape):
    ndims = len(patch_shape)

    poss_shape = [patch_shape[i + 1] * (expected_shape[i] // patch_shape[i + 1]) for i in range(ndims - 1)]

    idxs = [range(patch_shape[i + 1], poss_shape[i] - patch_shape[i + 1], patch_shape[i + 1]) for i in range(ndims - 1)]

    return itertools.product(*idxs)


def reconstruct_volume(patches, expected_shape):
    patch_shape = patches.shape

    assert len(patch_shape) - 1 == len(expected_shape)

    reconstructed_img = np.zeros(expected_shape)

    for count, coord in enumerate(generate_indexes(patch_shape, expected_shape)):
        selection = [slice(coord[i], coord[i] + patch_shape[i + 1]) for i in range(len(coord))]
        reconstructed_img[selection] = patches[count]

    return reconstructed_img


T1_vols = np.empty((3, 144, 192, 256))
T2_vols = np.empty((3, 144, 192, 256))
label_vols = np.empty((3, 144, 192, 256))
for case_idx in range(1, 4) :
    T1_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'T1')
    T2_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'T2')
    label_vols[(case_idx - 1), :, :, :] = read_vol(case_idx, 'label')

## Intensity normalisation (zero mean and unit variance)
T1_mean = T1_vols.mean()
T1_std = T1_vols.std()
T1_vols = (T1_vols - T1_mean) / T1_std
T2_mean = T2_vols.mean()
T2_std = T2_vols.std()
T2_vols = (T2_vols - T2_mean) / T2_std

# Combine labels of BG and CSF
for class_idx in class_mapper :
    label_vols[label_vols == class_idx] = class_mapper[class_idx]

x_train, y_train = build_set(T1_vols, T2_vols, label_vols, (6, 16, 6))

del T1_vols
del T2_vols

from keras.layers import Activation
from keras.layers import Input
from keras.layers.advanced_activations import PReLU
from keras.layers.convolutional import Conv3D
from keras.layers.convolutional import Cropping3D
from keras.layers.core import Permute
from keras.layers.core import Reshape
from keras.layers.merge import concatenate
from keras.models import Model
print(np.shape(x_train))
print(np.shape(y_train))

from keras import backend as K
from keras.layers import Activation
from keras.layers import Input
from keras.layers.advanced_activations import PReLU
from keras.layers.convolutional import Conv3D
from keras.layers.convolutional import Cropping3D
from keras.layers.core import Permute
from keras.layers.core import Reshape
from keras.layers.merge import concatenate
from keras.models import Model

K.set_image_dim_ordering('th')

def generate_model(num_classes) :
    init_input = Input((2, 27, 27, 27))

    x = Conv3D(1, kernel_size=(3, 3, 3))(init_input)
    x = PReLU()(x)
    x = Conv3D(1, kernel_size=(3, 3, 3))(x)
    x = PReLU()(x)
    x = Conv3D(1, kernel_size=(3, 3, 3))(x)
    x = PReLU()(x)

    y = Conv3D(1, kernel_size=(3, 3, 3))(x)
    y = PReLU()(y)
    y = Conv3D(1, kernel_size=(3, 3, 3))(y)
    y = PReLU()(y)
    y = Conv3D(1, kernel_size=(3, 3, 3))(y)
    y = PReLU()(y)

    z = Conv3D(1, kernel_size=(3, 3, 3))(y)
    z = PReLU()(z)
    z = Conv3D(1, kernel_size=(3, 3, 3))(z)
    z = PReLU()(z)
    z = Conv3D(1, kernel_size=(3, 3, 3))(z)
    z = PReLU()(z)

    x_crop = Cropping3D(cropping=((6, 6), (6, 6), (6, 6)))(x)
    y_crop = Cropping3D(cropping=((3, 3), (3, 3), (3, 3)))(y)

    concat = concatenate([x_crop, y_crop, z], axis=1)

    fc = Conv3D(4, kernel_size=(1, 1, 1))(concat)
    fc = PReLU()(fc)
    fc = Conv3D(2, kernel_size=(1, 1, 1))(fc)
    fc = PReLU()(fc)
    fc = Conv3D(3, kernel_size=(1, 1, 1))(fc)
    fc = PReLU()(fc)

    pred = Conv3D(num_classes, kernel_size=(1, 1, 1))(fc)
    pred = PReLU()(pred)
    pred = Reshape((num_classes, 9 * 9 * 9))(pred)
    pred = Permute((2, 1))(pred)
    pred = Activation('softmax')(pred)

    model = Model(inputs=init_input, outputs=pred)
    model.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['categorical_accuracy'])
    return model

abc=generate_model(3)

#What we have removed so far: images 4 -> 11
#patch size that splits into training and test set was doubled: less precision
#values from training set after 100
x_train=x_train[:100]
y_train=y_train[:100]

abc.fit(x_train,y_train,verbose=1,validation_split=0.1,epochs=10)
#How to use the above NN:
#increase number of epochs
#increase the number you see right after Conv3D
#change optimizer to Nadam, or from the e-mail I sent you
'''© 2018 GitHub, Inc.
Terms
Privacy
Security
Status
Help
Contact GitHub
API
Training
Shop
Blog
About
Press h to open a hovercard with more details.'''