In [1]:
import torch
import torch.nn.functional as F

In [2]:
class GCNLayer(torch.nn.Module):
    
    def __init__(self, input_dim: int, output_dim: int, A: torch.Tensor):
        super(GCNLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.A = A
        print(f"macierz A:{A}")
        # A + id 
        self.A_hat = self.A + torch.eye(self.A.size(0))
        # diag matrix
        self.ones = torch.ones(self.input_dim, self.input_dim)
        self.D = torch.matmul(self.A.float(), self.ones.float())

        self.D = torch.diag(self.D)
        self.D = torch.diag_embed(self.D)

        self.D_neg_sqrt = torch.diag_embed(torch.diag(torch.pow(self.D, -0.5)))

        self.W = torch.nn.Parameter(torch.rand(self.input_dim, self.output_dim))

    def forward(self, X: torch.Tensor):
        support_1 = torch.matmul(self.D_neg_sqrt, torch.matmul(self.A_hat, self.D_neg_sqrt))
        support_2 = torch.matmul(support_1, torch.matmul(X, self.W))
        H = F.relu(support_2)
        return H 

In [3]:
A = torch.tensor([[1.,0.,0.],[0.,1.,1.],[0.,1.,1.]])
gcn_layer = GCNLayer(3, 2, A)
X = torch.tensor([[1.,2.,3.],[4.,5.,6.],[7.,8.,9.]])
output = gcn_layer(X)
print(output)

macierz A:tensor([[1., 0., 0.],
        [0., 1., 1.],
        [0., 1., 1.]])
tensor([[ 5.2253,  7.2303],
        [12.2221, 16.8300],
        [14.2978, 19.6819]], grad_fn=<ReluBackward0>)
