In [1]:
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from typing import List, Tuple, Union
import numpy as np
from itertools import zip_longest
from copy import deepcopy
from collections import Counter
import logging
from rdkit import Chem

import sys
sys.path.append("/home/chenlidong/polyAttn/models")
sys.path.append("/home/chenlidong/polyAttn/utils")
# import polygnn
import chem_utils
import pdb

# polygnn test code

In [2]:
# `polygnn` is referenced below, but its top-level import in the first cell is commented
# out, so this cell raised NameError on a fresh kernel. Import it here; it is presumably
# found through the model directory appended to sys.path in the first cell.
import polygnn

# Hyper-parameters handed to the polygnn model.
hps = {
    "capacity": 2,
    "activation": nn.functional.leaky_relu,
    "readout_dim": 128,
}

# NOTE(review): the meaning of the positional arguments (32, 7, True) is not visible in
# this notebook -- confirm against the polygnn.polygnn signature.
model = polygnn.polygnn(32, 7, True, hps)

In [3]:
# Featurise one repeat unit ([*] marks the attachment points in the SMILES string).
repeat_unit_smiles = "[*]Oc1ccc(-c2ccc(C([*])=O)cc2)cc1"
repeat_unit_features = chem_utils.get_feature(repeat_unit_smiles)
node_features, edge_features, edge_indices = repeat_unit_features

In [4]:
# Convert the raw feature lists to tensors first, then bundle them into one PyG graph.
node_feature_tensor = torch.tensor(node_features, dtype=torch.float)
edge_index_tensor = torch.tensor(edge_indices, dtype=torch.long).T  # transposed before use
edge_weight_tensor = torch.tensor(edge_features, dtype=torch.float)

data = Data(
    x=node_feature_tensor,
    edge_index=edge_index_tensor,
    edge_weight=edge_weight_tensor,
)

In [5]:
# One-graph loader, used only to push a single batch through the model.
train_loader_options = {"batch_size": 1, "shuffle": True, "pin_memory": True}
train_loader = DataLoader([data], **train_loader_options)

In [8]:
# Forward every batch through the model and show the raw output tensor.
for batch_idx, batch in enumerate(train_loader):
    model_output = model(batch)
    print(model_output)

tensor([[0.4337, 0.3888, 0.3650, 0.3436, 0.3988, 0.3499, 0.3956, 0.4133, 0.3567,
         0.4352, 0.3738, 0.3076, 0.3832, 0.3693, 0.4158, 0.4409, 0.4127, 0.4084,
         0.4287, 0.4251, 0.3337, 0.4168, 0.4429, 0.3431, 0.4172, 0.3630, 0.4217,
         0.4064, 0.3803, 0.3957, 0.4111, 0.3503, 0.4047, 0.4475, 0.3495, 0.3449,
         0.3784, 0.3999, 0.3801, 0.3518, 0.4274, 0.3092, 0.3732, 0.3583, 0.3823,
         0.3584, 0.4099, 0.4279, 0.3385, 0.3373, 0.4120, 0.3693, 0.3263, 0.4335,
         0.3913, 0.4095, 0.3489, 0.3001, 0.4066, 0.3830, 0.3599, 0.3715, 0.4065,
         0.4032, 0.3953, 0.4147, 0.3725, 0.4038, 0.3840, 0.4478, 0.3282, 0.3981,
         0.3390, 0.3768, 0.4021, 0.4305, 0.4000, 0.3825, 0.3874, 0.4016, 0.3292,
         0.3710, 0.3279, 0.4633, 0.3882, 0.3455, 0.4211, 0.4184, 0.3750, 0.3709,
         0.3644, 0.3426, 0.4080, 0.4180, 0.3483, 0.3365, 0.3140, 0.2962, 0.3640,
         0.4033, 0.3791, 0.4103, 0.4069, 0.4254, 0.3285, 0.3350, 0.2872, 0.3729,
         0.3825, 0.3926, 0.3

# polymer-chemprop

In [3]:
import sys
sys.path.append("/home/chenlidong/polymer-chemprop-master/chemprop")
from features import mol2graph

In [556]:
# Polymer input string passed to `mol2graph` below.
# NOTE(review): the '|'-separated fields look like <monomer SMILES>|<fraction>|<fraction>|
# followed by '<i-j:w:w' entries pairing the numbered [*:n] attachment points -- confirm the
# exact grammar against the polymer-chemprop documentation.
smiles = "[*:1]c1ccc([*:2])c2nsnc12.[*:3]c1cc([*:4])nc(OC)c1N|0.25|0.75|<1-3:0.25:0.25<1-4:0.25:0.25<2-3:0.25:0.25<2-4:0.25:0.25<1-2:0.25:0.25<3-4:0.25:0.25<1-1:0.25:0.25<2-2:0.25:0.25<3-3:0.25:0.25<4-4:0.25:0.25"

In [557]:
# Featurise the polymer string and unpack the flat batch tensors.
batch_mol2graph = mol2graph([smiles])

# The tenth component was previously bound to `_`; at notebook top level that shadows
# IPython's last-output variable. It is named `degree_of_polym` here, the name used for
# the same position when MPNEncoder.forward unpacks get_components().
(
    f_atoms, f_bonds,          # per-atom / per-bond feature tensors
    w_atoms, w_bonds,          # per-atom / per-bond weights
    a2b, b2a, b2revb,          # index maps between atoms, bonds and reverse bonds
    a_scope, b_scope,          # per-molecule (start, size) pairs, indexed [0]/[1] below
    degree_of_polym,
) = batch_mol2graph.get_components()

In [558]:
# Extra lookup tables exposed by the batch graph (a2a is what MPNEncoder.forward uses to
# gather neighbouring-atom messages).
a2a, b2b = batch_mol2graph.get_a2a(), batch_mol2graph.get_b2b()

# Regression targets, currently disabled (`valid_df` is not defined in this notebook):
# ea = torch.tensor(valid_df.Ea.to_list(),dtype=torch.float32)
# ip = torch.tensor(valid_df.IP.to_list(),dtype=torch.float32)

In [559]:
# Split the flat chemprop batch into one PyG `Data` object per molecule.
#
# Every index tensor in the batch (b2a, b2revb) holds *global* positions, so each one is
# shifted by the molecule's start offset to make it local to the slice stored in `Data`.
dataset = []

for idx, (a_scope_i, b_scope_i) in enumerate(zip(a_scope, b_scope)):
    a_start, a_size = a_scope_i[0], a_scope_i[1]
    b_start, b_size = b_scope_i[0], b_scope_i[1]
    a_stop, b_stop = a_start + a_size, b_start + b_size

    b2a_i = b2a[b_start:b_stop]                      # global atom index each bond starts from
    # Local index of each bond's reverse bond. The offset is this molecule's b_start;
    # the previous hard-coded `- 1` was only right when b_start == 1 (first molecule).
    b2revb_i = b2revb[b_start:b_stop] - b_start

    # Row 0: origin atom of bond i. Row 1: origin atom of its reverse bond, i.e. the atom
    # bond i points to. Built with tensor ops instead of a per-bond Python loop.
    edge_source = b2a_i - a_start
    edge_target = b2a_i[b2revb_i] - a_start
    edge_index = torch.stack([edge_source, edge_target]).long()

    data = Data(
        x=f_atoms[a_start:a_stop],
        edge_index=edge_index,
        edge_attr=f_bonds[b_start:b_stop],
        w_atoms=w_atoms[a_start:a_stop],
        w_bonds=w_bonds[b_start:b_stop],
        b2revb=b2revb_i,
        # Optional fields, currently disabled (`idx` is kept for the per-molecule targets):
        # a2b = a2b[a_start:a_stop],
        # b2a = b2a[b_start:b_stop],
        # a2a = a2a[a_start:a_stop],
        # b2b = b2b[b_start:b_stop],
        # ea = ea[idx],
        # ip = ip[idx]
    )
    dataset.append(data)

In [560]:
# Unshuffled, one graph per batch, so the batch order matches `dataset`.
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

In [693]:
from typing import List, Union, Tuple
from functools import reduce
from rdkit import Chem
from args import TrainArgs
from features import BatchMolGraph, get_atom_fdim, get_bond_fdim
from nn_utils import get_activation_function, index_select_ND
import torch.nn.init as init

class MPNEncoder(nn.Module):
    """Message passing encoder operating on a chemprop ``BatchMolGraph``, in an instrumented state.

    NOTE(review): this copy carries debugging modifications that are visible in the code:

    * ``atom_messages`` is forced to ``True`` and the device to CPU, regardless of ``args``;
    * all three linear layers are initialised to constant weights (0.01) and biases (1),
      presumably so the output can be compared number-for-number with ``MPNEncoder_PyG``;
    * the message-passing loop prints the aggregated message and stops after one iteration;
    * the readout is commented out, so ``forward`` returns ``None``.
    """

    def __init__(self, args: TrainArgs, atom_fdim: int, bond_fdim: int):
        """
        :param args: A :class:`~chemprop.args.TrainArgs` object containing model arguments.
        :param atom_fdim: Atom feature vector dimension.
        :param bond_fdim: Bond feature vector dimension.
        """
        super(MPNEncoder, self).__init__()
        self.atom_fdim = atom_fdim                  # atom feature length
        self.bond_fdim = bond_fdim                  # bond feature length
        self.atom_messages = args.atom_messages     # whether messages are centred on atoms (overwritten below)
        self.hidden_size = args.hidden_size         # 300 according to the original note
        self.depth = args.depth                     # number of message passing steps
        self.dropout = args.dropout
        self.undirected = args.undirected           # False according to the original note
        self.device = torch.device('cpu')           # hard-coded; args is not consulted for the device
        self.aggregation = args.aggregation         # "mean" according to the original note
        self.aggregation_norm = args.aggregation_norm   # 100 according to the original note
        # NOTE(review): overrides the value read from args above, so `directed` is always False.
        self.atom_messages = True
        self.directed = not self.atom_messages
        # Dropout
        self.dropout_layer = nn.Dropout(p=self.dropout)

        # Activation
        self.act_func = get_activation_function(args.activation)    # ReLU according to the original note

        # Cached zeros
        self.cached_zero_vector = nn.Parameter(torch.zeros(self.hidden_size), requires_grad=False)

        # Input
        input_dim = self.atom_fdim if self.atom_messages else self.bond_fdim
        self.W_i = nn.Linear(input_dim, self.hidden_size)

        # With atom messages, bond features are concatenated to the hidden state before W_h.
        if self.atom_messages:
            w_h_input_size = self.hidden_size + self.bond_fdim
        else:
            w_h_input_size = self.hidden_size

        # Shared weight matrix across depths (default)
        self.W_h = nn.Linear(w_h_input_size, self.hidden_size)

        self.W_o = nn.Linear(self.atom_fdim + self.hidden_size, self.hidden_size)

        # Constant initialisation: every weight is 0.01 and every bias is 1.
        init.constant_(self.W_i.weight, 0.01)
        init.constant_(self.W_i.bias, 1)
        init.constant_(self.W_h.weight, 0.01)
        init.constant_(self.W_h.bias, 1)
        init.constant_(self.W_o.weight, 0.01)
        init.constant_(self.W_o.bias, 1)

    def forward(self,
                mol_graph: BatchMolGraph,
                atom_descriptors_batch: List[np.ndarray] = None) -> torch.FloatTensor:
        """
        Runs the input layer, one message passing step and the final neighbour aggregation.

        :param mol_graph: A :class:`~chemprop.features.featurization.BatchMolGraph` representing
                          a batch of molecular graphs.
        :param atom_descriptors_batch: A list of numpy arrays containing additional atomic descriptors
                                       (not used in this method).
        :return: NOTE(review): despite the annotation, nothing is returned at the moment. The
                 readout that would produce a :code:`(num_molecules, hidden_size)` tensor is
                 commented out, and ``a_input`` is computed but not used.
        """
        # NOTE(review): atom_messages=False is hard-coded here even though self.atom_messages is True.
        f_atoms, f_bonds, w_atoms, w_bonds, a2b, b2a, b2revb, \
        a_scope, b_scope, degree_of_polym = mol_graph.get_components(atom_messages=False)

        f_atoms, f_bonds, w_atoms, w_bonds, a2b, b2a, b2revb = f_atoms.to(self.device), f_bonds.to(self.device), \
                                                               w_atoms.to(self.device), w_bonds.to(self.device), \
                                                               a2b.to(self.device), b2a.to(self.device), \
                                                               b2revb.to(self.device)

        if self.atom_messages:
            a2a = mol_graph.get_a2a().to(self.device)


        # Input
        if self.atom_messages:
            input = self.W_i(f_atoms)                                   # [num_atoms , hidden_size]
        else:
            input = self.W_i(f_bonds)                                   # [num_bonds , hidden_size]
        message = self.act_func(input)                                  # [num_bonds / num_atoms , hidden_size]


        # Message passing
        # NOTE(review): the unconditional `break` at the end of the body means only the first
        # iteration runs, whatever self.depth is.
        for depth in range(self.depth - 1):
            if self.undirected:
                message = (message + message[b2revb]) / 2

            if self.atom_messages:
                nei_a_message = index_select_ND(message, a2a)           # [num_atoms , max_num_bonds , hidden_size] for each atom, the features of all its neighbouring atoms
                nei_f_bonds = index_select_ND(f_bonds, a2b)             # [num_atoms , max_num_bonds , bond_fdim]
                nei_message = torch.cat((nei_a_message, nei_f_bonds), dim=2)  # [num_atoms , max_num_bonds , hidden + bond_fdim]
                message = nei_message.sum(dim=1)                        # [num_atoms , hidden + bond_fdim]
            else:
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = sum(nei_a_message)      rev_message
                nei_a_message = index_select_ND(message, a2b)           # [num_atoms , max_num_bonds , hidden]
                nei_a_weight = index_select_ND(w_bonds, a2b)            # [num_atoms , max_num_bonds]
                # weight nei_a_message based on edge weights
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1) * weight(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = dot(nei_a_message,nei_a_weight)      rev_message
                nei_a_message = nei_a_message * nei_a_weight[..., None]  # [num_atoms , max_num_bonds , hidden]
                a_message = nei_a_message.sum(dim=1)  # [num_atoms , hidden]
                rev_message = message[b2revb]  # [num_bonds , hidden]
                # print(rev_message.shape)
                # print(nei_a_weight[..., None].shape)
                message = a_message[b2a] - rev_message * w_bonds[..., None]       # [num_bonds , hidden]
            print(message)                                              # debug output of the aggregated message
            message = self.W_h(message)
            message = self.act_func(input + message)                    # [num_bonds/num_atoms , hidden]  skip connection
            # message = self.dropout_layer(message)                       # [num_bonds/num_atoms , hidden]
            # print(message)

            break
        # Final aggregation over neighbours, weighted by w_bonds gathered with the same index table.
        a2x = a2a if self.atom_messages else a2b
        nei_a_message = index_select_ND(message, a2x)                   # [num_bonds/num_atoms , max_num_bonds , hidden]
        nei_a_weight = index_select_ND(w_bonds, a2x)                    # [num_bonds/num_atoms , max_num_bonds]
        # # weight messages
        nei_a_message = nei_a_message * nei_a_weight[..., None]         # [num_bonds/num_atoms , max_num_bonds , hidden]
        a_message = nei_a_message.sum(dim=1)                            # [num_bonds/num_atoms , hidden]

        a_input = torch.cat([f_atoms, a_message], dim=1)                # [num_bonds/num_atoms , hidden + f_atom]
        # print(a_input)
        # atom_hiddens = self.act_func(self.W_o(a_input))                 # [num_bonds/num_atoms , hidden]
        # atom_hiddens = self.dropout_layer(atom_hiddens)                 # [num_bonds/num_atoms , hidden]
        # print(a_input)

        # Everything below (output layer above, readout and return) is disabled.
        # # Readout
        # mol_vecs = []
        # for i, (a_start, a_size) in enumerate(a_scope):
        #     if a_size == 0:
        #         mol_vecs.append(self.cached_zero_vector)
        #     else:
        #         cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)
        #         mol_vec = cur_hiddens  # (num_atoms, hidden_size)
        #         w_atom_vec = w_atoms.narrow(0, a_start, a_size)
        #         # if input are polymers, weight atoms from each repeating unit according to specified monomer fractions
        #         # weight h by atom weights (weights are all 1 for non-polymer input)
        #         mol_vec = w_atom_vec[..., None] * mol_vec
        #         # weight each atoms at readout
        #         if self.aggregation == 'mean':
        #             mol_vec = mol_vec.sum(dim=0) / w_atom_vec.sum(dim=0)  # if not --polymer, w_atom_vec.sum == a_size
        #         elif self.aggregation == 'sum':
        #             mol_vec = mol_vec.sum(dim=0)
        #         elif self.aggregation == 'norm':
        #             mol_vec = mol_vec.sum(dim=0) / self.aggregation_norm

        #         # if input are polymers, multiply mol vectors by degree of polymerization
        #         # if not --polymer, Xn is 1
        #         mol_vec = degree_of_polym[i] * mol_vec

        #         mol_vecs.append(mol_vec)

        # mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)

        # return mol_vecs  # num_molecules x hidden

In [694]:
from torch_geometric import nn as pnn
from typing import List, Union, Tuple
from functools import reduce
from rdkit import Chem
from args import TrainArgs
from features import BatchMolGraph, get_atom_fdim, get_bond_fdim
from nn_utils import get_activation_function, index_select_ND
class MPNEncoder_PyG(pnn.MessagePassing):
    """Re-implementation of ``MPNEncoder`` on top of PyTorch Geometric's ``MessagePassing`` (sum aggregation).

    NOTE(review): like ``MPNEncoder`` in this notebook, this class is in an instrumented state:
    ``atom_messages`` is forced to ``True``, the device is fixed to CPU, the linear layers use
    constant initialisation, and the message-passing loop prints and stops after one iteration.
    ``init`` (``torch.nn.init``) is imported in the cell that defines ``MPNEncoder``, not in this one.
    """

    def __init__(self, args: TrainArgs, atom_fdim: int, bond_fdim: int):
        """
        :param args: A :class:`~chemprop.args.TrainArgs` object containing model arguments.
        :param atom_fdim: Atom feature vector dimension.
        :param bond_fdim: Bond feature vector dimension.
        """
        super(MPNEncoder_PyG, self).__init__(aggr="add")
        self.atom_fdim = atom_fdim                  # atom feature length
        self.bond_fdim = bond_fdim                  # bond feature length
        self.atom_messages = args.atom_messages     # whether messages are centred on atoms (overwritten below)
        self.hidden_size = args.hidden_size         # 300 according to the original note
        self.depth = args.depth                     # number of message passing steps
        self.dropout = args.dropout
        self.device = torch.device('cpu')           # hard-coded; args is not consulted for the device
        self.aggregation = args.aggregation         # "mean" according to the original note
        self.aggregation_norm = args.aggregation_norm   # 100 according to the original note
        # NOTE(review): overrides the value read from args above, so `directed` is always False.
        self.atom_messages = True
        self.directed = not self.atom_messages

        # Dropout
        self.dropout_layer = nn.Dropout(p=self.dropout)

        # Activation
        self.act_func = get_activation_function(args.activation)    # ReLU according to the original note

        # Cached zeros
        self.cached_zero_vector = nn.Parameter(torch.zeros(self.hidden_size), requires_grad=False)

        # Input
        input_dim = self.atom_fdim if self.atom_messages else self.bond_fdim
        self.W_i = nn.Linear(input_dim, self.hidden_size)

        # With atom messages, bond features are concatenated to the hidden state before W_h.
        if self.atom_messages:
            w_h_input_size = self.hidden_size + self.bond_fdim
        else:
            w_h_input_size = self.hidden_size

        # Shared weight matrix across depths (default)
        self.W_h = nn.Linear(w_h_input_size, self.hidden_size)
        self.W_o = nn.Linear(self.atom_fdim + self.hidden_size, self.hidden_size)


        # Constant initialisation: every weight is 0.01 and every bias is 1.
        init.constant_(self.W_i.weight, 0.01)
        init.constant_(self.W_i.bias, 1)
        init.constant_(self.W_h.weight, 0.01)
        init.constant_(self.W_h.bias, 1)
        init.constant_(self.W_o.weight, 0.01)
        init.constant_(self.W_o.bias, 1)


    def forward(self, x, edge_index, edge_attr,b2revb,w_atoms,w_bonds, batch):
        """
        Runs the input layer, one message passing step and a final weighted aggregation.

        :param x: Node (atom) feature tensor.
        :param edge_index: Edge index tensor passed through to ``propagate``.
        :param edge_attr: Edge (bond) feature tensor.
        :param b2revb: Reverse-bond index tensor (only read in the directed branch of ``update``).
        :param w_atoms: Atom weights; forwarded to ``propagate`` but not read by ``message``/``update``.
        :param w_bonds: Bond weights.
        :param batch: Not used in this method.
        :return: Result of the second ``propagate`` call (``skip_connection=True``), i.e. the
                 original ``x`` concatenated with the aggregated messages along dim 1.
        """
        a_message = x
        b_message = edge_attr
        if not self.directed:
            input = self.W_i(a_message)
            a_message = self.act_func(input)              # [num_atoms , hidden_size]
        else:
            input = self.W_i(b_message)
            b_message = self.act_func(input)              # [num_bonds , hidden_size]

        # NOTE(review): the unconditional `break` at the end of the body means only the first
        # iteration runs, whatever self.depth is.
        for depth in range(self.depth - 1):
            message = self.propagate(
                edge_index      =   edge_index,\
                x               =   a_message,\
                edge_attr       =   b_message,\
                b2revb          =   b2revb,\
                w_atoms         =   w_atoms,\
                w_bonds         =   w_bonds,\
                skip_connection =   False,\
                first           =   x)
            print(message)                                # debug output of the aggregated message
            if self.directed:
                b_message = self.act_func(self.W_h(message)+input)
            else:
                a_message = self.act_func(self.W_h(message)+input)
                # print(a_message)

            break

        # Final pass: weighted neighbour sum, concatenated with the raw input features in `update`.
        return self.propagate(
                edge_index      =   edge_index,\
                x               =   a_message,\
                edge_attr       =   b_message,\
                b2revb          =   b2revb,\
                w_atoms         =   w_atoms,\
                w_bonds         =   w_bonds,\
                skip_connection =   True,\
                first           =   x)

    def update(self, aggr_out,edge_index,edge_attr,b2revb,w_bonds,skip_connection,first):
        """
        Post-processes the aggregated messages.

        Without ``skip_connection``: returns ``aggr_out`` unchanged in the atom-message case, or
        subtracts the weighted reverse-bond message in the directed case. With ``skip_connection``:
        returns ``first`` concatenated with ``aggr_out`` along dim 1.
        """
        if not skip_connection:
            if self.directed:
                message = aggr_out[edge_index[0]] - edge_attr[b2revb] * w_bonds[...,None]
            else:
                message = aggr_out
        else:
            message = torch.cat([first, aggr_out], dim=1)
        return message

    def message(self, x_i, x_j, edge_index,edge_attr,b2revb,w_bonds,skip_connection,first):
        """
        Builds the per-edge message that ``propagate`` aggregates.

        Atom-message case: ``x_j`` concatenated with ``edge_attr`` (no skip connection), or ``x_j``
        scaled by ``w_bonds[edge_index[0]]`` (skip connection). Directed case: ``edge_attr`` scaled
        by ``w_bonds`` in both modes. ``x_i``, ``b2revb`` and ``first`` are not read here.
        """
        # Original note (translated): "x_i x_j -- the edge from node i to node j is edge_index".
        # NOTE(review): under PyG's default flow, x_j is presumably gathered from edge_index[0]
        # and x_i from edge_index[1] -- confirm the direction against the MessagePassing docs.
        if not skip_connection:
            if self.directed:
                nei_a_message = edge_attr * w_bonds[..., None]
            else:
                nei_a_message = torch.cat((x_j,edge_attr),dim=1)
        else:
            if self.directed:
                nei_a_message = edge_attr * w_bonds[..., None]
            else:
                # NOTE(review): w_bonds is indexed with edge_index[0] here rather than with an
                # edge position -- confirm this matches the intended weighting.
                nei_a_message = x_j * w_bonds[edge_index[0]][..., None]
        return nei_a_message

In [695]:
# Build both encoders from the same arguments so their outputs can be compared.
# NOTE(review): `TrainArgs` is passed as the class itself, not as an instance.
encoder_kwargs = dict(args=TrainArgs, atom_fdim=get_atom_fdim(), bond_fdim=get_bond_fdim())

model_pyg = MPNEncoder_PyG(**encoder_kwargs)
model = MPNEncoder(**encoder_kwargs)

In [696]:
# Run the reference encoder and the PyG encoder once each; both print their intermediate
# message tensor. Only the first batch of the loader is used.
for batch in loader:
    # print(batch.x)
    model(batch_mol2graph)
    pyg_inputs = (
        batch.x,
        batch.edge_index,
        batch.edge_attr,
        batch.b2revb,
        batch.w_atoms,
        batch.w_bonds,
        batch.batch,
    )
    model_pyg(*pyg_inputs)
    break

tensor([[6.4272, 6.4272, 6.4272,  ..., 0.0000, 0.0000, 0.0000],
        [2.1424, 2.1424, 2.1424,  ..., 0.0000, 0.0000, 0.0000],
        [2.1424, 2.1424, 2.1424,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [1.0616, 1.0616, 1.0616,  ..., 0.0000, 0.0000, 0.0000],
        [3.2038, 3.2038, 3.2038,  ..., 0.0000, 0.0000, 0.0000],
        [1.0712, 1.0712, 1.0712,  ..., 0.0000, 0.0000, 0.0000]],
       grad_fn=<ScatterAddBackward0>)


In [689]:
# Print the shape of every tensor field on each collated batch, one field per line.
# The f-string uses backslash line continuations, so the four leading spaces of each
# continued line are part of the printed text.
for batch in loader:
    print(f"\
    x  {batch.x.shape}\n\
    edge_index  {batch.edge_index.shape}\n\
    edge_attr  {batch.edge_attr.shape}\n\
    w_atoms  {batch.w_atoms.shape}\n\
    w_bonds  {batch.w_bonds.shape}\n\
    batch  {batch.batch.shape}")

    x  torch.Size([18, 133])
    edge_index  torch.Size([2, 54])
    edge_attr  torch.Size([54, 147])
    w_atoms  torch.Size([18])
    w_bonds  torch.Size([54])
    batch  torch.Size([18])


In [245]:
# Display the connectivity of the batch left over from the last loader loop.
# NOTE(review): the saved output has 44 columns, while the shape cell reports 54 edges, and
# this cell's execution count is lower -- the output presumably comes from an earlier input;
# re-run to refresh it.
batch.edge_index

tensor([[ 0,  1,  0,  6,  1,  2,  2,  3,  2,  4,  4,  5,  5,  6,  6,  7,  8,  9,
          8, 13,  9, 10, 10, 11, 11, 12, 12, 13,  0,  8,  0, 12,  4,  8,  4, 12,
          0,  4,  8, 12,  0,  4,  8, 12],
        [ 1,  0,  6,  0,  2,  1,  3,  2,  4,  2,  5,  4,  6,  5,  7,  6,  9,  8,
         13,  8, 10,  9, 11, 10, 12, 11, 13, 12,  8,  0, 12,  0,  8,  4, 12,  4,
          4,  0, 12,  8,  0,  4,  8, 12]])

In [608]:
# Two 18 x 3 tables entered separately by hand and compared element by element.
# NOTE(review): the source of the numbers is not recorded in the notebook.
a = torch.tensor([
    [1.0, 197.0, 908.0],
    [2.0, 44.0, 200.0],
    [3.0, 39.0, 213.0],
    [4.0, 204.0, 914.0],
    [5.0, 61.0, 268.0],
    [6.0, 33.0, 68.0],
    [7.0, 28.0, 64.0],
    [8.0, 28.0, 68.0],
    [9.0, 73.0, 259.0],
    [10.0, 243.0, 1007.0],
    [11.0, 106.0, 473.0],
    [12.0, 227.0, 970.0],
    [13.0, 96.0, 337.0],
    [14.0, 98.0, 218.0],
    [15.0, 60.0, 130.0],
    [16.0, 30.0, 67.0],
    [17.0, 117.0, 363.0],
    [18.0, 42.0, 94.0],
])
b = torch.tensor([
    [1.0, 197.0, 908.0],
    [2.0, 44.0, 200.0],
    [3.0, 39.0, 213.0],
    [4.0, 204.0, 914.0],
    [5.0, 61.0, 268.0],
    [6.0, 33.0, 68.0],
    [7.0, 28.0, 64.0],
    [8.0, 28.0, 68.0],
    [9.0, 73.0, 259.0],
    [10.0, 243.0, 1007.0],
    [11.0, 106.0, 473.0],
    [12.0, 227.0, 970.0],
    [13.0, 96.0, 337.0],
    [14.0, 98.0, 218.0],
    [15.0, 60.0, 130.0],
    [16.0, 30.0, 67.0],
    [17.0, 117.0, 363.0],
    [18.0, 42.0, 94.0],
])
# Element-wise comparison; the boolean tensor is the cell's displayed result.
a == b

tensor([[True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True]])