In [78]:
import torch

from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj

In [79]:
class VanillaGNNLayer(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.linear = torch.nn.Linear(dim_in, dim_out, bias=False)

    def forward(self, x, adjacency):
        x = self.linear(x)
        x = torch.sparse.mm(adjacency, x)
        return x

In [80]:
edge_index = torch.tensor([
    [0, 1],
    [0, 2],
    [0, 3],
    [1, 0],
    [2, 0],
    [3, 0]
], dtype=torch.long)

# x = torch.tensor([[0], [0], [0], [0]], dtype=torch.float) # configs: [0, 0, 0, 0], [1, 1, 1, 1]
x = torch.tensor([[0, 1], [0, 1], [0, 1], [0, 1]], dtype=torch.float) # configs: [0, 0, 0, 0], [1, 1, 1, 1]

y = torch.tensor([[3.0, 1.0]]) # ranks

data = Data(x=x, edge_index=edge_index.t().contiguous(), y=y)
data.train_mask = torch.tensor([1 for _ in range(len(x))])

In [81]:
adjacency = to_dense_adj(edge_index.t().contiguous())[0]
adjacency += torch.eye(len(adjacency))
adjacency

tensor([[1., 1., 1., 1.],
        [1., 1., 0., 0.],
        [1., 0., 1., 0.],
        [1., 0., 0., 1.]])

In [82]:
class VanillaGNN(torch.nn.Module):
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.gnn1 = VanillaGNNLayer(dim_in, dim_h)
        self.gnn2 = VanillaGNNLayer(dim_h, dim_h)
        self.out = torch.nn.Linear(dim_h, dim_out)
    
    def forward(self, x, adjacency):
        h = self.gnn1(x, adjacency)
        h = torch.relu(h)
        h = self.gnn2(h, adjacency)
        h = torch.relu(h)
        h = self.out(h)
        return torch.sum(h, dim=0)
    
    def fit(self, data, epochs):
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=5e-4)
        self.train()
        for epoch in range(epochs+1):
            optimizer.zero_grad()
            out = self(data.x, adjacency)
            loss = criterion(out, data.y)
            print("Loss:", loss)
            loss.backward()
            optimizer.step()


In [83]:
gnn = VanillaGNN(data.num_features, 16, data.num_features)
print(gnn)
gnn.fit(data, epochs=200)

VanillaGNN(
  (gnn1): VanillaGNNLayer(
    (linear): Linear(in_features=2, out_features=16, bias=False)
  )
  (gnn2): VanillaGNNLayer(
    (linear): Linear(in_features=16, out_features=16, bias=False)
  )
  (out): Linear(in_features=16, out_features=2, bias=True)
)
Loss: tensor(7.4973, grad_fn=<MseLossBackward0>)
Loss: tensor(4.5842, grad_fn=<MseLossBackward0>)
Loss: tensor(2.5534, grad_fn=<MseLossBackward0>)
Loss: tensor(1.2304, grad_fn=<MseLossBackward0>)
Loss: tensor(0.4593, grad_fn=<MseLossBackward0>)
Loss: tensor(0.1056, grad_fn=<MseLossBackward0>)
Loss: tensor(0.1012, grad_fn=<MseLossBackward0>)
Loss: tensor(0.3242, grad_fn=<MseLossBackward0>)
Loss: tensor(0.5694, grad_fn=<MseLossBackward0>)
Loss: tensor(0.6836, grad_fn=<MseLossBackward0>)
Loss: tensor(0.6446, grad_fn=<MseLossBackward0>)
Loss: tensor(0.4972, grad_fn=<MseLossBackward0>)
Loss: tensor(0.3210, grad_fn=<MseLossBackward0>)
Loss: tensor(0.1747, grad_fn=<MseLossBackward0>)
Loss: tensor(0.0696, grad_fn=<MseLossBackward0>)

  return F.mse_loss(input, target, reduction=self.reduction)


In [84]:
out = gnn(data.x, adjacency)
print(out)

tensor([3.0000, 1.0000], grad_fn=<SumBackward1>)
