The message passing module is identical to what you have seen in the [full graph message passing tutorial](H3_message_passing.ipynb).

In [None]:
import dgl.function as fn

class SAGEConv(nn.Module):
    """GraphSAGE-style convolution layer for a bipartite graph.

    Source-node features are averaged onto the destination nodes, joined
    with the destination nodes' own features, and passed through a single
    linear projection.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super().__init__()
        # The projection sees the destination feature and the aggregated
        # neighbor feature side by side, hence twice the input width.
        concat_width = in_feat * 2
        self.linear = nn.Linear(concat_width, out_feat)

    def forward(self, g, h):
        """Run one round of mean-aggregation message passing.

        Parameters
        ----------
        g : Graph
            The input bipartite graph.
        h : pair of Tensor
            Unpacked as ``(source features, destination features)``.

        Returns
        -------
        Tensor
            Output of the linear projection applied to the destination
            features concatenated (along dimension 1) with the averaged
            neighbor features.
        """
        src_feat, dst_feat = h
        # local_scope keeps the temporary 'h' / 'h_neigh' entries from
        # persisting on the graph after this call.
        with g.local_scope():
            g.srcdata['h'] = src_feat
            # Message: copy each source node's 'h' along its edges as 'm'.
            message_func = fn.copy_u('h', 'm')
            # Reduce: average incoming 'm' into 'h_neigh' on each destination.
            reduce_func = fn.mean('m', 'h_neigh')
            g.update_all(message_func, reduce_func)
            neigh_feat = g.dstdata['h_neigh']
            combined = torch.cat([dst_feat, neigh_feat], dim=1)
            return self.linear(combined)

My problem: we currently recommend that users write the following for message passing with minibatch training:

In [None]:
class HeteroSAGEConv(nn.Module):
    """Two stacked heterogeneous GraphSAGE layers.

    Each layer wraps one ``SAGEConv`` per edge type ('follows', 'plays',
    'sells') in a ``dglnn.HeteroGraphConv`` with ``aggregate='sum'``.

    Parameters
    ----------
    in_feat : int
        NOTE(review): accepted but currently unused; the first layer is
        hard-coded to 10 input features.
    out_feat : int
        NOTE(review): accepted but currently unused; both layers are
        hard-coded to 20 output features.
    """
    def __init__(self, in_feat, out_feat):
        super().__init__()

        edge_types = ('follows', 'plays', 'sells')
        # First layer: 10 -> 20 for every edge type.
        first_layer_convs = {etype: SAGEConv(10, 20) for etype in edge_types}
        self.conv1 = dglnn.HeteroGraphConv(first_layer_convs, aggregate='sum')
        # Second layer: 20 -> 20 for every edge type.
        second_layer_convs = {etype: SAGEConv(20, 20) for etype in edge_types}
        self.conv2 = dglnn.HeteroGraphConv(second_layer_convs, aggregate='sum')

    def forward(self, gs, x):
        """Apply both layers with a ReLU in between.

        Parameters
        ----------
        gs : sequence of Graph
            ``gs[0]`` is given to the first layer and ``gs[1]`` to the second.
        x : dict
            Input passed to the first layer.

        Returns
        -------
        dict
            Output of the second layer.
        """
        hidden = self.conv1(gs[0], x)
        activated = {}
        for key, feat in hidden.items():
            activated[key] = F.relu(feat)
        return self.conv2(gs[1], activated)

In [None]:
# Neighbor sampler configured with a fan-out of 4 for each of three layers.
# NOTE(review): confirm that `model` consumes as many bipartite graphs as
# there are fan-outs listed here.
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids_dict, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)

with tqdm.tqdm(train_dataloader) as tq:
    for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
        # Move the sampled bipartite graphs to the GPU.
        bipartites = [b.to(torch.device('cuda')) for b in bipartites]
        # Gather per-node-type input features for the sampled input nodes.
        inputs = {k: node_features[k][v].cuda() for k, v in input_nodes.items()}    # <---
        # Gather per-node-type labels for the output nodes.  The dict is keyed
        # by node type `k` (not by the node-ID tensor `v`) so that it lines up
        # with `inputs` and can be looked up by node type.
        labels = {k: node_labels[k][v].cuda() for k, v in output_nodes.items()}     # <---
        predictions = model(bipartites, inputs)

The above code will work only if, in each bipartite graph, the output nodes of each node type always appear first among the input nodes of the same node type.  The reason is that, as you may recall, we have the following statement in [large (homogeneous) graph message passing](L4_message_passing.ipynb):

```python
h_dst = h[:g.number_of_dst_nodes()]           # <---
```

`dgl.nn.HeteroGraphConv` already does that for you for each node type if it can know that the bipartite graph comes from neighbor sampling (i.e. `dgl.to_block`).

```python
    def forward(self, g, inputs):
        if isinstance(inputs, tuple) or g.is_block:
            if isinstance(inputs, tuple):
                src_inputs, dst_inputs = inputs
            else:
                src_inputs = inputs
                dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
```

However, we intend to remove the concept of blocks, so `HeteroGraphConv` will never know if it needs to perform the slicing or not.