## `dgl.nn`: the official DGL implementations of famous GNNs

`dgl.nn` is the DGL package to check first when you start a GNN project. It provides highly optimized GNN layers that are ready to use for general purposes. Let's start with DGL's implementation of GCN.

```python
import dgl
import torch
import dgl.nn.pytorch.conv as dglconv
```

```python
u, v = torch.tensor([0, 0, 0, 1]), torch.tensor([1, 2, 3, 3])
g = dgl.graph((u, v), num_nodes=8)  # 4 edges: 0->1, 0->2, 0->3, 1->3; nodes 4-7 have no edges
g = dgl.add_self_loop(g)            # one self-loop per node, so no node has zero in-degree

node_feat_dim = 32  # the node feature dim
edge_feat_dim = 3   # the edge feature dim

g.ndata['feat'] = torch.randn(g.num_nodes(), node_feat_dim)
g.edata['feat'] = torch.randn(g.num_edges(), edge_feat_dim)
```
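To see why this graph ends up with 12 edges, and why the self-loops matter, here is a plain-Python mirror of the edge list above that counts in-degrees before and after adding self-loops. It does not use DGL; `add_self_loops` and `in_degrees` are hypothetical helper names for illustration only.

```python
# Plain-Python mirror of the toy graph above (no DGL required).
src = [0, 0, 0, 1]
dst = [1, 2, 3, 3]
num_nodes = 8


def add_self_loops(src, dst, num_nodes):
    """Hypothetical helper: append one edge (i, i) for every node i."""
    loops = list(range(num_nodes))
    return src + loops, dst + loops


def in_degrees(dst, num_nodes):
    """Count incoming edges for each node."""
    deg = [0] * num_nodes
    for v in dst:
        deg[v] += 1
    return deg


print(in_degrees(dst, num_nodes))     # [0, 1, 1, 2, 0, 0, 0, 0]
src_sl, dst_sl = add_self_loops(src, dst, num_nodes)
print(len(src_sl))                    # 12 = 4 original edges + 8 self-loops
print(in_degrees(dst_sl, num_nodes))  # [1, 2, 2, 3, 1, 1, 1, 1]
```

Without self-loops, nodes 0 and 4 to 7 would receive no messages at all. In DGL 0.5 and later, `GraphConv` raises an error on such zero in-degree nodes unless it is built with `allow_zero_in_degree=True`, so adding self-loops is the usual fix; it also lets each node keep its own feature in the aggregation.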

```python
gc_out_dim = 256

gc = dglconv.GraphConv(in_feats=node_feat_dim,
                       out_feats=gc_out_dim)
```

```python
%%timeit
h_updated = gc(g, g.ndata['feat'])
```

```
2.33 ms ± 980 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```


Is that all? Yes, that is all! Super simple. 

## So what happens under the hood of `dglconv.GraphConv`?

`dglconv.GraphConv` takes care of several implementation details of GCN that matter a lot in practice:

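1. Computing the degree normalization on the fly. With the default `norm='both'`, a message sent along edge $u \to v$ is scaled by $1/\sqrt{d_{\text{out}}(u)\, d_{\text{in}}(v)}$, i.e. the layer applies the symmetrically normalized adjacency $D_{\text{in}}^{-1/2} A D_{\text{out}}^{-1/2}$ computed from node degrees, without materializing a normalized Laplacian. Self-loops are not added automatically, which is why we called `dgl.add_self_loop` above.
2. Adaptive computation of $AXW$
   > It compares the input and output feature dimensions and orders the matrix products to keep the arithmetic cost small: if the input dimension is larger than the output dimension, it computes $A(XW)$ (project first, then aggregate); otherwise it computes $(AX)W$ (aggregate first, then project).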
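Both points can be made concrete with a small plain-Python sketch. It assumes DGL's default `norm='both'` scaling and counts only multiply-adds for a sparse $A$; the helper names are made up for illustration, and the real layer works on sparse tensors rather than Python lists.

```python
import math


def gcn_edge_weights(src, dst, num_nodes):
    """Per-edge scale 1 / sqrt(out_deg(u) * in_deg(v)) for each edge u -> v,
    i.e. the nonzero entries of D_in^{-1/2} A D_out^{-1/2} (assumes the
    degrees already include any self-loops)."""
    out_deg = [0] * num_nodes
    in_deg = [0] * num_nodes
    for u, v in zip(src, dst):
        out_deg[u] += 1
        in_deg[v] += 1
    return [1.0 / math.sqrt(out_deg[u] * in_deg[v]) for u, v in zip(src, dst)]


def product_costs(num_nodes, num_edges, in_dim, out_dim):
    """Rough multiply-add counts for both orders of A X W with a sparse A.
    The dense projection costs num_nodes * in_dim * out_dim either way;
    the sparse aggregation costs num_edges * (width of the features it sees)."""
    projection = num_nodes * in_dim * out_dim
    aggregate_first = num_edges * in_dim + projection  # (A X) W
    project_first = projection + num_edges * out_dim   # A (X W)
    return aggregate_first, project_first


def choose_order(in_dim, out_dim):
    """The rule described above: project first only when it shrinks the features."""
    return "A(XW)" if in_dim > out_dim else "(AX)W"


# Toy graph from above, including its 8 self-loops.
src = [0, 0, 0, 1] + list(range(8))
dst = [1, 2, 3, 3] + list(range(8))
weights = gcn_edge_weights(src, dst, num_nodes=8)
print(round(weights[0], 4))                 # 0.3536: edge 0 -> 1, out_deg(0)=4, in_deg(1)=2
print(product_costs(8, len(src), 32, 256))  # (65920, 68608)
print(choose_order(32, 256))                # (AX)W
```

For our 32 to 256 layer, aggregating first is cheaper because message passing then runs on 32-dimensional features instead of 256-dimensional ones.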

## Which famous GNN layers are implemented?

`dgl.nn` contains implementations of many well-known GNNs. Among them, you may be happy to find the following:
1. Graph convolutional networks (GCN): `GraphConv`
2. Graph attention networks (GAT): `GATConv`
3. GraphSAGE: `SAGEConv`
4. Graph isomorphism networks (GIN): `GINConv`
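Roughly speaking, these layers differ in how a node combines the features of its neighbors. The sketch below contrasts the aggregation steps of GIN (a sum plus a weighted self term) and mean-pooling GraphSAGE as described in their papers, using one scalar feature per node and leaving out the learnable weights and MLPs. It is an illustration only, not how `GINConv` or `SAGEConv` are written.

```python
def in_neighbors(src, dst, v):
    """Source nodes of all edges that point into v."""
    return [u for u, w in zip(src, dst) if w == v]


def gin_aggregate(h, src, dst, v, eps=0.0):
    """GIN: (1 + eps) * h_v + sum of in-neighbor features (then fed to an MLP)."""
    return (1 + eps) * h[v] + sum(h[u] for u in in_neighbors(src, dst, v))


def sage_mean_aggregate(h, src, dst, v):
    """GraphSAGE (mean): the node's own feature paired with the mean of its
    in-neighbors; a linear layer then maps the concatenated pair."""
    nbrs = in_neighbors(src, dst, v)
    mean = sum(h[u] for u in nbrs) / len(nbrs) if nbrs else 0.0
    return (h[v], mean)


# Toy graph without self-loops; one scalar feature per node.
src, dst = [0, 0, 0, 1], [1, 2, 3, 3]
h = [1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]
print(gin_aggregate(h, src, dst, 3))        # 7.0 = 4.0 + (1.0 + 2.0)
print(sage_mean_aggregate(h, src, dst, 3))  # (4.0, 1.5)
```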

In this tutorial, let's try out `GATConv`.

## Computing attention-weighted node features with `GATConv`

```python
gat = dglconv.GATConv(in_feats=node_feat_dim,
                      out_feats=gc_out_dim,
                      num_heads=4)
```

```python
h_updated = gat(g, g.ndata['feat'])
```

```python
h_updated.shape  # [num. nodes x num. attention heads x out dim]
```

```
torch.Size([8, 4, 256])
```
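What each attention head computes can be sketched in plain Python, following the attention mechanism of the GAT paper: score every edge, normalize the scores with a softmax over each node's incoming edges, then take the weighted sum of neighbor features. The attention vectors `a_src`/`a_dst` and the tiny 2-dimensional projected features are made-up values; the real layer learns them and runs all `num_heads` heads in parallel, which is where the middle dimension of size 4 above comes from.

```python
import math


def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def gat_head(z, src, dst, a_src, a_dst, num_nodes):
    """One attention head over already-projected features z (z = W h).
    Edge u -> v gets score e = LeakyReLU(a_src . z_u + a_dst . z_v); scores are
    softmax-normalized over the incoming edges of each destination v."""
    scores = [leaky_relu(dot(a_src, z[u]) + dot(a_dst, z[v])) for u, v in zip(src, dst)]
    alpha = [0.0] * len(scores)
    for v in set(dst):
        idx = [k for k, d in enumerate(dst) if d == v]
        m = max(scores[k] for k in idx)  # subtract the max for numerical stability
        exps = [math.exp(scores[k] - m) for k in idx]
        total = sum(exps)
        for k, e in zip(idx, exps):
            alpha[k] = e / total
    # Attention-weighted sum of source features into each destination node.
    out = [[0.0] * len(z[0]) for _ in range(num_nodes)]
    for (u, v), a in zip(zip(src, dst), alpha):
        for k in range(len(z[u])):
            out[v][k] += a * z[u][k]
    return alpha, out


# Toy graph with self-loops and made-up 2-dim projected features.
src = [0, 0, 0, 1] + list(range(8))
dst = [1, 2, 3, 3] + list(range(8))
z = [[float(i), 1.0] for i in range(8)]
alpha, out = gat_head(z, src, dst, a_src=[0.0, 0.0], a_dst=[0.0, 0.0], num_nodes=8)
# Node 3 receives edges from 0, 1 and itself; zero attention vectors give uniform weights.
print([round(alpha[k], 3) for k in (2, 3, 7)])  # [0.333, 0.333, 0.333]
```

To feed the `[8, 4, 256]` output into the next layer, the heads are commonly concatenated with `h_updated.flatten(1)` (giving `[8, 1024]`) or averaged with `h_updated.mean(1)` (giving `[8, 256]`).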

## Batched graph computations

Mini-batch training is common practice for training neural networks efficiently, and the same holds for GNNs. However, batched computation over graphs can be less intuitive than batched computation over ordinary tensors.

For tensor inputs, batched forward (and backward) propagation simply reserves the first dimension of the input for the batch, e.g., $X \in \mathbb{R}^{b \times p}$, where $b$ is the mini-batch size and $p$ is the input feature dimension.

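How can we batch graphs and compute their features in a batched fashion? The idea is simple: build a block-diagonal matrix $A_{\text{batch}} = \operatorname{diag}(A_1, \dots, A_b)$ whose diagonal blocks are the adjacency matrices of the individual graphs. Because the off-diagonal blocks are zero, no messages pass between graphs, and the batch behaves like one large graph made of disconnected components.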
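Here is a small plain-Python sketch of that block-diagonal construction, with two hypothetical graphs and dense lists standing in for sparse adjacency matrices. It checks the key property: one aggregation step on the batched matrix gives the same result as running each graph separately and stacking the outputs.

```python
def block_diag(*mats):
    """Place square adjacency matrices on the diagonal of one big zero matrix."""
    n = sum(len(m) for m in mats)
    out = [[0] * n for _ in range(n)]
    offset = 0
    for m in mats:
        k = len(m)
        for i in range(k):
            for j in range(k):
                out[offset + i][offset + j] = m[i][j]
        offset += k
    return out


def matvec(A, x):
    """Aggregate one scalar feature per node: (A x)_i = sum_j A[i][j] * x[j]."""
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]


A1 = [[0, 1],
      [0, 0]]      # hypothetical graph 1: 2 nodes
A2 = [[0, 1, 0],
      [0, 0, 1],
      [0, 0, 0]]   # hypothetical graph 2: 3 nodes
A = block_diag(A1, A2)
x1, x2 = [1, 2], [3, 4, 5]

# One product on the batch equals the per-graph products stacked together.
print(matvec(A, x1 + x2))                # [2, 0, 4, 5, 0]
print(matvec(A1, x1) + matvec(A2, x2))   # [2, 0, 4, 5, 0]
```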

Implementing batched graphs manually is painful: you have to shift all node and edge indices, keep track of which block belongs to which graph, and so on.
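To see what that bookkeeping involves, here is a rough plain-Python sketch of batching graphs stored as edge lists. The helper `batch_edge_lists` is hypothetical and only illustrates the idea; `dgl.batch` does this for you and also concatenates the node and edge features. Applied to two copies of the toy graph, it gives the same node and edge counts that DGL reports for `batched_g` below.

```python
def batch_edge_lists(graphs):
    """Hypothetical helper. graphs: list of (src, dst, num_nodes) tuples.
    Returns the merged edge list, with each graph's node ids shifted by the
    number of nodes before it, plus the bookkeeping needed to undo the merge."""
    src_all, dst_all, graph_id = [], [], []
    num_nodes, num_edges = [], []
    offset = 0
    for gid, (src, dst, n) in enumerate(graphs):
        src_all += [u + offset for u in src]
        dst_all += [v + offset for v in dst]
        graph_id += [gid] * n  # which graph each node came from
        num_nodes.append(n)
        num_edges.append(len(src))
        offset += n
    return src_all, dst_all, graph_id, num_nodes, num_edges


# The toy graph from above (with self-loops), batched with itself.
toy = ([0, 0, 0, 1] + list(range(8)), [1, 2, 3, 3] + list(range(8)), 8)
src_b, dst_b, graph_id, num_nodes, num_edges = batch_edge_lists([toy, toy])
print(len(graph_id), len(src_b))  # 16 24
print(num_nodes, num_edges)       # [8, 8] [12, 12]
print(src_b[12], dst_b[12])       # 8 9: edge 0 -> 1 of the second copy
```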

## Instead, why don't we use `dgl.batch`?

```python
batched_g = dgl.batch([g, g])
```

## Check the graph statistics of `batched_g`

```python
print("Number of graphs in the batched graphs : {} \n".format(batched_g.batch_size))

print("Total number of nodes : {}".format(batched_g.num_nodes()))
print("Total number of edges : {} \n".format(batched_g.num_edges()))
n_nodes = [i.item() for i in batched_g.batch_num_nodes()]
n_edges = [i.item() for i in batched_g.batch_num_edges()]

print("Per graph number of nodes : {}".format(n_nodes))
print("Per graph number of edges : {} \n".format(n_edges))
```

```
Number of graphs in the batched graphs : 2

Total number of nodes : 16
Total number of edges : 24

Per graph number of nodes : [8, 8]
Per graph number of edges : [12, 12]
```
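These per-graph counts are all you need to map rows of a batched feature tensor back to their graphs: graph $i$ owns the rows starting at the sum of the node counts of graphs $0, \dots, i-1$. DGL's `dgl.unbatch` handles this for you; the sketch below (plain Python, hypothetical helper name) only shows the index arithmetic.

```python
from itertools import accumulate


def split_by_counts(rows, counts):
    """Split stacked per-node rows into one chunk per graph, using the
    per-graph node counts (prefix sums give each graph's first row)."""
    starts = [0] + list(accumulate(counts))[:-1]
    return [rows[s:s + c] for s, c in zip(starts, counts)]


n_nodes = [8, 8]        # per-graph node counts, as printed above
rows = list(range(16))  # stand-in for the 16 stacked node-feature rows
first, second = split_by_counts(rows, n_nodes)
print(first)   # [0, 1, 2, 3, 4, 5, 6, 7]
print(second)  # [8, 9, 10, 11, 12, 13, 14, 15]
```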



Note that in DGL (version 0.4 and later), a batched graph and a single graph are instances of the same graph class. Therefore, the methods we used on the single graph `g` also work on the batched graph `batched_g`.

```python
type(g), type(batched_g)
```

```
(dgl.heterograph.DGLHeteroGraph, dgl.heterograph.DGLHeteroGraph)
```

## Computing with batched graphs

```python
h_updated = gc(batched_g, batched_g.ndata['feat'])
print(h_updated.shape)  # the first dimension is now doubled: 2 graphs x 8 nodes
```

```
torch.Size([16, 256])
```
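The output still has one row per node. For graph-level tasks such as graph classification, node features are usually pooled into one vector per graph; DGL offers readout functions such as `dgl.mean_nodes` for this. The plain-Python sketch below shows what a mean readout does with the per-graph node counts, using made-up feature values and a hypothetical helper name.

```python
def mean_readout(rows, counts):
    """Average the node-feature rows of each graph in a batch.
    rows: stacked per-node feature vectors; counts: nodes per graph."""
    out, start = [], 0
    for c in counts:
        chunk = rows[start:start + c]
        out.append([sum(col) / c for col in zip(*chunk)])
        start += c
    return out


# Two hypothetical graphs with 2 and 3 nodes, 2-dim node features each.
rows = [[1.0, 0.0], [3.0, 2.0],              # graph 0
        [0.0, 3.0], [3.0, 3.0], [6.0, 0.0]]  # graph 1
print(mean_readout(rows, [2, 3]))  # [[2.0, 1.0], [3.0, 2.0]]
```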
