# Convolutional operators in PyG

PyG already implements many convolution-like operators. Moreover, the `MessagePassing` base class makes it easy to implement additional ones. In this notebook we implement the graph convolutional operator from the paper [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907) (Kipf and Welling, ICLR 2017), translating equations from the paper into code along the way.

# Installation

In [1]:
!pip uninstall -y torch torchvision torchtext fastai
!pip install torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter==latest+cu101 torch-sparse==latest+cu101 -f https://s3.eu-central-1.amazonaws.com/pytorch-geometric.com/whl/torch-1.6.0.html
!pip install torch-geometric==1.6.1


# Graph-Network Block

A general way to express convolution-like operators on graphs is the message-passing framework described in
[Battaglia, Peter W., et al. "Relational inductive biases, deep learning, and graph networks."](https://arxiv.org/pdf/1806.01261.pdf). In this framework, a convolution-like operator applied to a graph consists of the following steps:


1.  For each edge $e_k=(v_{s_k}, v_{r_k})$  which starts at node $v_{s_k}$ and ends at node $v_{r_k}$, <sup>[*](#myfootnote1)</sup>  we compute an (updated) edge feature-vector 
$$ \mathbf{e}_k^\prime = \phi^e\left(\mathbf{e}_k, \mathbf{v}_{s_k}, \mathbf{v}_{r_k}\right),$$ 
where $\mathbf{v}_{s_k}$ is the initial feature vector of node $v_{s_k}$,  $\mathbf{v}_{r_k}$ is the initial feature vector of node $v_{r_k}$, and $\mathbf{e}_k$ is the initial feature vector of edge $e_k.$ The updated edge feature vector  will be passed to node  $v_{r_k}$ as a message. 
The function $\phi^e$ can be any differentiable function such as a multi-layer perceptron or a simple linear function. 

<center>
<img src="https://ai.science/api/authorized-images/HzNp9jfkWd9u2Mi4WPB8Yyux4ayT2%2B0ywDOVoSDzXECYMfdaH%2B8JA8Cj4g18CiAtxl2tnuIQeK8RUpgqA03QMu2LiYKAOrVqNnFq%2Fc9EWKnFR0DDejFabHGh0sM5J%2FlLVGA88L9ZPZLD3LPJNLJN4rrsdqCJQOWwUiupBlmOHY2bc4xG5fwvdh5l8xy81betZCZrIzZR6nqBMtb7PKOz9%2BF88IHUbCkK99YZau079WGzkd0IX8t67kJpZuiD4rFQQPMYaN2WUBJyn8RnDsntlsIkeYS10jDbvNRuHNkrAU5SY%2FiPBTnhDUyMLrmMViI24WOAy2v1r4wrWDf5kjpTgDaV09RjXNFMj8MRqnEAHtnWRtBCGsbmZFoWmNrFo0YyuGUP%2BxydbkHuzXxTdEhW6FrsGvxGPDap0OGoVfneXPZ87kic1fYS8P3Qv65VgTk%2FNHXWeXhTLN9Y4hbpCOTaaIXJpYJaPPNofuDGQJK2zbPVFv6elUtNY2wwgDt%2F60KfWSdopVF9xqTB1twzLoqru7nfYN%2BhrGlohwn425rBk%2BR1FZpAxX8MoRtycuykiPhxsHJN9obSsPCRk1nswRMtXJRoe5UnoCrCXpwu0jHq%2F64l45XLfFC%2FKwWFAYyyY7Xr2VbqYJh374zEO47S08Ead1c0y27zkVSdQftaw1ZJibk%3D" width="40%" > </center>


2. Once all the messages are computed, all incoming messages to a given node $v_i$ are aggregated by a permutation-invariant operator 
$\rho^{e\rightarrow v}$ 
to compute a single aggregated message $\bar{\mathbf{e}}_i^\prime$ for node $v_i$,
 where
$$ \bar{\mathbf{e}}_i^\prime = \rho^{e\rightarrow v}(\{\mathbf{e}_k^\prime\}_{r_k=i}). $$ 
Here, $\{\mathbf{e}_k^\prime\}_{r_k=i}$ is the set of all incoming messages from edges that end at node $v_i$. We stress that the function $\rho^{e\rightarrow v}$ must accept an arbitrary number of arguments (because a node can have an arbitrary number of neighbors) and be permutation invariant (because a node's neighbors, and hence its incoming messages, have no natural ordering). Examples of permutation-invariant operators are
`sum`, `mean`, `max`, `min`, and a softmax-weighted sum.


<center>
<img src="https://ai.science/api/authorized-images/Zs594R5Id36UdAwRFxSDGUv0Xc09IYGZ4%2BNmsHXffKW%2Bo1t6tRTvAunyORISMAEQKUCWIkrXap%2B1Jrr6cHGrXGcnNGiTuuSplHi6szdstTRtpz9r1S9jGT4MjdBdCJMT%2BcWmSzFWe2Nbxt7CZCiC%2BWRXhtlHmE7kdQhuKmkPxuJI6G7qVa7KCPxWj43I7aBdfSLBmAlf36qn4oaGAUmp%2FuoOtEnAnjgSwVJoCCSZlpSfjzsLy%2Bp26R69Ltr9e1MXax1OChGc4Aco%2BvfPbx1Dt29ayGCJhRJQ0h6XqfXFq87gQSMTMe852ftzx%2Bo5xIcpOJjD2sZwDG%2Bjsk2Kdwr9ZiKUetHAU1yOk%2BvgeRzIxuOw4znbcacTN0gjxMT0l6rTl09o4ueX6dGh9JNC0loidSpo1qC6CqOLlKefnjMsz5joSk0xOc2fvFc38rzgSqVk7zjEvf8KEoKENKYQjTdA6wxdYHoeol0ugHSPf7Kk42l0E4%2FqNmBm8Lv9%2BXYjHumkXPoVqQdGTxk7vJlGT5k5%2F7qUuboCheHhPbnfdybtSILTJ8wYUKjXybJ1DjXw3%2BBsCFZEhkqj39LFDQBKriKAaaQqxtX6M1iw%2B8uZoo9XZQNZlK2amPiA27z8eDWlCv09gP%2BN9KPWrBN8i1a04sR8FRubm90fw4nkIEKi9FdPIAU%3D" width="40%" > </center>



3. Finally, the aggregated message $\bar{\mathbf{e}}_i^\prime$ is used to compute an updated feature vector for node $v_i$, which is passed to the next layer or used for a downstream task:
$$ \mathbf{v}_i^\prime = \phi^v\left(\bar{\mathbf{e}}_i^\prime, \mathbf{v}_i \right)$$
where $\phi^v$ is any differentiable  function.

<center>
<img src="https://ai.science/api/authorized-images/0C4kMtUxFY3jcOh4F6mtIAO4n2GR9aNZ%2BtJuecObzsJSEPJYKE7ljsswtSCKxc7GoSH64LU5Z%2BO7AtEkQ6Jkpm4hpvjfrRLJSgbf3lyS1YMxsFbBYT7TaCQPcfxYonNWEcC%2FEKQHtlXsbfzgdhxdRuNLuIiHU3wtHcQ9LA7XMXAkNR8yR1Zd%2FYO%2Fr85l0hXKc9mhmZDmGato56E7kAga23mLpNBHo42Trtt4mq%2Bn5%2Fma5QzT0KGW1XjrFM%2BAeqnr1L%2Fr8a4I6lfUUfoa9LNog5s9Ln3KL7DY0C2p1DrQb9mOmmaMkjJ4ArKpyx6jshV4vAdI9FZNvzai8AvTxd4Wr7hIhIN%2FW1HzrOpv1RfDqh5hI%2FwVMORUz51PoeYhggVoMZnqxTMbQNWOPAJaorxL7lfcXkmjpdUI9RawmjBUJ3wI8PVVqpeE%2BeWTShK3jYuCKHX51frjOBpQq6uzfh6cGALEcoUvdTUhBu0hNYE0dgXV%2Fn8lGte%2FrtcaRVqkFpJ1k0e5g9jzXHd4NV5okfpTLELevii5sHHc%2B8%2BJQmvjKAJgbxWvzE1BFXx6NUWH1ZpdAvmqaoRNgpCMZhb%2FHVlLbb%2Fia3zTFZgpZ%2FImCjXx6pQmi5I%2Fr603AsHbwIuHpvkUTWJNoxT%2FRsY%2BBd%2BfldYelZU4Fjwu2cnzUL5keAPZM8I%3D" width="40%" > </center>



Note: For simplicity, the treatment of the feature vector of the graph itself $\mathbf{u}$ has been omitted above.

Steps 1-3 form one GN block. The output of this block is the updated edge and node feature-vectors. These updated feature-vectors can be passed on to another block.
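
The three steps above can be sketched without any framework. The following is a minimal illustration in plain Python, not PyG's implementation: the names `gn_block`, `sum_aggregate`, `phi_e`, `phi_v` and the toy graph are assumptions made for this example, and the graph-level feature $\mathbf{u}$ is omitted as in the text.

```python
from functools import partial
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Edge = Tuple[int, int]  # (sender index s_k, receiver index r_k)


def sum_aggregate(messages: Sequence[Vector], dim: int) -> Vector:
    """Permutation-invariant aggregation: element-wise sum (zeros if empty)."""
    total = [0.0] * dim
    for message in messages:
        total = [t + m for t, m in zip(total, message)]
    return total


def gn_block(
    node_feats: Sequence[Vector],
    edges: Sequence[Edge],
    edge_feats: Sequence[Vector],
    phi_e: Callable[[Vector, Vector, Vector], Vector],
    rho: Callable[[Sequence[Vector]], Vector],
    phi_v: Callable[[Vector, Vector], Vector],
) -> Tuple[List[Vector], List[Vector]]:
    # Step 1: compute one message (updated edge feature) per edge.
    messages = [phi_e(e, node_feats[s], node_feats[r])
                for e, (s, r) in zip(edge_feats, edges)]
    new_node_feats = []
    for i, v_i in enumerate(node_feats):
        # Step 2: aggregate all messages whose receiver is node i.
        incoming = [m for m, (_, r) in zip(messages, edges) if r == i]
        aggregated = rho(incoming)
        # Step 3: update the node feature from the aggregated message.
        new_node_feats.append(phi_v(aggregated, v_i))
    return messages, new_node_feats


# Toy choices for the three functions (illustrative only).
def phi_e(e: Vector, v_s: Vector, v_r: Vector) -> Vector:
    return [e[0] * x for x in v_s]  # sender features scaled by the edge weight


def phi_v(aggregated: Vector, v: Vector) -> Vector:
    return [a + x for a, x in zip(aggregated, v)]  # residual-style update


node_feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]  # undirected path 0 - 1 - 2
edge_feats = [[0.5], [0.5], [0.5], [0.5]]
rho = partial(sum_aggregate, dim=2)

_, updated = gn_block(node_feats, edges, edge_feats, phi_e, rho, phi_v)
print(updated)  # [[1.0, 0.5], [1.0, 1.5], [1.0, 1.5]]

# Listing the edges in a different order does not change the result.
_, updated_reversed = gn_block(node_feats, edges[::-1], edge_feats[::-1],
                               phi_e, rho, phi_v)
print(updated == updated_reversed)  # True
```

The explicit loops are only for readability; libraries such as PyG perform the same gather and aggregate steps with vectorized tensor operations.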



<a name="myfootnote1">*</a>: Note: the letter *r* is for receiver and the letter *s* is for sender


# Implementation of Graph Convolutional Network (Kipf and Welling, 2017)

We will implement the graph convolutional operator described in [Kipf and Welling ICLR 2017](https://arxiv.org/abs/1609.02907).

The convolution operator can be written as

$$\mathbf{v}_i^{\prime} = \sum_{j \in \mathcal{N}(i) \cup \{i\}}  \frac{1}{\sqrt{(d_i+1)(d_j+1)}}\Theta \mathbf{v}_j,$$ 
where $\mathcal{N}(i)$ is the set of indices of the neighbors of node $v_i$, $d_i$ is the degree (number of neighbors) of node $v_i$, and $\Theta$ is a learnable weight matrix.


In terms of the message-passing framework described above <sup>[**](#myfootnote2)</sup>, the message function $\phi^e$ (which doesn't use any edge features in this case) is 

$$\phi^e\left(\mathbf{v}_i, \mathbf{v}_j\right) = \frac{1}{\sqrt{(d_i+1)(d_j+1)}}\Theta\mathbf{v}_j, $$ 

where $v_j$ is the sender and $v_i$ is the receiver. The aggregation operator is the `sum` operator, i.e.
$$ \rho^{e\rightarrow v}(\{\mathbf{e}_k^\prime\}_{r_k=i})=  \sum_{k:\, r_k=i} \mathbf{e}_k^\prime, $$
and the update operator $\phi^v$ is given by

$$\phi^v\left(\bar{\mathbf{e}}_i^\prime, \mathbf{v}_i \right)= \frac{1}{1+d_i}\Theta\mathbf{v}_i + \bar{\mathbf{e}}_i^\prime.$$

The first term of $\phi^v$ is the $j=i$ term of the sum in the convolution operator, for which the normalization factor reduces to $1/(d_i+1)$.


<a name="myfootnote2">**</a>: Note: The message-passing framework was initially described in [Gilmer et al., 2017](https://arxiv.org/abs/1704.01212). The paper [Kipf and Welling ICLR 2017](https://arxiv.org/abs/1609.02907) does not use this terminology.
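
For readers who want to match this decomposition to the matrix form used in the paper, the following sketch may help. It assumes an undirected graph without self-loops, stacks the node feature vectors as the rows of a matrix $V$, and borrows the paper's symbols $A$ (adjacency matrix), $\tilde{A}$ and $\tilde{D}$, none of which are used elsewhere in this notebook. The nonlinearity applied in the paper is left out, and under this row convention the paper's weight matrix corresponds to $\Theta^\top$.

$$\tilde{A} = A + I, \qquad \tilde{D}_{ii} = \sum_j \tilde{A}_{ij} = d_i + 1, \qquad V^\prime = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}\, V\, \Theta^\top .$$

Reading off row $i$ and splitting the sum into neighbors and the self-loop gives

$$\mathbf{v}_i^\prime = \sum_{j} \frac{\tilde{A}_{ij}}{\sqrt{(d_i+1)(d_j+1)}}\Theta\mathbf{v}_j = \underbrace{\sum_{j\in\mathcal{N}(i)} \frac{1}{\sqrt{(d_i+1)(d_j+1)}}\Theta\mathbf{v}_j}_{\bar{\mathbf{e}}_i^\prime} + \frac{1}{d_i+1}\Theta\mathbf{v}_i ,$$

so the underbraced part is the aggregated message and the remaining term is the self-loop contribution added in the update step.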


Let's see how this can be implemented in PyG.
We will define our `MyGCNConv` class, which inherits from the 
`MessagePassing` base class, which itself inherits from `torch.nn.Module`.
In addition to the usual `forward()` function, we need to define 
`message()` and `update()` functions.

See [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html) for more details on the `MessagePassing` Base Class
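
Before writing the class, it can help to picture what `propagate` does with sum aggregation and the default flow (`source_to_target`): per-edge features are gathered with the two rows of `edge_index`, and the resulting messages are summed at the receiving nodes. The sketch below is a simplified picture in plain PyTorch, not PyG's actual implementation; `propagate_sum_sketch` and `message_fn` are hypothetical names used only here.

```python
import torch


def propagate_sum_sketch(x, edge_index, message_fn):
    """Simplified picture of propagate() with aggr='add' (hypothetical helper)."""
    src, dst = edge_index              # row 0: sender j, row 1: receiver i
    x_j = x[src]                       # per-edge sender features,   shape (E, F)
    x_i = x[dst]                       # per-edge receiver features, shape (E, F)
    messages = message_fn(x_i, x_j)    # one message per edge,       shape (E, F_out)
    out = torch.zeros(x.size(0), messages.size(1),
                      dtype=messages.dtype, device=messages.device)
    out.index_add_(0, dst, messages)   # sum the messages at their receiver node
    return out


x = torch.tensor([[1.0], [10.0], [100.0]])
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # undirected path 0 - 1 - 2

# Each node simply receives the features of its neighbors.
out = propagate_sum_sketch(x, edge_index, lambda x_i, x_j: x_j)
print(out)  # tensor([[ 10.], [101.], [ 10.]])
```

Any argument passed to `propagate` can be requested in `message()` with the suffix `_j` (sender side) or `_i` (receiver side), which corresponds to the two gathers above.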


In [2]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree


class MyGCNConv(MessagePassing):

    # 0 - Instantiation of class. Here the linear operator is defined according
    # to the arguments in_channels and out_channels
    def __init__(self, in_channels, out_channels):

        # Define the aggregation function used to aggregate
        # all messages passed into a node
        super().__init__(aggr='add')

        # Linear layer to transform feature vector of a neighboring node
        # from length in_channels to length out_channels
        self.lin = torch.nn.Linear(in_channels, out_channels, bias=False)

    # 1 - The forward function is first called when a user calls an instance
    # of MyGCNConv. By PyG's convention, the feature vector is denoted by x. 
    def forward(self, x, edge_index):
        # x: node features. Shape:  (N, in_channels)
        # edge_index: graph structure. Shape: (2, E)

        # Transform the feature vectors of all nodes. 
        # Output Shape: (N, out_channels)
        x = self.lin(x) # Note this gives \Theta * V

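        # rows contains the source-node indices and cols the receiver-node indices
        rows, cols = edge_index

        # Compute the degree (i.e. number of neighbors) of each node.
        # num_nodes is passed explicitly so that isolated nodes with the
        # highest indices are still included (with degree 0).
        degrees = degree(rows, x.size(0), dtype=x.dtype) # Shape: (N,)
        degrees = degrees.view(-1, 1) # Shape: (N, 1)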

        # We are now ready to call the propagate function defined in the base class,
        # which takes in the edge indices (graph structure) and additional data.
        # The propagate function first calls the message function (defined below).
        # After the message is built, the propagate function then aggregates 
        # the messages according to the operator defined in the call to
        # super().__init__ (see above). Finally the propagate function calls 
        # the update function (defined below).
        # The user can specify here which variables to pass on to 
        # both the message and update functions 
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), 
                              x=x, degrees=degrees)

    # The message function computes the message for each edge.
    # The arguments provided to the propagate function can be accessed here.
    # The suffixes "_j" and "_i" can be used to map any node features to the
    # source and destination nodes, respectively, of the edges in the graph.
    def message(self, x_j, degrees_i, degrees_j):
        # x_j: features of the source node for each edge. Shape: [E, out_channels]
        # degrees_j: degree of the source node for each edge. Shape: [E, 1]
        # degrees_i: degree of the destination node for each edge. Shape: [E, 1]
        norm = ((degrees_j+1).pow(-0.5)*(degrees_i+1).pow(-0.5))

        # Construct the message
        return norm*x_j

    # The update function takes the output of the aggregation step.
    # The arguments passed to the propagate function can also be accessed here.
    def update(self, aggr_out, x, degrees):
        # aggr_out has shape [N, out_channels]
        # x has shape [N, out_channels]
        # degrees has shape [N, 1]
        # Add the self-loop term \Theta v_i / (d_i + 1)
        x_new = aggr_out + x / (degrees + 1)
        return x_new
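
One way to sanity-check the layer is to compare it with a dense evaluation of the same formula. The sketch below uses plain PyTorch; `gcn_dense_reference` is a hypothetical helper (not part of PyG), and it assumes an undirected graph whose `edge_index` lists both directions of every edge and contains no self-loops.

```python
import torch


def gcn_dense_reference(x, edge_index, weight):
    """Dense evaluation of sum_{j in N(i) U {i}} Theta v_j / sqrt((d_i+1)(d_j+1)).

    x: (N, in_channels) node features.
    edge_index: (2, E), both directions of each undirected edge, no self-loops.
    weight: (out_channels, in_channels), same layout as torch.nn.Linear.weight.
    """
    n = x.size(0)
    adj = torch.zeros(n, n, dtype=x.dtype, device=x.device)
    adj[edge_index[1], edge_index[0]] = 1.0                 # adj[receiver, sender]
    adj_tilde = adj + torch.eye(n, dtype=x.dtype, device=x.device)  # self-loops
    d_inv_sqrt = adj_tilde.sum(dim=1).pow(-0.5)             # (d_i + 1)^(-1/2)
    norm_adj = d_inv_sqrt.view(-1, 1) * adj_tilde * d_inv_sqrt.view(1, -1)
    return norm_adj @ (x @ weight.t())                      # (N, out_channels)


x = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # undirected path 0 - 1 - 2
weight = torch.eye(2)                      # identity, so the output is easy to read

out = gcn_dense_reference(x, edge_index, weight)
print(out)
```

Under the stated assumptions, `torch.allclose(conv(x, edge_index), gcn_dense_reference(x, edge_index, conv.lin.weight), atol=1e-6)` should hold for `conv = MyGCNConv(2, 2)`. The dense version builds an N-by-N matrix, so it is only suitable for small test graphs.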

# Optional Exercises:

Exercise 1: Write a Graph Neural Network which uses the GCN Operator above.

Exercise 2: Train the Network in Exercise 1 for a node classification task.

Exercise 3: Try modifying different aspects of the network above (e.g. the normalization factor, or a bias in the linear layer or in the update step), and see how this affects the accuracy on the task in Exercise 2.