def multi_modal_model(mode, train_clinical, train_snp, train_img):
    
    in_clinical = Input(shape=(train_clinical.shape[1]))
    
    in_snp = Input(shape=(train_snp.shape[1]))
    
    in_img = Input(shape=(train_img.shape[1], train_img.shape[2], train_img.shape[3]))
    
    dense_clinical = create_model_clinical()(in_clinical)
    dense_snp = create_model_snp()(in_snp) 
    dense_img = create_model_img()(in_img) 
        
    ########### Attention Layer ############
        
    ## Cross Modal Bi-directional Attention ##

    if mode == 'MM_BA':
            
        vt_att = cross_modal_attention(dense_img, dense_clinical)
        av_att = cross_modal_attention(dense_snp, dense_img)
        ta_att = cross_modal_attention(dense_clinical, dense_snp)
                
        merged = concatenate([vt_att, av_att, ta_att, dense_img, dense_snp, dense_clinical])
        
    ## Self Attention ##
    elif mode == 'MM_SA': 
        vv_att = self_attention(dense_img)
        tt_att = self_attention(dense_clinical)
        aa_att = self_attention(dense_snp)
            
        merged = concatenate([aa_att, vv_att, tt_att, dense_img, dense_snp, dense_clinical])
        
    ## Self Attention and Cross Modal Bi-directional Attention##
    elif mode == 'MM_SA_BA':
            
        vv_att = self_attention(dense_img)
        tt_att = self_attention(dense_clinical)
        aa_att = self_attention(dense_snp)
        
        vt_att = cross_modal_attention(vv_att, tt_att)
        av_att = cross_modal_attention(aa_att, vv_att)
        ta_att = cross_modal_attention(tt_att, aa_att)
            
        merged = concatenate([vt_att, av_att, ta_att, dense_img, dense_snp, dense_clinical])
            
        
    ## No Attention ##    
    elif mode == 'None':
        merged = concatenate([dense_img, dense_snp, dense_clinical])
                
    else:
        print ("Mode must be one of 'MM_SA', 'MM_BA', 'MU_SA_BA' or 'None'.")
        return
    
    
        
    ########### Output Layer ############
        
    output = Dense(3, activation='softmax')(merged)
    model = Model([in_clinical, in_snp, in_img], output)        
        
    return model



def train(mode, batch_size, epochs, learning_rate, seed):
    
    # train_img = train_img.astype("float32")

    reset_random_seeds(seed)
    class_weights = compute_class_weight(class_weight = 'balanced',classes = np.unique(train_label),y = train_label)
    d_class_weights = dict(enumerate(class_weights))
    
    # compile model #
    model = multi_modal_model(mode, train_clinical, train_snp, train_img)
    model.compile(optimizer=Adam(learning_rate = learning_rate), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
    

    # summarize results
    history = model.fit([train_clinical,
                         train_snp,
                         train_img],
                         train_label,
                        epochs=epochs,
                        batch_size=batch_size,
                        class_weight=d_class_weights,
                        validation_split=0.1,
                        verbose=1)
                        
                

    score = model.evaluate([test_clinical, test_snp, test_img], test_label)
    
    acc = score[1] 
    test_predictions = model.predict([test_clinical, test_snp, test_img])
    cr, precision_d, recall_d, thres = calc_confusion_matrix(test_predictions, test_label, mode, learning_rate, batch_size, epochs)
    
    
    """
    plt.clf()
    plt.plot(history.history['sparse_categorical_accuracy'])
    plt.plot(history.history['val_sparse_categorical_accuracy'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
    plt.savefig('accuracy_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'.png')
    plt.clf()
    # summarize history for loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
    plt.savefig('loss_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'.png')
    plt.clf()
    """
    
    # release gpu memory #
    K.clear_session()
    del model, history
    gc.collect()
        
        
    print ('Mode: ', mode)
    print ('Batch size:  ', batch_size)
    print ('Learning rate: ', learning_rate)
    print ('Epochs:  ', epochs)
    print ('Test Accuracy:', '{0:.4f}'.format(acc))
    print ('-'*55)
    
    return acc, batch_size, learning_rate, epochs, seed
我现在有Clinical data，SNP data(0, 1, 2, 3)和MRI data。我希望使用multimodal来进行阿尔兹海默症的三分类。我需要一些简单的multimodel来作为baseline

这个是我本身的模型，你需要在模型训练完后，提取得到的feature（merged），再连接上对应的传统ML算法就行。

In [1]:
import os
import random
import gc, numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.utils import compute_class_weight
import tensorflow as tf
from keras.models import Model
from keras import backend as K
from keras.layers import Input, Dense, Dropout,Flatten, BatchNormalization, Conv2D, MultiHeadAttention, concatenate
from sklearn.metrics import classification_report
from tensorflow.keras.optimizers import Adam
from keras.models import Sequential
from tensorflow.keras.utils import to_categorical
import seaborn as sns
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import precision_recall_curve

In [2]:
def make_img(t_img):
    img = pd.read_pickle(t_img)
    img_l = []
    for i in range(len(img)):
        img_l.append(img.values[i][0])
    
    return np.array(img_l)


def reset_random_seeds(seed):
    os.environ['PYTHONHASHSEED']=str(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
   
               
def create_model_snp():
    model = Sequential()
    model.add(Dense(200,  activation = "relu")) 
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(100, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.3))
    
    model.add(Dense(50, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))
    return model

def create_model_clinical():
    model = Sequential()
    model.add(Dense(128,  activation = "relu")) 
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(128, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    
    model.add(Dense(50, activation = "relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))    
    return model

def create_model_img():
    model = Sequential()
    model.add(Conv2D(72, (3, 3), activation='relu')) 
    model.add(Dropout(0.3))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(Dropout(0.3))
    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(Dropout(0.3))
    model.add(Flatten())
    model.add(Dense(50, activation='relu'))   
    return model

def plot_classification_report(y_tru, y_prd, mode, learning_rate, batch_size,epochs, figsize=(7, 7), ax=None):

    plt.figure(figsize=figsize)

    xticks = ['precision', 'recall', 'f1-score', 'support']
    yticks = ["Control", "Moderate", "Alzheimer's" ] 
    yticks += ['avg']

    rep = np.array(precision_recall_fscore_support(y_tru, y_prd)).T
    avg = np.mean(rep, axis=0)
    avg[-1] = np.sum(rep[:, -1])
    rep = np.insert(rep, rep.shape[0], avg, axis=0)

    sns.heatmap(rep,
                annot=True, 
                cbar=False, 
                xticklabels=xticks, 
                yticklabels=yticks,
                ax=ax, cmap = "Blues")
    
    plt.savefig('report_' + str(mode) + '_' + str(learning_rate) +'_' + str(batch_size)+'_' + str(epochs)+'.png')
    

def calc_confusion_matrix(result, test_label,mode, learning_rate, batch_size, epochs):
    test_label = to_categorical(test_label,3)

    true_label= np.argmax(test_label, axis =1)

    predicted_label= np.argmax(result, axis =1)
    
    n_classes = 3
    precision = dict()
    recall = dict()
    thres = dict()
    for i in range(n_classes):
        precision[i], recall[i], thres[i] = precision_recall_curve(test_label[:, i],
                                                            result[:, i])


    print ("Classification Report :") 
    print (classification_report(true_label, predicted_label))
    cr = classification_report(true_label, predicted_label, output_dict=True)
    return cr, precision, recall, thres


def cross_modal_attention(x, y):
    x = tf.expand_dims(x, axis=1)
    y = tf.expand_dims(y, axis=1)
    a1 = MultiHeadAttention(num_heads = 4,key_dim=50)(x, y)
    a2 = MultiHeadAttention(num_heads = 4,key_dim=50)(y, x)
    a1 = a1[:,0,:]
    a2 = a2[:,0,:]
    return concatenate([a1, a2])


def self_attention(x):
    x = tf.expand_dims(x, axis=1)
    attention = MultiHeadAttention(num_heads = 4, key_dim=50)(x, x)
    attention = attention[:,0,:]
    return attention

In [3]:
train_clinical = pd.read_csv("../preprocess_overlap/X_train_clinical.csv").values
test_clinical= pd.read_csv("../preprocess_overlap/X_test_clinical.csv").values

train_snp = pd.read_csv("../preprocess_overlap/X_train_snp.csv").values
test_snp = pd.read_csv("../preprocess_overlap/X_test_snp.csv").values

train_img= make_img("../preprocess_overlap/X_train_img.pkl")
test_img= make_img("../preprocess_overlap/X_test_img.pkl")

train_label= pd.read_csv("../preprocess_overlap/y_train.csv").values.astype("int").flatten()
test_label= pd.read_csv("../preprocess_overlap/y_test.csv").values.astype("int").flatten()

train_clinical = train_clinical.astype("float32")
test_clinical = test_clinical.astype("float32")
# train_snp = train_snp.astype("float32")
# train_snp = test_snp.astype("float32")

In [None]:
def multi_modal_model(mode, train_clinical, train_snp, train_img, ml_algorithm):
    """
    Multi-modal model with traditional ML layer for direct classification.
    """
    # Inputs
    in_clinical = Input(shape=(train_clinical.shape[1]))
    in_snp = Input(shape=(train_snp.shape[1]))
    in_img = Input(shape=(train_img.shape[1], train_img.shape[2], train_img.shape[3]))

    # Dense layers for each modality
    dense_clinical = create_model_clinical()(in_clinical)
    dense_snp = create_model_snp()(in_snp)
    dense_img = create_model_img()(in_img)

    # Attention mechanism based on mode
    if mode == 'MM_BA':
        vt_att = cross_modal_attention(dense_img, dense_clinical)
        av_att = cross_modal_attention(dense_snp, dense_img)
        ta_att = cross_modal_attention(dense_clinical, dense_snp)
        merged = concatenate([vt_att, av_att, ta_att, dense_img, dense_snp, dense_clinical])
    elif mode == 'MM_SA':
        vv_att = self_attention(dense_img)
        tt_att = self_attention(dense_clinical)
        aa_att = self_attention(dense_snp)
        merged = concatenate([aa_att, vv_att, tt_att, dense_img, dense_snp, dense_clinical])
    elif mode == 'MM_SA_BA':
        vv_att = self_attention(dense_img)
        tt_att = self_attention(dense_clinical)
        aa_att = self_attention(dense_snp)
        vt_att = cross_modal_attention(vv_att, tt_att)
        av_att = cross_modal_attention(aa_att, vv_att)
        ta_att = cross_modal_attention(tt_att, aa_att)
        print(vt_att.shape, av_att.shape, ta_att.shape, dense_img.shape, dense_snp.shape, dense_clinical.shape)
        
        merged = concatenate([vt_att, av_att, ta_att, dense_img, dense_snp, dense_clinical])
    elif mode == 'None':
        merged = concatenate([dense_img, dense_snp, dense_clinical])
    else:
        raise ValueError("Mode must be one of 'MM_SA', 'MM_BA', 'MM_SA_BA' or 'None'.")
    
    
    output = Dense(3, activation='softmax')(merged)
    model = Model([in_clinical, in_snp, in_img], output)

    return model




In [None]:
def train(mode, batch_size, epochs, learning_rate, seed):
    
    # train_img = train_img.astype("float32")

    reset_random_seeds(seed)
    class_weights = compute_class_weight(class_weight = 'balanced',classes = np.unique(train_label),y = train_label)
    d_class_weights = dict(enumerate(class_weights))
    
    # compile model #
    model = multi_modal_model(mode, train_clinical, train_snp, train_img)
    model.compile(optimizer=Adam(learning_rate = learning_rate), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
    

    # summarize results
    history = model.fit([train_clinical,
                         train_snp,
                         train_img],
                        train_label,
                        epochs=epochs,
                        batch_size=batch_size,
                        class_weight=d_class_weights,
                        validation_split=0.1,
                        verbose=1)
                        
                

    score = model.evaluate([test_clinical, test_snp, test_img], test_label)
    
    acc = score[1]
    test_predictions = model.predict([test_clinical, test_snp, test_img])
    cr, precision_d, recall_d, thres = calc_confusion_matrix(test_predictions, test_label, mode, learning_rate, batch_size, epochs)
 
    # release gpu memory #
    K.clear_session()
    del model, history
    gc.collect()
        
        
    print ('Mode: ', mode)
    print ('Batch size:  ', batch_size)
    print ('Learning rate: ', learning_rate)
    print ('Epochs:  ', epochs)
    print ('Test Accuracy:', '{0:.4f}'.format(acc))
    print ('-'*55)
    
    return acc, batch_size, learning_rate, epochs, seed

In [None]:
model = multi_modal_model('MM_SA_BA', train_clinical, train_snp, train_img, ml_algorithm='RandomForest')
model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

history = model.fit(
    [train_clinical, train_snp, train_img],
    train_label,
    epochs=10,
    batch_size=32,
    validation_split=0.1
)




(None, 100) (None, 100) (None, 100) (None, 50) (None, 50) (None, 50)
KerasTensor(type_spec=TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), name='lambda_3/EagerPyFunc:0', description="created by layer 'lambda_3'")
Epoch 1/10


ValueError: in user code:

    File "/Users/flynnzhang/anaconda3/envs/Intro2DL/lib/python3.8/site-packages/keras/src/engine/training.py", line 1338, in train_function  *
        return step_function(self, iterator)
    File "/Users/flynnzhang/anaconda3/envs/Intro2DL/lib/python3.8/site-packages/keras/src/engine/training.py", line 1322, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/Users/flynnzhang/anaconda3/envs/Intro2DL/lib/python3.8/site-packages/keras/src/engine/training.py", line 1303, in run_step  **
        outputs = model.train_step(data)
    File "/Users/flynnzhang/anaconda3/envs/Intro2DL/lib/python3.8/site-packages/keras/src/engine/training.py", line 1085, in train_step
        return self.compute_metrics(x, y, y_pred, sample_weight)
    File "/Users/flynnzhang/anaconda3/envs/Intro2DL/lib/python3.8/site-packages/keras/src/engine/training.py", line 1179, in compute_metrics
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
    File "/Users/flynnzhang/anaconda3/envs/Intro2DL/lib/python3.8/site-packages/keras/src/engine/compile_utils.py", line 577, in update_state
        self.build(y_pred, y_true)
    File "/Users/flynnzhang/anaconda3/envs/Intro2DL/lib/python3.8/site-packages/keras/src/engine/compile_utils.py", line 483, in build
        self._metrics = tf.__internal__.nest.map_structure_up_to(
    File "/Users/flynnzhang/anaconda3/envs/Intro2DL/lib/python3.8/site-packages/keras/src/engine/compile_utils.py", line 631, in _get_metric_objects
        return [self._get_metric_object(m, y_t, y_p) for m in metrics]
    File "/Users/flynnzhang/anaconda3/envs/Intro2DL/lib/python3.8/site-packages/keras/src/engine/compile_utils.py", line 631, in <listcomp>
        return [self._get_metric_object(m, y_t, y_p) for m in metrics]
    File "/Users/flynnzhang/anaconda3/envs/Intro2DL/lib/python3.8/site-packages/keras/src/engine/compile_utils.py", line 653, in _get_metric_object
        y_p_rank = len(y_p.shape.as_list())

    ValueError: as_list() is not defined on an unknown TensorShape.
