In [8]:
import torch
from torch.nn import Linear, Parameter, BatchNorm1d
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_scatter import scatter_add
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU

unique_value_edge_attr = 2

class GINConv(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information by concatenation.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        embed_input (bool): whether to embed input or not. 
        

    See https://arxiv.org/abs/1810.00826
    """
    def __init__(self, in_channels, emb_dim, aggr = "add"):
        super(GINConv, self).__init__()
        #multi-layer perceptron
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(in_channels, 2*emb_dim),
            torch.nn.BatchNorm1d(2*emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2*emb_dim, 4*emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(4*emb_dim, 2*emb_dim),
            torch.nn.ReLU(),
            )
        self.edge_embedding = torch.nn.Embedding(unique_value_edge_attr, in_channels)

        torch.nn.init.xavier_uniform_(self.edge_embedding.weight.data)

        self.aggr = aggr
        self.in_channels = in_channels
        self.emb_dim = emb_dim

    def forward(self, x, edge_index, edge_attr):
        #add self loops in the edge space
        edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), edge_attr.size(1))
        self_loop_attr[:,-4] = 1 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = torch.zeros((edge_attr.size(0), self.in_channels), dtype=torch.float).to(edge_attr.device).to(edge_attr.dtype)

        for i in range(edge_attr.size(1)):  # Iterate over the second dimension
            embedding_ith = self.edge_embedding(edge_attr[:, i]).clone().detach().to(edge_attr.device).to(edge_attr.dtype)
            edge_embeddings += embedding_ith

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        print(x_j.shape, edge_attr.shape)
        return x_j + edge_attr

    def update(self, aggr_out):
        print(aggr_out.size())
        aggr_out = torch.tensor(aggr_out, dtype=torch.float)
        return self.mlp(aggr_out)

In [10]:
import torch
from torch_geometric.data import Data

# Initialize the GCNConv layer
# Transform from 3-dimensional features to 2-dimensional features
gin_conv = GINConv(in_channels=3, emb_dim=6)

# Define node features (4 nodes with 3 features each)
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float)

# Define the edges in the graph (making it undirected)
# Each pair of nodes is connected in both directions
edge_index = torch.tensor([[0, 1, 3, 3, 2], 
                           [1, 2, 2, 0, 0]], dtype=torch.long)  # Edges: 0-1, 1-2, 2-3, 3-0, 0-2
edge_attr = torch.randint(0, 1, (edge_index.size(1), 19))
# Apply the GCNConv layer to the node features
out_features = gin_conv(x, edge_index, edge_attr)

print("Original node features:\n", x)
print("\nTransformed node features:\n", out_features.shape)


torch.Size([9, 3]) torch.Size([9, 3])
torch.Size([4, 3])
Original node features:
 tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.]])

Transformed node features:
 torch.Size([4, 12])


  aggr_out = torch.tensor(aggr_out, dtype=torch.float)


In [134]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax, remove_self_loops
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros
import numpy as np

num_atom_type = 121 #including the extra motif tokens and graph token
num_chirality_tag = 11  #degree

num_bond_type = 7 
num_bond_direction = 3 

class GINConv(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information by concatenation.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        embed_input (bool): whether to embed input or not. 
        

    See https://arxiv.org/abs/1810.00826
    """
    def __init__(self, emb_dim, aggr = "add"):
        super(GINConv, self).__init__()
        #multi-layer perceptron
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        self.aggr = aggr

    def forward(self, x, edge_index, edge_attr):
        #add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)

In [135]:
import torch
from torch_geometric.data import Data

# Define the number of bond types and directions
num_bond_type = 7
num_bond_direction = 3

# Define the GINConv class (as you provided)

# Initialize node features: 3 nodes with 4 features each
x = torch.randn((3, 4))

# Define edges: 0->1 and 1->2
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)

# Define edge attributes: 2 edges with bond type and direction
# Here we just use random integers for demonstration
edge_attr = torch.randint(0, 3, (4, 2), dtype=torch.long)

# Initialize the GINConv layer
emb_dim = 4  # Embedding dimension
gin_conv = GINConv(emb_dim)

# Apply the GINConv layer
out = gin_conv(x, edge_index, edge_attr)

print("Original node features:")
print(x)
print("\nUpdated node features:")
print(out)


Original node features:
tensor([[ 0.2145, -2.9410,  0.2738,  1.5769],
        [-0.5983, -0.7514, -0.8116,  0.2553],
        [-0.6392, -0.0398, -1.6644,  1.0835]])

Updated node features:
tensor([[ 0.2887,  0.5924, -0.3227, -0.0814],
        [ 0.6193,  0.6467, -0.9354, -0.4441],
        [ 0.5508,  0.4071, -0.2276, -0.2344]], grad_fn=<AddmmBackward0>)
