In [42]:
import torch
import torch_geometric.nn as nn
from torch_geometric.utils import degree
from torch_geometric.utils import to_undirected

In [59]:
class CustomLayer(nn.MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        
        self.linear = nn.Linear(in_channels, out_channels, bias=False)
        self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
        
        self.reset_parameters()
        
    def forward(self, x, edge_index):
        x = self.linear(x)
        row, col = edge_index
        degrees = degree(col, len(x))
        inv_sqrt = degrees.pow(-0.5)
        inv_sqrt[inv_sqrt == float('inf')] = 0
        norm = inv_sqrt[row] * inv_sqrt[col]
        
        x = self.propagate(edge_index, x=x, norm=norm)
        
        x = x + self.bias
        
        return x
        
    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j
        
    def reset_parameters(self):
        self.linear.reset_parameters()
        self.bias.data.zero_()
        

In [60]:
class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.customLayer = CustomLayer(3, 1)

In [90]:
model = GNN()

customLayer = CustomLayer(3, 1)
test_features = torch.rand(5, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [4, 4, 4, 4]], dtype=torch.long)
edge_index = to_undirected(edge_index)
optim = torch.optim.SGD()
y = torch.tensor([[0],[0],[0],[0],[1]], dtype=torch.float)
for i in range(0, 10):
    out = customLayer.forward(test_features, edge_index)
    print(out)
    print(y)
    loss = torch.nn.MSELoss()
    los = loss(out, y)
    los.backward()

tensor([[0.2001],
        [0.2001],
        [0.2001],
        [0.2001],
        [0.3110]], grad_fn=<AddBackward0>)
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [1.]])
tensor([[0.2001],
        [0.2001],
        [0.2001],
        [0.2001],
        [0.3110]], grad_fn=<AddBackward0>)
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [1.]])
tensor([[0.2001],
        [0.2001],
        [0.2001],
        [0.2001],
        [0.3110]], grad_fn=<AddBackward0>)
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [1.]])
tensor([[0.2001],
        [0.2001],
        [0.2001],
        [0.2001],
        [0.3110]], grad_fn=<AddBackward0>)
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [1.]])
tensor([[0.2001],
        [0.2001],
        [0.2001],
        [0.2001],
        [0.3110]], grad_fn=<AddBackward0>)
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [1.]])
tensor([[0.2001],
        [0.2001],
        [0.2001],
        [0.2001]