In [164]:
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from misc_tools.print_latex import print_tex
import torch

paper: https://arxiv.org/pdf/1710.10903.pdf

pytorch source: https://nn.labml.ai/graphs/gat/index.html

* features $\vec{h}_i$ are stored as columns
* stacking the columns $\vec{h}_i$ side by side gives the matrix $H$

$$H=
\begin{bmatrix}
\vec{h}_1 & \vec{h}_2 & \dots
\end{bmatrix}
$$


In [165]:
N_NODES,N_FEATS = 2,2
H = 0.5*torch.arange(N_NODES*N_FEATS, dtype = float).view(N_NODES, N_FEATS).T
print_tex(r'\vec{h}_1 = ', H[:,[0]].numpy(), r'; \vec{h}_2 = ', H[:,[1]].numpy(), '; H = ', H.numpy())

<IPython.core.display.Math object>

* embed $\vec{h}_i$ into $\vec{g}_i$ via the weight matrix $W$: $\vec{g}_i = W\vec{h}_i$, or for all nodes at once $G = WH$

In [166]:
W1 = 2*torch.eye(N_FEATS, dtype=H.dtype)
G = W1 @ H
print_tex('W = ', W1.numpy(),'; G = WH = ', G.numpy())

<IPython.core.display.Math object>
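A quick sanity check of the column convention, as a minimal sketch. The non-diagonal `W` below is a hypothetical choice (not the `W1` used above), picked so that the effect of the embedding is visible. It checks that each column of $G$ is $W\vec{h}_i$ and that the row-vector form $H^T W^T$ gives the same result after transposing.

```python
import torch

N_NODES, N_FEATS = 2, 2
# Same toy features as above: columns of H are the node feature vectors.
H = 0.5 * torch.arange(N_NODES * N_FEATS, dtype=torch.float64).view(N_NODES, N_FEATS).T

# Hypothetical non-diagonal weight matrix, for illustration only.
W = torch.tensor([[1.0, 2.0],
                  [0.0, 1.0]], dtype=torch.float64)

G = W @ H  # embeds every column at once

# Column i of G is W applied to column i of H.
for i in range(N_NODES):
    assert torch.allclose(G[:, i], W @ H[:, i])

# Row-vector layout (one node per row) gives the transposed result.
assert torch.allclose((H.T @ W.T).T, G)
print(G)
```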

* make pairwise concatenation matrix $C$
$$G = 
\begin{bmatrix}
\vec{g}_1 &
\vec{g}_2 & \dots
\end{bmatrix}
$$

$$C = 
\begin{bmatrix}
\vec{g}_1||\vec{g}_1 &  \vec{g}_1||\vec{g}_2 \\
\vec{g}_2||\vec{g}_1 &  \vec{g}_2||\vec{g}_2
\end{bmatrix}
$$
Concatenation does not broadcast, so every pair has to be concatenated explicitly. If we flatten $C$ into a single row of pairs:
$$C_{flat} = 
\begin{bmatrix}
\vec{g}_1||\vec{g}_1 & \vec{g}_1||\vec{g}_2 &
\vec{g}_2||\vec{g}_1 & \vec{g}_2||\vec{g}_2
\end{bmatrix}
=
\begin{bmatrix}
\vec{g}_1 &  \vec{g}_1 &
\vec{g}_2 &  \vec{g}_2
\end{bmatrix}
||
\begin{bmatrix}
\vec{g}_1 &  \vec{g}_2 &
\vec{g}_1 &  \vec{g}_2
\end{bmatrix}
$$

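So it is a column-wise concatenation of two matrices. Note the column ordering in each: the first repeats every column $N$ times in a row, the second repeats the whole matrix $N$ times.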

In [167]:
g1      = torch.repeat_interleave(G, N_NODES, dim = 1)  # g1,g1,g1,g2,g2,g2,g3,...
g2      = torch.tile(G, dims=(1, N_NODES))              # g1,g2,g3,g1,g2,g3,g1,...
C_f     = torch.cat((g1,g2), dim = 0)                   # stack vertically: column k = i*N + j holds g_i || g_j
print_tex(r'\vec{g}_1 = ', G[:,[0]].numpy(), r'; \vec{g}_2 = ', G[:,[1]].numpy(), '; C_{flat} = ',C_f.numpy())

<IPython.core.display.Math object>
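The column ordering of the flattened matrix can be checked directly. The sketch below rebuilds the same toy $G$ (columns $\vec{g}_1 = [0, 1]$, $\vec{g}_2 = [2, 3]$) and asserts that column $k = iN + j$ (0-based $i, j$) holds $\vec{g}_i||\vec{g}_j$. The names `left` and `right` are illustrative only.

```python
import torch

N_NODES, N_FEATS = 2, 2
# Toy embeddings from the cells above: columns are g_1 = [0, 1] and g_2 = [2, 3].
G = torch.tensor([[0.0, 2.0],
                  [1.0, 3.0]], dtype=torch.float64)

left = torch.repeat_interleave(G, N_NODES, dim=1)  # g1, g1, g2, g2
right = torch.tile(G, dims=(1, N_NODES))           # g1, g2, g1, g2
C_flat = torch.cat((left, right), dim=0)           # shape [2F, N^2]

# Column k = i*N + j should be the concatenation g_i || g_j.
for i in range(N_NODES):
    for j in range(N_NODES):
        expected = torch.cat((G[:, i], G[:, j]))
        assert torch.equal(C_flat[:, i * N_NODES + j], expected)
print(C_flat)
```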

* repack into a form where each concatenated feature vector can be dot-multiplied by the attention vector $\vec{a} \in \mathbb{R}^{2F \times 1}$

$\vec{C_{i,j}} = (\vec{g}_i||\vec{g}_j) \in \mathbb{R}^{2F \times 1}$ is a column vector.

$$\vec{a}^T = 
\begin{bmatrix}
a_1 & a_2 & \dots & a_{2F}
\end{bmatrix}
$$

Technically, I could keep the flat matrix $C_{flat} \in \mathbb{R}^{2F \times N^2}$ and get all scores in one product:
$$E = 
\vec{a}^T
\begin{bmatrix}
\vec{C_{1,1}} & \vec{C_{1,2}} & \vec{C_{2,1}} &  \vec{C_{2,2}}
\end{bmatrix}=
\begin{bmatrix}
\vec{a}^T\vec{C_{1,1}} & \vec{a}^T\vec{C_{1,2}} & \vec{a}^T\vec{C_{2,1}} &  \vec{a}^T\vec{C_{2,2}}
\end{bmatrix}=
\begin{bmatrix}
e_{1,1} & e_{1,2} & e_{2,1} & e_{2,2}
\end{bmatrix}
$$

Or I can reshape it to ${[N \times N  \times 2F \times 1]}$, which has the same $N \times N$ layout as the adjacency matrix:
$$
C = 
\begin{bmatrix}
\vec{C_{1,1}} & \vec{C_{1,2}} \\ \vec{C_{2,1}} &  \vec{C_{2,2}}
\end{bmatrix}=
\begin{bmatrix}
\begin{bmatrix}
0 \\ 1 \\ 0 \\ 1
\end{bmatrix}
& 
\begin{bmatrix}
0 \\ 1 \\ 2 \\ 3
\end{bmatrix}
\\
\begin{bmatrix}
2 \\ 3 \\ 0 \\ 1
\end{bmatrix}
& 
\begin{bmatrix}
2 \\ 3 \\ 2 \\ 3
\end{bmatrix}
\end{bmatrix}
$$

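Reshape $\vec{a}$ to ${[1 \times 1  \times 2F \times 1]}$ and transpose the last two dimensions, so that $\vec{a}^T$ has shape ${[1 \times 1  \times 1 \times 2F]}$. The batched product $\vec{a}^T C$ then broadcasts over the leading $N \times N$ dimensions of $C$, whose shape is ${[N \times N  \times 2F \times 1]}$, and yields a result of shape ${[N \times N  \times 1 \times 1]}$.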

In [168]:
# Row k = i*N + j of C_f.T is g_i || g_j, so after the reshape C[i, j] = g_i || g_j
C = C_f.T.reshape(N_NODES, N_NODES, 2*N_FEATS, 1)
print_tex(r'\vec{C}_{1,1} = ', C[0,0].numpy(), r'; \vec{C}_{1,2} = ', C[0,1].numpy(), r'; \vec{C}_{2,1} = ', C[1,0].numpy(), r'; \vec{C}_{2,2} = ', C[1,1].numpy())

<IPython.core.display.Math object>

Broadcasting can be tested with an element-picker (one-hot) vector.

For example, $\vec{a}^T = [1,0,0,0]$ picks the first element: $\vec{a}^T \vec{C_{1,1}} = (\vec{C_{1,1}})_1$

In [169]:
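diag = torch.eye(2*N_FEATS, dtype=C.dtype)  # each row is a one-hot element picker
for i, e in enumerate(diag):
    # [2F] -> [1, 1, 2F, 1], then transpose the last two dims to get a^T with shape [1, 1, 1, 2F]
    a = e.reshape(1, 1, 2*N_FEATS, 1).transpose(3, 2)
    # a @ C broadcasts over the N x N grid; squeeze [N, N, 1, 1] -> [N, N]
    print_tex(rf'element \ {i + 1} \ of \ C_{{i,j}}', (a @ C).squeeze(-1).squeeze(-1).numpy())
print(f'{C.shape = }')
print(f'{a.shape = }')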


<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

C.shape = torch.Size([2, 2, 4, 1])
a.shape = torch.Size([1, 1, 1, 4])
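The two layouts discussed above (flat $[2F \times N^2]$ and grid $[N \times N \times 2F \times 1]$) should give the same scores for any $\vec{a}$, not just one-hot pickers. The sketch below checks this with a hypothetical attention vector; its values are arbitrary and only serve the comparison.

```python
import torch

N_NODES, N_FEATS = 2, 2
# Toy embeddings from the cells above: columns are g_1 = [0, 1] and g_2 = [2, 3].
G = torch.tensor([[0.0, 2.0],
                  [1.0, 3.0]], dtype=torch.float64)

left = torch.repeat_interleave(G, N_NODES, dim=1)
right = torch.tile(G, dims=(1, N_NODES))
C_flat = torch.cat((left, right), dim=0)                  # [2F, N^2]
C = C_flat.T.reshape(N_NODES, N_NODES, 2 * N_FEATS, 1)    # [N, N, 2F, 1]

# Hypothetical attention vector (arbitrary values, for illustration only).
a = torch.tensor([0.5, -1.0, 2.0, 0.25], dtype=torch.float64)

# Option 1: one matrix product on the flat layout, then fold into N x N.
E_flat = (a.reshape(1, -1) @ C_flat).reshape(N_NODES, N_NODES)

# Option 2: batched product, a^T of shape [1, 1, 1, 2F] broadcast over the N x N grid.
E = (a.reshape(1, 1, 1, -1) @ C).squeeze(-1).squeeze(-1)

assert torch.allclose(E_flat, E)
print(E)
```

Both options produce an $N \times N$ matrix whose entry $(i, j)$ is $e_{i,j}$, so it lines up with the adjacency matrix.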


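* Because `torch.nn.Linear` computes $y = xA^T + b$ (https://pytorch.org/docs/stable/generated/torch.nn.Linear.html), it expects the features along the last dimension of its input, i.e. as row vectors.
We might therefore transpose the problem so it is $\vec{C_{i,j}}^T\vec{a}$, with $C$ stored in shape ${[N \times N \times 2F]}$.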

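* apply non-linearity: the paper uses LeakyReLU with negative slope 0.2, so $e_{i,j} = \text{LeakyReLU}(\vec{a}^T \vec{C_{i,j}})$

* apply softmax over the neighbors $j$ of each node $i$ to get the attention coefficients $\alpha_{i,j}$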
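The three steps above can be sketched end to end as follows. This is a minimal illustration under stated assumptions, not the reference implementation: the attention weights `a` and the adjacency matrix `A` are hypothetical, the bias of `torch.nn.Linear` is disabled so the layer reduces to $\vec{C_{i,j}}^T\vec{a}$, and every node is assumed to have at least one neighbor (here a self-loop) so the masked softmax stays finite.

```python
import torch
import torch.nn.functional as F

N_NODES, N_FEATS = 2, 2
# Toy embeddings from the cells above: columns are g_1 = [0, 1] and g_2 = [2, 3].
G = torch.tensor([[0.0, 2.0],
                  [1.0, 3.0]], dtype=torch.float64)

left = torch.repeat_interleave(G, N_NODES, dim=1)
right = torch.tile(G, dims=(1, N_NODES))
C_flat = torch.cat((left, right), dim=0)                   # [2F, N^2]
# Row layout for Linear: features in the last dim, C_rows[i, j] = g_i || g_j.
C_rows = C_flat.T.reshape(N_NODES, N_NODES, 2 * N_FEATS)   # [N, N, 2F]

# Linear(2F -> 1) without bias; its weight has shape [1, 2F] and plays the role of a^T.
attn = torch.nn.Linear(2 * N_FEATS, 1, bias=False, dtype=torch.float64)
a = torch.tensor([0.5, -1.0, 2.0, 0.25], dtype=torch.float64)  # hypothetical values
with torch.no_grad():
    attn.weight.copy_(a.reshape(1, -1))

E = attn(C_rows).squeeze(-1)                 # raw scores e_ij, shape [N, N]
E = F.leaky_relu(E, negative_slope=0.2)      # non-linearity (slope 0.2 as in the paper)

# Hypothetical adjacency with self-loops: A[i, j] is True if j is a neighbor of i.
A = torch.tensor([[True, True],
                  [False, True]])
E = E.masked_fill(~A, float("-inf"))         # non-neighbors get zero weight after softmax
alpha = torch.softmax(E, dim=1)              # normalize over neighbors j for each node i
print(alpha)
```

Each row of `alpha` sums to 1, and entries for non-neighbors are exactly 0.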