In [None]:
# !pip install --quiet matplotlib
# !pip install --quiet seaborn
# !pip install --quiet tqdm
# ! pip install --quiet ipywidgets

In [1]:
# https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial15/Vision_Transformer.html

In [48]:
## Standard libraries
import os
import numpy as np
import math
import json
from functools import partial



## Imports for plotting

import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf')  # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## tqdm for loading bars
from tqdm.auto import tqdm

## To run JAX on TPU in Google Colab, uncomment the two lines below
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()

## JAX
import jax
import jax.numpy as jnp
from jax import random
# Seeding for random operations
main_rng = random.PRNGKey(42)

## Flax (NN in JAX)
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax
from flax import linen as nn
from flax.training import train_state, checkpoints

## Optax (Optimizers in JAX)
try:
    import optax
except ModuleNotFoundError: # Install optax if missing
    !pip install --quiet optax
    import optax

## PyTorch
import torch
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR100


<Figure size 640x480 with 0 Axes>

In [2]:

# Filesystem locations used by this notebook (relative to the current working directory):
#   DATASET_PATH    -> folder where datasets are / should be stored
#   CHECKPOINT_PATH -> folder where pretrained model checkpoints are saved
DATASET_PATH, CHECKPOINT_PATH = "../data", "../saved_models/tutorial6_jax"


In [49]:

# Report the first device JAX lists (on a CPU-only machine this is the CPU device, e.g. TFRT_CPU_0).
print("Device:", jax.devices()[0])

Device: TFRT_CPU_0


#### Multihead Attention

In [6]:
def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention: softmax(q k^T / sqrt(d_k)) v.

    Args:
        q: queries, shape (..., seq_q, d_k).
        k: keys, shape (..., seq_k, d_k).
        v: values, shape (..., seq_k, d_v).
        mask: optional array broadcastable to (..., seq_q, seq_k); entries equal
            to 0 are excluded by replacing their logit with -9e15.

    Returns:
        (values, attention): attention-weighted values (..., seq_q, d_v) and the
        attention weights (..., seq_q, seq_k).
    """
    scale = math.sqrt(q.shape[-1])
    keys_t = jnp.swapaxes(k, -2, -1)
    logits = jnp.matmul(q, keys_t) / scale
    if mask is not None:
        logits = jnp.where(mask == 0, -9e15, logits)
    weights = nn.softmax(logits, axis=-1)
    return jnp.matmul(weights, v), weights

In [None]:
# Sanity-check scaled_dot_product on a tiny unbatched example:
# q, k and v each have shape (seq_len, d_k).
seq_len, d_k = 3, 2
main_rng, rand1 = random.split(main_rng)
qkv = random.normal(rand1, (3, seq_len, d_k))  # stacked [q, k, v]; was `L`, an undefined name
q, k, v = qkv[0], qkv[1], qkv[2]
values, attention = scaled_dot_product(q, k, v)
assert attention.shape == (seq_len, seq_len) and values.shape == (seq_len, d_k)
print("Q\n", q)
print("K\n", k)
print("V\n", v)
print("Values\n", values)
print("Attention\n", attention)

Q
 [[ 0.60576403  0.7990441 ]
 [-0.908927   -0.63525754]
 [-1.2226585  -0.83226097]]
K
 [[-0.47417238 -1.2504351 ]
 [-0.17678244 -0.04917514]
 [-0.41177532 -0.39363015]]
V
 [[ 1.3116323   0.21555556]
 [ 0.41164538 -0.28955024]
 [-0.96516913  0.4492738 ]]
Values
 [[0.12734914 0.06441191]
 [0.4115729  0.17320421]
 [0.46902645 0.1854193 ]]
Attention
 [[0.20383833 0.4564296  0.33973208]
 [0.46830934 0.2255167  0.30617398]
 [0.51187545 0.19520193 0.29292265]]


#### Step 1: Understand Multi-head Latent Attention (MLA)

MLA modifies the standard Multi-head Attention (MHA) by:

 1. **Compressing Keys and Values:** Instead of storing full KV matrices, MLA projects them into a low-dimensional latent space (e.g., from $d_{model}$ to $d_c$ (`latent_dim` in the code), where $d_c \ll d_{model}$) using a down-projection matrix.
 2. **Reconstructing Keys and Values:** During attention computation, latent vectors are up-projected back to the original dimension using separate matrices for keys and values.
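 3. **Decoupled RoPE:** Positional information (via RoPE) is carried by a separate, small set of query/key dimensions ($d_h^R$ per head, with a single rotary key shared by all heads). This keeps RoPE compatible with KV compression. Because RoPE is position-dependent, applying it to the up-projected keys would prevent absorbing $W^{UK}$ into the query projection at inference time.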
 4. **Query Handling:** Queries can also be compressed into a latent space and then up-projected, though this is optional depending on the design.

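 The goal is to reduce the KV cache size (e.g., from $O(n_h d_h)$ to $O(d_c)$ per token, plus $d_h^R$ for the shared rotary key when decoupled RoPE is used). Here $n_h$ (`num_heads`) is the **number of heads**, $d_h$ is **the head dimension**, and $d_c$ (`latent_dim`) is the latent dimension. Attention quality should be preserved despite the smaller cache.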

 $X \in \mathbb{R}^{T \times D}$, where $T$ represents the sequence length and $D$ is the hidden dimension.

#### Step 2: Define the Architecture

- Down-projection: A linear layer to compress KV into a latent vector.
- Up-projection: Two linear layers to reconstruct keys and values from the latent vector.
- Query projection: Optionally compress queries into a latent space, though typically queries remain in full dimension for flexibility.
- RoPE: A custom implementation for positional embeddings, decoupled from KV compression.
- Attention computation: Standard scaled dot-product attention using the reconstructed keys and values.



In [50]:
# Define rotary Positional Embedding
def rotary_embedding(x, max_seq_len, dim):
    """Apply rotary positional embeddings (RoPE) along the last axis of ``x``.

    The last two axes of ``x`` are interpreted as (sequence, feature), so both
    [B, L, dim] and per-head [B, n_h, L, dim] inputs are supported. The feature
    pair (x[..., 2i], x[..., 2i + 1]) at position p is rotated by the angle
    p * 10000 ** (-2i / dim).

    Args:
        x: Array of shape (..., L, dim).
        max_seq_len: Size of the position table; must be >= L. Only the first
            L positions are used, so a value sized for the longest sequence can
            be reused for shorter inputs.
        dim: Rotated feature size; must be even and equal to ``x.shape[-1]``.

    Returns:
        Array of shape (..., L, dim). NOTE: the rotated pairs are stored in a
        half-split layout (all first components, then all second components)
        rather than re-interleaved. Because queries and keys receive the same
        permutation, their dot products are unaffected, but q and k must be
        rotated per head with the same ``dim``.

    Raises:
        ValueError: if ``dim`` is odd, does not match ``x.shape[-1]``, or
            ``max_seq_len`` is smaller than the sequence length of ``x``.
    """
    seq_len, feat_dim = x.shape[-2], x.shape[-1]
    if dim % 2 != 0:
        raise ValueError(f"RoPE dimension must be even, got dim={dim}")
    if feat_dim != dim:
        raise ValueError(f"x.shape[-1]={feat_dim} does not match dim={dim}")
    if max_seq_len < seq_len:
        raise ValueError(
            f"max_seq_len={max_seq_len} is smaller than the sequence length {seq_len}"
        )

    # Only the positions actually present in x are needed (a prefix of the table).
    positions = jnp.arange(seq_len, dtype=jnp.float32)
    freqs = 1.0 / (10000 ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    angles = positions[:, None] * freqs[None, :]  # (L, dim // 2)
    sin, cos = jnp.sin(angles), jnp.cos(angles)

    x1, x2 = x[..., ::2], x[..., 1::2]  # even / odd features, (..., L, dim // 2)
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

In [51]:
# def scaled_dot_product(q, k, v, mask=None):
#     d_k = q.shape[-1]
#     attn_logits = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
#     attn_logits = attn_logits / math.sqrt(d_k)
#     if mask is not None:
#         attn_logits = jnp.where(mask == 0, -9e15, attn_logits)
#     attention = nn.softmax(attn_logits, axis=-1)
#     values = jnp.matmul(attention, v)
#     return values, attention

In [52]:
# Scaled Dot-Product Attention
def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention used by the MLA modules below.

    Computes softmax(q k^T / sqrt(d_qk)) v with d_qk = q.shape[-1]; for the
    decoupled-RoPE module this is d_h + d_h^R.

    Args:
        q: queries, shape (..., L_q, d_qk), e.g. [B, n_h, L, d_h + d_h^R].
        k: keys, shape (..., L_k, d_qk).
        v: values, shape (..., L_k, d_v).
        mask: optional, broadcastable to (..., L_q, L_k). A boolean mask marks
            positions to keep (False -> excluded). Any other dtype is treated as
            an additive bias on the logits (e.g. 0 to keep, -1e9 to exclude).

    Returns:
        (values, attention) with shapes (..., L_q, d_v) and (..., L_q, L_k).
    """
    d_qk = q.shape[-1]
    # swapaxes (instead of transpose(0, 1, 3, 2)) works for any rank >= 2.
    scores = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / jnp.sqrt(d_qk)
    if mask is not None:
        mask = jnp.asarray(mask)
        if mask.dtype == jnp.bool_:
            # Adding a boolean mask would only shift kept logits by +1 and mask
            # nothing. Select instead, using the most negative finite value so
            # fully masked rows stay finite (no NaNs from softmax).
            scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
        else:
            scores = scores + mask
    attention = jax.nn.softmax(scores, axis=-1)
    values = jnp.matmul(attention, v)
    return values, attention


In [53]:
# Helper function to support different mask shapes.
# Output shape supports (B, number of heads, seq length, seq length)
# If 2D: broadcasted over batch size and number of heads
# If 3D: broadcasted over number of heads
# If 4D: leave as is
def expand_mask(mask):
    """Broadcast an attention mask to the 4-D layout (B, n_h, L, L).

    - 2-D (L, L): batch and head axes are prepended -> (1, 1, L, L).
    - 3-D (B, L, L): a head axis is inserted -> (B, 1, L, L).
    - 4-D or more: returned unchanged.

    Uses ``None`` indexing instead of ``Tensor.unsqueeze`` (a PyTorch-only
    method that JAX and NumPy arrays do not have).

    Raises:
        AssertionError: if ``mask`` has fewer than 2 dimensions.
    """
    assert mask.ndim >= 2, "Mask must be at least 2-dimensional with L x L"
    if mask.ndim == 3:
        mask = mask[:, None]  # (B, L, L) -> (B, 1, L, L): shared across heads
    while mask.ndim < 4:
        mask = mask[None]  # prepend broadcastable leading axes
    return mask

### Original MLA with decoupled RoPE

In [54]:
# Multi-head Latent Attention Module
class MultiHeadLatentAttention(nn.Module):
    """Multi-head Latent Attention (MLA) with decoupled rotary embeddings.

    Keys/values are reconstructed from a shared low-rank latent c^KV and
    queries from a separate latent c^Q. Each head's query/key is the
    concatenation of a content part (d_h = d_model // n_h) and a RoPE part
    (d_h^R). The RoPE key is a single projection of the input that all heads
    share.
    """
    d_model: int  # Output dimension (d_model)
    n_h: int  # Number of parallel heads (h)
    d_c : int # Latent dimension for compression (d_c) kv
    d_c_ : int # Latent dimension for compression (d_c') queries
    d_h_R: int  # Rotated dimension for RoPE (d_h^R); must be even

    def setup(self):
        if self.d_model % self.n_h != 0:
            raise ValueError(
                f"d_model={self.d_model} must be divisible by n_h={self.n_h}")
        if self.d_h_R % 2 != 0:
            raise ValueError(f"d_h_R={self.d_h_R} must be even for RoPE")
        self.d_h = self.d_model // self.n_h

        # Down-projection for KV
        self.dkv_proj = nn.Dense(
            features=self.d_c,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros
        )  # W^DKV

        # Up-projection for keys (content part)
        self.uk_proj = nn.Dense(
            features=self.d_model,   # d_h * n_h = d_model
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros
        )  # W^UK

        # Up-projection for values
        self.uv_proj = nn.Dense(
            features=self.d_model, # d_h * n_h = d_model
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros
        )  # W^UV

        # Down-projection for queries
        self.dq_proj = nn.Dense(
            features=self.d_c_,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros
        )  # W^DQ

        # Up-projection for queries (content part): c^Q (d_c') -> d_h * n_h
        self.uq_proj = nn.Dense(
            features=self.d_model,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros
        )  # W^UQ

        # Rotary projection for queries: one d_h^R slice per head (not shared)
        self.qr_proj = nn.Dense(
            features=self.d_h_R * self.n_h, # d_h^R * n_h
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros
        )  # W^QR

        # Rotary projection for keys: a single d_h^R key shared by all heads
        self.kr_proj = nn.Dense(
            features=self.d_h_R,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros
        )  # W^KR

        # Output projection
        self.o_proj = nn.Dense(
            features=self.d_model,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros
        )  # W^O

    def __call__(self, x, mask=None):
        """Apply MLA to hidden states ``x`` of shape [B, L, d_model].

        Args:
            x: input h_t, shape [B, L, d_model].
            mask: optional attention mask, expanded with ``expand_mask`` and
                forwarded to ``scaled_dot_product``.

        Returns:
            (o, attention): output [B, L, d_model] and attention weights
            [B, n_h, L, L].
        """
        B, L, _ = x.shape  # Batch, Sequence Length, Dimension

        def split_heads(t, head_dim):
            # [B, L, n_h * head_dim] -> [B, n_h, L, head_dim]
            return t.reshape(B, L, self.n_h, head_dim).transpose(0, 2, 1, 3)

        if mask is not None:
            mask = expand_mask(mask)

        # Keys / values from the shared KV latent
        c_kv = self.dkv_proj(x)                          # c_t^KV = W^DKV h_t      [B, L, d_c]
        k_c = split_heads(self.uk_proj(c_kv), self.d_h)  # k_t^C  = W^UK c_t^KV    [B, n_h, L, d_h]
        v = split_heads(self.uv_proj(c_kv), self.d_h)    # v_t^C  = W^UV c_t^KV    [B, n_h, L, d_h]

        # Queries from their own latent
        c_q = self.dq_proj(x)                            # c_t^Q = W^DQ h_t        [B, L, d_c']
        q_c = split_heads(self.uq_proj(c_q), self.d_h)   # q_t^C = W^UQ c_t^Q      [B, n_h, L, d_h]

        # RoPE part of the queries: split into heads FIRST, then rotate each
        # head's d_h^R features. Rotating the flat [B, L, n_h * d_h^R] tensor
        # would use frequencies for n_h * d_h^R (not the keys' d_h^R), and the
        # half-split output layout would scatter rotated pairs across heads,
        # destroying the relative-position property of RoPE.
        q_r = split_heads(self.qr_proj(c_q), self.d_h_R)  # W^QR c_t^Q [B, n_h, L, d_h^R]
        q_r = rotary_embedding(q_r, L, self.d_h_R)        # q_t^R = RoPE(W^QR c_t^Q)
        q = jnp.concatenate([q_c, q_r], axis=-1)          # [B, n_h, L, d_h + d_h^R]

        # RoPE part of the keys: one rotated key shared by every head
        k_r = rotary_embedding(self.kr_proj(x), L, self.d_h_R)  # k_t^R = RoPE(W^KR h_t) [B, L, d_h^R]
        k_r = jnp.broadcast_to(k_r[:, None], (B, self.n_h, L, self.d_h_R))
        k = jnp.concatenate([k_c, k_r], axis=-1)          # [B, n_h, L, d_h + d_h^R]

        # Attention (scaled by sqrt(d_h + d_h^R) inside scaled_dot_product)
        values, attention = scaled_dot_product(q, k, v, mask=mask)  # [B, n_h, L, d_h]

        # Merge heads and project output
        values = values.transpose(0, 2, 1, 3).reshape(B, L, self.d_model)  # [B, L, n_h * d_h]
        o = self.o_proj(values)  # [B, L, d_model]

        return o, attention


In [56]:
# Example usage
def main():
    """Smoke-test the decoupled-RoPE MLA module: init, one forward pass, shape checks."""
    # Hyperparameters
    dim = 512  # Output dimension (d_model)
    num_heads = 8  # n_h
    latent_dim = 128  # Latent dimension, used for both d_c (KV) and d_c' (queries)
    rotary_dim = 32  # Rotated dimension (d_h^R)
    seq_len = 64
    batch_size = 2

    # Independent keys for the input data and the parameter init: reusing one
    # PRNG key for both would make the "random" input correlated with the weights.
    data_rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
    model = MultiHeadLatentAttention(
        d_model=dim, n_h=num_heads, d_c=latent_dim, d_c_=latent_dim, d_h_R=rotary_dim
    )
    x = jax.random.normal(data_rng, (batch_size, seq_len, dim))
    params = model.init(init_rng, x)

    # Forward pass
    output, attention = model.apply(params, x)
    print("Output shape:", output.shape)
    print("Attention shape:", attention.shape)
    assert output.shape == (batch_size, seq_len, dim)
    assert attention.shape == (batch_size, num_heads, seq_len, seq_len)



In [57]:
%%timeit
# Times all of main(): model construction, parameter init, tracing and the two
# print calls, repeated once per timeit loop, so the printed lines repeat.
main()

Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 6

#### Simplified MLA without RoPE (KV compression only)

In [59]:
# # Scaled Dot-Product Attention
# def scaled_dot_product(q, k, v, mask=None):
#     denominator = q.shape[-1]  # d_h 
#     scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / jnp.sqrt(denominator)  # Normalize by sqrt(d_h + d_h^R) transpose[B, n_h, d_h + d_h_r, L]
#     if mask is not None:
#         scores = scores + mask
#     attention = nn.softmax(scores, axis=-1)
#     values = jnp.matmul(attention, v)
#     return values, attention

In [60]:
# Multi-head Latent Attention Module
class MultiHeadLatentAttention(nn.Module):
    """Multi-head Latent Attention without RoPE (KV compression only).

    Keys and values are both up-projected from one shared latent
    c^KV = W^DKV h_t of width d_c. Queries use a plain full-width projection
    of the input (no query compression).
    """
    d_model: int  # Model / output dimension (d_model)
    n_h: int  # Number of attention heads (n_h)
    d_c : int  # Width of the shared KV latent (d_c)

    def setup(self):
        self.d_h = self.d_model // self.n_h

        def dense(width):
            # Every projection uses Xavier-uniform weights and zero biases.
            return nn.Dense(
                features=width,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=nn.initializers.zeros
            )

        self.dkv_proj = dense(self.d_c)      # W^DKV: h_t -> c_t^KV
        self.uk_proj = dense(self.d_model)   # W^UK: c_t^KV -> keys (n_h * d_h)
        self.uv_proj = dense(self.d_model)   # W^UV: c_t^KV -> values (n_h * d_h)
        self.q_proj = dense(self.d_model)    # W^Q: h_t -> queries (n_h * d_h)
        self.o_proj = dense(self.d_model)    # W^O: concatenated heads -> output

    def __call__(self, x, mask=None):
        """Map x [B, L, d_model] to (output [B, L, d_model], attention [B, n_h, L, L])."""
        batch, seq, _ = x.shape

        def to_heads(t):
            # [B, L, n_h * d_h] -> [B, n_h, L, d_h]
            return t.reshape(batch, seq, self.n_h, self.d_h).transpose(0, 2, 1, 3)

        full_mask = None if mask is None else expand_mask(mask)

        latent = self.dkv_proj(x)               # c_t^KV           [B, L, d_c]
        keys = to_heads(self.uk_proj(latent))   # W^UK c_t^KV      [B, n_h, L, d_h]
        vals = to_heads(self.uv_proj(latent))   # W^UV c_t^KV      [B, n_h, L, d_h]
        queries = to_heads(self.q_proj(x))      # W^Q h_t          [B, n_h, L, d_h]

        heads_out, attention = scaled_dot_product(queries, keys, vals, mask=full_mask)
        merged = heads_out.transpose(0, 2, 1, 3).reshape(
            batch, seq, self.n_h * self.d_model // self.n_h
        )  # [B, L, n_h * d_h]
        return self.o_proj(merged), attention

In [61]:
# Example usage
def main():
    """Smoke-test the RoPE-free MLA module: init, one forward pass, shape checks."""
    # Hyperparameters
    dim = 512  # Output dimension (d_model)
    num_heads = 8  # n_h
    latent_dim = 128  # Latent dimension (d_c)
    seq_len = 64
    batch_size = 2

    # Independent keys for the input data and the parameter init: reusing one
    # PRNG key for both would make the "random" input correlated with the weights.
    data_rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
    model = MultiHeadLatentAttention(d_model=dim, n_h=num_heads, d_c=latent_dim)
    x = jax.random.normal(data_rng, (batch_size, seq_len, dim))
    params = model.init(init_rng, x)

    # Forward pass
    output, attention = model.apply(params, x)
    print("Output shape:", output.shape)
    print("Attention shape:", attention.shape)
    assert output.shape == (batch_size, seq_len, dim)
    assert attention.shape == (batch_size, num_heads, seq_len, seq_len)



In [62]:
%%timeit
# Times all of main(): model construction, parameter init, tracing and the two
# print calls, repeated once per timeit loop, so the printed lines repeat.
main()

Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 64)
Output shape: (2, 64, 512)
Attention shape: (2, 8, 64, 6