# Attention mechanism

In [1]:
import math, torch
from torch import nn
from miniai.activations import *

In [2]:
import matplotlib.pyplot as plt

In [3]:
# to check our implementation
from diffusers.models.attention import Attention

In [5]:
set_seed(42)
# (batch_size, channels, height, width)
x = torch.randn(64, 32, 16, 16)

In [7]:
nc = 32

## Scaled Dot-Product Attention
Attention works by picking up which pixel(s) is/are the most important, i.e. it does a weighted average across the pixels.

In [6]:
t = x.view(*x.shape[:2], -1).transpose(1, 2)
t.shape

torch.Size([64, 256, 32])

Let's dive into each step of the Attention Mechanism.

In [9]:
# normalise input across channels for better training
norm = nn.BatchNorm2d(nc)

In [11]:
scale = math.sqrt(nc)
scale

5.656854249492381

In [12]:
kqv = nn.Linear(nc, nc*3)

In [14]:
k,q,v = torch.chunk(kqv(t), 3, dim=-1)
q.shape

torch.Size([64, 256, 32])

In [17]:
k.permute(0, 2, 1).shape

torch.Size([64, 32, 256])

In [27]:
prod = torch.matmul(q, k.transpose(1,2))
prod.shape

torch.Size([64, 256, 256])

In [28]:
prod /= scale

In [29]:
prod.shape

torch.Size([64, 256, 256])

"In the context of a 3D tensor like [64, 256, 256], applying softmax along the last dimension would independently normalize the values in each row (256 elements) across all the columns, resulting in a probability distribution for each row."

In [24]:
torch.softmax(prod, dim=-1)

tensor([[[0.0022, 0.0041, 0.0027,  ..., 0.0036, 0.0019, 0.0043],
         [0.0028, 0.0036, 0.0057,  ..., 0.0068, 0.0037, 0.0054],
         [0.0043, 0.0033, 0.0025,  ..., 0.0024, 0.0023, 0.0036],
         ...,
         [0.0033, 0.0037, 0.0039,  ..., 0.0031, 0.0046, 0.0041],
         [0.0033, 0.0043, 0.0043,  ..., 0.0044, 0.0031, 0.0045],
         [0.0049, 0.0055, 0.0059,  ..., 0.0043, 0.0041, 0.0034]],

        [[0.0030, 0.0023, 0.0046,  ..., 0.0038, 0.0028, 0.0035],
         [0.0043, 0.0032, 0.0033,  ..., 0.0035, 0.0096, 0.0033],
         [0.0035, 0.0043, 0.0039,  ..., 0.0041, 0.0038, 0.0034],
         ...,
         [0.0025, 0.0032, 0.0028,  ..., 0.0045, 0.0027, 0.0030],
         [0.0048, 0.0094, 0.0026,  ..., 0.0030, 0.0026, 0.0028],
         [0.0029, 0.0026, 0.0050,  ..., 0.0034, 0.0031, 0.0036]],

        [[0.0058, 0.0089, 0.0049,  ..., 0.0051, 0.0032, 0.0024],
         [0.0053, 0.0042, 0.0026,  ..., 0.0058, 0.0025, 0.0044],
         [0.0041, 0.0032, 0.0027,  ..., 0.0040, 0.0042, 0.

In [25]:
sm_prod = torch.softmax(prod, dim=-1)
sm_prod.shape

torch.Size([64, 256, 256])

In [26]:
v.shape

torch.Size([64, 256, 32])

In [30]:
(sm_prod@v).shape

torch.Size([64, 256, 32])

`proj` is simply a linear layer to allow our computed matrix to make some adjustments such as scaling, translation etc.

In [31]:
proj = nn.Linear(nc, nc)

In [32]:
proj(sm_prod@v).shape

torch.Size([64, 256, 32])

In [33]:
proj(sm_prod@v).transpose(1, 2).shape

torch.Size([64, 32, 256])

In [35]:
projected = proj(sm_prod@v).transpose(1, 2)
projected.reshape(64, 32, 16, 16).shape

torch.Size([64, 32, 16, 16])

Working on not flattening 2D image to 1D tensor.

In [37]:
norm(x).transpose(1,2).shape

torch.Size([64, 16, 32, 16])

In [49]:
class SelfAttention(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.BatchNorm2d(ni)
        self.kqv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)
    
    def forward(self, x):
        n, c, h, w = x.shape
        x = self.norm(x)
        x = x.view(*x.shape[:2], -1).transpose(1, 2)
        k, q, v = self.kqv(x).chunk(3, dim=-1)
        x = (q@k.transpose(1,2))/self.scale
        x = x.softmax(dim=-1)
        x = x@v
        x = self.proj(x).transpose(1, 2)
        x = x.reshape(n, c, h, w)
        return x
     

In [50]:
x.shape

torch.Size([64, 32, 16, 16])

In [51]:
sa = SelfAttention(32)
sa(x).shape

torch.Size([64, 32, 16, 16])

In [52]:
sa(x)

tensor([[[[-7.1404e-02, -9.4668e-02, -7.5854e-02,  ..., -1.2073e-01,
           -4.8674e-02, -5.8882e-02],
          [-9.5503e-02, -5.9106e-02, -9.8340e-02,  ..., -8.4026e-02,
           -9.5866e-02, -6.2507e-02],
          [-5.3636e-02, -1.0224e-01, -2.6004e-02,  ..., -7.9946e-02,
           -8.3766e-02, -8.4824e-02],
          ...,
          [-7.5610e-02, -8.0140e-02, -7.9661e-02,  ..., -7.2991e-02,
           -2.9725e-02, -7.8685e-02],
          [-6.5076e-02, -1.0856e-01, -5.9817e-02,  ..., -4.1871e-02,
           -6.5401e-02, -8.1491e-02],
          [-7.0243e-02, -8.7479e-02, -4.7026e-02,  ..., -3.5627e-02,
           -7.6907e-02, -6.4361e-02]],

         [[-1.2950e-01, -6.7728e-02, -1.0389e-01,  ..., -1.3295e-01,
           -9.0823e-02, -1.1598e-01],
          [-1.3166e-01, -1.3364e-01, -1.0794e-01,  ..., -1.5637e-01,
           -9.7193e-02, -1.4320e-01],
          [-1.4450e-01, -9.5884e-02, -7.7089e-02,  ..., -1.3615e-01,
           -1.2720e-01, -1.7574e-01],
          ...,
     

In [54]:
sa(x).mean()

tensor(0.0201, grad_fn=<MeanBackward0>)

In [55]:
sa(x).std()

tensor(0.1221, grad_fn=<StdBackward0>)

## Multi-Head Attention
While scaled dot product attention helps pixels within the image to find its relationship with other pixels of the image via a weighted average, there is a limiting factor. Scaled dot product attention works across all the channels, but this limits its ability to 'find other features' in certain scenario.

Recall that a channel represents output of a filter. The importance of pixels in a channel can be dependent on the pixels in another channel- e.g. pointedness of 2 ears + length of tail to determine the breed of a dog. 

In this case, we will want to assess the channels separately. Multi-headed attention allows us to do this.

We transform each image from `(1, c, h, w)` to `(1*n_heads, c/n_heads, h, w)`. Each batch of $\frac{c}{n_heads}$ channels can then interact more 'independently' with each other.

In [56]:
x.shape

torch.Size([64, 32, 16, 16])

In [58]:
n,c,h,w = x.shape
n,c,h,w

(64, 32, 16, 16)

In [59]:
t = x.view(*x.shape[:2], -1).transpose(1, 2)
t.shape

torch.Size([64, 256, 32])

In [69]:
kqv = nn.Linear(c, c*3)
kqv = kqv(t)
kqv.shape

torch.Size([64, 256, 96])

In [70]:
n_heads = 8

We want to change the dimension from `(64,256,96)` to `(64*n_heads, 256, 96/n_heads)`.

In [71]:
n, sz, nf = kqv.shape
n, sz, nf

(64, 256, 96)

In [73]:
kqv.reshape(n, sz, n_heads, -1).shape

torch.Size([64, 256, 8, 12])

In [74]:
# batch_size, n_heads, num_pixels, channels_per_head
kqv.reshape(n, sz, n_heads, -1).transpose(1, 2).shape

torch.Size([64, 8, 256, 12])

In [75]:
# batch_size * n_heads, num_pixels, channels_per_head
kqv = kqv.reshape(n, sz, n_heads, -1).transpose(1, 2).reshape(n*n_heads, sz, -1)
kqv.shape

torch.Size([512, 256, 12])

We can now chunk `kqv` into individual `k`, `q` and `v` to do the scaled dot product attention.

In [76]:
k, q, v = kqv.chunk(3, dim=-1)
k.shape

torch.Size([512, 256, 4])

In [77]:
(q@k.transpose(1,2)).shape

torch.Size([512, 256, 256])

In [78]:
scale = math.sqrt(nc)

In [80]:
qk_scaled = (q@k.transpose(1,2))/scale
qk_scaled.shape

torch.Size([512, 256, 256])

In [81]:
softmax_qk = qk_scaled.softmax(dim=-1)
softmax_qk.shape

torch.Size([512, 256, 256])

In [83]:
attn_prod = softmax_qk@v
attn_prod.shape

torch.Size([512, 256, 4])

In [85]:
# convert back to shape of (batch_size, sz, nc)
attn_prod_reshaped = attn_prod.reshape(n, sz, -1)
attn_prod_reshaped.shape

torch.Size([64, 256, 32])

In [86]:
proj = nn.Linear(nc,nc)

In [87]:
projected_attn = proj(attn_prod_reshaped)
projected_attn.shape

torch.Size([64, 256, 32])

In [88]:
projected_attn.transpose(1, 2).reshape(n,c,h,w).shape

torch.Size([64, 32, 16, 16])

This is the workflow for multi-headed attention.

In [89]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, ni, n_heads):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.BatchNorm2d(ni)
        self.kqv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)
        self.n_heads = n_heads
    
    def forward(self, x):
        n, c, h, w = x.shape
        x = self.norm(x)
        x = x.view(*x.shape[:2], -1).transpose(1, 2)
        x = self.kqv(x)
        x = x.reshape(n, h*w, self.n_heads, -1).transpose(1, 2) \
             .reshape(n*self.n_heads, h*w, -1)
        k, q, v = x.chunk(3, dim=-1)
        x = (q@k.transpose(1,2))/self.scale
        x = x.softmax(dim=-1)
        x = (x@v).reshape(n, h*w, -1)
        x = self.proj(x).transpose(1, 2)
        x = x.reshape(n, c, h, w)
        return x
     

In [90]:
mha = MultiHeadedAttention(32, 8)

In [96]:
x = torch.randn(64, 32, 16, 16)
x

tensor([[[[ 1.3436e+00, -5.7902e-01,  3.3024e-01,  ...,  6.8883e-01,
           -1.7589e+00, -2.4182e-01],
          [ 2.2925e+00,  8.5349e-01,  6.1042e-01,  ..., -1.5404e+00,
           -1.5196e+00, -1.0797e+00],
          [-2.5255e-01, -7.9961e-01, -3.0755e-01,  ...,  9.5952e-01,
            1.2505e+00, -2.2428e+00],
          ...,
          [ 1.8232e-01,  3.4944e-01, -5.5609e-01,  ...,  6.3996e-01,
            1.3510e+00,  9.8223e-01],
          [ 2.1318e-01, -1.2750e-01, -1.0384e+00,  ..., -4.6253e-02,
           -5.6790e-02, -7.6728e-02],
          [ 1.9761e+00, -1.6511e+00,  5.9411e-01,  ...,  6.0437e-01,
           -7.3852e-01,  9.9078e-01]],

         [[-4.3274e-01,  5.1995e-01,  2.1706e-01,  ..., -5.8823e-01,
           -2.0282e+00, -1.5199e+00],
          [ 7.2450e-01,  4.6457e-01, -5.5252e-01,  ...,  2.2462e+00,
           -6.2283e-01,  1.3075e-02],
          [-9.6425e-01,  1.4548e+00,  1.6828e+00,  ..., -7.2695e-02,
           -1.1654e-01, -3.9207e-01],
          ...,
     

In [92]:
mha(x)

tensor([[[[ 8.6980e-02,  1.0640e-01,  1.1123e-01,  ...,  8.6808e-02,
            9.3569e-02,  7.9098e-02],
          [ 1.1391e-01,  9.9476e-02,  9.3977e-02,  ...,  9.7607e-02,
            9.4890e-02,  8.9818e-02],
          [ 3.9074e-02,  3.1692e-02,  6.2416e-02,  ...,  6.1854e-02,
            4.5320e-02,  4.9323e-02],
          ...,
          [ 4.1540e-02,  3.6839e-02,  3.0327e-02,  ...,  3.5147e-02,
            4.2592e-02,  3.7658e-02],
          [ 2.2156e-02,  2.2266e-02,  1.2131e-02,  ...,  8.2679e-03,
            2.3898e-02,  2.8549e-02],
          [ 2.4418e-02,  2.6451e-02,  2.3892e-02,  ...,  1.7363e-02,
            2.5196e-02,  2.7250e-02]],

         [[ 2.3117e-01,  2.2114e-01,  2.2113e-01,  ...,  2.3514e-01,
            2.2097e-01,  2.1988e-01],
          [ 2.3845e-01,  2.3209e-01,  2.0031e-01,  ...,  2.2696e-01,
            2.3378e-01,  2.4057e-01],
          [ 2.0781e-01,  1.9323e-01,  2.0696e-01,  ...,  2.0826e-01,
            2.1336e-01,  2.0367e-01],
          ...,
     

In [93]:
mha(x).shape

torch.Size([64, 32, 16, 16])

In [94]:
mha(x).mean()

tensor(0.0031, grad_fn=<MeanBackward0>)

In [95]:
mha(x).std()

tensor(0.1223, grad_fn=<StdBackward0>)