In [1]:
import torch
import math
import torch.nn as nn

In [None]:
Attns2v

In [61]:
class Set2Vec(torch.nn.Module):
    """
    S2V readout function.
    """
    def __init__(self, node_features : int, hidden_node_features : int,
                 lstm_computations : int, memory_size : int,
                 constants : namedtuple) -> None:

        super().__init__()

        self.constants         = constants
        self.lstm_computations = lstm_computations
        self.memory_size       = memory_size

        self.embedding_matrix = torch.nn.Linear(
            in_features=node_features + hidden_node_features,
            out_features=self.memory_size,
            bias=True
        )

        self.lstm = torch.nn.LSTMCell(
            input_size=self.memory_size,
            hidden_size=self.memory_size,
            bias=True
        )

    def forward(self, hidden_output_nodes : torch.Tensor, input_nodes : torch.Tensor,
                node_mask : torch.Tensor) -> torch.Tensor:
        """
        Defines forward pass.
        """
        Softmax      = torch.nn.Softmax(dim=1)

        batch_size   = input_nodes.shape[0]
        energy_mask  = torch.bitwise_not(node_mask).float() * 0.3
        lstm_input   = torch.zeros(batch_size, self.memory_size, device=self.constants.device)
        cat          = torch.cat((hidden_output_nodes, input_nodes), dim=2)
        memory       = self.embedding_matrix(cat)
        hidden_state = torch.zeros(batch_size, self.memory_size, device=self.constants.device)
        cell_state   = torch.zeros(batch_size, self.memory_size, device=self.constants.device)

        for _ in range(self.lstm_computations):
            query, cell_state = self.lstm(lstm_input, (hidden_state, cell_state))

            # dot product query x memory
            energies  = (query.view(batch_size, 1, self.memory_size) * memory).sum(dim=-1)
            attention = Softmax(energies + energy_mask)
            read      = (attention.unsqueeze(-1) * memory).sum(dim=1)

            hidden_state = query
            lstm_input   = read

        cat = torch.cat((query, read), dim=1)
        return cat

In [65]:
message_size = 100
hidden_node_features = 100

class SummationMpnn(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_node_features = 100 #size from constant file
        self.edge_features = 4
        self.message_size = 100
        self.message_passes = 3
    def forward(self,nodes:tuple,edges:torch.tensor)->None:
        adjacency = torch.sum(edges, dim=3)
        print("adj is ",adjacency.shape)
        (edge_batch_batch_idc,
        edge_batch_node_idc,
        edge_batch_nghb_idc) = adjacency.nonzero(as_tuple=True)
        
        (node_batch_batch_idc,
        node_batch_node_idc) = adjacency.sum(-1).nonzero(as_tuple=True)
        print("nodes various ares",(node_batch_batch_idc,node_batch_node_idc))

        same_batch = node_batch_batch_idc.view(-1,1)== edge_batch_batch_idc
        same_node  = node_batch_node_idc.view(-1, 1) == edge_batch_node_idc
        print("same batch,node",same_batch,   same_node)
        message_summation_matrix = (same_batch * same_node).float()
        print("message ",message_summation_matrix)
        edge_batch_edges = edges[edge_batch_batch_idc, edge_batch_node_idc, edge_batch_nghb_idc, :]
        print("edge_batch_edges ",edge_batch_edges.shape,'hello ',edges.shape)

        hidden_nodes = torch.zeros(nodes.shape[0],
                                   nodes.shape[1],
                                   self.hidden_node_features,
                                   device='cpu')
        hidden_nodes[:nodes.shape[0], :nodes.shape[1], :nodes.shape[2]] = nodes.clone()#padding upto 13 nodes,with features to 100....we have only 9 features here
        node_batch_nodes = hidden_nodes[node_batch_batch_idc, node_batch_node_idc, :]#picking out the same batch

        # print("yoyo ",node_batch_nodes.shape)
        # print("yoyo2",hidden_nodes.shape)

        for _ in range(self.message_passes):
            edge_batch_nodes = hidden_nodes[edge_batch_batch_idc, edge_batch_node_idc, :]#getting hi  13*13*100

            edge_batch_nghbs = hidden_nodes[edge_batch_batch_idc, edge_batch_nghb_idc, :]#getting neighs(hj)  13*13*1#why its 1
            print("in message pass ",edge_batch_nghbs.shape, "nodes ",edge_batch_nodes.shape)

            message_terms    = self.message_terms(edge_batch_nodes,  
                                                  edge_batch_nghbs,
                                                  edge_batch_edges)

            if len(message_terms.size()) == 1:  # if a single graph in batch
                message_terms = message_terms.unsqueeze(0)

            # the summation in eq. 1 of the NMPQC paper happens here

            messages = torch.matmul(message_summation_matrix.cpu(), message_terms.cpu())

            node_batch_nodes = self.update(node_batch_nodes, messages)
            hidden_nodes[node_batch_batch_idc, node_batch_node_idc, :] = node_batch_nodes.clone()

        node_mask = adjacency.sum(-1) != 0
        output    = self.readout(hidden_nodes, nodes, node_mask)
        return hidden_nodes
        #return output


In [71]:
import math
network = Attns2v(constants)
network(nodes,edges)

hello  4 [250, 250, 250, 250] 10000 0.0
100 100 3 100
adj is  torch.Size([1, 10, 10])
nodes various ares (tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
same batch,node tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, Tr

RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x109 and 200x100)

In [67]:
class Attns2v(SummationMpnn):
    """
    The "message neural network" model.
    """
    def __init__(self,constants):
        super().__init__()
        message_size = 100
        hidden_node_features = 100
        self.constants       = constants
        message_weights      = torch.Tensor(message_size,
                                            hidden_node_features,
                                           4)
        if False:#== "cuda":
            message_weights = message_weights.to("cuda", non_blocking=True)

        self.message_weights = torch.nn.Parameter(message_weights)

        self.gru             = torch.nn.GRUCell(
            input_size=message_size,
            hidden_size = hidden_node_features,
            bias=True
        )
        #4 edge features
        self.mlp1 = MLP(
            in_features=100,
            hidden_layer_sizes=[250] * 1,
            out_features=100,
            dropout_p=0.2
        )
        self.mlp2 = MLP(4,[250]*4,4,dropout_p=0)
        print("hello ",constants.n_edge_features,
            [constants.enn_hidden_dim] * constants.enn_depth,
            constants.hidden_node_features * constants.message_size,
            constants.enn_dropout_p)

        self.enn = MLP(
            in_features= constants.n_edge_features,
            hidden_layer_sizes=[constants.enn_hidden_dim] * constants.enn_depth,
            out_features=constants.hidden_node_features * constants.message_size,
            dropout_p=constants.enn_dropout_p
        )
        print(self.constants.n_node_features,
            self.constants.hidden_node_features,
            self.constants.s2v_lstm_computations,self.constants.s2v_memory_size)#self.constants)
            
        self.s2v = Set2Vec(node_features=self.constants.n_node_features,
            hidden_node_features=self.constants.hidden_node_features,
            lstm_computations=self.constants.s2v_lstm_computations,
            memory_size=self.constants.s2v_memory_size,constants=self.constants)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        stdev = 1.0 / math.sqrt(self.message_weights.size(1))
        self.message_weights.data.uniform_(-stdev, stdev)

    # def message_terms(self, nodes,node_neighbours,edges):
    #     edges_v = edges.view(-1,1,1)#it gives value for edges.....type of edge..size->(6,4)->in molecule with 6 edges kind
    #     neighs = edges_v*node_neighbours.view(-1,1,1)#multiplying each by this number
    #     m1,b1 = [],[]
    #     for i in range(4):
    #         m1.append(edges[:,i,:]*self.mlp1[i](neighs[:,i,:]))
    #         b1.append(edges[:,i,:]*self.mlp2[i](neighs[:,i,:]))
    #     m1 = sum(m1)
    #     b1 = sum(b1)
    #     a = self.Softmax(b1)#check this Neigh(j) not understabbale
    #     return output
    def message_terms(self, nodes : torch.Tensor, node_neighbours : torch.Tensor,
                      edges : torch.Tensor) -> torch.Tensor:
                      
        enn_output = self.enn(edges) #check this out hows its working
        matrices   = enn_output.view(-1,
                                     self.constants.message_size,
                                     self.constants.hidden_node_features)
        print("sizes are such as  ",matrices.shape,node_neighbours.unsqueeze(-1).shape)
        msg_terms  = torch.matmul(matrices,
                                  node_neighbours.unsqueeze(-1)).squeeze(-1)
        return msg_terms

    
    def update(self, nodes : torch.Tensor, messages : torch.Tensor) -> torch.Tensor:
        return self.gru(messages, nodes)

    def readout(self, hidden_nodes : torch.Tensor, input_nodes : torch.Tensor,
                node_mask : torch.Tensor) -> torch.Tensor:
        graph_embeddings = self.s2v(hidden_nodes, input_nodes, node_mask)
        output           = self.APDReadout(hidden_nodes, graph_embeddings)
        return output

    
class MLP(torch.nn.Module):
    def __init__(self, in_features : int, hidden_layer_sizes : list, out_features : int,
                 dropout_p : float) -> None:
        super().__init__()

        activation_function = torch.nn.SELU
        # create list of all layer feature sizes
        fs = [in_features, *hidden_layer_sizes, out_features]
        # create list of linear_blocks
        layers = [self._linear_block(in_f, out_f,
                                     activation_function,
                                     dropout_p)
                  for in_f, out_f in zip(fs, fs[1:])]
        # concatenate modules in all sequentials in layers list
        layers = [module for sq in layers for module in sq.children()]

        # add modules to sequential container
        self.seq = torch.nn.Sequential(*layers)

    def _linear_block(self, in_f : int, out_f : int, activation : torch.nn.Module,
                      dropout_p : float) -> torch.nn.Sequential:
        
        # bias must be used in most MLPs in our models to learn from empty graphs
        linear = torch.nn.Linear(in_f, out_f, bias=True)
        torch.nn.init.xavier_uniform_(linear.weight)
        return torch.nn.Sequential(linear, activation(), torch.nn.AlphaDropout(dropout_p))

    def forward(self, layers_input : torch.nn.Sequential) -> torch.nn.Sequential:
        return self.seq(layers_input)


In [68]:
class GlobalReadout(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp1 = MLP(in_features=constants.message_size,
                  hidden_layer_sizes=[constants.mlp1_hidden_dim]*constants.mlp1_depth,
                  out_features=constants.message_size,
                  dropout_p=0.0)
        self.mlp2 = MLP(in_features=constants.message_size,
                  hidden_layer_sizes=[constants.mlp2_hidden_dim]*constants.mlp2_depth,
                  out_features=constants.message_size,
                  dropout_p=0.0)
        
        self.mlp3 = MLP(in_features=constants.message_size,
                  hidden_layer_sizes=[constants.mlp1_hidden_dim]*constants.mlp1_depth,
                  out_features=constants.message_size,
                  dropout_p=0.0)
        self.mlp4 = MLP(in_features=2*constants.message_size,
                  hidden_layer_sizes=[constants.mlp2_hidden_dim]*constants.mlp2_depth,
                  out_features=constants.message_size,
                  dropout_p=0.0)
        self.mlpt = MLP(in_features=constants.message_size,
                  hidden_layer_sizes=[constants.mlp1_hidden_dim]*constants.mlp1_depth,
                  out_features=constants.message_size,
                  dropout_p=0.0)
          
    def forward(self,features):
        g= torch.sum(features,dim=1)
        g = g.view(1,1,100)
        #g = torch.broadcast_to(g, (1,10, 100))
        print("api is ",g.shape)
        fadd1 = self.mlp1(features)
        fconn1 = self.mlp2(features)  

        print("dims fadd ip",fconn1.shape,g.shape,torch.cat([fadd1,g],dim=1).shape)
        fadd = self.mlp3(torch.cat([fadd1,g],dim=1)).unsqueeze(dim=1)
        fconn = self.mlp3(torch.cat([fconn1,g],dim=1)).unsqueeze(dim=1)

        fterm = self.mlpt(g)
        cat = torch.cat((fadd.squeeze(dim=1), fconn.squeeze(dim=1), fterm), dim=1)
        print("final shape ",cat.shape)
        return cat
        #apd = self.Softmax()....from original code its removed

In [70]:

class Set2Vec(torch.nn.Module):
    """
    S2V readout function.
    """
    def __init__(self, node_features : int, hidden_node_features : int,
                 lstm_computations : int, memory_size : int,
                 constants : namedtuple) -> None:

        super().__init__()

        self.constants         = constants
        self.lstm_computations = lstm_computations
        self.memory_size       = memory_size

        self.embedding_matrix = torch.nn.Linear(
            in_features=100 + 100,
            out_features=self.memory_size,
            bias=True
        )

        self.lstm = torch.nn.LSTMCell(
            input_size=self.memory_size,
            hidden_size=self.memory_size,
            bias=True
        )

    def forward(self, hidden_output_nodes : torch.Tensor, input_nodes : torch.Tensor,
                node_mask : torch.Tensor) -> torch.Tensor:
        """
        Defines forward pass.
        """
        Softmax      = torch.nn.Softmax(dim=1)

        batch_size   = input_nodes.shape[0]
        energy_mask  = torch.bitwise_not(node_mask).float() * 0.3
        lstm_input   = torch.zeros(batch_size, self.memory_size, device=self.constants.device)
        cat          = torch.cat((hidden_output_nodes, input_nodes), dim=2)
        memory       = self.embedding_matrix(cat)
        hidden_state = torch.zeros(batch_size, self.memory_size, device=self.constants.device)
        cell_state   = torch.zeros(batch_size, self.memory_size, device=self.constants.device)

        for _ in range(self.lstm_computations):
            query, cell_state = self.lstm(lstm_input, (hidden_state, cell_state))

            # dot product query x memory
            energies  = (query.view(batch_size, 1, self.memory_size) * memory).sum(dim=-1)
            attention = Softmax(energies + energy_mask)
            read      = (attention.unsqueeze(-1) * memory).sum(dim=1)

            hidden_state = query
            lstm_input   = read

        cat = torch.cat((query, read), dim=1)
        return cat

In [57]:
import torch
from collections import namedtuple
hyperparameters = {
        "att_depth"            : 4,
        "att_dropout_p"        : 0.0,
        "att_hidden_dim"       : 250,
        "enn_depth"            : 4,
        "enn_dropout_p"        : 0.0,
        "enn_hidden_dim"       : 250,
        "mlp1_depth"           : 4,
        "mlp1_dropout_p"       : 0.0,
        "mlp1_hidden_dim"      : 100,
        "mlp2_depth"           : 4,
        "mlp2_dropout_p"       : 0.0,
        "mlp2_hidden_dim"      : 100,
        "hidden_node_features" : 100,
        "message_passes"       : 3,
        "message_size"         : 100,
        "s2v_lstm_computations": 3,
        "s2v_memory_size"      : 100,
        "n_edge_features"      : 4,
        "n_node_features"      :100,
    }
import json
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
constants = dotdict(hyperparameters)

from ashishcode import load_molecule,MolecularGraph,params
path = "data\pre-training\gdb13_1K\Train.smi"
molecule_set = load_molecule(path)
print(molecule_set)
for mol in molecule_set:
    molecule = MolecularGraph(mol)
    print(molecule)
    break

True
<rdkit.Chem.rdmolfiles.SmilesMolSupplier object at 0x0000016569C203B0>
<ashishcode.MolecularGraph object at 0x0000016569E9C0D0>


In [16]:
nodes,edges = molecule.get_graph_state()
nodes,edges = torch.Tensor(nodes).view((1,10,9)),torch.Tensor(edges).view((1,10,10,4))


In [44]:
network(nodes,edges)

adj is  torch.Size([1, 10, 10])
nodes various ares (tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
same batch,node tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper__bmm)