In [2]:
!pip install torch_geometric
import torch_geometric as pyg
import torch
from torch_geometric.datasets import Planetoid
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, is_undirected
from torch_geometric.utils import softmax

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [3]:
dataset = Planetoid(root='/tmp/Cora', name='Cora')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [4]:
edge_index = dataset.edge_index
x = dataset.x

In [1]:
class Gattn(MessagePassing):
  def __init__(self, in_channels, out_channels, activation=torch.nn.ReLU()):
    super().__init__(aggr='add')
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.lin = Linear(in_channels, out_channels, bias=False)
    self.att = Linear(2 * out_channels, 1, bias=False)
    self.reset_parameters()
    self.att_activation  = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)
    self.activation = activation

  def reset_parameters(self):
    self.lin.reset_parameters()
    self.att.reset_parameters()

  def forward(self, x, edge_index):
    edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
    x = self.lin(x)

    att_features = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=1)
    att_coeff = self.att(att_features).squeeze(-1)
    att_coeff = torch.exp(self.att_activation(att_coeff))

    att_coeff = softmax(att_coeff, edge_index[1])

    out = self.propagate(edge_index, x=x, att_coeff=att_coeff)

    return self.activation(out)

  def message(self, x_j, att_coeff):

    return x_j * att_coeff.repeat(x_j.shape[1],1).t()



NameError: name 'MessagePassing' is not defined

In [16]:
import torch
from torch_geometric.nn import MessagePassing
from torch.nn import Linear
from torch_geometric.utils import add_self_loops, softmax

class Gattn(MessagePassing):
    def __init__(self, in_channels, out_channels, dropout=0.3, activation=torch.nn.ReLU()):
        super().__init__(aggr='add')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dropout = dropout

        # Linear transformations
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.att = Linear(2 * out_channels, 1, bias=False)

        # Activations
        self.att_activation = torch.nn.LeakyReLU(negative_slope=0.2)
        self.activation = activation

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.att.reset_parameters()

    def forward(self, x, edge_index):
        # Add self-loops
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Apply linear transformation
        x = self.lin(x)

        # Compute attention
        x_i = x[edge_index[0]]  # Source node features
        x_j = x[edge_index[1]]  # Target node features

        # Attention coefficients
        alpha = torch.cat([x_i, x_j], dim=-1)
        alpha = self.att(alpha)
        alpha = self.att_activation(alpha)
        alpha = softmax(alpha.squeeze(), edge_index[1])

        # Apply dropout to attention coefficients
        alpha = torch.nn.functional.dropout(alpha, p=self.dropout, training=self.training)

        # Propagate
        out = self.propagate(edge_index, x=x, alpha=alpha)

        return self.activation(out)

    def message(self, x_j, alpha):
        # Apply attention coefficients
        return x_j * alpha.view(-1, 1)

In [17]:
class SemiSupervisedClassifier(torch.nn.Module):
    def __init__(self, input_embed_dim : int,  num_classes : int, latent_dim = None):
        super(SemiSupervisedClassifier, self).__init__()
        if latent_dim is None:
          latent_dim = input_embed_dim
        self.gattn1a = Gattn(input_embed_dim, latent_dim)
        self.gattn1b = Gattn(input_embed_dim, latent_dim)
        self.gattn2a = Gattn(2 * latent_dim, num_classes)
        self.gattn2b = Gattn(2 * latent_dim, num_classes)
    def forward(self, H : torch.Tensor, A : torch.Tensor):
        x = self.gattn1a(H, A)
        x = torch.cat([self.gattn1b(H, A), x], dim=1)
        x1 = self.gattn2a(x, A)
        x1 = torch.unsqueeze(x1, 0)
        x2 = self.gattn2b(x, A)
        x2 = torch.unsqueeze(x2, 0)
        x = torch.mean(torch.cat([x1, x2]), dim=0)

        return x

In [66]:
model = SemiSupervisedClassifier(dataset.x.shape[1], dataset.num_classes)

In [67]:
import torch.nn.functional as F
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

model.train()
for epoch in range(15):
    optimizer.zero_grad()
    out = model(x, edge_index)
    loss = F.nll_loss(out[dataset.train_mask], dataset.y[dataset.train_mask])
    loss.backward()
    optimizer.step()

In [68]:
model.eval()
pred = model(x, edge_index).argmax(dim=1)
correct = (pred[dataset.test_mask] == dataset.y[dataset.test_mask]).sum()
acc = int(correct) / int(dataset.test_mask.sum())
print(f'Accuracy: {acc:.4f}')

Accuracy: 0.7120
