# Pretraining VerticalGNN Models with hERG dataset

In [3]:
import torch
import numpy as np
import os

from torch_geometric.loader import DataLoader
from model import VerticalGNN
from config import NUM_FEATURES, NUM_TARGET, EDGE_DIM, DEVICE, SEED_NO, PATIENCE, EPOCHS, NUM_GRAPHS_PER_BATCH, N_SPLITS, best_params_vertical
from engine import EnginehERG
from utils import seed_everything, LoadhERGDataset


# Pre-training models with different epoch values

In [4]:
def run_training_trf_learning_model(train_loader, params, pretrained_model_path, epochs):
    '''
    Pretrain a VerticalGNN model on the data served by train_loader and
    save its weights (state_dict) to pretrained_model_path after every epoch.

    Args:
    train_loader: DataLoader class from pytorch geometric containing train dataset
    params (dict): Dictionary containing hyperparameters; the keys read here are
        'num_gin_layers', 'num_graph_trans_layers', 'hidden_size', 'n_heads',
        'dropout' and 'learning_rate'
    pretrained_model_path (str): path to save the pretrained model; missing
        parent directories are created, a bare filename saves into the
        current working directory
    epochs (int): Number of epochs to pretrain the model

    Return:
    loss: train loss of the final epoch, or None when epochs <= 0
        (in that case nothing is trained and no model file is written)
    '''

    # Create the output directory once, up front, so path problems surface
    # before any training time is spent. dirname is '' for a bare filename
    # and os.makedirs('') would raise, so only create when there is one.
    save_dir = os.path.dirname(pretrained_model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model = VerticalGNN(
        num_features=NUM_FEATURES,
        num_targets=NUM_TARGET,
        num_gin_layers=params['num_gin_layers'],
        num_graph_trans_layers=params['num_graph_trans_layers'],
        hidden_size=params['hidden_size'],
        n_heads=params['n_heads'],
        dropout=params['dropout'],
        edge_dim=EDGE_DIM,
    )
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
    eng = EnginehERG(model, optimizer, device=DEVICE)

    # Defined before the loop so the return below is valid even if the
    # loop body never runs (epochs <= 0).
    train_loss = None

    for epoch in range(epochs):
        train_loss = eng.train(train_loader)
        print(f'Epoch: {epoch+1}/{epochs}, train loss : {train_loss}')
        # Overwrite the file every epoch so the latest weights survive an
        # interrupted run.
        print('Saving model...')
        torch.save(model.state_dict(), pretrained_model_path)

    return train_loss

In [None]:
# Number of pretraining epochs; set on demand, such as 10-100.
epochs = 10
params = best_params_vertical

# seed_everything is called before the dataset and the shuffling loader are built.
seed_everything(SEED_NO)

train_dataset_mid = LoadhERGDataset(
    root='./data/graph_data/trf_learning_hERG/',
    raw_filename='hERG_data_for_pretrained_model.csv',
)
train_loader_mid = DataLoader(
    train_dataset_mid,
    batch_size=64,
    shuffle=True,
)

# The epoch count is embedded in the checkpoint filename.
pretrained_model_path = f'./trf_learning_models/pretrained_models/vertical/pretrained_vertical_model_{epochs}_epoch.pt'

train_loss = run_training_trf_learning_model(
    train_loader=train_loader_mid,
    params=params,
    pretrained_model_path=pretrained_model_path,
    epochs=epochs,
)
print(f'train loss: {train_loss}')