In [1]:
import torch,os
from torch.utils.data import TensorDataset,random_split
import pandas as pd
from rdkit import Chem,DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig
import os, random
import numpy as np
from dgllife.data import MoleculeCSVDataset
from functools import partial
from dgllife.utils import smiles_to_bigraph, RandomSplitter# ConsecutiveSplitter,

Using backend: pytorch


In [14]:
# Obtain the features of atoms and bonds
# In-memory cache: (csv_path, coeff_path) -> (canonical SMILES -> row index, coefficients).
# load_coeff is meant to be called once per molecule while featurizing, so without
# this cache every call re-read the CSV, re-canonicalized every SMILES and re-ran
# torch.load. NOTE: entries are never invalidated; clear _COEFF_CACHE if the
# files change on disk during the session.
_COEFF_CACHE = {}


def load_coeff(mol, csv_path='train_set_5631.csv', coeff_path='orb_coeff.t'):
    """Return the stored orbital coefficients for ``mol``.

    ``mol`` is matched by canonical SMILES against the ``smiles`` column of
    ``csv_path``. The row index of the first matching row selects the entry of
    the object loaded from ``coeff_path`` (the two files are assumed to be
    row-aligned -- TODO confirm against how ``orb_coeff.t`` was produced).

    Args:
        mol: RDKit molecule to look up.
        csv_path: CSV file with a ``smiles`` column.
        coeff_path: file written with ``torch.save``, indexable by row number.

    Returns:
        ``coeff[idx]`` for the matching row.

    Raises:
        ValueError: if ``mol`` does not appear in ``csv_path``.
    """
    key = (csv_path, coeff_path)
    if key not in _COEFF_CACHE:
        index_of = {}
        for row, sm in enumerate(pd.read_csv(csv_path)['smiles'].tolist()):
            parsed = Chem.MolFromSmiles(sm)
            if parsed is None:
                # An unparseable row can never match a query molecule; skip it
                # instead of crashing inside MolToSmiles(None).
                continue
            # setdefault keeps the FIRST occurrence, matching list.index().
            index_of.setdefault(Chem.MolToSmiles(parsed, canonical=True), row)
        # NOTE(security): torch.load unpickles the file -- only load trusted files.
        coeff = torch.load(coeff_path)
        _COEFF_CACHE[key] = (index_of, coeff)

    index_of, coeff = _COEFF_CACHE[key]
    query = Chem.MolToSmiles(mol, canonical=True)
    try:
        idx = index_of[query]
    except KeyError:
        # Keep the exception type the original list.index() lookup raised.
        raise ValueError('%r not found in %s' % (query, csv_path)) from None
    return coeff[idx]


def featurize_atoms(mol):
    """Build the node-feature tensor for an RDKit molecule.

    One row per atom, in atom-index order. Columns, in order:
    atomic number, explicit valence, implicit valence, total H count, degree,
    aromatic flag (0/1), then a one-hot over SP / SP2 / SP3 hybridization
    (all zeros for any other hybridization).

    Extra features tried earlier (formal charge, radical electrons, ring
    membership, per-atom orbital coefficients from ``load_coeff``) are not
    included.

    Returns:
        dict mapping ``'h'`` to a float tensor with one row per atom.
    """
    hybridizations = (Chem.rdchem.HybridizationType.SP,
                      Chem.rdchem.HybridizationType.SP2,
                      Chem.rdchem.HybridizationType.SP3)

    def _atom_row(atom):
        hyb = atom.GetHybridization()
        one_hot = [int(hyb == kind) for kind in hybridizations]
        scalars = [
            atom.GetAtomicNum(),
            atom.GetExplicitValence(),
            atom.GetImplicitValence(),
            atom.GetTotalNumHs(),
            atom.GetDegree(),
            int(atom.GetIsAromatic()),
        ]
        return scalars + one_hot

    rows = [_atom_row(atom) for atom in mol.GetAtoms()]
    return {'h': torch.tensor(rows).float()}


def featurize_edges(mol, add_self_loop=True):
    """Build bond features ordered to match the edges of ``smiles_to_bigraph``.

    dgllife's bigraph constructor adds, for each bond in bond-index order, the
    edge (begin, end) followed by (end, begin). When ``add_self_loop`` is true
    it then adds one self-loop per atom, in atom-index order. Rows are emitted
    in exactly that order:
      * each bond's one-hot type (SINGLE, DOUBLE, TRIPLE, AROMATIC), twice;
      * then one all-zero row per atom for the self-loops.

    (The previous version walked the full atom-pair matrix and interleaved the
    self-loop rows, so features were attached to the wrong edges. It also
    ignored ``add_self_loop``.)

    Args:
        mol: RDKit molecule.
        add_self_loop: must match the ``add_self_loop`` given to the graph
            constructor; when true, zero rows are appended for self-loops.

    Returns:
        dict mapping ``'e'`` to a float tensor of shape (num_edges, 4).
    """
    bond_types = (Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                  Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC)
    feats = []
    for bond in mol.GetBonds():
        bt = bond.GetBondType()
        one_hot = [int(bt == x) for x in bond_types]
        # One row per direction: (u, v) then (v, u).
        feats.append(one_hot)
        feats.append(list(one_hot))
    if add_self_loop:
        feats.extend([0, 0, 0, 0] for _ in range(mol.GetNumAtoms()))
    # reshape keeps a (0, 4) shape when the molecule yields no edges at all.
    return {'e': torch.tensor(feats).float().reshape(-1, len(bond_types))}



if __name__ == "__main__":
    # Seed every RNG in play so the random split is reproducible.
    SEED = 1024
    torch.manual_seed(SEED)
    random.seed(SEED)
    np.random.seed(SEED)

    df = pd.read_csv('est/5000est.csv')

    # SMILES -> DGL graphs with DGL-Life. Only node features are used here;
    # passing edge_featurizer=partial(featurize_edges, add_self_loop=True)
    # would also attach bond features.
    dataset = MoleculeCSVDataset(df=df,
                                 smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True),
                                 node_featurizer=featurize_atoms,
                                 edge_featurizer=None,
                                 smiles_column='smiles',
                                 cache_file_path='graph.pt',
                                 log_every=200)

    # 80/10/10 random split; the three subsets are saved together.
    train_set, val_set, test_set = RandomSplitter.train_val_test_split(
        dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1, random_state=SEED)
    torch.save([train_set, val_set, test_set], "graph.bin")

Processing dgl graphs from scratch...
Processing molecule 200/5000
Processing molecule 400/5000
Processing molecule 600/5000
Processing molecule 800/5000
Processing molecule 1000/5000
Processing molecule 1200/5000
Processing molecule 1400/5000
Processing molecule 1600/5000
Processing molecule 1800/5000
Processing molecule 2000/5000
Processing molecule 2200/5000
Processing molecule 2400/5000
Processing molecule 2600/5000
Processing molecule 2800/5000
Processing molecule 3000/5000
Processing molecule 3200/5000
Processing molecule 3400/5000
Processing molecule 3600/5000
Processing molecule 3800/5000
Processing molecule 4000/5000
Processing molecule 4200/5000
Processing molecule 4400/5000
Processing molecule 4600/5000
Processing molecule 4800/5000
Processing molecule 5000/5000


In [16]:
# RandomSplitter returns DGL Subset wrappers whose __getitem__ delegates to
# `.dataset`. Setting `load_full` on the wrapper itself has no effect, so set it
# on the underlying MoleculeCSVDataset. That object is shared with val_set and
# test_set, so they change too. Items then come back as
# (smiles, graph, label, mask).
train_set.dataset.load_full = True
train_set[0][1].ndata

{'h': tensor([[6., 3., 1., 1., 2., 1., 0., 1., 0.],
        [7., 2., 1., 1., 2., 0., 0., 1., 0.],
        [6., 3., 1., 1., 2., 1., 0., 1., 0.],
        [6., 3., 1., 1., 2., 1., 0., 1., 0.],
        [6., 4., 0., 0., 3., 1., 0., 1., 0.],
        [6., 3., 1., 1., 2., 1., 0., 1., 0.],
        [7., 3., 0., 0., 2., 1., 0., 1., 0.],
        [6., 4., 0., 0., 3., 1., 0., 1., 0.],
        [6., 3., 1., 1., 2., 1., 0., 1., 0.],
        [6., 4., 0., 0., 3., 1., 0., 1., 0.],
        [6., 3., 1., 1., 2., 1., 0., 1., 0.],
        [6., 3., 1., 1., 2., 1., 0., 1., 0.],
        [6., 4., 0., 0., 3., 1., 0., 1., 0.],
        [6., 3., 1., 1., 2., 1., 0., 1., 0.],
        [6., 4., 0., 0., 3., 1., 0., 1., 0.],
        [6., 3., 1., 1., 2., 1., 0., 1., 0.],
        [6., 4., 0., 0., 3., 1., 0., 1., 0.],
        [6., 4., 0., 0., 3., 1., 0., 1., 0.],
        [6., 4., 0., 0., 3., 1., 0., 1., 0.],
        [6., 3., 1., 1., 2., 1., 0., 1., 0.],
        [6., 3., 1., 1., 2., 1., 0., 1., 0.],
        [5., 3., 0., 0., 3.,