In [4]:
import math,torch
from torch import nn
from minai.activations import *

In [5]:
import matplotlib.pyplot as plt

In [6]:
from diffusers.models.attention import Attention

In [7]:
# Fixed seed, then a random batch of feature maps: 64 samples, 32 channels, 16x16 pixels.
set_seed(42)
x = torch.randn((64, 32, 16, 16))

In [8]:
# Flatten the spatial dims and put channels last: (n, c, h, w) -> (n, h*w, c),
# i.e. one c-dimensional token per pixel.
t = x.flatten(2).transpose(1, 2)
t.shape

torch.Size([64, 256, 32])

In [9]:
ni = 32  # channel count of x (see torch.randn above); the token width used by the projections below

Three linear projections are needed, one each for the keys, queries and values (I called these $W_k$, $W_q$ and $W_v$).

In [10]:
# One ni -> ni projection each for keys, queries and values (created in that order).
sk, sq, sv = (nn.Linear(ni, ni) for _ in range(3))

In [11]:
# Project every token into key, query and value space.
k, q, v = sk(t), sq(t), sv(t)

In [12]:
# Attention scores q k^T: one (s x s) matrix per batch item.
torch.matmul(q, k.transpose(-2, -1)).shape

torch.Size([64, 256, 256])

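Group norm with a single group (`GroupNorm(1, ni)`) normalizes each sample over all of its channels and spatial positions together, i.e. over (C, H, W), not just over the channel dimension. That is close to LayerNorm applied over [C, H, W]. The difference is that GroupNorm's learnable scale and shift are per channel, while LayerNorm's are per element.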

Note that we set `bias=False` on the q, k and v projections, because the `Attention` module in diffusers now turns these biases off by default.

In [13]:
class SelfAttention(nn.Module):
    """Single-head self-attention over the pixels of an (n, c, h, w) feature map.

    Pre-norm with GroupNorm(1, ni), separate bias-free q/k/v projections
    (matching the diffusers ``Attention`` default), an output projection, and
    a residual connection back to the un-normalised input.
    """
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)          # sqrt(d) in softmax(q k^T / sqrt(d))
        self.norm = nn.GroupNorm(1, ni)
        # creation order (q, k, v, proj) fixes both the init RNG draws and the repr order
        self.q = nn.Linear(ni, ni, bias=False)
        self.k = nn.Linear(ni, ni, bias=False)
        self.v = nn.Linear(ni, ni, bias=False)
        self.proj = nn.Linear(ni, ni)

    def forward(self, x):
        n, c, h, w = x.shape
        # (n, c, h, w) -> (n, h*w, c): one c-dim token per pixel
        tokens = self.norm(x).view(n, c, h * w).transpose(1, 2)
        q, k, v = self.q(tokens), self.k(tokens), self.v(tokens)
        weights = torch.softmax(q @ k.transpose(1, 2) / self.scale, dim=-1)
        # back to channels-first image layout
        out = self.proj(weights @ v).transpose(1, 2).reshape(n, c, h, w)
        return out + x  # skip connection

In [14]:
# GroupNorm variant with 32 channels, matching x
sa = SelfAttention(ni=32)

In [15]:
# Show the submodules: q/k/v are bias-free, proj keeps its bias
sa

SelfAttention(
  (norm): GroupNorm(1, 32, eps=1e-05, affine=True)
  (q): Linear(in_features=32, out_features=32, bias=False)
  (k): Linear(in_features=32, out_features=32, bias=False)
  (v): Linear(in_features=32, out_features=32, bias=False)
  (proj): Linear(in_features=32, out_features=32, bias=True)
)

In [16]:
# Run the block on the random batch; the residual add means the input shape is kept
ra = sa(x)
ra.size()

torch.Size([64, 32, 16, 16])

In [17]:
# Sample 0, channel 0, first row: a reference slice to compare with diffusers' Attention
ra[0,0,0]

tensor([ 1.9386,  1.5361,  0.9011, -2.0526,  0.6612, -1.2371, -0.0379, -1.5735,
        -0.7395,  1.6879, -0.3889, -1.4020, -0.7013, -0.5660, -0.7522,  0.7790],
       grad_fn=<SelectBackward0>)

In [18]:
def cp_parms(a, b):
    """Make module ``b`` use ``a``'s ``weight`` and ``bias`` Parameters.

    The Parameter objects themselves are assigned (not copied), so the two
    modules stay tied afterwards. ``a.bias`` may be ``None`` (e.g. a Linear
    built with ``bias=False``), in which case ``b.bias`` becomes ``None`` too.
    """
    for attr in ('weight', 'bias'):
        setattr(b, attr, getattr(a, attr))

In [19]:
# Build diffusers' Attention configured like SelfAttention above (1 head of 32 dims,
# GroupNorm with 1 group, residual connection, no dropout) and tie its parameters
# to `sa`'s so the two can be compared output-for-output.
at = Attention(32, heads=1, dim_head=32, out_dim=32, residual_connection=True, norm_num_groups=1, dropout=0)
src = sa.q, sa.k, sa.v, sa.proj, sa.norm
dst = at.to_q, at.to_k, at.to_v, at.to_out[0], at.group_norm
for mine, theirs in zip(src, dst):
    cp_parms(mine, theirs)

In [20]:
# Put diffusers' Attention in eval mode, then print the same slice as ra[0,0,0];
# with its parameters tied to `sa`, the values should agree.
at.eval()
at(x)[0,0,0]

tensor([ 1.9386,  1.5361,  0.9011, -2.0526,  0.6612, -1.2371, -0.0379, -1.5735,
        -0.7395,  1.6879, -0.3889, -1.4020, -0.7013, -0.5660, -0.7522,  0.7790],
       grad_fn=<SelectBackward0>)

In [18]:
# Recompute the slice from our own module to confirm it matches at(x)[0,0,0]
sa(x)[0,0,0]

tensor([ 1.9386,  1.5361,  0.9011, -2.0526,  0.6612, -1.2371, -0.0379, -1.5735,
        -0.7395,  1.6879, -0.3889, -1.4020, -0.7013, -0.5660, -0.7522,  0.7790],
       grad_fn=<SelectBackward0>)

We can combine the three linear projections into a single layer with three times as many outputs, and then chunk the result into q, k and v:

In [19]:
# One fused projection producing q, k and v side by side (output width 3*ni)
sqkv = nn.Linear(ni, 3 * ni)
st = sqkv(t)
st.size()

torch.Size([64, 256, 96])

In [20]:
# Split the fused output back into three ni-wide pieces
q, k, v = st.chunk(3, dim=-1)
q.size()

torch.Size([64, 256, 32])

In [21]:
# Attention scores are q @ k^T (queries index the rows), as in the earlier cells; shape (n, s, s)
(q@k.transpose(1,2)).shape

torch.Size([64, 256, 256])

This version uses the qkv trick, and also replaces the group norm with batch norm. Jeremy doesn't mention this change.

In [None]:
class SelfAttention(nn.Module):
    """Single-head self-attention over the pixels of an (n, c, h, w) feature map.

    Variant of the GroupNorm version: BatchNorm2d for the pre-norm and one
    fused Linear that produces q, k and v together. Output = attention + input.
    """
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)          # sqrt(d) for scaled dot-product attention
        self.norm = nn.BatchNorm2d(ni)
        self.qkv = nn.Linear(ni, ni*3)      # fused [q | k | v] projection
        self.proj = nn.Linear(ni, ni)

    def forward(self, inp):
        n, c, h, w = inp.shape
        # (n, c, h, w) -> (n, h*w, c): one c-dim token per pixel
        tokens = self.norm(inp).flatten(2).transpose(1, 2)
        q, k, v = self.qkv(tokens).chunk(3, dim=-1)
        weights = (q @ k.transpose(1, 2) / self.scale).softmax(dim=-1)
        out = self.proj(weights @ v)
        # back to (n, c, h, w), then the skip connection
        return out.transpose(1, 2).reshape(n, c, h, w) + inp

In [None]:
class SelfAttention(nn.Module):
    """Single-head self-attention over positions, with a residual connection.

    Accepts a channels-first sequence ``(n, c, s)`` or feature map
    ``(n, c, h, w)`` (any number of trailing spatial dims). All positions are
    flattened into a sequence of ``c``-dimensional tokens; the output has the
    same shape as the input.

    Args:
        ni: number of channels ``c``; also the token width.
    """
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        # BatchNorm1d on (n, c, s) normalises each channel over batch and positions:
        # the same statistics BatchNorm2d computes over (n, h, w), but it also
        # accepts already-flattened 3-d input (BatchNorm2d rejects it).
        self.norm = nn.BatchNorm1d(ni)
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)

    def forward(self, x):
        inp = x
        n, c = x.shape[:2]
        # (n, c, *spatial) -> (n, c, s) -> (n, s, c): channels must be last for the Linears
        x = self.norm(x.reshape(n, c, -1)).transpose(1, 2)
        q, k, v = torch.chunk(self.qkv(x), 3, dim=-1)
        s = (q@k.transpose(1, 2))/self.scale
        x = s.softmax(dim=-1)@v
        # back to channels-first and the caller's original shape
        x = self.proj(x).transpose(1, 2).reshape(inp.shape)
        return x + inp  # skip connection, as in the other variants

In [None]:
# Instantiate the latest SelfAttention definition and check it preserves the input shape
sa = SelfAttention(32)
sa(x).size()

torch.Size([64, 32, 16, 16])

In [None]:
# Spread of the output activations
sa(x).std()

tensor(1.0047, grad_fn=<StdBackward0>)

### Multihead Attention 

With multiple heads, the model can attend to different parts of the image depending on the head, so it can attend to several different parts at the same time. Because of the softmax, each head's mixing tends to be quite peaky (dominated by a few positions), so having several heads is a good idea.

This is done by splitting the features into `nheads` equal chunks and running attention on each chunk independently. The outputs are then concatenated back together.

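We apply $W_q$, $W_k$ and $W_v$ to the input first and only then split into heads. That way each head works on a learned projection of all the input features, rather than on a fixed slice of the raw input channels.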

To implement this, we just turn each head into a separate batch item, run the attention as normal, and then fold those batch items back into heads, concatenated along the feature dimension.

In [None]:
def heads_to_batch(x, heads):
    """Split the feature dim into ``heads`` chunks and fold them into the batch dim.

    (n, sl, heads*dh) -> (n*heads, sl, dh). Head ``j`` of sample ``i`` ends up at
    batch index ``i*heads + j``.
    """
    n, sl, _ = x.shape
    split = x.reshape(n, sl, heads, -1).permute(0, 2, 1, 3)   # (n, heads, sl, dh)
    return split.reshape(n * heads, sl, -1)

def batch_to_heads(x, heads):
    """Inverse of ``heads_to_batch``: (n*heads, sl, dh) -> (n, sl, heads*dh).

    Batch index ``i*heads + j`` becomes head ``j`` of sample ``i``, and the
    heads are concatenated along the last dim.
    """
    _, sl, dh = x.shape
    grouped = x.reshape(-1, heads, sl, dh).permute(0, 2, 1, 3)   # (n, sl, heads, dh)
    return grouped.reshape(-1, sl, dh * heads)

In [21]:
from einops import rearrange

In [22]:
# einops equivalent of heads_to_batch: split the 32 features into 8 heads and fold heads into the batch
t2 = rearrange(t, 'n s (h d) -> (n h) s d', h=8)
(t.shape, t2.shape)

(torch.Size([64, 256, 32]), torch.Size([512, 256, 4]))

In [23]:
# Inverse rearrange (batch_to_heads): unfold the heads from the batch and concatenate them along the features
t3 = rearrange(t2, '(n h) s d -> n s (h d)', h=8)

In [24]:
# Shapes with heads folded into the batch vs. merged back
t2.shape,t3.shape

(torch.Size([512, 256, 4]), torch.Size([64, 256, 32]))

In [25]:
# The split/merge round trip must reproduce t exactly
torch.eq(t, t3).all()

tensor(True)

In [26]:
class SelfAttentionMultiHead(nn.Module):
    """Multi-head self-attention over the pixels of an (n, c, h, w) feature map, with a residual.

    Args:
        ni: number of input channels ``c``; also the token width.
        nheads: number of heads. Must divide ``ni`` so that every head gets
            ``ni // nheads`` features for each of q, k and v.

    Raises:
        ValueError: if ``ni`` is not divisible by ``nheads``. Without this check,
            ``torch.chunk`` would produce uneven q/k/v pieces and ``forward``
            would fail later with an opaque shape error in ``proj``.

    Note: the qkv output is split into heads *before* it is chunked into q/k/v,
    so each head owns a contiguous ``3*ni/nheads`` slice laid out as [q|k|v].
    This is a valid (learned) parametrisation, but it differs from the common
    "chunk into q, k, v first, then split each into heads" layout. Weights from
    such implementations are therefore not interchangeable without permuting.
    """
    def __init__(self, ni, nheads):
        super().__init__()
        if ni % nheads:
            raise ValueError(f"ni ({ni}) must be divisible by nheads ({nheads})")
        self.nheads = nheads
        self.scale = math.sqrt(ni/nheads)  # sqrt of the per-head dimension
        self.norm = nn.BatchNorm2d(ni) # note still using batchnorm 
        self.qkv = nn.Linear(ni, ni*3) # using the qkv thingy.
        self.proj = nn.Linear(ni, ni)
    
    def forward(self, inp):
        n,c,h,w = inp.shape
        # (n, c, h, w) -> (n, h*w, c): one c-dim token per pixel
        x = self.norm(inp).view(n, c, -1).transpose(1, 2)
        x = self.qkv(x)  # (n, s, 3c)
        x = rearrange(x, 'n s (h d) -> (n h) s d', h=self.nheads)  ## heads_to_batch -> (n*nheads, s, 3c/nheads)
        q,k,v = torch.chunk(x, 3, dim=-1)  ## chunk the qkv projections, each (n*nheads, s, c/nheads)
        s = (q@k.transpose(1,2))/self.scale
        x = s.softmax(dim=-1)@v
        x = rearrange(x, '(n h) s d -> n s (h d)', h=self.nheads) ## batch_to_heads -> (n, s, c)
        x = self.proj(x).transpose(1,2).reshape(n,c,h,w)
        return x+inp

In [27]:
# 4 heads of 8 features each over the 32 channels of x
sa = SelfAttentionMultiHead(ni=32, nheads=4)
sx = sa(x)
sx.size()

torch.Size([64, 32, 16, 16])

In [28]:
# Output statistics of our multi-head block (compare with nn.MultiheadAttention below)
sx.mean(),sx.std()

(tensor(-0.0191, grad_fn=<MeanBackward0>),
 tensor(1.0074, grad_fn=<StdBackward0>))

In [29]:
# PyTorch's built-in multi-head attention on the flattened tokens t (queries = keys = values),
# plus a residual add, for comparison. Note: 8 heads here vs 4 above, and no pre-norm.
nm = nn.MultiheadAttention(embed_dim=32, num_heads=8, batch_first=True)
nmx, nmw = nm(t, t, t)
nmx = t + nmx

In [30]:
# Output statistics for PyTorch's implementation
nmx.mean(),nmx.std()

(tensor(-0.0021, grad_fn=<MeanBackward0>),
 tensor(1.0040, grad_fn=<StdBackward0>))