# Customizing Aggregations within Message Passing

We provide **modular and re-usable aggregations** in the newly defined `torch_geometric.nn.aggr.*` package. Unifying these concepts also lets us implement optimizations and specialized implementations in a single place. With the new integration, the `aggr` argument of `MessagePassing` can be specified in any of the following ways:

```python
# Original interface: select the aggregation by passing its string name
class MyConv(MessagePassing):
    def __init__(self):
        # The string "mean" names the aggregation to use.
        super().__init__(aggr="mean")


# Use a single aggregation module instance as the aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        # An aggregation module object is passed instead of a string name.
        super().__init__(aggr=MeanAggregation())


# Use a list of aggregation strings as the aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        # Several aggregations are requested at once, each by its string name.
        super().__init__(aggr=['mean', 'max', 'sum', 'std', 'var'])


# Use a list of aggregation modules as the aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        # Same five aggregations as the string-based list, given as module
        # instances.
        super().__init__(aggr=[
            MeanAggregation(),
            MaxAggregation(),
            SumAggregation(),
            StdAggregation(),
            VarAggregation(),
        ])


# Use a list mixing modules and strings as the aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        # String names and module instances may appear in the same list.
        super().__init__(aggr=[
            'mean',
            MaxAggregation(),
            'sum',
            StdAggregation(),
            'var',
        ])


# Define multiple aggregations explicitly with the `MultiAggregation` module
class MyConv(MessagePassing):
    def __init__(self):
        # Three SoftmaxAggregation modules that differ only in their `t`
        # argument (0.1, 1, 10).
        # NOTE(review): presumably `t` is the initial softmax temperature and
        # `learn=True` makes it trainable -- confirm against the PyG docs.
        super().__init__(aggr=MultiAggregation([
            SoftmaxAggregation(t=0.1, learn=True),
            SoftmaxAggregation(t=1, learn=True),
            SoftmaxAggregation(t=10, learn=True)
        ]))

```



In [2]:
import os

import torch

os.environ["TORCH"] = torch.__version__
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

In [7]:
# Load the PubMed citation dataset from the Planetoid collection, with the
# NormalizeFeatures transform attached, and print a short summary of it.
dataset = Planetoid(
    root="data/Planetoid",
    name="PubMed",
    transform=NormalizeFeatures(),
)
print(
    f"Dataset: {dataset}:",
    "==================",
    f"Number of graphs: {len(dataset)}",
    f"Number of features: {dataset.num_features}",
    f"Number of classes: {dataset.num_classes}",
    sep="\n",  # one summary item per output line
)

Dataset: PubMed():
Number of graphs: 1
Number of features: 500
Number of classes: 3


In [8]:
# The dataset reports a single graph, so index 0 is the whole PubMed graph.
# `data` is used as a module-level global by the evaluation code further down.
data = dataset[0]  # Get the first graph object.
print(data)  # Show the attributes stored on the graph and their sizes.

Data(x=[19717, 500], edge_index=[2, 88648], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717])


In [9]:
from torch_geometric.loader import ClusterData, ClusterLoader

# Seed PyTorch's RNG before building the loader.
# NOTE(review): presumably this makes the shuffled batching repeatable.
seed = 42
torch.manual_seed(seed)

# 1. Create subgraphs: split the full graph into 128 partitions.
cluster_data = ClusterData(data, num_parts=128)

# 2. Stochastic partitioning scheme: each mini-batch combines 32 of the
#    partitions, drawn in shuffled order.
train_loader = ClusterLoader(cluster_data, batch_size=32, shuffle=True)

Computing METIS partitioning...
Done!


In [10]:
# Shared loss function; `train` applies it to the model's outputs on the
# training nodes and their integer class labels.
criterion = torch.nn.CrossEntropyLoss()


def train(model):
    """Train ``model`` for one pass over ``train_loader``.

    Uses the module-level ``train_loader`` and ``criterion``. The loss of each
    mini-batch is computed only on the nodes selected by that batch's
    ``train_mask``.

    Args:
        model: Module called as ``model(x, edge_index)`` that returns one row
            of class scores per node.

    Returns:
        float: Mean of the per-batch training losses, or ``nan`` if the
        loader yielded no batches. (Previously nothing was returned, so
        callers doing ``loss = train(model)`` always received ``None``.)
    """
    model.train()
    # NOTE(review): the optimizer is rebuilt on every call, so Adam's running
    # moment estimates do not carry over between epochs -- confirm intended.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    total_loss = 0.0
    num_batches = 0
    for sub_data in train_loader:  # Iterate over each mini-batch.
        optimizer.zero_grad()  # Clear gradients.
        out = model(sub_data.x, sub_data.edge_index)  # Single forward pass.
        mask = sub_data.train_mask
        # Compute the loss solely based on the training nodes.
        loss = criterion(out[mask], sub_data.y[mask])
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float("nan")  # No batch was processed, so no loss exists.
    return total_loss / num_batches


def test(model):
    """Evaluate ``model`` on the full graph held in the module-level ``data``.

    Args:
        model: Module called as ``model(x, edge_index)`` that returns one row
            of class scores per node.

    Returns:
        list[float]: Accuracies on the train, validation and test masks, in
        that order.

    Raises:
        ZeroDivisionError: If one of the masks selects no nodes.
    """
    model.eval()
    # Inference only: disable autograd so no computation graph is recorded
    # for the full-graph forward pass.
    with torch.no_grad():
        out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with the highest score.

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = pred[mask] == data.y[mask]  # Check against ground-truth labels.
        # Ratio of correct predictions among the nodes selected by the mask.
        accs.append(int(correct.sum()) / int(mask.sum()))
    return accs


def run(model, epochs=5):
    """Alternate training and evaluation, printing accuracies every epoch.

    Args:
        model: Model handed to ``train`` and ``test`` each epoch.
        epochs: Number of train/evaluate rounds to perform.
    """
    for epoch in range(epochs):
        # The value returned by train() is bound here but not reported below.
        loss = train(model)
        accs = test(model)
        train_acc, val_acc, test_acc = accs
        message = f"Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}"
        print(message)

## Define a GNN class and Import Aggregations
Now, let's define a GNN helper class and import all those new aggregation operators!




In [11]:
import copy

import torch.nn.functional as F
from torch_geometric.nn import (
    Aggregation,
    MaxAggregation,
    MeanAggregation,
    MultiAggregation,
    SAGEConv,
    SoftmaxAggregation,
    StdAggregation,
    SumAggregation,
    VarAggregation,
)


class GNN(torch.nn.Module):
    """Two-layer SAGEConv network with a configurable aggregation.

    The input width and the number of output classes are read from the
    module-level ``dataset``.

    Args:
        hidden_channels: Output size of the first layer and input size of the
            second.
        aggr: Aggregation specification forwarded to both ``SAGEConv`` layers.
            The second layer receives a deep copy of it.
        aggr_kwargs: Extra aggregation arguments forwarded unchanged to both
            layers.
    """

    def __init__(self, hidden_channels, aggr="mean", aggr_kwargs=None):
        super().__init__()
        in_channels = dataset.num_node_features
        out_channels = dataset.num_classes

        self.conv1 = SAGEConv(
            in_channels, hidden_channels, aggr=aggr, aggr_kwargs=aggr_kwargs
        )
        # NOTE(review): the copy presumably keeps the second layer from
        # sharing an aggregation module instance (and any learnable state it
        # holds) with the first layer -- confirm.
        second_aggr = copy.deepcopy(aggr)
        self.conv2 = SAGEConv(
            hidden_channels, out_channels, aggr=second_aggr, aggr_kwargs=aggr_kwargs
        )

    def forward(self, x, edge_index):
        """Return the second layer's raw per-node outputs for the given graph."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is active only while the module is in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.conv2(hidden, edge_index)