# Writing GNN Modules for Large Graph Training

This tutorial assumes that
1. You know [how to write GNN modules for full graph training](3_message_passing.ipynb).
2. You know [how large graph training works](L1_large_node_classification.ipynb).

In [None]:
import dgl.function as fn

class SAGEConv(nn.Module):
    """GraphSAGE-style graph convolution applied to one bipartite graph.

    Each destination node's own feature is concatenated with the mean of
    its neighbors' features, and the result is projected by a linear layer.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # The projection consumes [own feature || mean neighbor feature],
        # hence twice the input width.
        self.linear = nn.Linear(2 * in_feat, out_feat)

    def forward(self, g, h):
        """Run the convolution on a single bipartite graph.

        Parameters
        ----------
        g : Graph
            The input bipartite graph.
        h : Tensor
            The input node feature, one row per source node of ``g``.

        Returns
        -------
        Tensor
            The projected feature, one row per destination node of ``g``.
        """
        # local_scope keeps the temporary 'h' / 'h_neigh' fields off the caller's graph.
        with g.local_scope():
            h_src = h                                                  # <---
            g.srcdata['h'] = h_src                                     # <---
            # The leading rows of the source features are taken as the
            # destination nodes' own features.
            h_dst = h[:g.number_of_dst_nodes()]                        # <---
            # Message passing: copy each source feature along its edges and
            # average the incoming messages on every destination node.
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
            h_total = torch.cat((h_dst, g.dstdata['h_neigh']), dim=1)  # <---
            return self.linear(h_total)
        
class Net(nn.Module):
    """Two-layer GraphSAGE model that consumes one bipartite graph per layer.

    Parameters
    ----------
    in_feats : int
        Input feature size.
    h_feats : int
        Hidden feature size.
    num_classes : int
        Output size of the second layer.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, bipartites, in_feat):
        """Apply both layers in order.

        ``bipartites[0]`` feeds the first layer and ``bipartites[1]`` the
        second; ``in_feat`` is the feature tensor passed to the first layer.
        """
        hidden = F.relu(self.conv1(bipartites[0], in_feat))
        return self.conv2(bipartites[1], hidden)
    
# One fanout per GNN layer: Net stacks two SAGEConv layers and reads only
# bipartites[0] and bipartites[1], so the sampler must generate exactly two
# bipartite graphs per minibatch. A third fanout would sample a block that is
# never used and leave the predictions misaligned with output_nodes.
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
# `graph` and `train_nids` are expected to be defined before this point
# (they are not shown in this cell).
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0,
)

# Every minibatch is moved to the GPU before the forward pass.
device = torch.device('cuda')

with tqdm.tqdm(train_dataloader) as tq:
    for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
        # The sampled bipartite graphs, input features and labels all go to
        # the same device.
        bipartites = [block.to(device) for block in bipartites]
        inputs = node_features[input_nodes].to(device)
        labels = node_labels[output_nodes].to(device)
        # NOTE(review): `model` is not defined in this cell; presumably a Net
        # instance already placed on the GPU. Confirm.
        predictions = model(bipartites, inputs)

[WIP]

Several things to notice in the tutorial:

* `g` is now a bipartite graph with the input nodes on one side and the output nodes on the other side.  This explains the difference between `h_src` and `h_dst`.
* The output nodes always appear first among the input nodes; this is why `h[:g.number_of_dst_nodes()]` can be used to select their features.  This property is guaranteed by `dgl.to_block()` and by the neighbor samplers that subclass `BlockSampler` (e.g. the builtin sampler).