In [70]:
import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Dataset
import numpy as np 
import os
from tqdm import tqdm
import deepchem as dc
from rdkit import Chem 


In [74]:
class MoleculeDataset(Dataset):
    """PyG ``Dataset`` that stores one featurized molecule graph per file.

    The raw CSV in ``raw_dir`` must provide a ``smiles`` column and an
    ``HIV_active`` label column. Each row is featurized with DeepChem's
    ``MolGraphConvFeaturizer`` and saved to ``processed_dir`` as
    ``data_<i>.pt`` (or ``data_test_<i>.pt`` when ``test=True``), where ``i`` is
    the row's position in the CSV.
    """

    def __init__(self, root, filename, test=False, transform=None, pre_transform=None):
        """
        Args:
            root: Where the dataset should be stored. This folder is split
                into raw_dir (downloaded dataset) and processed_dir (processed data).
            filename: Name of the raw CSV file inside raw_dir.
            test: If True, processed files use the ``data_test_`` prefix
                instead of ``data_``.
            transform, pre_transform: Forwarded unchanged to the PyG ``Dataset``.
        """
        # Must be set before super().__init__(), which may trigger processing and
        # therefore reads raw_file_names / processed_file_names.
        self.test = test
        self.filename = filename
        self.data = None  # raw CSV as a DataFrame; filled lazily
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """If this file exists in raw_dir, the download is not triggered.
        (The download func. is not implemented here.)
        """
        return self.filename

    @property
    def processed_file_names(self):
        """If all of these files are found in processed_dir, processing is skipped.

        Reads the raw CSV to know how many files to expect, and keeps the frame
        on ``self.data`` so ``len()`` can reuse it.
        """
        self.data = self._read_raw()
        return [self._file_name(i) for i in self.data.index]

    def download(self):
        # Nothing to download: the raw CSV is expected to already be in raw_dir.
        pass

    def process(self):
        """Featurize every CSV row and save it as ``processed_dir/<file_name>``.

        Raises:
            ValueError: if RDKit cannot parse a row's SMILES string.
        """
        self.data = self._read_raw()
        featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)
        for index, row in tqdm(self.data.iterrows(), total=self.data.shape[0]):
            smiles = row["smiles"]
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # RDKit signals a parse failure by returning None; report the
                # offending row instead of letting the featurizer crash on None.
                raise ValueError(
                    f"Row {index}: RDKit could not parse SMILES {smiles!r}")
            # NOTE: _featurize is a private DeepChem method; it returns a single
            # graph object that can be converted directly to a PyG Data object.
            graph = featurizer._featurize(mol).to_pyg_graph()
            graph.y = self._get_label(row["HIV_active"])
            graph.smiles = smiles
            torch.save(graph, os.path.join(self.processed_dir, self._file_name(index)))

    def _get_label(self, label):
        """Wrap a scalar label in a 1-element int64 tensor."""
        label = np.asarray([label])
        return torch.tensor(label, dtype=torch.int64)

    def len(self):
        """Number of molecules, i.e. rows in the raw CSV."""
        if self.data is None:
            self.data = self._read_raw()
        return self.data.shape[0]

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        path = os.path.join(self.processed_dir, self._file_name(idx))
        return self._torch_load(path)

    def _read_raw(self):
        """Load the raw CSV with a default 0..n-1 index (row position == file index)."""
        return pd.read_csv(self.raw_paths[0]).reset_index()

    def _file_name(self, idx):
        """Processed file name for row ``idx``; test files carry a ``_test`` tag."""
        prefix = "data_test" if self.test else "data"
        return f"{prefix}_{idx}.pt"

    @staticmethod
    def _torch_load(path):
        """``torch.load`` a pickled PyG ``Data`` object across torch versions.

        Newer torch releases default ``weights_only=True``, which refuses to
        unpickle arbitrary objects such as ``Data``; older releases have no
        ``weights_only`` argument at all. Full unpickling is acceptable here only
        because these files are written by ``process`` — do not point this at
        untrusted files.
        """
        import inspect
        if "weights_only" in inspect.signature(torch.load).parameters:
            return torch.load(path, weights_only=False)
        return torch.load(path)

In [75]:
# Build the test-split graphs; processing is skipped if all processed files
# for this CSV already exist under data/processed.
TEST_CSV = "HIV_test.csv"
dataset = MoleculeDataset(
    root="data/",
    filename=TEST_CSV,
    test=True,
)

Processing...


      level_0  index                                             smiles  \
0           0      0  CCC1=[O+][Cu-3]2([O+]=C(CC)C1)[O+]=C(CC)CC(CC)...   
1           1      1  C(=Cc1ccccc1)C1=[O+][Cu-3]2([O+]=C(C=Cc3ccccc3...   
2           2      2                   CC(=O)N1c2ccccc2Sc2c1ccc1ccccc21   
3           3      3    Nc1ccc(C=Cc2ccc(N)cc2S(=O)(=O)O)c(S(=O)(=O)O)c1   
4           4      4                             O=S(=O)(O)CCS(=O)(=O)O   
...       ...    ...                                                ...   
3994     3994   3994  COc1cc(C[N+]23CC[N+](Cc4cc(OC)cc(OC)c4)(CC2)C3...   
3995     3995   3995   CCC1(O)C(=O)OCc2c1cc1n(c2=O)Cc2cc3cc(O)ccc3nc2-1   
3996     3996   3996  O=C(Oc1ccc(N=Cc2ccc3c(c2)OCO3)cc1)Oc1ccc(N=Cc2...   
3997     3997   3997     [N-]=[N+]=NC(=O)c1ccc(S(=O)(=O)c2ccc(F)cc2)cc1   
3998     3998   3998  C=CC(C)(C)c1c(OC(C)=O)cc2oc3c(OC(C)=O)c(OC(C)=...   

     activity  HIV_active  
0          CI           0  
1          CI           0  
2          CI  

100%|██████████| 3999/3999 [00:36<00:00, 110.24it/s]
Done!


In [76]:
# Display the dataset's repr (class name and number of graphs).
dataset

MoleculeDataset2(3999)