In [1]:
import os
os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_VISIBLE_CHIPS"] = "0"
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import jax
import jax.numpy as jnp
import jax.nn as jnn
import jax.lax as lax
from jax import jit
from functools import partial
from flax import struct
from typing import Optional
from dataclasses import dataclass
import math

# Try to import torch_xla for TPU support
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    TORCH_XLA_AVAILABLE = True
except ImportError:
    TORCH_XLA_AVAILABLE = False

# Device selection: TPU > CUDA > CPU
if TORCH_XLA_AVAILABLE:
    device = torch_xla.device()  # TPU device
    print(f"Using PyTorch XLA device: {device}")
elif torch.cuda.is_available():
    device = "cuda"
    print("Using CUDA device")
else:
    device = "cpu"
    print("Using CPU device")

jax_dtype = jnp.bfloat16
torch_dtype = torch.bfloat16



Using PyTorch XLA device: xla:0


In [2]:
# JAX KVCache implementation
@struct.dataclass
class KVCache:
    k: jax.Array
    v: jax.Array
    positions: jax.Array  # [bsz] - tracks current filled position for each sequence

    @classmethod
    def new(
        cls,
        n_layers: int,
        bsz: int,
        max_seqlen: int,
        kv_heads: int,
        head_dim: int,
        dtype=jnp.bfloat16,
    ) -> "KVCache":
        return cls(
            k=jnp.zeros((n_layers, bsz, max_seqlen, kv_heads, head_dim), dtype=dtype),
            v=jnp.zeros((n_layers, bsz, max_seqlen, kv_heads, head_dim), dtype=dtype),
            positions=jnp.zeros(bsz, dtype=jnp.int32),
        )

    def update(self, xk: jax.Array, xv: jax.Array, layer_idx: int):
        """Updates the Key and Value cache for all sequences, each at their own position."""
        bsz, seqlen, n_kv_heads, head_dim = xk.shape

        # Ensure xk/xv have the same dtype as cache
        xk = xk.astype(self.k.dtype)
        xv = xv.astype(self.v.dtype)

        # Start with current cache
        new_k = self.k
        new_v = self.v

        # Update each sequence at its own position
        for i in range(bsz):
            start_pos = self.positions[i]
            xk_i = xk[i]
            xv_i = xv[i]
            xk_update = xk_i[None, None, :, :, :]
            xv_update = xv_i[None, None, :, :, :]

            new_k = jax.lax.dynamic_update_slice(
                new_k,
                xk_update,
                (layer_idx, i, start_pos, 0, 0)
            )
            new_v = jax.lax.dynamic_update_slice(
                new_v,
                xv_update,
                (layer_idx, i, start_pos, 0, 0)
            )

        new_positions = self.positions + seqlen
        return KVCache(k=new_k, v=new_v, positions=new_positions)

    def get_layer(self, layer_idx: int):
        """Retrieves K/V for a specific layer."""
        keys = self.k[layer_idx]
        values = self.v[layer_idx]
        return keys, values


In [3]:
# JAX RoPE and helper functions
@partial(
    jit,
    static_argnames=[
        "scale_factor",
        "low_freq_factor",
        "high_freq_factor",
        "old_context_len",
    ],
    donate_argnums=[0],
)
def apply_scaling_jax(
    freqs: jax.Array,
    scale_factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    old_context_len: float = 8192.0,
) -> jax.Array:
    """Apply RoPE scaling to frequencies based on Llama 3 implementation."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * jnp.pi / freqs

    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    freqs_for_mid_range = (1 - smooth) * freqs / scale_factor + smooth * freqs

    new_freqs = jnp.where(
        wavelen > low_freq_wavelen,
        freqs / scale_factor,
        jnp.where(wavelen < high_freq_wavelen, freqs, freqs_for_mid_range),
    )
    return new_freqs


@partial(jit, static_argnames=["head_dim", "end", "use_scaled", "dtype"])
def precompute_freqs_cis_jax(
    head_dim: int,
    end: int,
    theta: float = 500000.0,
    use_scaled: bool = False,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Precompute the rotational frequency embeddings."""
    freqs = 1.0 / (
        theta
        ** (jnp.arange(0, head_dim, 2)[: (head_dim // 2)].astype(dtype) / head_dim)
    )
    if use_scaled:
        freqs = apply_scaling_jax(freqs)
    t = jnp.arange(end, dtype=dtype)
    freqs = jnp.outer(t, freqs)

    freqs_cos = jnp.cos(freqs)
    freqs_sin = jnp.sin(freqs)

    freqs_cis = jnp.stack([freqs_cos, freqs_sin], axis=-1)
    return freqs_cis


@partial(jit, donate_argnums=[0])
def apply_rotary_emb_batch(x: jax.Array, freqs_cis: jax.Array) -> jax.Array:
    """Apply Rotary Positional Embeddings (RoPE) to a tensor with per-batch-item frequencies."""
    x_shaped = x.reshape(*x.shape[:-1], -1, 2)
    x_r, x_i = x_shaped[..., 0], x_shaped[..., 1]

    freqs_cis = freqs_cis[:, :, None, :, :]
    freqs_cos, freqs_sin = freqs_cis[..., 0], freqs_cis[..., 1]

    x_out_r = x_r * freqs_cos - x_i * freqs_sin
    x_out_i = x_r * freqs_sin + x_i * freqs_cos

    x_out = jnp.stack([x_out_r, x_out_i], axis=-1).reshape(x.shape)
    return x_out


@partial(jit, static_argnames=["n_rep"])
def repeat_kv(x: jax.Array, n_rep: int) -> jax.Array:
    """Repeat Key/Value heads for Grouped Query Attention."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return jnp.broadcast_to(
        x[:, :, :, None, :], (bs, slen, n_kv_heads, n_rep, head_dim)
    ).reshape(bs, slen, n_kv_heads * n_rep, head_dim)


In [4]:
# JAX Attention implementation
@struct.dataclass
class AttentionParams:
    wq: jax.Array
    wk: jax.Array
    wv: jax.Array
    wo: jax.Array


@partial(jit, static_argnames=["layer_idx"], donate_argnums=[0, 1, 3])
def grouped_query_attention(
    x: jax.Array,
    freqs_cis: jax.Array,
    params: AttentionParams,
    kv_cache: KVCache,
    layer_idx: int,
    seq_lengths: jax.Array,
) -> tuple[jax.Array, KVCache]:
    """Compute Grouped Query Attention with variable-length sequences."""
    bsz, seqlen, dim = x.shape

    # Get per-sequence cache positions
    start_positions = kv_cache.positions

    # Project inputs to queries, keys, values
    xq = jnp.einsum("bsd,dhc->bshc", x, params.wq)
    xk = jnp.einsum("bsd,dkc->bskc", x, params.wk)
    xv = jnp.einsum("bsd,dvc->bsvc", x, params.wv)

    # Apply RoPE at absolute positions
    position_offsets = jnp.arange(seqlen)[None, :]
    absolute_positions = start_positions[:, None] + position_offsets

    # Get frequencies for these absolute positions
    batch_freqs_cis = freqs_cis[absolute_positions]

    # Apply rotary embeddings
    xq = apply_rotary_emb_batch(xq, batch_freqs_cis)
    xk = apply_rotary_emb_batch(xk, batch_freqs_cis)

    # Update cache
    updated_cache = kv_cache.update(xk, xv, layer_idx)
    keys, values = updated_cache.get_layer(layer_idx)

    max_seqlen = keys.shape[1]

    # Build per-sequence attention masks
    def build_mask_for_sequence(true_len, cache_pos):
        query_offsets = jnp.arange(seqlen)
        key_positions = jnp.arange(max_seqlen)
        query_positions = (cache_pos - seqlen) + query_offsets

        causal_mask = query_positions[:, None] >= key_positions[None, :]
        valid_query_mask = query_offsets < true_len

        mask = causal_mask & valid_query_mask[:, None]
        return mask

    # Apply to all sequences
    mask = jax.vmap(build_mask_for_sequence)(seq_lengths, updated_cache.positions)
    mask = mask[:, None, :, :]

    # Perform attention
    attn_output = jnn.dot_product_attention(
        query=xq,
        key=keys,
        value=values,
        mask=mask,
    )

    # Output projection
    attn_output = attn_output.reshape(bsz, seqlen, -1)
    output = jnp.einsum("bsd,do->bso", attn_output, params.wo)

    return output, updated_cache


In [5]:
# PyTorch ModelArgs and helper functions
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    flash: bool = False

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if hasattr(self, k):
                setattr(self, k, v)
        if self.n_kv_heads is None:
            self.n_kv_heads = self.n_heads


def apply_scaling_torch(freqs: torch.Tensor):
    """Apply RoPE scaling based on Llama 3 implementation."""
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis_torch(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    if use_scaled:
        
        freqs = apply_scaling_torch(freqs)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    freqs_cis_real = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return freqs_cis_real


def apply_rotary_emb_torch(x, freqs_cis):
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)


def repeat_kv_torch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


In [6]:
# PyTorch KVCache and Attention implementation
class KVCache_torch(nn.Module):
    def __init__(self, batch_size, seq_length, n_kv_heads, head_dim, dtype, device):
        super().__init__()
        cache_shape = (batch_size, seq_length, n_kv_heads, head_dim)
        self.register_buffer(
            "cache_k", torch.zeros(cache_shape, dtype=dtype, device=device)
        )
        self.register_buffer(
            "cache_v", torch.zeros(cache_shape, dtype=dtype, device=device)
        )

    def update(self, start_pos, xk, xv):
        seqlen = xk.size(1)
        self.cache_k[:, start_pos : start_pos + seqlen] = xk
        self.cache_v[:, start_pos : start_pos + seqlen] = xv
        xk = self.cache_k[:, : start_pos + seqlen]
        xv = self.cache_v[:, : start_pos + seqlen]
        return xk, xv


class Attention_torch(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.flash = args.flash
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        self.cache = None

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        
        xq = apply_rotary_emb_torch(xq, freqs_cis)
        xk = apply_rotary_emb_torch(xk, freqs_cis)
        
        if self.cache is not None:
            xk, xv = self.cache.update(start_pos, xk, xv)
        
        xk = repeat_kv_torch(xk, self.n_rep)
        xv = repeat_kv_torch(xv, self.n_rep)
        
        xq, xk, xv = (x.transpose(1, 2) for x in (xq, xk, xv))
        
        if self.flash:
            output = F.scaled_dot_product_attention(xq, xk, xv, mask)
        else:
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            if mask is not None:
                scores = scores + mask
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, xv)
        
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        proj = self.wo(output)
        return proj


In [7]:
# 1. Setup Parameters
bsz = 4
seqlen = 128
dim = 2048
n_heads = 16
n_kv_heads = 4
head_dim = dim // n_heads
cache_seq_len = 64
start_pos = cache_seq_len
dtype = np.float32
max_seq_len = 1024
n_layers = 1

print(f"Batch size: {bsz}")
print(f"Sequence length: {seqlen}")
print(f"Model dimension: {dim}")
print(f"Number of heads: {n_heads}")
print(f"Number of KV heads: {n_kv_heads}")
print(f"Head dimension: {head_dim}")
print(f"Cache sequence length: {cache_seq_len}")
print(f"Start position: {start_pos}")


Batch size: 4
Sequence length: 128
Model dimension: 2048
Number of heads: 16
Number of KV heads: 4
Head dimension: 128
Cache sequence length: 64
Start position: 64


In [8]:
# 2. Create inputs and parameters
use_scaled_rope = True  # Enable scaled RoPE

np.random.seed(42)
x_np = np.random.randn(bsz, seqlen, dim).astype(dtype)

# Create weight matrices
wq_np = np.random.normal(0, 0.02, (dim, n_heads * head_dim)).astype(dtype)
wk_np = np.random.normal(0, 0.02, (dim, n_kv_heads * head_dim)).astype(dtype)
wv_np = np.random.normal(0, 0.02, (dim, n_kv_heads * head_dim)).astype(dtype)
wo_np = np.random.normal(0, 0.02, (n_heads * head_dim, dim)).astype(dtype)

# Precompute frequencies with scaled RoPE
print(f"Using scaled RoPE: {use_scaled_rope}")
freqs_cis_torch = precompute_freqs_cis_torch(head_dim, max_seq_len, use_scaled=use_scaled_rope).to(dtype=torch_dtype)
freqs_cis_jax = jnp.array(freqs_cis_torch.float().detach().cpu().numpy(), dtype=jax_dtype)

print(f"Input shape: {x_np.shape}")
print(f"Weight shapes: wq={wq_np.shape}, wk={wk_np.shape}, wv={wv_np.shape}, wo={wo_np.shape}")
print(f"Frequencies shape: {freqs_cis_jax.shape}")


Using scaled RoPE: True
Input shape: (4, 128, 2048)
Weight shapes: wq=(2048, 2048), wk=(2048, 512), wv=(2048, 512), wo=(2048, 2048)
Frequencies shape: (1024, 64, 2)


In [9]:
# 3. JAX setup
x_jax = jnp.array(x_np, dtype=jax_dtype)
jax_params = AttentionParams(
    wq=jnp.array(wq_np, dtype=jax_dtype).reshape(dim, n_heads, head_dim),
    wk=jnp.array(wk_np, dtype=jax_dtype).reshape(dim, n_kv_heads, head_dim),
    wv=jnp.array(wv_np, dtype=jax_dtype).reshape(dim, n_kv_heads, head_dim),
    wo=jnp.array(wo_np, dtype=jax_dtype),
)

# Initialize KV cache with randomly initialized values up to cache_seq_len
kv_cache_jax = KVCache.new(
    n_layers, bsz, max_seq_len, n_kv_heads, head_dim, dtype=jax_dtype
)

# Pre-fill cache with random values
prefill_k = np.random.randn(n_layers, bsz, cache_seq_len, n_kv_heads, head_dim).astype(dtype)
prefill_v = np.random.randn(n_layers, bsz, cache_seq_len, n_kv_heads, head_dim).astype(dtype)
k_init_jax = jnp.zeros((n_layers, bsz, max_seq_len, n_kv_heads, head_dim), dtype=jax_dtype)
v_init_jax = jnp.zeros((n_layers, bsz, max_seq_len, n_kv_heads, head_dim), dtype=jax_dtype)
k_updated_jax = k_init_jax.at[:, :, :cache_seq_len, :, :].set(jnp.array(prefill_k, dtype=jax_dtype))
v_updated_jax = v_init_jax.at[:, :, :cache_seq_len, :, :].set(jnp.array(prefill_v, dtype=jax_dtype))
kv_cache_jax = KVCache(
    k=k_updated_jax,
    v=v_updated_jax,
    positions=jnp.full((bsz,), cache_seq_len, dtype=jnp.int32)
)

print(f"JAX KV cache initialized with shape k={kv_cache_jax.k.shape}, v={kv_cache_jax.v.shape}")


JAX KV cache initialized with shape k=(1, 4, 1024, 4, 128), v=(1, 4, 1024, 4, 128)


In [10]:
# 4. PyTorch setup
model_args = ModelArgs(dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, flash=False)
x_torch = torch.tensor(x_np, device=device, dtype=torch_dtype)
torch_attention = Attention_torch(model_args)
torch_attention.wq.weight = nn.Parameter(torch.tensor(wq_np.T, device=device, dtype=torch_dtype))
torch_attention.wk.weight = nn.Parameter(torch.tensor(wk_np.T, device=device, dtype=torch_dtype))
torch_attention.wv.weight = nn.Parameter(torch.tensor(wv_np.T, device=device, dtype=torch_dtype))
torch_attention.wo.weight = nn.Parameter(torch.tensor(wo_np.T, device=device, dtype=torch_dtype))

kv_cache_torch = KVCache_torch(
    bsz, max_seq_len, n_kv_heads, head_dim, dtype=torch_dtype, device=device
)

# Pre-fill PyTorch cache with same random values
kv_cache_torch.cache_k[:, :cache_seq_len, :, :] = torch.tensor(prefill_k[0], device=device, dtype=torch_dtype)
kv_cache_torch.cache_v[:, :cache_seq_len, :, :] = torch.tensor(prefill_v[0], device=device, dtype=torch_dtype)
torch_attention.cache = kv_cache_torch

print(f"PyTorch KV cache initialized")


PyTorch KV cache initialized


In [11]:
# 5. Execute JAX attention
seq_lengths = jnp.full((bsz,), seqlen, dtype=jnp.int32)
output_jax, updated_kv_cache_jax = grouped_query_attention(
    x_jax, freqs_cis_jax, jax_params, kv_cache_jax, 0, seq_lengths
)

print(f"JAX output shape: {output_jax.shape}")
print(f"JAX output dtype: {output_jax.dtype}")


See an explanation at https://docs.jax.dev/en/latest/faq.html#buffer-donation.


JAX output shape: (4, 128, 2048)
JAX output dtype: bfloat16


In [12]:
# 6. Execute PyTorch attention
freqs_cis_torch_sliced = freqs_cis_torch[start_pos : start_pos + seqlen].to(device=device)

# Create causal mask for PyTorch
mask = None
if seqlen > 1:
    mask = torch.full((seqlen, seqlen), float("-inf"), device=device)
    mask = torch.triu(mask, diagonal=1)
    mask = torch.hstack(
        [torch.zeros((seqlen, start_pos), device=device), mask]
    ).type_as(x_torch)

output_torch = torch_attention.forward(
    x_torch, start_pos, freqs_cis_torch_sliced, mask
)

print(f"PyTorch output shape: {output_torch.shape}")
print(f"PyTorch output dtype: {output_torch.dtype}")


PyTorch output shape: torch.Size([4, 128, 2048])
PyTorch output dtype: torch.bfloat16


In [20]:
# 7. Compare outputs
output_jax_np = np.array(output_jax)
output_torch_np = output_torch.float().detach().cpu().numpy()

# Check shapes match
assert output_jax_np.shape == output_torch_np.shape, f"Shape mismatch: JAX {output_jax_np.shape} vs PyTorch {output_torch_np.shape}"

# Check that output is not all zeros or NaN
assert not np.all(output_jax_np == 0), "JAX output should not be all zeros"
assert not np.any(np.isnan(output_jax_np)), "JAX output should not contain NaN values"
assert not np.all(output_torch_np == 0), "PyTorch output should not be all zeros"
assert not np.any(np.isnan(output_torch_np)), "PyTorch output should not contain NaN values"

# Compare with same tolerances as test_ops.py
np.testing.assert_allclose(
    output_jax_np, output_torch_np, rtol=5e-3, atol=5e-3
)

print("✓ Attention output test passed!")
print(f"Max absolute difference: {np.max(np.abs(output_jax_np - output_torch_np))}")
print(f"Mean absolute difference: {np.mean(np.abs(output_jax_np - output_torch_np))}")


✓ Attention output test passed!
Max absolute difference: 0.0039016902446746826
Mean absolute difference: 0.0004866586241405457


In [18]:
# 8. Compare KV caches
updated_k_jax = updated_kv_cache_jax.k[0]
updated_v_jax = updated_kv_cache_jax.v[0]
updated_k_torch = torch_attention.cache.cache_k
updated_v_torch = torch_attention.cache.cache_v

# Check KV cache was updated correctly
assert not np.all(np.array(updated_k_jax[:, :cache_seq_len, :, :]) == 0), "JAX pre-filled keys should remain in cache"
assert not np.all(np.array(updated_v_jax[:, :cache_seq_len, :, :]) == 0), "JAX pre-filled values should remain in cache"
assert not np.all(np.array(updated_k_jax[:, cache_seq_len:cache_seq_len+seqlen, :, :]) == 0), "JAX new keys should be cached"
assert not np.all(np.array(updated_v_jax[:, cache_seq_len:cache_seq_len+seqlen, :, :]) == 0), "JAX new values should be cached"

# Compare cache values
np.testing.assert_allclose(
    np.array(updated_k_jax),
    updated_k_torch.float().detach().cpu().numpy(),
    rtol=1e-5,
    atol=1e-5,
)
np.testing.assert_allclose(
    np.array(updated_v_jax),
    updated_v_torch.float().detach().cpu().numpy(),
    rtol=1e-5,
    atol=1e-4,
)

print("✓ KV cache test passed!")
print(f"Max key difference: {np.max(np.abs(np.array(updated_k_jax) - updated_k_torch.float().detach().cpu().numpy()))}")
print(f"Max value difference: {np.max(np.abs(np.array(updated_v_jax) - updated_v_torch.float().detach().cpu().numpy()))}")


AssertionError: 
Not equal to tolerance rtol=1e-05, atol=1e-05

Mismatched elements: 80122 / 2097152 (3.82%)
Max absolute difference among violations: 0.03125
Max relative difference among violations: 227.83545
 ACTUAL: array([[[[-1.09375, -0.796875, 0.746094, ..., -0.275391, -1.01562,
          -0.9375],
         [-1.01562, 0.292969, 1.08594, ..., 0.388672, 0.0966797,...
 DESIRED: array([[[[-1.09375 , -0.796875,  0.746094, ..., -0.275391, -1.015625,
          -0.9375  ],
         [-1.015625,  0.292969,  1.085938, ...,  0.388672,  0.09668 ,...