In [3]:
!pip install torch



[0m

In [4]:
import torch
import networkx as nx
import numpy as np

# Load graph
G = nx.karate_club_graph()

# Number of nodes
num_nodes = G.number_of_nodes()

# One-hot features for each node
X = torch.eye(num_nodes)

# Get edge_index tensor: shape (2, num_edges)
edges = list(G.edges())
edge_index = torch.tensor(edges, dtype=torch.long).t()
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)


label_map = {'0': 0, '1': 1}
y = torch.tensor([label_map[G.nodes[i]['club']] for i in G.nodes], dtype=torch.long)


In [5]:
class GCNConv(torch.nn.Module):

    def __init__(self, num_features, output_dim):
        super().__init__()
        self.W0 = torch.nn.Parameter(torch.randn(num_features, 16))  # First layer
        self.W1 = torch.nn.Parameter(torch.randn(16, output_dim))    # Second layer

    def g_conv(self, x, w, edge_indices):
        num_nodes = x.size(0)
        A = torch.zeros((num_nodes, num_nodes))

        # Fill in adjacency matrix from edge indices
        for i in range(edge_indices.shape[1]):
            src = edge_indices[0, i]
            tgt = edge_indices[1, i]
            A[src, tgt] = 1

        # Add self-loops
        A += torch.eye(num_nodes)

        # Normalize A: D^(-1/2) A D^(-1/2)
        D = torch.diag(A.sum(1))
        D_inv_sqrt = torch.linalg.inv(torch.sqrt(D))
        A_hat = D_inv_sqrt @ A @ D_inv_sqrt

        # GCN operation: A_hat X W
        h = A_hat @ x @ w
        return h

    def forward(self, x, edge_index):
        h1 = self.g_conv(x, self.W0, edge_index).relu()
        h = self.g_conv(h1, self.W1, edge_index)
        return torch.softmax(h, dim=1)


In [6]:
model = GCNConv(num_features=X.shape[1], output_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

# Train
for epoch in range(200):
    model.train()
    out = model(X, edge_index)
    loss = loss_fn(out, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        pred = out.argmax(dim=1)
        acc = (pred == y).float().mean()
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}, Accuracy: {acc.item():.4f}")


Epoch 0, Loss: 0.8084, Accuracy: 0.2647
Epoch 10, Loss: 0.6597, Accuracy: 0.6176
Epoch 20, Loss: 0.5502, Accuracy: 0.7647
Epoch 30, Loss: 0.4803, Accuracy: 0.9118
Epoch 40, Loss: 0.4328, Accuracy: 0.9412
Epoch 50, Loss: 0.3986, Accuracy: 0.9412
Epoch 60, Loss: 0.3745, Accuracy: 0.9706
Epoch 70, Loss: 0.3596, Accuracy: 0.9706
Epoch 80, Loss: 0.3505, Accuracy: 0.9706
Epoch 90, Loss: 0.3448, Accuracy: 0.9706
Epoch 100, Loss: 0.3411, Accuracy: 0.9706
Epoch 110, Loss: 0.3382, Accuracy: 1.0000
Epoch 120, Loss: 0.3361, Accuracy: 1.0000
Epoch 130, Loss: 0.3343, Accuracy: 1.0000
Epoch 140, Loss: 0.3329, Accuracy: 1.0000
Epoch 150, Loss: 0.3317, Accuracy: 1.0000
Epoch 160, Loss: 0.3305, Accuracy: 1.0000
Epoch 170, Loss: 0.3295, Accuracy: 1.0000
Epoch 180, Loss: 0.3285, Accuracy: 1.0000
Epoch 190, Loss: 0.3276, Accuracy: 1.0000
