# GAT (Graph Attention Network) dummy flowchart notebook
This is a small companion notebook to GNN_Attention_notes.ipynb, which presents a more compact version of 'pseudo solved' case.

Full implementation does many operations in-place and adds features.

In [2]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
device = torch.device("cpu")# torch.device("cuda" if torch.cuda.is_available() else "cpu")

from misc_tools.print_latex import print_tex

input example : 
>>> arr_T = np.array([[r'\vec{v}_1', r'\vec{v}_2']]).T
>>> print_tex(arr_T,'=', np.arange(1,5).reshape(2,-1)/4, r'; symbols: \otimes, \cdot,\times')
output: 


<IPython.core.display.Math object>

In [3]:
np.random.seed(1337)
N = 3
G = nx.gnm_random_graph(N, 2*N)
A = nx.adjacency_matrix(G).todense()
print_tex('A = ', A)
A = torch.tensor(A)

<IPython.core.display.Math object>

In [4]:
#N_FEATURES, N_NODES = A.shape
N_NODES = len(G.nodes())
N_FEATURES = 3
N_HIDDEN = 2
N_HEADS = 2
OUT_FEATURES = 2

In [5]:
H = torch.arange(N_NODES*N_FEATURES, dtype = float).view(N_NODES, N_FEATURES) + 1
print_tex('H = ', H.numpy())

<IPython.core.display.Math object>

In [6]:
class debug_net(nn.Module):
    def __init__(self, H0, A, test = False):
        super(debug_net, self).__init__()
        self.H = H0
        self.A = A
        self.W_gh   = nn.Linear(in_features=N_FEATURES, out_features=N_HEADS*N_HIDDEN, bias=False, dtype=H0.dtype)
        self.Gk     = torch.zeros(size=(N_NODES, N_HEADS*N_HIDDEN), dtype=H0.dtype)
        self.GkR    = self.Gk.view(N_NODES, N_HEADS, N_FEATURES)
        self.Ck_l   = torch.zeros(size=(N_NODES*N_NODES, N_HEADS, N_HIDDEN), dtype=H0.dtype) 
        self.Ck_r   = torch.zeros_like(self.Ck_l)
        self.Ck_f   = torch.zeros(size=(N_NODES*N_NODES, N_HEADS, 2*N_HIDDEN), dtype=H0.dtype) 
        self.Ck     = self.Ck_f.view(N_NODES, N_NODES, N_HEADS, 2*N_HIDDEN)
        self.attnt  = nn.Parameter(torch.zeros(size=(2*N_HIDDEN, N_HEADS), dtype=H0.dtype))
        self.activ  = nn.LeakyReLU(0.2)
        self.E      = torch.zeros(size=(N_NODES,N_NODES, N_HEADS), dtype=H0.dtype)
        self.alpha  = torch.zeros_like(self.E)
        self.softmax= nn.Softmax(dim = 1)
        self.GkPrime= torch.zeros_like(self.GkR)
         
        if test:
            self.debug()

    def debug(self):
        with torch.no_grad():
            print_tex(r'N_{nodes} = '+ str(N_NODES) + r'; \ N_{heads} = '+ str(N_HEADS) + r'; \ N_{features} = '+ str(N_FEATURES)+ r'; \ N_{hidden} = '+ str(N_HIDDEN))
            print_tex(r'G_{K} \text{ is a matrix of concatenated embeddings } \vec{g}_i^{k} , \ shape : [N_{nodes} \times N_{heads}*N_{hidden}]')

            # set scaling transforms
            for i in range(N_HEADS):
                s = 1 if i == 0 else 4*i

                self.W_gh.weight[i*N_FEATURES:(i+1)*N_FEATURES] = s*torch.eye(N_FEATURES)

            self.Gk += self.W_gh(self.H)        # cannot redefine, it will break a view

            print_tex('G_K = H W_K^T = ', H.numpy() , self.W_gh.weight[:].T.numpy(),' = ', self.Gk.numpy())
            print_tex(r'\text{Reshape } G_{K} \ to \ [N_{nodes} \times N_{heads} \times N_{hidden}] \text{ to isolate each head`s data to its own dimension}')

            print_tex(r"\text{Goal: a matrix } C_K \text{ that holds concatenated node feature pairs. Shape: }[N_{nodes} \times N_{nodes}\times N_{heads} \times 2 N_{hidden}]")
            print("Its only (?) possible with flattening, concatenating and unflattening. See notes.")

            self.Ck_l += self.GkR.repeat_interleave(N_NODES, dim=0)
            self.Ck_r += self.GkR.repeat(N_NODES, 1, 1)
            self.Ck_f += torch.cat([self.Ck_l, self.Ck_r], dim=-1);
            
            print_tex(r'C_{flat} \ (K=1) = Repeat_{flat} \ ||  \ Interleave_{flat} = '
                      ,self.Ck_l[:,[0]].squeeze(1).numpy(),r' \ \bigg|\bigg| \ ', self.Ck_r[:,[0]].squeeze(1).numpy(), ' = ',
                      self.Ck_f[:,[0]].squeeze(1).numpy())
            
            print_tex(r'\text{Features C}_{0,0} = ', self.Ck[0,0].numpy(), r'; \ shape: \ [N_{heads} \times 2 N_{hidden}]')
            prnt_vec = [r'\vec{a}_'+str(i)+ ' = ' for i in range(N_HEADS)]
            prnt_vec2 = [r'^T ; \ ' for i in range(N_HEADS)]

            self.attnt += np.repeat(1/(torch.arange(N_HEADS, dtype=self.H.dtype).unsqueeze(0)+1), repeats=2*N_HIDDEN, axis = 0)
            print_tex(r"\text{Goal: a matrix E that holds edge weights. Shape: }[N_{nodes} \times N_{nodes} \times N_{heads}]")
            print_tex(r'E = \sigma(\vec{a}[C_K])')
            print('Test attention vectors:')
            print_tex(*[l for lists in zip(prnt_vec,self.attnt.T.numpy(),prnt_vec2) for l in lists])
            print('>>>See how to apply multiple attention vectors to data in notes<<<')
            
            self.E += self.activ(torch.einsum('ijkf,fk -> ijk', self.Ck, self.attnt)).squeeze(-1)#

            print_tex(r'\text{Features E}_{0,0} = ', self.E[0,0].numpy())
            self.alpha += self.E.masked_fill(self.A.view(N_NODES,N_NODES,1) == 0, float('-inf'))
            for i in range(N_HEADS):
                print_tex('E_{K='+str(i + 1)+'} = ',self.E.numpy()[:,:,i], r'\rightarrow MASK \rightarrow ',(self.E[:,:,i]*self.A).numpy() )
            print_tex(r"\text{Goal: a matrix } \Alpha \ or \ \alpha \text{ with row-wise softmax normalized weights. Shape: }[N_{nodes} \times N_{nodes} \times N_{heads}]")
            self.alpha = self.softmax(self.alpha)
            
            for i in range(N_HEADS):
                print_tex(r'\alpha_{K='+str(i + 1)+'} = ', self.alpha.numpy()[:,:,i], r'{\rightarrow set \ to \ A \ for \ example \rightarrow }:', self.A.numpy())
                self.alpha.numpy()[:,:,i] = self.A
            self.GkPrime += torch.einsum('ijk,jkf->ikf', self.alpha , self.GkR)

            print_tex(r"\text{Goal: updated features } G_k^\prime \text{ based on aggregation of features } \vec{g}_i^k \text{ with weights } \Alpha \text{ . Shape: }[N_{nodes} \times N_{heads} \times N_{hidden}]")
            print('>>>See how to aggregate multi head case in notes<<<')
            n1_neighbors_id = torch.argwhere(self.A[0] == 1).flatten().numpy()
            
            a = [r'G_'+str(i)+ ' = ' for i in n1_neighbors_id]
            b = [self.GkR[i].numpy() for i in n1_neighbors_id]
            c = [r'; \ ' for i in n1_neighbors_id]
            print_tex(*[l for lists in zip(a,b,c) for l in lists], r'\Alpha|_{row,1}= ', self.alpha[0].numpy() )
            print_tex(r'G_0^\prime = \vec{A}_0 \otimes G = ', self.GkPrime[0].numpy())
            print('New embeddings can be either concatenated across different variants of k or averaged"')
            GkP_concat = self.GkPrime.reshape(N_NODES, N_HEADS * N_HIDDEN)
            GkP_avg  = self.GkPrime.mean(dim=1)
            print_tex(r'G_0^{concat} = ', GkP_concat[0].numpy(), r'; \ G_0^{Avg} = ',GkP_avg[0].numpy())


model = debug_net(H,A,True)

RuntimeError: shape '[3, 2, 3]' is invalid for input of size 12