### In this file, we will reconstruct our data into graphs

Instructions: You should probably have another separate notebook that creates the graph version of the dataset. Again, you should save the data; for this, make sure to use the `save_graphs` and `load_graphs` functions of DGL.

Important Libraries:
ase
Structure/Geometry of a molecule

In [1]:
import pandas as pd
import torch
from rdkit import Chem
from dgllife.utils import smiles_to_bigraph
from dgl.data.utils import save_graphs, load_graphs

Using backend: pytorch


In [12]:
def feature_map(atom):
    """Return the 11 raw descriptors of a single atom as a list.

    Order: atomic number, degree, total degree, explicit valence,
    implicit valence, total number of hydrogens, formal charge,
    number of radical electrons, aromaticity flag, ring-membership flag,
    and the atomic mass scaled by 0.01.
    """
    return [
        atom.GetAtomicNum(),
        atom.GetDegree(),
        atom.GetTotalDegree(),
        atom.GetExplicitValence(),
        atom.GetImplicitValence(),
        atom.GetTotalNumHs(),
        atom.GetFormalCharge(),
        atom.GetNumRadicalElectrons(),
        atom.GetIsAromatic(),
        atom.IsInRing(),
        atom.GetMass() * 0.01,
    ]


def featurize_atoms(mol):
    """Build the node feature matrix of a molecule.

    Parameters
    ----------
    mol
        Object exposing ``GetAtoms()``; each atom is passed to ``feature_map``.

    Returns
    -------
    dict
        ``{'atom_feats': tensor}`` with a float tensor of shape
        ``(number of atoms, 11)``, one row per atom.
    """
    rows = [feature_map(atom) for atom in mol.GetAtoms()]
    # Cast first, then force the 11-column layout so an atom-free molecule
    # still produces a (0, 11) tensor.
    return {'atom_feats': torch.tensor(rows).float().reshape(-1, 11)}

In [3]:
def featurize_bonds(mol):
    """Build the edge feature matrix of a molecule.

    Every bond is described by one 5-value row
    ``[bond type index, is conjugated, is in ring, stereo, direction]``.
    The row is emitted twice in succession, so the two consecutive edge
    entries belonging to one bond carry identical features.

    Parameters
    ----------
    mol
        Object exposing ``GetBonds()`` (an RDKit molecule).

    Returns
    -------
    dict
        ``{'bond_feats': tensor}`` with an int64 tensor of shape
        ``(2 * number of bonds, 5)``.

    Raises
    ------
    ValueError
        If a bond type is not SINGLE, DOUBLE, TRIPLE or AROMATIC.
    """
    bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                  Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
    feats = []
    for bond in mol.GetBonds():
        try:
            btype = bond_types.index(bond.GetBondType())
        except ValueError:
            raise ValueError(
                f"Unsupported bond type: {bond.GetBondType()}"
            ) from None
        row = [
            btype,
            int(bond.GetIsConjugated()),
            int(bond.IsInRing()),
            int(bond.GetStereo()),
            int(bond.GetBondDir()),
        ]
        # Append whole rows (not interleaved scalars): duplicating each scalar
        # and reshaping to 5 columns would scramble features across the two
        # rows of a bond.
        feats.extend([row, list(row)])
    # Explicit dtype keeps the result int64 even when the molecule has no bonds.
    return {'bond_feats': torch.tensor(feats, dtype=torch.long).reshape(-1, 5)}

### Generate Graph objects from our Dataset(s)

In [6]:
from tqdm.notebook import tqdm

In [8]:
def get_graphs(dataset_name):
    """Convert every SMILES string of a CSV dataset into a graph.

    Parameters
    ----------
    dataset_name
        Path of a CSV file readable by ``pd.read_csv`` that contains a
        ``"SMILES"`` column.

    Returns
    -------
    list
        One ``smiles_to_bigraph`` result per row, in row order, built with
        ``featurize_atoms`` as node featurizer, ``featurize_bonds`` as edge
        featurizer and explicit hydrogens enabled.
    """
    # NOTE(review): confirm how smiles_to_bigraph reports an unparsable SMILES;
    # whatever it returns is stored in the list unchanged.
    smiles_column = pd.read_csv(dataset_name)["SMILES"]

    def _to_graph(smiles):
        # Single place that fixes the featurizers used for every molecule.
        return smiles_to_bigraph(
            smiles,
            node_featurizer=featurize_atoms,
            edge_featurizer=featurize_bonds,
            explicit_hydrogens=True,
        )

    return [_to_graph(smiles) for smiles in tqdm(smiles_column)]

In [13]:
# Build one graph per SMILES row of the dataset CSV.
graphs = get_graphs("Data/pe_data_F.csv")

  0%|          | 0/1210 [00:00<?, ?it/s]

In [16]:
# Inspect the first graph: its node IDs and the first atom-feature column,
# which holds the atomic number (first entry produced by feature_map).
graphs[0].nodes(), graphs[0].ndata['atom_feats'][:,0]

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41], dtype=torch.int32),
 tensor([1., 1., 6., 1., 1., 6., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 7., 6., 7., 6., 6., 6., 6., 6., 8., 6., 6., 6., 6., 6.,
         6., 7., 6., 6., 6., 7.]))

In [22]:
# Inspect the first graph: the pair of edge endpoint tensors and the edge
# feature dict (key 'bond_feats', one 5-value row per edge).
graphs[0].edges(), graphs[0].edata

((tensor([1, 2, 1, 3, 1, 4, 1, 0], dtype=torch.int32),
  tensor([2, 1, 3, 1, 4, 1, 0, 1], dtype=torch.int32)),
 {'bond_feats': tensor([[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]])})

In [17]:
# Persist the whole list of graphs to disk with DGL's save_graphs.
# NOTE(review): assumes the ./DataGraphs directory already exists — confirm.
save_graphs("./DataGraphs/data_F_graph.bin", graphs)

In [18]:
# Reload the file just written to check the round trip; the displayed result
# is a tuple whose first element is the list of graphs.
load_graphs("./DataGraphs/data_F_graph.bin")

([Graph(num_nodes=42, num_edges=90,
        ndata_schemes={'atom_feats': Scheme(shape=(11,), dtype=torch.float32)}
        edata_schemes={'bond_feats': Scheme(shape=(5,), dtype=torch.int64)}),
  Graph(num_nodes=43, num_edges=92,
        ndata_schemes={'atom_feats': Scheme(shape=(11,), dtype=torch.float32)}
        edata_schemes={'bond_feats': Scheme(shape=(5,), dtype=torch.int64)}),
  Graph(num_nodes=49, num_edges=110,
        ndata_schemes={'atom_feats': Scheme(shape=(11,), dtype=torch.float32)}
        edata_schemes={'bond_feats': Scheme(shape=(5,), dtype=torch.int64)}),
  Graph(num_nodes=26, num_edges=56,
        ndata_schemes={'atom_feats': Scheme(shape=(11,), dtype=torch.float32)}
        edata_schemes={'bond_feats': Scheme(shape=(5,), dtype=torch.int64)}),
  Graph(num_nodes=22, num_edges=44,
        ndata_schemes={'atom_feats': Scheme(shape=(11,), dtype=torch.float32)}
        edata_schemes={'bond_feats': Scheme(shape=(5,), dtype=torch.int64)}),
  Graph(num_nodes=23, num_edges=50