In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Notebook-wide floating-point dtype constant (32-bit floats), for code that
# needs to pass an explicit `dtype=`.
DTYPE = torch.float32

In [3]:
class Flatten(nn.Module):
  """Collapse every dimension after the first into a single one.

  An input of shape (N, d1, d2, ...) is returned with shape (N, d1 * d2 * ...).
  """

  def __init__(self):
    super().__init__()

  def forward(self, x):
    """Return `x` reshaped to (x.size(0), -1)."""
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. the result
    # of a transpose; it still returns a view whenever one is possible.
    return x.reshape(x.size(0), -1)

In [4]:
class BatchNorm1d(nn.Module):
  """Batch normalisation for inputs of shape (M, C, L).

  Statistics are computed per channel over the batch (M) and length (L) axes.
  In training mode the batch statistics normalise the input and update the
  running estimates; in eval mode the running estimates are used instead.

  Args:
    in_channels: number of channels C.
    eps: constant added to the variance for numerical stability.
    momentum: weight given to the newest batch statistic when updating the
      running mean / variance.
  """

  def __init__(self, in_channels, eps=1e-05, momentum=0.1):
    super().__init__()
    self.momentum = momentum
    self.eps = eps
    # Learnable scale / shift. They must be nn.Parameter so that they receive
    # gradients, follow .to(device) and are visible to parent modules.
    self.gamma = nn.Parameter(torch.ones(in_channels))
    self.beta = nn.Parameter(torch.zeros(in_channels))
    # Running statistics are state, not parameters: buffers move with the
    # module and are saved in state_dict. They always keep shape (C,).
    self.register_buffer('running_mean', torch.zeros(in_channels))
    self.register_buffer('running_var', torch.ones(in_channels))

  def forward(self, x):
    """Normalise `x` of shape (M, C, L); the output has the same shape."""
    # INPUT SHAPE (BATCH_SIZE=M, IN_CHANNELS=C, LENGTH=L)
    if self.training:
      mean = x.mean(dim=(0, 2), keepdim=True)  # (1, C, 1)
      # Biased (population) variance is used for the normalisation itself,
      # as in nn.BatchNorm1d; it is also well defined for a single sample.
      var = ((x - mean) ** 2).mean(dim=(0, 2), keepdim=True)  # (1, C, 1)
      n = x.size(0) * x.size(2)
      # The running estimates must not hold on to the autograd graph.
      with torch.no_grad():
        # Running variance tracks the unbiased estimate when it exists.
        var_sample = var * (n / (n - 1)) if n > 1 else var
        self.running_mean.mul_(1 - self.momentum).add_(mean.reshape(-1), alpha=self.momentum)
        self.running_var.mul_(1 - self.momentum).add_(var_sample.reshape(-1), alpha=self.momentum)
    else:
      # Reshape (C,) -> (1, C, 1) so the statistics broadcast over M and L.
      mean = self.running_mean.view(1, -1, 1)
      var = self.running_var.view(1, -1, 1)

    gamma = self.gamma.view(1, -1, 1)
    beta = self.beta.view(1, -1, 1)
    return ((x - mean) / torch.sqrt(var + self.eps)) * gamma + beta

  def parameters(self, recurse=True):
    """Return the learnable parameters as a list: [gamma, beta]."""
    # A list (rather than nn.Module's generator) is kept for callers that
    # index or take len() of the result.
    return list(super().parameters(recurse))

In [5]:
# Side length of the square lower-triangular mask buffer that Attention
# registers when it is constructed.
SEQ_LEN = 1

In [6]:
class Attention(nn.Module):
  """Single causal self-attention head.

  Takes x of shape (M, C, L) and returns a tensor of shape (M, L, dv).

  Args:
    in_channels: number of input channels C.
    head_size: dimension H of the query / key projections.
    dv: dimension of the value projection.
    seq_len: side length of the precomputed causal mask buffer. Defaults to
      the notebook-level SEQ_LEN. Longer sequences are still masked
      correctly; the mask is then built on the fly.
  """

  def __init__(self, in_channels, head_size, dv, seq_len=None):
    super().__init__()
    self.head_size = head_size
    self.Wk = nn.Linear(in_channels, head_size) # (C, H)
    self.Wq = nn.Linear(in_channels, head_size) # (C, H)
    self.Wv = nn.Linear(in_channels, dv) # (C, dv)

    if seq_len is None:
      seq_len = SEQ_LEN
    # Register the triangular mask as a buffer
    self.register_buffer('tril', torch.tril(torch.ones(seq_len, seq_len)))

  def forward(self, x):
    """Apply causal attention to x of shape (M, C, L); returns (M, L, dv)."""
    # x - (M, C, L)
    x = x.transpose(-2,-1)
    K = self.Wk(x) # (M, L, H)
    Q = self.Wq(x) # (M, L, H)
    V = self.Wv(x) # (M, L, dv)
    scores = Q @ K.transpose(-2, -1) / self.head_size ** 0.5 # (M, L, L)

    # Causal mask: position i may only attend to positions j <= i.
    L = scores.size(-1)
    if L <= self.tril.size(0):
      causal = self.tril[:L, :L] != 0
    else:
      # The buffer is too small for this sequence. Slicing it would yield a
      # smaller mask that broadcasts and silently disables masking, so build
      # a mask of the right size instead.
      causal = torch.ones(L, L, dtype=torch.bool, device=scores.device).tril()
    scores = scores.masked_fill(~causal, float('-inf'))

    attn_weights = torch.softmax(scores, dim=-1)
    return attn_weights @ V

  def parameters(self, recurse=True):
    """Return the learnable tensors (weights and biases of Wk, Wq, Wv) as a list."""
    # Optimizers need tensors, not the nn.Linear modules themselves.
    return list(super().parameters(recurse))

In [7]:
class MultiHeadAttention(nn.Module):
  """Runs several Attention heads on the same input and mixes their outputs.

  Each head receives x of shape (M, C, L). The head outputs are concatenated
  along the last dimension and projected back to `in_channels` by `Wo`.

  Args:
    in_channels: number of input channels C.
    n_heads: number of Attention heads.
    head_size: query / key dimension of every head.
    dv: value dimension of every head.
  """

  def __init__(self, in_channels, n_heads, head_size, dv):
    super().__init__()
    self.head_size = head_size
    # nn.ModuleList (not a plain list) registers the heads as submodules, so
    # their weights appear in parameters() / state_dict() and follow .to().
    self.heads = nn.ModuleList([Attention(in_channels, head_size, dv) for _ in range(n_heads)])
    self.Wo = nn.Linear(dv * n_heads, in_channels)

  def forward(self, x):
    """Apply every head to x (M, C, L) and project the concatenated result."""
    # x - (M, C, L)
    # The per-head outputs stay available as an attribute after the call.
    self.heads_output = [head(x) for head in self.heads]
    return self.Wo(torch.cat(self.heads_output, dim=-1))

In [8]:
# MAX_BATCH_SIZE and MAX_SEQ_LEN give the batch and sequence capacity of the
# key/value caches in GroupedQueryAttention; MAX_SEQ_LEN also sizes its
# triangular mask buffer. N_KV_HEADS and HEAD_DIM are not referenced in the
# cells visible here.
MAX_BATCH_SIZE, MAX_SEQ_LEN, N_KV_HEADS, HEAD_DIM = 1,1,1,1

In [9]:
class GroupedQueryAttention(nn.Module):
    """Causal grouped-query attention with a key/value cache.

    `n_heads` query heads share `n_kv_heads` key/value heads; each key/value
    head serves `n_heads // n_kv_heads` consecutive query heads.

    Args:
        in_channels: number of input channels C.
        head_size: dimension of every head.
        n_heads: number of query heads.
        n_kv_heads: number of key/value heads; must divide `n_heads`.
        max_batch_size: batch capacity of the cache. Defaults to the
            notebook-level MAX_BATCH_SIZE.
        max_seq_len: sequence capacity of the cache and of the causal mask.
            Defaults to the notebook-level MAX_SEQ_LEN.

    Raises:
        ValueError: if `n_heads` is not a multiple of `n_kv_heads`.
    """

    def __init__(self, in_channels, head_size, n_heads, n_kv_heads,
                 max_batch_size=None, max_seq_len=None):
        super().__init__()
        if n_heads % n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be a multiple of n_kv_heads ({n_kv_heads})")
        if max_batch_size is None:
            max_batch_size = MAX_BATCH_SIZE
        if max_seq_len is None:
            max_seq_len = MAX_SEQ_LEN

        self.n_heads = n_heads
        self.head_size = head_size
        self.n_kv_heads = n_kv_heads
        self.Wq = nn.Linear(in_channels, n_heads * head_size)
        self.Wk = nn.Linear(in_channels, n_kv_heads * head_size)
        self.Wv = nn.Linear(in_channels, n_kv_heads * head_size)
        self.Wo = nn.Linear(n_heads * head_size, in_channels)

        # Create empty caches for keys and values. Non-persistent buffers
        # follow the module across devices without entering state_dict.
        cache_shape = (max_batch_size, max_seq_len, n_kv_heads, head_size)
        self.register_buffer('cache_k', torch.zeros(cache_shape), persistent=False)
        self.register_buffer('cache_v', torch.zeros(cache_shape), persistent=False)
        # Register the triangular mask as a buffer
        self.register_buffer('tril', torch.tril(torch.ones(max_seq_len, max_seq_len)))

    def forward(self, x, start_pos=0, mask=None):
        """Attend over the cached prefix plus the new positions in `x`.

        Args:
            x: input of shape (M, C, L).
            start_pos: absolute position of x[..., 0] in the sequence; cache
                entries before it are used as context.
            mask: optional attention mask broadcastable to
                (M, n_heads, L, start_pos + L); boolean (True = may attend)
                or additive float. When None a causal mask is used.

        Returns:
            Tensor of shape (M, L, C).

        Raises:
            ValueError: if the batch or sequence does not fit in the cache.
        """
        M, C, L = x.shape
        end_pos = start_pos + L
        if M > self.cache_k.size(0) or end_pos > self.cache_k.size(1):
            raise ValueError(
                f"batch size {M} / sequence end {end_pos} exceed the cache capacity "
                f"({self.cache_k.size(0)}, {self.cache_k.size(1)})")

        x = x.transpose(-2, -1)  # (M, L, C)
        Q = self.Wq(x).view(M, L, self.n_heads, self.head_size)
        K = self.Wk(x).view(M, L, self.n_kv_heads, self.head_size)
        V = self.Wv(x).view(M, L, self.n_kv_heads, self.head_size)

        # Simple caching. The write is done without gradient tracking so the
        # cache never keeps a previous step's autograd graph alive.
        with torch.no_grad():
            self.cache_k[:M, start_pos:end_pos] = K
            self.cache_v[:M, start_pos:end_pos] = V
        # Context = cached prefix + the fresh K/V, so gradients still reach
        # the current step's projections.
        K = torch.cat([self.cache_k[:M, :start_pos].to(K.dtype), K], dim=1)
        V = torch.cat([self.cache_v[:M, :start_pos].to(V.dtype), V], dim=1)

        # Every key/value head is shared by n_rep consecutive query heads.
        n_rep = self.n_heads // self.n_kv_heads
        K = torch.repeat_interleave(K, repeats=n_rep, dim=2)
        V = torch.repeat_interleave(V, repeats=n_rep, dim=2)

        Q = Q.transpose(-3, -2)  # (M, L, n_heads, head_size) -> (M, n_heads, L, head_size)
        K = K.transpose(-3, -2)  # (M, S, n_heads, head_size) -> (M, n_heads, S, head_size)
        V = V.transpose(-3, -2)  # (M, S, n_heads, head_size) -> (M, n_heads, S, head_size)

        if mask is None:
            # Query i sits at absolute position start_pos + i and may see keys
            # 0..start_pos + i. The mask must be boolean: a float mask is
            # added to the scores instead of masking them.
            mask = self.tril[start_pos:end_pos, :end_pos] != 0  # (L, S)

        out = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)
        return self.Wo(out.transpose(-3, -2).reshape(M, L, -1))