In [12]:
import os
import random
import gc, numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.utils import compute_class_weight
import tensorflow as tf
from keras.models import Model
from keras import backend as K
from keras.layers import Input, Dense, Dropout,Flatten, BatchNormalization, Conv2D, MultiHeadAttention, concatenate
from sklearn.metrics import classification_report
from tensorflow.keras.optimizers import Adam
from keras.models import Sequential
from tensorflow.keras.utils import to_categorical
import seaborn as sns
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import precision_recall_curve
from tensorflow.keras.layers import (
    Conv2D, 
    BatchNormalization, 
    Activation, 
    MaxPooling2D, 
    GlobalAveragePooling2D, 
    Dense, 
    Input, 
    Add, 
    Dropout
)
from tensorflow.keras.models import Model

In [3]:
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth=True
sess = tf.compat.v1.Session(config=config)

## IMG - Resnet - acc: 0.83

In [None]:
def make_img(t_img):
    img = pd.read_pickle(t_img)
    img_l = []
    for i in range(len(img)):
        img_l.append(img.values[i][0])
    
    return np.array(img_l)


def reset_random_seeds(seed):
    os.environ['PYTHONHASHSEED']=str(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
   
               
def create_model_snp():
    
    model = Sequential()
    model.add(Dense(200,  activation = "relu")) 
    model.add(BatchNormalization())
    model.add(Dropout(0.3))
    model.add(Dense(100, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))
    
    model.add(Dense(50, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))
    return model

def create_model_clinical():
    
    model = Sequential()
    model.add(Dense(128,  activation = "relu")) 
    model.add(BatchNormalization())
    model.add(Dropout(0.2))
    model.add(Dense(128, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))
    
    model.add(Dense(50, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))    
    return model

def residual_block(x, filters, kernel_size=3, stride=1, conv_shortcut=True):
    shortcut = x
    
    if conv_shortcut:
        shortcut = Conv2D(filters, 1, strides=stride)(shortcut)
        shortcut = BatchNormalization()(shortcut)
    
    x = Conv2D(filters, kernel_size, strides=stride, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    
    x = Conv2D(filters, kernel_size, padding='same')(x)
    x = BatchNormalization()(x)
    
    x = Add()([shortcut, x])
    x = Activation('relu')(x)
    
    return x

def create_model_img():
    inputs = Input(shape=(train_img.shape[1], train_img.shape[2], train_img.shape[3]))
    
    # 初始卷积层
    x = Conv2D(64, 7, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(3, strides=2, padding='same')(x)
    
    # ResNet blocks
    x = residual_block(x, 64)
    x = Dropout(0.3)(x)
    x = residual_block(x, 64)
    
    x = residual_block(x, 128, stride=2)
    x = Dropout(0.3)(x)
    x = residual_block(x, 128)
    
    x = residual_block(x, 256, stride=2)
    x = Dropout(0.3)(x)
    x = residual_block(x, 256)
    
    # 全局平均池化
    x = GlobalAveragePooling2D()(x)
    
    # 全连接层
    x = Dense(512, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.4)(x)
    x = Dense(50, activation='relu')(x)
    
    model = Model(inputs=inputs, outputs=x)
    return model


def plot_classification_report(y_tru, y_prd, mode, learning_rate, batch_size,epochs, figsize=(7, 7), ax=None):

    plt.figure(figsize=figsize)

    xticks = ['precision', 'recall', 'f1-score', 'support']
    yticks = ["Control", "Moderate", "Alzheimer's" ] 
    yticks += ['avg']

    rep = np.array(precision_recall_fscore_support(y_tru, y_prd)).T
    avg = np.mean(rep, axis=0)
    avg[-1] = np.sum(rep[:, -1])
    rep = np.insert(rep, rep.shape[0], avg, axis=0)

    sns.heatmap(rep,
                annot=True, 
                cbar=False, 
                xticklabels=xticks, 
                yticklabels=yticks,
                ax=ax, cmap = "Blues")
    
    plt.savefig('report_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'_' + str(epochs)+'.png')
    


def calc_confusion_matrix(result, test_label,mode, learning_rate, batch_size, epochs):
    test_label = to_categorical(test_label,3)

    true_label= np.argmax(test_label, axis =1)

    predicted_label= np.argmax(result, axis =1)
    
    n_classes = 3
    precision = dict()
    recall = dict()
    thres = dict()
    for i in range(n_classes):
        precision[i], recall[i], thres[i] = precision_recall_curve(test_label[:, i],
                                                            result[:, i])


    print ("Classification Report :") 
    print (classification_report(true_label, predicted_label))
    cr = classification_report(true_label, predicted_label, output_dict=True)
    return cr, precision, recall, thres



def cross_modal_attention(x, y):
    x = tf.expand_dims(x, axis=1)
    y = tf.expand_dims(y, axis=1)
    a1 = MultiHeadAttention(num_heads = 4,key_dim=50)(x, y)
    a2 = MultiHeadAttention(num_heads = 4,key_dim=50)(y, x)
    a1 = a1[:,0,:]
    a2 = a2[:,0,:]
    return concatenate([a1, a2])


def self_attention(x):
    x = tf.expand_dims(x, axis=1)
    attention = MultiHeadAttention(num_heads = 4, key_dim=50)(x, x)
    attention = attention[:,0,:]
    return attention
    

def multi_modal_model(mode, train_clinical, train_snp, train_img):
    
    in_clinical = Input(shape=(train_clinical.shape[1]))
    
    in_snp = Input(shape=(train_snp.shape[1]))
    
    in_img = Input(shape=(train_img.shape[1], train_img.shape[2], train_img.shape[3]))
    
    dense_clinical = create_model_clinical()(in_clinical)
    dense_snp = create_model_snp()(in_snp) 
    dense_img = create_model_img()(in_img) 
    
 
        
    ########### Attention Layer ############
        
    ## Cross Modal Bi-directional Attention ##

    if mode == 'MM_BA':
            
        vt_att = cross_modal_attention(dense_img, dense_clinical)
        av_att = cross_modal_attention(dense_snp, dense_img)
        ta_att = cross_modal_attention(dense_clinical, dense_snp)
                
        merged = concatenate([vt_att, av_att, ta_att, dense_img, dense_snp, dense_clinical])
                 
   
        
        
    ## Self Attention ##
    elif mode == 'MM_SA':
            
        vv_att = self_attention(dense_img)
        tt_att = self_attention(dense_clinical)
        aa_att = self_attention(dense_snp)
            
        merged = concatenate([aa_att, vv_att, tt_att, dense_img, dense_snp, dense_clinical])
        
    ## Self Attention and Cross Modal Bi-directional Attention##
    elif mode == 'MM_SA_BA':
            
        vv_att = self_attention(dense_img)
        tt_att = self_attention(dense_clinical)
        aa_att = self_attention(dense_snp)
        
        vt_att = cross_modal_attention(vv_att, tt_att)
        av_att = cross_modal_attention(aa_att, vv_att)
        ta_att = cross_modal_attention(tt_att, aa_att)
            
        merged = concatenate([vt_att, av_att, ta_att, dense_img, dense_snp, dense_clinical])
            
        
    ## No Attention ##    
    elif mode == 'None':
            
        merged = concatenate([dense_img, dense_snp, dense_clinical])
                
    else:
        print ("Mode must be one of 'MM_SA', 'MM_BA', 'MU_SA_BA' or 'None'.")
        return
                
        
    ########### Output Layer ############
        
    output = Dense(3, activation='softmax')(merged)
    model = Model([in_clinical, in_snp, in_img], output)        
        
    return model



def train(mode, batch_size, epochs, learning_rate, seed):
    
    # train_img = train_img.astype("float32")

    reset_random_seeds(seed)
    class_weights = compute_class_weight(class_weight = 'balanced',classes = np.unique(train_label),y = train_label)
    d_class_weights = dict(enumerate(class_weights))
    
    # compile model #
    model = multi_modal_model(mode, train_clinical, train_snp, train_img)
    model.compile(optimizer=Adam(learning_rate = learning_rate), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
    

    # summarize results
    history = model.fit([train_clinical,
                         train_snp,
                         train_img],
                        train_label,
                        epochs=epochs,
                        batch_size=batch_size,
                        class_weight=d_class_weights,
                        validation_split=0.1,
                        verbose=1)
                        
                

    score = model.evaluate([test_clinical, test_snp, test_img], test_label)
    
    acc = score[1] 
    test_predictions = model.predict([test_clinical, test_snp, test_img])
    cr, precision_d, recall_d, thres = calc_confusion_matrix(test_predictions, test_label, mode, learning_rate, batch_size, epochs)
    
    
    """
    plt.clf()
    plt.plot(history.history['sparse_categorical_accuracy'])
    plt.plot(history.history['val_sparse_categorical_accuracy'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
    plt.savefig('accuracy_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'.png')
    plt.clf()
    # summarize history for loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
    plt.savefig('loss_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'.png')
    plt.clf()
    """
    
 
    
    # release gpu memory #
    K.clear_session()
    del model, history
    gc.collect()
        
        
    print ('Mode: ', mode)
    print ('Batch size:  ', batch_size)
    print ('Learning rate: ', learning_rate)
    print ('Epochs:  ', epochs)
    print ('Test Accuracy:', '{0:.4f}'.format(acc))
    print ('-'*55)
    
    return acc, batch_size, learning_rate, epochs, seed
    

In [5]:
train_clinical = pd.read_csv("../preprocess_overlap/X_train_clinical.csv").values
test_clinical= pd.read_csv("../preprocess_overlap/X_test_clinical.csv").values


In [6]:
train_clinical

array([[True, False, False, ..., False, False, False],
       [True, False, False, ..., True, False, False],
       [True, False, False, ..., True, False, False],
       ...,
       [True, False, False, ..., False, False, False],
       [True, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]], dtype=object)

In [13]:

train_clinical = pd.read_csv("../preprocess_overlap/X_train_clinical.csv").values
test_clinical= pd.read_csv("../preprocess_overlap/X_test_clinical.csv").values


train_snp = pd.read_csv("../preprocess_overlap/X_train_snp.csv").values
test_snp = pd.read_csv("../preprocess_overlap/X_test_snp.csv").values


train_img= make_img("../preprocess_overlap/X_train_img.pkl")
test_img= make_img("../preprocess_overlap/X_test_img.pkl")


train_label= pd.read_csv("../preprocess_overlap/y_train.csv").values.astype("int").flatten()
test_label= pd.read_csv("../preprocess_overlap/y_test.csv").values.astype("int").flatten()

train_clinical = train_clinical.astype("float32")
test_clinical = test_clinical.astype("float32")
# train_snp = train_snp.astype("float32")
# train_snp = test_snp.astype("float32")

In [31]:
m_a = {}
seeds = random.sample(range(1, 200), 1)
for s in seeds:
    acc, bs_, lr_, e_ , seed= train('MM_SA_BA', 8, 18, 0.001, 45)
    m_a[acc] = ('MM_SA_BA', acc, bs_, lr_, e_, seed)
print(m_a)
print ('-'*55)
max_acc = max(m_a, key=float)
print("Highest accuracy of: " + str(max_acc) + " with parameters: " + str(m_a[max_acc]))





Epoch 1/18
Epoch 2/18
Epoch 3/18
Epoch 4/18
Epoch 5/18
Epoch 6/18
Epoch 7/18
Epoch 8/18
Epoch 9/18
Epoch 10/18
Epoch 11/18
Epoch 12/18
Epoch 13/18
Epoch 14/18
Epoch 15/18
Epoch 16/18
Epoch 17/18
Epoch 18/18
Classification Report :
              precision    recall  f1-score   support

           0       0.88      0.96      0.92        24
           1       0.00      0.00      0.00         4
           2       0.78      1.00      0.88         7

    accuracy                           0.86        35
   macro avg       0.55      0.65      0.60        35
weighted avg       0.76      0.86      0.81        35



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Mode:  MM_SA_BA
Batch size:   8
Learning rate:  0.001
Epochs:   18
Test Accuracy: 0.8571
-------------------------------------------------------
{0.8571428656578064: ('MM_SA_BA', 0.8571428656578064, 8, 0.001, 18, 45)}
-------------------------------------------------------
Highest accuracy of: 0.8571428656578064 with parameters: ('MM_SA_BA', 0.8571428656578064, 8, 0.001, 18, 45)


In [29]:
train_snp = pd.read_csv("../preprocess_overlap/X_train_snp.csv").values
test_snp = pd.read_csv("../preprocess_overlap/X_test_snp.csv").values
train_snp

array([[0, 1, 0, ..., 0, 1, 1],
       [0, 2, 0, ..., 0, 2, 2],
       [0, 2, 0, ..., 0, 2, 2],
       ...,
       [0, 2, 0, ..., 0, 2, 2],
       [0, 2, 0, ..., 0, 2, 2],
       [0, 2, 0, ..., 0, 2, 2]])

In [30]:


print("train_clinical shape:", train_clinical.shape)
print("train_snp shape:", train_snp.shape)
print("train_img shape:", train_img.shape)
print("train_label shape:", train_label.shape)

train_clinical shape: (71, 149)
train_snp shape: (71, 179666)
train_img shape: (71, 72, 72, 3)
train_label shape: (71,)


In [36]:
import numpy as np
print("训练集类别分布:", np.bincount(train_label))
print("验证集类别分布:", np.bincount(test_label))

训练集类别分布: [41 30]
验证集类别分布: [5 3]


## GoogleNet 

In [25]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

def make_img(t_img):
    img = pd.read_pickle(t_img)
    img_l = []
    for i in range(len(img)):
        img_l.append(img.values[i][0])
    
    return np.array(img_l)


def reset_random_seeds(seed):
    os.environ['PYTHONHASHSEED']=str(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
   
               
def create_model_snp():
    
    model = Sequential()
    model.add(Dense(200,  activation = "relu")) 
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(100, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.3))
    
    model.add(Dense(50, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))
    return model

def create_model_clinical():
    
    model = Sequential()
    model.add(Dense(200,  activation = "relu")) 
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(100, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    
    model.add(Dense(50, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))    
    return model

def inception_module(x, filters_1x1, filters_3x3_reduce, filters_3x3, filters_5x5_reduce, filters_5x5, filters_pool):
    # 1x1 convolution branch
    conv1x1 = Conv2D(filters_1x1, (1, 1), padding='same')(x)
    conv1x1 = BatchNormalization()(conv1x1)
    conv1x1 = Activation('relu')(conv1x1)
    
    # 3x3 convolution branch
    conv3x3 = Conv2D(filters_3x3_reduce, (1, 1), padding='same')(x)
    conv3x3 = BatchNormalization()(conv3x3)
    conv3x3 = Activation('relu')(conv3x3)
    conv3x3 = Conv2D(filters_3x3, (3, 3), padding='same')(conv3x3)
    conv3x3 = BatchNormalization()(conv3x3)
    conv3x3 = Activation('relu')(conv3x3)
    
    # 5x5 convolution branch
    conv5x5 = Conv2D(filters_5x5_reduce, (1, 1), padding='same')(x)
    conv5x5 = BatchNormalization()(conv5x5)
    conv5x5 = Activation('relu')(conv5x5)
    conv5x5 = Conv2D(filters_5x5, (5, 5), padding='same')(conv5x5)
    conv5x5 = BatchNormalization()(conv5x5)
    conv5x5 = Activation('relu')(conv5x5)
    
    # Max pooling branch
    pool = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(x)
    pool = Conv2D(filters_pool, (1, 1), padding='same')(pool)
    pool = BatchNormalization()(pool)
    pool = Activation('relu')(pool)
    
    # Concatenate all branches
    output = concatenate([conv1x1, conv3x3, conv5x5, pool], axis=-1)
    return output

def create_model_img():
    inputs = Input(shape=(train_img.shape[1], train_img.shape[2], train_img.shape[3]))
    
    # Initial convolution
    x = Conv2D(64, 7, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(3, strides=2, padding='same')(x)
    
    # First Inception modules
    x = inception_module(x, 64, 96, 128, 16, 32, 32)
    x = Dropout(0.3)(x)
    x = inception_module(x, 128, 128, 192, 32, 96, 64)
    x = MaxPooling2D(3, strides=2, padding='same')(x)
    
    # Middle Inception modules
    x = inception_module(x, 192, 96, 208, 16, 48, 64)
    x = Dropout(0.3)(x)
    x = inception_module(x, 160, 112, 224, 24, 64, 64)
    x = inception_module(x, 128, 128, 256, 24, 64, 64)
    x = Dropout(0.3)(x)
    x = inception_module(x, 112, 144, 288, 32, 64, 64)
    
    # Global average pooling
    x = GlobalAveragePooling2D()(x)
    
    # Dense layers
    x = Dense(512, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.4)(x)
    x = Dense(50, activation='relu')(x)
    
    model = Model(inputs=inputs, outputs=x)
    return model

def plot_classification_report(y_tru, y_prd, mode, learning_rate, batch_size,epochs, figsize=(7, 7), ax=None):

    plt.figure(figsize=figsize)

    xticks = ['precision', 'recall', 'f1-score', 'support']
    yticks = ["Control", "Moderate", "Alzheimer's" ] 
    yticks += ['avg']

    rep = np.array(precision_recall_fscore_support(y_tru, y_prd)).T
    avg = np.mean(rep, axis=0)
    avg[-1] = np.sum(rep[:, -1])
    rep = np.insert(rep, rep.shape[0], avg, axis=0)

    sns.heatmap(rep,
                annot=True, 
                cbar=False, 
                xticklabels=xticks, 
                yticklabels=yticks,
                ax=ax, cmap = "Blues")
    
    plt.savefig('report_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'_' + str(epochs)+'.png')
    


def calc_confusion_matrix(result, test_label,mode, learning_rate, batch_size, epochs):
    test_label = to_categorical(test_label,3)

    true_label= np.argmax(test_label, axis =1)

    predicted_label= np.argmax(result, axis =1)
    
    n_classes = 3
    precision = dict()
    recall = dict()
    thres = dict()
    for i in range(n_classes):
        precision[i], recall[i], thres[i] = precision_recall_curve(test_label[:, i],
                                                            result[:, i])


    print ("Classification Report :") 
    print (classification_report(true_label, predicted_label))
    cr = classification_report(true_label, predicted_label, output_dict=True)
    return cr, precision, recall, thres



def cross_modal_attention(x, y):
    x = tf.expand_dims(x, axis=1)
    y = tf.expand_dims(y, axis=1)
    a1 = MultiHeadAttention(num_heads = 4,key_dim=50)(x, y)
    a2 = MultiHeadAttention(num_heads = 4,key_dim=50)(y, x)
    a1 = a1[:,0,:]
    a2 = a2[:,0,:]
    return concatenate([a1, a2])


def self_attention(x):
    x = tf.expand_dims(x, axis=1)
    attention = MultiHeadAttention(num_heads = 4, key_dim=50)(x, x)
    attention = attention[:,0,:]
    return attention
    

def multi_modal_model(mode, train_clinical, train_snp, train_img):
    
    in_clinical = Input(shape=(train_clinical.shape[1]))
    
    in_snp = Input(shape=(train_snp.shape[1]))
    
    in_img = Input(shape=(train_img.shape[1], train_img.shape[2], train_img.shape[3]))
    
    dense_clinical = create_model_clinical()(in_clinical)
    dense_snp = create_model_snp()(in_snp) 
    dense_img = create_model_img()(in_img) 
    
 
        
    ########### Attention Layer ############
        
    ## Cross Modal Bi-directional Attention ##

    if mode == 'MM_BA':
            
        vt_att = cross_modal_attention(dense_img, dense_clinical)
        av_att = cross_modal_attention(dense_snp, dense_img)
        ta_att = cross_modal_attention(dense_clinical, dense_snp)
                
        merged = concatenate([vt_att, av_att, ta_att, dense_img, dense_snp, dense_clinical])
                 
   
        
        
    ## Self Attention ##
    elif mode == 'MM_SA':
            
        vv_att = self_attention(dense_img)
        tt_att = self_attention(dense_clinical)
        aa_att = self_attention(dense_snp)
            
        merged = concatenate([aa_att, vv_att, tt_att, dense_img, dense_snp, dense_clinical])
        
    ## Self Attention and Cross Modal Bi-directional Attention##
    elif mode == 'MM_SA_BA':
            
        vv_att = self_attention(dense_img)
        tt_att = self_attention(dense_clinical)
        aa_att = self_attention(dense_snp)
        
        vt_att = cross_modal_attention(vv_att, tt_att)
        av_att = cross_modal_attention(aa_att, vv_att)
        ta_att = cross_modal_attention(tt_att, aa_att)
            
        merged = concatenate([vt_att, av_att, ta_att, dense_img, dense_snp, dense_clinical])
            
        
    ## No Attention ##    
    elif mode == 'None':
            
        merged = concatenate([dense_img, dense_snp, dense_clinical])
                
    else:
        print ("Mode must be one of 'MM_SA', 'MM_BA', 'MU_SA_BA' or 'None'.")
        return
                
        
    ########### Output Layer ############
        
    output = Dense(3, activation='softmax')(merged)
    model = Model([in_clinical, in_snp, in_img], output)        
        
    return model



def train(mode, batch_size, epochs, learning_rate, seed):
    
    # train_img = train_img.astype("float32")

    reset_random_seeds(seed)
    class_weights = compute_class_weight(class_weight = 'balanced',classes = np.unique(train_label),y = train_label)
    d_class_weights = dict(enumerate(class_weights))
    
    # compile model #
    model = multi_modal_model(mode, train_clinical, train_snp, train_img)
    model.compile(optimizer=Adam(learning_rate = learning_rate), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
            mode='min'
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=10,
            min_lr=1e-6,
            mode='min'
        )
    ]

    # summarize results
    history = model.fit([train_clinical,
                         train_snp,
                         train_img],
                        train_label,
                        epochs=epochs,
                        batch_size=batch_size,
                        class_weight=d_class_weights,
                        validation_split=0.1,
                        verbose=1,
                        callbacks = callbacks)
                        
                

    score = model.evaluate([test_clinical, test_snp, test_img], test_label)
    
    acc = score[1] 
    test_predictions = model.predict([test_clinical, test_snp, test_img])
    cr, precision_d, recall_d, thres = calc_confusion_matrix(test_predictions, test_label, mode, learning_rate, batch_size, epochs)
    
    
    """
    plt.clf()
    plt.plot(history.history['sparse_categorical_accuracy'])
    plt.plot(history.history['val_sparse_categorical_accuracy'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
    plt.savefig('accuracy_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'.png')
    plt.clf()
    # summarize history for loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
    plt.savefig('loss_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'.png')
    plt.clf()
    """
    
 
    
    # release gpu memory #
    K.clear_session()
    del model, history
    gc.collect()
        
        
    print ('Mode: ', mode)
    print ('Batch size:  ', batch_size)
    print ('Learning rate: ', learning_rate)
    print ('Epochs:  ', epochs)
    print ('Test Accuracy:', '{0:.4f}'.format(acc))
    print ('-'*55)
    
    return acc, batch_size, learning_rate, epochs, seed
    

In [26]:
m_a = {}
seeds = random.sample(range(1, 200), 1)
for s in seeds:
    acc, bs_, lr_, e_ , seed= train('MM_SA_BA', 8, 30, 0.001, s)
    m_a[acc] = ('MM_SA_BA', acc, bs_, lr_, e_, seed)
print(m_a)
print ('-'*55)
max_acc = max(m_a, key=float)
print("Highest accuracy of: " + str(max_acc) + " with parameters: " + str(m_a[max_acc]))





Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Classification Report :
              precision    recall  f1-score   support

           0       0.95      0.88      0.91        24
           1       0.50      0.50      0.50         4
           2       0.78      1.00      0.88         7

    accuracy                           0.86        35
   macro avg       0.74      0.79      0.76        35
weighted avg       0.87      0.86      0.86        35

Mode:  MM_SA_BA
Batch size:   8
Learning rate:  0.001
Epochs:   30
Test Accuracy: 0.8571
-------------------------------------------------------
{0.8571428656578064: ('MM_SA_BA', 0.8571428656578064, 8, 0.001, 30, 3)}
-------------------------------------------------------
Highest accuracy of: 0.85714286

## SNP - CNN

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Input, concatenate
from tensorflow.keras.layers import Conv2D, Conv1D,Conv2D, Flatten, MultiHeadAttention
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, precision_recall_fscore_support, precision_recall_curve
from tensorflow.keras import backend as K
import numpy as np
import gc
import os
import random

def reset_random_seeds(seed):
    os.environ['PYTHONHASHSEED']=str(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def create_model_snp():
    model = Sequential([
        Conv1D(64, kernel_size=5, activation="relu", input_shape=(input_shape, 1)),
        BatchNormalization(),
        MaxPooling1D(pool_size=2),
        Dropout(config["dropout"]),

        Conv1D(128, kernel_size=3, activation="relu"),
        BatchNormalization(),
        MaxPooling1D(pool_size=2),
        Dropout(config["dropout"]),

        Flatten(),
        Dense(128, activation="relu"),
        BatchNormalization(),
        Dropout(config["dropout"]),

        Dense(50, activation="softmax")
    ])
    return model

def create_model_clinical():
    model = Sequential([
        Dense(128, activation="relu"),
        BatchNormalization(),
        Dropout(0.2),
        Dense(128, activation="relu"),
        BatchNormalization(),
        Dropout(0.2),
        Dense(50, activation="relu"),
        BatchNormalization(),
        Dropout(0.2)
    ])
    return model

def create_model_img():
    model = Sequential([
        Conv2D(64, (3, 3), activation='relu'),
        BatchNormalization(),
        Dropout(0.4),
        Conv2D(32, (3, 3), activation='relu'),
        BatchNormalization(),
        Dropout(0.4),
        Flatten(),
        Dense(50, activation='relu'),
        BatchNormalization(),
        Dropout(0.3)
    ])
    return model

def simplified_attention(x, y):
    """简化的注意力机制"""
    x = tf.expand_dims(x, axis=1)
    y = tf.expand_dims(y, axis=1)
    attention = MultiHeadAttention(num_heads=2, key_dim=25)(x, y)
    return attention[:,0,:]

def multi_modal_model(mode, train_clinical, train_snp, train_img):
    # 输入层
    in_clinical = Input(shape=(train_clinical.shape[1]))
    in_snp = Input(shape=(train_snp.shape[1]))
    in_img = Input(shape=(train_img.shape[1], train_img.shape[2], train_img.shape[3]))
    
    # 特征提取
    dense_clinical = create_model_clinical()(in_clinical)
    dense_snp = create_model_snp()(in_snp)
    dense_img = create_model_img()(in_img)
    
    # 简化的注意力机制
    if mode == 'MM_SA_BA':
        # 只保留最重要的跨模态注意力
        img_clinical_att = simplified_attention(dense_img, dense_clinical)
        snp_clinical_att = simplified_attention(dense_snp, dense_clinical)
        merged = concatenate([img_clinical_att, snp_clinical_att, dense_img, dense_snp, dense_clinical])
    else:
        merged = concatenate([dense_img, dense_snp, dense_clinical])
    
    # 添加额外的整合层
    merged = Dense(100, activation='relu')(merged)
    merged = BatchNormalization()(merged)
    merged = Dropout(0.5)(merged)
    merged = Dense(50, activation='relu')(merged)
    merged = BatchNormalization()(merged)
    merged = Dropout(0.3)(merged)
    
    # 输出层
    output = Dense(3, activation='softmax')(merged)
    model = Model([in_clinical, in_snp, in_img], output)
    
    return model

def train(mode, batch_size, epochs, learning_rate, seed):
    reset_random_seeds(seed)
    
    # 计算类别权重
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.unique(train_label),
        y=train_label
    )
    d_class_weights = dict(enumerate(class_weights))
    
    # 创建和编译模型
    model = multi_modal_model(mode, train_clinical, train_snp, train_img)
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy']
    )
    
    # 回调函数
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
            mode='min'
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            mode='min'
        )
    ]
    
    # 训练模型
    history = model.fit(
        [train_clinical, train_snp, train_img],
        train_label,
        epochs=epochs,
        batch_size=batch_size,
        class_weight=d_class_weights,
        validation_split=0.2,  # 增加验证集比例
        callbacks=callbacks,
        verbose=1
    )
    
    # 评估模型
    score = model.evaluate([test_clinical, test_snp, test_img], test_label)
    acc = score[1]
    test_predictions = model.predict([test_clinical, test_snp, test_img])
    
    # 输出分类报告
    pred_labels = np.argmax(test_predictions, axis=1)
    print("\nClassification Report:")
    print(classification_report(test_label, pred_labels))
    
    # 清理内存
    K.clear_session()
    del model, history
    gc.collect()
    
    print(f'Mode: {mode}')
    print(f'Batch size: {batch_size}')
    print(f'Learning rate: {learning_rate}')
    print(f'Epochs: {epochs}')
    print(f'Test Accuracy: {acc:.4f}')
    print('-'*55)
    
    return acc, batch_size, learning_rate, epochs, seed

# 使用示例:
"""
results = train(
    mode='MM_SA_BA',
    batch_size=16,  # 增大batch size
    epochs=100,
    learning_rate=0.001,
    seed=42
)
"""

In [32]:
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import (
    Dense, 
    Dropout, 
    BatchNormalization, 
    Input, 
    concatenate,
    Conv1D,
    MaxPooling1D,
    GlobalAveragePooling1D,
    MultiHeadAttention,
    LayerNormalization,
    Reshape,
    Add,
    Activation
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K

import numpy as np
import pandas as pd
import random
import os
import gc

In [35]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

def make_img(t_img):
    img = pd.read_pickle(t_img)
    img_l = []
    for i in range(len(img)):
        img_l.append(img.values[i][0])
    
    return np.array(img_l)


def reset_random_seeds(seed):
    os.environ['PYTHONHASHSEED']=str(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
   
               
def create_model_snp():
    
    model = Sequential()
    model.add(Dense(200,  activation = "relu")) 
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(100, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.3))
    
    model.add(Dense(50, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))
    return model

def create_model_clinical():
    
    model = Sequential()
    model.add(Dense(200,  activation = "relu")) 
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(100, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    
    model.add(Dense(50, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))    
    return model

# DeepSNP版本
def create_model_snp_deep():
    model = Sequential([
        # 第一层卷积块
        Conv1D(64, 3, activation='relu', input_shape=(train_snp.shape[1], 1)),
        BatchNormalization(),
        MaxPooling1D(2),
        Dropout(0.2),
        
        # 第二层卷积块
        Conv1D(128, 3, activation='relu'),
        BatchNormalization(),
        MaxPooling1D(2),
        Dropout(0.2),
        
        # 第三层卷积块
        Conv1D(256, 3, activation='relu'),
        BatchNormalization(),
        GlobalAveragePooling1D(),
        
        # 全连接层
        Dense(128, activation='relu'),
        BatchNormalization(),
        Dropout(0.3),
        Dense(50, activation='relu'),
        BatchNormalization(),
        Dropout(0.2)
    ])
    return model

# Transformer版本
def positional_encoding(length, depth):
    positions = np.arange(length)[:, np.newaxis]
    depths = np.arange(depth)[np.newaxis, :]/depth
    angle_rates = 1 / (10000**depths)
    angle_rads = positions * angle_rates
    
    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)
    
    return tf.cast(pos_encoding, dtype=tf.float32)

def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    # Multi-head self attention
    x = LayerNormalization(epsilon=1e-6)(inputs)
    x = MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(x, x)
    x = Dropout(dropout)(x)
    res = x + inputs

    # Feed Forward
    x = LayerNormalization(epsilon=1e-6)(res)
    x = Dense(ff_dim, activation="relu")(x)
    x = Dropout(dropout)(x)
    x = Dense(inputs.shape[-1])(x)
    return x + res

def create_model_snp_transformer():
    input_shape = train_snp.shape[1]
    
    inputs = Input(shape=(input_shape,))
    
    # Reshape and add positional encoding
    x = Dense(64)(inputs)  # 投影到嵌入空间
    x = Reshape((input_shape, 64))(x)
    
    # Add positional encoding
    pos_encoding = positional_encoding(input_shape, 64)
    x = x + pos_encoding
    
    # Transformer blocks
    for _ in range(4):  # 4个Transformer块
        x = transformer_encoder(
            x,
            head_size=32,
            num_heads=4,
            ff_dim=128,
            dropout=0.1
        )
    
    # Global pooling
    x = GlobalAveragePooling1D()(x)
    
    # Final dense layers
    x = Dense(128, activation="relu")(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)
    x = Dense(50, activation="relu")(x)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)
    
    model = Model(inputs=inputs, outputs=x)
    return model


def plot_classification_report(y_tru, y_prd, mode, learning_rate, batch_size,epochs, figsize=(7, 7), ax=None):

    plt.figure(figsize=figsize)

    xticks = ['precision', 'recall', 'f1-score', 'support']
    yticks = ["Control", "Moderate", "Alzheimer's" ] 
    yticks += ['avg']

    rep = np.array(precision_recall_fscore_support(y_tru, y_prd)).T
    avg = np.mean(rep, axis=0)
    avg[-1] = np.sum(rep[:, -1])
    rep = np.insert(rep, rep.shape[0], avg, axis=0)

    sns.heatmap(rep,
                annot=True, 
                cbar=False, 
                xticklabels=xticks, 
                yticklabels=yticks,
                ax=ax, cmap = "Blues")
    
    plt.savefig('report_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'_' + str(epochs)+'.png')
    


def calc_confusion_matrix(result, test_label,mode, learning_rate, batch_size, epochs):
    test_label = to_categorical(test_label,3)

    true_label= np.argmax(test_label, axis =1)

    predicted_label= np.argmax(result, axis =1)
    
    n_classes = 3
    precision = dict()
    recall = dict()
    thres = dict()
    for i in range(n_classes):
        precision[i], recall[i], thres[i] = precision_recall_curve(test_label[:, i],
                                                            result[:, i])


    print ("Classification Report :") 
    print (classification_report(true_label, predicted_label))
    cr = classification_report(true_label, predicted_label, output_dict=True)
    return cr, precision, recall, thres



def cross_modal_attention(x, y):
    x = tf.expand_dims(x, axis=1)
    y = tf.expand_dims(y, axis=1)
    a1 = MultiHeadAttention(num_heads = 4,key_dim=50)(x, y)
    a2 = MultiHeadAttention(num_heads = 4,key_dim=50)(y, x)
    a1 = a1[:,0,:]
    a2 = a2[:,0,:]
    return concatenate([a1, a2])


def self_attention(x):
    x = tf.expand_dims(x, axis=1)
    attention = MultiHeadAttention(num_heads = 4, key_dim=50)(x, x)
    attention = attention[:,0,:]
    return attention

def simplified_attention(x, y):
    """简化的注意力机制"""
    x = tf.expand_dims(x, axis=1)
    y = tf.expand_dims(y, axis=1)
    attention = MultiHeadAttention(num_heads=2, key_dim=25)(x, y)
    return attention[:,0,:]
    

def multi_modal_model(mode, train_clinical, train_snp, train_img, snp_model_type='deepsnp'):
    # 输入层
    in_clinical = Input(shape=(train_clinical.shape[1]))
    in_snp = Input(shape=(train_snp.shape[1]))
    in_img = Input(shape=(train_img.shape[1], train_img.shape[2], train_img.shape[3]))
    
    # 特征提取
    dense_clinical = create_model_clinical()(in_clinical)
    # 选择SNP模型类型
    if snp_model_type == 'deepsnp':
        dense_snp = create_model_snp_deep()(in_snp)
    else:  # transformer
        dense_snp = create_model_snp_transformer()(in_snp)
    dense_img = create_model_img()(in_img)
    
    # 简化的注意力机制
    if mode == 'MM_SA_BA':
        # 只保留最重要的跨模态注意力
        img_clinical_att = simplified_attention(dense_img, dense_clinical)
        snp_clinical_att = simplified_attention(dense_snp, dense_clinical)
        merged = concatenate([img_clinical_att, snp_clinical_att, dense_img, dense_snp, dense_clinical])
    else:
        merged = concatenate([dense_img, dense_snp, dense_clinical])
    
    # 添加额外的整合层
    merged = Dense(100, activation='relu')(merged)
    merged = BatchNormalization()(merged)
    merged = Dropout(0.5)(merged)
    merged = Dense(50, activation='relu')(merged)
    merged = BatchNormalization()(merged)
    merged = Dropout(0.3)(merged)
    
    # 输出层
    output = Dense(3, activation='softmax')(merged)
    model = Model([in_clinical, in_snp, in_img], output)
    
    return model


def train(mode, batch_size, epochs, learning_rate, seed):
    
    # train_img = train_img.astype("float32")

    reset_random_seeds(seed)
    class_weights = compute_class_weight(class_weight = 'balanced',classes = np.unique(train_label),y = train_label)
    d_class_weights = dict(enumerate(class_weights))
    
    # compile model #
    model = multi_modal_model(mode, train_clinical, train_snp, train_img)
    model.compile(optimizer=Adam(learning_rate = learning_rate), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
            mode='min'
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=10,
            min_lr=1e-6,
            mode='min'
        )
    ]

    # summarize results
    history = model.fit([train_clinical,
                         train_snp,
                         train_img],
                        train_label,
                        epochs=epochs,
                        batch_size=batch_size,
                        class_weight=d_class_weights,
                        validation_split=0.1,
                        verbose=1,
                        callbacks = callbacks)
                        
                

    score = model.evaluate([test_clinical, test_snp, test_img], test_label)
    
    acc = score[1] 
    test_predictions = model.predict([test_clinical, test_snp, test_img])
    cr, precision_d, recall_d, thres = calc_confusion_matrix(test_predictions, test_label, mode, learning_rate, batch_size, epochs)
    
    
    """
    plt.clf()
    plt.plot(history.history['sparse_categorical_accuracy'])
    plt.plot(history.history['val_sparse_categorical_accuracy'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
    plt.savefig('accuracy_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'.png')
    plt.clf()
    # summarize history for loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
    plt.savefig('loss_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'.png')
    plt.clf()
    """
    
 
    
    # release gpu memory #
    K.clear_session()
    del model, history
    gc.collect()
        
        
    print ('Mode: ', mode)
    print ('Batch size:  ', batch_size)
    print ('Learning rate: ', learning_rate)
    print ('Epochs:  ', epochs)
    print ('Test Accuracy:', '{0:.4f}'.format(acc))
    print ('-'*55)
    
    return acc, batch_size, learning_rate, epochs, seed
    

In [36]:
m_a = {}
seeds = random.sample(range(1, 200), 1)
for s in seeds:
    acc, bs_, lr_, e_ , seed= train('MM_SA_BA', 8, 30, 0.001, s)
    m_a[acc] = ('MM_SA_BA', acc, bs_, lr_, e_, seed)
print(m_a)
print ('-'*55)
max_acc = max(m_a, key=float)
print("Highest accuracy of: " + str(max_acc) + " with parameters: " + str(m_a[max_acc]))





Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30

KeyboardInterrupt: 

In [None]:
def create_model_snp_deep():
    model = Sequential([
        # 第一层卷积块
        Conv1D(64, 3, activation='relu', input_shape=(train_snp.shape[1], 1)),
        BatchNormalization(),
        MaxPooling1D(2),
        Dropout(0.2),
        
        # 第二层卷积块
        Conv1D(128, 3, activation='relu'),
        BatchNormalization(),
        MaxPooling1D(2),
        Dropout(0.2),
        
        # 第三层卷积块
        Conv1D(256, 3, activation='relu'),
        BatchNormalization(),
        GlobalAveragePooling1D(),
        
        # 全连接层
        Dense(128, activation='relu'),
        BatchNormalization(),
        Dropout(0.3),
        Dense(50, activation='relu'),
        BatchNormalization(),
        Dropout(0.2)
    ])
    return model

# Transformer版本
def positional_encoding(length, depth):
    positions = np.arange(length)[:, np.newaxis]
    depths = np.arange(depth)[np.newaxis, :]/depth
    angle_rates = 1 / (10000**depths)
    angle_rads = positions * angle_rates
    
    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)
    
    return tf.cast(pos_encoding, dtype=tf.float32)

def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    # Multi-head self attention
    x = LayerNormalization(epsilon=1e-6)(inputs)
    x = MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(x, x)
    x = Dropout(dropout)(x)
    res = x + inputs

    # Feed Forward
    x = LayerNormalization(epsilon=1e-6)(res)
    x = Dense(ff_dim, activation="relu")(x)
    x = Dropout(dropout)(x)
    x = Dense(inputs.shape[-1])(x)
    return x + res

def create_model_snp_transformer():
    input_shape = train_snp.shape[1]
    
    inputs = Input(shape=(input_shape,))
    
    # Reshape and add positional encoding
    x = Dense(64)(inputs)  # 投影到嵌入空间
    x = Reshape((input_shape, 64))(x)
    
    # Add positional encoding
    pos_encoding = positional_encoding(input_shape, 64)
    x = x + pos_encoding
    
    # Transformer blocks
    for _ in range(4):  # 4个Transformer块
        x = transformer_encoder(
            x,
            head_size=32,
            num_heads=4,
            ff_dim=128,
            dropout=0.1
        )
    
    # Global pooling
    x = GlobalAveragePooling1D()(x)
    
    # Final dense layers
    x = Dense(128, activation="relu")(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)
    x = Dense(50, activation="relu")(x)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)
    
    model = Model(inputs=inputs, outputs=x)
    return model

In [44]:
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Input, concatenate
from tensorflow.keras.layers import Conv2D, Flatten, MultiHeadAttention
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, precision_recall_fscore_support, precision_recall_curve
from tensorflow.keras import backend as K
import numpy as np
import gc
import os
import random

def reset_random_seeds(seed):
    os.environ['PYTHONHASHSEED']=str(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def create_model_snp():
    model = Sequential([
        Dense(100, activation="relu"),
        BatchNormalization(),
        Dropout(0.5),
        Dense(50, activation="relu"),
        BatchNormalization(),
        Dropout(0.3)
    ])
    return model

def create_model_clinical():
    model = Sequential([
        Dense(100, activation="relu"),
        BatchNormalization(),
        Dropout(0.5),
        Dense(50, activation="relu"),
        BatchNormalization(),
        Dropout(0.4)
    ])
    return model

def create_model_img():
    model = Sequential([
        Conv2D(64, (3, 3), activation='relu'),
        BatchNormalization(),
        Dropout(0.4),
        Conv2D(32, (3, 3), activation='relu'),
        BatchNormalization(),
        Dropout(0.4),
        Flatten(),
        Dense(50, activation='relu'),
        BatchNormalization(),
        Dropout(0.3)
    ])
    return model

def simplified_attention(x, y):
    """简化的注意力机制"""
    x = tf.expand_dims(x, axis=1)
    y = tf.expand_dims(y, axis=1)
    attention = MultiHeadAttention(num_heads=2, key_dim=25)(x, y)
    return attention[:,0,:]

def multi_modal_model(mode, train_clinical, train_snp, train_img):
    # 输入层
    in_clinical = Input(shape=(train_clinical.shape[1]))
    in_snp = Input(shape=(train_snp.shape[1]))
    in_img = Input(shape=(train_img.shape[1], train_img.shape[2], train_img.shape[3]))
    
    # 特征提取
    dense_clinical = create_model_clinical()(in_clinical)
    dense_snp = create_model_snp()(in_snp)
    dense_img = create_model_img()(in_img)
    
    # 简化的注意力机制
    if mode == 'MM_SA_BA':
        # 只保留最重要的跨模态注意力
        img_clinical_att = simplified_attention(dense_img, dense_clinical)
        snp_clinical_att = simplified_attention(dense_snp, dense_clinical)
        merged = concatenate([img_clinical_att, snp_clinical_att, dense_img, dense_snp, dense_clinical])
    else:
        merged = concatenate([dense_img, dense_snp, dense_clinical])
    
    # 添加额外的整合层
    merged = Dense(100, activation='relu')(merged)
    merged = BatchNormalization()(merged)
    merged = Dropout(0.5)(merged)
    merged = Dense(50, activation='relu')(merged)
    merged = BatchNormalization()(merged)
    merged = Dropout(0.3)(merged)
    
    # 输出层
    output = Dense(3, activation='softmax')(merged)
    model = Model([in_clinical, in_snp, in_img], output)
    
    return model

def train(mode, batch_size, epochs, learning_rate, seed):
    reset_random_seeds(seed)
    
    # 计算类别权重
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.unique(train_label),
        y=train_label
    )
    d_class_weights = dict(enumerate(class_weights))
    
    # 创建和编译模型
    model = multi_modal_model(mode, train_clinical, train_snp, train_img)
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy']
    )
    
    # 回调函数
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
            mode='min'
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            mode='min'
        )
    ]
    
    # 训练模型
    history = model.fit(
        [train_clinical, train_snp, train_img],
        train_label,
        epochs=epochs,
        batch_size=batch_size,
        class_weight=d_class_weights,
        validation_split=0.2,  # 增加验证集比例
        callbacks=callbacks,
        verbose=1
    )
    
    # 评估模型
    score = model.evaluate([test_clinical, test_snp, test_img], test_label)
    acc = score[1]
    test_predictions = model.predict([test_clinical, test_snp, test_img])
    
    # 输出分类报告
    pred_labels = np.argmax(test_predictions, axis=1)
    print("\nClassification Report:")
    print(classification_report(test_label, pred_labels))
    
    # 清理内存
    K.clear_session()
    del model, history
    gc.collect()
    
    print(f'Mode: {mode}')
    print(f'Batch size: {batch_size}')
    print(f'Learning rate: {learning_rate}')
    print(f'Epochs: {epochs}')
    print(f'Test Accuracy: {acc:.4f}')
    print('-'*55)
    
    return acc, batch_size, learning_rate, epochs, seed

# 使用示例:
"""
results = train(
    mode='MM_SA_BA',
    batch_size=16,  # 增大batch size
    epochs=100,
    learning_rate=0.001,
    seed=42
)
"""

"\nresults = train(\n    mode='MM_SA_BA',\n    batch_size=16,  # 增大batch size\n    epochs=100,\n    learning_rate=0.001,\n    seed=42\n)\n"

In [45]:
m_a = {}
seeds = random.sample(range(1, 200), 1)
for s in seeds:
    acc, bs_, lr_, e_ , seed= train('MM_SA_BA', 8, 50, 0.001, s)
    m_a[acc] = ('MM_SA_BA', acc, bs_, lr_, e_, seed)
print(m_a)
print ('-'*55)
max_acc = max(m_a, key=float)
print("Highest accuracy of: " + str(max_acc) + " with parameters: " + str(m_a[max_acc]))





Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50





Classification Report:
              precision    recall  f1-score   support

           0       1.00      0.67      0.80        12
           1       0.29      1.00      0.44         2
           2       1.00      0.75      0.86         4

    accuracy                           0.72        18
   macro avg       0.76      0.81      0.70        18
weighted avg       0.92      0.72      0.77        18

Mode: MM_SA_BA
Batch size: 8
Learning rate: 0.001
Epochs: 50
Test Accuracy: 0.7222
-------------------------------------------------------
{0.7222222089767456: ('MM_SA_BA', 0.7222222089767456, 8, 0.001, 50, 53)}
-------------------------------------------------------
Highest accuracy of: 0.7222222089767456 with parameters: ('MM_SA_BA', 0.7222222089767456, 8, 0.001, 50, 53)
