In [None]:
import math,torch
from torch import nn

from miniai.activations import *

import matplotlib.pyplot as plt
from diffusers.models.attention import AttentionBlock

# set_seed comes from miniai.activations; presumably it seeds torch's RNG so the random
# input and every layer initialisation below are reproducible.
set_seed(42)
# Random stand-in for a UNet activation: (batch, channels, height, width)
x = torch.randn((64, 32, 16, 16))

# Attention, as used in the Stable Diffusion UNet, replaces each pixel with a weighted
# average of (projections of) all the pixels.
# Stable Diffusion uses "1D" attention: the h x w grid is flattened into a single sequence
# of pixels. Each sequence element is still a vector (one value per channel), so every
# image becomes a (sequence, channels) matrix.

# Each output pixel = the original pixel (residual connection) + a weighted average over
# all pixels, where the weights for that pixel sum to 1.

# Q, K and V are all computed from the same input -- that is what makes it self-attention.
# Softmax is what makes each pixel's weights sum to 1.

In [None]:
# Goal: reproduce the Stable Diffusion (diffusers) attention block exactly.

# Flatten the spatial dims with view(): keep the batch and channel dims and let -1 merge
# height*width into one axis, then swap axes 1 and 2 so that each pixel becomes one row.
t = x.view(x.size(0), x.size(1), -1).transpose(1, 2)
# Result layout: (batch, sequence = h*w pixels, dimension = channels)
t.size()

In [None]:
# ni = number of channels of x, i.e. the embedding size of each pixel (32 here).
# Derived from x instead of hard-coded so it stays in sync if x's shape changes.
ni = x.shape[1]

In [None]:
# Three independent ni -> ni projections, created in the order key, query, value
sk, sq, sv = (nn.Linear(ni, ni) for _ in range(3))

In [None]:
# Project the same sequence t into keys, queries and values (self-attention)
k, q, v = sk(t), sq(t), sv(t)

In [None]:
# Raw attention scores: every query pixel dotted with every key pixel -> (batch, seq, seq).
# They still need scaling and a softmax before they can be used as weights.
torch.matmul(q, k.transpose(1, 2)).size()

In [None]:
# The full block also normalises its input first. nn.GroupNorm(1, ni) uses a single group,
# so each sample is normalised over all of its channels and pixels together.
# Steps: normalise -> flatten pixels into a sequence -> project to q/k/v -> dot-product
# scores divided by sqrt(ni) (so their magnitude does not grow with the embedding size)
# -> softmax -> weighted sum of v -> output projection -> reshape back to an image.
# Finally the input is added back (residual connection).

class SelfAttention(nn.Module):
    """Single-head self-attention over the pixels of an (n, c, h, w) feature map.

    Laid out like diffusers' AttentionBlock (group norm, separate q/k/v linears, output
    projection, residual add) so its parameters can be tied to that block's.

    Args:
        ni: number of channels (embedding size of each pixel).
    """
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.GroupNorm(1, ni)
        self.q = nn.Linear(ni, ni)
        self.k = nn.Linear(ni, ni)
        self.v = nn.Linear(ni, ni)
        self.proj = nn.Linear(ni, ni)

    def forward(self, x):
        residual = x
        n, c, h, w = x.shape
        # (n, c, h, w) -> (n, h*w, c): one row per pixel
        seq = self.norm(x).view(n, c, -1).transpose(1, 2)
        query, key, value = self.q(seq), self.k(seq), self.v(seq)
        weights = torch.softmax(torch.matmul(query, key.transpose(1, 2)) / self.scale, dim=-1)
        out = self.proj(torch.matmul(weights, value))
        # back to (n, c, h, w) before adding the residual
        out = out.transpose(1, 2).reshape(n, c, h, w)
        return out + residual

In [None]:
sa = SelfAttention(ni=32)  # single-head attention over 32-channel feature maps

In [None]:
ra = sa(x)  # output is reshaped back to the input's (n, c, h, w) layout
ra.size()

In [None]:
ra[0][0][0]  # image 0, channel 0, first row of pixels -- compared against diffusers below

In [None]:
def cp_parms(a,b):
    "Make module `b` use `a`'s weight and bias Parameter objects (the parameters are shared, not copied)."
    for attr in ('weight', 'bias'):
        setattr(b, attr, getattr(a, attr))

In [None]:
# Tie our parameters into diffusers' AttentionBlock so we can check both produce the same
# values. cp_parms shares the Parameter objects (it does not copy them).
# norm_num_groups=1 is chosen to match our single-group nn.GroupNorm(1, ni).
at = AttentionBlock(32, norm_num_groups=1)
src = (sa.q, sa.k, sa.v, sa.proj, sa.norm)
dst = (at.query, at.key, at.value, at.proj_attn, at.group_norm)
for s, d in zip(src, dst):
    cp_parms(s, d)

In [None]:
rb = at(x)
# Check numerically rather than by eye: with tied parameters, diffusers' AttentionBlock
# should reproduce our SelfAttention output `ra` up to float32 rounding.
torch.testing.assert_close(rb, ra)
rb[0,0,0]

In [None]:
# A single fused linear producing q, k and v together: 3*ni outputs per pixel
sqkv = nn.Linear(ni, 3 * ni)
st = sqkv(t)
st.size()

In [None]:
# Split the fused output back into three ni-wide pieces along the feature axis
q, k, v = st.chunk(3, dim=-1)
q.size()

In [None]:
# Scores are q @ k^T (rows = queries, columns = keys), the same order as the SelfAttention
# classes; k @ q^T would give the transposed score matrix.
(q @ k.transpose(1, 2)).shape

In [None]:
# This version uses one linear layer for q, k and v. That should be faster: one larger
# matmul instead of three separate ones.
# Different channels can carry different kinds of information; multi-head attention
# (further below) lets separate groups of channels attend independently.
class SelfAttention(nn.Module):
    """Self-attention over the pixels of an (n, c, h, w) map using one fused q/k/v linear.

    Normalises with BatchNorm2d (per-channel statistics) and returns the attention output
    added to the input (residual).

    Args:
        ni: number of channels (embedding size of each pixel).
    """
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.BatchNorm2d(ni)
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)

    def forward(self, inp):
        n, c, h, w = inp.shape
        normed = self.norm(inp)
        # (n, c, h, w) -> (n, h*w, c): one row per pixel
        seq = normed.view(n, c, -1).transpose(1, 2)
        qry, key, val = self.qkv(seq).chunk(3, dim=-1)
        scores = torch.matmul(qry, key.transpose(1, 2)) / self.scale
        attended = torch.matmul(torch.softmax(scores, dim=-1), val)
        projected = self.proj(attended)
        return projected.transpose(1, 2).reshape(n, c, h, w) + inp

In [None]:
class SelfAttention(nn.Module):
    """Single-head self-attention over the positions of a channels-first tensor, with residual.

    Accepts an (n, c, h, w) feature map or an (n, c, s) sequence (any number of trailing
    positional dims). The positions are flattened into one sequence of length L, attended
    over, and reshaped back, so the output has exactly the input's shape.

    Args:
        ni: number of channels `c` (the embedding size of each position).
    """
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.BatchNorm2d(ni)
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)

    def forward(self, x):
        inp = x
        n, c = inp.shape[:2]
        # Fold all positional dims into one axis of length L. BatchNorm2d requires 4-D input,
        # so normalise as (n, c, L, 1): each channel's statistics still cover the same
        # batch x position elements. Then lay the L positions out as rows: (n, L, c).
        # (Transposing the raw 4-D input instead would feed the width axis into qkv.)
        x = self.norm(inp.reshape(n, c, -1, 1)).view(n, c, -1).transpose(1, 2)
        q, k, v = torch.chunk(self.qkv(x), 3, dim=-1)
        s = (q @ k.transpose(1, 2)) / self.scale
        x = s.softmax(dim=-1) @ v
        # (n, L, c) -> (n, c, L) -> original positional layout, then add the residual
        x = self.proj(x).transpose(1, 2).reshape(inp.shape)
        return x + inp

In [None]:
sa = SelfAttention(ni=32)
sa(x).size()

torch.Size([64, 32, 16, 16])

In [None]:
torch.std(sa(x))  # residual output should stay close to unit std

tensor(1.0047, grad_fn=<StdBackward0>)

In [None]:
# heads_to_batch: fold attention heads into the batch dimension,
# e.g. (64 images, 256 pixels, 32 channels) with 8 heads -> (512, 256, 4)
def heads_to_batch(x, heads):
    """Split the last dim of an (n, sl, d) tensor into `heads` chunks and move them into the batch.

    Returns a tensor of shape (n*heads, sl, d//heads) -- the same rearrangement as
    einops' 'n s (h d) -> (n h) s d'.
    """
    n, sl, _ = x.shape
    # (n, sl, d) -> (n, sl, heads, d/heads) -> (n, heads, sl, d/heads)
    per_head = x.reshape(n, sl, heads, -1).transpose(1, 2)
    return per_head.reshape(n * heads, sl, -1)

# batch_to_heads: inverse of heads_to_batch -- unfold the heads back out of the batch dim
def batch_to_heads(x, heads):
    """Turn an (n*heads, sl, d) tensor back into (n, sl, heads*d).

    Same rearrangement as einops' '(n h) s d -> n s (h d)'.
    """
    _, sl, d = x.shape
    grouped = x.reshape(-1, heads, sl, d)  # (n, heads, sl, d)
    return grouped.transpose(1, 2).reshape(-1, sl, heads * d)

In [None]:
# einops is inspired by einsum's notation
from einops import rearrange

In [None]:
# einops describes a tensor rearrangement as "input pattern -> output pattern".
# t is a 3-D tensor 'n s (h d)' whose last dim is split into h=8 groups (heads) of d.
# Each head is moved into the batch, '(n h) s d': the batch becomes 8x bigger and each
# element keeps only 32/8 = 4 channels, i.e. (64, 256, 32) -> (512, 256, 4).
t2 = rearrange(t, 'n s (h d) -> (n h) s d', h=8)
(t.size(), t2.size())

(torch.Size([64, 256, 32]), torch.Size([512, 256, 4]))

In [None]:
# Inverse rearrangement: move the heads back out of the batch and re-join their channels
t3 = rearrange(t2, '(n h) s d -> n s (h d)', h=8)

In [None]:
(t2.size(), t3.size())  # t3 should be back to t's original shape

(torch.Size([512, 256, 4]), torch.Size([64, 256, 32]))

In [None]:
t.eq(t3).all()  # the split/merge round trip is lossless

tensor(True)

In [None]:
# Multi-head attention runs the whole scaled dot-product attention independently on separate
# groups of channels (heads). The heads simply partition the channels, which lets different
# heads extract different information.

class SelfAttentionMultiHead(nn.Module):
    """Multi-head self-attention over the pixels of an (n, c, h, w) feature map, with residual.

    Args:
        ni: number of channels; expected to be divisible by `nheads`.
        nheads: number of heads. Scores are scaled by sqrt(ni/nheads), the per-head size.
    """
    def __init__(self, ni, nheads):
        super().__init__()
        self.nheads = nheads
        self.scale = math.sqrt(ni/nheads)
        self.norm = nn.BatchNorm2d(ni)
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)

    def forward(self, inp):
        n, c, h, w = inp.shape
        nh = self.nheads
        # (n, c, h, w) -> (n, h*w, c), then the fused projection -> (n, h*w, 3c)
        fused = self.qkv(self.norm(inp).view(n, c, -1).transpose(1, 2))
        # Fold the heads into the batch: each image becomes nh independent sequences that
        # never interact (e.g. 1 image x 96 features -> 4 "images" x 24 features).
        # NOTE: heads are split *before* q/k/v are separated, so each head's q, k and v
        # come from one contiguous 3c/nh slice of the qkv output. That is a different
        # feature-to-role assignment than nn.MultiheadAttention uses. It is fine for weights
        # trained from scratch, but not for loading weights laid out as [q | k | v].
        per_head = rearrange(fused, 'n s (h d) -> (n h) s d', h=nh)
        qry, key, val = per_head.chunk(3, dim=-1)
        weights = torch.softmax(torch.matmul(qry, key.transpose(1, 2)) / self.scale, dim=-1)
        # Undo the fold: heads leave the batch and their channels are concatenated again
        merged = rearrange(torch.matmul(weights, val), '(n h) s d -> n s (h d)', h=nh)
        out = self.proj(merged).transpose(1, 2).reshape(n, c, h, w)
        return out + inp

In [None]:
sa = SelfAttentionMultiHead(ni=32, nheads=4)
sx = sa(x)
sx.size()

torch.Size([64, 32, 16, 16])

In [None]:
(torch.mean(sx), torch.std(sx))

(tensor(0.0248, grad_fn=<MeanBackward0>),
 tensor(1.0069, grad_fn=<StdBackward0>))

In [None]:
# PyTorch ships nn.MultiheadAttention. By default it expects the batch as the *second*
# dimension; batch_first=True makes it accept (batch, seq, dim), the same layout as our t.
# Passing the same tensor as query, key and value gives self-attention; passing different
# tensors for key/value would give cross-attention.
# No normalisation is applied here; the residual is added by hand.
nm = nn.MultiheadAttention(32, num_heads=8, batch_first=True)
nmx, nmw = nm(query=t, key=t, value=t)
nmx = nmx + t

In [None]:

(torch.mean(nmx), torch.std(nmx))

(tensor(-0.0021, grad_fn=<MeanBackward0>),
 tensor(1.0015, grad_fn=<StdBackward0>))