# <center>**Build GNN Model**</center>  
**Author**: Shirshak Aryal  
**Last Updated**: 18 July 2025

---
**Purpose:** This notebook is dedicated to the development, training, and evaluation of a Graph Neural Network (GNN) model for `pGI50` prediction. It covers creating the graph data objects (and saving them), defining the GNN architecture, optimizing hyperparameters, training the final model with optimal parameters, and comprehensively evaluating its performance on unseen test data.

---

## 1. Setup Notebook
This section initializes the notebook environment by importing all necessary libraries, configuring system and PyTorch-specific settings for performance, defining the project path for module imports, and establishing global parameters and file paths.

### 1.1. Configure Environment
This sub-section configures environment variables for CPU usage optimization and sets up PyTorch-specific thread management. It also ensures the project's root directory is added to the system path for proper module imports.

In [1]:
import os
import sys
import time
from pathlib import Path

# General CPU Usage Optimization
os.environ["OMP_NUM_THREADS"] = "16"
os.environ["MKL_NUM_THREADS"] = "16"
os.environ["OPENBLAS_NUM_THREADS"] = "16"
os.environ["NUMEXPR_NUM_THREADS"] = "16"

# PyTorch-specific CPU Usage Optimization (if not using GPU exclusively)
import torch

try:
    torch.set_num_threads(16)
except RuntimeError as e:
    print(f"Warning: Could not set torch.set_num_threads.\n{e}")

try:
    torch.set_num_interop_threads(16)
except RuntimeError as e:
    print(f"Warning: Could not set torch.set_num_interop_threads.\n{e}")

print(f"PyTorch threads: {torch.get_num_threads()}")
print(f"PyTorch interop threads: {torch.get_num_interop_threads()}")


# Configure Project Path
from pathlib import Path

# Get the current working directory
current_dir = os.getcwd()

# Navigate up to the project root directory
project_root = Path(current_dir).parent.resolve()

# Add the project root to sys.path if it's not already there
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

print(f"Project root added to sys.path: {project_root}")

PyTorch threads: 16
PyTorch interop threads: 16
Project root added to sys.path: C:\Users\Acer\Desktop\Projects for Data Science\Drug Gi50 Value Prediction


### 1.2. Import Libraries
All required Python libraries for data manipulation, molecular handling, PyTorch core functionalities, PyTorch Geometric, machine learning utilities, hyperparameter optimization, and general utilities are imported here.

In [2]:
# Standard Library Imports
from datetime import datetime
import subprocess  # For getting Git commit ID

# Core Data Science Libraries
import numpy as np
import pandas as pd

# Molecular Handling (RDKit) Libraries
from rdkit import Chem  # For basic molecule handling
from rdkit.Chem import (
    AllChem,
)  # For atom features like Gasteiger charges, and other utilities

# PyTorch Core Libraries
import torch.nn as nn  # Neural network modules like Linear, ReLU, MSELoss
import torch.nn.functional as F  # Functional interface for activations, e.g. F.ReLU
import torch.optim as optim  # Optimization functions like Adam, AdamW, etc.
from torch.optim import lr_scheduler  # Learning rate scheduling

# PyTorch Geometric (PyG) Libraries
from torch_geometric.data import Data  # The graph data object in PyG
from torch_geometric.loader import (
    DataLoader as PyGDataLoader,
)  # PyG DataLoader for graphs
import torch_geometric.nn as pyg_nn  # Common GNN layers (e.g., GCNConv, GraphSAGEConv)
import torch_geometric.utils as pyg_utils  # Utility functions for graph manipulation

# Machine Learning Utilities
from sklearn.metrics import mean_squared_error, r2_score  # For model evaluation metrics
from sklearn.preprocessing import (
    StandardScaler,
)  # For feature scaling

# Hyperparameter Optimization Libraries
import optuna

# Conditional import for progress bars (tqdm)
tqdm_notebook_available = False  # Initialize flag
try:
    from tqdm.notebook import tqdm

    tqdm.pandas()  # Enable tqdm for pandas apply
    tqdm_notebook_available = True
    print("tqdm.notebook found and enabled for pandas.")
except ImportError:
    print("tqdm.notebook not found. Install with 'pip install tqdm'.")

# Local Project Imports
from src.models.gnn_models import GNN  # Import your custom GNN model class

tqdm.notebook found and enabled for pandas.


### 1.3. Define Device (GPU/CPU)
This sub-section defines the computational device (GPU if available, otherwise CPU) for PyTorch operations.

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


### 1.4. Set Final Model Save Location

In [4]:
gnn_models_base_dir = Path("../models/gnn")
gnn_models_base_dir.mkdir(parents=True, exist_ok=True)
print(f"The best final GNN model will be saved in: {gnn_models_base_dir}")

The best final GNN model will be saved in: ..\models\gnn


## 2. Load Data Splits
This section loads the pre-engineered and split datasets (training, validation, and test sets for both features and target variable) that were prepared in the previous notebook.

In [6]:
splits_dir = Path("../data/splits")
print(f"\nLoading data splits from {splits_dir}...")

try:
    X_train = pd.read_parquet(splits_dir / "X_train.parquet")
    X_val = pd.read_parquet(splits_dir / "X_val.parquet")
    X_test = pd.read_parquet(splits_dir / "X_test.parquet")
    
    y_train = pd.read_parquet(splits_dir / "y_train.parquet")
    y_val = pd.read_parquet(splits_dir / "y_val.parquet")
    y_test = pd.read_parquet(splits_dir / "y_test.parquet")
    print("Data splits loaded successfully.")
except FileNotFoundError:
    print(f"Error: One or more split files not found in '{splits_dir}'.")
    print("Please ensure you have run '02_Split_Features.ipynb' to generate and save the splits.")

print(f"X_train shape: {X_train.shape}")
print(f"X_val shape: {X_val.shape}")
print(f"X_test shape: {X_test.shape}")

print(f"y_train shape: {y_train.shape}")
print(f"y_val shape: {y_val.shape}")
print(f"y_test shape: {y_test.shape}")

# Display first few rows to verify data
print("\nFirst 5 rows of X_train:")
display(X_train.head())

print("\nFirst 5 rows of y_train:")
display(y_train.head())


Loading data splits from ..\data\splits...
Data splits loaded successfully.
X_train shape: (13119, 2268)
X_val shape: (2812, 2268)
X_test shape: (2812, 2268)
y_train shape: (13119, 1)
y_val shape: (2812, 1)
y_test shape: (2812, 1)

First 5 rows of X_train:


Unnamed: 0,molregno,canonical_smiles,num_activities,MaxAbsEStateIndex,MaxEStateIndex,MinAbsEStateIndex,MinEStateIndex,qed,SPS,MolWt,...,morgan_fp_2038,morgan_fp_2039,morgan_fp_2040,morgan_fp_2041,morgan_fp_2042,morgan_fp_2043,morgan_fp_2044,morgan_fp_2045,morgan_fp_2046,morgan_fp_2047
0,2307646,COc1cccc2c1OCc1c-2nc2cnc3ccccc3c2c1C,6,6.033142,6.033142,0.494176,0.494176,0.476742,12.56,328.371,...,0,0,0,0,0,0,0,0,0,0
1,2081122,COc1cc(/C(C#N)=C/c2ccc3c(c2)OCCO3)cc(OC)c1OC,9,9.645791,9.645791,0.459195,0.459195,0.604738,12.923077,353.374,...,0,0,0,0,0,0,0,0,0,0
2,2199496,COC(=O)[C@@H]1CCCN1Cc1ccc(-c2ncc(-c3ccc(OCC=C(...,6,11.953178,11.953178,0.169552,-0.173158,0.359463,15.909091,447.535,...,0,0,0,0,0,0,0,0,0,0
3,2221960,O=C(/C=C/c1cccn(C/C=C/c2ccccc2Br)c1=O)NO,4,12.253458,12.253458,0.216419,-0.686457,0.479732,11.217391,375.222,...,0,0,0,0,0,0,0,0,0,0
4,2879093,Cc1cc(C2c3c(-c4cccc5[nH]c(=O)oc45)n[nH]c3C(=O)...,2,14.128489,14.128489,0.124437,-3.116139,0.437556,16.121212,472.879,...,0,0,0,0,0,0,0,0,0,0



First 5 rows of y_train:


Unnamed: 0,pGI50
14387,5.734742
12543,7.164746
12810,4.928428
13172,6.882724
18712,6.094208


## 3. Prepare Data for GNN
This section performs specific data preparation steps required to transform the molecular data into PyTorch Geometric graph objects, suitable for GNN input.

### 3.1. Extract Global Features of Each Molecule
This sub-section extracts global molecular features (i.e., all the RDKit descriptors and Morgan fingerprints) from the dataset, which will be incorporated into the graph objects.

In [7]:
print("\n--- Excracting global features of each molecule ---")

# Identify columns for global features
# These are all columns in X_train EXCEPT 'molregno' and 'canonical_smiles'
global_feature_columns = X_train.drop(columns=['molregno', 'canonical_smiles'], errors='ignore').columns.tolist()

print(f"Identified {len(global_feature_columns)} global feature columns for GNN.")
print(f"Global feature columns: {global_feature_columns}")

# Extract global features into new DataFrames
# These DataFrames will be the source for data.global_features in GNN Data objects
X_train_global_features = X_train[global_feature_columns]
X_val_global_features = X_val[global_feature_columns]
X_test_global_features = X_test[global_feature_columns]

print("\nExtracted global features for each split:")
print(f"X_train_global_features shape: {X_train_global_features.shape}")
print(f"X_val_global_features shape: {X_val_global_features.shape}")
print(f"X_test_global_features shape: {X_test_global_features.shape}")

print("\nFirst 5 rows of X_train_global_features:")
display(X_train_global_features.head())


# Confirm original X DataFrames (with molregno and canonical_smiles) are still available
# These will be used for iterating and building individual graph objects.
print("\nOriginal X DataFrames (with molregno and canonical_smiles) are retained and ready for graph construction:")
print(f"X_train shape: {X_train.shape}")
print(f"X_val shape: {X_val.shape}")
print(f"X_test shape: {X_test.shape}")
print(f"y_train shape: {y_train.shape}, y_val shape: {y_val.shape}, y_test shape: {y_test.shape}")

display(X_train.head()) # Show that original X_train still has molregno and canonical_smiles
display(y_train.head()) # Show the pGI50 target values


--- Excracting global features of each molecule ---
Identified 2266 global feature columns for GNN.
Global feature columns: ['num_activities', 'MaxAbsEStateIndex', 'MaxEStateIndex', 'MinAbsEStateIndex', 'MinEStateIndex', 'qed', 'SPS', 'MolWt', 'HeavyAtomMolWt', 'ExactMolWt', 'NumValenceElectrons', 'NumRadicalElectrons', 'MaxPartialCharge', 'MinPartialCharge', 'MaxAbsPartialCharge', 'MinAbsPartialCharge', 'FpDensityMorgan1', 'FpDensityMorgan2', 'FpDensityMorgan3', 'BCUT2D_MWHI', 'BCUT2D_MWLOW', 'BCUT2D_CHGHI', 'BCUT2D_CHGLO', 'BCUT2D_LOGPHI', 'BCUT2D_LOGPLOW', 'BCUT2D_MRHI', 'BCUT2D_MRLOW', 'AvgIpc', 'BalabanJ', 'BertzCT', 'Chi0', 'Chi0n', 'Chi0v', 'Chi1', 'Chi1n', 'Chi1v', 'Chi2n', 'Chi2v', 'Chi3n', 'Chi3v', 'Chi4n', 'Chi4v', 'HallKierAlpha', 'Ipc', 'Kappa1', 'Kappa2', 'Kappa3', 'LabuteASA', 'PEOE_VSA1', 'PEOE_VSA10', 'PEOE_VSA11', 'PEOE_VSA12', 'PEOE_VSA13', 'PEOE_VSA14', 'PEOE_VSA2', 'PEOE_VSA3', 'PEOE_VSA4', 'PEOE_VSA5', 'PEOE_VSA6', 'PEOE_VSA7', 'PEOE_VSA8', 'PEOE_VSA9', 'SMR_VSA1

Unnamed: 0,num_activities,MaxAbsEStateIndex,MaxEStateIndex,MinAbsEStateIndex,MinEStateIndex,qed,SPS,MolWt,HeavyAtomMolWt,ExactMolWt,...,morgan_fp_2038,morgan_fp_2039,morgan_fp_2040,morgan_fp_2041,morgan_fp_2042,morgan_fp_2043,morgan_fp_2044,morgan_fp_2045,morgan_fp_2046,morgan_fp_2047
0,6,6.033142,6.033142,0.494176,0.494176,0.476742,12.56,328.371,312.243,328.121178,...,0,0,0,0,0,0,0,0,0,0
1,9,9.645791,9.645791,0.459195,0.459195,0.604738,12.923077,353.374,334.222,353.126323,...,0,0,0,0,0,0,0,0,0,0
2,6,11.953178,11.953178,0.169552,-0.173158,0.359463,15.909091,447.535,418.303,447.215806,...,0,0,0,0,0,0,0,0,0,0
3,4,12.253458,12.253458,0.216419,-0.686457,0.479732,11.217391,375.222,360.102,374.026604,...,0,0,0,0,0,0,0,0,0,0
4,2,14.128489,14.128489,0.124437,-3.116139,0.437556,16.121212,472.879,453.727,472.111375,...,0,0,0,0,0,0,0,0,0,0



Original X DataFrames (with molregno and canonical_smiles) are retained and ready for graph construction:
X_train shape: (13119, 2268)
X_val shape: (2812, 2268)
X_test shape: (2812, 2268)
y_train shape: (13119, 1), y_val shape: (2812, 1), y_test shape: (2812, 1)


Unnamed: 0,molregno,canonical_smiles,num_activities,MaxAbsEStateIndex,MaxEStateIndex,MinAbsEStateIndex,MinEStateIndex,qed,SPS,MolWt,...,morgan_fp_2038,morgan_fp_2039,morgan_fp_2040,morgan_fp_2041,morgan_fp_2042,morgan_fp_2043,morgan_fp_2044,morgan_fp_2045,morgan_fp_2046,morgan_fp_2047
0,2307646,COc1cccc2c1OCc1c-2nc2cnc3ccccc3c2c1C,6,6.033142,6.033142,0.494176,0.494176,0.476742,12.56,328.371,...,0,0,0,0,0,0,0,0,0,0
1,2081122,COc1cc(/C(C#N)=C/c2ccc3c(c2)OCCO3)cc(OC)c1OC,9,9.645791,9.645791,0.459195,0.459195,0.604738,12.923077,353.374,...,0,0,0,0,0,0,0,0,0,0
2,2199496,COC(=O)[C@@H]1CCCN1Cc1ccc(-c2ncc(-c3ccc(OCC=C(...,6,11.953178,11.953178,0.169552,-0.173158,0.359463,15.909091,447.535,...,0,0,0,0,0,0,0,0,0,0
3,2221960,O=C(/C=C/c1cccn(C/C=C/c2ccccc2Br)c1=O)NO,4,12.253458,12.253458,0.216419,-0.686457,0.479732,11.217391,375.222,...,0,0,0,0,0,0,0,0,0,0
4,2879093,Cc1cc(C2c3c(-c4cccc5[nH]c(=O)oc45)n[nH]c3C(=O)...,2,14.128489,14.128489,0.124437,-3.116139,0.437556,16.121212,472.879,...,0,0,0,0,0,0,0,0,0,0


Unnamed: 0,pGI50
14387,5.734742
12543,7.164746
12810,4.928428
13172,6.882724
18712,6.094208


### 3.2. Create Graph Objects of Each Molecule
This step transforms each molecule's SMILES string into a `torch_geometric.data.Data` object, incorporating atom features, bond features, and the extracted global features.

#### 3.2.1. Define Helper Function to Create Graph Object
A helper function is defined here to encapsulate the logic for converting a single molecule's SMILES and its corresponding features into a PyTorch Geometric `Data` object.

In [None]:
def mol_to_pyg_data(mol, pgi50_value, global_features_vector, molregno, smiles_string):
    if mol is None:
        return None  # Handle cases where SMILES parsing fails

    # Compute Gasteiger charges (how electron-dense the area occupied by this atom is, crucial for interactions)
    try:
        AllChem.ComputeGasteigerCharges(mol)
    except Exception as e:
        print(f"Warning: Could not compute Gasteiger charges for molregno {molregno}: {e}")
        # If computation fails, atoms will default to 0.0 for this property
        pass

    # Node Features (x): Atom Properties
    atom_features = []
    for atom in mol.GetAtoms():
        # Initialize a list for this atom's features
        features = []

        # Atomic Number (int, not one-hot coded)
        features.append(atom.GetAtomicNum())

        # Basic Connectivity
        features.append(atom.GetDegree())  # Num of directly-bonded heavy (non-Hydrogen) atoms
        features.append(atom.GetTotalDegree())  # Total numb of neighbors (including all Hydrogens)

        # Charge and Valence
        features.append(atom.GetFormalCharge())  # Formal charge (integer charge based on bonding rules)
        features.append(atom.GetNumExplicitHs())  # Number of explicitly defined hydrogens attached
        features.append(atom.GetNumImplicitHs())  # Number of hydrogens implicitly defined by valence
        features.append(atom.GetTotalNumHs())  # Total number of hydrogens attached (explicit + implicit)
        features.append(atom.GetValence(Chem.ValenceType.IMPLICIT))  # Implicit Valence: Number of bonds formed by implicit hydrogens
        features.append(atom.GetValence(Chem.ValenceType.EXPLICIT))  # Explicit Valence: Sum of bond orders (1 for single, 2 for double, etc.) to explicitly defined atoms
        features.append(atom.GetTotalValence())  # Total Valence: Total number of bonds (sum of explicit & implicit valence)

        # Hybridization (convert enum to int) (e.g., sp3, sp2)
        features.append(int(atom.GetHybridization()))

        # Aromaticity and Ring Information (boolean converted to int)
        features.append(int(atom.GetIsAromatic()))        # Whether the atom is part of an aromatic system
        features.append(int(atom.IsInRing()))             # Whether the atom is in ANY ring structure
        features.append(int(atom.IsInRingSize(3)))        # Whether the atom is in a 3-membered ring
        features.append(int(atom.IsInRingSize(4)))        # Whether the atom is in a 4-membered ring
        features.append(int(atom.IsInRingSize(5)))        # Whether the atom is in a 5-membered ring
        features.append(int(atom.IsInRingSize(6)))        # Whether the atom is in a 6-membered ring
        features.append(int(atom.IsInRingSize(7)))        # Whether the atom is in a 7-membered ring
        features.append(int(atom.IsInRingSize(8)))        # Whether the atom is in an 8-membered ring

        # Chirality (convert enum to int)(stereochemical information, crucial for biological activity)
        features.append(int(atom.GetChiralTag()))

         # Partial Charges (from Gasteiger calculation)
        gasteiger_charge = 0.0
        if atom.HasProp('_GasteigerCharge'):
            try:
                gasteiger_charge = float(atom.GetProp('_GasteigerCharge'))
            except ValueError:
                pass # Handle potential 'nan' or non-float values gracefully
        features.append(gasteiger_charge)

        # Add to the list of all atom features for this molecule
        atom_features.append(features)
        
    # Convert the list of lists to a PyTorch tensor
    x = torch.tensor(atom_features, dtype=torch.float)

    # Edge Index (edge_index): Bond connectivity
    edge_indices = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_indices.append([i, j])
        edge_indices.append([j, i]) # Add reverse edge for undirected graph
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()

    # Handle molecules with no bonds (single atom, e.g., for [Ne])
    if edge_index.numel() == 0:
        edge_index = torch.empty((2, 0), dtype=torch.long) # Create an empty edge_index tensor

    # Graph-level Target (y): pGI50
    y = torch.tensor([pgi50_value], dtype=torch.float)

    # Global Features (global_features)
    try:
        global_features_vector = global_features_vector.astype(float)
    except ValueError as e:
        print(f"Error converting global_features_vector to float for molregno {molregno}: {e}")
        
    global_features_tensor = torch.tensor(global_features_vector, dtype=torch.float).unsqueeze(0)

    # Create the PyTorch Geometric Data object
    data = Data(x=x,
                edge_index=edge_index,
                y=y,
                global_features=global_features_tensor,
                molregno=molregno,
                smiles=smiles_string)

    return data

#### 3.2.2. Apply Helper Function to Create Graph Objects
The defined helper function is applied across the entire dataset to generate a list of PyTorch Geometric `Data` objects for each molecule.

In [9]:
train_data_list = []
val_data_list = []
test_data_list = []

# Process Training Data
print("\n--- Creating PyG Data objects for Training Set ---")

print(f"Type of X_train: {type(X_train)}")
print(f"Type of y_train: {type(y_train)}")
print(f"Type of X_train_global_features: {type(X_train_global_features)}")

# Ensure X_train, y_train, X_train_global_features have the same index for alignment
train_df = pd.concat([X_train.reset_index(drop=True),
                      y_train.reset_index(drop=True),
                      X_train_global_features.reset_index(drop=True)],
                     axis=1)
print(f"Length of train_df after concatenation: {len(train_df)}")

successful_train_graphs = 0
for index, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Processing Train Molecules"):
    smiles = row['canonical_smiles']
    molregno = row['molregno']
    pgi50 = row['pGI50']
    
    # Extract global features based on the column names extracted after loading data splits
    global_features_vector = row[global_feature_columns].values

    # Convert SMILES to RDKit Mol object
    mol = Chem.MolFromSmiles(smiles)

    # Create PyG Data object
    pyg_data = mol_to_pyg_data(mol, pgi50, global_features_vector, molregno, smiles)

    if pyg_data is not None and pyg_data.x.numel() > 0: # Ensure valid mol and has nodes
        train_data_list.append(pyg_data)
        successful_train_graphs += 1
    else:
        print(f"Warning: Could not process SMILES: {smiles} (Molregno: {molregno})")

print(f"Successfully created {successful_train_graphs} / {len(train_df)} graph objects for the training set.")
print(f"Total training graphs: {len(train_data_list)}")


# Process Validation Data
print("\n--- Creating PyG Data objects for Validation Set ---")
val_df = pd.concat([X_val.reset_index(drop=True),
                    y_val.reset_index(drop=True),
                    X_val_global_features.reset_index(drop=True)],
                   axis=1)
print(f"Length of val_df after concatenation: {len(val_df)}")

successful_val_graphs = 0
for index, row in tqdm(val_df.iterrows(), total=len(val_df), desc="Processing Validation Molecules"):
    smiles = row['canonical_smiles']
    molregno = row['molregno']
    pgi50 = row['pGI50']
    global_features_vector = row[global_feature_columns].values

    mol = Chem.MolFromSmiles(smiles)
    pyg_data = mol_to_pyg_data(mol, pgi50, global_features_vector, molregno, smiles)

    if pyg_data is not None and pyg_data.x.numel() > 0:
        val_data_list.append(pyg_data)
        successful_val_graphs += 1

print(f"Successfully created {successful_val_graphs} / {len(val_df)} graph objects for the validation set.")
print(f"Total validation graphs: {len(val_data_list)}")


# Process Test Data
print("\n--- Creating PyG Data objects for Test Set ---")
test_df = pd.concat([X_test.reset_index(drop=True),
                     y_test.reset_index(drop=True),
                     X_test_global_features.reset_index(drop=True)],
                    axis=1)
print(f"Length of test_df after concatenation: {len(test_df)}")

successful_test_graphs = 0
for index, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Processing Test Molecules"):
    smiles = row['canonical_smiles']
    molregno = row['molregno']
    pgi50 = row['pGI50']
    global_features_vector = row[global_feature_columns].values

    mol = Chem.MolFromSmiles(smiles)
    pyg_data = mol_to_pyg_data(mol, pgi50, global_features_vector, molregno, smiles)

    if pyg_data is not None and pyg_data.x.numel() > 0:
        test_data_list.append(pyg_data)
        successful_test_graphs += 1

print(f"Successfully created {successful_test_graphs} / {len(test_df)} graph objects for the test set.")
print(f"Total test graphs: {len(test_data_list)}")


--- Creating PyG Data objects for Training Set ---
Type of X_train: <class 'pandas.core.frame.DataFrame'>
Type of y_train: <class 'pandas.core.frame.DataFrame'>
Type of X_train_global_features: <class 'pandas.core.frame.DataFrame'>
Length of train_df after concatenation: 13119


Processing Train Molecules:   0%|          | 0/13119 [00:00<?, ?it/s]

Successfully created 13119 / 13119 graph objects for the training set.
Total training graphs: 13119

--- Creating PyG Data objects for Validation Set ---
Length of val_df after concatenation: 2812


Processing Validation Molecules:   0%|          | 0/2812 [00:00<?, ?it/s]

Successfully created 2812 / 2812 graph objects for the validation set.
Total validation graphs: 2812

--- Creating PyG Data objects for Test Set ---
Length of test_df after concatenation: 2812


Processing Test Molecules:   0%|          | 0/2812 [00:00<?, ?it/s]

Successfully created 2812 / 2812 graph objects for the test set.
Total test graphs: 2812


##### 3.2.2.1. Verify Creation of Graph Objects
Basic checks are performed to verify that the graph objects have been correctly created (by checking the number of nodes, edges, and features for a sample object).

In [10]:
print("\n--- Sample PyTorch Geometric Data object (from train_data_list[0]) ---")
if len(train_data_list) > 0:
    sample_data = train_data_list[0]
    print(sample_data)
    print(f"  Number of nodes (atoms): {sample_data.num_nodes}")
    print(f"  Number of edges (bonds): {sample_data.num_edges}")
    
    # Node features (data.x) details
    print(f"\n  Node features (data.x) shape: {sample_data.x.shape}")
    if sample_data.x.numel() > 0:
        print(f"    Node features (data.x) sample (first 5 values): {sample_data.x.flatten()[:5].tolist()}")
        print(f"    Node features (data.x) min: {sample_data.x.min().item():.4f}")
        print(f"    Node features (data.x) max: {sample_data.x.max().item():.4f}")
        print(f"    Node features (data.x) mean: {sample_data.x.float().mean().item():.4f}")
        print(f"    Node features (data.x) std: {sample_data.x.float().std().item():.4f}")
        print(f"    Contains NaN in data.x: {torch.isnan(sample_data.x).any().item()}")
        print(f"    Contains Inf in data.x: {torch.isinf(sample_data.x).any().item()}")
    else:
        print("    Node features (data.x) is an empty tensor.")

    print(f"  Edge index (data.edge_index) shape: {sample_data.edge_index.shape}")
    print(f"  Target (data.y): {sample_data.y.item():.4f}") # Display target with 4 decimal places
    
    # Global features (data.global_features) details (VERIFYING SCALING HERE)
    print(f"\n  Global features (data.global_features) shape: {sample_data.global_features.shape}")
    if sample_data.global_features.numel() > 0:
        print(f"    Global features (data.global_features) sample (first 5 values): {sample_data.global_features.flatten()[:5].tolist()}")
        print(f"    Global features (data.global_features) min: {sample_data.global_features.min().item():.4f}")
        print(f"    Global features (data.global_features) max: {sample_data.global_features.max().item():.4f}")
        print(f"    Global features (data.global_features) mean: {sample_data.global_features.float().mean().item():.4f}")
        print(f"    Global features (data.global_features) std: {sample_data.global_features.float().std().item():.4f}")
        print(f"    Contains NaN in global_features: {torch.isnan(sample_data.global_features).any().item()}")
        print(f"    Contains Inf in global_features: {torch.isinf(sample_data.global_features).any().item()}")
    else:
        print("    Global features (data.global_features) is an empty tensor.")
        
    print(f"\n  SMILES: {sample_data.smiles}")
    print(f"  Molregno: {sample_data.molregno}")
else:
    print("No training data objects created to display sample.")


--- Sample PyTorch Geometric Data object (from train_data_list[0]) ---
Data(x=[25, 21], edge_index=[2, 58], y=[1], global_features=[1, 4532], molregno=2307646, smiles='COc1cccc2c1OCc1c-2nc2cnc3ccccc3c2c1C')
  Number of nodes (atoms): 25
  Number of edges (bonds): 58

  Node features (data.x) shape: torch.Size([25, 21])
    Node features (data.x) sample (first 5 values): [6.0, 1.0, 4.0, 0.0, 0.0]
    Node features (data.x) min: -0.4928
    Node features (data.x) max: 8.0000
    Node features (data.x) mean: 1.2363
    Node features (data.x) std: 1.7344
    Contains NaN in data.x: False
    Contains Inf in data.x: False
  Edge index (data.edge_index) shape: torch.Size([2, 58])
  Target (data.y): 5.7347

  Global features (data.global_features) shape: torch.Size([1, 4532])
    Global features (data.global_features) sample (first 5 values): [6.0, 6.0, 6.033141613006592, 6.033141613006592, 6.033141613006592]
    Global features (data.global_features) min: -3.1400
    Global features (data.g

#### 3.2.3. Standardize Global Features of Each Molecule
The global features (RDKit descriptors, Morgan fingerprints) within each graph object are standardized (using `StandardScaler`) to normalize their scale.

In [11]:
# Collect all global features to fit the scaler
# Concatenate all global_features tensors. Each data.global_features is already (1, feature_dim),
# so torch.cat(..., dim=0) will result in (num_total_graphs, feature_dim).
list_of_global_features_tensors = [data.global_features for data in train_data_list + val_data_list]
all_global_features_combined = torch.cat(list_of_global_features_tensors, dim=0).cpu().numpy()

# Initialize and fit the scaler on the combined global features from training and validation sets
global_feature_scaler = StandardScaler()
global_feature_scaler.fit(all_global_features_combined)

# Apply scaling to the 'global_features' in Data objects for all splits
for data_list in [train_data_list, val_data_list, test_data_list]:
    for data in data_list:
        # Ensure it's numpy before scaling, then back to torch
        original_global_features_np = data.global_features.cpu().numpy()
        scaled_global_features_np = global_feature_scaler.transform(original_global_features_np)
        # Put it back on the correct device
        data.global_features = torch.tensor(scaled_global_features_np, dtype=torch.float32).to(data.global_features.device)

print("\nGlobal features in torch_geometric.data.Data objects have been scaled!")


Global features in torch_geometric.data.Data objects have been scaled!


##### 3.2.3.1. Verify Scaling of Global Features
A quick check is performed to confirm that the global features within the graph objects have been successfully scaled.

In [12]:
print("\n--- Sample PyTorch Geometric Data object (from train_data_list[0]) ---")
if len(train_data_list) > 0:
    sample_data = train_data_list[0]
    print(sample_data)
    print(f"  Number of nodes (atoms): {sample_data.num_nodes}")
    print(f"  Number of edges (bonds): {sample_data.num_edges}")
    
    # Node features (data.x) details
    print(f"\n  Node features (data.x) shape: {sample_data.x.shape}")
    if sample_data.x.numel() > 0:
        print(f"    Node features (data.x) sample (first 5 values): {sample_data.x.flatten()[:5].tolist()}")
        print(f"    Node features (data.x) min: {sample_data.x.min().item():.4f}")
        print(f"    Node features (data.x) max: {sample_data.x.max().item():.4f}")
        print(f"    Node features (data.x) mean: {sample_data.x.float().mean().item():.4f}")
        print(f"    Node features (data.x) std: {sample_data.x.float().std().item():.4f}")
        print(f"    Contains NaN in data.x: {torch.isnan(sample_data.x).any().item()}")
        print(f"    Contains Inf in data.x: {torch.isinf(sample_data.x).any().item()}")
    else:
        print("    Node features (data.x) is an empty tensor.")

    print(f"  Edge index (data.edge_index) shape: {sample_data.edge_index.shape}")
    print(f"  Target (data.y): {sample_data.y.item():.4f}") # Display target with 4 decimal places
    
    # Global features (data.global_features) details (VERIFYING SCALING HERE)
    print(f"\n  Global features (data.global_features) shape: {sample_data.global_features.shape}")
    if sample_data.global_features.numel() > 0:
        print(f"    Global features (data.global_features) sample (first 5 values): {sample_data.global_features.flatten()[:5].tolist()}")
        print(f"    Global features (data.global_features) min: {sample_data.global_features.min().item():.4f}")
        print(f"    Global features (data.global_features) max: {sample_data.global_features.max().item():.4f}")
        print(f"    Global features (data.global_features) mean: {sample_data.global_features.float().mean().item():.4f}")
        print(f"    Global features (data.global_features) std: {sample_data.global_features.float().std().item():.4f}")
        print(f"    Contains NaN in global_features: {torch.isnan(sample_data.global_features).any().item()}")
        print(f"    Contains Inf in global_features: {torch.isinf(sample_data.global_features).any().item()}")
    else:
        print("    Global features (data.global_features) is an empty tensor.")
        
    print(f"\n  SMILES: {sample_data.smiles}")
    print(f"  Molregno: {sample_data.molregno}")
else:
    print("No training data objects created to display sample.")


--- Sample PyTorch Geometric Data object (from train_data_list[0]) ---
Data(x=[25, 21], edge_index=[2, 58], y=[1], global_features=[1, 4532], molregno=2307646, smiles='COc1cccc2c1OCc1c-2nc2cnc3ccccc3c2c1C')
  Number of nodes (atoms): 25
  Number of edges (bonds): 58

  Node features (data.x) shape: torch.Size([25, 21])
    Node features (data.x) sample (first 5 values): [6.0, 1.0, 4.0, 0.0, 0.0]
    Node features (data.x) min: -0.4928
    Node features (data.x) max: 8.0000
    Node features (data.x) mean: 1.2363
    Node features (data.x) std: 1.7344
    Contains NaN in data.x: False
    Contains Inf in data.x: False
  Edge index (data.edge_index) shape: torch.Size([2, 58])
  Target (data.y): 5.7347

  Global features (data.global_features) shape: torch.Size([1, 4532])
    Global features (data.global_features) sample (first 5 values): [0.038460150361061096, 0.038460150361061096, -2.2056941986083984, -2.2056941986083984, -2.2056941986083984]
    Global features (data.global_features) 

#### 3.2.4. Save Graph Objects
The list of generated PyTorch Geometric `Data` objects with the scaled global features is saved locally to avoid regenerating them in subsequent notebooks.

In [None]:
# Directory for saving the processed graph data
save_dir = Path('../data/pyg_data_graphs')
save_dir.mkdir(parents=True, exist_ok=True)

# Define the full file paths
train_data_path = save_dir / 'train_data_list.pt'
val_data_path = save_dir / 'val_data_list.pt'
test_data_path = save_dir / 'test_data_list.pt'

# Save the lists of Data objects
torch.save(train_data_list, train_data_path)
torch.save(val_data_list, val_data_path)
torch.save(test_data_list, test_data_path)

print(f"Processed graph data saved to: {save_dir}")
print(f"Train data list size: {len(train_data_list)}")
print(f"Validation data list size: {len(val_data_list)}")
print(f"Test data list size: {len(test_data_list)}")

Processed graph data saved to: ..\data\splits\pyg_data_graphs
Train data list size: 13119
Validation data list size: 2812
Test data list size: 2812


## 4. Optimize Hyperparameters
This section utilizes Optuna to systematically search for the optimal set of hyperparameters for the GNN model, aiming to minimize prediction error on the validation set.


### 4.1. Load Graph Objects
The previously saved PyTorch Geometric `Data` graph objects are loaded, serving as input for the hyperparameter optimization process.

In [5]:
# Directory where the graph objects are saved
load_dir = Path('../data/pyg_data_graphs')

# Define the full file paths
train_data_path = load_dir / 'train_data_list.pt'
val_data_path = load_dir / 'val_data_list.pt'
test_data_path = load_dir / 'test_data_list.pt'

# Load the lists of Data objects
try:
    train_data_list = torch.load(train_data_path, weights_only=False)
    val_data_list = torch.load(val_data_path, weights_only=False)
    test_data_list = torch.load(test_data_path, weights_only=False)

    print(f"Loaded {len(train_data_list)} training graphs.")
    print(f"Loaded {len(val_data_list)} validation graphs.")
    print(f"Loaded {len(test_data_list)} test graphs.")

except FileNotFoundError:
    print(f"Error: Processed data not found in {load_dir}. Please run the data processing and saving step first.")
except Exception as e:
    print(f"An error occurred during loading: {e}")

Loaded 13119 training graphs.
Loaded 2812 validation graphs.
Loaded 2812 test graphs.


#### 4.1.1. Verify Loading of Graph Objects
Basic checks are performed to ensure the graph objects have been loaded correctly.

In [6]:
print("\n--- Sample PyTorch Geometric Data object (from train_data_list[0]) ---")
if len(train_data_list) > 0:
    sample_data = train_data_list[0]
    print(sample_data)
    print(f"  Number of nodes (atoms): {sample_data.num_nodes}")
    print(f"  Number of edges (bonds): {sample_data.num_edges}")
    
    # Node features (data.x) details
    print(f"\n  Node features (data.x) shape: {sample_data.x.shape}")
    if sample_data.x.numel() > 0:
        print(f"    Node features (data.x) sample (first 5 values): {sample_data.x.flatten()[:5].tolist()}")
        print(f"    Node features (data.x) min: {sample_data.x.min().item():.4f}")
        print(f"    Node features (data.x) max: {sample_data.x.max().item():.4f}")
        print(f"    Node features (data.x) mean: {sample_data.x.float().mean().item():.4f}")
        print(f"    Node features (data.x) std: {sample_data.x.float().std().item():.4f}")
        print(f"    Contains NaN in data.x: {torch.isnan(sample_data.x).any().item()}")
        print(f"    Contains Inf in data.x: {torch.isinf(sample_data.x).any().item()}")
    else:
        print("    Node features (data.x) is an empty tensor.")

    print(f"  Edge index (data.edge_index) shape: {sample_data.edge_index.shape}")
    print(f"  Target (data.y): {sample_data.y.item():.4f}") # Display target with 4 decimal places
    
    # Global features (data.global_features) details (VERIFYING SCALING HERE)
    print(f"\n  Global features (data.global_features) shape: {sample_data.global_features.shape}")
    if sample_data.global_features.numel() > 0:
        print(f"    Global features (data.global_features) sample (first 5 values): {sample_data.global_features.flatten()[:5].tolist()}")
        print(f"    Global features (data.global_features) min: {sample_data.global_features.min().item():.4f}")
        print(f"    Global features (data.global_features) max: {sample_data.global_features.max().item():.4f}")
        print(f"    Global features (data.global_features) mean: {sample_data.global_features.float().mean().item():.4f}")
        print(f"    Global features (data.global_features) std: {sample_data.global_features.float().std().item():.4f}")
        print(f"    Contains NaN in global_features: {torch.isnan(sample_data.global_features).any().item()}")
        print(f"    Contains Inf in global_features: {torch.isinf(sample_data.global_features).any().item()}")
    else:
        print("    Global features (data.global_features) is an empty tensor.")
        
    print(f"\n  SMILES: {sample_data.smiles}")
    print(f"  Molregno: {sample_data.molregno}")
else:
    print("No training data objects created to display sample.")


--- Sample PyTorch Geometric Data object (from train_data_list[0]) ---
Data(x=[25, 21], edge_index=[2, 58], y=[1], global_features=[1, 4532], molregno=2307646, smiles='COc1cccc2c1OCc1c-2nc2cnc3ccccc3c2c1C')
  Number of nodes (atoms): 25
  Number of edges (bonds): 58

  Node features (data.x) shape: torch.Size([25, 21])
    Node features (data.x) sample (first 5 values): [6.0, 1.0, 4.0, 0.0, 0.0]
    Node features (data.x) min: -0.4928
    Node features (data.x) max: 8.0000
    Node features (data.x) mean: 1.2363
    Node features (data.x) std: 1.7344
    Contains NaN in data.x: False
    Contains Inf in data.x: False
  Edge index (data.edge_index) shape: torch.Size([2, 58])
  Target (data.y): 5.7347

  Global features (data.global_features) shape: torch.Size([1, 4532])
    Global features (data.global_features) sample (first 5 values): [0.038460150361061096, 0.038460150361061096, -2.2056941986083984, -2.2056941986083984, -2.2056941986083984]
    Global features (data.global_features) 

### 4.2. Define Optuna Objective Function
The Optuna objective function is defined here. This function instantiates and trains a GNN model with a given set of hyperparameters, returning its performance (i.e., RMSE) on the validation set, which Optuna aims to minimize.

In [1]:
def objective(trial):
    # Hyperparameters to tune
    hidden_channels = trial.suggest_int("hidden_channels", 128, 1024, log=True) # Number of neurons in hidden layer
    learning_rate = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256]) # Batch size for DataLoaders
    n_epochs = trial.suggest_int("n_epochs", 150, 600)  # Number of training epochs
    num_layers = trial.suggest_int("num_layers", 1, 4) # Number of GNN layers
    dropout_rate = trial.suggest_float("dropout_rate", 0.0, 0.5) # Dropout rate
    weight_decay = trial.suggest_float("weight_decay", 1e-8, 1e-3, log=True)

    # Determine feature dimensions dynamically from loaded/created graph objects
    # Ensure train_data_list is not empty before accessing its first element
    if not train_data_list:
        raise ValueError("train_data_list is empty. Cannot determine feature dimensions.")

    # node_feature_dim: Number of features per atom
    # global_feature_dim: Number of global features per molecule
    node_feature_dim = train_data_list[0].x.shape[1]
    global_feature_dim = train_data_list[0].global_features.shape[1]

    # Initialize model
    model = GNN(
        node_feature_dim=node_feature_dim,
        global_feature_dim=global_feature_dim,
        hidden_channels=hidden_channels,  # From Optuna trial
        num_layers=num_layers,  # From Optuna trial
        dropout_rate=dropout_rate
    ).to(device)

    # Loss function and Optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # PyTorch Geometric DataLoaders
    num_workers = 0
    train_loader = PyGDataLoader(train_data_list, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = PyGDataLoader(val_data_list, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Early Stopping Logic
    best_val_rmse = float('inf')
    patience_counter = 0
    patience = 50 # Number of epochs to wait for improvement before stopping

    # Training loop
    for epoch in range(n_epochs):
        # Training
        model.train()  # Set model to training mode
        total_loss = 0
        start_epoch_time = time.time()
        for batch_idx, data_batch in enumerate(train_loader):
            data_batch = data_batch.to(device)

            optimizer.zero_grad()
            outputs = model(data_batch)
                
            # Ensure outputs and target are same shape for loss calculation
            loss = criterion(outputs.view(-1), data_batch.y.view(-1)) # .view(-1) flattens to ensure shape compatibility

            if torch.isnan(outputs).any() or torch.isinf(outputs).any():
                print(f"!!! WARNING: NaN/Inf in model outputs at epoch {epoch+1}, batch {batch_idx+1}")
            if torch.isnan(loss).any() or torch.isinf(loss).any():
                print(f"!!! WARNING: NaN/Inf in loss at epoch {epoch+1}, batch {batch_idx+1}")

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            for name, param in model.named_parameters():
                if param.grad is not None:
                    if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                        print(f"!!! CRITICAL: NaN/Inf in gradient of {name} - Epoch {epoch+1}, Batch {batch_idx+1}")
                        # Add a break here for deeper inspection if this happens
                        # import sys; sys.exit("Gradient instability detected.")
            
            optimizer.step()
            total_loss += loss.item()

        # Validation
        model.eval()  # Set model to evaluation mode
        val_predictions = []
        val_targets = []
        with torch.no_grad(): # Disable gradient calculations for validation
            for data_batch in val_loader:
                data_batch = data_batch.to(device)
                val_outputs = model(data_batch)
                val_predictions.extend(val_outputs.cpu().numpy().flatten())
                val_targets.extend(data_batch.y.cpu().numpy().flatten()) # Extract y from PyG Data object

        val_rmse = np.sqrt(mean_squared_error(val_targets, val_predictions))

        if device.type == 'cuda': # Ensure GPU operations are finished before timing an epoch
            torch.cuda.synchronize()
        end_epoch_time = time.time()

        print(f"Trial {trial.number}, Epoch {epoch+1}/{n_epochs}, Val RMSE: {val_rmse:.4f}, Time: {end_epoch_time - start_epoch_time:.2f}s")

        # Optuna Pruning: Report current validation RMSE to Optuna
        trial.report(val_rmse, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

        # Manual Early Stopping Check
        if val_rmse < best_val_rmse:
            best_val_rmse = val_rmse
            patience_counter = 0 # Reset patience if improvement is found
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1} for trial {trial.number}")
                break # Exit training loop for current trial

    # Final evaluation on validation set after training (or early stopping)
    model.eval()
    final_val_predictions = []
    final_val_targets = []
    with torch.no_grad():
        for data_batch in val_loader:
            data_batch = data_batch.to(device)
            val_outputs = model(data_batch)
            final_val_predictions.extend(val_outputs.cpu().numpy().flatten())
            final_val_targets.extend(data_batch.y.cpu().numpy().flatten())

    final_rmse = np.sqrt(mean_squared_error(final_val_targets, final_val_predictions))
    final_r2 = r2_score(final_val_targets, final_val_predictions)

    # Store R2 score as well in the study for later analysis
    trial.set_user_attr("final_r2_score", float(final_r2))

    return final_rmse # Optuna minimizes this value

### 4.3. Run Optuna Study
An Optuna study is created and executed to perform the hyperparameter optimization, iterating through various trials to find the best combination of GNN model parameters.

In [7]:
study_dir = Path("../studies/gnn_study")
study_dir.mkdir(parents=True, exist_ok=True)

study_db_path = f"sqlite:///{study_dir / 'gnn_optuna_study.db'}"
study_name = "gnn_regression_pGI50"
print(f"Optuna study for GNN will be stored at: {study_db_path}")

pruner = optuna.pruners.MedianPruner(
    n_startup_trials=10,  # Run at least these many trials completely before starting to prune
    n_warmup_steps=20,    # Don't prune trials until they've completed these many epochs
    interval_steps=10     # Check for pruning every these many epochs
)
# pruner = None

# Check if a study with the same name already exists in the database
# If it does, load it to resume the optimization.
try:
    study = optuna.load_study(study_name=study_name, storage=study_db_path)
    print(f"Loaded existing study '{study_name}' from {study_db_path}. Resuming optimization.")
except KeyError:
    # If the study does not exist, create a new one
    print(f"Creating new study '{study_name}' at {study_db_path}.")
    study = optuna.create_study(
        study_name=study_name,
        direction="minimize",
        storage=study_db_path,
        pruner=pruner
    )

print("\nStarting Optuna optimization for GNN...")
# Run for 'n_trial' trials or 'timeout' seconds, whichever completes first
study.optimize(objective, n_trials=None, timeout=14400, show_progress_bar=True)
print("\nOptuna optimization finished for GNN.")

# Print best trial results
print("\n--- Best Trial Results for GNN ---")
print(f"Best trial number: {study.best_trial.number}")
print(f"Best RMSE (Validation): {study.best_value:.4f}")
print("Best hyperparameters:")
for key, value in study.best_params.items():
    print(f"  {key}: {value}")

if "final_r2_score" in study.best_trial.user_attrs:
    print(f"Best R2 Score (Validation): {study.best_trial.user_attrs['final_r2_score']:.4f}")

Optuna study for GNN will be stored at: sqlite:///..\studies\gnn_study\gnn_optuna_study.db
Loaded existing study 'gnn_regression_pGI50' from sqlite:///..\studies\gnn_study\gnn_optuna_study.db. Resuming optimization.

Starting Optuna optimization for GNN...


   0%|          | 00:00/4:00:00

[W 2025-07-18 20:01:54,236] Trial 97 failed with parameters: {'hidden_channels': 959, 'lr': 0.000335866088243583, 'batch_size': 32, 'n_epochs': 295, 'num_layers': 3, 'dropout_rate': 0.2061426548243178, 'weight_decay': 1.092019611038451e-08} because of the following error: NameError("name 'train_data_list' is not defined").
Traceback (most recent call last):
  File "c:\Users\Acer\Desktop\Projects for Data Science\Drug Gi50 Value Prediction\venv\Lib\site-packages\optuna\study\_optimize.py", line 201, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "C:\Users\Acer\AppData\Local\Temp\ipykernel_28584\1816650752.py", line 13, in objective
    if not train_data_list:
           ^^^^^^^^^^^^^^^
NameError: name 'train_data_list' is not defined
[W 2025-07-18 20:01:54,261] Trial 97 failed with value None.



KeyboardInterrupt



## 5. Train Final Model
This section trains the final GNN model using the best hyperparameters identified by Optuna and saves it for future use.

### 5.1. Load Graph Objects
The necessary PyTorch Geometric graph objects are reloaded to ensure a fresh start for final model training.

In [7]:
# Directory where the graph objects are saved
load_dir = Path("../data/pyg_data_graphs")

# Define the full file paths
train_data_path = load_dir / "train_data_list.pt"
val_data_path = load_dir / "val_data_list.pt"
test_data_path = load_dir / "test_data_list.pt"

# Load the lists of Data objects
try:
    train_data_list = torch.load(train_data_path, weights_only=False)
    val_data_list = torch.load(val_data_path, weights_only=False)
    test_data_list = torch.load(test_data_path, weights_only=False)

    print(f"Loaded {len(train_data_list)} training graphs.")
    print(f"Loaded {len(val_data_list)} validation graphs.")
    print(f"Loaded {len(test_data_list)} test graphs.")

except FileNotFoundError:
    print(
        f"Error: Processed data not found in {load_dir}. Please run the data processing and saving step first."
    )
except Exception as e:
    print(f"An error occurred during loading: {e}")

Loaded 13119 training graphs.
Loaded 2812 validation graphs.
Loaded 2812 test graphs.


#### 5.1.1. Verify Loading of Graph Objects
Basic checks are performed to ensure the graph objects have been loaded correctly.

In [8]:
print("\n--- Sample PyTorch Geometric Data object (from train_data_list[0]) ---")
if len(train_data_list) > 0:
    sample_data = train_data_list[0]
    print(sample_data)
    print(f"  Number of nodes (atoms): {sample_data.num_nodes}")
    print(f"  Number of edges (bonds): {sample_data.num_edges}")
    
    # Node features (data.x) details
    print(f"\n  Node features (data.x) shape: {sample_data.x.shape}")
    if sample_data.x.numel() > 0:
        print(f"    Node features (data.x) sample (first 5 values): {sample_data.x.flatten()[:5].tolist()}")
        print(f"    Node features (data.x) min: {sample_data.x.min().item():.4f}")
        print(f"    Node features (data.x) max: {sample_data.x.max().item():.4f}")
        print(f"    Node features (data.x) mean: {sample_data.x.float().mean().item():.4f}")
        print(f"    Node features (data.x) std: {sample_data.x.float().std().item():.4f}")
        print(f"    Contains NaN in data.x: {torch.isnan(sample_data.x).any().item()}")
        print(f"    Contains Inf in data.x: {torch.isinf(sample_data.x).any().item()}")
    else:
        print("    Node features (data.x) is an empty tensor.")

    print(f"  Edge index (data.edge_index) shape: {sample_data.edge_index.shape}")
    print(f"  Target (data.y): {sample_data.y.item():.4f}") # Display target with 4 decimal places
    
    # Global features (data.global_features) details (VERIFYING SCALING HERE)
    print(f"\n  Global features (data.global_features) shape: {sample_data.global_features.shape}")
    if sample_data.global_features.numel() > 0:
        print(f"    Global features (data.global_features) sample (first 5 values): {sample_data.global_features.flatten()[:5].tolist()}")
        print(f"    Global features (data.global_features) min: {sample_data.global_features.min().item():.4f}")
        print(f"    Global features (data.global_features) max: {sample_data.global_features.max().item():.4f}")
        print(f"    Global features (data.global_features) mean: {sample_data.global_features.float().mean().item():.4f}")
        print(f"    Global features (data.global_features) std: {sample_data.global_features.float().std().item():.4f}")
        print(f"    Contains NaN in global_features: {torch.isnan(sample_data.global_features).any().item()}")
        print(f"    Contains Inf in global_features: {torch.isinf(sample_data.global_features).any().item()}")
    else:
        print("    Global features (data.global_features) is an empty tensor.")
        
    print(f"\n  SMILES: {sample_data.smiles}")
    print(f"  Molregno: {sample_data.molregno}")
else:
    print("No training data objects created to display sample.")


--- Sample PyTorch Geometric Data object (from train_data_list[0]) ---
Data(x=[25, 21], edge_index=[2, 58], y=[1], global_features=[1, 4532], molregno=2307646, smiles='COc1cccc2c1OCc1c-2nc2cnc3ccccc3c2c1C')
  Number of nodes (atoms): 25
  Number of edges (bonds): 58

  Node features (data.x) shape: torch.Size([25, 21])
    Node features (data.x) sample (first 5 values): [6.0, 1.0, 4.0, 0.0, 0.0]
    Node features (data.x) min: -0.4928
    Node features (data.x) max: 8.0000
    Node features (data.x) mean: 1.2363
    Node features (data.x) std: 1.7344
    Contains NaN in data.x: False
    Contains Inf in data.x: False
  Edge index (data.edge_index) shape: torch.Size([2, 58])
  Target (data.y): 5.7347

  Global features (data.global_features) shape: torch.Size([1, 4532])
    Global features (data.global_features) sample (first 5 values): [0.038460150361061096, 0.038460150361061096, -2.2056941986083984, -2.2056941986083984, -2.2056941986083984]
    Global features (data.global_features) 

### 5.2. Reinitialize Everything with Best Hyperparameters
The GNN model is reinitialized with the optimal hyperparameters found during the Optuna study, along with the final training (training and validation data **combined**) and testing data, and their respective DataLoaders. 

In [13]:
# Re-load the study to ensure the latest best parameters
study_dir = Path("../studies/gnn_study")
study_db_path = f"sqlite:///{study_dir / 'gnn_optuna_study.db'}"
study_name = "gnn_regression_pGI50"

try:
    study = optuna.load_study(study_name=study_name, storage=study_db_path)
    print("Best trial parameters (GNN):", study.best_trial.params)
    best_params = study.best_trial.params
except KeyError:
    print(f"Study '{study_name}' does not exist at {study_db_path}. Please make sure the GNN Optuna study cell has been run.")

best_hidden_channels = best_params["hidden_channels"]
best_learning_rate = best_params["lr"]
best_batch_size = best_params["batch_size"]
best_n_epochs = best_params["n_epochs"]
best_num_layers = best_params["num_layers"]
best_dropout_rate = best_params["dropout_rate"]
best_weight_decay = best_params["weight_decay"]

print(f"Best hyperparameters from Optuna: {best_params}")

# Re-initialize the model with best hyperparameters
if not train_data_list:
    raise ValueError("train_data_list is empty. Cannot determine feature dimensions for GNN.")

node_feature_dim = train_data_list[0].x.shape[1]
global_feature_dim = train_data_list[0].global_features.shape[1]

final_gnn_model = GNN(
    node_feature_dim=node_feature_dim,
    global_feature_dim=global_feature_dim,
    hidden_channels=best_hidden_channels,
    num_layers=best_num_layers,
    dropout_rate=best_dropout_rate
).to(device)

# Re-initialize criterion and optimizer
final_criterion = nn.MSELoss()
final_optimizer = optim.Adam(final_gnn_model.parameters(), lr=best_learning_rate, weight_decay=best_weight_decay)

# Re-create DataLoaders with the best batch size (Training + Validation data COMBINED)
final_train_val_data_list = train_data_list + val_data_list

# Create DataLoaders with the best batch size
num_workers = 0
final_train_val_loader = PyGDataLoader(final_train_val_data_list, batch_size=best_batch_size, shuffle=True, num_workers=num_workers)

# Create the FINAL TEST DataLoader
final_test_loader = PyGDataLoader(test_data_list, batch_size=best_batch_size, shuffle=False, num_workers=num_workers)

print(f"Final GNN model, criterion, optimizer, and DataLoaders initialized with best parameters.")
print(f"Training on combined {len(final_train_val_data_list)} samples, testing on {len(test_data_list)} samples.")

Best trial parameters (GNN): {'hidden_channels': 999, 'lr': 0.0002654886343578734, 'batch_size': 128, 'n_epochs': 190, 'num_layers': 1, 'dropout_rate': 0.12998376396007172, 'weight_decay': 4.855663649252953e-08}
Best hyperparameters from Optuna: {'hidden_channels': 999, 'lr': 0.0002654886343578734, 'batch_size': 128, 'n_epochs': 190, 'num_layers': 1, 'dropout_rate': 0.12998376396007172, 'weight_decay': 4.855663649252953e-08}
Final GNN model, criterion, optimizer, and DataLoaders initialized with best parameters.
Training on combined 15931 samples, testing on 2812 samples.


### 5.3. Get Current Commit ID
The current Git commit ID (hash) is programmatically retrieved. This commit ID will be incorporated into the final model's filename to ensure direct traceability and reproducibility.

In [13]:
def get_git_commit_hash():
    try:
        # Get the short commit hash
        commit_hash = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).strip().decode('ascii')
        return commit_hash
    except (subprocess.CalledProcessError, FileNotFoundError):
        return "unknown_commit"

In [14]:
# Optionally, see the current commit ID
current_commit = get_git_commit_hash()
print(f"Current Git Commit ID: {current_commit}")

Current Git Commit ID: 271d2f4


### 5.4. Train and Save Model
The final GNN model is trained on the combined training and validation graph datasets and then saved locally with a filename that includes the Git commit ID.

In [16]:
best_final_val_rmse = float('inf')
patience_counter_final = 0
final_patience = 50

current_commit_hash = get_git_commit_hash()
model_filename = f"final_best_gnn_model_{current_commit_hash}.pt" # Pre-define filename

print(f"Retraining final GNN model for {best_n_epochs} epochs with best parameters...")
print(f"Associated Git Commit ID for saved model: {current_commit_hash}")

for epoch in range(best_n_epochs):
    start_epoch_time = time.time()
    
    # Training
    final_gnn_model.train()
    total_train_loss = 0
    num_train_batches = 0
    for data_batch in final_train_val_loader:
        # Move data to device
        data_batch = data_batch.to(device)
        
        final_optimizer.zero_grad()
        outputs = final_gnn_model(data_batch)
        
        # Ensure outputs and target are of same shape for loss calculation
        loss = final_criterion(outputs.view(-1), data_batch.y.view(-1))
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(final_gnn_model.parameters(), max_norm=1.0)
        final_optimizer.step()

        total_train_loss += loss.item()
        num_train_batches += 1

    avg_train_loss = total_train_loss / num_train_batches

    # Evaluation on combined data
    final_gnn_model.eval()
    val_predictions = []
    val_targets = []
    with torch.no_grad():
        for data_batch_eval in final_train_val_loader:
            data_batch_eval = data_batch_eval.to(device)
            
            val_outputs = final_gnn_model(data_batch_eval)
            val_predictions.extend(val_outputs.cpu().numpy().flatten())
            val_targets.extend(data_batch_eval.y.cpu().numpy().flatten())

    current_val_rmse = np.sqrt(mean_squared_error(val_targets, val_predictions))

    if device.type == 'cuda': # Ensure GPU operations are finished before timing an epoch
        torch.cuda.synchronize()
    end_epoch_time = time.time()
    
    print(f"Epoch {epoch+1}/{best_n_epochs}, Train Loss: {avg_train_loss:.4f}, Eval RMSE on combined data: {current_val_rmse:.4f}, Time: {end_epoch_time - start_epoch_time:.2f}s")

    # Dynamic Best Model Saving & Early Stopping
    if current_val_rmse < best_final_val_rmse:
        best_final_val_rmse = current_val_rmse
        # Save the GNN model state dict
        torch.save(final_gnn_model.state_dict(), gnn_models_base_dir / model_filename)
        patience_counter_final = 0 # Reset patience counter if performance improved
        print(f"--- New best final GNN model saved at epoch {epoch+1} with RMSE: {current_val_rmse:.4f} ---")
    else:
        patience_counter_final += 1 # Increment patience counter if no improvement
        print(f"No improvement for {patience_counter_final} epochs. Best RMSE so far: {best_final_val_rmse:.4f}")

    if patience_counter_final >= final_patience:
        print(f"Early stopping triggered at epoch {epoch+1}.")
        break

print("Final model training complete.")

Retraining final GNN model for 190 epochs with best parameters...
Associated Git Commit ID for saved model: 271d2f4
Epoch 1/190, Train Loss: 2.0786, Eval RMSE on combined data: 0.7288, Time: 5.49s
--- New best final GNN model saved at epoch 1 with RMSE: 0.7288 ---
Epoch 2/190, Train Loss: 0.5506, Eval RMSE on combined data: 0.6565, Time: 5.28s
--- New best final GNN model saved at epoch 2 with RMSE: 0.6565 ---
Epoch 3/190, Train Loss: 0.4511, Eval RMSE on combined data: 0.5760, Time: 5.31s
--- New best final GNN model saved at epoch 3 with RMSE: 0.5760 ---
Epoch 4/190, Train Loss: 0.3592, Eval RMSE on combined data: 0.5163, Time: 5.25s
--- New best final GNN model saved at epoch 4 with RMSE: 0.5163 ---
Epoch 5/190, Train Loss: 0.2898, Eval RMSE on combined data: 0.4665, Time: 5.26s
--- New best final GNN model saved at epoch 5 with RMSE: 0.4665 ---
Epoch 6/190, Train Loss: 0.2443, Eval RMSE on combined data: 0.4380, Time: 5.23s
--- New best final GNN model saved at epoch 6 with RMSE: 0

### 5.5. Evaluate Model
This section performs a final, unbiased evaluation of the trained GNN model's performance on the previously unseen test dataset.

In [17]:
# Load the best state dict model
print(f"Loading best saved GNN model from '{gnn_models_base_dir / model_filename}' for final test evaluation...")
path_to_saved_model = gnn_models_base_dir / model_filename
loaded_model_state_dict = torch.load(path_to_saved_model)
final_gnn_model.load_state_dict(loaded_model_state_dict)
final_gnn_model.eval()

print("\nStarting final evaluation on test set for GNN...")
test_predictions = []
test_targets = []
with torch.no_grad():
    for data_batch_test in final_test_loader:
        data_batch_test = data_batch_test.to(device)

        test_outputs = final_gnn_model(data_batch_test)

        # Collect predictions and targets
        test_predictions.extend(test_outputs.cpu().numpy().flatten())
        test_targets.extend(data_batch_test.y.cpu().numpy().flatten())

final_test_rmse = np.sqrt(mean_squared_error(test_targets, test_predictions))
final_test_r2 = r2_score(test_targets, test_predictions)

print(f"Final GNN Model Test RMSE: {final_test_rmse:.4f}")
print(f"Final GNN Model Test R2: {final_test_r2:.4f}")

Loading best saved GNN model from '..\models\gnn\final_best_gnn_model_271d2f4.pt' for final test evaluation...

Starting final evaluation on test set for GNN...
Final GNN Model Test RMSE: 0.6114
Final GNN Model Test R2: 0.6100
