In [None]:
#neccessary imports
import pandas as pd
import pyarrow.dataset as ds
from loguru import logger
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
import tensorflow as tf
import dgl
from dgl import batch as dgl_batch
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
from torch.utils.data import Dataset, DataLoader
from dgl.nn import GraphConv, GlobalAttentionPooling

wandb.login()

import torch_geometric
from torch_geometric.data import Data

# Report library versions and GPU visibility for this kernel.
print("PyTorch version:", torch.__version__)
print("DGL version:", dgl.__version__)
print("CUDA available in PyTorch:", torch.cuda.is_available())

# Check CUDA support in DGL by building a tiny graph and moving it to the GPU.
try:
    # Create a simple graph
    g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))

    # Try to move the graph to GPU
    g = g.to('cuda')
    print("CUDA available in DGL: True")
except Exception as exc:
    # Catch ordinary errors only: a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit and hide the reason the probe failed.
    print("CUDA available in DGL: False")
    print(f"Reason: {type(exc).__name__}: {exc}")
print(dgl.__version__)

In [None]:
# Open the combined training parquet file as a PyArrow dataset.
# The path is relative, so it depends on the notebook's working directory.
data_combined = ds.dataset(source="../../../data/train_combined.parquet", format="parquet")
# Read the whole dataset into an Arrow table and convert it to a pandas DataFrame.
df = data_combined.to_table().to_pandas()
from sklearn.model_selection import train_test_split

def split_data(df, target_col, test_size=0.2, random_state=42):
    """Split a frame into train and test parts, stratified on ``target_col``.

    Stratifying keeps the distribution of the target (e.g. the bind vs.
    no-bind ratio) the same in both parts.

    Parameters
    ----------
    df : DataFrame to split.
    target_col : name of the column whose class balance must be preserved.
    test_size : fraction (or count) of rows assigned to the test part.
    random_state : seed forwarded to ``train_test_split``.

    Returns
    -------
    tuple
        ``(train_frame, test_frame)``.
    """
    stratify_labels = df[target_col]
    train_part, test_part = train_test_split(
        df,
        test_size=test_size,
        random_state=random_state,
        stratify=stratify_labels,
    )
    return train_part, test_part


# Split each protein's rows separately (stratified on "binds") so every
# protein keeps its own bind ratio in both the train and the test part.
protein_splits = {
    protein: split_data(df[df['protein_name'] == protein], "binds")
    for protein in ("BRD4", "HSA", "sEH")
}
train_brd4, test_brd4 = protein_splits["BRD4"]
train_hsa, test_hsa = protein_splits["HSA"]
train_seh, test_seh = protein_splits["sEH"]

# Report the size of every split.
for protein, (train_part, test_part) in protein_splits.items():
    logger.info(f"Train set {protein} {len(train_part)}")
    logger.info(f"Test set {protein} {len(test_part)}")

# Report the fraction of binders in every split.
for protein, (train_part, test_part) in protein_splits.items():
    logger.info(f"Train set bind ratio {protein} {len(train_part[train_part['binds']==1])/len(train_part)}")
    logger.info(f"Test set bind ratio {protein} {len(test_part[test_part['binds']==1])/len(test_part)}")

# Stack the per-protein parts back into one train frame and one test frame.
train_combined = pd.concat([train_brd4, train_hsa, train_seh])
test_combined = pd.concat([test_brd4, test_hsa, test_seh])


In [None]:
def clean_smi(smi: str | list):
    r""" Clean a SMILES string by removing salts and fragments.
    Parameters
    ----------
    smi : str | list
        The SMILES string for a molecule. or a list of SMILES strings
    Returns
    -------
    str | list
        The cleaned SMILES string.
    """
    if isinstance(smi, list):
        return [clean_smi(s) for s in smi]
    # Remove [Dy] from smiles
    smi = smi.replace("[Dy]", "")

    # Convert SMILES to a RDKit molecule object
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        raise ValueError("Invalid SMILES string")
    
    # Remove any salts or fragments
    mol = Chem.RemoveHs(mol)  # Remove explicit hydrogens
    fragments = Chem.GetMolFrags(mol, asMols=True)
    
    # Keep the largest fragment
    largest_fragment = max(fragments, default=mol, key=lambda m: m.GetNumAtoms())
    
    # Standardize the molecule
    AllChem.Compute2DCoords(largest_fragment)  # Compute 2D coordinates
    
    # Convert the molecule back to a canonical SMILES string
    cleaned_smiles = Chem.MolToSmiles(largest_fragment, canonical=True)
    return cleaned_smiles


def smiles_to_dgl_graph(smiles: str |list):
    r""" Convert a SMILES string to a DGLGraph.
    Parameters
    ----------
    smiles : str | list
        The SMILES string for a molecule. or a list of SMILES strings
    Returns
    -------
    DGLGraph
        A DGLGraph object for the molecule.
    """
    if isinstance(smiles, list):
        return [smiles_to_dgl_graph(s) for s in smiles]
    clean_smiles=clean_smi(smiles)
    mol = Chem.MolFromSmiles(clean_smiles)
    if mol is None:
        return None

    # Node features
    atom_features = []
    for atom in mol.GetAtoms():
        atom_features.append([
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            atom.GetHybridization(),
            atom.GetIsAromatic(),
            atom.GetTotalNumHs()
        ])
    
    # Edge features and adjacency list
    src, dst = [], []
    bond_features = []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        src.append(start)
        dst.append(end)
        bond_features.append([
            int(bond.GetBondTypeAsDouble()), 
            bond.GetBondType(),
            bond.GetIsConjugated(),
            bond.IsInRing()
        ])
    """
    g = dgl.graph((src, dst))
    g.ndata['h'] = torch.tensor(atom_features, dtype=torch.float)
    g.edata['h'] = torch.tensor(bond_features, dtype=torch.float)
    g= dgl.add_self_loop(g)
    return g
    """
    edge_index = torch.tensor([src, dst], dtype=torch.long)
    node_features = torch.tensor(atom_features, dtype=torch.float)
    edge_features = torch.tensor(bond_features, dtype=torch.float)

    data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features)
    return data

# Map each target protein name to its amino-acid sequence (one-letter codes).
protein_sequences = {
    "BRD4": "NPPPPETSNPNKPKRQTNQLQYLLRVVLKTLWKHQFAWPFQQPVDAVKLNLPDYYKIIKTPMDMGTIKKRLENNYYWNAQECIQDFNTMFTNCYIYNKPGDDIVLMAEALEKLFLQKINELPTEETEIMIVQAKGRGRGRKETGTAKPGVSTVPNTTQASTPPQTQTPQPNPPPVQATPHPFPAVTPDLIVQTPVMTVVPPQPLQTPPPVPPQPQPPPAPAPQPVQSHPPIIAATPQPVKTKKGVKRKADTTTPTTIDPIHEPPSLPPEPKTTKLGQRRESSRPVKPPKKDVPDSQQHPAPEKSSKVSEQLKCCSGILKEMFAKKHAAYAWPFYKPVDVEALGLHDYCDIIKHPMDMSTIKSKLEAREYRDAQEFGADVRLMFSNCYKYNPPDHEVVAMARKLQDVFEMRFAKMPDE",
    "sEH": "TLRAAVFDLDGVLALPAVFGVLGRTEEALALPRGLLNDAFQKGGPEGATTRLMKGEITLSQWIPLMEENCRKCSETAKVCLPKNFSIKEIFDKAISARKINRPMLQAALMLRKKGFTTAILTNTWLDDRAERDGLAQLMCELKMHFDFLIESCQVGMVKPEPQIYKFLLDTLKASPSEVVFLDDIGANLKPARDLGMVTILVQDTDTALKELEKVTGIQLLNTPAPLPTSCNPSDMSHGYVTVKPRVRLHFVELGSGPAVCLCHGFPESWYSWRYQIPALAQAGYRVLAMDMKGYGESSAPPEIEEYCMEVLCKEMVTFLDKLGLSQAVFIGHDWGGMLVWYMALFYPERVRAVASLNTPFIPANPNMSPLESIKANPVFDYQLYFQEPGVAEAELEQNLSRTFKSLFRASDESVLSMHKVCEAGGLFVNSPEEPSLSRMVTEEEIQFYVQQFKKSGFRGPLNWYRNMERNWKWACKSLGRKILIPALMVTAEKDFVLVPQMSQHMEDWIPHLKRGHIEDCGHWTQMDKPTEVNQILIKWLDSDARNPPVVSKM",
    "HSA": "DAHKSEVAHRFKDLGEENFKALVLIAFAQYLQQCPFEDHVKLVNEVTEFAKTCVADESAENCDKSLHTLFGDKLCTVATLRETYGEMADCCAKQEPERNECFLQHKDDNPNLPRLVRPEVDVMCTAFHDNEETFLKKYLYEIARRHPYFYAPELLFFAKRYKAAFTECCQAADKAACLLPKLDELRDEGKASSAKQRLKCASLQKFGERAFKAWAVARLSQRFPKAEFAEVSKLVTDLTKVHTECCHGDLLECADDRADLAKYICENQDSISSKLKECCEKPLLEKSHCIAEVENDEMPADLPSLAADFVESKDVCKNYAEAKDVFLGMFLYEYARRHPDYSVVLLLRLAKTYETTLEKCCAAADPHECYAKVFDEFKPLVEEPQNLIKQNCELFEQLGEYKFQNALLVRYTKKVPQVSTPTLVEVSRNLGKVGSKCCKHPEAKRMPCAEDYLSVVLNQLCVLHEKTPVSDRVTKCCTESLVNRRPCFSALEVDETYVPKEFNAETFTFHADICTLSEKERQIKKQTALVELVKHKPKATKEQLKAVMDDFAAFVEKCCKADDKETCFAEEGKKLVAASQAALGL",
}

# One-letter residue codes; a letter's position is its one-hot column index.
amino_acids = "ACDEFGHIKLMNPQRSTVWY"
# Lookup table: residue letter -> column index in the one-hot encoding.
aa_to_index = dict(zip(amino_acids, range(len(amino_acids))))


def one_hot_encode_sequence(sequence, aa_to_index=aa_to_index):
    """One-hot encode a residue sequence.

    Parameters
    ----------
    sequence : iterable of one-letter residue codes.
    aa_to_index : mapping from residue letter to column index.

    Returns
    -------
    numpy.ndarray
        float32 array of shape (len(sequence), len(amino_acids)). Residues
        missing from ``aa_to_index`` leave their row all zeros.
    """
    encoded = np.zeros((len(sequence), len(amino_acids)), dtype=np.float32)
    # Collect the (row, column) of every residue the mapping knows about.
    hits = [
        (position, aa_to_index[residue])
        for position, residue in enumerate(sequence)
        if residue in aa_to_index
    ]
    if hits:
        rows, cols = zip(*hits)
        encoded[list(rows), list(cols)] = 1.0
    return encoded

def generate_protein_one_hot_features(protein_sequences):
    """Build the one-hot encoding of every protein in a name -> sequence map.

    Parameters
    ----------
    protein_sequences : dict mapping protein name to its residue sequence.

    Returns
    -------
    dict
        Protein name -> array returned by ``one_hot_encode_sequence``.
    """
    return {
        name: one_hot_encode_sequence(seq)
        for name, seq in protein_sequences.items()
    }

def protein_to_features(protein_seq: str):
    """Build the per-residue feature matrix of a protein sequence.

    Each residue gets its one-hot identity, a hydrophobicity flag and a
    charge value, concatenated in that order.

    Parameters
    ----------
    protein_seq : str
        Residue sequence in one-letter codes.

    Returns
    -------
    torch.Tensor
        Float tensor with one row per residue; columns are the one-hot
        encoding followed by hydrophobicity (0/1) and charge (-1/0/+1).
    """
    hydrophobic_residues = ('A', 'I', 'L', 'M', 'F', 'W', 'V')
    charge_by_residue = {'K': 1, 'R': 1, 'D': -1, 'E': -1}

    # Residue identity (one-hot).
    composition = torch.tensor(one_hot_encode_sequence(protein_seq), dtype=torch.float)

    # Hydrophobicity: 1 for the residues listed above, 0 otherwise.
    hydrophobic_flags = [1 if residue in hydrophobic_residues else 0 for residue in protein_seq]
    hydrophobicity = torch.tensor(hydrophobic_flags, dtype=torch.float).unsqueeze(1)

    # Charge: +1 for K/R, -1 for D/E, 0 for everything else.
    charges = [charge_by_residue.get(residue, 0) for residue in protein_seq]
    charge = torch.tensor(charges, dtype=torch.float).unsqueeze(1)

    return torch.cat((composition, hydrophobicity, charge), dim=1)

def protein_to_graph(protein_seq, protein_features, neighbor_distance=3):
    """Build a sequence-neighbourhood graph for a protein.

    Every residue is a node; residue ``i`` is linked (in both directions) to
    the residues at positions ``i+1 .. i+neighbor_distance`` that exist.

    Parameters
    ----------
    protein_seq : sequence of residues; only its length is used.
    protein_features : node feature tensor stored as ``x`` on the graph.
    neighbor_distance : maximum distance along the sequence between linked
        residues.

    Returns
    -------
    Data
        PyTorch Geometric graph with ``x`` and an ``edge_index`` of shape
        (2, num_edges).
    """
    seq_len = len(protein_seq)
    pairs = []
    for i in range(seq_len):
        for j in range(i + 1, min(i + neighbor_distance + 1, seq_len)):
            # Store both directions so the graph is undirected.
            pairs.append([i, j])
            pairs.append([j, i])

    # reshape() pins the (num_edges, 2) layout even when there are no pairs
    # (empty or single-residue sequence), so the result is always (2, E);
    # contiguous() materialises the transposed layout in memory.
    edge_index = (
        torch.tensor(pairs, dtype=torch.long)
        .reshape(len(pairs), 2)
        .t()
        .contiguous()
    )

    return Data(x=protein_features, edge_index=edge_index)



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_add_pool, global_mean_pool, global_max_pool

class ProteinLigandGCN(torch.nn.Module):
    """Two-branch GCN scoring a (protein graph, ligand graph) pair.

    Each branch applies two GCN layers (ReLU + dropout after each), mean-pools
    the node embeddings per graph and projects them through a linear layer
    with ReLU. The two graph embeddings are concatenated and mapped to
    ``output_dim`` values by a final linear layer (no activation).
    """

    def __init__(self, protein_feature_dim, ligand_feature_dim, hidden_dim, output_dim,dropout=0.2):
        """Create the layers.

        Parameters
        ----------
        protein_feature_dim : number of input features per protein node.
        ligand_feature_dim : number of input features per ligand node.
        hidden_dim : width of every hidden layer.
        output_dim : number of outputs per protein-ligand pair.
        dropout : dropout probability used after each GCN layer.
        """
        super().__init__()
        # Protein branch convolutions.
        self.protein_conv1 = GCNConv(protein_feature_dim, hidden_dim)
        self.protein_conv2 = GCNConv(hidden_dim, hidden_dim)

        # Ligand branch convolutions.
        self.ligand_conv1 = GCNConv(ligand_feature_dim, hidden_dim)
        self.ligand_conv2 = GCNConv(hidden_dim, hidden_dim)

        # Dropout modules, shared by both branches.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Per-branch projection applied to the pooled graph embedding.
        self.protein_fc = nn.Linear(hidden_dim, hidden_dim)
        self.ligand_fc = nn.Linear(hidden_dim, hidden_dim)

        # Head operating on the concatenated branch embeddings.
        self.final_fc = nn.Linear(2 * hidden_dim, output_dim)

    def _embed_graph(self, graph, node_features, first_conv, second_conv, projection):
        """Run one branch: two GCN layers, mean pooling, then a projection."""
        hidden = self.dropout1(F.relu(first_conv(node_features, graph.edge_index)))
        hidden = self.dropout2(F.relu(second_conv(hidden, graph.edge_index)))
        # Average node embeddings within each graph of the batch.
        pooled = global_mean_pool(hidden, graph.batch)
        return F.relu(projection(pooled))

    def forward(self, protein_data, ligand_data):
        """Score a batch of protein-ligand pairs.

        Parameters
        ----------
        protein_data : tuple ``(graph, node_features)`` for the proteins.
        ligand_data : tuple ``(graph, node_features)`` for the ligands.
            Each graph must expose ``edge_index`` and ``batch``.

        Returns
        -------
        torch.Tensor
            Output of the final linear layer, one row per pair.
        """
        protein_graph, protein_features = protein_data
        ligand_graph, ligand_features = ligand_data

        protein_embedding = self._embed_graph(
            protein_graph, protein_features,
            self.protein_conv1, self.protein_conv2, self.protein_fc,
        )
        ligand_embedding = self._embed_graph(
            ligand_graph, ligand_features,
            self.ligand_conv1, self.ligand_conv2, self.ligand_fc,
        )

        fused = torch.cat((protein_embedding, ligand_embedding), dim=1)
        return self.final_fc(fused)


In [None]:
# Pick the compute backend: CUDA if available, otherwise MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Batch
class ProteinLigandDataset(Dataset):
    """Dataset yielding one (protein graph, ligand graph, label) sample per row.

    Parameters
    ----------
    data : DataFrame with ``protein_name`` and ``molecule_smiles`` columns,
        and optionally a ``binds`` label column.
    protein_sequences : dict mapping protein name to its residue sequence.
    transform : optional callable applied to the sample dict before it is
        returned.
    """

    def __init__(self, data, protein_sequences, transform=None):
        self.data = data
        self.protein_sequences = protein_sequences
        self.transform = transform
        # Protein graphs depend only on the protein name, so each one is
        # built once and reused instead of being rebuilt for every row.
        self._protein_graph_cache = {}

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.data)

    def _get_protein_graph(self, protein_name):
        """Return a private copy of the (cached) graph for ``protein_name``."""
        cached = self._protein_graph_cache.get(protein_name)
        if cached is None:
            sequence = self.protein_sequences[protein_name]
            cached = protein_to_graph(sequence, protein_to_features(sequence))
            self._protein_graph_cache[protein_name] = cached
        # Hand out a clone so a transform or caller mutating the sample
        # cannot corrupt the cached graph.
        return cached.clone()

    def __getitem__(self, idx):
        """Build the sample at positional index ``idx``.

        Returns
        -------
        dict
            ``protein_graph`` and ``smiles_graph`` (graph objects) plus
            ``label``: a float32 scalar tensor, or ``None`` when the frame
            has no ``binds`` column. If a transform was given, its return
            value is returned instead.
        """
        # Fetch the row once instead of once per column.
        row = self.data.iloc[idx]
        protein_name = row["protein_name"]
        molecule_smiles = row["molecule_smiles"]
        # Unlabeled frames simply lack the "binds" column.
        label = row["binds"] if "binds" in self.data.columns else None

        sample = {
            'protein_graph': self._get_protein_graph(protein_name),
            'smiles_graph': smiles_to_dgl_graph(molecule_smiles),
            'label': torch.tensor(label, dtype=torch.float32) if label is not None else None
        }

        if self.transform:
            sample = self.transform(sample)

        return sample


def custom_collate_fn(batch):
    """Collate a list of dataset samples into one batched sample.

    Parameters
    ----------
    batch : list of dicts with ``protein_graph``, ``smiles_graph`` and
        ``label`` entries.

    Returns
    -------
    dict
        ``protein_graph`` and ``smiles_graph``: the per-sample graphs merged
        with ``Batch.from_data_list``.
        ``label``: float32 tensor with one value per sample, or ``None`` when
        any sample carries no label (unlabeled data).
    """
    protein_graphs = [item['protein_graph'] for item in batch]
    smiles_graphs = [item['smiles_graph'] for item in batch]

    # Decide on None *before* building the tensor: torch.tensor() cannot
    # convert None, so unlabeled batches used to crash here.
    raw_labels = [item['label'] for item in batch]
    if any(lbl is None for lbl in raw_labels):
        labels = None
    else:
        labels = torch.tensor([float(lbl) for lbl in raw_labels], dtype=torch.float32)

    batched_protein_graph = Batch.from_data_list(protein_graphs)
    batched_smiles_graph = Batch.from_data_list(smiles_graphs)
    return {
        'protein_graph': batched_protein_graph,
        'smiles_graph': batched_smiles_graph,
        'label': labels
    }


In [None]:
# Wrap the full train/test frames in datasets and data loaders.
batch_size = 128

train_dataset = ProteinLigandDataset(train_combined, protein_sequences)
val_dataset = ProteinLigandDataset(test_combined, protein_sequences)

# Both loaders use the custom collate function to batch the graphs.
# Training batches are shuffled; validation runs one sample at a time, in order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=custom_collate_fn,
)


In [None]:
def _sample_binders_and_non_binders(pos_frame, neg_frame, protein, n):
    """Draw ``n`` binders and ``n`` non-binders for one protein.

    ``pos_frame`` / ``neg_frame`` hold the rows with binds==1 / binds==0.
    The binder sample is drawn first, then the non-binder sample.
    Returns ``(binder_sample, non_binder_sample)``.
    """
    binders = pos_frame[pos_frame['protein_name'] == protein].sample(n=n)
    non_binders = neg_frame[neg_frame['protein_name'] == protein].sample(n=n)
    return binders, non_binders


# --- Training subset: 100,000 binders and 100,000 non-binders per protein ---
train_combined_pos = train_combined[train_combined['binds'] == 1]
train_combined_neg = train_combined[train_combined['binds'] == 0]

train_brd4_pos, train_brd4_neg = _sample_binders_and_non_binders(
    train_combined_pos, train_combined_neg, 'BRD4', 100000)
train_hsa_pos, train_hsa_neg = _sample_binders_and_non_binders(
    train_combined_pos, train_combined_neg, 'HSA', 100000)
train_seh_pos, train_seh_neg = _sample_binders_and_non_binders(
    train_combined_pos, train_combined_neg, 'sEH', 100000)

train_combined_sampled = pd.concat([
    train_brd4_pos, train_brd4_neg,
    train_hsa_pos, train_hsa_neg,
    train_seh_pos, train_seh_neg,
])

# --- Test subset: 10,000 binders and 10,000 non-binders per protein ---
test_combined_pos = test_combined[test_combined['binds'] == 1]
test_combined_neg = test_combined[test_combined['binds'] == 0]

test_brd4_pos, test_brd4_neg = _sample_binders_and_non_binders(
    test_combined_pos, test_combined_neg, 'BRD4', 10000)
test_hsa_pos, test_hsa_neg = _sample_binders_and_non_binders(
    test_combined_pos, test_combined_neg, 'HSA', 10000)
test_seh_pos, test_seh_neg = _sample_binders_and_non_binders(
    test_combined_pos, test_combined_neg, 'sEH', 10000)

test_combined_sampled = pd.concat([
    test_brd4_pos, test_brd4_neg,
    test_hsa_pos, test_hsa_neg,
    test_seh_pos, test_seh_neg,
])

# Datasets over the balanced subsets.
sampled_train_dataset = ProteinLigandDataset(train_combined_sampled, protein_sequences)
sampled_val_dataset = ProteinLigandDataset(test_combined_sampled, protein_sequences)

# Loaders over the balanced subsets, using the custom collate function.
sampled_train_loader = DataLoader(
    sampled_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate_fn,
)
sampled_val_loader = DataLoader(
    sampled_val_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=custom_collate_fn,
)

In [None]:

# Training run: resume the GCN from a checkpoint, train on `train_loader`,
# validate on `val_loader` (both created in an earlier cell) and log to W&B.
global_step = 0
batch_size = 128
learning_rate = 0.1
# Single source of truth for the epoch budget: reported to W&B *and* used by
# the training loop (the loop previously ran one epoch fewer than reported).
num_epochs = 100
wandb.init(project="BELKA_NeruIPS", entity="mayarahmed",config={"learning_rate":learning_rate,"architecture":"GCN","dataset":"BELKA","epochs":num_epochs,"batch_size":batch_size})
logger.info("Initialized WandB run")
logger.info(f'batch size: {batch_size}')
protein_node_feats = 22  # width of the tensor built by protein_to_features
smiles_node_feats = 6    # atom descriptors per node in smiles_to_dgl_graph
edge_feats = 4           # bond descriptors per edge; not consumed by the model
hidden_dim = 128
output_dim = 1


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using {device} device")
# Initialize the model
model = ProteinLigandGCN(protein_node_feats, smiles_node_feats, hidden_dim, output_dim).to(device)

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
checkpoint = torch.load('best_model_30.pt', map_location=device)

wandb.watch(model, log_freq=2)
logger.info(checkpoint.keys())
# Checkpoints are written below under the key 'model_state_dict'.
model.load_state_dict(checkpoint['model_state_dict'])
logger.info("Loaded model state from checkpoint")

# Weight positive samples by the negative/positive ratio of the training set
# to counter the class imbalance.
num_negative = len(train_combined[train_combined['binds']==0])
num_positive = len(train_combined[train_combined['binds']==1])
penalty = torch.tensor([num_negative / num_positive], dtype=torch.float32).to(device)

# Define the loss function, optimizer and learning-rate scheduler
criterion = nn.BCEWithLogitsLoss(pos_weight=penalty)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.25, patience=5)

if 'optimizer_state_dict' in checkpoint:
    # NOTE(review): restoring the optimizer state also restores its saved
    # hyper-parameters (e.g. lr), which supersede `learning_rate` above -
    # confirm this is intended.
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    logger.info("Loaded optimizer state from checkpoint")

# Train the model
best_loss = float('inf')
patience = 5  # consecutive "converged" epochs required before stopping
patience_counter = 0
min_delta = 0.001  # max |train loss - val loss| for an epoch to count as converged

for epoch in range(1, num_epochs + 1):
    model.train()
    train_loss = 0
    for batch_ix, batch in enumerate(train_loader):
        protein_graph = batch['protein_graph'].to(device)
        protein_features = protein_graph.x.to(device)
        smiles_graph = batch['smiles_graph'].to(device)
        smiles_features = smiles_graph.x.to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        # Forward pass
        output = model((protein_graph, protein_features), (smiles_graph, smiles_features))
        # The model emits shape (batch, 1); drop the last axis to match labels.
        loss = criterion(output.squeeze(1), labels)
        # Backward pass
        loss.backward()
        optimizer.step()

        # Log the loss and accumulate train loss
        wandb.log({"epoch": epoch, "train_loss": loss.item(), "trainer/global_step": global_step,"batch":batch_ix,"learning_rate":optimizer.param_groups[0]['lr']})
        train_loss += loss.item()
        global_step += 1

    # Calculate average train loss
    train_loss /= len(train_loader)

    # Evaluate on validation set
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            protein_graph = batch['protein_graph'].to(device)
            protein_features = protein_graph.x.to(device)
            smiles_graph = batch['smiles_graph'].to(device)
            smiles_features = smiles_graph.x.to(device)
            labels = batch['label'].to(device)
            output = model((protein_graph, protein_features), (smiles_graph, smiles_features))
            val_loss_tmp = criterion(output.squeeze(1), labels).item()
            val_loss += val_loss_tmp
            wandb.log({"val_loss":val_loss_tmp, "trainer/global_step": global_step,'epoch':epoch})

        val_loss /= len(val_loader)
        # ReduceLROnPlateau lowers the learning rate when val_loss stalls.
        scheduler.step(val_loss)

    # Log the loss to WandB
    wandb.log({"average_train_loss": train_loss, "average_val_loss": val_loss, "epoch": epoch})

    # Save the model with the lowest loss
    if val_loss < best_loss:
        best_loss = val_loss
        # Save the model and optimizer state
        torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'average_train_loss': train_loss,
        'average_val_loss': val_loss,
        }, f'best_model_{str(epoch)}.pt')

        wandb.save(f"best_model_{str(epoch)}.pt")

    # Print the loss
    logger.info(f'Epoch {epoch}, Loss: {train_loss}, Val Loss: {val_loss}')

    # Convergence check, done last so the final epoch is still logged and
    # checkpointed before the loop exits: stop once train and validation loss
    # agree within min_delta (and val_loss < 0.5) for `patience` epochs in a row.
    if abs(train_loss - val_loss) < min_delta and val_loss < 0.5:
        patience_counter += 1
        if patience_counter >= patience:
            logger.info("Model is converging. Stopping training.")
            break
    else:
        patience_counter = 0

# Finish the WandB run
wandb.finish()
