# Neighbor Sampling and Customization

This tutorial will teach you the details of neighbor sampling and how to specify your own neighbor sampling algorithm in stochastic GNN training.  It assumes that you have read the basics of stochastic GNN training in the [node classification tutorial](L1_large_node_classification.ipynb).

In [None]:
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset('ogbn-products')

graph, node_labels = dataset[0]
idx_split = dataset.get_idx_split()
train_nids = idx_split['train']

## Review of Stochastic Training Pipeline

Training a GNN stochastically involves the following steps:

1. Iterate over the nodes or edges in minibatches.
2. For each minibatch, generate a list of bipartite graphs, one bipartite graph for each GNN layer.
3. Perform message passing on the list of bipartite graphs.
4. Compute gradient and optimize.

This tutorial dives into the details of Step 2.  It first introduces what neighbor sampling is and why it yields bipartite graphs.  It then shows how to plug your own neighbor sampling algorithm into DGL.

## Multi-layer Message Passing in Detail

The message passing formulation defined in [Gilmer et al.](https://arxiv.org/abs/1704.01212) is as follows:

$$
m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right) \\
m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)} \\
h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)
$$

Essentially, the $l$-th layer representation of a node depends on the $(l-1)$-th layer representation of the same node, as well as the $(l-1)$-th layer representation of the neighboring nodes.

<div class="alert alert-info">
    <b>Note: </b>See <a href=3_message_passing.ipynb>this tutorial</a> for more details on message passing in DGL.
</div>
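
The three equations can be illustrated with a minimal, dependency-free sketch.  It assumes a hypothetical three-node graph with scalar features, omits edge features, and picks the simplest possible $M$ (copy the source representation) and $U$ (add the aggregated message to the node's own state).  None of these choices come from DGL; they only make the formulation concrete.

```python
# Hypothetical toy graph: in_neighbors[v] lists the nodes u with an edge u -> v.
in_neighbors = {0: [1, 2], 1: [2], 2: []}

# Hypothetical scalar representations h^{(l-1)} of each node.
h_prev = {0: 1.0, 1: 2.0, 2: 3.0}


def message(h_v, h_u):
    # M: in this sketch the message is simply the source node's representation.
    return h_u


def update(h_v, m_v):
    # U: in this sketch the aggregated message is added to the node's own state.
    return h_v + m_v


def message_passing_layer(in_neighbors, h_prev):
    h_next = {}
    for v, neighbors in in_neighbors.items():
        # m_v is the sum of the messages from all in-neighbors of v.
        m_v = sum(message(h_prev[v], h_prev[u]) for u in neighbors)
        h_next[v] = update(h_prev[v], m_v)
    return h_next


h_next = message_passing_layer(in_neighbors, h_prev)
print(h_next)  # {0: 6.0, 1: 5.0, 2: 3.0}
```

Node 0 receives $2 + 3 = 5$ from its two in-neighbors and ends up at $1 + 5 = 6$, while node 2 has no in-neighbors and keeps its value.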

For instance, if you would like to compute the representation of the red node by a 2-layer GNN:

![Imgur](assets/seed.png)

The formulation shows that to compute the red node's second GNN layer output $\boldsymbol{h}_8^{(2)}$ you will need the first GNN layer output of the same node $\boldsymbol{h}_8^{(1)}$, as well as the first GNN layer output of its neighboring nodes $\boldsymbol{h}_4^{(1)}$, $\boldsymbol{h}_5^{(1)}$, $\boldsymbol{h}_7^{(1)}$, and $\boldsymbol{h}_{11}^{(1)}$ (colored green).  The message passing will happen on the green dashed edges visualized below.

![Imgur](assets/3.png)

To compute the first-layer representation of the red and green nodes, you further need to perform message passing on the yellow edges visualized below.  Therefore, in addition to the red and green nodes, the yellow nodes' input features are also necessary to compute the red node's second GNN layer output.

<div class="alert alert-info">
    <b>Note</b>: The edges that show up in the second layer (i.e. as the green dashed arrows) appear again in the first layer (i.e. as the yellow dashed arrows), but they represent different message passing computations.  The second layer computes from $\boldsymbol{h}_\cdot^{(1)}$ to $\boldsymbol{h}_\cdot^{(2)}$ while the first layer computes from $\boldsymbol{h}_\cdot^{(0)}$ (i.e. the input) to $\boldsymbol{h}_\cdot^{(1)}$.
</div>

![Imgur](assets/4.png)

You may notice that to figure out which nodes' input features are necessary, you are going in the opposite direction of message aggregation: you start from the layer closest to the output and work backward to the input.  Message passing, in contrast, goes from the layer closest to the input towards the output.
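
This backward dependency tracing can be sketched in plain Python.  The function `required_input_nodes` and the toy graph below are hypothetical illustrations (the graph is not the one in the figures): starting from the output nodes, each layer adds the in-neighbors of every node collected so far.

```python
def required_input_nodes(in_neighbors, seeds, num_layers):
    """Return the nodes whose input features are needed to compute the
    num_layers-th layer output of `seeds` when all neighbors are used."""
    needed = set(seeds)
    # Walk from the layer closest to the output back to the input.
    for _ in range(num_layers):
        frontier = set(needed)  # each node needs its own previous-layer state
        for v in needed:
            frontier.update(in_neighbors.get(v, []))
        needed = frontier
    return needed


# Hypothetical toy graph: in_neighbors[v] lists the nodes u with an edge u -> v.
in_neighbors = {0: [1, 2], 1: [3], 2: [3, 4], 3: [5], 4: [], 5: []}

one_layer = required_input_nodes(in_neighbors, [0], 1)
two_layers = required_input_nodes(in_neighbors, [0], 2)
print(sorted(one_layer))   # [0, 1, 2]
print(sorted(two_layers))  # [0, 1, 2, 3, 4]
```

Each additional layer widens the set by one hop, which is why the required inputs grow quickly with the number of layers.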

## Neighbor Sampling Overview

The previous example also shows that computing the representations of a small number of nodes often requires the input features of a significantly larger number of nodes.  Aggregating over all neighbors is often too costly because the nodes whose input features are needed can easily cover a large portion of the graph, especially for real-world graphs, which are often [scale-free](https://en.wikipedia.org/wiki/Scale-free_network).

Neighbor sampling addresses this issue by selecting a subset of the neighbors to perform aggregation.  For instance, to compute $\boldsymbol{h}_8^{(2)}$, you can aggregate from two of the neighbors instead of all of them, so you need the first layer representation of the red node and of only two green nodes.

![Imgur](assets/5.png)

Similarly, to compute the red and green nodes' first layer representation, you can also do neighbor sampling that takes two of each node's neighbors.

<div class="alert alert-info">
    <b>Note</b>: Since you need the first layer representation of the red node, you will need to sample the neighbors of the red node again.
</div>

![Imgur](assets/6.png)

You can see that neighbor sampling reduces the number of nodes whose input features are needed.
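
The saving can be checked with a small dependency-free sketch.  The function `sampled_input_nodes`, the toy graph, and the fanout values are all hypothetical; the sketch only mimics uniform neighbor sampling and is not DGL's implementation.  It assumes `fanouts[0]` belongs to the layer closest to the input, so it walks the list in reverse.

```python
import random


def sampled_input_nodes(in_neighbors, seeds, fanouts, seed=0):
    """Trace the nodes needed for input features when each node aggregates
    from at most `fanout` uniformly sampled in-neighbors per layer."""
    rng = random.Random(seed)
    needed = set(seeds)
    # fanouts[0] is for the layer closest to the input, so go backward.
    for fanout in reversed(fanouts):
        frontier = set(needed)
        for v in sorted(needed):
            neighbors = in_neighbors.get(v, [])
            # Sample without replacement; take everything if there are few.
            frontier.update(rng.sample(neighbors, min(fanout, len(neighbors))))
        needed = frontier
    return needed


# Hypothetical toy graph: node 0 has ten in-neighbors (1..10), and each of
# those has two in-neighbors among nodes 11..20.
in_neighbors = {0: list(range(1, 11))}
for u in range(1, 11):
    in_neighbors[u] = [10 + u, 11 + (u % 10)]

# A fanout larger than any in-degree is equivalent to taking all neighbors.
full_nodes = sampled_input_nodes(in_neighbors, [0], [100, 100])
sampled_nodes = sampled_input_nodes(in_neighbors, [0], [2, 2])
print(len(full_nodes), len(sampled_nodes))
```

With all neighbors, the two-layer computation touches all 21 nodes.  With a fanout of 2 per layer it touches at most 9: the seed, 2 sampled neighbors, and at most 2 more per node in the next step.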

## How Neighbor Sampling Works in DGL

DGL implements the uniform neighbor sampling described above (and in the [node classification tutorial](L1_large_node_classification.ipynb)) in `dgl.dataloading.MultiLayerNeighborSampler`, whose core logic is roughly as follows:

In [2]:
class MultiLayerNeighborSampler(dgl.dataloading.BlockSampler):
    def __init__(self, fanouts):
        super().__init__(len(fanouts), return_eids=False)
        self.fanouts = fanouts
        
    def sample_frontier(self, layer_id, g, seed_nodes):
        fanout = self.fanouts[layer_id]
        return dgl.sampling.sample_neighbors(g, seed_nodes, fanout)

`MultiLayerNeighborSampler` inherits from the `dgl.dataloading.BlockSampler` class.  Its core method is `sample_frontier`, which receives three arguments:

* `layer_id`: Indicating which GNN layer the neighbor sampler is processing.  0 indicates the GNN layer closest to the input.  DGL iterates from the last GNN layer to the first GNN layer as introduced in the [node classification tutorial](L1_large_node_classification.ipynb).
* `g`: The entire graph.  It can contain node and edge features so that the neighbor sampler can utilize them.
* `seed_nodes`: The array of node IDs whose output representations the given GNN layer should compute.

It returns one single object, which is a graph containing all the nodes in the original graph, as well as the edges for the GNN layer to perform message passing to compute the outputs of the `seed_nodes`.

`MultiLayerNeighborSampler` does this via `dgl.sampling.sample_neighbors`, which returns a subgraph of the original graph `g`.  The subgraph includes all the nodes in `g`, as well as a fixed number of randomly sampled incoming edges for each node in the ID array `seed_nodes`.

The `dgl.dataloading.BlockSampler` base class then transforms the graph returned by `sample_frontier` into a *bipartite* graph:

![](assets/bipartite.png)

So why does DGL return a list of *bipartite* graphs when training on a *homogeneous* graph?  The reason is that a given GNN layer has different numbers of input and output nodes.  Take the example above:

![Imgur](assets/6.png)

That GNN layer will output the representation of three nodes (two green nodes and one red node), but it will require input from 7 nodes (the green nodes and red node, plus 4 yellow nodes).  Only a bipartite graph can describe such computation.

Moreover, the transformation in `dgl.dataloading.BlockSampler` ensures that the first few input nodes are always the same as the output nodes, even if some of the output nodes did not appear as a neighbor of any output node:

![](assets/bipartite2.png)
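
The idea behind this transformation can be sketched without DGL.  The function `to_bipartite` below is a hypothetical illustration, not DGL's implementation: it assumes unique seed nodes and a frontier given as a list of `(src, dst)` pairs whose destinations are all seed nodes, and it relabels the nodes so that the output nodes occupy the first positions on the input side.

```python
def to_bipartite(edges, seed_nodes):
    """Relabel a frontier into a bipartite (input side -> output side) graph.

    Returns (input_nodes, output_nodes, relabeled_edges), where the relabeled
    edges index into input_nodes and output_nodes respectively.
    """
    output_nodes = list(seed_nodes)
    # The output nodes always come first on the input side.
    input_nodes = list(seed_nodes)
    input_index = {n: i for i, n in enumerate(input_nodes)}
    for src, _dst in edges:
        if src not in input_index:
            input_index[src] = len(input_nodes)
            input_nodes.append(src)
    output_index = {n: i for i, n in enumerate(output_nodes)}
    relabeled = [(input_index[s], output_index[d]) for s, d in edges]
    return input_nodes, output_nodes, relabeled


# Hypothetical frontier: nodes 8 and 5 are output nodes that are not the
# source of any edge, yet they still appear on the input side.
edges = [(4, 8), (7, 8), (9, 4)]
input_nodes, output_nodes, relabeled = to_bipartite(edges, [8, 4, 5])
print(input_nodes)   # [8, 4, 5, 7, 9]
print(output_nodes)  # [8, 4, 5]
print(relabeled)     # [(1, 0), (3, 0), (4, 1)]
```

Because the first `len(output_nodes)` input nodes are the output nodes themselves, a layer can pick up each output node's own previous-layer representation by slicing the input features.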

<div class="alert alert-info">

**Note**: this tutorial will only deal with how to generate those bipartite graphs with your own algorithm.  For more details on how message passing works with those bipartite graphs, please see the [message passing tutorial](L4_message_passing.ipynb).
    
</div>

## Customizing Neighbor Sampler

This tutorial gives two examples of customized neighbor sampling:

### Message dropout

One example is to randomly drop the neighboring edges with a fixed probability.  You can do that by returning a subgraph:

In [26]:
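class MessageDropoutNeighborSampler(dgl.dataloading.BlockSampler):
    def __init__(self, p, num_layers):
        super().__init__(num_layers, return_eids=False)
        self.p = p  # probability of dropping each incoming edge
        
    def sample_frontier(self, layer_id, g, seed_nodes):
        # Get all the incoming edges of the seed nodes.
        sg = dgl.in_subgraph(g, seed_nodes)
        # IDs of those edges in the original graph g.
        eids = sg.edata[dgl.EID]
        # Keep each edge with probability 1 - p, i.e. drop it with probability p.
        mask = torch.zeros(len(eids), dtype=torch.bool).bernoulli_(1 - self.p)
        sampled_eids = eids[mask]
        # Return the kept edges while preserving all the nodes of g.
        return dgl.edge_subgraph(g, sampled_eids, preserve_nodes=True)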

In [41]:
sampler = MessageDropoutNeighborSampler(0.5, 2)
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)

input_nodes, output_nodes, bipartites = next(iter(train_dataloader))

### Sampling k-hop neighbors

Another common scenario is to aggregate directly from 2-hop or even 3-hop neighbors.  Since computing all 2-hop and 3-hop neighbors (i.e. computing the square or the cube of the adjacency matrix) is usually impractical, sampling neighbors by random walk is one way to approximate them.

Note that the returned graph will not be a subgraph of the original graph in this case.

* The nodes of the returned graph will be the same as the original one, but the edges of the returned graph will instead indicate message passing directions for k-hop neighbors.
* You will need to copy the node features from the original graph to the sampled graph yourself.
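
The random-walk approximation itself can be sketched in plain Python.  The function `sample_khop_edges` and the toy graph are hypothetical: each seed node starts `fanout` walks of length `k` along outgoing edges, a walk that hits a node without successors is marked with -1 and discarded, and every surviving walk contributes one edge from its endpoint to the seed node.

```python
import random


def sample_khop_edges(successors, seed_nodes, fanout, k, seed=0):
    """Return (k-hop neighbor, seed node) edges found by random walks."""
    rng = random.Random(seed)
    edges = []
    for v in seed_nodes:
        for _walk in range(fanout):
            node = v
            for _step in range(k):
                candidates = successors.get(node, [])
                if not candidates:
                    node = -1  # the walk cannot continue
                    break
                node = rng.choice(candidates)
            if node != -1:
                # The message flows from the k-hop neighbor to the seed node.
                edges.append((node, v))
    return edges


# Hypothetical toy graph: a directed cycle 0 -> 1 -> 2 -> 0 plus an isolated
# node 3, so every walk here is deterministic.
successors = {0: [1], 1: [2], 2: [0], 3: []}
edges = sample_khop_edges(successors, [0, 1, 3], fanout=2, k=2)
print(edges)  # [(2, 0), (2, 0), (0, 1), (0, 1)]
```

Node 3 contributes no edges because its walks stop immediately, and the edges `(2, 0)` and `(0, 1)` do not exist in the toy graph, which is why the result is not a subgraph of it.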

In [36]:
class KHopNeighborSampler(dgl.dataloading.BlockSampler):
    def __init__(self, fanouts, k):
        super().__init__(len(fanouts), return_eids=False)
        self.k = k
        self.fanouts = fanouts
        
    def sample_frontier(self, layer_id, g, seed_nodes):
        # ATTENTION: currently the seed nodes can be either a dict or a tensor.
        # For a dict, take the tensor of its first (and, for a homogeneous graph, only) node type.
        if isinstance(seed_nodes, dict):
            seed_nodes = next(iter(seed_nodes.values()))
        fanout = self.fanouts[layer_id]
        # Generate the number of random walk traces equal to the fanout.
        seed_nodes = seed_nodes.repeat_interleave(fanout)
        nodes, _ = dgl.sampling.random_walk(g, seed_nodes, length=self.k)
        neighbor_nodes = nodes[:, self.k]
        # When a random walk cannot continue because a node has no successors (e.g. an
        # isolated node with no outgoing edges), dgl.sampling.random_walk will pad the
        # trace with -1.  Since OGB Products has isolated nodes, we should look for the
        # -1 entries and remove them.
        mask = (neighbor_nodes != -1)
        neighbor_nodes = neighbor_nodes[mask]
        seed_nodes = seed_nodes[mask]
        # Construct a new graph with only the edges involved in message passing.
        sg = dgl.graph((neighbor_nodes, seed_nodes), num_nodes=g.num_nodes())
        # Copy the node data from the original graph.
        # The edges in the returned graph may not exist in the original graph anyway, so we do not have to
        # copy the edge data.
        sg.ndata.update(g.ndata)
        return sg

In [37]:
sampler = KHopNeighborSampler([4, 4], 2)
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)

input_nodes, output_nodes, bipartites = next(iter(train_dataloader))