In [27]:
import torch
from torch_geometric.datasets import Planetoid

In [29]:
dataset = Planetoid(root='/tmp/Cora', name='Cora')

In [31]:
device = torch.device('cuda')

In [30]:
class GCN(torch.nn.Module):
    def __init__(self, input_embed_dim : int, output_embed_dim : int, activation=torch.nn.Sigmoid):
        super(GCN, self).__init__()
        # The k-1 embedding dimensions.
        self.input_embed_dim = input_embed_dim
        # The k embedding dimensions.
        self.output_embed_dim = output_embed_dim
        # Th enon linearity
        self.activation = activation
        self.weights = torch.nn.Parameter(torch.rand(self.input_embed_dim, self.output_embed_dim))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weights)

    def forward(self, H : torch.Tensor, A : torch.Tensor):
        if A.shape[0] != A.shape[1]:
          raise Exception("GNN layers expects an Adjecancy Matrix(square).")
        if H.shape[0] != A.shape[0]:
          raise Exception("Shape mismatch between H and A. expected shape (", A.shape[0], self.input_embed_dim, ") got (", H.shape[0], ", ", H.shape[1], ")")
        if H.shape[1] != self.input_embed_dim:
          raise Exception("H expected shape (", A.shape[0], ", ", self.input_embed_dim, ") got (", H.shape[0], ", ", H.shape[1], ")")
        A = A + torch.eye(A.shape[0]).to(torch.device('cuda'))
        D = torch.sum(A, dim=1)
        D_sqrt_inv = torch.diag(torch.pow(torch.sqrt(D), -1))
        A_cap = torch.matmul(torch.matmul(D_sqrt_inv, A), D_sqrt_inv).to(torch.device('cuda'))
        return self.activation(torch.matmul(A_cap, torch.matmul(H, self.weights)))

In [32]:
A = torch.zeros(dataset.x.shape[0], dataset.x.shape[0], requires_grad=False)
edge_index = dataset.edge_index
for i in range (0, edge_index.shape[1]):
  A[edge_index[0, i], edge_index[1, i]] += 1
A = A.to(device)
H = dataset.x.to(device)

In [33]:
Ai = A
Hi = H

In [34]:
class SemiSupervisedClassifier(torch.nn.Module):
    def __init__(self, input_embed_dim : int,  num_classes : int, latent_dim = None):
        super(SemiSupervisedClassifier, self).__init__()
        if latent_dim is None:
          latent_dim = input_embed_dim
        self.gnn1 = GCN(input_embed_dim, latent_dim, torch.nn.ReLU())
        self.gnn2 = GCN(latent_dim, num_classes, torch.nn.Softmax(dim=1))

    def forward(self, H : torch.Tensor, A : torch.Tensor):
        return self.gnn2(self.gnn1(H, A), A)

In [35]:
model = SemiSupervisedClassifier(dataset.x.shape[1], dataset.num_classes).to(torch.device('cuda'))

In [37]:
import torch.nn.functional as F

In [38]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(60):
    optimizer.zero_grad()
    out = model(H, A)
    loss = F.nll_loss(out[dataset.train_mask], dataset.y[dataset.train_mask].to(device))
    print(loss)
    loss.backward()
    optimizer.step()

tensor(-0.1431, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.4026, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.7663, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9148, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9702, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9836, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9887, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9910, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9921, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9919, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9916, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9919, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9918, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9910, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9896, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9884, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(-0.9860, device='

In [39]:
model.eval()
pred = model(H, A).argmax(dim=1)
correct = (pred[dataset.test_mask] == dataset.y[dataset.test_mask].to(device)).sum()
acc = int(correct) / int(dataset.test_mask.sum())
print(f'Accuracy: {acc:.4f}')

Accuracy: 0.8030
