## Self attention

Used in the Transformer, introduced in [*Attention Is All You Need*](https://proceedings.neurips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html).

Take the generalised attention formulation and, instead of comparing one sequence with another (src -> tgt), compare the source with itself (src -> src): the queries, keys and values are all computed from the same sequence. Hence, self attention.
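To make the src -> tgt vs src -> src distinction concrete, here is a minimal sketch of a generic, unscaled dot-product attention function. The helper name `attend` is hypothetical, and for brevity it takes already-projected queries, keys and values (the Q/K/V linear layers are left out). Comparing two sequences passes different tensors for the queries and for the keys/values; self attention simply passes the same sequence for all three.

```python
import torch


def attend(queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: plain (unscaled) dot-product attention on 2D tensors.

    queries: (n_q, d), keys: (n_k, d), values: (n_k, d_v) -> output: (n_q, d_v)
    """
    scores = queries @ keys.T         # (n_q, n_k): one score per (query, key) pair
    weights = scores.softmax(dim=-1)  # each query's weights over the keys sum to 1
    return weights @ values           # weighted sum of the values


torch.manual_seed(0)
tgt_seq = torch.randn(7, 10)  # e.g. 7 already-projected target positions
src_seq = torch.randn(5, 10)  # e.g. 5 already-projected source positions

cross = attend(tgt_seq, src_seq, src_seq)  # two different sequences
self_ = attend(src_seq, src_seq, src_seq)  # one sequence compared with itself

print(cross.shape)  # torch.Size([7, 10])
print(self_.shape)  # torch.Size([5, 10])
```

The output always has one row per query, so self attention returns one vector per source position.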

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
import numpy as np
```

```python
# size for the input
seq_len = 50
src = torch.arange(seq_len)  # token ids 0..49, shape (seq_len,)
```

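```python
@torch.no_grad()
def self_attention(src: torch.Tensor) -> torch.Tensor:
    # modules (randomly initialised on every call, for demonstration only)
    src_embed = nn.Embedding(100, 10)  # vocabulary of 100 ids -> 10-dim embeddings

    Q = nn.Linear(10, 10)
    K = nn.Linear(10, 10)
    V = nn.Linear(10, 10)

    x = src_embed(src)  # (seq_len, 10)

    # queries, keys and values all come from the same sequence: *self* attention
    q = Q(x)
    k = K(x)
    v = V(x)

    # every position compared with every other position: (seq_len, seq_len)
    attention_scores = q @ k.T
    # each row becomes a probability distribution over the positions
    alpha = attention_scores.softmax(dim=-1)
    # attention: weighted sum of the values, (seq_len, 10)
    a = alpha @ v
    return a

sa = self_attention(src)
```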

```python
sa
```

```
tensor([[ 2.2446e-02,  2.1148e-02, -4.1858e-01,  7.3752e-02,  1.4769e-01,
          3.2192e-02,  3.7492e-02,  1.3095e-01, -1.3525e-01, -1.2915e-01],
        [-6.5445e-02,  1.0459e-02, -1.9456e-01,  2.5783e-02,  2.8349e-01,
          1.3515e-01,  1.7737e-01,  1.4303e-01, -5.6320e-02, -3.6342e-02],
        [-4.5950e-02,  1.4819e-01, -1.5381e-01,  5.2708e-02,  2.4689e-01,
          1.0627e-01,  1.2080e-01,  2.4718e-01, -3.6499e-02, -9.3811e-03],
        [ 5.1255e-02, -1.9176e-02, -4.8275e-01,  3.5726e-02,  1.5140e-01,
          6.3747e-02,  2.5899e-02,  1.0242e-01, -1.2603e-01, -7.2023e-02],
        [-1.1805e-01, -6.0078e-02, -1.4068e-01,  1.9195e-02,  1.9841e-01,
          1.3663e-01,  1.1511e-01,  2.4002e-01, -3.7578e-05,  7.2950e-02],
        [-3.1617e-02,  3.7727e-02, -2.4627e-01,  7.4321e-02,  1.9820e-01,
          1.2627e-01,  9.8531e-02,  1.2523e-01, -1.2109e-01, -1.1994e-01],
        [-6.4535e-02,  2.8273e-02, -1.2834e-01,  5.8515e-02,  2.3900e-01,
          1.8573e-01,  1.7891e-0
```

```python
sa.size()
```

```
torch.Size([50, 10])
```
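One caveat with `self_attention` as written: the embedding and projection layers are created inside the function, so every call draws fresh random weights and returns a different result. A more typical pattern is to keep the projections in an `nn.Module` so the weights persist and can be trained. The sketch below is an assumption for illustration (the class name `SelfAttention` and the batched `(batch, seq_len, d_model)` layout are not from the notebook). It also uses `transpose(-2, -1)` rather than `.T`, because `.T` reverses *all* dimensions of a tensor with more than two dimensions (and is deprecated for that use).

```python
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """Hypothetical module: single-head, unscaled self attention like `self_attention`."""

    def __init__(self, d_model: int):
        super().__init__()
        # projections are created once, so the same weights are reused on every call
        self.Q = nn.Linear(d_model, d_model)
        self.K = nn.Linear(d_model, d_model)
        self.V = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        q, k, v = self.Q(x), self.K(x), self.V(x)
        scores = q @ k.transpose(-2, -1)  # (batch, seq_len, seq_len)
        alpha = scores.softmax(dim=-1)
        return alpha @ v                  # (batch, seq_len, d_model)


torch.manual_seed(0)
embed = nn.Embedding(100, 10)
attn = SelfAttention(10)

tokens = torch.arange(50).unsqueeze(0)  # (1, 50): a batch holding one sequence of ids
with torch.no_grad():
    out1 = attn(embed(tokens))
    out2 = attn(embed(tokens))

print(out1.shape)               # torch.Size([1, 50, 10])
print(torch.equal(out1, out2))  # True: same weights, same output
```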

## Scaled self attention

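Self attention is based on the dot product between the query and the key. The problem is that dot products grow with the dimension. If the components of $q$ and $k$ are independent with mean 0 and variance 1, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has mean 0 and variance $d_k$. With $d_k = 512$, the scores therefore typically have a standard deviation of about $\sqrt{512} \approx 22.6$, and they spread out further as $d_k$ grows. Large values don't play well with softmax: it pushes almost all of the probability onto the largest score, and in that saturated region the gradients become extremely small. (Softmax produces a probability distribution after all, and large inputs skew it towards one-hot.)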
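A quick empirical check of both points, assuming (as in the variance argument) that query and key components are independent standard normal variables. The sketch samples random vectors instead of using learned projections, and the five fixed scores are made up, so the numbers are only illustrative.

```python
import torch

torch.manual_seed(0)
d_k = 512
n = 4096

# n random query/key pairs whose components are i.i.d. N(0, 1)
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

dots = (q * k).sum(dim=-1)  # n dot products, one per pair
scaled = dots / d_k ** 0.5
print(dots.std().item())    # close to sqrt(512) ≈ 22.6
print(scaled.std().item())  # close to 1

# softmax of the same five scores at two different magnitudes
scores = torch.tensor([1.0, 0.5, 0.0, -0.5, -1.0])
print(scores.softmax(dim=-1).max().item())           # ≈ 0.43
print((scores * 22.6).softmax(dim=-1).max().item())  # ≈ 1.0: almost one-hot
```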

So in the Transformer paper, they scale the attention scores by $1/\sqrt{d_k}$ before applying softmax, where $d_k$ is the dimension of the queries and keys.
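Written out, this is the paper's scaled dot-product attention, with the queries, keys and values stacked row-wise into $Q$, $K$ and $V$:

$$
\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V
$$

If the components of $q$ and $k$ are independent with mean 0 and variance 1, then $\operatorname{Var}(q \cdot k) = d_k$, so $\operatorname{Var}\big(q \cdot k / \sqrt{d_k}\big) = d_k / d_k = 1$.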

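```python
@torch.no_grad()
def scaled_self_attention(src: torch.Tensor) -> torch.Tensor:
    src_embed = nn.Embedding(100, 512)
    x = src_embed(src)  # (seq_len, 512)

    # modules: project the 512-dim embeddings to 100-dim queries, keys and values
    Q = nn.Linear(x.size(-1), 100)
    K = nn.Linear(x.size(-1), 100)
    V = nn.Linear(x.size(-1), 100)

    q = Q(x)
    k = K(x)
    v = V(x)

    # scaling factor: d_k is the dimension of the queries and keys (100 here),
    # not the embedding dimension (512)
    dk = k.size(-1)

    # scale, then softmax over each row
    attention_scores = (q @ k.T) / dk ** 0.5
    alpha = attention_scores.softmax(dim=-1)
    # attention: weighted sum of the values, (seq_len, 100)
    a = alpha @ v
    return a

ssa = scaled_self_attention(src)
```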

```python
ssa
```

```
tensor([[-0.0383,  0.0145,  0.0007,  ..., -0.0889,  0.0965, -0.0302],
        [-0.0370,  0.0450, -0.0110,  ..., -0.0941,  0.1030, -0.0178],
        [-0.0371,  0.0136,  0.0092,  ..., -0.1074,  0.0901, -0.0361],
        ...,
        [-0.0302,  0.0214,  0.0121,  ..., -0.0812,  0.0733, -0.0547],
        [-0.0642,  0.0256,  0.0020,  ..., -0.0957,  0.0789, -0.0306],
        [-0.0639,  0.0295, -0.0151,  ..., -0.1058,  0.0872, -0.0257]])
```

```python
ssa.size()
```

```
torch.Size([50, 100])
```
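As a sanity check, PyTorch 2.0 and later provide `torch.nn.functional.scaled_dot_product_attention`, which computes the same softmax of scaled scores times the values (by default it scales by $1/\sqrt{E}$, with $E$ the last dimension of the query). The sketch below assumes PyTorch ≥ 2.0. To avoid the fresh random layers inside `scaled_self_attention`, it uses fixed random projection matrices (`W_q`, `W_k`, `W_v` are illustrative stand-ins), with the 512 → 100 sizes mirroring that function.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model, d_k = 50, 512, 100

x = torch.randn(seq_len, d_model)  # stand-in for the embedded sequence
W_q = torch.randn(d_model, d_k) / d_model ** 0.5
W_k = torch.randn(d_model, d_k) / d_model ** 0.5
W_v = torch.randn(d_model, d_k) / d_model ** 0.5

q, k, v = x @ W_q, x @ W_k, x @ W_v  # each (seq_len, d_k)

# hand-written scaled dot-product self attention
manual = ((q @ k.T) / d_k ** 0.5).softmax(dim=-1) @ v

# the built-in expects (batch, ..., L, E): add batch and head dims, then drop them
fused = F.scaled_dot_product_attention(q[None, None], k[None, None], v[None, None])[0, 0]

print(manual.shape)                              # torch.Size([50, 100])
print(torch.allclose(manual, fused, atol=1e-4))  # True
```

In practice the built-in is preferable, since it can dispatch to fused kernels; the hand-written version is mainly useful for seeing what happens step by step.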