### Import Library

In [1]:
from tqdm.auto import tqdm
from utils import *
from SDGCCA import SDGCCA_3_M
import torch.nn as nn
import pandas as pd
import shap
import warnings
warnings.simplefilter("ignore", UserWarning)

### Seed fix

In [2]:
# Seed Setting
random_seed = 100
set_seed(random_seed)

### Train SDGCCA

**Toy Dataset**  
- Label: Binary
- Modality1: n(376) x d1 (18164)
- Modality1: n(376) x d2 (19353)
- Modality1: n(376) x d3 (309)

In [3]:
def train_SDGCCA(hyper_dict):
    # Return List
    ensemble_list = {'ACC': [], 'F1': [], 'AUC': [], 'MCC': []}
    metric_list = ['ACC', 'F1', 'AUC', 'MCC']
    hyper_param_list = []
    best_hyper_param_list = []
    
    # Prepare Toy Dataset
    dataset = Toy_Dataset(hyper_dict['random_seed'])

    # 5 CV
    cv = 0
    # Prepare Dataset
    [x_train_1, x_val_1, x_test_1], [x_train_2, x_val_2, x_test_2], [x_train_3, x_val_3, x_test_3], \
    [y_train, y_val, y_test] = dataset(cv, tensor=True, device=hyper_dict['device'])

    # Define Deep neural network dimension of the each modality
    m1_embedding_list = [x_train_1.shape[1]] + hyper_dict['embedding_size']
    m2_embedding_list = [x_train_2.shape[1]] + hyper_dict['embedding_size']
    m3_embedding_list = [x_train_3.shape[1]] + hyper_dict['embedding_size'][1:]

    # Train Label -> One_Hot_Encoding
    y_train_onehot = torch.zeros(y_train.shape[0], 2).float().to(hyper_dict['device'])
    y_train_onehot[range(y_train.shape[0]), y_train.squeeze()] = 1

    # Find Best K by Validation MCC
    val_mcc_result_list = []
    test_ensemble_dict = {'ACC': [], 'F1': [], 'AUC': [], 'MCC': []}

    # Grid search for find best hyperparameter by Validation MCC
    for top_k in tqdm(range(1, hyper_dict['max_top_k']+1), desc='Grid seach for find best hyperparameter...'):
        for lr in hyper_dict['lr']:
            for reg in hyper_dict['reg']:
                hyper_param_list.append([top_k, lr, reg])
                early_stopping = EarlyStopping(patience=hyper_dict['patience'], delta=hyper_dict['delta'])
                best_loss = np.Inf

                # Define SDGCCA with 3 modality
                model = SDGCCA_3_M(m1_embedding_list, m2_embedding_list, m3_embedding_list, top_k).to(hyper_dict['device'])

                # Optimizer
                clf_optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)

                # Cross Entropy Loss
                criterion = nn.CrossEntropyLoss()

                # Model Train
                for i in range(hyper_dict['epoch']):
                    model.train()

                    # Calculate correlation loss
                    out1, out2, out3 = model(x_train_1, x_train_2, x_train_3)
                    cor_loss = model.cal_loss([out1, out2, out3, y_train_onehot])

                    # Calculate classification loss
                    clf_optimizer.zero_grad()

                    y_hat1, y_hat2, y_hat3, _ = model.predict(x_train_1, x_train_2, x_train_3)
                    clf_loss1 = criterion(y_hat1, y_train.squeeze())
                    clf_loss2 = criterion(y_hat2, y_train.squeeze())
                    clf_loss3 = criterion(y_hat3, y_train.squeeze())

                    clf_loss = clf_loss1 + clf_loss2 + clf_loss3

                    clf_loss.backward()
                    clf_optimizer.step()

                    # Model Validation
                    with torch.no_grad():
                        model.eval()
                        _, _, _, y_ensemble = model.predict(x_val_1, x_val_2, x_val_3)
                        val_loss = criterion(y_ensemble, y_val.squeeze())

                        early_stopping(val_loss)
                        if val_loss < best_loss:
                            best_loss = val_loss

                        if early_stopping.early_stop:
                            break

                # Load Best Model
                model.eval()

                # Model Validation
                _, _, _, ensembel_y_hat = model.predict(x_val_1, x_val_2, x_val_3)
                y_pred_ensemble = torch.argmax(ensembel_y_hat, 1).cpu().detach().numpy()
                y_pred_proba_ensemble = ensembel_y_hat[:, 1].cpu().detach().numpy()            
                _, _, _, val_mcc = calculate_metric(y_val.cpu().detach().numpy(), y_pred_ensemble, y_pred_proba_ensemble)
                val_mcc_result_list.append(val_mcc)

                # Model Tset
                _, _, _, ensembel_y_hat = model.predict(x_test_1, x_test_2, x_test_3)
                y_pred_ensemble = torch.argmax(ensembel_y_hat, 1).cpu().detach().numpy()
                y_pred_proba_ensemble = ensembel_y_hat[:, 1].cpu().detach().numpy()
                test_acc, test_f1, test_auc, test_mcc = calculate_metric(y_test.cpu().detach().numpy(), y_pred_ensemble, y_pred_proba_ensemble)
                ensemble_result = [test_acc, test_f1, test_auc, test_mcc]
                for k, metric in enumerate(metric_list):
                    test_ensemble_dict[metric].append(ensemble_result[k])

        # Find best K
        best_k = np.argmax(val_mcc_result_list)
        
        # Find best hyperparameter
        best_hyper_param_list.append(hyper_param_list[best_k])
        
        # Append Best K Test Result
        for metric in metric_list:
            ensemble_list[metric].append(test_ensemble_dict[metric][best_k])

    return ensemble_list, best_hyper_param_list

In [4]:
class CLF_Predict(nn.Module):
    def __init__(self, embedding_list, top_k, mi_flip=False):
        super().__init__()
        if mi_flip:
            # Embedding List
            du0, du1, du2 = embedding_list
            self.encoder = nn.Sequential(
                nn.Linear(du0, du1), nn.Tanh(),
                nn.Linear(du1, du2), nn.Tanh())

            self.classifier = nn.Sequential(
                nn.Linear(du2, top_k, bias=False),
                nn.Linear(top_k, 2, bias=False)
            )
        else:
            # Embedding List
            du0, du1, du2, du3 = embedding_list

            self.encoder = nn.Sequential(
                nn.Linear(du0, du1), nn.Tanh(),
                nn.Linear(du1, du2), nn.Tanh(),
                nn.Linear(du2, du3), nn.Tanh())

            self.classifier = nn.Sequential(
                nn.Linear(du3, top_k, bias=False),
                nn.Linear(top_k, 2, bias=False)
            )

    def forward(self, x):
        x = self.encoder(x)
        x = self.classifier(x)

        return x

### Setting Hyperparameter

In [5]:
hyper_dict = {'epoch': 1000, 'delta': 0, 'random_seed': random_seed,
              'device': torch.device("cuda:1" if torch.cuda.is_available() else "cpu"),
              'lr': [0.0001,0.00001], 'reg': [0, 0.01,0.0001], 
              'patience': 30, 'embedding_size': [256, 64, 16], 'max_top_k': 10}

### Model Training & Check Performance

In [6]:
ensemble_list, hyper = train_SDGCCA(hyper_dict)

# Check Performance
performance_result = check_mean_std_performance(ensemble_list)

print('Test Performance')
print('ACC: {} F1: {} AUC: {} MCC: {}'.format(performance_result[0], performance_result[1], performance_result[2], performance_result[3]))

print('\nBest Hyperparameter')
for i, h in enumerate(hyper):
    print('CV: {} Best k: {} Learning Rage: {} Regularization Term: {}'.format(i+1, h[0], h[1], h[2]))

Grid seach for find best hyperparameter...:   0%|          | 0/10 [00:00<?, ?it/s]

Test Performance
ACC: 49.34+-3.01 F1: 48.85+-7.71 AUC: 45.98+-5.36 MCC: -0.91+-6.51

Best Hyperparameter
CV: 1 Best k: 1 Learning Rage: 0.0001 Regularization Term: 0
CV: 2 Best k: 2 Learning Rage: 1e-05 Regularization Term: 0.0001
CV: 3 Best k: 2 Learning Rage: 1e-05 Regularization Term: 0.0001
CV: 4 Best k: 2 Learning Rage: 1e-05 Regularization Term: 0.0001
CV: 5 Best k: 5 Learning Rage: 1e-05 Regularization Term: 0.0001
CV: 6 Best k: 5 Learning Rage: 1e-05 Regularization Term: 0.0001
CV: 7 Best k: 5 Learning Rage: 1e-05 Regularization Term: 0.0001
CV: 8 Best k: 5 Learning Rage: 1e-05 Regularization Term: 0.0001
CV: 9 Best k: 5 Learning Rage: 1e-05 Regularization Term: 0.0001
CV: 10 Best k: 5 Learning Rage: 1e-05 Regularization Term: 0.0001


## Important features (example on CV1)    
CV: 1 Best k: 1 Learning Rage: 0.0001 Regularization Term: 0

In [7]:

# Prepare Toy Dataset
dataset = Toy_Dataset(hyper_dict['random_seed'])
# Return List
important_result_pd = pd.DataFrame(
    columns=['CV', 'Correlation',
             'm1_clf_feature', 'm2_clf_feature', 'm3_clf_feature'])
# CV: 1 Best k: 5 Learning Rage: 1e-05 Regularization Term: 0.0001
cv = 0
top_k = 5
lr = 0.0001
reg = 0
# Prepare Dataset
[x_train_1, x_val_1, x_test_1], [x_train_2, x_val_2, x_test_2], [x_train_3, x_val_3, x_test_3], \
[y_train, y_val, y_test] = dataset(cv, tensor=True, device=hyper_dict['device'])

# Define Deep neural network dimension of the each modality
m1_embedding_list = [x_train_1.shape[1]] + hyper_dict['embedding_size']
m2_embedding_list = [x_train_2.shape[1]] + hyper_dict['embedding_size']
m3_embedding_list = [x_train_3.shape[1]] + hyper_dict['embedding_size'][1:]

# Feature Name
m1_name_list = []
m2_name_list = []
m3_name_list = []
for i in range(x_train_1.shape[1]):
    m1_name_list.append('Data1_'+str(i))

for i in range(x_train_2.shape[1]):
    m2_name_list.append('Data2_'+str(i))

for i in range(x_train_3.shape[1]):
    m3_name_list.append('Data3_'+str(i))
# Train Label -> One_Hot_Encoding
y_train_onehot = torch.zeros(y_train.shape[0], 2).float().to(hyper_dict['device'])
y_train_onehot[range(y_train.shape[0]), y_train.squeeze()] = 1

# Define SDGCCA with 3 modality
model = SDGCCA_3_M(m1_embedding_list, m2_embedding_list, m3_embedding_list, top_k).to(hyper_dict['device'])
early_stopping = EarlyStopping(patience=30, delta=0)
best_loss = np.Inf
# Optimizer
clf_optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)

# Cross Entropy Loss
criterion = nn.CrossEntropyLoss()

# Model Train
for i in range(hyper_dict['epoch']):
    model.train()

    # Calculate correlation loss
    out1, out2, out3 = model(x_train_1, x_train_2, x_train_3)
    cor_loss = model.cal_loss([out1, out2, out3, y_train_onehot])

    # Calculate classification loss
    clf_optimizer.zero_grad()

    y_hat1, y_hat2, y_hat3, _ = model.predict(x_train_1, x_train_2, x_train_3)
    clf_loss1 = criterion(y_hat1, y_train.squeeze())
    clf_loss2 = criterion(y_hat2, y_train.squeeze())
    clf_loss3 = criterion(y_hat3, y_train.squeeze())

    clf_loss = clf_loss1 + clf_loss2 + clf_loss3

    clf_loss.backward()
    clf_optimizer.step()

    # Model Validation
    with torch.no_grad():
        model.eval()
        _, _, _, y_ensemble = model.predict(x_val_1, x_val_2, x_val_3)
        val_loss = criterion(y_ensemble, y_val.squeeze())

        early_stopping(val_loss)
        if val_loss < best_loss:
            best_loss = val_loss

        if early_stopping.early_stop:
            break

                # Load Best Model
model.eval()

# Load U

u =  [model.U[0].cpu().detach(), model.U[1].cpu().detach(),
                          model.U[2], torch.pinverse(model.U[3]).cpu().detach()]
for i in range(len(u)):
    u[i] = u[i].to(hyper_dict['device'])

"""
x1 = torch.cat((x_train_1, x_val_1, x_test_1), dim=0)
x2 = torch.cat((x_train_2, x_val_2, x_test_2), dim=0)
x3 = torch.cat((x_train_3, x_val_3, x_test_3), dim=0)
"""
x1 = x_train_1
x2 = x_train_2
x3 = x_train_3

output1, output2, output3 = model(x1, x2, x3)

# Raw -> Embedding -> PLS Embedding
pls_embedding1 = torch.mm(output1, u[0].to(hyper_dict['device']))
pls_embedding2 = torch.mm(output2, u[1].to(hyper_dict['device']))
pls_embedding3 = torch.mm(output3, u[2].to(hyper_dict['device']))

# Correlation - PLS Embedding Dimension=0 -> SVD max eigenvalue
correlation = \
    (np.corrcoef(pls_embedding1[:, 0].detach().cpu().numpy(), pls_embedding2[:, 0].detach().cpu().numpy())[0, 1] +
     np.corrcoef(pls_embedding1[:, 0].detach().cpu().numpy(), pls_embedding3[:, 0].detach().cpu().numpy())[0, 1] +
     np.corrcoef(pls_embedding2[:, 0].detach().cpu().numpy(), pls_embedding3[:, 0].detach().cpu().numpy())[0, 1])/3

# Define Modality_Classifier
m1_classifier = CLF_Predict(m1_embedding_list,top_k)
m2_classifier = CLF_Predict(m2_embedding_list, top_k)
m3_classifier = CLF_Predict(m3_embedding_list, top_k, mi_flip=True)

# Embedding & Classifier Weight
m1_classifier.encoder = model.model1
for i, p in enumerate(m1_classifier.classifier.parameters()):
    if i == 0:
        p.data = u[0].T
    else:
        p.data = u[3].T
m1_classifier = m1_classifier.to(hyper_dict['device'])
m1_classifier.eval()

# ME Embedding & Classifier Weight
m2_classifier.encoder = model.model2
for i, p in enumerate(m2_classifier.classifier.parameters()):
    if i == 0:
        p.data = u[1].T
    else:
        p.data = u[3].T
m2_classifier = m2_classifier.to(hyper_dict['device'])
m2_classifier.eval()

# ME Embedding & Classifier Weight
m3_classifier.encoder = model.model3
for i, p in enumerate(m3_classifier.classifier.parameters()):
    if i == 0:
        p.data = u[2].T
    else:
        p.data = u[3].T
m3_classifier = m3_classifier.to(hyper_dict['device'])
m3_classifier.eval()

# Modality1 Classification Important Feature
explainer_shap = shap.DeepExplainer(m1_classifier, x1)
shap_values = explainer_shap.shap_values(x1)
clf_m1_important_feature = np.argsort(abs(shap_values[0].mean(0)) + abs(shap_values[1].mean(0)))[::-1][:300]
m1_clf_feature = np.array(m1_name_list)[clf_m1_important_feature]

# Modality2 Classification Important Feature
explainer_shap = shap.DeepExplainer(m2_classifier, x2)
shap_values = explainer_shap.shap_values(x2)
clf_m2_important_feature = np.argsort(abs(shap_values[0].mean(0)) + abs(shap_values[1].mean(0)))[::-1][:300]
m2_clf_feature = np.array(m2_name_list)[clf_m2_important_feature]

# Modality3 Classification Important Feature
explainer_shap = shap.DeepExplainer(m3_classifier, x3)
shap_values = explainer_shap.shap_values(x3)
clf_m3_important_feature = np.argsort(abs(shap_values[0].mean(0)) + abs(shap_values[1].mean(0)))[::-1][:30]
m3_clf_feature = np.array(m3_name_list)[clf_m3_important_feature]

important_dict = {'CV': cv + 1, 'Correlation': correlation,
                  'm1_clf_feature': m1_clf_feature, 'm2_clf_feature': m2_clf_feature, 'm3_clf_feature': m3_clf_feature}
important_result_pd = important_result_pd.append(important_dict, ignore_index=True)




Important features (data 1)

In [8]:
m1_clf_feature

array(['Data1_15348', 'Data1_3288', 'Data1_7453', 'Data1_14143',
       'Data1_8334', 'Data1_15793', 'Data1_2745', 'Data1_3290',
       'Data1_10881', 'Data1_8440', 'Data1_17406', 'Data1_13185',
       'Data1_16885', 'Data1_10418', 'Data1_11008', 'Data1_15480',
       'Data1_3389', 'Data1_15811', 'Data1_2911', 'Data1_8826',
       'Data1_5446', 'Data1_6685', 'Data1_7821', 'Data1_11394',
       'Data1_3617', 'Data1_3240', 'Data1_10628', 'Data1_3962',
       'Data1_11800', 'Data1_2014', 'Data1_6199', 'Data1_3267',
       'Data1_11176', 'Data1_11232', 'Data1_7718', 'Data1_7414',
       'Data1_7506', 'Data1_8976', 'Data1_2843', 'Data1_9036',
       'Data1_6295', 'Data1_12579', 'Data1_1244', 'Data1_12091',
       'Data1_11251', 'Data1_9918', 'Data1_17066', 'Data1_11186',
       'Data1_17288', 'Data1_524', 'Data1_11520', 'Data1_10832',
       'Data1_10986', 'Data1_9668', 'Data1_5476', 'Data1_5586',
       'Data1_16309', 'Data1_2594', 'Data1_12084', 'Data1_10708',
       'Data1_16573', 'Data1