In [1]:
import logging
from rdkit import Chem
from stellargraph import StellarGraph

import networkx as nx
import numpy as np
import six
import pandas as pd


In [2]:
def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception(
                "input {0} not in allowable set{1}:".format(x, allowable_set))
    return list(map(lambda s: x == s, allowable_set))

def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))

def get_atom_features(atom):
    return  1*(np.array(one_of_k_encoding_unk(atom.GetSymbol(),
                                            ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl',
                                            'Br', 'Mg', 'Na',
                                            'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K',
                                            'Tl', 'Yb',
                                            'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti',
                                            'Zn', 'H',  
                                            'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In',
                                            'Mn','Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown']) +
                        
                        one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5]) +
                        
                        one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) +
                        
                        one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5]) + 
                        
                        [atom.GetIsAromatic()]))

def get_bond_features(bond):
    
        bt = bond.GetBondType()
        return 1*(np.array([bt == Chem.rdchem.BondType.SINGLE,
                         bt == Chem.rdchem.BondType.DOUBLE,
                         bt == Chem.rdchem.BondType.TRIPLE,
                         bt == Chem.rdchem.BondType.AROMATIC,
                         bond.GetIsConjugated(),
                         bond.IsInRing()]))
    
def smirk_to_stellarGraph(reactans):
    
    no_of_atoms = reactans.GetNumAtoms()
    source_list = []
    target_list = []
    node_feature_matrix = []
    edge_feature_matrix = []
    
    for i in range(no_of_atoms):
        
        atom = reactans.GetAtomWithIdx(i)
        node_feature_matrix.append(get_atom_features(atom))
        
        for neighbour in atom.GetNeighbors():
        
            target = neighbour.GetIdx()
            source_list.append(i)
            target_list.append(target)
            bond = reactans.GetBondBetweenAtoms(i, target)
            edge_feature_matrix.append(get_bond_features(bond))
            
    indirected_edges = pd.DataFrame({"source": source_list, "target": target_list})
    
    node_feature_matrix = np.array(node_feature_matrix)
    edge_feature_matrix = np.array(edge_feature_matrix)
    
    node_feature_matrix = pd.DataFrame(node_feature_matrix)
    edge_feature_matrix =  pd.DataFrame(edge_feature_matrix)
    
    indirected_features_edges = pd.concat([indirected_edges, edge_feature_matrix], axis=1)
    molecularGraph = StellarGraph(node_feature_matrix, indirected_features_edges)
    print(molecularGraph.info())

In [4]:
smile = "CC1CCC[CH:20]1.C1CC[CH:10]C1>>CC1CCC[CH:20]1[CH:10]2CCCC2 10-10,20;20-10,20"
chunks = smile.split(">>")

reactans = Chem.MolFromSmiles(chunks[0])
smirk_to_stellarGraph(reactans)

StellarGraph: Undirected multigraph
 Nodes: 11, Edges: 22

 Node types:
  default: [11]
    Features: float32 vector, length 62
    Edge types: default-default->default

 Edge types:
    default-default->default: [22]
        Weights: all 1 (default)
        Features: float32 vector, length 6
