In [1]:
import torch
import torch.nn as nn

In [77]:
# a Single set of nodes
q = torch.arange(0, 4).float()
q

tensor([0., 1., 2., 3.])

In [100]:
key = torch.arange(2, 10, 2).float()
key

tensor([2., 4., 6., 8.])

A Directed graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$ contains a set of nodes $\mathcal{V} = \{1, \ldots,n\}$ and a set of edges $\mathcal{E} \subseteq \mathcal{V} \times \mathcal{V}$, where $(j, i) \in \mathcal{E}$ denotes an edge from node $j$ to node $i$. We assume that every node $i \in \mathcal{V}$ has an initial representation $\mathbf{h}_i^{(0)} \in \mathbb{R}^{d_0}$.

A graph nerual network (GNN) layer updates every node representation by aggregating its neighbours' representations. A layer's input is a set of node represnetations $\{\mathbf{h}_i \in \mathbb{R}^{d}| i \in \mathcal{V}$  and the set of edges $\mathcal{E}$. A layer outputs a new set of node representations $\{\mathbf{h}_i' \in \mathbb{R}^{d'}| i \in \mathcal{V}\}$, where the same parametric function is applied to every node given its neighbours $\mathcal{N}_i = \{j \in \mathcal{V} | (j, i) \in \mathcal{E}\}$:
$$\mathbf{h}_i' = f_{\theta}(\mathbf{h}_i), \text{AGGREGATE}(\{\mathbf{h}_j|j\in \mathcal{N}_i\})$$ 

The design of $f$ and $\text{AGGREGATE}$ is mostly what distinguishes one type of GNN from the other.

For example [_GraphSAGE_ (Hamilton et al., 2017)](https://arxiv.org/abs/1706.02216) uses a mean aggregator, followed by concatenation with $\mathbf{h}_i$, a linear layer and a ReLU as $f$, i.e.:

In [35]:
aggregate = (torch.ones_like(q)/q.shape[0]).unsqueeze(0)
mean = torch.matmul(aggregate, q.unsqueeze(1))
input = torch.cat([q[0:1].unsqueeze(-1), mean], dim=0)
W = torch.ones_like(input)
h_1 = nn.functional.relu(torch.matmul(W.T, input))

In [36]:
h_1

tensor([[4.5000]])

GraphSage and many other popular GNN archtictures weight all neigbours equally (e.g. mean or max-pooling as the AGGREGATE function) e.g.

In [37]:
aggregate

tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000]])

To address this limitation, [Graph Attention Networks (GAT)](https://arxiv.org/abs/1710.10903) (Veličković et al., 2017)  computes a weighted average of the representation of $\mathcal{N}_i$. A scoring function $e: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}$ computes a score for every edge $(j, i) \in \mathcal{E}$, which is then normalized by a softmax function over all edges $(j, i) \in \mathcal{E}$, which indicates the importance of the representation of node $j$  to its neighbour $i$:

$$e(\mathbf{h}_i, \mathbf{h}_j) = \text{LeakyReLU}(\mathbf{a}^\top\cdot [\mathbf{Wh}_i\|\mathbf{Wh}_j])$$

In [44]:
attention_weight = torch.ones(4, 1)
W = torch.ones(2, 1)
e_12 = torch.cat([torch.matmul(W, q[1:2].unsqueeze(-1)), torch.matmul(W, key[2:3].unsqueeze(-1))])

torch.Size([4, 1])

We can break this down into the following steps:
Step 1 - Linear Transformation

In [101]:
W = torch.ones(2, 1)
g_query = torch.matmul(W, q.unsqueeze(0))
g_key = torch.matmul(W, key.unsqueeze(0))
g_value = torch.matmul(W, key.unsqueeze(0))

In [102]:
g_query.shape

torch.Size([2, 4])

In [103]:
query_batched = g_query.T.unsqueeze(0)

In [104]:
query_batched.shape

torch.Size([1, 4, 2])

In [105]:
key_batched = g_key.T.unsqueeze(0)
key_batched.shape

torch.Size([1, 4, 2])

In [106]:
repeated_q = query_batched.repeat(1, 4, 1)

In [107]:
repeated_q.shape

torch.Size([1, 16, 2])

In [108]:
repeated_q

tensor([[[0., 0.],
         [1., 1.],
         [2., 2.],
         [3., 3.],
         [0., 0.],
         [1., 1.],
         [2., 2.],
         [3., 3.],
         [0., 0.],
         [1., 1.],
         [2., 2.],
         [3., 3.],
         [0., 0.],
         [1., 1.],
         [2., 2.],
         [3., 3.]]])

In [109]:
repeated_key = key_batched.repeat_interleave(4, dim=1)

In [110]:
repeated_key.shape

torch.Size([1, 16, 2])

In [111]:
repeated_key

tensor([[[2., 2.],
         [2., 2.],
         [2., 2.],
         [2., 2.],
         [4., 4.],
         [4., 4.],
         [4., 4.],
         [4., 4.],
         [6., 6.],
         [6., 6.],
         [6., 6.],
         [6., 6.],
         [8., 8.],
         [8., 8.],
         [8., 8.],
         [8., 8.]]])

In [112]:
g_pre_sum = repeated_q + repeated_key

In [113]:
g_pre_sum

tensor([[[ 2.,  2.],
         [ 3.,  3.],
         [ 4.,  4.],
         [ 5.,  5.],
         [ 4.,  4.],
         [ 5.,  5.],
         [ 6.,  6.],
         [ 7.,  7.],
         [ 6.,  6.],
         [ 7.,  7.],
         [ 8.,  8.],
         [ 9.,  9.],
         [ 8.,  8.],
         [ 9.,  9.],
         [10., 10.],
         [11., 11.]]])

In [114]:
g_pre_sum.view(1, 4, 4, 2)

tensor([[[[ 2.,  2.],
          [ 3.,  3.],
          [ 4.,  4.],
          [ 5.,  5.]],

         [[ 4.,  4.],
          [ 5.,  5.],
          [ 6.,  6.],
          [ 7.,  7.]],

         [[ 6.,  6.],
          [ 7.,  7.],
          [ 8.,  8.],
          [ 9.,  9.]],

         [[ 8.,  8.],
          [ 9.,  9.],
          [10., 10.],
          [11., 11.]]]])

In [115]:
g_sum = g_pre_sum.view(1, 4, 4, 2)

In [116]:
g_sum.shape

torch.Size([1, 4, 4, 2])

In [117]:
g_sum[0, :, 0, 0]

tensor([2., 4., 6., 8.])

In [118]:
g_sum[0, :, 1, 0]

tensor([3., 5., 7., 9.])

In [119]:
adjacency = g_sum[0, :, :, 0]

In [130]:
adjacency[0]

tensor([2., 3., 4., 5.])

In [None]:
mask = torch.ones_like(adjacency)


In [124]:
e = g_sum.sum(dim=-1, keepdim=True)

In [127]:
g_sum

tensor([[[[ 2.,  2.],
          [ 3.,  3.],
          [ 4.,  4.],
          [ 5.,  5.]],

         [[ 4.,  4.],
          [ 5.,  5.],
          [ 6.,  6.],
          [ 7.,  7.]],

         [[ 6.,  6.],
          [ 7.,  7.],
          [ 8.,  8.],
          [ 9.,  9.]],

         [[ 8.,  8.],
          [ 9.,  9.],
          [10., 10.],
          [11., 11.]]]])

In [128]:
e = e.squeeze(-1)

In [129]:
e.shape

torch.Size([1, 4, 4])

In [131]:
e

tensor([[[ 4.,  6.,  8., 10.],
         [ 8., 10., 12., 14.],
         [12., 14., 16., 18.],
         [16., 18., 20., 22.]]])

In [132]:
e.shape

torch.Size([1, 4, 4])

In [133]:
e = e.unsqueeze(-1)

In [134]:
e.shape

torch.Size([1, 4, 4, 1])

In [136]:
mask = torch.ones_like(e)
mask[0, 1, 0, 0] = 0
mask[0, 2, 0, 0] = 0
mask[0, 3, 0, 0] = 0
mask[0, 2, 1, 0] = 0
mask[0, 3, 1, 0] = 0
mask[0, 3, 2, 0] = 0
mask[0, :, :, 0]

tensor([[1., 1., 1., 1.],
        [0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.]])

In [142]:
attn_res = torch.einsum('bij,bjf->bjf', mask.squeeze(-1), e.squeeze(-1))

In [143]:
attn_res.shape

torch.Size([1, 4, 4])

In [144]:
attn_res

tensor([[[ 4.,  6.,  8., 10.],
         [16., 20., 24., 28.],
         [36., 42., 48., 54.],
         [64., 72., 80., 88.]]])

In [154]:
g = torch.arange(0, 16).float().reshape(4, 4, 1)
e = torch.arange(0, 16).float().reshape(4, 4, 1)


In [150]:
e = e.masked_fill(mask.squeeze(-1) == 0, float('-inf'))

In [151]:
e

tensor([[[ 0.,  0.,  0.,  0.],
         [-inf,  1.,  1.,  1.],
         [-inf, -inf,  2.,  2.],
         [-inf, -inf, -inf,  3.]],

        [[ 4.,  4.,  4.,  4.],
         [-inf,  5.,  5.,  5.],
         [-inf, -inf,  6.,  6.],
         [-inf, -inf, -inf,  7.]],

        [[ 8.,  8.,  8.,  8.],
         [-inf,  9.,  9.,  9.],
         [-inf, -inf, 10., 10.],
         [-inf, -inf, -inf, 11.]],

        [[12., 12., 12., 12.],
         [-inf, 13., 13., 13.],
         [-inf, -inf, 14., 14.],
         [-inf, -inf, -inf, 15.]]])

In [152]:
softmax = nn.Softmax(dim=1)

In [153]:
softmax(e)

tensor([[[1.0000, 0.2689, 0.0900, 0.0321],
         [0.0000, 0.7311, 0.2447, 0.0871],
         [0.0000, 0.0000, 0.6652, 0.2369],
         [0.0000, 0.0000, 0.0000, 0.6439]],

        [[1.0000, 0.2689, 0.0900, 0.0321],
         [0.0000, 0.7311, 0.2447, 0.0871],
         [0.0000, 0.0000, 0.6652, 0.2369],
         [0.0000, 0.0000, 0.0000, 0.6439]],

        [[1.0000, 0.2689, 0.0900, 0.0321],
         [0.0000, 0.7311, 0.2447, 0.0871],
         [0.0000, 0.0000, 0.6652, 0.2369],
         [0.0000, 0.0000, 0.0000, 0.6439]],

        [[1.0000, 0.2689, 0.0900, 0.0321],
         [0.0000, 0.7311, 0.2447, 0.0871],
         [0.0000, 0.0000, 0.6652, 0.2369],
         [0.0000, 0.0000, 0.0000, 0.6439]]])

In [155]:
attn_res = torch.einsum('ijh,jhf->ihf', e, g)

In [156]:
attn_res.shape

torch.Size([4, 4, 1])

In [158]:
attn_res[:, :, 0]

tensor([[ 56.,  62.,  68.,  74.],
        [152., 174., 196., 218.],
        [248., 286., 324., 362.],
        [344., 398., 452., 506.]])

In [159]:
g

tensor([[[ 0.],
         [ 1.],
         [ 2.],
         [ 3.]],

        [[ 4.],
         [ 5.],
         [ 6.],
         [ 7.]],

        [[ 8.],
         [ 9.],
         [10.],
         [11.]],

        [[12.],
         [13.],
         [14.],
         [15.]]])

In [161]:
g_value.shape

torch.Size([2, 4])

In [162]:
g_value = g_value.T

In [163]:
g_value.shape

torch.Size([4, 2])

In [164]:
g_value = g_value.unsqueeze(1)

In [165]:
e.shape

torch.Size([4, 4, 1])

In [166]:
g_value.shape

torch.Size([4, 1, 2])

In [193]:
mask = mask.transpose(0, 1)
mask[:, :, 0]

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [194]:
e = torch.arange(0, 16).float().reshape(4, 4, 1)
e = e.masked_fill(mask == 0, float('-inf'))
e = softmax(e)
e[:, :, 0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.2689, 0.7311, 0.0000, 0.0000],
        [0.0900, 0.2447, 0.6652, 0.0000],
        [0.0321, 0.0871, 0.2369, 0.6439]])

In [195]:

attn_res = torch.einsum('ijh,jhf->ihf', e, g_value)

In [196]:
attn_res.shape

torch.Size([4, 1, 2])

In [197]:
g_value

tensor([[[2., 2.]],

        [[4., 4.]],

        [[6., 6.]],

        [[8., 8.]]])

In [198]:
attn_res

tensor([[[2.0000, 2.0000]],

        [[3.4621, 3.4621]],

        [[5.1504, 5.1504]],

        [[6.9853, 6.9853]]])