# Semi-supervised Community Detection using Graph Neural Networks

Almost every introductory programming class starts with a "Hello World" example. Just as deep learning has MNIST, the graph domain has Zachary's Karate Club problem. The karate club is a social network of 34 members that records pairwise links between members who interact outside the club. The club later split into two communities led by the instructor (node 0) and the club president (node 33). The network is visualized below, with color indicating the community.

<img src='../asset/karat_club.png' align='center' width="400px" height="300px" />

In this tutorial, you will learn how to:

* Formulate the community detection problem as a semi-supervised node classification task.
* Build a two-layer graph neural network from DGL's `HeteroGraphConv` and `GraphConv` modules.
* Train the model and understand the result.

In [None]:
.. _guide-nn-heterograph:

3.3 Heterogeneous GraphConv Module
------------------------------------

:ref:`(Chinese version) <guide_cn-nn-heterograph>`

:class:`~dgl.nn.pytorch.HeteroGraphConv`
is a module-level encapsulation for running DGL NN modules on heterogeneous
graphs. Its implementation logic is the same as that of the message-passing-level API
:meth:`~dgl.DGLGraph.multi_update_all`, and consists of:

-  A DGL NN module applied within each relation :math:`r`.
-  A reduction that merges the results on the same node type from multiple
   relations.

This can be formulated as:

.. math::  h_{dst}^{(l+1)} = \underset{r\in\mathcal{R}, r_{dst}=dst}{AGG} (f_r(g_r, h_{r_{src}}^l, h_{r_{dst}}^l))

where :math:`f_r` is the NN module for each relation :math:`r` and
:math:`AGG` is the aggregation function.

HeteroGraphConv implementation logic:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code::

    import torch.nn as nn

    class HeteroGraphConv(nn.Module):
        def __init__(self, mods, aggregate='sum'):
            super(HeteroGraphConv, self).__init__()
            self.mods = nn.ModuleDict(mods)
            if isinstance(aggregate, str):
                # An internal function to get common aggregation functions
                self.agg_fn = get_aggregate_fn(aggregate)
            else:
                self.agg_fn = aggregate

The heterograph convolution takes a dictionary ``mods`` that maps each
relation to an nn module and sets the function that aggregates results on
the same node type from multiple relations.

.. code::

    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
        if mod_args is None:
            mod_args = {}
        if mod_kwargs is None:
            mod_kwargs = {}
        outputs = {nty : [] for nty in g.dsttypes}

Besides the input graph and input tensors, the ``forward()`` function takes
two additional dictionary parameters, ``mod_args`` and ``mod_kwargs``.
These two dictionaries have the same keys as ``self.mods``. They are
used as customized parameters when calling the corresponding NN
modules in ``self.mods`` for different types of relations.

An output dictionary is created to hold the output tensors for each
destination type ``nty``. Note that the value for each ``nty`` is a
list, because a single node type may get multiple outputs if more
than one relation has ``nty`` as its destination type. ``HeteroGraphConv``
performs a further aggregation over these lists.

.. code::

        if g.is_block:
            src_inputs = inputs
            dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
        else:
            src_inputs = dst_inputs = inputs

        for stype, etype, dtype in g.canonical_etypes:
            rel_graph = g[stype, etype, dtype]
            if rel_graph.num_edges() == 0:
                continue
            if stype not in src_inputs or dtype not in dst_inputs:
                continue
            dstdata = self.mods[etype](
                rel_graph,
                (src_inputs[stype], dst_inputs[dtype]),
                *mod_args.get(etype, ()),
                **mod_kwargs.get(etype, {}))
            outputs[dtype].append(dstdata)

The input ``g`` can be a heterogeneous graph or a subgraph block from a
heterogeneous graph. As in an ordinary NN module, the ``forward()``
function needs to handle the different input graph types separately.

Each relation is represented as a ``canonical_etype``, which is
``(stype, etype, dtype)``. Using the ``canonical_etype`` as the key, one can
extract a bipartite graph ``rel_graph``. For a bipartite graph, the
input features are organized as a tuple
``(src_inputs[stype], dst_inputs[dtype])``. The NN module for each
relation is called and its output is saved. To avoid unnecessary calls,
relations with no edges, or whose source or destination node type is
missing from the inputs, are skipped.

.. code::

        rsts = {}
        for nty, alist in outputs.items():
            if len(alist) != 0:
                rsts[nty] = self.agg_fn(alist, nty)
        return rsts

Finally, the results on the same destination node type from multiple
relations are aggregated using ``self.agg_fn`` function. Examples can
be found in the API Doc for :class:`~dgl.nn.pytorch.HeteroGraphConv`.


## Community detection as node classification

The study of community structure in graphs has a long history. Many proposed methods are *unsupervised* (or *self-supervised* by recent definition), where the model predicts community labels from connectivity alone. More recently, [Kipf et al.](https://arxiv.org/abs/1609.02907) proposed formulating community detection as a semi-supervised node classification task. With the help of only a small portion of labeled nodes, a GNN can accurately predict the community labels of the others.

In this tutorial, we apply Kipf's setting to Zachary's Karate Club network to predict community membership, using the labels of only a few nodes.
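
To make the semi-supervised setting concrete, here is a minimal plain-Python sketch of a loss that is averaged over the labeled nodes only. The helper names `log_softmax` and `masked_nll` and the toy numbers are illustrative assumptions, not part of the tutorial code; the training loop later in this tutorial uses PyTorch's `F.log_softmax` and `F.nll_loss` for the same purpose.

```python
import math

def log_softmax(row):
    # Numerically stable log-softmax over one node's logits.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

def masked_nll(logits, labels, labeled_nodes):
    # Average negative log-likelihood over the labeled nodes only.
    losses = [-log_softmax(logits[v])[labels[v]] for v in labeled_nodes]
    return sum(losses) / len(losses)

# Toy example (hypothetical values): 4 nodes, 2 classes, nodes 0 and 2 labeled.
toy_logits = [[2.0, 0.0], [0.5, 0.5], [0.0, 2.0], [1.0, -1.0]]
toy_labels = [0, 0, 1, 1]
toy_loss = masked_nll(toy_logits, toy_labels, labeled_nodes=[0, 2])
print(toy_loss)
```

The labels of nodes 1 and 3 never enter the loss, so changing them leaves `toy_loss` unchanged. Unlabeled nodes still influence training, because their features flow into the labeled nodes' predictions through message passing.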

We first load the graph and node labels, as covered in the [previous session](./1_load_data.ipynb). A helper function for loading the data is provided.

In [None]:
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

from tutorial_utils import load_zachery

# ----------- 0. load graph -------------- #
g = load_zachery()
print(g)

In the original Zachary's Karate Club graph, nodes are feature-less. (The `'Age'` attribute is an artificial one, added mainly for tutorial purposes.) For a feature-less graph, a common practice is to give every node an embedding vector that is updated during training.

We can use PyTorch's `Embedding` module to achieve this.

In [None]:
# ----------- 1. node features -------------- #
node_embed = nn.Embedding(g.number_of_nodes(), 5)  # Every node has an embedding of size 5.
inputs = node_embed.weight                         # Use the embedding weight as the node features.
nn.init.xavier_uniform_(inputs)
print(inputs)
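
Conceptually, the embedding is a table with one trainable row per node. The sketch below builds such a table in plain Python, drawing each entry uniformly from `[-a, a]` with `a = sqrt(6 / (fan_in + fan_out))`, which is the Xavier uniform bound with gain 1. The function name `xavier_uniform_table` and the use of Python's `random` module in place of PyTorch's generator are assumptions made for illustration.

```python
import math
import random

def xavier_uniform_table(num_nodes, dim, seed=0):
    # One row per node; entries drawn from U(-bound, bound).
    rng = random.Random(seed)
    bound = math.sqrt(6.0 / (num_nodes + dim))
    table = [[rng.uniform(-bound, bound) for _ in range(dim)]
             for _ in range(num_nodes)]
    return table, bound

# 34 karate club members, 5 features each, as in the cell above.
table, bound = xavier_uniform_table(34, 5)
print(len(table), len(table[0]), round(bound, 4))
```

For a 34 x 5 table the bound is `sqrt(6 / 39)`, roughly 0.39, so the initial features are small and centered on zero.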

The community label is stored in the `'club'` node feature (0 for instructor, 1 for club president). Only nodes 0 and 33 are labeled.

In [None]:
labels = g.ndata['club']
labeled_nodes = [0, 33]
print('Labels', labels[labeled_nodes])

## Define a HeteroGraphConv model

`HeteroGraphConv` is a module-level encapsulation for running DGL NN modules on heterogeneous graphs. Its implementation logic is the same as that of the message-passing-level API `multi_update_all()`, and consists of:

* A DGL NN module applied within each relation $r$.
* A reduction that merges the results on the same node type from multiple relations.

This can be formulated as:

$$
h_{dst}^{(l+1)} = \underset{r\in\mathcal{R}, r_{dst}=dst}{AGG} (f_r(g_r, h_{r_{src}}^l, h_{r_{dst}}^l))$$

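where $f_r$ is the NN module for relation $r$ and $AGG$ is the aggregation function. See the [DGL user guide](https://docs.dgl.ai/guide/nn-heterograph.html?highlight=heterogenous%20graphs) for details.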
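
The formula can be sketched without DGL as a loop over relations followed by a per-node-type reduction. Everything in this sketch is a simplifying assumption: the relation names are hypothetical, features are single floats rather than tensors, and each toy $f_r$ looks only at the source feature, ignoring the graph structure $g_r$ and the destination feature.

```python
from collections import defaultdict

def hetero_conv(canonical_etypes, rel_fns, feats, agg=sum):
    # Collect one output per relation, grouped by destination node type.
    outputs = defaultdict(list)
    for stype, etype, dtype in canonical_etypes:
        if stype not in feats:
            continue  # skip relations whose source type has no input
        outputs[dtype].append(rel_fns[etype](feats[stype]))
    # Reduce the per-relation outputs for each destination node type.
    return {nty: agg(outs) for nty, outs in outputs.items()}

# Hypothetical schema: two relations end at 'game', one at 'user'.
etypes = [('user', 'follows', 'user'),
          ('user', 'plays', 'game'),
          ('store', 'sells', 'game')]
rel_fns = {'follows': lambda h: h * 2,
           'plays': lambda h: h + 10,
           'sells': lambda h: h * 100}
feats = {'user': 1.0, 'game': 2.0, 'store': 3.0}

result = hetero_conv(etypes, rel_fns, feats)
print(result)
```

Here `'game'` receives one output from `plays` and one from `sells`, which are summed, while `'store'` is absent from the result because no relation has it as a destination type.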


If your graph is heterogeneous, you may want to gather messages from neighbors along all edge types. The module `dgl.nn.pytorch.HeteroGraphConv` (also available for MXNet and TensorFlow) performs message passing on all edge types, applying a separate graph convolution module to each edge type and then combining the results.

The following code defines a model built from heterogeneous graph convolution modules. Each module first performs a separate graph convolution on each edge type, then sums the per-edge-type results as the final result for each node type.

`dgl.nn.HeteroGraphConv` takes a dictionary mapping node types to node feature tensors as input, and returns another dictionary mapping node types to node features.



In [None]:
# ----------- 2. create model -------------- #
# build a two-layer RGCN model
import dgl.nn as dglnn

class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h
    
# Create the model with given dimensions
# input layer dimension: 5, node embeddings
# hidden layer dimension: 16
# output layer dimension: 2, the two classes, 0 and 1
# rel_names: the edge types of the graph, one GraphConv per edge type
net = RGCN(5, 16, 2, g.etypes)

In [None]:
# ----------- 3. set up loss and optimizer -------------- #
# in this case, the loss is computed inside the training loop
optimizer = torch.optim.Adam(itertools.chain(net.parameters(), node_embed.parameters()), lr=0.01)

# ----------- 4. training -------------------------------- #
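all_logits = []
ntype = g.ntypes[0]  # the karate club graph has a single node type
for e in range(100):
    # forward: HeteroGraphConv takes and returns a dict keyed by node type
    logits = net(g, {ntype: inputs})[ntype]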
    
    # compute loss
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[labeled_nodes], labels[labeled_nodes])
    
    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    all_logits.append(logits.detach())
    
    if e % 5 == 0:
        print('In epoch {}, loss: {}'.format(e, loss))

In [None]:
# ----------- 5. check results ------------------------ #
pred = torch.argmax(logits, axis=1)
print('Accuracy', (pred == labels).sum().item() / len(pred))

## Visualize the result

Since the GNN produces a logit vector of size 2 for each node, we can plot each node as a point on a 2-D plane.

<img src='../asset/gnn_ep0.png' align='center' width="400px" height="300px"/>
<img src='../asset/gnn_ep_anime.gif' align='center' width="400px" height="300px"/>
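
To read such a plot, note that a node at position `(x, y)` is predicted as class 0 when `x > y` and as class 1 otherwise, so the diagonal `y = x` acts as the decision boundary. The plain-Python sketch below mirrors the argmax and accuracy computation on a few made-up points; the function names and values are illustrative assumptions, and ties are assigned to class 1 here for simplicity.

```python
def predict(points):
    # Class 0 if the first logit is larger, otherwise class 1.
    return [0 if x > y else 1 for x, y in points]

def accuracy(pred, labels):
    correct = sum(p == t for p, t in zip(pred, labels))
    return correct / len(labels)

# Hypothetical 2-D logits for four nodes and their true communities.
points = [(1.5, -0.5), (-0.2, 0.9), (0.3, 0.1), (0.4, 1.0)]
true_labels = [0, 1, 1, 1]

pred_labels = predict(points)
print(pred_labels, accuracy(pred_labels, true_labels))
```

The third point lies below the diagonal, so it is predicted as class 0 even though its true label is 1. As training progresses, nodes of the two communities should drift to opposite sides of the diagonal.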

Run the following code to visualize the result. This requires ffmpeg.

In [None]:
# A bit of setup, just ignore this cell
import matplotlib.pyplot as plt

# for auto-reloading external modules
%load_ext autoreload
%autoreload 2

%matplotlib inline
plt.rcParams['figure.figsize'] = (4.0, 3.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['animation.html'] = 'html5'

In [None]:
# Visualize the node classification using the logits output. Requires ffmpeg.
import networkx as nx
import numpy as np
import matplotlib.animation as animation
from IPython.display import HTML

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()
nx_G = g.to_networkx()
def draw(i):
    cls1color = '#00FFFF'
    cls2color = '#FF00FF'
    pos = {}
    colors = []
    for v in range(34):
        pred = all_logits[i].numpy()
        pos[v] = pred[v]
        cls = labels[v]
        colors.append(cls1color if cls else cls2color)
    ax.cla()
    ax.axis('off')
    ax.set_title('Epoch: %d' % i)
    nx.draw(nx_G.to_undirected(), pos, node_color=colors, with_labels=True, node_size=200)

ani = animation.FuncAnimation(fig, draw, frames=len(all_logits), interval=200)
HTML(ani.to_html5_video())

## Exercise

Experiment with the model by swapping in other [graph convolution modules](https://docs.dgl.ai/api/python/nn.pytorch.html#module-dgl.nn.pytorch.conv).