In [1]:
import torch
from torch import nn
from einops import rearrange
import math


In [2]:
class SelfAttentionMultiHead(nn.Module):
    """Multi-head self-attention over the spatial positions of a 4-D feature map.

    The input is batch-normalised, flattened into a sequence with one token per
    spatial position, run through scaled dot-product attention split across
    ``nheads`` heads, linearly projected, reshaped back to the input shape and
    added to the original input (residual connection).

    Args:
        ic: Number of input (and output) channels.
        nheads: Number of attention heads; must be positive and divide ``ic``.

    Raises:
        ValueError: If ``nheads`` is not positive or does not evenly divide ``ic``.
    """

    def __init__(self, ic: int, nheads: int):
        super().__init__()
        # Fail fast: when ``ic`` is not a multiple of ``nheads`` the head split in
        # ``forward`` yields unevenly sized q/k/v chunks (or too few of them) and
        # the failure only shows up later as an opaque shape error.
        if nheads <= 0 or ic % nheads != 0:
            raise ValueError(
                f"ic ({ic}) must be divisible by a positive nheads ({nheads})"
            )
        self.nheads = nheads
        # Attention scores are divided by sqrt(per-head feature size).
        self.scale = math.sqrt(ic/nheads)
        self.norm = nn.BatchNorm2d(ic)
        # One linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(ic, ic*3)
        self.proj = nn.Linear(ic, ic)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Apply self-attention and return ``inp`` plus the attention output.

        Args:
            inp: Tensor of shape ``(n, c, h, w)``.

        Returns:
            Tensor with the same shape as ``inp``.
        """
        n,c,h,w = inp.shape
        # Normalise, then flatten the spatial dims: (n, c, h, w) -> (n, h*w, c).
        x = self.norm(inp).view(n, c, -1).transpose(1, 2)
        # (n, s, c) -> (n, s, 3*c)
        x = self.qkv(x)
        # Fold the heads into the batch dim: (n, s, 3*c) -> (n*nheads, s, 3*c/nheads).
        # Note: the ``h=`` keyword is einops' head axis, not the local height ``h``.
        x = rearrange(x, 'n s (h d) -> (n h) s d', h=self.nheads)
        # Each head's features split evenly into q, k, v: (n*nheads, s, c/nheads).
        q,k,v = torch.chunk(x, 3, dim=-1)
        # Scaled dot-product scores between every pair of positions: (n*nheads, s, s).
        scores = (q@k.transpose(1,2))/self.scale
        x = scores.softmax(dim=-1)@v
        # Unfold the heads back into the feature dim: (n*nheads, s, d) -> (n, s, c).
        x = rearrange(x, '(n h) s d -> n s (h d)', h=self.nheads)
        # Project, then restore the original (n, c, h, w) layout.
        x = self.proj(x).transpose(1,2).reshape(n,c,h,w)
        return x+inp

In [3]:
# Random 4-D batch used as the test input for the attention layers below.
x = torch.randn((64, 32, 16, 16))
x.size()

torch.Size([64, 32, 16, 16])

In [4]:
# Keep the first two dims, merge everything after them into one axis,
# then swap that merged axis with dim 1 so the features come last.
t = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)
t.size()

torch.Size([64, 256, 32])

In [5]:
# Build the custom attention block (32 channels, 4 heads) and run the batch through it.
sa = SelfAttentionMultiHead(ic=32, nheads=4)
sx = sa(x)
sx.size()

torch.Size([64, 32, 16, 16])

In [6]:
# Mean over every element of the attention output.
torch.mean(sx)

tensor(0.0110, grad_fn=<MeanBackward0>)

In [7]:
# PyTorch's built-in multi-head attention applied to the flattened sequence,
# with the same tensor used as query, key and value (self-attention).
# NOTE(review): this uses 8 heads while the custom module above was built
# with 4 — confirm the mismatch is intended if the two are being compared.
nm = nn.MultiheadAttention(embed_dim=32, num_heads=8, batch_first=True)
nmx, nmw = nm(query=t, key=t, value=t)
# Residual connection, mirroring the custom module's `x+inp`.
nmx = t + nmx