In [1]:
import sys
sys.path.append('../../resnet3d')
sys.path.append('../../Data')

In [2]:
from resnet3d import Resnet3DBuilder
from LIB import merge_data, split_flip, augmentate
import tensorflow as tf
from tensorflow import keras 
from keras.models import Model
from keras.layers import (
    Input,
    Activation,
    Dense,
    Concatenate,
    Flatten
)
from keras.regularizers import l2

import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix, log_loss
from matplotlib import pyplot as plt

In [3]:
def plot_metric(history, metric):
    """Plot the per-epoch training and validation curves of one Keras metric.

    Clears the current pyplot figure, draws ``history.history[metric]`` and
    ``history.history['val_' + metric]`` against 1-based epoch numbers, saves the
    figure as ``Training_Validation_<metric>.pdf`` in the working directory, then shows it.

    Parameters
    ----------
    history : object whose ``.history`` dict maps metric names to per-epoch values
        (e.g. the object returned by ``model.fit``).
    metric : name of the training metric; its validation twin is ``'val_' + metric``.
    """
    plt.clf()

    logs = history.history
    val_key = 'val_' + metric
    train_series = logs[metric]
    val_series = logs[val_key]
    epoch_axis = range(1, len(train_series) + 1)

    # Training curve first, validation second -- matches the legend order below.
    for series in (train_series, val_series):
        plt.plot(epoch_axis, series)

    plt.title('Training and validation ' + metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_" + metric, val_key])
    plt.savefig('Training_Validation_' + metric + '.pdf')
    plt.show()
    
def performance(model, X, y_true, threshold=0.5, average='binary', class_names=("CN", "AD")):
    """Print accuracy/precision/recall of a binary sigmoid classifier and plot its confusion matrix.

    Parameters
    ----------
    model : object whose ``predict(X)`` returns one probability per sample
        (e.g. the single-unit sigmoid model built by ``build_model``).
    X : input accepted by ``model.predict`` (for the two-input model, presumably a
        ``[input_a, input_b]`` pair -- TODO confirm against ``split_flip``).
    y_true : ground-truth labels, assumed to be 0 (``class_names[0]``) / 1 (``class_names[1]``).
    threshold : probability at or above which a sample is predicted as class 1.
    average : sklearn ``average`` mode applied to BOTH precision and recall, so the two
        printed numbers are computed the same way ('binary' = scores for class 1).
        Previously precision used 'weighted' while recall used 'binary'.
    class_names : display names for labels 0 and 1 in the confusion-matrix plot.
    """
    # For a single-unit sigmoid head predict() returns an (n, 1) array; flatten it to
    # (n,) ints so sklearn receives plain 0/1 labels rather than a boolean column.
    y_pred = (model.predict(X) >= threshold).astype(int).ravel()

    print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
    print(f"Precision: {precision_score(y_true, y_pred, average=average)}")
    print(f"Recall: {recall_score(y_true, y_pred, average=average)}")

    # Pin labels to [0, 1] so the matrix is always 2x2 and lines up with the two
    # display labels, even when one class is missing from both y_true and y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=list(class_names))
    disp.plot()
    plt.show()

In [4]:
# Load the 'AD' and 'CN' groups from ../../Data/ via the project helper merge_data;
# it returns (train inputs, train labels, test inputs, test labels) in that order
# -- presumably a fixed train/test split; TODO confirm split details in LIB.
(X_train, Y_train,
 X_test, Y_test) = merge_data('AD', 'CN', '../../Data/')

In [5]:
# Transform the training inputs with the project helper split_flip (presumably turns each
# volume into the pair of inputs expected by the two-branch model -- TODO confirm in LIB).
# NOTE(review): this cell overwrites its own input, so it is NOT idempotent -- re-running it
# applies split_flip to already-transformed data. Re-run the merge_data cell before this one.
# NOTE(review): X_test is not passed through split_flip, yet it is later used as validation
# data for the two-input model; if the model needs split_flip'd input, X_test needs it too.
X_train = split_flip(X_train)

In [6]:
def build_model(inputs_shape=(121, 145, 61, 1), num_outputs=1, reg_factor=1e-4):
    """Build a siamese binary classifier around one shared 3D ResNet-34 encoder.

    A single encoder instance is applied to both inputs, so the two branches share
    weights. Their outputs are concatenated and fed to one Dense layer with sigmoid
    activation, ``he_normal`` initialization and L2 kernel regularization.

    Parameters
    ----------
    inputs_shape : shape of EACH of the two inputs, without the batch dimension.
    num_outputs : number of sigmoid units in the output layer.
    reg_factor : L2 factor for the output layer's kernel only (not passed to the encoder).

    Returns
    -------
    Model taking ``[input_a, input_b]`` and producing the sigmoid output.
    """
    first_in, second_in = Input(inputs_shape), Input(inputs_shape)

    # 256 is passed as the encoder's second argument -- presumably its output width;
    # TODO confirm against Resnet3DBuilder.build_resnet_34_block.
    shared_encoder = Resnet3DBuilder.build_resnet_34_block(inputs_shape, 256)
    embeddings = [shared_encoder(first_in), shared_encoder(second_in)]
    merged = Concatenate()(embeddings)

    head = Dense(
        units=num_outputs,
        kernel_initializer="he_normal",
        activation="sigmoid",
        kernel_regularizer=l2(reg_factor),
    )
    return Model(inputs=[first_in, second_in], outputs=head(merged))

In [7]:
def training(model, X_train, Y_train, X_val, Y_val, epochs=1, batch_size=16, show_summary=False,
             optimizer="sgd", patience=8):
    """Compile ``model`` for binary classification, fit it with early stopping, and plot its curves.

    The model is compiled on every call with binary cross-entropy loss and an
    ``accuracy`` metric. An EarlyStopping callback monitors ``val_loss`` with
    ``restore_best_weights=True``. After fitting, accuracy and loss curves are plotted
    via ``plot_metric``, which also writes ``Training_Validation_accuracy.pdf`` and
    ``Training_Validation_loss.pdf``.

    Parameters
    ----------
    model : Keras model to compile and train.
    X_train, Y_train : training inputs and labels passed to ``model.fit``.
    X_val, Y_val : passed as ``validation_data``; their loss drives early stopping.
    epochs : maximum number of epochs.
    batch_size : mini-batch size for ``model.fit``.
    show_summary : if True, print ``model.summary()`` before training.
    optimizer : optimizer given to ``model.compile`` (default "sgd", as before).
    patience : EarlyStopping patience in epochs (default 8, as before).

    Returns
    -------
    The ``History`` object from ``model.fit``. It was previously discarded; returning it
    lets callers inspect or persist the curves.
    """
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=patience, restore_best_weights=True)
    model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=['accuracy'])
    if show_summary:
        model.summary()

    history = model.fit(X_train, Y_train, validation_data=(X_val, Y_val), epochs=epochs,
                        batch_size=batch_size, callbacks=[early_stopping])

    plot_metric(history, 'accuracy')
    plot_metric(history, 'loss')
    return history

In [8]:
# Two-branch model; each branch takes an input of shape (60, 64, 128, 1), with one sigmoid output.
model = build_model(inputs_shape=(60, 64, 128, 1), num_outputs=1)

channels last


In [9]:
# NOTE(review): the test set doubles as the validation set here, so EarlyStopping
# (monitor="val_loss", restore_best_weights=True) selects weights using it. Any later
# evaluation on X_test/Y_test would be optimistically biased; consider a separate
# validation split.
training(model, X_train, Y_train, X_val=X_test, Y_val=Y_test, epochs=100)

Epoch 1/100
 1/39 [..............................] - ETA: 27:52 - loss: 2.4182 - accuracy: 0.5625

KeyboardInterrupt: 

In [None]:
# Persist the trained model under "SN_R34" in the working directory.
# NOTE(review): with no file extension, Keras 2 (tf.keras up to TF 2.15) writes a
# SavedModel directory, whereas Keras 3 rejects the path and requires a ".keras" or
# ".h5" suffix -- confirm which Keras version this notebook runs on.
model.save(filepath="SN_R34")