## Tutorial 4: Minimalistic re-implementation of Graph Convolutional Networks (GCN) in PyG

In this tutorial, we re-implement the semi-supervised node classification model from *Semi-Supervised Classification with Graph Convolutional Networks* (GCN) by [Kipf and Welling (2017)](https://arxiv.org/abs/1609.02907) using PyTorch Geometric (PyG). The code below is inspired by an open-source implementation available [here](https://github.com/ki-ljl/PyG-GCN/tree/main).
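As a reference point, the paper's layer-wise propagation rule is $H^{(l+1)} = \sigma\big(\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}H^{(l)}W^{(l)}\big)$, where $\hat{A} = A + I$ is the adjacency matrix with self-loops added and $\hat{D}$ is the diagonal degree matrix of $\hat{A}$. The sketch below spells this out with dense matrices in plain PyTorch on a hypothetical three-node path graph; `normalized_adjacency` and `dense_gcn_layer` are illustrative helpers written for this example, not PyG functions. PyG's `GCNConv` applies the same symmetric normalization by default, but computes it with sparse message passing instead of a dense $N \times N$ matrix.

```python
import torch


def normalized_adjacency(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Return D^{-1/2} (A + I) D^{-1/2} as a dense matrix (toy-sized graphs only)."""
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0  # undirected edges are listed in both directions
    adj = adj + torch.eye(num_nodes)         # add self-loops: A_hat = A + I
    deg = adj.sum(dim=1)                     # node degrees in A_hat (always >= 1)
    d_inv_sqrt = deg.pow(-0.5)
    # Scale row i by d_i^{-1/2} and column j by d_j^{-1/2}
    return d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]


def dense_gcn_layer(x: torch.Tensor, edge_index: torch.Tensor,
                    weight: torch.Tensor) -> torch.Tensor:
    """One GCN layer without activation: A_norm @ X @ W."""
    a_norm = normalized_adjacency(edge_index, x.size(0))
    return a_norm @ x @ weight


# Hypothetical toy graph: path 0 - 1 - 2
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
x = torch.eye(3)           # one-hot node features
weight = torch.ones(3, 2)  # hypothetical 3 -> 2 weight matrix
h = torch.relu(dense_gcn_layer(x, edge_index, weight))
print(h.shape)  # torch.Size([3, 2])
```

The dense form needs $O(N^2)$ memory, so it only suits toy graphs; for Cora-sized graphs the sparse `edge_index` representation used in the rest of this tutorial is the practical choice.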

In [1]:
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.nn.models import GCN

from tqdm.auto import tqdm

In [2]:
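def get_dataset(name: str) -> Planetoid:
    """Download (if needed) and load one of the Planetoid citation datasets."""
    assert name in ['Cora', 'CiteSeer', 'PubMed'], f'Unknown dataset: {name}'
    dataset = Planetoid(root=f'/tmp/{name}', name=name)
    return dataset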

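def train(model, data,
          num_epochs: int = 200,
          device: str = 'cpu'):
    """Full-batch training; the loss only uses the labelled train_mask nodes."""
    model = model.to(device)
    data = data.to(device)

    # The original paper used lr=0.01; this tutorial uses a smaller lr=0.001
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()  # enable dropout and batch-norm batch statistics

    pbar = tqdm(range(num_epochs), ascii=' =', leave=True)
    for epoch in pbar:
        opt.zero_grad()
        # Forward pass over the whole graph, loss on the training nodes only
        out = model(data.x, data.edge_index)
        loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        opt.step()

        # Show the current loss in the progress bar
        pbar.set_description(f'Epoch {epoch:03d} loss {loss.item():.4f}')

    # nn.Module.to() works in place, so the caller's model is moved back to the CPU
    model.to('cpu')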

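@torch.no_grad()  # no gradients are needed for evaluation
def test(model, data):
    """Print the classification accuracy on the test_mask nodes."""
    model.eval()  # disable dropout and use running batch-norm statistics
    _, pred = model(data.x, data.edge_index).max(dim=1)
    correct = int((pred[data.test_mask] == data.y[data.test_mask]).sum())
    acc = correct / int(data.test_mask.sum())
    print(f'Accuracy: {acc:.4f}')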
    
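Both helpers rely on the boolean masks stored on a Planetoid `Data` object: the forward pass runs over every node of the graph, but the loss and the accuracy only read the rows selected by `train_mask` or `test_mask`. With Planetoid's default public split, `train_mask` covers 20 labelled nodes per class (140 on Cora) and `test_mask` covers 1,000 nodes. The toy tensors below are hypothetical logits and labels, not Cora data; they isolate the masking pattern, and `argmax(dim=1)` returns the same indices as the `.max(dim=1)` call used in `test()`.

```python
import torch
import torch.nn.functional as F

# Hypothetical model output for 4 nodes and 3 classes
logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.3, 0.2, 1.0],
                       [1.2, 0.4, 0.9]])
y = torch.tensor([0, 1, 1, 0])
train_mask = torch.tensor([True, True, False, False])
test_mask = ~train_mask  # here: the two nodes without training labels

# Cross-entropy over the labelled training nodes only
loss = F.cross_entropy(logits[train_mask], y[train_mask])

# Accuracy: compare the predicted class with the label on the masked nodes
pred = logits.argmax(dim=1)
acc = (pred[test_mask] == y[test_mask]).float().mean().item()
print(f'train loss {loss.item():.4f}, test accuracy {acc:.2f}')
```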


In [3]:
dataset_names = ['Cora', 'CiteSeer', 'PubMed']
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # fall back to CPU without a GPU

for dataset_name in dataset_names:
    dataset = get_dataset(dataset_name)

    model = GCN(in_channels=dataset.num_node_features,
                hidden_channels=32,
                num_layers=2,
                out_channels=dataset.num_classes,
                dropout=0.5,
                norm='batch',  # alternatives: 'instance', 'layer', or None for no normalization
                )

    train(model, dataset[0], device=device)
    print(f'--- {dataset_name} ---')
    test(model, dataset[0])

--- Cora ---
Accuracy: 0.6920


--- CiteSeer ---
Accuracy: 0.5120


--- PubMed ---
Accuracy: 0.7310
