In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt


In [2]:
from torch_geometric.data import Data
def get_karate_club_data():
    G = nx.karate_club_graph()
    num_nodes = G.number_of_nodes()
    x = torch.eye(num_nodes, dtype=torch.float)
    edge_index = torch.tensor(list(G.edges), dtype=torch.long).t().contiguous()
    edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)  # undirected
 
    label_map = {'Mr. Hi': 0, 'Officer': 1}
    y = torch.tensor([label_map[G.nodes[i]['club']] for i in range(num_nodes)], dtype=torch.long)
 
    perm = torch.randperm(num_nodes)
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[perm[:20]] = True
    test_mask[perm[20:]] = True
 
    data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask)
    return data



In [7]:
'''import torch
import networkx as nx
from torch.utils.data import Dataset

class KarateClubDataset(Dataset):
    def __init__(self):
        G = nx.karate_club_graph()
        self.num_nodes = G.number_of_nodes()
        self.x = torch.eye(self.num_nodes)  # One-hot encoding

        # Get edge list 
        edges = np.array(G.edges())
        edge_index = np.array(edges).T
        edge_index = np.concatenate((edge_index, edge_index[::-1]), axis=1)
        self.edge_index = torch.tensor(edge_index, dtype=torch.long)

        self.labels = torch.zeros(self.num_nodes, dtype=torch.long)
        for i in range(self.num_nodes):
            # Get the club label from networkx (0 for Mr. Hi's group, 1 for Officer's group)
            self.labels[i] = G.nodes[i]['club'] == 'Officer'

    def __len__(self):
        return 1  

    def __getitem__(self, idx):
        return self.x, self.edge_index'''


In [8]:
dataset = KarateClubDataset()
loader = DataLoader(dataset, batch_size=1, shuffle=False)

# For testing
for x, edge_index in loader:
    print("Node features shape:", x.shape)
    print("Edge index shape:", edge_index.shape)


Node features shape: torch.Size([1, 34, 34])
Edge index shape: torch.Size([1, 2, 156])


In [6]:
class GCNConv(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, output_dim):
        super().__init__()
        self.W0 = torch.nn.Parameter(torch.randn(num_features, hidden_channels))
        self.W1 = torch.nn.Parameter(torch.randn(hidden_channels, output_dim))

    def g_conv(self, x, w, edge_indices):
        num_nodes = x.size(0)
        A = torch.zeros((num_nodes, num_nodes), device=x.device)
        
        A[edge_indices[0], edge_indices[1]] = 1
        A += torch.eye(num_nodes, device=x.device)  # Add self-connections
        D_inv_sqrt = torch.diag(torch.pow(A.sum(1), -0.5))
        A_hat = D_inv_sqrt @ A @ D_inv_sqrt
        return A_hat @ x @ w

    def forward(self, x, edge_index):
        h1 = self.g_conv(x, self.W0, edge_index).relu()
        h = self.g_conv(h1, self.W1, edge_index).softmax(dim=1)
        return h


In [9]:
import networkx as nx
data = get_karate_club_data()
model = GCNConv(num_features=data.num_node_features, hidden_channels=16, output_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
 
for epoch in range(300):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
 
    loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
 
    model.eval()
    _, pred = out.max(dim=1)
    correct = int((pred[data.test_mask] == data.y[data.test_mask]).sum())
    acc = correct / int(data.test_mask.sum())
 
    if epoch % 10 == 0:
        print(f"Epoch {epoch:3d} | Loss: {loss.item():.4f} | Test Acc: {acc:.4f}")

Epoch   0 | Loss: 0.8169 | Test Acc: 0.2857
Epoch  10 | Loss: 0.6389 | Test Acc: 0.6429
Epoch  20 | Loss: 0.5068 | Test Acc: 0.8571
Epoch  30 | Loss: 0.4316 | Test Acc: 0.8571
Epoch  40 | Loss: 0.3912 | Test Acc: 0.8571
Epoch  50 | Loss: 0.3675 | Test Acc: 0.9286
Epoch  60 | Loss: 0.3526 | Test Acc: 0.9286
Epoch  70 | Loss: 0.3420 | Test Acc: 0.9286
Epoch  80 | Loss: 0.3350 | Test Acc: 0.9286
Epoch  90 | Loss: 0.3302 | Test Acc: 0.9286
Epoch 100 | Loss: 0.3268 | Test Acc: 0.9286
Epoch 110 | Loss: 0.3242 | Test Acc: 0.9286
Epoch 120 | Loss: 0.3222 | Test Acc: 0.9286
Epoch 130 | Loss: 0.3206 | Test Acc: 0.9286
Epoch 140 | Loss: 0.3194 | Test Acc: 0.9286
Epoch 150 | Loss: 0.3185 | Test Acc: 0.9286
Epoch 160 | Loss: 0.3177 | Test Acc: 0.9286
Epoch 170 | Loss: 0.3171 | Test Acc: 0.9286
Epoch 180 | Loss: 0.3167 | Test Acc: 0.9286
Epoch 190 | Loss: 0.3163 | Test Acc: 0.9286
Epoch 200 | Loss: 0.3160 | Test Acc: 0.9286
Epoch 210 | Loss: 0.3157 | Test Acc: 0.9286
Epoch 220 | Loss: 0.3155 | Test 