In [1]:
import torch
from torch_geometric.datasets import Amazon
from torch_geometric.transforms import NormalizeFeatures
dataset = Amazon(root='/tmp/Amazon', name='Computers', transform=NormalizeFeatures())

In [2]:
print()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]

print()
print(data)
print('===========================================================================================================')

print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')


Dataset: AmazonComputers():
Number of graphs: 1
Number of features: 767
Number of classes: 10

Data(x=[13752, 767], edge_index=[2, 491722], y=[13752])
Number of nodes: 13752
Number of edges: 491722
Average node degree: 35.76


In [22]:
from torch_geometric.transforms import RandomNodeSplit

transform = RandomNodeSplit(split = "random", num_train_per_class=1250, num_val = 0.10, num_test = 0.30)
data = transform(data)

print(data.train_mask)
print(data.val_mask)
print(data.test_mask)

tensor([True, True, True,  ..., True, True, True])
tensor([False, False, False,  ..., False, False, False])
tensor([False, False, False,  ..., False, False, False])


In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.ops import MLP
from torch_geometric.nn import GATConv
from torch_geometric.loader import ClusterData, ClusterLoader
from torch_geometric.data import Data

class GATNetwork(nn.Module):
    def __init__(self, heads):
        super(GATNetwork, self).__init__()
        self.gat1 = GATConv(in_channels=data.num_features, out_channels=64, heads=heads[0])
        self.gat2 = GATConv(in_channels=64 * heads[0], out_channels=128, heads=heads[1], concat=False)
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(128)
        self.mlp1 = MLP(in_channels=128, hidden_channels=[64, dataset.num_classes], norm_layer=nn.BatchNorm1d)
        self.skip = nn.Sequential(nn.Linear(64 * heads[0], 128), nn.ReLU())

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        x = self.bn1(self.gat1(x, edge_index))
        x = F.elu(x)
        x = self.skip(x) + self.bn2(self.gat2(x, edge_index))
        x = self.mlp1(x)
        return x

model = GATNetwork(heads=[4, 2])

cluster_data = ClusterData(data, num_parts=128)
train_loader = ClusterLoader(cluster_data, batch_size=32, shuffle=True)

optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

def train():
      model.train()

      for sub_data in train_loader: 
          batch = Data(x=sub_data.x, edge_index=sub_data.edge_index)
          out = model(batch)  
          loss = criterion(out[sub_data.train_mask], sub_data.y[sub_data.train_mask]) 
          loss.backward()  
          optimizer.step()  
          optimizer.zero_grad()  

def test():
      model.eval()
      out = model(data)
      pred = out.argmax(dim=1) 
      
      accs = []
      for mask in [data.train_mask, data.val_mask, data.test_mask]:
          correct = pred[mask] == data.y[mask] 
          accs.append(int(correct.sum()) / int(mask.sum())) 
      return accs

best_val_acc = 0
patience = 10
counter = 0

for epoch in range(1, 201):
    loss = train()
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
    
    if val_acc > best_val_acc:
        torch.save(model.state_dict(), 'best_model.pth')
        best_val_acc = val_acc
        counter = 0
    else:
        counter += 1
    
    if counter == patience:
        print('Early stopping')
        break

model.load_state_dict(torch.load('best_model.pth'))

Computing METIS partitioning...
Done!


Epoch: 001, Train: 0.1586, Val Acc: 0.0269, Test Acc: 0.0291
Epoch: 002, Train: 0.1588, Val Acc: 0.0276, Test Acc: 0.0291
Epoch: 003, Train: 0.1587, Val Acc: 0.0276, Test Acc: 0.0291
Epoch: 004, Train: 0.1588, Val Acc: 0.0269, Test Acc: 0.0293
Epoch: 005, Train: 0.1592, Val Acc: 0.0269, Test Acc: 0.0296
Epoch: 006, Train: 0.1711, Val Acc: 0.0342, Test Acc: 0.0388
Epoch: 007, Train: 0.2440, Val Acc: 0.0938, Test Acc: 0.1100
Epoch: 008, Train: 0.2950, Val Acc: 0.1542, Test Acc: 0.1592
Epoch: 009, Train: 0.3734, Val Acc: 0.3287, Test Acc: 0.3272
Epoch: 010, Train: 0.5425, Val Acc: 0.7207, Test Acc: 0.7172
Epoch: 011, Train: 0.6080, Val Acc: 0.7935, Test Acc: 0.8005
Epoch: 012, Train: 0.6487, Val Acc: 0.8116, Test Acc: 0.8306
Epoch: 013, Train: 0.6972, Val Acc: 0.8473, Test Acc: 0.8674
Epoch: 014, Train: 0.7548, Val Acc: 0.8596, Test Acc: 0.8805
Epoch: 015, Train: 0.8079, Val Acc: 0.8669, Test Acc: 0.8873
Epoch: 016, Train: 0.8419, Val Acc: 0.8742, Test Acc: 0.8941
Epoch: 017, Train: 0.874

<All keys matched successfully>