In [85]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

import torch#, time
import torch.nn as nn
device = torch.device("cpu")# torch.device("cuda" if torch.cuda.is_available() else "cpu")
from sklearn.preprocessing import OneHotEncoder
import matplotlib.animation as animation
from IPython.display import HTML
from misc_tools.print_latex import print_tex

For the derivation, see `GNN_Attention_notes.ipynb`.
I'm not sure whether to implement a separate attention vector $\vec{a}$ for each head. The original paper hints at that; the implementation I found reuses a single one for all heads.


In [86]:
SEED = 1337
np.random.seed(SEED)
# Alternative small test graph: G = nx.house_graph()
N = 3
# networkx does not draw from NumPy's global RNG (without `seed` it uses Python's `random`
# module), so pass the seed explicitly to make the graph reproducible.
# Note: a simple graph on N = 3 nodes has at most 3 edges; asking for 2*N edges makes
# networkx return the complete graph.
G = nx.gnm_random_graph(N, 2*N, seed=SEED)

A = nx.adjacency_matrix(G).todense()
print_tex('A = ', A)

<IPython.core.display.Math object>

In [87]:
# Problem sizes. The node count comes from the graph; A.shape is (N_NODES, N_NODES),
# so it cannot provide the feature count.
N_NODES = G.number_of_nodes()
N_FEATURES = 2      # input features per node (columns of H)
N_HIDDEN = 2        # features per attention head after W_gh
N_HEADS = 2         # number of attention heads
OUT_FEATURES = 2    # NOTE(review): not used in the visible cells

In [88]:
# Deterministic float64 node features: H[i, f] = 0.5 * (i * N_FEATURES + f).
H = torch.arange(N_NODES * N_FEATURES, dtype=torch.float64).reshape(N_NODES, N_FEATURES) * 0.5
print_tex('H = ', H.numpy())

<IPython.core.display.Math object>

In [94]:
class net(nn.Module):
    """Single-layer, multi-head graph-attention (GAT-style) prototype with a debug walkthrough.

    All working tensors are preallocated in ``__init__``. Several are views of others
    (``GkR`` of ``Gk``, ``Ck`` of ``Ck_f``), so they must be filled in place (``+=``),
    never rebound, or the views go stale.

    Parameters
    ----------
    H0 : torch.Tensor
        Node feature matrix; its two dimensions are taken as (n_nodes, n_features).
    A : torch.Tensor
        Adjacency matrix of shape (n_nodes, n_nodes); zero entries are treated as non-edges.
    test : bool, default False
        If True, run ``debug()`` right after construction.
    n_heads : int, optional
        Number of attention heads; defaults to the notebook global ``N_HEADS``.
    n_hidden : int, optional
        Hidden size per head; defaults to the notebook global ``N_HIDDEN``.
    """

    def __init__(self, H0, A, test = False, n_heads=None, n_hidden=None):
        super().__init__()
        self.H = H0
        self.A = A
        # Node/feature counts come from H0 itself rather than the N_NODES/N_FEATURES globals.
        self.n_nodes, self.n_features = H0.shape
        self.n_heads = N_HEADS if n_heads is None else n_heads
        self.n_hidden = N_HIDDEN if n_hidden is None else n_hidden
        n, f, k, h = self.n_nodes, self.n_features, self.n_heads, self.n_hidden
        dtype = H0.dtype

        # One linear map for all heads; its output columns are k consecutive blocks of size h.
        self.W_gh   = nn.Linear(in_features=f, out_features=k*h, bias=False, dtype=dtype)
        self.Gk     = torch.zeros(size=(n, k*h), dtype=dtype)
        # Per-head view of Gk: (n_nodes, n_heads, n_hidden). The last axis is the hidden
        # size per head, not the input feature count.
        self.GkR    = self.Gk.view(n, k, h)
        # Flattened pair buffers: row p = i*n + j pairs node i (left) with node j (right).
        self.Ck_l   = torch.zeros(size=(n*n, k, h), dtype=dtype)
        self.Ck_r   = torch.zeros_like(self.Ck_l)
        self.Ck_f   = torch.zeros(size=(n*n, k, 2*h), dtype=dtype)
        self.Ck     = self.Ck_f.view(n, n, k, 2*h)
        # One attention vector of length 2*h per head (one column per head).
        self.attnt  = nn.Parameter(torch.zeros(size=(2*h, k), dtype=dtype))
        # GAT scores use LeakyReLU with negative slope 0.2. nn.ReLU(0.2) would instead be
        # interpreted as ReLU(inplace=True).
        self.activ  = nn.LeakyReLU(0.2)
        self.E      = torch.zeros(size=(n, n, k), dtype=dtype)
        self.alpha  = torch.zeros_like(self.E)
        # Normalise over dim 1 (the neighbour index j) for every node i and head.
        self.softmax= nn.Softmax(dim = 1)
        self.GkPrime= torch.zeros_like(self.GkR)

        if test:
            self.debug()

    def debug(self):
        """Step through one attention layer with hand-picked weights, printing every stage.

        Runs under ``torch.no_grad()`` and fills the preallocated buffers in place.
        Head i of ``W_gh`` gets the scaling 2*(i+1), and column i of ``attnt`` is filled
        with i + 1. Then ``Gk``, ``Ck_*``, ``E``, ``alpha`` and ``GkPrime`` are computed.
        For illustration, ``alpha`` is overwritten with ``A`` before aggregation.
        """
        n, f, k, h = self.n_nodes, self.n_features, self.n_heads, self.n_hidden
        with torch.no_grad():
            print('Create matrix transformed embeddings (concatenated N_HEADS (K) blocks of shape (N_nodes,N_HIDDEN):')
            print_tex('G_K = H W_{K=1}^T || H W_{K=2}^T = G_{K=1} || G_{K=2}')

            # Head i owns rows [i*h, (i+1)*h) of W_gh; set them to s * I (rectangular when
            # n_hidden != n_features), with s = 2, 4, ...
            for i, s in enumerate(range(2, 2*k + 2, 2)):
                self.W_gh.weight[i*h:(i+1)*h] = s*torch.eye(h, f)

            self.Gk += self.W_gh(self.H)        # cannot redefine, it will break a view

            print_tex('G_K = H W_K^T = ', self.H.numpy(), self.W_gh.weight[:].T.numpy(), ' = ', self.Gk.numpy())
            print('Reshape G_K form (N_NODES, N_HEADS*N_HIDDEN) to (N_NODES, N_HEADS, N_HIDDEN)')
            print("Prepare C_K: stacked N_HEADS attention matrices of shape (N_NODES, N_NODES, N_HEADS, 2*N_HIDDEN)")
            print('Each entry i,j of C_K holds N_HEADS of concatenated feature pair which via attention mechanism will determine weights if edge(i,j).')
            print('Concatenation is not broadcasted. Create it from flattened features that have proper ordering. shape (N_NODES*N_NODES, N_HEADS, 2*N_HIDDEN)')

            # Row p = i*n + j: repeat_interleave puts g_i on the left, repeat puts g_j on the right.
            self.Ck_l += self.GkR.repeat_interleave(n, dim=0)
            self.Ck_r += self.GkR.repeat(n, 1, 1)
            self.Ck_f += torch.cat([self.Ck_l, self.Ck_r], dim=-1)

            # Labels name the op that produced each buffer (Ck_l: repeat_interleave, Ck_r: repeat).
            print_tex(r'Interleave_{flat} \ (K=1): ', self.Ck_l[:, 0].numpy(),
                      r'Repeat_{flat} \ (K=1): '    , self.Ck_r[:, 0].numpy(),
                      r'C_{flat} \ (K=1): '         , self.Ck_f[:, 0].numpy())

            print_tex(r'\text{Features C}_{0,0} = ', self.Ck[0,0].numpy(), r'; \ shape: \ [N_{heads} \times 2 N_{hidden}]')
            prnt_vec = [r'\vec{a}_'+str(i)+ ' = ' for i in range(k)]
            prnt_vec2 = [r'^T ; \ ' for i in range(k)]

            # Test attention vectors: every entry of column i (head i) equals i + 1.
            self.attnt += torch.arange(1, k + 1, dtype=self.H.dtype).expand(2*h, k)

            print_tex(r'E = \sigma(\vec{a}[C_K])')
            print('Test attention vectors:')
            print_tex(*[l for lists in zip(prnt_vec,self.attnt.T.numpy(),prnt_vec2) for l in lists])
            print('>>>See how to apply multiple attention vectors to data in notes<<<')

            # e_ij^k = activ(a_k . [g_i^k || g_j^k]) for every pair (i, j) and head k.
            # No trailing squeeze: with a single head it would drop the head axis.
            self.E += self.activ(torch.einsum('ijkf,fk->ijk', self.Ck, self.attnt))

            print_tex(r'\text{Features E}_{0,0} = ', self.E[0,0].numpy())
            # Non-edges get -inf so the softmax gives them zero weight.
            # NOTE(review): a node without neighbours gets an all -inf row, i.e. NaN after softmax.
            self.alpha += self.E.masked_fill(self.A.reshape(n, n, 1) == 0, float('-inf'))
            for i in range(k):
                print_tex('E_{K='+str(i + 1)+'} = ',self.E.numpy()[:,:,i], r'\rightarrow MASK \rightarrow ',(self.E[:,:,i]*self.A).numpy() )

            print('Apply row-wise softmax:')
            self.alpha = self.softmax(self.alpha)

            for i in range(k):
                print_tex(r'\alpha_{K='+str(i + 1)+'} = ', self.alpha.numpy()[:,:,i], r'{\rightarrow set \ to \ A \ for \ example \rightarrow }:', self.A.numpy())
                # Replace the computed weights by A so the aggregation below is easy to check by hand.
                self.alpha[:, :, i] = self.A
            # g'_i^k = sum_j alpha_ij^k * g_j^k
            self.GkPrime += torch.einsum('ijk,jkf->ikf', self.alpha, self.GkR)

            print('>>>See how to aggregate multi head case in notes<<<')
            # Neighbours of node 0: any non-zero adjacency entry (same convention as the mask above).
            n1_neighbors_id = torch.argwhere(self.A[0] != 0).flatten().numpy()

            a = [r'\vec{g}_'+str(i)+ '^T = ' for i in n1_neighbors_id]
            b = [self.GkR[i].numpy() for i in n1_neighbors_id]
            c = [r'; \ ' for i in n1_neighbors_id]
            print_tex(*[l for lists in zip(a,b,c) for l in lists], r'\vec{A}_0 = ', self.A[0].numpy() )
            print_tex(r'\vec{g}_0^\prime = \vec{A}_0 \otimes G = ', self.GkPrime[0].numpy())
            # Multi-head output by concatenating heads: (n_nodes, n_heads * n_hidden).
            concat = self.GkPrime.reshape(n, k * h)
            print_tex(concat[0].numpy())

# Build the layer on the toy graph and print the step-by-step debug trace.
model = net(H0=H, A=torch.tensor(A), test=True)

Create matrix transformed embeddings (concatenated N_HEADS (K) blocks of shape (N_nodes,N_HIDDEN):


<IPython.core.display.Math object>

<IPython.core.display.Math object>

Reshape G_K form (N_NODES, N_HEADS*N_HIDDEN) to (N_NODES, N_HEADS, N_HIDDEN)
Prepare C_K: stacked N_HEADS attention matrices of shape (N_NODES, N_NODES, N_HEADS, 2*N_HIDDEN)
Each entry i,j of C_K holds N_HEADS of concatenated feature pair which via attention mechanism will determine weights if edge(i,j).
Concatenation is not broadcasted. Create it from flattened features that have proper ordering. shape (N_NODES*N_NODES, N_HEADS, 2*N_HIDDEN)


<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

Test attention vectors:


<IPython.core.display.Math object>

>>>See how to apply multiple attention vectors to data in notes<<<


<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

Apply row-wise softmax:


<IPython.core.display.Math object>

<IPython.core.display.Math object>

>>>See how to aggregate multi head case in notes<<<


<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

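## Some things I've learned
* You can define a tensor and a reshaped view of it. As long as you modify the original in place (instead of rebinding its name), the reshaped view reflects the change, because both share the same storage (`reshape` returns a view whenever the memory layout allows it).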

In [90]:
# A 1-D tensor and a column-shaped view of the same storage.
asd = torch.arange(2, 5)
print(asd)
asd2 = asd.reshape(-1, 1)
print(asd2)

# Updating the original in place (no rebinding) is visible through the view.
asd += torch.arange(3, 6)
print(asd)
print(asd2)

tensor([2, 3, 4])
tensor([[2],
        [3],
        [4]])
tensor([5, 7, 9])
tensor([[5],
        [7],
        [9]])


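* Masking does not create a view: `masked_fill` returns a new tensor, so later in-place changes to the original do not show up in the masked result.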

In [91]:
asd = torch.arange(2, 5)
print(asd)
mask = torch.tensor([True, False, False])
print(mask)

# Out-of-place masked_fill: the result owns its own storage.
asd2 = asd.masked_fill(mask=mask, value=0)
print(asd2)

# Changing the source afterwards leaves asd2 untouched.
asd[2] = 8
print(asd)
print(asd2)

tensor([2, 3, 4])
tensor([ True, False, False])
tensor([0, 3, 4])
tensor([2, 3, 8])
tensor([0, 3, 4])


In [92]:
# asd2: shape (2, 2, 3); a: shape (3, 2), one column per row position r of asd2.
asd2 = torch.stack([torch.tensor([[1, 2, 3], [4, 5, 6]]),
                    torch.tensor([[7, 8, 9], [10, 11, 12]])])
a = torch.tensor([[0, 0, 1], [0, 1, 0]]).T
print(asd2.shape, a.shape)

# The diagonal of (asd2 @ a) pairs row r with column r of `a`. einsum computes exactly
# those entries without the wasted off-diagonal products (same pattern as 'ijkf,fk->ijk').
diag = torch.diagonal(asd2 @ a, dim1=-2, dim2=-1)
assert torch.equal(diag, torch.einsum('brf,fr->br', asd2, a))
diag


torch.Size([2, 2, 3]) torch.Size([3, 2])


tensor([[ 3,  5],
        [ 9, 11]])

In [93]:
# Iterable unpacking (`*`) is not allowed inside a comprehension. Compile the snippet instead
# of running it directly, so the cell shows the error without aborting "Restart & Run All".
unpack_error = None
try:
    compile("[*[i, k] for i, k in zip([1, 2], [10, 20])]", "<demo>", "eval")
except SyntaxError as err:
    unpack_error = err
    print('SyntaxError:', err.msg)

SyntaxError: iterable unpacking cannot be used in comprehension (3707840137.py, line 1)

In [None]:
list1, list2 = [1, 2], [10, 20]

# Flatten the pairs from zip: outer loop over pairs, inner loop over the items of each pair.
interleaved_list = [value for pair in zip(list1, list2) for value in pair]

print(interleaved_list)


[1, 10, 2, 20]
