# Vertical GNN Model for DICT Dataset 

In [5]:
from utils import seed_everything, LoadDICTDataset
from config import SEED_NO, NUM_FEATURES, NUM_GRAPHS_PER_BATCH, NUM_TARGET, EDGE_DIM, DEVICE, PATIENCE, EPOCHS, N_SPLITS, params_vertical_gnn
from engine import EngineDICT
from model import VerticalGNN

import torch
import numpy as np
import optuna
from sklearn.model_selection import KFold
from torch_geometric.loader import DataLoader
import os 

## 2. Tuning the Model

In [2]:
def run_tuning(train_loader, valid_loader, params):
    """Train a fresh VerticalGNN and report its best validation loss.

    Used during hyper-parameter search: nothing is written to disk, only the
    lowest validation loss seen is returned.

    Args:
        train_loader: Loader passed to ``EngineDICT.train`` each epoch.
        valid_loader: Loader passed to ``EngineDICT.validate`` each epoch.
        params: Mapping with the keys ``num_gin_layers``,
            ``num_graph_trans_layers``, ``hidden_size``, ``n_heads``,
            ``dropout`` and ``learning_rate``.

    Returns:
        The smallest validation loss observed (``np.inf`` if no epoch
        produced a loss below infinity).
    """
    model = VerticalGNN(
        num_features=NUM_FEATURES,
        num_targets=NUM_TARGET,
        num_gin_layers=params['num_gin_layers'],
        num_graph_trans_layers=params['num_graph_trans_layers'],
        hidden_size=params['hidden_size'],
        n_heads=params['n_heads'],
        dropout=params['dropout'],
        edge_dim=EDGE_DIM,
    )
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
    eng = EngineDICT(model, optimizer, device=DEVICE)

    patience = PATIENCE
    best_loss = np.inf
    stale_epochs = 0  # consecutive epochs without a new best validation loss

    for epoch_number in range(1, EPOCHS + 1):
        train_loss = eng.train(train_loader)
        # validate() returns a sequence; its first entry is the loss tracked here.
        valid_loss = eng.validate(valid_loader)[0]
        print(f'Epoch: {epoch_number}/{EPOCHS}, train loss : {train_loss}, validation loss : {valid_loss}')

        improved = valid_loss < best_loss
        if improved:
            best_loss = valid_loss
            stale_epochs = 0
        else:
            stale_epochs += 1

        # Stop once the counter strictly exceeds the configured patience.
        if stale_epochs > patience:
            print('Early stopping...')
            break
        print(f'Early stop counter: {stale_epochs}')

    return best_loss


In [None]:
def objective(trial):
    """Optuna objective: mean best validation loss over the K-fold split.

    Samples one hyper-parameter configuration from ``trial``, trains a model
    on every fold with ``run_tuning`` and averages the per-fold best
    validation losses.

    Args:
        trial: Optuna trial used to sample the categorical hyper-parameters.

    Returns:
        The sum of the per-fold best validation losses divided by
        ``N_SPLITS``.
    """
    params = {
        'num_gin_layers': trial.suggest_categorical('num_gin_layers', [1, 2, 3]),
        'num_graph_trans_layers': trial.suggest_categorical('num_graph_trans_layers', [1, 2, 3]),
        'hidden_size': trial.suggest_categorical('hidden_size', [64, 128, 256]),
        'n_heads': trial.suggest_categorical('n_heads', [1, 2, 3]),
        'dropout': trial.suggest_categorical('dropout', [0.1, 0.2, 0.3, 0.4]),
        'learning_rate': trial.suggest_categorical('learning_rate', [1e-3, 3e-3, 5e-3, 7e-3, 9e-3]),
    }

    # Load the dataset; the per-molecule graph files are then read from its
    # ``processed`` sub-folder.
    data_root = './data/graph_data/data_DICT_train/'
    processed_dir = os.path.join(data_root, 'processed')
    dataset_for_cv = LoadDICTDataset(root=data_root, raw_filename='data_DICT_train.csv')
    kf = KFold(n_splits=N_SPLITS)
    fold_loss = 0

    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        print(f'Fold {fold_no}')
        train_dataset = [
            torch.load(os.path.join(processed_dir, f'molecule_{t_idx}.pt'))
            for t_idx in train_idx
        ]
        valid_dataset = [
            torch.load(os.path.join(processed_dir, f'molecule_{v_idx}.pt'))
            for v_idx in valid_idx
        ]

        train_loader = DataLoader(train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

        fold_loss += run_tuning(train_loader, valid_loader, params)

    # Average over the number of folds actually evaluated. The divisor used to
    # be a hard-coded 10, which is only a true mean when N_SPLITS == 10.
    return fold_loss / N_SPLITS
if __name__ == '__main__':
    # Hyper-parameter search: minimise the value returned by ``objective``
    # over 30 trials, then report the best trial found.
    study = optuna.create_study(direction='minimize')
    study.optimize(objective, n_trials=30)

    print('best trial:')
    trial_ = study.best_trial
    print(trial_.values)
    print(f'Best parameters: {trial_.params}')

## 3. Train/validate/test model

In [6]:
def run_training(train_loader, valid_loader, params, trained_model_path):
    """Train a VerticalGNN with early stopping, checkpointing the best epoch.

    Every time the validation loss improves, the model's ``state_dict`` is
    written to ``trained_model_path`` (overwriting the previous checkpoint).

    Args:
        train_loader: Loader passed to ``EngineDICT.train`` each epoch.
        valid_loader: Loader passed to ``EngineDICT.validate`` each epoch.
        params: Mapping with the keys ``num_gin_layers``,
            ``num_graph_trans_layers``, ``hidden_size``, ``n_heads``,
            ``dropout`` and ``learning_rate``.
        trained_model_path: File path the best ``state_dict`` is saved to.
            Missing parent directories are created on demand.

    Returns:
        The smallest validation loss observed.
    """
    model = VerticalGNN(
        num_features=NUM_FEATURES,
        num_targets=NUM_TARGET,
        num_gin_layers=params['num_gin_layers'],
        num_graph_trans_layers=params['num_graph_trans_layers'],
        hidden_size=params['hidden_size'],
        n_heads=params['n_heads'],
        dropout=params['dropout'],
        edge_dim=EDGE_DIM,
    )
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
    eng = EngineDICT(model, optimizer, device=DEVICE)

    # Resolved once. dirname() is '' for a bare filename, and os.makedirs('')
    # raises FileNotFoundError, so the directory step is skipped in that case.
    checkpoint_dir = os.path.dirname(trained_model_path)

    best_loss = np.inf
    early_stopping_iter = PATIENCE
    early_stopping_counter = 0

    for epoch in range(EPOCHS):
        train_loss = eng.train(train_loader)
        # validate() returns a sequence; its first entry is the loss tracked here.
        valid_loss = eng.validate(valid_loader)[0]
        print(f'Epoch: {epoch+1}/{EPOCHS}, train loss : {train_loss}, validation loss : {valid_loss}')
        if valid_loss < best_loss:
            best_loss = valid_loss
            early_stopping_counter = 0  # reset counter
            print('Saving model...')
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), trained_model_path)
        else:
            early_stopping_counter += 1

        # Stop once the counter strictly exceeds the configured patience.
        if early_stopping_counter > early_stopping_iter:
            print('Early stopping...')
            break
        print(f'Early stop counter: {early_stopping_counter}')

    return best_loss

def run_validation(valid_loader, params, trained_model_path):
    """Evaluate a saved checkpoint on a validation loader.

    Args:
        valid_loader: Loader passed to ``EngineDICT.validate``.
        params: Mapping with the keys ``num_gin_layers``,
            ``num_graph_trans_layers``, ``hidden_size``, ``n_heads``,
            ``dropout`` and ``learning_rate``; the architecture keys must
            match the ones the checkpoint was trained with.
        trained_model_path: Path of the ``state_dict`` file to load.

    Returns:
        The ``(bce, acc, f1, roc_auc)`` tuple produced by
        ``EngineDICT.validate``.
    """
    model = VerticalGNN(
        num_features=NUM_FEATURES,
        num_targets=NUM_TARGET,
        num_gin_layers=params['num_gin_layers'],
        num_graph_trans_layers=params['num_graph_trans_layers'],
        hidden_size=params['hidden_size'],
        n_heads=params['n_heads'],
        dropout=params['dropout'],
        edge_dim=EDGE_DIM,
    )
    # map_location lets a checkpoint written on one device (e.g. a GPU) be
    # restored when DEVICE is a different one (e.g. a CPU-only machine).
    model.load_state_dict(torch.load(trained_model_path, map_location=DEVICE))
    model.to(DEVICE)
    # EngineDICT is constructed with an optimizer here exactly as in training;
    # this function itself only calls validate().
    optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
    eng = EngineDICT(model, optimizer, device=DEVICE)
    bce, acc, f1, roc_auc = eng.validate(valid_loader)
    print(f"bce:{bce}, acc :{acc}, f1: {f1}, roc_auc: {roc_auc}")
    return bce, acc, f1, roc_auc


def run_testing(test_loader, params, trained_model_path):
    """Evaluate a saved checkpoint on the test loader and dump its predictions.

    The labels and predictions returned by ``EngineDICT.test`` are written to
    a CSV next to the checkpoint, named ``<checkpoint stem>_pre_test.csv``.

    Args:
        test_loader: Loader passed to ``EngineDICT.test``.
        params: Mapping with the keys ``num_gin_layers``,
            ``num_graph_trans_layers``, ``hidden_size``, ``n_heads``,
            ``dropout`` and ``learning_rate``; the architecture keys must
            match the ones the checkpoint was trained with.
        trained_model_path: Path of the ``state_dict`` file to load.

    Returns:
        The ``(bce, acc, f1, roc_auc)`` metrics produced by
        ``EngineDICT.test``.
    """
    model = VerticalGNN(
        num_features=NUM_FEATURES,
        num_targets=NUM_TARGET,
        num_gin_layers=params['num_gin_layers'],
        num_graph_trans_layers=params['num_graph_trans_layers'],
        hidden_size=params['hidden_size'],
        n_heads=params['n_heads'],
        dropout=params['dropout'],
        edge_dim=EDGE_DIM,
    )
    # map_location lets a checkpoint written on one device (e.g. a GPU) be
    # restored when DEVICE is a different one (e.g. a CPU-only machine).
    model.load_state_dict(torch.load(trained_model_path, map_location=DEVICE))
    model.to(DEVICE)
    # EngineDICT is constructed with an optimizer here exactly as in training;
    # this function itself only calls test().
    optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
    eng = EngineDICT(model, optimizer, device=DEVICE)

    bce, acc, f1, roc_auc, all_labels, all_predictions = eng.test(test_loader)
    print(f'bce:{bce}, acc:{acc}, f1: {f1}, roc_auc: {roc_auc}')

    # Save the predicted labels to a CSV file. splitext() strips whatever
    # extension the checkpoint has; slicing off three characters was only
    # correct for a literal '.pt' suffix.
    checkpoint_stem, _ = os.path.splitext(trained_model_path)
    predictions_file = f"{checkpoint_stem}_pre_test.csv"
    save_predictions_to_csv(all_labels, all_predictions, predictions_file)

    return bce, acc, f1, roc_auc

def save_predictions_to_csv(labels, predictions, file_path):
    """Write true labels and predictions side by side to a CSV file.

    Args:
        labels: Values for the ``True_Labels`` column.
        predictions: Values for the ``Predictions`` column.
        file_path: Destination of the CSV; written without an index column.
    """
    import pandas as pd

    columns = {"True_Labels": labels, "Predictions": predictions}
    pd.DataFrame(columns).to_csv(file_path, index=False)
    print(f"Prediction results saved to {file_path}")

In [None]:
# Experiment configuration: hyper-parameters, number of repetitions, and the
# dataset / checkpoint locations.
params = params_vertical_gnn
n_repetitions = 1
train_data_root_path = './data/graph_data/data_DICT_train/'
train_data_raw_filename = 'data_DICT_train.csv'
test_data_root_path = './data/graph_data/data_DICT_test/'
test_data_raw_filename = 'data_DICT_test.csv'
path_to_save_trained_model = './GNN_models/vertical/'

# Metrics pooled over every (repeat, fold) pair: validation-fold scores ...
val_bce_list = []
val_acc_list = []
val_f1_list = []
val_roc_auc_list = []

# ... and the matching test-set scores.
bce_list = []
acc_list = []
f1_list = []
roc_auc_list = []

dataset_for_cv = LoadDICTDataset(root=train_data_root_path, raw_filename=train_data_raw_filename)
test_dataset = LoadDICTDataset(root=test_data_root_path, raw_filename=test_data_raw_filename)
kf = KFold(n_splits=N_SPLITS)

# Per-molecule graph files live in the ``processed`` folder under the training
# root. Deriving the folder from train_data_root_path keeps both in sync
# (the folder used to be spelled out a second time as a literal path).
processed_dir = os.path.join(train_data_root_path, 'processed')

# The test set is the same for every fold and is never shuffled, so a single
# loader is shared by all folds.
test_loader = DataLoader(test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

for repeat in range(n_repetitions):
    repeat_val_bce_list = []
    repeat_val_acc_list = []
    repeat_val_f1_list = []
    repeat_val_roc_auc_list = []

    repeat_bce_list = []
    repeat_acc_list = []
    repeat_f1_list = []
    repeat_roc_auc_list = []

    for fold_no, (train_idx, valid_idx) in enumerate(kf.split(dataset_for_cv)):
        # NOTE(review): the same seed is applied for every fold of every
        # repeat and KFold is not shuffled, so with n_repetitions > 1 the
        # repeats presumably reproduce each other - confirm this is intended.
        seed_everything(SEED_NO)

        train_dataset = [
            torch.load(os.path.join(processed_dir, f'molecule_{t_idx}.pt'))
            for t_idx in train_idx
        ]
        valid_dataset = [
            torch.load(os.path.join(processed_dir, f'molecule_{v_idx}.pt'))
            for v_idx in valid_idx
        ]

        train_loader = DataLoader(train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

        # One checkpoint per (repeat, fold): written by run_training, then
        # reloaded by run_validation and run_testing.
        model_path = os.path.join(path_to_save_trained_model, f'vertical_repeat_{repeat}_fold_{fold_no}.pt')

        run_training(train_loader, valid_loader, params, model_path)
        val_bce, val_acc, val_f1, val_roc_auc = run_validation(valid_loader, params, model_path)
        bce, acc, f1, roc_auc = run_testing(test_loader, params, model_path)

        # Record each score both for this repeat and in the pooled lists.
        for per_repeat, pooled, score in (
            (repeat_val_bce_list, val_bce_list, val_bce),
            (repeat_val_acc_list, val_acc_list, val_acc),
            (repeat_val_f1_list, val_f1_list, val_f1),
            (repeat_val_roc_auc_list, val_roc_auc_list, val_roc_auc),
            (repeat_bce_list, bce_list, bce),
            (repeat_acc_list, acc_list, acc),
            (repeat_f1_list, f1_list, f1),
            (repeat_roc_auc_list, roc_auc_list, roc_auc),
        ):
            per_repeat.append(score)
            pooled.append(score)

    # Output statistics for validation and test results for the repeat
    print(f'Statistics for repeat {repeat}:')
    for label, values in (
        ('Validation - BCE', repeat_val_bce_list),
        ('Validation - ACC', repeat_val_acc_list),
        ('Validation - F1', repeat_val_f1_list),
        ('Validation - ROC_AUC', repeat_val_roc_auc_list),
        ('test - BCE', repeat_bce_list),
        ('test - ACC', repeat_acc_list),
        ('test - F1', repeat_f1_list),
        ('test - ROC_AUC', repeat_roc_auc_list),
    ):
        print(f'{label}: {np.mean(values):.3f}±{np.std(values):.3f}')

# Overall summary: mean ± standard deviation (NumPy default, ddof=0) of every
# metric, pooled across all repeats and folds.

# --- validation folds ---
val_bce_arr = np.asarray(val_bce_list)
val_mean_bce, val_sd_bce = val_bce_arr.mean(), val_bce_arr.std()
print(f'validation bce:{val_mean_bce:.3f}±{val_sd_bce:.3f}')

val_acc_arr = np.asarray(val_acc_list)
val_acc_mean, val_acc_sd = val_acc_arr.mean(), val_acc_arr.std()
print(f'validation acc:{val_acc_mean:.3f}±{val_acc_sd:.3f}')

val_f1_arr = np.asarray(val_f1_list)
val_f1_mean, val_f1_sd = val_f1_arr.mean(), val_f1_arr.std()
print(f'validation f1: {val_f1_mean:.3f}±{val_f1_sd:.3f}')

val_roc_auc_arr = np.asarray(val_roc_auc_list)
val_roc_auc_mean, val_roc_auc_sd = val_roc_auc_arr.mean(), val_roc_auc_arr.std()
print(f'validation roc_auc: {val_roc_auc_mean:.3f}±{val_roc_auc_sd:.3f}')

# --- test set ---
bce_arr = np.asarray(bce_list)
mean_bce, sd_bce = bce_arr.mean(), bce_arr.std()
print(f'bce:{mean_bce:.3f}±{sd_bce:.3f}')

acc_arr = np.asarray(acc_list)
acc_mean, acc_sd = acc_arr.mean(), acc_arr.std()
print(f'acc:{acc_mean:.3f}±{acc_sd:.3f}')

f1_arr = np.asarray(f1_list)
f1_mean, f1_sd = f1_arr.mean(), f1_arr.std()
print(f'f1: {f1_mean:.3f}±{f1_sd:.3f}')

roc_auc_arr = np.asarray(roc_auc_list)
roc_auc_mean, roc_auc_sd = roc_auc_arr.mean(), roc_auc_arr.std()
print(f'roc_auc: {roc_auc_mean:.3f}±{roc_auc_sd:.3f}')