In [52]:
import pandas as pd
import torch
from lightning import pytorch as pl
from chemprop import data, models, nn
import json
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import IterableDataset
import numpy as np
import torch
from chemprop import data
import rdkit
from rdkit import Chem
from chemprop import data, featurizers
from chemprop.featurizers.molecule import MorganBinaryFeaturizer
import math
import math
from torch.utils.data import IterableDataset
from chemprop.data.collate import collate_batch
from sklearn.preprocessing import StandardScaler

In [53]:
class Data_Preprocessor:
    '''A class to prepare Chemprop dataset from Pandas dataframe.'''
    
    def __init__(self):
        pass

    
    def is_hbd(self,atom):
        '''Check if an atom is a Hydrogen Bond Donor (HBD). An atom is considered an HBD if it's N or O with at least one hydrogen.
        
        Parameters:
        ----------
        atom: RDKit atom object.
            
        Returns:
        ----------
        bool: True if atom is HBD, False otherwise.
        '''
        
        if atom.GetAtomicNum() not in [7, 8]:  # 7 for N, 8 for O
            return False
        
        n_hydrogens = atom.GetTotalNumHs()
        return n_hydrogens > 0

    
    def is_hba(self,atom):
        '''Check if an atom is a Hydrogen Bond Acceptor (HBA). An atom is considered an HBA if it's N or O with a lone pair electron
        
        Parameters:
        ----------
        atom: RDKit atom object.
            
        Returns:
        ----------
        bool: True if atom is HBD, False otherwise.
        '''
        
        atomic_num = atom.GetAtomicNum()
        if atomic_num not in [7, 8]: 
            return False
        
        valence = atom.GetTotalValence()
        if atomic_num == 7:  
            return valence <= 3 
        else: 
            return valence <= 2  
        

    def get_mol_HBD_HBA(self,mols):
        '''A function to generate HBD_HBA properties for molecules
        
        Parameters:
        ---------
        mols (list): list of RDKit mol objects.
        
        Returns:
        ----------
        mol_HBs (list): list of array that contain HBD-HBA descriptor for molecules, shape of each array is (n_atom, 2)  '''
        mol_HBs = []
        for mol in mols:
            mol_HB = [[],[]]
            for atom in mol.GetAtoms():
                if self.is_hbd(atom):
                    mol_HB[0].append(1)
                else:
                    mol_HB[0].append(0)
                    
                if self.is_hba(atom):
                    mol_HB[1].append(1)
                else:
                    mol_HB[1].append(0)
            mol_HB = np.array(mol_HB).T
            mol_HBs.append(mol_HB)
        return mol_HBs

    

    def dataset_generator(self):
        '''Prepare chemprop dataset without additional HBD/HBA feature.
    
        Returns:
        ----------
        dataset (Chemprop dataset): Chemprop dataset.
        '''
        
        morgan_fp = MorganBinaryFeaturizer()
        def datapoint_generator(df,smiles,y,addH,HB,morgan):
            smis = df.loc[:,smiles].values
            ys = df.loc[:,[y]].values
            mols = [Chem.MolFromSmiles(smi) for smi in smis]

            if HB:
                mol_HBs = self.get_mol_HBD_HBA(mols)
            else:
                mol_HBs = [None]*len(smis)

            if morgan:
                x_ds = [morgan_fp(mol) for mol in mols]
            else:
                x_ds = [None]*len(smis)
            
            datapoints = [data.MoleculeDatapoint.from_smi(smi,y,add_h=addH, V_f = mol_HB, x_d = x_d) for smi, y, mol_HB, x_d in zip(smis,ys,mol_HBs,x_ds)]
            return datapoints

        datapoints = datapoint_generator(df=self.df,smiles=self.smiles_column,y=self.target_column,addH=self.addH,HB=self.HB,morgan=self.morgan)
        dataset = data.MoleculeDataset(datapoints, featurizer=self.featurizer)
        return dataset
    

    

    def generate(self, df, smiles_column = 'smiles', target_column='docking_score', addH=False, HB = False, morgan = False,
                 featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()):
        '''Generate chemprop dataset according to a given configuration

        Parameters:
        ----------
        df (Pandas DataFrame): a data frame that contains SMILES code of compounds.
        smiles_column (str): a string that indicates SMILES column in the data frame.
        target_column (str): a string that indicates the target column (i.e. docking_scores, solubility) in the data frame.
        addH (boolean): to incorporate explicit hydrogen atoms into a molecular graph.
        HB (boolean): to incorporate additional HBD/HBA features for each atom in BatchMolGraph.
        morgan (boolean): to incorporate morgan binaray fingerprint for each molecules
        featurizer (Chemprop Featurizer): a Featurizer from Chemprop to encode features for atoms, bonds, and molecules.
    
        Returns:
        ----------
        dataset (Chemprop dataset): Chemprop dataset
        '''
                     
        self.df = df
        self.smiles_column = smiles_column
        self.target_column = target_column
        self.addH = addH
        self.HB = HB
        self.featurizer = featurizer
        self.morgan = morgan
        

        return self.dataset_generator()
    

In [58]:
class StreamingMolDataset(IterableDataset):
    def __init__(self, df, smiles_column, target_column, data_generator, scaler, batch_size=64, shuffle=True):
        self.df = df
        self.smiles_column = smiles_column
        self.target_column = target_column
        self.data_generator = data_generator
        self.batch_size = batch_size
        self.shuffle= shuffle
        self.scaler = scaler


    def __iter__(self):
        # Shuffle self.df at the start of each epoch
        if self.shuffle:
            df_shuffled = self.df.sample(frac=1).reset_index(drop=True)
        else:
            df_shuffled = self.df.copy()

        # Process data in batches, yielding each batch of processed data
        for i in range(0, len(self.df), self.batch_size):
            df_chunk = self.df.iloc[i:i + self.batch_size]
            # Generate processed data using the data generator
            df_process = self.data_generator.generate(
                df=df_chunk,
                smiles_column=self.smiles_column,
                target_column=self.target_column, addH =False, HB = False, morgan = False
            )
            df_process.normalize_targets(self.scaler)
        
        # Yield all the samples in the current batch
            for mol in df_process:  # Debug what is being yielded
                yield mol
            

In [55]:
data_path = '../../../DRD2_diverse_data.csv'
smiles_column = 'smiles'
target_column = 'docking_score'
epochs = 50
batch_size = 64
model_config = '../../../hyperparam_optim_5/model.json'

with open(model_config, 'r') as file:
    model_config = json.load(file)
    model_train_config = model_config['train_loop_config']

# Prepare data
df = pd.read_csv(data_path)
df_train = df[df['split_random_1']!='test']
df_test = df[df['split_random_1']=='test']
num_compounds = df_train.shape[0]

# Establish model
mp = nn.BondMessagePassing(d_h = model_train_config['message_hidden_dim'],
                           dropout=model_train_config['dropout'],
                           depth=model_train_config['depth'])

agg = nn.SumAggregation()

ffn = nn.RegressionFFN(n_layers=model_train_config['ffn_num_layers'],
                       dropout=model_train_config['dropout'],
                       input_dim=model_train_config['message_hidden_dim'],
                       hidden_dim=model_train_config['ffn_hidden_dim'])
metric_list = [nn.metrics.RMSE(), nn.metrics.MAE(), nn.metrics.R2Score()]

mpnn = models.MPNN(message_passing=mp, 
                   agg = agg, 
                   predictor=ffn, 
                   batch_norm=False, 
                   metrics=metric_list,
                   warmup_epochs=model_train_config['warmup_epochs'],
                   init_lr=model_train_config['init_lr_ratio'],
                   max_lr=model_train_config['max_lr'],
                   final_lr=model_train_config['final_lr_ratio'])

#mpnn = models.MPNN.load_from_checkpoint('../../../hyperparam_optim_5/best_checkpoint.ckpt')
scaler = StandardScaler().fit(df_train[[target_column]])
data_generator = Data_Preprocessor()


train_streaming_dataset = StreamingMolDataset(
    df=df_train,
    smiles_column=smiles_column,
    target_column=target_column,
    data_generator=data_generator, 
    batch_size=batch_size, scaler=scaler
)

train_loader = torch.utils.data.DataLoader(
    train_streaming_dataset,
    batch_size=batch_size,
    collate_fn=collate_batch)

test_streaming_dataset = StreamingMolDataset(
    df=df_test,
    smiles_column=smiles_column,
    target_column=target_column,
    data_generator=data_generator,
    batch_size=batch_size, scaler=scaler
)

test_loader = torch.utils.data.DataLoader(
    test_streaming_dataset,
    batch_size=batch_size,
    collate_fn=collate_batch)



  df = pd.read_csv(data_path)


In [57]:
import warnings
warnings.filterwarnings("ignore", message="X does not have valid feature names.*")


checkpointing = ModelCheckpoint(
    "checkpoints",  # Directory where model checkpoints will be saved
    "best-{epoch}-{val_loss:.2f}",  # Filename format for checkpoints, including epoch and validation loss
    "val_loss",  # Metric used to select the best checkpoint (based on validation loss)
    mode="min",  # Save the checkpoint with the lowest validation loss (minimization objective)
    save_last=True,  # Always save the most recent checkpoint, even if it's not the best
)


trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=epochs,
    callbacks=[checkpointing]
)

trainer.fit(mpnn, train_dataloaders=train_loader, val_dataloaders=test_loader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/course/.conda/envs/long_env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | message_passing | BondMessagePassing | 11.9 M | train
1 | agg             | SumAggregation     | 0      | train
2 | bn              | Identity           | 0      | train
3 | predictor       | RegressionFFN      | 5.5 M  | train
4 | X_d_transform   | Identity           | 0      | train
5 | metrics         | ModuleList     

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/course/.conda/envs/long_env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [56]:
df_train

Unnamed: 0,id,smiles,docking_score,split_random_1,split_random_2,split_random_3,split_random_4,split_random_5,weight_lowscores
0,15945052,CC(=O)OCC1(CC23CCC4C(C2CCC1C3)(CCCC4(C)C(=O)O)...,-0.023991,val,train,train,train,train,0.526316
1,42628272,CCOC(=O)C1C(=CCC(N1S(=O)(=O)C2=CC=C(C=C2)C)C3=...,-0.024816,train,train,train,val,train,0.526316
2,16187407,CCN(CC)CCC(C)OC(=O)C1=C(C=C(C=C1)O)O.Cl,-0.031822,train,train,val,train,train,0.526316
3,3580396,CC1CC2=CC=CC=C2N1C(=O)C3=CC=CC=C3NS(=O)(=O)C4=...,-0.043640,train,train,train,train,val,0.526316
4,9614845,CN(C)CCNC(=O)C=NO,-0.076320,train,train,train,val,train,0.526316
...,...,...,...,...,...,...,...,...,...
995754,CP002847469043,CCCCCN(CCCCC)CCOCCOCC,0.192226,train,train,val,train,train,0.526316
995755,CP003439278953,CC(C)(CN(C)C)C(=O)N(C)CCN(C)C(=O)C1(CCC1)N(C)C,0.436100,train,train,train,train,val,0.526316
995756,CP002164348858,Cc1oc(c(c1)C(=O)N[C@H]2C[C@]3(C2)N(CCCC3)C(=O)...,0.596025,val,train,train,train,train,0.526316
995757,CP003228611624,CCC(CC)(CO)C(=O)N1CCN(C2(C1)CCC2)C(=O)C(=C(C)C)C,2.564700,train,val,train,train,train,0.526316
