## A more generalised form of attention

There are some issues with the method described by Bahdanau et al. First, it is tied to an RNN-based encoder-decoder model. What if you don't want to use recurrent networks? Second, it needs outputs from both the encoder and the decoder, which form the information space (encoder states) and the search space (decoder output) used to find probable output candidates. But you may want to use attention outside a seq2seq setting.


So there is a more generalised version of attention, which describes the process in terms of queries, keys and values.

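To find parallels with the Bahdanau method:

query: what we want to know or search for (the decoder output)

key: what information we have that can be searched against (the encoder states)

value: the content that is returned once the query has been matched against the keys. The softmax of the query-key scores gives the probability of each key being related to the query, and those probabilities weight the values (in the Bahdanau paper, the values are the same as the keys)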
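As a rough illustration of these three roles, here is a minimal sketch of a soft lookup with one query over three key/value pairs. The tensors `query`, `keys` and `values` are made-up toy data, and the dot product is only one possible scoring function (Bahdanau et al. use a small feed-forward network instead).

```python
import torch

# toy data (hypothetical): one query, three keys, three values
query = torch.tensor([1.0, 0.0])
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [-1.0, 0.0]])
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [5.0, 5.0]])

# one score per key: how well does the query match it?
scores = keys @ query              # shape (3,)
# softmax turns the scores into probabilities that sum to 1
weights = scores.softmax(dim=-1)   # shape (3,)
# the result is the probability-weighted average of the values
result = weights @ values          # shape (2,)

print(weights)
print(result)
```

The first key points in the same direction as the query, so it receives the largest weight and its value dominates the result.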



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
import numpy as np

So the implementation looks like this:

Say you have an input sequence $X$.

You then need a query representation of it, computed with a learned query weight, a key representation, computed with a separate key weight, and likewise a value representation with its own value weight.
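Written out, a sketch of these steps could look as follows, where $W_Q$, $W_K$ and $W_V$ are names assumed here for the three weight matrices (bias terms are left out, and no scaling is applied to the scores):

$$
Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V
$$

$$
A = \mathrm{softmax}\left(Q K^\top\right) V
$$

The softmax is taken row-wise, so each row of $\mathrm{softmax}(Q K^\top)$ holds the weights that one query assigns to all keys.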

In [2]:
# toy inputs: two sequences of token ids with the same length
seq_len = 50
src = torch.arange(seq_len)      # source token ids 0..49
tgt = torch.arange(seq_len) + 1  # target token ids 1..50

In [3]:
@torch.no_grad()
def attention(src: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # modules (randomly initialised on every call, so outputs differ between runs)
    src_embed = nn.Embedding(100, 10)
    tgt_embed = nn.Embedding(100, 10)

    Q = nn.Linear(10, 10)
    K = nn.Linear(10, 10)
    V = nn.Linear(10, 10)

    x = src_embed(src)     # (src_len, 10)
    y = tgt_embed(target)  # (tgt_len, 10)

    # target -> src attention:
    # queries come from the target (decoder side), while keys and values
    # both come from the src (encoder side), so they share one length
    q = Q(y)
    k = K(x)
    v = V(x)

    # now that we have q, k, v
    attention_scores = q @ k.T                # (tgt_len, src_len)
    alpha = attention_scores.softmax(dim=-1)  # each row sums to 1 over src positions
    # attention
    a = alpha @ v                             # (tgt_len, 10)
    return a


a = attention(src, tgt)  

In [4]:
a.size()

torch.Size([50, 10])
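To see where this shape comes from, here is a small sketch that uses different lengths for the queries and for the keys/values. The function `toy_attention` and the lengths 7 and 5 are illustrative assumptions, not part of the cells above, and random tensors stand in for the projected embeddings.

```python
import torch

def toy_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Unscaled dot-product attention for 2-D inputs (hypothetical helper)."""
    scores = q @ k.T                 # (n_queries, n_keys)
    alpha = scores.softmax(dim=-1)   # each row sums to 1 over the keys
    return alpha @ v                 # (n_queries, d_value)

torch.manual_seed(0)
q = torch.randn(7, 10)   # 7 queries of dimension 10
k = torch.randn(5, 10)   # 5 keys of dimension 10
v = torch.randn(5, 10)   # 5 values, one per key

out = toy_attention(q, k, v)
print(out.shape)  # one output row per query
```

The output has one row per query and one column per value dimension, which is why 50 queries with 10-dimensional values give `[50, 10]`. The keys and values must share their first dimension, since each weight in a row of `alpha` belongs to exactly one value.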

In [5]:
a

tensor([[ 3.6128e-02, -2.9233e-01, -2.4212e-01, -2.5898e-01,  3.5666e-02,
         -1.2268e-01,  2.5762e-01,  3.5595e-01, -9.5216e-02, -9.7683e-02],
        [ 4.3134e-02, -3.7733e-01, -2.3537e-01, -2.8420e-01,  2.0973e-02,
         -1.8788e-01,  3.3490e-01,  2.7411e-01,  7.4452e-02, -7.2363e-02],
        [-1.2856e-02, -4.9443e-01, -1.6425e-01, -1.7546e-01, -9.8770e-02,
          4.4919e-02,  4.5979e-01,  1.1665e-01, -9.5980e-02,  8.1025e-02],
        [-3.4816e-02, -3.5247e-01, -2.0780e-01, -3.1957e-01,  7.1137e-02,
         -3.3858e-02,  3.6715e-01,  2.7959e-01,  1.2284e-01, -1.4793e-01],
        [ 2.1492e-02, -3.5198e-01, -9.7461e-02, -2.8343e-01,  9.7485e-02,
         -6.8736e-02,  3.1947e-01,  2.3882e-01, -2.1193e-02, -1.0721e-01],
        [ 1.4464e-01, -4.0921e-01, -2.7013e-01, -1.9694e-01,  1.6508e-01,
         -1.5613e-01,  2.3145e-01, -2.9909e-02, -1.8270e-01, -1.9135e-01],
        [-1.5929e-02, -4.3593e-01, -2.1074e-01, -2.3540e-01, -7.0843e-02,
         -3.9005e-02,  4.2824e-0