# Chapter 2: Message Passing

## Message Passing Paradigm

Let $x_v \in \mathbb{R}^{d_1}$ be the feature of node $v$, and $w_e \in \mathbb{R}^{d_2}$ be the feature of edge $(u, v)$. The **message passing paradigm** defines the following edge-wise and node-wise computation at step $t+1$:

$$\text{Edge-wise:}\quad m_e^{(t+1)} = \phi\left(x_v^{(t)}, x_u^{(t)}, w_e^{(t)}\right), \quad (u, v, e) \in \mathcal{E}$$

$$\text{Node-wise:}\quad x_v^{(t+1)} = \psi\left(x_v^{(t)}, \rho\left(\left\{ m_e^{(t+1)} : (u, v, e) \in \mathcal{E} \right\}\right)\right)$$

where $\phi$ is a **message function** defined on each edge that generates a message by combining the edge feature with the features of its incident nodes, and $\psi$ is an **update function** defined on each node that updates the node feature by aggregating its incoming messages with the **reduce function** $\rho$.
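To make the two equations concrete, here is a minimal framework-free sketch of one step, assuming scalar features stored in plain dictionaries. It is not DGL code; `message_passing_step` and the particular choices of $\phi$, $\rho$, and $\psi$ are hypothetical examples.

```python
# Framework-free sketch of one message passing step (not DGL code).
# Features are scalars; phi, rho and psi below are hypothetical example choices.


def message_passing_step(x, w, edges, phi, rho, psi):
    """Compute x^(t+1) from x^(t).

    x:     dict node -> x_v^(t)
    w:     dict edge id -> w_e^(t)
    edges: list of (u, v, e) triples, i.e. the edge set E
    """
    # Edge-wise: one message per edge, m_e = phi(x_v, x_u, w_e)
    messages = {e: phi(x[v], x[u], w[e]) for (u, v, e) in edges}
    # Node-wise: collect each node's incoming messages, reduce, then update
    new_x = {}
    for v in x:
        incoming = [messages[e] for (_, dst, e) in edges if dst == v]
        new_x[v] = psi(x[v], rho(incoming))
    return new_x


def phi(x_v, x_u, w_e):
    # Message: source feature scaled by the edge feature
    return w_e * x_u


def rho(messages):
    # Reduce: sum of incoming messages (0 when there are none, in this sketch)
    return sum(messages)


def psi(x_v, aggregated):
    # Update: old feature plus aggregated messages
    return x_v + aggregated


x = {0: 1.0, 1: 2.0, 2: 3.0}
w = {"a": 0.5, "b": 2.0, "c": 1.0}
edges = [(0, 2, "a"), (1, 2, "b"), (2, 0, "c")]
x_next = message_passing_step(x, w, edges, phi, rho, psi)
print(x_next)  # {0: 4.0, 1: 2.0, 2: 7.5}
```

Node 1 has no incoming edge, so in this sketch it keeps its old feature. The sketch scans the whole edge list for every node for readability, not speed.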

## 2.1 Built-in Functions and Message Passing APIs

In DGL, the **message function** takes a single argument `edges`, which is an `EdgeBatch` instance. During message passing, DGL generates it internally to represent a batch of edges. An `EdgeBatch` has three members, `src`, `dst`, and `data`, which access the features of the source nodes, destination nodes, and edges, respectively.

The **reduce function** takes a single argument `nodes`, which is a `NodeBatch` instance. During message passing, DGL generates it internally to represent a batch of nodes. Its `mailbox` member accesses the messages received by the nodes in the batch. Common reduce operations include `sum`, `max`, `min`, and `mean`.

The **update function** also takes a single argument `nodes`, as described above. It operates on the aggregation result of the reduce function, typically combining it with the node's feature from the previous step and saving the result as a node feature.

DGL implements commonly used message functions and reduce functions as built-ins in the namespace `dgl.function`. In general, DGL suggests using built-in functions when possible, since they are optimized and handle dimension broadcasting.

Built-in message functions can be unary or binary. For unary functions, DGL supports `copy`. For binary functions, DGL supports `add`, `sub`, `mul`, `div`, and `dot`. In the names of built-in message functions, `u` represents `src` nodes, `v` represents `dst` nodes, and `e` represents edges.

The parameters of built-in functions are strings giving the input and output field names on the corresponding nodes and edges. The list of supported built-in functions is in the DGL Built-in Function documentation: https://docs.dgl.ai/en/1.0.x/api/python/dgl.function.html#api-built-in. For example, to add the `hu` feature of `src` nodes to the `hv` feature of `dst` nodes and save the result on the edge in the `he` field, we can use the built-in `dgl.function.u_add_v('hu', 'hv', 'he')`. This is equivalent to the following message UDF:

```python
def message_func(edges):
    # Add the source feature 'hu' and the destination feature 'hv' on each edge
    return {'he': edges.src['hu'] + edges.dst['hv']}
```
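The naming convention can be illustrated with a small, hypothetical factory that turns a name such as `u_add_v` into a message UDF. This is not how DGL implements its built-ins: `make_binary_message` is an illustrative helper, it works on a single edge with scalar features, and it leaves out `dot`, which needs vector features.

```python
import operator
from types import SimpleNamespace

# Hypothetical illustration of the built-in naming convention (not DGL code):
# 'u' -> source nodes (src), 'v' -> destination nodes (dst), 'e' -> edges (data).
TARGETS = {"u": "src", "v": "dst", "e": "data"}
OPS = {"add": operator.add, "sub": operator.sub,
       "mul": operator.mul, "div": operator.truediv}


def make_binary_message(name, lhs_field, rhs_field, out_field):
    """Build a message UDF from a name such as 'u_add_v' or 'u_mul_e'."""
    lhs, op, rhs = name.split("_")
    combine = OPS[op]

    def message_func(edges):
        # Read each input field from the member the name points to
        left = getattr(edges, TARGETS[lhs])[lhs_field]
        right = getattr(edges, TARGETS[rhs])[rhs_field]
        return {out_field: combine(left, right)}

    return message_func


# A stand-in for one edge, exposing src/dst/data the way an EdgeBatch does
edge = SimpleNamespace(src={"hu": 1.5}, dst={"hv": 2.0}, data={"a": 4.0})

u_add_v = make_binary_message("u_add_v", "hu", "hv", "he")
u_mul_e = make_binary_message("u_mul_e", "hu", "a", "m")
print(u_add_v(edge))  # {'he': 3.5}
print(u_mul_e(edge))  # {'m': 6.0}
```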



Built-in reduce functions support the operations `sum`, `max`, `min`, and `mean`. They usually take two string parameters: the name of the message field in `mailbox` and the name of the node feature field that receives the result. For example, `dgl.function.sum('m', 'h')` is equivalent to the following reduce UDF, which sums up the messages `m`:

```python
import torch

def reduce_func(nodes):
    # mailbox['m'] stacks each node's incoming messages along dim 1
    return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
```
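Assuming the mailbox layout implied by `dim=1` above (one row per node, one entry per received message), the four built-in reducers can be sketched in plain Python as follows. `reduce_mailbox` is a hypothetical helper, and nested lists stand in for the message tensor.

```python
from statistics import fmean

# Plain-Python stand-ins for the built-in reducers (not DGL code).
REDUCERS = {"sum": sum, "max": max, "min": min, "mean": fmean}


def reduce_mailbox(mailbox, op):
    """Reduce each node's messages to a single value, like reducing over dim=1."""
    return [REDUCERS[op](messages) for messages in mailbox]


# mailbox[i][j] is the j-th message received by the i-th node in the batch
mailbox = [[1.0, 2.0, 6.0],
           [4.0, 0.0, 2.0]]

for op in ("sum", "max", "min", "mean"):
    print(op, reduce_mailbox(mailbox, op))
# sum [9.0, 6.0]
# max [6.0, 4.0]
# min [1.0, 0.0]
# mean [3.0, 2.0]
```

Every row has the same length here, matching a dense (nodes, messages) tensor; how DGL groups nodes that receive different numbers of messages is not covered in this section.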

It is possible to invoke edge-wise computation with `apply_edges()` without invoking message passing. `apply_edges()` takes a message function as its parameter and, by default, updates the features of all edges. For example:

```
import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
```
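Unlike message passing, the result stays on the edges. The following framework-free sketch reproduces what `fn.u_add_v('el', 'er', 'e')` computes under `apply_edges()`, assuming scalar features and a graph given as parallel source/destination ID lists; `apply_edges_sketch` is a hypothetical helper, not DGL's implementation.

```python
# Framework-free sketch of apply_edges semantics (not DGL code): an edge-wise
# function runs on every edge and its output is stored as an edge feature.
# Nothing is aggregated onto the nodes.


def apply_edges_sketch(src_ids, dst_ids, ndata, edata, edge_func):
    """Store edge_func(src, dst, edge) into edata for every edge."""
    num_edges = len(src_ids)
    for eid, (u, v) in enumerate(zip(src_ids, dst_ids)):
        src = {name: feats[u] for name, feats in ndata.items()}
        dst = {name: feats[v] for name, feats in ndata.items()}
        edge = {name: feats[eid] for name, feats in edata.items()}
        for name, value in edge_func(src, dst, edge).items():
            edata.setdefault(name, [None] * num_edges)[eid] = value


def u_add_v_el_er(src, dst, edge):
    # Same per-edge computation as fn.u_add_v('el', 'er', 'e')
    return {"e": src["el"] + dst["er"]}


src_ids = [0, 1, 2]
dst_ids = [1, 2, 0]
ndata = {"el": [1.0, 2.0, 3.0], "er": [10.0, 20.0, 30.0]}
edata = {}
apply_edges_sketch(src_ids, dst_ids, ndata, edata, u_add_v_el_er)
print(edata["e"])  # [21.0, 32.0, 13.0]
```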


For message passing, `update_all()` is a high-level API that combines message generation, message aggregation, and node update in a single call.

The parameters of `update_all()` are a message function, a reduce function, and an optional update function. The update function can instead be called outside of `update_all()` and omitted from the call. DGL recommends this approach, since the update function can usually be written as pure tensor operations, which keeps the code concise. For example:

```python
import dgl.function as fn

def update_all_example(graph):
    # Store the result in graph.ndata['ft']
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    # Call the update function outside of update_all
    final_ft = graph.ndata['ft'] * 2
    return final_ft
```

This call generates the messages `m` by multiplying the `src` node features `ft` by the edge features `a`, sums up the messages `m` to update the node features `ft`, and finally multiplies `ft` by 2 to get the result `final_ft`. After the call, DGL cleans up the intermediate messages `m`. The mathematical formula for the above is:

$$\mathrm{final\_ft}_i = 2 \sum_{j \in \mathcal{N}(i)} \mathrm{ft}_j \, a_{ij}$$
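For scalar features, the formula can be checked with a short framework-free sketch of `update_all_example`. `update_all_example_sketch` is hypothetical, and the graph is given as source/destination ID lists with one `a` value per edge.

```python
# Framework-free sketch of update_all_example for scalar features (not DGL code).


def update_all_example_sketch(src_ids, dst_ids, ft, a):
    # Message (u_mul_e) and reduce (sum): accumulate ft_j * a_ij into node i
    summed = [0.0] * len(ft)
    for eid, (j, i) in enumerate(zip(src_ids, dst_ids)):
        summed[i] += ft[j] * a[eid]
    # Update step, done outside of update_all: multiply by 2
    return [2 * h for h in summed]


src_ids = [0, 1, 2]
dst_ids = [2, 2, 0]
ft = [1.0, 2.0, 3.0]
a = [0.5, 0.25, 2.0]
print(update_all_example_sketch(src_ids, dst_ids, ft, a))  # [12.0, 0.0, 2.0]
```

Node 1 has no incoming edge, so the sketch leaves its sum at 0.0 before doubling; that is a choice of the sketch, not a statement about DGL.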

## 2.2 Writing Efficient Message Passing Code

DGL optimizes memory consumption and computing speed for message passing, so it is common practice to write custom message passing code as a combination of `update_all()` calls with built-in functions as parameters.

Because the number of edges in some graphs is much larger than the number of nodes, avoiding unnecessary memory copies from nodes to edges is beneficial.

In some cases, such as `GATConv`, it is necessary to save messages on the edges, and we need to call `apply_edges()` with built-in functions. When edge features are needed, DGL recommends keeping their dimension as low as possible.

For example, dimension reduction can be achieved by splitting operations on edges into operations on nodes. Consider concatenating the `src` feature and the `dst` feature and then applying a linear layer, i.e. $W \times (u \| v)$. The `src` and `dst` feature dimension is high, while the linear layer's output dimension is low. A straightforward implementation looks like this:

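```python
import torch
import torch.nn as nn

# Assumes g is a DGLGraph whose node features g.ndata['feat'] have size
# node_feat_dim, and out_dim is the (low) output dimension.
# Weight of shape (2 * node_feat_dim, out_dim), randomly initialized here.
linear = nn.Parameter(torch.randn(node_feat_dim * 2, out_dim))

def concat_message_function(edges):
    # Concatenate the source and destination features on every edge
    return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']], dim=1)}

g.apply_edges(concat_message_function)
g.edata['out'] = g.edata['cat_feat'] @ linear
```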

A more efficient implementation splits the linear operation in two: one applies to the `src` feature and the other to the `dst` feature. The outputs of the two linear operations are then added on the edges at the final stage, in effect computing $W_l \times u + W_r \times v$. This works because $W \times (u \| v) = W_l \times u + W_r \times v$, where $W_l$ and $W_r$ are the left and right halves of the matrix $W$, respectively.

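```python
import dgl.function as fn

# W_l and W_r as two separate (node_feat_dim, out_dim) weights
linear_src = nn.Parameter(torch.randn(node_feat_dim, out_dim))
linear_dst = nn.Parameter(torch.randn(node_feat_dim, out_dim))
# Project on the nodes, where each high-dimensional feature is stored once
out_src = g.ndata['feat'] @ linear_src
out_dst = g.ndata['feat'] @ linear_dst
g.srcdata.update({'out_src': out_src})
g.dstdata.update({'out_dst': out_dst})
# Only the low-dimensional outputs are combined on the edges
g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))
```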

The implementation above is more efficient than the preceding one, even though the two are mathematically equivalent, because it does not need to save the concatenated `src` and `dst` features on the edges, which is not memory efficient.
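The equivalence can be checked numerically. The sketch below uses plain Python lists and the row-vector convention of the code (features multiplied on the left of the weight), so $W_l$ and $W_r$ become the top and bottom halves of the weight's rows. It also counts the floats each version keeps per edge, following the fields the two snippets store. `matvec` and `floats_on_edges` are hypothetical helpers, and the sizes are arbitrary examples.

```python
# Numerical check of (u || v) @ W == u @ W_src + v @ W_dst (not DGL code).


def matvec(x, W):
    """Row vector x (length n) times matrix W (n rows, m columns)."""
    return [sum(x[k] * W[k][c] for k in range(len(x))) for c in range(len(W[0]))]


d = 2                      # node_feat_dim
u = [1.0, 2.0]             # source node feature
v = [3.0, 4.0]             # destination node feature
W = [[1.0, 0.0, 2.0],      # rows 0..d-1  -> W_src (W_l)
     [0.0, 1.0, 1.0],
     [1.0, 1.0, 0.0],      # rows d..2d-1 -> W_dst (W_r)
     [2.0, 0.0, 1.0]]

naive = matvec(u + v, W)   # concatenate on the edge, then project
split = [s + t for s, t in zip(matvec(u, W[:d]), matvec(v, W[d:]))]
print(naive, split)        # [12.0, 5.0, 8.0] [12.0, 5.0, 8.0]


def floats_on_edges(num_edges, node_feat_dim, out_dim, naive):
    # Naive version stores 'cat_feat' (2 * node_feat_dim) and 'out' per edge;
    # split version stores only 'out' per edge (out_src/out_dst live on nodes
    # and are not counted here).
    per_edge = 2 * node_feat_dim + out_dim if naive else out_dim
    return num_edges * per_edge


print(floats_on_edges(1_000_000, 128, 8, naive=True))   # 264000000
print(floats_on_edges(1_000_000, 128, 8, naive=False))  # 8000000
```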

## 2.3 Apply Message Passing On Part of The Graph

To update only part of the nodes in a graph, the recommended practice is to create a subgraph from the IDs of the nodes to include in the update, and then call `update_all()` on the subgraph. For example:

```python
nid = [0, 2, 3, 6, 7, 9]
sg = g.subgraph(nid)
sg.update_all(message_func, reduce_func, apply_node_func)
```
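`g.subgraph(nid)` builds a node-induced subgraph, and DGL records the original IDs in `sg.ndata[dgl.NID]` and `sg.edata[dgl.EID]`. The following framework-free sketch shows which edges survive and how nodes are relabeled; `induced_subgraph` is a hypothetical helper, not DGL's implementation. It also makes one consequence visible: edges from nodes outside `nid` are not part of the subgraph, so their messages do not contribute to the update.

```python
# Framework-free sketch of a node-induced subgraph (not DGL code): keep the
# edges whose endpoints are both selected, and relabel nodes to 0..len(nid)-1.


def induced_subgraph(src_ids, dst_ids, nid):
    old_to_new = {old: new for new, old in enumerate(nid)}
    sub_src, sub_dst, orig_eids = [], [], []
    for eid, (u, v) in enumerate(zip(src_ids, dst_ids)):
        if u in old_to_new and v in old_to_new:
            sub_src.append(old_to_new[u])
            sub_dst.append(old_to_new[v])
            orig_eids.append(eid)  # remember which original edge this was
    return sub_src, sub_dst, orig_eids


src_ids = [0, 1, 2, 3, 6, 9]
dst_ids = [2, 2, 3, 6, 7, 0]
nid = [0, 2, 3, 6, 7, 9]
print(induced_subgraph(src_ids, dst_ids, nid))
# ([0, 1, 2, 3, 5], [1, 2, 3, 4, 0], [0, 2, 3, 4, 5])
```

Edge 1 (from node 1 to node 2) is dropped because node 1 is not selected, so node 2 would not receive that message when `update_all()` runs on the subgraph.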

## 2.4 Message Passing on Heterogeneous Graph

Message passing on heterographs can be split into two parts:
1. Message computation and aggregation within each relation r.
2. Reduction that merges the aggregation results from all relations for each node type.

DGL's interface for message passing on heterographs is `multi_update_all()`. It takes a dictionary that maps each relation (the key) to the `update_all()` parameters for that relation (the value), and a string naming the cross-type reducer, which can be one of `sum`, `min`, `max`, `mean`, or `stack`.

Example: 

```python
import torch.nn as nn
import dgl.function as fn


class HeteroRGCNLayer(nn.Module):
    def __init__(self, in_size, out_size, etypes):
        super().__init__()
        # One linear projection W_r per edge type
        self.weight = nn.ModuleDict({name: nn.Linear(in_size, out_size) for name in etypes})

    def forward(self, G, feat_dict):
        funcs = {}
        for c_etype in G.canonical_etypes:
            srctype, etype, dsttype = c_etype
            Wh = self.weight[etype](feat_dict[srctype])
            # Save it in the graph for message passing
            G.nodes[srctype].data['Wh_%s' % etype] = Wh
            # Specify per-relation message passing functions: (message_func, reduce_func).
            # All relations write to the same destination feature 'h', which tells
            # the cross-type reducer which results to aggregate.
            funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
        # Trigger message passing over all relations, summing across types
        G.multi_update_all(funcs, 'sum')
        # Return the updated node feature dictionary
        return {ntype: G.nodes[ntype].data['h'] for ntype in G.ntypes}
```
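The two stages can be sketched without DGL, assuming scalar messages, a single destination node type, a per-relation `mean` as in the example above, and a cross-type reducer passed as a Python function (`sum` mirrors the `'sum'` string). `multi_update_sketch` and the relation names are hypothetical.

```python
from statistics import fmean

# Framework-free sketch of the two-stage heterograph reduction (not DGL code).


def multi_update_sketch(relations, cross_reducer=sum):
    """relations: dict etype -> list of (dst_node, message) pairs."""
    # Stage 1: aggregate messages within each relation (mean, as in fn.mean)
    per_relation = {}
    for etype, pairs in relations.items():
        inbox = {}
        for dst, msg in pairs:
            inbox.setdefault(dst, []).append(msg)
        per_relation[etype] = {dst: fmean(msgs) for dst, msgs in inbox.items()}
    # Stage 2: merge the per-relation results for each destination node
    merged = {}
    for results in per_relation.values():
        for dst, value in results.items():
            merged.setdefault(dst, []).append(value)
    return {dst: cross_reducer(values) for dst, values in merged.items()}


relations = {
    "follows": [("alice", 1.0), ("alice", 3.0)],
    "likes": [("alice", 10.0), ("bob", 4.0)],
}
print(multi_update_sketch(relations))       # {'alice': 12.0, 'bob': 4.0}
print(multi_update_sketch(relations, max))  # {'alice': 10.0, 'bob': 4.0}
```

In this sketch, `alice` receives two `follows` messages and one `likes` message, yet each relation contributes exactly one averaged value to the cross-type sum.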