## Multimodal Deep Neural Network (CNN + Attention + MLP) for ADHD Classification

### Motivation

Previous convolutional neural network (CNN) models trained solely on fMRI-derived features such as Functional Connectivity (FC), ReHo, and fALFF achieved modest classification accuracy (typically in the 50–60% range). However, traditional machine learning models (e.g., Random Forest, XGBoost) showed better performance when phenotypic information (age, sex, IQ, etc.) was incorporated.

This motivates a shift toward **multimodal learning**, where both imaging and non-imaging data are utilized jointly. The proposed architecture addresses this by combining CNNs for imaging data, attention mechanisms to emphasize important spatial features, and a multi-layer perceptron (MLP) for structured phenotypic input.

---

### Architecture Overview

1. **CNN Branch (for FC, ReHo, and fALFF)**
   - The three imaging modalities (FC, ReHo, fALFF) are stacked as the channels of a single 190 × 190 × 3 input and passed through three 2D convolutional blocks.
   - Each block consists of Conv2D → BatchNormalization → MaxPooling2D.
   - The resulting feature maps are reweighted by an attention layer and reduced with global average pooling.

2. **Attention Mechanism**
   - Learns one weight per feature channel from the globally averaged CNN output and rescales the feature maps with it.
   - Helps the model emphasize the learned features that contribute most to classification.

3. **MLP Branch (for Phenotypic Data)**
   - Handles the six non-imaging inputs: Inattentive and Hyper/Impulsive scores, Verbal IQ, Performance IQ, Full4 IQ, and medication status.
   - Structured as Dense → Dropout → Dense layers.

4. **Fusion Layer**
   - The outputs of the CNN-attention branch and the MLP branch are concatenated.
   - A fully connected layer with dropout follows, and a single sigmoid unit performs the final binary classification.
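
The attention step in `build_model` further down is a per-channel gate (global average pooling, a sigmoid dense layer, then a broadcast multiply). The following is a minimal NumPy sketch of that idea, not the trained Keras layers: the function name `channel_gate`, the random weights, and the toy shapes are assumptions made for illustration.

```python
import numpy as np

def channel_gate(feature_map, weights, bias):
    """Rescale each channel of an (H, W, C) feature map by a weight in (0, 1)."""
    # Squeeze: one spatial average per channel, shape (C,)
    squeezed = feature_map.mean(axis=(0, 1))
    # Excite: dense layer followed by a sigmoid, shape (C,)
    gate = 1.0 / (1.0 + np.exp(-(squeezed @ weights + bias)))
    # Broadcast the (C,) gate over both spatial dimensions
    return feature_map * gate, gate

rng = np.random.default_rng(0)
fmap = rng.normal(size=(23, 23, 8))       # toy feature map (hypothetical sizes)
W = rng.normal(scale=0.1, size=(8, 8))    # hypothetical dense weights
b = np.zeros(8)

gated, gate = channel_gate(fmap, W, b)
print(gated.shape, gate.round(3))
```

Because the gate holds one value per channel, it reweights whole feature maps rather than individual spatial positions.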

---

### Implementation Goals

- Align input shapes and normalization across modalities.
- Enable the CNN branches to learn distinct patterns per fMRI modality.
- Integrate attention to improve interpretability and classification.
- Train the model end-to-end to optimize for accuracy and generalization.


In [1]:
import pandas as pd
import os
import math
import numpy as np

from IPython.core.interactiveshell import InteractiveShell 
InteractiveShell.ast_node_interactivity = 'all'

import warnings
warnings.filterwarnings("ignore")

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.preprocessing import StandardScaler
from scipy.stats import zscore

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout, 
    BatchNormalization, concatenate, GlobalAveragePooling2D, 
    Reshape, Multiply
)

from keras_tuner.tuners import RandomSearch
import random

def set_seeds(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

set_seeds(42)  


### Importing the Necessary Files

In [2]:
# Loading the data
FC_data = pd.read_csv("C:/Users/prajw/Desktop/Undergrad Research/Datasets/Preprocessed FC/Merged_FC.csv", index_col = 0)
FC_data.index.name = "Subject ID"

Reho = pd.read_csv("C:/Users/prajw/Desktop/Undergrad Research/Datasets/Preprocessed ReHo/All_ReHo.csv", index_col = 0)
Reho.index.name = "Subject ID"

falff = pd.read_csv("C:/Users/prajw/Desktop/Undergrad Research/Datasets/Preprocessed fALFF/All_falff.csv", index_col = 0)
falff.index.name = "Subject ID"

FC_pheno_data = pd.read_csv("C:/Users/prajw/Desktop/Undergrad Research/Datasets/Preprocessed FC matrix with Pheno/FC_Merged.csv", index_col = 0)
FC_pheno_data.index.name = "Subject ID"

FC_data['DX'] = FC_data['DX'].apply(lambda x: 1 if x > 0 else 0)
Reho['DX'] = Reho['DX'].apply(lambda x: 1 if x > 0 else 0)
falff['DX'] = falff['DX'].apply(lambda x: 1 if x > 0 else 0)
Reho = Reho.set_index('ScanDir ID')

FC_data
Reho
falff

Subject ID,DX,0,1,2,3,4,5,6,7,8,...,17945,17946,17947,17948,17949,17950,17951,17952,17953,17954
1018959,0,0.082767,-0.202121,-0.253291,0.143162,-0.212533,0.572503,-0.346531,0.010921,0.014705,...,0.112687,0.171248,-0.073971,-0.413591,0.238855,0.300857,0.177778,0.658069,-0.052805,0.145006
1019436,1,0.216872,-0.055456,0.274632,0.057173,0.318357,0.334924,-0.285290,-0.093749,0.051842,...,-0.466976,-0.189028,0.048556,-0.476408,0.064047,-0.008339,0.424513,0.450524,-0.022595,0.231871
1043241,0,-0.060757,0.218841,-0.220541,-0.009787,-0.018797,0.102055,-0.207456,-0.332605,0.157679,...,-0.118490,-0.123154,-0.287799,-0.404065,0.111559,-0.233413,-0.120683,0.083335,-0.058685,0.343333
1266183,0,-0.063801,0.061519,-0.011792,0.016329,0.135984,0.164675,0.174189,-0.108558,-0.031854,...,-0.256891,-0.414658,-0.549495,0.543498,0.292311,0.164908,-0.031181,0.084472,-0.235999,-0.609832
1535233,0,0.022708,0.380467,0.404897,0.422225,0.546707,-0.022366,0.218834,0.274712,0.191838,...,-0.238173,-0.059068,-0.020537,-0.465405,0.143333,0.194637,0.410916,0.510831,-0.106011,0.137268
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5669389,0,-0.136454,-0.065583,0.017311,0.052200,0.460014,0.306935,0.176286,-0.345477,0.164890,...,0.206593,0.249444,-0.039177,0.034391,0.666300,0.386416,0.166788,0.555021,0.308751,0.346726
6383713,1,0.004203,-0.314136,-0.223007,0.475450,0.595427,0.529921,-0.154340,-0.193150,-0.176574,...,0.274156,0.141414,-0.183968,-0.383677,0.533130,0.236219,0.114466,0.385278,0.143373,0.421396
6477085,0,-0.619545,-0.075170,-0.074883,0.100222,0.329930,0.071147,-0.110990,0.133600,-0.494886,...,0.137766,0.480185,0.059669,0.074272,0.309700,0.376793,0.156937,0.225204,0.034719,0.042565
7994085,0,-0.298730,0.156669,-0.152175,0.284072,0.307652,0.079670,-0.070846,-0.309843,-0.055744,...,-0.386502,0.386439,-0.201634,-0.290110,-0.528532,-0.041827,0.313892,0.230130,-0.464864,-0.326013


ScanDir ID,DX,ReHo_1,ReHo_2,ReHo_3,ReHo_4,ReHo_5,ReHo_6,ReHo_7,ReHo_8,ReHo_9,...,ReHo_181,ReHo_182,ReHo_183,ReHo_184,ReHo_185,ReHo_186,ReHo_187,ReHo_188,ReHo_189,ReHo_190
1018959,0,-0.000693,-0.002784,0.005485,0.013734,0.009531,0.012442,0.015524,0.004018,0.013782,...,-0.000228,-0.004366,0.007430,0.014627,0.001317,-0.006228,0.001716,-0.000210,-0.012670,-0.006496
1019436,1,0.001641,0.000157,-0.000600,0.008507,-0.005215,0.021348,0.037630,-0.002221,-0.028085,...,0.005336,0.026657,-0.010433,0.010363,0.016745,0.020867,-0.026912,-0.010929,0.002073,-0.014229
1043241,0,-0.025526,-0.007973,-0.003846,-0.008930,0.002580,0.012220,0.027966,0.006901,0.004053,...,-0.004057,0.029532,0.017034,0.011103,-0.010745,-0.009088,0.008249,-0.008181,-0.001200,-0.007629
1266183,0,0.017338,-0.020910,0.026822,-0.028537,0.004145,0.018429,0.023084,0.018687,-0.043965,...,-0.009916,0.024002,0.001976,-0.016781,0.019819,0.008669,0.015358,-0.016745,-0.011874,0.012493
1535233,0,0.025408,-0.004181,0.019778,0.010639,0.042414,0.037103,0.037410,0.053010,-0.005686,...,-0.033017,-0.009099,-0.008167,-0.052603,0.034772,-0.018005,0.037081,0.002573,0.026957,0.030917
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5669389,0,0.041583,0.031368,0.018500,0.003853,0.005902,0.035916,0.007781,0.037332,-0.033296,...,0.029162,-0.003820,0.023286,-0.027044,0.047881,-0.004888,0.033479,0.043651,0.021794,0.022059
6383713,1,-0.011001,0.023553,0.045278,0.023125,-0.010543,0.009812,-0.002366,0.047632,0.009774,...,0.023663,-0.011299,0.034275,0.012848,0.053413,-0.014896,0.022261,0.042729,0.033602,0.029257
6477085,0,-0.090311,0.077690,0.036559,-0.020320,0.005736,-0.015036,-0.008798,0.056555,-0.011541,...,0.046824,0.049709,0.059777,0.007273,0.057467,0.073680,0.053626,0.080742,0.041598,0.056369
7994085,0,-0.009549,0.003129,0.000436,0.005846,0.003584,0.004464,0.017457,0.018368,-0.005317,...,0.002915,0.000692,-0.005344,-0.006689,0.002440,0.005350,0.000327,-0.005584,0.002570,-0.007553


Subject ID,0,1,2,3,4,5,6,7,8,9,...,181,182,183,184,185,186,187,188,189,DX
1018959,0.440059,0.654038,0.604660,0.539188,0.427841,0.556644,0.456504,0.500982,0.631251,0.613161,...,0.375033,0.613263,0.645600,0.468351,0.381020,0.650361,0.684567,0.676759,0.612947,0
1019436,0.756351,0.527825,0.619720,0.466123,0.497030,0.384478,0.546645,0.465267,0.588696,0.561305,...,0.492485,0.475318,0.603202,0.554377,0.620290,0.637574,0.674873,0.568847,0.609281,1
1043241,0.730804,0.413692,0.591623,0.642186,0.434255,0.555075,0.479140,0.540595,0.621012,0.618000,...,0.584502,0.637769,0.727549,0.557828,0.637839,0.755363,0.547045,0.572328,0.555771,0
1266183,0.324953,0.566304,0.556289,0.538490,0.495830,0.416628,0.328429,0.372241,0.477356,0.531180,...,0.448431,0.631884,0.497536,0.458568,0.515995,0.474847,0.524475,0.535504,0.719011,0
1535233,0.544618,0.434999,0.646133,0.715492,0.669921,0.582833,0.459039,0.543931,0.610697,0.451788,...,0.539778,0.641483,0.671250,0.642956,0.653623,0.727645,0.405513,0.391241,0.657757,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5669389,0.256059,0.476082,0.668876,0.450693,0.398631,0.544864,0.429395,0.567268,0.646345,0.525099,...,0.316761,0.561301,0.515058,0.305818,0.312659,0.622423,0.618545,0.537619,0.563718,0
6383713,0.229875,0.445956,0.446567,0.515856,0.507962,0.443237,0.339974,0.519716,0.503483,0.485144,...,0.332223,0.643870,0.490933,0.395509,0.346755,0.377433,0.524141,0.334546,0.418346,1
6477085,0.340038,0.422289,0.542378,0.624692,0.546239,0.549973,0.443172,0.644051,0.535838,0.474425,...,0.383385,0.465567,0.499877,0.352168,0.394593,0.424113,0.480928,0.514339,0.584390,0
7994085,0.391632,0.510163,0.715460,0.562021,0.546323,0.440242,0.350638,0.617332,0.689878,0.572333,...,0.567139,0.657349,0.644709,0.422675,0.511961,0.565686,0.756029,0.552888,0.624066,0


In [3]:
phenotype_cols = ['Inattentive', 'Hyper/Impulsive', 'Verbal IQ', 
                  'Performance IQ', 'Full4 IQ', 'Med Status', 'DX']
pheno_data = FC_pheno_data[phenotype_cols].copy()
pheno_data.index.name = 'Subject ID'
pheno_data['DX'] = pheno_data['DX'].apply(lambda x: 1 if x > 0 else 0)
pheno_data

# Find the subjects present in the phenotype table and in all three fMRI tables
common_subjects = (
    pheno_data.index
    .intersection(FC_data.index)
    .intersection(Reho.index)
    .intersection(falff.index)
)

# Filter every dataset to those subjects, in the same row order,
# so that row i refers to the same subject in all four tables
FC_data = FC_data.loc[common_subjects]
Reho = Reho.loc[common_subjects]
falff = falff.loc[common_subjects]
pheno_data = pheno_data.loc[common_subjects]


Subject ID,Inattentive,Hyper/Impulsive,Verbal IQ,Performance IQ,Full4 IQ,Med Status,DX
1018959,47.0,44.0,99.0,115.0,103.0,1,0
1019436,60.0,66.0,124.0,108.0,122.0,1,1
1043241,40.0,43.0,128.0,106.0,120.0,1,0
1266183,44.0,43.0,136.0,96.0,120.0,1,0
1535233,41.0,43.0,106.0,135.0,122.0,1,0
...,...,...,...,...,...,...,...
5669389,15.0,9.0,120.0,97.0,110.0,1,0
6383713,29.0,32.0,115.0,91.0,104.0,1,1
6477085,13.0,12.0,115.0,112.0,115.0,1,0
7994085,23.0,15.0,89.0,86.0,86.0,1,0


### Shaping the Data into 2D Matrices with Three Channels (FC, ReHo, fALFF)

In [5]:
X_FC = FC_data.drop(columns = "DX").values
y_FC = FC_data["DX"].values

X_Reho = Reho.drop(columns= "DX").values
X_Reho = X_Reho.astype(float)
y_Reho = Reho["DX"].values

X_falff = falff.drop(columns="DX").values
X_falff = X_falff.astype(float)
y_falff = falff["DX"].values



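# Rebuild each flattened FC vector (strict upper triangle, 190 * 189 / 2 = 17955 values)
# into a symmetric 190 x 190 matrix with a zero diagonal
n_regions = 190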
triu_indices = np.triu_indices(n_regions, k = 1)

fc_matrices = []

for row in X_FC:
    mat = np.zeros((n_regions, n_regions))
    mat[triu_indices] = row
    mat += mat.T
    fc_matrices.append(mat)

X_fc_reshape = np.array(fc_matrices)

# ReHo: expand each 190-value vector into a 190 x 190 matrix via its outer product
reho_matrices = np.array([np.outer(row, row) for row in X_Reho])
X_reho_reshape = reho_matrices

# fALFF: same outer-product expansion to 190 x 190
falff_matrices = np.array([np.outer(row, row) for row in X_falff])
X_falff_reshape = falff_matrices

# Function to apply z-score normalization per matrix
def normalize_per_subject(matrices):
    normalized = []
    for mat in matrices:
        flat = mat.flatten()
        norm_flat = zscore(flat)
        norm_mat = norm_flat.reshape(mat.shape)
        normalized.append(norm_mat)
    return np.array(normalized)

# Apply to FC, ReHo, fALFF
X_fc_reshaped = normalize_per_subject(X_fc_reshape)
X_reho_reshaped = normalize_per_subject(X_reho_reshape)
X_falff_reshaped = normalize_per_subject(X_falff_reshape)


X_combined = np.stack([X_fc_reshaped,
                      X_reho_reshaped,
                      X_falff_reshaped], axis = -1)

y_combined = y_FC
X_combined.shape

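X_pheno = pheno_data.drop(columns = "DX").values
# Standardize the phenotypic features. The scaler is fit on all subjects here,
# i.e. before any train/validation split.
scaler = StandardScaler()
X_pheno = scaler.fit_transform(X_pheno)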

(493, 190, 190, 3)
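
The FC reconstruction above relies on the flattened vector holding exactly the strict upper triangle of the connectivity matrix. A small worked example makes the indexing visible; it is only a sketch, and the helper name `vector_to_symmetric` and the toy values are assumptions that do not appear in the notebook.

```python
import numpy as np

def vector_to_symmetric(vec, n_regions):
    """Rebuild a symmetric matrix (zero diagonal) from its strict upper triangle."""
    expected = n_regions * (n_regions - 1) // 2
    if len(vec) != expected:
        raise ValueError(f"expected {expected} values, got {len(vec)}")
    mat = np.zeros((n_regions, n_regions))
    mat[np.triu_indices(n_regions, k=1)] = vec   # fill row by row above the diagonal
    return mat + mat.T                            # mirror into the lower triangle

toy = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])   # 4 regions -> 6 unique pairs
fc = vector_to_symmetric(toy, n_regions=4)
print(fc)

# The FC table has columns 0..17954, i.e. 17955 values per subject
n_pairs = 190 * 189 // 2
print(n_pairs)
```

The length check is the useful part in practice: if the atlas size and the number of FC columns disagree, the assignment fails loudly instead of producing a misaligned matrix.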

### Building the Multimodal Model

In [49]:
def build_model(hp):
    # CNN Branch
    cnn_input = Input(shape=(190, 190, 3), name="cnn_input")
    
    x = Conv2D(
        filters=hp.Choice('conv1_filters', [32, 64]),
        kernel_size=(3, 3),
        activation='relu',
        padding='same'
    )(cnn_input)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2))(x)

    x = Conv2D(
        filters=hp.Choice('conv2_filters', [64, 96, 128]),
        kernel_size=(3, 3),
        activation='relu',
        padding='same'
    )(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2))(x)

    x = Conv2D(
        filters=hp.Choice('conv3_filters', [128, 256]),
        kernel_size=(3, 3),
        activation='relu',
        padding='same'
    )(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2))(x)
    
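    # Channel attention in place of an immediate Flatten:
    # squeeze each feature map to one value, learn a sigmoid weight per channel,
    # then rescale the feature maps with those weights
    gap = GlobalAveragePooling2D()(x)             # shape (batch, channels)
    attention = Dense(gap.shape[-1], activation='sigmoid')(gap)
    attention = Reshape((1, 1, gap.shape[-1]))(attention)
    x = Multiply()([x, attention])                # broadcast attention over spatial dimensions

    # Global average pooling instead of Flatten keeps the CNN output at one value per channel
    #x = Flatten()(x)
    x = GlobalAveragePooling2D()(x)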

    # MLP Branch
    mlp_input = Input(shape=(6,), name="mlp_input")

    y = Dense(
        units=hp.Int('mlp_units1', 32, 128, step=32),
        activation='relu'
    )(mlp_input)
    y = Dropout(hp.Float('mlp_dropout', 0.3, 0.7, step=0.1))(y)

    y = Dense(
        units=hp.Int('mlp_units2', 16, 64, step=16),
        activation='relu'
    )(y)

    # Merge branches
    merged = concatenate([x, y])

    z = Dense(
        units=hp.Int('dense_units', 64, 256, step=64),
        activation='relu'
    )(merged)
    z = Dropout(hp.Float('final_dropout', 0.3, 0.7, step=0.1))(z)

    output = Dense(1, activation='sigmoid')(z)

    model = Model(inputs=[cnn_input, mlp_input], outputs=output)

    model.compile(
        optimizer=Adam(learning_rate=hp.Choice('lr', [1e-3, 5e-4, 1e-4])),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    return model

results = []
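
To see what the fusion layer actually receives, the tensor sizes in `build_model` can be traced by hand. The sketch below assumes the Keras defaults used above (`padding='same'` convolutions keep the spatial size, 2 × 2 max pooling with `'valid'` padding floors it) and plugs in the filter and unit counts that the tuner later reports as best. The helper `pooled_size` is a hypothetical name.

```python
def pooled_size(size, n_pools, pool=2):
    """Spatial size after n_pools max-pooling layers with 'valid' padding."""
    for _ in range(n_pools):
        size //= pool
    return size

side = pooled_size(190, n_pools=3)        # 190 -> 95 -> 47 -> 23
conv3_filters, mlp_units2 = 128, 32       # values reported as best by the tuner

gap_width = conv3_filters                 # global average pooling: one value per channel
flatten_width = side * side * conv3_filters   # what Flatten would have produced
merged_width = gap_width + mlp_units2     # input width of the fusion Dense layer

print(side, gap_width, flatten_width, merged_width)
```

With global average pooling the imaging branch contributes a few hundred values at most, so the phenotypic branch is not swamped at the concatenation, which it would be if the flattened feature maps were used.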

### Splitting Data and Predicting

In [52]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.utils.class_weight import compute_class_weight
from keras_tuner.tuners import RandomSearch
from tensorflow.keras.callbacks import EarlyStopping

# --- Step 0: Split Data ---
X_cnn = X_combined               # shape: (n_samples, 190, 190, 3)
X_mlp = X_pheno                  # standardized phenotypic features
y = y_FC                         # binary target

X_cnn_train, X_cnn_val, X_mlp_train, X_mlp_val, y_train, y_val = train_test_split(
    X_cnn, X_mlp, y, test_size=0.2, stratify=y, random_state=42
)

# --- Step 1: Define Hyperparameter Tuner ---
tuner = RandomSearch(
    build_model,                              # your model function
    objective='val_accuracy',
    max_trials=15,
    executions_per_trial=1,
    directory='tuner_multimodal',
    project_name='cnn_mlp_combo',
    overwrite=True                           # <== important to allow fresh rerun
)

# --- Step 2: Run Tuning ---
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

tuner.search(
    [X_cnn_train, X_mlp_train], y_train,
    validation_data=([X_cnn_val, X_mlp_val], y_val),
    epochs=20,
    batch_size=32,
    callbacks=[early_stop],
    verbose=1
)

# --- Step 3: Retrieve Best Model ---
best_model = tuner.get_best_models(1)[0]
best_hp = tuner.get_best_hyperparameters(1)[0]

print("\nBest Hyperparameters:")
for key in best_hp.values:
    print(f"{key}: {best_hp.get(key)}")

# --- Step 4: Evaluate or Retrain ---
#class_weights_array = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
#class_weight_dict = {i: class_weights_array[i] for i in range(len(class_weights_array))}
class_weight_dict = {0: 1.0, 1: 1.5}
#class_weight_dict = {0: 1.0, 1: 2.0}


best_model.fit(
    [X_cnn_train, X_mlp_train], y_train,
    validation_data=([X_cnn_val, X_mlp_val], y_val),
    epochs=20,
    batch_size=32,
    callbacks=[early_stop],
    class_weight=class_weight_dict,
    verbose=1
)

# --- Step 5: Final Evaluation ---
# Predict probabilities, then apply a 0.4 decision threshold instead of the default 0.5
#y_pred = (best_model.predict([X_cnn_val, X_mlp_val]) > 0.5).astype(int)
y_pred_probs = best_model.predict([X_cnn_val, X_mlp_val])
y_pred = (y_pred_probs > 0.4).astype(int)


acc = accuracy_score(y_val, y_pred)
f1 = f1_score(y_val, y_pred, average='weighted')

print(f"\nValidation Accuracy: {acc:.4f}")
print(f"Validation F1 Score: {f1:.4f}")
print("\nClassification Report:\n", classification_report(y_val, y_pred, digits=4))

# Save to results
results.append({
    'Dataset': 'CNN with Attention + MLP',
    'Best Params': {key: best_hp.get(key) for key in best_hp.values},
    'Val Accuracy': acc,
    'Val F1': f1
})


Trial 15 Complete [00h 00m 20s]
val_accuracy: 0.7777777910232544

Best val_accuracy So Far: 0.7979797720909119
Total elapsed time: 00h 03m 20s

Best Hyperparameters:
conv1_filters: 64
conv2_filters: 128
conv3_filters: 128
mlp_units1: 96
mlp_dropout: 0.3
mlp_units2: 32
dense_units: 256
final_dropout: 0.5
lr: 0.001
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20


<keras.callbacks.History at 0x1bd2a208700>


Validation Accuracy: 0.6667
Validation F1 Score: 0.6658

Classification Report:
               precision    recall  f1-score   support

           0     0.7619    0.5818    0.6598        55
           1     0.5965    0.7727    0.6733        44

    accuracy                         0.6667        99
   macro avg     0.6792    0.6773    0.6665        99
weighted avg     0.6884    0.6667    0.6658        99
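
The summary rows of the report can be recovered from the per-class rows, which is a quick way to check how the numbers relate. This sketch only re-derives figures printed above; the variable names are illustrative.

```python
support = {0: 55, 1: 44}
recall = {0: 0.5818, 1: 0.7727}
f1 = {0: 0.6598, 1: 0.6733}

# Correct predictions per class = recall * support, rounded to whole subjects
correct = {c: round(recall[c] * support[c]) for c in support}
total = sum(support.values())

accuracy = sum(correct.values()) / total
weighted_f1 = sum(f1[c] * support[c] for c in support) / total

# Precision for class 1 = true positives / all predicted positives
false_pos = support[0] - correct[0]
precision_1 = correct[1] / (correct[1] + false_pos)

print(correct, round(accuracy, 4), round(weighted_f1, 4), round(precision_1, 4))
```

The weighted F1 reported by `f1_score(..., average='weighted')` is the support-weighted mean of the per-class F1 values, so it sits close to accuracy here because the two classes are similar in size.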



#### Cross Validation

In [54]:
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.utils.class_weight import compute_class_weight

cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

cv_accuracies = []
cv_f1_scores = []

for fold, (train_idx, val_idx) in enumerate(cv.split(X_combined, y_FC)):
    print(f"\nFold {fold + 1}")

    # Split data for this fold
    X_cnn_train, X_cnn_val = X_combined[train_idx], X_combined[val_idx]
    X_mlp_train, X_mlp_val = X_pheno[train_idx], X_pheno[val_idx]
    y_train, y_val = y_FC[train_idx], y_FC[val_idx]

    # Compute class weights
    #class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
    #class_weight_dict = {i: class_weights[i] for i in range(len(class_weights))}
    #for w in [1.0, 1.5, 2.0, 3.0]:
        #class_weight_dict = {0: 1.0, 1: w}
    #class_weight_dict = {0: 1.0, 1: 1.5}
    class_weight_dict = {0: 1.0, 1: 2.0}

    # Build model with best hyperparameters
    model = build_model(best_hp)

    # Train model
    history = model.fit(
        [X_cnn_train, X_mlp_train], y_train,
        validation_data=([X_cnn_val, X_mlp_val], y_val),
        epochs=20,
        batch_size=32,
        class_weight=class_weight_dict,
        callbacks=[EarlyStopping(patience=5, restore_best_weights=True)],
        verbose=1
    )

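    # Evaluate the model trained on this fold. best_model was fit on a different split
    # and has already seen most of these validation subjects.
    # The fold scores recorded below come from a run that called best_model.predict here.
    # y_pred = (model.predict([X_cnn_val, X_mlp_val]) > 0.5).astype(int)
    y_pred_probs = model.predict([X_cnn_val, X_mlp_val])
    y_pred = (y_pred_probs > 0.4).astype(int)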

    acc = accuracy_score(y_val, y_pred)
    f1 = f1_score(y_val, y_pred, average='weighted')

    print(f"Fold {fold + 1} Accuracy: {acc:.4f} | F1 Score: {f1:.4f}")
    print(classification_report(y_val, y_pred, digits=4))

    cv_accuracies.append(acc)
    cv_f1_scores.append(f1)

# Final results
print(f"\nAverage CV Accuracy: {np.mean(cv_accuracies):.4f}")
print(f"Average CV F1 Score: {np.mean(cv_f1_scores):.4f}")

results.append({
    'Dataset': 'Multimodal CNN + MLP (CV + Attention + Tuned)',
    'Val Accuracy': np.mean(cv_accuracies),
    'Val F1': np.mean(cv_f1_scores),
    'Best Params': {key: best_hp.get(key) for key in best_hp.values}
})



Fold 1
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Fold 1 Accuracy: 0.8400 | F1 Score: 0.8390
              precision    recall  f1-score   support

           0     1.0000    0.7143    0.8333        28
           1     0.7333    1.0000    0.8462        22

    accuracy                         0.8400        50
   macro avg     0.8667    0.8571    0.8397        50
weighted avg     0.8827    0.8400    0.8390        50


Fold 2
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Fold 2 Accuracy: 0.8200 | F1 Score: 0.8195
              precision    recall  f1-score   support

           0     0.9524    0.7143    0.8163        28
           1

In [None]:
results_df = pd.DataFrame(results)
results_df
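
The evaluation cells in this notebook use fixed decision thresholds of 0.4 and 0.7. One way to see what such a cut-off trades away is to sweep it over validation probabilities. The sketch below uses made-up probabilities and labels purely for illustration, and `metrics_at` is a hypothetical helper, not part of the notebook.

```python
def metrics_at(probs, labels, threshold):
    """Accuracy and positive-class recall when p > threshold is labelled 1."""
    preds = [int(p > threshold) for p in probs]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    true_pos = sum(int(p == 1 and y == 1) for p, y in zip(preds, labels))
    return correct / len(labels), true_pos / sum(labels)

# Made-up validation probabilities and labels, for illustration only
probs = [0.10, 0.35, 0.45, 0.55, 0.65, 0.75, 0.80, 0.95]
labels = [0, 0, 0, 1, 0, 1, 1, 1]

sweep = {t: metrics_at(probs, labels, t) for t in (0.4, 0.5, 0.7)}
for t, (acc, rec) in sweep.items():
    print(f"threshold={t:.1f}  accuracy={acc:.3f}  recall(1)={rec:.3f}")
```

A lower threshold favours recall on the ADHD class and a higher one favours precision. If a threshold is tuned this way, it should be chosen on validation folds only, so that the hold-out score stays an unbiased estimate.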


In [56]:
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.utils.class_weight import compute_class_weight
import numpy as np


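# Step 1: Split off a 20% test set.
# Note: this call uses the same arguments (test_size=0.2, stratify, random_state=42) as the
# split used for hyperparameter tuning above, so it should yield the same 99 subjects that
# served as validation data during the search.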
X_train_img, X_test_img, X_train_pheno, X_test_pheno, y_train_all, y_test = train_test_split(
    X_combined, X_pheno, y_FC, test_size=0.2, stratify=y_FC, random_state=42
)

# Step 2: Perform 10-fold CV on training set
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
cv_accuracies, cv_f1_scores = [], []

for fold, (train_idx, val_idx) in enumerate(cv.split(X_train_img, y_train_all)):
    print(f"\nFold {fold + 1}")

    X_cnn_train, X_cnn_val = X_train_img[train_idx], X_train_img[val_idx]
    X_mlp_train, X_mlp_val = X_train_pheno[train_idx], X_train_pheno[val_idx]
    y_train, y_val = y_train_all[train_idx], y_train_all[val_idx]

    # Class weights
    class_weight_dict = {0: 1, 1: 2}
    #class_weight_dict = {0: 1.0, 1: 2.0}

    # Build and train model
    model = build_model(best_hp)
    model.fit(
        [X_cnn_train, X_mlp_train], y_train,
        validation_data=([X_cnn_val, X_mlp_val], y_val),
        epochs=20,
        batch_size=32,
        class_weight=class_weight_dict,
        callbacks=[EarlyStopping(patience=5, restore_best_weights=True)],
        verbose=1
    )

    # Evaluate on validation fold
    y_pred_probs = model.predict([X_cnn_val, X_mlp_val])
    y_pred = (y_pred_probs > 0.7).astype(int)

    acc = accuracy_score(y_val, y_pred)
    f1 = f1_score(y_val, y_pred, average='weighted')

    print(f"Fold {fold + 1} Accuracy: {acc:.4f} | F1 Score: {f1:.4f}")
    print(classification_report(y_val, y_pred, digits=4))

    cv_accuracies.append(acc)
    cv_f1_scores.append(f1)

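# Step 3: Evaluate on the hold-out test set.
# Note: `model` here is the network trained in the last CV fold; no model is refit on the full training set.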
y_test_probs = model.predict([X_test_img, X_test_pheno])
y_test_pred = (y_test_probs > 0.7).astype(int)

print("\nHold-Out Test Set Performance:")
print(classification_report(y_test, y_test_pred, digits=4))

# Step 4: Final CV results summary
print(f"\nAverage CV Accuracy: {np.mean(cv_accuracies):.4f}")
print(f"Average CV F1 Score: {np.mean(cv_f1_scores):.4f}")

# Step 5: Save to results list
results.append({
    'Dataset': 'Multimodal CNN + MLP (10-fold CV on 80% training split)',
    'Val Accuracy': np.mean(cv_accuracies),
    'Val F1': np.mean(cv_f1_scores),
    'Best Params': {key: best_hp.get(key) for key in best_hp.values}
})



Fold 1
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20


<keras.callbacks.History at 0x1bd22ea7c10>

Fold 1 Accuracy: 0.8250 | F1 Score: 0.8159
              precision    recall  f1-score   support

           0     0.7586    1.0000    0.8627        22
           1     1.0000    0.6111    0.7586        18

    accuracy                         0.8250        40
   macro avg     0.8793    0.8056    0.8107        40
weighted avg     0.8672    0.8250    0.8159        40


Fold 2
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20


<keras.callbacks.History at 0x1bd2c65e470>

Fold 2 Accuracy: 0.7500 | F1 Score: 0.7335
              precision    recall  f1-score   support

           0     0.7000    0.9545    0.8077        22
           1     0.9000    0.5000    0.6429        18

    accuracy                         0.7500        40
   macro avg     0.8000    0.7273    0.7253        40
weighted avg     0.7900    0.7500    0.7335        40


Fold 3
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20


<keras.callbacks.History at 0x1bd30b8ff10>

Fold 3 Accuracy: 0.8250 | F1 Score: 0.8198
              precision    recall  f1-score   support

           0     0.7778    0.9545    0.8571        22
           1     0.9231    0.6667    0.7742        18

    accuracy                         0.8250        40
   macro avg     0.8504    0.8106    0.8157        40
weighted avg     0.8432    0.8250    0.8198        40


Fold 4
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x1bd29d3aa10>

Fold 4 Accuracy: 0.8000 | F1 Score: 0.7868
              precision    recall  f1-score   support

           0     0.7333    1.0000    0.8462        22
           1     1.0000    0.5556    0.7143        18

    accuracy                         0.8000        40
   macro avg     0.8667    0.7778    0.7802        40
weighted avg     0.8533    0.8000    0.7868        40


Fold 5
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20


<keras.callbacks.History at 0x1bd22987eb0>

Fold 5 Accuracy: 0.7692 | F1 Score: 0.7552
              precision    recall  f1-score   support

           0     0.7241    0.9545    0.8235        22
           1     0.9000    0.5294    0.6667        17

    accuracy                         0.7692        39
   macro avg     0.8121    0.7420    0.7451        39
weighted avg     0.8008    0.7692    0.7552        39


Fold 6
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20


<keras.callbacks.History at 0x1bd2d4b96c0>

Fold 6 Accuracy: 0.7179 | F1 Score: 0.6911
              precision    recall  f1-score   support

           0     0.6774    0.9545    0.7925        22
           1     0.8750    0.4118    0.5600        17

    accuracy                         0.7179        39
   macro avg     0.7762    0.6832    0.6762        39
weighted avg     0.7635    0.7179    0.6911        39


Fold 7
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x1bd2d4bbf10>

Fold 7 Accuracy: 0.8718 | F1 Score: 0.8673
              precision    recall  f1-score   support

           0     0.8148    1.0000    0.8980        22
           1     1.0000    0.7059    0.8276        17

    accuracy                         0.8718        39
   macro avg     0.9074    0.8529    0.8628        39
weighted avg     0.8955    0.8718    0.8673        39


Fold 8
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20


<keras.callbacks.History at 0x1bd2ec0fb20>

Fold 8 Accuracy: 0.7692 | F1 Score: 0.7552
              precision    recall  f1-score   support

           0     0.7241    0.9545    0.8235        22
           1     0.9000    0.5294    0.6667        17

    accuracy                         0.7692        39
   macro avg     0.8121    0.7420    0.7451        39
weighted avg     0.8008    0.7692    0.7552        39


Fold 9
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20


<keras.callbacks.History at 0x1bd2f3fa080>

Fold 9 Accuracy: 0.7179 | F1 Score: 0.6787
              precision    recall  f1-score   support

           0     0.6667    1.0000    0.8000        22
           1     1.0000    0.3529    0.5217        17

    accuracy                         0.7179        39
   macro avg     0.8333    0.6765    0.6609        39
weighted avg     0.8120    0.7179    0.6787        39


Fold 10
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20


<keras.callbacks.History at 0x1bd2cf834c0>

Fold 10 Accuracy: 0.7692 | F1 Score: 0.7552
              precision    recall  f1-score   support

           0     0.7241    0.9545    0.8235        22
           1     0.9000    0.5294    0.6667        17

    accuracy                         0.7692        39
   macro avg     0.8121    0.7420    0.7451        39
weighted avg     0.8008    0.7692    0.7552        39


Hold-Out Test Set Performance:
              precision    recall  f1-score   support

           0     0.7222    0.9455    0.8189        55
           1     0.8889    0.5455    0.6761        44

    accuracy                         0.7677        99
   macro avg     0.8056    0.7455    0.7475        99
weighted avg     0.7963    0.7677    0.7554        99


Average CV Accuracy: 0.7815
Average CV F1 Score: 0.7659
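
The averages above hide how much the folds differ (accuracy ranges from 0.7179 to 0.8718). A short sketch that reports the spread alongside the mean, using the per-fold values printed above and only the standard library:

```python
from statistics import mean, stdev

# Per-fold scores copied from the output above
fold_acc = [0.8250, 0.7500, 0.8250, 0.8000, 0.7692,
            0.7179, 0.8718, 0.7692, 0.7179, 0.7692]
fold_f1 = [0.8159, 0.7335, 0.8198, 0.7868, 0.7552,
           0.6911, 0.8673, 0.7552, 0.6787, 0.7552]

for name, values in [("accuracy", fold_acc), ("weighted F1", fold_f1)]:
    # stdev is the sample standard deviation across the 10 folds
    print(f"{name}: {mean(values):.4f} +/- {stdev(values):.4f} (n={len(values)})")
```

With roughly 40 subjects per fold, a single subject changes fold accuracy by about 2.5 percentage points, so the standard deviation is worth reporting next to the mean.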


In [31]:
import numpy as np

print("Training set class distribution:")
unique_train, counts_train = np.unique(y_train_all, return_counts=True)
print(dict(zip(unique_train, counts_train)))

print("\nTest set class distribution:")
unique_test, counts_test = np.unique(y_test, return_counts=True)
print(dict(zip(unique_test, counts_test)))


Training set class distribution:
{0: 220, 1: 174}

Test set class distribution:
{0: 55, 1: 44}
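
These counts also show what the commented-out `compute_class_weight(class_weight='balanced', ...)` call would have produced, which helps put the hand-picked weights `{0: 1.0, 1: 1.5}` and `{0: 1.0, 1: 2.0}` in context. The sketch applies the "balanced" formula, `n_samples / (n_classes * class_count)`, to the training distribution printed above.

```python
counts = {0: 220, 1: 174}                 # training-set distribution printed above
n_samples = sum(counts.values())
n_classes = len(counts)

# 'balanced' heuristic: n_samples / (n_classes * count_for_class)
balanced = {c: n_samples / (n_classes * k) for c, k in counts.items()}

# Relative weight of the ADHD class (1) versus controls (0)
ratio = balanced[1] / balanced[0]

print({c: round(w, 4) for c, w in balanced.items()}, round(ratio, 4))
```

The balanced ratio is about 1.26, so the manual weights of 1.5 and 2.0 push the model toward the ADHD class more strongly than the class imbalance alone would.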
