# Train transfer learning model for DICT prediction

In [None]:
# import all required materials

import torch
import numpy as np
from torch_geometric.loader import DataLoader
from sklearn.model_selection import KFold
import os

from model import VerticalGNN
from config import NUM_FEATURES, NUM_TARGET, EDGE_DIM, DEVICE, SEED_NO, PATIENCE, EPOCHS, NUM_GRAPHS_PER_BATCH, N_SPLITS, best_params_vertical
from engine import EnginehERG, EngineDICT
from utils import seed_everything, LoadDICTDataset, LoadhERGDataset

First, we define training, validation and test functions that support transfer learning. Then, we evaluate how different factors affect the results of transfer learning.

In [None]:
def run_training(method_tf, train_loader, valid_loader, params,es_trigger, path_to_pretrained_model, path_to_save_trained_model):
    '''
    Fine-tune a pretrained VerticalGNN on DICT data and checkpoint the best epoch.

    Args:
    method_tf (str): transfer-learning strategy for the feature-extraction block
                    (model.gin_model and model.graph_trans_model):
                    'freeze'        -> freeze that block; only the remaining trainable
                                       parameters are optimised at the base learning rate
                    'fine_tune'     -> fine tune it at the base learning rate
                    'fine_tune_2x'  -> fine tune it at 1/2 of the base learning rate
                    'fine_tune_5x'  -> fine tune it at 1/5 of the base learning rate
                    'fine_tune_10x' -> fine tune it at 1/10 of the base learning rate
                    In every fine_tune mode the readout block (model.ro) uses the base
                    learning rate, params['learning_rate'].
    train_loader: DataLoader class from pytorch geometric containing train data
    valid_loader: DataLoader class from pytorch geometric containing validation data
    params (dict): dictionary containing the hyperparameters
    es_trigger (int): number of initial epochs during which no checkpoint is saved and
                      early stopping is not evaluated
    path_to_pretrained_model (str): path to load the pretrained model state_dict from
    path_to_save_trained_model (str): path where the best model state_dict is saved

    Return:
    best loss: best validation loss (first element of eng.validate) seen after
               es_trigger epochs; np.inf if no checkpoint was saved

    Raises:
    ValueError: if method_tf is not one of the strategies listed above
    '''
    # Learning-rate divisor applied to the feature-extraction block per fine-tune mode.
    fine_tune_lr_divisors = {'fine_tune': 1, 'fine_tune_2x': 2, 'fine_tune_5x': 5, 'fine_tune_10x': 10}
    # Fail fast: an unknown strategy used to leave `optimizer` unbound and crash later.
    if method_tf != 'freeze' and method_tf not in fine_tune_lr_divisors:
        raise ValueError(
            f"Unknown method_tf {method_tf!r}; expected 'freeze' or one of {sorted(fine_tune_lr_divisors)}"
        )

    model = VerticalGNN(
        num_features=NUM_FEATURES,
        num_targets=NUM_TARGET,
        num_gin_layers=params["num_gin_layers"],
        num_graph_trans_layers=params["num_graph_trans_layers"],
        hidden_size=params["hidden_size"],
        n_heads=params["n_heads"],
        dropout=params["dropout"],
        edge_dim=EDGE_DIM,
    )

    # map_location lets a checkpoint saved on another device (e.g. GPU) load onto DEVICE.
    model.load_state_dict(torch.load(path_to_pretrained_model, map_location=DEVICE))
    model.to(DEVICE)

    base_lr = params['learning_rate']
    if method_tf == 'freeze':
        for frozen_block in (model.gin_model, model.graph_trans_model):
            for param in frozen_block.parameters():
                param.requires_grad = False
        optimizer = torch.optim.Adam(
            [p for p in model.parameters() if p.requires_grad], lr=base_lr
        )
    else:
        feature_lr = base_lr / fine_tune_lr_divisors[method_tf]
        optimizer = torch.optim.Adam([
            {'params': model.gin_model.parameters(), 'lr': feature_lr},
            {'params': model.graph_trans_model.parameters(), 'lr': feature_lr},
            # No per-group lr: the readout block falls back to the default lr=base_lr.
            {'params': model.ro.parameters()},
        ], lr=base_lr)

    eng = EngineDICT(model, optimizer, device=DEVICE)

    best_loss = np.inf
    early_stopping_iter = PATIENCE
    early_stopping_counter = 0

    for epoch in range(EPOCHS):
        train_loss = eng.train(train_loader)
        # eng.validate returns (bce, acc, f1, roc_auc); early stopping tracks the BCE.
        valid_loss = eng.validate(valid_loader)[0]
        print(
            f"Epoch: {epoch+1}/{EPOCHS}, train loss : {train_loss}, validation loss : {valid_loss}"
        )
        # Force at least es_trigger epochs of training before checkpointing/early stopping.
        if epoch + 1 > es_trigger:
            if valid_loss < best_loss:
                best_loss = valid_loss
                early_stopping_counter = 0  # reset counter

                # dirname is '' for a bare filename; os.makedirs('') would raise.
                model_save_directory = os.path.dirname(path_to_save_trained_model)
                if model_save_directory:
                    os.makedirs(model_save_directory, exist_ok=True)

                print("Saving model...")
                torch.save(model.state_dict(), path_to_save_trained_model)
            else:
                early_stopping_counter += 1

            # Stops after PATIENCE + 1 consecutive epochs without improvement.
            if early_stopping_counter > early_stopping_iter:
                print("Early stopping...")
                break
            print(f"Early stop counter: {early_stopping_counter}")

    if np.isinf(best_loss):
        # Nothing was saved (e.g. es_trigger >= EPOCHS or a non-finite validation loss);
        # any existing file at this path is NOT from this run.
        print(f"Warning: no checkpoint was saved to {path_to_save_trained_model}")

    return best_loss



In [None]:
def run_validation(method_tf, valid_loader, params, path_to_trained_model):
    '''
    Load a trained VerticalGNN checkpoint and evaluate it on a validation loader.

    Args:
    method_tf (str): transfer-learning strategy used in training: 'freeze', 'fine_tune',
                    'fine_tune_2x', 'fine_tune_5x' or 'fine_tune_10x'. It only controls how
                    the optimizer passed to EngineDICT is rebuilt (presumably unused during
                    evaluation - TODO confirm against EngineDICT).
    valid_loader: DataLoader class from pytorch geometric containing validation data
    params (dict): dictionary containing the hyperparameters
    path_to_trained_model (str): path to the trained model state_dict

    Return:
    (bce, acc, f1, roc_auc) as returned by EngineDICT.validate

    Raises:
    ValueError: if method_tf is not one of the strategies listed above
    '''
    fine_tune_lr_divisors = {'fine_tune': 1, 'fine_tune_2x': 2, 'fine_tune_5x': 5, 'fine_tune_10x': 10}
    # Fail fast: an unknown strategy used to leave `optimizer` unbound and crash later.
    if method_tf != 'freeze' and method_tf not in fine_tune_lr_divisors:
        raise ValueError(
            f"Unknown method_tf {method_tf!r}; expected 'freeze' or one of {sorted(fine_tune_lr_divisors)}"
        )

    model = VerticalGNN(
        num_features=NUM_FEATURES,
        num_targets=NUM_TARGET,
        num_gin_layers=params["num_gin_layers"],
        num_graph_trans_layers=params["num_graph_trans_layers"],
        hidden_size=params["hidden_size"],
        n_heads=params["n_heads"],
        dropout=params["dropout"],
        edge_dim=EDGE_DIM,
    )

    # map_location lets a checkpoint saved on another device (e.g. GPU) load onto DEVICE.
    model.load_state_dict(torch.load(path_to_trained_model, map_location=DEVICE))
    model.to(DEVICE)

    base_lr = params['learning_rate']
    if method_tf == 'freeze':
        for frozen_block in (model.gin_model, model.graph_trans_model):
            for param in frozen_block.parameters():
                param.requires_grad = False
        optimizer = torch.optim.Adam(
            [p for p in model.parameters() if p.requires_grad], lr=base_lr
        )
    else:
        feature_lr = base_lr / fine_tune_lr_divisors[method_tf]
        optimizer = torch.optim.Adam([
            {'params': model.gin_model.parameters(), 'lr': feature_lr},
            {'params': model.graph_trans_model.parameters(), 'lr': feature_lr},
            {'params': model.ro.parameters()},
        ], lr=base_lr)

    eng = EngineDICT(model, optimizer, device=DEVICE)
    bce, acc, f1, roc_auc = eng.validate(valid_loader)
    print(f"bce:{bce}, acc :{acc}, f1: {f1}, roc_auc: {roc_auc}")
    return bce, acc, f1, roc_auc

In [None]:
def run_testing(method_tf, test_loader, params, path_to_trained_model):
    '''
    Load a trained VerticalGNN checkpoint and evaluate it on the held-out test loader.

    Args:
    method_tf (str): transfer-learning strategy used in training: 'freeze', 'fine_tune',
                    'fine_tune_2x', 'fine_tune_5x' or 'fine_tune_10x'. It only controls how
                    the optimizer passed to EngineDICT is rebuilt (presumably unused during
                    testing - TODO confirm against EngineDICT).
    test_loader: DataLoader class from pytorch geometric containing test data
    params (dict): dictionary containing the hyperparameters
    path_to_trained_model (str): path to the trained model state_dict

    Return:
    (bce, acc, f1, roc_auc) as returned by EngineDICT.test

    Raises:
    ValueError: if method_tf is not one of the strategies listed above
    '''
    fine_tune_lr_divisors = {'fine_tune': 1, 'fine_tune_2x': 2, 'fine_tune_5x': 5, 'fine_tune_10x': 10}
    # Fail fast: an unknown strategy used to leave `optimizer` unbound and crash later.
    if method_tf != 'freeze' and method_tf not in fine_tune_lr_divisors:
        raise ValueError(
            f"Unknown method_tf {method_tf!r}; expected 'freeze' or one of {sorted(fine_tune_lr_divisors)}"
        )

    model = VerticalGNN(
        num_features=NUM_FEATURES,
        num_targets=NUM_TARGET,
        num_gin_layers=params["num_gin_layers"],
        num_graph_trans_layers=params["num_graph_trans_layers"],
        hidden_size=params["hidden_size"],
        n_heads=params["n_heads"],
        dropout=params["dropout"],
        edge_dim=EDGE_DIM,
    )

    # map_location lets a checkpoint saved on another device (e.g. GPU) load onto DEVICE.
    model.load_state_dict(torch.load(path_to_trained_model, map_location=DEVICE))
    model.to(DEVICE)

    base_lr = params['learning_rate']
    if method_tf == 'freeze':
        for frozen_block in (model.gin_model, model.graph_trans_model):
            for param in frozen_block.parameters():
                param.requires_grad = False
        optimizer = torch.optim.Adam(
            [p for p in model.parameters() if p.requires_grad], lr=base_lr
        )
    else:
        feature_lr = base_lr / fine_tune_lr_divisors[method_tf]
        optimizer = torch.optim.Adam([
            {'params': model.gin_model.parameters(), 'lr': feature_lr},
            {'params': model.graph_trans_model.parameters(), 'lr': feature_lr},
            {'params': model.ro.parameters()},
        ], lr=base_lr)

    eng = EngineDICT(model, optimizer, device=DEVICE)
    bce, acc, f1, roc_auc = eng.test(test_loader)
    print(f"bce:{bce}, acc :{acc}, f1: {f1}, roc_auc: {roc_auc}")
    return bce, acc, f1, roc_auc

# Effect of the number of pre-training epochs on transfer learning model performance
1. The weights of the feature extraction block are frozen first. 
2. The models are then trained, and only the weights of the classifier block are updated. 
3. This evaluates how the number of pre-training epochs affects transfer learning prediction performance.

In [None]:
# Cross-validated transfer learning with the feature-extraction block frozen.
train_data_root_path = './data/graph_data/data_DICT_train/'
train_data_raw_filename = 'data_DICT_train.csv'
test_data_root_path = './data/graph_data/data_DICT_test'
test_data_raw_filename = 'data_DICT_test.csv'
n_repetitions = 1
method_tf = 'freeze'
params = best_params_vertical
es_trigger = 0
path_to_pretrained_model = './trf_learning_models/pretrained_models/vertical/'
path_to_save_trained_model = './trf_learning_models/trained_models/vertical/pretrained_40/'    # Setup on demand
path_to_trained_model = './trf_learning_models/trained_models/vertical/pretrained_40/'   # Setup on demand
pretrained_checkpoint = os.path.join(path_to_pretrained_model, 'pretrained_vertical_model_40_epoch.pt')

# Metrics accumulated over every fold of every repeat.
val_bce_list = []
val_acc_list = []
val_f1_list = []
val_roc_auc_list = []

bce_list = []
acc_list = []
f1_list = []
roc_auc_list = []

dataset_for_cv = LoadDICTDataset(train_data_root_path, train_data_raw_filename)
# NOTE: folds are not shuffled and the seed is reset per fold, so with
# n_repetitions > 1 every repeat reuses identical folds.
kf = KFold(n_splits=N_SPLITS)

# Train + validation indices of every fold cover all molecules, so read the processed
# graphs from disk once (derived from train_data_root_path) instead of once per fold.
all_graphs = [
    torch.load(os.path.join(train_data_root_path, 'processed', f'molecule_{idx}.pt'))
    for idx in range(len(dataset_for_cv))
]

# The held-out test set is the same for every fold; build its loader once.
test_dataset = LoadDICTDataset(test_data_root_path, test_data_raw_filename)
test_loader = DataLoader(
    test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
)

for repeat in range(n_repetitions):
    repeat_val_bce_list = []
    repeat_val_acc_list = []
    repeat_val_f1_list = []
    repeat_val_roc_auc_list = []

    repeat_bce_list = []
    repeat_acc_list = []
    repeat_f1_list = []
    repeat_roc_auc_list = []

    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        seed_everything(SEED_NO)
        train_dataset = [all_graphs[idx] for idx in train_idx]
        valid_dataset = [all_graphs[idx] for idx in valid_idx]

        train_loader = DataLoader(
            train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True
        )
        valid_loader = DataLoader(
            valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
        )
        print(f'Rep no {repeat}, Fold no {fold_no}')

        checkpoint_path = os.path.join(
            path_to_save_trained_model,
            f"trained_vertical_model_{method_tf}_repeat_{repeat}_fold_{fold_no}_{es_trigger}_es_trigger.pt",
        )
        run_training(method_tf, train_loader, valid_loader, params, es_trigger, pretrained_checkpoint, checkpoint_path)

        val_bce, val_acc, val_f1, val_roc_auc = run_validation(method_tf, valid_loader, params, checkpoint_path)
        bce, acc, f1, roc_auc = run_testing(method_tf, test_loader, params, checkpoint_path)

        repeat_val_bce_list.append(val_bce)
        repeat_val_acc_list.append(val_acc)
        repeat_val_f1_list.append(val_f1)
        repeat_val_roc_auc_list.append(val_roc_auc)

        repeat_bce_list.append(bce)
        repeat_acc_list.append(acc)
        repeat_f1_list.append(f1)
        repeat_roc_auc_list.append(roc_auc)

        val_bce_list.append(val_bce)
        val_acc_list.append(val_acc)
        val_f1_list.append(val_f1)
        val_roc_auc_list.append(val_roc_auc)

        bce_list.append(bce)
        acc_list.append(acc)
        f1_list.append(f1)
        roc_auc_list.append(roc_auc)

    # Output statistics for validation and test results for the repeat
    print(f'Statistics for repeat {repeat}:')
    for metric_name, metric_values in (
        ('BCE', repeat_val_bce_list),
        ('ACC', repeat_val_acc_list),
        ('F1', repeat_val_f1_list),
        ('ROC_AUC', repeat_val_roc_auc_list),
    ):
        print(f'Validation - {metric_name}: {np.mean(metric_values):.3f}±{np.std(metric_values):.3f}')
    for metric_name, metric_values in (
        ('BCE', repeat_bce_list),
        ('ACC', repeat_acc_list),
        ('F1', repeat_f1_list),
        ('ROC_AUC', repeat_roc_auc_list),
    ):
        print(f'test - {metric_name}: {np.mean(metric_values):.3f}±{np.std(metric_values):.3f}')

# Overall mean ± population std (np.std, ddof=0) across all folds and repeats.
val_bce_arr = np.array(val_bce_list)
val_mean_bce = np.mean(val_bce_arr)
val_sd_bce = np.std(val_bce_arr)
print(f'validation bce:{val_mean_bce:.3f}±{val_sd_bce:.3f}')

val_acc_arr = np.array(val_acc_list)
val_acc_mean = np.mean(val_acc_arr)
val_acc_sd = np.std(val_acc_arr)
print(f'validation acc:{val_acc_mean:.3f}±{val_acc_sd:.3f}')

val_f1_arr = np.array(val_f1_list)
val_f1_mean = np.mean(val_f1_arr)
val_f1_sd = np.std(val_f1_arr)
print(f'validation f1: {val_f1_mean:.3f}±{val_f1_sd:.3f}')

val_roc_auc_arr = np.array(val_roc_auc_list)
val_roc_auc_mean = np.mean(val_roc_auc_arr)
val_roc_auc_sd = np.std(val_roc_auc_arr)
print(f'validation roc_auc: {val_roc_auc_mean:.3f}±{val_roc_auc_sd:.3f}')

bce_arr = np.array(bce_list)
mean_bce = np.mean(bce_arr)
sd_bce = np.std(bce_arr)
print(f'bce:{mean_bce:.3f}±{sd_bce:.3f}')

acc_arr = np.array(acc_list)
acc_mean = np.mean(acc_arr)
acc_sd = np.std(acc_arr)
print(f'acc:{acc_mean:.3f}±{acc_sd:.3f}')

f1_arr = np.array(f1_list)
f1_mean = np.mean(f1_arr)
f1_sd = np.std(f1_arr)
print(f'f1: {f1_mean:.3f}±{f1_sd:.3f}')

roc_auc_arr = np.array(roc_auc_list)
roc_auc_mean = np.mean(roc_auc_arr)
roc_auc_sd = np.std(roc_auc_arr)
print(f'roc_auc: {roc_auc_mean:.3f}±{roc_auc_sd:.3f}')

print("Training Completed!")


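# Effect of reducing learning rates on transfer learning prediction performance
The reduction applies to the feature extraction block; the readout block always uses the base learning rate.  
Fine_tune = no change in learning rate  
Fine_tune_2x = 2-fold reduced learning rate  
Fine_tune_5x = 5-fold reduced learning rate  
Fine_tune_10x = 10-fold reduced learning rate  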

In [None]:
# Cross-validated transfer learning comparing fine-tuning learning-rate reductions.
train_data_root_path = './data/graph_data/data_DICT_train/'
train_data_raw_filename = 'data_DICT_train.csv'
test_data_root_path = './data/graph_data/data_DICT_test'
test_data_raw_filename = 'data_DICT_test.csv'
n_repetitions = 1
method_tf = 'fine_tune' # Setup on demand
params = best_params_vertical
es_trigger = 0
path_to_pretrained_model = './trf_learning_models/pretrained_models/vertical/'
path_to_save_trained_model = './trf_learning_models/trained_models/vertical/pretrained_40/'
path_to_trained_model = './trf_learning_models/trained_models/vertical/pretrained_40/'
pretrained_checkpoint = os.path.join(path_to_pretrained_model, 'pretrained_vertical_model_40_epoch.pt')

# Metrics accumulated over every fold of every repeat.
val_bce_list = []
val_acc_list = []
val_f1_list = []
val_roc_auc_list = []

bce_list = []
acc_list = []
f1_list = []
roc_auc_list = []

dataset_for_cv = LoadDICTDataset(train_data_root_path, train_data_raw_filename)
# NOTE: folds are not shuffled and the seed is reset per fold, so with
# n_repetitions > 1 every repeat reuses identical folds.
kf = KFold(n_splits=N_SPLITS)

# Train + validation indices of every fold cover all molecules, so read the processed
# graphs from disk once (derived from train_data_root_path) instead of once per fold.
all_graphs = [
    torch.load(os.path.join(train_data_root_path, 'processed', f'molecule_{idx}.pt'))
    for idx in range(len(dataset_for_cv))
]

# The held-out test set is the same for every fold; build its loader once.
test_dataset = LoadDICTDataset(test_data_root_path, test_data_raw_filename)
test_loader = DataLoader(
    test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
)

for repeat in range(n_repetitions):
    repeat_val_bce_list = []
    repeat_val_acc_list = []
    repeat_val_f1_list = []
    repeat_val_roc_auc_list = []

    repeat_bce_list = []
    repeat_acc_list = []
    repeat_f1_list = []
    repeat_roc_auc_list = []

    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        seed_everything(SEED_NO)
        train_dataset = [all_graphs[idx] for idx in train_idx]
        valid_dataset = [all_graphs[idx] for idx in valid_idx]

        train_loader = DataLoader(
            train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True
        )
        valid_loader = DataLoader(
            valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
        )
        print(f'Rep no {repeat}, Fold no {fold_no}')

        checkpoint_path = os.path.join(
            path_to_save_trained_model,
            f"trained_vertical_model_{method_tf}_repeat_{repeat}_fold_{fold_no}_{es_trigger}_es_trigger.pt",
        )
        run_training(method_tf, train_loader, valid_loader, params, es_trigger, pretrained_checkpoint, checkpoint_path)

        val_bce, val_acc, val_f1, val_roc_auc = run_validation(method_tf, valid_loader, params, checkpoint_path)
        bce, acc, f1, roc_auc = run_testing(method_tf, test_loader, params, checkpoint_path)

        repeat_val_bce_list.append(val_bce)
        repeat_val_acc_list.append(val_acc)
        repeat_val_f1_list.append(val_f1)
        repeat_val_roc_auc_list.append(val_roc_auc)

        repeat_bce_list.append(bce)
        repeat_acc_list.append(acc)
        repeat_f1_list.append(f1)
        repeat_roc_auc_list.append(roc_auc)

        val_bce_list.append(val_bce)
        val_acc_list.append(val_acc)
        val_f1_list.append(val_f1)
        val_roc_auc_list.append(val_roc_auc)

        bce_list.append(bce)
        acc_list.append(acc)
        f1_list.append(f1)
        roc_auc_list.append(roc_auc)

    # Output statistics for validation and test results for the repeat
    print(f'Statistics for repeat {repeat}:')
    for metric_name, metric_values in (
        ('BCE', repeat_val_bce_list),
        ('ACC', repeat_val_acc_list),
        ('F1', repeat_val_f1_list),
        ('ROC_AUC', repeat_val_roc_auc_list),
    ):
        print(f'Validation - {metric_name}: {np.mean(metric_values):.3f}±{np.std(metric_values):.3f}')
    for metric_name, metric_values in (
        ('BCE', repeat_bce_list),
        ('ACC', repeat_acc_list),
        ('F1', repeat_f1_list),
        ('ROC_AUC', repeat_roc_auc_list),
    ):
        print(f'test - {metric_name}: {np.mean(metric_values):.3f}±{np.std(metric_values):.3f}')

# Overall mean ± population std (np.std, ddof=0) across all folds and repeats.
val_bce_arr = np.array(val_bce_list)
val_mean_bce = np.mean(val_bce_arr)
val_sd_bce = np.std(val_bce_arr)
print(f'validation bce:{val_mean_bce:.3f}±{val_sd_bce:.3f}')

val_acc_arr = np.array(val_acc_list)
val_acc_mean = np.mean(val_acc_arr)
val_acc_sd = np.std(val_acc_arr)
print(f'validation acc:{val_acc_mean:.3f}±{val_acc_sd:.3f}')

val_f1_arr = np.array(val_f1_list)
val_f1_mean = np.mean(val_f1_arr)
val_f1_sd = np.std(val_f1_arr)
print(f'validation f1: {val_f1_mean:.3f}±{val_f1_sd:.3f}')

val_roc_auc_arr = np.array(val_roc_auc_list)
val_roc_auc_mean = np.mean(val_roc_auc_arr)
val_roc_auc_sd = np.std(val_roc_auc_arr)
print(f'validation roc_auc: {val_roc_auc_mean:.3f}±{val_roc_auc_sd:.3f}')

bce_arr = np.array(bce_list)
mean_bce = np.mean(bce_arr)
sd_bce = np.std(bce_arr)
print(f'bce:{mean_bce:.3f}±{sd_bce:.3f}')

acc_arr = np.array(acc_list)
acc_mean = np.mean(acc_arr)
acc_sd = np.std(acc_arr)
print(f'acc:{acc_mean:.3f}±{acc_sd:.3f}')

f1_arr = np.array(f1_list)
f1_mean = np.mean(f1_arr)
f1_sd = np.std(f1_arr)
print(f'f1: {f1_mean:.3f}±{f1_sd:.3f}')

roc_auc_arr = np.array(roc_auc_list)
roc_auc_mean = np.mean(roc_auc_arr)
roc_auc_sd = np.std(roc_auc_arr)
print(f'roc_auc: {roc_auc_mean:.3f}±{roc_auc_sd:.3f}')

print("Training Completed!")

# Effect of the number of epochs before triggering the early stopping mechanism on transfer learning prediction performance

In [None]:
# Cross-validated transfer learning comparing the number of epochs before early stopping.
train_data_root_path = './data/graph_data/data_DICT_train/'
train_data_raw_filename = 'data_DICT_train.csv'
test_data_root_path = './data/graph_data/data_DICT_test'
test_data_raw_filename = 'data_DICT_test.csv'
n_repetitions = 1
method_tf = 'fine_tune_2x'
params = best_params_vertical
es_trigger = 0  # Setup on demand
path_to_pretrained_model = './trf_learning_models/pretrained_models/vertical/'
path_to_save_trained_model = './trf_learning_models/trained_models/vertical/pretrained_40/'
path_to_trained_model = './trf_learning_models/trained_models/vertical/pretrained_40/'
pretrained_checkpoint = os.path.join(path_to_pretrained_model, 'pretrained_vertical_model_40_epoch.pt')

# Metrics accumulated over every fold of every repeat.
val_bce_list = []
val_acc_list = []
val_f1_list = []
val_roc_auc_list = []

bce_list = []
acc_list = []
f1_list = []
roc_auc_list = []

dataset_for_cv = LoadDICTDataset(train_data_root_path, train_data_raw_filename)
# NOTE: folds are not shuffled and the seed is reset per fold, so with
# n_repetitions > 1 every repeat reuses identical folds.
kf = KFold(n_splits=N_SPLITS)

# Train + validation indices of every fold cover all molecules, so read the processed
# graphs from disk once (derived from train_data_root_path) instead of once per fold.
all_graphs = [
    torch.load(os.path.join(train_data_root_path, 'processed', f'molecule_{idx}.pt'))
    for idx in range(len(dataset_for_cv))
]

# The held-out test set is the same for every fold; build its loader once.
test_dataset = LoadDICTDataset(test_data_root_path, test_data_raw_filename)
test_loader = DataLoader(
    test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
)

for repeat in range(n_repetitions):
    repeat_val_bce_list = []
    repeat_val_acc_list = []
    repeat_val_f1_list = []
    repeat_val_roc_auc_list = []

    repeat_bce_list = []
    repeat_acc_list = []
    repeat_f1_list = []
    repeat_roc_auc_list = []

    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        seed_everything(SEED_NO)
        train_dataset = [all_graphs[idx] for idx in train_idx]
        valid_dataset = [all_graphs[idx] for idx in valid_idx]

        train_loader = DataLoader(
            train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True
        )
        valid_loader = DataLoader(
            valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
        )
        print(f'Rep no {repeat}, Fold no {fold_no}')

        checkpoint_path = os.path.join(
            path_to_save_trained_model,
            f"trained_vertical_model_{method_tf}_repeat_{repeat}_fold_{fold_no}_{es_trigger}_es_trigger.pt",
        )
        run_training(method_tf, train_loader, valid_loader, params, es_trigger, pretrained_checkpoint, checkpoint_path)

        val_bce, val_acc, val_f1, val_roc_auc = run_validation(method_tf, valid_loader, params, checkpoint_path)
        bce, acc, f1, roc_auc = run_testing(method_tf, test_loader, params, checkpoint_path)

        repeat_val_bce_list.append(val_bce)
        repeat_val_acc_list.append(val_acc)
        repeat_val_f1_list.append(val_f1)
        repeat_val_roc_auc_list.append(val_roc_auc)

        repeat_bce_list.append(bce)
        repeat_acc_list.append(acc)
        repeat_f1_list.append(f1)
        repeat_roc_auc_list.append(roc_auc)

        val_bce_list.append(val_bce)
        val_acc_list.append(val_acc)
        val_f1_list.append(val_f1)
        val_roc_auc_list.append(val_roc_auc)

        bce_list.append(bce)
        acc_list.append(acc)
        f1_list.append(f1)
        roc_auc_list.append(roc_auc)

    # Output statistics for validation and test results for the repeat
    print(f'Statistics for repeat {repeat}:')
    for metric_name, metric_values in (
        ('BCE', repeat_val_bce_list),
        ('ACC', repeat_val_acc_list),
        ('F1', repeat_val_f1_list),
        ('ROC_AUC', repeat_val_roc_auc_list),
    ):
        print(f'Validation - {metric_name}: {np.mean(metric_values):.3f}±{np.std(metric_values):.3f}')
    for metric_name, metric_values in (
        ('BCE', repeat_bce_list),
        ('ACC', repeat_acc_list),
        ('F1', repeat_f1_list),
        ('ROC_AUC', repeat_roc_auc_list),
    ):
        print(f'test - {metric_name}: {np.mean(metric_values):.3f}±{np.std(metric_values):.3f}')

# Overall mean ± population std (np.std, ddof=0) across all folds and repeats.
val_bce_arr = np.array(val_bce_list)
val_mean_bce = np.mean(val_bce_arr)
val_sd_bce = np.std(val_bce_arr)
print(f'validation bce:{val_mean_bce:.3f}±{val_sd_bce:.3f}')

val_acc_arr = np.array(val_acc_list)
val_acc_mean = np.mean(val_acc_arr)
val_acc_sd = np.std(val_acc_arr)
print(f'validation acc:{val_acc_mean:.3f}±{val_acc_sd:.3f}')

val_f1_arr = np.array(val_f1_list)
val_f1_mean = np.mean(val_f1_arr)
val_f1_sd = np.std(val_f1_arr)
print(f'validation f1: {val_f1_mean:.3f}±{val_f1_sd:.3f}')

val_roc_auc_arr = np.array(val_roc_auc_list)
val_roc_auc_mean = np.mean(val_roc_auc_arr)
val_roc_auc_sd = np.std(val_roc_auc_arr)
print(f'validation roc_auc: {val_roc_auc_mean:.3f}±{val_roc_auc_sd:.3f}')

bce_arr = np.array(bce_list)
mean_bce = np.mean(bce_arr)
sd_bce = np.std(bce_arr)
print(f'bce:{mean_bce:.3f}±{sd_bce:.3f}')

acc_arr = np.array(acc_list)
acc_mean = np.mean(acc_arr)
acc_sd = np.std(acc_arr)
print(f'acc:{acc_mean:.3f}±{acc_sd:.3f}')

f1_arr = np.array(f1_list)
f1_mean = np.mean(f1_arr)
f1_sd = np.std(f1_arr)
print(f'f1: {f1_mean:.3f}±{f1_sd:.3f}')

roc_auc_arr = np.array(roc_auc_list)
roc_auc_mean = np.mean(roc_auc_arr)
roc_auc_sd = np.std(roc_auc_arr)
print(f'roc_auc: {roc_auc_mean:.3f}±{roc_auc_sd:.3f}')

print("Training Completed!")