<a href="https://colab.research.google.com/github/amanzoni1/MoE_Adapter_Routing_Analysis/blob/main/fp_hellora.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q transformers datasets peft accelerate bitsandbytes wandb

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m38.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os, json
import torch
import torch.nn.functional as F
from transformers import OlmoeForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm.notebook import tqdm
from typing import Any, Callable, Dict, List, Optional
from google.colab import drive, files



In [3]:
drive.mount("/content/drive")

SAVE_ROOT_DRIVE = "/content/drive/MyDrive/HELLoRA_Experiments"
STAB_ROOT = os.path.join(SAVE_ROOT_DRIVE, "telemetry")
os.makedirs(STAB_ROOT, exist_ok=True)

Mounted at /content/drive


In [4]:
model_name = "allenai/OLMoE-1B-7B-0924"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = OlmoeForCausalLM.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map="auto",
)

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/759 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/3.84G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/120 [00:00<?, ?B/s]

In [5]:
# --- formatters ---
def fmt_gsm8k(ex):  # math
    return f"Question: {ex['question']}\nAnswer: {ex['answer']}"

def fmt_alpaca(ex):  # code
    inp = ex.get("input", "") or ""
    return f"{ex['instruction']}\n{inp}\n{ex['output']}"

def fmt_wikitext(ex):  # general text
    return ex["text"]

# --- unified registry ---
DATASETS: Dict[str, Dict[str, Any]] = {
    "gsm8k": {
        "path": "openai/gsm8k",
        "name": "main",
        "splits": {"train": "train"},
        "text_fn": fmt_gsm8k,
        "full_n": 7473,
        "train_defaults": {"lr": None, "epochs": 3},
    },
    "alpaca": {
        "path": "sahil2801/CodeAlpaca-20k",
        "name": None,
        "splits": {"train": "train"},
        "text_fn": fmt_alpaca,
        "full_n": 20000,
        "train_defaults": {"lr": None, "epochs": 2},
    },
    "wikitext": {
        "path": "Salesforce/wikitext",
        "name": "wikitext-2-raw-v1",
        "splits": {"train": "train"},
        "text_fn": fmt_wikitext,
        "full_n": 36718,
        "train_defaults": {"lr": None, "epochs": 1},
    },
}

def load_hf_split(dataset_key: str, split: str, seed: int, n: Optional[int] = None):
    """
    Generic loader for ANY phase (profiling/training/eval).
    n: optional number of samples to select after shuffle.
    """
    cfg = DATASETS[dataset_key]
    split_name = cfg["splits"][split]
    ds = load_dataset(cfg["path"], cfg["name"], split=split_name)
    ds = ds.shuffle(seed=seed)
    if n is not None:
        ds = ds.select(range(min(n, len(ds))))
    return ds

def format_text(dataset_key: str, ex: Dict[str, Any]) -> str:
    return DATASETS[dataset_key]["text_fn"](ex)


In [6]:
import os
import torch
import torch.nn.functional as F  # NEW
from datasets import load_dataset
from tqdm.notebook import tqdm
from typing import Any, Callable, Dict, List, Optional


class ProfilerEngine:
    """
    Profiles MoE routing from gate logits (layer.mlp.gate):
      - counts: expert hit counts from top-k routing per token
      - mass:   summed routing probability mass assigned to each expert (top-k only)
               (useful for coverage/entropy stats later)

    Saves per-layer:
      - global:
          counts: [E] (long)
          mass:   [E] (float32)
          total:  scalar (#token*top_k) (long)
      - bucketed:
          counts: [B,E]
          mass:   [B,E]
          total:  [B]

    Notes:
      - We compute probs = softmax(logits) in float32, then topk(probs).
      - Top-k indices match topk(logits), but probs enable “40% coverage” etc.
    """

    def __init__(
        self,
        model,
        tokenizer,
        output_dir: str,
        seq_len: int = 2048,
        bucket_edges: Optional[List[int]] = None,
        gate_getter: Optional[Callable[[Any], Any]] = None,
        num_experts: Optional[int] = None,
        top_k: Optional[int] = None,
        store_mass: bool = True,                 # keep probability-mass telemetry
        prob_dtype: torch.dtype = torch.float32, # stable softmax dtype
        renorm_topk_prob: Optional[bool] = None, # NEW: match HF MoE scaling for mass if enabled
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.output_dir = output_dir
        self.seq_len = int(seq_len)

        # Layers
        if hasattr(model, "model") and hasattr(model.model, "layers"):
            self.layers = model.model.layers
        elif hasattr(model, "layers"):
            self.layers = model.layers
        else:
            raise ValueError("Could not automatically identify model layers.")

        self.num_layers = len(self.layers)
        self.num_experts = int(num_experts or getattr(model.config, "num_experts", 64))
        self.top_k = int(top_k or getattr(model.config, "num_experts_per_tok", 8))

        self.store_mass = bool(store_mass)
        self.prob_dtype = prob_dtype

        # Buckets
        self.bucket_edges = bucket_edges
        if self.bucket_edges is not None:
            if self.bucket_edges[0] != 0:
                raise ValueError("bucket_edges must start at 0.")
            if self.bucket_edges[-1] < self.seq_len:
                raise ValueError("bucket_edges[-1] must be >= seq_len.")
            self.num_buckets = len(self.bucket_edges) - 1

            # NEW: warn on left padding (bucket positions become meaningless)
            if getattr(self.tokenizer, "padding_side", "right") == "left":
                print("⚠️ [Profiler] tokenizer.padding_side='left' detected.")
                print("   Positional bucketing will be skewed; recommended: tokenizer.padding_side='right'.")
        else:
            self.num_buckets = 0

        self._current_attn_mask: Optional[torch.Tensor] = None
        self.hooks = []
        self.data_buffer: Dict[int, Dict[str, Any]] = {}

        # Default OLMoE gate path
        self.gate_getter = gate_getter or (lambda layer: layer.mlp.gate)

        # NEW: auto-detect renorm_topk_prob (affects mass only, not indices)
        if renorm_topk_prob is None:
            try:
                renorm_topk_prob = bool(getattr(self.layers[0].mlp, "norm_topk_prob", False))
            except Exception:
                renorm_topk_prob = False
        self.renorm_topk_prob = bool(renorm_topk_prob)

        os.makedirs(self.output_dir, exist_ok=True)

    def _init_buffer(self):
        if self.bucket_edges is None:
            self.data_buffer = {
                i: {
                    "counts": torch.zeros(self.num_experts, dtype=torch.long),
                    "mass": torch.zeros(self.num_experts, dtype=torch.float32) if self.store_mass else None,
                    "total": 0,
                }
                for i in range(self.num_layers)
            }
        else:
            self.data_buffer = {
                i: {
                    "counts": torch.zeros(self.num_buckets, self.num_experts, dtype=torch.long),
                    "mass": torch.zeros(self.num_buckets, self.num_experts, dtype=torch.float32) if self.store_mass else None,
                    "total": [0 for _ in range(self.num_buckets)],
                }
                for i in range(self.num_layers)
            }

    def _get_hook(self, layer_idx: int):
        def hook(module, input, output):
            logits = output[0] if isinstance(output, tuple) else output  # gate logits

            if self._current_attn_mask is None:
                raise RuntimeError("attention_mask not set; cannot mask padding.")

            B, S = self._current_attn_mask.shape
            m = self._current_attn_mask.to(torch.bool)  # [B,S]

            # NEW: safer universal flattening (still supports 2D or 3D logits)
            if logits.dim() == 3:
                logits_flat = logits.reshape(-1, logits.shape[-1])  # [B*S, E]
            elif logits.dim() == 2:
                logits_flat = logits  # [B*S, E]
            else:
                raise RuntimeError(f"Unexpected gate output shape: {tuple(logits.shape)}")

            if logits_flat.shape[1] != self.num_experts:
                raise RuntimeError(f"Gate last-dim {logits_flat.shape[1]} != num_experts {self.num_experts}")
            if logits_flat.shape[0] != B * S:
                raise RuntimeError(f"Gate tokens {logits_flat.shape[0]} != B*S {B*S} (cannot align)")

            flat_mask = m.reshape(-1)          # [B*S]
            logits2d = logits_flat[flat_mask]  # [N,E] (non-pad tokens)
            if logits2d.numel() == 0:
                return

            # Exact HF routing rule: softmax -> topk
            probs2d = F.softmax(logits2d.to(self.prob_dtype), dim=-1)  # [N,E] float32
            top_p, idx = torch.topk(probs2d, k=self.top_k, dim=-1)      # [N,K], [N,K]

            # match HF norm_topk_prob for mass (indices unchanged)
            if self.renorm_topk_prob:
                top_p = top_p / top_p.sum(dim=-1, keepdim=True)

            # counts (hit frequency)
            flat_idx_cpu = idx.reshape(-1).to(torch.long).cpu()
            counts = torch.bincount(flat_idx_cpu, minlength=self.num_experts)

            if self.bucket_edges is None:
                self.data_buffer[layer_idx]["counts"] += counts
                self.data_buffer[layer_idx]["total"] += int(idx.shape[0]) * self.top_k

                # mass (probability mass, top-k only)
                if self.store_mass:
                    mass = torch.zeros(self.num_experts, dtype=torch.float32)
                    mass.index_add_(0, flat_idx_cpu, top_p.reshape(-1).to(torch.float32).cpu())
                    self.data_buffer[layer_idx]["mass"] += mass

            else:
                # bucket by token position
                pos = torch.arange(S, device=self._current_attn_mask.device).view(1, S).expand(B, S)  # [B,S]
                pos2d = pos[m]  # [N]

                for bi in range(self.num_buckets):
                    lo, hi = self.bucket_edges[bi], self.bucket_edges[bi + 1]
                    bm = (pos2d >= lo) & (pos2d < hi)
                    if not bm.any():
                        continue

                    idx_b_cpu = idx[bm].reshape(-1).to(torch.long).cpu()
                    counts_b = torch.bincount(idx_b_cpu, minlength=self.num_experts)

                    self.data_buffer[layer_idx]["counts"][bi] += counts_b
                    self.data_buffer[layer_idx]["total"][bi] += int(bm.sum().item()) * self.top_k

                    if self.store_mass:
                        mass_b = torch.zeros(self.num_experts, dtype=torch.float32)
                        mass_b.index_add_(0, idx_b_cpu, top_p[bm].reshape(-1).to(torch.float32).cpu())
                        self.data_buffer[layer_idx]["mass"][bi] += mass_b

        return hook

    def attach_hooks(self):
        self.detach_hooks()
        self._init_buffer()
        print(f"[Profiler] Attaching gate hooks to {self.num_layers} layers...")
        for i, layer in enumerate(self.layers):
            gate = self.gate_getter(layer)  # e.g., layer.mlp.gate
            self.hooks.append(gate.register_forward_hook(self._get_hook(i)))

    def detach_hooks(self):
        for h in self.hooks:
            try:
                h.remove()
            except Exception:
                pass
        self.hooks = []

    def _process_batch(self, batch_text: List[str]):
        inputs = self.tokenizer(
            batch_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.seq_len,
        ).to(self.model.device)

        self._current_attn_mask = inputs.get("attention_mask", None)
        if self._current_attn_mask is None:
            raise RuntimeError("Tokenizer output missing attention_mask")

        with torch.no_grad():
            self.model(**inputs)

        self._current_attn_mask = None

    def run_dataset(
        self,
        dataset_cfg: Dict[str, Any],
        dataset_key: str,
        seed: int = 123,
        batch_size: int = 32,
        num_samples: int = 1024,
    ) -> str:
        print(f"\n[Profiler] {dataset_key} | samples={num_samples} | seed={seed}")
        ds = load_dataset(dataset_cfg["path"], dataset_cfg.get("name"), split=dataset_cfg["split"])
        ds = ds.shuffle(seed=seed).select(range(num_samples))

        self.attach_hooks()
        try:
            batch_text = []
            pbar = tqdm(total=num_samples, desc="Profiling")
            for item in ds:
                text = dataset_cfg["text_fn"](item)
                if not isinstance(text, str) or len(text) < 10:
                    continue
                batch_text.append(text)
                if len(batch_text) == batch_size:
                    self._process_batch(batch_text)
                    pbar.update(len(batch_text))
                    batch_text = []
            if batch_text:
                self._process_batch(batch_text)
                pbar.update(len(batch_text))
            pbar.close()
        finally:
            self.detach_hooks()

        return self._save(dataset_key)

    def _save(self, dataset_key: str) -> str:
        suffix = "bucketed" if self.bucket_edges is not None else "global"
        filename = os.path.join(self.output_dir, f"telemetry_{dataset_key}_{suffix}.pt")

        final = {
            "meta": {
                "num_layers": self.num_layers,
                "num_experts": self.num_experts,
                "top_k": self.top_k,
                "seq_len": self.seq_len,
                "bucket_edges": self.bucket_edges,
                "store_mass": self.store_mass,
                "renorm_topk_prob": self.renorm_topk_prob,  # NEW
            }
        }

        for layer_idx in range(self.num_layers):
            buf = self.data_buffer[layer_idx]
            obj = {
                "counts": buf["counts"].clone(),
                "total": torch.tensor(buf["total"], dtype=torch.long),
            }
            if self.store_mass:
                obj["mass"] = buf["mass"].clone()
            final[layer_idx] = obj

        torch.save(final, filename)
        print(f"[IO] Saved: {filename}")
        return filename


In [7]:
def make_profile(
    dataset_key: str,
    model,
    tokenizer,
    output_dir: str,
    *,
    split: str = "train",
    seed: int = 123,
    n: int = 2000,
    bs: int = 16,
    seq_len: int = 2048,
    bucket_edges: Optional[List[int]] = None,
    gate_getter: Optional[Callable[[Any], Any]] = None,
    store_mass: bool = True,
) -> str:
    """
    Profiles routing and saves a telemetry .pt file.
    Returns telemetry .pt path.
    """
    os.makedirs(output_dir, exist_ok=True)

    # your existing helpers
    ds = load_hf_split(dataset_key, split=split, seed=seed, n=n)

    eng = ProfilerEngine(
        model=model,
        tokenizer=tokenizer,
        output_dir=output_dir,
        seq_len=seq_len,
        bucket_edges=bucket_edges,
        gate_getter=gate_getter,   # default works for OLMoE: layer.mlp.gate
        store_mass=store_mass,
    )

    eng.attach_hooks()
    try:
        buf = []
        pbar = tqdm(total=len(ds), desc=f"profiling {dataset_key}/{split}")
        for ex in ds:
            text = format_text(dataset_key, ex)
            if not isinstance(text, str) or len(text) < 10:
                continue
            buf.append(text)
            if len(buf) == bs:
                eng._process_batch(buf)
                pbar.update(len(buf))
                buf = []
        if buf:
            eng._process_batch(buf)
            pbar.update(len(buf))
        pbar.close()
    finally:
        eng.detach_hooks()

    # ---- save (custom filename + meta you want) ----
    suffix = "bucketed" if bucket_edges is not None else "global"
    pt_path = os.path.join(output_dir, f"telemetry_{dataset_key}_{split}_n{n}_seed{seed}_{suffix}.pt")

    payload: Dict[Any, Any] = {
        "meta": {
            "dataset_key": dataset_key,
            "split": split,
            "seed": seed,
            "n": n,
            "bs": bs,
            "seq_len": seq_len,
            "bucket_edges": bucket_edges,
            "num_layers": eng.num_layers,
            "num_experts": eng.num_experts,
            "top_k": eng.top_k,
            "store_mass": store_mass,
        }
    }

    for li in range(eng.num_layers):
        layer_buf = eng.data_buffer[li]
        payload[li] = {
            "counts": layer_buf["counts"].clone(),
            "total": torch.tensor(layer_buf["total"], dtype=torch.long),
        }
        if store_mass:
            payload[li]["mass"] = layer_buf["mass"].clone()

    torch.save(payload, pt_path)
    print("[IO] saved:", pt_path)
    return pt_path


In [4]:
def build_hotmap(
    pt_path: str,
    k: int = 8,
    out_json: Optional[str] = None,
    mode: str = "counts",   # "counts" or "mass"
) -> str:
    """
    Build hotmap json from telemetry pt.
    - mode="counts": top-k by hit frequency (what you used so far)
    - mode="mass":   top-k by summed prob mass (useful for “40% coverage” logic)
    """
    d = torch.load(pt_path, map_location="cpu")
    layers = {kk: vv for kk, vv in d.items() if isinstance(kk, int)}

    if mode not in ("counts", "mass"):
        raise ValueError("mode must be 'counts' or 'mass'")

    hm: Dict[int, List[int]] = {}
    for layer_idx, obj in layers.items():
        x = obj[mode]
        if x.dim() == 2:
            x = x.sum(dim=0)  # aggregate buckets -> global

        top_idx = torch.topk(x.to(torch.float32), k=k).indices.tolist()
        hm[int(layer_idx)] = top_idx

    if out_json is None:
        base = os.path.splitext(pt_path)[0]
        out_json = f"{base}_hotmap_{mode}_k{k}.json"

    with open(out_json, "w") as f:
        json.dump({str(l): exps for l, exps in hm.items()}, f, indent=2)

    print("[IO] saved:", out_json)
    return out_json


In [9]:
ROOT = "/content/drive/MyDrive/HELLoRA_Experiments/telemetry/gsm8k"

gsm8k_telemetry = make_profile("gsm8k", model, tokenizer, ROOT, n=7473)

README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

[Profiler] Attaching gate hooks to 16 layers...


profiling gsm8k/train:   0%|          | 0/7473 [00:00<?, ?it/s]

[IO] saved: /content/drive/MyDrive/HELLoRA_Experiments/telemetry/gsm8k/telemetry_gsm8k_train_n7473_seed123_global.pt


In [10]:
GSM8K_HOTMAP_K8 = build_hotmap(gsm8k_telemetry, k=8, mode="counts")

[IO] saved: /content/drive/MyDrive/HELLoRA_Experiments/telemetry/gsm8k/telemetry_gsm8k_train_n7473_seed123_global_hotmap_counts_k8.json


In [11]:
GSM8K_HOTMAP_MASS_K8 = build_hotmap(gsm8k_telemetry, k=8, mode="mass")

[IO] saved: /content/drive/MyDrive/HELLoRA_Experiments/telemetry/gsm8k/telemetry_gsm8k_train_n7473_seed123_global_hotmap_mass_k8.json


In [5]:
gsm8k_telemetry = "/content/drive/MyDrive/HELLoRA_Experiments/telemetry/gsm8k/telemetry_gsm8k_train_n7473_seed123_global.pt"

GSM8K_HOTMAP_MASS_K4 = build_hotmap(gsm8k_telemetry, k=4)
GSM8K_HOTMAP_MASS_K16 = build_hotmap(gsm8k_telemetry, k=16)

[IO] saved: /content/drive/MyDrive/HELLoRA_Experiments/telemetry/gsm8k/telemetry_gsm8k_train_n7473_seed123_global_hotmap_counts_k4.json
[IO] saved: /content/drive/MyDrive/HELLoRA_Experiments/telemetry/gsm8k/telemetry_gsm8k_train_n7473_seed123_global_hotmap_counts_k16.json


In [16]:
import inspect
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock

print(inspect.getsource(OlmoeSparseMoeBlock.forward)[:2000])  # print first part


    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be 