In [5]:
import os
import random
import gc, numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.utils import compute_class_weight
import tensorflow as tf
from keras.models import Model
from keras import backend as K
from keras.layers import Input, Dense, Dropout,Flatten, BatchNormalization, Conv2D, MultiHeadAttention, concatenate
from sklearn.metrics import classification_report
from tensorflow.keras.optimizers import Adam
from keras.models import Sequential
from tensorflow.keras.utils import to_categorical
import seaborn as sns
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import precision_recall_curve


from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from tensorflow.keras.layers import Lambda
from tensorflow.keras.layers import Lambda
import tensorflow as tf

In [6]:
def make_img(t_img):
    img = pd.read_pickle(t_img)
    img_l = []
    for i in range(len(img)):
        img_l.append(img.values[i][0])
    
    return np.array(img_l)


def reset_random_seeds(seed):
    os.environ['PYTHONHASHSEED']=str(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
   
               
def create_model_snp():
    model = Sequential()
    model.add(Dense(200,  activation = "relu")) 
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(100, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.3))
    
    model.add(Dense(50, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))
    return model

def create_model_clinical():
    model = Sequential()
    model.add(Dense(128,  activation = "relu")) 
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(128, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    
    model.add(Dense(50, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))    
    return model

def create_model_img():
    model = Sequential()
    model.add(Conv2D(72, (3, 3), activation='relu')) 
    model.add(Dropout(0.3))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(Dropout(0.3))
    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(Dropout(0.3))
    model.add(Flatten())
    model.add(Dense(50, activation='relu'))   
    return model

def plot_classification_report(y_tru, y_prd, mode, learning_rate, batch_size,epochs, figsize=(7, 7), ax=None):

    plt.figure(figsize=figsize)

    xticks = ['precision', 'recall', 'f1-score', 'support']
    yticks = ["Control", "Moderate", "Alzheimer's" ] 
    yticks += ['avg']

    rep = np.array(precision_recall_fscore_support(y_tru, y_prd)).T
    avg = np.mean(rep, axis=0)
    avg[-1] = np.sum(rep[:, -1])
    rep = np.insert(rep, rep.shape[0], avg, axis=0)

    sns.heatmap(rep,
                annot=True, 
                cbar=False, 
                xticklabels=xticks, 
                yticklabels=yticks,
                ax=ax, cmap = "Blues")
    
    plt.savefig('report_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'_' + str(epochs)+'.png')
    

def calc_confusion_matrix(result, test_label,mode, learning_rate, batch_size, epochs):
    test_label = to_categorical(test_label,3)

    true_label= np.argmax(test_label, axis =1)

    predicted_label= np.argmax(result, axis =1)
    
    n_classes = 3
    precision = dict()
    recall = dict()
    thres = dict()
    for i in range(n_classes):
        precision[i], recall[i], thres[i] = precision_recall_curve(test_label[:, i],
                                                            result[:, i])


    print ("Classification Report :") 
    print (classification_report(true_label, predicted_label))
    cr = classification_report(true_label, predicted_label, output_dict=True)
    return cr, precision, recall, thres


def cross_modal_attention(x, y):
    x = tf.expand_dims(x, axis=1)
    y = tf.expand_dims(y, axis=1)
    a1 = MultiHeadAttention(num_heads = 4,key_dim=50)(x, y)
    a2 = MultiHeadAttention(num_heads = 4,key_dim=50)(y, x)
    a1 = a1[:,0,:]
    a2 = a2[:,0,:]
    return concatenate([a1, a2])


def self_attention(x):
    x = tf.expand_dims(x, axis=1)
    attention = MultiHeadAttention(num_heads = 4, key_dim=50)(x, x)
    attention = attention[:,0,:]
    return attention

In [7]:
train_clinical = pd.read_csv("../preprocess_overlap/X_train_clinical.csv").values
test_clinical= pd.read_csv("../preprocess_overlap/X_test_clinical.csv").values

train_snp = pd.read_csv("../preprocess_overlap/X_train_snp.csv").values
test_snp = pd.read_csv("../preprocess_overlap/X_test_snp.csv").values

train_img= make_img("../preprocess_overlap/X_train_img.pkl")
test_img= make_img("../preprocess_overlap/X_test_img.pkl")

train_label= pd.read_csv("../preprocess_overlap/y_train.csv").values.astype("int").flatten()
test_label= pd.read_csv("../preprocess_overlap/y_test.csv").values.astype("int").flatten()

train_clinical = train_clinical.astype("float32")
test_clinical = test_clinical.astype("float32")
# train_snp = train_snp.astype("float32")
# train_snp = test_snp.astype("float32")

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import classification_report, accuracy_score
from keras.models import Model
import numpy as np

def multi_modal_model_baseline(mode, train_clinical, train_snp, train_img):
    in_clinical = Input(shape=(train_clinical.shape[1]))
    in_snp = Input(shape=(train_snp.shape[1]))
    in_img = Input(shape=(train_img.shape[1], train_img.shape[2], train_img.shape[3]))
    
    dense_clinical = create_model_clinical()(in_clinical)
    dense_snp = create_model_snp()(in_snp)
    dense_img = create_model_img()(in_img)
    
    if mode == 'MM_BA':
        vt_att = cross_modal_attention(dense_img, dense_clinical)
        av_att = cross_modal_attention(dense_snp, dense_img)
        ta_att = cross_modal_attention(dense_clinical, dense_snp)
        merged = concatenate([vt_att, av_att, ta_att, dense_img, dense_snp, dense_clinical])
    elif mode == 'MM_SA':
        vv_att = self_attention(dense_img)
        tt_att = self_attention(dense_clinical)
        aa_att = self_attention(dense_snp)
        merged = concatenate([aa_att, vv_att, tt_att, dense_img, dense_snp, dense_clinical])
    elif mode == 'MM_SA_BA':
        vv_att = self_attention(dense_img)
        tt_att = self_attention(dense_clinical)
        aa_att = self_attention(dense_snp)
        vt_att = cross_modal_attention(vv_att, tt_att)
        av_att = cross_modal_attention(aa_att, vv_att)
        ta_att = cross_modal_attention(tt_att, aa_att)
        merged = concatenate([vt_att, av_att, ta_att, dense_img, dense_snp, dense_clinical])
    elif mode == 'None':
        merged = concatenate([dense_img, dense_snp, dense_clinical])
    else:
        print("Invalid mode. Choose from 'MM_SA', 'MM_BA', 'MM_SA_BA', 'None'.")
        return
    
    features_model = Model(inputs=[in_clinical, in_snp, in_img], outputs=merged)
    return features_model


In [18]:
def train_baseline(mode, train_clinical, train_snp, train_img, train_label, test_clinical, test_snp, test_img, test_label):
    feature_model = multi_modal_model_baseline(mode, train_clinical, train_snp, train_img)
    train_features = feature_model.predict([train_clinical, train_snp, train_img])
    test_features = feature_model.predict([test_clinical, test_snp, test_img])
    
    classifier = LogisticRegression(max_iter=1000)
    classifier.fit(train_features, train_label)
    predictions = classifier.predict(test_features)
    
    accuracy = accuracy_score(test_label, predictions)
    print("Test Accuracy: ", accuracy)
    print(classification_report(test_label, predictions))
    
    return accuracy

In [19]:
accuracy = train_baseline(
    mode="MM_SA_BA",
    train_clinical=train_clinical,
    train_snp=train_snp,
    train_img=train_img,
    train_label=train_label,
    test_clinical=test_clinical,
    test_snp=test_snp,
    test_img=test_img,
    test_label=test_label
)
print(f"Baseline Model Accuracy: {accuracy:.4f}")


KerasTensor(type_spec=TensorSpec(shape=(None, 450), dtype=tf.float32, name=None), name='concatenate_15/concat:0', description="created by layer 'concatenate_15'")
Test Accuracy:  0.9142857142857143
              precision    recall  f1-score   support

           0       0.92      0.96      0.94        24
           1       0.67      0.50      0.57         4
           2       1.00      1.00      1.00         7

    accuracy                           0.91        35
   macro avg       0.86      0.82      0.84        35
weighted avg       0.91      0.91      0.91        35

Baseline Model Accuracy: 0.9143


In [3]:
from keras.models import load_model, Model
import numpy as np

best_model = load_model('best_model.h5')
best_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_3 (InputLayer)        [(None, 72, 72, 3)]          0         []                            
                                                                                                  
 input_1 (InputLayer)        [(None, 149)]                0         []                            
                                                                                                  
 input_2 (InputLayer)        [(None, 179666)]             0         []                            
                                                                                                  
 sequential_2 (Sequential)   (None, 50)                   7031666   ['input_3[0][0]']             
                                                                                              

In [8]:
feature_extractor = Model(inputs=best_model.input, outputs=best_model.get_layer('concatenate_3').output) # Get merged layer

train_features = feature_extractor.predict([train_clinical, train_snp, train_img])
test_features = feature_extractor.predict([test_clinical, test_snp, test_img])

print(f"Training feature shape: {train_features.shape}")
print(f"Testing feature shape: {test_features.shape}")

Training feature shape: (137, 450)
Testing feature shape: (35, 450)


In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.metrics import classification_report, accuracy_score

y_train = train_label
y_test = test_label

# Logistic Regression
log_reg = LogisticRegression()
# log_reg = LogisticRegression(penalty='l2', C=0.1, solver='lbfgs', max_iter=200)
log_reg.fit(train_features, y_train)
y_pred_log_reg = log_reg.predict(test_features)
print("Logistic Regression Report:")
print(classification_report(y_test, y_pred_log_reg))

# Random Forest
rf = RandomForestClassifier()
# rf = RandomForestClassifier(n_estimators=200, max_depth=10, min_samples_split=5, class_weight='balanced')
rf.fit(train_features, y_train)
y_pred_rf = rf.predict(test_features)
print("Random Forest Report:")
print(classification_report(y_test, y_pred_rf))

# SVM
svm = SVC()
# svm = SVC(C=0.5, kernel='rbf', gamma='scale', class_weight='balanced', probability=True)
svm.fit(train_features, y_train)
y_pred_svm = svm.predict(test_features)
print("SVM Report:")
print(classification_report(y_test, y_pred_svm))

# XGBoost
xgb = XGBClassifier()
# xgb = XGBClassifier(n_estimators=200, learning_rate=0.1, max_depth=5, subsample=0.8, colsample_bytree=0.8)
xgb.fit(train_features, y_train)
y_pred_xgb = xgb.predict(test_features)
print("XGBoost Report:")
print(classification_report(y_test, y_pred_xgb))

Logistic Regression Report:
              precision    recall  f1-score   support

           0       0.86      1.00      0.92        24
           1       0.00      0.00      0.00         4
           2       1.00      1.00      1.00         7

    accuracy                           0.89        35
   macro avg       0.62      0.67      0.64        35
weighted avg       0.79      0.89      0.83        35

Random Forest Report:
              precision    recall  f1-score   support

           0       0.88      0.96      0.92        24
           1       0.00      0.00      0.00         4
           2       0.78      1.00      0.88         7

    accuracy                           0.86        35
   macro avg       0.55      0.65      0.60        35
weighted avg       0.76      0.86      0.81        35

SVM Report:
              precision    recall  f1-score   support

           0       1.00      0.79      0.88        24
           1       0.43      0.75      0.55         4
           2 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


XGBoost Report:
              precision    recall  f1-score   support

           0       0.88      0.96      0.92        24
           1       0.00      0.00      0.00         4
           2       0.78      1.00      0.88         7

    accuracy                           0.86        35
   macro avg       0.55      0.65      0.60        35
weighted avg       0.76      0.86      0.81        35



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [17]:
from sklearn.model_selection import GridSearchCV

param_grid_lr = {
    'penalty': ['l1', 'l2'],
    'C': [0.01, 0.1, 1, 10],
    'solver': ['liblinear', 'saga']
}

grid_search_lr = GridSearchCV(
    estimator=LogisticRegression(max_iter=1000),
    param_grid=param_grid_lr,
    scoring='accuracy',
    cv=3
)
grid_search_lr.fit(train_features, y_train)

print("Best param:", grid_search_lr.best_params_)
print("Highest accuracy:", grid_search_lr.best_score_)

best_lr_model = grid_search_lr.best_estimator_
y_pred_lr = best_lr_model.predict(test_features)
print(classification_report(y_test, y_pred_lr))



Best param: {'C': 1, 'penalty': 'l2', 'solver': 'liblinear'}
Highest accuracy: 0.9706924315619968
              precision    recall  f1-score   support

           0       0.86      1.00      0.92        24
           1       0.00      0.00      0.00         4
           2       1.00      1.00      1.00         7

    accuracy                           0.89        35
   macro avg       0.62      0.67      0.64        35
weighted avg       0.79      0.89      0.83        35



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
param_grid = {
    'n_estimators': [100, 200, 300],
    'max_depth': [None, 10, 20],
    'min_samples_split': [2, 5, 10],
    'class_weight': ['balanced', None]
}

grid_search = GridSearchCV(estimator=RandomForestClassifier(), param_grid=param_grid, cv=3, scoring='accuracy')
grid_search.fit(train_features, y_train)

print("Best param:", grid_search.best_params_)
print("Highest accuracy:", grid_search.best_score_)

best_model = grid_search.best_estimator_
y_pred = best_model.predict(test_features)
print("Report:")
print(classification_report(y_test, y_pred))

Best param: {'class_weight': 'balanced', 'max_depth': None, 'min_samples_split': 10, 'n_estimators': 300}
Highest accuracy: 0.926892109500805
Report:
              precision    recall  f1-score   support

           0       0.88      0.96      0.92        24
           1       0.00      0.00      0.00         4
           2       0.78      1.00      0.88         7

    accuracy                           0.86        35
   macro avg       0.55      0.65      0.60        35
weighted avg       0.76      0.86      0.81        35



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [16]:
param_grid_svm = {
    'C': [0.1, 1, 10],
    'kernel': ['linear', 'rbf', 'poly'],
    'gamma': ['scale', 'auto']
}

grid_search_svm = GridSearchCV(
    estimator=SVC(probability=True),
    param_grid=param_grid_svm,
    scoring='accuracy',
    cv=3
)
grid_search_svm.fit(train_features, y_train)

print("Best param:", grid_search_svm.best_params_)
print("Highest accuracy:", grid_search_svm.best_score_)

best_svm_model = grid_search_svm.best_estimator_
y_pred_svm = best_svm_model.predict(test_features)
print(classification_report(y_test, y_pred_svm))

Best param: {'C': 1, 'gamma': 'scale', 'kernel': 'linear'}
Highest accuracy: 0.9561996779388084
              precision    recall  f1-score   support

           0       0.86      1.00      0.92        24
           1       0.00      0.00      0.00         4
           2       1.00      1.00      1.00         7

    accuracy                           0.89        35
   macro avg       0.62      0.67      0.64        35
weighted avg       0.79      0.89      0.83        35



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [19]:
param_grid_xgb = {
    'n_estimators': [100, 200, 300],
    'learning_rate': [0.01, 0.1, 0.2],
    'max_depth': [3, 5, 7],
    'subsample': [0.8, 1.0],
    'colsample_bytree': [0.8, 1.0]
}

grid_search_xgb = GridSearchCV(
    estimator=XGBClassifier(eval_metric='mlogloss'),
    param_grid=param_grid_xgb,
    scoring='accuracy',
    cv=3
)
grid_search_xgb.fit(train_features, y_train)

print("Best param:", grid_search_xgb.best_params_)
print("Highest accuracy:", grid_search_xgb.best_score_)

best_xgb_model = grid_search_xgb.best_estimator_
y_pred_xgb = best_xgb_model.predict(test_features)
print(classification_report(y_test, y_pred_xgb))

Best param: {'colsample_bytree': 1.0, 'learning_rate': 0.2, 'max_depth': 3, 'n_estimators': 100, 'subsample': 1.0}
Highest accuracy: 0.9194847020933977
              precision    recall  f1-score   support

           0       0.88      0.96      0.92        24
           1       0.00      0.00      0.00         4
           2       0.78      1.00      0.88         7

    accuracy                           0.86        35
   macro avg       0.55      0.65      0.60        35
weighted avg       0.76      0.86      0.81        35



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
