In [1]:
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root=".", name="Cora")
data = dataset[0]
num_labels = len(set(data.y.numpy()))  # used for output_dim

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.nn import GATv2Conv, GCNConv


class MLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.lin1 = nn.Linear(input_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        x = data.x  # no graph structure, only node features
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return F.log_softmax(x, dim=1)
    
class GCN(torch.nn.Module):    
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
    
class GAT(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, heads=8):
        super().__init__()
        self.gat1 = GATv2Conv(input_dim, hidden_dim, heads=heads)
        # for the last GAT layer we use concat=False to average the outputs of the heads
        self.gat2 = GATv2Conv(hidden_dim * heads, output_dim, heads=heads, concat=False)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.gat2(x, edge_index)
        return F.log_softmax(x, dim=1)

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
results = {}

# Define the accuracy function
def accuracy(pred, target):
    return (pred == target).sum().item() / target.size(0)

# iterate over the different model types
for model_class in [MLP, GCN, GAT]: # later we test also with GCN (this post) and GAT (next blog post)
    results[model_class.__name__] = []
    for i in range(10):
        print(f"Training {model_class.__name__} iteration {i+1}")
        
        # the output_dim is the number of unique classes in the set
        model = model_class(input_dim=data.x.shape[1], hidden_dim=32, output_dim=num_labels).to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

        # deal with the class imbalance
        class_weights = torch.bincount(data.y) / len(data.y)

        print(f"Class weights: {class_weights}")

        loss_fn = nn.CrossEntropyLoss(weight=1/class_weights).to(device)

        data = data.to(device)

        # training loop
        for epoch in range(100):
            model.train()
            optimizer.zero_grad()
            out = model(data)
            
            # calculate loss
            train_loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
            acc = accuracy(out[data.train_mask].argmax(dim=1), data.y[data.train_mask])
            train_loss.backward()
            optimizer.step()

            if epoch % 10 == 0:
                model.eval()
                with torch.no_grad():
                    val_loss = loss_fn(out[data.val_mask], data.y[data.val_mask])
                    val_acc = accuracy(out[data.val_mask].argmax(dim=1), data.y[data.val_mask])
                    print(f'Epoch {epoch} | Training Loss: {train_loss.item():.2f} | Train Acc: {acc:>5.2f} | Validation Loss: {val_loss.item():.2f} | Validation Acc: {val_acc:>5.2f}')

        # final evaluation on the test set
        model.eval()
        with torch.no_grad():
            out = model(data)
            test_loss = loss_fn(out[data.test_mask], data.y[data.test_mask])
            test_acc = accuracy(out[data.test_mask].argmax(dim=1), data.y[data.test_mask])
            print(f'{model_class.__name__} Test Loss: {test_loss.item():.2f} | Test Acc: {test_acc:>5.2f}')
            results[model_class.__name__].append([acc, val_acc, test_acc])


# print average on test set and standard deviation
for model_name, model_results in results.items():
    model_results = torch.tensor(model_results)
    print(f'{model_name} Test Accuracy: {model_results[:, 2].mean():.2f} ± {model_results[:, 2].std():.2f}')

Training MLP iteration 1
Class weights: tensor([0.1296, 0.0801, 0.1544, 0.3021, 0.1573, 0.1100, 0.0665])
Epoch 0 | Training Loss: 1.94 | Train Acc:  0.14 | Validation Loss: 1.95 | Validation Acc:  0.12
Epoch 10 | Training Loss: 0.28 | Train Acc:  0.98 | Validation Loss: 1.54 | Validation Acc:  0.37
Epoch 20 | Training Loss: 0.03 | Train Acc:  1.00 | Validation Loss: 1.32 | Validation Acc:  0.54
Epoch 30 | Training Loss: 0.01 | Train Acc:  1.00 | Validation Loss: 1.36 | Validation Acc:  0.56
Epoch 40 | Training Loss: 0.00 | Train Acc:  1.00 | Validation Loss: 1.38 | Validation Acc:  0.53
Epoch 50 | Training Loss: 0.00 | Train Acc:  1.00 | Validation Loss: 1.35 | Validation Acc:  0.53
Epoch 60 | Training Loss: 0.00 | Train Acc:  1.00 | Validation Loss: 1.31 | Validation Acc:  0.54
Epoch 70 | Training Loss: 0.01 | Train Acc:  1.00 | Validation Loss: 1.29 | Validation Acc:  0.54
Epoch 80 | Training Loss: 0.01 | Train Acc:  1.00 | Validation Loss: 1.28 | Validation Acc:  0.53
Epoch 90 | Tra