In [1]:
# --- Cell 1: Local setup (no Colab) ---
from pathlib import Path
import os, sys

# If you run this notebook from inside the repo, this will just work.
# Otherwise set:  export REPO_ROOT=/path/to/pure_ssm
REPO_ROOT = Path(os.environ.get("REPO_ROOT", Path.cwd())).resolve()

# If the notebook lives in a subdirectory, walk upwards until we find a repo marker.
_repo_markers = ["pyproject.toml", "setup.py", "runner_ssm.py"]
for parent in [REPO_ROOT, *REPO_ROOT.parents]:
    if any((parent / m).exists() for m in _repo_markers):
        REPO_ROOT = parent
        break

if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

os.chdir(REPO_ROOT)

# Where to write logs/results (no Drive mount).
OUTPUT_ROOT = Path(os.environ.get("OUTPUT_ROOT", REPO_ROOT / "outputs")).resolve()
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)

print("Working directory:", REPO_ROOT)
print("Output root      :", OUTPUT_ROOT)


Working directory: /insomnia001/home/dwz2107/SSM_experiment/test_notebooks/pure_ssm
Output root      : /insomnia001/home/dwz2107/SSM_experiment/test_notebooks/pure_ssm/outputs


In [2]:
# --- Cell 2: Imports & global config ---
from runner_ssm import (
    RunConfig,
    load_llm,
    DEFAULT_SAMPLING,
    LB_ALEVAL_SAMPLING,
    profile_run,
    log_result,
    run_and_log,
    score_mc_options,
    choose_mc_option,
    # smoke_test_pure_ssm,  # optional; not needed for main pipeline
)

from config.pure_ssm_config import PURE_SSM_CONTEXTS, PURE_SSM_MODELS, DECODE_CONFIG

from data import longbench_v2_utils as lbv2
from data import ada_leval_utils as ada
from data import pg19_utils as pg

from transformers import AutoTokenizer
from pathlib import Path
import json
import pandas as pd
import subprocess

PROMPT_ROOT = REPO_ROOT / "data" / "prompt_sets"
PROMPT_ROOT.mkdir(parents=True, exist_ok=True)
print("Prompt root:", PROMPT_ROOT)
print(PURE_SSM_CONTEXTS)

  from .autonotebook import tqdm as notebook_tqdm


INFO 12-15 09:18:12 [__init__.py:216] Automatically detected platform cuda.
Prompt root: /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/data/prompt_sets
[131072]


In [3]:
# --- Cell 3: vLLM + W&B instrumentation (LongBench/LEval compatible) ---
import os, re, time, json, statistics, threading, subprocess
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple

# Optional W&B (set USE_WANDB=1 to enable)
try:
    import wandb  # type: ignore
except Exception:
    wandb = None

# OpenAI-compatible client (vLLM exposes this API)
try:
    from openai import OpenAI  # type: ignore
except Exception:
    OpenAI = None
    print("⚠️  openai not installed. Install with: pip install openai>=1.0.0")


def percentile(xs: List[float], q: float) -> float:
    """Nearest-rank percentile, deterministic."""
    if not xs:
        return float("nan")
    ys = sorted(xs)
    k = int(round((len(ys) - 1) * q))
    k = max(0, min(k, len(ys) - 1))
    return float(ys[k])


def extract_answer(text: str) -> Optional[str]:
    """Extract A/B/C/D from a model response."""
    if not text:
        return None
    m = re.search(r"Final\s*answer\s*:\s*([ABCD])\b", text, flags=re.IGNORECASE)
    if m:
        return m.group(1).upper()
    # fallback: last standalone A/B/C/D in the string
    m = re.search(r"\b([ABCD])\b(?!.*\b[ABCD]\b)", text, flags=re.IGNORECASE)
    if m:
        return m.group(1).upper()
    return None


def _encode(tok, s: str) -> List[int]:
    # transformers tokenizers
    if hasattr(tok, "encode"):
        return tok.encode(s, add_special_tokens=False)
    return tok(s)["input_ids"]


def truncate_prompt(prompt: str, tok, max_prompt_tokens: int) -> str:
    """Keep the *tail* of the prompt within a token budget."""
    if max_prompt_tokens <= 0:
        return ""
    try:
        ids = _encode(tok, prompt)
    except Exception:
        return prompt

    if len(ids) <= max_prompt_tokens:
        return prompt

    ids = ids[-max_prompt_tokens:]
    try:
        return tok.decode(ids)
    except Exception:
        # crude fallback
        return prompt[-max_prompt_tokens * 4 :]


class VramSampler:
    """
    Lightweight VRAM sampler using `nvidia-smi`.
    Exposes `peak_mib_per_gpu_max` to mirror longbench_vllm/leval-vllm scripts.
    """

    def __init__(self, poll_s: float = 0.1):
        self.poll_s = poll_s
        self._stop = threading.Event()
        self._t: Optional[threading.Thread] = None
        self.peak_mib_per_gpu_max: float = float("nan")

    def _query_used_mib(self) -> Optional[List[float]]:
        try:
            out = subprocess.check_output(
                ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
                text=True,
            )
            vals = [float(x.strip()) for x in out.strip().splitlines() if x.strip()]
            return vals if vals else None
        except Exception:
            return None

    def _loop(self):
        peak: List[float] = []
        while not self._stop.is_set():
            vals = self._query_used_mib()
            if vals is not None:
                if not peak:
                    peak = vals
                else:
                    peak = [max(a, b) for a, b in zip(peak, vals)]
                self.peak_mib_per_gpu_max = float(max(peak)) if peak else float("nan")
            time.sleep(self.poll_s)

    def __enter__(self):
        self._stop.clear()
        self._t = threading.Thread(target=self._loop, daemon=True)
        self._t.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self._stop.set()
        if self._t is not None:
            self._t.join(timeout=1.0)
        return False


def _openai_client(vllm_url: str):
    if OpenAI is None:
        raise RuntimeError("openai is not installed; pip install openai>=1.0.0")
    api_key = os.environ.get("OPENAI_API_KEY", "EMPTY")
    return OpenAI(base_url=vllm_url, api_key=api_key)


def stream_completion(
    client,
    served_model_name: str,
    prompt: str,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
    seed: int,
    use_chat: bool = True,
) -> Tuple[str, float, float]:
    """Return (response_text, ttft_seconds, e2e_seconds)."""
    start = time.time()
    first: Optional[float] = None
    chunks: List[str] = []

    try:
        if use_chat:
            stream = client.chat.completions.create(
                model=served_model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_new_tokens,
                seed=seed,
                stream=True,
            )
            for ev in stream:
                if first is None:
                    first = time.time()
                delta = ev.choices[0].delta.content
                if delta:
                    chunks.append(delta)
        else:
            stream = client.completions.create(
                model=served_model_name,
                prompt=prompt,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_new_tokens,
                seed=seed,
                stream=True,
            )
            for ev in stream:
                if first is None:
                    first = time.time()
                delta = ev.choices[0].text
                if delta:
                    chunks.append(delta)
    except Exception:
        end = time.time()
        return "", float((first - start) if first else (end - start)), float(end - start)

    end = time.time()
    resp = "".join(chunks).strip()
    ttft_s = (first - start) if first else (end - start)
    e2e_s = (end - start)
    return resp, float(ttft_s), float(e2e_s)


def build_mc_prompt_from_record(item: Dict[str, Any]) -> Tuple[str, str]:
    """Returns (prompt, context)."""
    if isinstance(item.get("prompt"), str):
        prompt = item["prompt"]
        context = item.get("context", item.get("document", ""))
        return prompt, str(context) if context is not None else ""

    context = item.get("context") or item.get("document") or item.get("passage") or ""
    q = item.get("question") or item.get("query") or item.get("Q") or ""

    A = item.get("choice_A") or item.get("A") or item.get("option_A") or ""
    B = item.get("choice_B") or item.get("B") or item.get("option_B") or ""
    C = item.get("choice_C") or item.get("C") or item.get("option_C") or ""
    D = item.get("choice_D") or item.get("D") or item.get("option_D") or ""

    prompt = (
        f"{context}\n\n"
        f"Question: {q}\n"
        f"A. {A}\n"
        f"B. {B}\n"
        f"C. {C}\n"
        f"D. {D}\n"
    )
    return prompt, str(context)

In [10]:
# --- Cell 4: Build prompt sets (8k/16k/32k) if missing ---

tokenizer_model_id = PURE_SSM_MODELS["mamba-codestral-7b"]["hf_id"]

# LongBench v2
lbv2_32k_path = PROMPT_ROOT / "longbench_v2" / "lbv2_128k.jsonl"

if not lbv2_32k_path.exists():
    print("Building LongBench v2 prompt sets (8k/16k/32k)...")
    lbv2.build_lb2_prompt_sets(
        prompt_root=PROMPT_ROOT,
        tokenizer_model_id=tokenizer_model_id,
        pure_ssm_contexts=PURE_SSM_CONTEXTS,
        max_new_tokens=DECODE_CONFIG["max_new_tokens"],
        split="train",
        max_examples_per_ctx=500,
        tol=512,
    )
else:
    print("LongBench v2 prompt sets already exist, skipping build.")

# Ada‑LEval BestAnswer
ada_32k_path = PROMPT_ROOT / "ada_leval" / "ada_bestanswer_128k.jsonl"

if not ada_32k_path.exists():
    print("Building Ada‑LEval BestAnswer prompt sets (8k/16k/32k)...")
    ada.build_ada_bestanswer_prompt_sets(
        prompt_root=PROMPT_ROOT,
        tokenizer_model_id=tokenizer_model_id,
        pure_ssm_contexts=PURE_SSM_CONTEXTS,
        max_new_tokens=DECODE_CONFIG["max_new_tokens"],
        setting="8k",
        max_examples_per_ctx=500,
    )
else:
    print("Ada‑LEval BestAnswer prompt sets already exist, skipping build.")

# PG‑19
pg19_32k_path = PROMPT_ROOT / "pg19" / "pg19_32k.jsonl"

if not pg19_32k_path.exists():
    print("Building PG‑19 prompt sets (8k/16k/32k)...")
    pg.build_pg19_prompt_sets(
        prompt_root=PROMPT_ROOT,
        tokenizer_model_id=tokenizer_model_id,
        pure_ssm_contexts=PURE_SSM_CONTEXTS,
        max_new_tokens=DECODE_CONFIG["max_new_tokens"],
        split="test",
        max_examples_per_ctx=500,
        min_fraction=0.75,
    )
else:
    print("PG‑19 prompt sets already exist, skipping build.")

LongBench v2 prompt sets already exist, skipping build.
Building Ada‑LEval BestAnswer prompt sets (8k/16k/32k)...
[Ada BestAnswer] Loaded 1000 samples from Ada-LEval/data/stackselect_8k.json
Building Ada-LEval BestAnswer prompt sets.
Tokenizer: mistralai/Mamba-Codestral-7B-v0.1
Contexts: [131072] max_new_tokens: 128
[Ada BestAnswer ctx=131072] collected=500, min_len=6695, max_len=10474, mean_len=9152.5
Saved 500 records to /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/data/prompt_sets/ada_leval/ada_bestanswer_128k.jsonl
Building PG‑19 prompt sets (8k/16k/32k)...
Dataset({
    features: ['short_book_title', 'publication_date', 'url', 'text'],
    num_rows: 100
})
Building PG-19 prompt sets with tokenizer: mistralai/Mamba-Codestral-7B-v0.1
Contexts: [131072] max_new_tokens: 128
[PG19 ctx=131072] collected=34, min_len=99379, max_len=130912, mean_len=125686.9
Saved 34 records to /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/data/prompt_sets/pg19/pg19_128k.jsonl


In [11]:
# --- Cell 5: Run LongBench / LEval over vLLM with longbench_vllm-style metrics ---
from transformers import AutoTokenizer

@dataclass
class VLLMBenchArgs:
    benchmark: str = "longbench_v2"  # or: "ada_bestanswer" (LEval)
    tag: str = "8k"                 # "8k" / "16k" / "32k" prompt set tag

    vllm_url: str = os.environ.get("VLLM_URL", "http://localhost:8000/v1")
    served_model_name: str = os.environ.get("SERVED_MODEL_NAME", "model")
    tokenizer_id: str = os.environ.get("TOKENIZER_ID", "gpt2")

    max_len: int = int(os.environ.get("CTX_LEN", "8192"))
    seed: int = int(os.environ.get("SEED", "42"))
    cot: bool = bool(int(os.environ.get("COT", "0")))

    # logging
    log_every: int = int(os.environ.get("LOG_EVERY", "50"))
    use_wandb: bool = bool(int(os.environ.get("USE_WANDB", "0")))

    # optional bookkeeping (mirrors your script config)
    no_context: bool = bool(int(os.environ.get("NO_CONTEXT", "0")))
    rag_k: int = int(os.environ.get("RAG_K", "0"))
    rep: str = os.environ.get("RUN_REP", "")


def _default_wandb_project(benchmark: str) -> str:
    b = (benchmark or "").lower()
    if "leval" in b or "ada" in b:
        return "leval-vllm"
    return "longbench-vllm"


def _wandb_init_if_enabled(args: VLLMBenchArgs):
    if not args.use_wandb:
        return None
    if wandb is None:
        raise RuntimeError("USE_WANDB=1 but wandb is not installed. pip install wandb")

    project = os.environ.get("WANDB_PROJECT", _default_wandb_project(args.benchmark))
    run = wandb.init(
        project=project,
        entity=os.environ.get("WANDB_ENTITY"),
        name=os.environ.get("WANDB_NAME"),
        group=os.environ.get("WANDB_GROUP"),
        config={
            "benchmark": args.benchmark,
            "prompt_tag": args.tag,
            "model/served_name": args.served_model_name,
            "tokenizer_id": args.tokenizer_id,
            "vllm/url": args.vllm_url,
            "protocol/seed": args.seed,
            "protocol/temperature": 0.0,
            "protocol/top_p": 1.0,
            "cot": bool(args.cot),
            "no_context": bool(args.no_context),
            "rag_k": int(args.rag_k),
            "ctx_len": int(args.max_len),
            "rep": args.rep,
        },
    )
    return run


def _load_records_for_benchmark(benchmark: str, tag: str) -> List[Dict[str, Any]]:
    """Uses the existing prompt-set loaders if available."""
    b = (benchmark or "").lower()
    if "longbench" in b:
        recs, _prompts = lbv2.load_lb2_prompts_for_tag(PROMPT_ROOT, tag)
        return list(recs)
    if "leval" in b or "ada" in b:
        recs, _prompts = ada.load_ada_bestanswer_prompts_for_tag(PROMPT_ROOT, tag)
        return list(recs)
    raise ValueError(f"Unknown benchmark={benchmark!r}")


def run_mc_over_vllm(args: VLLMBenchArgs) -> Dict[str, Any]:
    # Load records
    data = _load_records_for_benchmark(args.benchmark, args.tag)

    # Tokenizer for output-token counting + prompt truncation
    tok = AutoTokenizer.from_pretrained(args.tokenizer_id, trust_remote_code=True)

    client = _openai_client(args.vllm_url)

    # Output file
    out_dir = OUTPUT_ROOT / "vllm_runs"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"{args.benchmark}_{args.tag}_{args.served_model_name.replace('/', '_')}.jsonl"

    run = _wandb_init_if_enabled(args)

    # metrics accumulators
    lat_e2e: List[float] = []
    lat_ttft: List[float] = []
    out_tokens: List[int] = []
    reprompt_used = 0

    # accuracy accumulators (match longbench_vllm schema)
    counts = {"easy": 0, "hard": 0, "short": 0, "medium": 0, "long": 0, "total": 0}
    correct = {"easy": 0, "hard": 0, "short": 0, "medium": 0, "long": 0, "total": 0}

    def _gold_letter(item: Dict[str, Any]) -> Optional[str]:
        for k in ["answer", "gold", "label", "target"]:
            v = item.get(k)
            if isinstance(v, str) and v.strip():
                if v.strip().upper() in {"A","B","C","D"}:
                    return v.strip().upper()
                m = re.search(r"\b([ABCD])\b", v.strip(), flags=re.IGNORECASE)
                if m:
                    return m.group(1).upper()
        return None

    def _mc_fields(item: Dict[str, Any]) -> Tuple[str, str, str, str, str, str]:
        context = item.get("context") or item.get("document") or item.get("passage") or ""
        q = item.get("question") or item.get("query") or item.get("Q") or ""
        A = item.get("choice_A") or item.get("A") or item.get("option_A") or ""
        B = item.get("choice_B") or item.get("B") or item.get("option_B") or ""
        C = item.get("choice_C") or item.get("C") or item.get("option_C") or ""
        D = item.get("choice_D") or item.get("D") or item.get("option_D") or ""
        return str(context), str(q), str(A), str(B), str(C), str(D)

    # Start run
    with out_file.open("w", encoding="utf-8") as fout, VramSampler(poll_s=0.1) as vs:
        for item in data:
            prompt, context = build_mc_prompt_from_record(item)

            # Max tokens: mirror your script behavior
            max_new = 1024 if args.cot else 128
            prompt_budget = max(256, args.max_len - max_new - 64)
            prompt = truncate_prompt(prompt, tok, prompt_budget)

            # Encourage parsable answer without forcing CoT
            prompt += "\n\nYou MUST end with: Final answer: X (X is A, B, C, or D)."

            resp, ttft_s, e2e_s = stream_completion(
                client=client,
                served_model_name=args.served_model_name,
                prompt=prompt,
                temperature=0.0,
                top_p=1.0,
                max_new_tokens=max_new,
                seed=args.seed,
                use_chat=True,
            )

            response_cot: Optional[str] = None

            # Optional explicit CoT mode: ask for reasoning then answer extraction
            if args.cot:
                response_cot = (resp or "").strip()

                # Build a compact answer-only prompt (keeps the same doc/Q/choices)
                ctx, q, A, B, C, D = _mc_fields(item)
                prompt2 = (
                    f"{ctx}\n\n"
                    f"Question: {q}\n"
                    f"A. {A}\n"
                    f"B. {B}\n"
                    f"C. {C}\n"
                    f"D. {D}\n\n"
                    f"Reasoning:\n{response_cot}\n\n"
                    f"Final answer:"
                )
                prompt2 = truncate_prompt(prompt2, tok, prompt_budget)

                resp2, ttft2, e2e2 = stream_completion(
                    client=client,
                    served_model_name=args.served_model_name,
                    prompt=prompt2,
                    temperature=0.0,
                    top_p=1.0,
                    max_new_tokens=16,
                    seed=args.seed,
                    use_chat=True,
                )
                resp = (resp2 or "").strip()
                ttft_s += ttft2
                e2e_s += e2e2

            response = (resp or "").strip()
            pred = extract_answer(response)

            # Reprompt if missing
            if pred is None:
                reprompt = "Answer with ONLY one letter: A, B, C, or D.\nFinal answer:"
                resp2, ttft2, e2e2 = stream_completion(
                    client=client,
                    served_model_name=args.served_model_name,
                    prompt=reprompt,
                    temperature=0.0,
                    top_p=1.0,
                    max_new_tokens=5,
                    seed=args.seed,
                    use_chat=True,
                )
                reprompt_used += 1
                ttft_s += ttft2
                e2e_s += e2e2
                pred = extract_answer(resp2 or "")
                response = response + "\n\n[REPROMPT]\n" + (resp2 or "")

            gold = _gold_letter(item)
            judge = (pred == gold) if (pred is not None and gold is not None) else False

            # token counting for throughput
            try:
                n_out = len(_encode(tok, response))
            except Exception:
                n_out = 0

            lat_e2e.append(float(e2e_s))
            lat_ttft.append(float(ttft_s))
            out_tokens.append(int(n_out))

            # accuracy breakdown
            counts["total"] += 1
            correct["total"] += int(judge)

            d = item.get("difficulty", "hard")
            if d not in ("easy", "hard"):
                d = "hard"
            counts[d] += 1
            correct[d] += int(judge)

            L = item.get("length", "long")
            if L not in ("short", "medium", "long"):
                L = "long"
            counts[L] += 1
            correct[L] += int(judge)

            # write line (keep original item keys, add tracking)
            item_out = dict(item)
            if response_cot is not None:
                item_out["response_cot"] = response_cot

            item_out["response"] = response
            item_out["pred"] = pred
            item_out["gold"] = gold
            item_out["judge"] = judge
            item_out["context"] = context[:1000] if isinstance(context, str) else str(context)[:1000]
            fout.write(json.dumps(item_out, ensure_ascii=False) + "\n")
            fout.flush()

            if run is not None and (counts["total"] % args.log_every == 0):
                wandb.log(
                    {
                        "progress/seen": counts["total"],
                        "eff/latency_e2e_s_mean_sofar": float(statistics.mean(lat_e2e)) if lat_e2e else float("nan"),
                        "eff/ttft_s_mean_sofar": float(statistics.mean(lat_ttft)) if lat_ttft else float("nan"),
                        "debug/reprompt_used_sofar": reprompt_used,
                    }
                )

    # finalize metrics
    def acc(k: str) -> float:
        return (correct[k] / counts[k]) if counts[k] > 0 else float("nan")

    overall = acc("total")
    easy = acc("easy")
    hard = acc("hard")
    short = acc("short")
    medium = acc("medium")
    long = acc("long")

    e2e_mean = float(statistics.mean(lat_e2e)) if lat_e2e else float("nan")
    e2e_p50 = float(percentile(lat_e2e, 0.50))
    e2e_p95 = float(percentile(lat_e2e, 0.95))

    ttft_mean = float(statistics.mean(lat_ttft)) if lat_ttft else float("nan")
    ttft_p50 = float(percentile(lat_ttft, 0.50))
    ttft_p95 = float(percentile(lat_ttft, 0.95))

    total_out = float(sum(out_tokens))
    total_time = float(sum(lat_e2e))
    toks_per_s = (total_out / total_time) if total_time > 0 else float("nan")

    peak_vram = float(getattr(vs, "peak_mib_per_gpu_max", float("nan")))

    summary = {
        "acc_overall": overall,
        "acc_easy": easy,
        "acc_hard": hard,
        "acc_short": short,
        "acc_medium": medium,
        "acc_long": long,
        "e2e_mean_s": e2e_mean,
        "e2e_p50_s": e2e_p50,
        "e2e_p95_s": e2e_p95,
        "ttft_mean_s": ttft_mean,
        "ttft_p50_s": ttft_p50,
        "ttft_p95_s": ttft_p95,
        "tokens_per_s": toks_per_s,
        "peak_vram_mib": peak_vram,
        "reprompt_used": reprompt_used,
        "out_file": str(out_file),
    }

    print(json.dumps(summary, indent=2))

    if run is not None:
        wandb.log(
            {
                "acc/overall": overall,
                "acc/easy": easy,
                "acc/hard": hard,
                "acc/short": short,
                "acc/medium": medium,
                "acc/long": long,
                "eff/latency_e2e_s_mean": e2e_mean,
                "eff/latency_e2e_s_p50": e2e_p50,
                "eff/latency_e2e_s_p95": e2e_p95,
                "eff/ttft_s_mean": ttft_mean,
                "eff/ttft_s_p50": ttft_p50,
                "eff/ttft_s_p95": ttft_p95,
                "eff/output_tokens_total": total_out,
                "eff/tokens_per_s": toks_per_s,
                "eff/peak_vram_mib_per_gpu_max": peak_vram,
                "debug/reprompt_used_total": reprompt_used,
            }
        )
        run.finish()

    return summary


# ---- Example: run LongBench v2 @ 8k ----
# args = VLLMBenchArgs(
#     benchmark="longbench_v2",
#     tag="8k",
#     served_model_name="your-served-name",
#     tokenizer_id="your-tokenizer-id",
# )
# summary = run_mc_over_vllm(args)


In [14]:
# --- Cell 6: Load datasets by context tag ---

def load_datasets_for_ctx(tag: str):
    """
    tag: "8k", "16k", or "32k"
    Returns a dict: dataset_name -> list of prompts
    """
    lb_recs, lb_prompts = lbv2.load_lb2_prompts_for_tag(PROMPT_ROOT, tag)
    ada_recs, ada_prompts = ada.load_ada_bestanswer_prompts_for_tag(PROMPT_ROOT, tag)
    pg_recs, pg_prompts   = pg.load_pg19_prompts_for_tag(PROMPT_ROOT, tag)

    print(f"[load_datasets_for_ctx] tag={tag}")
    print("  LongBench v2 prompts   :", len(lb_prompts))
    print("  Ada BestAnswer prompts :", len(ada_prompts))
    print("  PG-19 chunks           :", len(pg_prompts))

    datasets = {
        "longbench_v2": lb_prompts,
        "ada_bestanswer": ada_prompts,
        "pg19": pg_prompts
    }
    return datasets

# Optional, cheap: eager load so you can inspect counts
datasets_8k  = load_datasets_for_ctx("128k")
#datasets_16k = load_datasets_for_ctx("16k")
#datasets_32k = load_datasets_for_ctx("32k")

Loaded 500 records from /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/data/prompt_sets/longbench_v2/lbv2_128k.jsonl
Loaded 500 records from /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/data/prompt_sets/ada_leval/ada_bestanswer_128k.jsonl
Loaded 34 records from /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/data/prompt_sets/pg19/pg19_128k.jsonl
[load_datasets_for_ctx] tag=128k
  LongBench v2 prompts   : 500
  Ada BestAnswer prompts : 500
  PG-19 chunks           : 34


In [15]:
# --- Cell 7: Tokenizer cache + prompt filtering ---

_TOKENIZER_CACHE = {}

def get_tokenizer_for_model_key(model_key: str):
    if model_key in _TOKENIZER_CACHE:
        return _TOKENIZER_CACHE[model_key]
    info = PURE_SSM_MODELS[model_key]
    tok = AutoTokenizer.from_pretrained(
        info["hf_id"],
        revision=info.get("revision", "main"),
        trust_remote_code=True,
    )
    _TOKENIZER_CACHE[model_key] = tok
    return tok


def filter_prompts_to_ctx(prompts, cfg: RunConfig, safety_margin: int = 32):
    """
    Drop prompts that would exceed cfg.context_len - max_new_tokens - safety_margin
    when tokenized with THIS model's tokenizer.
    """
    tok = get_tokenizer_for_model_key(cfg.model_key)
    hard_max = cfg.context_len - cfg.max_new_tokens - safety_margin

    kept = []
    dropped = 0
    for p in prompts:
        n_tokens = len(tok.encode(p))
        if n_tokens <= hard_max:
            kept.append(p)
        else:
            dropped += 1

    print(
        f"[filter_prompts_to_ctx] {cfg.model_key} @ {cfg.context_len} "
        f"=> kept {len(kept)} prompts, dropped {dropped} (hard_max={hard_max})"
    )
    return kept

In [17]:
# --- Cell 8: Run pure SSM sweeps with log-file checks ---

LOG_DIR_ROOT = OUTPUT_ROOT / "pure_ssm_logs"
LOG_DIR_ROOT.mkdir(parents=True, exist_ok=True)
print("Log dir root:", LOG_DIR_ROOT)

def make_configs_for_ctx(ctx_len: int, seed: int = 42):
    """Two pure-SSM configs (Mamba-2.8B, Mamba-Codestral-7B) for a given context."""
    return [
        RunConfig(
            model_key="mamba-2.8b",
            context_len=ctx_len,
            max_new_tokens=DECODE_CONFIG["max_new_tokens"],
            precision="bfloat16",
            gpu_mem_util=0.90,
            seed=seed,
        ),
        RunConfig(
            model_key="mamba-codestral-7b",
            context_len=ctx_len,
            max_new_tokens=DECODE_CONFIG["max_new_tokens"],
            precision="bfloat16",
            gpu_mem_util=0.90,
            seed=seed,
        ),
    ]


def run_phase_for_ctx(ctx_len: int, ctx_tag: str):
    """
    Run 8k/16k/32k baselines for all pure SSM models on all datasets.

    ctx_len: 8192, 16384, or 32768
    ctx_tag: "8k", "16k", or "32k"

    This function is *idempotent*: if a log file already exists for a given
    (dataset, model, ctx_tag), that run is skipped.
    """
    configs = make_configs_for_ctx(ctx_len)
    datasets = load_datasets_for_ctx(ctx_tag)

    all_stats = []

    for cfg in configs:
        print("\n" + "=" * 80)
        print(f"PHASE | pure SSM | ctx={ctx_len} ({ctx_tag}) | model={cfg.model_key}")
        print("=" * 80)

        for dataset_name, prompts in datasets.items():
            log_name = f"{dataset_name}_{ctx_tag}_{cfg.model_key}"
            log_path = LOG_DIR_ROOT / f"{log_name}.jsonl"

            # NEW: skip heavy run if the log already exists
            if log_path.exists():
                print(
                    f"[SKIP] Found existing log for "
                    f"dataset={dataset_name}, model={cfg.model_key}, ctx={ctx_tag}:"
                    f"\n       {log_path}"
                )
                continue

            prompts_filtered = filter_prompts_to_ctx(prompts, cfg, safety_margin=32)

            if not prompts_filtered:
                print(
                    f"[WARN] No prompts left after filtering for "
                    f"{dataset_name} @ {ctx_tag} ({cfg.model_key}); skipping."
                )
                continue

            tag = f"{dataset_name}_{ctx_tag}.{cfg.model_key}"

            if dataset_name in {"longbench_v2", "ada_bestanswer"}:
                sampling = LB_ALEVAL_SAMPLING
            else:
                sampling = DEFAULT_SAMPLING

            print(
                f"\nRunning dataset={dataset_name} | "
                f"original_prompts={len(prompts)} | "
                f"filtered_prompts={len(prompts_filtered)} | "
                f"tag={tag}"
            )

            stats, outputs, log_path_returned = run_and_log(
                cfg=cfg,
                prompts=prompts_filtered,
                tag=tag,
                log_name=log_name,
                log_dir_root=LOG_DIR_ROOT,
                sampling_params=sampling,
                batch_size=1,
            )

            stats_record = {
                "model_key": cfg.model_key,
                "dataset": dataset_name,
                "ctx_len": cfg.context_len,
                "seed": cfg.seed,
                "tokens_per_s": stats["tokens_per_s"],
                "peak_vram_gb": stats["peak_vram_gb"],
                "total_time_s": stats["total_time_s"],
                "num_prompts": stats["num_prompts"],
                "log_path": str(log_path_returned),
            }
            all_stats.append(stats_record)

            print("  tokens/s      :", stats["tokens_per_s"])
            print("  peak VRAM (GB):", stats["peak_vram_gb"])
            print("  total time (s):", stats["total_time_s"])
            print("  log file      :", log_path_returned)

    return all_stats

# Heavy but reproducible: will only run missing combinations.
phase_stats_128k  = run_phase_for_ctx(8192,  "128k")


Log dir root: /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/outputs/pure_ssm_logs
Loaded 500 records from /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/data/prompt_sets/longbench_v2/lbv2_128k.jsonl
Loaded 500 records from /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/data/prompt_sets/ada_leval/ada_bestanswer_128k.jsonl
Loaded 34 records from /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/data/prompt_sets/pg19/pg19_128k.jsonl
[load_datasets_for_ctx] tag=128k
  LongBench v2 prompts   : 500
  Ada BestAnswer prompts : 500
  PG-19 chunks           : 34

PHASE | pure SSM | ctx=8192 (128k) | model=mamba-2.8b
[filter_prompts_to_ctx] mamba-2.8b @ 8192 => kept 0 prompts, dropped 500 (hard_max=8032)
[WARN] No prompts left after filtering for longbench_v2 @ 128k (mamba-2.8b); skipping.
[filter_prompts_to_ctx] mamba-2.8b @ 8192 => kept 57 prompts, dropped 443 (hard_max=8032)

Running dataset=ada_bestanswer | original_prompts=500 | filtered_prompts=57 | tag=ada_bestanswer_128k.mamba-2

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


INFO 12-15 09:24:29 [model.py:547] Resolved architecture: MambaForCausalLM


`torch_dtype` is deprecated! Use `dtype` instead!


INFO 12-15 09:24:29 [model.py:1730] Downcasting torch.float32 to torch.bfloat16.
INFO 12-15 09:24:29 [model.py:1510] Using max model len 8192


2025-12-15 09:24:32,674	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


INFO 12-15 09:24:32 [scheduler.py:205] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 12-15 09:24:32 [config.py:297] Hybrid or mamba-based model detected: disabling prefix caching since it is not yet supported.
INFO 12-15 09:24:32 [config.py:308] Hybrid or mamba-based model detected: setting cudagraph mode to FULL_AND_PIECEWISE in order to optimize performance.
INFO 12-15 09:24:32 [__init__.py:381] Cudagraph is disabled under eager mode
[1;36m(EngineCore_DP0 pid=841149)[0;0m INFO 12-15 09:24:32 [core.py:644] Waiting for init message from front-end.
[1;36m(EngineCore_DP0 pid=841149)[0;0m INFO 12-15 09:24:32 [core.py:77] Initializing a V1 LLM engine (v0.11.0) with config: model='state-spaces/mamba-2.8b-hf', speculative_config=None, tokenizer='state-spaces/mamba-2.8b-hf', skip_tokenizer_init=False, tokenizer_mode=auto, revision=main, tokenizer_revision=main, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_p

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[1;36m(EngineCore_DP0 pid=841149)[0;0m INFO 12-15 09:24:36 [parallel_state.py:1208] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[1;36m(EngineCore_DP0 pid=841149)[0;0m INFO 12-15 09:24:37 [topk_topp_sampler.py:55] Using FlashInfer for top-p & top-k sampling.
[1;36m(EngineCore_DP0 pid=841149)[0;0m INFO 12-15 09:24:37 [gpu_model_runner.py:2602] Starting to load model state-spaces/mamba-2.8b-hf...
[1;36m(Engine

Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:03<00:07,  3.94s/it]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:04<00:02,  2.19s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:08<00:00,  2.98s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:08<00:00,  2.95s/it]
[1;36m(EngineCore_DP0 pid=841149)[0;0m 


[1;36m(EngineCore_DP0 pid=841149)[0;0m INFO 12-15 09:24:46 [default_loader.py:267] Loading weights took 8.93 seconds
[1;36m(EngineCore_DP0 pid=841149)[0;0m INFO 12-15 09:24:47 [gpu_model_runner.py:2653] Model loading took 5.2347 GiB and 9.222777 seconds
[1;36m(EngineCore_DP0 pid=841149)[0;0m INFO 12-15 09:24:50 [gpu_worker.py:298] Available KV cache memory: 36.99 GiB
[1;36m(EngineCore_DP0 pid=841149)[0;0m INFO 12-15 09:24:50 [kv_cache_utils.py:1087] GPU KV cache size: 26,124,288 tokens
[1;36m(EngineCore_DP0 pid=841149)[0;0m INFO 12-15 09:24:50 [kv_cache_utils.py:1091] Maximum concurrency for 8,192 tokens per request: 3189.00x
[1;36m(EngineCore_DP0 pid=841149)[0;0m INFO 12-15 09:24:50 [core.py:210] init engine (profile, create kv cache, warmup model) took 3.01 seconds
[1;36m(EngineCore_DP0 pid=841149)[0;0m INFO 12-15 09:24:50 [__init__.py:381] Cudagraph is disabled under eager mode
INFO 12-15 09:24:51 [llm.py:306] Supported_tasks: ['generate']


Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 45.85it/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.18s/it, est. speed input: 1891.96 toks/s, output: 7.66 toks/s]
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 64.75it/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.85s/it, est. speed input: 4138.00 toks/s, output: 17.30 toks/s]
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 70.96it/s]
Processed promp

Logged 57 samples to /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/outputs/pure_ssm_logs/pure_ssm_logs/ada_bestanswer_128k_mamba-2.8b.jsonl
  tokens/s      : 4196.745971714533
  peak VRAM (GB): 42.8212890625
  total time (s): 106.49631953239441
  log file      : /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/outputs/pure_ssm_logs/pure_ssm_logs/ada_bestanswer_128k_mamba-2.8b.jsonl
[filter_prompts_to_ctx] mamba-2.8b @ 8192 => kept 0 prompts, dropped 34 (hard_max=8032)
[WARN] No prompts left after filtering for pg19 @ 128k (mamba-2.8b); skipping.

PHASE | pure SSM | ctx=8192 (128k) | model=mamba-codestral-7b
[filter_prompts_to_ctx] mamba-codestral-7b @ 8192 => kept 0 prompts, dropped 500 (hard_max=8032)
[WARN] No prompts left after filtering for longbench_v2 @ 128k (mamba-codestral-7b); skipping.
[filter_prompts_to_ctx] mamba-codestral-7b @ 8192 => kept 1 prompts, dropped 499 (hard_max=8032)

Running dataset=ada_bestanswer | original_prompts=500 | filtered_prompts=1 | tag=ada_besta

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


INFO 12-15 09:27:56 [model.py:547] Resolved architecture: Mamba2ForCausalLM
INFO 12-15 09:27:56 [model.py:1510] Using max model len 8192
INFO 12-15 09:27:56 [scheduler.py:205] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 12-15 09:27:56 [config.py:297] Hybrid or mamba-based model detected: disabling prefix caching since it is not yet supported.
INFO 12-15 09:27:56 [config.py:308] Hybrid or mamba-based model detected: setting cudagraph mode to FULL_AND_PIECEWISE in order to optimize performance.
INFO 12-15 09:27:56 [__init__.py:381] Cudagraph is disabled under eager mode


  return get_tokenizer(






INFO 12-15 09:27:59 [__init__.py:216] Automatically detected platform cuda.
[1;36m(EngineCore_DP0 pid=842617)[0;0m INFO 12-15 09:28:01 [core.py:644] Waiting for init message from front-end.
[1;36m(EngineCore_DP0 pid=842617)[0;0m INFO 12-15 09:28:01 [core.py:77] Initializing a V1 LLM engine (v0.11.0) with config: model='mistralai/Mamba-Codestral-7B-v0.1', speculative_config=None, tokenizer='mistralai/Mamba-Codestral-7B-v0.1', skip_tokenizer_init=False, tokenizer_mode=auto, revision=main, tokenizer_revision=main, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser=''), observability_

Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:04<00:08,  4.11s/it]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:08<00:04,  4.14s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:12<00:00,  4.24s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:12<00:00,  4.21s/it]
[1;36m(EngineCore_DP0 pid=842617)[0;0m 


[1;36m(EngineCore_DP0 pid=842617)[0;0m INFO 12-15 09:30:04 [default_loader.py:267] Loading weights took 12.71 seconds
[1;36m(EngineCore_DP0 pid=842617)[0;0m INFO 12-15 09:30:04 [gpu_model_runner.py:2653] Model loading took 13.6327 GiB and 120.600976 seconds
[1;36m(EngineCore_DP0 pid=842617)[0;0m INFO 12-15 09:30:07 [gpu_worker.py:298] Available KV cache memory: 27.86 GiB
[1;36m(EngineCore_DP0 pid=842617)[0;0m INFO 12-15 09:30:08 [kv_cache_utils.py:1087] GPU KV cache size: 1,769,472 tokens
[1;36m(EngineCore_DP0 pid=842617)[0;0m INFO 12-15 09:30:08 [kv_cache_utils.py:1091] Maximum concurrency for 8,192 tokens per request: 216.00x
[1;36m(EngineCore_DP0 pid=842617)[0;0m INFO 12-15 09:30:08 [core.py:210] init engine (profile, create kv cache, warmup model) took 3.37 seconds


[1;36m(EngineCore_DP0 pid=842617)[0;0m   return get_tokenizer(


[1;36m(EngineCore_DP0 pid=842617)[0;0m INFO 12-15 09:30:08 [__init__.py:381] Cudagraph is disabled under eager mode
INFO 12-15 09:30:08 [llm.py:306] Supported_tasks: ['generate']


Adding requests: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 100.13it/s]
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:31<00:00, 31.41s/it, est. speed input: 213.21 toks/s, output: 0.16 toks/s]


Logged 1 samples to /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/outputs/pure_ssm_logs/pure_ssm_logs/ada_bestanswer_128k_mamba-codestral-7b.jsonl
  tokens/s      : 213.29208147318863
  peak VRAM (GB): 42.0654296875
  total time (s): 31.421700954437256
  log file      : /insomnia001/home/dwz2107/SSM_experiment/pure_ssm/outputs/pure_ssm_logs/pure_ssm_logs/ada_bestanswer_128k_mamba-codestral-7b.jsonl
[filter_prompts_to_ctx] mamba-codestral-7b @ 8192 => kept 0 prompts, dropped 34 (hard_max=8032)
[WARN] No prompts left after filtering for pg19 @ 128k (mamba-codestral-7b); skipping.


In [18]:
# --- Cell 9: LongBench MC evaluation with file checks ---

LB_MC_ROOT = REPO_ROOT / "longbench_mc_preds"
LB_MC_ROOT.mkdir(exist_ok=True)

LB_MC_COMBOS = [
    ("mamba-2.8b",         8192),
    ("mamba-2.8b",        16384),
    ("mamba-2.8b",        32768),
    ("mamba-codestral-7b", 8192),
    ("mamba-codestral-7b",16384),
    ("mamba-codestral-7b",32768),
]

for model_key, ctx_len in LB_MC_COMBOS:
    pred_path = LB_MC_ROOT / f"lbv2_mc_{model_key}_{ctx_len}.jsonl"
    if pred_path.exists():
        print(f"[LB-MC] Using existing predictions for {model_key} @ {ctx_len}: {pred_path}")
        continue

    print(f"[LB-MC] Running eval_longbench_mc.py for {model_key} @ {ctx_len}...")
    subprocess.run(
        [
            "python",
            "eval_longbench_mc.py",
            "--model-key", model_key,
            "--ctx", str(ctx_len),
        ],
        cwd=REPO_ROOT,
        check=True,
    )

[LB-MC] Running eval_longbench_mc.py for mamba-2.8b @ 8192...




INFO 12-15 09:35:43 [__init__.py:216] Automatically detected platform cuda.

=== LongBench MC eval | model_key=mamba-2.8b @ ctx_len=8192 (num_raw_prompts=499) ===


Traceback (most recent call last):
  File "/insomnia001/home/dwz2107/SSM_experiment/pure_ssm/eval_longbench_mc.py", line 205, in <module>
    main()
  File "/insomnia001/home/dwz2107/SSM_experiment/pure_ssm/eval_longbench_mc.py", line 201, in main
    run_eval_for_ctx(args.model_key, ctx_len)
  File "/insomnia001/home/dwz2107/SSM_experiment/pure_ssm/eval_longbench_mc.py", line 119, in run_eval_for_ctx
    cfg = find_run_config(model_key, ctx_len)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/insomnia001/home/dwz2107/SSM_experiment/pure_ssm/eval_longbench_mc.py", line 85, in find_run_config
    raise ValueError(
ValueError: No RunConfig found for model_key=mamba-2.8b, ctx_len=8192.


CalledProcessError: Command '['python', 'eval_longbench_mc.py', '--model-key', 'mamba-2.8b', '--ctx', '8192']' returned non-zero exit status 1.

In [None]:
# --- Cell 10: Aggregate 8k accuracy + efficiency into results/ssm_8k.csv ---

LOG_ROOT    = OUTPUT_ROOT / "pure_ssm_logs"          # pure_ssm_logs/*.jsonl
RESULTS_DIR = REPO_ROOT / "results"
RESULTS_DIR.mkdir(exist_ok=True)

print("REPO_ROOT :", REPO_ROOT)
print("LOG_ROOT  :", LOG_ROOT)
print("RESULTS   :", RESULTS_DIR)


def load_lb_accuracy(pred_path: Path) -> float:
    """
    Compute accuracy from a LongBench MC prediction JSONL file.

    Each line is expected to have keys:
      ["model_key", "context_len", "prompt_idx", "gold", "pred"]
    """
    correct = 0
    total = 0
    with pred_path.open() as f:
        for line in f:
            row = json.loads(line)
            total += 1
            if row["pred"] == row["gold"]:
                correct += 1
    return correct / total if total > 0 else 0.0


def load_efficiency_from_log(path: Path) -> dict:
    """
    Read the FIRST line of a pure_ssm_logs JSONL file and return
    a small dict with model+ctx+dataset + throughput / VRAM / num_prompts.

    The log row is produced by runner_ssm.log_result / profile_run, e.g.:

        {
          "tag": "longbench_v2_8k.mamba-2.8b",
          "model_key": "mamba-2.8b",
          "context_len": 8192,
          "tokens_per_s": ...,
          "peak_vram_gb": ...,
          "num_prompts": ...,
          ...
        }
    """
    with path.open() as f:
        first = json.loads(next(f))

    tag = first["tag"]             # e.g. "longbench_v2_8k.mamba-2.8b"
    head = tag.split(".")[0]       # "longbench_v2_8k"
    dataset_name = head.rsplit("_", 1)[0]  # "longbench_v2"

    return {
        "model":        first["model_key"],
        "ctx":          first["context_len"],
        "dataset":      dataset_name,
        "tokens_per_s": first["tokens_per_s"],
        "peak_vram_gb": first["peak_vram_gb"],
        "num_prompts":  first["num_prompts"],
    }


# 1) Accuracy from MC prediction files @ 8k
mamba_mc_path     = LB_MC_ROOT / "lbv2_mc_mamba-2.8b_8192.jsonl"
codestral_mc_path = LB_MC_ROOT / "lbv2_mc_mamba-codestral-7b_8192.jsonl"

acc_mamba_lb_8k     = load_lb_accuracy(mamba_mc_path)
acc_codestral_lb_8k = load_lb_accuracy(codestral_mc_path)

print("Mamba-2.8B    @ 8k LongBench MC acc:", acc_mamba_lb_8k)
print("Codestral-7B  @ 8k LongBench MC acc:", acc_codestral_lb_8k)

# 2) Efficiency stats from pure_ssm_logs (official 8k runs)
eff_mamba_lb_8k = load_efficiency_from_log(
    LOG_ROOT / "longbench_v2_8k_mamba-2.8b.jsonl"
)
eff_codestral_lb_8k = load_efficiency_from_log(
    LOG_ROOT / "longbench_v2_8k_mamba-codestral-7b.jsonl"
)

print("\nMamba-2.8B   @ 8k efficiency:", eff_mamba_lb_8k)
print("Codestral-7B @ 8k efficiency:", eff_codestral_lb_8k)


# 3) Build ssm_8k.csv with quality + efficiency metrics
rows = []

# ---- LongBench MC accuracy ----
rows.append({
    "model":   "mamba-2.8b",
    "ctx":     8192,
    "dataset": "longbench_v2_mc",
    "metric":  "accuracy",
    "value":   acc_mamba_lb_8k,
})

rows.append({
    "model":   "mamba-codestral-7b",
    "ctx":     8192,
    "dataset": "longbench_v2_mc",
    "metric":  "accuracy",
    "value":   acc_codestral_lb_8k,
})


# ---- LongBench 8k efficiency (tokens/s, VRAM, num_prompts) ----
def add_eff_rows(eff: dict):
    rows.append({
        "model":   eff["model"],
        "ctx":     eff["ctx"],
        "dataset": eff["dataset"],
        "metric":  "tokens_per_s",
        "value":   eff["tokens_per_s"],
    })
    rows.append({
        "model":   eff["model"],
        "ctx":     eff["ctx"],
        "dataset": eff["dataset"],
        "metric":  "peak_vram_gb",
        "value":   eff["peak_vram_gb"],
    })
    rows.append({
        "model":   eff["model"],
        "ctx":     eff["ctx"],
        "dataset": eff["dataset"],
        "metric":  "num_prompts",
        "value":   eff["num_prompts"],
    })

add_eff_rows(eff_mamba_lb_8k)
add_eff_rows(eff_codestral_lb_8k)

df_8k = pd.DataFrame(rows)

out_path = RESULTS_DIR / "ssm_8k.csv"
if out_path.exists():
    print("results/ssm_8k.csv already exists; overwriting with freshly computed values.")
df_8k.to_csv(out_path, index=False)

print("Wrote:", out_path)
df_8k.head()