# Apply the TL-GNN Model to Predict Labels for Chemicals Given as SMILES

In [1]:
import torch
import numpy as np
import pandas as pd
from torch_geometric.loader import DataLoader
from sklearn.model_selection import KFold
import os

from model import VerticalGNN
from config import NUM_FEATURES, NUM_TARGET, EDGE_DIM, DEVICE, SEED_NO, PATIENCE, EPOCHS, NUM_GRAPHS_PER_BATCH, N_SPLITS, best_params_vertical
from engine import EnginehERG, EngineDICT
from utils import seed_everything, LoadDICTDataset, LoadhERGDataset, LoadSMILESDataset

## Prediction application

In [None]:
# Inference cell: load a trained VerticalGNN checkpoint, run it over the
# dataset built from the test CSV, threshold the sigmoid outputs at 0.5 and
# write the resulting 0/1 labels to 'predictions_test.csv'.

# Inputs handed to LoadSMILESDataset (root directory + raw CSV file name).
test_data_root_path = './data/graph_data/pre'
test_data_raw_filename = 'test.csv'

# Checkpoint to evaluate and the hyper-parameters used to rebuild the network.
path_to_best_model = './trf_learning_models/trained_models/vertical/mid_10x/pretrained_40/trained_vertical_model_fine_tune_2x_repeat_4_fold_1_20_es_trigger.pt'
params = best_params_vertical
method_tf = 'fine_tune_2x'

# The architecture must be constructed with the same hyper-parameters as the
# saved weights, otherwise load_state_dict() will reject the checkpoint.
model = VerticalGNN(
    num_features=NUM_FEATURES,
    num_targets=NUM_TARGET,
    num_gin_layers=params["num_gin_layers"],
    num_graph_trans_layers=params["num_graph_trans_layers"],
    hidden_size=params["hidden_size"],
    n_heads=params["n_heads"],
    dropout=params["dropout"],
    edge_dim=EDGE_DIM,
)

# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved on
# a GPU can still be loaded on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(path_to_best_model, map_location=DEVICE))
model.to(DEVICE)
model.eval()  # inference mode: disables dropout etc.

# NOTE(review): presumably LoadSMILESDataset converts the SMILES strings in
# the CSV into graph objects -- confirm in utils.
test_dataset = LoadSMILESDataset(test_data_root_path, test_data_raw_filename)
# shuffle=False keeps predictions in the dataset's iteration order.
test_loader = DataLoader(
    test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False
)

predictions = []
with torch.no_grad():  # no gradients needed for inference
    for data in test_loader:
        data = data.to(DEVICE)
        out = model(data.x, data.edge_attr, data.edge_index, data.batch)
        prob = torch.sigmoid(out)
        # Flatten with reshape(-1) rather than squeeze(): squeeze() turns a
        # batch holding a single graph into a 0-d tensor, whose .tolist() is a
        # bare int that list.extend() cannot iterate over (TypeError).
        # Assumes one output value per graph, as the single-column DataFrame
        # below already requires.
        pred = (prob > 0.5).long().reshape(-1)
        predictions.extend(pred.cpu().tolist())

print("Predictions:", predictions)

# One row per predicted graph, in loader order.
predictions_df = pd.DataFrame(predictions, columns=['Predicted Label'])

predictions_df.to_csv('predictions_test.csv', index=False)

Processing...
100%|██████████| 1883/1883 [00:46<00:00, 40.12it/s]
Done!


Predictions: [1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,