In [11]:
# Environment Setup
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [12]:
# Setting
args = {
    'device': device,
    "dataset": "Cora",
    'hidden_dim': 16,
    'lr': 0.01,
    'weight_decay': 5e-4,
    'epochs': 200
}

In [13]:
# Load the data
dataset = Planetoid(root='./', name=args['dataset'], transform=T.NormalizeFeatures())
data = dataset[0].to(device)
data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [14]:
# Implement simple GCNConv Layer with pytorch MessagePassing
class GCNLayer(MessagePassing):
  def __init__(self, input_dim, output_dim):
    # aggr: how to aggregate message of each node    
    super(GCNLayer, self).__init__(aggr='add')

    self.input_dim = input_dim
    self.output_dim = output_dim
    self.lin = nn.Linear(input_dim, output_dim).to(device)

  def forward(self, X, edge_index):

    edge_index, _ = add_self_loops(edge_index, num_nodes = X.shape[0])
    x = self.lin(X)

    row, col = edge_index
    deg_mat = degree(row, x.shape[0], dtype=x.dtype)
    deg_inv_sqrt_mat = deg_mat.pow(-0.5)
    norm = deg_inv_sqrt_mat[row] * deg_inv_sqrt_mat[col]

    return self.propagate(edge_index, size=(x.shape[0], x.shape[0]),
                          x = x, norm = norm)
    
    def message(self, X_j):
      return norm.view(-1, 1) * X_j

    # Update node embedding
    def update(self, aggr_output):
      return aggr_output

In [15]:
# Model
class GCN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = GCNLayer(dataset.num_node_features, args['hidden_dim'])
    self.conv2 = GCNLayer(args['hidden_dim'], dataset.num_classes)

  def forward(self, data):
    x, edge_index = data.x, data.edge_index

    x = self.conv1(x, edge_index)
    x = F.relu(x)
    x = F.dropout(x, training=self.training)
    x = self.conv2(x, edge_index)

    return F.log_softmax(x, dim=1)

In [16]:
# learning
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

def train():
  model.train()
  optimizer.zero_grad()
  out = model(data)
  loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
  loss.backward()
  optimizer.step()

In [17]:
@torch.no_grad()
def test():
  model.eval()
  out, accuracy = model(data), []
  for _, mask in data('train_mask','val_mask', 'test_mask'):
    pred = out[mask].max(1)[1]
    acc = pred.eq(data.y[mask]).sum() / mask.sum()
    accuracy.append(acc)
  return accuracy

In [18]:
for epoch in range(args['epochs']):
  train()
  train_acc, val_acc, test_acc = test()
  
  print('Epoch: {:02d},'.format(epoch+1),
        'train_acc: {:.4f},'.format(train_acc.item()),
        'val_acc: {:.4f},'.format(val_acc.item()),
        'test_acc: {:.4f}'.format(test_acc.item()))

Epoch: 01, train_acc: 0.1143, val_acc: 0.1840, test_acc: 0.1760
Epoch: 02, train_acc: 0.1500, val_acc: 0.1600, test_acc: 0.1410
Epoch: 03, train_acc: 0.1571, val_acc: 0.1580, test_acc: 0.1410
Epoch: 04, train_acc: 0.2429, val_acc: 0.2220, test_acc: 0.2090
Epoch: 05, train_acc: 0.3143, val_acc: 0.2900, test_acc: 0.2770
Epoch: 06, train_acc: 0.3571, val_acc: 0.3000, test_acc: 0.2920
Epoch: 07, train_acc: 0.4214, val_acc: 0.3540, test_acc: 0.3510
Epoch: 08, train_acc: 0.6500, val_acc: 0.4860, test_acc: 0.5200
Epoch: 09, train_acc: 0.7214, val_acc: 0.5560, test_acc: 0.5800
Epoch: 10, train_acc: 0.7714, val_acc: 0.5900, test_acc: 0.5910
Epoch: 11, train_acc: 0.7857, val_acc: 0.6000, test_acc: 0.5910
Epoch: 12, train_acc: 0.8000, val_acc: 0.6140, test_acc: 0.5970
Epoch: 13, train_acc: 0.8286, val_acc: 0.6300, test_acc: 0.6100
Epoch: 14, train_acc: 0.8571, val_acc: 0.6420, test_acc: 0.6260
Epoch: 15, train_acc: 0.8786, val_acc: 0.6820, test_acc: 0.6670
Epoch: 16, train_acc: 0.8929, val_acc: 0