As you have seen in the [node classification tutorial](H1_node_classification.ipynb), you can assign one convolution module to each edge type. Here, you can also write your own message passing module for each edge type.

In [None]:
import dgl.function as fn

class SAGEConv(nn.Module):
    """GraphSAGE-style convolution layer for a single bipartite relation.

    Each destination node's own feature is concatenated with the mean of
    the features of its source-side neighbors, and the result is passed
    through one linear projection.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super().__init__()
        # The projection consumes the destination feature concatenated with
        # the aggregated neighbor feature, hence twice the input width.
        self.linear = nn.Linear(2 * in_feat, out_feat)

    def forward(self, g, h):
        """Run the layer on one bipartite graph.

        Parameters
        ----------
        g : Graph
            The input bipartite graph.
        h : tuple[Tensor, Tensor]
            Pair ``(source features, destination features)``.

        Returns
        -------
        Tensor
            Result of the linear projection applied to the destination
            features concatenated with the aggregated neighbor features.
        """
        # Unlike the homogeneous case, features arrive as a (src, dst) pair.
        src_feat, dst_feat = h
        # local_scope keeps the temporary 'h' / 'h_neigh' fields from
        # leaking into the caller's graph.
        with g.local_scope():
            g.srcdata['h'] = src_feat
            # Message passing: copy each source feature along its edges as
            # message 'm', then average the messages per destination node.
            send_src = fn.copy_u('h', 'm')
            average = fn.mean('m', 'h_neigh')
            g.update_all(send_src, average)
            neigh_feat = g.dstdata['h_neigh']
            combined = torch.cat([dst_feat, neigh_feat], dim=1)
            return self.linear(combined)

Then you can combine them with the `dgl.nn.HeteroGraphConv` module.

In [None]:
class HeteroSAGEConv(nn.Module):
    """Two-layer heterogeneous GraphSAGE model.

    Each layer is a ``dglnn.HeteroGraphConv`` holding one ``SAGEConv`` per
    edge type ('follows', 'plays', 'sells'); results from different edge
    types are combined with ``aggregate='sum'``.

    Parameters
    ----------
    in_feat : int
        Input feature size, shared by every edge type's first-layer module.
    out_feat : int
        Output feature size. It is also used as the hidden size between
        the two layers.
    """

    # Edge types that get their own SAGEConv module in each layer.
    _ETYPES = ('follows', 'plays', 'sells')

    def __init__(self, in_feat, out_feat):
        super().__init__()

        # Build the sizes from the constructor arguments instead of
        # hard-coded 10 / 20 so the arguments actually take effect.
        # HeteroSAGEConv(10, 20) yields exactly the previous architecture.
        self.conv1 = dglnn.HeteroGraphConv({
            etype: SAGEConv(in_feat, out_feat) for etype in self._ETYPES
        }, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            etype: SAGEConv(out_feat, out_feat) for etype in self._ETYPES
        }, aggregate='sum')

    def forward(self, g, x):
        """Apply both layers with a ReLU in between.

        Parameters
        ----------
        g : Graph
            The input heterogeneous graph.
        x : dict
            Input features, passed straight to the first
            ``HeteroGraphConv``; presumably keyed by node type (confirm
            against the caller).

        Returns
        -------
        dict
            Output of the second ``HeteroGraphConv`` layer.
        """
        x = self.conv1(g, x)
        # The first layer returns a dict, so ReLU is applied per entry.
        x = {k: F.relu(v) for k, v in x.items()}
        x = self.conv2(g, x)
        return x