# Set Up PyTorch and PyG

In [1]:
import os

# Use the eager mode
os.environ['PT_HPU_LAZY_MODE'] = '0'

# Verify the environment variable is set
print(f"PT_HPU_LAZY_MODE: {os.environ['PT_HPU_LAZY_MODE']}")

import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

import habana_frameworks.torch.core as htcore

# Use rich tracebacks for more readable error output

from rich import traceback
traceback.install()

device = torch.device("hpu")


PT_HPU_LAZY_MODE: 0




2.4.0a0+git74cd574




# Customizing Aggregations within Message Passing

Aggregation functions play an important role in the message passing framework and in the readout functions of GNNs. Many works in the GNN literature ([Hamilton et al. (2017)](https://cs.stanford.edu/~jure/pubs/graphsage-nips17.pdf), [Xu et al. (2018)](https://arxiv.org/abs/1810.00826), [Corso et al. (2020)](https://proceedings.neurips.cc/paper/2020/file/99cad265a1768cc2dd013f0e740300ae-Paper.pdf), [Li et al. (2020)](https://arxiv.org/abs/2006.07739)) demonstrate that the choice of aggregation function contributes significantly to the performance of GNN models. In particular, GNNs with different aggregation functions perform differently across tasks and datasets. Recent works also show that using multiple aggregations ([Corso et al. (2020)](https://proceedings.neurips.cc/paper/2020/file/99cad265a1768cc2dd013f0e740300ae-Paper.pdf)) and learnable aggregations ([Li et al. (2020)](https://arxiv.org/abs/2006.07739)) can yield substantial improvements. To facilitate experimentation with these aggregation schemes and to unify the concept of aggregation within GNNs across both [`MessagePassing`](https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py) and [global readouts](https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric/nn/glob), we provide **modular and re-usable aggregations** in the newly defined `torch_geometric.nn.aggr.*` package. Unifying these concepts also lets us implement optimizations and specialized kernels in a single place. With this integration, all of the following ways of specifying aggregations are supported:

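```python
from torch_geometric.nn import (
    MaxAggregation,
    MeanAggregation,
    MessagePassing,
    MultiAggregation,
    SoftmaxAggregation,
    StdAggregation,
    SumAggregation,
    VarAggregation,
)


# Original interface with string type as aggregation argument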
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr="mean")


# Use a single aggregation module as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=MeanAggregation())


# Use a list of aggregation strings as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=['mean', 'max', 'sum', 'std', 'var'])


# Use a list of aggregation modules as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=[
            MeanAggregation(),
            MaxAggregation(),
            SumAggregation(),
            StdAggregation(),
            VarAggregation(),
        ])


# Use a list of mixed modules and strings as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=[
            'mean',
            MaxAggregation(),
            'sum',
            StdAggregation(),
            'var',
        ])


# Define multiple aggregations with `MultiAggregation` module
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=MultiAggregation([
            SoftmaxAggregation(t=0.1, learn=True),
            SoftmaxAggregation(t=1, learn=True),
            SoftmaxAggregation(t=10, learn=True)
        ]))

```
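To make concrete what these reductions compute, here is a pure-Python sketch of the scatter-style aggregation behind `'sum'`, `'mean'`, `'max'`, `'std'`, and `'var'`, using scalar features for readability. The helpers `aggregate` and `multi_aggregate` are hypothetical; PyG's modules operate on batched tensors and handle numerical-stability details that this sketch ignores. Passing a list of aggregations corresponds roughly to `multi_aggregate`, which concatenates one result per aggregation (the default `mode='cat'` of `MultiAggregation`).

```python
import math
from collections import defaultdict


def aggregate(messages, index, num_nodes, reduce):
    """Reduce per-edge messages (scalars) into one value per target node."""
    groups = defaultdict(list)
    for msg, dst in zip(messages, index):
        groups[dst].append(msg)

    out = []
    for node in range(num_nodes):
        vals = groups[node]
        if not vals:  # Nodes without incoming messages get 0 in this sketch.
            out.append(0.0)
        elif reduce == 'sum':
            out.append(sum(vals))
        elif reduce == 'mean':
            out.append(sum(vals) / len(vals))
        elif reduce == 'max':
            out.append(max(vals))
        elif reduce in ('var', 'std'):
            mean = sum(vals) / len(vals)
            # Population variance (divide by the number of messages).
            var = sum((v - mean) ** 2 for v in vals) / len(vals)
            out.append(var if reduce == 'var' else math.sqrt(var))
        else:
            raise ValueError(f'Unknown reduction: {reduce!r}')
    return out


def multi_aggregate(messages, index, num_nodes, reduces):
    """Apply several reductions and concatenate the results per node."""
    per_reduce = [aggregate(messages, index, num_nodes, r) for r in reduces]
    return [list(row) for row in zip(*per_reduce)]


# Node 0 receives messages 1.0 and 3.0, node 1 receives 2.0, node 2 nothing.
messages = [1.0, 3.0, 2.0]
index = [0, 0, 1]
print(aggregate(messages, index, 3, 'mean'))  # [2.0, 2.0, 0.0]
print(multi_aggregate(messages, index, 3, ['mean', 'max', 'sum', 'std', 'var']))
```

With five aggregations, every node ends up with five values per input channel, which is why multi-aggregation widens the feature dimension.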



In this tutorial, we explore the new aggregation package with `SAGEConv` ([Hamilton et al. (2017)](https://cs.stanford.edu/~jure/pubs/graphsage-nips17.pdf)) and `ClusterLoader` ([Chiang et al. (2019)](https://arxiv.org/abs/1905.07953)) and showcase it on the `PubMed` graph from the `Planetoid` node classification benchmark suite ([Yang et al. (2016)](https://arxiv.org/abs/1603.08861)).

## Loading the dataset
Let's first load the `Planetoid` dataset and create subgraphs with `ClusterData` for training.
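The `ClusterData` implementation below stores the permuted graph in CSR form, converting between a sorted row-index vector and row pointers with `index2ptr` and `ptr2index` from `torch_geometric.index`. A pure-Python sketch of what these two conversions compute (illustrative re-implementations on lists, not PyG's tensor code; a sorted input is assumed):

```python
def index2ptr(index, size):
    """Compress a sorted row-index vector into CSR row pointers."""
    ptr = [0] * (size + 1)
    for i in index:          # Count the entries in each row...
        ptr[i + 1] += 1
    for i in range(size):    # ...then take the running sum.
        ptr[i + 1] += ptr[i]
    return ptr


def ptr2index(ptr):
    """Expand CSR row pointers back into a sorted row-index vector."""
    index = []
    for row in range(len(ptr) - 1):
        index.extend([row] * (ptr[row + 1] - ptr[row]))
    return index


row = [0, 0, 1, 3]               # Source node of each edge, sorted.
ptr = index2ptr(row, size=4)
print(ptr)                       # [0, 2, 3, 3, 4]: row r owns ptr[r]:ptr[r + 1]
print(ptr2index(ptr) == row)     # True
```

Storing `ptr` lets the code slice out all edges of a contiguous node range (one cluster) with two lookups, which is exactly how `__getitem__` and `_collate` find a partition's edges.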

In [2]:
torch.ops.load_library("/root/raw_torch_for_scatter/metis/csrc/build/libtsmetis.so")
torch.ops.load_library("/root/raw_torch_for_scatter/pytorch_cluster/csrc/build/librandom_walk.so")

In [5]:
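# Local copy of torch_geometric/loader/cluster.py. Unlike upstream PyG, `_metis`
# always calls `torch.ops.torch_sparse.partition`, presumably provided by the
# METIS library loaded in the previous cell.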

import copy
import os
import os.path as osp
import sys
from dataclasses import dataclass
from typing import List, Literal, Optional

import torch
import torch.utils.data
from torch import Tensor

import torch_geometric.typing
from torch_geometric.data import Data
from torch_geometric.index import index2ptr, ptr2index
from torch_geometric.io import fs
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import index_sort, narrow, select, sort_edge_index
from torch_geometric.utils.map import map_index


@dataclass
class Partition:
    indptr: Tensor
    index: Tensor
    partptr: Tensor
    node_perm: Tensor
    edge_perm: Tensor
    sparse_format: Literal['csr', 'csc']


class ClusterData(torch.utils.data.Dataset):
    r"""Clusters/partitions a graph data object into multiple subgraphs, as
    motivated by the `"Cluster-GCN: An Efficient Algorithm for Training Deep
    and Large Graph Convolutional Networks"
    <https://arxiv.org/abs/1905.07953>`_ paper.

    .. note::
        The underlying METIS algorithm requires undirected graphs as input.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        num_parts (int): The number of partitions.
        recursive (bool, optional): If set to :obj:`True`, will use multilevel
            recursive bisection instead of multilevel k-way partitioning.
            (default: :obj:`False`)
        save_dir (str, optional): If set, will save the partitioned data to the
            :obj:`save_dir` directory for faster re-use. (default: :obj:`None`)
        filename (str, optional): Name of the stored partitioned file.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            progress. (default: :obj:`True`)
        keep_inter_cluster_edges (bool, optional): If set to :obj:`True`,
            will keep inter-cluster edge connections. (default: :obj:`False`)
        sparse_format (str, optional): The sparse format to use for computing
            partitions. (default: :obj:`"csr"`)
    """
    def __init__(
        self,
        data,
        num_parts: int,
        recursive: bool = False,
        save_dir: Optional[str] = None,
        filename: Optional[str] = None,
        log: bool = True,
        keep_inter_cluster_edges: bool = False,
        sparse_format: Literal['csr', 'csc'] = 'csr',
    ):
        assert data.edge_index is not None
        assert sparse_format in ['csr', 'csc']

        self.num_parts = num_parts
        self.recursive = recursive
        self.keep_inter_cluster_edges = keep_inter_cluster_edges
        self.sparse_format = sparse_format

        recursive_str = '_recursive' if recursive else ''
        root_dir = osp.join(save_dir or '', f'part_{num_parts}{recursive_str}')
        path = osp.join(root_dir, filename or 'metis.pt')

        if save_dir is not None and osp.exists(path):
            self.partition = fs.torch_load(path)
        else:
            if log:  # pragma: no cover
                print('Computing METIS partitioning...', file=sys.stderr)

            cluster = self._metis(data.edge_index, data.num_nodes)
            self.partition = self._partition(data.edge_index, cluster)

            if save_dir is not None:
                os.makedirs(root_dir, exist_ok=True)
                torch.save(self.partition, path)

            if log:  # pragma: no cover
                print('Done!', file=sys.stderr)

        self.data = self._permute_data(data, self.partition)

    def _metis(self, edge_index: Tensor, num_nodes: int) -> Tensor:
        # Computes a node-level partition assignment vector via METIS.
        if self.sparse_format == 'csr':  # Calculate CSR representation:
            row, index = sort_edge_index(edge_index, num_nodes=num_nodes)
            indptr = index2ptr(row, size=num_nodes)
        else:  # Calculate CSC representation:
            index, col = sort_edge_index(edge_index, num_nodes=num_nodes,
                                         sort_by_row=False)
            indptr = index2ptr(col, size=num_nodes)

        # Compute METIS partitioning on CPU and move the result back to the
        # original device:
        cluster = torch.ops.torch_sparse.partition(
            indptr.cpu(),
            index.cpu(),
            None,  # No edge weights.
            self.num_parts,
            self.recursive,
        ).to(edge_index.device)

        return cluster

    def _partition(self, edge_index: Tensor, cluster: Tensor) -> Partition:
        # Computes node-level and edge-level permutations and permutes the edge
        # connectivity accordingly:

        # Sort `cluster` and compute boundaries `partptr`:
        cluster, node_perm = index_sort(cluster, max_value=self.num_parts)
        partptr = index2ptr(cluster, size=self.num_parts)

        # Permute `edge_index` based on node permutation:
        edge_perm = torch.arange(edge_index.size(1), device=edge_index.device)
        arange = torch.empty_like(node_perm)
        arange[node_perm] = torch.arange(cluster.numel(),
                                         device=cluster.device)
        edge_index = arange[edge_index]

        # Compute final CSR representation:
        (row, col), edge_perm = sort_edge_index(
            edge_index,
            edge_attr=edge_perm,
            num_nodes=cluster.numel(),
            sort_by_row=self.sparse_format == 'csr',
        )
        if self.sparse_format == 'csr':
            indptr, index = index2ptr(row, size=cluster.numel()), col
        else:
            indptr, index = index2ptr(col, size=cluster.numel()), row

        return Partition(indptr, index, partptr, node_perm, edge_perm,
                         self.sparse_format)

    def _permute_data(self, data: Data, partition: Partition) -> Data:
        # Permute node-level and edge-level attributes according to the
        # calculated permutations in `Partition`:
        out = copy.copy(data)
        for key, value in data.items():
            if key == 'edge_index':
                continue
            elif data.is_node_attr(key):
                cat_dim = data.__cat_dim__(key, value)
                out[key] = select(value, partition.node_perm, dim=cat_dim)
            elif data.is_edge_attr(key):
                cat_dim = data.__cat_dim__(key, value)
                out[key] = select(value, partition.edge_perm, dim=cat_dim)
        out.edge_index = None

        return out

    def __len__(self) -> int:
        return self.partition.partptr.numel() - 1

    def __getitem__(self, idx: int) -> Data:
        node_start = int(self.partition.partptr[idx])
        node_end = int(self.partition.partptr[idx + 1])
        node_length = node_end - node_start

        indptr = self.partition.indptr[node_start:node_end + 1]
        edge_start = int(indptr[0])
        edge_end = int(indptr[-1])
        edge_length = edge_end - edge_start
        indptr = indptr - edge_start

        if self.sparse_format == 'csr':
            row = ptr2index(indptr)
            col = self.partition.index[edge_start:edge_end]
            if not self.keep_inter_cluster_edges:
                edge_mask = (col >= node_start) & (col < node_end)
                row = row[edge_mask]
                col = col[edge_mask] - node_start
        else:
            col = ptr2index(indptr)
            row = self.partition.index[edge_start:edge_end]
            if not self.keep_inter_cluster_edges:
                edge_mask = (row >= node_start) & (row < node_end)
                col = col[edge_mask]
                row = row[edge_mask] - node_start

        out = copy.copy(self.data)

        for key, value in self.data.items():
            if key == 'num_nodes':
                out.num_nodes = node_length
            elif self.data.is_node_attr(key):
                cat_dim = self.data.__cat_dim__(key, value)
                out[key] = narrow(value, cat_dim, node_start, node_length)
            elif self.data.is_edge_attr(key):
                cat_dim = self.data.__cat_dim__(key, value)
                out[key] = narrow(value, cat_dim, edge_start, edge_length)
                if not self.keep_inter_cluster_edges:
                    out[key] = out[key][edge_mask]

        out.edge_index = torch.stack([row, col], dim=0)

        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.num_parts})'


class ClusterLoader(torch.utils.data.DataLoader):
    r"""The data loader scheme from the `"Cluster-GCN: An Efficient Algorithm
    for Training Deep and Large Graph Convolutional Networks"
    <https://arxiv.org/abs/1905.07953>`_ paper which merges partitioned subgraphs
    and their between-cluster links from a large-scale graph data object to
    form a mini-batch.

    .. note::

        Use :class:`~torch_geometric.loader.ClusterData` and
        :class:`~torch_geometric.loader.ClusterLoader` in conjunction to
        form mini-batches of clusters.
        For an example of using Cluster-GCN, see
        `examples/cluster_gcn_reddit.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/cluster_gcn_reddit.py>`_ or
        `examples/cluster_gcn_ppi.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/cluster_gcn_ppi.py>`_.

    Args:
        cluster_data (torch_geometric.loader.ClusterData): The already
            partitioned data object.
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
    """
    def __init__(self, cluster_data, **kwargs):
        self.cluster_data = cluster_data
        iterator = range(len(cluster_data))
        super().__init__(iterator, collate_fn=self._collate, **kwargs)

    def _collate(self, batch: List[int]) -> Data:
        if not isinstance(batch, torch.Tensor):
            batch = torch.tensor(batch)

        global_indptr = self.cluster_data.partition.indptr
        global_index = self.cluster_data.partition.index

        # Get all node-level and edge-level start and end indices for the
        # current mini-batch:
        node_start = self.cluster_data.partition.partptr[batch]
        node_end = self.cluster_data.partition.partptr[batch + 1]
        edge_start = global_indptr[node_start]
        edge_end = global_indptr[node_end]

        # Iterate over each partition in the batch and calculate new edge
        # connectivity. This is done by slicing the corresponding source and
        # destination indices for each partition and adjusting their indices to
        # start from zero:
        rows, cols, nodes, cumsum = [], [], [], 0
        for i in range(batch.numel()):
            nodes.append(torch.arange(node_start[i], node_end[i]))
            indptr = global_indptr[node_start[i]:node_end[i] + 1]
            indptr = indptr - edge_start[i]
            if self.cluster_data.partition.sparse_format == 'csr':
                row = ptr2index(indptr) + cumsum
                col = global_index[edge_start[i]:edge_end[i]]

            else:
                col = ptr2index(indptr) + cumsum
                row = global_index[edge_start[i]:edge_end[i]]

            rows.append(row)
            cols.append(col)
            cumsum += indptr.numel() - 1

        node = torch.cat(nodes, dim=0)
        row = torch.cat(rows, dim=0)
        col = torch.cat(cols, dim=0)

        # Map `col` vector to valid entries and remove any entries that do not
        # connect two nodes within the same mini-batch:
        if self.cluster_data.partition.sparse_format == 'csr':
            col, edge_mask = map_index(col, node)
            row = row[edge_mask]
        else:
            row, edge_mask = map_index(row, node)
            col = col[edge_mask]
        out = copy.copy(self.cluster_data.data)

        # Slice node-level and edge-level attributes according to its offsets:
        for key, value in self.cluster_data.data.items():
            if key == 'num_nodes':
                out.num_nodes = cumsum
            elif self.cluster_data.data.is_node_attr(key):
                cat_dim = self.cluster_data.data.__cat_dim__(key, value)
                out[key] = torch.cat([
                    narrow(out[key], cat_dim, s, e - s)
                    for s, e in zip(node_start, node_end)
                ], dim=cat_dim)
            elif self.cluster_data.data.is_edge_attr(key):
                cat_dim = self.cluster_data.data.__cat_dim__(key, value)
                value = torch.cat([
                    narrow(out[key], cat_dim, s, e - s)
                    for s, e in zip(edge_start, edge_end)
                ], dim=cat_dim)
                out[key] = select(value, edge_mask, dim=cat_dim)

        out.edge_index = torch.stack([row, col], dim=0)

        return out
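Stripped of the CSR bookkeeping, the two classes above implement the Cluster-GCN idea: `ClusterData` groups nodes by cluster id, and a single partition keeps only its internal edges (with `keep_inter_cluster_edges=False`), while `ClusterLoader` concatenates several clusters into one mini-batch and also keeps the edges that run between them. A pure-Python sketch, assuming a hand-written cluster assignment in place of METIS (the helper `cluster_batch` is hypothetical):

```python
def cluster_batch(edges, cluster, parts):
    """Return the nodes of the clusters in `parts` and the edges among them.

    `cluster[v]` is the cluster id of node v. Nodes are relabeled 0..n-1 in
    the order their clusters appear in `parts`; an edge is kept only if both
    endpoints belong to the selected clusters.
    """
    nodes = [v for p in parts for v in range(len(cluster)) if cluster[v] == p]
    new_id = {v: i for i, v in enumerate(nodes)}
    sub_edges = [(new_id[s], new_id[d]) for s, d in edges
                 if s in new_id and d in new_id]
    return nodes, sub_edges


# Hand-written assignment standing in for METIS: 6 nodes, 3 clusters.
cluster = [0, 0, 1, 1, 2, 2]
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2), (3, 4), (4, 3)]

print(cluster_batch(edges, cluster, [1]))     # One cluster: internal edges only.
print(cluster_batch(edges, cluster, [1, 0]))  # Two clusters plus edges between them.
```

Because `ClusterLoader` draws the clusters of each mini-batch at random (with `shuffle=True`), different inter-cluster edges are recovered in different batches.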

In [6]:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

dataset = Planetoid(root='data/Planetoid', name='PubMed',
                    transform=NormalizeFeatures())

print()
print(f'Dataset: {dataset}:')
print('==================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)

# Use the local ClusterData/ClusterLoader defined above instead of:
# from torch_geometric.loader import ClusterData, ClusterLoader

seed = 42
torch.manual_seed(seed)
cluster_data = ClusterData(data, num_parts=128)  # 1. Create subgraphs.
train_loader = ClusterLoader(cluster_data, batch_size=32,
                             shuffle=True)  # 2. Stochastic partitioning scheme.


Dataset: PubMed():
Number of graphs: 1
Number of features: 500
Number of classes: 3

Data(x=[19717, 500], edge_index=[2, 88648], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717])


Computing METIS partitioning...
Done!
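A quick back-of-the-envelope check of what this loader yields: with 128 partitions and `batch_size=32`, each epoch consists of 4 mini-batches, and each mini-batch holds about a quarter of PubMed's nodes on average. METIS balances partition sizes only approximately, so individual batches vary.

```python
import math

num_nodes = 19717   # From the printed `Data(x=[19717, 500], ...)`.
num_parts = 128     # ClusterData(data, num_parts=128)
batch_size = 32     # ClusterLoader(..., batch_size=32)

# DataLoader keeps the last, possibly smaller batch (drop_last=False).
batches_per_epoch = math.ceil(num_parts / batch_size)
avg_nodes_per_cluster = num_nodes / num_parts
avg_nodes_per_batch = avg_nodes_per_cluster * batch_size

print(batches_per_epoch)               # 4
print(round(avg_nodes_per_cluster))    # 154
print(round(avg_nodes_per_batch))      # 4929
```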


## Define train, test, and run functions
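Here we define `train` (one pass over the `ClusterLoader` mini-batches), `test` (accuracy on the training, validation, and test masks of the full graph), and a simple `run` function that alternates between the two for a number of epochs.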

In [15]:
criterion = torch.nn.CrossEntropyLoss()

data = data.to(device)

def train(model, optimizer):
    model.train()
    for sub_data in train_loader:  # Iterate over each mini-batch.
        optimizer.zero_grad()  # Clear gradients.
        sub_data = sub_data.to(device)
        out = model(sub_data.x,
                    sub_data.edge_index)  # Perform a single forward pass.
        loss = criterion(
            out[sub_data.train_mask], sub_data.y[sub_data.train_mask]
        )  # Compute the loss solely based on the training nodes.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.


@torch.no_grad()  # No gradients are needed for evaluation.
def test(model):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with highest probability.

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = pred[mask] == data.y[
            mask]  # Check against ground-truth labels.
        accs.append(int(correct.sum()) /
                    int(mask.sum()))  # Derive ratio of correct predictions.
    return accs


def run(model, optimizer, epochs=5):
    for epoch in range(epochs):
        train(model, optimizer)  # `train` does not return a loss value.
        train_acc, val_acc, test_acc = test(model)
        print(
            f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}'
        )
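`test` runs the model once on the full graph and, for each of the three masks, reports the fraction of selected nodes whose arg-max prediction matches the ground-truth label. A pure-Python sketch of that per-mask computation (the function `masked_accuracy` is hypothetical):

```python
def masked_accuracy(logits, labels, mask):
    """Fraction of nodes selected by `mask` whose arg-max class equals the label."""
    correct = total = 0
    for scores, label, selected in zip(logits, labels, mask):
        if not selected:
            continue
        pred = max(range(len(scores)), key=lambda c: scores[c])  # Arg-max.
        correct += int(pred == label)
        total += 1
    return correct / total


logits = [[0.1, 0.9, 0.0], [2.0, 1.0, 0.5], [0.2, 0.3, 0.9], [0.6, 0.1, 0.3]]
labels = [1, 0, 0, 0]
train_mask = [True, True, True, False]
test_mask = [False, False, True, True]
print(masked_accuracy(logits, labels, train_mask))  # 2 of 3 correct
print(masked_accuracy(logits, labels, test_mask))   # 0.5
```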

## Define a GNN class and import aggregations
Now, let's define a GNN helper class and import all those new aggregation operators!




In [13]:
import copy

import torch.nn.functional as F
from torch_geometric.nn import (
    Aggregation,
    MaxAggregation,
    MeanAggregation,
    MultiAggregation,
    SAGEConv,
    SoftmaxAggregation,
    StdAggregation,
    SumAggregation,
    VarAggregation,
)


class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, aggr='mean', aggr_kwargs=None):
        super().__init__()
        self.conv1 = SAGEConv(
            dataset.num_node_features,
            hidden_channels,
            aggr=aggr,
            aggr_kwargs=aggr_kwargs,
        )
        self.conv2 = SAGEConv(
            hidden_channels,
            dataset.num_classes,
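            # Deep-copy so that a module-based `aggr` (e.g., a learnable
            # `SoftmaxAggregation`) is not shared between the two layers:
            aggr=copy.deepcopy(aggr),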
            aggr_kwargs=aggr_kwargs,
        )

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
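`conv2` receives `copy.deepcopy(aggr)` rather than `aggr` itself. In PyTorch, registering the same module object in two places shares its parameters, so without the copy both layers would presumably use one aggregation module, including any learnable state such as the temperature of `SoftmaxAggregation(learn=True)`; for string arguments the copy is harmless. A plain-Python illustration of the difference, using a hypothetical stand-in class:

```python
import copy


class ToyAggregation:
    """Hypothetical stand-in for an aggregation module with a learnable `t`."""
    def __init__(self, t):
        self.t = t


def build_layers(aggr, copy_second_layer):
    # Mirrors `GNN.__init__`: the first layer gets `aggr` itself, the second
    # either the same object or a deep copy of it.
    second = copy.deepcopy(aggr) if copy_second_layer else aggr
    return aggr, second


first, second = build_layers(ToyAggregation(t=1.0), copy_second_layer=False)
second.t = 5.0          # "Training" the second layer...
print(first.t)          # 5.0: ...also changed the first one (shared object).

first, second = build_layers(ToyAggregation(t=1.0), copy_second_layer=True)
second.t = 5.0
print(first.t)          # 1.0: each layer owns its own parameter.
```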

### Original interface with string type as the aggregation argument
Previously, PyG only supported customizing [MessagePassing](https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py) with simple aggregations (e.g., `'mean'`, `'max'`, `'sum'`). Let's define a GNN with `'mean'` aggregation and run it for 5 epochs.
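With `aggr='mean'`, each `SAGEConv` layer combines a node's own features with the mean of its neighbors' features, `x'_i = W_1 x_i + W_2 · mean_{j ∈ N(i)} x_j` (following the `SAGEConv` documentation; bias and optional normalization omitted). A pure-Python sketch with scalar features and scalar weights standing in for the two linear layers (the function `sage_mean_layer` is hypothetical):

```python
def sage_mean_layer(x, edges, w_root, w_neigh):
    """One SAGE-style update with scalar features and scalar weights.

    x'_i = w_root * x_i + w_neigh * mean_{j -> i} x_j   (bias omitted)
    """
    incoming = {i: [] for i in range(len(x))}
    for src, dst in edges:            # Messages flow from source to target.
        incoming[dst].append(x[src])

    out = []
    for i, xi in enumerate(x):
        neigh = incoming[i]
        agg = sum(neigh) / len(neigh) if neigh else 0.0  # The 'mean' aggregation.
        out.append(w_root * xi + w_neigh * agg)
    return out


x = [1.0, 2.0, 4.0]
edges = [(1, 0), (2, 0), (0, 1)]   # Node 0 hears from 1 and 2; node 1 from 0.
print(sage_mean_layer(x, edges, w_root=1.0, w_neigh=0.5))  # [2.5, 2.5, 4.0]
```

Swapping the `'mean'` line for another reduction is, conceptually, all that changing the `aggr` argument does.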

In [9]:
device

device(type='hpu')

In [21]:
torch.manual_seed(seed)
model = GNN(16, aggr='mean').to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
model = torch.compile(model, backend="hpu_backend")

run(model, optimizer)

GNN(
  (conv1): SAGEConv(500, 16, aggr=mean)
  (conv2): SAGEConv(16, 3, aggr=mean)
)




Epoch: 000, Train: 0.4167, Val Acc: 0.2220, Test Acc: 0.2140
Epoch: 001, Train: 0.6167, Val Acc: 0.4040, Test Acc: 0.3920
Epoch: 002, Train: 0.7833, Val Acc: 0.5220, Test Acc: 0.5080
Epoch: 003, Train: 0.8833, Val Acc: 0.6360, Test Acc: 0.6340
Epoch: 004, Train: 0.9000, Val Acc: 0.6820, Test Acc: 0.6730


### Use a single aggregation module as the aggregation argument
In the new interface, the [MessagePassing](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.MessagePassing) class can take an [Aggregation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.Aggregation) module as an argument. Here we define the mean aggregation via `MeanAggregation`. The model reaches accuracy comparable to the `'mean'` string version above, although the exact numbers differ slightly between the two runs.

In [20]:
torch.manual_seed(seed)
model = GNN(16, aggr=MeanAggregation()).to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
model = torch.compile(model, backend="hpu_backend")
run(model, optimizer)

GNN(
  (conv1): SAGEConv(500, 16, aggr=MeanAggregation())
  (conv2): SAGEConv(16, 3, aggr=MeanAggregation())
)
Epoch: 000, Train: 0.4667, Val Acc: 0.2240, Test Acc: 0.2130




Epoch: 001, Train: 0.5167, Val Acc: 0.2820, Test Acc: 0.2920
Epoch: 002, Train: 0.8333, Val Acc: 0.5420, Test Acc: 0.5390
Epoch: 003, Train: 0.9000, Val Acc: 0.6360, Test Acc: 0.6150
Epoch: 004, Train: 0.9000, Val Acc: 0.6720, Test Acc: 0.6470


### Use a list of aggregation strings as the aggregation argument

For defining multiple aggregations, we can use a list of strings as the input argument. The aggregations are **resolved from pure strings** via a lookup table, following the design principles of the [class-resolver](https://github.com/cthoyt/class-resolver) library: for example, passing `"mean"` to the [**MessagePassing**](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.MessagePassing) module automatically resolves it to the `MeanAggregation` class. Let's see how a PNA-like GNN ([Corso et al. (2020)](https://proceedings.neurips.cc/paper/2020/file/99cad265a1768cc2dd013f0e740300ae-Paper.pdf)) works. It converges much faster: after the first epoch it already reaches 70.4% validation accuracy, compared to about 22% with the single mean aggregation.
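The string lookup can be pictured roughly as follows: class names are normalized by lower-casing them and stripping the shared `Aggregation` suffix, so `'mean'`, `'Mean'`, and `'MeanAggregation'` all map to the same class, while objects that are already instances pass through unchanged (which is also why mixed lists of strings and modules work). This is a simplified sketch with toy stand-in classes that merely share PyG's names; it is not PyG's or class-resolver's actual code.

```python
class MeanAggregation:
    def __call__(self, values):
        return sum(values) / len(values)


class MaxAggregation:
    def __call__(self, values):
        return max(values)


class SumAggregation:
    def __call__(self, values):
        return sum(values)


REGISTRY = [MeanAggregation, MaxAggregation, SumAggregation]
SUFFIX = 'aggregation'


def normalize(name):
    """'MeanAggregation', 'Mean', and 'mean' all normalize to 'mean'."""
    name = name.lower()
    return name[:-len(SUFFIX)] if name.endswith(SUFFIX) else name


def resolve(query):
    """Map a string to a new instance; pass instances through unchanged."""
    if not isinstance(query, str):
        return query
    lookup = {normalize(cls.__name__): cls for cls in REGISTRY}
    try:
        return lookup[normalize(query)]()
    except KeyError:
        raise ValueError(f'Could not resolve {query!r}') from None


def resolve_all(queries):
    """Resolve a (possibly mixed) list of strings and instances."""
    return [resolve(q) for q in queries]


aggrs = resolve_all(['mean', MaxAggregation(), 'sum'])
print([type(a).__name__ for a in aggrs])
print([a([1.0, 3.0]) for a in aggrs])   # [2.0, 3.0, 4.0]
```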

In [23]:
torch.manual_seed(seed)
model = GNN(16, aggr=['mean', 'max', 'sum', 'std', 'var']).to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
model = torch.compile(model, backend="hpu_backend")
run(model, optimizer)

GNN(
  (conv1): SAGEConv(500, 16, aggr=['mean', 'max', 'sum', 'std', 'var'])
  (conv2): SAGEConv(16, 3, aggr=['mean', 'max', 'sum', 'std', 'var'])
)




Epoch: 000, Train: 0.8167, Val Acc: 0.7040, Test Acc: 0.6900
Epoch: 001, Train: 0.8667, Val Acc: 0.6860, Test Acc: 0.6850
Epoch: 002, Train: 0.9333, Val Acc: 0.7240, Test Acc: 0.7170
Epoch: 003, Train: 0.9333, Val Acc: 0.7340, Test Acc: 0.7290
Epoch: 004, Train: 0.9667, Val Acc: 0.7560, Test Acc: 0.7360


### Use a list of aggregation modules as the aggregation argument
You can also use a list of [Aggregation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.Aggregation) modules to specify your convolutions.

In [24]:
torch.manual_seed(seed)
model = GNN(
    16, aggr=[
        MeanAggregation(),
        MaxAggregation(),
        SumAggregation(),
        StdAggregation(),
        VarAggregation(),
    ]).to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
model = torch.compile(model, backend="hpu_backend")
run(model, optimizer)

GNN(
  (conv1): SAGEConv(500, 16, aggr=['MeanAggregation()', 'MaxAggregation()', 'SumAggregation()', 'StdAggregation()', 'VarAggregation()'])
  (conv2): SAGEConv(16, 3, aggr=['MeanAggregation()', 'MaxAggregation()', 'SumAggregation()', 'StdAggregation()', 'VarAggregation()'])
)




Epoch: 000, Train: 0.8167, Val Acc: 0.7040, Test Acc: 0.6900
Epoch: 001, Train: 0.8667, Val Acc: 0.6860, Test Acc: 0.6840
Epoch: 002, Train: 0.9333, Val Acc: 0.7260, Test Acc: 0.7180
Epoch: 003, Train: 0.9167, Val Acc: 0.7340, Test Acc: 0.7300
Epoch: 004, Train: 0.9667, Val Acc: 0.7560, Test Acc: 0.7380


### Use a list of mixed modules and strings as the aggregation argument
Mixing strings and modules in a single list is supported as well, for convenience.

In [25]:
torch.manual_seed(seed)
model = GNN(16, aggr=[
    'mean',
    MaxAggregation(),
    'sum',
    StdAggregation(),
    'var',
]).to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
model = torch.compile(model, backend="hpu_backend")
run(model, optimizer)

GNN(
  (conv1): SAGEConv(500, 16, aggr=['mean', 'MaxAggregation()', 'sum', 'StdAggregation()', 'var'])
  (conv2): SAGEConv(16, 3, aggr=['mean', 'MaxAggregation()', 'sum', 'StdAggregation()', 'var'])
)




Epoch: 000, Train: 0.8167, Val Acc: 0.7040, Test Acc: 0.6900
Epoch: 001, Train: 0.8667, Val Acc: 0.6860, Test Acc: 0.6840
Epoch: 002, Train: 0.9333, Val Acc: 0.7240, Test Acc: 0.7170
Epoch: 003, Train: 0.9333, Val Acc: 0.7340, Test Acc: 0.7280
Epoch: 004, Train: 0.9667, Val Acc: 0.7560, Test Acc: 0.7370


### Define multiple aggregations with `MultiAggregation` module

When given a list, [MessagePassing](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.MessagePassing) automatically combines the aggregators via the [MultiAggregation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.MultiAggregation) module. You can also pass a [MultiAggregation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.MultiAggregation) directly instead of a list. Let's see how to define multiple aggregations with [MultiAggregation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.MultiAggregation). Here we use different initial temperatures for [SoftmaxAggregation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.SoftmaxAggregation) ([Li et al. (2020)](https://arxiv.org/abs/2006.07739)). Each temperature results in an aggregation of different softness: a small `t` behaves like a mean, while a large `t` approaches a max. With `learn=True`, the temperatures are trained along with the rest of the model.
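`SoftmaxAggregation` weights each neighbor's message by a softmax of the messages scaled by `t` and sums the weighted messages, separately for each feature channel. Small `t` gives nearly uniform weights (close to a mean), while large `t` concentrates the weight on the largest message (close to a max). A pure-Python sketch of that formula for one feature channel (the function `softmax_aggregate` is hypothetical; PyG's module additionally makes `t` a trainable parameter when `learn=True`):

```python
import math


def softmax_aggregate(values, t):
    """Softmax-weighted sum: sum_j softmax(t * x)_j * x_j (one channel)."""
    m = max(t * v for v in values)  # Subtract the max for numerical stability.
    weights = [math.exp(t * v - m) for v in values]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total


values = [1.0, 2.0, 6.0]   # mean = 3.0, max = 6.0
for t in (0.01, 1.0, 100.0):
    # The result moves from roughly the mean toward the max as t grows.
    print(t, round(softmax_aggregate(values, t), 3))
```

`MultiAggregation` (in its default `mode=cat`, as the printout above shows) then concatenates the outputs of the three differently tempered aggregations.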

In [27]:
torch.manual_seed(seed)
aggr = MultiAggregation([
    SoftmaxAggregation(t=0.01, learn=True),
    SoftmaxAggregation(t=1, learn=True),
    SoftmaxAggregation(t=100, learn=True),
])
model = GNN(16, aggr=aggr).to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
model = torch.compile(model, backend="hpu_backend")
run(model, optimizer)

GNN(
  (conv1): SAGEConv(500, 16, aggr=MultiAggregation([
    SoftmaxAggregation(learn=True),
    SoftmaxAggregation(learn=True),
    SoftmaxAggregation(learn=True),
  ], mode=cat))
  (conv2): SAGEConv(16, 3, aggr=MultiAggregation([
    SoftmaxAggregation(learn=True),
    SoftmaxAggregation(learn=True),
    SoftmaxAggregation(learn=True),
  ], mode=cat))
)




Epoch: 000, Train: 0.7000, Val Acc: 0.6300, Test Acc: 0.6330
Epoch: 001, Train: 0.9333, Val Acc: 0.7740, Test Acc: 0.7090
Epoch: 002, Train: 0.9500, Val Acc: 0.7720, Test Acc: 0.7210
Epoch: 003, Train: 0.9500, Val Acc: 0.7660, Test Acc: 0.7190
Epoch: 004, Train: 0.9667, Val Acc: 0.7540, Test Acc: 0.7210


### What else?
There are many other aggregation operators you can use to assemble your GNNs like Lego bricks. [PowerMeanAggregation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.PowerMeanAggregation) lets you define, and potentially learn, generalized means beyond the simple arithmetic mean, such as the harmonic and geometric means. [LSTMAggregation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.LSTMAggregation) performs permutation-variant aggregation. Other interesting aggregation operators such as [Set2Set](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.Set2Set), [DegreeScalerAggregation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.DegreeScalerAggregation), [SortAggregation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.SortAggregation), [GraphMultisetTransformer](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.GraphMultisetTransformer), [AttentionalAggregation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.AttentionalAggregation), and [EquilibriumAggregation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.EquilibriumAggregation) are ready for you to explore.
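`PowerMeanAggregation` is based on the generalized (power) mean `(mean_j x_j^p)^(1/p)`: `p = 1` recovers the arithmetic mean, `p = -1` the harmonic mean, and `p → 0` the geometric mean. A pure-Python sketch of the formula for positive inputs (the helper `power_mean` is hypothetical; PyG's module additionally supports a learnable `p` and handles numerical details omitted here):

```python
import math


def power_mean(values, p):
    """Generalized (power) mean of positive values: (mean(x^p))^(1/p).

    p = 1 gives the arithmetic mean and p = -1 the harmonic mean; as p -> 0
    the power mean converges to the geometric mean, handled as a special case.
    """
    if any(v <= 0 for v in values):
        raise ValueError('this sketch expects strictly positive values')
    n = len(values)
    if p == 0:
        return math.exp(sum(math.log(v) for v in values) / n)
    return (sum(v ** p for v in values) / n) ** (1.0 / p)


vals = [1.0, 2.0, 4.0]
print(power_mean(vals, 1))    # Arithmetic mean, 7/3
print(power_mean(vals, 0))    # Geometric mean, approximately 2.0
print(power_mean(vals, -1))   # Harmonic mean, 12/7
```

Treating `p` as a parameter is what allows a single module to interpolate between these classical means.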

## Conclusion

This tutorial introduced the `torch_geometric.nn.aggr` package, which provides a flexible interface for experimenting with different aggregation functions in your message passing convolutions and unifies aggregation within GNNs across [`MessagePassing`](https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py) and [global readouts](https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric/nn/glob). This abstraction also makes it easier to design new types of aggregation functions: you can create your own by subclassing the base `Aggregation` class. Please refer to the [docs](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/aggr/base.html#Aggregation) for more details.

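```python
from typing import Optional

from torch import Tensor
from torch_geometric.nn.aggr import Aggregation


class MyAggregation(Aggregation):
    def __init__(self):
        super().__init__()
        # Define any (learnable) parameters here.

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # Placeholder: replace with your own reduction. `self.reduce` applies
        # a built-in scatter reduction ('sum', 'mean', 'max', ...).
        return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')
```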

*Have fun!*