### In this file, we will reconstruct our data into graphs

Instructions: You should probably have another separate notebook that creates the graph version of the dataset. Again, you should save the data, and for this make use to use the save_graph and load_graph functions of DGL.

Important Libraries:
ase
Structure/Geometry of a molecule

In [85]:
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from ase.io import read, write
from dgllife.utils import smiles_to_bigraph, featurizers
from dgl.data.utils import save_graphs, load_graphs

In [86]:

# Features added from this paper (https://arxiv.org/pdf/1704.01212.pdf) Table 1
'''
According to table 1 of that paper, these are the features they use. We have already discovered the importance of
both acceptor/donator of electrons as well as partial magnetism, thus we will for sure want to use those similar features

Unfortunately, we could not get the acceptor/donor binary features in time, though we looked at the WeaveAtomFeaturizer to accomplish this

Atom type        H, C, N, O, F (one-hot)
Atomic number    Number of protons (integer)
Acceptor         Accepts electrons (binary)
Donor            Donates electrons (binary)
Aromatic         In an aromatic system (binary)
Hybridization    sp, sp2, sp3 (one-hot or null)
Number of Hydrogens (integer)
'''

# To get our ROMol and get numhacceptors/donors working, would use PandasTools.AddMoleculeColumnToFrame(esol_data, smilesCol='smiles')

def featurize_atoms(mol):
    
    def one_hot_to_value(one_hot_li):
        for i in range(len(one_hot_li)):
            if one_hot_li[i]:
                return i
        return -1

    feature_map = lambda atom: [
                                one_hot_to_value(featurizers.atom_type_one_hot(atom)), # One-hot to index-value of atom type
                                atom.GetAtomicNum(),
#                                 Chem.Descriptors.NumHAcceptors(atom),    # H-Bond acceptors
#                                 Chem.Descriptors.NumHDonors(atom),    # H-Bond donors
                                atom.GetIsAromatic(),
                                one_hot_to_value(featurizers.atom_hybridization_one_hot(atom)), # Hybridization mentioned above
                                atom.GetDegree(),
                                atom.GetTotalDegree(),
                                atom.GetExplicitValence(),
                                atom.GetImplicitValence(),
                                atom.GetTotalNumHs(),
                                atom.GetFormalCharge(),
                                atom.GetNumRadicalElectrons(),
                                atom.IsInRing(), 
                                atom.GetMass() * 0.01
                                ]
    
    feats = []
    AllChem.EmbedMolecule(mol)
    Chem.rdmolfiles.MolToXYZFile(mol, "3Dembedded.xyz")
    mol_ase = read("3Dembedded.xyz")
    atom_count = 0
    for atom in mol.GetAtoms():
        feats.append(feature_map(atom))
        feats.append(pos[atom_count][0])
        feats.append(pos[atom_count][1])
        feats.append(pos[atom_count][2])
        atom_count += 1
    return {'atom_feats': torch.tensor(feats).reshape(-1, len(feats[0])).float()}

In [79]:
def featurize_bonds(mol):
    feats = []
    bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                  Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
    for bond in mol.GetBonds():
        btype = bond_types.index(bond.GetBondType())
        is_conjugated = bond.GetIsConjugated()
        is_in_ring = bond.IsInRing()
        stereo_config = bond.GetStereo()
        direction = bond.GetBondDir()
        feats.extend([btype, btype])
        feats.extend([is_conjugated, is_conjugated])
        feats.extend([is_in_ring, is_in_ring])
        feats.extend([stereo_config, stereo_config])
        feats.extend([direction, direction])
    return {'bond_feats': torch.tensor(feats).reshape(-1, 5)}

### Generate Graph objects from our Dataset(s)

In [80]:
from tqdm.notebook import tqdm

In [81]:
def get_graphs(dataset_name):
    df = pd.read_csv(dataset_name)
    graphs =[]
    for smile in tqdm(df["SMILES"]):
        graphs.append(smiles_to_bigraph(smile,
                                       node_featurizer = featurize_atoms,
                                       edge_featurizer = featurize_bonds,
                                       explicit_hydrogens = True)
                     )
    return graphs

In [82]:
graphs = get_graphs("Data/pe_data_combined.csv")

  0%|          | 0/13480 [00:00<?, ?it/s]

In [51]:
# Visualize the nodes, the first column of atom features
graphs[0].nodes(), graphs[0].ndata['atom_feats'][:,0]

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41], dtype=torch.int32),
 tensor([1., 1., 6., 1., 1., 6., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 7., 6., 7., 6., 6., 6., 6., 6., 8., 6., 6., 6., 6., 6.,
         6., 7., 6., 6., 6., 7.]))

In [52]:
graphs[-1]

Graph(num_nodes=24, num_edges=46,
      ndata_schemes={'atom_feats': Scheme(shape=(11,), dtype=torch.float32)}
      edata_schemes={'bond_feats': Scheme(shape=(5,), dtype=torch.int64)})

In [54]:
graphs[0].edges()[0], graphs[0].edata['bond_feats'][:4]

(tensor([24, 31, 31, 30, 31, 32, 32, 29, 29, 22, 22, 23, 23, 37, 37, 27, 27, 28,
         28, 35, 35, 36, 36, 38, 38, 39, 39, 25, 25, 26, 36, 33, 33, 34, 34, 40,
         40, 41, 41,  2,  2,  5, 37, 32, 26, 35,  5, 33, 33, 23, 24,  1, 24,  4,
         29, 14, 27, 13, 27, 15, 28, 12, 28, 10, 38, 11, 39,  6, 25,  3, 26,  0,
         34, 19, 34,  7, 40,  9, 40, 18, 41, 20,  2,  8,  2, 17,  5, 21,  5, 16],
        dtype=torch.int32),
 tensor([[0, 0, 1, 1, 0],
         [0, 0, 0, 0, 0],
         [1, 1, 1, 1, 0],
         [0, 0, 0, 0, 0]]))

In [83]:
save_graphs("./DataGraphs/combined_graph_reduced_new_feats.bin", graphs)

In [57]:
gs = load_graphs("./DataGraphs/pe_data_combined_graph.bin")[0][0]
gs.number_of_nodes

<bound method DGLHeteroGraph.number_of_nodes of Graph(num_nodes=42, num_edges=90,
      ndata_schemes={'atom_feats': Scheme(shape=(11,), dtype=torch.float32)}
      edata_schemes={'bond_feats': Scheme(shape=(5,), dtype=torch.int64)})>