In [1]:
import pickle
import netdataio
import graph_conv_many_nuc_util
import torch
from rdkit import Chem
from rdkit.Chem import Draw
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from rdkit.Chem import PandasTools

In [2]:
message_data = pickle.load(open('../nmr_mpnn-master/datasets/data_13C.pickle', 'rb'))

In [3]:
seok_train_df = message_data['train_df']
seok_test_df = message_data['test_df']

frames = [seok_train_df, seok_test_df]
seok_df = pd.concat(frames)

In [4]:
new_target = [{k: v for k, v in sorted(d[0].items(), key=lambda item: item[0], reverse = False)} for d in seok_df['value']]

seok_df.drop('value', axis = 1, inplace = True)
seok_df['value'] = new_target

In [6]:
mol = seok_df.rdmol.tolist()[0]

In [7]:
def get_nos_coords(mol, conf_i):
    conformer = mol.GetConformers()[conf_i] #get conf_i'th conformation of mol
    coord_objs = [conformer.GetAtomPosition(i) for i in  range(mol.GetNumAtoms())] #return list of positions for each atom in molecule
    coords = np.array([(c.x, c.y, c.z) for c in coord_objs]) #turn positions into 3d coords
    atomic_nos = np.array([a.GetAtomicNum() for a in mol.GetAtoms()]).astype(int) #return atomic number for each atom
    return atomic_nos, coords #return 3cd coordinates and atomic numbers
def to_onehot(x, vals):
    return [x == v for v in vals]

HYBRIDIZATIONS = [Chem.HybridizationType.S, 
                  Chem.HybridizationType.SP, 
                  Chem.HybridizationType.SP2, 
                  Chem.HybridizationType.SP3, 
                  Chem.HybridizationType.SP3D, 
                  Chem.HybridizationType.SP3D2]

In [15]:
def feat_tensor_atom(mol, 
                     feat_atomicno = True, feat_pos=True, 
                     feat_atomicno_onehot=[1, 6, 7, 8, 9], 
                     feat_valence=True, aromatic=True, hybridization=True, 
                     partial_charge=True, formal_charge=True, r_covalent=True,
                     r_vanderwals=True, default_valence=True, rings=False, 
                     total_valence_onehot=False, 
                     conf_idx = 0):

    """
    Featurize a molecule on a per-atom basis
    feat_atomicno_onehot : list of atomic numbers

    Always assume using conf_idx unless otherwise passed

    Returns an (ATOM_N x feature) float32 tensor

    NOTE: Performs NO santization or cleanup of molecule, 
    assumes all molecules have sanitization calculated ahead
    of time. 

    """

    pt = Chem.GetPeriodicTable()
    mol = Chem.Mol(mol) # copy molecule

    atomic_nos, coords = get_nos_coords(mol, conf_idx) #returns tuple of (array of atomic numbers, array of 3d atom coords) of conf_idx'th conformation
    ATOM_N = len(atomic_nos)

    #Chem.SanitizeMol(mol, Chem.rdmolops.SanitizeFlags.SANITIZE_ALL, 
    #                 catchErrors=True)


    if partial_charge:
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)

    atom_features = []
      
    for i in range(mol.GetNumAtoms()): #for every atom in molecule
        a = mol.GetAtomWithIdx(i) #gets atom of index i
        atomic_num = int(atomic_nos[i])
        atom_feature = []

        if feat_atomicno: #yes
            atom_feature += [atomic_num]

        if feat_pos: #no
            atom_feature += coords[i].tolist()

        if feat_atomicno_onehot is not None : #yes
            atom_feature += to_onehot(atomic_num, feat_atomicno_onehot)
            
        if feat_valence: #yes
            atom_feature += [a.GetTotalValence()]
        if total_valence_onehot: #yes
            atom_feature +=  to_onehot(a.GetTotalValence(), range(1, 7))

        if aromatic: #yes
            atom_feature += [a.GetIsAromatic()]

        if hybridization: #yes
            atom_feature += to_onehot(a.GetHybridization(), HYBRIDIZATIONS)

        if partial_charge: #no
            gc = float(a.GetProp('_GasteigerCharge'))
            #assert np.isfinite(gc)
            if not np.isfinite(gc):
                gc = 0.0
            atom_feature += [gc]

        if formal_charge: #yes
            atom_feature += to_onehot(a.GetFormalCharge(), [-1, 0, 1])

        if r_covalent: #no
            atom_feature += [pt.GetRcovalent(atomic_num)] #radius of atom in covalent bond
        if r_vanderwals: #no
            atom_feature += [pt.GetRvdw(atomic_num)]
            

        if default_valence: #yes
            atom_feature += to_onehot(pt.GetDefaultValence(atomic_num), range(1, 7))
        
        if rings: #yes
            atom_feature += [a.IsInRingSize(r) for r in range(3, 8)]


        # electronegativities = {1:2.20, 6:2.55, 7:3.04, 8:3.44, 9:3.98, 15:2.19, 16:2.58, 17:3.16}
        # atom_feature += [electronegativities[atomic_num]]

        atom_features.append(atom_feature)

    z = [0]*len(atom_features[0])
    while len(atom_features) < 64:
        atom_features.append(z)

#     atom_features = np.array(atom_features)
#     m = atom_features.mean(axis = 0)
#     s = atom_features.std(axis = 0)
#     atom_features_normalized = (atom_features-m)/s

    #atom features is a list of lists; inner list represents one atom and contains atom features
    # print(torch.Tensor(atom_features).size())
    return torch.Tensor(atom_features)

In [16]:
default_atomicno = [1, 6, 7, 8, 9, 15, 16, 17] #filtering for H, C, N, O, F, Cl, etd

default_feat_vect_args = dict(feat_atomicno=True, feat_pos=False, feat_atomicno_onehot=default_atomicno, 
                              
                              feat_valence=True, aromatic=True, hybridization=True, 
                              partial_charge=False, formal_charge=True,  # WE SHOULD REALLY USE THIS 
                              r_covalent=False,
                              total_valence_onehot=True, 
                              
                              r_vanderwals=False, default_valence=True, rings=True)

In [18]:
torch.set_printoptions(profile="full")
feat_tensor_atom(mol, conf_idx=0, **default_feat_vect_args)

tensor([[6., 0., 1., 0., 0., 0., 0., 0., 0., 4., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
         0.],
        [6., 0., 1., 0., 0., 0., 0., 0., 0., 4., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
         0.],
        [6., 0., 1., 0., 0., 0., 0., 0., 0., 4., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
         0.],
        [6., 0., 1., 0., 0., 0., 0., 0., 0., 4., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
         0.],
        [6., 0., 1., 0., 0., 0., 0., 0., 0., 4., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
         0.],
        [6., 0., 1., 0., 0., 0., 0., 0., 0., 4., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 1., 0.

In [28]:
atomic_nos, coords = get_nos_coords(mol, 0)
pt = Chem.GetPeriodicTable()
atom_features = []
for i in range(mol.GetNumAtoms()): #for every atom in molecule
    atom_feature = []
    a = mol.GetAtomWithIdx(i) #gets atom of index i
    atomic_num = int(atomic_nos[i])
    atom_feature += [atomic_num]
    atom_feature += to_onehot(atomic_num, [1, 6, 7, 8, 9])
    atom_feature += [a.GetTotalValence()]
    atom_feature += to_onehot(a.GetTotalValence(), range(1, 7))
    atom_feature += [a.GetIsAromatic()]
    atom_feature += to_onehot(a.GetHybridization(), HYBRIDIZATIONS)
    atom_feature += to_onehot(a.GetFormalCharge(), [-1, 0, 1])
    atom_feature += to_onehot(pt.GetDefaultValence(atomic_num), range(1, 7))
    atom_feature += [a.IsInRingSize(r) for r in range(3, 8)]

In [29]:
atom_feature

[6,
 False,
 True,
 False,
 False,
 False,
 4,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 6,
 False,
 True,
 False,
 False,
 False,
 4,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 6,
 False,
 True,
 False,
 False,
 False,
 4,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 6,
 False,
 True,
 False,
 False,
 False,
 4,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,