In [1]:
import os
import random
import gc, numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.utils import compute_class_weight
import tensorflow as tf
from keras.models import Model
from keras import backend as K
from keras.layers import Input, Dense, Dropout,Flatten, BatchNormalization, Conv2D, MultiHeadAttention, concatenate
from sklearn.metrics import classification_report
from tensorflow.keras.optimizers import Adam
from keras.models import Sequential
from tensorflow.keras.utils import to_categorical
import seaborn as sns
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import precision_recall_curve

In [38]:
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth=True
sess = tf.compat.v1.Session(config=config)

In [2]:
def make_img(t_img):
    img = pd.read_pickle(t_img)
    img_l = []
    for i in range(len(img)):
        img_l.append(img.values[i][0])
    
    return np.array(img_l)


def reset_random_seeds(seed):
    os.environ['PYTHONHASHSEED']=str(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
   
               
def create_model_snp():
    
    model = Sequential()
    model.add(Dense(200,  activation = "relu")) 
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(100, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.3))
    
    model.add(Dense(50, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))
    return model

def create_model_clinical():
    
    model = Sequential()
    model.add(Dense(128,  activation = "relu")) 
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(128, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    
    model.add(Dense(50, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))    
    return model

def create_model_img():
    
    
    
    model = Sequential()
    model.add(Conv2D(72, (3, 3), activation='relu')) 
    model.add(Dropout(0.3))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(Dropout(0.3))
    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(Dropout(0.3))
    model.add(Flatten())
    model.add(Dense(50, activation='relu'))   
    return model


def plot_classification_report(y_tru, y_prd, mode, learning_rate, batch_size,epochs, figsize=(7, 7), ax=None):

    plt.figure(figsize=figsize)

    xticks = ['precision', 'recall', 'f1-score', 'support']
    yticks = ["Control", "Moderate", "Alzheimer's" ] 
    yticks += ['avg']

    rep = np.array(precision_recall_fscore_support(y_tru, y_prd)).T
    avg = np.mean(rep, axis=0)
    avg[-1] = np.sum(rep[:, -1])
    rep = np.insert(rep, rep.shape[0], avg, axis=0)

    sns.heatmap(rep,
                annot=True, 
                cbar=False, 
                xticklabels=xticks, 
                yticklabels=yticks,
                ax=ax, cmap = "Blues")
    
    plt.savefig('report_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'_' + str(epochs)+'.png')
    


def calc_confusion_matrix(result, test_label,mode, learning_rate, batch_size, epochs):
    test_label = to_categorical(test_label,3)

    true_label= np.argmax(test_label, axis =1)

    predicted_label= np.argmax(result, axis =1)
    
    n_classes = 3
    precision = dict()
    recall = dict()
    thres = dict()
    for i in range(n_classes):
        precision[i], recall[i], thres[i] = precision_recall_curve(test_label[:, i],
                                                            result[:, i])


    print ("Classification Report :") 
    print (classification_report(true_label, predicted_label))
    cr = classification_report(true_label, predicted_label, output_dict=True)
    return cr, precision, recall, thres



def cross_modal_attention(x, y):
    x = tf.expand_dims(x, axis=1)
    y = tf.expand_dims(y, axis=1)
    a1 = MultiHeadAttention(num_heads = 4,key_dim=50)(x, y)
    a2 = MultiHeadAttention(num_heads = 4,key_dim=50)(y, x)
    a1 = a1[:,0,:]
    a2 = a2[:,0,:]
    return concatenate([a1, a2])


def self_attention(x):
    x = tf.expand_dims(x, axis=1)
    attention = MultiHeadAttention(num_heads = 4, key_dim=50)(x, x)
    attention = attention[:,0,:]
    return attention
    

def multi_modal_model(mode, train_clinical, train_snp, train_img):
    
    in_clinical = Input(shape=(train_clinical.shape[1]))
    
    in_snp = Input(shape=(train_snp.shape[1]))
    
    in_img = Input(shape=(train_img.shape[1], train_img.shape[2], train_img.shape[3]))
    
    dense_clinical = create_model_clinical()(in_clinical)
    dense_snp = create_model_snp()(in_snp) 
    dense_img = create_model_img()(in_img) 
    
 
        
    ########### Attention Layer ############
        
    ## Cross Modal Bi-directional Attention ##

    if mode == 'MM_BA':
            
        vt_att = cross_modal_attention(dense_img, dense_clinical)
        av_att = cross_modal_attention(dense_snp, dense_img)
        ta_att = cross_modal_attention(dense_clinical, dense_snp)
                
        merged = concatenate([vt_att, av_att, ta_att, dense_img, dense_snp, dense_clinical])
                 
   
        
        
    ## Self Attention ##
    elif mode == 'MM_SA':
            
        vv_att = self_attention(dense_img)
        tt_att = self_attention(dense_clinical)
        aa_att = self_attention(dense_snp)
            
        merged = concatenate([aa_att, vv_att, tt_att, dense_img, dense_snp, dense_clinical])
        
    ## Self Attention and Cross Modal Bi-directional Attention##
    elif mode == 'MM_SA_BA':
            
        vv_att = self_attention(dense_img)
        tt_att = self_attention(dense_clinical)
        aa_att = self_attention(dense_snp)
        
        vt_att = cross_modal_attention(vv_att, tt_att)
        av_att = cross_modal_attention(aa_att, vv_att)
        ta_att = cross_modal_attention(tt_att, aa_att)
            
        merged = concatenate([vt_att, av_att, ta_att, dense_img, dense_snp, dense_clinical])
            
        
    ## No Attention ##    
    elif mode == 'None':
            
        merged = concatenate([dense_img, dense_snp, dense_clinical])
                
    else:
        print ("Mode must be one of 'MM_SA', 'MM_BA', 'MU_SA_BA' or 'None'.")
        return
                
        
    ########### Output Layer ############
        
    output = Dense(3, activation='softmax')(merged)
    model = Model([in_clinical, in_snp, in_img], output)        
        
    return model



def train(mode, batch_size, epochs, learning_rate, seed):
    
    # train_img = train_img.astype("float32")

    reset_random_seeds(seed)
    class_weights = compute_class_weight(class_weight = 'balanced',classes = np.unique(train_label),y = train_label)
    d_class_weights = dict(enumerate(class_weights))
    
    # compile model #
    model = multi_modal_model(mode, train_clinical, train_snp, train_img)
    model.compile(optimizer=Adam(learning_rate = learning_rate), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
    

    # summarize results
    history = model.fit([train_clinical,
                         train_snp,
                         train_img],
                        train_label,
                        epochs=epochs,
                        batch_size=batch_size,
                        class_weight=d_class_weights,
                        validation_split=0.1,
                        verbose=1)
                        
                

    score = model.evaluate([test_clinical, test_snp, test_img], test_label)
    
    acc = score[1] 
    test_predictions = model.predict([test_clinical, test_snp, test_img])
    cr, precision_d, recall_d, thres = calc_confusion_matrix(test_predictions, test_label, mode, learning_rate, batch_size, epochs)
    
    
    """
    plt.clf()
    plt.plot(history.history['sparse_categorical_accuracy'])
    plt.plot(history.history['val_sparse_categorical_accuracy'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
    plt.savefig('accuracy_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'.png')
    plt.clf()
    # summarize history for loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
    plt.savefig('loss_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'.png')
    plt.clf()
    """
    
 
    
    # release gpu memory #
    K.clear_session()
    del model, history
    gc.collect()
        
        
    print ('Mode: ', mode)
    print ('Batch size:  ', batch_size)
    print ('Learning rate: ', learning_rate)
    print ('Epochs:  ', epochs)
    print ('Test Accuracy:', '{0:.4f}'.format(acc))
    print ('-'*55)
    
    return acc, batch_size, learning_rate, epochs, seed
    

In [40]:
train_clinical = pd.read_csv("../preprocess_overlap/X_train_clinical.csv").values
test_clinical= pd.read_csv("../preprocess_overlap/X_test_clinical.csv").values


In [41]:
train_clinical

array([[True, False, False, ..., False, False, False],
       [True, False, False, ..., True, False, False],
       [True, False, False, ..., True, False, False],
       ...,
       [True, False, False, ..., False, False, False],
       [True, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]], dtype=object)

In [6]:

train_clinical = pd.read_csv("../preprocess_overlap/X_train_clinical.csv").values
test_clinical= pd.read_csv("../preprocess_overlap/X_test_clinical.csv").values


train_snp = pd.read_csv("../preprocess_overlap/X_train_snp.csv").values
test_snp = pd.read_csv("../preprocess_overlap/X_test_snp.csv").values


train_img= make_img("../preprocess_overlap/X_train_img.pkl")
test_img= make_img("../preprocess_overlap/X_test_img.pkl")


train_label= pd.read_csv("../preprocess_overlap/y_train.csv").values.astype("int").flatten()
test_label= pd.read_csv("../preprocess_overlap/y_test.csv").values.astype("int").flatten()

train_clinical = train_clinical.astype("float32")
test_clinical = test_clinical.astype("float32")
# train_snp = train_snp.astype("float32")
# train_snp = test_snp.astype("float32")

In [7]:
def train(mode, batch_size, epochs, learning_rate, seed):
    reset_random_seeds(seed)
    class_weights = compute_class_weight(class_weight = 'balanced',classes = np.unique(train_label),y = train_label)
    d_class_weights = dict(enumerate(class_weights))
    
    # compile model #
    model = multi_modal_model(mode, train_clinical, train_snp, train_img)
    model.compile(optimizer=Adam(learning_rate = learning_rate), 
                 loss='sparse_categorical_crossentropy', 
                 metrics=['sparse_categorical_accuracy'])

    # Model Checkpoint
    checkpoint_dir = 'best_model'
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
        
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(checkpoint_dir, 'best_model.h5'),
        monitor='val_sparse_categorical_accuracy',
        save_best_only=True,
        save_weights_only=False,
        mode='max',
        verbose=1
    )

    # summarize results
    history = model.fit([train_clinical,
                        train_snp,
                        train_img],
                        train_label,
                        epochs=epochs,
                        batch_size=batch_size,
                        class_weight=d_class_weights,
                        validation_split=0.1,
                        callbacks=[checkpoint_callback],
                        verbose=1)
                        
    score = model.evaluate([test_clinical, test_snp, test_img], test_label)
    acc = score[1] 
    test_predictions = model.predict([test_clinical, test_snp, test_img])
    cr, precision_d, recall_d, thres = calc_confusion_matrix(test_predictions, test_label, mode, learning_rate, batch_size, epochs)
    
    # release gpu memory #
    K.clear_session()
    del model, history
    gc.collect()
        
    print ('Mode: ', mode)
    print ('Batch size:  ', batch_size)
    print ('Learning rate: ', learning_rate)
    print ('Epochs:  ', epochs)
    print ('Test Accuracy:', '{0:.4f}'.format(acc))
    print ('-'*55)
    
    return acc, batch_size, learning_rate, epochs, seed

# Save best model
def find_best_model():
    m_a = {}
    seeds = random.sample(range(1, 200), 1)
    best_acc = 0
    best_model = None
    
    for s in seeds:
        acc, bs_, lr_, e_, seed = train('MM_SA_BA', 8, 10, 0.001, s)
        m_a[acc] = ('MM_SA_BA', acc, bs_, lr_, e_, seed)
        
        if acc > best_acc:
            best_acc = acc
            # Save the better model
            if os.path.exists('best_model/best_model.h5'):
                best_model = tf.keras.models.load_model('best_model/best_model.h5')
                best_model.save(f'best_model/final_best_model_acc_{acc:.4f}.h5')
    
    print(m_a)
    print('-'*55)
    
    max_acc = max(m_a, key=float)
    print("Highest accuracy of: " + str(max_acc) + " with parameters: " + str(m_a[max_acc]))
    
    return max_acc, m_a[max_acc]


# Train and find the best model
best_acc, best_params = find_best_model()

# Load the best model
best_model = tf.keras.models.load_model(f'best_model/final_best_model_acc_{best_acc:.4f}.h5')




Epoch 1/10
Epoch 1: val_sparse_categorical_accuracy improved from -inf to 0.42857, saving model to best_model/best_model.h5


  saving_api.save_model(


Epoch 2/10
Epoch 2: val_sparse_categorical_accuracy improved from 0.42857 to 0.50000, saving model to best_model/best_model.h5
Epoch 3/10
Epoch 3: val_sparse_categorical_accuracy did not improve from 0.50000
Epoch 4/10
Epoch 4: val_sparse_categorical_accuracy did not improve from 0.50000
Epoch 5/10
Epoch 5: val_sparse_categorical_accuracy did not improve from 0.50000
Epoch 6/10
Epoch 6: val_sparse_categorical_accuracy improved from 0.50000 to 0.57143, saving model to best_model/best_model.h5
Epoch 7/10
Epoch 7: val_sparse_categorical_accuracy improved from 0.57143 to 0.78571, saving model to best_model/best_model.h5
Epoch 8/10
Epoch 8: val_sparse_categorical_accuracy did not improve from 0.78571
Epoch 9/10
Epoch 9: val_sparse_categorical_accuracy did not improve from 0.78571
Epoch 10/10
Epoch 10: val_sparse_categorical_accuracy did not improve from 0.78571
Classification Report :
              precision    recall  f1-score   support

           0       1.00      0.79      0.88        2

  saving_api.save_model(


{0.8571428656578064: ('MM_SA_BA', 0.8571428656578064, 8, 0.001, 10, 36)}
-------------------------------------------------------
Highest accuracy of: 0.8571428656578064 with parameters: ('MM_SA_BA', 0.8571428656578064, 8, 0.001, 10, 36)


In [61]:
m_a = {}
seeds = random.sample(range(1, 200), 1)
for s in seeds:
    acc, bs_, lr_, e_ , seed= train('MM_SA_BA', 8, 10, 0.001, s)
    m_a[acc] = ('MM_SA_BA', acc, bs_, lr_, e_, seed)
print(m_a)
print ('-'*55)
max_acc = max(m_a, key=float)
print("Highest accuracy of: " + str(max_acc) + " with parameters: " + str(m_a[max_acc]))





Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Classification Report :
              precision    recall  f1-score   support

           0       1.00      0.79      0.88        24
           1       0.43      0.75      0.55         4
           2       0.78      1.00      0.88         7

    accuracy                           0.83        35
   macro avg       0.74      0.85      0.77        35
weighted avg       0.89      0.83      0.84        35

Mode:  MM_SA_BA
Batch size:   8
Learning rate:  0.001
Epochs:   10
Test Accuracy: 0.8286
-------------------------------------------------------
{0.8285714387893677: ('MM_SA_BA', 0.8285714387893677, 8, 0.001, 10, 146)}
-------------------------------------------------------
Highest accuracy of: 0.8285714387893677 with parameters: ('MM_SA_BA', 0.8285714387893677, 8, 0.001, 10, 146)


In [4]:
import tensorflow as tf
from tensorflow.keras.models import load_model
import os
import json

def train(mode, batch_size, epochs, learning_rate, seed):
    reset_random_seeds(seed)
    
    # Calculate weights
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.unique(train_label),
        y=train_label
    )
    d_class_weights = dict(enumerate(class_weights))
    
    # Create model
    model = multi_modal_model(mode, train_clinical, train_snp, train_img)
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy']
    )
    
    # Callback
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
            mode='min'
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            mode='min'
        )
    ]
    
    # Train Model
    history = model.fit(
        [train_clinical, train_snp, train_img],
        train_label,
        epochs=epochs,
        batch_size=batch_size,
        class_weight=d_class_weights,
        validation_split=0.2,
        callbacks=callbacks,
        verbose=1
    )
    
    # Evaluation
    score = model.evaluate([test_clinical, test_snp, test_img], test_label)
    acc = score[1]
    
    # Save model and para
    model_params = {
        'mode': mode,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'epochs': epochs,
        'seed': seed,
        'accuracy': acc
    }
    
    # Create directory
    save_dir = 'saved_models'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    # Save
    model_name = f'model_acc_{acc:.4f}_seed_{seed}'
    model_path = os.path.join(save_dir, model_name)
    os.makedirs(model_path, exist_ok=True)
    
    # Save
    model.save(os.path.join(model_path, 'model.h5'))
    
    # Save
    with open(os.path.join(model_path, 'params.json'), 'w') as f:
        json.dump(model_params, f)
    
    # Clean Cache
    K.clear_session()
    del model, history
    gc.collect()
    
    return acc, batch_size, learning_rate, epochs, seed

# Best Model
def find_best_model():
    m_a = {}
    seeds = random.sample(range(1, 200), 1)
    
    for s in seeds:
        acc, bs_, lr_, e_, seed = train('MM_SA_BA', 8, 10, 0.001, s)
        m_a[acc] = ('MM_SA_BA', acc, bs_, lr_, e_, seed)
    
    print(m_a)
    print('-'*55)
    
    max_acc = max(m_a, key=float)
    print("Highest accuracy of: " + str(max_acc) + " with parameters: " + str(m_a[max_acc]))
    
    return max_acc, m_a[max_acc]

# Save
def load_best_model(acc, seed):
    model_name = f'model_acc_{acc:.4f}_seed_{seed}'
    model_path = os.path.join('saved_models', model_name)
    
    # Load
    model = load_model(os.path.join(model_path, 'model.h5'))
    
    # Load
    with open(os.path.join(model_path, 'params.json'), 'r') as f:
        params = json.load(f)
    
    return model, params


# Train and save the best
best_acc, best_params = find_best_model()

# Load
best_model, model_params = load_best_model(best_acc, best_params[-1])

# Predict
predictions = best_model.predict([test_clinical, test_snp, test_img])


NameError: name 'train_label' is not defined

In [29]:
train_snp = pd.read_csv("../preprocess_overlap/X_train_snp.csv").values
test_snp = pd.read_csv("../preprocess_overlap/X_test_snp.csv").values
train_snp

array([[0, 1, 0, ..., 0, 1, 1],
       [0, 2, 0, ..., 0, 2, 2],
       [0, 2, 0, ..., 0, 2, 2],
       ...,
       [0, 2, 0, ..., 0, 2, 2],
       [0, 2, 0, ..., 0, 2, 2],
       [0, 2, 0, ..., 0, 2, 2]])

In [30]:


print("train_clinical shape:", train_clinical.shape)
print("train_snp shape:", train_snp.shape)
print("train_img shape:", train_img.shape)
print("train_label shape:", train_label.shape)

train_clinical shape: (71, 149)
train_snp shape: (71, 179666)
train_img shape: (71, 72, 72, 3)
train_label shape: (71,)


In [36]:
import numpy as np
print("Train Dataset:", np.bincount(train_label))
print("Valid Dataset:", np.bincount(test_label))

Train Dataset: [41 30]
Valid Dataset: [5 3]


In [62]:
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Input, concatenate
from tensorflow.keras.layers import Conv2D, Flatten, MultiHeadAttention
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, precision_recall_fscore_support, precision_recall_curve
from tensorflow.keras import backend as K
import numpy as np
import gc
import os
import random

def reset_random_seeds(seed):
    os.environ['PYTHONHASHSEED']=str(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def create_model_snp():
    model = Sequential([
        Dense(100, activation="relu"),
        BatchNormalization(),
        Dropout(0.5),
        Dense(50, activation="relu"),
        BatchNormalization(),
        Dropout(0.3)
    ])
    return model

def create_model_clinical():
    model = Sequential([
        Dense(128, activation="relu"),
        BatchNormalization(),
        Dropout(0.2),
        Dense(128, activation="relu"),
        BatchNormalization(),
        Dropout(0.2),
        Dense(50, activation="relu"),
        BatchNormalization(),
        Dropout(0.2)
    ])
    return model

def create_model_img():
    model = Sequential([
        Conv2D(64, (3, 3), activation='relu'),
        BatchNormalization(),
        Dropout(0.4),
        Conv2D(32, (3, 3), activation='relu'),
        BatchNormalization(),
        Dropout(0.4),
        Flatten(),
        Dense(50, activation='relu'),
        BatchNormalization(),
        Dropout(0.3)
    ])
    return model

def simplified_attention(x, y):
    """简化的注意力机制"""
    x = tf.expand_dims(x, axis=1)
    y = tf.expand_dims(y, axis=1)
    attention = MultiHeadAttention(num_heads=2, key_dim=25)(x, y)
    return attention[:,0,:]

def multi_modal_model(mode, train_clinical, train_snp, train_img):
    in_clinical = Input(shape=(train_clinical.shape[1]))
    in_snp = Input(shape=(train_snp.shape[1]))
    in_img = Input(shape=(train_img.shape[1], train_img.shape[2], train_img.shape[3]))
    
    dense_clinical = create_model_clinical()(in_clinical)
    dense_snp = create_model_snp()(in_snp)
    dense_img = create_model_img()(in_img)
    
    if mode == 'MM_SA_BA':
        # Only save this layer
        img_clinical_att = simplified_attention(dense_img, dense_clinical)
        snp_clinical_att = simplified_attention(dense_snp, dense_clinical)
        merged = concatenate([img_clinical_att, snp_clinical_att, dense_img, dense_snp, dense_clinical])
    else:
        merged = concatenate([dense_img, dense_snp, dense_clinical])
    
    # Add additional layer after concatenation 
    merged = Dense(100, activation='relu')(merged)
    merged = BatchNormalization()(merged)
    merged = Dropout(0.5)(merged)
    merged = Dense(50, activation='relu')(merged)
    merged = BatchNormalization()(merged)
    merged = Dropout(0.3)(merged)
    
    # Output
    output = Dense(3, activation='softmax')(merged)
    model = Model([in_clinical, in_snp, in_img], output)
    
    return model

def train(mode, batch_size, epochs, learning_rate, seed):
    reset_random_seeds(seed)
    
    # 
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.unique(train_label),
        y=train_label
    )
    d_class_weights = dict(enumerate(class_weights))
    
    # 
    model = multi_modal_model(mode, train_clinical, train_snp, train_img)
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy']
    )
    
    # Callback
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
            mode='min'
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            mode='min'
        )
    ]
    
    # Train
    history = model.fit(
        [train_clinical, train_snp, train_img],
        train_label,
        epochs=epochs,
        batch_size=batch_size,
        class_weight=d_class_weights,
        validation_split=0.2,  # Valid
        callbacks=callbacks,
        verbose=1
    )
    
    # Evaluation
    score = model.evaluate([test_clinical, test_snp, test_img], test_label)
    acc = score[1]
    test_predictions = model.predict([test_clinical, test_snp, test_img])
    
    # 
    pred_labels = np.argmax(test_predictions, axis=1)
    print("\nClassification Report:")
    print(classification_report(test_label, pred_labels))
    
    # 
    K.clear_session()
    del model, history
    gc.collect()
    
    print(f'Mode: {mode}')
    print(f'Batch size: {batch_size}')
    print(f'Learning rate: {learning_rate}')
    print(f'Epochs: {epochs}')
    print(f'Test Accuracy: {acc:.4f}')
    print('-'*55)
    
    return acc, batch_size, learning_rate, epochs, seed

# 
"""
results = train(
    mode='MM_SA_BA',
    batch_size=16,  # maximize batch size
    epochs=100,
    learning_rate=0.001,
    seed=42
)
"""

"\nresults = train(\n    mode='MM_SA_BA',\n    batch_size=16,  # 增大batch size\n    epochs=100,\n    learning_rate=0.001,\n    seed=42\n)\n"

In [63]:
m_a = {}
seeds = random.sample(range(1, 200), 1)
for s in seeds:
    acc, bs_, lr_, e_ , seed= train('MM_SA_BA', 8, 50, 0.001, s)
    m_a[acc] = ('MM_SA_BA', acc, bs_, lr_, e_, seed)
print(m_a)
print ('-'*55)
max_acc = max(m_a, key=float)
print("Highest accuracy of: " + str(max_acc) + " with parameters: " + str(m_a[max_acc]))





Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50

Classification Report:
              precision    recall  f1-score   support

           0       0.95      0.88      0.91        24
           1       0.40      0.50      0.44         4
           2       0.88      1.00      0.93         7

    accuracy                           0.86        35
   macro avg       0.74      0.79      0.76        35
weighted avg       0.88      0.86      0.86        35

Mode: MM_SA_BA
Batch size: 8
Learning rate: 0.001
Epochs: 50
Test Accuracy: 0.8571
-------------------------------------------------------
{0.8571428656578064: ('MM_SA_BA', 0.8571428656578064, 8, 0.001, 50, 100)}
-------------------------------------------------------
Highest accuracy of: 0.8571428656578064 with parameters: ('MM_SA_BA', 0.85714286565