In [None]:
%%capture
# Colab setup cell: install the graph/chemistry stack, import it, and seed torch.
# %%capture suppresses this cell's output (including the pip logs).
# TODO(review): pin versions (e.g. `%pip install dgl==X.Y.Z`) so a fresh
# "Restart & Run All" reproduces the same environment.
!pip install kora
!pip install dgl
!pip install dgllife
import dgl
import dgllife
import torch 
import networkx as nx
import dgl.function as fn
# NOTE(review): presumably importing kora.install.rdkit installs RDKit into the
# Colab runtime, which is why it precedes the rdkit imports below -- verify.
import kora.install.rdkit
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig
from dgllife.utils import featurizers as fs 
# Global torch seed so weight initialisation is repeatable between runs.
torch.manual_seed(16)

In [None]:
import sys
! wget https://repo.anaconda.com/archive/Anaconda3-2020.02-Linux-x86_64.sh
! chmod +x Anaconda3-2020.02-Linux-x86_64.sh
! bash ./Anaconda3-2020.02-Linux-x86_64.sh -b -f -p /usr/local
sys.path.append('/usr/local/lib/python3.7/site-packages/')
! conda update -n base -c defaults conda -y
! conda config --add channels bioconda
! conda config --add channels conda-forge
!conda install -c conda-forge rdkit
import rdkit
from rdkit import Chem

In [None]:

def atom_is_in_ring_list_one_hot(atom, allowable_set=None, encode_unknown=False):
    """Ring-membership flags for an RDKit atom.

    Returns ``[atom.IsInRing()]`` followed by ``atom.IsInRingSize(n)`` for every
    ring size ``n`` in ``allowable_set``.

    Args:
        atom: RDKit ``Atom`` (anything with ``IsInRing`` / ``IsInRingSize``).
        allowable_set: ring sizes to test; ``None`` (default) means ``[3, 4, 5]``.
        encode_unknown: unused; presumably kept so the signature mirrors dgllife's
            ``*_one_hot(atom, allowable_set, encode_unknown)`` featurizers.

    Returns:
        list of bool of length ``1 + len(ring sizes)``.
    """
    ring_sizes = [3, 4, 5] if allowable_set is None else allowable_set
    return [atom.IsInRing()] + [atom.IsInRingSize(size) for size in ring_sizes]

def bond_is_in_ring_list_one_hot(bond, allowable_set=None, encode_unknown=False):
    """Ring-membership flags for an RDKit bond.

    Returns ``[bond.IsInRing()]`` followed by ``bond.IsInRingSize(n)`` for every
    ring size ``n`` in ``allowable_set``.

    Args:
        bond: RDKit ``Bond`` (anything with ``IsInRing`` / ``IsInRingSize``).
        allowable_set: ring sizes to test; ``None`` (default) means ``[3, 4, 5]``.
        encode_unknown: unused; presumably kept so the signature mirrors dgllife's
            ``*_one_hot(bond, allowable_set, encode_unknown)`` featurizers.

    Returns:
        list of bool of length ``1 + len(ring sizes)``.
    """
    ring_sizes = [3, 4, 5] if allowable_set is None else allowable_set
    return [bond.IsInRing()] + [bond.IsInRingSize(size) for size in ring_sizes]

class CanonicalAtomFeaturizer(fs.BaseAtomFeaturizer):
    """Atom featurizer: dgllife per-atom descriptors plus ring-size flags.

    All descriptor functions below are concatenated (in this order) into one
    feature vector per atom, stored under ``atom_data_field``.

    Args:
        atom_data_field: key of the concatenated atom features (default ``'h'``).
    """

    def __init__(self, atom_data_field='h'):
        atom_feature_fns = [
            fs.atomic_number_one_hot,
            fs.atom_degree_one_hot,
            fs.atom_explicit_valence_one_hot,
            fs.atom_implicit_valence_one_hot,
            fs.atom_hybridization_one_hot,
            fs.atom_formal_charge_one_hot,
            fs.atom_num_radical_electrons_one_hot,
            fs.atom_is_aromatic,
            fs.atom_total_num_H_one_hot,
            fs.atom_is_in_ring,
            fs.atom_chiral_tag_one_hot,
            fs.atom_chirality_type_one_hot,
            fs.atom_mass,
            fs.atom_is_chiral_center,
            atom_is_in_ring_list_one_hot,
        ]
        concatenated = fs.ConcatFeaturizer(atom_feature_fns)
        super().__init__(featurizer_funcs={atom_data_field: concatenated})

class CanonicalBondFeaturizer(fs.BaseBondFeaturizer):
    """Bond featurizer: dgllife bond descriptors plus ring-size flags.

    The descriptor functions below are concatenated (in this order) into one
    feature vector per bond, stored under ``bond_data_field``.

    Args:
        bond_data_field: key of the concatenated bond features (default ``'e'``).
        self_loop: forwarded unchanged to ``fs.BaseBondFeaturizer``.
    """

    def __init__(self, bond_data_field='e', self_loop=False):
        bond_feature_fns = [
            fs.bond_type_one_hot,
            fs.bond_is_conjugated,
            fs.bond_is_in_ring,
            fs.bond_stereo_one_hot,
            bond_is_in_ring_list_one_hot,
        ]
        concatenated = fs.ConcatFeaturizer(bond_feature_fns)
        super().__init__(featurizer_funcs={bond_data_field: concatenated},
                         self_loop=self_loop)

# Shared featurizer instances: atom features land under 'h', bond features under 'e'.
node_featurizer, bond_featurizer = CanonicalAtomFeaturizer(), CanonicalBondFeaturizer()

class DMPNN_i4(torch.nn.Module):
    """One directed message-passing (D-MPNN style) layer on a DGL graph.

    Reads node features from ``ndata['states']`` and edge features from
    ``edata['states']`` and computes, in order:

    1. initial edge states ``h0[u->v] = act1(linear1([x_u || e_uv]))``;
    2. one directed edge update: ``m[u->v]`` is the sum of the states of all
       edges ``k->u`` except the reverse edge ``v->u``, and
       ``h[u->v] = act2(linear2(m[u->v]) + h0[u->v])``;
    3. node states ``x'_v = act3(linear3([x_v || sum of h over v's out-edges]))``.

    The caller's graph is left unmodified. The returned graph carries
    ``ndata['initial_states']`` (input node features), ``ndata['states']``
    (new node states), ``edata['initial_states']`` and ``edata['states']``.

    Args:
        input_dim_v: size of the incoming node features.
        output_dim_v: size of the node states produced.
        input_dim_e: size of the incoming edge features.
        output_dim_e: size of the edge states produced.
        act1, act2, act3: names of functions in ``torch.nn.functional`` applied
            after linear1 / linear2 / linear3. Each also selects the Xavier gain
            used to initialise that layer.
        bias_v: accepted for backward compatibility; unused (all linear layers
            are bias-free).
    """

    def __init__(self, input_dim_v=10, output_dim_v=10, input_dim_e=10, output_dim_e=10, act1="relu", act2="relu", act3="relu", bias_v=True):
        super().__init__()
        self.input_dim_v = input_dim_v
        self.input_dim_e = input_dim_e
        self.output_dim_v = output_dim_v
        self.output_dim_e = output_dim_e
        self.act1 = act1
        self.act2 = act2
        self.act3 = act3
        self.linear1 = torch.nn.Linear(self.input_dim_v + self.input_dim_e, self.output_dim_e, bias=False)
        self.linear2 = torch.nn.Linear(self.output_dim_e, self.output_dim_e, bias=False)
        self.linear3 = torch.nn.Linear(self.output_dim_e + self.input_dim_v, self.output_dim_v, bias=False)
        torch.nn.init.xavier_normal_(self.linear1.weight, gain=torch.nn.init.calculate_gain(self.act1))
        torch.nn.init.xavier_normal_(self.linear2.weight, gain=torch.nn.init.calculate_gain(self.act2))
        torch.nn.init.xavier_normal_(self.linear3.weight, gain=torch.nn.init.calculate_gain(self.act3))

    @staticmethod
    def _activation(name):
        """Return ``torch.nn.functional.<name>`` (attribute lookup, no eval)."""
        return getattr(torch.nn.functional, name)

    def initial_edge_states(self, edges):
        """Edge UDF: ``act1(linear1([source node state || edge state]))``."""
        concatenation = torch.cat((edges.src['states'], edges.data['states']), 1)
        return {'initial_states': self._activation(self.act1)(self.linear1(concatenation))}

    def edge_updater(self, edges):
        """Edge UDF: one directed message-passing step over the edge batch.

        For edge i (u->v), the message sums the states of every edge j with
        ``dst[j] == u``, skipping edges with ``src[j] == v`` (the reverse edge).
        The mask is a dense (E, E) matrix, so memory grows quadratically in the
        number of edges; this is fine for single molecules.
        """
        src, dst, _ = edges.edges()
        states = edges.data['states']
        feeds_source = dst.unsqueeze(0) == src.unsqueeze(1)   # [i, j]: dst[j] == src[i]
        not_reverse = src.unsqueeze(0) != dst.unsqueeze(1)    # [i, j]: src[j] != dst[i]
        mask = (feeds_source & not_reverse).to(states.dtype)
        messages = mask @ states  # stays on the graph's device; (0, d) for edgeless graphs
        updated = self.linear2(messages) + edges.data['initial_states']
        return {'states': self._activation(self.act2)(updated)}

    def forward(self, molecular_graph):
        """Run the layer.

        Returns:
            ``(graph, molecular_representation)``, where the representation is
            the mean of the new node states over all nodes.
        """
        # Work on a local view so the caller's graph keeps its original features.
        # Otherwise a second forward on the same graph (the next epoch, or
        # re-running a cell) sees already-updated, differently sized states.
        molecular_graph = molecular_graph.local_var()
        molecular_graph.apply_edges(self.initial_edge_states)
        molecular_graph.edata['states'] = molecular_graph.edata['initial_states']
        molecular_graph.ndata['initial_states'] = molecular_graph.ndata['states']
        molecular_graph.apply_edges(self.edge_updater)

        # Node message: sum of the states of each node's out-edges (src == node).
        # NOTE(review): D-MPNN formulations usually sum *incoming* edge states for
        # the node update; this sums out-edges -- confirm intent.
        edge_states = molecular_graph.edata['states']
        src, _ = molecular_graph.edges()
        node_message = edge_states.new_zeros(molecular_graph.num_nodes(), edge_states.shape[1])
        node_message = node_message.index_add(0, src.long(), edge_states)

        node_input = torch.cat((molecular_graph.ndata['initial_states'], node_message), dim=1)
        molecular_graph.ndata['states'] = self._activation(self.act3)(self.linear3(node_input))
        molecular_representation = torch.mean(molecular_graph.ndata['states'], dim=0)
        return molecular_graph, molecular_representation

class GNN1(torch.nn.Module):
  """Two stacked DMPNN_i4 layers.

  Layer 1: node 158 -> 40, edge 16 -> 20. Layer 2: node 40 -> 10, edge 20 -> 5.
  The 158 / 16 input widths are assumed to match the atom / bond feature vectors
  from node_featurizer / bond_featurizer; update them if the featurizers change.
  """

  def __init__(self):
    super().__init__()
    # DMPNN_i4 argument order: (input_dim_v, output_dim_v, input_dim_e, output_dim_e).
    self.GNN1 = DMPNN_i4(158, 40, 16, 20)
    self.GNN2 = DMPNN_i4(40, 10, 20, 5)

  def forward(self, molecular_graph):
    """Return ``(graph, molecular representation)`` from the second layer."""
    outputgraph, _ = self.GNN1(molecular_graph)
    return self.GNN2(outputgraph)

class GNN2(torch.nn.Module):
  """Two stacked DMPNN_i4 layers.

  Layer 1: node 158 -> 30, edge 16 -> 40. Layer 2: node 30 -> 10, edge 40 -> 10.
  The 158 / 16 input widths are assumed to match the atom / bond feature vectors
  from node_featurizer / bond_featurizer; update them if the featurizers change.
  """

  def __init__(self):
    super().__init__()
    # DMPNN_i4 argument order: (input_dim_v, output_dim_v, input_dim_e, output_dim_e).
    self.GNN1 = DMPNN_i4(158, 30, 16, 40)
    self.GNN2 = DMPNN_i4(30, 10, 40, 10)

  def forward(self, molecular_graph):
    """Return ``(graph, molecular representation)`` from the second layer."""
    outputgraph, _ = self.GNN1(molecular_graph)
    return self.GNN2(outputgraph)

def graph_constructor_featurizer(molecule_smiles):
  """Convert SMILES strings into featurized DGL graphs.

  Each molecule becomes a bidirected graph from ``dgllife.utils.mol_to_bigraph``.
  Atom features from ``node_featurizer`` go in ``ndata['states']`` and bond
  features from ``bond_featurizer`` go in ``edata['states']``.

  Args:
    molecule_smiles: iterable of SMILES strings.

  Returns:
    list of DGL graphs, in input order.

  Raises:
    ValueError: if RDKit cannot parse one of the SMILES strings.
  """
  output = []
  for smiles in molecule_smiles:
    # Parse once and reuse the same Mol for the graph and both featurizers.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
      raise ValueError(f"Could not parse SMILES: {smiles!r}")
    molecular_graph = dgllife.utils.mol_to_bigraph(mol)
    molecular_graph.ndata['states'] = node_featurizer(mol)['h']
    molecular_graph.edata['states'] = bond_featurizer(mol)['e']
    output.append(molecular_graph)
  return output

# Smoke test: featurize two copies of cyclobutane, then run each model on one copy.
g = graph_constructor_featurizer(['C1CCC1'] * 2)
model1, model2 = GNN1(), GNN2()
print(model1(g[0]))
print(model2(g[1]))




(Graph(num_nodes=4, num_edges=8,
      ndata_schemes={'states': Scheme(shape=(10,), dtype=torch.float32), 'initial_states': Scheme(shape=(40,), dtype=torch.float32)}
      edata_schemes={'states': Scheme(shape=(5,), dtype=torch.float32), 'initial_states': Scheme(shape=(5,), dtype=torch.float32)}), tensor([3.6104, 0.0000, 0.0000, 0.0000, 0.7703, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000], grad_fn=<MeanBackward1>))
(Graph(num_nodes=4, num_edges=8,
      ndata_schemes={'states': Scheme(shape=(10,), dtype=torch.float32), 'initial_states': Scheme(shape=(30,), dtype=torch.float32)}
      edata_schemes={'states': Scheme(shape=(10,), dtype=torch.float32), 'initial_states': Scheme(shape=(10,), dtype=torch.float32)}), tensor([1.8925, 0.0000, 0.8106, 0.0000, 0.0000, 0.0000, 3.1122, 0.0000, 2.0153,
        2.1313], grad_fn=<MeanBackward1>))


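## Older modules

`DMPNN_i4_0` is an older version of the D-MPNN layer. Unlike `DMPNN_i4`, its constructor takes the dimensions in the order `(input_dim_v, input_dim_e, output_dim_e, output_dim_v)`. Its `forward` returns only the updated graph, with no pooled molecular representation.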

In [None]:

class DMPNN_i4_0(torch.nn.Module):
    """Older directed message-passing (D-MPNN style) layer; returns only the graph.

    Reads node features from ``ndata['states']`` and edge features from
    ``edata['states']`` and computes, in order:

    1. initial edge states ``h0[u->v] = act1(linear1([x_u || e_uv]))``;
    2. one directed edge update: ``m[u->v]`` is the sum of the states of all
       edges ``k->u`` except the reverse edge ``v->u``, and
       ``h[u->v] = act2(linear2(m[u->v]) + h0[u->v])``;
    3. node states ``x'_v = act3(linear3([x_v || sum of h over v's out-edges]))``.

    Intermediate tensors are created on the device of the graph's features. The
    caller's graph is left unmodified; the updated graph is returned.

    Args:
        input_dim_v: size of the incoming node features.
        input_dim_e: size of the incoming edge features.
        output_dim_e: size of the edge states produced.
        output_dim_v: size of the node states produced.
        act1, act2, act3: names of functions in ``torch.nn.functional`` applied
            after linear1 / linear2 / linear3. Each also selects the Xavier gain
            used to initialise that layer.
        bias_v: accepted for backward compatibility; unused (all linear layers
            are bias-free).
    """

    def __init__(self, input_dim_v=10, input_dim_e=10, output_dim_e=10, output_dim_v=10, act1="relu", act2="relu", act3="relu", bias_v=True):
        super().__init__()
        self.input_dim_v = input_dim_v
        self.input_dim_e = input_dim_e
        self.output_dim_v = output_dim_v
        self.output_dim_e = output_dim_e
        self.act1 = act1
        self.act2 = act2
        self.act3 = act3
        self.linear1 = torch.nn.Linear(self.input_dim_v + self.input_dim_e, self.output_dim_e, bias=False)
        self.linear2 = torch.nn.Linear(self.output_dim_e, self.output_dim_e, bias=False)
        self.linear3 = torch.nn.Linear(self.output_dim_e + self.input_dim_v, self.output_dim_v, bias=False)
        torch.nn.init.xavier_normal_(self.linear1.weight, gain=torch.nn.init.calculate_gain(self.act1))
        torch.nn.init.xavier_normal_(self.linear2.weight, gain=torch.nn.init.calculate_gain(self.act2))
        torch.nn.init.xavier_normal_(self.linear3.weight, gain=torch.nn.init.calculate_gain(self.act3))

    @staticmethod
    def _activation(name):
        """Return ``torch.nn.functional.<name>`` (attribute lookup, no eval)."""
        return getattr(torch.nn.functional, name)

    def initial_edge_states(self, edges):
        """Edge UDF: ``act1(linear1([source node state || edge state]))``."""
        concatenation = torch.cat((edges.src['states'], edges.data['states']), 1)
        return {'initial_states': self._activation(self.act1)(self.linear1(concatenation))}

    def edge_updater(self, edges):
        """Edge UDF: one directed message-passing step over the edge batch.

        For edge i (u->v), the message sums the states of every edge j with
        ``dst[j] == u``, skipping edges with ``src[j] == v`` (the reverse edge).
        The mask is a dense (E, E) matrix, so memory grows quadratically in the
        number of edges; this is fine for single molecules.
        """
        src, dst, _ = edges.edges()
        states = edges.data['states']
        feeds_source = dst.unsqueeze(0) == src.unsqueeze(1)   # [i, j]: dst[j] == src[i]
        not_reverse = src.unsqueeze(0) != dst.unsqueeze(1)    # [i, j]: src[j] != dst[i]
        mask = (feeds_source & not_reverse).to(states.dtype)
        messages = mask @ states
        updated = self.linear2(messages) + edges.data['initial_states']
        return {'states': self._activation(self.act2)(updated)}

    def forward(self, molecular_graph):
        """Run the layer and return the updated graph (a local view of the input)."""
        # Local view so repeated forwards on the same input graph see its original
        # features rather than states written by a previous call.
        molecular_graph = molecular_graph.local_var()
        molecular_graph.apply_edges(self.initial_edge_states)
        molecular_graph.edata['states'] = molecular_graph.edata['initial_states']
        molecular_graph.ndata['initial_states'] = molecular_graph.ndata['states']
        molecular_graph.apply_edges(self.edge_updater)

        # Node message: sum of the states of each node's out-edges (src == node).
        edge_states = molecular_graph.edata['states']
        src, _ = molecular_graph.edges()
        node_message = edge_states.new_zeros(molecular_graph.num_nodes(), edge_states.shape[1])
        node_message = node_message.index_add(0, src.long(), edge_states)

        node_input = torch.cat((molecular_graph.ndata['initial_states'], node_message), dim=1)
        molecular_graph.ndata['states'] = self._activation(self.act3)(self.linear3(node_input))
        return molecular_graph


In [None]:
%%capture
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

def atom_is_in_ring_list_one_hot(atom, allowable_set=None, encode_unknown=False):
    """Ring-membership flags for an RDKit atom.

    Returns ``[atom.IsInRing()]`` followed by ``atom.IsInRingSize(n)`` for every
    ring size ``n`` in ``allowable_set``.

    Args:
        atom: RDKit ``Atom`` (anything with ``IsInRing`` / ``IsInRingSize``).
        allowable_set: ring sizes to test; ``None`` (default) means ``[3, 4, 5]``.
        encode_unknown: unused; presumably kept so the signature mirrors dgllife's
            ``*_one_hot(atom, allowable_set, encode_unknown)`` featurizers.

    Returns:
        list of bool of length ``1 + len(ring sizes)``.
    """
    ring_sizes = [3, 4, 5] if allowable_set is None else allowable_set
    return [atom.IsInRing()] + [atom.IsInRingSize(size) for size in ring_sizes]

def bond_is_in_ring_list_one_hot(bond, allowable_set=None, encode_unknown=False):
    """Ring-membership flags for an RDKit bond.

    Returns ``[bond.IsInRing()]`` followed by ``bond.IsInRingSize(n)`` for every
    ring size ``n`` in ``allowable_set``.

    Args:
        bond: RDKit ``Bond`` (anything with ``IsInRing`` / ``IsInRingSize``).
        allowable_set: ring sizes to test; ``None`` (default) means ``[3, 4, 5]``.
        encode_unknown: unused; presumably kept so the signature mirrors dgllife's
            ``*_one_hot(bond, allowable_set, encode_unknown)`` featurizers.

    Returns:
        list of bool of length ``1 + len(ring sizes)``.
    """
    ring_sizes = [3, 4, 5] if allowable_set is None else allowable_set
    return [bond.IsInRing()] + [bond.IsInRingSize(size) for size in ring_sizes]

class CanonicalAtomFeaturizer(fs.BaseAtomFeaturizer):
    """Atom featurizer: dgllife per-atom descriptors plus ring-size flags.

    All descriptor functions below are concatenated (in this order) into one
    feature vector per atom, stored under ``atom_data_field``.

    Args:
        atom_data_field: key of the concatenated atom features (default ``'h'``).
    """

    def __init__(self, atom_data_field='h'):
        atom_feature_fns = [
            fs.atomic_number_one_hot,
            fs.atom_degree_one_hot,
            fs.atom_explicit_valence_one_hot,
            fs.atom_implicit_valence_one_hot,
            fs.atom_hybridization_one_hot,
            fs.atom_formal_charge_one_hot,
            fs.atom_num_radical_electrons_one_hot,
            fs.atom_is_aromatic,
            fs.atom_total_num_H_one_hot,
            fs.atom_is_in_ring,
            fs.atom_chiral_tag_one_hot,
            fs.atom_chirality_type_one_hot,
            fs.atom_mass,
            fs.atom_is_chiral_center,
            atom_is_in_ring_list_one_hot,
        ]
        concatenated = fs.ConcatFeaturizer(atom_feature_fns)
        super().__init__(featurizer_funcs={atom_data_field: concatenated})

class CanonicalBondFeaturizer(fs.BaseBondFeaturizer):
    """Bond featurizer: dgllife bond descriptors plus ring-size flags.

    The descriptor functions below are concatenated (in this order) into one
    feature vector per bond, stored under ``bond_data_field``.

    Args:
        bond_data_field: key of the concatenated bond features (default ``'e'``).
        self_loop: forwarded unchanged to ``fs.BaseBondFeaturizer``.
    """

    def __init__(self, bond_data_field='e', self_loop=False):
        bond_feature_fns = [
            fs.bond_type_one_hot,
            fs.bond_is_conjugated,
            fs.bond_is_in_ring,
            fs.bond_stereo_one_hot,
            bond_is_in_ring_list_one_hot,
        ]
        concatenated = fs.ConcatFeaturizer(bond_feature_fns)
        super().__init__(featurizer_funcs={bond_data_field: concatenated},
                         self_loop=self_loop)

# Shared featurizer instances: atom features land under 'h', bond features under 'e'.
node_featurizer, bond_featurizer = CanonicalAtomFeaturizer(), CanonicalBondFeaturizer()