<a href="https://colab.research.google.com/github/QasimWani/simple-transformer/blob/main/transformers/attention_variants.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
### This notebook implements several attention variants

# Variants:
# 1. SingleQuery Dot Attention - 2m
# 2. SHA - 5m
# 3. MHA - 8m (with causal + attention mask)
# 4. SingleQuery MHA - 10m
# 5. MQA - 7m
# 6. SingleQuery MQA - 10m
# 7. GQA - 15m
# 8. Sliding Window Attention - 7m
# 9. Custom Hand-Rolled QKV attention - 10m
# 10. RoPE - 9m

In [2]:
import torch
from einops import rearrange
import torch.nn as nn
import torch.nn.functional as F

In [3]:
## Define Constants
batch_size, seq_len, d_embed, max_seq_len, num_heads = 4, 1024, 128, 4096, 16

# Key-padding mask of shape (batch_size, seq_len): True = real token, False = <pad>.
# Draw ONE length per sequence in [1, seq_len] (size=(batch_size, 1)) so every row is a
# contiguous run of real tokens followed by padding; broadcasting against arange(seq_len)
# expands it to (batch_size, seq_len). Lengths start at 1, so position 0 is always real.
attention_mask = torch.arange(seq_len)[None, :] < torch.randint(low=1, high=(seq_len+1), size=(batch_size, 1))
X = torch.randn(batch_size, seq_len, d_embed)  # dummy input sequence
QUERY = torch.rand(batch_size, d_embed)  # dummy single query / new token per batch element

In [4]:
class SingleQueryDotAttention(nn.Module):
  '''
  Dot-product attention for a single query vector per batch element.

  From https://arxiv.org/pdf/1911.02150, section 2.1
    Attention(q, K, V) = sum_i alpha_i * V_i,  alpha = softmax_i(q . K_i)

  Keys and values are learned linear projections of `x`. The query is used as given
  (it is not projected), and the logits are not scaled by 1/sqrt(d).
  '''
  def __init__(self, d_embed: int):
    super().__init__()

    self.k = nn.Linear(d_embed, d_embed)  # key projection
    self.v = nn.Linear(d_embed, d_embed)  # value projection

  def forward(self, query: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    '''
    query: (batch_size, d_embed); x: (batch_size, seq_len, d_embed) -> (batch_size, d_embed)
    '''
    keys, values = self.k(x), self.v(x)  # each (batch_size, seq_len, d_embed)

    alpha = torch.softmax(torch.einsum('bmd,bd->bm', keys, query), dim=-1)  # weight per position
    return torch.einsum('bm,bmd->bd', alpha, values)

In [5]:
m_sqda = SingleQueryDotAttention(d_embed)
sqda_out = m_sqda(QUERY, X)  # one attended vector per batch element
assert sqda_out.shape == (batch_size, d_embed)

In [6]:
class SHA(nn.Module):
  '''
  Single-head scaled dot-product self-attention (no mask) with a fused QKV projection.
  '''
  def __init__(self, d_embed: int):
    super().__init__()

    self.qkv = nn.Linear(d_embed, d_embed * 3)  # fused Q/K/V projection
    self.w_out = nn.Linear(d_embed, d_embed)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    '''
    x: (batch_size, seq_len, d_embed) -> (batch_size, seq_len, d_embed)

    Time complexity: O(bsd^2 + bs^2d), where s = seq_len, d = d_embed, b = batch_size
    Space complexity: O(bsd + bs^2)

    Rule of thumb: (m x k) @ (k x n) costs O(mkn) time and its result takes O(mn) space;
    batching multiplies both by b.
    '''
    b, s, d = x.shape

    # (s x d) @ (d x 3d) -> O(bsd^2). The fused features are grouped as (d_embed, 3):
    # feature 3*i + t feeds Q (t=0), K (t=1) or V (t=2).
    q, k, v = self.qkv(x).reshape(b, s, d, 3).unbind(dim=-1)

    # (s x d) @ (d x s) -> O(bs^2 d) time, O(bs^2) space for the score matrix
    logits = torch.einsum('bid,bjd->bij', q, k) * (1 / (d ** 0.5))
    probs = torch.softmax(logits, dim=-1)  # normalise over the key axis

    # (s x s) @ (s x d) -> O(bs^2 d), same order as above
    context = torch.einsum('bij,bjd->bid', probs, v)

    # (s x d) @ (d x d) -> O(bsd^2), same order as the QKV projection
    return self.w_out(context)

In [7]:
m_sha = SHA(d_embed)
sha_out = m_sha(X)
assert sha_out.shape == X.shape  # (batch_size, seq_len, d_embed)

In [8]:
class MHA(nn.Module):
  '''
  Causal multi-head self-attention with an optional key-padding mask.
  '''
  def __init__(self, d_embed: int, num_heads: int, max_seq_len: int):
    super().__init__()
    assert d_embed % num_heads == 0, "d_embed needs to be divisible by number of heads"

    self.head_dim = d_embed // num_heads

    self.q = nn.Linear(d_embed, d_embed)
    self.k = nn.Linear(d_embed, d_embed)
    self.v = nn.Linear(d_embed, d_embed)

    self.w_out = nn.Linear(d_embed, d_embed)

    # True strictly above the diagonal, i.e. key position > query position (the "future").
    causal_mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
    self.register_buffer('mask', causal_mask)

  def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
    # (batch_size, seq_len, num_heads * head_dim) -> (batch_size, num_heads, seq_len, head_dim)
    batch_size, seq_len, _ = t.shape
    return t.reshape(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)

  def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
    '''
    Args:
      x: (batch_size, seq_len, d_embed)
      attention_mask: optional (batch_size, seq_len) key-padding mask; nonzero / True = real
        token, 0 / False = <pad>. Bool, integer and float masks are all accepted.
    Returns:
      (batch_size, seq_len, d_embed)
    Raises:
      ValueError: if seq_len is larger than the max_seq_len the causal mask was built for.

    A sequence whose mask row is all padding leaves every query without a visible key,
    so its softmax rows (and outputs) are NaN.

    Time complexity: O(bsd^2 + bs^2d), where s = seq_len, d = d_embed, b = batch_size
    Space complexity: O(bsd + bs^2)
    '''
    batch_size, seq_len, d_embed = x.shape
    max_seq_len = self.mask.shape[-1]
    if seq_len > max_seq_len:
      raise ValueError(f"seq_len={seq_len} exceeds max_seq_len={max_seq_len}")

    Q = self._split_heads(self.q(x))
    K = self._split_heads(self.k(x))
    V = self._split_heads(self.v(x))

    # Compute attention
    weight = torch.einsum('bhqd,bhkd->bhqk', Q, K) / (self.head_dim ** 0.5)
    # causal mask, broadcast over batch and heads
    mask = self.mask[:seq_len, :seq_len][None, None, :, :]  # (1, 1, seq_len, seq_len)
    if attention_mask is not None:
      # Cast to bool BEFORE inverting: `~` on an integer tensor is bitwise NOT (~1 == -2,
      # ~0 == -1), which does not give a valid boolean pad mask.
      is_pad = ~attention_mask.to(device=x.device, dtype=torch.bool)
      mask = mask | is_pad[:, None, None, :]  # (batch_size, 1, seq_len, seq_len)
    weight = weight.masked_fill(mask, float('-inf'))
    score = torch.softmax(weight, dim=-1)

    attention = torch.einsum('bhqk,bhkd->bhqd', score, V)
    attention = attention.transpose(1, 2).reshape(batch_size, seq_len, d_embed)
    out = self.w_out(attention)
    return out

In [9]:
m_mha = MHA(d_embed, num_heads=num_heads, max_seq_len=max_seq_len)
mha_out = m_mha(X, attention_mask=attention_mask)  # causal + key-padding masks
assert mha_out.shape == (batch_size, seq_len, d_embed)

In [10]:
class SingleQueryMHA(nn.Module):
  '''
  One incremental decoding step of multi-head attention with a KV cache.
  From https://arxiv.org/pdf/1911.02150, section 2.4
  '''
  def __init__(self, d_embed: int, num_heads: int):
    super().__init__()
    assert d_embed % num_heads == 0
    self.num_heads = num_heads
    self.head_dim = d_embed // num_heads

    self.q = nn.Linear(d_embed, d_embed)
    self.k = nn.Linear(d_embed, d_embed)
    self.v = nn.Linear(d_embed, d_embed)

    self.w_out = nn.Linear(d_embed, d_embed)

  def forward(self, x: torch.Tensor, prev_k: torch.Tensor = None, prev_v: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Notation: b = batch_size, n = num_heads, h = head_dim, d = n * h, m = cached length.

    Args:
      x: the new token, (b, d)
      prev_k, prev_v: cached keys / values, (b, n, m, h), oldest position first.
        Pass None for both on the first step (empty cache).
    Returns:
      out: (b, d) attention output for the new token
      K, V: updated caches, (b, n, m+1, h), with the new token's key / value appended
        at index m (the end), as in the paper's concat([prev_K, k]).
    Raises:
      ValueError: if exactly one of prev_k / prev_v is None.
    '''
    if (prev_k is None) != (prev_v is None):
      raise ValueError("prev_k and prev_v must both be given or both be None")
    b = x.shape[0]
    n, h = self.num_heads, self.head_dim

    q = self.q(x).view(b, n, h)     # b, n, h
    k = self.k(x).view(b, n, 1, h)  # b, n, 1, h
    v = self.v(x).view(b, n, 1, h)  # b, n, 1, h

    # Append (not prepend) the new entry so the cache stays ordered oldest -> newest,
    # i.e. cache index == token position when the cache is built step by step.
    K = k if prev_k is None else torch.concat((prev_k, k), dim=2)  # b, n, m+1, h
    V = v if prev_v is None else torch.concat((prev_v, v), dim=2)  # b, n, m+1, h

    # Attention of the single query over all m+1 positions
    logits = torch.einsum('bnh,bnmh->bnm', q, K) / (h ** 0.5)
    scores = torch.softmax(logits, dim=-1)

    attention = torch.einsum('bnm,bnmh->bnh', scores, V).reshape(b, n * h)  # b, d
    out = self.w_out(attention)  # b, d
    return out, K, V

In [11]:
m_sqmha = SingleQueryMHA(d_embed, num_heads=num_heads)
# KV cache for `seq_len` previous tokens: (batch_size, num_heads, seq_len, head_dim)
prev_k, prev_v = torch.rand(2, batch_size, num_heads, seq_len, m_sqmha.head_dim)

out = m_sqmha(QUERY, prev_k, prev_v)

assert out[0].shape == (batch_size, d_embed)
# the cache keeps one slot per head and grows by exactly one position
assert out[1].shape == out[2].shape == (batch_size, num_heads, seq_len + 1, m_sqmha.head_dim)

In [12]:
class MQA(nn.Module):
  def __init__(self, d_embed: int, num_heads: int):
    '''
    Multi-query attention (MQA): num_heads query heads, one shared key/value head.
    Introduced in Section 3 of https://arxiv.org/pdf/1911.02150
    '''
    super().__init__()
    assert d_embed % num_heads == 0
    self.head_dim = d_embed // num_heads

    self.q = nn.Linear(d_embed, d_embed)        # every query head
    self.k = nn.Linear(d_embed, self.head_dim)  # single shared key head
    self.v = nn.Linear(d_embed, self.head_dim)  # single shared value head

    self.w_out = nn.Linear(d_embed, d_embed)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    '''
    Identical to MHA except different heads all share a single set of keys and values.

    x: (batch_size, seq_len, d_embed) -> (batch_size, seq_len, d_embed). No mask is applied.
    '''
    batch, length, _ = x.shape
    heads = self.q(x).reshape(batch, length, -1, self.head_dim)  # b, m, n, h
    shared_k, shared_v = self.k(x), self.v(x)                    # b, m, h each

    logits = torch.einsum('bqnh,bkh->bnqk', heads, shared_k) / (self.head_dim ** 0.5)
    probs = torch.softmax(logits, dim=-1)  # over keys
    mixed = torch.einsum('bnqk,bkh->bqnh', probs, shared_v)      # b, m, n, h
    return self.w_out(mixed.reshape(batch, length, -1))          # b, m, d


In [13]:
m_mqa = MQA(d_embed, num_heads=num_heads)
out = m_mqa(X)
expected_shape = (batch_size, seq_len, d_embed)
assert out.shape == expected_shape

In [14]:
class SingleQueryMQA(nn.Module):
  def __init__(self, d_embed: int, num_heads: int):
    '''
    One incremental decoding step of multi-query attention with a KV cache:
    num_heads query heads share a single key/value head of size head_dim.
    Combination of SingleQuery MHA and MQA. Introduced in https://arxiv.org/pdf/1911.02150
    '''
    super().__init__()
    assert d_embed % num_heads == 0
    self.head_dim = d_embed // num_heads

    self.q = nn.Linear(d_embed, d_embed)        # all query heads
    self.k = nn.Linear(d_embed, self.head_dim)  # the single shared key head
    self.v = nn.Linear(d_embed, self.head_dim)  # the single shared value head

    self.w_out = nn.Linear(d_embed, d_embed)

  def forward(self, x: torch.Tensor, prev_k: torch.Tensor, prev_v: torch.Tensor) -> torch.Tensor:
    '''
    Args:
      x: the new token, (batch_size, d_embed)
      prev_k, prev_v: cached shared keys / values, (batch_size, seq_len, head_dim)
    Returns:
      (out, K, V): out is (batch_size, d_embed); K and V are the caches with the new
      token's key / value appended at the end, (batch_size, seq_len + 1, head_dim).
    '''
    batch, _ = x.shape
    heads = self.q(x).reshape(batch, -1, self.head_dim)  # b, n, h

    # append the new token's key / value after the cached positions
    K = torch.concat((prev_k, self.k(x)[:, None, :]), dim=1)  # b, m+1, h
    V = torch.concat((prev_v, self.v(x)[:, None, :]), dim=1)  # b, m+1, h

    logits = torch.einsum('bnh,bmh -> bnm', heads, K) / (self.head_dim ** 0.5)
    weights = torch.softmax(logits, dim=-1)  # every head attends over the same m+1 keys
    mixed = torch.einsum('bnm,bmh -> bnh', weights, V)  # b, n, h
    return self.w_out(mixed.reshape(batch, -1)), K, V


In [15]:
m_sqmqa = SingleQueryMQA(d_embed, num_heads=num_heads)
# shared KV cache for `seq_len` previous tokens: (batch_size, seq_len, head_dim)
prev_k, prev_v = torch.rand(2, batch_size, seq_len, m_sqmqa.head_dim)

out = m_sqmqa(QUERY, prev_k, prev_v)

assert out[0].shape == (batch_size, d_embed)
assert out[1].shape == out[2].shape == (batch_size, seq_len + 1, m_sqmqa.head_dim)

In [16]:
class GQA(nn.Module):
  def __init__(self, d_embed: int, num_heads: int, num_groups: int):
    '''
    Grouped-query attention (GQA): num_heads query heads share num_groups key/value heads.
    Balance between MQA and MHA.
    If GQA(num_groups=1), MQA.
    If GQA(num_groups=num_heads), MHA.

    Query head i reads key/value group i % num_groups.
    '''
    super().__init__()
    assert d_embed % num_heads == 0
    # Without this check num_heads // num_groups silently rounds down and the forward
    # pass later fails with an opaque einsum shape error.
    assert num_groups > 0 and num_heads % num_groups == 0, "num_heads needs to be divisible by num_groups"
    self.head_dim = d_embed // num_heads
    self.num_heads = num_heads
    self.num_groups = num_groups

    self.q = nn.Linear(d_embed, d_embed)
    self.k = nn.Linear(d_embed, num_groups * self.head_dim)
    self.v = nn.Linear(d_embed, num_groups * self.head_dim)

    self.w_out = nn.Linear(d_embed, d_embed)

  def forward(self, x: torch.Tensor):
    '''
    x: (batch_size, seq_len, d_embed) -> (batch_size, seq_len, d_embed). No mask is applied.

    Query heads are indexed as i = r * num_groups + g (r = repeat index, g = group), so
    head i attends with key/value group g = i % num_groups.
    '''
    b, m, _ = x.shape
    g, h = self.num_groups, self.head_dim
    r = self.num_heads // g  # query heads per key/value group

    Q = self.q(x).reshape(b, m, r, g, h).permute(0, 2, 3, 1, 4)  # b, r, g, m, h
    K = self.k(x).reshape(b, m, g, h).transpose(1, 2)             # b, g, m, h
    V = self.v(x).reshape(b, m, g, h).transpose(1, 2)             # b, g, m, h

    # The r axis exists only on Q, so einsum broadcasts each K/V group across its r query
    # heads. There is no need to build a repeated copy of K/V first (expand() alone is a view,
    # but merging the expanded axis into the head axis with flatten/reshape copies anyway).
    logits = torch.einsum('brgqh,bgkh->brgqk', Q, K) / (h ** 0.5)  # b, r, g, q-seq_len, k-seq_len
    scores = torch.softmax(logits, dim=-1)
    attention = torch.einsum('brgqk,bgkh->brgqh', scores, V)      # b, r, g, m, h

    attention = attention.permute(0, 3, 1, 2, 4).reshape(b, m, -1)  # b, m, d
    out = self.w_out(attention)
    return out


In [17]:
m_gqa = GQA(d_embed, num_heads=num_heads, num_groups=2)  # num_heads query heads share 2 K/V heads
gqa_out = m_gqa(X)
assert gqa_out.shape == X.shape  # (batch_size, seq_len, d_embed)

In [18]:
class SlidingWindowAttention(nn.Module):

  def __init__(self, d_embed: int, window: int):
    '''
    SHA version that implements a bidirectional sliding window. Trivial to extend to MHA / other variants.

    Query i may attend to key j iff |i - j| <= window, i.e. `window` tokens on each side
    plus itself. window=0 therefore attends only to the token itself.
    '''
    super().__init__()
    assert window >= 0, "window needs to be non-negative"
    self.q = nn.Linear(d_embed, d_embed)
    self.k = nn.Linear(d_embed, d_embed)
    self.v = nn.Linear(d_embed, d_embed)

    self.w_out = nn.Linear(d_embed, d_embed)
    self.window = window

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    '''
    x: (batch_size, seq_len, d_embed) -> (batch_size, seq_len, d_embed)
    '''
    b, m, d = x.shape # batch_size, seq_len, d_embed

    Q = self.q(x) # b, m, d
    K = self.k(x) # b, m, d
    V = self.v(x) # b, m, d

    # Build the mask on x's device. True marks keys OUTSIDE the window (to be hidden);
    # the diagonal is never masked, so no row is all -inf.
    pos = torch.arange(m, device=x.device)
    mask = (pos[:, None] - pos[None, :]).abs() > self.window  # m, m

    # Attention (compute rest as-is)
    logits = torch.einsum('bqd,bkd -> bqk', Q, K) / (d ** 0.5)
    scores = logits.masked_fill(mask, float('-inf'))
    scores = torch.softmax(scores, dim=-1)

    attention = torch.einsum('bqk,bkd -> bqd', scores, V)
    out = self.w_out(attention)
    return out


In [19]:
m_swa = SlidingWindowAttention(d_embed, window=5)
swa_out = m_swa(X)
assert swa_out.shape == X.shape  # (batch_size, seq_len, d_embed)

In [20]:
# Build a custom Hand-rolled QKV Attention with Custom Scaling

# Problem: Implement a QKV attention module from scratch in PyTorch without using built-in attention functions. For input X (shape [batch_size, seq_len, d_model]),
# project to Q, K, V (each [batch_size, seq_len, d_k]). Compute attention as Attn = softmax( (Q @ K^T / sqrt(d_k)) + scaling_matrix ),
# where scaling_matrix_{i,j} = log(1 + |i - j|), then output = Attn @ V. Use einsum for all matrix operations to ensure efficiency.

# Include a forward pass test with dummy data.

# Follow-up ML questions:
# 1. Explain why the sqrt(d_k) scaling is typically used in standard attention, and how the added log(1 + |i - j|) might affect long-range dependencies.
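# Answer: if the components of q and k are independent with roughly unit variance, q . k is a sum of d_k terms with
#         variance ~d_k, so dividing by sqrt(d_k) keeps the logits at roughly unit variance and stops the softmax from
#         saturating as d_k grows. The added log(1 + |i - j|) term is a positive bias that grows with distance: as the
#         context window increases, the logits for long-range pairs get larger, which pushes the model to attend to distant
#         tokens more than to nearby ones. The bias also shifts the logits unevenly across positions and, for long sequences,
#         can dominate the scaled q . k term, largely nullifying the benefit of the sqrt(d_k) scaling.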
# 2. If this were part of a multi-layer transformer, what gradients might become unstable during backpropagation, and why?
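# Answer: the gradients flowing into the Q and K projections. The distance bias is added before the softmax and can be
#         relatively large, so the softmax saturates (almost all weight tends to land on a few distant keys); in that regime
#         the softmax Jacobian is close to zero, so the gradients w.r.t. the q . k logits (and hence W_q, W_k) vanish.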

In [21]:
class HandRolledQKV(nn.Module):
  '''
  Single-head attention with an additive distance bias:
    Attn = softmax(Q K^T / sqrt(d) + S),  S[i, j] = log(1 + |i - j|),  out = w_out(Attn V)
  '''
  def __init__(self, d_embed: int):
    super().__init__()

    self.qkv = nn.Linear(d_embed, d_embed * 3)  # fused projection, laid out as [Q | K | V]
    self.w_out = nn.Linear(d_embed, d_embed)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    '''
    x: (batch_size, seq_len, d_embed) -> (batch_size, seq_len, d_embed)
    '''
    b, m, d = x.shape # batch_size, seq_len, d_embed

    # The fused last axis is split as (three d): consecutive Q, K, V blocks of size d.
    Q, K, V = self.qkv(x).chunk(3, dim=-1)  # each b, m, d

    # Attention
    logits = torch.einsum('bqd,bkd -> bqk', Q, K) / (d ** 0.5)
    # scaling_matrix[i, j] = log(1 + |i - j|), with i indexing queries and j indexing keys
    _, q_seq_len, k_seq_len = logits.shape

    # Vectorized form of the double loop over (i, j): broadcast a column of query
    # indices against a row of key indices. Indices live on the logits' device.
    q_idx = torch.arange(q_seq_len, device=logits.device).unsqueeze(1) # Q, 1
    k_idx = torch.arange(k_seq_len, device=logits.device).unsqueeze(0) # 1, K
    # log1p turns the integer distances into default-dtype floats; cast to the logits'
    # dtype so half / bfloat16 inputs keep their dtype through softmax and the V einsum.
    scaling_matrix = torch.log1p(torch.abs(q_idx - k_idx)).to(logits.dtype) # Q, K

    score = torch.softmax(logits + scaling_matrix, dim=-1)

    attention = torch.einsum('bqk,bkd -> bqd', score, V)

    out = self.w_out(attention)
    return out



In [22]:
m_customQKV = HandRolledQKV(d_embed)
custom_out = m_customQKV(X)
assert custom_out.shape == X.shape  # (batch_size, seq_len, d_embed)

In [23]:
class RoPE(nn.Module):

  def __init__(self, d_embed: int, base: float = 10000.0):
    '''
    Rotary position embedding (RoFormer, https://arxiv.org/abs/2104.09864).

    Each adjacent feature pair (x[2i], x[2i+1]) at position p is rotated by the angle
    p * theta_i, where theta_i = base^(-2i / d_embed) for i = 0 .. d_embed/2 - 1.
    `base` defaults to 10000 as in the paper.
    '''
    super().__init__()
    assert d_embed % 2 == 0, "d_embed needs to be even for RoPE"

    two_i = torch.arange(0, d_embed, 2)  # 0, 2, 4, ..., d_embed - 2
    freqs = base ** (-two_i / d_embed)   # theta_i = base^(-2i / d_embed), shape (d_embed / 2,)
    self.register_buffer('freqs', freqs)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    '''
    x: (batch_size, seq_len, d_embed) -> same shape; the token at index p is rotated by p * theta.
    The output keeps the input's interleaved pair layout, so position 0 is returned unchanged.
    '''
    b, m, d = x.shape # batch_size, seq_len, d_embed

    x_even = x[..., 0::2]  # b, m, d/2 -- first element of each pair
    x_odd = x[..., 1::2]   # b, m, d/2 -- second element of each pair

    positions = torch.arange(m, device=self.freqs.device, dtype=self.freqs.dtype) # seq_len
    angles = torch.outer(positions, self.freqs) # seq_len, d_embed / 2

    # Rotate each pair by +angle (RoFormer convention):
    # [cos x  -sin x] [x0] = [x0*cosx - x1*sinx]
    # [sin x   cos x] [x1]   [x0*sinx + x1*cosx]
    cos = torch.cos(angles).to(x.dtype) # m, d/2 (broadcasts over batch)
    sin = torch.sin(angles).to(x.dtype) # m, d/2

    rotated = torch.stack((
        x_even * cos - x_odd * sin,
        x_even * sin + x_odd * cos,
    ), dim=-1) # b, m, d/2, 2

    # Re-interleave so output pair i occupies features (2i, 2i+1), exactly like the input.
    return rotated.flatten(-2) # b, m, d


In [24]:
m_rope = RoPE(d_embed)
rope_out = m_rope(X)

assert rope_out.shape == X.shape  # RoPE preserves (batch_size, seq_len, d_embed)