In [None]:
from unsloth import FastLanguageModel
import torch
import math
from datasets import load_dataset
from tqdm import tqdm
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import gc
from huggingface_hub import hf_hub_download
from msclap import CLAP
import re
import tiktoken
from torch.utils.data import Dataset, DataLoader
from IPython.display import display, HTML
from IPython.display import Audio, display, HTML, clear_output
import time
import pandas as pd
import os

In [None]:
class NormLayer(nn.Module):
    """Layer normalization over the last dimension with a learnable affine map.

    Every feature vector is shifted to zero mean and scaled by its biased
    standard deviation, then multiplied by ``gamma`` and offset by ``bias``.

    Args:
        emb_dim: Size of the last (feature) dimension.
        eps: Constant added to the variance before the square root so the
            division stays finite for constant inputs.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Learnable per-feature scale and shift: they let the model move away
        # from the zero-mean / unit-variance distribution when that helps.
        self.gamma = nn.Parameter(torch.ones(emb_dim))
        self.bias = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """Normalize ``x`` along its last dimension and apply the affine map."""
        centered = x - x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, as in the standard LayerNorm formula.
        variance = x.var(dim=-1, unbiased=False, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.gamma * normalized + self.bias


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Unlike ``NormLayer`` there is no mean subtraction and no bias: the input
    is only divided by its RMS and scaled by a learnable ``weight``.

    Args:
        emb_dim: Size of the last (feature) dimension.
        eps: Constant added to the RMS (outside the square root) before the
            division.
    """

    def __init__(self, emb_dim, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(emb_dim))

    def forward(self, x):
        """Return ``x / (rms(x) + eps) * weight`` along the last dimension."""
        feature_count = x.size(-1)
        # RMS = ||x||_2 / sqrt(d), written as a multiplication by d^(-1/2).
        l2_norm = x.norm(2, dim=-1, keepdim=True)
        rms = l2_norm * (1.0 / feature_count) ** 0.5
        return x / (rms + self.eps) * self.weight


class SwiGLU(nn.Module):
    """Gated linear unit with a SiLU (Swish) gate.

    Two independent linear projections of the input are computed; the ``W2``
    projection is passed through SiLU and multiplies the ``W1`` projection
    element-wise.
    """

    def __init__(self, emb_dim: int, hidden_dim: int):
        """
        Args:
            emb_dim: Size of the input features.
            hidden_dim: Size of each projection (and of the output).
        """
        super().__init__()
        self.W1 = nn.Linear(emb_dim, hidden_dim)  # signal path
        self.W2 = nn.Linear(emb_dim, hidden_dim)  # gate path

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``W1(x) * silu(W2(x))``."""
        signal = self.W1(x)
        gate = F.silu(self.W2(x))
        return signal * gate


class FFN(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear.

    The hidden width is four times ``cfg["emb_dim"]``; the output has the
    same width as the input.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: Mapping that provides the key ``"emb_dim"``.
        """
        super().__init__()
        width = cfg["emb_dim"]
        expanded = 4 * width
        self.layers = nn.Sequential(
            nn.Linear(width, expanded),
            nn.GELU(),
            nn.Linear(expanded, width),
        )

    def forward(self, x):
        """Apply the expand / activate / project stack to ``x``."""
        return self.layers(x)


class SwiGLU_FFN(nn.Module):
    """Feed-forward network built from a SwiGLU expansion and a linear projection."""

    def __init__(self, emb_dim: int, hidden_dim: int | None = None):
        """
        Args:
            emb_dim: Input and output feature size.
            hidden_dim: Width of the gated hidden layer. When omitted it is
                set to ``int(8 * emb_dim / 3)``.
        """
        super().__init__()

        # Fall back to the 8/3 expansion ratio when no width is given.
        width = int(8 * emb_dim / 3) if hidden_dim is None else hidden_dim

        # Gated expansion: emb_dim -> width.
        self.swiGLU = SwiGLU(emb_dim, width)
        # Projection back to the model dimension: width -> emb_dim.
        self.W_out = nn.Linear(width, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand ``x`` through SwiGLU, then project back to ``emb_dim``."""
        return self.W_out(self.swiGLU(x))

In [None]:
class GQA_SWA_Flash(nn.Module):
    """
    Grouped Query Attention (GQA) with optional Sliding Window Attention (SWA),
    computed through ``F.scaled_dot_product_attention``.

    GQA: ``heads_num`` query heads share ``kv_groups`` key/value heads, so the
    K/V projections (and the KV cache) are ``heads_num / kv_groups`` times
    smaller than in standard multi-head attention.

    SWA: when ``use_swa=True`` in non-cache mode, each token attends only to
    itself and the ``swa_size - 1`` tokens before it.

    KV cache: when ``use_cache=True`` the module keeps a ring buffer holding the
    most recent ``window_size`` keys/values. The first call after
    ``reset_cache()`` is a prefill over the whole prompt; later calls must feed
    exactly one token.

    Args:
        emb_dim: Input feature size.
        model_dim: Total attention width (``heads_num * head_dim``); also the
            output size.
        max_context_len: Fallback for ``window_size`` and ``swa_size``.
        drop_rate: Attention dropout probability, applied only in non-cache
            mode while the module is in training mode.
        heads_num: Number of query heads.
        kv_groups: Number of key/value heads (defaults to 1).
        window_size: KV-cache capacity in tokens (defaults to ``max_context_len``).
        swa_size: Sliding-window width (defaults to ``max_context_len``).
        qkv_bias: Whether the Q/K/V projections have a bias term.
    """
    def __init__(self,
                 emb_dim,
                 model_dim,
                 max_context_len,
                 drop_rate,
                 heads_num,
                 kv_groups=None,
                 window_size=None,
                 swa_size=None,
                 qkv_bias=False):
        super().__init__()
        self.model_dim = model_dim
        self.heads_num = heads_num                     # total query heads H

        # Safety: ensure model_dim is divisible by number of heads
        assert model_dim % heads_num == 0, "model_dim must be divisible by heads_num"

        self.head_dim = model_dim // heads_num           # dim per head
        self.kv_groups = kv_groups or 1                # number of KV groups G
        assert self.heads_num % self.kv_groups == 0, \
            "heads_num must be divisible by kv_groups"
        self.q_in_group = self.heads_num // self.kv_groups  # H/G heads per KV group

        # Linear projections for Query, Key, Value
        self.W_q = nn.Linear(emb_dim, model_dim, bias=qkv_bias)
        # In GQA, K/V have kv_groups groups of head_dim each: G * d_head
        self.W_k = nn.Linear(emb_dim, self.head_dim * self.kv_groups, bias=qkv_bias)
        self.W_v = nn.Linear(emb_dim, self.head_dim * self.kv_groups, bias=qkv_bias)

        self.window_size = window_size or max_context_len   # KV-cache capacity
        self.swa_size = swa_size or max_context_len         # SWA width (non-cache mode)
        self.drop_rate = drop_rate

        # --- KV cache (inference only) ---
        # Allocated lazily as (b, kv_groups, window_size, head_dim); non-persistent
        # so it never ends up in the state dict.
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)

        # Next write position in the ring buffer, and number of valid entries.
        self.cache_ptr = 0
        self.cache_len = 0

        self.final_proj = nn.Linear(model_dim, model_dim)

    def reset_cache(self):
        """Drop the KV cache. Call this before starting a new sequence."""
        self.cache_k = None
        self.cache_v = None
        self.cache_ptr = 0
        self.cache_len = 0

    def sliding_window_mask(self, seq_len, device):
        """
        Build the SWA mask of shape (T, T):
        True  = masked
        False = allowed
        Token i can attend j if:
          j <= i and (i - j) < swa_size

        Note that this is the opposite polarity of the boolean ``attn_mask``
        expected by ``F.scaled_dot_product_attention`` (True = may attend);
        ``forward`` inverts it before use.
        """
        i = torch.arange(seq_len, device=device).unsqueeze(1)  # (T, 1)
        j = torch.arange(seq_len, device=device).unsqueeze(0)  # (1, T)
        # Mask positions where j > i (future) or too far in the past (i - j >= swa_size)
        return (j > i) | (i - j >= self.swa_size)

    def _expand_kv(self, kv):
        """Repeat each KV group for its query heads: (b, G, T, d) -> (b, H, T, d).

        Query head ``h`` uses KV group ``h // q_in_group``.
        """
        b_size, _, t_k, _ = kv.shape
        kv = kv.unsqueeze(2).expand(
            b_size, self.kv_groups, self.q_in_group, t_k, self.head_dim
        )
        # reshape materializes the broadcast copy when q_in_group > 1.
        return kv.reshape(b_size, self.heads_num, t_k, self.head_dim)

    def _allocate_cache(self, b_size, device, dtype):
        """Create zero-filled ring buffers for a batch of ``b_size`` sequences."""
        self.cache_k = torch.zeros(
            b_size, self.kv_groups, self.window_size, self.head_dim,
            device=device, dtype=dtype
        )
        self.cache_v = torch.zeros_like(self.cache_k)
        self.cache_ptr = 0
        self.cache_len = 0

    def _write_to_cache(self, k_new, v_new):
        """Append (b, G, T, d) keys/values at ``cache_ptr``, wrapping if needed."""
        insert_len = min(k_new.size(2), self.window_size)
        # Only the most recent insert_len timesteps can fit.
        k_new = k_new[:, :, -insert_len:, :]
        v_new = v_new[:, :, -insert_len:, :]

        start = self.cache_ptr
        end = start + insert_len
        if end <= self.window_size:
            # Straight write
            self.cache_k[:, :, start:end, :] = k_new
            self.cache_v[:, :, start:end, :] = v_new
        else:
            # Wrap-around: fill the tail of the buffer, then continue at the front.
            first = self.window_size - start
            overflow = end - self.window_size
            self.cache_k[:, :, start:, :] = k_new[:, :, :first, :]
            self.cache_k[:, :, :overflow, :] = k_new[:, :, first:, :]
            self.cache_v[:, :, start:, :] = v_new[:, :, :first, :]
            self.cache_v[:, :, :overflow, :] = v_new[:, :, first:, :]

        self.cache_ptr = end % self.window_size
        self.cache_len = min(self.cache_len + insert_len, self.window_size)

    def _ordered_cache(self):
        """Return cached (K, V) in chronological order, each (b, G, cache_len, d)."""
        if self.cache_len < self.window_size:
            # Buffer has not wrapped yet: valid entries are the first cache_len slots.
            return (self.cache_k[:, :, :self.cache_len, :],
                    self.cache_v[:, :, :self.cache_len, :])
        # Buffer is full: the oldest entry sits at cache_ptr.
        ptr = self.cache_ptr
        K = torch.cat((self.cache_k[:, :, ptr:, :], self.cache_k[:, :, :ptr, :]), dim=2)
        V = torch.cat((self.cache_v[:, :, ptr:, :], self.cache_v[:, :, :ptr, :]), dim=2)
        return K, V

    def forward(self, x, use_cache: bool = False, use_swa: bool = False):
        """
        x:
          - training / no-cache: (b, T, emb_d)
          - inference / cache:
                * first call after reset_cache (prefill): (b, T_prompt, emb_d)
                * subsequent incremental calls:        (b, 1, emb_d)

        use_cache: read/write the ring-buffer KV cache (inference).
        use_swa:   apply the sliding-window mask (non-cache mode only).

        Returns:
            (b, T, model_d)

        Raises:
            AssertionError: if ``use_swa`` is combined with ``use_cache``, if the
                prefill is longer than ``window_size``, or if a post-prefill
                cached call has ``T != 1``.
        """
        b_size, seq_len, _ = x.shape

        # ---- Q, K, V projections, split into heads ----
        # Q: (b, H, T, d_head); K_new, V_new: (b, G, T, d_head)
        Q = self.W_q(x).view(b_size, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
        K_new = self.W_k(x).view(b_size, seq_len, self.kv_groups, self.head_dim).transpose(1, 2)
        V_new = self.W_v(x).view(b_size, seq_len, self.kv_groups, self.head_dim).transpose(1, 2)

        attn_mask = None
        dropout_p = 0.0

        if use_cache:
            # ================= RING BUFFER KV-CACHE MODE =================
            # The cache itself bounds the attention span, so SWA is not combined with it.
            assert not use_swa, "use_swa is not supported when use_cache=True"

            # (Re)allocate the buffers when missing or when the batch size changed.
            if self.cache_k is None or self.cache_k.size(0) != b_size:
                self._allocate_cache(b_size, x.device, K_new.dtype)

            if self.cache_len == 0:
                # ---------- PREFILL: first cached call, whole prompt ----------
                assert seq_len <= self.window_size, \
                    "In prefill (first cache) call, seq_len must be <= window_size"
                self._write_to_cache(K_new, V_new)
                # Attend over the prompt itself with a standard causal mask.
                K, V = K_new, V_new
                is_causal = True
            else:
                # ---------- INCREMENTAL DECODE: one token per call ----------
                assert seq_len == 1, \
                    "After prefill, cache mode expects seq_len == 1 for incremental decoding"
                self._write_to_cache(K_new, V_new)
                K, V = self._ordered_cache()
                # The cache only holds past + current token, so no mask is needed.
                is_causal = False
        else:
            # ================= NORMAL (NO-CACHE) MODE =================
            K, V = K_new, V_new
            dropout_p = self.drop_rate if self.training else 0.0
            if use_swa:
                # sliding_window_mask marks *blocked* positions with True, whereas a
                # boolean attn_mask marks positions that *may* be attended with True.
                attn_mask = ~self.sliding_window_mask(seq_len, device=Q.device)
                is_causal = False  # causality is already encoded in attn_mask
            else:
                is_causal = True

        att = F.scaled_dot_product_attention(
            Q,                    # (b, H, T_q, d)
            self._expand_kv(K),   # (b, H, T_k, d)
            self._expand_kv(V),   # (b, H, T_k, d)
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )

        # Merge heads: (b, H, T, d) -> (b, T, H*d) = (b, T, model_d)
        out = att.transpose(1, 2).reshape(b_size, seq_len, self.model_dim)
        return self.final_proj(out)


class HDynMoF(nn.Module):
    """
    Adaptive hierarchical sparse MoE with:
      - Group-level top-p routing (adaptive #groups per token)
      - Expert-level top-p routing within each active group (adaptive #experts)
      - Sparse execution: each expert is called only on the tokens routed to it.

    cfg keys (defaults in parentheses):
      emb_dim              : int, embedding dim (required)
      e_num                : int, total experts across all groups (16)
      moe_groups           : int, number of groups (4)
      moe_group_top_p      : float, group-level top-p threshold per token (0.9)
      moe_max_groups       : int, max groups per token (moe_groups)
      moe_top_p            : float, expert-level top-p threshold per token (0.9)
      moe_max_k            : int, max experts per group per token (2)

    Routing rule used at both levels: candidates are sorted by probability,
    truncated to the configured maximum count, and a candidate stays active
    while the cumulative probability up to and including it is <= the top-p
    threshold. The best candidate is always kept. Active probabilities are
    renormalized to sum to 1.
    """

    def __init__(self, cfg: dict):
        super().__init__()

        emb_dim = int(cfg["emb_dim"])

        # ---- Core MoE config ----
        self.e_num = int(cfg.get("e_num", 16))
        self.num_groups = int(cfg.get("moe_groups", 4))

        # Expert routing config
        self.top_p = float(cfg.get("moe_top_p", 0.9))
        self.max_k_per_grp = int(cfg.get("moe_max_k", 2))

        # Group routing config
        self.group_top_p = float(cfg.get("moe_group_top_p", 0.9))
        self.max_groups_per_token = int(cfg.get("moe_max_groups", self.num_groups))

        # ---- Sanity checks ----
        assert self.e_num >= 1, "e_num must be >= 1"
        assert self.num_groups >= 1, "moe_groups must be >= 1"
        assert 0.0 < self.top_p <= 1.0, "moe_top_p must be in (0, 1]"
        assert 0.0 < self.group_top_p <= 1.0, "moe_group_top_p must be in (0, 1]"
        assert self.e_num % self.num_groups == 0, "e_num must be divisible by moe_groups"

        self.exp_per_group = self.e_num // self.num_groups

        assert 1 <= self.max_k_per_grp <= self.exp_per_group, \
            f"moe_max_k must be in [1, {self.exp_per_group}], got {self.max_k_per_grp}"

        assert 1 <= self.max_groups_per_token <= self.num_groups, \
            f"moe_max_groups must be in [1, {self.num_groups}], got {self.max_groups_per_token}"

        # ---- Experts ----
        # experts[g][e] = expert e in group g
        self.experts = nn.ModuleList([
            nn.ModuleList([
                SwiGLU_FFN(emb_dim, 4 * emb_dim)
                for _ in range(self.exp_per_group)
            ])
            for _ in range(self.num_groups)
        ])

        # ---- Gating over experts (one gate per group) ----
        self.gates = nn.ModuleList([
            nn.Linear(emb_dim, self.exp_per_group)
            for _ in range(self.num_groups)
        ])

        # ---- Group router ----
        self.group_router = nn.Linear(emb_dim, self.num_groups)

        # Constant factor applied to every expert contribution.
        self.group_scale = 1.0 / math.sqrt(self.num_groups)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Route every token through its active groups/experts and return the
        weighted sum of the expert outputs.

        x: (B, T, D)
        returns: (B, T, D)
        """
        B, T, D = x.shape
        dtype = x.dtype

        # Flatten tokens: N = B * T
        x_flat = x.reshape(-1, D)              # (N, D)
        N = x_flat.size(0)

        # Output accumulator
        out_flat = x_flat.new_zeros((N, D))    # (N, D)

        # ============================================================
        # 1) Group-level routing (vectorized)
        # ============================================================
        group_probs = F.softmax(self.group_router(x), dim=-1)   # (B, T, G)
        group_sorted_probs, group_sorted_idx = group_probs.sort(
            dim=-1, descending=True
        )  # (B, T, G)

        Kg = min(self.max_groups_per_token, self.num_groups)

        # top-Kg candidates per token
        group_probs_k = group_sorted_probs[..., :Kg]        # (B, T, Kg)
        group_idx_k = group_sorted_idx[..., :Kg]            # (B, T, Kg)

        # top-p on the cumulative probability; the best group is always kept
        group_active_mask = group_probs_k.cumsum(dim=-1) <= self.group_top_p
        group_active_mask[..., 0] = True

        # flatten to (N, Kg)
        group_probs_k_flat = group_probs_k.reshape(N, Kg)
        group_idx_k_flat = group_idx_k.reshape(N, Kg)
        group_active_flat = group_active_mask.reshape(N, Kg)

        # zero out inactive groups and renormalize the rest
        group_masked_probs = group_probs_k_flat * group_active_flat.to(dtype)
        group_sum_active = group_masked_probs.sum(dim=-1, keepdim=True) + 1e-9
        group_renorm_probs = group_masked_probs / group_sum_active   # (N, Kg)

        # ============================================================
        # 2) Group-wise expert routing
        # ============================================================
        # Max experts considered per token inside a group (loop-invariant).
        K = min(self.max_k_per_grp, self.exp_per_group)

        # The outer loop over groups is kept (usually small: 4-8).
        for g in range(self.num_groups):
            experts_g = self.experts[g]

            # (token, slot) pairs where group g was selected and is active.
            # Sort indices are a permutation of the groups, so a token shows up
            # at most once here.
            active_pos = (
                (group_idx_k_flat == g) & group_active_flat
            ).nonzero(as_tuple=False)                       # (M_g, 2)
            if active_pos.size(0) == 0:
                continue

            token_idx_g = active_pos[:, 0]          # (M_g,)
            group_slot_idx_g = active_pos[:, 1]     # (M_g,)

            x_g_flat = x_flat[token_idx_g]          # (M_g, D)
            group_weight_g = group_renorm_probs[token_idx_g, group_slot_idx_g]  # (M_g,)

            # --------------------------------------------------------
            # Expert-level routing for group g (vectorized over tokens)
            # --------------------------------------------------------
            probs_g = F.softmax(self.gates[g](x_g_flat), dim=-1)              # (M_g, E_g)
            sorted_probs, sorted_idx = probs_g.sort(dim=-1, descending=True)  # (M_g, E_g)

            sorted_probs_k = sorted_probs[:, :K]      # (M_g, K)
            sorted_idx_k = sorted_idx[:, :K]          # (M_g, K)

            # top-p within the group; the best expert is always kept
            active_mask = sorted_probs_k.cumsum(dim=-1) <= self.top_p         # (M_g, K)
            active_mask[:, 0] = True

            # renormalize only over active experts
            masked_probs = sorted_probs_k * active_mask.to(dtype)     # (M_g, K)
            sum_active = masked_probs.sum(dim=-1, keepdim=True) + 1e-9
            renorm_probs = masked_probs / sum_active                  # (M_g, K)

            # --------------------------------------------------------
            # Flatten to one record per active (token, expert) pair:
            #   global token index, local expert index, mixture weight.
            # Column 0 of active_mask is always True, so this is non-empty.
            # --------------------------------------------------------
            active_exp_pos = active_mask.nonzero(as_tuple=False)          # (M_ge, 2)
            token_local_all = active_exp_pos[:, 0]    # (M_ge,)
            slot_all = active_exp_pos[:, 1]           # (M_ge,)

            # Map local token indices back to global [0, N)
            global_token_all = token_idx_g[token_local_all]               # (M_ge,)
            # Which local expert each (token, slot) chose
            expert_local_all = sorted_idx_k[token_local_all, slot_all]    # (M_ge,)

            # expert-level probs p_{g,e}(x) times group-level probs q_g(x)
            p_all = renorm_probs[token_local_all, slot_all]               # (M_ge,)
            q_all = group_weight_g[token_local_all]                       # (M_ge,)
            total_weight_all = (q_all * p_all).to(dtype)                  # (M_ge,)

            # --------------------------------------------------------
            # Sparse expert execution: sort the records by expert id so each
            # expert's tokens form one contiguous slice.
            # --------------------------------------------------------
            sort_by_expert = torch.argsort(expert_local_all)
            expert_local_sorted = expert_local_all[sort_by_expert]
            global_token_sorted = global_token_all[sort_by_expert]
            weight_sorted = total_weight_all[sort_by_expert]

            unique_experts, counts = torch.unique_consecutive(
                expert_local_sorted, return_counts=True
            )

            # Pull the slice boundaries to the host once per group instead of
            # calling .item() twice per expert inside the loop.
            expert_ids = unique_experts.tolist()
            slice_ends = counts.cumsum(dim=0).tolist()

            start = 0
            # iterate only over experts that actually received tokens
            for e_local, end in zip(expert_ids, slice_ends):
                token_slice = global_token_sorted[start:end]   # (M_e,)
                w_slice = weight_sorted[start:end]             # (M_e,)

                y_e = experts_g[e_local](x_flat[token_slice])  # (M_e, D)

                w_e = (w_slice * self.group_scale).unsqueeze(-1)  # (M_e, 1)
                # Each token occurs at most once per expert, so the indexed
                # in-place add never hits the same row twice.
                out_flat[token_slice] += w_e * y_e
                start = end

        # back to (B, T, D)
        return out_flat.reshape(B, T, D)

In [None]:
class TransformerBlock_GQA_SWA(nn.Module):
    """
    One transformer layer: a GQA/SWA attention module plus a feed-forward
    module (``HDynMoF`` when ``use_adaptive_moe`` is set, otherwise ``FFN``).

    The wiring is selected from ``cfg`` (first match wins):

    1. Dual-stream (``use_dual_stream``): the input is a pair
       ``(x_global, x_local)``. Attention updates only the global stream and
       the feed-forward module updates only the local stream.
    2. Parallel (``use_parallel_att``): attention and feed-forward read the
       same normalized input; their sum is scaled by 1/sqrt(2) and added to
       the residual.
    3. Classic pre-norm: attention sub-layer followed by feed-forward
       sub-layer, each with its own norm and residual connection.
    """
    def __init__(self, cfg):
        super().__init__()

        # Wiring / component switches (all default to False).
        self.use_parallel_att = cfg.get("use_parallel_att", False)
        self.use_RMSNorm      = cfg.get("use_RMSNorm", False)
        self.use_adaptive_moe = cfg.get("use_adaptive_moe", False)
        self.use_dual_stream  = cfg.get("use_dual_stream", False)

        emb_dim = cfg["emb_dim"]

        self.att = GQA_SWA_Flash(
            emb_dim=emb_dim,
            model_dim=emb_dim,
            max_context_len=cfg["context_length"],
            drop_rate=cfg["drop_rate"],
            heads_num=cfg["n_heads"],
            kv_groups=cfg["kv_groups"],
            swa_size=cfg["swa_size"],
            qkv_bias=cfg["qkv_bias"],
        )

        # Feed-forward path: adaptive MoE or plain GELU FFN.
        self.ff = HDynMoF(cfg) if self.use_adaptive_moe else FFN(cfg)
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

        # Norm flavour shared by every norm in this block.
        norm_cls = RMSNorm if self.use_RMSNorm else NormLayer

        if self.use_dual_stream:
            # One norm per stream: global feeds attention, local feeds the FFN/MoE.
            self.ln_global = norm_cls(emb_dim)
            self.ln_local  = norm_cls(emb_dim)
            # 1/sqrt(2) scale kept available; the dual-stream forward path does not use it.
            self.res_scale = 1.0 / math.sqrt(2.0)
            self.enhance_global_ctx = nn.GELU()
        elif self.use_parallel_att:
            # A single norm feeds both branches of the parallel residual.
            self.ln = norm_cls(emb_dim)
            self.res_scale = 1.0 / math.sqrt(2.0)
        else:
            # Classic pre-norm: one norm in front of each sub-layer.
            self.ln1 = norm_cls(emb_dim)
            self.ln2 = norm_cls(emb_dim)

    def _dual_stream_step(self, streams, use_cache, use_swa):
        """Update the (global, local) pair; the two streams never mix here."""
        x_global, x_local = streams

        # Global stream: norm -> attention -> GELU.
        att_out = self.att(self.ln_global(x_global), use_cache=use_cache, use_swa=use_swa)
        # Local stream: norm -> FFN/MoE.
        ff_out = self.ff(self.ln_local(x_local))

        enhanced_att = self.enhance_global_ctx(att_out)
        new_global = x_global + self.drop_shortcut(enhanced_att)
        new_local = x_local + self.drop_shortcut(ff_out)
        return (new_global, new_local)

    def _parallel_step(self, x, use_cache, use_swa):
        """Parallel residual: x + dropout((att(ln(x)) + ff(ln(x))) / sqrt(2))."""
        normed = self.ln(x)
        att_out = self.att(normed, use_cache=use_cache, use_swa=use_swa)
        ff_out = self.ff(normed)
        combined = self.res_scale * (att_out + ff_out)
        return x + self.drop_shortcut(combined)

    def _classic_step(self, x, use_cache, use_swa):
        """Sequential pre-norm: x + att(ln1(x)), then x + ff(ln2(x))."""
        att_out = self.att(self.ln1(x), use_cache=use_cache, use_swa=use_swa)
        x = x + self.drop_shortcut(att_out)
        ff_out = self.ff(self.ln2(x))
        return x + self.drop_shortcut(ff_out)

    def forward(self, x, use_cache: bool = False, use_swa: bool = False):
        """
        Run the block in the wiring mode chosen at construction time.

        x: a tensor, or a ``(x_global, x_local)`` tuple in dual-stream mode.
        use_cache / use_swa: forwarded unchanged to the attention module.

        Returns the same kind of object as ``x`` (tensor or tuple).
        """
        if self.use_dual_stream:
            return self._dual_stream_step(x, use_cache, use_swa)
        if self.use_parallel_att:
            return self._parallel_step(x, use_cache, use_swa)
        return self._classic_step(x, use_cache, use_swa)


class MyGPT_GQA_SWA(nn.Module):
    """
    Decoder-style language model built from a stack of ``TransformerBlock_GQA_SWA``.

    What this class itself does:
    - Embeds token ids and adds learned absolute position embeddings.
    - Optionally prepends one projected audio vector as a prefix token.
    - Runs the block stack either on a single hidden state or, when
      ``use_dual_stream`` is set, on a ``(global, local)`` pair that is merged
      once after the last block.
    - Tracks ``cache_pos`` so that cached (incremental) calls continue the
      position ids where the previous call stopped.

    Attention/FFN internals (GQA, SWA, KV cache) live in the block class, which
    is defined elsewhere in this file.
    """
    def __init__(self, cfg):
        super().__init__()

        # Mode switches; both default to False when missing from cfg.
        self.use_RMSNorm = cfg.get("use_RMSNorm", False)
        self.use_dual_stream = cfg.get("use_dual_stream", False)

        # Token table, learned position table, and dropout applied to their sum.
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        # Audio bridge applied to `audio_features` in forward().
        # NOTE(review): the width is hard-coded to 1024 instead of cfg["emb_dim"];
        # presumably this matches both the CLAP feature size and gpt2-medium's
        # embedding width -- confirm against SwiGLU_FFN before using another size.
        self.audio_proj = SwiGLU_FFN(1024)

        # A ModuleList (not Sequential) so forward() can hand flags to every block.
        layers = [TransformerBlock_GQA_SWA(cfg) for _ in range(cfg["n_layers"])]
        self.trm_blocks = nn.ModuleList(layers)

        # Final normalization follows the same RMSNorm / NormLayer switch.
        norm_cls = RMSNorm if self.use_RMSNorm else NormLayer
        self.final_norm = norm_cls(cfg["emb_dim"])

        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

        # Number of token positions already consumed while running with use_cache=True.
        self.cache_pos = 0

        # Scale applied when the global and local streams are summed.
        self._dual_res_scale = 1.0 / math.sqrt(2.0)

    def reset_cache(self):
        """Clear every block's attention cache and rewind the position counter."""
        for layer in self.trm_blocks:
            layer.att.reset_cache()
        self.cache_pos = 0

    def forward(self, x, audio_features=None, use_cache: bool = False, use_swa: bool = False):
        """
        Compute next-token logits.

        Args:
            x: Token ids of shape ``(batch, seq_len)``.
            audio_features: Optional tensor fed to ``audio_proj``; its output is
                inserted as one extra position in front of the text embeddings.
                # assumes shape (batch, feature_dim) -- TODO confirm with caller
            use_cache: When True, position ids continue from ``cache_pos`` and the
                counter is advanced by ``seq_len``; the flag is also forwarded
                to every block.
            use_swa: Forwarded to every block.

        Returns:
            Logits with one row per processed position (text positions plus the
            audio prefix position when ``audio_features`` is given).

        Raises:
            AssertionError: if the requested positions exceed the size of the
                position-embedding table.
        """
        _, n_tokens = x.shape
        max_positions = self.pos_emb.num_embeddings

        # 1. Position ids
        if use_cache:
            # Prefill or incremental decode: continue after what is already cached.
            first = self.cache_pos
            last = first + n_tokens
            assert last <= max_positions, (
                f"Total length {last} exceeds context length "
                f"{max_positions}"
            )
            pos_ids = torch.arange(first, last, device=x.device)
            self.cache_pos = last
        else:
            # Stateless call: positions always start from zero.
            assert n_tokens <= max_positions, \
                f"Sequence length {n_tokens} exceeds context length {max_positions}"
            pos_ids = torch.arange(n_tokens, device=x.device)

        # 2. Text embeddings (+ optional audio prefix)
        hidden = self.tok_emb(x) + self.pos_emb(pos_ids)

        if audio_features is not None:
            # The projected audio vector gets no position embedding and becomes
            # position 0 of the sequence, ahead of all text tokens.
            audio_prefix = self.audio_proj(audio_features).unsqueeze(1)
            hidden = torch.cat([audio_prefix, hidden], dim=1)

        hidden = self.drop_emb(hidden)

        # 3. Block stack
        if self.use_dual_stream:
            # Both streams start from the same embeddings; each block returns an
            # updated (global, local) pair.
            streams = (hidden, hidden)
            for layer in self.trm_blocks:
                streams = layer(streams, use_cache=use_cache, use_swa=use_swa)

            global_out, local_out = streams
            # The streams are merged exactly once, right before the final norm.
            hidden = (global_out + local_out) * self._dual_res_scale
        else:
            for layer in self.trm_blocks:
                hidden = layer(hidden, use_cache=use_cache, use_swa=use_swa)

        # 4. Final norm and vocabulary projection
        return self.out_head(self.final_norm(hidden))


def load_standard_baseline(model, sd_hf):
    """
    Copy GPT-2 style weights from a Hugging Face state dict into ``model`` in place.

    Expects the classic block layout: each block exposes ``ln1``/``ln2`` with
    ``gamma``/``bias``, ``att`` with ``W_q``/``W_k``/``W_v``/``final_proj``, and
    ``ff.layers[0]`` / ``ff.layers[2]`` as the two feed-forward linears.
    ``model.audio_proj`` is never touched and keeps its current weights.

    Args:
        model: Target model; mutated in place.
        sd_hf: Mapping with un-prefixed GPT-2 keys (``wte.weight``, ``h.0.ln_1.weight``, ...).

    Returns:
        None.
    """
    print("ü©π Initializing Multimodal-Aware Baseline weight mapping...")

    def copy_norm(norm, hf_key):
        # Target norm layers name their scale `gamma`; the source calls it `weight`.
        norm.gamma.copy_(sd_hf[f'{hf_key}.weight'])
        norm.bias.copy_(sd_hf[f'{hf_key}.bias'])

    def copy_linear(linear, hf_key):
        # Source weights are stored transposed relative to nn.Linear, hence .t().
        linear.weight.copy_(sd_hf[f'{hf_key}.weight'].t())
        linear.bias.copy_(sd_hf[f'{hf_key}.bias'])

    with torch.no_grad():
        # 1. Embedding tables. Only the first `n_pretrained` rows of the token
        #    table are overwritten, so a larger target vocabulary keeps its
        #    extra rows as they were.
        pretrained_wte = sd_hf['wte.weight']
        n_pretrained = pretrained_wte.size(0)
        model.tok_emb.weight[:n_pretrained].copy_(pretrained_wte)
        model.pos_emb.weight.copy_(sd_hf['wpe.weight'])

        # 2. Transformer blocks
        for layer_idx, block in enumerate(model.trm_blocks):
            base = f'h.{layer_idx}.'

            copy_norm(block.ln1, f'{base}ln_1')
            copy_norm(block.ln2, f'{base}ln_2')

            # The source fuses Q, K and V into one projection; after the
            # transpose the three equal slices are stacked along dim 0.
            fused_w = sd_hf[f'{base}attn.c_attn.weight'].t()
            fused_b = sd_hf[f'{base}attn.c_attn.bias']
            q_w, k_w, v_w = fused_w.chunk(3, dim=0)
            q_b, k_b, v_b = fused_b.chunk(3, dim=0)
            split_targets = (
                ('W_q', q_w, q_b),
                ('W_k', k_w, k_b),
                ('W_v', v_w, v_b),
            )
            for attr, weight_part, bias_part in split_targets:
                proj = getattr(block.att, attr)
                proj.weight.copy_(weight_part)
                proj.bias.copy_(bias_part)

            copy_linear(block.att.final_proj, f'{base}attn.c_proj')

            # Feed-forward: layers[0] is the up-projection, layers[2] the down-projection.
            copy_linear(block.ff.layers[0], f'{base}mlp.c_fc')
            copy_linear(block.ff.layers[2], f'{base}mlp.c_proj')

        # 3. Final norm and output head
        copy_norm(model.final_norm, 'ln_f')

        # The head receives a *copy* of the token embedding (no parameter
        # sharing); as above, only the overlapping rows are written.
        model.out_head.weight[:n_pretrained].copy_(pretrained_wte)

        # 4. audio_proj is intentionally left as-is.
        print("‚ÑπÔ∏è  Note: 'audio_proj' remains randomly initialized for training.")

    print("‚úÖ Multimodal Baseline loaded. The 'Brain' is pre-trained, the 'Ears' are ready to learn.")

In [None]:
def text_to_token_ids(text, tokenizer):
    """
    Encode ``text`` into a batch of one row of token ids.

    The ``<|endoftext|>`` marker is explicitly allowed so it is encoded as a
    special token instead of being rejected by the tokenizer.

    Args:
        text: String to encode.
        tokenizer: Object exposing ``encode(text, allowed_special=...)`` that
            returns a list of ints.

    Returns:
        An int64 tensor of shape ``(1, num_tokens)``.
    """
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    # Pin the dtype: torch.tensor([]) would otherwise infer float32 for an
    # empty encoding, and float ids cannot be used as embedding indices.
    return torch.tensor(encoded, dtype=torch.long).unsqueeze(0)

def top_k_top_p_filtering(logits, top_k=None, top_p=0.9):
    """
    Mask logits outside the top-k set and/or outside the top-p nucleus.

    Filtering acts on the last dimension. The input tensor is never modified;
    when neither filter is active the input itself is returned unchanged.

    Args:
        logits: Float tensor whose last dimension is the vocabulary.
        top_k: Keep only the ``top_k`` largest logits per row (ties with the
            k-th value are kept). Ignored when ``None``, non-positive, or not
            smaller than the vocabulary size.
        top_p: Keep the smallest prefix of the probability-sorted vocabulary
            whose cumulative probability exceeds ``top_p``. Ignored when
            ``None`` or outside the open interval (0, 1).

    Returns:
        A tensor shaped like ``logits`` with removed entries set to ``-inf``.
    """
    vocab_size = logits.size(-1)
    neg_inf = float('-inf')

    # 1. Top-k: everything strictly below the k-th largest value is dropped.
    if top_k is not None and 0 < top_k < vocab_size:
        kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        # Out-of-place on purpose so the caller's tensor is left untouched.
        logits = logits.masked_fill(logits < kth_best, neg_inf)

    # 2. Top-p (nucleus)
    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Shift the mask right by one so the token that crosses the threshold
        # is kept, and the best token always survives.
        drop_sorted = cumulative_probs > top_p
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = False

        sorted_logits = sorted_logits.masked_fill(drop_sorted, neg_inf)

        # Scatter the masked values back to their original vocabulary slots.
        logits = torch.full_like(logits, neg_inf).scatter_(
            dim=-1, index=sorted_indices, src=sorted_logits
        )

    return logits

@torch.no_grad()
def generate_multimodal(
    model, idx, audio_features, max_new_tokens, context_size,
    use_cache=True, temperature=0.2, top_k=None, top_p=0.9,
    repetition_penalty=1.1,
    eos_id=50256
):
    """
    Autoregressively extend ``idx`` with up to ``max_new_tokens`` tokens.

    The model is switched to eval mode. With ``use_cache=True`` its cache is
    reset first; the prompt and the audio features are fed only on the first
    step, and every later step feeds just the newest token with no audio.
    With ``use_cache=False`` the last ``context_size`` tokens and the audio
    features are fed on every step.

    Only batch size 1 is supported: the repetition penalty is applied to row 0
    and the sampled token is read with ``.item()``.

    Args:
        model: Callable as ``model(ids, audio_features=..., use_cache=...)``
            returning per-position logits; must provide ``eval()`` and,
            in cache mode, ``reset_cache()``.
        idx: Prompt token ids, shape ``(1, T)``.
        audio_features: Audio tensor handed to the model as described above.
        max_new_tokens: Upper bound on the number of generated tokens.
        context_size: Maximum number of trailing prompt tokens fed per call.
        use_cache: Whether to use incremental decoding.
        temperature: ``<= 0`` selects greedy decoding; otherwise logits are
            divided by it before top-k/top-p sampling.
        top_k: Forwarded to ``top_k_top_p_filtering``.
        top_p: Forwarded to ``top_k_top_p_filtering``.
        repetition_penalty: Positive logits of already generated tokens are
            divided by this value, non-positive ones multiplied by it.
        eos_id: Generation stops after this token has been appended.

    Returns:
        ``idx`` with the generated tokens appended along the last dimension.
    """
    model.eval()
    if use_cache:
        model.reset_cache()

    # Distinct tokens generated so far (prompt tokens are not penalised).
    seen_tokens = set()

    for step in range(max_new_tokens):
        # 1. Pick the model input. In cache mode everything before the newest
        #    token, including the audio prefix, was already fed on step 0.
        if use_cache and step > 0:
            idx_cond, audio_cond = idx[:, -1:], None
        else:
            idx_cond, audio_cond = idx[:, -context_size:], audio_features

        # 2. Logits for the last position only
        logits = model(idx_cond, audio_features=audio_cond, use_cache=use_cache)[:, -1, :]

        # 3. Repetition penalty, applied to all seen tokens in one indexed
        #    update instead of one host/device round-trip per token.
        if seen_tokens:
            seen = torch.tensor(sorted(seen_tokens), dtype=torch.long, device=logits.device)
            picked = logits[0, seen]
            logits[0, seen] = torch.where(
                picked > 0, picked / repetition_penalty, picked * repetition_penalty
            )

        # 4. Choose the next token
        if temperature <= 0.0:
            next_token_id = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            filtered = top_k_top_p_filtering(logits / temperature, top_k=top_k, top_p=top_p)
            probs = torch.softmax(filtered, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

        # 5. Append, record, and stop on the end-of-sequence token
        token = next_token_id.item()
        seen_tokens.add(token)
        idx = torch.cat((idx, next_token_id), dim=-1)

        if token == eos_id:
            break

    return idx

def generate_multimodal_sample(model, tokenizer, device, question, audio_key, features_dict):
    """
    Generate and print one answer for ``question`` about a pre-extracted audio clip.

    The model is put in eval mode for the duration of the call and afterwards
    returned to whatever mode it was in before, also when generation raises.

    Args:
        model: Model exposing ``pos_emb`` (its row count is used as the context size).
        tokenizer: Tokenizer with ``encode(..., allowed_special=...)`` and ``decode``.
        device: Device the prompt ids and audio features are moved to.
        question: Instruction text inserted into the prompt template.
        audio_key: Key selecting the audio feature tensor in ``features_dict``.
        features_dict: Mapping from audio key to a feature tensor.

    Returns:
        None. The decoded sequence (prompt plus completion) is printed.
    """
    was_training = model.training
    model.eval()
    try:
        context_size = model.pos_emb.weight.shape[0]

        # Prompt template; per the original author it must match the training data format.
        prompt = (
            "Below is an instruction describing an audio task. "
            "Respond appropriately using the provided audio input.\n\n"
            f"### Instruction:\n{question}\n\n"
            "### Response:\n"
        )

        encoded = text_to_token_ids(prompt, tokenizer).to(device)
        audio_vector = features_dict[audio_key].to(device)

        # Resolve the stop token through the tokenizer rather than hard-coding its id.
        tiktoken_eos_id = tokenizer.encode("<|endoftext|>", allowed_special={'<|endoftext|>'})[0]

        with torch.no_grad():
            token_ids = generate_multimodal(
                model=model,
                idx=encoded,
                audio_features=audio_vector,
                max_new_tokens=200,
                context_size=context_size,
                temperature=0.2,
                top_p=0.9,
                repetition_penalty=1.1,
                eos_id=tiktoken_eos_id
            )

        decoded_text = tokenizer.decode(token_ids[0].tolist())
        print(f"\n--- INFERENCE RESULT ({audio_key}) ---\n")
        print(decoded_text)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)

In [None]:
class AudioAnalyzerAssistant:
    """
    Inference wrapper that answers questions about audio files.

    It owns two models: a CLAP encoder that turns an audio file into an
    embedding, and a ``MyGPT_GQA_SWA`` language model that receives that
    embedding together with an instruction prompt and generates the answer.

    Attributes set by ``__init__``:
        device: CUDA when available, otherwise CPU.
        clap_model: The CLAP audio encoder.
        tokenizer: The tokenizer passed in by the caller.
        gpt_model: The language model, loaded from a checkpoint and in eval mode.
    """
    def __init__(self, model_path, tokenizer, CHOSEN_MODEL="gpt2-medium (355M)"):
        """
        Build both models and load the language-model checkpoint.

        Args:
            model_path: Path to a checkpoint containing a ``'model_state_dict'`` entry.
            tokenizer: Tokenizer used for prompts and for decoding the output.
            CHOSEN_MODEL: Key into ``model_configs`` selecting the model size.
        """
        # 1. Device
        cuda_ok = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_ok else "cpu")

        # 2. Audio encoder
        self.clap_model = CLAP(version='2023', use_cuda=cuda_ok)
        self.tokenizer = tokenizer

        # 3. Language-model config: base settings, then the size preset, then
        #    forced overrides (classic block layout, one KV group per head,
        #    dropout disabled).
        gpt_cfg = {**BASE_CONFIG, **model_configs[CHOSEN_MODEL]}
        gpt_cfg.update(
            use_dual_stream=False,
            use_adaptive_moe=False,
            use_parallel_att=False,
            kv_groups=gpt_cfg["n_heads"],
            drop_rate=0.0,
        )
        self.gpt_model = MyGPT_GQA_SWA(gpt_cfg).to(self.device)

        # 4. Checkpoint
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from a trusted source.
        weights_name = os.path.basename(model_path)
        print(f"üîÑ Loading weights from: {weights_name}...")
        checkpoint = torch.load(model_path, map_location='cpu')
        self.gpt_model.load_state_dict(checkpoint['model_state_dict'])
        self.gpt_model.to(self.device)
        self.gpt_model.eval()
        print(f"‚úÖ Assistant Ready on {self.device}")

    def analyze_audio(self, task_list):
        """
        Answer every question of every task and collect the results.

        Args:
            task_list: Iterable of dicts with a ``"path"`` (audio file) and an
                optional ``"questions"`` list. Tasks whose file is missing, or
                whose audio embedding cannot be computed, are reported on
                stdout and skipped.

        Returns:
            A list with one dict per answered question, holding the keys
            ``"Audio"``, ``"Question"`` and ``"AI Reasoning"``.
        """
        context_size = self.gpt_model.pos_emb.weight.shape[0]
        # Stop token id, resolved through the tokenizer.
        eos_id = self.tokenizer.encode("<|endoftext|>", allowed_special={'<|endoftext|>'})[0]

        report = []

        for task in task_list:
            audio_path = task.get("path")
            questions = task.get("questions", [])

            if not (audio_path and os.path.exists(audio_path)):
                print(f"‚ö†Ô∏è File not found: {audio_path}")
                continue

            # --- Step 1: audio embedding ---
            try:
                with torch.no_grad():
                    audio_emb = self.clap_model.get_audio_embeddings([audio_path]).to(self.device)
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole batch.
                print(f"‚ùå CLAP Error on {audio_path}: {e}")
                continue

            clip_name = os.path.basename(audio_path)
            print(f"\n--- üéß Analyzing: {clip_name} ---")

            # --- Step 2: one generation per question ---
            for q in questions:
                # The prompt ends with "<|start_thought|>" so generation
                # continues from inside the reasoning section.
                prompt = (
                    "Below is an instruction describing an audio task. "
                    "Respond appropriately using the provided audio input.\n\n"
                    f"### Instruction:\n{q}\n\n"
                    "### Response:\n<|start_thought|>"
                )
                prompt_ids = text_to_token_ids(prompt, self.tokenizer).to(self.device)

                with torch.no_grad():
                    output_ids = generate_multimodal(
                        model=self.gpt_model,
                        idx=prompt_ids,
                        audio_features=audio_emb,
                        max_new_tokens=256,
                        temperature=0.1,
                        top_p=0.9,
                        repetition_penalty=1.1,
                        eos_id=eos_id,
                        context_size=context_size,
                    )

                decoded = self.tokenizer.decode(output_ids[0].tolist())

                # Keep only what follows the last response header.
                answer = decoded.rpartition("### Response:\n")[2]
                # Drop the opening tag and turn the closing tag into an answer marker.
                answer = answer.replace("<|start_thought|>", "")
                answer = answer.replace("<|end_thought|>", "\n‚û°Ô∏è Answer:").strip()
                # Cut everything from the end-of-text marker onwards.
                answer = answer.partition("<|endoftext|>")[0].strip()

                preview = answer[:150]
                print(f"‚ùì Q: {q}")
                print(f"ü§ñ Logic Trace: {preview}...")
                print("   (Full logic saved to report)\n")

                report.append({
                    "Audio": clip_name,
                    "Question": q,
                    "AI Reasoning": answer,
                })

        return report

In [None]:
# Settings shared by every model size; a size preset from `model_configs`
# is merged on top of a copy of this dict before the model is built.
BASE_CONFIG = {
    # --- Text ---
    "vocab_size": 50257,        # Rows of the token embedding / output head (GPT-2 vocabulary)
    "context_length": 1024,     # Rows of the learned position-embedding table

    # --- Layer options ---
    "qkv_bias": True,           # Bias terms on the attention Q/K/V projections
    "drop_rate": 0.0,           # Dropout probability (0.0 disables dropout)
    "use_RMSNorm": False,       # True -> RMSNorm, False -> NormLayer (LayerNorm-style)

    # --- Architecture switches ---
    "use_dual_stream": False,   # Separate global (attention) and local (FFN) residual streams
    "use_adaptive_moe": False,  # Mixture-of-experts feed-forward (consumed by code outside this cell)
    "use_parallel_att": False,  # Attention and FFN read the same normalized input in parallel

    # --- Grouped-query attention ---
    "kv_groups": 16,            # Number of key/value groups; equal to n_heads means plain multi-head

    # --- Mixture-of-experts (only relevant when use_adaptive_moe is on; presumably
    #     read by the MoE feed-forward defined elsewhere -- confirm there) ---
    "e_num": 1,                 # Number of experts
    "moe_groups": 1,            # Number of expert groups
    "moe_group_top_p": 0.7,     # Threshold used when selecting groups
    "moe_top_p": 0.9,           # Threshold used when selecting experts inside a group
    "moe_max_groups": 1,        # Cap on groups per token
    "moe_max_k": 1,             # Cap on experts per group per token

    # --- Sliding-window attention (read by the attention module defined elsewhere) ---
    "window_size": 1024,        # Cache window used at inference
    "swa_size": 1024,           # Look-back window used during training
}

# Size presets following the published GPT-2 family: (width, depth, heads).
model_configs = {
    label: {"emb_dim": width, "n_layers": depth, "n_heads": heads}
    for label, width, depth, heads in (
        ("gpt2-small (124M)", 768, 12, 12),
        ("gpt2-medium (355M)", 1024, 24, 16),
        ("gpt2-large (774M)", 1280, 36, 20),
        ("gpt2-xl (1558M)", 1600, 48, 25),
    )
}

# Readable preset name -> Hugging Face Hub model identifier.
mapping = {
    "gpt2-small (124M)": "gpt2",
    "gpt2-medium (355M)": "gpt2-medium",
    "gpt2-large (774M)": "gpt2-large",
    "gpt2-xl (1558M)": "gpt2-xl",
}