In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

In [2]:
class GCNLayer(nn.Module):
    """Single graph-convolution layer computing ``adj @ (x @ W)``.

    Args:
        input_dim: size of the last dimension of the input features ``x``.
        output_dim: size of the last dimension of the layer output.
    """

    # https://medium.com/@jrosseruk/demystifying-gcns-a-step-by-step-guide-to-building-a-graph-convolutional-network-layer-in-pytorch-09bf2e788a51
    # https://blog.csdn.net/qq_43787862/article/details/113830925
    def __init__(self, input_dim, output_dim):
        super(GCNLayer, self).__init__()
        # Learnable projection matrix of shape (input_dim, output_dim).
        # NOTE(review): torch.rand samples U[0, 1), so every weight starts
        # non-negative; a zero-centred init (e.g. Glorot/uniform(-s, s)) is more
        # usual for GCNs. Kept unchanged here to avoid altering results.
        self.W = nn.Parameter(torch.rand(input_dim, output_dim))

    def forward(self, x, adj):
        """Propagate features over the graph: ``adj @ (x @ W)``.

        Args:
            x: node features; last dimension must equal ``input_dim``.
            adj: (normalised) adjacency matrix whose last dimension matches the
                node dimension of ``x``.

        Returns:
            Tensor whose last dimension is ``output_dim``.
        """
        # Use the argument ``x`` -- the original read the notebook-global ``X``,
        # which only worked by coincidence and broke on a fresh kernel.
        return torch.matmul(adj, torch.matmul(x, self.W))


class GCN(nn.Module):
    """Graph convolutional network over a fixed graph.

    The symmetrically normalised adjacency ``D^-1/2 (A + I) D^-1/2`` is computed
    once in ``__init__`` and reused on every forward pass.

    Args:
        A: square (n, n) adjacency tensor; an identity (self-loops) is added to
            it internally before normalisation.
        input_dim: last dimension of the node features passed to ``forward``.
        output_dim: output width of the second layer ``gl2``.
    """

    def __init__(self, A, input_dim, output_dim):
        super(GCN, self).__init__()
        # Registered as buffers (not plain attributes) so ``model.to(device)``
        # moves them together with the parameters. persistent=False keeps them
        # out of the state_dict, so existing checkpoints still load unchanged.
        self.register_buffer("A", A, persistent=False)
        self.register_buffer("adj", self.__init_normal_adj(A), persistent=False)
        hidden_dim = input_dim
        self.gl1 = GCNLayer(input_dim=input_dim, output_dim=hidden_dim)
        self.gl2 = GCNLayer(input_dim=hidden_dim, output_dim=output_dim)

    def __init_normal_adj(self, A):
        """Return ``D^-1/2 (A + I) D^-1/2`` where D holds the row sums of A + I.

        Floating-point ``A`` keeps its dtype; integer/bool ``A`` is promoted to the
        default float dtype. The result lives on ``A``'s device.
        """
        n = A.size(0)
        dtype = A.dtype if A.is_floating_point() else torch.get_default_dtype()
        # A_hat = A + I (add self-loops), built on A's device/dtype.
        A_hat = A.to(dtype) + torch.eye(n, dtype=dtype, device=A.device)
        # Degree vector = row sums; O(n^2) instead of an n x n ones-matrix matmul.
        deg = A_hat.sum(dim=1)
        d_inv_sqrt = deg.pow(-0.5)
        # Diagonal scaling on both sides via broadcasting:
        # adj[i, j] = d_i^-1/2 * A_hat[i, j] * d_j^-1/2
        return d_inv_sqrt.unsqueeze(1) * A_hat * d_inv_sqrt.unsqueeze(0)

    def forward(self, x):
        """Apply one graph convolution followed by ReLU.

        Args:
            x: node features of shape (n, input_dim).

        Returns:
            Tensor of shape (n, input_dim), since ``hidden_dim == input_dim``.
        """
        out = F.relu(self.gl1(x, self.adj))
        # NOTE(review): ``gl2`` is constructed but not applied, so ``output_dim``
        # currently has no effect on the result. Add
        # ``out = self.gl2(out, self.adj)`` here if a two-layer model is intended.
        return out


# Small 5-node toy graph (symmetric, zero diagonal). Seed the RNG so the random
# weights and features -- and hence the displayed output -- are reproducible.
torch.manual_seed(0)
A = torch.tensor(
    [
        [0.0, 0.0, 0.0, 0.0, 1.0],
        [0.0, 0.0, 0.0, 1.0, 1.0],
        [0.0, 0.0, 0.0, 1.0, 1.0],
        [0.0, 1.0, 1.0, 0.0, 1.0],
        [1.0, 1.0, 1.0, 1.0, 0.0],
    ]
)
assert torch.equal(A, A.T), "toy adjacency should be symmetric"
model = GCN(A=A, input_dim=5, output_dim=5)
X = torch.randn(5, 5)
model(X), model.adj

(tensor([[0.2583, 0.1185, 0.0000, 0.4876, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0976, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=<ReluBackward0>),
 tensor([[0.5000, 0.0000, 0.0000, 0.0000, 0.3162],
         [0.0000, 0.3333, 0.0000, 0.2887, 0.2582],
         [0.0000, 0.0000, 0.3333, 0.2887, 0.2582],
         [0.0000, 0.2887, 0.2887, 0.2500, 0.2236],
         [0.3162, 0.2582, 0.2582, 0.2236, 0.2000]]))

In [3]:
# Graph of N_BLOCKS disconnected cliques of BLOCK_SIZE nodes each:
# block-diagonal ones with the diagonal subtracted (no self-loops in A).
N_BLOCKS, BLOCK_SIZE, FEATURE_DIM = 15, 4, 128
N_NODES = N_BLOCKS * BLOCK_SIZE
torch.manual_seed(0)  # reproducible weights and features

A = torch.block_diag(*[torch.ones(BLOCK_SIZE, BLOCK_SIZE)] * N_BLOCKS) - torch.eye(N_NODES)
print(A[:10, :10])

model = GCN(A=A, input_dim=FEATURE_DIM, output_dim=FEATURE_DIM)
X = torch.randn(N_NODES, FEATURE_DIM)
out = model(X)  # evaluate once; reuse for the shape
out, model.adj, out.shape

tensor([[0., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]])


(tensor([[0.0000, 0.0000, 0.0000,  ..., 0.3109, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.3109, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.3109, 0.0000, 0.0000],
         ...,
         [0.5077, 0.0000, 0.8649,  ..., 0.1727, 1.5901, 0.0000],
         [0.5077, 0.0000, 0.8649,  ..., 0.1727, 1.5901, 0.0000],
         [0.5077, 0.0000, 0.8649,  ..., 0.1727, 1.5901, 0.0000]],
        grad_fn=<ReluBackward0>),
 tensor([[0.2500, 0.2500, 0.2500,  ..., 0.0000, 0.0000, 0.0000],
         [0.2500, 0.2500, 0.2500,  ..., 0.0000, 0.0000, 0.0000],
         [0.2500, 0.2500, 0.2500,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.2500, 0.2500, 0.2500],
         [0.0000, 0.0000, 0.0000,  ..., 0.2500, 0.2500, 0.2500],
         [0.0000, 0.0000, 0.0000,  ..., 0.2500, 0.2500, 0.2500]]),
 torch.Size([60, 128]))

In [4]:
# Display the random node-feature matrix X (60 x 128) fed to the model in the previous cell.
X

tensor([[-0.0973,  2.0113,  1.0547,  ...,  1.2939, -0.1909, -0.4825],
        [-0.5021, -0.0394,  0.0563,  ..., -0.7722, -0.1307, -0.6411],
        [-1.7490, -0.7771,  1.5408,  ...,  0.2686,  0.2822, -1.7990],
        ...,
        [-0.5704, -0.1729, -0.3620,  ..., -0.3185,  0.4943,  2.6446],
        [-0.1729, -1.0727, -0.3436,  ...,  0.8912,  0.0719, -0.3624],
        [ 0.2210, -1.5154, -0.8486,  ...,  0.4117, -0.4244, -0.0537]])