### Code from teacher

In [2]:
"""
Minimal decoder-only Transformer blocks in Flax/JAX, commented for learning.

The model mirrors a GPT-style architecture:
- Token embeddings + learned positional embeddings
- Stack of Pre-LayerNorm decoder blocks with causal self-attention
- Final LayerNorm
- Weight tying between input embeddings and output logits projection

Tensor shape conventions used below:
- B: batch size
- T: sequence length (time/positions)
- D: hidden size / embedding dimension (d_model)
- V: vocabulary size
"""

import jax.numpy as jnp
import flax.linen as nn
from flax.linen import attention as attn

class MLP(nn.Module):
    """Position-wise feed-forward sublayer of a Transformer block.

    Computes ``Dense(hidden) -> GELU -> Dense(d_model)`` with
    ``hidden = int(d_model * mlp_ratio)``; the default ratio of 4 gives the
    classic D -> 4D -> D expansion.

    Attributes:
        d_model: Output width D.
        mlp_ratio: Multiplier applied to ``d_model`` to get the hidden width.

    Input shape:  (B, T, D)
    Output shape: (B, T, D)
    """

    d_model: int
    mlp_ratio: int = 4

    @nn.compact
    def __call__(self, x):
        width = int(self.d_model * self.mlp_ratio)
        # Expand to the hidden width and apply the non-linearity ...
        expanded = nn.gelu(nn.Dense(width)(x))
        # ... then project back down to the model width.
        return nn.Dense(self.d_model)(expanded)

class DecoderBlock(nn.Module):
    """One Pre-LayerNorm decoder layer: self-attention then MLP, each residual.

    Data flow::

        x = x + SelfAttention(LayerNorm(x), mask)
        x = x + MLP(LayerNorm(x))

    Causality is not built in; it comes entirely from ``mask``. Pass a causal
    mask so that position t only attends to positions <= t.

    Args:
      d_model: Hidden size D (also the MLP output width).
      n_heads: Number of attention heads.
      mlp_ratio: Hidden-width multiplier forwarded to ``MLP``.

    Input/Output shape: (B, T, D)
    """

    d_model: int
    n_heads: int
    mlp_ratio: int = 4

    @nn.compact
    def __call__(self, x, *, mask=None):
        # Attention sublayer (pre-norm) with residual connection.
        attn_in = nn.LayerNorm()(x)
        attention = nn.SelfAttention(num_heads=self.n_heads, use_bias=False)
        x = x + attention(attn_in, mask=mask)

        # Feed-forward sublayer (pre-norm) with residual connection.
        mlp_in = nn.LayerNorm()(x)
        feed_forward = MLP(self.d_model, mlp_ratio=self.mlp_ratio)
        return x + feed_forward(mlp_in)

class DecoderOnlyTransformer(nn.Module):
    """GPT-style decoder-only Transformer for language modeling.

    Components:
      - Token embeddings: maps token ids to D-dim vectors
      - Learned positional embeddings: adds position information (0..T-1)
      - N stacked decoder blocks with causal self-attention
      - Final LayerNorm
      - Output projection:
          * tie_weights=False (default): a separate bias-free linear head
            projects to V logits.
          * tie_weights=True: reuse the token embedding matrix E and compute
            logits as x @ E^T (via einsum); no separate head is created.

    Args:
      vocab_size: Vocabulary size V.
      d_model: Hidden size D.
      n_layers: Number of decoder blocks.
      n_heads: Attention heads per block.
      max_len: Maximum supported sequence length for positional embeddings.
      mlp_ratio: Hidden-width multiplier for each block's MLP.
      tie_weights: Share the output projection with the token embedding table.

    NOTE: ``DecoderBlock`` is looked up by global name when ``setup`` runs.
    Later cells in this notebook redefine ``DecoderBlock`` with different
    fields (no ``mlp_ratio``). Make sure the ``DecoderBlock`` from this cell is
    the one in scope when this model is initialized or applied.
    """

    vocab_size: int
    d_model: int
    n_layers: int
    n_heads: int
    max_len: int
    mlp_ratio: int = 4
    tie_weights: bool = False

    def setup(self):
        # Token embedding table E with shape (V, D)
        self.tok_embed = nn.Embed(self.vocab_size, self.d_model)

        # Learned positional embeddings P with shape (max_len, D).
        # Each forward pass slices P[:T] and adds it to the token embeddings.
        self.positional_embed = self.param(
            "positional_embed",
            nn.initializers.normal(stddev=0.02),
            (self.max_len, self.d_model),
        )

        # Stack of decoder blocks
        self.blocks = [
            DecoderBlock(d_model=self.d_model, n_heads=self.n_heads, mlp_ratio=self.mlp_ratio)
            for _ in range(self.n_layers)
        ]

        # Final LayerNorm before projecting to logits
        self.layerNorm_final = nn.LayerNorm()

        # Separate output head, only when not weight-tying.
        if not self.tie_weights:
            self.project_to_vocab = nn.Dense(self.vocab_size, use_bias=False)

    def __call__(self, idx):
        """Forward pass (causal-only).

        Args:
          idx: Token ids of shape (B, T), dtype int32/int64.

        Returns:
          logits: (B, T, V) unnormalized vocabulary scores for next-token prediction.

        Raises:
          ValueError: if T exceeds ``max_len``. Slicing the positional table
            would otherwise silently truncate and fail later with a broadcast error.
        """
        B, T = idx.shape
        if T > self.max_len:
            raise ValueError(f"sequence length {T} exceeds max_len={self.max_len}")

        # Token + positional embeddings -> (B, T, D)
        x = self.tok_embed(idx) + self.positional_embed[:T]

        # Strictly causal (lower-triangular) attention mask; no padding mask.
        mask = attn.make_causal_mask(jnp.ones((B, T), dtype=bool))

        for blk in self.blocks:
            x = blk(x, mask=mask)

        # Final LayerNorm before output projection
        x = self.layerNorm_final(x)

        # Output projection to logits over V tokens.
        if self.tie_weights:
            return jnp.einsum('btd,vd->btv', x, self.tok_embed.embedding)
        return self.project_to_vocab(x)

### RoPE + multi-query attention

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned gain.

    y = x * scale / sqrt(mean(x**2, axis=-1) + eps). There is no mean-centering
    and no bias.

    Attributes:
        dim: Length of the learned ``scale`` vector (size of x's last axis).
        eps: Added inside the square root for numerical stability.
    """
    dim: int
    eps: float = 1e-8

    @nn.compact
    def __call__(self, x):
        gain = self.param('scale', nn.initializers.ones, (self.dim,))
        mean_sq = jnp.square(x).mean(axis=-1, keepdims=True)
        return x * (gain / jnp.sqrt(mean_sq + self.eps))

class SwiGLU(nn.Module):
    """Gated feed-forward block: W_o @ dropout(silu(W_u x) * (W_v x)), no biases.

    Attributes:
        d_model: Output width D.
        mult: Hidden width multiplier; hidden = int(d_model * mult).
        dropout: Dropout rate applied to the gated hidden activations.
    """
    d_model: int
    mult: float = 2.667
    dropout: float = 0.1

    @nn.compact
    def __call__(self, x, *, deterministic: bool):
        width = int(self.d_model * self.mult)
        gate = nn.silu(nn.Dense(width, use_bias=False)(x))
        value = nn.Dense(width, use_bias=False)(x)
        hidden = nn.Dropout(self.dropout)(gate * value, deterministic=deterministic)
        return nn.Dense(self.d_model, use_bias=False)(hidden)

def rotary_emb(d: int, T: int, base: float = 10000.0, dtype=jnp.float32):
    """Build RoPE cosine/sine tables for ``T`` positions and head size ``d``.

    Frequency i is ``1 / base ** (2i / d)``; the rotation angle for position t
    is ``t * freq_i``.

    Returns:
        (cos, sin), each of shape (1, 1, T, d // 2) for even ``d``. The two
        leading singleton axes let them broadcast against (B, H, T, d // 2).
    """
    exponents = jnp.arange(0, d, 2, dtype=dtype) / d
    inv_freq = 1.0 / (base ** exponents)
    positions = jnp.arange(T, dtype=dtype)
    angles = jnp.einsum('t,f->tf', positions, inv_freq)  # (T, d // 2)
    return jnp.cos(angles)[None, None], jnp.sin(angles)[None, None]

def apply_rope(x, cos, sin):
    """Rotate the two halves of the last axis of ``x`` by the given angles.

    The last axis is split into halves (a, b). Each (a_i, b_i) pair is rotated:
    (a*cos - b*sin, a*sin + b*cos). ``cos``/``sin`` must broadcast against each
    half.
    """
    first, second = jnp.split(x, 2, axis=-1)
    rotated = (
        first * cos - second * sin,
        first * sin + second * cos,
    )
    return jnp.concatenate(rotated, axis=-1)


class MQSelfAttention(nn.Module):
    """Multi-query / grouped-query self-attention with rotary position embeddings.

    Queries use ``n_heads`` heads. Keys/values use ``n_kv_heads`` heads that are
    repeated to match: n_kv_heads=1 is multi-query, 1 < n_kv_heads < n_heads is
    grouped-query, and n_kv_heads == n_heads is standard multi-head attention.
    RoPE is applied to q and k.

    Attributes:
        d_model: Model width D; must be divisible by ``n_heads``.
        n_heads: Number of query heads H.
        n_kv_heads: Number of key/value heads; must be >= 1 and divide ``n_heads``.
        dropout: Dropout rate on attention probabilities and on the output.

    ``__call__`` args:
        x: (B, T, D) input.
        mask: Optional boolean mask broadcastable to (B, H, T, T); True = attend.
        deterministic: Disable dropout when True.

    Raises:
        ValueError: if the head configuration is inconsistent. Without this
            check, such configurations fail later with opaque reshape/einsum errors.
    """
    d_model: int
    n_heads: int
    n_kv_heads: int = 1
    dropout: float = 0.1

    @nn.compact
    def __call__(self, x, *, mask=None, deterministic: bool = True):
        H = self.n_heads
        H_kv = self.n_kv_heads
        if H <= 0 or self.d_model % H != 0:
            raise ValueError(f"d_model ({self.d_model}) must be divisible by n_heads ({H})")
        if H_kv <= 0 or H % H_kv != 0:
            raise ValueError(f"n_kv_heads ({H_kv}) must be a positive divisor of n_heads ({H})")
        Dh = self.d_model // H
        if Dh % 2 != 0:
            raise ValueError(f"head dim d_model // n_heads ({Dh}) must be even for RoPE")

        q = nn.Dense(self.d_model, use_bias=False)(x)
        k = nn.Dense(H_kv * Dh, use_bias=False)(x)
        v = nn.Dense(H_kv * Dh, use_bias=False)(x)

        # (B, T, heads*Dh) -> (B, heads, T, Dh)
        B, T, _ = q.shape
        q = q.reshape(B, T, H, Dh).transpose(0, 2, 1, 3)
        k = k.reshape(B, T, H_kv, Dh).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, H_kv, Dh).transpose(0, 2, 1, 3)

        cos, sin = rotary_emb(Dh, T)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # Share each KV head across H // H_kv query heads.
        if H_kv != H:
            factor = H // H_kv
            k = jnp.repeat(k, repeats=factor, axis=1)
            v = jnp.repeat(v, repeats=factor, axis=1)

        scale = 1.0 / jnp.sqrt(jnp.array(Dh, dtype=x.dtype))
        att = jnp.einsum('bhtd,bhTd->bhtT', q, k) * scale

        if mask is not None:
            att = jnp.where(mask, att, jnp.finfo(att.dtype).min)
        att = nn.softmax(att, axis=-1)
        att = nn.Dropout(self.dropout)(att, deterministic=deterministic)

        y = jnp.einsum('bhtT,bhTd->bhtd', att, v)
        y = y.transpose(0, 2, 1, 3).reshape(B, T, self.d_model)
        y = nn.Dense(self.d_model, use_bias=False)(y)
        y = nn.Dropout(self.dropout)(y, deterministic=deterministic)
        return y

class DecoderBlock(nn.Module):
    """Pre-RMSNorm decoder block: RoPE MQ/GQ attention then SwiGLU, both residual.

        x = x + s * Attn(RMSNorm(x));  x = x + s * SwiGLU(RMSNorm(x))

    A single learned scalar ``s`` (param ``res_scale``, initialized to
    ``resid_scale_init``) scales *both* residual branches.
    """
    d_model: int
    n_heads: int
    n_kv_heads: int = 1
    attn_dropout: float = 0.1
    mlp_dropout: float = 0.1
    resid_scale_init: float = 1e-2

    @nn.compact
    def __call__(self, x, *, mask=None, deterministic: bool = True):
        s = self.param('res_scale', nn.initializers.constant(self.resid_scale_init), ())

        attn_in = RMSNorm(self.d_model)(x)
        attention = MQSelfAttention(self.d_model, self.n_heads, self.n_kv_heads,
                                    dropout=self.attn_dropout)
        x = x + s * attention(attn_in, mask=mask, deterministic=deterministic)

        ffn_in = RMSNorm(self.d_model)(x)
        ffn = SwiGLU(self.d_model, mult=2.667, dropout=self.mlp_dropout)
        return x + s * ffn(ffn_in, deterministic=deterministic)

class DecoderOnlyTransformer(nn.Module):
    """Decoder-only LM with RoPE MQ/GQ attention, RMSNorm and SwiGLU blocks.

    There is no learned positional table; position enters only through RoPE
    inside the attention layers, so ``max_len`` is currently unused. Output
    logits reuse the token-embedding matrix (weight tying).

    NOTE: ``DecoderBlock`` is resolved by global name when ``setup`` runs.
    Other cells in this notebook define classes with the same name but
    different fields. Make sure the ``DecoderBlock`` from this cell is the one
    in scope when this model is initialized or applied.
    """
    vocab_size: int
    d_model: int
    n_layers: int
    n_heads: int
    n_kv_heads: int = 1
    max_len: int = 2048
    dropout: float = 0.1

    def setup(self):
        self.tok_embed = nn.Embed(self.vocab_size, self.d_model)
        # Must be created here: constructing nn.Dropout inline in a plain
        # (non-@compact) __call__ raises flax.errors.AssignSubModuleError.
        self.embed_dropout = nn.Dropout(self.dropout)
        self.blocks = [
            DecoderBlock(
                d_model=self.d_model,
                n_heads=self.n_heads,
                n_kv_heads=self.n_kv_heads,
                attn_dropout=self.dropout,
                mlp_dropout=self.dropout,
            )
            for _ in range(self.n_layers)
        ]
        self.final_norm = RMSNorm(self.d_model)

    def __call__(self, idx, *, deterministic: bool = True, pad_mask: jnp.ndarray | None = None):
        """Compute next-token logits.

        Args:
          idx: (B, T) token ids.
          deterministic: Disable dropout when True.
          pad_mask: Optional (B, T) mask. It is combined with the causal mask via
            ``make_attention_mask(pad_mask, pad_mask)``, so positions where it
            is False/0 are excluded both as keys and as queries.

        Returns:
          (B, T, V) logits.
        """
        B, T = idx.shape
        x = self.tok_embed(idx)
        x = self.embed_dropout(x, deterministic=deterministic)

        causal = attn.make_causal_mask(jnp.ones((B, T), dtype=bool))
        mask = causal if pad_mask is None else attn.combine_masks(
            causal, attn.make_attention_mask(pad_mask, pad_mask, dtype=bool))

        for blk in self.blocks:
            x = blk(x, mask=mask, deterministic=deterministic)

        x = self.final_norm(x)
        return jnp.einsum('btd,vd->btv', x, self.tok_embed.embedding)

### Transformer-XL with segment memory

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import attention as attn

class RMSNorm(nn.Module):
    """RMS normalization over the last axis, scaled by a learned per-feature gain.

    Attributes:
        dim: Size of the gain vector ``scale``.
        eps: Numerical-stability term added before the square root.
    """
    dim: int
    eps: float = 1e-8

    @nn.compact
    def __call__(self, x):
        weight = self.param('scale', nn.initializers.ones, (self.dim,))
        squared = x * x
        rms = jnp.sqrt(squared.mean(axis=-1, keepdims=True) + self.eps)
        normalizer = weight / rms
        return x * normalizer

class SwiGLU(nn.Module):
    """SwiGLU feed-forward: silu(x W_u) * (x W_v), dropout, then project to d_model.

    Attributes:
        d_model: Output width.
        mult: Hidden width is int(d_model * mult).
        dropout: Dropout rate on the gated hidden activations.
    """
    d_model: int
    mult: float = 2.667
    dropout: float = 0.1

    @nn.compact
    def __call__(self, x, *, deterministic: bool):
        hidden_dim = int(self.d_model * self.mult)
        up_gate = nn.Dense(hidden_dim, use_bias=False)
        up_value = nn.Dense(hidden_dim, use_bias=False)
        mixed = nn.silu(up_gate(x)) * up_value(x)
        mixed = nn.Dropout(self.dropout)(mixed, deterministic=deterministic)
        down = nn.Dense(self.d_model, use_bias=False)
        return down(mixed)

class XLBlock(nn.Module):
    """Pre-RMSNorm decoder block with Transformer-XL-style segment memory.

    Queries come from the current segment ``x`` (B, T, D). Keys/values come from
    ``concat([mem, x])``, so each position attends to all M cached positions
    plus current positions up to itself (causal). This block adds no positional
    encoding.

    Memory semantics: as in Transformer-XL, the cache for this layer holds the
    hidden states that *entered* this layer on previous segments. Cached and
    current keys/values therefore come from the same depth. The returned
    ``new_mem`` is the last ``mem_len`` positions of ``concat([mem, x])`` (the
    block's input, not its output), with gradients stopped. ``mem_len <= 0``
    disables caching and returns an empty (B, 0, D) memory; the old ``[-0:]``
    slice kept everything in that case.

    Returns:
        (x_out, new_mem), where x_out has shape (B, T, D).
    """
    d_model: int
    n_heads: int
    mlp_mult: float = 2.667
    attn_dropout: float = 0.1
    mlp_dropout: float = 0.1
    resid_scale_init: float = 1e-2
    mem_len: int = 512

    @nn.compact
    def __call__(self, x, *, mem=None, deterministic: bool = True):
        B, T, D = x.shape
        mem = jnp.zeros((B, 0, D), x.dtype) if mem is None else mem
        M = mem.shape[1]
        # Layer input over memory + current segment: (B, M + T, D).
        cat = jnp.concatenate([mem, x], axis=1)

        scale = self.param('res_scale', nn.initializers.constant(self.resid_scale_init), ())

        h = RMSNorm(self.d_model)(x)
        q = nn.Dense(self.d_model, use_bias=False)(h)
        kv = nn.Dense(2 * self.d_model, use_bias=False)(RMSNorm(self.d_model)(cat))
        k, v = jnp.split(kv, 2, axis=-1)

        H = self.n_heads
        Dh = self.d_model // H
        q = q.reshape(B, T, H, Dh).transpose(0, 2, 1, 3)
        k = k.reshape(B, M + T, H, Dh).transpose(0, 2, 1, 3)
        v = v.reshape(B, M + T, H, Dh).transpose(0, 2, 1, 3)

        # Memory columns fully visible; current-segment columns lower-triangular.
        base = jnp.ones((1, 1, T, M + T), dtype=bool)
        tri = jnp.tril(jnp.ones((T, T), dtype=bool))
        mask = base.at[:, :, :, M:].set(tri[None, None, :, :])

        att_logits = jnp.einsum('bhtd,bhTd->bhtT', q, k) / jnp.sqrt(Dh)
        att_logits = jnp.where(mask, att_logits, jnp.finfo(att_logits.dtype).min)
        att_probs = nn.softmax(att_logits, axis=-1)
        att_probs = nn.Dropout(self.attn_dropout)(att_probs, deterministic=deterministic)
        y = jnp.einsum('bhtT,bhTd->bhtd', att_probs, v)
        y = y.transpose(0, 2, 1, 3).reshape(B, T, D)
        y = nn.Dense(self.d_model, use_bias=False)(y)
        x = x + scale * y

        h = RMSNorm(self.d_model)(x)
        y = SwiGLU(self.d_model, mult=self.mlp_mult, dropout=self.mlp_dropout)(h, deterministic=deterministic)
        x = x + scale * y

        # Cache this layer's *input* states for the next segment.
        if self.mem_len > 0:
            new_mem = jax.lax.stop_gradient(cat[:, -self.mem_len:, :])
        else:
            new_mem = jnp.zeros((B, 0, D), cat.dtype)
        return x, new_mem

class XLMemoryTransformer(nn.Module):
    """Transformer-XL-style LM that carries per-layer segment memory across calls.

    Call with ``mems=None`` for the first segment, then pass the returned
    ``new_mems`` (one array per layer) back in with the next segment. Logits
    reuse the token-embedding matrix (weight tying). Neither this class nor
    ``XLBlock`` adds a positional encoding.
    """
    vocab_size: int
    d_model: int
    n_layers: int
    n_heads: int
    mem_len: int = 512
    dropout: float = 0.1
    mlp_mult: float = 2.667

    def setup(self):
        self.tok_embed = nn.Embed(self.vocab_size, self.d_model)
        # Must be created here: constructing nn.Dropout inline in a plain
        # (non-@compact) __call__ raises flax.errors.AssignSubModuleError.
        self.embed_dropout = nn.Dropout(self.dropout)
        self.blocks = [XLBlock(self.d_model, self.n_heads, self.mlp_mult,
                               attn_dropout=self.dropout, mlp_dropout=self.dropout,
                               mem_len=self.mem_len) for _ in range(self.n_layers)]
        self.final_norm = RMSNorm(self.d_model)

    def __call__(self, idx, *, mems=None, deterministic: bool = True):
        """Run one segment.

        Args:
          idx: (B, T) token ids.
          mems: None, or a sequence with exactly one entry per layer (each a
            (B, M, D) array or None), typically from a previous call.
          deterministic: Disable dropout when True.

        Returns:
          (logits, new_mems): (B, T, V) logits and a list of ``n_layers`` memories.

        Raises:
          ValueError: if ``mems`` does not have one entry per layer. Otherwise
            ``zip`` would silently skip the unmatched layers.
        """
        x = self.tok_embed(idx)
        x = self.embed_dropout(x, deterministic=deterministic)

        if mems is None:
            mems = [None] * self.n_layers
        elif len(mems) != self.n_layers:
            raise ValueError(
                f"expected {self.n_layers} memories (one per layer), got {len(mems)}")

        new_mems = []
        for blk, mem in zip(self.blocks, mems):
            x, mem_out = blk(x, mem=mem, deterministic=deterministic)
            new_mems.append(mem_out)

        x = self.final_norm(x)
        logits = jnp.einsum('btd,vd->btv', x, self.tok_embed.embedding)
        return logits, new_mems


### Transformer with a Switch/Top-1 MoE MLP

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import attention as attn

class RMSNorm(nn.Module):
    """Scale-only RMS normalization of the last axis.

    Output = x * (scale / sqrt(mean(x * x) + eps)), where ``scale`` is a learned
    vector of length ``dim`` initialized to ones.
    """
    dim: int
    eps: float = 1e-8

    @nn.compact
    def __call__(self, x):
        gamma = self.param('scale', nn.initializers.ones, (self.dim,))
        variance_like = jnp.mean(x * x, axis=-1, keepdims=True)
        root = jnp.sqrt(variance_like + self.eps)
        return x * (gamma / root)

class SwitchMoE(nn.Module):
    """Switch-style (top-1) mixture-of-experts feed-forward layer.

    Each token is routed to the expert with the highest router logit. The
    expert's output is scaled by that expert's router probability, as in the
    Switch Transformer. Without this scaling the router weights get no
    gradient from the task loss, because argmax is not differentiable, and
    routing is never learned for the task.

    For clarity over speed, every expert runs on the full input with
    non-routed tokens zeroed, so compute grows linearly with ``n_experts``.

    Attributes:
        d_model: Nominal width; the layer actually reads D from ``x``.
        mult: Expert hidden width multiplier; hidden = int(D * mult).
        n_experts: Number of experts E.
        dropout: Dropout rate on each expert's output.

    Returns:
        (out, aux_loss). ``out`` has the shape of ``x``. ``aux_loss`` is the
        squared deviation of the mean router probabilities (over batch and
        time) from uniform. This is a simple balancing penalty, not the
        f_i * P_i Switch loss.
    """
    d_model: int
    mult: float = 2.667
    n_experts: int = 8
    dropout: float = 0.1

    @nn.compact
    def __call__(self, x, *, deterministic: bool):
        _, _, D = x.shape
        H = int(D * self.mult)
        router_w = self.param('router_w', nn.initializers.normal(0.02), (D, self.n_experts))
        logits = jnp.einsum('btd,de->bte', x, router_w)
        gates = nn.softmax(logits, axis=-1)                                   # (B, T, E)
        top1 = jnp.argmax(logits, axis=-1)                                    # (B, T)
        top1_gate = jnp.take_along_axis(gates, top1[..., None], axis=-1)      # (B, T, 1)

        def expert():
            return nn.Sequential([
                nn.Dense(H, use_bias=False),
                nn.silu,
                nn.Dense(D, use_bias=False),
            ])
        experts = [expert() for _ in range(self.n_experts)]

        out = jnp.zeros_like(x)
        for e, ff in enumerate(experts):
            routed = (top1 == e)[..., None]
            ye = ff(jnp.where(routed, x, 0.0))
            ye = nn.Dropout(self.dropout)(ye, deterministic=deterministic)
            out = out + jnp.where(routed, ye, 0.0)

        # Scale by the chosen expert's probability so the task loss trains the router.
        out = out * top1_gate

        expert_usage = jnp.mean(gates, axis=(0, 1))
        aux_loss = jnp.sum((expert_usage - 1.0 / self.n_experts) ** 2)
        return out, aux_loss

class DecoderBlockMoE(nn.Module):
    """Pre-RMSNorm decoder block whose feed-forward sublayer is a Switch MoE.

    Returns ``(x, aux)``, where ``aux`` is the MoE's auxiliary balancing term.
    Both residual branches are scaled by one learned scalar ``res_scale``
    (initialized to ``resid_scale_init``).
    """
    d_model: int
    n_heads: int
    attn_dropout: float = 0.1
    moe_dropout: float = 0.1
    moe_mult: float = 2.667
    n_experts: int = 8
    resid_scale_init: float = 1e-2

    @nn.compact
    def __call__(self, x, *, mask=None, deterministic: bool = True):
        res_scale = self.param('res_scale', nn.initializers.constant(self.resid_scale_init), ())

        normed = RMSNorm(self.d_model)(x)
        attention = nn.SelfAttention(num_heads=self.n_heads, use_bias=False,
                                     dropout_rate=self.attn_dropout)
        x = x + res_scale * attention(normed, mask=mask, deterministic=deterministic)

        normed = RMSNorm(self.d_model)(x)
        moe = SwitchMoE(self.d_model, mult=self.moe_mult,
                        n_experts=self.n_experts, dropout=self.moe_dropout)
        moe_out, aux = moe(normed, deterministic=deterministic)
        return x + res_scale * moe_out, aux

class MoETransformer(nn.Module):
    """Causal decoder LM whose blocks use a Switch (top-1) MoE feed-forward layer.

    Returns ``(logits, aux_total)``. ``aux_total`` is the sum of the per-block
    MoE auxiliary terms; the caller decides how to weight it in the loss.
    Logits reuse the token-embedding matrix (weight tying).
    """
    vocab_size: int
    d_model: int
    n_layers: int
    n_heads: int
    n_experts: int = 8
    dropout: float = 0.1
    moe_mult: float = 2.667

    def setup(self):
        self.tok_embed = nn.Embed(self.vocab_size, self.d_model)
        # Must be created here: constructing nn.Dropout inline in a plain
        # (non-@compact) __call__ raises flax.errors.AssignSubModuleError.
        self.embed_dropout = nn.Dropout(self.dropout)
        self.blocks = [
            DecoderBlockMoE(
                self.d_model, self.n_heads,
                attn_dropout=self.dropout,
                moe_dropout=self.dropout,
                moe_mult=self.moe_mult,
                n_experts=self.n_experts
            )
            for _ in range(self.n_layers)
        ]
        self.final_norm = RMSNorm(self.d_model)

    def __call__(self, idx, *, deterministic: bool = True, pad_mask: jnp.ndarray | None = None):
        """Compute logits and the summed MoE auxiliary term.

        Args:
          idx: (B, T) token ids.
          deterministic: Disable dropout when True.
          pad_mask: Optional (B, T) mask combined with the causal mask;
            False/0 positions are excluded as keys and as queries.
        """
        B, T = idx.shape
        x = self.tok_embed(idx)
        x = self.embed_dropout(x, deterministic=deterministic)

        causal = attn.make_causal_mask(jnp.ones((B, T), dtype=bool))
        mask = causal if pad_mask is None else attn.combine_masks(
            causal, attn.make_attention_mask(pad_mask, pad_mask, dtype=bool)
        )

        aux_total = jnp.array(0.0, dtype=x.dtype)
        for blk in self.blocks:
            x, aux = blk(x, mask=mask, deterministic=deterministic)
            aux_total = aux_total + aux

        x = self.final_norm(x)

        logits = jnp.einsum('btd,vd->btv', x, self.tok_embed.embedding)
        return logits, aux_total

### Relative Position Bias + Local-Global

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import attention as attn

class RMSNorm(nn.Module):
    """Normalize the last axis by its root-mean-square, then apply a learned gain.

    Attributes:
        dim: Length of the ``scale`` parameter.
        eps: Stability constant inside the square root.
    """
    dim: int
    eps: float = 1e-8

    @nn.compact
    def __call__(self, x):
        g = self.param('scale', nn.initializers.ones, (self.dim,))
        denom = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + self.eps)
        factor = g / denom
        return x * factor

class SwiGLU(nn.Module):
    """Bias-free SwiGLU MLP: Dense_out(Dropout(silu(Dense_u(x)) * Dense_v(x))).

    Attributes:
        d_model: Output width.
        mult: Hidden width is int(d_model * mult).
        dropout: Dropout rate on the gated hidden activations.
    """
    d_model: int
    mult: float = 2.667
    dropout: float = 0.1

    @nn.compact
    def __call__(self, x, *, deterministic: bool):
        inner = int(self.d_model * self.mult)
        u_proj, v_proj = (nn.Dense(inner, use_bias=False), nn.Dense(inner, use_bias=False))
        act = nn.silu(u_proj(x)) * v_proj(x)
        act = nn.Dropout(self.dropout)(act, deterministic=deterministic)
        out_proj = nn.Dense(self.d_model, use_bias=False)
        return out_proj(act)

def relative_position_bucket(rel_pos, bidirectional: bool, num_buckets=32, max_distance=128):
    n = num_buckets // (2 if bidirectional else 1)
    sign = (rel_pos < 0).astype(jnp.int32) if bidirectional else 0
    rp = jnp.abs(rel_pos)
    max_exact = n // 2
    is_small = rp < max_exact
    val_large = max_exact + (jnp.log(rp / max_exact + 1e-6) / jnp.log(max_distance / max_exact)) * (n - max_exact)
    val_large = val_large.astype(jnp.int32)
    val_large = jnp.minimum(val_large, n - 1)
    buckets = jnp.where(is_small, rp, val_large)
    return buckets + sign * n

class RelPosBias(nn.Module):
    """Learned per-head additive attention bias indexed by relative-position bucket.

    Returns a (1, num_heads, qlen, klen) array. Entry [0, h, i, j] is the
    learned value for head h at the bucket of offset ``j - i`` (key minus query).

    NOTE: when bidirectional, the table has ``2 * num_buckets`` rows, but
    ``relative_position_bucket`` halves the bucket count per direction in
    that mode. Only the first ``num_buckets`` rows are ever indexed.
    """
    num_buckets: int
    num_heads: int
    max_distance: int
    bidirectional: bool = False

    @nn.compact
    def __call__(self, qlen, klen):
        offsets = jnp.arange(klen)[None, :] - jnp.arange(qlen)[:, None]
        bucket_ids = relative_position_bucket(offsets, self.bidirectional,
                                              self.num_buckets, self.max_distance)
        n_rows = self.num_buckets * (2 if self.bidirectional else 1)
        table = self.param('bias', nn.initializers.normal(0.02), (n_rows, self.num_heads))
        per_pair = table[bucket_ids]                 # (qlen, klen, num_heads)
        return jnp.moveaxis(per_pair, -1, 0)[None, ...]

class LocalGlobalSelfAttention(nn.Module):
    """Causal sliding-window attention plus learned global memory tokens.

    Each query at position t may attend to:
      * ``n_global`` learned global key/value tokens (always visible,
        zero relative bias), and
      * sequence positions s with t - window <= s <= t (relative bias from
        ``RelPosBias``).

    The full (T, G + T) score matrix is materialized and then masked. This is
    clear but not memory-efficient for long sequences.

    Input/Output shape: (B, T, D); ``D`` must be divisible by ``n_heads``.
    """
    d_model: int
    n_heads: int
    window: int = 512
    n_global: int = 8
    attn_dropout: float = 0.1
    num_buckets: int = 32
    max_distance: int = 512

    @nn.compact
    def __call__(self, x, *, deterministic: bool):
        B, T, D = x.shape
        H, G = self.n_heads, self.n_global
        Dh = D // H

        def split_heads(t):
            # (B, L, D) -> (B, H, L, Dh)
            return t.reshape(B, t.shape[1], H, Dh).transpose(0, 2, 1, 3)

        global_tokens = self.param('globals', nn.initializers.normal(0.02), (G, D))
        keys_src = jnp.concatenate(
            [jnp.broadcast_to(global_tokens[None, :, :], (B, G, D)), x], axis=1)  # (B, G+T, D)

        q = split_heads(nn.Dense(D, use_bias=False)(x))
        k = split_heads(nn.Dense(D, use_bias=False)(keys_src))
        v = split_heads(nn.Dense(D, use_bias=False)(keys_src))

        seq_bias = RelPosBias(num_buckets=self.num_buckets, num_heads=H,
                              max_distance=self.max_distance)(T, T)
        bias = jnp.concatenate([jnp.zeros((1, H, T, G), dtype=q.dtype), seq_bias], axis=-1)

        offset = jnp.arange(T)[None, :] - jnp.arange(T)[:, None]   # key - query
        in_window = (offset >= -self.window) & (offset <= 0)
        allowed = jnp.concatenate(
            [jnp.ones((1, 1, T, G), dtype=bool), in_window[None, None]], axis=-1)

        scores = jnp.einsum('bhtd,bhkd->bhtk', q, k) / jnp.sqrt(Dh) + bias
        scores = jnp.where(allowed, scores, jnp.finfo(scores.dtype).min)
        probs = nn.Dropout(self.attn_dropout)(nn.softmax(scores, axis=-1),
                                              deterministic=deterministic)
        merged = jnp.einsum('bhtk,bhkd->bhtd', probs, v).transpose(0, 2, 1, 3).reshape(B, T, D)
        return nn.Dense(D, use_bias=False)(merged)

class DecoderBlockLocalGlobal(nn.Module):
    """Pre-RMSNorm block: local+global attention, then SwiGLU, both residual.

    One learned scalar ``res_scale`` (init ``resid_scale_init``) scales both
    residual branches. The attention sublayer builds its own causal
    local/global mask, so no mask argument is taken.
    """
    d_model: int
    n_heads: int
    window: int = 512
    n_global: int = 8
    attn_dropout: float = 0.1
    mlp_dropout: float = 0.1
    resid_scale_init: float = 1e-2

    @nn.compact
    def __call__(self, x, *, deterministic: bool = True):
        alpha = self.param('res_scale', nn.initializers.constant(self.resid_scale_init), ())

        attn_in = RMSNorm(self.d_model)(x)
        attention = LocalGlobalSelfAttention(self.d_model, self.n_heads, self.window, self.n_global,
                                             attn_dropout=self.attn_dropout)
        x = x + alpha * attention(attn_in, deterministic=deterministic)

        mlp_in = RMSNorm(self.d_model)(x)
        mlp = SwiGLU(self.d_model, mult=2.667, dropout=self.mlp_dropout)
        return x + alpha * mlp(mlp_in, deterministic=deterministic)

class RelPosLocalGlobalTransformer(nn.Module):
    """Decoder LM built from local-window + global-token attention blocks.

    Causality is enforced inside each block's attention mask; there is no
    padding mask. Position information comes from the learned relative-position
    bias in the attention layers. Logits reuse the token-embedding matrix.
    """
    vocab_size: int
    d_model: int
    n_layers: int
    n_heads: int
    window: int = 512
    n_global: int = 8
    dropout: float = 0.1

    def setup(self):
        self.tok_embed = nn.Embed(self.vocab_size, self.d_model)
        # Must be created here: constructing nn.Dropout inline in a plain
        # (non-@compact) __call__ raises flax.errors.AssignSubModuleError.
        self.embed_dropout = nn.Dropout(self.dropout)
        self.blocks = [DecoderBlockLocalGlobal(self.d_model, self.n_heads, self.window, self.n_global,
                                               attn_dropout=self.dropout, mlp_dropout=self.dropout)
                       for _ in range(self.n_layers)]
        self.final_norm = RMSNorm(self.d_model)

    def __call__(self, idx, *, deterministic: bool = True):
        """Return (B, T, V) logits for (B, T) token ids ``idx``."""
        x = self.tok_embed(idx)
        x = self.embed_dropout(x, deterministic=deterministic)
        for blk in self.blocks:
            x = blk(x, deterministic=deterministic)
        x = self.final_norm(x)
        return jnp.einsum('btd,vd->btv', x, self.tok_embed.embedding)

### Standard causal decoder with n-gram auxiliary outputs (logits2/logits3 currently just repeat logits1)

In [None]:
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import attention as attn

class RMSNorm(nn.Module):
    """RMSNorm: divide by the last-axis root-mean-square and multiply by ``scale``.

    Attributes:
        dim: Feature count of the last axis (length of ``scale``).
        eps: Small constant added under the square root.
    """
    dim: int
    eps: float = 1e-8

    @nn.compact
    def __call__(self, x):
        learned_gain = self.param('scale', nn.initializers.ones, (self.dim,))
        mean_square = (x * x).mean(axis=-1, keepdims=True)
        return x * (learned_gain / jnp.sqrt(mean_square + self.eps))

class SwiGLU(nn.Module):
    """Two bias-free up-projections combined by a SiLU gate, then a down-projection.

    Attributes:
        d_model: Output width.
        mult: Hidden width is int(d_model * mult).
        dropout: Dropout rate on the gated hidden activations.
    """
    d_model: int
    mult: float = 2.667
    dropout: float = 0.1

    @nn.compact
    def __call__(self, x, *, deterministic: bool):
        n_hidden = int(self.d_model * self.mult)
        gate_pre = nn.Dense(n_hidden, use_bias=False)(x)
        linear = nn.Dense(n_hidden, use_bias=False)(x)
        gated = nn.Dropout(self.dropout)(nn.silu(gate_pre) * linear, deterministic=deterministic)
        return nn.Dense(self.d_model, use_bias=False)(gated)

class DecoderBlock(nn.Module):
    """Pre-RMSNorm decoder block: Flax multi-head self-attention, then SwiGLU.

    Both residual branches are scaled by one learned scalar ``res_scale``
    (init ``resid_scale_init``). Causality comes from the ``mask`` argument.
    """
    d_model: int
    n_heads: int
    attn_dropout: float = 0.1
    mlp_dropout: float = 0.1
    resid_scale_init: float = 1e-2

    @nn.compact
    def __call__(self, x, *, mask=None, deterministic: bool = True):
        gain = self.param('res_scale', nn.initializers.constant(self.resid_scale_init), ())

        normed = RMSNorm(self.d_model)(x)
        self_attn = nn.SelfAttention(num_heads=self.n_heads, use_bias=False,
                                     dropout_rate=self.attn_dropout)
        x = x + gain * self_attn(normed, mask=mask, deterministic=deterministic)

        normed = RMSNorm(self.d_model)(x)
        ffn = SwiGLU(self.d_model, dropout=self.mlp_dropout)
        return x + gain * ffn(normed, deterministic=deterministic)

class NgramAuxTransformer(nn.Module):
    """Causal decoder LM that returns a dict of logits for n-gram auxiliary heads.

    NOTE(review): ``logits2`` and ``logits3`` are currently the very same array
    as ``logits1``; no separate heads exist. Any auxiliary loss computed from
    them just repeats the main next-token loss. Presumably these are
    placeholders for multi-token (t+2, t+3) prediction heads; confirm before
    relying on them.
    """
    vocab_size: int
    d_model: int
    n_layers: int
    n_heads: int
    dropout: float = 0.1

    def setup(self):
        self.tok_embed = nn.Embed(self.vocab_size, self.d_model)
        # Must be created here: constructing nn.Dropout inline in a plain
        # (non-@compact) __call__ raises flax.errors.AssignSubModuleError.
        self.embed_dropout = nn.Dropout(self.dropout)
        # Keyword arguments so the call does not depend on DecoderBlock's field order.
        self.blocks = [DecoderBlock(d_model=self.d_model, n_heads=self.n_heads,
                                    attn_dropout=self.dropout, mlp_dropout=self.dropout)
                       for _ in range(self.n_layers)]
        self.final_norm = RMSNorm(self.d_model)

    def __call__(self, idx, *, deterministic: bool = True, pad_mask: jnp.ndarray | None = None):
        """Return {'logits1', 'logits2', 'logits3'}, each of shape (B, T, V).

        Args:
          idx: (B, T) token ids.
          deterministic: Disable dropout when True.
          pad_mask: Optional (B, T) mask combined with the causal mask;
            False/0 positions are excluded as keys and as queries.
        """
        B, T = idx.shape
        x = self.tok_embed(idx)
        x = self.embed_dropout(x, deterministic=deterministic)

        causal = attn.make_causal_mask(jnp.ones((B, T), dtype=bool))
        mask = causal if pad_mask is None else attn.combine_masks(
            causal, attn.make_attention_mask(pad_mask, pad_mask, dtype=bool))

        for blk in self.blocks:
            x = blk(x, mask=mask, deterministic=deterministic)

        x = self.final_norm(x)
        E = self.tok_embed.embedding
        logits1 = jnp.einsum('btd,vd->btv', x, E)
        logits2 = logits1
        logits3 = logits1
        return {'logits1': logits1, 'logits2': logits2, 'logits3': logits3}