In [1]:
import torch

In [2]:
class GCNConvByHand(torch.nn.Module):
    """Graph-convolution layer written out by hand: maps D x N to O x N.

    Computes ``beta_k 1^T + Omega_k X (A + I)``, where ``Omega_k`` (O x D) and
    ``beta_k`` (O) are the weight and bias of an internal ``torch.nn.Linear``.
    The adjacency matrix is not normalised and no activation is applied.

    Args:
        dim_in: number of input features per node (D).
        dim_out: number of output features per node (O).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # The Linear module is only used as a container for Omega_k and beta_k;
        # forward() reads its weight and bias directly instead of calling it.
        self.linear = torch.nn.Linear(dim_in, dim_out, bias=True)

    def forward(self, x, A):
        """Apply the layer to node features ``x`` using adjacency ``A``.

        Args:
            x: node features with nodes on the last axis, e.g. B x D x N.
            A: adjacency matrix, e.g. B x N x N. It is cast to ``x.dtype``,
                so an integer adjacency matrix is accepted.

        Returns:
            Tensor with D replaced by O, e.g. B x O x N.
        """
        num_nodes = A.shape[-1]  # last axis of (B x) N x N
        omega_k = self.linear.weight  # O x D
        beta_k = self.linear.bias.reshape(-1, 1)  # O x 1

        # Build the helper tensors with x's dtype and device. The defaults
        # (float32 on the CPU) break as soon as the layer is used in float64
        # or on an accelerator.
        ones_row = x.new_ones((1, num_nodes))  # 1^T, shape 1 x N
        identity = torch.eye(num_nodes, dtype=x.dtype, device=x.device)

        # A + I adds self-loops so each node keeps its own features.
        A_hat = A.to(dtype=x.dtype) + identity

        h = torch.matmul(beta_k, ones_row) + torch.matmul(
            omega_k, torch.matmul(x, A_hat)
        )
        return h

In [3]:
# Problem sizes: B graphs per batch, D input features per node, N nodes per graph.
B = 2
D = 4
N = 3

# Fix the RNG state so the random features (and any later random draws in a
# top-to-bottom run) are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

x = torch.randn((B, D, N))  # node features, B x D x N
A = torch.tensor(
    [[[0, 2, 1], [1, 0, 0], [1, 0, 0]], [[0, 2, 1], [1, 0, 2], [1, 0, 0]]],
    dtype=torch.float32,
)  # adjacency matrix ( B x N x N )

In [None]:
# Number of output features per node.
O = 10

# Shape bookkeeping for one graph:
#   (O x D) @ ((D x N) @ (N x N)) = (O x D) @ (D x N) = O x N
# so a batch of B graphs gives B x O x N.
conv = GCNConvByHand(dim_in=D, dim_out=O)

In [6]:
# Apply the layer to the batch; the result has shape B x O x N (2 x 10 x 3 here).
conv(x, A)

tensor([[[-2.2564, -2.0061, -1.8521],
         [ 1.1493,  1.7374,  1.2489],
         [ 1.3229,  1.3732,  1.1650],
         [ 1.1285,  1.0028,  1.4222],
         [ 1.1270,  1.4183,  0.9826],
         [-2.8941, -1.6851, -1.4126],
         [-3.2285, -2.3829, -1.4578],
         [ 0.6476,  0.0695,  0.5307],
         [-1.6266, -1.3949, -0.3867],
         [ 2.4957,  1.9101,  1.0963]],

        [[-0.4018,  0.8643, -0.5551],
         [-0.4297, -1.1870, -0.8342],
         [-0.3052, -1.5593, -0.2877],
         [-0.4393, -0.8246, -0.5393],
         [-0.9022, -1.1699, -1.3383],
         [-1.4261, -0.9230, -1.9998],
         [-1.1011, -1.2705, -1.1923],
         [ 1.1236,  1.3713,  1.4659],
         [ 0.1648,  0.5788,  0.2182],
         [ 0.6467,  0.6597,  0.6749]]], grad_fn=<AddBackward0>)

In [8]:
import torch
from torch_geometric.utils import to_dense_adj

# Example edge index (edge list format) as a torch tensor.
# Column j describes one directed edge: edge_index[0, j] -> edge_index[1, j].
edge_index = torch.tensor([[0, 1, 0],   # Row 0: source node of each edge
                           [1, 2, 2]],  # Row 1: target node -> edges 0->1, 1->2, 0->2
                          dtype=torch.long)

# Number of nodes in the graph
num_nodes = 3

# Convert edge index to a dense adjacency matrix of shape 1 x N x N.
# max_num_nodes pins N explicitly, so the size does not depend on the largest
# node index that happens to appear in edge_index.
adj_matrix = to_dense_adj(edge_index, max_num_nodes=num_nodes)

# Print the adjacency matrix
print(adj_matrix)

tensor([[[0., 1., 1.],
         [0., 0., 1.],
         [0., 0., 0.]]])
