In [1]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.out_channels = out_channels
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))
        self.embeddings = torch.nn.Embedding(2, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index, edge_attr):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # print(f"1. shape endge index {edge_index.shape}")
        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 19)
        self_loop_attr[:,-1] = 1 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)
        
        summed_embeddings = torch.zeros(edge_attr.size(0), self.out_channels)

        for i in range(19):  # Iterate over the second dimension
            summed_embeddings += self.embeddings(edge_attr[:, i])
        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)
        # print(f"2. shape x {x.shape}")
        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        # print(f"3. shape norm {norm.shape}")
        # Step 4-5: Start propagating messages.
        out = self.propagate(edge_index, x=x, edge_attr = summed_embeddings,norm=norm)
        # print(f"7. shape out {out.shape}")
        # Step 6: Apply a final bias vector.
        out = out + self.bias

        return out

    def message(self, x_j, edge_attr, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        # print(f"4. shape xj {x_j.shape}")
        # print(f"5. shape norm.vier(-1,1) {norm.view(-1, 1).shape}")
        # print(f"6. shape output {(norm.view(-1, 1) * x_j).shape}")
        return norm.view(-1, 1) * (x_j + edge_attr)

In [2]:
import torch
from torch_geometric.data import Data

# Initialize the GCNConv layer
# Transform from 3-dimensional features to 2-dimensional features
gcn_conv = GCNConv(in_channels=3, out_channels=2)

# Define node features (4 nodes with 3 features each)
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float)

# Define the edges in the graph (making it undirected)
# Each pair of nodes is connected in both directions
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0, 0, 2], 
                           [1, 0, 2, 1, 3, 2, 0, 3, 2, 0]], dtype=torch.long)  # Edges: 0-1, 1-2, 2-3, 3-0, 0-2
edge_attr = torch.randint(0, 1, (edge_index.size(1), 19))
# Apply the GCNConv layer to the node features
out_features = gcn_conv(x, edge_index, edge_attr)

print("Original node features:\n", x)
print("\nTransformed node features:\n", out_features)


Original node features:
 tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.]])

Transformed node features:
 tensor([[-1.4963, 16.2255],
        [-2.0074, 12.4923],
        [-1.4963, 16.2255],
        [-1.0643, 14.1538]], grad_fn=<AddBackward0>)
