In [1]:
import math,torch
from torch import nn
import sys
from pathlib import Path

# Add workspace root to Python path so we can import silen_lib
workspace_root = Path.cwd().parent.parent
if str(workspace_root) not in sys.path:
    sys.path.insert(0, str(workspace_root))

import matplotlib.pyplot as plt

from diffusers.models.attention_processor import Attention

import silen_lib.utils as utils

In [2]:
utils.set_seed(42)
# Here, we simulate the output from a convolutional block (not a raw image), 
# which is why num_channels (C) is 32 instead of 1 or 3 as with input images.
x = torch.randn(64,32,16,16) # N (batch size), C (channels from conv feature map), H, W

In [3]:
# The code below reshapes the 4D tensor of shape (batch_size, channels, height, width)
# into a 3D tensor suitable for attention modules.
# - x.shape[:2] unpacks (N, C)
# - .view(*x.shape[:2], -1) collapses H*W into the last dimension: (N, C, H*W)
# - .transpose(1, 2) swaps the channels and "sequence" axes to: (N, H*W, C)
#
# This is done to treat each spatial location as an element in a sequence,
# with the "embedding" size being the channels.

# Example:
# Suppose x.shape == (64, 32, 16, 16)
# After `.view(*x.shape[:2], -1)`, shape is (64, 32, 256)
# After `.transpose(1, 2)`, shape is (64, 256, 32)

t = x.view(*x.shape[:2], -1).transpose(1, 2)
print("Original x shape:", x.shape)
print("After view and transpose, t shape:", t.shape)

Original x shape: torch.Size([64, 32, 16, 16])
After view and transpose, t shape: torch.Size([64, 256, 32])


In [4]:
ni = 32

In [5]:
sk = nn.Linear(ni, ni)
sq = nn.Linear(ni, ni)
sv = nn.Linear(ni, ni)

In [6]:
k = sk(t)
q = sq(t)
v = sv(t)

In [7]:
(q@k.transpose(1,2)).shape

torch.Size([64, 256, 256])

In [8]:
class SelfAttention(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.GroupNorm(1, ni)
        self.q = nn.Linear(ni, ni)
        self.k = nn.Linear(ni, ni)
        self.v = nn.Linear(ni, ni)
        self.proj = nn.Linear(ni, ni)
    
    def forward(self, x):
        inp = x
        n,c,h,w = x.shape
        x = self.norm(x)
        x = x.view(n, c, -1).transpose(1, 2)
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        s = (q@k.transpose(1,2))/self.scale
        x = s.softmax(dim=-1)@v
        x = self.proj(x)
        x = x.transpose(1,2).reshape(n,c,h,w)
        return x+inp

In [9]:
sa = SelfAttention(32)

In [10]:
ra = sa(x)
ra.shape

torch.Size([64, 32, 16, 16])

In [11]:
ra[0,0,0]

tensor([ 1.9104,  1.4186,  0.8385, -2.1584,  0.6318, -1.2443, -0.0789, -1.6844,
        -0.7939,  1.6117, -0.3852, -1.4307, -0.7494, -0.6010, -0.8335,  0.7477],
       grad_fn=<SelectBackward0>)

In [12]:
def cp_parms(a,b):
    b.weight = a.weight
    b.bias = a.bias

In [13]:
# New Attention class API:
# - Uses query_dim instead of just a single channels arg
# - heads=1, dim_head=32 for single-head attention matching our implementation
# - norm_num_groups=1 for group normalization
# - residual_connection=True to add input to output
at = Attention(
    query_dim=32,
    heads=1, 
    dim_head=32,
    norm_num_groups=1,
    residual_connection=True,
    bias=True
)

# New attribute names: to_q, to_k, to_v (instead of query, key, value)
# and to_out[0] for the projection (it's a ModuleList)
src = sa.q, sa.k, sa.v, sa.proj, sa.norm
dst = at.to_q, at.to_k, at.to_v, at.to_out[0], at.group_norm
for s,d in zip(src,dst): cp_parms(s,d)

In [14]:
# The new Attention class expects hidden_states (same as x)
rb = at(x)
rb[0,0,0]

tensor([ 1.9104,  1.4186,  0.8385, -2.1584,  0.6318, -1.2443, -0.0789, -1.6844,
        -0.7939,  1.6117, -0.3852, -1.4307, -0.7494, -0.6010, -0.8335,  0.7477],
       grad_fn=<SelectBackward0>)

In [15]:
sqkv = nn.Linear(ni, ni*3)
st = sqkv(t)
st.shape

torch.Size([64, 256, 96])

In [16]:
q,k,v = torch.chunk(st, 3, dim=-1)
q.shape

torch.Size([64, 256, 32])

In [17]:
(k@q.transpose(1,2)).shape

torch.Size([64, 256, 256])

In [18]:
class SelfAttention(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.BatchNorm2d(ni)
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)
    
    def forward(self, inp):
        n,c,h,w = inp.shape
        x = self.norm(inp).view(n, c, -1).transpose(1, 2)
        q,k,v = torch.chunk(self.qkv(x), 3, dim=-1)
        s = (q@k.transpose(1,2))/self.scale
        x = s.softmax(dim=-1)@v
        x = self.proj(x).transpose(1,2).reshape(n,c,h,w)
        return x+inp

In [19]:
class SelfAttention(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.BatchNorm2d(ni)
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)
    
    def forward(self, x):
        x = self.norm(x).transpose(1, 2)
        q,k,v = torch.chunk(self.qkv(x), 3, dim=-1)
        s = (q@k.transpose(1,2))/self.scale
        x = s.softmax(dim=-1)@v
        return self.proj(x).transpose(1,2)

In [20]:
sa = SelfAttention(32)
sa(x).shape

RuntimeError: mat1 and mat2 shapes cannot be multiplied (32768x16 and 32x96)

In [None]:
sa(x).std()

In [21]:
def heads_to_batch(x, heads):
    n,sl,d = x.shape
    x = x.reshape(n, sl, heads, -1)
    return x.transpose(2, 1).reshape(n*heads,sl,-1)

def batch_to_heads(x, heads):
    n,sl,d = x.shape
    x = x.reshape(-1, heads, sl, d)
    return x.transpose(2, 1).reshape(-1,sl,d*heads)

In [22]:
from einops import rearrange

In [23]:
t2 = rearrange(t , 'n s (h d) -> (n h) s d', h=8)
t.shape, t2.shape

(torch.Size([64, 256, 32]), torch.Size([512, 256, 4]))

In [24]:
t3 = rearrange(t2, '(n h) s d -> n s (h d)', h=8)

In [25]:
t2.shape,t3.shape

(torch.Size([512, 256, 4]), torch.Size([64, 256, 32]))

In [26]:
(t==t3).all()

tensor(True)

In [27]:
class SelfAttentionMultiHead(nn.Module):
    def __init__(self, ni, nheads):
        super().__init__()
        self.nheads = nheads
        # Scale factor for attention scores to help with stable gradients
        # (division by sqrt of dimension per head)
        self.scale = math.sqrt(ni/nheads)
        # Normalize input on channel dimension for more stable and performant training
        self.norm = nn.BatchNorm2d(ni)
        # Linear to compute all queries, keys, values at once
        # Output has 3x input feature dimension: for q, k, v concatenated
        self.qkv = nn.Linear(ni, ni*3)
        # Linear projection for output after attention mechanism
        self.proj = nn.Linear(ni, ni)
    
    def forward(self, inp):
        n, c, h, w = inp.shape  # n=batch, c=channels, h & w = spatial dims
        # Apply normalization, then flatten spatial dims for attention
        x = self.norm(inp).view(n, c, -1).transpose(1, 2)  # shape: (n, hw, c)
        # Compute queries, keys, values (all at once)
        x = self.qkv(x)  # shape: (n, sequence_len, c*3)
        # Rearrange so that 'n' and 'nheads' are combined into one dimension, 
        # splitting the (c*3) inner dimension into (number of heads x features per head).
        # einops notation: 
        #    - n: batch size
        #    - s: sequence length (here hw)
        #    - h: number of heads
        #    - d: features per head (should be c // nheads)
        # Reshape for multi-head: (n, s, h*d) --> (n*h, s, d)
        x = rearrange(x, 'n s (h d) -> (n h) s d', h=self.nheads)
        # Split heads into Q, K, V on last dimension (each of shape (n*h, s, d))
        q, k, v = torch.chunk(x, 3, dim=-1)
        # Compute scaled dot-product attention scores
        # k.transpose(1,2) changes (n*h, s, d) -> (n*h, d, s)
        # Resulting s: (n*h, s, s), attention map for each head & batch element
        s = (q @ k.transpose(1, 2)) / self.scale
        # Apply softmax so each row of attention adds up to 1 (probability distribution)
        # then multiply by v to get weighted representations
        x = s.softmax(dim=-1) @ v  # (n*h, s, d)
        # Rearrange back to (batch, sequence, features): (n*h, s, d) -> (n, s, h*d)
        # This undoes the earlier reshape for multi-head
        x = rearrange(x, '(n h) s d -> n s (h d)', h=self.nheads)
        # Final linear projection, then reshape back to original (n, c, h, w)
        # .transpose(1,2): (n, s, c) -> (n, c, s). .reshape(n, c, h, w) maps back to 4D
        x = self.proj(x).transpose(1, 2).reshape(n, c, h, w)
        # Add residual connection to preserve input information
        return x + inp

In [None]:
sa = SelfAttentionMultiHead(32, 4)
sx = sa(x)
sx.shape

In [None]:
sx.mean(),sx.std()

In [None]:
nm = nn.MultiheadAttention(32, num_heads=8, batch_first=True)
nmx,nmw = nm(t,t,t)
nmx = nmx+t

In [42]:
nmx.mean(),nmx.std()

(tensor(-0.0015, grad_fn=<MeanBackward0>),
 tensor(1.0034, grad_fn=<StdBackward0>))

Thoughts
- Has someone played around with the ratio of MLP params to attention params?