### Import Library

In [1]:
from util import *
from model import *
import torch
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import StratifiedKFold, train_test_split

### Hyperparameter Setting

<table style="width: 100%">
    <colgroup>
       <col span="1" style="width: 60%;">
       <col span="1" style="width: 20%;">
       <col span="1" style="width: 20%;">
    </colgroup>
    <thead>
        <tr>
            <th>Description</th>
            <th>Code</th>
            <th>Value</th>
        </tr>
    </thead>
    <tbody>
        <tr>
            <td>Learning rate ($\eta_1$) for training the multiple coefficient matrices ($V_{i,r}$)</td>
            <td>re_lr</td>
            <td>$1\mathrm{e}{-3}$</td>
        </tr>
    </tbody>
    <tbody>
        <tr>
            <td>Regularization rate ($\alpha$) for training the multiple coefficient matrices ($V_{i,r}$)</td>
            <td>re_reg</td>
            <td>$1\mathrm{e}{-4}$</td>
        </tr>
    </tbody>
    <tbody>
        <tr>
            <td>Learning rate ($\eta_2$) for the classifier ($f(\cdot)$) and the common basis matrix ($U$)</td>
            <td>clf_lr</td>
            <td>$1\mathrm{e}{-3}$</td>
        </tr>
    </tbody>
    <tbody>
        <tr>
            <td>Regularization rate ($\beta$) for the classifier ($f(\cdot)$) and the common basis matrix ($U$)</td>
            <td>clf_reg</td>
            <td>$1\mathrm{e}{-3}$</td>
        </tr>
    </tbody>
    <tbody>
        <tr>
            <td>Dimensions of the multiple coefficient matrices, $[D_{i2}, D_{i3}]$, followed by the dimension $h$ of the classifier's hidden layer</td>
            <td>[du1, du2, du3]</td>
            <td>$[110, 90, 70]$</td>
        </tr>
    </tbody>
    <tbody>
        <tr>
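            <td>Warm-up patience ($T_{\text{p}}$): number of initial epochs trained on the reconstruction losses only, before the classification loss is introduced</td>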
            <td>patience</td>
            <td>30</td>
        </tr>
    </tbody>
</table>

In [2]:
# Device Setting: use the first CUDA GPU when available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Hyperparameter Setting
# NOTE(review): 'clf_reg' (1e-4) and 're_reg' (1e-3) here are swapped relative to
# the markdown table above (clf_reg = 1e-3, re_reg = 1e-4) -- confirm which is intended.
hyper_dict = {
    'patience': 30,    # warm-up: train() only uses the classifier once epoch > patience
    'epoch': 50000,    # upper bound on training epochs (early stopping may end sooner)
    'device': device,
    # du1, du2 are passed to DMF; du2, du3 to Softmax_Classifier
    'du1': 110,
    'du2': 90,
    'du3': 70,
    # Adam lr / weight_decay for the classifier and the common parameters of net.get_param()
    'clf_lr': 1e-3,
    'clf_reg': 1e-4,
    # Adam lr / weight_decay for the individual parameters of net.get_param()
    're_lr': 1e-3,
    're_reg': 1e-3,
}

### Example: Toy Data with 3 Modalities and Labels
- Modality 1: 1000 x 1000
- Modality 2: 1000 x 2000
- Modality 3: 1000 x 3000
- Label: Binary Class

In [3]:
# Prepare Toy Example
# A seeded generator makes "Restart & Run All" reproduce the same toy data.
SEED = 3
N_SAMPLES = 1000                    # rows shared by every modality
MODALITY_DIMS = (1000, 2000, 3000)  # feature counts of modality 1, 2, 3
rng = np.random.default_rng(SEED)
modality_1, modality_2, modality_3 = (rng.random((N_SAMPLES, dim)) for dim in MODALITY_DIMS)
label = rng.integers(2, size=N_SAMPLES)  # binary class labels in {0, 1}

# Split Toy Example Dataset: 60% train / 20% validation / 20% test.
# Derive the sample count from the labels instead of hard-coding it.
all_index = np.arange(len(label))
train_test_index, val_index = train_test_split(all_index, test_size=0.2, random_state=3)
# 0.25 of the remaining 80% -> 20% of the full dataset for testing.
train_index, test_index = train_test_split(train_test_index, test_size=0.25, random_state=3)
index = [train_index, val_index, test_index]

### Model Train & Test

In [4]:
# Model Train with 3 Modality

# Log-message pieces. The 16-space indent reproduces the layout of the earlier
# backslash-continued format strings, so printed output stays the same.
_LOG_INDENT = ' ' * 16
_RECON_FMT = ('M1 Reconstruction Loss: {:.4f}, M2 Reconstruction Loss: {:.4f}, '
              'M3 Reconstruction Loss: {:.4f}')
_METRIC_FMT = 'Ba: {:.4f}, F1: {:.4f}, AUC: {:.4f} MCC: {:.4f}'


def _log_losses(stage, epoch, total_epochs, recon_losses, clf_loss=None, end=''):
    """Print one epoch's reconstruction losses, plus the classification loss if given.

    ``recon_losses`` is the ``(m1, m2, m3)`` triple returned by the DMF forward
    pass; its entries and ``clf_loss`` may be 0-d tensors or floats.
    """
    lines = ['{}: Epoch [{}/{}]'.format(stage, epoch, total_epochs),
             _LOG_INDENT + _RECON_FMT.format(*(float(loss) for loss in recon_losses))]
    if clf_loss is not None:
        lines.append(_LOG_INDENT + 'Classification Loss: {:.4f}'.format(float(clf_loss)))
    print('\n'.join(lines) + end)


def _reconstruction_grad(net, loss):
    """Backpropagate one reconstruction loss and return a copy of its gradient on ``net.u``.

    ``net.u.grad`` is zeroed afterwards so the next loss's gradient is isolated.
    """
    loss.backward()
    grad = net.u.grad.clone()
    net.u.grad.zero_()
    return grad


def _evaluate_split(clf, features, y_true):
    """Score ``clf`` on one split and return ``calculate_metric``'s ``(ba, f1, auc, mcc)``.

    ``features`` are the rows of ``net.u`` for the split and ``y_true`` the
    matching labels. Prediction runs without building an autograd graph.
    """
    with torch.no_grad():
        prob, prediction = clf.predict(features)
    prob = prob.detach().cpu().numpy()
    prediction = prediction.detach().cpu().numpy()
    return calculate_metric(y_true, prediction, prob)


def train(modality_1, modality_2, modality_3, label, index, hyper_dict):
    """Jointly train a DMF model and a softmax classifier on three modalities.

    Every epoch backpropagates the three reconstruction losses separately to
    obtain per-modality gradients on the common basis matrix ``net.u``. While
    ``epoch <= hyper_dict['patience']`` (warm-up), ``net.u`` is updated with
    their mean. Afterwards the classifier is trained on SMOTE-oversampled rows
    of ``net.u[train_index]``, and the four gradients are combined with
    ``project_conflicting``. The validation classification loss drives early
    stopping and best-checkpoint selection.

    Parameters
    ----------
    modality_1, modality_2, modality_3 : array-like
        Per-modality data passed to ``DMF`` and to its forward pass.
    label : numpy.ndarray
        Class labels, indexed by the arrays in ``index``.
    index : sequence of three index arrays
        ``(train_index, val_index, test_index)``; they select rows of ``label``
        and of ``net.u``.
    hyper_dict : dict
        Must contain 'patience', 'epoch', 'device', 'du1', 'du2', 'du3',
        'clf_lr', 'clf_reg', 're_lr', 're_reg'. The optional 'es_patience'
        key (default 300) sets the early-stopping patience.

    Returns
    -------
    (net, clf)
        The DMF model and the classifier, restored to the weights with the
        lowest validation classification loss when such a checkpoint exists.
    """
    import copy  # deep-copies the best state_dicts

    device = hyper_dict['device']
    total_epochs = hyper_dict['epoch']
    train_index, val_index, test_index = index
    y_train = label[train_index]
    y_val = label[val_index]
    y_val_tensor = torch.tensor(y_val).to(device)
    n_train = len(train_index)

    # Early stopping / best-checkpoint tracking
    early_stopping = EarlyStopping(patience=hyper_dict.get('es_patience', 300), delta=0)
    best_loss = float('inf')  # np.Inf no longer exists in NumPy >= 2.0
    net_best_model_sd = None
    clf_best_model_sd = None

    # Models
    net = DMF(modality_1, modality_2, modality_3,
              hyper_dict['du1'], hyper_dict['du2'], device).to(device)
    clf = Softmax_Classifier(hyper_dict['du2'], hyper_dict['du3']).to(device)

    # Optimizers: individual parameters use re_*; common parameters and the classifier use clf_*
    individual_param_list, common_param = net.get_param()
    mf_individual_optimizer = torch.optim.Adam(individual_param_list, lr=hyper_dict['re_lr'],
                                               weight_decay=hyper_dict['re_reg'])
    mf_common_optimizer = torch.optim.Adam(common_param, lr=hyper_dict['clf_lr'],
                                           weight_decay=hyper_dict['clf_reg'])
    clf_optimizer = torch.optim.Adam(clf.parameters(), lr=hyper_dict['clf_lr'],
                                     weight_decay=hyper_dict['clf_reg'])

    # With an int random_state every fit_resample call reseeds, so one instance suffices.
    smote = SMOTE(random_state=0)
    train_clf_loss = None  # stays None until the classifier has been trained once

    for i in range(total_epochs):
        # ---- Training step ----
        net.train()
        clf.train()
        mf_individual_optimizer.zero_grad()
        mf_common_optimizer.zero_grad()
        clf_optimizer.zero_grad()

        train_recon = net(modality_1, modality_2, modality_3)
        g1, g2, g3 = [_reconstruction_grad(net, loss) for loss in train_recon]

        if i > hyper_dict['patience']:
            # Oversample the (detached) training rows of U to balance the classes.
            u_over, y_train_over = smote.fit_resample(
                net.u[train_index].detach().cpu().numpy(), y_train)
            u_over = torch.tensor(u_over, dtype=torch.float32, device=device, requires_grad=True)

            train_clf_loss = clf(u_over, torch.tensor(y_train_over).to(device))
            train_clf_loss.backward()

            # SMOTE.fit_resample stacks the original samples before the synthetic
            # ones, so the first n_train gradient rows map back onto train_index.
            g4 = torch.zeros_like(net.u.grad)
            g4[train_index] = u_over.grad[:n_train]

            # Projected gradient: rows in train_index see g1-g4, the rest only g1-g3.
            proj_grad = project_conflicting([g1.flatten(), g2.flatten(), g3.flatten()], g4.flatten())
            net.u.grad = proj_grad.reshape_as(net.u.grad)

            mf_individual_optimizer.step()
            mf_common_optimizer.step()
            clf_optimizer.step()
        else:
            # Warm-up: reconstruction only, with the averaged gradient on U.
            net.u.grad = (g1 + g2 + g3) / 3
            mf_individual_optimizer.step()
            mf_common_optimizer.step()

        if i % 100 == 0:
            _log_losses('Training', i, total_epochs, train_recon, train_clf_loss)
            if train_clf_loss is not None:
                print(_METRIC_FMT.format(*_evaluate_split(clf, net.u[train_index], y_train)))

        # ---- Validation ----
        net.eval()
        clf.eval()
        with torch.no_grad():
            val_recon = net(modality_1, modality_2, modality_3)
            val_clf_loss = clf(net.u[val_index], y_val_tensor)

        early_stopping(val_clf_loss)
        if val_clf_loss < best_loss:
            net_best_model_sd = copy.deepcopy(net.state_dict())
            clf_best_model_sd = copy.deepcopy(clf.state_dict())
            best_loss = val_clf_loss.item()

        if early_stopping.early_stop:
            print('Early Stopping... Epoch [{}/{}]'.format(i, total_epochs))
            print('Best CLassification Loss: {:.4f}'.format(best_loss))
            break

        if i % 100 == 0:
            if i == 0:
                _log_losses('Validation', i, total_epochs, val_recon, end='\n')
            else:
                _log_losses('Validation', i, total_epochs, val_recon, val_clf_loss)
                val_metrics = _evaluate_split(clf, net.u[val_index], y_val)
                print(_METRIC_FMT.format(*val_metrics) + '\n')

    # ---- Test with the best checkpoint ----
    if net_best_model_sd is not None:
        net.load_state_dict(net_best_model_sd)
        clf.load_state_dict(clf_best_model_sd)
    else:
        # No validation loss ever fell below +inf (e.g. NaN losses, or epoch == 0).
        print('No best checkpoint recorded; evaluating the final weights.')
    net.eval()
    clf.eval()

    ba, f1, auc, mcc = _evaluate_split(clf, net.u[test_index], label[test_index])
    print('\nTest Performace\nBa: {:.4f}, F1: {:.4f}, AUC: {:.4f} MCC: {:.4f}'.format(ba, f1, auc, mcc))

    return net, clf

### Example Code Train & Test

In [5]:
# Train on the toy data; train() returns the DMF model and classifier restored
# to the weights with the lowest validation classification loss.
net, clf = train(modality_1, modality_2, modality_3, label, index, hyper_dict)

Training: Epoch [0/50000]
                M1 Reconstruction Loss: 0.3333, M2 Reconstruction Loss: 0.3337, M3 Reconstruction Loss: 0.3337
Validation: Epoch [0/50000]
                M1 Reconstruction Loss: 0.3333, M2 Reconstruction Loss: 0.3337, M3 Reconstruction Loss: 0.3337

Training: Epoch [100/50000]
                M1 Reconstruction Loss: 0.3333, M2 Reconstruction Loss: 0.3337, M3 Reconstruction Loss: 0.3337
                Classification Loss: 0.3169
Ba: 1.0000, F1: 1.0000, AUC: 1.0000 MCC: 1.0000
Validation: Epoch [100/50000]
                M1 Reconstruction Loss: 0.3333, M2 Reconstruction Loss: 0.3337, M3 Reconstruction Loss: 0.3337
                Classification Loss: 0.6918
Ba: 0.4814, F1: 0.7063, AUC: 0.5054 MCC: -0.0720

Training: Epoch [200/50000]
                M1 Reconstruction Loss: 0.3333, M2 Reconstruction Loss: 0.3337, M3 Reconstruction Loss: 0.3337
                Classification Loss: 0.3167
Ba: 1.0000, F1: 1.0000, AUC: 1.0000 MCC: 1.0000
Validation: Epoch [200/500