1. Architecture

In [6]:
import jax 
import flax.linen as nn
import jax.numpy as jnp

In [7]:
class FeedForward(nn.Module):
    """Gated feed-forward block built from three bias-free dense layers.

    Two parallel projections map the input to ``hidden_dim``. The first is
    passed through the tanh-approximated GELU and multiplied elementwise with
    the second; the product is then projected to ``emb_dim``.
    """

    # Output feature size of the final projection.
    emb_dim: int
    # Feature size of the two parallel (gate / up) projections.
    hidden_dim: int
    # Computation dtype forwarded to every ``nn.Dense`` layer.
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Apply the gated feed-forward transform over the last axis of ``x``.

        Args:
            x: Input array; the dense layers act on its last axis.

        Returns:
            Array with the same leading axes as ``x`` and ``emb_dim`` features
            on the last axis.
        """
        # Creation order is kept (gate, up, down) so the auto-generated
        # submodule names Dense_0 / Dense_1 / Dense_2 stay stable.
        gate = nn.Dense(self.hidden_dim, use_bias=False, dtype=self.dtype)(x)
        up = nn.Dense(self.hidden_dim, use_bias=False, dtype=self.dtype)(x)
        gated = nn.gelu(gate, approximate=True) * up
        return nn.Dense(self.emb_dim, use_bias=False, dtype=self.dtype)(gated)


In [8]:
class RMNSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale.

    The learned ``scale`` parameter is zero-initialised and applied as
    ``(1 + scale)``; an optional zero-initialised ``shift`` is added when
    ``bias`` is true.
    """

    # Length of the learned ``scale`` / ``shift`` vectors.
    emb_dim: int
    # Added to the mean square before the square root.
    eps: float = 1e-6
    # Whether to create and add the ``shift`` parameter.
    bias: bool = False

    @nn.compact
    def __call__(self, x):
        """Normalise ``x`` along its last axis and apply the learned affine terms."""
        mean_square = jnp.mean(x**2, axis=-1, keepdims=True)
        normalized = x / jnp.sqrt(mean_square + self.eps)

        # Match Gemma3's (1 + weight) scaling.
        weight = self.param('scale', nn.initializers.zeros, (self.emb_dim,))
        out = normalized * (1.0 + weight)

        if not self.bias:
            return out

        offset = self.param('shift', nn.initializers.zeros, (self.emb_dim,))
        return out + offset


In [9]:
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=jnp.float32):
    """Precompute the cosine and sine tables used by rotary position embeddings.

    Args:
        head_dim: Per-head feature size; must be even.
        theta_base: Base of the geometric frequency progression.
        context_length: Number of positions (rows) to tabulate.
        dtype: dtype used for the frequency exponents.

    Returns:
        Tuple ``(cos, sin)``, each of shape ``(context_length, head_dim)``.
        The second half of the last axis repeats the first half.
    """
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # One inverse frequency per pair of features.
    exponents = jnp.arange(0, head_dim, 2, dtype=dtype) / head_dim
    inv_freq = 1.0 / (theta_base ** exponents)

    # Outer product position x frequency -> (context_length, head_dim // 2).
    token_idx = jnp.arange(context_length)
    half_angles = token_idx[:, None] * inv_freq[None, :]

    # Duplicate so the table spans the full head dimension.
    full_angles = jnp.concatenate([half_angles, half_angles], axis=-1)
    return jnp.cos(full_angles), jnp.sin(full_angles)

def apply_rope(x, cos, sin):
    """Rotate ``x`` with rotary position embeddings.

    Uses the rotate-half formulation that pairs feature ``i`` with feature
    ``i + head_dim // 2``::

        out = x * cos + concat(-x2, x1) * sin

    where ``x1`` / ``x2`` are the first / second halves of the last axis.

    Args:
        x: Array of shape ``(..., seq_len, head_dim)``, typically
            ``(batch, num_heads, seq_len, head_dim)``.
        cos: Cosine table of shape ``(context_length, head_dim)``.
        sin: Sine table of shape ``(context_length, head_dim)``.

    Returns:
        Array with the same shape and dtype as ``x``.

    Raises:
        ValueError: If the tables hold fewer than ``seq_len`` positions.
    """
    *_, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"
    if cos.shape[0] < seq_len or sin.shape[0] < seq_len:
        raise ValueError(
            f"RoPE tables cover {min(cos.shape[0], sin.shape[0])} positions "
            f"but the input has sequence length {seq_len}"
        )

    # (seq_len, head_dim) broadcasts against every leading axis of x, so no
    # explicit expansion of the tables is required.
    cos = cos[:seq_len, :]
    sin = sin[:seq_len, :]

    # Split on the LAST axis; both halves are read from the untouched input so
    # neither result depends on an already-rotated value.
    x1, x2 = jnp.split(x, 2, axis=-1)
    rotated = jnp.concatenate([-x2, x1], axis=-1)

    out = x * cos + rotated * sin
    # The tables may be wider than x (e.g. float32 vs bfloat16); keep x's dtype.
    return out.astype(x.dtype)


In [None]:
class GroupedQueryAttention(nn.Module):
    """Grouped-query self-attention with rotary position embeddings.

    ``num_heads`` query heads share ``num_kv_groups`` key/value heads; each
    key/value head serves ``num_heads // num_kv_groups`` consecutive query
    heads. All projections are bias-free.
    """

    # Feature size of the input and of the returned output.
    d_in: int
    # Number of query heads.
    num_heads: int
    # Number of key/value heads; must divide ``num_heads``.
    num_kv_groups: int
    # Per-head feature size; falls back to ``d_in // num_heads`` when unset.
    head_dim: int = None
    # Whether to normalise queries and keys per head before the rotation.
    qk_norm: bool = False
    # When set, scores are scaled by ``query_pre_attn_scalar ** -0.5``
    # instead of ``head_dim ** -0.5``.
    query_pre_attn_scalar: float = None
    # Computation dtype forwarded to every ``nn.Dense`` layer.
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x, cos, sin, mask=None):
        """Run grouped-query attention over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, d_in)``.
            cos: RoPE cosine table, ``(context_length, head_dim)``.
            sin: RoPE sine table, ``(context_length, head_dim)``.
            mask: Optional boolean array broadcastable to
                ``(batch, num_heads, seq_len, seq_len)``; ``True`` marks
                query/key pairs that may attend. When ``None`` no masking is
                applied, so callers needing causal attention must pass a mask.

        Returns:
            Array of shape ``(batch, seq_len, d_in)``.
        """
        assert self.num_heads % self.num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
        head_dim = self.head_dim or (self.d_in // self.num_heads)
        d_out = self.num_heads * head_dim
        group_size = self.num_heads // self.num_kv_groups

        # Creation order is kept (query, key, value, output) so the
        # auto-generated submodule names Dense_0 .. Dense_3 stay stable.
        W_query = nn.Dense(d_out, use_bias=False, dtype=self.dtype)
        W_key = nn.Dense(self.num_kv_groups * head_dim, use_bias=False, dtype=self.dtype)
        W_value = nn.Dense(self.num_kv_groups * head_dim, use_bias=False, dtype=self.dtype)
        out_proj = nn.Dense(self.d_in, use_bias=False, dtype=self.dtype)

        batch_size, seq_len, _ = x.shape

        def split_heads(t, n_heads):
            # (batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)
            return t.reshape(batch_size, seq_len, n_heads, head_dim).transpose(0, 2, 1, 3)

        queries = split_heads(W_query(x), self.num_heads)
        keys = split_heads(W_key(x), self.num_kv_groups)
        values = split_heads(W_value(x), self.num_kv_groups)

        if self.qk_norm:
            # NOTE(review): the previous code called `nn.normalize`, which
            # flax.linen does not provide. L2 normalisation over head_dim is
            # used here; swap in a learned RMS norm if that was the intent.
            norm_eps = 1e-6
            queries = queries / jnp.maximum(
                jnp.linalg.norm(queries, axis=-1, keepdims=True), norm_eps
            )
            keys = keys / jnp.maximum(
                jnp.linalg.norm(keys, axis=-1, keepdims=True), norm_eps
            )

        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)

        # Expand key/value heads so each query head has a matching partner;
        # jnp.repeat keeps the copies of one group adjacent.
        keys = jnp.repeat(keys, group_size, axis=1)
        values = jnp.repeat(values, group_size, axis=1)

        if self.query_pre_attn_scalar is not None:
            scaling = self.query_pre_attn_scalar ** -0.5
        else:
            scaling = head_dim ** -0.5

        scores = jnp.einsum('bhqd,bhkd->bhqk', queries * scaling, keys)
        if mask is not None:
            # A large finite negative avoids NaNs for fully masked rows.
            scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
        weights = jax.nn.softmax(scores, axis=-1)

        context = jnp.einsum('bhqk,bhkd->bhqd', weights, values)
        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        context = context.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, d_out)
        return out_proj(context)