### Import Library

In [1]:
import warnings

import numpy as np
import torch
import torch.nn as nn
from tqdm.auto import tqdm

# Star import kept for the project helpers it provides; `np` and `torch` are
# imported explicitly above so the notebook no longer depends on `utils`
# happening to re-export them.
from utils import *
from SDGCCA import SDGCCA_3_M

warnings.simplefilter("ignore", UserWarning)

### Seed fix

In [2]:
# Seed setting: a single shared value, also stored later as hyper_dict['random_seed'].
random_seed = 100
# set_seed is not defined in this notebook; presumably it is provided by
# `from utils import *` — TODO confirm which RNGs (python / numpy / torch) it seeds.
set_seed(random_seed)

### Train SDGCCA

**Toy Dataset**  
- Label: Binary
- Modality1: n(376) x d1 (18164)
- Modality2: n(376) x d2 (19353)
- Modality3: n(376) x d3 (309)

In [3]:
def _evaluate_ensemble(model, x_1, x_2, x_3, y_true):
    """Score the ensemble output of `model.predict` on one data split.

    The prediction runs without gradient tracking. The arg-max over axis 1 is
    used as the predicted class and column 1 of the ensemble output as the
    positive-class score; both are handed to `calculate_metric` together with
    the true labels.

    Returns:
        The value of `calculate_metric(...)`, which the caller unpacks as
        (ACC, F1, AUC, MCC).
    """
    with torch.no_grad():
        _, _, _, ensemble_y_hat = model.predict(x_1, x_2, x_3)
    y_pred = torch.argmax(ensemble_y_hat, 1).cpu().detach().numpy()
    y_score = ensemble_y_hat[:, 1].cpu().detach().numpy()
    return calculate_metric(y_true.cpu().detach().numpy(), y_pred, y_score)


def train_SDGCCA(hyper_dict):
    """Run a 5-fold grid search of SDGCCA on the toy dataset.

    For every fold, one model is trained per (top_k, lr, reg) combination.
    The combination with the highest validation MCC is selected, and the test
    metrics obtained with that same combination are reported for the fold.

    Args:
        hyper_dict: dict read with the keys 'random_seed', 'device',
            'embedding_size', 'max_top_k', 'lr', 'reg', 'patience', 'delta'
            and 'epoch'.

    Returns:
        (ensemble_list, best_hyper_param_list) where
        - ensemble_list maps each of 'ACC', 'F1', 'AUC', 'MCC' to a list with
          one test value per fold (the value of the selected combination);
        - best_hyper_param_list holds one [top_k, lr, reg] list per fold.
    """
    metric_list = ['ACC', 'F1', 'AUC', 'MCC']
    ensemble_list = {metric: [] for metric in metric_list}
    best_hyper_param_list = []

    device = hyper_dict['device']

    # Prepare Toy Dataset
    dataset = Toy_Dataset(hyper_dict['random_seed'])

    # 5 CV
    for cv in tqdm(range(5), desc='CV...'):
        # Prepare Dataset
        [x_train_1, x_val_1, x_test_1], [x_train_2, x_val_2, x_test_2], [x_train_3, x_val_3, x_test_3], \
        [y_train, y_val, y_test] = dataset(cv, tensor=True, device=device)

        # Layer widths of each modality's network: input width followed by the
        # configured hidden widths. Modality 3 skips the first hidden width
        # (presumably because its input is much narrower — TODO confirm).
        m1_embedding_list = [x_train_1.shape[1]] + hyper_dict['embedding_size']
        m2_embedding_list = [x_train_2.shape[1]] + hyper_dict['embedding_size']
        m3_embedding_list = [x_train_3.shape[1]] + hyper_dict['embedding_size'][1:]

        # Train Label -> One_Hot_Encoding (two columns: the label is binary)
        y_train_onehot = torch.zeros(y_train.shape[0], 2).float().to(device)
        y_train_onehot[range(y_train.shape[0]), y_train.squeeze()] = 1

        # Per-fold bookkeeping. All three lists are index-aligned: entry i
        # belongs to the i-th grid combination tried in THIS fold. (The
        # combination list used to be shared across folds and only stayed
        # aligned because every fold walks the grid in the same order.)
        hyper_param_list = []
        val_mcc_result_list = []
        test_ensemble_dict = {metric: [] for metric in metric_list}

        # Grid search for find best hyperparameter by Validation MCC
        for top_k in tqdm(range(1, hyper_dict['max_top_k']+1), desc='Grid seach for find best hyperparameter...'):
            for lr in hyper_dict['lr']:
                for reg in hyper_dict['reg']:
                    hyper_param_list.append([top_k, lr, reg])
                    early_stopping = EarlyStopping(patience=hyper_dict['patience'], delta=hyper_dict['delta'])

                    # Define SDGCCA with 3 modality
                    model = SDGCCA_3_M(m1_embedding_list, m2_embedding_list, m3_embedding_list, top_k).to(device)

                    # Optimizer (weight_decay carries the regularization term)
                    clf_optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)

                    # Cross Entropy Loss
                    criterion = nn.CrossEntropyLoss()

                    # Model Train
                    for epoch in range(hyper_dict['epoch']):
                        model.train()

                        # The correlation loss value itself is never
                        # back-propagated here; the call is kept because
                        # cal_loss presumably updates internal state that
                        # predict() relies on — TODO confirm in SDGCCA.py.
                        out1, out2, out3 = model(x_train_1, x_train_2, x_train_3)
                        model.cal_loss([out1, out2, out3, y_train_onehot])

                        # Classification loss: sum over the three modality heads
                        clf_optimizer.zero_grad()

                        y_hat1, y_hat2, y_hat3, _ = model.predict(x_train_1, x_train_2, x_train_3)
                        clf_loss1 = criterion(y_hat1, y_train.squeeze())
                        clf_loss2 = criterion(y_hat2, y_train.squeeze())
                        clf_loss3 = criterion(y_hat3, y_train.squeeze())

                        clf_loss = clf_loss1 + clf_loss2 + clf_loss3

                        clf_loss.backward()
                        clf_optimizer.step()

                        # Model Validation: ensemble loss drives early stopping
                        with torch.no_grad():
                            model.eval()
                            _, _, _, y_ensemble = model.predict(x_val_1, x_val_2, x_val_3)
                            val_loss = criterion(y_ensemble, y_val.squeeze())

                            early_stopping(val_loss)
                            if early_stopping.early_stop:
                                break

                    # NOTE(review): no checkpoint is saved or restored, so the
                    # model scored below is the one from the last executed
                    # epoch, not the epoch with the lowest validation loss.
                    model.eval()

                    # Model Validation (only the MCC is used for selection)
                    _, _, _, val_mcc = _evaluate_ensemble(model, x_val_1, x_val_2, x_val_3, y_val)
                    val_mcc_result_list.append(val_mcc)

                    # Model Test
                    test_acc, test_f1, test_auc, test_mcc = _evaluate_ensemble(
                        model, x_test_1, x_test_2, x_test_3, y_test)
                    ensemble_result = [test_acc, test_f1, test_auc, test_mcc]
                    for metric, value in zip(metric_list, ensemble_result):
                        test_ensemble_dict[metric].append(value)

        # Index of the grid combination with the highest validation MCC
        # (first one wins on ties, as np.argmax returns the first maximum).
        best_idx = int(np.argmax(val_mcc_result_list))

        # Find best hyperparameter
        best_hyper_param_list.append(hyper_param_list[best_idx])

        # Append the test result of the selected combination
        for metric in metric_list:
            ensemble_list[metric].append(test_ensemble_dict[metric][best_idx])

    return ensemble_list, best_hyper_param_list

### Setting Hyperparameter

In [4]:
# Experiment configuration consumed by train_SDGCCA.
#   epoch / patience / delta : training length and early-stopping settings
#   lr / reg / max_top_k     : grid-search axes (reg is passed as Adam weight_decay)
#   embedding_size           : hidden widths appended after each modality's input width
# Device: the second GPU ("cuda:1") is still preferred when it exists, but a
# host with a single visible GPU now falls back to "cuda:0" instead of failing
# on a non-existent device index; without CUDA the CPU is used.
hyper_dict = {'epoch': 1000, 'delta': 0, 'random_seed': random_seed,
              'device': torch.device(
                  "cuda:{}".format(min(1, torch.cuda.device_count() - 1))
                  if torch.cuda.is_available() else "cpu"),
              'lr': [0.0001,0.00001], 'reg': [0, 0.01,0.0001], 
              'patience': 30, 'embedding_size': [256, 64, 16], 'max_top_k': 10}

### Model Training & Check Performance

In [5]:
# Run the whole cross-validated grid search: per-fold test metrics plus the
# hyperparameters selected in each fold.
ensemble_list, hyper = train_SDGCCA(hyper_dict)

# Aggregate the per-fold test metrics (the first four entries of the result
# are printed in ACC, F1, AUC, MCC order).
performance_result = check_mean_std_performance(ensemble_list)

print('Test Performance')
summary_line = 'ACC: {} F1: {} AUC: {} MCC: {}'.format(
    *(performance_result[idx] for idx in range(4)))
print(summary_line)

print('\nBest Hyperparameter')
# Label text is reproduced exactly (including the "Rage" spelling) so that it
# keeps matching the output stored with the notebook.
row_template = 'CV: {} Best k: {} Learning Rage: {} Regularization Term: {}'
for fold_no, best_setting in enumerate(hyper, start=1):
    print(row_template.format(fold_no, best_setting[0], best_setting[1], best_setting[2]))

CV...:   0%|          | 0/5 [00:00<?, ?it/s]

Grid seach for find best hyperparameter...:   0%|          | 0/10 [00:00<?, ?it/s]

Grid seach for find best hyperparameter...:   0%|          | 0/10 [00:00<?, ?it/s]

Grid seach for find best hyperparameter...:   0%|          | 0/10 [00:00<?, ?it/s]

Grid seach for find best hyperparameter...:   0%|          | 0/10 [00:00<?, ?it/s]

Grid seach for find best hyperparameter...:   0%|          | 0/10 [00:00<?, ?it/s]

Test Performance
ACC: 50.27+-2.58 F1: 38.35+-10.52 AUC: 48.03+-3.45 MCC: -0.96+-5.80

Best Hyperparameter
CV: 1 Best k: 5 Learning Rage: 1e-05 Regularization Term: 0.0001
CV: 2 Best k: 4 Learning Rage: 1e-05 Regularization Term: 0
CV: 3 Best k: 2 Learning Rage: 1e-05 Regularization Term: 0.01
CV: 4 Best k: 8 Learning Rage: 1e-05 Regularization Term: 0.0001
CV: 5 Best k: 4 Learning Rage: 1e-05 Regularization Term: 0
