In [1]:
# Record the installed PyTorch version in the `TORCH` environment variable and
# print it so the reader can see which build the notebook ran against.
# NOTE(review): despite the cell's original "install" label, nothing is
# installed here — presumably `TORCH` is meant to select version-matched
# PyTorch Geometric wheels in a follow-up `pip install`; confirm and restore
# that step if it was dropped.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)



1.12.0


# Customizing Aggregations within Message Passing with `torch_geometric.nn.aggr`

Aggregation functions play an important role in the message passing framework and the readout function when implementing GNNs. Many works in the GNN literature ([Hamilton et al. (2017)](https://cs.stanford.edu/~jure/pubs/graphsage-nips17.pdf), [Xu et al. (2018)](https://arxiv.org/abs/1810.00826), [Corso et al. (2020)](https://proceedings.neurips.cc/paper/2020/file/99cad265a1768cc2dd013f0e740300ae-Paper.pdf), [Li et al. (2020)](https://arxiv.org/abs/2006.07739)) demonstrate that the choice of aggregation functions contributes significantly to the performance of GNN models. In particular, the performance of GNNs with different aggregation functions differs when applied to distinct tasks and datasets. Recent works also show that using multiple aggregations ([Corso et al. (2020)](https://proceedings.neurips.cc/paper/2020/file/99cad265a1768cc2dd013f0e740300ae-Paper.pdf)) and learnable aggregations ([Li et al. (2020)](https://arxiv.org/abs/2006.07739)) can potentially yield substantial improvements. To facilitate experimentation with these different aggregation schemes and unify concepts of aggregation within GNNs across both [`MessagePassing`](https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py) and [global readouts](https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric/nn/glob), we provide **modular and re-usable aggregations** in the newly defined `torch_geometric.nn.aggr.*` package. Unifying these concepts also lets us perform optimizations and provide specialized implementations in a single place. With the new integration, all of the following usages are supported:

```python
# Original interface with string type as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr="mean")

# Use a single aggregation module as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=MeanAggregation())

# Use a list of aggregation strings as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=['mean', 'max', 'sum', 'std', 'var'])

# Use a list of aggregation modules as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=[
          MeanAggregation(),
          MaxAggregation(),
          SumAggregation(),
          StdAggregation(),
          VarAggregation(),
          ])

# Use a list of mixed modules and strings as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=[
          'mean',
          MaxAggregation(),
          'sum',
          StdAggregation(),
          'var',
          ])

# Define multiple learnable aggregations with keyword arguments
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=['softmax', 'softmax', 'softmax'],
        aggr_kwargs = dict(aggrs_kwargs=[
                            dict(t=0.1, learn=True),
                            dict(t=1, learn=True),
                            dict(t=10, learn=True)]))

# Define multiple aggregations with `MultiAggregation` module
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=MultiAggregation([
          SoftmaxAggregation(t=0.1, learn=True),
          SoftmaxAggregation(t=1, learn=True),
          SoftmaxAggregation(t=10, learn=True)]))
```



In this tutorial, we explore the new aggregation package with `SAGEConv` ([Hamilton et al. (2017)](https://cs.stanford.edu/~jure/pubs/graphsage-nips17.pdf)) and `ClusterLoader` ([Chiang et al. (2019)](https://arxiv.org/abs/1905.07953)) and showcase it on the `PubMed` graph from the `Planetoid` node classification benchmark suite ([Yang et al. (2016)](https://arxiv.org/abs/1603.08861)).

## Loading the dataset

In [2]:
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import ClusterData, ClusterLoader
from torch_geometric.transforms import NormalizeFeatures

# Load PubMed from the Planetoid suite with the NormalizeFeatures transform.
feature_transform = NormalizeFeatures()
dataset = Planetoid(root='data/Planetoid', name='PubMed', transform=feature_transform)

# Summarise the dataset as a whole.
print()
print(f'Dataset: {dataset}:')
print('==================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

# The dataset holds a single graph; keep a handle on it for full-graph evaluation.
data = dataset[0]

print()
print(data)
print('===============================================================================================================')

# Seed before partitioning so the shuffled mini-batches are reproducible.
torch.manual_seed(12345)
# Step 1: split the graph into 128 subgraphs.
cluster_data = ClusterData(data, num_parts=128)
# Step 2: stochastic partitioning scheme — each shuffled mini-batch groups 32 of them.
train_loader = ClusterLoader(cluster_data, batch_size=32, shuffle=True)


Dataset: PubMed():
Number of graphs: 1
Number of features: 500
Number of classes: 3

Data(x=[19717, 500], edge_index=[2, 88648], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717])


Computing METIS partitioning...
Done!


## Define train, test and run functions

In [3]:
criterion = torch.nn.CrossEntropyLoss()

def train(model):
  model.train()
  optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
  for sub_data in train_loader:  # Iterate over each mini-batch.
    optimizer.zero_grad()  # Clear gradients.
    out = model(sub_data.x, sub_data.edge_index)  # Perform a single forward pass.
    loss = criterion(out[sub_data.train_mask], sub_data.y[sub_data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.

def test(model):
  model.eval()
  out = model(data.x, data.edge_index)
  pred = out.argmax(dim=1)  # Use the class with highest probability.
  
  accs = []
  for mask in [data.train_mask, data.val_mask, data.test_mask]:
    correct = pred[mask] == data.y[mask]  # Check against ground-truth labels.
    accs.append(int(correct.sum()) / int(mask.sum()))  # Derive ratio of correct predictions.
  return accs

def run(model, epochs=5):
  for epoch in range(1, epochs):
    loss = train(model)
    train_acc, val_acc, test_acc = test(model)
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')


## Training GNNs with `torch_geometric.nn.aggr` package

### Define a GNN class




In [10]:
import copy

import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
# The concrete aggregation classes are imported here because the experiment
# cells further down instantiate them by name.
from torch_geometric.nn.aggr import (
    Aggregation,
    MaxAggregation,
    MeanAggregation,
    MultiAggregation,
    SoftmaxAggregation,
    StdAggregation,
    SumAggregation,
    VarAggregation,
)


class GNN(torch.nn.Module):
  """Two-layer `SAGEConv` network with a configurable aggregation scheme.

  Channel sizes are read from the notebook-level `dataset`: the first layer
  consumes `dataset.num_node_features` and the second emits
  `dataset.num_classes` scores per node.

  Args:
    hidden_channels: Output width of the first layer.
    aggr: Aggregation spec forwarded to both layers — a string, an
      `Aggregation` module, a `MultiAggregation`, or a list/tuple of
      strings and/or modules.
    aggr_kwargs: Optional keyword arguments forwarded unchanged to both
      layers as `aggr_kwargs`.

  Raises:
    KeyError: If `aggr` is none of the supported kinds.
  """

  def __init__(self, hidden_channels, aggr='mean', aggr_kwargs=None):
    super().__init__()
    torch.manual_seed(12345)  # Same initial weights for every configuration.
    num_aggrs = self._num_aggregations(aggr)
    # The second layer gets its own deep copy so that aggregation modules
    # (and any learnable state they carry) are not shared between layers.
    conv1_aggr, conv2_aggr = aggr, copy.deepcopy(aggr)
    num_features = dataset.num_node_features
    # The first entry of each `in_channels` pair is scaled by the number of
    # aggregations. NOTE(review): this presumably relies on the individual
    # aggregation outputs being concatenated (the printed `mode=cat`);
    # confirm before using another `MultiAggregation` mode.
    self.conv1 = SAGEConv([num_features * num_aggrs, num_features],
                          hidden_channels,
                          aggr=conv1_aggr,
                          aggr_kwargs=aggr_kwargs)
    self.conv2 = SAGEConv([hidden_channels * num_aggrs, hidden_channels],
                          dataset.num_classes,
                          aggr=conv2_aggr,
                          aggr_kwargs=aggr_kwargs)

  @staticmethod
  def _num_aggregations(aggr):
    """Return how many individual aggregations the spec `aggr` describes."""
    if isinstance(aggr, (list, tuple)):
      return len(aggr)
    # Checked before the generic `Aggregation` case, as in the original
    # ordering (MultiAggregation is presumably an Aggregation subclass).
    if isinstance(aggr, MultiAggregation):
      return len(aggr.aggrs)
    if isinstance(aggr, (str, Aggregation)):
      return 1
    raise KeyError(f"Unknown aggr: {aggr}")

  def forward(self, x, edge_index):
    """Return the second layer's output for node features `x` on `edge_index`."""
    x = self.conv1(x, edge_index)
    x = x.relu()
    x = F.dropout(x, p=0.5, training=self.training)  # Active in training mode only.
    x = self.conv2(x, edge_index)
    return x

ModuleNotFoundError: No module named 'torch_geometric.nn.aggr'

### Original interface with string type as aggregation argument

In [9]:
# Baseline: select the aggregation through its string alias.
# (No optimizer is created here — the training helpers build the optimizer
# they use from `model.parameters()` themselves.)
model = GNN(16, aggr='mean')
print(model)
run(model)

NameError: name 'GNN' is not defined

### Use a single aggregation module as aggregation argument

In [None]:
# Hand over an aggregation module instance rather than a string alias.
mean_aggr = MeanAggregation()
model = GNN(hidden_channels=16, aggr=mean_aggr)
print(model)
run(model)

GNN(
  (conv1): SAGEConv([500, 500], 16, aggr=MeanAggregation())
  (conv2): SAGEConv([16, 16], 3, aggr=MeanAggregation())
)
Epoch: 001, Train: 0.3333, Val Acc: 0.3880, Test Acc: 0.4130
Epoch: 002, Train: 0.3333, Val Acc: 0.3880, Test Acc: 0.4130
Epoch: 003, Train: 0.3333, Val Acc: 0.3880, Test Acc: 0.4130
Epoch: 004, Train: 0.6667, Val Acc: 0.5060, Test Acc: 0.5430


### Use a list of aggregation strings as aggregation argument

In [None]:
# Several string aliases at once; GNN sizes its layers by the list length.
aggr_names = ['mean', 'max', 'sum', 'std', 'var']
model = GNN(hidden_channels=16, aggr=aggr_names)
print(model)
run(model)

GNN(
  (conv1): SAGEConv([2500, 500], 16, aggr=['mean', 'max', 'sum', 'std', 'var'])
  (conv2): SAGEConv([80, 16], 3, aggr=['mean', 'max', 'sum', 'std', 'var'])
)
Epoch: 001, Train: 0.5000, Val Acc: 0.3640, Test Acc: 0.3550
Epoch: 002, Train: 0.7833, Val Acc: 0.6120, Test Acc: 0.6160
Epoch: 003, Train: 0.8167, Val Acc: 0.5680, Test Acc: 0.5350
Epoch: 004, Train: 0.8667, Val Acc: 0.7120, Test Acc: 0.6940


### Use a list of aggregation modules as aggregation argument

In [None]:
# The same five aggregations, given as module instances instead of strings.
aggr_modules = [
    MeanAggregation(),
    MaxAggregation(),
    SumAggregation(),
    StdAggregation(),
    VarAggregation(),
]
model = GNN(hidden_channels=16, aggr=aggr_modules)
print(model)
run(model)

GNN(
  (conv1): SAGEConv([2500, 500], 16, aggr=['MeanAggregation()', 'MaxAggregation()', 'SumAggregation()', 'StdAggregation()', 'VarAggregation()'])
  (conv2): SAGEConv([80, 16], 3, aggr=['MeanAggregation()', 'MaxAggregation()', 'SumAggregation()', 'StdAggregation()', 'VarAggregation()'])
)
Epoch: 001, Train: 0.5000, Val Acc: 0.3640, Test Acc: 0.3550
Epoch: 002, Train: 0.7833, Val Acc: 0.6120, Test Acc: 0.6160
Epoch: 003, Train: 0.8167, Val Acc: 0.5680, Test Acc: 0.5350
Epoch: 004, Train: 0.8667, Val Acc: 0.7120, Test Acc: 0.6940


### Use a list of mixed modules and strings as aggregation argument

In [None]:
# String aliases and module instances may be mixed freely within one list.
mixed_aggrs = [
    'mean',
    MaxAggregation(),
    'sum',
    StdAggregation(),
    'var',
]
model = GNN(hidden_channels=16, aggr=mixed_aggrs)
print(model)
run(model)

GNN(
  (conv1): SAGEConv([2500, 500], 16, aggr=['mean', 'MaxAggregation()', 'sum', 'StdAggregation()', 'var'])
  (conv2): SAGEConv([80, 16], 3, aggr=['mean', 'MaxAggregation()', 'sum', 'StdAggregation()', 'var'])
)
Epoch: 001, Train: 0.5000, Val Acc: 0.3640, Test Acc: 0.3550
Epoch: 002, Train: 0.7833, Val Acc: 0.6120, Test Acc: 0.6160
Epoch: 003, Train: 0.8167, Val Acc: 0.5680, Test Acc: 0.5350
Epoch: 004, Train: 0.8667, Val Acc: 0.7120, Test Acc: 0.6940


### Define multiple learnable aggregations with keyword arguments

In [None]:
# Three softmax aggregations that differ only in their temperature `t`;
# the per-aggregation settings are passed via `aggrs_kwargs`.
aggr = ['softmax'] * 3
aggrs_kwargs = [dict(t=temperature, learn=True) for temperature in (0.1, 1, 10)]
model = GNN(hidden_channels=16, aggr=aggr,
            aggr_kwargs=dict(aggrs_kwargs=aggrs_kwargs))
print(model)
run(model)

GNN(
  (conv1): SAGEConv([1500, 500], 16, aggr=['softmax', 'softmax', 'softmax'])
  (conv2): SAGEConv([48, 16], 3, aggr=['softmax', 'softmax', 'softmax'])
)
Epoch: 001, Train: 0.8500, Val Acc: 0.6980, Test Acc: 0.7010
Epoch: 002, Train: 0.9333, Val Acc: 0.6420, Test Acc: 0.6600
Epoch: 003, Train: 0.7500, Val Acc: 0.6260, Test Acc: 0.6520
Epoch: 004, Train: 0.9333, Val Acc: 0.7580, Test Acc: 0.7430


### Define multiple aggregations with `MultiAggregation` module

In [None]:
# The same three softmax aggregations, assembled explicitly as one
# `MultiAggregation` module.
aggr = MultiAggregation(
    [SoftmaxAggregation(t=temperature, learn=True) for temperature in (0.1, 1, 10)]
)
model = GNN(hidden_channels=16, aggr=aggr)
print(model)
run(model)

GNN(
  (conv1): SAGEConv([1500, 500], 16, aggr=MultiAggregation([
    SoftmaxAggregation(learn=True),
    SoftmaxAggregation(learn=True),
    SoftmaxAggregation(learn=True)
  ], mode=cat))
  (conv2): SAGEConv([48, 16], 3, aggr=MultiAggregation([
    SoftmaxAggregation(learn=True),
    SoftmaxAggregation(learn=True),
    SoftmaxAggregation(learn=True)
  ], mode=cat))
)
Epoch: 001, Train: 0.8500, Val Acc: 0.6980, Test Acc: 0.7010
Epoch: 002, Train: 0.9333, Val Acc: 0.6420, Test Acc: 0.6600
Epoch: 003, Train: 0.7500, Val Acc: 0.6260, Test Acc: 0.6520
Epoch: 004, Train: 0.9333, Val Acc: 0.7580, Test Acc: 0.7430


## Conclusion

In this tutorial, you have been presented with the new `torch_geometric.nn.aggr` package, which provides a flexible interface for experimenting with different aggregation functions in your message passing convolutions and unifies aggregation within GNNs across [`MessagePassing`](https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py) and [global readouts](https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric/nn/glob). This new abstraction also makes designing new types of aggregation functions easier. Now, you can create your own aggregation function with the base `Aggregation` class:

```python
class MyAggregation(Aggregation):
    def __init__(self, ...):
      ...

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
      ...
```

*Have fun!*