Deep Graph Library (DGL)
=====================

DGL is designed to bring machine learning closer to graph-structured data. Specifically, DGL enables trouble-free implementation of the graph neural network (GNN) model family. Unlike PyTorch or TensorFlow, DGL provides friendly APIs to perform the fundamental operations in GNNs, such as message passing and reduction. Through DGL, we hope to benefit both researchers trying out new ideas and engineers in production.

In this tutorial, we demonstrate the basics of DGL, including:
- How to create a graph?
- How to manipulate node/edge features on a graph?
- How to convert a graph to/from other formats?
- How to perform message passing computation on a graph?
- How to implement a Graph Convolutional Network?
- How to batch the execution of multiple graphs?
- How to perform efficient read-out on a batch of graphs?

Although this tutorial uses [PyTorch](https://pytorch.org) as the backend for tensor-related computations (so some familiarity with PyTorch is helpful), DGL is designed to be platform-agnostic and can be seamlessly integrated into other frameworks like [MXNet](https://mxnet.apache.org/) and [TensorFlow](https://www.tensorflow.org/); we are actively working on this.

In [None]:
# A bit of setup, just ignore this cell
import matplotlib.pyplot as plt

# for auto-reloading external modules
%load_ext autoreload
%autoreload 2

%matplotlib inline
plt.rcParams['figure.figsize'] = (8.0, 6.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['animation.html'] = 'html5'

We start by creating the well-known *"Zachary's karate club"* social network. The network captures 34 members of a karate club, documenting pairwise links between members who interacted outside the club. The club later split into two communities led by the instructor (node 0) and the club president (node 33). You can read more about the story on the [wiki page](https://en.wikipedia.org/wiki/Zachary%27s_karate_club). A visualization of the network and its communities is shown below:

![karate](https://www.dropbox.com/s/d7dgs4fantje3dg/karate.jpg?dl=1)

Part 1: creating a graph
-----------------------------------

Let's see how we can create such a graph in DGL. We start with importing `dgl` and other relevant packages.

In [None]:
import dgl

We first create an empty `DGLGraph`. In DGL, nodes are consecutive integers starting from 0. The following code adds all 34 club members to this graph.

In [None]:
G = dgl.DGLGraph()
G.add_nodes(34)
print('Number of nodes:', G.number_of_nodes())

The Karate Club network contains 78 edges:
```
[1 0]
[2 0] [2 1]
[3 0] [3 1] [3 2]
[4 0]
[5 0]
[6 0] [6 4] [6 5]
[7 0] [7 1] [7 2] [7 3]
[8 0] [8 2]
[9 2]
[10 0] [10 4] [10 5]
[11 0]
[12 0] [12 3]
[13 0] [13 1] [13 2] [13 3]
[16 5] [16 6]
[17 0] [17 1]
[19 0] [19 1]
[21 0] [21 1]
[25 23] [25 24]
[27 2] [27 23] [27 24]
[28 2]
[29 23] [29 26]
[30 1] [30 8]
[31 0] [31 24] [31 25] [31 28]
[32 2] [32 8] [32 14] [32 15] [32 18] [32 20] [32 22] [32 23] [32 29] [32 30] [32 31]
[33 8] [33 9] [33 13] [33 14] [33 15] [33 18] [33 19] [33 20] [33 22] [33 23] [33 26] [33 27] [33 28] [33 29] [33 30] [33 31] [33 32]
```

In DGL, edges can be added by specifying the two endpoints.

In [None]:
G.add_edge(1, 0)
print('Now we have %d edges!' % G.number_of_edges())

To add multiple edges at once, use a list/tensor of nodes to specify the endpoints.

In [None]:
import torch

########
# NOTE: in DGL, edges are added by specifying a list of source nodes and a list of destination nodes,
# rather than a list of source-destination node pairs. This is different from other popular graph
# packages such as networkx and python-igraph.

########
# NOTE: edges in DGLGraphs are all directed.

# add two edges 2->0 and 2->1 using list
G.add_edges([2, 2], [0, 1])

# add three edges 3->0, 3->1 and 3->2 using torch tensor
src = torch.tensor([3, 3, 3])
dst = torch.tensor([0, 1, 2])
G.add_edges(src, dst)

print('Now we have %d edges!' % G.number_of_edges())

In [None]:
# add two edges 4->0, 5->0 using list
G.add_edges([4, 5], 0)

# add three edges 6->0 6->4 6->5 using torch tensor
G.add_edges(6, torch.tensor([0, 4, 5]))

print('Now we have %d edges!' % G.number_of_edges())

If the edges share the same source or destination node, the list/tensor for that endpoint can be replaced with a single integer, as in the cell above.
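The broadcasting rule can be sketched in plain Python. `expand_endpoints` below is a hypothetical helper written only for illustration (it is not part of DGL): it mimics how a single integer endpoint is paired with every entry of the other list, and the last lines show how `(src, dst)` pairs are turned into the two-list form DGL expects.

```python
def expand_endpoints(src, dst):
    """Hypothetical helper (not a DGL function): broadcast a scalar endpoint
    against a list of the other endpoints, returning two equal-length lists."""
    if isinstance(src, int) and isinstance(dst, int):
        return [src], [dst]
    if isinstance(src, int):
        src = [src] * len(dst)
    if isinstance(dst, int):
        dst = [dst] * len(src)
    if len(src) != len(dst):
        raise ValueError("src and dst must have the same length")
    return list(src), list(dst)

# Mirrors G.add_edges([4, 5], 0): edges 4->0 and 5->0
print(expand_endpoints([4, 5], 0))     # ([4, 5], [0, 0])
# Mirrors G.add_edges(6, [0, 4, 5]): edges 6->0, 6->4, 6->5
print(expand_endpoints(6, [0, 4, 5]))  # ([6, 6, 6], [0, 4, 5])

# Turning (src, dst) pairs into a source list and a destination list
pairs = [(7, 0), (7, 1), (8, 2)]
src, dst = map(list, zip(*pairs))
print(src, dst)  # [7, 7, 8] [0, 1, 2]
```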

In [None]:
# Exercise: please finish the karate club graph by adding the remaining edges. We have provided you all the
# remaining edge tuples in a list.

edge_list = [(7, 0), (7, 1), (7, 2), (7, 3), (8, 0), (8, 2), (9, 2), (10, 0), (10, 4), (10, 5),
             (11, 0), (12, 0), (12, 3), (13, 0), (13, 1), (13, 2), (13, 3), (16, 5), (16, 6),
             (17, 0), (17, 1), (19, 0), (19, 1), (21, 0), (21, 1), (25, 23), (25, 24), (27, 2),
             (27, 23), (27, 24), (28, 2), (29, 23), (29, 26), (30, 1), (30, 8), (31, 0), (31, 24),
             (31, 25), (31, 28), (32, 2), (32, 8), (32, 14), (32, 15), (32, 18), (32, 20), (32, 22),
             (32, 23), (32, 29), (32, 30), (32, 31), (33, 8), (33, 9), (33, 13), (33, 14), (33, 15),
             (33, 18), (33, 19), (33, 20), (33, 22), (33, 23), (33, 26), (33, 27), (33, 28),
             (33, 29), (33, 30), (33, 31), (33, 32)]

# >>> YOUR CODES START

src, dst = tuple(zip(*edge_list))
G.add_edges(src, dst)

# <<< YOUR CODES END

# We should have 78 edges now!
print('Now we have %d edges!' % G.number_of_edges())

Part 2: manipulating node/edge features
---------------------------------------------------------

Nodes and edges in a `DGLGraph` can have **feature** tensors. Features of multiple nodes/edges are batched along the first dimension. Let's start by assigning a random feature vector of length 5 to every node.

In [None]:
G.ndata['feat'] = torch.randn((34, 5))

Now each node has a feature vector `'feat'` with 5 elements. Note that since there are 34 nodes in this graph, the first dimension must be of size 34, so that each row is the feature vector of one node. An error is raised if the dimensions mismatch:

In [None]:
# This will raise error!!
# G.ndata['wrong_feat'] = torch.randn((35, 5))

`G.ndata` is a dictionary-like structure, so it supports the usual dictionary operations.

In [None]:
# Use `dict.update` to add new features (vector of length 3)
G.ndata.update({'another_feat' : torch.randn((34, 3))})

# Print the feature dictionary
print(G.ndata)

# Delete the new feature using `dict.pop`
G.ndata.pop('another_feat')

Sometimes, you might want to update features of some but not all of the nodes. This can be done using the following syntax:

In [None]:
# Set node 0's feat to be all-zeros vector. Please be aware of the extra size 1 dimension here.
G.nodes[0].data['feat'] = torch.zeros((1, 5))

# Set node 2, 3's feat to be all-ones vector at once using list type.
G.nodes[[2, 3]].data['feat'] = torch.ones((2, 5))

# Set node 10, 11, 12's feat to be all-twos vector at once using tensor type.
to_change = torch.tensor([10, 11, 12])
G.nodes[to_change].data['feat'] = torch.ones((3, 5)) * 2

Similar to `G.ndata` and `G.nodes`, we have `G.edata` and `G.edges` to access and modify edge features:

In [None]:
# The broness edge feature is just a scalar.
G.edata['broness'] = torch.ones((G.number_of_edges(),))

# The instructor (node 0) is a tough guy, so his friends are a little bit scared of him.
G.edges[G.predecessors(0), 0].data['broness'] *= 0.5

print(G.edata)

In [None]:
# Exercise: We know that measuring bro-ness cannot be accurate. Could you add some small random noise to it?
# Hint: Use `torch.randn` to add a small perturbation to it.
#
# >>> YOUR CODES START

G.edata['broness'] += torch.randn((G.number_of_edges(),)) * 0.1

# <<< YOUR CODES END

# You should see some randomness here
print(G.edata['broness'])

Part 3: converting to/from networkx graph
------------------------------------------------------------

[Networkx](https://networkx.github.io/documentation/stable/) is a classic and popular Python graph library. It provides many useful utilities to analyze and visualize graphs. A `DGLGraph` can easily be converted to and from a `networkx` graph:

In [None]:
import networkx as nx

nx_G = G.to_networkx()
pos = nx.circular_layout(nx_G)
nx.draw(nx_G, pos, with_labels=True)

Constructing a `DGLGraph` from networkx is straightforward. In fact, DGL borrows many of the networkx utilities to create graphs from different formats:

In [None]:
# from networkx graph
G_from_nx = dgl.DGLGraph(nx_G)  # this gives you the same karate club network

# from edge list
G_from_elist = dgl.DGLGraph([(0,1), (1,2), (2,3)])  # this gives you a chain graph

# from scipy sparse matrix
import scipy.sparse as sp
A = sp.eye(5, 5, 1)
G_from_sp = dgl.DGLGraph(A)  # this also gives you a chain of 5 nodes
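As a rough, framework-free illustration of the scipy case: `sp.eye(5, 5, 1)` has ones only on the first superdiagonal, so reading each nonzero entry $A_{ij}$ as an edge $i \to j$ yields a 5-node chain. The helpers below are hypothetical, and treating the row index as the source node is our assumption about how DGL interprets the matrix.

```python
def offset_eye(n, k):
    """Dense stand-in for scipy.sparse.eye(n, n, k): ones on the k-th diagonal."""
    return [[1.0 if j - i == k else 0.0 for j in range(n)] for i in range(n)]

def nonzeros_as_edges(A):
    """Read every nonzero entry A[i][j] as a directed edge i -> j (assumed convention)."""
    return [(i, j) for i, row in enumerate(A) for j, x in enumerate(row) if x != 0]

A = offset_eye(5, 1)
print(nonzeros_as_edges(A))  # [(0, 1), (1, 2), (2, 3), (3, 4)]
```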

Part 4: Message passing on graph
-------------------------------------------------

Many graph neural networks follow the **message passing** computation model: nodes send out messages, which are then aggregated and used to update the receiving nodes. We go through the basic mechanism of message passing using a toy task and then use it to implement a Graph Convolutional Network (GCN).

Suppose the club president (node 33) is sending out invitations to their annual karate match. The president also asks club members to pass the news on to, of course, their friends in the club. We use a scalar to represent whether a member has received the invitation (1 for invited, 0 for not invited). Initially, everyone is 0 except node 33.

In [None]:
# We first convert the uni-directional edges to bi-directional ones so that messages can
#   be sent in both directions.
# We also add a self loop for each node for convenience.
src, dst = G.edges()
GG = dgl.DGLGraph()
GG.add_nodes(34)
GG.add_edges(src, dst)
GG.add_edges(dst, src)
# add a self loop for each node
v = G.nodes()
GG.add_edges(v, v)
print('We now have %d edges!' % GG.number_of_edges())

# init the state
GG.ndata['invited'] = torch.zeros((34,))
GG.nodes[33].data['invited'] = torch.tensor([1.])
print(GG.ndata['invited'])

We first define the function that computes the messages. In DGL, the message function is an **Edge UDF** that takes in a single argument `edges`. It has three members `src`, `dst`, and `data` for accessing source node features, destination node features, and edge features respectively.

In [None]:
def message_func(edges):
    # The message is simply the 'invited' state of the source nodes.
    return {'msg' : edges.src['invited']}

Next, we define the reduce function, which accumulates and consumes the messages to update the node features. In DGL, the reduce function is a **Node UDF** that takes in a single argument `nodes`, which has two members, `data` and `mailbox`. `data` contains the node features, while `mailbox` contains all incoming message features, stacked along the second dimension (hence the `dim=1` argument).

In [None]:
def reduce_func(nodes):
    # The reduce function sets the 'invited' state to be one if the node has already
    #   been invited or any of the received messages contains an invitation (is one).
    #   This can be done using sum and clamp operations as follows.
    accum = nodes.mailbox['msg'].sum(dim=1)  # note that messages are stacked on dim=1
    return {'invited' : accum.clamp(max=1)}

To trigger the message and reduce functions, one can use the `send` and `recv` APIs. The following code sends out messages from node 33:

In [None]:
# The first argument to `GG.send` is the set of edges along which the messages are sent.
# Note that we can use the same syntax used in adding edges to the graph.
# The second argument is the message function we just defined.
GG.send((33, GG.successors(33)), message_func)

We then call `recv` on the receiver nodes to trigger the reduce function.

In [None]:
GG.recv(GG.successors(33), reduce_func)

You can print out the `'invited'` status to see the invitation being propagated.

In [None]:
print(GG.ndata['invited'])

We can keep doing so until all the nodes have received the invitation.

In [None]:
num_invited = int(torch.sum(GG.ndata['invited']))
while num_invited != 34:
    GG.send(GG.edges(), message_func)
    GG.recv(GG.nodes(), reduce_func)
    num_invited = int(torch.sum(GG.ndata['invited']))
    print('%d members have been invited.' % num_invited)
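For intuition, here is the same invite-until-everyone-is-invited loop without DGL, on a small hypothetical path graph. The function name and the graph are made up for illustration: the explicit `mailbox` dictionary plays the role of `send`, and the sum-and-clamp plays the role of `reduce_func`. Unlike the cell above, this sketch stops when a round changes nothing, since on a disconnected graph the loop would otherwise never end.

```python
def propagate_invitations(num_nodes, edges, start):
    """Plain-Python version of the send/recv loop above (illustrative only).

    `edges` lists the (src, dst) pairs of an undirected graph. Like GG, we add
    both directions plus a self loop on every node, so a node keeps its own state.
    Returns the final states and the number of invited members after each round.
    """
    all_edges = list(edges) + [(v, u) for u, v in edges] + [(v, v) for v in range(num_nodes)]
    invited = [0.0] * num_nodes
    invited[start] = 1.0
    history = [sum(invited)]
    while sum(invited) < num_nodes:
        # "send": every edge carries the source node's state as its message
        mailbox = {v: [] for v in range(num_nodes)}
        for u, v in all_edges:
            mailbox[v].append(invited[u])
        # "recv": sum the incoming messages, then clamp to at most 1
        new_invited = [min(sum(mailbox[v]), 1.0) for v in range(num_nodes)]
        if new_invited == invited:
            break  # no progress: the remaining nodes are unreachable
        invited = new_invited
        history.append(sum(invited))
    return invited, history

# Hypothetical 4-node path 0-1-2-3; the invitation starts at node 3.
state, history = propagate_invitations(4, [(0, 1), (1, 2), (2, 3)], start=3)
print(state)    # [1.0, 1.0, 1.0, 1.0]
print(history)  # [1.0, 2.0, 3.0, 4.0]
```

Each round pushes the invitation exactly one hop further, so the number of rounds equals the hop distance from the starting member to the farthest one.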

**What's under the hood?**

The key idea here is to automatically batch the node and edge features so that your UDF can compute message passing on multiple nodes and edges in parallel.

```python
def message_func(edges):
    return {'msg' : edges.src['invited']}
```

The `edges` argument is an `EdgeBatch` object representing a batch of edges. It has three members, `src`, `dst`, `data`. The `edges.src['invited']` returns a tensor of shape `(B,)`, where `B` is the number of edges being triggered.

```python
def reduce_func(nodes):
    accum = nodes.mailbox['msg'].sum(dim=1)
    return {'invited' : accum.clamp(max=1)}
```

Similarly, for the reduce function, the argument `nodes` is a `NodeBatch` object representing a batch of nodes. It has two members, `data` and `mailbox`. `nodes.mailbox['msg']` returns a tensor of shape `(B, deg)`, where `B` is the number of nodes that have the same in-degree `deg`. The reduce function is therefore called once per degree group, which can mean *many times* within a single `recv`.
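A simplified model of this degree bucketing in plain Python is shown next. It is a hypothetical sketch of the behavior described above, not DGL's actual scheduler: the mailboxes of nodes with equal in-degree are stacked into a `(B, deg)` block, and the reduce function runs once per block.

```python
from collections import defaultdict

def reduce_by_degree(in_edges, messages, reduce_fn):
    """Hypothetical model of degree bucketing: group receivers by in-degree,
    stack each group's mailboxes into a (B, deg) block, call reduce_fn once per block."""
    mailbox = defaultdict(list)
    for (u, v), m in zip(in_edges, messages):
        mailbox[v].append(m)
    buckets = defaultdict(list)  # in-degree -> nodes with that in-degree
    for v, msgs in mailbox.items():
        buckets[len(msgs)].append(v)
    result, calls = {}, []
    for deg, nodes in sorted(buckets.items()):
        block = [mailbox[v] for v in nodes]  # B rows, each of length deg
        calls.append((len(nodes), deg))
        for v, out in zip(nodes, reduce_fn(block)):
            result[v] = out
    return result, calls

def sum_clamp(block):
    # Same logic as reduce_func: sum each row (dim=1), then clamp to at most 1
    return [min(sum(row), 1.0) for row in block]

invited = [1.0, 0.0, 0.0, 0.0]
in_edges = [(0, 1), (2, 1), (0, 2), (1, 2), (3, 2), (2, 3), (0, 3)]
messages = [invited[u] for u, _ in in_edges]  # message_func: the source's state
result, calls = reduce_by_degree(in_edges, messages, sum_clamp)
print(calls)   # [(2, 2), (1, 3)]  i.e. (B, deg) for each call
print(result)  # {1: 1.0, 3: 1.0, 2: 1.0}
```

Here nodes 1 and 3 both have in-degree 2 and share one call, while node 2 (in-degree 3) gets its own; node 0 receives no messages, so in this sketch it is never reduced.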

Part 5: Implementing Graph Convolutional Network (GCN) in DGL
--------------------------------------------------------------

Graph convolutional network (GCN) is a popular model proposed by [Kipf & Welling](https://arxiv.org/abs/1609.02907) to encode graph structure by message passing. The high-level idea is similar to our toy task -- node features are updated by aggregating the messages from the neighbors. Here is its message passing equation:

$$
h_{v_i}^{(l+1)} = \sigma \left(\sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h_{v_j}^{(l)}W^{(l)} \right)
$$

, where $v_i$ is any node in the graph; $h_{v_i}$ is the feature of node $v_i$; $\mathcal{N}(i)$ denotes the neighborhood of $v_i$; $c_{ij}$ is a normalization constant related to node degrees; $W^{(l)}$ is the learnable weight matrix of layer $l$; and $\sigma$ is a non-linear activation function.
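To make the equation concrete, the function below evaluates one layer by hand with plain Python lists. The name `gcn_layer` and the input layout are illustrative assumptions, ReLU stands in for $\sigma$, and we pick $c_{ij} = \sqrt{d_i d_j}$, the choice from the Kipf & Welling paper that the exercise later in this tutorial asks you to implement.

```python
import math

def gcn_layer(neighbors, H, W, act=lambda x: max(x, 0.0)):
    """One GCN layer computed directly from the equation (illustrative sketch).

    neighbors[i] lists N(i) (include i itself to model a self loop),
    H holds one feature row per node, W is an in_dim x out_dim matrix,
    and c_ij = sqrt(d_i * d_j) with d_i = len(neighbors[i]).
    """
    deg = [len(nbrs) for nbrs in neighbors]
    out = []
    for i, nbrs in enumerate(neighbors):
        # sum_j h_j / c_ij over the neighborhood of node i
        agg = [0.0] * len(H[0])
        for j in nbrs:
            c_ij = math.sqrt(deg[i] * deg[j])
            agg = [a + h / c_ij for a, h in zip(agg, H[j])]
        # multiply by W (shared by all neighbors), then apply sigma
        row = [sum(agg[k] * W[k][m] for k in range(len(agg))) for m in range(len(W[0]))]
        out.append([act(x) for x in row])
    return out

# Two connected nodes, each with a self loop, so d = [2, 2] and every c_ij = 2.
print(gcn_layer([[0, 1], [0, 1]], [[1.0], [3.0]], [[1.0, -1.0]]))
# [[2.0, 0.0], [2.0, 0.0]]
```

Summing first and multiplying by $W^{(l)}$ afterwards gives the same result as summing $h_{v_j}^{(l)}W^{(l)}$, because the same weight matrix is shared by all neighbors; this is also why the DGL module below can aggregate first and apply `nn.Linear` at the end.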

The procedure to implement GCN in DGL is also similar to the toy task:
* Define the message function.
* Define the reduce function.
* Define how they are triggered using `send` and `recv`.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# Define the message & reduce function
# NOTE: we ignore the normalization constant c_ij for now.
def gcn_message(edges):
    # messages are the features of the source nodes.
    return {'msg' : edges.src['h']}

def gcn_reduce(nodes):
    # messages are summed
    return {'h' : torch.sum(nodes.mailbox['msg'], dim=1)}

# Define the GCN module
class GCN(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCN, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
    
    def forward(self, g, inputs):
        # g is the graph and the inputs is the input node features
        # first set the node features
        g.ndata['h'] = inputs
        # trigger message passing
        g.send(g.edges(), gcn_message)
        g.recv(g.nodes(), gcn_reduce)
        # get the result node features
        h = g.ndata.pop('h')
        # perform linear transformation
        return self.linear(h)

To test this model, let's try to predict which group (the instructor's or the club president's) each club member will join after the split. We adopt the semi-supervised setting of Kipf & Welling; here only the instructor and the president are labeled:

In [None]:
# Clear previous features
GG.ndata.clear()
GG.edata.clear()

# Define a 2-layer GCN model
class Net(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes):
        super(Net, self).__init__()
        self.gcn1 = GCN(in_feats, hidden_size)
        self.gcn2 = GCN(hidden_size, num_classes)
    
    def forward(self, g, inputs):
        h = self.gcn1(g, inputs)
        h = torch.relu(h)
        h = self.gcn2(g, h)
        return h

inputs = torch.eye(34)  # featureless inputs
labeled_nodes = torch.tensor([0, 33])  # only the instructor and the president nodes are labeled
labels = torch.tensor([0, 1])  # their labels are different
net = Net(34, 5, 2)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

all_logits = []
for epoch in range(30):
    logits = net(GG, inputs)
    all_logits.append(logits.detach())
    logp = F.log_softmax(logits, 1)
    # we only compute loss for node 0 and node 33
    loss = F.nll_loss(logp[labeled_nodes], labels)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))

In [None]:
# Visualize the node classification using the logits output.
import numpy as np
import matplotlib.animation as animation
from IPython.display import HTML

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()
def draw(i):
    cls1color = '#00FFFF'
    cls2color = '#FF00FF'
    pos = {}
    colors = []
    for v in range(34):
        pos[v] = all_logits[i][v].numpy()
        cls = np.argmax(pos[v])
        colors.append(cls1color if cls else cls2color)
    ax.cla()
    ax.axis('off')
    ax.set_title('Epoch: %d' % i)
    nx.draw(nx_G.to_undirected(), pos, node_color=colors, with_labels=True, node_size=500)

ani = animation.FuncAnimation(fig, draw, frames=len(all_logits), interval=200)
HTML(ani.to_html5_video())

Advanced Topic: speed up GNN training
--------------------------------------------------

DGL provides many routines that combine the basic `send` and `recv` in various ways. They are called **level-2 APIs**. For example, we can use the `update_all` API in the GCN module so that no explicit `edges()` and `nodes()` tensors need to be generated.

In [None]:
# Re-define the GCN module using level-2 APIs.
class GCN_level2(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCN_level2, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
    
    def forward(self, g, inputs):
        # g is the graph and the inputs is the input node features
        # first set the node features
        g.ndata['h'] = inputs
        # trigger message passing using `update_all`
        # original codes:
        #   g.send(g.edges(), gcn_message)
        #   g.recv(g.nodes(), gcn_reduce)
        g.update_all(gcn_message, gcn_reduce)
        # get the result node features
        h = g.ndata.pop('h')
        # perform linear transformation
        return self.linear(h)

As some message and reduce functions are very commonly used, DGL also provides them as **builtin functions**. The following code uses the `copy_src` and `sum` builtins.

In [None]:
# Re-define the GCN module using DGL builtin functions.
import dgl.function as fn

class GCN_builtin(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCN_builtin, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
    
    def forward(self, g, inputs):
        # g is the graph and the inputs is the input node features
        # first set the node features
        g.ndata['h'] = inputs
        # trigger message passing using `update_all`
        # original codes:
        #   g.send(g.edges(), gcn_message)
        #   g.recv(g.nodes(), gcn_reduce)
        g.update_all(fn.copy_src('h', 'msg'), fn.sum('msg', 'h'))
        # get the result node features
        h = g.ndata.pop('h')
        # perform linear transformation
        return self.linear(h)

### Exercise

There is still one missing piece. In our GCN model, 
$$
h_{v_i}^{(l+1)} = \sigma \left(\sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h_{v_j}^{(l)}W^{(l)} \right)
$$
the normalizer $c_{ij}$ has not been implemented yet. In the GCN paper, Kipf & Welling compute it as follows:

$$
c_{ij} = \sqrt{d_id_j}
$$

, where $d_i, d_j$ are the degrees of node $v_i$ and $v_j$ respectively. Your task is to modify the program to implement it.

**Hint #1**: Use `GG.in_degrees(GG.nodes())` to get a 1-D tensor containing the degrees of all the nodes.

**Hint #2**: Since $c_{ij}$ has the subscript $ij$, it is tied to the edges, and our message function is (not coincidentally) an **edge UDF**.

Have fun :)
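One way to think about organizing the computation, sketched in plain Python with hypothetical helpers: because $1/\sqrt{d_i d_j} = (1/\sqrt{d_i})(1/\sqrt{d_j})$, the per-edge normalizer can equally be split into a source-side scaling before sending and a destination-side scaling after receiving. The check below compares both orders on a small path graph with self loops, using scalar features and counting degrees with the self loop included (as `GG.in_degrees` would on `GG`).

```python
import math

def aggregate_per_edge(in_neighbors, h):
    """Divide each message h_j by c_ij = sqrt(d_i * d_j) on its edge, then sum."""
    d = [len(nbrs) for nbrs in in_neighbors]
    return [sum(h[j] / math.sqrt(d[i] * d[j]) for j in nbrs)
            for i, nbrs in enumerate(in_neighbors)]

def aggregate_factored(in_neighbors, h):
    """Scale each source by 1/sqrt(d_j) before sending, sum, then scale by 1/sqrt(d_i)."""
    d = [len(nbrs) for nbrs in in_neighbors]
    scaled = [h[j] / math.sqrt(d[j]) for j in range(len(h))]
    return [sum(scaled[j] for j in nbrs) / math.sqrt(d[i])
            for i, nbrs in enumerate(in_neighbors)]

# Path 0-1-2 with a self loop on every node, so d = [2, 3, 2].
in_neighbors = [[0, 1], [0, 1, 2], [1, 2]]
h = [1.0, 2.0, 3.0]
per_edge = aggregate_per_edge(in_neighbors, h)
factored = aggregate_factored(in_neighbors, h)
print(all(math.isclose(a, b) for a, b in zip(per_edge, factored)))  # True
```

The per-edge form maps naturally onto the edge UDF from Hint #2; the factored form only needs node-level scaling, which can be convenient once the builtin functions are used.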

Part 6: Batch execution of graphs
------------------------------------------------------
So far, we have learned how to create and modify graphs with DGL and how to trigger computation on their nodes and edges. This is what we need to implement models that learn representations on a single large graph, such as a citation graph or a knowledge graph.

However, in other scenarios, such as learning over sentence syntax trees ([Tai et al., 2015](https://arxiv.org/abs/1503.00075)) or chemical structures ([Jin et al., 2018](https://arxiv.org/abs/1802.04364)), the goal is to learn representations of many individual graphs, for example to classify them or to extract features from them.

In such applications, the ability to batch computation over multiple graphs matters. We now demonstrate how DGL handles graph batching with its BatchedGraph API.

### Simple Graph Classification Task
To make it more concrete, let's use graph classification as our example. Graph classification is an important problem with applications across many fields – bioinformatics, chemoinformatics, social network analysis, urban computing and cyber-security. Applying graph neural networks to this problem has been a popular approach recently ([Ying et al., 2018](https://arxiv.org/abs/1806.08804), [Cangea et al., 2018](https://arxiv.org/abs/1811.01287), [Knyazev et al., 2018](https://arxiv.org/abs/1811.09595), [Bianchi et al., 2019](https://arxiv.org/abs/1901.01343), [Liao et al., 2019](https://arxiv.org/abs/1901.01484), [Gao et al., 2019](https://openreview.net/forum?id=HJePRoAct7)).


In this tutorial, we use a simple graph classification task. We create a synthetic dataset with `dgl.data.MiniGCDataset`, which contains 8 different types of graphs, with the same number of samples in each class. The task is to decide which of the 8 types below each graph belongs to.

![](https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/dataset_overview.png)



In [None]:
from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
# A dataset with 80 samples, each graph is
# of size [10, 20]
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[0]
fig, ax = plt.subplots()
nx.draw(graph.to_networkx())
ax.set_title('Class: {:d}'.format(label))
plt.show()

### Form a graph mini-batch
To train neural networks more efficiently, we need to **batch** multiple samples together to form a mini-batch. Batching fixed-shaped tensor inputs is quite easy (for example, batching two images of size 28×28 gives a tensor of shape 2×28×28). By contrast, batching graph inputs has two challenges:

- Graphs are sparse.
- Graphs vary in size (number of nodes and edges).

To address this, DGL provides the `dgl.batch()` API. It relies on the trick that a batch of graphs can be viewed as one large graph with many disjoint connected components. Below is an illustration:

![](https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/batch.png)
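The disjoint-union idea can be sketched with plain edge lists. `batch_edge_lists` is a hypothetical helper, not DGL's implementation: it shifts the node IDs of each graph by the number of nodes that come before it and keeps the per-graph sizes, which is the information needed to split the batch apart again (DGL offers `dgl.unbatch` for that).

```python
def batch_edge_lists(graphs):
    """Hypothetical sketch of batching as a disjoint union.

    Each graph is (num_nodes, [(src, dst), ...]) with local node IDs 0..n-1.
    Returns the total node count, the relabeled edges, and the per-graph sizes.
    """
    offset = 0
    edges, nodes_per_graph = [], []
    for n, graph_edges in graphs:
        # shift local IDs so this graph's nodes follow those of earlier graphs
        edges.extend((u + offset, v + offset) for u, v in graph_edges)
        nodes_per_graph.append(n)
        offset += n
    return offset, edges, nodes_per_graph

# A 3-node chain and a 2-node chain become one 5-node graph with two components.
total, edges, sizes = batch_edge_lists([(3, [(0, 1), (1, 2)]), (2, [(0, 1)])])
print(total, edges, sizes)  # 5 [(0, 1), (1, 2), (3, 4)] [3, 2]
```

No edge ever connects two different input graphs, so message passing on the batched graph never mixes information between samples.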

We define the following `collate` function to form a mini-batch from a given list of graph and label pairs.

In [None]:
import dgl

def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

The return type of `dgl.batch()` is still a graph (similar to the fact that a batch of tensors is still a tensor). This means that any code that works for one graph immediately works for a batch of graphs. More importantly, since DGL processes messages on all nodes and edges in parallel, this greatly improves efficiency.

### Graph Classifier
Graph classification proceeds as follows:
![](https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/graph_classifier.png)
From a batch of graphs, we first perform message passing/graph convolution so that nodes can “communicate” with each other. After message passing, we compute a graph representation tensor from the node (and edge) attributes. This step is called either “readout” or “aggregation”. Finally, the graph representations are fed into a classifier $g$ to predict the graph labels.

### Graph Convolution

Our graph convolution operation is basically the same as the one for GCN in Part 5. The only difference is that we use a simpler normalization factor for aggregating messages: $c_{ij}=|N(v_i)|$. Therefore, the update equation becomes:
$$h^{(l+1)}_{v_i}=ReLU(b^{(l)}+\frac{1}{|N(v_i)|}\sum\limits_{v_j\in N(v_i)}h^{(l)}_{v_j}W^{(l)})$$
Replacing the sum with an average balances nodes with different degrees, which gives better performance in this experiment.

Note that the self edges added during dataset initialization allow us to include the original node feature $h^{(l)}_{v_i}$.
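A small plain-Python comparison illustrates the degree-balancing argument. The star graph and the `aggregate` helper are hypothetical: every node carries the same feature, yet summing makes the hub's value grow with its degree, while averaging keeps every node on the same scale.

```python
def aggregate(neighbors, h, how):
    """Aggregate scalar neighbor features with 'sum' (as in Part 5) or 'mean' (as here).
    neighbors[i] lists N(v_i); with self edges, i itself is included."""
    out = []
    for nbrs in neighbors:
        total = sum(h[j] for j in nbrs)
        out.append(total / len(nbrs) if how == "mean" else total)
    return out

# A star: hub 0 linked to leaves 1..4, plus a self edge on every node.
neighbors = [[0, 1, 2, 3, 4]] + [[i, 0] for i in range(1, 5)]
h = [1.0] * 5  # identical features everywhere
print(aggregate(neighbors, h, "sum"))   # [5.0, 2.0, 2.0, 2.0, 2.0]
print(aggregate(neighbors, h, "mean"))  # [1.0, 1.0, 1.0, 1.0, 1.0]
```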


In [None]:
import dgl.function as fn
import torch
import torch.nn as nn


# Sends a message of node feature h.
msg = fn.copy_src(src='h', out='m')

def reduce(nodes):
    """Take an average over all neighbor node features hu and use it to
    overwrite the original node feature."""
    accum = torch.mean(nodes.mailbox['m'], 1)
    return {'h': accum}

class NodeApplyModule(nn.Module):
    """Update the node feature hv with ReLU(Whv+b)."""
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h' : h}

class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        # Initialize the node features with h.
        g.ndata['h'] = feature
        g.update_all(msg, reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')

### Readout and Classification

For this demonstration, we use each node's degree as its initial feature. After two rounds of graph convolution, we perform a graph readout by averaging over all node features for each graph in the batch:

$$h_g=\frac{1}{|V|}\sum\limits_{v\in V}h_v$$

In DGL, `dgl.mean_nodes()` handles this task for a batch of graphs with variable size. We then feed our graph representations into a classifier with one linear layer to obtain pre-softmax logits.
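Conceptually, the readout splits the batched node features back into per-graph segments and averages each one. The sketch below is a plain-Python model of that behavior (the name `mean_readout` and the list layout are our assumptions, not DGL's internals), and it assumes every graph has at least one node.

```python
def mean_readout(node_feats, nodes_per_graph):
    """Model of a per-graph mean readout on a batched graph.

    node_feats: feature rows of all nodes, graph after graph (as in a batch).
    nodes_per_graph: number of nodes in each graph, in batch order.
    Returns one averaged feature row h_g per graph.
    """
    readouts, start = [], 0
    for n in nodes_per_graph:
        segment = node_feats[start:start + n]
        # average each feature column over the nodes of this graph only
        readouts.append([sum(col) / n for col in zip(*segment)])
        start += n
    return readouts

# Two graphs in one batch: 3 nodes, then 2 nodes, each with 2 features.
feats = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0], [2.0, 2.0], [4.0, 0.0]]
print(mean_readout(feats, [3, 2]))  # [[3.0, 2.0], [3.0, 1.0]]
```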


In [None]:
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()

        self.layers = nn.ModuleList([
            GCN(in_dim, hidden_dim, F.relu),
            GCN(hidden_dim, hidden_dim, F.relu)])
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # For undirected graphs, in_degree is the same as
        # out_degree.
        h = g.in_degrees().view(-1, 1).float()
        for conv in self.layers:
            h = conv(g, h)
        g.ndata['h'] = h
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)

### Setup and Training
We create a synthetic dataset of 400 graphs with 10 to 20 nodes each. 320 graphs form the training set and 80 graphs form the test set.


In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader

# Create training and test sets.
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)
# Use PyTorch's DataLoader and the collate function
# defined before.
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
                         collate_fn=collate)

# Create model
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for iter, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

The learning curve of a run is presented below:

In [None]:
plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()

In [None]:
model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))