This week you will implement toy versions of GATConv and SAGEConv by following the library documentation, once in PyG and once in DGL. Doing so should give you a better understanding of the tensor-centric programming model of PyG and the graph-centric programming model of DGL.

## DGL

In [4]:
import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

### GraphConv
Mathematically it is defined as follows:

$$
  h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ji}}h_j^{(l)}W^{(l)})
$$
where $\mathcal{N}(i)$ is the set of neighbors of node $i$, 
$c_{ji}$ is the product of the square root of node degrees
$(i.e.,  c_{ji} = \sqrt{|\mathcal{N}(j)|}\sqrt{|\mathcal{N}(i)|})$,
and $\sigma$ is an activation function.

If a weight tensor on each edge is provided, the weighted graph convolution is defined as:

$$
  h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{e_{ji}}{c_{ji}}h_j^{(l)}W^{(l)})
$$
where $e_{ji}$ is the scalar weight on the edge from node $j$ to node $i$.
This is NOT equivalent to the weighted graph convolutional network formulation in the paper.

To customize the normalization term $c_{ji}$, one can first set `norm='none'` for
the model and pass the pre-normalized $e_{ji}$ to the forward computation. DGL provides
`dgl.nn.pytorch.EdgeWeightNorm` to normalize scalar edge weights following the GCN paper.

In [26]:
class DGL_GraphConv(nn.Module):
    """Toy graph convolution with symmetric degree normalization (DGL).

    For every destination node ``i`` the layer computes::

        relu(b + sum_j  e_ji / c_ji * h_j @ W)

    where the sum runs over the source nodes ``j`` of the edges entering
    ``i``, ``c_ji = sqrt(out_degree(j)) * sqrt(in_degree(i))`` and ``e_ji`` is
    an optional scalar edge weight.

    Args:
        in_channels: Size of each input node feature.
        out_channels: Size of each output node feature.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.W = nn.Parameter(torch.rand(in_channels, out_channels))
        self.b = nn.Parameter(torch.rand(out_channels))
        self.activate = nn.ReLU()

    def forward(self, g, h):
        """Apply the convolution.

        Args:
            g: DGL graph. If ``g.edata`` contains ``"edge_weight"`` the
                messages are scaled by it; otherwise every edge counts as 1.
            h: Node feature tensor with ``in_channels`` columns, one row per
                node of ``g``.

        Returns:
            Tensor with ``out_channels`` columns, one row per node of ``g``.
        """
        with g.local_scope():
            # "Both-sided" normalization: the source side is scaled by the
            # out-degree, the destination side by the in-degree.
            # Degrees are integer tensors; cast them to the dtype/device of
            # ``h`` before taking the negative fractional power.
            # Clamping to 1 avoids dividing by zero for isolated nodes.
            norm_src = g.out_degrees().to(h).clamp(min=1).view(-1, 1)
            norm_src = torch.pow(norm_src, -0.5)

            feat_src = torch.matmul(norm_src * h, self.W)
            g.srcdata["h"] = feat_src

            # Edge weights are optional: fall back to plain copying so the
            # layer also works on graphs without an "edge_weight" feature.
            if "edge_weight" in g.edata:
                message = fn.u_mul_e("h", "edge_weight", "he")
            else:
                message = fn.copy_u("h", "he")
            g.update_all(message, fn.sum("he", "rst"))
            rst = g.dstdata["rst"]

            norm_dst = g.in_degrees().to(h).clamp(min=1).view(-1, 1)
            norm_dst = torch.pow(norm_dst, -0.5)
            rst = rst * norm_dst

            rst = rst + self.b
            return self.activate(rst)

In [27]:
# Toy directed graph: 6 edges given as parallel source/destination id lists.
src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
# All-ones node features: 5 nodes, 8 features each.
h = torch.ones((5, 8))

g = dgl.graph((src, dst))
g.ndata["h"] = h
edge_weight = torch.ones(g.num_edges())  # assign a (unit) weight to every edge
g.edata["edge_weight"] = edge_weight
# Bare expression: the notebook displays the graph summary.
g

Graph(num_nodes=5, num_edges=6,
      ndata_schemes={'h': Scheme(shape=(8,), dtype=torch.float32)}
      edata_schemes={'edge_weight': Scheme(shape=(), dtype=torch.float32)})

In [32]:
# Input width is read from the feature matrix; the output width is kept equal
# to it here.
in_channels = h.shape[1]
out_channels = in_channels  # *2
dgl_graphConv = DGL_GraphConv(in_channels, out_channels)

In [33]:
# Run the layer on the toy graph; the bare expression displays the result.
dgl_graphConv(g, h)

tensor([[2.1296, 4.5994, 3.6825, 3.6156, 2.8758, 3.3465, 2.9255, 4.2599],
        [0.2342, 0.9638, 0.6446, 0.8826, 0.4221, 0.4741, 0.7989, 0.4236],
        [3.4698, 7.1701, 5.8306, 5.5481, 4.6109, 5.3777, 4.4293, 6.9727],
        [3.4698, 7.1701, 5.8306, 5.5481, 4.6109, 5.3777, 4.4293, 6.9727],
        [2.1296, 4.5994, 3.6825, 3.6156, 2.8758, 3.3465, 2.9255, 4.2599]],
       grad_fn=<ReluBackward0>)

### GATConv
Graph attention layer from Graph Attention Network
$$h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)}$$

where $\alpha_{ij}$ is the attention score between node $i$ and
node $j$:

$$
\begin{align}\begin{aligned}\alpha_{ij}^{l} &= \mathrm{softmax_i} (e_{ij}^{l})\\e_{ij}^{l} &= \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right)\end{aligned}\end{align}
$$


In [34]:
class DGL_GATConv(nn.Module):
    """Toy single-head graph attention layer (DGL).

    For every destination node ``i`` the layer computes::

        h_i' = sum_j alpha_ij * (h_j @ W)
        alpha_ij = softmax over the edges entering i of
                   LeakyReLU(a^T [h_i @ W || h_j @ W])

    Args:
        in_channels: Size of each input node feature.
        out_channels: Size of each output node feature.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.W = nn.Parameter(torch.rand(in_channels, out_channels))
        # The attention vector is applied to two concatenated *projected*
        # feature vectors, so its length is 2 * out_channels.
        self.a = nn.Parameter(torch.rand(2 * out_channels, 1))
        # Keeps nn.LeakyReLU's default negative slope.
        self.leakyrelu = nn.LeakyReLU()

    def forward(self, g, h):
        """Apply the attention layer.

        Args:
            g: DGL graph.
            h: Node feature tensor with ``in_channels`` columns, one row per
                node of ``g``.

        Returns:
            Tensor with ``out_channels`` columns, one row per node of ``g``.
            Nodes without incoming edges aggregate nothing; add self-loops to
            the graph if every node should attend to itself.
        """
        # Imported locally so the block does not depend on an extra
        # file-level import.
        from dgl.nn.functional import edge_softmax

        # Following the reference implementation: the paper concatenates
        # [Wh_i || Wh_j] per edge and then applies a, which would ship the
        # wider concatenated tensor along every edge. Instead split a into
        # [a_l || a_r], so that
        #   a^T [Wh_i || Wh_j] = a_l^T Wh_i + a_r^T Wh_j
        # and only one scalar per node has to be sent to the edges.
        out_channels = self.W.shape[1]
        a_l = self.a[:out_channels]  # half paired with the destination i
        a_r = self.a[out_channels:]  # half paired with the source j

        with g.local_scope():
            wh = torch.matmul(h, self.W)
            g.srcdata["wh"] = wh
            g.srcdata["score_src"] = torch.matmul(wh, a_r)
            g.dstdata["score_dst"] = torch.matmul(wh, a_l)

            # Un-normalized attention logit on every edge j -> i.
            g.apply_edges(fn.u_add_v("score_src", "score_dst", "e"))
            logits = self.leakyrelu(g.edata.pop("e"))

            # Normalize over the incoming edges of each destination node.
            g.edata["alpha"] = edge_softmax(g, logits)

            g.update_all(fn.u_mul_e("wh", "alpha", "m"), fn.sum("m", "rst"))
            return g.dstdata["rst"]

In [131]:
# Rebuild the toy graph (6 directed edges), this time without edge weights.
src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
# All-ones node features: 5 nodes, 8 features each.
h = torch.ones((5, 8))

g = dgl.graph((src, dst))
g.ndata["h"] = h
# Bare expression: the notebook displays the graph summary.
g

Graph(num_nodes=5, num_edges=6,
      ndata_schemes={'h': Scheme(shape=(8,), dtype=torch.float32)}
      edata_schemes={})

In [134]:
# Scratch cell: step through the GAT projection by hand with all-ones
# weights.
in_channels=8
out_channels=4
h_input=g.ndata['h']
W=torch.ones(in_channels,out_channels)
# NOTE(review): a has 2*in_channels rows, but the score formula applies a to
# the projected features (width out_channels) -- confirm the intended size.
a=torch.ones(2*in_channels,1)

# local_scope keeps any graph features set inside from persisting on g.
with g.local_scope():
    hW=h_input@W    
    

In [135]:
# Node count for the small dense example built in the following cells.
N=3

In [174]:
# Three nodes with four features each; the digit patterns make it easy to see
# which row ended up where after repeat/view operations.
h = torch.tensor(
    [
        [1.0, 11.0, 111.0, 1111.0],
        [2.0, 22.0, 222.0, 2222.0],
        [3.0, 33.0, 333.0, 3333.0],
    ],
    dtype=torch.float32,
)
# All-ones attention vector: one entry per column of two concatenated rows.
a = torch.ones((2 * 4, 1), dtype=torch.float32)

In [175]:
# Left half of every node pair: tile each row N times along the columns, then
# reshape so each original row appears N times consecutively (N*N rows).
hl=h.repeat(1,N).view(N*N,-1)
hl

tensor([[1.0000e+00, 1.1000e+01, 1.1100e+02, 1.1110e+03],
        [1.0000e+00, 1.1000e+01, 1.1100e+02, 1.1110e+03],
        [1.0000e+00, 1.1000e+01, 1.1100e+02, 1.1110e+03],
        [2.0000e+00, 2.2000e+01, 2.2200e+02, 2.2220e+03],
        [2.0000e+00, 2.2000e+01, 2.2200e+02, 2.2220e+03],
        [2.0000e+00, 2.2000e+01, 2.2200e+02, 2.2220e+03],
        [3.0000e+00, 3.3000e+01, 3.3300e+02, 3.3330e+03],
        [3.0000e+00, 3.3000e+01, 3.3300e+02, 3.3330e+03],
        [3.0000e+00, 3.3000e+01, 3.3300e+02, 3.3330e+03]])

In [176]:
# Right half of every node pair: stack the whole matrix N times, so the rows
# cycle through the nodes (N*N rows).
hr=h.repeat(N,1)
hr

tensor([[1.0000e+00, 1.1000e+01, 1.1100e+02, 1.1110e+03],
        [2.0000e+00, 2.2000e+01, 2.2200e+02, 2.2220e+03],
        [3.0000e+00, 3.3000e+01, 3.3300e+02, 3.3330e+03],
        [1.0000e+00, 1.1000e+01, 1.1100e+02, 1.1110e+03],
        [2.0000e+00, 2.2000e+01, 2.2200e+02, 2.2220e+03],
        [3.0000e+00, 3.3000e+01, 3.3300e+02, 3.3330e+03],
        [1.0000e+00, 1.1000e+01, 1.1100e+02, 1.1110e+03],
        [2.0000e+00, 2.2000e+01, 2.2200e+02, 2.2220e+03],
        [3.0000e+00, 3.3000e+01, 3.3300e+02, 3.3330e+03]])

In [190]:
# Row k of h_cat is the concatenation of one row of hl and one row of hr.
h_cat=torch.concat([hl,hr],dim=1)
# NOTE(review): in_channels is whatever an earlier cell set; the view only
# works when it equals the concatenated width (2 * feature size of h).
h_cat.view(N,-1,in_channels).shape


torch.Size([3, 3, 8])

In [181]:
# Un-normalized score for every row of h_cat.
e=F.leaky_relu(h_cat@a)
# NOTE(review): dim=0 normalizes over all rows of e jointly; the GAT formula
# normalizes per node over its neighbours -- confirm this is intended.
alpha=F.softmax(e,dim=0)
alpha

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.]])

In [184]:
# Shape check: one raw score per row of h_cat.
(h_cat@a).shape

torch.Size([9, 1])

In [75]:
# Case 1: Homogeneous graph
# Reference run of DGL's built-in GATConv, for comparison with the toy layer.
from dgl.nn import GATConv
g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
# Self-loops give every node at least one incoming edge.
g = dgl.add_self_loop(g)
feat = torch.ones(6, 10)
gatconv = GATConv(10, 2, num_heads=1)
res = gatconv(g, feat)
# Bare expression: the notebook displays the output tensor.
res

tensor([[[ 1.4764, -2.1203]],

        [[ 1.4764, -2.1203]],

        [[ 1.4764, -2.1203]],

        [[ 1.4764, -2.1203]],

        [[ 1.4764, -2.1203]],

        [[ 1.4764, -2.1203]]], grad_fn=<AddBackward0>)

### SAGEConv
$$
\begin{align}\begin{aligned}h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
\left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)\\h_{i}^{(l+1)} &= \sigma \left(W \cdot \mathrm{concat}
(h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) \right)\\h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{(l+1)})\end{aligned}\end{align}
$$
If a weight tensor on each edge is provided, the aggregation becomes:
$$
h_{\mathcal{N}(i)}^{(l+1)} = \mathrm{aggregate}
\left(\{e_{ji} h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)
$$
where $e_{ji}$ is the scalar weight on the edge from node $j$ to node $i$.
Please make sure that $e_{ji}$ is broadcastable with $h_j^{l}$.





In [None]:
class DGL_SAGEConv(nn.Module):
    """Toy GraphSAGE layer with a mean aggregator (DGL).

    For every destination node ``i`` the layer computes::

        h_N(i) = mean over incoming edges j -> i of (e_ji * h_j)
        h_i'   = l2_normalize(relu(concat(h_i, h_N(i)) @ W))

    Args:
        in_channel: Size of each input node feature.
        out_channel: Size of each output node feature.
    """

    def __init__(self, in_channel, out_channel):
        # Without this call nn.Module is not initialized and the layer cannot
        # be called or register parameters.
        super().__init__()
        # One weight matrix for the concatenation [h_i || h_N(i)].
        self.W = nn.Parameter(torch.rand(2 * in_channel, out_channel))
        self.activate = nn.ReLU()

    def forward(self, g, h, edge_weight=None):
        """Apply the layer.

        Args:
            g: DGL graph.
            h: Node feature tensor with ``in_channel`` columns, one row per
                node of ``g``.
            edge_weight: Optional per-edge weights, broadcastable with the
                node features. When omitted, messages are not scaled.

        Returns:
            Tensor with ``out_channel`` columns, one row per node of ``g``,
            each row scaled to unit L2 norm (all-zero rows stay zero).
        """
        with g.local_scope():
            g.srcdata["h"] = h
            if edge_weight is None:
                message = fn.copy_u("h", "m")
            else:
                g.edata["_edge_weight"] = edge_weight
                message = fn.u_mul_e("h", "_edge_weight", "m")

            g.update_all(message, fn.mean("m", "h_neigh"))
            h_neigh = g.dstdata["h_neigh"]

            combined = torch.cat([h, h_neigh], dim=1)
            rst = self.activate(torch.matmul(combined, self.W))
            return F.normalize(rst, p=2, dim=1)

If you want to check your answer, you can run the following code.

In [None]:
# --- PyG check ---------------------------------------------------------
# edge_index holds the source ids in row 0 and the destination ids in row 1.
# NOTE(review): PyG_GATConv and PyG_SAGEConv are not defined in this part of
# the notebook -- they must be defined before this cell is run.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 4], [2, 0, 2, 3, 4, 3]])
x = torch.ones((5, 8))
conv = PyG_GATConv(8, 4)
output = conv(x, edge_index)
print(output)
conv = PyG_SAGEConv(8, 4)
output = conv(x, edge_index)
print(output)

# --- DGL check ---------------------------------------------------------
# Same edges and features as above, expressed as a DGL graph.
src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
h = torch.ones((5, 8))
g = dgl.graph((src, dst))
conv = DGL_GATConv(8, 4)
output = conv(g, h)
print(output)
conv = DGL_SAGEConv(8, 4)
output = conv(g, h)
print(output)