
In [1]:
%matplotlib inline

In [2]:
!pip install  dgl -f https://data.dgl.ai/wheels/cu116/repo.html

Looking in links: https://data.dgl.ai/wheels/cu116/repo.html
Collecting dgl
  Downloading https://data.dgl.ai/wheels/cu116/dgl-1.1.2%2Bcu116-cp310-cp310-manylinux1_x86_64.whl (92.2 MB)
Installing collected packages: dgl
Successfully installed dgl-1.1.2+cu116


In [3]:
!pip install  dglgo -f https://data.dgl.ai/wheels-test/repo.html

Looking in links: https://data.dgl.ai/wheels-test/repo.html
Collecting dglgo
  Downloading dglgo-0.0.2-py3-none-any.whl (63 kB)
Collecting isort>=5.10.1 (from dglgo)
  Downloading isort-5.12.0-py3-none-any.whl (91 kB)
Collecting autopep8>=1.6.0 (from dglgo)
  Downloading autopep8-2.0.4-py2.py3-none-any.whl (45 kB)
Collecting numpydoc>=1.1.0 (from dglgo)
  Downloading numpydoc-1.6.0-py3-none-any.whl (61 kB)
Collecting ruamel.yaml>=0.17.20 (from dglgo)
  Downloading ruamel.yaml-0.18.3-py3-none-any.whl (11


Write your own GNN module
=========================

Sometimes your model goes beyond simply stacking existing GNN modules.
For example, you might want to invent a new way of aggregating neighbor
information that takes node importance or edge weights into account.

By the end of this tutorial you will be able to:

-  Understand DGL’s message passing APIs.
-  Implement a GraphSAGE convolution module on your own.

This tutorial assumes that you already know :doc:`the basics of training a
GNN for node classification <1_introduction>`.

(Time estimate: 10 minutes)


In [4]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


Message passing and GNNs
------------------------

DGL follows the *message passing paradigm* inspired by the Message
Passing Neural Network proposed by `Gilmer et
al. <https://arxiv.org/abs/1704.01212>`__. They observed that many
GNN models fit into the following framework:

\begin{align}m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)\end{align}

\begin{align}m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)}\end{align}

\begin{align}h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)\end{align}

where DGL calls $M^{(l)}$ the *message function*, $\sum$ the
*reduce function* and $U^{(l)}$ the *update function*. Note that
$\sum$ here can represent any function and is not necessarily a
summation.
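To make the three steps concrete, here is a minimal sketch in plain PyTorch (no DGL) on a hypothetical four-edge graph. The choices of $M^{(l)}$ (copy the source feature), a summation as the reduce step, and $U^{(l)}$ (add the aggregated message to the old state) are illustrative assumptions, not something DGL prescribes.

```python
import torch

# Hypothetical toy directed graph as an edge list: edge i goes src[i] -> dst[i].
src = torch.tensor([0, 1, 2, 2])
dst = torch.tensor([1, 2, 0, 1])
num_nodes = 3
h = torch.tensor([[1.0], [2.0], [3.0]])  # h^{(l-1)}: one feature per node


def message(h_src, h_dst):
    # M^{(l)}: here simply the source feature (h_dst and edge features are ignored)
    return h_src


def update(h_old, m):
    # U^{(l)}: here the old state plus the aggregated message
    return h_old + m


# 1. Message: compute m_{u->v} for every edge.
msgs = message(h[src], h[dst])
# 2. Reduce: sum all messages that arrive at the same destination node.
m = torch.zeros(num_nodes, h.shape[1]).index_add_(0, dst, msgs)
# 3. Update: combine each node's old state with its aggregated message.
h_new = update(h, m)
print(h_new)
```

Node 1 receives messages from nodes 0 and 2, so its new state is 2 + (1 + 3) = 6.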




For example, the `GraphSAGE convolution (Hamilton et al.,
2017) <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__
takes the following mathematical form:

\begin{align}h_{\mathcal{N}(v)}^k\leftarrow \text{Average}\{h_u^{k-1},\forall u\in\mathcal{N}(v)\}\end{align}

\begin{align}h_v^k\leftarrow \text{ReLU}\left(W^k\cdot \text{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k) \right)\end{align}
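As a worked instance of the two GraphSAGE equations, the plain-PyTorch sketch below computes one layer by hand. The graph, the features and the weight matrix $W^k$ are hypothetical values chosen so the arithmetic is easy to follow.

```python
import torch

# Hypothetical toy graph: neighbors[v] lists the neighbors N(v) of node v.
neighbors = {0: [1, 2], 1: [0], 2: [0, 1]}
h_prev = torch.tensor([[1.0], [2.0], [4.0]])  # h^{k-1}: one feature per node
W = torch.tensor([[1.0, -0.5]])               # W^k: shape (out_dim, 2 * in_dim)

h_next = []
for v in range(len(neighbors)):
    # Average the neighbors' previous representations: h_{N(v)}^k
    h_N = h_prev[neighbors[v]].mean(dim=0)
    # Concatenate self and neighborhood representations, project, apply ReLU
    z = torch.cat([h_prev[v], h_N])
    h_next.append(torch.relu(W @ z))
h_next = torch.stack(h_next)
print(h_next)
```

For node 0, the neighborhood average is (2 + 4) / 2 = 3, the pre-activation is 1 - 0.5 * 3 = -0.5, and ReLU clips it to 0.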

Message passing is directional: the message sent from node $u$ to
node $v$ is not necessarily the same as the message sent from node
$v$ to node $u$ in the opposite direction.
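A quick way to see this is a message function that depends on both endpoints. The sketch below uses a hypothetical message $h_u - h_v$ in plain PyTorch; it is an illustration, not a function from DGL.

```python
import torch

h = torch.tensor([[1.0], [5.0]])  # features of nodes 0 and 1


def message(h_u, h_v):
    # Hypothetical message function that uses both the sender and the receiver
    return h_u - h_v


m_0_to_1 = message(h[0], h[1])  # sent along edge 0 -> 1
m_1_to_0 = message(h[1], h[0])  # sent along edge 1 -> 0
print(m_0_to_1, m_1_to_0)
```

Here the message along 0 → 1 is -4 while the message along 1 → 0 is 4.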

Although DGL has builtin support for GraphSAGE via
:class:`dgl.nn.SAGEConv <dgl.nn.pytorch.SAGEConv>`,
here is how you can implement GraphSAGE convolution in DGL on your own.




In [5]:
import dgl.function as fn

# dgl.function provides DGL's builtin message and reduce functions,
# such as copy_u and mean, which are used below for message passing.

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        # in_feat: size of each node's input feature.
        # out_feat: size of the output feature this convolution produces.
        super(SAGEConv, self).__init__()  # required when subclassing nn.Module
        # A linear submodule that projects the concatenation of a node's own
        # feature h and its aggregated neighbor feature h_N to the output,
        # hence the input dimension in_feat * 2.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        """
        with g.local_scope():
            # Store the input features on the nodes under the key 'h'.
            # local_scope() keeps this assignment from leaking out of forward().
            g.ndata['h'] = h
            # update_all is a message passing API:
            #  - fn.copy_u('h', 'm') copies each source node's 'h' onto its
            #    out-edges as the message 'm';
            #  - fn.mean('m', 'h_N') averages the messages each node receives
            #    and stores the result as the node feature 'h_N'.
            g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
            # Aggregated neighbor features.
            h_N = g.ndata['h_N']
            # Concatenate self and neighbor features along the feature dimension
            # and project them with the linear layer defined in __init__.
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)

The central piece in this code is the
:func:`g.update_all <dgl.DGLGraph.update_all>`
function, which gathers and averages the neighbor features. There are
three concepts here:

* Message function ``fn.copy_u('h', 'm')`` that
  copies the node feature under name ``'h'`` as *messages* sent to
  neighbors.

* Reduce function ``fn.mean('m', 'h_N')`` that averages
  all the received messages under name ``'m'`` and saves the result as a
  new node feature ``'h_N'``.

* ``update_all`` tells DGL to trigger the
  message and reduce functions for all the nodes and edges.
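The following plain-PyTorch sketch spells out what ``update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_N'))`` computes on a hypothetical three-node graph. It only illustrates the semantics and is not how DGL implements the call internally. In this sketch, a node with no incoming edges ends up with a zero vector because its degree is clamped to one; check DGL's documentation for its own behavior in that case.

```python
import torch


def copy_u_mean(src, dst, h, num_nodes):
    """Sketch of copy_u followed by a mean reduction.

    src, dst: edge endpoints (edge i goes src[i] -> dst[i]); h: node features.
    """
    msgs = h[src]  # copy_u: each edge carries its source node's feature
    summed = torch.zeros(num_nodes, h.shape[1]).index_add_(0, dst, msgs)
    # In-degree of every node, clamped to 1 to avoid dividing by zero.
    in_deg = torch.bincount(dst, minlength=num_nodes).clamp(min=1)
    return summed / in_deg.unsqueeze(1)  # mean over the received messages


src = torch.tensor([0, 1, 2])
dst = torch.tensor([2, 2, 0])
h = torch.tensor([[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]])
h_N = copy_u_mean(src, dst, h, num_nodes=3)
print(h_N)
```

Node 2 receives the features of nodes 0 and 1, so its ``h_N`` is their average, while node 1 receives nothing.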




Afterwards, you can stack your own GraphSAGE convolution layers to form
a multi-layer GraphSAGE network.




In [6]:
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

Training loop
~~~~~~~~~~~~~
The following code for data loading and the training loop is copied
directly from the introduction tutorial.




In [7]:
import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    for e in range(200):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that we should only compute the losses of the nodes in the training set,
        # i.e. with train_mask 1.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        all_logits.append(logits.detach())

        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_acc, best_val_acc, test_acc, best_test_acc))

model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)

Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...
Extracting file to /root/.dgl/cora_v2_d697a464
Finished data loading and preprocessing.
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.
In epoch 0, loss: 1.951, val acc: 0.114 (best 0.114), test acc: 0.103 (best 0.103)
In epoch 5, loss: 1.880, val acc: 0.308 (best 0.308), test acc: 0.289 (best 0.289)
In epoch 10, loss: 1.741, val acc: 0.414 (best 0.414), test acc: 0.397 (best 0.397)
In epoch 15, loss: 1.527, val acc: 0.586 (best 0.586), test acc: 0.602 (best 0.602)
In epoch 20, loss: 1.246, val acc: 0.640 (best 0.640), test acc: 0.643 (best 0.643)
In epoch 25, loss: 0.934, val acc: 0.672 (best 0.672), test acc: 0.672 (best 0.662)
In epoch 30, loss: 0.641, val acc: 0.712 (best 0.712), test acc: 0.691 (best 0.691)
In epoch 35, loss: 0.409, val acc: 0.718 (best 0.718), test ac

More customization
------------------

In DGL, we provide many built-in message and reduce functions under the
``dgl.function`` package. You can find more details in :ref:`the API
doc <apifunction>`.




These APIs allow you to quickly implement new graph convolution modules.
For example, the following implements a new ``SAGEConv`` that scales each
neighbor representation by its edge weight before averaging. Note that the
``edata`` member can hold edge features, which can also take part in
message passing.




In [8]:
class WeightedSAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model with edge weights.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h, w):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        w : Tensor
            The edge weight.
        """
        with g.local_scope():
            g.ndata['h'] = h
            g.edata['w'] = w
            # Builtin message function fn.u_mul_e multiplies the source node
            # feature 'h' element-wise by the edge feature 'w' (broadcasting when
            # the shapes differ) to form the message 'm'; fn.mean then averages
            # the messages each node receives.
            g.update_all(message_func=fn.u_mul_e('h', 'w', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)

Because the graph in this dataset does not have edge weights, we
manually set all edge weights to one in the ``forward()`` function of
the model. You can replace them with your own edge weights.
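Note that ``fn.mean`` divides the sum of the weighted messages by the number of incoming edges, not by the sum of the weights. With all weights equal to one, both coincide with the plain mean. The plain-PyTorch sketch below, using a hypothetical graph and hypothetical weights, contrasts the two; dividing by the summed weights is one possible alternative if you need a normalized weighted average.

```python
import torch

src = torch.tensor([0, 1])
dst = torch.tensor([2, 2])
num_nodes = 3
h = torch.tensor([[1.0], [3.0], [0.0]])
w = torch.tensor([[1.0], [3.0]])  # hypothetical edge weights, shape (num_edges, 1)

weighted = h[src] * w  # the per-edge message produced by u_mul_e
msg_sum = torch.zeros(num_nodes, 1).index_add_(0, dst, weighted)
in_deg = torch.bincount(dst, minlength=num_nodes).clamp(min=1).unsqueeze(1)
# Sum of incoming weights per node (clamped so nodes without edges avoid 0/0).
w_sum = torch.zeros(num_nodes, 1).index_add_(0, dst, w).clamp(min=1e-12)

mean_of_weighted = msg_sum / in_deg  # sum_u w_uv h_u / in-degree (like fn.mean)
normalized_avg = msg_sum / w_sum     # sum_u w_uv h_u / sum_u w_uv
print(mean_of_weighted[2], normalized_avg[2])
```

For node 2, the weighted messages are 1 and 9: their mean is 5, while the normalized weighted average is 10 / 4 = 2.5.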




In [9]:
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
        h = F.relu(h)
        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
        return h

model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)

In epoch 0, loss: 1.951, val acc: 0.122 (best 0.122), test acc: 0.130 (best 0.130)
In epoch 5, loss: 1.859, val acc: 0.462 (best 0.468), test acc: 0.472 (best 0.466)
In epoch 10, loss: 1.705, val acc: 0.340 (best 0.468), test acc: 0.356 (best 0.466)
In epoch 15, loss: 1.484, val acc: 0.402 (best 0.468), test acc: 0.418 (best 0.466)
In epoch 20, loss: 1.212, val acc: 0.480 (best 0.480), test acc: 0.481 (best 0.481)
In epoch 25, loss: 0.921, val acc: 0.542 (best 0.542), test acc: 0.556 (best 0.556)
In epoch 30, loss: 0.650, val acc: 0.584 (best 0.584), test acc: 0.592 (best 0.592)
In epoch 35, loss: 0.425, val acc: 0.638 (best 0.638), test acc: 0.634 (best 0.634)
In epoch 40, loss: 0.261, val acc: 0.696 (best 0.696), test acc: 0.681 (best 0.681)
In epoch 45, loss: 0.156, val acc: 0.722 (best 0.722), test acc: 0.701 (best 0.701)
In epoch 50, loss: 0.095, val acc: 0.732 (best 0.732), test acc: 0.714 (best 0.715)
In epoch 55, loss: 0.060, val acc: 0.730 (best 0.732), test acc: 0.713 (best 0

Even more customization by user-defined function
------------------------------------------------

DGL allows user-defined message and reduce functions for maximal
expressiveness. Here is a user-defined message function that is
equivalent to ``fn.u_mul_e('h', 'w', 'm')``.




In [10]:
def u_mul_e_udf(edges):
    return {'m' : edges.src['h'] * edges.data['w']}

``edges`` has three members: ``src``, ``data`` and ``dst``, representing
the source node feature, edge feature, and destination node feature for
all edges.
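A message UDF can read from several of these members at once. The sketch below defines a hypothetical ``u_sub_v_udf`` whose message is the difference between the source and destination features. To run it without building a graph, it uses a ``SimpleNamespace`` stand-in for the edge batch DGL would normally pass in; only the dictionary-style access to ``src``, ``dst`` and ``data`` mirrors the real object.

```python
from types import SimpleNamespace

import torch


def u_sub_v_udf(edges):
    # Hypothetical message: source feature minus destination feature
    return {'m': edges.src['h'] - edges.dst['h']}


# Stand-in for a batch of two edges (not a real DGL object).
fake_edges = SimpleNamespace(
    src={'h': torch.tensor([[1.0], [4.0]])},
    dst={'h': torch.tensor([[3.0], [3.0]])},
    data={'w': torch.tensor([[0.5], [2.0]])},
)
out = u_sub_v_udf(fake_edges)
print(out['m'])
```

On a real graph, you would pass such a function to ``g.update_all`` in place of a builtin message function.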




You can also write your own reduce function. For example, the following
is equivalent to the builtin ``fn.mean('m', 'h_N')`` function that averages
the incoming messages:




In [11]:
def mean_udf(nodes):
    return {'h_N': nodes.mailbox['m'].mean(1)}

In short, DGL groups the nodes by their in-degrees, and for each
group it stacks the incoming messages along the second dimension, so
``nodes.mailbox['m']`` has shape (number of nodes in the group,
in-degree, feature size). You can then perform a reduction along the
second dimension to aggregate messages.
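The degree bucketing can be mimicked in plain PyTorch. This is a conceptual sketch with hypothetical messages, not DGL's internal code: nodes 0 and 1 each receive two messages and share a bucket, node 2 receives three and gets its own bucket, and each bucket's mailbox is reduced with ``mean(1)`` just as ``mean_udf`` does.

```python
import torch

# Hypothetical incoming messages, keyed by receiving node.
incoming = {
    0: [torch.tensor([1.0]), torch.tensor([3.0])],
    1: [torch.tensor([2.0]), torch.tensor([4.0])],
    2: [torch.tensor([6.0]), torch.tensor([0.0]), torch.tensor([3.0])],
}

# Group nodes that have the same in-degree into one bucket.
buckets = {}
for node, msgs in incoming.items():
    buckets.setdefault(len(msgs), []).append(node)

h_N = {}
for degree, nodes in buckets.items():
    # mailbox shape: (nodes in bucket, degree, feature size)
    mailbox = torch.stack([torch.stack(incoming[n]) for n in nodes])
    reduced = mailbox.mean(1)  # reduce along the second (message) dimension
    for n, value in zip(nodes, reduced):
        h_N[n] = value
print(buckets, h_N)
```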

For more details on customizing message and reduce functions with
user-defined functions, please refer to the :ref:`API
reference <apiudf>`.




Best practice of writing custom GNN modules
-------------------------------------------

DGL recommends the following practices, in order of preference:

-  Use ``dgl.nn`` modules.
-  Use ``dgl.nn.functional`` functions which contain lower-level complex
   operations such as computing a softmax for each node over incoming
   edges.
-  Use ``update_all`` with builtin message and reduce functions.
-  Use user-defined message or reduce functions.




What’s next?
------------

-  :ref:`Writing Efficient Message Passing
   Code <guide-message-passing-efficient>`.




In [12]:
# Thumbnail credits: Representation Learning on Networks, Jure Leskovec, WWW 2018
# sphinx_gallery_thumbnail_path = '_static/blitz_3_message_passing.png'