In [1]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, in_channels, num_heads=8, head_dim=32, groups=32) -> None:
        super().__init__()
        self.scale = head_dim ** -0.5
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.q = nn.Conv2d(in_channels, num_heads * head_dim, kernel_size=1)
        self.k = nn.Conv2d(in_channels, num_heads * head_dim, kernel_size=1)
        self.v = nn.Conv2d(in_channels, num_heads * head_dim, kernel_size=1)
        self.norm = nn.GroupNorm(groups, in_channels)
        self.proj = nn.Conv2d(num_heads * head_dim, in_channels, kernel_size=1)

    def forward(self, x):
        B, _, H, W = x.shape
        q = self.q(x).view(B, self.num_heads, self.head_dim, H * W).permute(0, 1, 3, 2)
        k = self.k(x).view(B, self.num_heads, self.head_dim, H * W)
        v = self.v(x).view(B, self.num_heads, self.head_dim, H * W).permute(0, 1, 3, 2)

        attention = torch.softmax(torch.matmul(q, k) * self.scale, dim=-1)
        attention = torch.matmul(attention, v)
        attention = attention.permute(0, 1, 3, 2).contiguous().view(B, self.num_heads * self.head_dim, H, W)
        return self.norm(x + self.proj(attention))

In [6]:
num_heads = 8
head_dim = 32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

q, k, v = torch.randn(3, 2, 64, 32, 32).to(device)

q_proj = nn.Conv2d(64, num_heads * head_dim, kernel_size=1).to(device)
k_proj = nn.Conv2d(64, num_heads * head_dim, kernel_size=1).to(device)
v_proj = nn.Conv2d(64, num_heads * head_dim, kernel_size=1).to(device)

B, _, H, W = q.shape

q = q_proj(q).view(B, num_heads, head_dim, H * W)
k = k_proj(k).view(B, num_heads, head_dim, H * W)
v = v_proj(v).view(B, num_heads, head_dim, H * W)


mha = nn.MultiheadAttention(embed_dim=head_dim * num_heads, num_heads=num_heads).to(device)
attn = mha(q, k, v, need_weights=False)



AssertionError: was expecting embedding dimension of 256, but got 262144