In [8]:
import torch
import torch.nn as nn
# Adjacency matrix of an undirected 4-node path graph: 0 - 1 - 2 - 3.
A = torch.tensor([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0]
], dtype=torch.float32)


# Node features (4 nodes, 3 features each)
X = torch.tensor([
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 1.0],
    [1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0]
], dtype=torch.float32)

# Cheap invariants: the symmetric normalisation below assumes an undirected
# graph, and every node needs exactly one feature row.
assert torch.equal(A, A.T), "adjacency matrix must be symmetric"
assert A.size(0) == X.size(0), "one feature row per node is required"

# Add self-loops so each node keeps its own features during aggregation.
I = torch.eye(A.size(0))
A_hat = A + I

# Degree of every node in the self-looped graph. Entries of A are 0/1 and the
# self-loop adds 1, so every degree is >= 1 and the reciprocal sqrt is finite.
deg_hat = torch.sum(A_hat, dim=1)
D_hat = torch.diag(deg_hat)

# D_hat is diagonal, so D_hat^{-1/2} is just the elementwise reciprocal square
# root of the degrees; a dense matrix inverse is unnecessary.
D_hat_inv_sqrt = torch.diag(torch.rsqrt(deg_hat))

# Symmetrically normalised adjacency: D_hat^{-1/2} (A + I) D_hat^{-1/2}
A_norm = D_hat_inv_sqrt @ A_hat @ D_hat_inv_sqrt


In [15]:
class GCN(nn.Module):
  """Two-layer graph convolution driven by a caller-supplied adjacency matrix.

  Each layer multiplies the node representations by `A_norm` (aggregation)
  and then by a learned weight matrix (projection). A ReLU follows the first
  layer only.
  """

  def __init__(self, in_dim, hidden_dim, out_dim):
    super().__init__()
    # Weights are drawn from a standard normal; W1 is sampled before W2.
    first_weights = torch.randn(in_dim, hidden_dim)
    second_weights = torch.randn(hidden_dim, out_dim)
    self.W1 = nn.Parameter(first_weights)
    self.W2 = nn.Parameter(second_weights)

  def forward(self, A_norm , X):
    """Return per-node embeddings with `out_dim` columns."""
    # 1st aggregation and projection, followed by the non-linearity.
    neighbourhood = A_norm @ X
    hidden = (neighbourhood @ self.W1).relu()
    # 2nd aggregation and projection (no activation on the output).
    aggregated = A_norm @ hidden
    return aggregated @ self.W2
  

class node_classifer(nn.Module):
  """Two-layer perceptron head: Linear -> ReLU -> Linear.

  Maps inputs with `in_dim` columns to `num_classes` unnormalised scores
  (logits) per row.
  """

  def __init__(self, in_dim, hidden_dim, num_classes):
    super().__init__()
    # nn.Linear includes a bias term by default.
    self.l1 = nn.Linear(in_features=in_dim, out_features=hidden_dim)
    self.l2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)

  def forward(self,H):
    """Return the logits for every row of `H`."""
    activated = self.l1(H).relu()
    return self.l2(activated)

class classify_model(nn.Module):
  """End-to-end node classifier: a GCN encoder followed by an MLP head.

  `aggregator` turns node features into embeddings with `proj_emb` columns;
  `classifier` turns those embeddings into `num_classes` logits per node.
  `hidden_dim` is used as the hidden width of both sub-modules.
  """

  def __init__(self, in_dim, hidden_dim, proj_emb, num_classes):
    super().__init__()
    self.aggregator = GCN(in_dim, hidden_dim, proj_emb)
    self.classifier = node_classifer(proj_emb, hidden_dim, num_classes)

  def forward(self, A_norm, X):
    """Return the logits produced by the head on the encoder's embeddings."""
    embeddings = self.aggregator(A_norm, X)
    return self.classifier(embeddings)

In [12]:
# Class label of each node (size N).
labels = torch.tensor([0, 1, 0, 2])

# Boolean node masks: nodes 0-1 are used for training, nodes 2-3 for validation.
train_mask = torch.tensor([True, True, False, False])
val_mask = ~train_mask
# No test mask is defined: the train/val masks already cover all four nodes.


In [16]:
# Seed the global RNG *before* the model is built: its weights are sampled at
# construction time, so without a seed every kernel restart yields a different
# loss curve and the printed log cannot be reproduced.
SEED = 0
torch.manual_seed(SEED)

model = classify_model(in_dim = 3, hidden_dim= 4, proj_emb= 4, num_classes= 3)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

# The mode is never switched inside the loop, so setting it once is equivalent
# to setting it on every epoch.
model.train()

for epoch in range(200):
  optimizer.zero_grad()

  # Full-batch forward pass over the whole graph; the loss only looks at the
  # nodes selected by train_mask.
  out = model(A_norm, X)        # (N, C)
  loss = loss_fn(out[train_mask], labels[train_mask])

  loss.backward()
  optimizer.step()

  if epoch % 20 == 0:
    # Accuracy is computed from the logits produced before this epoch's
    # optimizer step, i.e. the same logits the reported loss is based on.
    pred = out.argmax(dim=1)
    acc = (pred[train_mask] == labels[train_mask]).float().mean()
    print(f"Epoch {epoch} | Loss {loss:.4f} | Train Acc {acc:.3f}")

Epoch 0 | Loss 0.9500 | Train Acc 0.500
Epoch 20 | Loss 0.9241 | Train Acc 0.500
Epoch 40 | Loss 0.9061 | Train Acc 0.500
Epoch 60 | Loss 0.8992 | Train Acc 0.500
Epoch 80 | Loss 0.8925 | Train Acc 0.500
Epoch 100 | Loss 0.8861 | Train Acc 0.500
Epoch 120 | Loss 0.8800 | Train Acc 0.500
Epoch 140 | Loss 0.8742 | Train Acc 0.500
Epoch 160 | Loss 0.8685 | Train Acc 0.500
Epoch 180 | Loss 0.8631 | Train Acc 0.500
