In [1]:
import torch

In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [3]:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [13]:
class SGCN(torch.nn.Module):
    def __init__(self, input_embed_dim : int, output_embed_dim : int, k : int):
        super(SGCN, self).__init__()
        # The input embedding dimensions.
        self.input_embed_dim = input_embed_dim
        # The output embedding dimensions.
        self.output_embed_dim = output_embed_dim
        # The non linearity
        self.k = k
        self.weights = torch.nn.Parameter(torch.Tensor(input_embed_dim, output_embed_dim))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weights)

    def forward(self, H : torch.Tensor, A : torch.Tensor):
        if A.shape[0] != A.shape[1]:
          raise Exception("GNN layers expects an Adjecancy Matrix(square).")
        if H.shape[0] != A.shape[0]:
          raise Exception("Shape mismatch between H and A. expected shape (", A.shape[0], self.input_embed_dim, ") got (", H.shape[0], ", ", H.shape[1], ")")
        if H.shape[1] != self.input_embed_dim:
          raise Exception("H expected shape (", A.shape[0], ", ", self.input_embed_dim, ") got (", H.shape[0], ", ", H.shape[1], ")")
        A = A + torch.eye(A.shape[0])
        A = (A > 0) * 1.
        D = torch.sum(A, dim=1) + 1e-6
        D_sqrt_inv = torch.diag(torch.pow(torch.sqrt(D), -1))
        S = torch.matmul(torch.matmul(D_sqrt_inv, A), D_sqrt_inv)
        S_k = torch.linalg.matrix_power(S, self.k)
        A_cap = torch.matmul(S_k, H)

        return torch.matmul(A_cap, self.weights)

In [14]:
A = torch.zeros(dataset.x.shape[0], dataset.x.shape[0], requires_grad=False)
edge_index = dataset.edge_index
for i in range (0, edge_index.shape[1]):
  A[edge_index[0, i], edge_index[1, i]] += 1
A = A
H = dataset.x

In [15]:
class SemiSupervisedClassifier(torch.nn.Module):
    def __init__(self, input_embed_dim : int,  num_classes : int, latent_dim = None):
        super(SemiSupervisedClassifier, self).__init__()
        if latent_dim is None:
          latent_dim = input_embed_dim
        self.gnn1 = SGCN(input_embed_dim, num_classes, 3)

    def forward(self, H : torch.Tensor, A : torch.Tensor):
        return self.gnn1(H, A)

In [29]:
model = SemiSupervisedClassifier(dataset.x.shape[1], dataset.num_classes)

In [30]:
import torch.nn.functional as F
optimizer = torch.optim.Adam(model.parameters(), lr=0.2, weight_decay=5e-4)

model.train()
for epoch in range(100):
    optimizer.zero_grad()
    out = model(H, A)
    loss = F.nll_loss(out[dataset.train_mask], dataset.y[dataset.train_mask])
    print(loss)
    loss.backward()
    optimizer.step()

tensor(0.0021, grad_fn=<NllLossBackward0>)
tensor(-3.6099, grad_fn=<NllLossBackward0>)
tensor(-7.1997, grad_fn=<NllLossBackward0>)
tensor(-10.7637, grad_fn=<NllLossBackward0>)
tensor(-14.2988, grad_fn=<NllLossBackward0>)
tensor(-17.8027, grad_fn=<NllLossBackward0>)
tensor(-21.2729, grad_fn=<NllLossBackward0>)
tensor(-24.7069, grad_fn=<NllLossBackward0>)
tensor(-28.1022, grad_fn=<NllLossBackward0>)
tensor(-31.4565, grad_fn=<NllLossBackward0>)
tensor(-34.7683, grad_fn=<NllLossBackward0>)
tensor(-38.0362, grad_fn=<NllLossBackward0>)
tensor(-41.2595, grad_fn=<NllLossBackward0>)
tensor(-44.4378, grad_fn=<NllLossBackward0>)
tensor(-47.5705, grad_fn=<NllLossBackward0>)
tensor(-50.6572, grad_fn=<NllLossBackward0>)
tensor(-53.6974, grad_fn=<NllLossBackward0>)
tensor(-56.6905, grad_fn=<NllLossBackward0>)
tensor(-59.6362, grad_fn=<NllLossBackward0>)
tensor(-62.5342, grad_fn=<NllLossBackward0>)
tensor(-65.3849, grad_fn=<NllLossBackward0>)
tensor(-68.1887, grad_fn=<NllLossBackward0>)
tensor(-70.946

In [31]:
model.eval()
pred = model(H, A).argmax(dim=1)
correct = (pred[dataset.test_mask] == dataset.y[dataset.test_mask]).sum()
acc = int(correct) / int(dataset.test_mask.sum())
print(f'Accuracy: {acc:.4f}')

Accuracy: 0.7040
