# Customizing Neighborhood Sampling

This tutorial shows you how to write your own neighbor sampling algorithm for large-scale graph training. It assumes that you have read the basics of large graph training in the [node classification tutorial](L1_large_node_classification.ipynb).

```python
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset('ogbn-products')

graph, node_labels = dataset[0]
idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
```

## Review of large graph training pipeline

Training a GNN on a single large graph involves the following steps:

1. Iterate over the nodes or edges in minibatches.
2. For each minibatch, generate a list of bipartite graphs, one bipartite graph for each GNN layer.
3. Perform message passing on the list of bipartite graphs.
4. Compute the loss and its gradients, then update the model parameters.

Neighborhood sampling customization takes place in Step 2.
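Step 1 is essentially a loop over (optionally shuffled) chunks of node IDs. The pure-Python sketch below mirrors the `batch_size`, `shuffle` and `drop_last` options that this tutorial passes to `dgl.dataloading.NodeDataLoader`; the helper `iterate_minibatches` is hypothetical and is not DGL's implementation.

```python
import random


def iterate_minibatches(nids, batch_size, shuffle=True, drop_last=False, seed=None):
    """Hypothetical helper: yield lists of node IDs, one list per minibatch."""
    order = list(nids)
    if shuffle:
        # Draw a new permutation of the training nodes.
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        batch = order[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # skip the final, incomplete batch
        yield batch


for batch in iterate_minibatches(range(10), batch_size=4, shuffle=False):
    print(batch)
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [8, 9]
```

Each of these batches becomes the set of output (seed) nodes for which Step 2 builds the per-layer bipartite graphs.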

## How neighbor sampling works in DGL

The uniform neighbor sampling used in the [node classification tutorial](L1_large_node_classification.ipynb) can be implemented in DGL in just a few lines:

```python
class MultiLayerNeighborSampler(dgl.dataloading.BlockSampler):
    def __init__(self, fanouts):
        # One GNN layer per entry in `fanouts`.
        super().__init__(len(fanouts), return_eids=False)
        self.fanouts = fanouts

    def sample_frontier(self, layer_id, g, seed_nodes):
        fanout = self.fanouts[layer_id]
        # Uniformly sample up to `fanout` incoming edges for each seed node.
        return dgl.sampling.sample_neighbors(g, seed_nodes, fanout)
```

`MultiLayerNeighborSampler` inherits from `dgl.dataloading.BlockSampler`. Its core method is `sample_frontier`, which receives three arguments:

* `layer_id`: Indicates which GNN layer the neighbor sampler is processing. 0 indicates the GNN layer closest to the input. DGL iterates from the last GNN layer to the first GNN layer, as introduced in the [node classification tutorial](L1_large_node_classification.ipynb).
* `g`: The entire graph. It can contain node and edge features so that the neighbor sampler can utilize them.
* `seed_nodes`: The array of node IDs whose output representations the given GNN layer should compute.
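The iteration order described above can be sketched in plain Python. The driver `sample_blocks` below is hypothetical and is not DGL's code: it assumes a `sample_frontier(layer_id, seeds)` callback that returns `(src, dst)` edge pairs, calls it from the last layer to the first, and grows the seed set so that each earlier layer computes representations for the current seeds plus their sampled neighbors.

```python
def sample_blocks(sample_frontier, seeds, num_layers):
    """Hypothetical driver: returns (input_nodes, blocks).

    blocks[i] = (dst_nodes, edges) for GNN layer i, where layer 0 is the
    layer closest to the input.
    """
    blocks = []
    for layer_id in reversed(range(num_layers)):  # last layer first
        edges = sample_frontier(layer_id, seeds)
        blocks.insert(0, (list(seeds), edges))
        # The previous layer must produce representations for the current
        # seeds themselves plus every sampled neighbor.
        next_seeds = list(seeds)
        seen = set(seeds)
        for src, _ in edges:
            if src not in seen:
                seen.add(src)
                next_seeds.append(src)
        seeds = next_seeds
    return seeds, blocks


# Toy graph as an in-neighbor dict: node -> list of nodes with an edge into it.
in_neighbors = {0: [1, 2], 1: [3], 2: [3, 4], 3: [], 4: [0]}
calls = []


def full_frontier(layer_id, seeds):
    calls.append(layer_id)  # record the order in which layers are visited
    return [(u, v) for v in seeds for u in in_neighbors[v]]


input_nodes, blocks = sample_blocks(full_frontier, [0], num_layers=2)
print(calls)        # [1, 0]
print(input_nodes)  # [0, 1, 2, 3, 4]
```

The toy callback takes every in-neighbor so the result is deterministic; a real sampler would subsample edges here.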

It returns a single graph, called a *frontier*, that contains all the nodes of the original graph but only the edges the GNN layer needs for message passing to compute the outputs of `seed_nodes`.

`MultiLayerNeighborSampler` builds this frontier with `dgl.sampling.sample_neighbors`, which returns a subgraph of the original graph `g`. The subgraph keeps all the nodes of `g`, plus `fanout` randomly sampled incoming edges for each node in `seed_nodes`; a node with fewer than `fanout` incoming edges keeps all of them.
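For intuition, here is a pure-Python analogue of a uniform frontier on an in-neighbor adjacency dict. The function `sample_neighbors_uniform` and the dict-based graph format are assumptions for illustration, not DGL APIs. Like `dgl.sampling.sample_neighbors` with its default sampling without replacement, it keeps every neighbor when a node has at most `fanout` of them.

```python
import random


def sample_neighbors_uniform(in_neighbors, num_nodes, seeds, fanout, rng):
    """Hypothetical analogue of a uniform frontier: edges into `seeds` only."""
    edges = []
    for v in seeds:
        nbrs = in_neighbors.get(v, [])
        if len(nbrs) <= fanout:
            picked = list(nbrs)  # not enough neighbors: keep them all
        else:
            picked = rng.sample(nbrs, fanout)  # sample without replacement
        edges.extend((u, v) for u in picked)
    # The frontier keeps every node of the original graph; only edges are sampled.
    return {"num_nodes": num_nodes, "edges": edges}


in_neighbors = {0: [1, 2, 3, 4], 1: [0], 2: [0, 3]}
frontier = sample_neighbors_uniform(in_neighbors, num_nodes=5, seeds=[0, 1],
                                    fanout=2, rng=random.Random(0))
print(frontier["num_nodes"], frontier["edges"])
```

Node 0 gets two of its four in-neighbors, node 1 keeps its single in-neighbor, and node 2 receives no edges because it is not a seed.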

Because `MultiLayerNeighborSampler` inherits from `dgl.dataloading.BlockSampler`, the graph returned by `sample_frontier` is automatically converted into a bipartite graph (a *block*):

![](assets/bipartite.png)

The transformation in `dgl.dataloading.BlockSampler` ensures that the first input nodes are always the output nodes, in the same order, even if some output node does not appear as a neighbor of any output node:

![](assets/bipartite2.png)
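The relabeling that turns a frontier into a block can be sketched in plain Python. The helper `frontier_to_block` is hypothetical: it assumes every frontier edge points at an output node, lists the output nodes first among the input nodes, and appends the remaining source nodes in order of first appearance. DGL's own conversion may order those extra input nodes differently.

```python
def frontier_to_block(edges, dst_nodes):
    """Hypothetical sketch: relabel global (src, dst) edges into block-local IDs."""
    src_nodes = list(dst_nodes)  # output nodes always come first among the inputs
    src_index = {n: i for i, n in enumerate(src_nodes)}
    dst_index = {n: i for i, n in enumerate(dst_nodes)}
    local_edges = []
    for u, v in edges:
        if u not in src_index:
            # A neighbor that is not an output node gets the next input slot.
            src_index[u] = len(src_nodes)
            src_nodes.append(u)
        local_edges.append((src_index[u], dst_index[v]))
    return src_nodes, list(dst_nodes), local_edges


# Output nodes 0 and 7; node 7 is nobody's neighbor but is still an input node.
src, dst, local = frontier_to_block([(3, 0), (5, 0), (0, 7)], dst_nodes=[0, 7])
print(src)    # [0, 7, 3, 5]
print(local)  # [(2, 0), (3, 0), (0, 1)]
```

Keeping the output nodes as a prefix of the input nodes lets a GNN layer read each output node's own previous-layer representation by slicing the first `len(dst)` rows of the input features.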

For more details, please see the [message passing tutorial](L4_message_passing.ipynb).

## Customizing Neighbor Sampler

This tutorial gives two examples of customized neighbor sampling:

### Message dropout

One example is to drop each neighboring edge at random with a fixed probability `p`. You can do this by returning a subgraph that keeps only the surviving edges:

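```python
class MessageDropoutNeighborSampler(dgl.dataloading.BlockSampler):
    def __init__(self, p, num_layers):
        super().__init__(num_layers, return_eids=False)
        self.p = p  # probability of dropping each edge

    def sample_frontier(self, layer_id, g, seed_nodes):
        # Take all incoming edges of the seed nodes; dgl.EID holds their IDs in g.
        sg = dgl.in_subgraph(g, seed_nodes)
        eids = sg.edata[dgl.EID]
        # Keep each edge independently with probability 1 - p.
        mask = torch.rand(len(eids), device=eids.device) >= self.p
        sampled_eids = eids[mask]
        # Keep every node of g so the result is a valid frontier.
        return dgl.edge_subgraph(g, sampled_eids, preserve_nodes=True)
```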
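The dropout rule itself does not depend on DGL. The sketch below applies it to a plain list of `(src, dst)` edges using a hypothetical helper `drop_messages`: each edge survives independently with probability `1 - p`, so on average a fraction `p` of the messages is dropped.

```python
import random


def drop_messages(edges, p, rng):
    """Hypothetical helper: keep each edge independently with probability 1 - p."""
    # rng.random() is uniform on [0, 1), so the test below passes with probability 1 - p.
    return [e for e in edges if rng.random() >= p]


edges = [(i, i + 1) for i in range(10000)]
kept = drop_messages(edges, p=0.3, rng=random.Random(0))
print(len(kept) / len(edges))  # roughly 0.7
```

Because a fresh mask is drawn every time the frontier is sampled, each minibatch and each layer sees a different random subset of messages.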

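```python
# Drop each edge with probability 0.5 in both GNN layers.
sampler = MessageDropoutNeighborSampler(0.5, 2)
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)

# Each minibatch yields the input node IDs, the output (seed) node IDs,
# and one bipartite graph per GNN layer.
input_nodes, output_nodes, bipartites = next(iter(train_dataloader))
```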

### Sampling k-hop neighbors

Another common scenario is to aggregate 2-hop or even 3-hop neighbors directly. Computing all 2-hop and 3-hop neighbors exactly (i.e., computing the square or the cube of the adjacency matrix) is usually impractical on large graphs, so sampling these neighbors with random walks is a practical approximation.

Note that the returned graph will not be a subgraph of the original graph in this case.

* The nodes of the returned graph will be the same as the original one, but the edges of the returned graph will instead indicate message passing directions for k-hop neighbors.
* You will need to copy the node features from the original graph to the sampled graph yourself.
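The random-walk idea can be illustrated without DGL. The helper `khop_random_walk_edges` below is a hypothetical sketch on an out-neighbor dict: it starts `fanout` walks of length `k` from every seed, follows outgoing edges (successors), discards walks that hit a dead end (the role played by the `-1` padding in `dgl.sampling.random_walk`), and adds an edge from each walk's endpoint back to its seed.

```python
import random


def khop_random_walk_edges(out_neighbors, seeds, fanout, k, rng):
    """Hypothetical sketch: approximate k-hop neighbors with random walks."""
    edges = []
    for v in seeds:
        for _ in range(fanout):  # one walk per fanout slot
            node = v
            for _ in range(k):
                succ = out_neighbors.get(node, [])
                if not succ:
                    node = -1  # dead end: no successors to continue from
                    break
                node = rng.choice(succ)
            if node != -1:
                # Messages flow from the walk's endpoint to the seed it started from.
                edges.append((node, v))
    return edges


chain = {0: [1], 1: [2], 2: [3], 3: []}
print(khop_random_walk_edges(chain, [0, 2, 3], fanout=3, k=2, rng=random.Random(0)))
# [(2, 0), (2, 0), (2, 0)]
```

Walks from nodes 2 and 3 reach a node without successors before taking two steps, so they contribute no edges. Note that a walk may revisit nodes (on an undirected graph, 0 → 1 → 0 is a valid 2-step walk), and endpoints are drawn according to the random-walk transition probabilities, so the result approximates the k-hop neighborhood rather than sampling it uniformly.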

```python
class KHopNeighborSampler(dgl.dataloading.BlockSampler):
    def __init__(self, fanouts, k):
        super().__init__(len(fanouts), return_eids=False)
        self.k = k
        self.fanouts = fanouts

    def sample_frontier(self, layer_id, g, seed_nodes):
        #### ATTENTION
        # Currently the seed nodes can be either a dict or a tensor.
        # This example assumes a single node type, so take the only tensor.
        if isinstance(seed_nodes, dict):
            seed_nodes = next(iter(seed_nodes.values()))
        fanout = self.fanouts[layer_id]
        # Start `fanout` random walks from every seed node.
        seed_nodes = seed_nodes.repeat_interleave(fanout)
        # Each trace has k + 1 entries; column k holds the walk's endpoint.
        nodes, _ = dgl.sampling.random_walk(g, seed_nodes, length=self.k)
        neighbor_nodes = nodes[:, self.k]
        # When a random walk cannot continue because a node has no successors
        # (e.g. an isolated node with no outgoing edges), dgl.sampling.random_walk
        # pads the rest of the trace with -1.  Since ogbn-products has isolated
        # nodes, we look for the -1 entries and remove them.
        mask = (neighbor_nodes != -1)
        neighbor_nodes = neighbor_nodes[mask]
        seed_nodes = seed_nodes[mask]
        # Construct a new graph with only the edges involved in message passing,
        # pointing from each walk's endpoint to the seed node it started from.
        sg = dgl.graph((neighbor_nodes, seed_nodes), num_nodes=g.num_nodes())
        # Copy the node data from the original graph.  The edges in the returned
        # graph may not exist in the original graph, so there is no edge data to copy.
        sg.ndata.update(g.ndata)
        return sg
```

```python
# Two GNN layers, each aggregating 4 sampled 2-hop neighbors per node.
sampler = KHopNeighborSampler([4, 4], 2)
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)

input_nodes, output_nodes, bipartites = next(iter(train_dataloader))
```