In [1]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

import torch#, time
import torch.nn as nn
device = torch.device("cpu")# torch.device("cuda" if torch.cuda.is_available() else "cpu")
from sklearn.preprocessing import OneHotEncoder
import matplotlib.animation as animation
from IPython.display import HTML
from misc_tools.print_latex import print_tex

input example : 
>>> arr_T = np.array([[r'\vec{v}_1', r'\vec{v}_2']]).T
>>> print_tex(arr_T,'=', np.arange(1,5).reshape(2,-1)/4, r'; symbols: \otimes, \cdot,\times')
output: 


<IPython.core.display.Math object>

For the derivation, see GNN_Attention_notes.ipynb.
I'm not sure whether a single attention vector $\vec{a}$ should be shared by all heads. The original paper hints at that, and the implementation I found reuses one vector.


In [2]:
# Seed every RNG the notebook relies on: numpy, and torch (used by nn.Linear weight init)
np.random.seed(1337)
torch.manual_seed(1337)
G_house = nx.house_graph()
A = nx.adjacency_matrix(G_house).todense()
print_tex('A = ', A)

<IPython.core.display.Math object>

In [3]:
# Problem sizes: the node count comes from the graph, the rest are free choices.
N_NODES = G_house.number_of_nodes()
N_FEATURES = 2    # input features per node
N_HIDDEN = 2      # embedding size produced per attention head
N_HEADS = 2       # number of attention heads
OUT_FEATURES = 2  # not used so far

In [4]:
# Toy node features 0, 0.5, 1.0, ... laid out row by row, shape (N_NODES, N_FEATURES), float64
H = torch.arange(N_NODES*N_FEATURES, dtype=torch.float64).reshape(N_NODES, N_FEATURES) * 0.5
print_tex('H = ', H.numpy())

<IPython.core.display.Math object>

In [5]:
class net(nn.Module):
    """Single-layer, multi-head graph attention (GAT-style) scaffold.

    Node features ``H0`` of shape (n_nodes, n_features) are projected by one
    linear map into ``n_heads`` blocks of ``n_hidden`` features each. For every
    ordered node pair (i, j) and every head, the two projected embeddings are
    concatenated, scored by a single attention vector shared by all heads,
    passed through LeakyReLU(0.2), masked with the adjacency matrix and
    normalised with a row-wise softmax.

    Intermediate results live in preallocated tensors, and some of them are
    views of others (``GkR`` of ``Gk``, ``Ck`` of ``Ck_f``). They are therefore
    always refilled in place with ``copy_``. Rebinding an attribute
    (``self.Gk = ...``) would leave its view pointing at the old storage.

    Args:
        H0: node feature matrix, shape (n_nodes, n_features).
        A: adjacency matrix, shape (n_nodes, n_nodes); zero entries mark non-edges.
        test: if True, run :meth:`debug` right after construction.
        n_heads: number of attention heads; defaults to the notebook-level ``N_HEADS``.
        n_hidden: embedding size per head; defaults to the notebook-level ``N_HIDDEN``.
    """

    def __init__(self, H0, A, test=False, n_heads=None, n_hidden=None):
        super().__init__()
        self.H = H0
        self.A = A
        self.n_nodes, self.n_features = H0.shape
        self.n_heads = N_HEADS if n_heads is None else n_heads
        self.n_hidden = N_HIDDEN if n_hidden is None else n_hidden
        n, k, d, dtype = self.n_nodes, self.n_heads, self.n_hidden, H0.dtype

        self.W_gh   = nn.Linear(in_features=self.n_features, out_features=k*d, bias=False, dtype=dtype)
        # projected embeddings: flat (n, k*d) and its per-head (n, k, d) view
        self.Gk     = torch.zeros(size=(n, k*d), dtype=dtype)
        self.GkR    = self.Gk.view(n, k, d)
        # left / right member of every ordered pair (i, j), flattened to row i*n + j
        self.Ck_l   = torch.zeros(size=(n*n, k, d), dtype=dtype)
        self.Ck_r   = torch.zeros_like(self.Ck_l)
        self.Ck_f   = torch.zeros(size=(n*n, k, 2*d), dtype=dtype)
        self.Ck     = self.Ck_f.view(n, n, k, 2*d)
        # one attention vector, shared by all heads
        self.attnt  = nn.Linear(2*d, 1, bias=False, dtype=dtype)
        # GAT uses LeakyReLU with negative slope 0.2 (nn.ReLU(0.2) would mean inplace=0.2)
        self.activ  = nn.LeakyReLU(0.2)
        self.E      = torch.zeros(size=(n, n, k), dtype=dtype)
        self.alpha  = torch.zeros_like(self.E)
        self.softmax= nn.Softmax(dim=1)

        if test:
            self.debug()

    def forward(self):
        """Compute the attention coefficients for the stored graph.

        Refills ``Gk``/``GkR``, ``Ck_l``, ``Ck_r``, ``Ck_f``/``Ck`` and ``E`` in
        place and stores the result in ``alpha``.

        Returns:
            Tensor of shape (n_nodes, n_nodes, n_heads). ``alpha[i, :, k]`` is
            the softmax over node i's neighbours for head k; non-edges get
            weight 0. A row of ``A`` with no non-zero entry would give NaN.
        """
        n = self.n_nodes
        self.Gk.copy_(self.W_gh(self.H))
        self.Ck_l.copy_(self.GkR.repeat_interleave(n, dim=0))  # row i*n + j holds node i
        self.Ck_r.copy_(self.GkR.repeat(n, 1, 1))               # row i*n + j holds node j
        self.Ck_f.copy_(torch.cat([self.Ck_l, self.Ck_r], dim=-1))
        # (n, n, k, 2d) -> attention -> (n, n, k, 1) -> squeeze -> (n, n, k)
        self.E.copy_(self.activ(self.attnt(self.Ck)).squeeze(-1))
        # -inf on non-edges so the softmax gives them zero weight
        non_edge = (self.A == 0).unsqueeze(-1)
        self.alpha = self.softmax(self.E.masked_fill(non_edge, float('-inf')))
        return self.alpha

    def debug(self):
        """Run one pass with hand-picked weights and render every intermediate step.

        This overwrites the learnable weights. Head i (0-based) of ``W_gh``
        becomes ``2*(i+1) * eye(n_hidden, n_features)`` and the attention
        vector is filled with ones, so each printed value can be checked by
        hand. Requires ``print_tex``.
        """
        with torch.no_grad():
            for i in range(self.n_heads):
                rows = slice(i*self.n_hidden, (i + 1)*self.n_hidden)
                self.W_gh.weight[rows] = 2*(i + 1)*torch.eye(self.n_hidden, self.n_features)
            # all-ones attention vector: a . [g_i || g_j] becomes a plain element sum
            self.attnt.weight.fill_(1.0)
            self.forward()

            print('Create matrix transformed embeddings (concatenated N_HEADS (K) blocks of shape (N_nodes,N_HIDDEN):')
            print_tex('G_K = H W_{K=1}^T || H W_{K=2}^T = G_{K=1} || G_{K=2}')
            print_tex('G_K = H W_K^T = ', self.H.numpy(), self.W_gh.weight[:].T.numpy(), ' = ', self.Gk.numpy())
            print('Reshape G_K form (N_NODES, N_HEADS*N_HIDDEN) to (N_NODES, N_HEADS, N_HIDDEN)')
            print("Prepare C_K: stacked N_HEADS attention matrices of shape (N_NODES, N_NODES, N_HEADS, 2*N_HIDDEN)")
            print('Each entry i,j of C_K holds N_HEADS of concatenated feature pair which via attention mechanism will determine weights if edge(i,j).')
            print('Concatenation is not broadcasted. Create it from flattened features that have proper ordering. shape (N_NODES*N_NODES, N_HEADS, 2*N_HIDDEN)')
            # Ck_l was built with repeat_interleave, Ck_r with repeat
            print_tex(r'Interleave_{flat} \ (K=1): ', self.Ck_l[:, 0].numpy(),
                      r'Repeat_{flat} \ (K=1): ',     self.Ck_r[:, 0].numpy(),
                      r'C_{flat} \ (K=1): ',          self.Ck_f[:, 0].numpy())

            print_tex(r'E = \sigma(\vec{a}[C_K])')
            print('To make attention vector interpretable for debug, fill it with ones. Dot = element sum.')
            print_tex(r'Attention \ vector = ', self.attnt.weight[:].numpy())
            print('Apply attention and activation to C_K ')
            print('(N_NODES, N_NODES, N_HEADS, 2*N_HIDDEN) . (2*N_HIDDEN, 1) -> (N_NODES, N_NODES, N_HEADS, 1) -> squeeze last')
            for i in range(self.n_heads):
                print_tex('E_{K='+str(i + 1)+'} = ', self.E.numpy()[:, :, i], r'\rightarrow MASK \rightarrow ', (self.E[:, :, i]*self.A).numpy())

            print('Apply row-wise softmax:')
            for i in range(self.n_heads):
                print_tex(r'\alpha_{K='+str(i + 1)+'} = ', self.alpha.numpy()[:, :, i])

# Build the model on the house graph and render the step-by-step debug walk-through
model = net(H, torch.tensor(A), test=True)



Create matrix transformed embeddings (concatenated N_HEADS (K) blocks of shape (N_nodes,N_HIDDEN):


<IPython.core.display.Math object>

<IPython.core.display.Math object>

Reshape G_K form (N_NODES, N_HEADS*N_FEATURES) to (N_NODES, N_HEADS, N_FEATURES)
Prepare C_K: stacked N_HEADS attention matrices of shape (N_NODES, N_NODES, N_HEADS, 2*N_FEATURES)
Each entry i,j of C_K holds N_HEADS of concatenated feature pair which via attention mechanism will determine weights if edge(i,j).
Concatenation is not broadcasted. Create it from flattened features that have proper ordering. shape (N_NODES*N_NODES, N_HEADS, 2*N_FEATURES)


<IPython.core.display.Math object>

<IPython.core.display.Math object>

To make attention vector interpretable for debug, fill it with ones. Dot = element sum.


<IPython.core.display.Math object>

Apply attention and activation to C_K 
(N_NODES, N_NODES, N_HEADS, 2*N_FEATURES) . (2*N_FEATURES, 1) -> (N_NODES, N_NODES, N_HEADS, 2*N_FEATURES, 1) -> squeeze last


<IPython.core.display.Math object>

<IPython.core.display.Math object>

Apply row-wise softmax:


<IPython.core.display.Math object>

<IPython.core.display.Math object>

## Some things I've learned
* You can define a tensor together with a reshaped view of it. As long as you modify the original in place instead of reassigning it, the reshaped view changes too (duh).

In [6]:
asd = torch.arange(2, 5, 1)
print(asd)
# reshaping a contiguous tensor gives a view that shares asd's storage
asd2 = asd.reshape(-1, 1)
print(asd2)
# in-place add: the change is visible through asd2 as well
asd += torch.arange(3, 6, 1)
print(asd)
print(asd2)



tensor([2, 3, 4])
tensor([[2],
        [3],
        [4]])
tensor([5, 7, 9])
tensor([[5],
        [7],
        [9]])


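* Masking (`masked_fill`) does not create a view: it returns a new tensor, so later changes to the original are not reflected in it.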

In [7]:
asd = torch.arange(2, 5, 1)
print(asd)
mask = torch.tensor([1, 0, 0], dtype=bool)
print(mask)
# masked_fill is out-of-place: asd2 gets its own storage
asd2 = asd.masked_fill(mask=mask, value=0)
print(asd2)
asd[2] = 8  # modify the original ...
print(asd)
print(asd2)  # ... and asd2 is unaffected

tensor([2, 3, 4])
tensor([ True, False, False])
tensor([0, 3, 4])
tensor([2, 3, 8])
tensor([0, 3, 4])


In [8]:
asd = torch.tensor([[1, 2, 3], [4, 5, 6]])  # not used below
# a batch of two (2, 3) matrices -> shape (2, 2, 3)
asd2 = torch.stack([
    torch.tensor([[1, 2, 3], [4, 5, 6]]),
    torch.tensor([[7, 8, 9], [10, 11, 12]]),
])

# column r of `a` is the vector to dot with row r of every batch matrix
a = torch.tensor([[0, 0, 1], [0, 1, 0]]).T
print(asd2.shape, a.shape)
# (2, 2, 3) @ (3, 2) -> (2, 2, 2); the diagonal keeps only (row r) . (column r)
asd2.matmul(a).diagonal(dim1=-2, dim2=-1)


torch.Size([2, 2, 3]) torch.Size([3, 2])


tensor([[ 3,  5],
        [ 9, 11]])