# Class 15: Machine Learning 2 — Building up to Graph Convolutional Networks

## Today's Goals
1. Think about how graph neural networks operate at a high level.
2. Understand what the building blocks of a graph neural network are.
3. Play with a graph neural network's hyperparameters.

In [42]:
from torch.nn import Linear
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import CitationFull

# Load the Cora citation graph from PyG's CitationFull collection, using the
# shared course data directory as the dataset root.
DATA_ROOT = '/courses/PHYS7332.202510/shared/data/'
dataset = CitationFull(root=DATA_ROOT, name='Cora')

In [70]:
import numpy as np
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import RandomNodeSplit

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class GCN(torch.nn.Module):
    """Two GCN layers followed by a linear classification head.

    Calling ``GCN()`` with no arguments reproduces the original architecture:
    ``dataset.num_features -> 64 -> 16 -> dataset.num_classes``. Every size can
    be overridden, so hyperparameters can be explored without editing the class.

    Args:
        in_channels: Size of each input node-feature vector. When None, uses
            the module-level ``dataset.num_features``.
        num_classes: Number of output classes (logit width). When None, uses
            the module-level ``dataset.num_classes``.
        hidden_channels: Output width of the first GCN layer.
        embedding_channels: Output width of the second GCN layer, i.e. the
            node-embedding size returned by ``forward``.
    """

    def __init__(self, in_channels=None, num_classes=None,
                 hidden_channels=64, embedding_channels=16):
        super().__init__()
        # Fall back to the notebook's global dataset so the zero-argument
        # ``GCN()`` call used elsewhere keeps working unchanged.
        if in_channels is None:
            in_channels = dataset.num_features
        if num_classes is None:
            num_classes = dataset.num_classes
        self.gcn1 = GCNConv(in_channels, hidden_channels)
        self.gcn2 = GCNConv(hidden_channels, embedding_channels)
        self.out = Linear(embedding_channels, num_classes)

    def forward(self, x, edge_index):
        """Run both graph convolutions and the classifier head.

        Args:
            x: Node-feature matrix passed to the first GCNConv.
            edge_index: Graph connectivity passed to both GCNConv layers.

        Returns:
            A tuple ``(h2, z)``. ``h2`` is the ReLU-activated output of the
            second GCN layer (the node embeddings). ``z`` is the raw,
            un-normalised class scores from the linear head. No softmax is
            applied, which suits ``torch.nn.CrossEntropyLoss``.
        """
        h1 = self.gcn1(x, edge_index).relu()
        h2 = self.gcn2(h1, edge_index).relu()
        z = self.out(h2)
        return h2, z

# Randomly assign nodes to train/val/test masks: 15% validation, 15% test,
# and the remainder ('train_rest') for training. `dataset[0]` is the public
# accessor for the graph; `dataset.data` exposes internal storage and is
# discouraged by PyG.
graph = dataset[0]
splits = RandomNodeSplit(split='train_rest', num_val=0.15, num_test=0.15)(graph)

model = GCN()
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
# NOTE: the masks in `splits` index the nodes of the single Cora graph. The
# training loop below therefore assumes every batch is that whole graph,
# which holds when the dataset contains exactly one graph.
loader = DataLoader(dataset, batch_size=128)

# Loop-invariant tensors: move them to the model's device once, not per epoch.
train_mask = splits.train_mask.to(device)
val_mask = splits.val_mask.to(device)
x_all = splits.x.to(device)
edge_index_all = splits.edge_index.to(device)
y_all = splits.y.to(device)

for epoch in range(1, 101):
    model.train()
    total_loss = 0
    for batch in loader:
        optimizer.zero_grad()
        h2, z = model(batch.x.to(device), batch.edge_index.to(device))
        loss = criterion(z[train_mask], batch.y.to(device)[train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Validation on the full graph. Gradients are not needed here. Labels come
    # from the same graph the masks were built on, not from whatever `batch`
    # the training loop happened to finish with.
    model.eval()
    with torch.no_grad():
        val_h, val_z = model(x_all, edge_index_all)
    val_z = val_z[val_mask]
    ans = val_z.argmax(dim=1)
    ys = y_all[val_mask]
    # Accuracy is measured once per epoch, so it is not averaged over the
    # number of training batches.
    accuracy = torch.eq(ans, ys).float().mean()
    loss = total_loss / len(loader)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}')

Epoch: 001, Loss: 4.2583, Accuracy: 0.0478
Epoch: 002, Loss: 4.1405, Accuracy: 0.0276
Epoch: 003, Loss: 3.9989, Accuracy: 0.0623
Epoch: 004, Loss: 3.7828, Accuracy: 0.1479
Epoch: 005, Loss: 3.5689, Accuracy: 0.1563
Epoch: 006, Loss: 3.3722, Accuracy: 0.2587
Epoch: 007, Loss: 3.1978, Accuracy: 0.3254
Epoch: 008, Loss: 3.0163, Accuracy: 0.3567
Epoch: 009, Loss: 2.8457, Accuracy: 0.3772
Epoch: 010, Loss: 2.6924, Accuracy: 0.4136
Epoch: 011, Loss: 2.5250, Accuracy: 0.4436
Epoch: 012, Loss: 2.3838, Accuracy: 0.4732
Epoch: 013, Loss: 2.2103, Accuracy: 0.4938
Epoch: 014, Loss: 2.0728, Accuracy: 0.5251
Epoch: 015, Loss: 1.9217, Accuracy: 0.5487
Epoch: 016, Loss: 1.7906, Accuracy: 0.5800
Epoch: 017, Loss: 1.6499, Accuracy: 0.5901
Epoch: 018, Loss: 1.5480, Accuracy: 0.6069
Epoch: 019, Loss: 1.4414, Accuracy: 0.6154
Epoch: 020, Loss: 1.3672, Accuracy: 0.6295
Epoch: 021, Loss: 1.2793, Accuracy: 0.6443
Epoch: 022, Loss: 1.2053, Accuracy: 0.6474
Epoch: 023, Loss: 1.1450, Accuracy: 0.6564
Epoch: 024,