## Display results

In [6]:
# =========================================
# Full-tasks table (display-only, no saving)
# =========================================
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, List, Optional
import math
import re

import pandas as pd
from IPython.display import display

# Render every float with two decimals when pandas objects are displayed.
# This is a process-wide pandas option, so it also affects any later display.
pd.options.display.float_format = "{:.2f}".format

# -------------------
# Config / Constants
# -------------------

# Tasks to keep. A CSV column named 'task/metric' is selected only when its
# 'task' part (the text before the first '/') is listed here, and the position
# of a task in this list decides where its rows appear in the final table.
TASKS = [
    "hellaswag",
    "lambada_openai",
    "mmlu",
    "piqa",
    "wikitext",
    "winogrande",
    "commonsense_qa",
    "pubmedqa",
    "race",
    "sciq",
    "wsc273",
    "xnli",
]

# Recognized model tokens used to normalize raw run names. A run is mapped to
# a token when the token occurs as a substring of the lower-cased run name
# (the longest matching token wins). The order here defines the column order.
MODEL_TOKENS = [
    "gpt2",
    "mha",
    "mla192-96-0",
    "mla192-96-192",
    "mla0-96-192",
    "mla0-0-192",
    "mla0-128-0",
    "mla192-0-0",
    "mla0-0-0",
]


# -------------------
# Helpers
# -------------------

def load_wandb_export(csv_path: str | Path) -> pd.DataFrame:
    """Load the W&B CSV export."""
    return pd.read_csv(csv_path)

def find_model_column(df: pd.DataFrame) -> Optional[str]:
    """
    Heuristically find the column that holds the model/run name.

    A column qualifies when at least one of its values, converted to a string,
    contains 'full-tasks' (case-insensitive). If several columns qualify, the
    one with the most matching values is chosen; on a tie the left-most column
    wins.

    Args:
        df: Frame to inspect.

    Returns:
        The label of the best-matching column, or None when no column contains
        'full-tasks' at all (including when the frame has no columns).
    """
    best_col: Optional[str] = None
    best_hits = 0
    for col in df.columns:
        # Scan each column exactly once; the hit count answers both
        # "does it match at all?" and "how well does it match?".
        hits = int(
            df[col].astype(str).str.contains("full-tasks", case=False, na=False).sum()
        )
        # Strict '>' keeps the earliest column when counts are equal.
        if hits > best_hits:
            best_col, best_hits = col, hits
    return best_col

def normalize_model_name(text: Any) -> Optional[str]:
    """
    Reduce a raw model/run string to one of the canonical MODEL_TOKENS.

    The input is stringified and lower-cased, then every token that occurs as
    a substring is collected. The longest such token is returned; when several
    share the maximum length, the one listed first in MODEL_TOKENS is kept.
    Returns None when no token occurs in the text.
    """
    lowered = str(text).lower()
    hits = [token for token in MODEL_TOKENS if token in lowered]
    # max() yields the first maximal element, so ties follow MODEL_TOKENS order.
    return max(hits, key=len) if hits else None

def select_task_metric_columns(df: pd.DataFrame, tasks: List[str]) -> List[str]:
    """
    Return the columns named 'task/metric' whose task is one we care about.

    The task is the text before the first '/' in the column name; everything
    after it is treated as the metric and is not inspected.

    Args:
        df: Frame whose column labels are scanned.
        tasks: Task names to keep.

    Returns:
        Matching column names, in the order they appear in ``df.columns``.
        Column labels that are not strings are skipped rather than raising.
    """
    wanted = set(tasks)  # O(1) membership instead of a list scan per column
    selected: List[str] = []
    for col in df.columns:
        # Non-string labels (e.g. integer column names) cannot be 'task/metric'.
        if not isinstance(col, str):
            continue
        task, sep, _ = col.partition("/")
        # An empty separator means the name had no '/' at all.
        if sep and task in wanted:
            selected.append(col)
    return selected

def pretty_row_label(col_name: str) -> str:
    """
    Turn a 'task/metric' column name into the row label 'task (metric)'.

    Only the first '/' separates the two parts, so a metric that itself
    contains '/' is kept whole inside the parentheses. A name without any '/'
    cannot be unpacked and raises ValueError.
    """
    task_part, metric_part = col_name.split("/", 1)
    return "{} ({})".format(task_part, metric_part)

import re

def _parse_label(label: str):
    """
    Split a row label of the form 'task (metric)' into ('task', 'metric').

    The label is stringified first. The metric is the final parenthesised
    group, which must not contain parentheses itself; whitespace between the
    task and the opening parenthesis, and after the closing one, is dropped.
    Labels that do not fit this shape come back as (label_text, "").
    """
    text = str(label)
    found = re.match(r"^(.*?)\s*\(([^()]*)\)\s*$", text)
    if found is None:
        return text, ""
    # The pattern has exactly two groups: the task and the metric.
    return found.groups()

def scale_acc_like(series: pd.Series) -> pd.Series:
    """
    Convert an accuracy row from a fraction to a percentage.

    The row's metric is read from ``series.name`` (expected to look like
    'task (metric)'). For 'acc' and 'acc_norm' the values are coerced to
    numbers (unparseable entries become NaN), multiplied by 100 and rounded to
    2 d.p. Every other row is handed back untouched.
    """
    _, metric_name = _parse_label(series.name)
    if metric_name not in {"acc", "acc_norm"}:
        return series
    as_numbers = pd.to_numeric(series, errors="coerce")
    return as_numbers.mul(100).round(2)



# -------------------
# Core builder
# -------------------

def build_full_tasks_dataframe(csv_path: str | Path) -> pd.DataFrame:
    """
    Assemble the results table shown in the notebook.

    Layout of the returned frame:
      - Columns: one per model token found among the 'full-tasks' runs,
        ordered as in MODEL_TOKENS
      - Rows: one per selected 'task/metric' column, labelled 'task (metric)',
        ordered by the task's position in TASKS and then by metric name
      - Rows whose metric is 'acc' or 'acc_norm' are scaled to percentages
        (2 d.p.)

    When several 'full-tasks' rows map to the same model token, the last of
    them (in export order) supplies the values.

    Raises:
        RuntimeError: if no column mentions 'full-tasks', if no row does, if
            none of those rows carries a recognizable model token, or if no
            'task/metric' column exists for the requested TASKS.
    """
    raw = load_wandb_export(csv_path)

    # Locate the column that names the runs.
    name_col = find_model_column(raw)
    if name_col is None:
        raise RuntimeError("Could not find any column containing 'full-tasks' to identify model rows.")

    # Restrict to the 'full-tasks' runs; copy so a helper column can be added.
    is_full_tasks = raw[name_col].astype(str).str.contains("full-tasks", case=False, na=False)
    runs = raw[is_full_tasks].copy()
    if runs.empty:
        raise RuntimeError("No rows with 'full-tasks' found in the CSV.")

    # Attach the canonical model token and drop runs that have none.
    runs["__model__"] = runs[name_col].apply(normalize_model_name)
    runs = runs.dropna(subset=["__model__"])
    if runs.empty:
        raise RuntimeError("No recognizable model tokens among 'full-tasks' rows.")

    # Metric columns belonging to the tasks of interest.
    wanted_cols = select_task_metric_columns(runs, TASKS)
    if not wanted_cols:
        raise RuntimeError("No columns of the form 'task/metric' for the requested TASKS were found.")

    def last_run_metrics(group: pd.DataFrame) -> pd.Series:
        # Last row of the group, limited to the metric columns, relabelled
        # from 'task/metric' to 'task (metric)'.
        picked = group.iloc[-1][wanted_cols]
        picked.index = [pretty_row_label(c) for c in picked.index]
        return picked

    # One column per model; sort=False keeps the groups in first-seen order.
    table = pd.DataFrame(
        {
            token: last_run_metrics(group)
            for token, group in runs.groupby("__model__", sort=False)
        }
    )

    # Row-wise: accuracy-like rows become percentages.
    table = table.apply(scale_acc_like, axis=1)

    # Order rows by task position in TASKS (unknown tasks last), then metric.
    task_rank = {name: pos for pos, name in enumerate(TASKS)}

    def sort_key(row_label: str):
        task_name, metric_name = _parse_label(row_label)
        return task_rank.get(task_name, 999), metric_name

    table = table.reindex(sorted(table.index, key=sort_key))

    # Order columns by MODEL_TOKENS, then any others.
    recognised = [tok for tok in MODEL_TOKENS if tok in table.columns]
    leftovers = [col for col in table.columns if col not in MODEL_TOKENS]
    return table[recognised + leftovers]


# -------------------
# Example usage
# -------------------
# Set your CSV path (W&B export)
# Path to the W&B CSV export; relative paths are resolved against the
# notebook's current working directory.
csv_path = "./wandb_export_2025-10-03T09_56_38.127-05_00.csv"  # <- change if needed

# Build the model-by-task table; raises RuntimeError when the export has no
# usable 'full-tasks' rows or no matching 'task/metric' columns.
df_results = build_full_tasks_dataframe(csv_path)

# Render the table as the cell output (nothing is written to disk).
display(df_results)


Unnamed: 0,gpt2,mha,mla192-96-0,mla192-96-192,mla0-96-192,mla0-0-192,mla0-128-0,mla192-0-0,mla0-0-0
hellaswag (acc),28.92,27.11,27.18,27.01,27.21,27.14,27.05,27.14,27.11
lambada_openai (acc),32.56,17.25,13.06,11.08,11.62,15.02,13.89,15.82,17.25
mmlu (acc),22.92,22.95,22.95,22.94,22.92,22.92,22.94,22.95,22.95
piqa (acc),62.89,61.32,61.43,59.63,60.01,61.37,60.5,61.21,61.32
wikitext (word_perplexity),37.37,80.63,89.5,96.61,92.87,82.44,85.11,81.84,80.63
winogrande (acc),51.62,53.28,50.75,51.38,52.57,51.46,50.59,49.96,53.28
commonsense_qa (acc),19.57,19.57,19.57,19.49,19.57,19.57,19.57,19.57,19.57
pubmedqa (acc),45.4,38.2,38.6,36.6,34.4,39.2,35.0,37.6,38.2
race (acc),29.47,25.36,23.83,24.4,24.11,24.02,25.74,26.32,25.36
sciq (acc),75.2,56.2,56.9,54.1,51.9,57.5,55.3,57.9,56.2
