**WEIGHTED CONVOLUTION ON DYNAMIC GRAPHS**

In [2]:
import torch
import torch.nn as nn
import dgl
import dgl.function as fn  # built-in message and reduce functions used in forward()

We are going to construct a convolution on dynamic graphs.
The input to this module is a sequence of dynamic graphs $\mathbb{G}_i = \{\mathcal{G}_i^1, \ldots, \mathcal{G}_i^T\}$, where each graph $\mathcal{G}_i^t \in \mathbb{G}_i$ has a sequence of elements represented as $\{e_{i,j}^t \in \mathbb{R}^F, \forall v_{i,j} \in \mathcal{V}_i\}$. ($F$ is the dimension of the element representation, equal to `in_features`.)

For each graph $\mathcal{G}_i^t$, the output of this module is a new sequence representation, which we denote as $\{c_{i,j}^t \in \mathbb{R}^{F'}, \forall v_{i,j} \in \mathcal{V}_i\}$. ($F'$ is the new dimension, equal to `out_features`.)

To reduce the number of parameters and to handle sequences of variable length, a parameter-sharing strategy is adopted: the same weights are used at every time step. The weighted convolution propagates information between the elements of each dynamic graph. For graph $\mathcal{G}_i^t$,
$$c_{i,j}^{t,l+1} = \sigma\left( b^l + \sum_{k \in N_{i,j}^t \cup \{j\}} A_i^t[j,k] \cdot \left( W^l c_{i,k}^{t,l} \right) \right),$$ where $N_{i,j}^t$ is the set of neighbors of $v_{i,j}$ in graph $\mathcal{G}_i^t$, $A_i^t[j,k]$ is the entry in the $j$-th row and $k$-th column of matrix $A_i^t$, i.e. the weight of the edge between $v_{i,j}$ and $v_{i,k}$ in graph $\mathcal{G}_i^t$, and $W^l$ and $b^l$ are the weight matrix and bias of layer $l$, shared across all time steps.
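For a single graph and a single layer, the propagation rule can be written in dense form. The following is a minimal sketch in plain PyTorch, not the module built later in this notebook. The names `dense_weighted_conv`, `A`, `C`, `W`, and `b` are assumptions made for this example, the self-loop term is assumed to be stored on the diagonal of `A`, and ReLU stands in for $\sigma$.

```python
import torch

def dense_weighted_conv(A, C, W, b):
    """One propagation step for a single graph, all time steps at once.

    A: (T, n, n) edge weights, A[t, j, k] is the weight between v_j and v_k at time t
       (self-loops are assumed to sit on the diagonal).
    C: (T, n, F) element representations at the current layer.
    W: (F_out, F) projection, the same for every time step.
    b: (F_out,) bias, the same for every time step.
    """
    projected = C @ W.T            # (T, n, F_out): W c_k for every element
    aggregated = A @ projected     # (T, n, F_out): sum_k A[t, j, k] * (W c_k)
    return torch.relu(aggregated + b)

torch.manual_seed(0)
T, n, F_in, F_out = 3, 4, 5, 2
A = torch.rand(T, n, n)
C = torch.randn(T, n, F_in)
W = torch.randn(F_out, F_in)
b = torch.randn(F_out)

out = dense_weighted_conv(A, C, W, b)
print(out.shape)  # torch.Size([3, 4, 2])
```

Because `W` and `b` carry no time index, the same function works for any number of time steps `T`.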

We are going to subclass `nn.Module` to construct our convolutional layer.

**DGL library**

We are going to use the DGL library; here we present the basics of the library that our code relies on.

DGL represents a directed graph as a `DGLGraph` object. You can construct a graph by specifying the number of nodes in the graph as well as the lists of source and destination nodes. Nodes in the graph have consecutive IDs starting from 0. For example, `g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]), num_nodes=6)` creates a star graph `g` with center node 0.

You can assign and retrieve node and edge features via the `ndata` and `edata` interfaces.

`DGLGraph.update_all(message_func, reduce_func, apply_node_func=None, etype=None)`
sends messages along all the edges of the specified type and updates all the nodes of the corresponding destination type.

`message_func` (`dgl.function.BuiltinFunction` or callable): the message function that generates messages along the edges. It must be either a DGL built-in function or a user-defined function.

`reduce_func` (`dgl.function.BuiltinFunction` or callable): the reduce function that aggregates the messages. It must be either a DGL built-in function or a user-defined function.

DGL’s convention is to use `u`, `v`, and `e` to represent source nodes, destination nodes, and edges, respectively.

`u_mul_e` is a `message_func` that tells DGL to multiply source node features with edge features.
`sum` is a `reduce_func` that sums the messages arriving at each destination node.
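To make the semantics concrete, the sketch below reproduces in plain PyTorch what `update_all` with `u_mul_e` and `sum` computes on the star graph from the example above. It only illustrates the arithmetic and is not DGL's implementation. The names `src`, `dst`, `node_feat`, and `edge_feat` are assumptions made for this example.

```python
import torch

# Star graph with center node 0: edges 0->1, 0->2, ..., 0->5.
src = torch.tensor([0, 0, 0, 0, 0])
dst = torch.tensor([1, 2, 3, 4, 5])
num_nodes = 6

node_feat = torch.arange(1, 13, dtype=torch.float32).reshape(num_nodes, 2)  # (N, F)
edge_feat = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])               # (E, 1)

# u_mul_e: each edge carries (source node feature) * (edge feature).
messages = node_feat[src] * edge_feat                                       # (E, F)

# sum: each destination node adds up the messages on its incoming edges.
aggregated = torch.zeros(num_nodes, 2).index_add_(0, dst, messages)         # (N, F)
print(aggregated)
```

Node 0 has no incoming edges, so its row stays at zero. Every other node receives the center's feature `[1., 2.]` scaled by the weight of its single incoming edge.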

In [3]:
class weighted_graph_conv(nn.Module):
    """
        Apply graph convolution over an input signal.
    """
    def __init__(self, in_features: int, out_features: int):
        super(weighted_graph_conv, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features, bias=True)

    def forward(self, graph, node_features, edge_weights):
        r"""Compute the weighted graph convolution.
        -----
        Input:
        graph : DGLGraph, batched graph with N = n_1 + n_2 + ... nodes and E edges.
        node_features : torch.Tensor, input node features, shape (N, in_features) or (N, T, in_features)
        edge_weights : torch.Tensor, input edge weights for the T time steps, shape (T, E), with E = n_1^2 + n_2^2 + ...
        Output:
        torch.Tensor, output node features, shape (N, T, out_features)
        """
        graph = graph.local_var()  # local copy: feature changes made here do not affect the caller's graph
        # store the raw node features; the linear projection (with bias) is applied after aggregation
        # node_features, shape (N, F) or (N, T, F)
        graph.ndata['n'] = node_features
        # edge_weights, shape (T, E)
        graph.edata['e'] = edge_weights.t().unsqueeze(dim=-1)  # (E, T, 1)
        # send weighted messages along every edge and sum them at the destination nodes
        graph.update_all(fn.u_mul_e('n', 'e', 'msg'), fn.sum('msg', 'h'))

        # retrieve the aggregated features (stored under 'h')
        node_features = graph.ndata.pop('h')
        # project the aggregated features with the shared weights and bias
        output = self.linear(node_features)
        return output
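
Note that `forward` aggregates the neighbor features first and applies `self.linear` afterwards, whereas the propagation rule projects with the weight matrix before summing. Because the projection is linear and the bias is added once per node, both orders give the same result. The sketch below checks this numerically on a dense matrix of edge weights; `A`, `X`, and `linear` are stand-ins assumed for this example and are not part of the module.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n, F_in, F_out = 4, 5, 3
A = torch.rand(n, n)        # dense edge weights, A[j, k]
X = torch.randn(n, F_in)    # node features
linear = nn.Linear(F_in, F_out, bias=True)

# Order used in forward(): aggregate, then project and add the bias.
aggregate_then_project = linear(A @ X)

# Order written in the propagation rule: project (no bias), aggregate, add the bias once.
project_then_aggregate = A @ (X @ linear.weight.T) + linear.bias

print(torch.allclose(aggregate_then_project, project_then_aggregate, atol=1e-5))
```

Aggregating first means the projection runs once per node rather than once per edge. Note also that `forward` returns the pre-activation values, so $\sigma$ has to be applied by the caller.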