In [1]:
# Run configuration: physical GPU index to expose, DataLoader worker count, global RNG seed.
GPU, num_workers, seed = 7, 4, 1000

In [3]:
# Candidate target Hoyer sparsity levels per layer group (extractor / predictor /
# discriminator); presumably the search space of an outer hyper-parameter search.
hsp_extr_cand = [0.9]
hsp_pred_cand = [0.7, 0.5, 0.3]
hsp_disc_cand = [0.7, 0.5, 0.3]

# Weight-sparsity-control (WSC) settings for the regularised layers, in the order
# extractor, predictor, discriminator: on/off flag, beta step size, beta ceiling.
# `wsc_flag` has one entry per layer; `beta_lr` / `max_beta` are indexed by the
# running index over the enabled layers.
wsc_flag = [1, 1, 1]
beta_lr = [1e-04, 1e-03, 1e-03]
max_beta = [1e-02, 5e-02, 5e-02]
n_wsc = wsc_flag.count(1)
assert len(beta_lr) >= n_wsc and len(max_beta) >= n_wsc, "missing per-layer WSC settings"

# `cur_param` is not defined anywhere in this notebook, so a fresh kernel raised
# NameError here. Keep a value injected by an outer loop if there is one; otherwise
# fall back to the targets used for the logged run.
cur_param = globals().get(
    "cur_param", {"1_hsp_extr": 0.9, "2_hsp_pred": 0.3, "3_hsp_disc": 0.5}
)
hsp_cand_1 = [cur_param["1_hsp_extr"]]
hsp_cand_2 = [cur_param["2_hsp_pred"]]
hsp_cand_3 = [cur_param["3_hsp_disc"]]

In [4]:
import os
import gc
import time
import pickle
import random
import itertools
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
from decimal import Decimal
from datetime import datetime as dt
from pytz import timezone

In [5]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable, Function
import torch.optim as optim
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import ReduceLROnPlateau 
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

In [6]:
import warnings
# Suppress every warning for the rest of the session.
# NOTE(review): a blanket "ignore" also hides deprecation warnings (e.g. NumPy
# alias deprecations) that flag upcoming breakage — consider a narrower filter.
warnings.filterwarnings('ignore')

In [7]:
# Number GPUs in PCI-bus order and expose only the device selected by `GPU`.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": str(GPU),
})

In [8]:
def seed_everything(seed=seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Also exports PYTHONHASHSEED, which is only honoured by Python processes started
    afterwards (the running interpreter's hash seed is fixed at start-up).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


seed_everything(seed)

In [9]:
# Timestamp (Asia/Seoul) used to name this run's output folder.
nowtime = dt.now(timezone("Asia/Seoul"))
# Zero-padded calendar/clock fields as strings, e.g. "2021" "07" "02" "18" "29" "04".
year, month, day, hour, minute, sec = nowtime.strftime("%Y %m %d %H %M %S").split()
# Hundredths of a second: pad the microsecond field to six digits before slicing.
# (str(microsecond)[:2] gave "51" for 5123 us instead of "00".)
msec = "{:06d}".format(nowtime.microsecond)[:2]

In [10]:
# Root directory for all runs; each run gets a timestamped sub-folder (trailing slash kept).
save_path = "/users/hjw/data/sann/optuna"
output_folder = f"{save_path}/{year}{month}{day}_{hour}{minute}{sec}{msec}/"
# Create the folder only when nothing exists at that path yet.
if not os.path.exists(output_folder):
    os.makedirs(output_folder)
print(output_folder)

/users/hjw/data/sann/optuna/20210702_18290451/


In [11]:
# Load features `X` and targets `y` (columns used later: p-factor, site, scanner).
# NOTE: allow_pickle=True lets a crafted .npz execute arbitrary code on load — only
# use trusted files. It is presumably required because `y` has object dtype.
# The context manager closes the underlying file handle once both arrays are read.
with np.load("/users/hjw/data/ABCD/npz_files/rsfc_p_site_scanner_si_ge.npz", allow_pickle=True) as data:
    # Standardise each row (subject) across its features.
    X = stats.zscore(data["X"], axis=1)
    y = data["y"]
print(X.shape, y.shape)

(6905, 61776) (6905, 3)


In [12]:
# Column positions inside the target array `y`.
p_factor_idx, site_idx, scanner_idx = 0, 1, 2

In [13]:
# `np.float` / `np.int` were deprecated aliases of the builtins and were removed in
# NumPy 1.24, so use `float` / `int` directly (identical behaviour).
y = np.array(y, dtype=float)
# Truncate site and scanner codes to whole numbers. Assigning back into the float
# array keeps `y` float64, so the codes are stored as integral floats.
y[:, site_idx] = y[:, site_idx].astype(int)
y[:, scanner_idx] = y[:, scanner_idx].astype(int)

In [14]:
# Split subject indices for leave-one-site-out CV with a validation set: every site is
# held out once as the test set, and a validation set is carved out of the other sites.
y_df = pd.DataFrame(y, columns=["p-factor", "site", "scanner"])
y_arr = np.array(y_df)
site_unq = np.unique(y_arr[:, site_idx])
data_idx = np.arange(y_arr.shape[0])

train_folds_idx = []
valid_folds_idx = []
test_folds_idx = []

split_seed = 0
valid_ratio = 0.10

for site in site_unq:
    temp_train_idx = np.where(y_arr[:, site_idx] != site)[0]
    temp_test_idx = np.where(y_arr[:, site_idx] == site)[0]

    temp_test_df = y_df.iloc[temp_test_idx]
    test_scnr_label = np.unique(temp_test_df["scanner"])

    temp_train_df = y_df.iloc[temp_train_idx]
    # Validation size is `valid_ratio` of ALL training-site subjects, but it is sampled
    # only from training subjects whose scanner matches the test site's (first) scanner;
    # pandas raises ValueError if fewer such subjects exist.
    # (`np.int` was removed in NumPy 1.24; the builtin `int` is what it aliased.)
    temp_n_valid = int(len(temp_train_df) * valid_ratio)
    temp_valid_df = temp_train_df[temp_train_df["scanner"] == test_scnr_label[0]]
    temp_valid_df = temp_valid_df.sample(n=temp_n_valid, replace=False, random_state=split_seed)

    # y_df has a default RangeIndex, so index labels are row positions into X / y.
    temp_train_idx = np.setdiff1d(temp_train_df.index.values, temp_valid_df.index.values)
    temp_valid_idx = temp_valid_df.index.values
    # Sanity check: the three index sets of a fold must be pairwise disjoint.
    assert np.intersect1d(temp_train_idx, temp_valid_idx).size == 0
    assert np.intersect1d(temp_train_idx, temp_test_idx).size == 0
    assert np.intersect1d(temp_valid_idx, temp_test_idx).size == 0
    train_folds_idx.append(temp_train_idx)
    valid_folds_idx.append(temp_valid_idx)
    test_folds_idx.append(temp_test_idx)

    # Scanner labels present in each split (for inspection only).
    train_scnr_label = np.unique(temp_train_df["scanner"])
    valid_scnr_label = np.unique(temp_valid_df["scanner"])

In [15]:
# ReduceLROnPlateau settings
mode = "max"
lr_patience = 5
min_lr = 1e-12
lr_factor = 0.25

# Learning rate of the SWALR scheduler
swa_lr = 5e-03

# SGD momentum (used by the "momentum" and "nag" optimizers)
momentum = 0.90

# Fixed L1 weight for layers without sparsity control; weight decay (L2)
l1_param = 0
l2_param = 1e-03

early_stopping_patience = 150

# Derive the input width from the loaded data instead of hard-coding it.
input_dim = X.shape[1]
n_classes = len(np.unique(y[:, scanner_idx]))
# CrossEntropyLoss needs class labels 0..n_classes-1; fail early rather than mid-training.
assert set(np.unique(y[:, scanner_idx])) == set(range(n_classes)), \
    "scanner labels must be contiguous integers starting at 0"
output_reg_dim = 1
output_clf_dim = n_classes

outer_n_splits = len(train_folds_idx)
inner_n_splits = outer_n_splits - 1

In [16]:
# 1-based positions of the sparsity-controlled layers (used as plot labels).
indices = [i + 1 for i, flag in enumerate(wsc_flag) if flag == 1]
# Every combination of candidate targets (extractor, predictor, discriminator).
# Bug fix: the third factor used to be `hsp_cand_2` instead of `hsp_cand_3`.
hsp_cand_list = [list(combo) for combo in itertools.product(hsp_cand_1, hsp_cand_2, hsp_cand_3)]
hsp_cand = [hsp_cand_1, hsp_cand_2, hsp_cand_3]
# Per-layer target-sparsity lists consumed by training and plotting.
tg_hsp = hsp_cand

[[0.9], [0.3], [0.5]]


In [17]:
# Training dataset
class train_dataset(Dataset):
    """Map-style dataset over the training split.

    Item ``idx`` is the pair (features, targets) taken from row ``idx`` of the two
    stored NumPy arrays, each converted to a float32 tensor.
    """

    def __init__(self, X_train, y_train):
        self.X_train, self.y_train = X_train, y_train

    def __len__(self):
        return len(self.X_train)

    def __getitem__(self, idx):
        sample_x, sample_y = self.X_train[idx], self.y_train[idx]
        return (torch.from_numpy(sample_x).float(),
                torch.from_numpy(sample_y).float())

In [18]:
# Validation dataset
class valid_dataset(Dataset):
    """Map-style dataset over the validation split.

    Returns row ``idx`` of the feature and target arrays as float32 tensors.
    """

    def __init__(self, X_valid, y_valid):
        self.X_valid, self.y_valid = X_valid, y_valid

    def __len__(self):
        return len(self.X_valid)

    def __getitem__(self, idx):
        features = torch.from_numpy(self.X_valid[idx]).float()
        targets = torch.from_numpy(self.y_valid[idx]).float()
        return features, targets

In [19]:
# Test dataset
class test_dataset(Dataset):
    """Map-style dataset over the held-out test split.

    Indexing yields ``(features, targets)`` for one subject as float32 tensors.
    """

    def __init__(self, X_test, y_test):
        self.X_test, self.y_test = X_test, y_test

    def __len__(self):
        return len(self.X_test)

    def __getitem__(self, idx):
        to_float_tensor = lambda arr: torch.from_numpy(arr).float()
        return to_float_tensor(self.X_test[idx]), to_float_tensor(self.y_test[idx])

In [20]:
# Run on the (single visible) GPU when available, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_CUDA else torch.device("cpu")

In [21]:
class GradientReversalFunction(Function):
    """
    Gradient Reversal Layer (Ganin & Lempitsky, 2015, "Unsupervised Domain
    Adaptation by Backpropagation").

    forward:  returns a copy of ``x`` unchanged.
    backward: multiplies the incoming gradient by ``-lambda_``; ``lambda_`` itself
              receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Stash the plain coefficient for the backward pass.
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        # Build lambda as a tensor with the gradient's dtype and device, then flip sign.
        scale = grads.new_tensor(ctx.lambda_)
        reversed_grads = -scale * grads
        return reversed_grads, None

In [22]:
class GradientReversal(torch.nn.Module):
    """Module wrapper around GradientReversalFunction with a fixed coefficient."""

    def __init__(self, lambda_=0.0):
        super().__init__()
        # Reversal strength: gradients passing back through this layer are scaled by -lambda_.
        self.lambda_ = lambda_

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_)

In [23]:
class DNN(nn.Module):
    """Multi-head MLP: shared feature extractor, regression head, discriminator head.

    The discriminator head receives the shared features through a gradient-reversal
    layer, so its gradient is negated (scaled by -lambda_) before reaching the
    extractor — intended to discourage the shared features from encoding what the
    discriminator predicts (the scanner label in this notebook's training loop).

    Relies on module-level `input_dim`, `output_reg_dim`, `output_clf_dim` and the
    helper `get_activation_function`.

    NOTE: keep the submodule registration order below. It fixes the order of
    `named_parameters()` (l1_penalty expects ext_1, reg_1, clf_1 in that order)
    and the order in which `weights_init` draws random numbers.
    """
    def __init__(self, extr_hidden, disc_hidden, pred_hidden, 
                 dropout_rate, dropout_reg, lambda_, act_func_name):
        """Build all layers and initialise Linear weights/biases.

        Args:
            extr_hidden: width of the shared extractor layer.
            disc_hidden: width of the discriminator's hidden layer.
            pred_hidden: width of the regressor's hidden layer.
            dropout_rate: dropout probability after the extractor.
            dropout_reg: dropout probability inside the regression head.
            lambda_: gradient-reversal coefficient.
            act_func_name: activation name understood by `get_activation_function`.
        """
        super(DNN, self).__init__()
        # Shared feature extractor
        self.ext_1 = nn.Linear(input_dim, extr_hidden)
        self.ext_bn_1 = nn.BatchNorm1d(extr_hidden)
        
        # Regression head (predictor)
        self.reg_1 = nn.Linear(extr_hidden, pred_hidden)
        self.reg_bn_1 = nn.BatchNorm1d(pred_hidden)
        self.reg_2 = nn.Linear(pred_hidden, output_reg_dim)
        
        # Classification head (discriminator), fed through gradient reversal
        self.clf_1 = nn.Linear(extr_hidden, disc_hidden)
        self.clf_bn_1 = nn.BatchNorm1d(disc_hidden)
        self.clf_2 = nn.Linear(disc_hidden, output_clf_dim)

        self.GradientReversal = GradientReversal(lambda_)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.dropout_reg = nn.Dropout(p=dropout_reg)
        # A single activation module is shared by all hidden layers (for PReLU this
        # means one shared learnable slope).
        self.act_func = get_activation_function(act_func_name)
        self.weights_init()
    
    def forward(self, x):
        """Return (x_reg, x_clf): regression output and discriminator scores.

        Assumes `x` is a 2-D batch (batch, input_dim) — TODO confirm; BatchNorm1d
        after each Linear also requires more than one sample per batch in training.
        x_clf is unnormalised (the training loop feeds it to CrossEntropyLoss).
        """
        # Shared features: Linear -> BN -> activation -> dropout
        feature = self.ext_1(x)
        feature = self.ext_bn_1(feature)
        feature = self.act_func(feature)
        feature = self.dropout(feature)
        
        # Regression head on the (dropped-out) shared features
        x_reg = self.reg_1(feature)
        x_reg = self.reg_bn_1(x_reg)
        x_reg = self.act_func(x_reg)
        x_reg = self.dropout_reg(x_reg)
        x_reg = self.reg_2(x_reg)
        
        # Discriminator head: identity in forward, negated/scaled gradient in backward
        x_clf = self.GradientReversal(feature)
        x_clf = self.clf_1(x_clf)
        x_clf = self.clf_bn_1(x_clf)
        x_clf = self.act_func(x_clf)
        # x_clf = self.dropout(x_clf)  (disabled: no dropout in the discriminator head)
        x_clf = self.clf_2(x_clf)
        
        return x_reg, x_clf
    
    def weights_init(self):
        """Kaiming-normal weights (fan_in, ReLU gain) and N(0, 0.01^2) biases for every Linear.

        The ReLU gain is used regardless of the configured activation function.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
                nn.init.normal_(m.bias, std=0.01)

In [24]:
def get_optimizer(model, opt_name, learning_rate=None, l2_param=None):
    """Build the optimizer named `opt_name` (case-insensitive) over `model`'s parameters.

    Supported names:
        "momentum": SGD with the module-level `momentum`;
        "nag":      SGD with the module-level `momentum` and Nesterov updates;
        "adam":     Adam.
    `l2_param` is passed as `weight_decay` in every case.

    Raises:
        ValueError: if `opt_name` is not one of the supported names.
    """
    lower_opt_name = opt_name.lower()
    if lower_opt_name == 'momentum':
        return optim.SGD(model.parameters(), lr=learning_rate,
                         momentum=momentum, weight_decay=l2_param)
    if lower_opt_name == 'nag':
        return optim.SGD(model.parameters(), lr=learning_rate,
                         momentum=momentum, weight_decay=l2_param, nesterov=True)
    if lower_opt_name == 'adam':
        return optim.Adam(model.parameters(), lr=learning_rate,
                          weight_decay=l2_param)
    # `sys` is not imported in this notebook, so the previous `sys.exit(...)` raised a
    # NameError instead; a ValueError reports the offending name clearly.
    raise ValueError("Illegal argument for optimizer type: {!r}".format(opt_name))

In [25]:
def get_activation_function(act_func_name):
    """Return a new activation module for `act_func_name` (case-insensitive).

    Supported: relu, prelu, elu, silu, leakyrelu, tanh (default constructor args).

    Raises:
        ValueError: for any other name.
    """
    activations = {
        'relu': nn.ReLU,
        'prelu': nn.PReLU,
        'elu': nn.ELU,
        'silu': nn.SiLU,
        'leakyrelu': nn.LeakyReLU,
        'tanh': nn.Tanh,
    }
    act_cls = activations.get(act_func_name.lower())
    if act_cls is None:
        # `sys` is not imported here, so the old `sys.exit(...)` raised NameError.
        raise ValueError(
            "Illegal argument for activation function type: {!r}".format(act_func_name)
        )
    return act_cls()

In [26]:
def init_hsp(n_wsc, epochs):
    """Allocate zeroed sparsity-control buffers.

    Returns:
        hsp_val, beta_val: float arrays of shape (n_wsc,) holding the current values.
        hsp_list, beta_list: float arrays of shape (n_wsc, epochs) for per-epoch history.
    All four arrays are independent.
    """
    hsp_val = np.zeros(n_wsc)
    beta_val = np.zeros_like(hsp_val)
    hsp_list, beta_list = (np.zeros((n_wsc, epochs)) for _ in range(2))
    return hsp_val, beta_val, hsp_list, beta_list

In [27]:
# Weight sparsity control with Hoyer's sparseness (layer-wise)
def calc_hsp(w, beta, max_beta, beta_lr, tg_hsp):
    """Measure a weight matrix's Hoyer sparsity and adjust its L1 coefficient.

    Hoyer sparsity over the n = rows * cols entries of the 2-D tensor ``w``:
        (sqrt(n) - ||w||_1 / ||w||_2) / (sqrt(n) - 1)
    which is 0 when all entries have equal magnitude and 1 when exactly one is non-zero.

    ``beta`` is increased by ``beta_lr`` when the sparsity is below ``tg_hsp``,
    decreased when above, and finally clipped to ``[0, max_beta]``.

    Returns:
        [hsp, beta]
    """
    n_rows, n_cols = w.shape
    sqrt_n = np.sqrt(n_rows * n_cols)
    l1_over_l2 = (torch.norm(w, 1) / torch.norm(w, 2)).item()
    hsp = (sqrt_n - l1_over_l2) / (sqrt_n - 1)

    # Step towards the target sparsity.
    beta = beta + beta_lr * np.sign(tg_hsp - hsp)

    # Clip into [0, max_beta].
    if beta < 0.0:
        beta = 0.0
    if beta > max_beta:
        beta = max_beta

    return [hsp, beta]

In [28]:
def l1_penalty(model, epoch, hsp_val, beta_val, hsp_list, beta_list, tg_hsp):
    """L1 regularisation over the hidden Linear layers, with weight-sparsity control.

    Visits, in `model.named_parameters()` order, every parameter whose name contains
    "weight" but not "bn" and contains "ext", "reg_1" or "clf_1". For the k-th one:

    * if `wsc_flag[k]` is non-zero, `calc_hsp` measures its Hoyer sparsity and updates
      its beta using `tg_hsp`, `max_beta` and `beta_lr` at the running WSC index; the
      results are written into `hsp_val` / `beta_val` and column `epoch - 1` of
      `hsp_list` / `beta_list` (all mutated in place), and the layer penalty is
      `beta * ||W||_1`;
    * otherwise the layer penalty is `l1_param * ||W||_1`.

    Returns:
        The summed penalty (a tensor attached to the graph), or 0.0 if no layer matched.
    """
    l1_reg = None
    layer_idx = 0
    wsc_idx = 0

    for name, param in model.named_parameters():
        is_weight = "weight" in name and "bn" not in name
        is_hidden_layer = "ext" in name or "reg_1" in name or "clf_1" in name
        if not (is_weight and is_hidden_layer):
            continue

        if wsc_flag[layer_idx] != 0:
            hsp_val[wsc_idx], beta_val[wsc_idx] = calc_hsp(
                param, beta_val[wsc_idx], max_beta[wsc_idx],
                beta_lr[wsc_idx], tg_hsp[wsc_idx]
            )
            hsp_list[wsc_idx, epoch - 1] = hsp_val[wsc_idx]
            beta_list[wsc_idx, epoch - 1] = beta_val[wsc_idx]
            layer_reg = torch.norm(param, 1) * beta_val[wsc_idx]
            wsc_idx += 1
        else:
            # Keep the norm as a tensor: the previous `.item()` detached it from the
            # autograd graph, so the fixed L1 penalty never produced any gradient.
            layer_reg = torch.norm(param, 1) * l1_param

        l1_reg = layer_reg if l1_reg is None else l1_reg + layer_reg
        layer_idx += 1

    # Returning None would make `loss + l1_term` fail in the caller.
    return 0.0 if l1_reg is None else l1_reg

In [29]:
def pearsonr(x, y):
    """Pearson correlation coefficient of two 1-D tensors, as a 0-dim tensor."""
    x_centered = x - torch.mean(x)
    y_centered = y - torch.mean(y)
    numerator = torch.dot(x_centered, y_centered)
    denominator = torch.norm(x_centered, 2) * torch.norm(y_centered, 2)
    return numerator / denominator

In [30]:
def train(model, epoch, train_loader, optimizer, criterion_clf, criterion_reg, 
          hsp_val, beta_val, hsp_list, beta_list, tg_hsp):
    """Run one training epoch on the joint predictor + discriminator + L1 objective.

    The sparsity buffers (`hsp_val`, `beta_val`, `hsp_list`, `beta_list`) are updated
    in place by `l1_penalty` on every batch.

    Returns:
        (clf_loss, reg_loss, clf_acc, train_corr): per-sample mean discriminator and
        predictor losses, scanner-classification accuracy in percent, and the Pearson
        correlation (0-dim tensor) between true and predicted p-factor over the epoch.
    """
    model.train()
    reg_loss = 0.0
    clf_loss = 0.0
    total = 0
    correct = 0
    y_train_true = []
    y_train_pred = []

    for inputs, target in train_loader:
        optimizer.zero_grad()
        inputs, target = inputs.to(DEVICE), target.to(DEVICE)
        output_reg, output_clf = model(inputs)
        target_clf = target[:, scanner_idx].long().view(-1)
        target_reg = target[:, p_factor_idx].view(-1, 1)
        running_clf_loss = criterion_clf(output_clf, target_clf)
        running_reg_loss = criterion_reg(output_reg, target_reg)
        l1_term = l1_penalty(model, epoch, hsp_val, beta_val, hsp_list, beta_list, tg_hsp)
        cost = running_clf_loss + running_reg_loss + l1_term
        cost.backward()
        optimizer.step()

        batch_size = output_reg.size(0)
        # The criteria average over the batch; weight by batch size so that dividing
        # by `total` below yields a true per-sample mean.
        clf_loss += running_clf_loss.item() * batch_size
        reg_loss += running_reg_loss.item() * batch_size
        total += batch_size
        pred = output_clf.detach().argmax(dim=1)
        # Compare with the scanner labels only. Comparing an (N, 1) prediction with
        # the whole (N, 3) target broadcast against every target column.
        correct += (pred == target_clf).sum().item()
        y_train_true.append(target_reg.detach().flatten())
        y_train_pred.append(output_reg.detach().flatten())

    reg_loss /= total
    clf_loss /= total
    clf_acc = 100 * correct / total
    # `cat` (unlike `stack`) also works when batches differ in size.
    train_corr = pearsonr(torch.cat(y_train_true), torch.cat(y_train_pred))
    return clf_loss, reg_loss, clf_acc, train_corr

In [31]:
def valid(model, epoch, valid_loader, criterion_clf, criterion_reg):
    """Evaluate `model` on the validation loader without gradient tracking.

    `epoch` is accepted for interface symmetry with `train` and is unused.

    Returns:
        (clf_loss, reg_loss, clf_acc, valid_corr): per-sample mean discriminator and
        predictor losses, scanner-classification accuracy in percent, and the Pearson
        correlation (0-dim tensor) between true and predicted p-factor.
    """
    model.eval()
    reg_loss = 0.0
    clf_loss = 0.0
    correct = 0
    total = 0
    y_valid_true = []
    y_valid_pred = []

    with torch.no_grad():
        for inputs, target in valid_loader:
            inputs, target = inputs.to(DEVICE), target.to(DEVICE)
            output_reg, output_clf = model(inputs)
            target_clf = target[:, scanner_idx].long().view(-1)
            target_reg = target[:, p_factor_idx].view(-1, 1)
            batch_size = output_reg.size(0)
            # The criteria average over the batch; weight by batch size for a per-sample mean.
            clf_loss += criterion_clf(output_clf, target_clf).item() * batch_size
            reg_loss += criterion_reg(output_reg, target_reg).item() * batch_size
            total += batch_size
            pred = output_clf.argmax(dim=1)
            # Compare with the scanner labels only (not every target column).
            correct += (pred == target_clf).sum().item()
            y_valid_true.append(target_reg.flatten())
            y_valid_pred.append(output_reg.flatten())

    clf_loss /= total
    reg_loss /= total
    clf_acc = 100 * correct / total
    valid_corr = pearsonr(torch.cat(y_valid_true), torch.cat(y_valid_pred))
    return clf_loss, reg_loss, clf_acc, valid_corr

In [32]:
def test(model, epoch, test_loader, criterion_clf, criterion_reg):
    """Evaluate the p-factor predictor on the held-out test loader.

    `epoch` and `criterion_clf` are accepted for interface symmetry and are unused.

    Returns:
        (reg_loss, test_corr): per-sample mean predictor loss and the Pearson
        correlation (0-dim tensor) between true and predicted p-factor.
    """
    model.eval()
    reg_loss = 0.0
    total = 0
    y_test_true = []
    y_test_pred = []

    with torch.no_grad():
        for inputs, target in test_loader:
            inputs, target = inputs.to(DEVICE), target.to(DEVICE)
            output_reg, _ = model(inputs)
            target_reg = target[:, p_factor_idx].view(-1, 1)
            batch_size = output_reg.size(0)
            # Weight each batch-mean loss by its size so the result is a per-sample mean
            # regardless of how many batches the loader yields.
            reg_loss += criterion_reg(output_reg, target_reg).item() * batch_size
            total += batch_size
            y_test_true.append(target_reg.flatten())
            y_test_pred.append(output_reg.flatten())

    reg_loss /= total
    test_corr = pearsonr(torch.cat(y_test_true), torch.cat(y_test_pred))
    return reg_loss, test_corr

In [33]:
class early_stopping_func:
    """Early stopping on validation correlation (higher is better).

    Tracks the best validation correlation so far, saves the model's state_dict to
    ``<path>/early_stopped_model.pt`` whenever it improves by at least ``delta``
    (the first call always counts as an improvement), and sets ``early_stop`` once
    ``patience`` consecutive calls fail to improve.
    """

    def __init__(self, patience=5, verbose=False, delta=0, path=None):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_epoch = 0
        # None means "no epoch seen yet". An initial 0 made the first-call branch
        # unreachable: if the validation correlation never exceeded 0, nothing was
        # saved and `best_corr_list` was read before ever being assigned.
        self.best_corr = None
        self.best_corr_list = None
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; `np.inf` works everywhere.
        self.valid_corr_max = -np.inf
        self.delta = delta
        self.path = path

    def __call__(self, valid_loss, model, epoch, train_corr, valid_corr, test_corr):
        """Update the stopping state with one epoch's results.

        `valid_loss` is accepted for interface compatibility; the decision uses
        `valid_corr` only. A NaN correlation never counts as an improvement.
        """
        if self.best_corr is None or valid_corr >= self.best_corr + self.delta:
            self.best_corr = valid_corr
            self.best_corr_list = [train_corr, valid_corr, test_corr]
            # Record the correlation (previously the loss was passed on the first call).
            self.save_checkpoint(valid_corr, model, epoch)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                print("Early Stopping! Best Model at Epoch {}"
                      .format(self.best_epoch), end=", ")
                print("valid corr: {:.4f}, test corr: {:.4f}"
                      .format(self.best_corr_list[1], self.best_corr_list[2]))

    def save_checkpoint(self, best_corr, model, epoch):
        """Save `model`'s state_dict and remember `best_corr` / `epoch` as the best."""
        if self.verbose:
            print("Validation Corr Increased! ({:.4f} --> {:.4f}), Saving the Model!"
                  .format(self.valid_corr_max, best_corr))
        torch.save(model.state_dict(), os.path.join(self.path, "early_stopped_model.pt"))
        self.valid_corr_max = best_corr
        self.best_epoch = epoch

In [34]:
def plot_learning_curves(
    save_dir, epochs, train_loss, valid_loss, test_loss, 
    train_corr, valid_corr, test_corr, train_acc, valid_acc, lr,
    plot_hsp_list, plot_beta_list, tg_hsp):
    """Save a 2x3 panel of learning curves to ``<save_dir>/Learning_curves.png``.

    Panels: train discriminator loss; valid/test predictor loss; learning rate;
    train/valid/test correlation; per-layer Hoyer sparsity; per-layer beta.

    Args:
        save_dir: output directory.
        epochs: number of leading history entries to plot (losses, lr, correlations).
        train_loss, valid_loss, test_loss: per-epoch ``[clf_loss, reg_loss]`` rows;
            column 0 of train and column 1 of valid/test are plotted.
        train_corr, valid_corr, test_corr: per-epoch correlations (floats or 0-dim tensors).
        train_acc, valid_acc: accepted for interface compatibility; not plotted.
        lr: per-epoch learning rates.
        plot_hsp_list, plot_beta_list: per-epoch lists of per-layer values.
        tg_hsp: per-layer target lists; ``tg_hsp[0][0]`` is shown in the HSP title.
    Uses the module-level `indices` for the HSP/beta layer labels.
    """
    sns.set(style="dark", font_scale=2)
    fig, ax = plt.subplots(2, 3, figsize=(28, 10))
    ax = ax.flat
    lw = 3.5
    last_epoch = epochs

    def _loss_column(history, col):
        # Read one loss component per epoch directly from the rows. Converting the
        # whole history with np.array raises on NumPy >= 1.24 when rows are ragged
        # (e.g. a `[]` placeholder for the absent test discriminator loss).
        return [float(row[col]) for row in history[:last_epoch]]

    def _as_floats(values):
        # Correlations may be 0-dim (possibly CUDA) tensors; plot plain floats.
        return [float(v) for v in values[:last_epoch]]

    ax[0].plot(_loss_column(train_loss, 0), label='train disc loss', lw=lw, color="r")
    ax[0].legend()
    ax[0].set_title("Discriminator Loss Plot")

    ax[1].plot(_loss_column(valid_loss, 1), label='valid pred loss', lw=lw, color="g")
    ax[1].plot(_loss_column(test_loss, 1), label='test pred loss', lw=lw, color="b")
    ax[1].legend()
    ax[1].set_title("Predictor Loss Plot")

    ax[2].plot(lr[:last_epoch], label='learning rate', lw=lw, color="k")
    ax[2].legend()
    ax[2].set_title("Learning Rate Plot")

    ax[3].plot(_as_floats(train_corr), label='train corr', lw=lw, color="r")
    ax[3].plot(_as_floats(valid_corr), label='valid corr', lw=lw, color="g")
    ax[3].plot(_as_floats(test_corr), label='test corr', lw=lw, color="b")
    ax[3].legend()
    ax[3].set_title("Correlation Plot ($r$={:.4f})".format(float(test_corr[-1])))

    # Transpose to (layer, epoch) so each row is one layer's history.
    plot_hsp_list, plot_beta_list = np.array(plot_hsp_list).T, np.array(plot_beta_list).T

    for idx, n_layer in enumerate(indices):
        ax[4].plot(plot_hsp_list[idx], label='layer{}'.format(n_layer), lw=lw)
        ax[5].plot(plot_beta_list[idx], label='layer{}'.format(n_layer), lw=lw)
    if len(indices) > 0:
        # Legends and titles only need to be set once, after all layers are drawn.
        ax[4].legend()
        ax[5].legend()
        ax[4].set_title("HSP plot [{:.3f}/{:.3f}]".format(plot_hsp_list[0, -1], tg_hsp[0][0]))
        ax[5].set_title("Beta plot")

    fig.tight_layout()
    fig.savefig("{}/Learning_curves.png".format(save_dir))

    plt.close(fig)

In [35]:
def run_fold(
    n_outer_cv=0, output_folder=output_folder, params=None):
    """Train and evaluate one leave-one-site-out fold.

    Builds loaders from the fold's train/valid/test indices, trains a fresh DNN for
    up to ``params["epochs"]`` epochs with early stopping on validation correlation,
    maintains three SWA averages, and writes checkpoints and learning curves under
    ``<output_folder>/Outer_fold_<k>``.

    Args:
        n_outer_cv: zero-based fold index into the module-level fold index lists.
        output_folder: root directory for this run's outputs.
        params: hyper-parameter dict; the keys read are listed at the top of the body.

    Returns:
        (train_corr, valid_corr, test_corr): the best-epoch values if early stopping
        fired, otherwise those of the final epoch.
    """
    start_fold_time = time.time()
    act_func_name = params["act_func"]
    optimizer_name = params["optimizer"]

    extr_hidden = params["extr_hidden"]
    pred_hidden = params["pred_hidden"]
    disc_hidden = params["disc_hidden"]

    learning_rate = params["learning_rate"]
    batch_size = params["batch_size"]
    dropout_rate = params["dropout_rate"]
    dropout_reg = params["dropout_reg"]
    lambda_ = params["lambda_"]

    epochs = params["epochs"]
    swa_start = params["swa_start"]

    print("\n=================================== Outer Fold [{}/{}] ==================================="
          .format(n_outer_cv + 1, outer_n_splits))
    outer_save_dir = "{}/Outer_fold_{}".format(output_folder, n_outer_cv + 1)
    os.makedirs(outer_save_dir, exist_ok=True)
    outer_train_idx = train_folds_idx[n_outer_cv]
    outer_valid_idx = valid_folds_idx[n_outer_cv]
    outer_test_idx = test_folds_idx[n_outer_cv]

    X_train, y_train = X[outer_train_idx], y[outer_train_idx]
    X_valid, y_valid = X[outer_valid_idx], y[outer_valid_idx]
    X_test, y_test = X[outer_test_idx], y[outer_test_idx]

    outer_train_dataset = train_dataset(X_train, y_train)
    outer_valid_dataset = valid_dataset(X_valid, y_valid)
    outer_test_dataset = test_dataset(X_test, y_test)

    outer_train_loader = DataLoader(
        outer_train_dataset, batch_size=batch_size, pin_memory=True,
        shuffle=True, num_workers=num_workers, drop_last=True
    )
    # Validation and test are evaluated as one full-size batch. Shuffling does not
    # change their metrics; it is left on so the RNG stream matches earlier runs.
    outer_valid_loader = DataLoader(
        outer_valid_dataset, batch_size=len(y_valid), pin_memory=True,
        shuffle=True, num_workers=num_workers, drop_last=True
    )
    outer_test_loader = DataLoader(
        outer_test_dataset, batch_size=len(y_test), pin_memory=True,
        shuffle=True, num_workers=num_workers, drop_last=True
    )

    # Assign model
    model = DNN(
        extr_hidden, disc_hidden, pred_hidden, dropout_rate, dropout_reg, lambda_, act_func_name
    ).to(DEVICE)
    optimizer = get_optimizer(model, optimizer_name, learning_rate, l2_param)
    scheduler = ReduceLROnPlateau(
        optimizer, mode=mode, patience=lr_patience, min_lr=min_lr, factor=lr_factor
    )
    # NOTE(review): all three SWA models start averaging at the same epoch
    # (`swa_start`) and receive identical updates, so they end up identical. The
    # _50/_75/_100 suffixes suggest different start epochs were intended — confirm.
    swa_model_50 = AveragedModel(model).to(DEVICE)
    swa_model_75 = AveragedModel(model).to(DEVICE)
    swa_model_100 = AveragedModel(model).to(DEVICE)
    swa_scheduler = SWALR(optimizer, swa_lr=swa_lr)

    criterion_clf = nn.CrossEntropyLoss()
    criterion_reg = nn.MSELoss(reduction="mean")

    early_stopping = early_stopping_func(patience=early_stopping_patience, path=outer_save_dir)

    # Per-epoch histories for plotting
    outer_train_loss = []
    outer_valid_loss = []
    outer_test_loss = []
    outer_train_corr = []
    outer_valid_corr = []
    outer_test_corr = []
    outer_train_acc = []
    outer_valid_acc = []
    outer_lr = []
    outer_hsp_list = []
    outer_beta_list = []

    hsp_val, beta_val, hsp_list, beta_list = init_hsp(n_wsc, epochs)

    for epoch in range(1, epochs + 1):
        train_clf_loss, train_reg_loss, train_acc, train_corr = train(
            model, epoch, outer_train_loader, 
            optimizer, criterion_clf, criterion_reg, 
            hsp_val, beta_val, hsp_list, beta_list, tg_hsp
        )
        valid_clf_loss, valid_reg_loss, valid_acc, valid_corr = valid(
            model, epoch, outer_valid_loader, criterion_clf, criterion_reg
        )
        test_reg_loss, test_corr = test(
            model, epoch, outer_test_loader, criterion_clf, criterion_reg
        )

        if epoch > swa_start:
            swa_model_50.update_parameters(model)
            swa_model_75.update_parameters(model)
            swa_model_100.update_parameters(model)
            # swa_scheduler.step()

        # NOTE(review): the plateau scheduler (mode="max") is driven by the first
        # controlled layer's Hoyer sparsity, not by a validation metric — confirm intended.
        scheduler.step(hsp_val[0])
        lr = optimizer.param_groups[0]['lr']

        outer_train_loss.append([train_clf_loss, train_reg_loss])
        outer_train_corr.append(train_corr)
        outer_train_acc.append(train_acc)
        outer_valid_loss.append([valid_clf_loss, valid_reg_loss])
        outer_valid_corr.append(valid_corr)
        outer_valid_acc.append(valid_acc)
        # The test pass has no discriminator loss; NaN keeps the rows rectangular
        # and numeric (an empty-list placeholder made the history ragged).
        outer_test_loss.append([np.nan, test_reg_loss])
        outer_test_corr.append(test_corr)
        outer_lr.append(lr)
        outer_hsp_list.append(list(hsp_val))
        outer_beta_list.append(list(beta_val))

        if epoch % 5 == 0:
            print("Epoch [{:d}/{:d}]".format(epoch, epochs), end=" ")
            print("Train corr: {:.4f}, Valid corr: {:.4f}, Test corr: {:.4f}"
                  .format(train_corr, valid_corr, test_corr))
            # hsp_val / tg_hsp are indexed by position among the controlled layers
            # (as in l1_penalty), not by absolute layer position.
            wsc_idx = 0
            for layer_idx, flag in enumerate(wsc_flag):
                if flag != 0:
                    print("Layer {:d}: [{:.4f}/{:.4f}]".
                          format(layer_idx + 1, hsp_val[wsc_idx], tg_hsp[wsc_idx][0]), end=" ")
                    wsc_idx += 1
            print("\nCurrent learning rate: {:.2e}".format(Decimal(str(lr))))

        early_stopping(valid_reg_loss, model, epoch, train_corr, valid_corr, test_corr)

        if early_stopping.early_stop:
            train_corr = early_stopping.best_corr_list[0]
            valid_corr = early_stopping.best_corr_list[1]
            test_corr = early_stopping.best_corr_list[2]
            break

        plot_learning_curves(
            outer_save_dir, epochs, outer_train_loss, outer_valid_loss, outer_test_loss,  
            outer_train_corr, outer_valid_corr, outer_test_corr, 
            outer_train_acc, outer_valid_acc, outer_lr, outer_hsp_list, outer_beta_list, tg_hsp
        )

    if not early_stopping.early_stop:
        print("Outer Fold [{}/{}] train corr: {:.4f}, valid corr: {:.4f}, test corr: {:.4f}"
              .format(n_outer_cv + 1, outer_n_splits, train_corr, valid_corr, test_corr))

    # Recompute BatchNorm statistics for the averaged weights (on CPU).
    swa_model_50 = swa_model_50.cpu()
    swa_model_75 = swa_model_75.cpu()
    swa_model_100 = swa_model_100.cpu()

    torch.optim.swa_utils.update_bn(outer_train_loader, swa_model_50)
    torch.optim.swa_utils.update_bn(outer_train_loader, swa_model_75)
    torch.optim.swa_utils.update_bn(outer_train_loader, swa_model_100)

    torch.save(swa_model_50.state_dict(), 
               outer_save_dir + "/swa_model_50_fold_" + str(n_outer_cv + 1) + ".pt")
    torch.save(swa_model_75.state_dict(), 
               outer_save_dir + "/swa_model_75_fold_" + str(n_outer_cv + 1) + ".pt")
    torch.save(swa_model_100.state_dict(), 
               outer_save_dir + "/swa_model_100_fold_" + str(n_outer_cv + 1) + ".pt")

    torch.save(model.state_dict(), 
               outer_save_dir + "/model_fold_" + str(n_outer_cv + 1) + ".pt")

    torch.cuda.empty_cache()
    gc.collect()

    tot_time = time.time() - start_fold_time
    print("Execution Time for Fold: {:.2f} mins".format(tot_time / 60))
    return train_corr, valid_corr, test_corr

In [36]:
params = {
    "act_func": "elu",
    "optimizer": "nag",
    "extr_hidden": 1024, 
    "pred_hidden": 1024,
    "disc_hidden": 1024,
    "dropout_rate": 0.9,
    "dropout_reg": 0.95,
    "learning_rate": 5e-05,
    "batch_size": 32,
    "lambda_": 1.25,
    "epochs": 150,
    "swa_start": 75
}

# Run-folder name: only the three hidden-layer widths are encoded.
temp_param = "arc_{}_{}_{}".format(
    params["extr_hidden"], params["disc_hidden"], params["pred_hidden"]
)

temp_output_folder = os.path.join(output_folder, temp_param)

code_start_time = time.time()
corr_list = []

for n_outer_cv in range(outer_n_splits):
    train_corr, valid_corr, test_corr = run_fold(
        n_outer_cv=n_outer_cv, output_folder=temp_output_folder, params=params
    )
    # The correlations come back as 0-dim (possibly CUDA) tensors; store plain
    # floats so the CSV holds numbers rather than "tensor(...)" strings.
    corr_list.append([float(train_corr), float(valid_corr), float(test_corr)])
    # Rewrite the summary after every fold so partial results survive an interruption.
    corr_df = pd.DataFrame(corr_list, columns=["train", "valid", "test"])
    corr_df.to_csv("{}/corr_df.csv".format(temp_output_folder))

print("Total execution time: {:.2f} mins".format((time.time() - code_start_time) / 60))


/users/hjw/data/sann/optuna/20210702_18290451/arc_1024_1024_1024

Epoch [5/150] Train corr: 0.0092, Valid corr: 0.1171, Test corr: 0.0477
Layer 1: [0.4864/0.9000] Layer 2: [0.3069/0.3000] Layer 3: [0.3923/0.5000] 
Current learning rate: 5.00e-5
Epoch [10/150] Train corr: 0.0223, Valid corr: 0.1467, Test corr: 0.0654
Layer 1: [0.7407/0.9000] Layer 2: [0.3066/0.3000] Layer 3: [0.5026/0.5000] 
Current learning rate: 5.00e-5
Epoch [15/150] Train corr: 0.0370, Valid corr: 0.0878, Test corr: 0.1430
Layer 1: [0.8072/0.9000] Layer 2: [0.3065/0.3000] Layer 3: [0.5027/0.5000] 
Current learning rate: 5.00e-5
Epoch [20/150] Train corr: 0.0473, Valid corr: 0.0623, Test corr: 0.0725
Layer 1: [0.5978/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5026/0.5000] 
Current learning rate: 1.25e-5
Epoch [25/150] Train corr: 0.0394, Valid corr: 0.1253, Test corr: 0.1807
Layer 1: [0.7128/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5023/0.5000] 
Current learning rate: 1.25e-5
Epoch [30/150] Train corr: 0.0615,

Epoch [75/150] Train corr: 0.1889, Valid corr: 0.1938, Test corr: 0.1786
Layer 1: [0.9001/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5020/0.5000] 
Current learning rate: 1.22e-8
Epoch [80/150] Train corr: 0.1595, Valid corr: 0.1937, Test corr: 0.1776
Layer 1: [0.9001/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5020/0.5000] 
Current learning rate: 1.22e-8
Epoch [85/150] Train corr: 0.1602, Valid corr: 0.1924, Test corr: 0.1786
Layer 1: [0.8999/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5020/0.5000] 
Current learning rate: 1.22e-8
Epoch [90/150] Train corr: 0.1928, Valid corr: 0.1926, Test corr: 0.1776
Layer 1: [0.9001/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5020/0.5000] 
Current learning rate: 1.22e-8
Epoch [95/150] Train corr: 0.2014, Valid corr: 0.1916, Test corr: 0.1767
Layer 1: [0.9000/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5020/0.5000] 
Current learning rate: 1.22e-8
Epoch [100/150] Train corr: 0.2040, Valid corr: 0.1900, Test corr: 0.1761
Layer 1: [0.9000/0.9000] L

Epoch [150/150] Train corr: 0.3253, Valid corr: 0.1304, Test corr: 0.1451
Layer 1: [0.8999/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 1.22e-8
Outer Fold [3/18] train corr: 0.3253, valid corr: 0.1304, test corr: 0.1451
Execution Time for Fold: 35.78 mins

Epoch [5/150] Train corr: 0.0312, Valid corr: 0.0980, Test corr: -0.0493
Layer 1: [0.4847/0.9000] Layer 2: [0.3069/0.3000] Layer 3: [0.3910/0.5000] 
Current learning rate: 5.00e-5
Epoch [10/150] Train corr: 0.0307, Valid corr: 0.1298, Test corr: -0.0442
Layer 1: [0.7386/0.9000] Layer 2: [0.3066/0.3000] Layer 3: [0.5024/0.5000] 
Current learning rate: 5.00e-5
Epoch [15/150] Train corr: 0.0387, Valid corr: 0.0860, Test corr: -0.0039
Layer 1: [0.8141/0.9000] Layer 2: [0.3065/0.3000] Layer 3: [0.5025/0.5000] 
Current learning rate: 5.00e-5
Epoch [20/150] Train corr: 0.0231, Valid corr: 0.1152, Test corr: -0.0902
Layer 1: [0.6083/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5024/0.5000] 
Current learni

Epoch [70/150] Train corr: 0.1675, Valid corr: 0.1606, Test corr: 0.1208
Layer 1: [0.8999/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5018/0.5000] 
Current learning rate: 1.22e-8
Epoch [75/150] Train corr: 0.1914, Valid corr: 0.1568, Test corr: 0.1204
Layer 1: [0.8999/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5018/0.5000] 
Current learning rate: 1.22e-8
Epoch [80/150] Train corr: 0.1465, Valid corr: 0.1560, Test corr: 0.1212
Layer 1: [0.8999/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5018/0.5000] 
Current learning rate: 1.22e-8
Epoch [85/150] Train corr: 0.1963, Valid corr: 0.1574, Test corr: 0.1229
Layer 1: [0.8999/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5018/0.5000] 
Current learning rate: 1.22e-8
Epoch [90/150] Train corr: 0.2228, Valid corr: 0.1561, Test corr: 0.1212
Layer 1: [0.9000/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5018/0.5000] 
Current learning rate: 1.22e-8
Epoch [95/150] Train corr: 0.2022, Valid corr: 0.1545, Test corr: 0.1221
Layer 1: [0.9000/0.9000] La

Epoch [145/150] Train corr: 0.3006, Valid corr: 0.1426, Test corr: 0.1070
Layer 1: [0.9002/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5019/0.5000] 
Current learning rate: 1.22e-8
Epoch [150/150] Train corr: 0.3114, Valid corr: 0.1439, Test corr: 0.1075
Layer 1: [0.8999/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5019/0.5000] 
Current learning rate: 1.22e-8
Outer Fold [6/18] train corr: 0.3114, valid corr: 0.1439, test corr: 0.1075
Execution Time for Fold: 54.02 mins

Epoch [5/150] Train corr: 0.0263, Valid corr: 0.0745, Test corr: 0.0337
Layer 1: [0.4881/0.9000] Layer 2: [0.3069/0.3000] Layer 3: [0.3932/0.5000] 
Current learning rate: 5.00e-5
Epoch [10/150] Train corr: 0.0221, Valid corr: 0.1324, Test corr: 0.0417
Layer 1: [0.7426/0.9000] Layer 2: [0.3066/0.3000] Layer 3: [0.5026/0.5000] 
Current learning rate: 5.00e-5
Epoch [15/150] Train corr: 0.0259, Valid corr: 0.1205, Test corr: 0.0438
Layer 1: [0.8088/0.9000] Layer 2: [0.3065/0.3000] Layer 3: [0.5023/0.5000] 
Current learning 

Epoch [65/150] Train corr: 0.1699, Valid corr: 0.2123, Test corr: 0.2389
Layer 1: [0.9002/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5020/0.5000] 
Current learning rate: 4.88e-8
Epoch [70/150] Train corr: 0.1788, Valid corr: 0.2058, Test corr: 0.2377
Layer 1: [0.9000/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5020/0.5000] 
Current learning rate: 1.22e-8
Epoch [75/150] Train corr: 0.1826, Valid corr: 0.2028, Test corr: 0.2387
Layer 1: [0.8999/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5020/0.5000] 
Current learning rate: 1.22e-8
Epoch [80/150] Train corr: 0.1753, Valid corr: 0.2036, Test corr: 0.2390
Layer 1: [0.8999/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5020/0.5000] 
Current learning rate: 1.22e-8
Epoch [85/150] Train corr: 0.1846, Valid corr: 0.2024, Test corr: 0.2386
Layer 1: [0.9000/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5020/0.5000] 
Current learning rate: 1.22e-8
Epoch [90/150] Train corr: 0.1931, Valid corr: 0.2027, Test corr: 0.2391
Layer 1: [0.9002/0.9000] La

Epoch [140/150] Train corr: 0.3358, Valid corr: 0.0939, Test corr: 0.1405
Layer 1: [0.8999/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 1.22e-8
Epoch [145/150] Train corr: 0.3309, Valid corr: 0.0932, Test corr: 0.1411
Layer 1: [0.8999/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 1.22e-8
Epoch [150/150] Train corr: 0.3315, Valid corr: 0.0980, Test corr: 0.1381
Layer 1: [0.8999/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 1.22e-8
Outer Fold [9/18] train corr: 0.3315, valid corr: 0.0980, test corr: 0.1381
Execution Time for Fold: 46.16 mins

Epoch [5/150] Train corr: 0.0129, Valid corr: 0.1411, Test corr: 0.1175
Layer 1: [0.4832/0.9000] Layer 2: [0.3068/0.3000] Layer 3: [0.3899/0.5000] 
Current learning rate: 5.00e-5
Epoch [10/150] Train corr: 0.0222, Valid corr: 0.1191, Test corr: 0.1804
Layer 1: [0.7367/0.9000] Layer 2: [0.3065/0.3000] Layer 3: [0.5027/0.5000] 
Current learning

Epoch [60/150] Train corr: 0.1516, Valid corr: 0.1725, Test corr: 0.2201
Layer 1: [0.9007/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 4.88e-8
Epoch [65/150] Train corr: 0.2034, Valid corr: 0.1733, Test corr: 0.2335
Layer 1: [0.9010/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 4.88e-8
Epoch [70/150] Train corr: 0.1672, Valid corr: 0.1687, Test corr: 0.2424
Layer 1: [0.9002/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 1.22e-8
Epoch [75/150] Train corr: 0.1965, Valid corr: 0.1673, Test corr: 0.2452
Layer 1: [0.9002/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 1.22e-8
Epoch [80/150] Train corr: 0.2138, Valid corr: 0.1704, Test corr: 0.2446
Layer 1: [0.9003/0.9000] Layer 2: [0.3064/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 1.22e-8
Epoch [85/150] Train corr: 0.2022, Valid corr: 0.1681, Test corr: 0.2501
Layer 1: [0.9001/0.9000] La

Epoch [135/150] Train corr: 0.3056, Valid corr: 0.1590, Test corr: 0.0870
Layer 1: [0.9002/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5022/0.5000] 
Current learning rate: 1.22e-8
Epoch [140/150] Train corr: 0.2963, Valid corr: 0.1590, Test corr: 0.0870
Layer 1: [0.9000/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5022/0.5000] 
Current learning rate: 1.22e-8
Epoch [145/150] Train corr: 0.2941, Valid corr: 0.1579, Test corr: 0.0858
Layer 1: [0.8999/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5022/0.5000] 
Current learning rate: 1.22e-8
Epoch [150/150] Train corr: 0.3091, Valid corr: 0.1575, Test corr: 0.0852
Layer 1: [0.9001/0.9000] Layer 2: [0.3063/0.3000] Layer 3: [0.5022/0.5000] 
Current learning rate: 1.22e-8
Outer Fold [12/18] train corr: 0.3091, valid corr: 0.1575, test corr: 0.0852
Execution Time for Fold: 49.22 mins

Epoch [5/150] Train corr: 0.0347, Valid corr: 0.1399, Test corr: 0.0204
Layer 1: [0.4832/0.9000] Layer 2: [0.3069/0.3000] Layer 3: [0.3896/0.5000] 
Current learni

Epoch [55/150] Train corr: 0.1576, Valid corr: 0.1779, Test corr: 0.1478
Layer 1: [0.9006/0.9000] Layer 2: [0.3062/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 4.88e-8
Epoch [60/150] Train corr: 0.2017, Valid corr: 0.1823, Test corr: 0.1469
Layer 1: [0.9002/0.9000] Layer 2: [0.3062/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 1.22e-8
Epoch [65/150] Train corr: 0.1905, Valid corr: 0.1835, Test corr: 0.1474
Layer 1: [0.9003/0.9000] Layer 2: [0.3062/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 1.22e-8
Epoch [70/150] Train corr: 0.2123, Valid corr: 0.1813, Test corr: 0.1427
Layer 1: [0.8999/0.9000] Layer 2: [0.3062/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 1.22e-8
Epoch [75/150] Train corr: 0.2234, Valid corr: 0.1781, Test corr: 0.1297
Layer 1: [0.9004/0.9000] Layer 2: [0.3062/0.3000] Layer 3: [0.5021/0.5000] 
Current learning rate: 1.22e-8
Epoch [80/150] Train corr: 0.2341, Valid corr: 0.1770, Test corr: 0.1285
Layer 1: [0.8998/0.9000] La

OSError: [Errno 12] Cannot allocate memory

In [37]:
# Report total wall-clock time of the training run, measured from the
# `code_start_time` timestamp recorded in an earlier cell (seconds -> hours).
code_tot_time = time.time() - code_start_time
print(f"Execution Time for the training: {code_tot_time / 60 / 60:.2f} hours")

Execution Time for the training: 53.69 hours


In [39]:
# Collect the per-fold (train, valid, test) correlations into a numeric frame.
# Cells of `corr_df` hold scalar correlations; the original conversion called
# .detach().cpu() on them, so they are presumably torch tensors (possibly on GPU).
# float() handles scalar tensors on any device, as well as numpy/Python scalars.
# It yields a float64 frame. The previous `.numpy()` conversion stored 0-d
# ndarrays, which pandas keeps as `object` dtype columns.
# NOTE(review): the training cell above ended with an OSError, so corr_df may
# hold fewer than all 18 outer folds. Check len(new_corr_df) before reporting.
new_corr_list = [
    [float(corr_df.iloc[row, col]) for col in range(3)]
    for row in range(len(corr_df))
]
new_corr_df = pd.DataFrame(new_corr_list, columns=["train", "valid", "test"], dtype=float)
# numpy mean (via .values) deliberately propagates NaN, so a fold with an
# undefined correlation is not silently dropped from the LOSOCV average.
print("LOSOCV results --> valid corr: {:.4f}, test corr: {:.4f}"
      .format(new_corr_df.valid.values.mean(), new_corr_df.test.values.mean()))

LOSOCV results --> valid corr: 0.1588, test corr: 0.1378
