In [1]:
import os
import pprint

import keras
import numpy as np
import sklearn.metrics
import tensorflow as tf

from keras import backend as K
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from keras.layers import average, AveragePooling2D, concatenate, Conv2D, Conv3D, Dense, Flatten, Input, Reshape, MaxPooling2D, Dropout, maximum, Lambda, Activation
from keras.models import Model, Sequential, load_model
from keras.optimizers import Adam, SGD
from sklearn.model_selection import StratifiedKFold

PATCH_HEIGHT = 28
PATCH_WIDTH = 28

data_dir = 'data'
if not os.path.exists('checkpoints'):
    os.mkdir('checkpoints')
if not os.path.exists('logs'):
    os.mkdir('logs')

pp = pprint.PrettyPrinter(indent=4)

Using TensorFlow backend.


In [2]:
ct_train = np.load(os.path.join(data_dir, 'ct_train.npy'))
pet_train = np.load(os.path.join(data_dir, 'pet_train.npy'))
y_train = np.load(os.path.join(data_dir, 'y_patches_train.npy'))

ct_test = np.load(os.path.join(data_dir, 'ct_test.npy'))
pet_test = np.load(os.path.join(data_dir, 'pet_test.npy'))
y_test = np.load(os.path.join(data_dir, 'y_patches_test.npy'))

BIG_PATCH_HEIGHT = ct_train.shape[1]
BIG_PATCH_WIDTH = ct_train.shape[2]
BIG_LABEL_HEIGHT = y_train.shape[1]
BIG_LABEL_WIDTH = y_train.shape[2]

def get_center_x_window(height, width):
    return (BIG_PATCH_HEIGHT - height) // 2, (BIG_PATCH_WIDTH - width) // 2

def get_center_y_window(height, width):
    return (BIG_LABEL_HEIGHT - height) // 2, (BIG_LABEL_WIDTH - width) // 2

def get_train(mode=None, subshape=(PATCH_HEIGHT, PATCH_WIDTH)):
    if subshape is None:
        if mode == 'ct':
            return ct_train
        elif mode == 'pet':
            return pet_train
        else:
            return [ct_train, pet_train]
    
    trim = subshape[0] % 2 == 1
    
    w = get_center_x_window(subshape[0], subshape[1])
    if mode == 'ct':
        return ct_train[:, w[0]+trim:-w[0], w[1]:-w[1]-trim, :]
    elif mode == 'pet':
        return pet_train[:, w[0]+trim:-w[0], w[1]:-w[1]-trim, :]
    else:
        return [ct_train[:, w[0]+trim:-w[0], w[1]:-w[1]-trim, :], pet_train[:, w[0]+trim:-w[0], w[1]:-w[1]-trim, :]]

def get_test(mode=None, subshape=(PATCH_HEIGHT, PATCH_WIDTH)):
    if subshape is None:
        if mode == 'ct':
            return ct_test
        elif mode == 'pet':
            return pet_test
        else:
            return [ct_test, pet_test]
    
    trim = subshape[0] % 2 == 1

    w = get_center_x_window(subshape[0], subshape[1])
    if mode == 'ct':
        return ct_test[:, w[0]+trim:-w[0], w[1]:-w[1]-trim, :]
    elif mode == 'pet':
        return pet_test[:, w[0]+trim:-w[0], w[1]:-w[1]-trim, :]
    else:
        return [ct_test[:, w[0]+trim:-w[0], w[1]:-w[1]-trim, :], pet_test[:, w[0]+trim:-w[0], w[1]:-w[1]-trim, :]]

def get_labels(mode=None, subshape=(0, 0), flatten=False):
    if subshape is None:
        if mode == 'train':
            return y_train
        else:
            return y_test
        
    trim = subshape[0] % 2 == 1
    
    w = get_center_y_window(subshape[0], subshape[1])
    if mode == 'train':
        if subshape == (0, 0):
            return y_train[:, w[0], w[1], :]
        elif flatten:
            return np.reshape(y_train[:, w[0]+trim:-w[0], w[1]:-w[1]+trim, :], (y_train.shape[0], (subshape[0] - trim) * (subshape[1] - trim)))
        else:
            return y_train[:, w[0]+trim:-w[0], w[1]:-w[1]-trim, :]
    else:
        if subshape == (0, 0):
            return y_test[:, w[0], w[1], :]
        elif flatten:
            return np.reshape(y_test[:, w[0]+trim:-w[0], w[1]:-w[1]-trim, :], (y_test.shape[0], (subshape[0] - trim) * (subshape[1] - trim)))
        else:
            return y_test[:, w[0]+trim:-w[0], w[1]:-w[1]-trim, :]

In [3]:
def confusion_matrix(y_true, y_pred):
    num_targets = y_true.shape[0]
    y_true_targets = (y_true == 1.).reshape((num_targets, -1))
    y_pred_targets = (y_pred >= 0.5).reshape((num_targets, -1))
    return sklearn.metrics.confusion_matrix(y_true_targets, y_pred_targets)

def accuracy(y_true, y_pred):
    num_targets = y_true.shape[0]
    y_true_targets = (y_true == 1.).reshape((num_targets, -1))
    y_pred_targets = (y_pred >= 0.5).reshape((num_targets, -1))
    return sklearn.metrics.accuracy_score(y_true_targets, y_pred_targets)

def f1(y_true, y_pred):
    c_matrix = confusion_matrix(y_true, y_pred)
    if c_matrix.shape != (2, 2):
        raise NotImplementedError(f'F1 not available for confusion matrix of shape {c_matrix.shape}')
    tp = c_matrix[1][1]
    fp = c_matrix[0][1]
    fn = c_matrix[1][0]
    return 2 * tp / (2 * tp + fn + fp)

In [4]:
def train_model(model_fn, name, batch_size=32, epochs=8, patience=2, mode=None, save=False, val=True,
                x_subshape=(PATCH_HEIGHT, PATCH_WIDTH), y_subshape=(0, 0), return_f1=True, return_model=False):
    print('Train...')

    best_model_path = os.path.join('checkpoints', f'best_model_{name}.h5')
    log_dir = os.path.join('logs', f'{name}')

    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    callbacks = []
    
    if val:
        callbacks.append(EarlyStopping(monitor='val_acc', patience=patience))
    
    if save:
        callbacks.append(ModelCheckpoint(best_model_path, monitor='val_acc', save_best_only=True, save_weights_only=True))
        callbacks.append(TensorBoard(log_dir=log_dir, histogram_freq=1, batch_size=batch_size, write_graph=False, write_grads=True, write_images=True))
    
    model = model_fn()
    model.fit(get_train(mode, subshape=x_subshape),
              get_labels('train', subshape=y_subshape),
              batch_size=batch_size,
              epochs=epochs,
              validation_split=0.1 if val else 0.0,
              verbose=1,
              shuffle=True,
              callbacks=callbacks)
    preds = model.predict(get_test(mode, subshape=x_subshape))
    
    acc_score = accuracy(get_labels('test', subshape=y_subshape), preds)
    print(f'Acc: {acc_score}')
    
    if return_model:
        return model

    if return_f1:
        f1_score = f1(get_labels('test', subshape=y_subshape), preds)
        print(f'F1: {f1_score}')
        print('\n\n')
        return f1_score, acc_score
    
    return None, acc_score

def train_n_sessions(model_fn, name, n, mode=None, **kwargs):
    f1s = []
    accs = []
    
    for i in range(n):
        print(f'Round {i + 1} out of {n}')
        print('-' * 101)
        f1, acc = train_model(model_fn, name, mode=mode, return_model=False, **kwargs)
        f1s.append(f1)
        accs.append(acc)
    
    return f1s, accs

# Type 1: Feature-Level Fusion

In [46]:
def get_type_1_model(summary=False):
    K.clear_session()

    ct_input = Input(shape=(PATCH_HEIGHT, PATCH_WIDTH, 1))
    pet_input = Input(shape=(PATCH_HEIGHT, PATCH_WIDTH, 1))

    x = concatenate([ct_input, pet_input])
    x = Reshape((PATCH_HEIGHT, PATCH_WIDTH, 2, 1))(x)
    x = Conv3D(16, (2, 2, 2), activation='relu')(x)
    x = Reshape((27, 27, 16))(x)
    x = Conv2D(36, (2, 2), activation='relu')(x)
    x = Conv2D(64, (2, 2), activation='relu')(x)
    x = Conv2D(144, (2, 2), activation='relu')(x)
    x = AveragePooling2D((23, 23))(x)
    x = Flatten()(x)
    x = Dense(864, activation='relu')(x)
    x = Dense(288, activation='relu')(x)
    output = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=[ct_input, pet_input], outputs=output)

    model.compile(optimizer=SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    if summary:
        model.summary()

    return model

In [47]:
f1s, accs = train_n_sessions(get_type_1_model, 'type_I', 10, epochs=5, val=False)

Round 1 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.9303872653837016
Acc: 0.9304368471035138



Round 2 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.9353846153846154
Acc: 0.9351851851851852



Round 3 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.9305654974946314
Acc: 0.9309116809116809



Round 4 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.932858837485172
Acc: 0.9328110161443495



Round 5 out of 10
-------------------------------------------------------------------

In [48]:
pp.pprint(f1s)
pp.pprint(accs)

[   0.93038726538370164,
    0.93538461538461537,
    0.93056549749463136,
    0.93285883748517195,
    0.92370637785800236,
    0.9301990885104342,
    0.93955014058106845,
    0.92181069958847739,
    0.93612024424612494,
    0.92006761651774938]
[   0.93043684710351382,
    0.93518518518518523,
    0.93091168091168086,
    0.93281101614434947,
    0.92473884140550811,
    0.93091168091168086,
    0.93874643874643871,
    0.92331433998100665,
    0.93542260208926875,
    0.92141500474833804]


# Type 2: Classifier-Level Fusion

In [51]:
def get_type_2_model(summary=False):
    K.clear_session()

    ct_input = Input(shape=(PATCH_HEIGHT, PATCH_WIDTH, 1))
    pet_input = Input(shape=(PATCH_HEIGHT, PATCH_WIDTH, 1))

    ct_model = Conv2D(16, (2, 2), activation='relu')(ct_input)
    ct_model = Conv2D(36, (2, 2), activation='relu')(ct_model)
    ct_model = Conv2D(64, (2, 2), activation='relu')(ct_model)
    ct_model = Conv2D(144, (2, 2), activation='relu')(ct_model)
    ct_model = AveragePooling2D((23, 23))(ct_model)
    ct_model = Flatten()(ct_model)

    pet_model = Conv2D(16, (2, 2), activation='relu')(pet_input)
    pet_model = Conv2D(36, (2, 2), activation='relu')(pet_model)
    pet_model = Conv2D(64, (2, 2), activation='relu')(pet_model)
    pet_model = Conv2D(144, (2, 2), activation='relu')(pet_model)
    pet_model = AveragePooling2D((23, 23))(pet_model)
    pet_model = Flatten()(pet_model)

    x = concatenate([ct_model, pet_model])
    x = Dense(864, activation='relu')(x)
    x = Dense(288, activation='relu')(x)
    output = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=[ct_input, pet_input], outputs=output)

    model.compile(optimizer=SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    if summary:
        model.summary()
    
    return model

In [52]:
f1s_2, accs_2 = train_n_sessions(get_type_2_model, 'type_II', 10, epochs=5, val=False)

Round 1 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.924933622978518
Acc: 0.9261633428300095



Round 2 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.935361216730038
Acc: 0.9354226020892688



Round 3 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.9326350868840753
Acc: 0.9328110161443495



Round 4 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.9360795454545454
Acc: 0.9358974358974359



Round 5 out of 10
--------------------------------------------------------------------

In [53]:
pp.pprint(f1s_2)
pp.pprint(accs_2)

[   0.924933622978518,
    0.93536121673003803,
    0.93263508688407526,
    0.93607954545454541,
    0.94128654970760239,
    0.93257088396473675,
    0.9376476145488899,
    0.93529272339416925,
    0.9302103250478011,
    0.93497864261983865]
[   0.92616334283000945,
    0.93542260208926875,
    0.93281101614434947,
    0.9358974358974359,
    0.94040835707502379,
    0.93281101614434947,
    0.93732193732193736,
    0.93518518518518523,
    0.93067426400759734,
    0.9349477682811016]


# Type 3: Decision-Level Fusion

In [54]:
def get_type_3_model(summary=False):
    K.clear_session()

    ct_input = Input(shape=(PATCH_HEIGHT, PATCH_WIDTH, 1))
    pet_input = Input(shape=(PATCH_HEIGHT, PATCH_WIDTH, 1))

    ct_model = Conv2D(16, (2, 2), activation='relu')(ct_input)
    ct_model = Conv2D(36, (2, 2), activation='relu')(ct_model)
    ct_model = Conv2D(64, (2, 2), activation='relu')(ct_model)
    ct_model = Conv2D(144, (2, 2), activation='relu')(ct_model)
    ct_model = AveragePooling2D((23, 23))(ct_model)
    ct_model = Flatten()(ct_model)
    ct_model = Dense(864, activation='relu')(ct_model)
    ct_model = Dense(288, activation='relu')(ct_model)
    ct_model = Dense(1, activation='sigmoid')(ct_model)

    pet_model = Conv2D(16, (2, 2), activation='relu')(pet_input)
    pet_model = Conv2D(36, (2, 2), activation='relu')(pet_model)
    pet_model = Conv2D(64, (2, 2), activation='relu')(pet_model)
    pet_model = Conv2D(144, (2, 2), activation='relu')(pet_model)
    pet_model = AveragePooling2D((23, 23))(pet_model)
    pet_model = Flatten()(pet_model)
    pet_model = Dense(864, activation='relu')(pet_model)
    pet_model = Dense(288, activation='relu')(pet_model)
    pet_model = Dense(1, activation='sigmoid')(pet_model)

    predictions = average([ct_model, pet_model])

    model = Model(inputs=[ct_input, pet_input], outputs=predictions)

    model.compile(optimizer=SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    if summary:
        model.summary()
    
    return model

In [55]:
f1s_3, accs_3 = train_n_sessions(get_type_3_model, 'type_III', 10, epochs=5, val=False)

Round 1 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.9308446996889208
Acc: 0.931386514719848



Round 2 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.9124877089478859
Acc: 0.9154795821462488



Round 3 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.9346327549322557
Acc: 0.9347103513770181



Round 4 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.9142019066242972
Acc: 0.9166666666666666



Round 5 out of 10
-------------------------------------------------------------------

In [56]:
pp.pprint(f1s_3)
pp.pprint(accs_3)

[   0.93084469968892081,
    0.91248770894788589,
    0.93463275493225573,
    0.91420190662429723,
    0.93320656049441408,
    0.92534428605943464,
    0.93567251461988299,
    0.92318911035488571,
    0.93707282583078011,
    0.93520056966532161]
[   0.93138651471984801,
    0.9154795821462488,
    0.93471035137701808,
    0.91666666666666663,
    0.93328584995251662,
    0.9266381766381766,
    0.93471035137701808,
    0.92497625830959163,
    0.93660968660968658,
    0.93518518518518523]


# Baseline: Single-Modality CNNs

In [46]:
def get_single_modality_model(summary=False):
    print('Build model...')

    K.clear_session()
    
    model = Sequential()
    model.add(Conv2D(16, (2, 2), activation='relu', input_shape=(PATCH_HEIGHT, PATCH_WIDTH, 1)))
    model.add(Conv2D(36, (2, 2), activation='relu'))
    model.add(Conv2D(64, (2, 2), activation='relu'))
    model.add(Conv2D(144, (2, 2), activation='relu'))
    model.add(AveragePooling2D((23, 23)))
    model.add(Flatten())
    model.add(Dense(864, activation='relu'))
    model.add(Dense(288, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    if summary:
        model.summary()

    print('Model built.')
    
    return model

In [58]:
f1s_c, accs_c = train_n_sessions(get_single_modality_model, 'ct', 10, mode='ct', epochs=5, val=False)

Round 1 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Build model...
Model built.
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.8450125542113672
Acc: 0.8387939221272555



Round 2 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Build model...
Model built.
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.8489110707803993
Acc: 0.8418803418803419



Round 3 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Build model...
Model built.
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.8487894015532207
Acc: 0.8428300094966762



Round 4 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Build model...
Model built.
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
F1: 0.8499887209564629

In [59]:
pp.pprint(f1s_c)
pp.pprint(accs_c)

[   0.84501255421136723,
    0.84891107078039929,
    0.84878940155322069,
    0.84998872095646294,
    0.84585521808632103,
    0.84483153793261512,
    0.85380642263642492,
    0.84710178000912828,
    0.83165535003512059,
    0.8514582862310649]
[   0.83879392212725545,
    0.84188034188034189,
    0.84283000949667619,
    0.84211775878442541,
    0.83974358974358976,
    0.83926875593542261,
    0.84544159544159547,
    0.84093067426400758,
    0.82929724596391263,
    0.84401709401709402]


In [47]:
f1s_p, accs_p = train_n_sessions(get_single_modality_model, 'pet', 10, mode='pet', epochs=5, val=False)

Round 1 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Build model...
Model built.
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9344729344729344
F1: 0.9340659340659341



Round 2 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Build model...
Model built.
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.939696106362773
F1: 0.9398389388915206



Round 3 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Build model...
Model built.
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9401709401709402
F1: 0.9407616361071932



Round 4 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Build model...
Model built.
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9482431149097815

In [62]:
pp.pprint(f1s_p)
pp.pprint(accs_p)

[   0.92912246865959502,
    0.94142554693013414,
    0.93514544587505966,
    0.94545454545454544,
    0.93377958403060002,
    0.94518311173314673,
    0.93720266412940056,
    0.94651162790697674,
    0.94192334822478252,
    0.94175669328323153]
[   0.93019943019943019,
    0.94088319088319083,
    0.93542260208926875,
    0.94444444444444442,
    0.93423551756885093,
    0.9442070275403609,
    0.93732193732193736,
    0.94539411206077872,
    0.94135802469135799,
    0.94112060778727447]


In [36]:
k_1 = 5
k_2 = 2
k_2mf = 3

p_1 = 4
p_2 = 2

k_g = k_1 + k_2 + p_1 + p_2 - 3
k_gmf = k_1 + k_2mf + p_1 + p_2 - 3

d_l = k_1 + p_1 - 2
d_mf = k_1 + k_2mf + p_1 + p_2 - 4

def get_stream_model(big_patch_dim, small_patch_dim, n_feature_maps=2, mode=None, summary=False, maxout=False, dropout=False):
    K.clear_session()
        
    k_f = big_patch_dim - k_g + 1

    if mode in ['ct', 'pet']:
        x = Input(shape=(big_patch_dim, big_patch_dim, 1))
        model_input = x
    else:
        ct_input = Input(shape=(big_patch_dim, big_patch_dim, 1))
        pet_input = Input(shape=(big_patch_dim, big_patch_dim, 1))
        model_input = [ct_input, pet_input]
        x = concatenate(model_input, axis=-1)
    
    if maxout:
        conv1_local = maximum([Conv2D(64, (k_1, k_1))(x) for _ in range(n_feature_maps)])
    else:
        conv1_local = Conv2D(64, (k_1, k_1), activation='relu')(x)
    pool1_local = MaxPooling2D((p_1, p_1), strides=(1, 1))(conv1_local)
    if dropout:
        pool1_local = Dropout(0.2)(pool1_local)
    
    if maxout:
        conv2_local = maximum([Conv2D(64, (k_2, k_2))(pool1_local) for _ in range(n_feature_maps)])
    else:
        conv2_local = Conv2D(64, (k_2, k_2), activation='relu')(pool1_local)
    pool2_local = MaxPooling2D((p_2, p_2), strides=(1, 1))(conv2_local)
    if dropout:
        pool2_local = Dropout(0.2)(pool2_local)

    if maxout:
        conv1_global= maximum([Conv2D(160, (k_g, k_g))(x) for _ in range(n_feature_maps)])
    else:
        conv1_global = Conv2D(160, (k_g, k_g), activation='relu')(x)
    if dropout:
        conv1_global = Dropout(0.2)(conv1_global)
    
    #combine = Flatten()(concatenate([pool2_local, conv1_global], axis=-1))
    #output = Dense(small_patch_dim * small_patch_dim, activation='sigmoid')(combine)
    combine = concatenate([pool2_local, conv1_global], axis=-1)
    output = Conv2D(small_patch_dim * small_patch_dim, (k_f, k_f), activation='sigmoid')(combine)
    
    if small_patch_dim > 1:
        output = Reshape((small_patch_dim, small_patch_dim, 1))(output)
    else:
        output = Reshape((1,))(output)

    model = Model(inputs=model_input, outputs=output)

    model.compile(optimizer=SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    if summary:
        model.summary()
    return model

def get_two_path_cascade_input(stream_model, n_feature_maps=2, mode=None, summary=False, maxout=False, dropout=False):
    K.clear_session()
    
    k_f = PATCH_HEIGHT - k_g + 1

    stream_model = stream_model(2 * PATCH_HEIGHT, PATCH_HEIGHT,
                                mode=mode, n_feature_maps=n_feature_maps, maxout=maxout, dropout=dropout)
    stream_model.trainable = False
    stream_model.load_weights(f'input_stream_{mode}.h5' if mode is not None else 'input_stream.h5')
    
    if mode in ['ct', 'pet']:
        model_input = Input(shape=(2 * PATCH_HEIGHT, 2 * PATCH_WIDTH, 1))
        x = model_input
        stream_output = stream_model(x)
    else:
        ct_input = Input(shape=(2 * PATCH_HEIGHT, 2 * PATCH_WIDTH, 1))
        pet_input = Input(shape=(2 * PATCH_HEIGHT, 2 * PATCH_WIDTH, 1))
        model_input = [ct_input, pet_input]
        x = concatenate(model_input, axis=-1)
        stream_output = stream_model([ct_input, pet_input])
    
    h = PATCH_HEIGHT // 2
    w = PATCH_WIDTH // 2
        
    x = Lambda(lambda x: x[:, h:-h, w:-w, :])(x)
    x = concatenate([x, stream_output], axis=-1)
    
    if maxout:
        conv1_local = maximum([Conv2D(64, (k_1, k_1))(x) for _ in range(n_feature_maps)])
    else:
        conv1_local = Conv2D(64, (k_1, k_1), activation='relu')(x)
    pool1_local = MaxPooling2D((p_1, p_1), strides=(1, 1))(conv1_local)
    if dropout:
        pool1_local = Dropout(0.2)(pool1_local)
        
    if maxout:
        conv2_local = maximum([Conv2D(64, (k_2, k_2))(pool1_local) for _ in range(n_feature_maps)])
    else:
        conv2_local = Conv2D(64, (k_2, k_2), activation='relu')(pool1_local)
    pool2_local = MaxPooling2D((p_2, p_2), strides=(1, 1))(conv2_local)
    if dropout:
        pool2_local = Dropout(0.2)(pool2_local)

    if maxout:
        conv1_global = maximum([Conv2D(160, (k_g, k_g))(x) for _ in range(n_feature_maps)])
    else:
        conv1_global = Conv2D(160, (k_g, k_g), activation='relu')(x)
    if dropout:
        conv1_global = Dropout(0.2)(conv1_global)
    
    #combine = Flatten()(concatenate([pool2_local, conv1_global], axis=-1))
    #output = Dense(1, activation='sigmoid')(combine)
    combine = concatenate([pool2_local, conv1_global], axis=-1)
    output = Conv2D(1, (k_f, k_f), activation='sigmoid')(combine)
    output = Reshape((1,))(output)

    model = Model(inputs=model_input, outputs=output)

    model.compile(optimizer=SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    if summary:
        model.summary()
    return model

def get_two_path_cascade_local(stream_model, n_feature_maps=2, mode=None, summary=False, maxout=False, dropout=False):
    K.clear_session()
    
    k_f = PATCH_HEIGHT - k_g + 1
    
    stream_model = stream_model(2 * PATCH_HEIGHT - d_l, PATCH_HEIGHT - d_l,
                                mode=mode, n_feature_maps=n_feature_maps, maxout=maxout, dropout=dropout)
    stream_model.trainable = False
    stream_model.load_weights(f'local_stream_{mode}.h5' if mode is not None else 'local_stream.h5')
    
    if mode in ['ct', 'pet']:
        x = Input(shape=(2 * PATCH_HEIGHT - d_l, 2 * PATCH_WIDTH - d_l, 1))
        model_input = x
        stream_output = stream_model(x)
    else:
        ct_input = Input(shape=(2 * PATCH_HEIGHT - d_l, 2 * PATCH_WIDTH - d_l, 1))
        pet_input = Input(shape=(2 * PATCH_HEIGHT - d_l, 2 * PATCH_WIDTH - d_l, 1))
        model_input = [ct_input, pet_input]
        x = concatenate([ct_input, pet_input], axis=-1)
        stream_output = stream_model([ct_input, pet_input])
    
    h = (PATCH_HEIGHT - d_l) // 2
    w = (PATCH_WIDTH - d_l) // 2
    trim = d_l % 2 == 1
        
    x = Lambda(lambda x: x[:, h+trim:-h, w:-w-trim, :])(x)
    
    if maxout:
        conv1_local = maximum([Conv2D(64, (k_1, k_1))(x) for _ in range(n_feature_maps)])
    else:
        conv1_local = Conv2d(64, (k_1, k_1), activation='relu')(x)
    pool1_local = MaxPooling2D((p_1, p_1), strides=(1, 1))(conv1_local)
    if dropout:
        pool1_local = Dropout(0.2)(pool1_local)
    
    pool1_local = concatenate([pool1_local, stream_output], axis=-1)
    
    if maxout:
        conv2_local = maximum([Conv2D(64, (k_2, k_2))(pool1_local) for _ in range(n_feature_maps)])
    else:
        conv2_local = Conv2D(64, (k_2, k_2), activation='relu')(pool1_local)
    pool2_local = MaxPooling2D((p_2, p_2), strides=(1, 1))(conv2_local)
    if dropout:
        pool2_local = Dropout(0.2)(pool2_local)

    if maxout:
        conv1_global= maximum([Conv2D(160, (k_g, k_g))(x) for _ in range(n_feature_maps)])
    else:
        conv1_global = Conv2D(160, (k_g, k_g), activation='relu')(x)
    if dropout:
        conv1_global = Dropout(0.2)(conv1_global)
    
    combine = concatenate([pool2_local, conv1_global], axis=-1)
    output = Conv2D(1, (k_f, k_f), activation='sigmoid')(combine)
    output = Reshape((1,))(output)

    model = Model(inputs=model_input, outputs=output)

    model.compile(optimizer=SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    if summary:
        model.summary()
    return model

def get_two_path_cascade_mf(stream_model, n_feature_maps=2, mode=None, summary=False, maxout=False, dropout=False):
    K.clear_session()
    
    k_f = PATCH_HEIGHT - k_gmf + 1
    
    stream_model = stream_model(2 * PATCH_HEIGHT - d_mf, PATCH_HEIGHT - d_mf,
                                mode=mode, n_feature_maps=n_feature_maps, maxout=maxout, dropout=dropout)
    stream_model.trainable = False
    stream_model.load_weights(f'mf_stream_{mode}.h5' if mode is not None else 'mf_stream.h5')
    
    if mode in ['ct', 'pet']:
        x = Input(shape=(2 * PATCH_HEIGHT - d_mf, 2 * PATCH_WIDTH - d_mf, 1))
        model_input = x
        stream_output = stream_model(x)
    else:
        ct_input = Input(shape=(2 * PATCH_HEIGHT - d_mf, 2 * PATCH_WIDTH - d_mf, 1))
        pet_input = Input(shape=(2 * PATCH_HEIGHT - d_mf, 2 * PATCH_WIDTH - d_mf, 1))
        model_input = [ct_input, pet_input]
        x = concatenate(model_input, axis=-1)
        stream_output = stream_model([ct_input, pet_input])
    
    h = (PATCH_HEIGHT - d_mf) // 2
    w = (PATCH_WIDTH - d_mf) // 2
    trim = d_mf % 2 == 1
        
    x = Lambda(lambda x: x[:, h+trim:-h, w:-w-trim, :])(x)
    
    if maxout:
        conv1_local = maximum([Conv2D(64, (k_1, k_1))(x) for _ in range(n_feature_maps)])
    else:
        conv1_local = Conv2D(64, (k_1, k_1), activation='relu')(x)
    pool1_local = MaxPooling2D((p_1, p_1), strides=(1, 1))(conv1_local)
    if dropout:
        pool1_local = Dropout(0.2)(pool1_local)
    
    if maxout:
        conv2_local = maximum([Conv2D(64, (k_2mf, k_2mf))(pool1_local) for _ in range(n_feature_maps)])
    else:
        conv2_local = Conv2D(64, (k_2mf, k_2mf), activation='relu')(pool1_local)
    pool2_local = MaxPooling2D((p_2, p_2), strides=(1, 1))(conv2_local)
    if dropout:
        pool2_local = Dropout(0.2)(pool2_local)

    if maxout:
        conv1_global= maximum([Conv2D(160, (k_gmf, k_gmf))(x) for _ in range(n_feature_maps)])
    else:
        conv1_global = Conv2D(160, (k_gmf, k_gmf), activation='relu')(x)
    if dropout:
        conv1_global = Dropout(0.2)(conv1_global)

    combine = concatenate([pool2_local, conv1_global, stream_output], axis=-1)
    output = Conv2D(1, (k_f, k_f), activation='sigmoid')(combine)
    output = Reshape((1,))(output)

    model = Model(inputs=model_input, outputs=output)

    model.compile(optimizer=SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    if summary:
        model.summary()
    return model

In [28]:
input_stream = train_model(lambda: get_stream_model(2 * PATCH_HEIGHT, PATCH_HEIGHT, mode='pet', maxout=True, dropout=True),
                           'input_stream', epochs=5, mode='pet', val=False,
                           x_subshape=None, y_subshape=None, return_f1=False, return_model=True)
input_stream.save('input_stream_pet.h5')

Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.3969610636277303


In [29]:
local_stream = train_model(lambda: get_stream_model(2 * PATCH_HEIGHT - d_l, PATCH_HEIGHT - d_l, mode='pet', maxout=True, dropout=True),
                           'local_stream', epochs=5, mode='pet', val=False,
                           x_subshape=(2 * PATCH_HEIGHT - d_l, 2 * PATCH_WIDTH - d_l),
                           y_subshape=(PATCH_HEIGHT - d_l, PATCH_WIDTH - d_l),
                           return_f1=False, return_model=True)
local_stream.save('local_stream_pet.h5')

Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.41999050332383664


In [30]:
mf_stream = train_model(lambda: get_stream_model(2 * PATCH_HEIGHT - d_mf, PATCH_HEIGHT - d_mf, mode='pet', maxout=True, dropout=True),
                        'mf_stream', epochs=5, val=False, mode='pet',
                        x_subshape=(2 * PATCH_HEIGHT - d_mf, 2 * PATCH_WIDTH - d_mf),
                        y_subshape=(PATCH_HEIGHT - d_mf, PATCH_WIDTH - d_mf),
                        return_f1=False, return_model=True)
mf_stream.save('mf_stream_pet.h5')

Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.4301994301994302


In [31]:
f1s_input, accs_input = train_n_sessions(
    lambda: get_two_path_cascade_input(get_stream_model, mode='pet', maxout=True, dropout=True), 'mf', 10,
    epochs=5, x_subshape=None, mode='pet', val=False)

Round 1 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9434947768281101
F1: 0.942870859337494



Round 2 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9370845204178537
F1: 0.935064935064935



Round 3 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9461063627730294
F1: 0.9457067687156183



Round 4 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9387464387464387
F1: 0.9375



Round 5 out of 10
--------------------------------------------------------------------------------

In [32]:
pp.pprint(f1s_input)
pp.pprint(accs_input)

[   0.94287085933749404,
    0.93506493506493504,
    0.94570676871561832,
    0.9375,
    0.93189612934835864,
    0.94100788153809412,
    0.93241919686581787,
    0.94447103877453331,
    0.91338974614235935,
    0.93869638962927071]
[   0.94349477682811012,
    0.93708452041785373,
    0.9461063627730294,
    0.93874643874643871,
    0.93399810066476729,
    0.94135802469135799,
    0.93447293447293445,
    0.94491927825261157,
    0.91737891737891741,
    0.93993352326685664]


In [33]:
f1s_local, accs_local = train_n_sessions(
    lambda: get_two_path_cascade_local(get_stream_model, mode='pet', maxout=True, dropout=True), 'local', 10,
    epochs=5, x_subshape=(2 * PATCH_HEIGHT - d_l, 2 * PATCH_WIDTH - d_l), mode='pet', val=False)

Round 1 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9404083570750238
F1: 0.9392104625817389



Round 2 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9411206077872745
F1: 0.9401255432158377



Round 3 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9368471035137702
F1: 0.934931506849315



Round 4 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9411206077872745
F1: 0.9405275779376499



Round 5 out of 10
-------------------------------------------------------------------

In [34]:
pp.pprint(f1s_local)
pp.pprint(accs_local)

[   0.9392104625817389,
    0.94012554321583774,
    0.93493150684931503,
    0.94052757793764985,
    0.94714079769341664,
    0.9446993479835788,
    0.92151522654122309,
    0.93555501102670913,
    0.94503375120540023,
    0.92715231788079466]
[   0.94040835707502379,
    0.94112060778727447,
    0.93684710351377021,
    0.94112060778727447,
    0.94776828110161448,
    0.94563152896486224,
    0.92473884140550811,
    0.93755935422602088,
    0.94586894586894588,
    0.92948717948717952]


In [37]:
f1s_mf, accs_mf = train_n_sessions(
    lambda: get_two_path_cascade_mf(get_stream_model, mode='pet', maxout=True, dropout=True), 'mf', 10,
    epochs=5, x_subshape=(2 * PATCH_HEIGHT - d_mf, 2 * PATCH_WIDTH - d_mf), mode='pet', val=False)

Round 1 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9349477682811016
F1: 0.9329745596868885



Round 2 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9207027540360874
F1: 0.9168326693227091



Round 3 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.938034188034188
F1: 0.9363259331544279



Round 4 out of 10
-----------------------------------------------------------------------------------------------------
Train...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Acc: 0.9370845204178537
F1: 0.9354760165570977



Round 5 out of 10
-------------------------------------------------------------------

In [17]:
pp.pprint(f1s_mf)
pp.pprint(accs_mf)

[   0.93615984405458086,
    0.95070930512142338,
    0.94587442867452487,
    0.91465863453815266,
    0.93985877769661552,
    0.93997071742313321,
    0.93209574987787003,
    0.94305791131572569,
    0.94166061486322927,
    0.91672918229557387]
[   0.93779677113010451,
    0.95132953466286796,
    0.94658119658119655,
    0.91927825261158591,
    0.94135802469135799,
    0.94159544159544162,
    0.93399810066476729,
    0.9442070275403609,
    0.94278252611585944,
    0.92094017094017089]
