1. Architecture

In [2]:
import jax 
import flax.linen as nn
import jax.numpy as jnp

In [3]:
class FeedForward(nn.Module):
    """Gated feed-forward block: ``down(gelu(gate(x)) * up(x))``.

    Two parallel bias-free projections expand the input to ``hidden_dim``;
    the first is passed through the tanh-approximated GELU and multiplies the
    second element-wise, then a third bias-free projection maps the product
    back to ``emb_dim``.
    """

    # Output width of the final projection.
    emb_dim: int
    # Width of the two parallel expansion projections.
    hidden_dim: int
    # Computation dtype forwarded to every Dense layer.
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Apply the gated feed-forward transform to ``x`` along its last axis."""
        # Construction order is kept (gate, up, down) so the auto-generated
        # parameter names Dense_0 / Dense_1 / Dense_2 stay the same.
        gate = nn.Dense(self.hidden_dim, use_bias=False, dtype=self.dtype)(x)
        up = nn.Dense(self.hidden_dim, use_bias=False, dtype=self.dtype)(x)
        hidden = nn.gelu(gate, approximate=True) * up
        return nn.Dense(self.emb_dim, use_bias=False, dtype=self.dtype)(hidden)


In [4]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with ``(1 + scale)`` gain.

    The learnable ``scale`` is zero-initialised and applied as ``1 + scale``,
    so a freshly initialised layer is a pure normalisation. An optional
    zero-initialised ``shift`` is added when ``bias`` is true.
    """

    # Size of the last axis of the input (length of ``scale`` / ``shift``).
    emb_dim: int
    # Added to the mean of squares before the square root.
    eps: float = 1e-6
    # Whether to add a learnable shift after scaling.
    bias: bool = False

    @nn.compact
    def __call__(self, x):
        """Return ``x`` normalised by its RMS along the last axis, then scaled."""
        x = jnp.asarray(x)

        # Accumulate the statistics in at least float32: a mean of squares
        # taken directly in bfloat16/float16 loses most of its precision.
        # float32 (and wider) inputs are left untouched by this cast.
        compute_dtype = jnp.promote_types(x.dtype, jnp.float32)
        x_wide = x.astype(compute_dtype)

        rms = jnp.sqrt(jnp.mean(x_wide**2, axis=-1, keepdims=True) + self.eps)
        x_norm = x_wide / rms

        scale = self.param('scale', nn.initializers.zeros, (self.emb_dim,))
        scale = 1.0 + scale  # Match Gemma3's (1 + weight) scaling
        x_norm = x_norm * scale

        if self.bias:
            shift = self.param('shift', nn.initializers.zeros, (self.emb_dim,))
            x_norm = x_norm + shift

        return x_norm


In [5]:
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=jnp.float32):
    """Build the cosine and sine lookup tables used for rotary embeddings.

    Args:
        head_dim: Size of one attention head; must be even.
        theta_base: Base of the geometric frequency progression.
        context_length: Number of positions (rows) to tabulate.
        dtype: dtype of the frequency exponents.

    Returns:
        ``(cos, sin)``, each of shape ``(context_length, head_dim)``. The
        second half of the last axis repeats the first half.

    Raises:
        AssertionError: if ``head_dim`` is odd.
    """
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # One inverse frequency per pair of channels.
    exponents = jnp.arange(0, head_dim, 2, dtype=dtype) / head_dim
    inv_freq = 1.0 / ((theta_base) ** exponents)

    # Angle for every (position, frequency) combination.
    token_positions = jnp.arange(context_length)
    half_angles = jnp.outer(token_positions, inv_freq)

    # Duplicate along the last axis so the table spans the full head_dim.
    full_angles = jnp.tile(half_angles, (1, 2))
    return jnp.cos(full_angles), jnp.sin(full_angles)

def apply_rope(x, cos, sin):
    """Rotate ``x`` by the position-dependent angles encoded in ``cos``/``sin``.

    Channel ``i`` of the first half of the last axis is paired with channel
    ``i`` of the second half, and each pair is rotated as a 2-D vector:
    ``(a, b) -> (a*cos - b*sin, b*cos + a*sin)``.

    Args:
        x: Array of shape ``(batch, num_heads, seq_len, head_dim)``.
        cos: Table of shape ``(>= seq_len, head_dim)`` whose two halves along
            the last axis are identical.
        sin: Table with the same layout as ``cos``.

    Returns:
        Array with the same shape and dtype as ``x``.

    Raises:
        AssertionError: if ``head_dim`` is odd.
    """
    # Unpacking four values also enforces that x is 4-D.
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split the channels into the two halves that form the rotated pairs.
    x1, x2 = jnp.split(x, 2, axis=-1)

    # Keep only the rows for the positions present and add broadcast axes
    # for batch and heads: (1, 1, seq_len, head_dim).
    cos = cos[:seq_len, :][None, None, :, :]
    sin = sin[:seq_len, :][None, None, :, :]

    # Both outputs are computed from the ORIGINAL halves. Writing
    # [-x2, x1] once and combining with full-width tables avoids reusing an
    # already-rotated half.
    rotated = jnp.concatenate([-x2, x1], axis=-1)
    x_rotated = x * cos + rotated * sin

    # The tables may be wider than x (e.g. float32 vs bfloat16); hand back
    # the caller's dtype.
    return x_rotated.astype(x.dtype)


In [6]:
class GroupedQueryAttention(nn.Module):
    """Multi-head self-attention in which groups of query heads share one K/V head.

    ``num_heads`` query heads are split into ``num_kv_groups`` groups; every
    head in a group attends with the same key/value projection. Rotary
    position embeddings are applied to queries and keys, and queries/keys can
    optionally be RMS-normalised per head first.
    """

    # Width of the input and of the returned tensor.
    d_in: int
    # Number of query heads.
    num_heads: int
    # Number of distinct key/value heads; must divide num_heads.
    num_kv_groups: int
    # Per-head width; falls back to d_in // num_heads when None.
    head_dim: int = None
    # Apply RMSNorm to queries and keys (per head) before RoPE.
    qk_norm: bool = False
    # If set, queries are scaled by this value ** -0.5 instead of head_dim ** -0.5.
    query_pre_attn_scalar: float = None
    # Computation dtype forwarded to every Dense layer.
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x, mask, cos, sin):
        """Run masked grouped-query self-attention over ``x``.

        Args:
            x: Array of shape ``(batch, seq_len, d_in)``.
            mask: Boolean array broadcastable to
                ``(batch, num_heads, seq_len, seq_len)`` in which ``True``
                marks positions that must NOT be attended to, or ``None`` to
                attend everywhere.
            cos: RoPE cosine table, passed through to ``apply_rope``.
            sin: RoPE sine table, passed through to ``apply_rope``.

        Returns:
            Array of shape ``(batch, seq_len, d_in)``.
        """
        assert self.num_heads % self.num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
        head_dim = self.head_dim
        if head_dim is None:
            assert self.d_in % self.num_heads == 0, "d_in must be divisible by num_heads when head_dim is not set"
            head_dim = self.d_in // self.num_heads
        group_size = self.num_heads // self.num_kv_groups
        d_out = self.num_heads * head_dim

        # Construction order is kept (query, key, value, out) so the
        # auto-generated parameter names Dense_0..Dense_3 stay the same.
        W_query = nn.Dense(d_out, use_bias=False, dtype=self.dtype)
        W_key = nn.Dense(self.num_kv_groups * head_dim, use_bias=False, dtype=self.dtype)
        W_value = nn.Dense(self.num_kv_groups * head_dim, use_bias=False, dtype=self.dtype)
        out_proj = nn.Dense(self.d_in, use_bias=False, dtype=self.dtype)

        batch_size, seq_len, _ = x.shape

        # Project, then move the head axis in front of the sequence axis:
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim).
        queries = W_query(x).reshape(batch_size, seq_len, self.num_heads, head_dim)
        keys = W_key(x).reshape(batch_size, seq_len, self.num_kv_groups, head_dim)
        values = W_value(x).reshape(batch_size, seq_len, self.num_kv_groups, head_dim)
        queries = queries.transpose(0, 2, 1, 3)
        keys = keys.transpose(0, 2, 1, 3)
        values = values.transpose(0, 2, 1, 3)

        if self.qk_norm:
            # Normalise each head's query/key vector over head_dim.
            queries = RMSNorm(head_dim, eps=1e-6, name="q_norm")(queries)
            keys = RMSNorm(head_dim, eps=1e-6, name="k_norm")(keys)

        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)

        # Expand each shared K/V head so it lines up with its query heads.
        keys = jnp.repeat(keys, group_size, axis=1)
        values = jnp.repeat(values, group_size, axis=1)

        if self.query_pre_attn_scalar is not None:
            scaling = self.query_pre_attn_scalar ** -0.5
        else:
            scaling = head_dim ** -0.5
        queries = queries * scaling

        # Scores and softmax are evaluated in float32 for numerical stability.
        scores = jnp.einsum("bhqd,bhkd->bhqk", queries, keys).astype(jnp.float32)
        if mask is not None:
            # A large finite negative (rather than -inf) keeps a fully
            # masked row from producing NaNs in the softmax.
            scores = jnp.where(mask, jnp.finfo(scores.dtype).min, scores)
        weights = nn.softmax(scores, axis=-1).astype(values.dtype)

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim).
        context = jnp.einsum("bhqk,bhkd->bhqd", weights, values)
        context = context.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, d_out)
        return out_proj(context)

In [7]:
class TransformerBlock(nn.Module):
    """One decoder layer: attention and feed-forward, each in a residual branch.

    Every branch is wrapped by an RMSNorm on the way in and another on the
    way out before being added back to the residual stream. ``attn_type``
    selects which mask and RoPE tables the attention sub-layer receives.
    """

    # Flax modules are dataclasses: hyperparameters are declared as fields
    # (so ``TransformerBlock(cfg, attn_type)`` still works positionally) and
    # submodules are created in ``setup`` instead of a hand-written __init__.

    # Model configuration; the keys read here are "emb_dim", "n_heads",
    # "n_kv_groups", "head_dim", "qk_norm", "query_pre_attn_scalar",
    # "hidden_dim" and "dtype".
    cfg: dict
    # "sliding_attention" selects the local mask / RoPE tables; any other
    # value selects the global ones.
    attn_type: str

    def setup(self):
        """Create the attention, feed-forward and normalisation submodules."""
        cfg = self.cfg

        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            head_dim=cfg["head_dim"],
            qk_norm=cfg["qk_norm"],
            query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
            dtype=cfg["dtype"],
        )
        self.ff = FeedForward(
            emb_dim=cfg["emb_dim"],
            hidden_dim=cfg["hidden_dim"],
            dtype=cfg["dtype"],
        )
        self.input_layernorm = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.post_attention_layernorm = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.pre_feedforward_layernorm = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.post_feedforward_layernorm = RMSNorm(cfg["emb_dim"], eps=1e-6)

    # No @nn.compact here: a module defines its submodules either in setup()
    # or inline in a compact method, not both.
    def __call__(self,
        x,
        mask_global,
        mask_local,
        cos_global,
        sin_global,
        cos_local,
        sin_local,
    ):
        """Apply the layer to ``x`` and return a tensor of the same shape.

        Both the global and the local mask / RoPE tables are accepted so all
        layers share one call signature; only the pair matching
        ``attn_type`` is used.
        """
        # Shortcut connection for attention block
        shortcut = x
        x = self.input_layernorm(x)

        if self.attn_type == "sliding_attention":
            attn_mask = mask_local
            cos = cos_local
            sin = sin_local
        else:
            attn_mask = mask_global
            cos = cos_global
            sin = sin_global

        x_attn = self.att(x, attn_mask, cos, sin)
        x_attn = self.post_attention_layernorm(x_attn)
        x = shortcut + x_attn

        # Shortcut connection for feed forward block
        shortcut = x
        x_ffn = self.pre_feedforward_layernorm(x)
        x_ffn = self.ff(x_ffn)
        x_ffn = self.post_feedforward_layernorm(x_ffn)
        x = shortcut + x_ffn
        return x

     

In [8]:
from typing import Any, Dict, Tuple

class Gemma3Model(nn.Module):
    """Decoder-only language model assembled from ``TransformerBlock`` layers.

    Token ids are embedded and scaled by ``sqrt(emb_dim)``, passed through one
    block per entry of ``cfg["layer_types"]``, normalised, and projected to
    vocabulary logits. Two RoPE tables are precomputed — one from
    ``rope_local_base`` and one from ``rope_base`` — and handed to every block
    together with a global causal mask and a sliding-window mask.
    """

    # Model configuration; the keys read here are "layer_types", "n_layers",
    # "vocab_size", "emb_dim", "dtype", "head_dim", "rope_local_base",
    # "rope_base", "context_length" and "sliding_window". The whole dict is
    # also forwarded to every TransformerBlock.
    cfg: Dict[str, Any]

    def setup(self):
        """Create submodules and the precomputed RoPE tables."""
        # --- Assertions and config checks ---
        assert self.cfg["layer_types"] is not None and len(self.cfg["layer_types"]) == self.cfg["n_layers"]

        # --- Token Embedding ---
        self.tok_emb = nn.Embed(
            num_embeddings=self.cfg["vocab_size"],
            features=self.cfg["emb_dim"],
            dtype=self.cfg["dtype"],
        )

        # --- Transformer blocks ---
        self.blocks = [
            TransformerBlock(self.cfg, attn_type)
            for attn_type in self.cfg["layer_types"]
        ]

        # --- Final normalization ---
        # RMSNorm's size field is named ``emb_dim``.
        self.final_norm = RMSNorm(
            emb_dim=self.cfg["emb_dim"], eps=1e-6
        )

        # --- Output projection (logits) ---
        self.out_head = nn.Dense(
            features=self.cfg["vocab_size"],
            use_bias=False,
            dtype=self.cfg["dtype"],
        )

        # --- RoPE params (cos, sin) ---
        cos_local, sin_local = compute_rope_params(
            head_dim=self.cfg["head_dim"],
            theta_base=self.cfg["rope_local_base"],
            context_length=self.cfg["context_length"],
            dtype=jnp.float32,
        )
        cos_global, sin_global = compute_rope_params(
            head_dim=self.cfg["head_dim"],
            theta_base=self.cfg["rope_base"],
            context_length=self.cfg["context_length"],
            dtype=jnp.float32,
        )

        # In Flax, store as "constants" (non-trainable variables).
        # NOTE(review): this puts the tables in a separate "constants"
        # collection next to "params"; presumably it has to be passed back
        # to ``apply`` along with the parameters — confirm at the call site.
        self.cos_local = self.variable("constants", "cos_local", lambda: cos_local)
        self.sin_local = self.variable("constants", "sin_local", lambda: sin_local)
        self.cos_global = self.variable("constants", "cos_global", lambda: cos_global)
        self.sin_global = self.variable("constants", "sin_global", lambda: sin_global)

    def _create_masks(self, seq_len: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Build the boolean attention masks; ``True`` means "do not attend".

        Returns:
            ``(mask_global, mask_local)``, each of shape ``(seq_len, seq_len)``.
            ``mask_global`` hides future positions; ``mask_local`` additionally
            hides keys that are ``sliding_window`` or more positions behind
            the query.
        """
        ones = jnp.ones((seq_len, seq_len), dtype=bool)

        # Global causal mask (upper triangular, future is masked)
        mask_global = jnp.triu(ones, k=1)

        # Far-past mask (sliding window)
        far_past = jnp.triu(ones, k=self.cfg["sliding_window"]).T

        # Local mask = future OR far past
        mask_local = jnp.logical_or(mask_global, far_past)

        return mask_global, mask_local

    def __call__(self, input_ids: jnp.ndarray) -> jnp.ndarray:
        """Map token ids of shape ``(batch, seq_len)`` to vocabulary logits.

        Returns:
            Array of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            AssertionError: if ``seq_len`` exceeds ``cfg["context_length"]``.
        """
        _, seq_len = input_ids.shape

        # The RoPE tables only cover context_length positions; fail with a
        # clear message instead of a broadcasting error deep inside a block.
        assert seq_len <= self.cfg["context_length"], (
            f"Sequence length {seq_len} exceeds context_length "
            f"{self.cfg['context_length']}"
        )

        # Token embedding + scaling
        x = self.tok_emb(input_ids) * (self.cfg["emb_dim"] ** 0.5)

        # Build masks
        mask_global, mask_local = self._create_masks(seq_len)

        # Pass through transformer blocks
        for block in self.blocks:
            x = block(
                x,
                mask_global=mask_global,
                mask_local=mask_local,
                cos_global=self.cos_global.value,
                sin_global=self.sin_global.value,
                cos_local=self.cos_local.value,
                sin_local=self.sin_local.value,
            )

        # Normalize and project to vocab logits
        x = self.final_norm(x)
        logits = self.out_head(x.astype(self.cfg["dtype"]))
        return logits


In [12]:
# Hyperparameters for the 270M variant. The layer schedule repeats a fixed
# pattern three times: five sliding-window layers followed by one
# full-attention layer (18 layers in total).
GEMMA3_CONFIG_270M = dict(
    vocab_size=262_144,
    context_length=32_768,
    emb_dim=640,
    n_heads=4,
    n_layers=18,
    hidden_dim=2048,
    head_dim=256,
    qk_norm=True,
    n_kv_groups=1,
    rope_local_base=10_000.0,
    rope_base=1_000_000.0,
    sliding_window=512,
    layer_types=(["sliding_attention"] * 5 + ["full_attention"]) * 3,
    dtype=jnp.bfloat16,
    query_pre_attn_scalar=256,
)


In [13]:
# Instantiate the model definition from the 270M configuration and show its repr.
model = Gemma3Model(cfg=GEMMA3_CONFIG_270M)

model

Gemma3Model(
    # attributes
    cfg = {'vocab_size': 262144, 'context_length': 32768, 'emb_dim': 640, 'n_heads': 4, 'n_layers': 18, 'hidden_dim': 2048, 'head_dim': 256, 'qk_norm': True, 'n_kv_groups': 1, 'rope_local_base': 10000.0, 'rope_base': 1000000.0, 'sliding_window': 512, 'layer_types': ['sliding_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention'], 'dtype': <class 'jax.numpy.bfloat16'>, 'query_pre_attn_scalar': 256}
)