In [None]:
import torch

In [None]:
!pip install torch_geometric
import torch_geometric as pyg
from torch_geometric.datasets import Planetoid



In [None]:
# Load the Cora dataset through PyG's Planetoid class, stored under /tmp/Cora.
dataset = Planetoid(name='Cora', root='/tmp/Cora')

# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
class DeepSet(torch.nn.Module):
    """Dense message-passing layer combining an aggregated message with a self term.

    For a 2-D feature matrix ``X`` and a 2-D matrix ``A`` the layer computes::

        out_activation(
            out_network(A @ in_activation(in_network(X))) @ W_message.T
            + X @ W_alpha.T
        )

    Args:
        in_channels: Number of columns of ``X``.
        out_channels: Number of columns of the result.
        message_channels: Width of the message produced by ``out_network``.
            Defaults to ``in_channels``.
        hidden_channels: Width between ``in_network`` and ``out_network``.
            Defaults to ``in_channels // 2``, but never less than 1 so the
            default networks are not zero-width.
        in_network: Optional replacement for the default
            ``Linear(in_channels, hidden_channels)``.
        out_network: Optional replacement for the default
            ``Linear(hidden_channels, message_channels)``.
        in_activation: Callable applied to the output of ``in_network``.
        out_activation: Callable applied to the final sum.
    """

    def __init__(self, in_channels : int, out_channels : int, message_channels=None,
                 hidden_channels=None, in_network=None, out_network=None,
                 in_activation=torch.nn.ReLU(), out_activation=torch.sigmoid):
        super(DeepSet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if hidden_channels is None:
            # in_channels // 2 is 0 for in_channels == 1; clamp so the default
            # Linear layers always carry at least one hidden unit.
            self.hidden_channels = max(1, in_channels // 2)
        else:
            self.hidden_channels = hidden_channels
        if message_channels is None:
            self.message_channels = in_channels
        else:
            self.message_channels = message_channels
        if in_network is None:
            self.in_network = torch.nn.Linear(in_channels, self.hidden_channels)
        else:
            self.in_network = in_network
        if out_network is None:
            self.out_network = torch.nn.Linear(self.hidden_channels, self.message_channels)
        else:
            self.out_network = out_network
        # Values are filled in by reset_parameters(), so allocate uninitialised
        # storage instead of drawing random numbers that are discarded.
        self.W_message = torch.nn.Parameter(torch.empty(self.out_channels, self.message_channels))
        self.W_alpha = torch.nn.Parameter(torch.empty(self.out_channels, self.in_channels))
        self.in_activation = in_activation
        self.out_activation = out_activation
        self.reset_parameters()

    @staticmethod
    def _reset_network(network):
        """Reset every submodule of ``network`` that exposes ``reset_parameters``.

        Containers such as ``torch.nn.Sequential`` have no ``reset_parameters``
        of their own, so walk ``modules()`` (which includes ``network`` itself)
        and reset whatever supports it. Plain callables are left untouched.
        """
        if not isinstance(network, torch.nn.Module):
            return
        for module in network.modules():
            reset = getattr(module, 'reset_parameters', None)
            if callable(reset):
                reset()

    def reset_parameters(self):
        """Re-initialise both networks and Xavier-initialise the two weight matrices."""
        self._reset_network(self.in_network)
        self._reset_network(self.out_network)
        torch.nn.init.xavier_uniform_(self.W_message)
        torch.nn.init.xavier_uniform_(self.W_alpha)

    def forward(self, X, A):
        """Apply the layer.

        Args:
            X: 2-D tensor with ``in_channels`` columns (required by ``torch.mm``
                against ``W_alpha``).
            A: 2-D tensor whose column count equals the row count of ``X``;
                it mixes the rows of the transformed features.

        Returns:
            Tensor with one row per row of ``X`` and ``out_channels`` columns.
        """
        alpha = self.in_network(X)
        alpha = self.in_activation(alpha)
        # Aggregate transformed rows according to A.
        alpha = torch.mm(A, alpha)
        message = self.out_network(alpha)

        # Aggregated-message term plus a linear term on the untouched input.
        return self.out_activation(torch.mm(message, self.W_message.t()) + torch.mm(X, self.W_alpha.t()))



In [None]:
# Keep a CPU fallback: hard-coding 'cuda' makes every later `.to(device)` fail
# on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Dense adjacency matrix: A[i, j] counts how many times the pair (i, j)
# appears as a column of edge_index.
num_nodes = dataset.x.shape[0]
edge_index = dataset.edge_index
A = torch.zeros(num_nodes, num_nodes, requires_grad=False)
# One vectorised scatter instead of a Python loop over every edge;
# accumulate=True sums repeated (i, j) pairs exactly like a per-edge `+= 1`.
A.index_put_(
    (edge_index[0], edge_index[1]),
    torch.ones(edge_index.shape[1], dtype=A.dtype),
    accumulate=True,
)
A = A.to(device)
# Node feature matrix on the same device as A.
H = dataset.x.to(device)

In [None]:
class SemiSupervisedClassifier(torch.nn.Module):
    """Two stacked DeepSet layers producing per-row class log-probabilities.

    Args:
        input_embed_dim: Number of input feature columns.
        num_classes: Number of output columns (one per class).
        latent_dim: Width of the intermediate representation. Defaults to
            ``input_embed_dim``.
    """

    def __init__(self, input_embed_dim : int,  num_classes : int, latent_dim = None):
        super(SemiSupervisedClassifier, self).__init__()
        if latent_dim is None:
            latent_dim = input_embed_dim
        self.gnn1 = DeepSet(input_embed_dim, latent_dim)
        # The output layer must emit log-probabilities: the notebook trains
        # with F.nll_loss, and DeepSet's default sigmoid output is not a
        # log-distribution (it yields a negative, meaningless NLL value).
        self.gnn2 = DeepSet(latent_dim, num_classes,
                            out_activation=torch.nn.LogSoftmax(dim=1))

    def forward(self, H : torch.Tensor, A : torch.Tensor):
        """Run both layers with the same matrix ``A``.

        Returns:
            Tensor with ``num_classes`` columns holding log-softmax scores
            normalised along dim 1.
        """
        return self.gnn2(self.gnn1(H, A), A)

In [None]:
# Place the model on the notebook-wide `device` (the same one A and H were
# moved to) instead of a hard-coded CUDA device.
model = SemiSupervisedClassifier(dataset.x.shape[1], dataset.num_classes).to(device)

In [None]:
import torch.nn.functional as F

# Adam with L2 regularisation (weight_decay) over every model parameter.
optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4, lr=0.01)

# The training mask and its labels do not change between epochs.
train_mask = dataset.train_mask
train_labels = dataset.y[train_mask].to(device)

model.train()
for epoch in range(1):
    optimizer.zero_grad()
    out = model(H, A)
    # F.nll_loss expects log-probabilities, so the model's final activation
    # must be a log-softmax for this value to be a proper NLL.
    loss = F.nll_loss(out[train_mask], train_labels)
    print(loss)
    loss.backward()
    optimizer.step()

tensor(-0.5769, device='cuda:0', grad_fn=<NllLossBackward0>)


In [None]:
model.eval()
# Inference only: disable autograd so no graph is recorded for the forward
# pass over the dense adjacency matrix.
with torch.no_grad():
    pred = model(H, A).argmax(dim=1)
test_mask = dataset.test_mask
# Fraction of test-mask rows whose arg-max class equals the label.
correct = (pred[test_mask] == dataset.y[test_mask].to(device)).sum()
acc = int(correct) / int(test_mask.sum())
print(f'Accuracy: {acc:.4f}')

Accuracy: 0.1300
