# Package Installation and Environment Initialization

In [1]:
!pip install torch-geometric --quiet
!pip install torch --quiet
!pip install pandas --quiet
!pip install rdkit --quiet
!pip install numpy --quiet
!pip install matplotlib --quiet
!pip install seaborn --quiet

In [2]:
import torch
import torch_geometric
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.data import Dataset
from torch_geometric.data import DataLoader
from torch_geometric.nn import global_mean_pool
import torch_geometric.nn as pyg_nn
from rdkit import Chem
from rdkit.Chem import Descriptors, Draw
from rdkit.Chem import rdPartialCharges
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch
from torch_geometric.loader import DenseDataLoader
from sklearn.preprocessing import LabelEncoder

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [4]:
pd.set_option("display.max_colwidth", None)

# Data Loading & Initial Exploration

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
# Path to the train dataset
file_path = '/content/drive/My Drive/GNN Project Materials/train_dataset.csv'

# Load the CSV file into a DataFrame
train_df = pd.read_csv(file_path)

# Display the first few rows
print(train_df.head())

   id                            buildingblock1_smiles buildingblock2_smiles  \
0   0  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
1   1  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
2   2  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
3   3  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
4   4  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   

     buildingblock3_smiles  \
0  Br.Br.NCC1CCCN1c1cccnn1   
1  Br.Br.NCC1CCCN1c1cccnn1   
2  Br.Br.NCC1CCCN1c1cccnn1   
3        Br.NCc1cccc(Br)n1   
4        Br.NCc1cccc(Br)n1   

                                                          molecule_smiles  \
0  C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1   
1  C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1   
2  C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1   
3     C#CCOc1ccc(CNc2nc(NCc3cc

In [7]:
print(train_df.count())

id                       5246830
buildingblock1_smiles    5246830
buildingblock2_smiles    5246830
buildingblock3_smiles    5246830
molecule_smiles          5246830
protein_name             5246830
binds                    5246830
mol_wt                   5246830
logP                     5246830
rotamers                 5246830
dtype: int64


In [8]:
# Drop building block SMILES columns
train_df.drop(columns=['buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles'], inplace=True)

# Addressing Target Imbalance

In [9]:
# Check unique values in the 'binds' column
print(train_df['binds'].value_counts())

binds
0    5236321
1      10509
Name: count, dtype: int64


In [10]:
# Step 1: Separate the data into two parts: one for binds = 0 and one for binds = 1
binds_0_df = train_df[train_df['binds'] == 0]
binds_1_df = train_df[train_df['binds'] == 1]

# Step 2: Downsample binds = 0 to match the number of binds = 1
binds_0_downsampled_df = binds_0_df.sample(n=len(binds_1_df), random_state=42)

# Step 3: Combine the downsampled binds = 0 data with binds = 1 data
balanced_df = pd.concat([binds_0_downsampled_df, binds_1_df])

# Step 4: Shuffle the final dataset to mix the rows
balanced_df = balanced_df.sample(frac=1, random_state=42).reset_index(drop=True)

# Step 5: Verify the shape and class distribution
print(f"Balanced dataset shape: {balanced_df.shape}")
print(f"Class distribution in balanced dataset: {balanced_df['binds'].value_counts()}")

# Step 6: Show the first few rows of the balanced dataset
balanced_df.head()

Balanced dataset shape: (21018, 7)
Class distribution in balanced dataset: binds
0    10509
1    10509
Name: count, dtype: int64


Unnamed: 0,id,molecule_smiles,protein_name,binds,mol_wt,logP,rotamers
0,3566327,C#CC[C@H](CC(=O)N[Dy])Nc1nc(NCC2Cc3ccccc3NC2=O)nc(Nc2ccc(Br)nc2OC)n1,sEH,0,742.055548,2.7786,2048
1,4617174,C#CC[C@H](Nc1nc(NCc2ccc(OC)c(OC)c2C)nc(Nc2ccc(O)cc2Cl)n1)C(=O)N[Dy],BRD4,1,674.09483,3.29622,2048
2,1425402,C#CC[C@@H](Nc1nc(NCc2ccccc2-c2cnn(C)c2)nc(Nc2cccc(Br)c2C)n1)C(=O)N[Dy],BRD4,0,722.065718,4.08302,1024
3,959535,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCc2ccc(C)cc2OC2CCOC2)nc(Nc2ccc3c(c2)CNC3=O)n1,BRD4,1,719.176001,2.72122,4096
4,617223,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCNC(=O)c2ccno2)nc(Nc2nc3c(C)cccc3s2)n1,BRD4,0,683.096705,2.13762,4096


In [11]:
# Save the sampled dataset to Drive
sampled_file_path = '/content/drive/My Drive/GNN Project Materials/sampled_train_dataset.csv'
balanced_df.to_csv(sampled_file_path, index=False)

# Confirm file has been saved
print(f"Sampled dataset saved to: {sampled_file_path}")

Sampled dataset saved to: /content/drive/My Drive/GNN Project Materials/sampled_train_dataset.csv


In [12]:
print(balanced_df.count())

id                 21018
molecule_smiles    21018
protein_name       21018
binds              21018
mol_wt             21018
logP               21018
rotamers           21018
dtype: int64


In [13]:
# Ensure all SMILES are strings and not NaN
balanced_df['molecule_smiles'] = balanced_df['molecule_smiles'].fillna("").astype(str)

# Check that every item is indeed a string
print(balanced_df['molecule_smiles'].apply(type).unique())  # Should print <class 'str'>

[<class 'str'>]


# Retrieving Protein Sequences of `protein_name`



In [17]:
print(balanced_df['protein_name'].value_counts())

protein_name
HSA     7753
BRD4    6746
sEH     6519
Name: count, dtype: int64


In [18]:
pip install biopython --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━[0m [32m3.1/3.3 MB[0m [31m93.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m47.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [82]:
from Bio import ExPASy
from Bio import SeqIO
from collections import Counter
import numpy as np

# Define the UniProt IDs for your proteins
protein_ids = {
    "HSA": "P02768",  # Human Serum Albumin
    "BRD4": "O60885",  # Bromodomain-containing protein 4
    "sEH": "P34913"   # Epoxide hydrolase 2
}

# Define the maximum sequence length (you need to determine this based on your data)
MAX_PROTEIN_LEN = 619  # Or find the maximum length in your protein_sequences

# Function to fetch protein sequence
def fetch_sequence(uniprot_id):
    try:
        handle = ExPASy.get_sprot_raw(uniprot_id)  # Fetch the raw data from UniProt
        record = SeqIO.read(handle, "swiss")  # Parse the data in Swiss-Prot format
        handle.close()
        return str(record.seq)  # Return the protein sequence as a string
    except Exception as e:
        print(f"Error fetching sequence for {uniprot_id}: {e}")
        return None

# Function to pad or truncate protein sequences
def pad_protein_sequence(sequence, max_len=MAX_PROTEIN_LEN, pad_char='X'):
    """Pads or truncates a protein sequence to a specific length."""
    if len(sequence) < max_len:
        return sequence + pad_char * (max_len - len(sequence))
    else:
        return sequence[:max_len]

# Fetch and print sequences
protein_sequences = {}
for protein, uniprot_id in protein_ids.items():
    sequence = fetch_sequence(uniprot_id)
    if sequence:
        padded_sequence = pad_protein_sequence(sequence, max_len=MAX_PROTEIN_LEN)
        protein_sequences[protein] = padded_sequence
        print(f"{protein} ({uniprot_id}): {sequence[:50]}...")
    else:
        print(f"Failed to fetch sequence for {protein} ({uniprot_id})")

# One-hot encode all padded sequences
encoded_sequences = {}
for protein, sequence in protein_sequences.items():
    encoded_sequences[protein] = one_hot_encode_sequence(sequence)

HSA (P02768): MKWVTFISLLFLFSSAYSRGVFRRDAHKSEVAHRFKDLGEENFKALVLIA...
BRD4 (O60885): MSAESGPGTRLRNLPVMGDGLETSQMSTTQAQAQPQPANAASTNPPPPET...
sEH (P34913): MTLRAAVFDLDGVLALPAVFGVLGRTEEALALPRGLLNDAFQKGGPEGAT...


In [83]:
# Define the 20 standard amino acids
amino_acids = "ACDEFGHIKLMNPQRSTVWY"  # Alphabetical order
aa_to_index = {aa: i for i, aa in enumerate(amino_acids)}  # Map each amino acid to an index

print("Amino acid to index mapping:")
print(aa_to_index)

Amino acid to index mapping:
{'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19}


In [84]:
import numpy as np

def one_hot_encode_sequence(sequence):
    """
    Convert a protein sequence into a one-hot encoded matrix.
    :param sequence: Protein sequence (string)
    :return: One-hot encoded matrix (numpy array of shape [sequence_length, 20])
    """
    encoding = np.zeros((len(sequence), len(amino_acids)), dtype=np.float32)  # Initialize matrix with float32 type
    for i, aa in enumerate(sequence):
        if aa in aa_to_index:  # Check if the amino acid is in the standard 20
            encoding[i, aa_to_index[aa]] = 1  # Set the corresponding index to 1
    return encoding

# One-hot encode all sequences
encoded_sequences = {}
for protein, sequence in protein_sequences.items():
    encoded_sequences[protein] = one_hot_encode_sequence(sequence)

# Print shapes of encoded sequences
for protein, encoded in encoded_sequences.items():
    print(f"{protein}: {encoded.shape}")

HSA: (619, 20)
BRD4: (619, 20)
sEH: (619, 20)


# Converting SMILES to Molecular Graphs

In [97]:
from rdkit import Chem
from rdkit.Chem import Descriptors
import torch
from torch_geometric.data import Data
from collections import Counter

def smiles_to_graph_with_protein(smiles, protein_name, binds, protein_sequences):
    """
    Convert a SMILES string and protein sequence into a graph representation.

    Args:
        smiles (str): SMILES string of the ligand.
        protein_name (str): Name of the protein (e.g., "HSA", "BRD4").
        binds (int): Binary label indicating binding (0 or 1).
        protein_sequences (dict): Dictionary mapping protein names to one-hot encoded sequences.

    Returns:
        Data: A PyTorch Geometric Data object.
    """
    # Parse the SMILES string to create the molecule object
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        print(f"Invalid SMILES string: {smiles}")
        return None

    # Initialize lists for atom and bond features
    atom_features = []
    bond_index = []
    bond_features = []

    # Extract atom features
    for atom in mol.GetAtoms():
        features = [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetTotalNumHs(),
            atom.GetFormalCharge(),
            int(atom.GetIsAromatic()),
            int(atom.GetHybridization()),
            atom.GetMass(),
            atom.GetAtomicNum() / 100.0,  # Normalize atomic number
            int(atom.IsInRing()),  # Whether atom is part of a ring
            int(atom.GetChiralTag() != Chem.rdchem.ChiralType.CHI_UNSPECIFIED),  # Chiral center
        ]
        atom_features.append(features)

    # Extract bond features
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        bond_index.append((i, j))
        bond_index.append((j, i))  # Bond is bidirectional
        bond_features.append([bond.GetBondTypeAsDouble(), int(bond.GetIsAromatic()), bond.IsInRing()])
        bond_features.append([bond.GetBondTypeAsDouble(), int(bond.GetIsAromatic()), bond.IsInRing()])

    # Handle empty graphs
    if len(atom_features) == 0 or len(bond_index) == 0:
        print(f"Skipping invalid SMILES string or empty graph: {smiles}")
        return None

    # Convert features to tensors
    atom_features = torch.tensor(atom_features, dtype=torch.float)
    bond_index = torch.tensor(bond_index, dtype=torch.long).t().contiguous()
    bond_features = torch.tensor(bond_features, dtype=torch.float)

    # Scalar molecular features (Descriptors)
    mol_wt = Descriptors.MolWt(mol)
    logP = Descriptors.MolLogP(mol)
    rotamers = Descriptors.NumRotatableBonds(mol)
    tpsa = Descriptors.TPSA(mol)
    qed = Descriptors.qed(mol)
    # Additional molecular descriptors
    heavy_atoms = Descriptors.HeavyAtomCount(mol)
    h_acceptors = Descriptors.NumHAcceptors(mol)
    h_donors = Descriptors.NumHDonors(mol)
    # Adding more molecular descriptors
    scalar_features = torch.tensor([mol_wt, logP, rotamers, tpsa, qed, heavy_atoms, h_acceptors, h_donors], dtype=torch.float)

    # Get one-hot encoded protein sequence
    if protein_name not in protein_sequences:
        print(f"Protein {protein_name} not found in protein_sequences.")
        return None
    protein_encoded = protein_sequences[protein_name]  # One-hot encoded sequence
    protein_encoded_tensor = torch.tensor(protein_encoded, dtype=torch.float)

    # Additional protein features
    protein_len = len(protein_encoded)  # Length of the protein sequence

    # Modify this part to get amino acid counts correctly
    aa_counts = Counter(protein_encoded.argmax(axis=1))  # Count amino acids based on argmax
    aa_features = [aa_counts.get(aa, 0) / protein_len for aa in range(20)]  # Relative frequency of each amino acid

    protein_length_tensor = torch.tensor([protein_len], dtype=torch.float)
    aa_composition_tensor = torch.tensor(aa_features, dtype=torch.float)

    # Create label tensor
    if binds is None:
        print(f"No label for molecule: {smiles}")
        return None
    label_tensor = torch.tensor([binds], dtype=torch.long)

    # Return a Data object
    return Data(
        x=atom_features,  # Atom features
        edge_index=bond_index,  # Edge indices
        edge_attr=bond_features,  # Bond features
        scalar_features=scalar_features.unsqueeze(0),  # Scalar molecular features
        protein_feature=protein_encoded_tensor,  # One-hot encoded protein sequence
        protein_len=protein_length_tensor,  # Protein sequence length
        aa_composition=aa_composition_tensor,  # Amino acid composition features
        y=label_tensor,  # Label tensor
        protein_name=protein_name  # Add protein name here

    )

In [98]:
from sklearn.model_selection import train_test_split
from torch_geometric.data import DataLoader

# Split the dataset into train and test sets (80% train, 20% test)
train_df, test_df = train_test_split(balanced_df, test_size=0.2, random_state=42)

# Function to process rows into graph data
def create_graph_data(df, protein_sequences):
    """
    Convert rows of the DataFrame into graph data.

    Args:
        df (pd.DataFrame): DataFrame containing SMILES, protein names, and binding labels.
        protein_sequences (dict): Dictionary mapping protein names to one-hot encoded sequences.

    Returns:
        list: List of PyTorch Geometric Data objects.
    """
    graph_data_list = []
    for i, row in df.iterrows():
        graph_data = smiles_to_graph_with_protein(
            row["molecule_smiles"], row["protein_name"], row["binds"], protein_sequences
        )
        if graph_data is not None:
            graph_data_list.append(graph_data)
    return graph_data_list

# Create graph data for train and test sets
train_graph_data_list = create_graph_data(train_df, encoded_sequences)
test_graph_data_list = create_graph_data(test_df, encoded_sequences)

# Create DataLoaders for batching
train_loader = DataLoader(train_graph_data_list, batch_size=32, shuffle=True, follow_batch=['protein_name'])
test_loader = DataLoader(test_graph_data_list, batch_size=32, shuffle=False, follow_batch=['protein_name'])



# Building a GCN (Graph Convolutional Network)

In [108]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn.functional as F

class GNNModel(nn.Module):
    def __init__(self, num_node_features, num_classes, protein_feature_dim, dropout_rate):
        super(GNNModel, self).__init__()

        # Graph layers
        self.conv1 = GCNConv(num_node_features, 64)
        self.conv2 = GCNConv(64, 32)
        self.batch_norm1 = nn.BatchNorm1d(64)
        self.batch_norm2 = nn.BatchNorm1d(32)
        self.dropout = nn.Dropout(p=dropout_rate)

        # Protein feature processing layers
        self.protein_fc1 = nn.Linear(protein_feature_dim, 128)  # Adjust dimensions as needed
        self.protein_fc2 = nn.Linear(128, 64)  # Adjust dimensions as needed
        self.protein_batch_norm = nn.BatchNorm1d(64)
        self.protein_dropout = nn.Dropout(p=dropout_rate)

        # Combined feature processing
        self.fc1 = nn.Linear(32 + 64, 128)  # Combine graph and protein features
        self.fc2 = nn.Linear(128, num_classes)
        self.fc_dropout = nn.Dropout(p=dropout_rate)

    def forward(self, data):
        x, edge_index, batch, protein_feature = data.x, data.edge_index, data.batch, data.protein_feature

        # Graph convolution layers
        x = torch.relu(self.conv1(x, edge_index))
        x = self.batch_norm1(x)
        x = self.dropout(x)
        x = torch.relu(self.conv2(x, edge_index))
        x = self.batch_norm2(x)
        x = self.dropout(x)
        x = global_mean_pool(x, batch)  # Global pooling for graph features

        # Protein feature processing
        protein_x = torch.relu(self.protein_fc1(protein_feature))
        protein_x = torch.relu(self.protein_fc2(protein_x))
        protein_x = self.protein_batch_norm(protein_x)
        protein_x = self.protein_dropout(protein_x)

        # Get unique protein names from the batch
        unique_protein_names = list(set(data.protein_name))

        # Create a dictionary to map protein names to indices
        protein_name_to_index = {name: i for i, name in enumerate(unique_protein_names)}

        # Convert protein names to indices
        protein_name_indices = [protein_name_to_index[name] for name in data.protein_name]

        # Create a tensor of protein name indices
        protein_name_tensor = torch.tensor(protein_name_indices, device=protein_feature.device)

        # Get protein embeddings using indices
        batch_protein_embeddings = protein_x[protein_name_tensor]

        # Concatenate graph and protein features
        x = torch.cat([x, batch_protein_embeddings], dim=1)

        # Fully connected layers
        x = torch.relu(self.fc1(x))
        x = self.fc_dropout(x)
        x = self.fc2(x)
        return x

In [109]:
import torch
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from torch_geometric.data import DataLoader

# 1. Define a function for training the model with early stopping
def train_with_early_stopping(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=50, patience=10):
    # Move model to the specified device (GPU/CPU)
    model.to(device)

    best_val_loss = float('inf')
    epochs_without_improvement = 0
    best_model_state = None  # To store the best model state

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for data in train_loader:
            data = data.to(device)  # Move data to the correct device
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Calculate average training loss for the epoch
        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}", end="")

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_preds = []
        val_labels = []
        with torch.no_grad():
            for data in val_loader:
                data = data.to(device)
                out = model(data)
                loss = criterion(out, data.y)
                val_loss += loss.item()
                _, predicted = torch.max(out, dim=1)
                val_preds.extend(predicted.cpu().numpy())
                val_labels.extend(data.y.cpu().numpy())

        # Calculate average validation loss and accuracy for the epoch
        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = accuracy_score(val_labels, val_preds)
        print(f", Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_without_improvement = 0
            best_model_state = model.state_dict()  # Save the model with best validation loss
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print("Early stopping triggered. Training stopped.")
            break

    # Load the best model state
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model

In [110]:
import torch
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from torch_geometric.data import DataLoader

# 2. K-Fold Cross Validation
def k_fold_cross_validation(k=5, num_epochs=50, patience=10, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')): # Use the device variable defined earlier
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    fold_accuracies = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(train_graph_data_list)):
        print(f"\nFold {fold + 1}/{k}")

        # Create train and validation sets
        train_data = [train_graph_data_list[i] for i in train_idx]
        val_data = [train_graph_data_list[i] for i in val_idx]

        # Create DataLoaders
        train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=16, shuffle=False)

        # Get protein_feature_dim from your data
        protein_feature_dim = train_graph_data_list[0].protein_feature.shape[1]  # Assuming protein_feature is a 2D tensor

        # Initialize the model, criterion, and optimizer
        model = GNNModel(num_node_features=10, num_classes=2, protein_feature_dim=protein_feature_dim, dropout_rate=0.2).to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
        criterion = nn.CrossEntropyLoss()

        # Train the model with early stopping
        trained_model = train_with_early_stopping(model, train_loader, val_loader, criterion, optimizer, device, num_epochs, patience)

        # Evaluate the model on the validation set
        trained_model.eval()
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for data in val_loader:
                data = data.to(device)
                if data.y is None:
                    continue
                out = trained_model(data)
                _, predicted = torch.max(out, dim=1)
                val_preds.extend(predicted.cpu().numpy())
                val_labels.extend(data.y.cpu().numpy())

        val_accuracy = accuracy_score(val_labels, val_preds)
        print(f"Fold {fold + 1} - Validation Accuracy: {val_accuracy:.4f}")
        fold_accuracies.append(val_accuracy)

    avg_accuracy = sum(fold_accuracies) / k
    print(f"\nAverage Validation Accuracy over {k} folds: {avg_accuracy:.4f}")

# Run K-Fold Cross-Validation, using the determined device
k_fold_cross_validation(k=5, num_epochs=50, patience=5, device=device) # Pass the device variable


Fold 1/5




Epoch 1/50, Train Loss: 0.6589, Val Loss: 0.6366, Val Accuracy: 0.6310
Epoch 2/50, Train Loss: 0.6388, Val Loss: 0.6307, Val Accuracy: 0.6384
Epoch 3/50, Train Loss: 0.6322, Val Loss: 0.6203, Val Accuracy: 0.6203
Epoch 4/50, Train Loss: 0.6115, Val Loss: 0.6472, Val Accuracy: 0.6328
Epoch 5/50, Train Loss: 0.5936, Val Loss: 0.5513, Val Accuracy: 0.7035
Epoch 6/50, Train Loss: 0.5810, Val Loss: 0.6033, Val Accuracy: 0.6747
Epoch 7/50, Train Loss: 0.5802, Val Loss: 0.6441, Val Accuracy: 0.6340
Epoch 8/50, Train Loss: 0.5701, Val Loss: 0.6450, Val Accuracy: 0.6705
Epoch 9/50, Train Loss: 0.5620, Val Loss: 0.6866, Val Accuracy: 0.6221
Epoch 10/50, Train Loss: 0.5590, Val Loss: 0.6328, Val Accuracy: 0.6830
Early stopping triggered. Training stopped.
Fold 1 - Validation Accuracy: 0.6830

Fold 2/5




KeyboardInterrupt: 