In [None]:
import os
import pprint

import numpy as np
import sklearn.metrics

from keras.callbacks import EarlyStopping

from models import *

# Default (height, width) of the centered crop fed to the models
# (the default `subshape` of get_train / get_test).
PATCH_HEIGHT = 28
PATCH_WIDTH = 28

data_dir = 'data'
# Output directories. makedirs(exist_ok=True) is idempotent across re-runs and,
# unlike an exists() check, fails loudly if a plain file occupies the path.
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('logs', exist_ok=True)

pp = pprint.PrettyPrinter(indent=4)

In [None]:
# Full-size stored patches; the get_* helpers center-crop them on demand.
ct_train = np.load(os.path.join(data_dir, 'ct_train.npy'))
pet_train = np.load(os.path.join(data_dir, 'pet_train.npy'))
y_train = np.load(os.path.join(data_dir, 'y_patches_train.npy'))

ct_test = np.load(os.path.join(data_dir, 'ct_test.npy'))
pet_test = np.load(os.path.join(data_dir, 'pet_test.npy'))
y_test = np.load(os.path.join(data_dir, 'y_patches_test.npy'))

# The crop windows are derived from ct_train / y_train alone and then applied to
# PET and to the test split too, and inputs are paired with labels in fit/predict,
# so sample counts and spatial sizes have to agree. Fail early if they do not.
assert ct_train.shape[:3] == pet_train.shape[:3], 'CT and PET training patches differ in count or size'
assert ct_test.shape[:3] == pet_test.shape[:3], 'CT and PET test patches differ in count or size'
assert ct_train.shape[1:3] == ct_test.shape[1:3], 'train and test patches differ in size'
assert len(y_train) == len(ct_train) and len(y_test) == len(ct_test), 'label count does not match patch count'
assert y_train.shape[1:3] == y_test.shape[1:3], 'train and test label patches differ in size'

# Spatial size (axes 1 and 2) of the stored, uncropped input and label patches.
BIG_PATCH_HEIGHT = ct_train.shape[1]
BIG_PATCH_WIDTH = ct_train.shape[2]
BIG_LABEL_HEIGHT = y_train.shape[1]
BIG_LABEL_WIDTH = y_train.shape[2]

def get_center_x_window(height, width):
    """Return ``(top, left)`` offsets that center a ``height`` x ``width`` crop in an input patch.

    Margins are measured against ``BIG_PATCH_HEIGHT`` / ``BIG_PATCH_WIDTH`` and
    floor-divided, so an odd leftover margin rounds the offset down.
    """
    top = (BIG_PATCH_HEIGHT - height) // 2
    left = (BIG_PATCH_WIDTH - width) // 2
    return top, left


def get_center_y_window(height, width):
    """Return ``(top, left)`` offsets that center a ``height`` x ``width`` crop in a label patch.

    Same arithmetic as ``get_center_x_window`` but against ``BIG_LABEL_HEIGHT`` /
    ``BIG_LABEL_WIDTH``.
    """
    margin_h = BIG_LABEL_HEIGHT - height
    margin_w = BIG_LABEL_WIDTH - width
    return margin_h // 2, margin_w // 2

def get_train(mode=None, subshape=(PATCH_HEIGHT, PATCH_WIDTH)):
    """Return the training inputs, optionally center-cropped.

    Args:
        mode: 'ct' or 'pet' returns that single modality; any other value
            (including None) returns ``[ct, pet]``.
        subshape: ``(height, width)`` of a centered crop, or ``None`` for the
            full stored patches. Offsets come from ``get_center_x_window``;
            when ``height`` is odd one extra leading row and column are dropped
            (the parity is taken from ``height`` only).

    Raises:
        ValueError: if ``subshape`` is larger than the stored patches.
    """
    if subshape is None:
        def crop(patches):
            return patches
    else:
        trim = subshape[0] % 2 == 1
        top, left = get_center_x_window(subshape[0], subshape[1])
        if top < 0 or left < 0:
            raise ValueError(f'subshape {subshape} is larger than the stored patches')

        def crop(patches):
            # Explicit end indices: a `-top` end would be `-0`, i.e. an empty
            # slice, when the requested crop is as large as the stored patch.
            return patches[:, top + trim:patches.shape[1] - top,
                           left + trim:patches.shape[2] - left, :]

    if mode == 'ct':
        return crop(ct_train)
    if mode == 'pet':
        return crop(pet_train)
    return [crop(ct_train), crop(pet_train)]

def get_test(mode=None, subshape=(PATCH_HEIGHT, PATCH_WIDTH)):
    """Return the test inputs, optionally center-cropped.

    Args:
        mode: 'ct' or 'pet' returns that single modality; any other value
            (including None) returns ``[ct, pet]``.
        subshape: ``(height, width)`` of a centered crop, or ``None`` for the
            full stored patches. Offsets come from ``get_center_x_window``;
            when ``height`` is odd one extra leading row and column are dropped
            (the parity is taken from ``height`` only).

    Raises:
        ValueError: if ``subshape`` is larger than the stored patches.
    """
    if subshape is None:
        def crop(patches):
            return patches
    else:
        trim = subshape[0] % 2 == 1
        top, left = get_center_x_window(subshape[0], subshape[1])
        if top < 0 or left < 0:
            raise ValueError(f'subshape {subshape} is larger than the stored patches')

        def crop(patches):
            # Explicit end indices: a `-top` end would be `-0`, i.e. an empty
            # slice, when the requested crop is as large as the stored patch.
            return patches[:, top + trim:patches.shape[1] - top,
                           left + trim:patches.shape[2] - left, :]

    if mode == 'ct':
        return crop(ct_test)
    if mode == 'pet':
        return crop(pet_test)
    return [crop(ct_test), crop(pet_test)]

def get_labels(mode=None, subshape=(0, 0), flatten=False):
    """Return training or test labels, optionally center-cropped.

    Args:
        mode: 'train' selects ``y_train``; any other value (including None)
            selects ``y_test``.
        subshape: ``None`` returns the full label array; ``(0, 0)`` returns
            only the center label ``y[:, cy, cx, :]``; any other
            ``(height, width)`` returns a centered crop using the offsets from
            ``get_center_y_window`` (when ``height`` is odd one extra leading
            row and column are dropped).
        flatten: for a crop, reshape it to ``(num_samples, -1)``. Ignored for
            ``None`` and ``(0, 0)``.

    Raises:
        ValueError: if ``subshape`` is larger than the stored label patches.
    """
    labels = y_train if mode == 'train' else y_test
    if subshape is None:
        return labels

    top, left = get_center_y_window(subshape[0], subshape[1])
    if tuple(subshape) == (0, 0):
        return labels[:, top, left, :]
    if top < 0 or left < 0:
        raise ValueError(f'subshape {subshape} is larger than the stored label patches')

    trim = subshape[0] % 2 == 1
    # Explicit end indices: a `-top` end would be `-0` (an empty slice) for a full-size crop.
    crop = labels[:, top + trim:labels.shape[1] - top, left + trim:labels.shape[2] - left, :]
    if flatten:
        # Infer the flattened length from the crop itself; a length computed from
        # `subshape` disagrees with the crop whenever the stored label size and
        # the requested size differ in parity.
        return crop.reshape(crop.shape[0], -1)
    return crop

In [None]:
def _binarize_targets(y_true, y_pred):
    """Threshold labels and predictions into boolean ``(num_samples, -1)`` arrays.

    A label is positive only when it equals exactly 1.0; a prediction is
    positive when it is >= 0.5.
    """
    num_targets = y_true.shape[0]
    y_true_targets = (y_true == 1.).reshape((num_targets, -1))
    y_pred_targets = (y_pred >= 0.5).reshape((num_targets, -1))
    return y_true_targets, y_pred_targets


def confusion_matrix(y_true, y_pred):
    """Return the 2x2 confusion matrix ``[[tn, fp], [fn, tp]]`` over all thresholded targets.

    Every target position is pooled (flattened), so both single-pixel labels and
    multi-pixel label patches are accepted; sklearn rejects 2-D multi-column
    boolean targets as multilabel-indicator. ``labels=[False, True]`` keeps the
    matrix 2x2 even when a batch contains only one class.
    """
    y_true_targets, y_pred_targets = _binarize_targets(y_true, y_pred)
    return sklearn.metrics.confusion_matrix(y_true_targets.ravel(), y_pred_targets.ravel(),
                                            labels=[False, True])


def accuracy(y_true, y_pred):
    """Return the fraction of samples whose thresholded targets all match.

    With more than one target per sample this is sklearn's subset (exact-match)
    accuracy: a label patch only counts as correct if every position is right.
    """
    y_true_targets, y_pred_targets = _binarize_targets(y_true, y_pred)
    return sklearn.metrics.accuracy_score(y_true_targets, y_pred_targets)


def f1(y_true, y_pred):
    """Return the F1 score of the positive class, ``2*tp / (2*tp + fp + fn)``.

    Returns 0.0 when neither labels nor predictions contain a positive, where
    the ratio would otherwise be 0/0 (nan).

    Raises:
        NotImplementedError: if the confusion matrix is not 2x2.
    """
    c_matrix = confusion_matrix(y_true, y_pred)
    if c_matrix.shape != (2, 2):
        raise NotImplementedError(f'F1 not available for confusion matrix of shape {c_matrix.shape}')
    tn, fp, fn, tp = c_matrix.ravel()
    denominator = 2 * tp + fn + fp
    if denominator == 0:
        return 0.0
    return 2 * tp / denominator

In [None]:
def train_model(model_fn, name, batch_size=32, epochs=8, patience=2, mode=None, val=True,
                x_subshape=(PATCH_HEIGHT, PATCH_WIDTH), y_subshape=(0, 0), return_f1=True,
                monitor='val_acc'):
    """Build, fit and evaluate one model on the CT/PET patch data.

    Args:
        model_fn: zero-argument callable returning a model ready for ``fit``.
        name: run name; a ``logs/<name>`` directory is created for it.
        batch_size, epochs: forwarded to ``model.fit``.
        patience: EarlyStopping patience (only used when ``val`` is True).
        mode: input selection forwarded to ``get_train`` / ``get_test``.
        val: hold out 10% of the training data for validation and early stopping.
        x_subshape, y_subshape: crop sizes forwarded to the input / label getters.
        return_f1: also compute the test F1 score.
        monitor: metric watched by EarlyStopping. NOTE(review): the default
            'val_acc' follows older Keras naming; newer Keras may log
            'val_accuracy' depending on the compile metrics — confirm.

    Returns:
        ``(model, test F1 or None, test accuracy)``
    """
    print('Train...')

    # Created per run; nothing in this function writes to it yet.
    os.makedirs(os.path.join('logs', f'{name}'), exist_ok=True)

    callbacks = []
    if val:
        callbacks.append(EarlyStopping(monitor=monitor, patience=patience))

    model = model_fn()
    model.fit(get_train(mode, subshape=x_subshape),
              get_labels('train', subshape=y_subshape),
              batch_size=batch_size,
              epochs=epochs,
              validation_split=0.1 if val else 0.0,
              verbose=1,
              shuffle=True,
              callbacks=callbacks)
    preds = model.predict(get_test(mode, subshape=x_subshape))

    # Fetch the test labels once and reuse them for every metric.
    y_true = get_labels('test', subshape=y_subshape)

    acc_score = accuracy(y_true, preds)
    print(f'Acc: {acc_score}')

    f1_score = None
    if return_f1:
        f1_score = f1(y_true, preds)
        print(f'F1: {f1_score}')

    print('\n\n')
    return model, f1_score, acc_score

def train_n_sessions(model_fn, name, n, mode=None, save_best_f1=True, **kwargs):
    """Train ``n`` independent models and collect their test scores.

    Each round calls ``train_model(model_fn, name, mode=mode, **kwargs)``.
    When ``save_best_f1`` is True, the model with the highest F1 so far is
    written to ``best_<name>_model.h5`` in the working directory; rounds
    without an F1 score (``return_f1=False`` passed via kwargs) are never saved.

    Returns:
        ``(per-round F1 scores, per-round test accuracies)``
    """
    f1s = []
    accs = []
    best_f1 = -1

    for round_idx in range(1, n + 1):
        print(f'Round {round_idx} out of {n}')
        print('-' * 101)
        # `round_f1` rather than `f1` so the module-level f1() is not shadowed.
        model, round_f1, round_acc = train_model(model_fn, name, mode=mode, **kwargs)
        f1s.append(round_f1)
        accs.append(round_acc)

        # round_f1 is None when F1 was not requested; `None > number` raises TypeError.
        if save_best_f1 and round_f1 is not None and round_f1 > best_f1:
            best_f1 = round_f1
            model.save(f'best_{name}_model.h5')

    return f1s, accs

# Type 1: Feature-Level Fusion

In [None]:
# Feature-level fusion: 10 independent runs of 5 epochs each, no validation split.
f1s, accs = train_n_sessions(model_fn=get_type_1_model, name='type_I', n=10, epochs=5, val=False)

In [None]:
# Per-run test F1 scores, then per-run test accuracies.
for per_run_scores in (f1s, accs):
    pp.pprint(per_run_scores)

# Type 2: Classifier-Level Fusion

In [None]:
# Classifier-level fusion: 10 independent runs of 5 epochs each, no validation split.
f1s_2, accs_2 = train_n_sessions(model_fn=get_type_2_model, name='type_II', n=10, epochs=5, val=False)

In [None]:
# Per-run test F1 scores, then per-run test accuracies.
for per_run_scores in (f1s_2, accs_2):
    pp.pprint(per_run_scores)

# Type 3: Decision-Level Fusion

In [None]:
# Decision-level fusion: 10 independent runs of 5 epochs each, no validation split.
f1s_3, accs_3 = train_n_sessions(model_fn=get_type_3_model, name='type_III', n=10, epochs=5, val=False)

In [None]:
# Per-run test F1 scores, then per-run test accuracies.
for per_run_scores in (f1s_3, accs_3):
    pp.pprint(per_run_scores)

# Baseline: Single-Modality CNNs

In [None]:
# CT-only baseline: 10 independent runs of 5 epochs each, no validation split.
f1s_c, accs_c = train_n_sessions(model_fn=get_single_modality_model, name='ct', n=10,
                                 mode='ct', epochs=5, val=False)

In [None]:
# Per-run test F1 scores, then per-run test accuracies.
for per_run_scores in (f1s_c, accs_c):
    pp.pprint(per_run_scores)

In [None]:
# PET-only baseline: 10 independent runs of 5 epochs each, no validation split.
f1s_p, accs_p = train_n_sessions(model_fn=get_single_modality_model, name='pet', n=10,
                                 mode='pet', epochs=5, val=False)

In [None]:
# Per-run test F1 scores, then per-run test accuracies.
for per_run_scores in (f1s_p, accs_p):
    pp.pprint(per_run_scores)

# Cascaded CNNs

In [None]:
# PET-only stream trained on the full stored patches (x_subshape=None) against
# the full label patches (y_subshape=None); only the model is kept.
# NOTE(review): d_i is not defined in this notebook; presumably it comes from
# `from models import *` — confirm.
input_stream = train_model(
    lambda: get_stream_model(2 * PATCH_HEIGHT - d_i, PATCH_HEIGHT, mode='pet', maxout=True, dropout=True),
    'input_stream', epochs=5, mode='pet', val=False,
    x_subshape=None, y_subshape=None, return_f1=False,
)[0]  # index instead of `x, _, _ =`: assigning `_` stops IPython updating its last-output history
input_stream.save('input_stream_pet.h5')

In [None]:
# PET-only stream trained on a centered (2*PATCH_HEIGHT - d_l) x (2*PATCH_WIDTH - d_l)
# input crop against a (PATCH_HEIGHT - d_l) x (PATCH_WIDTH - d_l) label crop.
# NOTE(review): d_l is not defined in this notebook; presumably it comes from
# `from models import *` — confirm.
local_stream = train_model(
    lambda: get_stream_model(2 * PATCH_HEIGHT - d_l, PATCH_HEIGHT - d_l, mode='pet', maxout=True, dropout=True),
    'local_stream', epochs=5, mode='pet', val=False,
    x_subshape=(2 * PATCH_HEIGHT - d_l, 2 * PATCH_WIDTH - d_l),
    y_subshape=(PATCH_HEIGHT - d_l, PATCH_WIDTH - d_l), return_f1=False,
)[0]  # index instead of `x, _, _ =`: assigning `_` stops IPython updating its last-output history
local_stream.save('local_stream_pet.h5')

In [None]:
# PET-only stream trained on a centered (2*PATCH_HEIGHT - d_mf) x (2*PATCH_WIDTH - d_mf)
# input crop against a (PATCH_HEIGHT - d_mf) x (PATCH_WIDTH - d_mf) label crop.
# NOTE(review): d_mf is not defined in this notebook; presumably it comes from
# `from models import *` — confirm.
mf_stream = train_model(
    lambda: get_stream_model(2 * PATCH_HEIGHT - d_mf, PATCH_HEIGHT - d_mf, mode='pet', maxout=True, dropout=True),
    'mf_stream', epochs=5, val=False, mode='pet',
    x_subshape=(2 * PATCH_HEIGHT - d_mf, 2 * PATCH_WIDTH - d_mf),
    y_subshape=(PATCH_HEIGHT - d_mf, PATCH_WIDTH - d_mf), return_f1=False,
)[0]  # index instead of `x, _, _ =`: assigning `_` stops IPython updating its last-output history
mf_stream.save('mf_stream_pet.h5')

In [None]:
# Input-cascade variant on the full stored PET patches (x_subshape=None).
# Run name 'input': the MF cascade run uses 'mf', and a shared name would make
# the two runs overwrite each other's best_<name>_model.h5 and logs/<name>.
f1s_input, accs_input = train_n_sessions(
    lambda: get_two_path_cascade_input(get_stream_model, mode='pet', maxout=True, dropout=True), 'input', 10,
    epochs=5, x_subshape=None, mode='pet', val=False)

In [None]:
# Per-run test F1 scores, then per-run test accuracies.
for per_run_scores in (f1s_input, accs_input):
    pp.pprint(per_run_scores)

In [None]:
# Local-cascade variant on a centered (2*PATCH_HEIGHT - d_l) x (2*PATCH_WIDTH - d_l) PET crop.
f1s_local, accs_local = train_n_sessions(
    model_fn=lambda: get_two_path_cascade_local(get_stream_model, mode='pet', maxout=True, dropout=True),
    name='local',
    n=10,
    mode='pet',
    epochs=5,
    x_subshape=(2 * PATCH_HEIGHT - d_l, 2 * PATCH_WIDTH - d_l),
    val=False,
)

In [None]:
# Per-run test F1 scores, then per-run test accuracies.
for per_run_scores in (f1s_local, accs_local):
    pp.pprint(per_run_scores)

In [None]:
# MF-cascade variant on a centered (2*PATCH_HEIGHT - d_mf) x (2*PATCH_WIDTH - d_mf) PET crop.
f1s_mf, accs_mf = train_n_sessions(
    model_fn=lambda: get_two_path_cascade_mf(get_stream_model, mode='pet', maxout=True, dropout=True),
    name='mf',
    n=10,
    mode='pet',
    epochs=5,
    x_subshape=(2 * PATCH_HEIGHT - d_mf, 2 * PATCH_WIDTH - d_mf),
    val=False,
)

In [None]:
# Per-run test F1 scores, then per-run test accuracies.
for per_run_scores in (f1s_mf, accs_mf):
    pp.pprint(per_run_scores)