# Scanner-Generalization Neural Networks (SGNN)

In [None]:
import os
import gc
import time
import math
import pickle
import random
import itertools
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
from decimal import Decimal
from datetime import datetime as dt
from pytz import timezone
from sklearn.model_selection import ParameterGrid
from sklearn.metrics import balanced_accuracy_score

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable, Function
import torch.optim as optim
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
import warnings
# Suppress every Python warning (all categories, including deprecation
# notices) for the remainder of the session.
warnings.filterwarnings(action='ignore')

## Assigning GPU for training

In [None]:
# Number GPUs by PCI bus ID and expose only GPU 0 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": str(0),
})

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

## Controlling the random seed

In [None]:
# Seed used when no `seed` has been defined earlier in the session.
DEFAULT_SEED = 0


def seed_everything(seed=DEFAULT_SEED):
    """Seed every RNG the pipeline uses and make cuDNN deterministic.

    Args:
        seed: integer seed applied to Python's `random`, NumPy, PyTorch (CPU
            and all CUDA devices) and exported as PYTHONHASHSEED.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose kernels by timing them, which can
    # select non-deterministic algorithms; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False


# The original `def seed_everything(seed=seed)` referenced `seed` before it was
# ever defined; reuse an existing session value if there is one.
seed = globals().get("seed", DEFAULT_SEED)
seed_everything(seed)

## Fixed parameters for leave-one-site-out cross-validation (LOSOCV; $n_{test} = 18$)

In [None]:
# Outer-fold indices 0..17 for leave-one-site-out CV.
tot_outer_cv = list(range(18))
# Two-site validation splits sampled per outer fold (None -> every site pair).
n_inner_valid = 5
# Root output directory name.
output_folder = "SGNN"

## Loading dataset

In [None]:
# Load features (X) and per-subject labels (y) from the .npz archive.
data = np.load("RSFC_p_factor_ABCD.npz")
X, y = data["X"], data["y"]

In [None]:
# Column positions inside the label matrix y.
p_factor_idx, site_idx, scanner_idx = 0, 1, 2

## Preparing LOSOCV by splitting the data

In [None]:
# Split subject indices for leave-one-site-out CV. Each outer fold holds out
# one site for testing; each inner fold holds out one scanner-0 site and one
# scanner-1 site of the outer training set for validation.

# Label table used for the splits (it was referenced but never defined).
# It keeps the default RangeIndex, so its index values are row positions
# into X / y; the inner folds below rely on that.
y_df = pd.DataFrame({
    "p_factor": y[:, p_factor_idx],
    "site": y[:, site_idx],
    "scanner": y[:, scanner_idx],
})

site_unq = np.unique(y[:, site_idx])
data_idx = np.arange(y.shape[0])

outer_train_folds_idx = []
outer_test_folds_idx = []
inner_folds_idx = []

split_seed = 1000
n_outer_repeat = len(site_unq)
n_inner_repeat_list = []

# Outer loop: one fold per site
for n_outer, outer_test_site in enumerate(site_unq):
    outer_train_idx = np.where(y[:, site_idx] != outer_test_site)[0]
    outer_test_idx = np.where(y[:, site_idx] == outer_test_site)[0]
    outer_train_folds_idx.append(outer_train_idx)
    outer_test_folds_idx.append(outer_test_idx)

    outer_test_df = y_df.iloc[outer_test_idx]
    outer_train_df = y_df.iloc[outer_train_idx]
    valid_0_df = outer_train_df[outer_train_df["scanner"] == 0]
    valid_1_df = outer_train_df[outer_train_df["scanner"] == 1]
    valid_0_site_unq = pd.Series(np.unique(valid_0_df["site"]))
    valid_1_site_unq = pd.Series(np.unique(valid_1_df["site"]))

    inner_train_folds_idx = []
    inner_valid_folds_idx = []

    sel_val_seed = n_outer + split_seed
    # Every (scanner-0 site, scanner-1 site) pair is a candidate validation split
    inner_valid_site_list = list(
        itertools.product(valid_0_site_unq.values, valid_1_site_unq.values)
    )
    if n_inner_valid is None:
        n_inner_repeat = len(inner_valid_site_list)  # Full combination
    else:
        # DataFrame.sample(replace=False) raises when asked for more rows
        # than exist, so never request more pairs than are available.
        n_inner_repeat = min(n_inner_valid, len(inner_valid_site_list))
    n_inner_repeat_list.append(n_inner_repeat)

    inner_valid_site_df = pd.DataFrame(inner_valid_site_list, columns=["SI", "GE"]).sample(
        n=n_inner_repeat, replace=False, random_state=sel_val_seed
    )
    # Inner loop
    for n_inner in range(n_inner_repeat):
        inner_valid_site_0 = inner_valid_site_df["SI"].values[n_inner]
        inner_valid_site_1 = inner_valid_site_df["GE"].values[n_inner]
        inner_valid_site = [inner_valid_site_0, inner_valid_site_1]
        inner_valid_cond_0 = (outer_train_df["site"] == inner_valid_site_0)
        inner_valid_cond_1 = (outer_train_df["site"] == inner_valid_site_1)
        inner_valid_df = outer_train_df[inner_valid_cond_0 | inner_valid_cond_1]

        # Index values are positional row indices into X / y (see y_df above)
        inner_train_idx = np.setdiff1d(
            outer_train_df.index.values, inner_valid_df.index.values)
        inner_valid_idx = inner_valid_df.index.values
        print("[{}/{}] inner fold: train: {}, valid: {}".
              format(n_inner + 1, n_inner_repeat, len(inner_train_idx), len(inner_valid_idx)),
              end=", ")
        print("valid site: {}, {}".format(int(inner_valid_site[0]), int(inner_valid_site[1])))
        inner_train_folds_idx.append(inner_train_idx)
        inner_valid_folds_idx.append(inner_valid_idx)

    inner_folds_idx.append([inner_train_folds_idx, inner_valid_folds_idx])

    outer_test_scnr_label = np.unique(outer_test_df["scanner"])
    outer_train_scnr_label = np.unique(outer_train_df["scanner"])
    inner_valid_scnr_label = np.unique(inner_valid_df["scanner"])

    print("[{}/{}] outer fold: train: {}, test: {}"
          .format(n_outer + 1, len(site_unq), len(outer_train_idx), len(outer_test_idx)),
          end=" --> ")
    print("outer test site: {}\n".format(int(outer_test_site)))

In [None]:
# ReduceLROnPlateau settings. run_inner_fold steps the scheduler on the
# first controlled layer's Hoyer sparsity, hence mode="max".
mode = "max"
lr_patience = 5
min_lr = 1e-08
# LR decay factor = lr_alpha * target_hsp + lr_beta (computed per fold).
lr_alpha = -1.5
lr_beta = 1.7

swa_lr = 5e-03
momentum = 0.90   # SGD momentum used by get_optimizer
l1_param = 0      # plain L1 weight for layers whose wsc_flag is 0
early_stopping_patience = 150

# Derived from the data (previously hard-coded as 61776) so the first
# layer always matches X; any other width would fail in the forward pass.
input_dim = X.shape[1]
n_classes = len(np.unique(y[:, scanner_idx]))
output_prd_dim = 1
output_dsc_dim = n_classes

# Per-layer weight sparsity control for fe_1, prd_1, dsc_1 (in that order).
wsc_flag = [1, 1, 1]
beta_lr = [1e-04, 1e-03, 1e-03]
max_beta = [1e-02, 5e-02, 5e-02]
# Count controlled layers with the same `!= 0` test calc_l1 uses.
n_wsc = sum(1 for flag in wsc_flag if flag != 0)

outer_n_splits = n_outer_repeat
# Smallest per-outer-fold inner count, so it is valid for every outer fold
# (not just whichever fold the split loop happened to finish on).
inner_n_splits = min(n_inner_repeat_list)

In [None]:
# Training dataset
class TrainDataset(Dataset):
    """Map-style dataset over in-memory training arrays.

    Item ``idx`` is the pair (row ``idx`` of X_train, row ``idx`` of
    y_train), each converted to a float32 tensor.
    """

    def __init__(self, X_train, y_train):
        self.X_train, self.y_train = X_train, y_train

    def __len__(self):
        return len(self.X_train)

    def __getitem__(self, idx):
        features = torch.from_numpy(self.X_train[idx]).float()
        labels = torch.from_numpy(self.y_train[idx]).float()
        return features, labels

In [None]:
# Validation dataset
class ValidDataset(Dataset):
    """Map-style dataset over in-memory validation arrays.

    Item ``idx`` is the pair (row ``idx`` of X_valid, row ``idx`` of
    y_valid), each converted to a float32 tensor.
    """

    def __init__(self, X_valid, y_valid):
        self.X_valid, self.y_valid = X_valid, y_valid

    def __len__(self):
        return len(self.X_valid)

    def __getitem__(self, idx):
        features = torch.from_numpy(self.X_valid[idx]).float()
        labels = torch.from_numpy(self.y_valid[idx]).float()
        return features, labels

In [None]:
# Test dataset
class TestDataset(Dataset):
    """Map-style dataset over in-memory test arrays.

    Item ``idx`` is the pair (row ``idx`` of X_test, row ``idx`` of y_test),
    each converted to a float32 tensor.
    """

    def __init__(self, X_test, y_test):
        self.X_test, self.y_test = X_test, y_test

    def __len__(self):
        return len(self.X_test)

    def __getitem__(self, idx):
        features = torch.from_numpy(self.X_test[idx]).float()
        labels = torch.from_numpy(self.y_test[idx]).float()
        return features, labels

# Function for gradient reversal layer (Ganin et al., 2015)

In [None]:
class GradRevFunc(Function):
    """Gradient reversal (Ganin et al., 2015).

    Forward: returns a copy of ``x`` unchanged.
    Backward: returns the incoming gradient multiplied by ``-lambda_``;
    ``lambda_`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Keep the scale for the backward pass.
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        # Match the gradient's dtype/device before scaling.
        scale = grads.new_tensor(ctx.lambda_)
        reversed_grads = scale * grads.neg()
        return reversed_grads, None

In [None]:
class GradRev(torch.nn.Module):
    """Module wrapper around GradRevFunc with a fixed reversal strength."""

    def __init__(self, lambda_=0.0):
        super().__init__()
        # Magnitude applied, with a minus sign, to gradients on the way back.
        self.lambda_ = lambda_

    def forward(self, x):
        reverse = GradRevFunc.apply
        return reverse(x, self.lambda_)

# Defining SGNN

In [None]:
class SGNN(nn.Module):
    """Scanner-generalization network.

    A shared feature extractor feeds two heads:
      * predictor:     fe_hidden -> prd_hidden -> output_prd_dim
      * discriminator: fe_hidden -> (gradient reversal) -> dsc_hidden
                       -> output_dsc_dim
    Each hidden stage is Linear -> BatchNorm1d -> activation -> Dropout.

    ``input_dim``, ``output_prd_dim`` and ``output_dsc_dim`` are read from
    module-level globals, not passed in.

    Args:
        fe_hidden: width of the shared feature layer.
        prd_hidden: width of the predictor hidden layer (2nd positional arg).
        dsc_hidden: width of the discriminator hidden layer (3rd positional arg).
        dropout_fe, dropout_prd, dropout_dsc: dropout probabilities per stage.
        act_func_name: name passed to ``get_act_func``. A single activation
            module instance is reused by every stage, so e.g. 'prelu' shares
            one learnable slope across all layers.
        lambda_: gradient-reversal strength for the discriminator branch.
    """
    # NOTE: attribute names and creation order are load-bearing. calc_l1
    # selects weights whose names contain "fe"/"prd_1"/"dsc_1" (and not "bn")
    # and maps them to wsc_flag entries in named_parameters() order; saved
    # state_dict keys and the RNG stream consumed by layer construction also
    # depend on them.
    def __init__(self, fe_hidden, prd_hidden, dsc_hidden,
                 dropout_fe, dropout_prd, dropout_dsc, act_func_name, lambda_):
        super(SGNN, self).__init__()
        # Shared feature extractor
        self.fe_1 = nn.Linear(input_dim, fe_hidden)
        self.fe_bn_1 = nn.BatchNorm1d(fe_hidden)
        
        # p-factor predictor head
        self.prd_1 = nn.Linear(fe_hidden, prd_hidden)
        self.prd_bn_1 = nn.BatchNorm1d(prd_hidden)
        self.prd_2 = nn.Linear(prd_hidden, output_prd_dim)
        
        # Scanner discriminator head
        self.dsc_1 = nn.Linear(fe_hidden, dsc_hidden)
        self.dsc_bn_1 = nn.BatchNorm1d(dsc_hidden)
        self.dsc_2 = nn.Linear(dsc_hidden, output_dsc_dim)

        self.dropout_fe = nn.Dropout(p=dropout_fe)
        self.dropout_prd = nn.Dropout(p=dropout_prd)
        self.dropout_dsc = nn.Dropout(p=dropout_dsc)
        
        self.act_func = get_act_func(act_func_name)
        self.GradRev = GradRev(lambda_)
        # Re-initialise all Linear layers after they have been created.
        self.weights_init()
    
    def forward(self, x):
        """Return ``(x_prd, x_dsc)``: predictor output and discriminator logits.

        Assumes ``x`` is (batch, input_dim), matching ``self.fe_1``.
        """
        x_ftr = self.fe_1(x)
        x_ftr = self.fe_bn_1(x_ftr)
        x_ftr = self.act_func(x_ftr)
        x_ftr = self.dropout_fe(x_ftr)
        
        x_prd = self.prd_1(x_ftr)
        x_prd = self.prd_bn_1(x_prd)
        x_prd = self.act_func(x_prd)
        x_prd = self.dropout_prd(x_prd)
        x_prd = self.prd_2(x_prd)
        
        # Identity on the forward pass; gradients coming back from the
        # discriminator are multiplied by -lambda_ (GradRevFunc), pushing the
        # shared features towards scanner invariance.
        x_rev = self.GradRev(x_ftr)
        x_dsc = self.dsc_1(x_rev)
        x_dsc = self.dsc_bn_1(x_dsc)
        x_dsc = self.act_func(x_dsc)
        x_dsc = self.dropout_dsc(x_dsc)
        x_dsc = self.dsc_2(x_dsc)
        
        return x_prd, x_dsc
    
    def weights_init(self):
        """Kaiming-normal (fan_in, ReLU gain) weights and N(0, 0.01^2) biases
        for every nn.Linear; BatchNorm/activation parameters keep defaults."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
                nn.init.normal_(m.bias, std=0.01)

In [None]:
def get_optimizer(model, opt_name, learning_rate=None, l2_param=None):
    """Build the optimizer named ``opt_name`` over ``model``'s parameters.

    Args:
        model: module whose parameters are optimized.
        opt_name: case-insensitive 'momentum' (SGD with the global
            ``momentum``), 'nag' (SGD with Nesterov momentum) or 'adam'.
        learning_rate: learning rate. When None it is not passed, so the
            optimizer's own default applies (SGD may require it, depending on
            the torch version).
        l2_param: weight decay; None means no weight decay (0.0).

    Returns:
        A ``torch.optim`` optimizer.

    Raises:
        ValueError: if ``opt_name`` is not supported. (The original called
            ``sys.exit`` without importing ``sys``, which raised NameError.)
    """
    # weight_decay=None is rejected by torch optimizers; treat it as "no decay".
    kwargs = {"weight_decay": 0.0 if l2_param is None else l2_param}
    if learning_rate is not None:
        kwargs["lr"] = learning_rate

    lower_opt_name = opt_name.lower()
    if lower_opt_name == 'momentum':
        return optim.SGD(model.parameters(), momentum=momentum, **kwargs)
    if lower_opt_name == 'nag':
        return optim.SGD(model.parameters(), momentum=momentum, nesterov=True, **kwargs)
    if lower_opt_name == 'adam':
        return optim.Adam(model.parameters(), **kwargs)
    raise ValueError("Illegal argument for optimizer type: {!r}".format(opt_name))

In [None]:
def get_act_func(act_func_name):
    """Return a fresh activation module for a case-insensitive name.

    Supported: 'relu', 'prelu', 'elu', 'silu', 'leakyrelu', 'tanh'.

    Raises:
        ValueError: for any other name. (The original called ``sys.exit``
            without importing ``sys``, which raised NameError.)
    """
    factories = {
        'relu': nn.ReLU,
        'prelu': nn.PReLU,
        'elu': nn.ELU,
        'silu': nn.SiLU,
        'leakyrelu': nn.LeakyReLU,
        'tanh': nn.Tanh,
    }
    try:
        factory = factories[act_func_name.lower()]
    except KeyError:
        raise ValueError(
            "Illegal argument for activation function type: {!r}".format(act_func_name)
        ) from None
    return factory()

# Functions for weight sparsity control with Hoyer's sparsity

In [None]:
def init_hsp(n_wsc, epochs):
    """Allocate zeroed sparsity-control state.

    Returns:
        (hsp_val, beta_val, hsp_list, beta_list): current per-layer Hoyer
        sparsity and L1 coefficient (shape (n_wsc,)), and their per-epoch
        histories (shape (n_wsc, epochs)). All four are independent tensors.
    """
    hsp_list, beta_list = (torch.zeros(n_wsc, epochs) for _ in range(2))
    hsp_val = torch.zeros(n_wsc)
    beta_val = hsp_val.clone()
    return hsp_val, beta_val, hsp_list, beta_list

In [None]:
# Weight sparsity control with Hoyer's sparsity (layer-wise)
def calc_hsp(w, beta, max_beta, beta_lr, tg_hsp):
    """Measure Hoyer's sparsity of a weight matrix and update its L1 coefficient.

    Hoyer's sparsity of the n = dim * n_nodes entries of ``w`` is
    ``(sqrt(n) - ||w||_1 / ||w||_2) / (sqrt(n) - 1)``. ``beta`` moves by
    ``beta_lr`` up when the sparsity is below ``tg_hsp`` and down when it is
    above, and is then clipped to ``[0, max_beta]``.

    Args:
        w: 2-D weight tensor (read only; no gradient flows through here).
        beta: current L1 coefficient (tensor).
        max_beta: upper bound for beta.
        beta_lr: step size for the beta update.
        tg_hsp: target sparsity, a scalar or 1-element sequence.

    Returns:
        [hsp, beta] as tensors; hsp lives on ``w``'s device.
    """
    dim, n_nodes = w.shape
    num_elements = dim * n_nodes
    w_val = w.detach()
    norm_ratio = torch.norm(w_val, 1) / torch.norm(w_val, 2)

    # Hoyer's sparsity level (already a detached tensor on w's device)
    num = math.sqrt(num_elements) - norm_ratio
    den = math.sqrt(num_elements) - 1
    hsp = num / den

    # Update beta towards the target sparsity
    target = torch.as_tensor(tg_hsp, dtype=hsp.dtype, device=hsp.device)
    beta = beta.clone() + beta_lr * torch.sign(target - hsp)

    # Clip to [0, max_beta]; always returns a tensor
    beta = torch.clamp(beta, min=0, max=max_beta)

    return [hsp, beta]

In [None]:
def calc_l1(model, epoch, hsp_val, beta_val, hsp_list, beta_list, tg_hsp):
    """Sum the L1 penalties of the model's fe / prd_1 / dsc_1 weight matrices.

    Parameters are visited in ``named_parameters()`` order; those whose name
    contains "weight" but not "bn", and contains "fe", "prd_1" or "dsc_1",
    are penalized. For the k-th such layer, ``wsc_flag[k]`` decides:
      * non-zero: ``calc_hsp`` updates ``hsp_val``/``beta_val`` in place
        (also written to column ``epoch - 1`` of ``hsp_list``/``beta_list``)
        and the penalty is ``||W||_1 * beta``;
      * zero: the penalty is ``||W||_1 * l1_param``.

    Returns:
        The summed penalty tensor, or None if no parameter matched.
    """
    total_reg = None
    n_layer = 0   # position among penalized layers (indexes wsc_flag)
    n_ctrl = 0    # position among sparsity-controlled layers

    for param_name, weight in model.named_parameters():
        is_dense_weight = "weight" in param_name and "bn" not in param_name
        is_target = "fe" in param_name or "prd_1" in param_name or "dsc_1" in param_name
        if not (is_dense_weight and is_target):
            continue

        if wsc_flag[n_layer] == 0:
            penalty = torch.norm(weight, 1) * l1_param
        else:
            hsp_val[n_ctrl], beta_val[n_ctrl] = calc_hsp(
                weight, beta_val[n_ctrl], max_beta[n_ctrl],
                beta_lr[n_ctrl], tg_hsp[n_ctrl]
            )
            hsp_list[n_ctrl, epoch - 1] = hsp_val[n_ctrl]
            beta_list[n_ctrl, epoch - 1] = beta_val[n_ctrl]
            penalty = torch.norm(weight, 1) * beta_val[n_ctrl].clone()
            n_ctrl += 1

        total_reg = penalty if total_reg is None else total_reg + penalty
        n_layer += 1

    return total_reg

In [None]:
def calc_pearsonr(x, y):
    """Pearson correlation coefficient between two 1-D tensors (as a tensor)."""
    x_centered = x - torch.mean(x)
    y_centered = y - torch.mean(y)
    numerator = torch.dot(x_centered, y_centered)
    denominator = x_centered.norm(2) * y_centered.norm(2)
    return numerator / denominator

In [None]:
def calc_mae(x, y):
    """Mean absolute error between two tensors, detached from autograd."""
    abs_err = (x - y).abs()
    return abs_err.mean().data

In [None]:
def _undersample_scanner_idx(scnr_smp):
    """Return batch row indices that keep every scanner-0 row and resample
    (with replacement) the same number of other-scanner rows, or None when
    the batch lacks either group."""
    minor_idx = np.where(scnr_smp == 0)[0]
    major_idx = np.where(scnr_smp != 0)[0]
    if len(minor_idx) == 0 or len(major_idx) == 0:
        return None
    major_smp_idx = np.random.choice(major_idx, size=len(minor_idx), replace=True)
    # np.where already yields integer indices; the old `.astype(np.int)`
    # fails on NumPy >= 1.24, where the `np.int` alias was removed.
    return np.concatenate((minor_idx, major_smp_idx))


def train(model, epoch, train_loader, optimizer, criterion_prd, criterion_dsc,
          hsp_val, beta_val, hsp_list, beta_list, tg_hsp, lambda_, X_train, y_train):
    """Run one training epoch and return training metrics.

    Each batch minimizes the p-factor loss plus the sparsity-controlled L1
    penalty from ``calc_l1``. ``calc_l1`` also updates ``hsp_val``/``beta_val``
    and ``hsp_list``/``beta_list`` in place. Once ``epoch > pretrain_epoch``,
    the scanner-discriminator loss on a scanner-balanced subsample of the
    batch is added as well.

    Args:
        model: SGNN instance.
        epoch: 1-based epoch number.
        train_loader: yields (features, labels) batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion_prd, criterion_dsc: predictor / discriminator losses.
        hsp_val, beta_val, hsp_list, beta_list, tg_hsp: sparsity-control
            state and targets forwarded to ``calc_l1``.
        lambda_: unused here; the reversal strength lives inside the model.
        X_train, y_train: full training arrays (numpy), used for the
            end-of-epoch balanced scanner accuracy.

    Returns:
        (prd_loss, dsc_loss, dsc_acc, train_corr, train_mae). Losses are the
        mean of per-batch (mean-reduced) losses. dsc_acc is the balanced
        scanner accuracy on the full training set. corr/mae compare the
        batch-wise p-factor predictions with their targets.
    """
    model.train()
    prd_loss = 0
    dsc_loss = 0
    n_batches = 0
    y_train_true = []
    y_train_pred = []

    for inputs, targets in train_loader:
        optimizer.zero_grad(set_to_none=True)
        inputs, targets = inputs.to(device), targets.to(device)
        output_prd, output_dsc = model(inputs)

        # 1. p-factor predictor loss
        target_prd = targets[:, p_factor_idx].view(-1, 1)
        running_prd_loss = criterion_prd(output_prd, target_prd)
        l1_norm = calc_l1(model, epoch, hsp_val, beta_val, hsp_list, beta_list, tg_hsp)

        if epoch > pretrain_epoch:
            # 2. Scanner discriminator loss on an undersampled, balanced batch
            target_dsc = targets[:, scanner_idx].long().view(-1)
            smp_idx = _undersample_scanner_idx(
                targets[:, scanner_idx].detach().cpu().numpy())
            if smp_idx is not None:
                running_dsc_loss = criterion_dsc(output_dsc[smp_idx], target_dsc[smp_idx])
                dsc_loss += running_dsc_loss.detach()
            else:
                running_dsc_loss = 0

            # Total loss
            running_loss = running_dsc_loss + running_prd_loss + l1_norm
        else:
            running_loss = running_prd_loss + l1_norm

        running_loss.backward()
        optimizer.step()

        prd_loss += running_prd_loss.detach()
        n_batches += 1
        y_train_true.append(torch.flatten(target_prd.detach()))
        y_train_pred.append(torch.flatten(output_prd.detach()))

    # Balanced scanner accuracy over the whole training set. Run in eval mode
    # without autograd so this pass neither updates BatchNorm running stats,
    # nor applies dropout, nor builds a full-dataset graph.
    model.eval()
    with torch.no_grad():
        X_all = torch.from_numpy(X_train).type(torch.FloatTensor).to(device)
        _, output_dsc = model(X_all)
        _, scnr_pred = torch.max(output_dsc, 1)
    model.train()
    scnr_pred = scnr_pred.cpu().numpy().ravel()
    scnr_true = y_train[:, scanner_idx].ravel()
    dsc_acc = balanced_accuracy_score(scnr_true, scnr_pred)

    # Each batch loss is already a mean, so average over batches. This puts
    # train losses on the same scale as a single full-batch validation loss.
    n_batches = max(n_batches, 1)
    prd_loss /= n_batches
    dsc_loss /= n_batches
    # cat (not stack) also copes with a shorter final batch
    y_train_true = torch.cat(y_train_true)
    y_train_pred = torch.cat(y_train_pred)
    train_corr = calc_pearsonr(y_train_true, y_train_pred)
    train_mae = calc_mae(y_train_true, y_train_pred)
    torch.cuda.empty_cache()

    return prd_loss, dsc_loss, dsc_acc, train_corr, train_mae

In [None]:
def valid(model, epoch, valid_loader, criterion_prd, criterion_dsc, X_valid, y_valid):
    """Evaluate the model (eval mode, no autograd) on a loader.

    Args:
        model: SGNN instance.
        epoch: epoch number (unused; kept for interface symmetry).
        valid_loader: yields (features, labels) batches.
        criterion_prd, criterion_dsc: predictor / discriminator losses.
        X_valid, y_valid: full arrays (numpy) behind the loader, used for the
            balanced scanner accuracy.

    Returns:
        (prd_loss, dsc_loss, dsc_acc, valid_corr, valid_mae). Losses are the
        mean of per-batch losses (equal to the batch loss for a single
        full-batch loader). dsc_acc is the balanced scanner accuracy.
    """
    model.eval()
    prd_loss = 0
    dsc_loss = 0
    n_batches = 0
    y_valid_true = []
    y_valid_pred = []

    with torch.no_grad():
        for inputs, targets in valid_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            output_prd, output_dsc = model(inputs)
            target_dsc = targets[:, scanner_idx].long().view(-1)
            target_prd = targets[:, p_factor_idx].view(-1, 1)
            dsc_loss += criterion_dsc(output_dsc, target_dsc)
            prd_loss += criterion_prd(output_prd, target_prd)
            n_batches += 1
            y_valid_true.append(torch.flatten(target_prd))
            y_valid_pred.append(torch.flatten(output_prd))

        # Full-set pass for balanced scanner accuracy, also without autograd
        X_all = torch.from_numpy(X_valid).type(torch.FloatTensor).to(device)
        _, output_dsc = model(X_all)
        _, scnr_pred = torch.max(output_dsc, 1)

    scnr_pred = scnr_pred.cpu().numpy().ravel()
    scnr_true = y_valid[:, scanner_idx].ravel()
    dsc_acc = balanced_accuracy_score(scnr_true, scnr_pred)

    n_batches = max(n_batches, 1)
    prd_loss /= n_batches
    dsc_loss /= n_batches
    y_valid_true = torch.cat(y_valid_true)
    y_valid_pred = torch.cat(y_valid_pred)
    valid_corr = calc_pearsonr(y_valid_true, y_valid_pred)
    valid_mae = calc_mae(y_valid_true, y_valid_pred)
    torch.cuda.empty_cache()

    return prd_loss, dsc_loss, dsc_acc, valid_corr, valid_mae

In [None]:
def test(model, epoch, test_loader, criterion_prd, criterion_dsc, X_test, y_test):
    """Evaluate the model (eval mode, no autograd) on the held-out test data.

    Args:
        model: SGNN instance.
        epoch: epoch number (unused; kept for interface symmetry).
        test_loader: yields (features, labels) batches.
        criterion_prd: predictor loss.
        criterion_dsc: unused; kept for interface symmetry with valid().
        X_test, y_test: full arrays (numpy) behind the loader, used for the
            balanced scanner accuracy.

    Returns:
        (prd_loss, test_corr, test_mae, dsc_acc). prd_loss is the mean of
        per-batch losses.
    """
    model.eval()
    prd_loss = 0
    n_batches = 0
    y_test_true = []
    y_test_pred = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            output_prd, _ = model(inputs)
            target_prd = targets[:, p_factor_idx].view(-1, 1)
            prd_loss += criterion_prd(output_prd, target_prd)
            n_batches += 1
            y_test_true.append(torch.flatten(target_prd))
            y_test_pred.append(torch.flatten(output_prd))

        # Full-set pass for balanced scanner accuracy, also without autograd
        X_all = torch.from_numpy(X_test).type(torch.FloatTensor).to(device)
        _, output_dsc = model(X_all)
        _, scnr_pred = torch.max(output_dsc, 1)

    scnr_pred = scnr_pred.cpu().numpy().ravel()
    scnr_true = y_test[:, scanner_idx].ravel()
    dsc_acc = balanced_accuracy_score(scnr_true, scnr_pred)

    prd_loss /= max(n_batches, 1)
    y_test_true = torch.cat(y_test_true)
    y_test_pred = torch.cat(y_test_pred)
    test_corr = calc_pearsonr(y_test_true, y_test_pred)
    test_mae = calc_mae(y_test_true, y_test_pred)
    torch.cuda.empty_cache()

    return prd_loss, test_corr, test_mae, dsc_acc

In [None]:
def _as_float_array(values):
    """Convert a list (or list of lists) of numbers / 0-dim tensors, possibly
    on the GPU, to a float ndarray that numpy/matplotlib can handle."""
    return np.array(
        [[float(v) for v in row] if isinstance(row, (list, tuple)) else float(row)
         for row in values],
        dtype=float,
    )


def plot_learning_curves(
    save_dir, epochs, train_loss, valid_loss,
    train_corr, valid_corr, train_acc, valid_acc, lr,
    plot_hsp_list, plot_beta_list, tg_hsp):
    """Save a 2x3 grid of learning curves to ``{save_dir}/Learning_curves.png``.

    Panels: predictor loss (train/valid), discriminator loss (train),
    discriminator accuracy (train), correlation (train/valid), per-layer
    Hoyer sparsity, and per-layer beta.

    Args:
        train_loss, valid_loss: per-epoch [prd_loss, dsc_loss] pairs.
        train_corr, valid_corr, train_acc: per-epoch scalars.
        valid_acc, lr: accepted for interface compatibility; not plotted.
        plot_hsp_list, plot_beta_list: per-epoch lists of per-layer values.
        tg_hsp: per-layer targets; ``tg_hsp[0][0]`` is shown in the title.
    """
    sns.set(style="whitegrid", font_scale=2)
    fig, ax = plt.subplots(2, 3, figsize=(28, 10))
    try:
        ax = ax.flat
        lw = 3.5
        last_epoch = epochs

        train_loss, valid_loss = _as_float_array(train_loss), _as_float_array(valid_loss)
        train_corr, valid_corr = _as_float_array(train_corr), _as_float_array(valid_corr)
        train_acc = _as_float_array(train_acc)

        ax[0].plot(train_loss[:last_epoch, 0], label='train prd loss', lw=lw, color="r")
        ax[0].plot(valid_loss[:last_epoch, 0], label='valid prd loss', lw=lw, color="g")
        ax[0].legend()
        ax[0].set_title("Predictor Loss Plot", pad=20)

        ax[1].plot(train_loss[:last_epoch, 1], label='train dsc loss', lw=lw, color="r")
        ax[1].legend()
        ax[1].set_title("Discriminator Loss Plot", pad=20)

        ax[2].plot(train_acc[:last_epoch], label='train dsc acc', lw=lw, color="r")
        ax[2].legend()
        ax[2].set_title("Discriminator Accuracy Plot", pad=20)

        ax[3].plot(train_corr[:last_epoch], label='train corr', lw=lw, color="r")
        ax[3].plot(valid_corr[:last_epoch], label='valid corr', lw=lw, color="g")
        ax[3].legend()
        ax[3].set_title("Correlation Plot ($r$={:.4f})".format(valid_corr[-1]), pad=20)

        # Rows become layers after transposing: (n_layers, n_epochs).
        plot_hsp_list = _as_float_array(plot_hsp_list).T
        plot_beta_list = _as_float_array(plot_beta_list).T
        n_layers = plot_hsp_list.shape[0] if plot_hsp_list.ndim == 2 else 0

        # The original looped over an undefined `indices`; label layers 1..n.
        for idx in range(n_layers):
            ax[4].plot(plot_hsp_list[idx], label='layer{}'.format(idx + 1), lw=lw)
            ax[5].plot(plot_beta_list[idx], label='layer{}'.format(idx + 1), lw=lw)
        if n_layers:
            ax[4].legend(); ax[5].legend()
            ax[4].set_title("HSP plot [{:.3f}/{:.3f}]"
                            .format(plot_hsp_list[0, -1], tg_hsp[0][0]), pad=20)
        ax[5].set_title("Beta plot", pad=20)

        fig.tight_layout()
        fig.savefig("{}/Learning_curves.png".format(save_dir))
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)

In [None]:
def run_inner_fold(
    n_outer_cv, output_save_dir=None, cur_tg_hsp=None, cur_lambda=None):
    """Train and evaluate SGNN on every inner fold of one outer fold.

    For each inner split stored in ``inner_folds_idx[n_outer_cv]`` this:
    z-scores each subject's features, trains a fresh SGNN for ``epochs``
    epochs, prints/plots progress every ``print_epoch`` epochs into
    ``{output_save_dir}/Inner_fold_{k}``, and records final metrics. Model
    weights are saved only for the first outer fold.

    Args:
        n_outer_cv: index of the outer LOSOCV fold.
        output_save_dir: directory for per-fold outputs and ``inner_cv.csv``.
        cur_tg_hsp: per-controlled-layer target Hoyer sparsity; entry ``k``
            is read as ``cur_tg_hsp[k][0]``.
        cur_lambda: gradient-reversal strength passed to SGNN.

    Returns:
        (avg_train_corr, avg_valid_corr, avg_train_mae, avg_valid_mae)
        averaged over the inner folds.
    """
    inner_cv_list = []

    inner_train_folds_idx, inner_valid_folds_idx = inner_folds_idx[n_outer_cv]
    # Iterate over the folds actually built for this outer fold, so a fold
    # with fewer site pairs than a global count cannot raise IndexError.
    n_inner_folds = len(inner_train_folds_idx)

    for n_inner_cv in range(n_inner_folds):

        print("\n===================================", end=" ")
        print("Inner Fold [{}/{}]".format(n_inner_cv + 1, n_inner_folds), end=" ")
        print("===================================")

        inner_start_fold_time = time.time()
        inner_save_dir = "{}/Inner_fold_{}".format(output_save_dir, n_inner_cv + 1)
        os.makedirs(inner_save_dir, exist_ok=True)

        inner_train_idx = inner_train_folds_idx[n_inner_cv]
        inner_valid_idx = inner_valid_folds_idx[n_inner_cv]

        X_train, y_train = X[inner_train_idx], y[inner_train_idx]
        X_valid, y_valid = X[inner_valid_idx], y[inner_valid_idx]

        # z-score each subject across its features (axis=1)
        X_train = stats.zscore(X_train, axis=1)
        X_valid = stats.zscore(X_valid, axis=1)

        inner_train_dataset = TrainDataset(X_train, y_train)
        inner_valid_dataset = ValidDataset(X_valid, y_valid)

        inner_train_loader = DataLoader(
            inner_train_dataset, batch_size=batch_size, pin_memory=True,
            shuffle=True, num_workers=num_workers, drop_last=True)
        # Validation runs as one full batch; shuffling it is pointless and
        # would draw from the global torch RNG.
        inner_valid_loader = DataLoader(
            inner_valid_dataset, batch_size=len(y_valid), pin_memory=True,
            shuffle=False, num_workers=num_workers, drop_last=True)

        # Assign model. Keywords are required: the old positional call passed
        # dsc_hidden as SGNN's prd_hidden and vice versa.
        model = SGNN(
            fe_hidden=fe_hidden, prd_hidden=prd_hidden, dsc_hidden=dsc_hidden,
            dropout_fe=dropout_fe, dropout_prd=dropout_prd, dropout_dsc=dropout_dsc,
            act_func_name=act_func_name, lambda_=cur_lambda,
        ).to(device)
        optimizer = get_optimizer(model, optimizer_name, learning_rate, l2_param)
        lr_factor = lr_alpha * cur_tg_hsp[0][0] + lr_beta
        scheduler = ReduceLROnPlateau(
            optimizer, mode=mode, patience=lr_patience, min_lr=min_lr, factor=lr_factor
        )
        criterion_prd = nn.MSELoss()
        criterion_dsc = nn.CrossEntropyLoss()

        # Lists to save learning curves
        inner_train_loss = []
        inner_valid_loss = []
        inner_train_corr = []
        inner_valid_corr = []
        inner_train_acc = []
        inner_valid_acc = []
        inner_lr = []
        inner_hsp_list = []
        inner_beta_list = []

        hsp_val, beta_val, hsp_list, beta_list = init_hsp(n_wsc, epochs)

        for epoch in range(1, epochs + 1):
            train_prd_loss, train_dsc_loss, train_acc, train_corr, train_mae = train(
                model, epoch, inner_train_loader,
                optimizer, criterion_prd, criterion_dsc,
                hsp_val, beta_val, hsp_list, beta_list, cur_tg_hsp, cur_lambda,
                X_train, y_train
            )
            valid_prd_loss, valid_dsc_loss, valid_acc, valid_corr, valid_mae = valid(
                model, epoch, inner_valid_loader, criterion_prd, criterion_dsc,
                X_valid, y_valid
            )

            # LR schedule is driven by the first controlled layer's sparsity
            scheduler.step(hsp_val[0])
            lr = optimizer.param_groups[0]['lr']

            inner_train_loss.append([train_prd_loss, train_dsc_loss])
            inner_train_corr.append(train_corr)
            inner_train_acc.append(train_acc)
            inner_valid_loss.append([valid_prd_loss, valid_dsc_loss])
            inner_valid_corr.append(valid_corr)
            inner_valid_acc.append(valid_acc)
            inner_lr.append(lr)
            inner_hsp_list.append(list(hsp_val.clone()))
            inner_beta_list.append(list(beta_val.clone()))

            if epoch % print_epoch == 0:
                print("\nEpoch [{:d}/{:d}]".format(epoch, epochs), end=" ")
                print("Train corr: {:.4f}, Valid corr: {:.4f}, Train loss: {:.4f}, Valid loss: {:.4f}"
                      .format(train_corr, valid_corr, train_prd_loss, valid_prd_loss))
                # hsp_val / cur_tg_hsp are indexed by controlled layer, not by
                # position in wsc_flag.
                wsc_pos = 0
                for i, flag in enumerate(wsc_flag):
                    if flag != 0:
                        print("Layer {:d}: [{:.4f}/{:.4f}]".
                              format(i + 1, hsp_val[wsc_pos], cur_tg_hsp[wsc_pos][0]), end=" ")
                        wsc_pos += 1
                print("Train acc: {:.2f}".format(train_acc))

                plot_learning_curves(
                    inner_save_dir, epochs, inner_train_loss, inner_valid_loss,
                    inner_train_corr, inner_valid_corr,
                    inner_train_acc, inner_valid_acc,
                    inner_lr, inner_hsp_list, inner_beta_list, cur_tg_hsp
                )

        # Final eval-mode metrics on the training set; validation metrics are
        # those of the last epoch.
        train_prd_loss, train_dsc_loss, train_acc, train_corr, train_mae = valid(
            model, epoch, inner_train_loader, criterion_prd, criterion_dsc, X_train, y_train
        )
        print("\nInner Fold [{}/{}] train corr: {:.4f}, valid corr: {:.4f}"
              .format(n_inner_cv + 1, n_inner_folds, train_corr, valid_corr))

        if n_outer_cv == 0:
            torch.save(
                model.state_dict(), inner_save_dir + "/inner_model_fold_" + str(n_inner_cv + 1) + ".pt")
        # Drop references first so the cached GPU memory can actually be freed.
        del model, optimizer, scheduler
        gc.collect()
        torch.cuda.empty_cache()

        inner_tot_time = (time.time() - inner_start_fold_time) / 60
        print("Execution Time for Fold: {:.2f} mins".format(inner_tot_time))
        inner_cv_list.append(
            [train_corr.cpu().numpy(), valid_corr.cpu().numpy(),
             train_mae.cpu().numpy(), valid_mae.cpu().numpy(),
             train_acc, valid_acc]
        )

    inner_cv_df = pd.DataFrame(
        np.array(inner_cv_list),
        columns=["train_corr", "valid_corr", "train_mae", "valid_mae", "train_acc", "valid_acc"]
    )

    avg_train_corr = inner_cv_df["train_corr"].mean()
    avg_valid_corr = inner_cv_df["valid_corr"].mean()
    avg_train_mae = inner_cv_df["train_mae"].mean()
    avg_valid_mae = inner_cv_df["valid_mae"].mean()

    inner_cv_df.to_csv("{}/inner_cv.csv".format(output_save_dir))

    return avg_train_corr, avg_valid_corr, avg_train_mae, avg_valid_mae

In [None]:
def run_outer_fold(n_outer_cv=0, outer_save_dir=None, sel_tg_hsp=None, sel_lambda=None):
    """Train and evaluate SGNN on one outer leave-one-site-out fold.

    Splits the module-level ``X``/``y`` with the precomputed outer fold
    indices and z-scores every row of the train and test features
    independently. It then trains a fresh SGNN for ``epochs`` epochs with the
    hyper-parameters chosen by the inner CV. Finally it saves the weights and
    a one-row ``outer_cv.csv`` to ``outer_save_dir``.

    Args:
        n_outer_cv: Zero-based index into ``outer_train_folds_idx`` /
            ``outer_test_folds_idx``.
        outer_save_dir: Directory (must already exist; it is not created
            here) that receives ``model_fold_<n_outer_cv + 1>.pt``,
            ``outer_cv.csv`` and the learning-curve plots.
        sel_tg_hsp: Per-layer target HSP as a list of single-element lists;
            ``sel_tg_hsp[0][0]`` also sets the LR-decay factor.
        sel_lambda: Lambda passed to ``SGNN`` and to ``train``.

    Returns:
        Tuple ``(train_corr, test_corr, train_mae, test_mae, outer_train_acc,
        outer_test_acc, outer_hsp_list)``. The first four items are NumPy
        values: the train metrics (re-evaluated with ``test`` after training)
        and the last-epoch test metrics. The last three are the per-epoch
        train-accuracy, test-accuracy and HSP histories.
    """
    outer_cv_list = []
    
    # Outer fold
    print("\n===================================", end=" ")
    print("Outer Fold [{}/{}]".format(n_outer_cv + 1, outer_n_splits), end=" ")
    print("===================================")
    
    outer_start_fold_time = time.time()
    outer_train_idx = outer_train_folds_idx[n_outer_cv]
    outer_test_idx = outer_test_folds_idx[n_outer_cv]

    X_train, y_train = X[outer_train_idx], y[outer_train_idx]
    X_test, y_test = X[outer_test_idx], y[outer_test_idx]
    
    # Standardize each row on its own (axis=1), so no statistics are shared
    # between the training sites and the held-out site.
    X_train = stats.zscore(X_train, axis=1)
    X_test = stats.zscore(X_test, axis=1)
        
    outer_train_dataset = TrainDataset(X_train, y_train)
    outer_test_dataset = TestDataset(X_test, y_test)
    
    outer_train_loader = DataLoader(
        outer_train_dataset, batch_size=batch_size, pin_memory=True,
        shuffle=True, num_workers=num_workers, drop_last=True)
    # batch_size=len(y_test): the held-out site is served as a single batch
    # (assuming TestDataset yields len(y_test) items).
    outer_test_loader = DataLoader(
        outer_test_dataset, batch_size=len(y_test), pin_memory=True,
        shuffle=True, num_workers=num_workers, drop_last=True)
        
    # Assign model 
    model = SGNN(
        fe_hidden, dsc_hidden, prd_hidden, 
        dropout_fe, dropout_prd, dropout_dsc, act_func_name, sel_lambda
    ).to(device)
    optimizer = get_optimizer(model, optimizer_name, learning_rate, l2_param)
    # The LR-decay factor is a linear function of the first layer's target HSP.
    lr_factor = lr_alpha * sel_tg_hsp[0][0] + lr_beta
    scheduler = ReduceLROnPlateau(
        optimizer, mode=mode, patience=lr_patience, min_lr=min_lr, factor=lr_factor
    )
    criterion_prd = nn.MSELoss()           # loss for the `prd` (regression) output
    criterion_dsc = nn.CrossEntropyLoss()  # loss for the `dsc` (classification) output
              
    # Per-epoch histories used for the learning-curve plots
    outer_train_loss = []
    outer_test_loss = []
    outer_train_corr = []
    outer_test_corr = []
    outer_train_acc = []
    outer_test_acc = []
    outer_lr = []
    outer_hsp_list = []
    outer_beta_list = []

    hsp_val, beta_val, hsp_list, beta_list = init_hsp(n_wsc, epochs)
        
    for epoch in range(1, epochs + 1):
        train_prd_loss, train_dsc_loss, train_acc, train_corr, train_mae = train(
            model, epoch, outer_train_loader, 
            optimizer, criterion_prd, criterion_dsc, 
            hsp_val, beta_val, hsp_list, beta_list, sel_tg_hsp, sel_lambda,
            X_train, y_train
        )
        test_prd_loss, test_corr, test_mae, test_acc = test(
            model, epoch, outer_test_loader, criterion_prd, criterion_dsc, X_test, y_test
        )

        # The LR scheduler is driven by the first layer's current HSP, not by a loss.
        scheduler.step(hsp_val[0])
        lr = optimizer.param_groups[0]['lr']
        
        outer_train_loss.append([train_prd_loss, train_dsc_loss])
        outer_train_corr.append(train_corr)
        outer_train_acc.append(train_acc)
        # test() returns no discriminator loss; [] is a placeholder for it.
        outer_test_loss.append([test_prd_loss, []])
        outer_test_corr.append(test_corr)
        outer_test_acc.append(test_acc)
        outer_lr.append(lr)
        # clone() snapshots the current values (hsp_val / beta_val are
        # presumably updated in place by train() — confirm).
        outer_hsp_list.append(list(hsp_val.clone()))
        outer_beta_list.append(list(beta_val.clone()))

        if epoch % print_epoch == 0:
            print("\nEpoch [{:d}/{:d}]".format(epoch, epochs), end=" ")
            print("Train corr: {:.4f}, Test corr: {:.4f}, Train loss: {:.4f}, Test loss: {:.4f}"
                  .format(train_corr, test_corr, train_prd_loss, test_prd_loss))
            for i in range(len(wsc_flag)):
                if wsc_flag[i] != 0:
                    print("Layer {:d}: [{:.4f}/{:.4f}]".
                          format(i + 1, hsp_val[i], sel_tg_hsp[i][0]), end=" ")
            print("Train acc: {:.2f}".format(train_acc))

            plot_learning_curves(
                outer_save_dir, epochs, outer_train_loss, outer_test_loss,  
                outer_train_corr, outer_test_corr, 
                outer_train_acc, outer_test_acc, 
                outer_lr, outer_hsp_list, outer_beta_list, sel_tg_hsp
            )
    
    # Re-evaluate the trained model on the training sites.
    train_prd_loss, train_corr, train_mae, train_acc = test(
        model, epoch, outer_train_loader, criterion_prd, criterion_dsc, X_train, y_train
    )
    torch.save(model.state_dict(), 
               outer_save_dir + "/model_fold_" + str(n_outer_cv + 1) + ".pt")
    
    torch.cuda.empty_cache()
    gc.collect()
    
    train_corr = train_corr.cpu().numpy()
    test_corr = test_corr.cpu().numpy()
    train_mae = train_mae.cpu().numpy()
    test_mae = test_mae.cpu().numpy()

    outer_cv_list.append([train_corr, test_corr, train_mae, test_mae, train_acc, test_acc])
            
    # Column names are shared with inner_cv.csv; the "valid_*" columns hold
    # the outer (held-out site) test metrics here.
    outer_cv_df = pd.DataFrame(
        np.array(outer_cv_list), 
        columns=["train_corr", "valid_corr", "train_mae", "valid_mae", "train_acc", "valid_acc"]
    )
    outer_cv_df.to_csv("{}/outer_cv.csv".format(outer_save_dir))

    outer_tot_time = time.time() - outer_start_fold_time
    print("\nExecution Time for Fold: {:.2f} mins".format(outer_tot_time / 60))
    
    # Order matches the caller's unpacking: ..., outer_train_acc,
    # outer_test_acc, outer_hsp_list.
    return (train_corr, test_corr, train_mae, test_mae,
            outer_train_acc, outer_test_acc, outer_hsp_list)

# Parameters for SGNN training

In [None]:
act_func_name = "elu"
optimizer_name = "nag"

# Hidden width of each SGNN sub-network.
fe_hidden = 1024
pp_hidden = 1024
scd_hidden = 1024

# Dropout settings of each SGNN sub-network (interpreted by SGNN).
dropout_fe = 0.9
dropout_pp = 0.9
dropout_scd = 0.9

batch_size = 32
learning_rate = 5e-05
epochs = 150
pretrain_epoch = 40

l2_param = 5e-02

# run_outer_fold builds the model from `prd_hidden`, `dsc_hidden`,
# `dropout_prd` and `dropout_dsc`, which are not the names defined above.
# Alias them so the values set in this cell are the ones actually used.
# NOTE(review): assumes pp == predictor (prd) and scd == scanner
# discriminator (dsc); confirm the naming.
prd_hidden = pp_hidden
dsc_hidden = scd_hidden
dropout_prd = dropout_pp
dropout_dsc = dropout_scd

# Hyperparameters for optimization with nested CV

In [None]:
# Search space for the nested-CV grid inside each outer fold:
#   hsp_fe  - candidate target HSP for the feature-extractor ("FE") layer
#   lambda_ - candidate lambda passed to SGNN / train()
hsp_fe_cand = [0.99, 0.98, 0.975, 0.9, 0.8, 0.5]
lambda_cand = [0, 0.002, 0.005, 0.01, 0.02]
param_cand = dict(lambda_=lambda_cand, hsp_fe=hsp_fe_cand)

In [None]:
# Epoch interval for progress printing and learning-curve plotting
# (checked as `epoch % print_epoch == 0` in run_outer_fold). With
# `epochs = 150` both happen only at the final epoch.
print_epoch = 150

In [None]:
# Wall-clock start of the whole nested-CV run; the last cell reports the
# elapsed time relative to this.
code_start_time = time.time()

# Training SGNN with LOSOCV framework

In [None]:
outer_cv = []

param_grid = list(ParameterGrid(param_cand))

# Per-layer grid keys, in the layer order of the target-HSP list passed to
# run_inner_fold / run_outer_fold: feature extractor, predictor,
# discriminator. Both the layer-prefixed spelling ("1_hsp_fe") and the bare
# spelling used by `param_cand` ("hsp_fe") are accepted.
_HSP_GRID_KEYS = (
    ("1_hsp_fe", "hsp_fe"),
    ("2_hsp_prd", "hsp_prd"),
    ("3_hsp_dsc", "hsp_dsc"),
)


def _target_hsp_from_param(param):
    """Build the per-layer target-HSP list ``[[fe], [prd], [dsc]]`` for one grid point.

    A layer with no candidate in ``param`` gets a 0.0 placeholder. This is
    only allowed when weight-sparsity control is off for that layer
    (``wsc_flag[i] == 0``).
    NOTE(review): assumes the training code ignores the target of layers
    whose wsc_flag is 0 (run_outer_fold only reports those with
    wsc_flag != 0); confirm against train().

    Raises:
        KeyError: if a layer with ``wsc_flag != 0`` has no candidate value.
    """
    tg_hsp = []
    for layer_idx, keys in enumerate(_HSP_GRID_KEYS):
        found = [param[key] for key in keys if key in param]
        if found:
            tg_hsp.append([found[0]])
        elif layer_idx < len(wsc_flag) and wsc_flag[layer_idx] != 0:
            raise KeyError(
                "No target HSP for layer {} (expected one of {})".format(layer_idx + 1, keys))
        else:
            tg_hsp.append([0.0])
    return tg_hsp


for n_outer_cv in outer_cv_part:
    print("\n===================================", end=" ")
    print("Outer Fold [{}/{}]".format(n_outer_cv + 1, outer_n_splits), end=" ")
    print("===================================")

    outer_save_dir = "{}/Outer_fold_{}".format(output_folder, n_outer_cv + 1)
    model_file = "model_fold_{}.pt".format(n_outer_cv + 1)
    model_path = os.path.join(outer_save_dir, model_file)
    
    os.makedirs(outer_save_dir, exist_ok=True)
    
    inner_cv = []
    temp_inner_cv_corr = []
    temp_inner_cv_mae = []

    # Inner fold: score every grid candidate with inner CV
    for param_idx, cur_param in enumerate(param_grid):
        print("\n===================================", end=" ")
        print("Param Cand [{}/{}]".format(param_idx + 1, len(param_grid)), end=" ")
        print("===================================")

        cur_tg_hsp = _target_hsp_from_param(cur_param)
        cur_lambda = cur_param["lambda_"]

        print("Param:", end=" ")
        for param in cur_param:
            if "hsp" in param or "lambda" in param: 
                print("{}: {}".format(param, cur_param[param]), end=" ")
        print("")
        
        cur_param_name = "FE_{}_LMD_{}".format(
            cur_tg_hsp[0][0], cur_lambda) 
        param_save_dir = "{}/{}".format(outer_save_dir, cur_param_name)
        
        final_inner_cv_df_path = os.path.join(param_save_dir, "inner_cv.csv")
        
        if os.path.exists(final_inner_cv_df_path):
            # Reuse the finished inner CV of this candidate.
            final_inner_cv_df = pd.read_csv(final_inner_cv_df_path, index_col=0)
            temp_inner_cv_corr.append(final_inner_cv_df["valid_corr"].mean())
            temp_inner_cv_mae.append(final_inner_cv_df["valid_mae"].mean())
            print("Training Inner CV Done")
        else:
            os.makedirs(param_save_dir, exist_ok=True)
            inner_train_corr, inner_valid_corr, inner_train_mae, inner_valid_mae = run_inner_fold(
                n_outer_cv, param_save_dir, cur_tg_hsp, cur_lambda
            )
            # Append scalars, like the cached branch, so that np.argmin below
            # always sees a flat sequence (a mix of scalars and 1-element
            # lists is a ragged array).
            temp_inner_cv_corr.append(inner_valid_corr)
            temp_inner_cv_mae.append(inner_valid_mae)

            print("\nParam Cand: [{}/{}] train corr: {:.4f}, valid corr: {:.4f}"
                  .format(param_idx + 1, len(param_grid), inner_train_corr, inner_valid_corr))

    # Select the candidate with the lowest mean inner-validation MAE.
    sel_idx = np.argmin(temp_inner_cv_mae)
    sel_param = param_grid[sel_idx]
    sel_hsp = []
    sel_lambda = sel_param["lambda_"]
    print("Selected param:", end=" ")
    for x in sel_param:
        if "hsp" in x or "lambda" in x: 
            print("{}: {}".format(x, sel_param[x]), end=" ")
            sel_hsp.append(sel_param[x])
    
    # Outer fold: retrain with the *selected* candidate's targets, rather than
    # with values left over from the last inner grid iteration.
    sel_tg_hsp = _target_hsp_from_param(sel_param)

    (outer_train_corr, outer_test_corr, outer_train_mae, 
    outer_test_mae, outer_train_acc, outer_test_acc, outer_hsp_list) = run_outer_fold(
        n_outer_cv, outer_save_dir, sel_tg_hsp, sel_lambda
    )
    outer_cv.append([sel_hsp, outer_train_corr, outer_test_corr])
    
    print("\nOuter Fold [{}/{}]: train corr: {:.4f}, valid corr: {:.4f}"
          .format(n_outer_cv + 1, outer_n_splits, outer_train_corr, outer_test_corr))

In [None]:
# Total wall-clock time since `code_start_time`, reported in hours.
code_tot_time = time.time() - code_start_time
print(f"Execution Time for the training: {code_tot_time / 60 / 60:.2f} hours")