In [2]:
import json

# Map each trace-set name to the JSONL file that holds its game traces.
# Every file sits under traces/ and is named after its key, so the paths
# are derived from the names; insertion order follows the tuple below.
TRACE_FILES = {
    trace_set: f"traces/{trace_set}.jsonl"
    for trace_set in (
        "only_grpo",
        "only_sft",
        "sft_teacher",
        "sft_teacher_masked",
        "grpo_on_sft_masked_teach",
    )
}

def load_jsonl(path):
    """Read a JSON Lines file and return its decoded records as a list.

    The file is opened as UTF-8 text. Each line is stripped of surrounding
    whitespace; lines that are empty after stripping are skipped, and every
    remaining line is decoded with ``json.loads``.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]

# Load every trace file up front, keyed by the same names as TRACE_FILES.
traces = {trace_set: load_jsonl(file_path) for trace_set, file_path in TRACE_FILES.items()}
# Last expression of the cell: how many game records each trace set holds.
{trace_set: len(games) for trace_set, games in traces.items()}


{'only_grpo': 5,
 'only_sft': 5,
 'sft_teacher': 5,
 'sft_teacher_masked': 5,
 'grpo_on_sft_masked_teach': 5}

In [3]:
def format_game_trace(game) -> str:
    """Render one game record as a plain-text, turn-by-turn transcript.

    Args:
        game: Mapping read with ``.get``. Keys used:
            ``"target"`` - shown in the header (``"?"`` when absent);
            ``"solved"`` - when falsy or absent a final ``→ FAILED`` line
            is appended;
            ``"steps"`` - iterable of per-turn mappings; a missing key or
            an explicit ``None`` (JSON ``null``) is treated as "no turns".
            Each step may carry ``"turn"``, ``"guess"``, ``"status"`` and
            ``"feedback"``.

    Returns:
        The transcript as one newline-joined string.
    """
    # Statuses that get a visible marker after the guess; any other status
    # (including a missing one) is printed without a suffix.
    status_suffix = {
        "oov": "   \u2190 OOV",
        "solved": "   \u2190 SOLVED",
        "format_fail": "   \u2190 PARSE FAIL",
    }

    # Header (kept minimal)
    lines = [f"Target: {game.get('target', '?')}", ""]

    # `or []` rather than a .get default: a record with "steps": null would
    # otherwise raise TypeError, while null guess/status are already tolerated.
    for step in game.get("steps") or []:
        turn = step.get("turn", "?")
        guess = step.get("guess") or "-----"  # placeholder when no guess was parsed
        status = (step.get("status") or "").lower()
        feedback = step.get("feedback", "")

        lines.append(f"Turn {turn}: {guess}{status_suffix.get(status, '')}")
        if feedback:
            lines.append(f"Feedback: {feedback}")
        lines.append("")  # blank line between turns

    if not game.get("solved", False):
        lines.append("→ FAILED")

    return "\n".join(lines)


def show_trace(game, show_raw=False, raw_max_chars=400):
    """Print a game's formatted transcript, optionally followed by raw outputs.

    Args:
        game: Game mapping; handed to ``format_game_trace`` and, when
            ``show_raw`` is set, its ``"steps"`` entries are read as well.
        show_raw: When true, also print each step's ``"raw_model_output"``.
        raw_max_chars: Raw outputs longer than this many characters are cut
            to that length and suffixed with an ellipsis.
    """
    print(format_game_trace(game))

    if not show_raw:
        return

    print("\n" + "-" * 60)
    print("Raw model outputs (truncated):")
    for step in game.get("steps", []):
        # A missing or null raw output is shown as an empty string.
        output = step.get("raw_model_output", "") or ""
        ellipsis = "…" if len(output) > raw_max_chars else ""
        clipped = output[:raw_max_chars] + ellipsis
        print(f"\n[Turn {step.get('turn')} | status={step.get('status')} | guess={step.get('guess')}]")
        print(clipped)


In [10]:
from IPython.display import clear_output

# Help text printed by the trace browser's 'h' command. It lists the short
# form of each command, including the bare-number jump the browser accepts.
HELP = """Commands
  q                 quit
  m                 back to model menu
  n / p             next / previous trace
  r                 toggle raw output (on/off)
  g <idx>           go to trace index, e.g.  g 12
  <idx>             shortcut for g <idx>, e.g.  12
                    (out-of-range indices are clamped)
  h                 show this help
"""

def trace_browser(traces_dict):
    """Run an interactive, ``input()``-driven text browser over game traces.

    A numbered menu of the keys of ``traces_dict`` is shown first. Once a
    model is chosen, each iteration clears the output, prints a header and
    the current trace, then reads one command: quit, back to the menu,
    next/previous, toggle raw output, help, or jump to a trace index
    (``g <idx>`` or a bare number). Requested indices are clamped to the
    valid range rather than rejected. A model with no traces cannot be
    selected, because there would be nothing to display.

    Args:
        traces_dict: Mapping of model name -> sequence of game records; each
            record is passed to ``show_trace``.

    Returns:
        None, once the user quits.
    """
    model_names = list(traces_dict.keys())
    selected_model = None
    idx = 0
    show_raw = False

    def clamp_index(requested):
        """Clamp a requested trace index into the selected model's range."""
        return max(0, min(requested, len(traces_dict[selected_model]) - 1))

    def print_model_menu():
        print("=== Wordle Trace Browser ===\n")
        print("Models:")
        for i, name in enumerate(model_names):
            print(f"  [{i}] {name}  (n={len(traces_dict[name])})")
        print("\nPick a model number, or type 'q' to quit.\n")

    def print_header():
        print("=== Wordle Trace Browser ===")
        print(f"Model: {selected_model} | Trace: {idx}/{len(traces_dict[selected_model])-1} | raw={'ON' if show_raw else 'OFF'}")
        print("-" * 60)
        print("Type 'h' for help. ('m' model menu, 'q' quit)\n")

    while True:
        # If no model selected, show model menu
        while selected_model is None:
            clear_output(wait=True)
            print_model_menu()
            s = input("Select model> ").strip().lower()
            if s == "q":
                return
            # isdecimal() instead of isdigit(): isdigit() also accepts
            # characters such as superscript digits that int() rejects.
            if s.isdecimal() and int(s) in range(len(model_names)):
                candidate = model_names[int(s)]
                if not traces_dict[candidate]:
                    # Indexing into an empty trace list would raise IndexError.
                    input(f"Model '{candidate}' has no traces. Press Enter to continue...")
                    continue
                selected_model = candidate
                idx = 0
                break
            input("Invalid selection. Press Enter to continue...")

        # Main trace view loop
        clear_output(wait=True)
        print_header()
        show_trace(traces_dict[selected_model][idx], show_raw=show_raw)

        cmd = input("\nCommand> ").strip()
        c = cmd.lower()

        if c in ("q", "quit", "exit"):
            return

        if c in ("m", "menu"):
            selected_model = None
            continue

        if c in ("h", "help", "?"):
            clear_output(wait=True)
            print_header()
            print(HELP)
            input("\nPress Enter to continue...")
            continue

        if c in ("n", "next"):
            idx = clamp_index(idx + 1)
            continue

        if c in ("p", "prev", "previous"):
            idx = clamp_index(idx - 1)
            continue

        if c in ("r", "raw"):
            show_raw = not show_raw
            continue

        # "g 12" go to index; a bare "g" gets the usage hint instead of
        # falling through to "Unknown command".
        if c == "g" or c.startswith("g "):
            parts = c.split()
            if len(parts) == 2 and parts[1].isdecimal():
                idx = clamp_index(int(parts[1]))
                continue
            input("Usage: g <idx>   (example: g 12). Press Enter...")
            continue

        # Allow entering a number directly to jump
        if c.isdecimal():
            idx = clamp_index(int(c))
            continue

        input("Unknown command. Type 'h' for help. Press Enter...")

# Launch the interactive browser over all loaded trace sets; this cell
# blocks on input() prompts until the user quits with 'q'.
trace_browser(traces)


=== Wordle Trace Browser ===
Model: only_grpo | Trace: 4/4 | raw=OFF
------------------------------------------------------------
Type 'h' for help. ('m' model menu, 'q' quit)

Target: MANTA

Turn 1: CANDY
Feedback: C(x) A(✓) N(✓) D(x) Y(x)

Turn 2: GARDE   ← OOV
Feedback: G(x) A(✓) R(x) D(x) E(x)

Turn 3: GARDN   ← OOV
Feedback: G(x) A(✓) R(x) D(x) N(-)

Turn 4: GARND   ← OOV
Feedback: G(x) A(✓) R(x) N(-) D(x)

Turn 5: GARND   ← OOV
Feedback: G(x) A(✓) R(x) N(-) D(x)

Turn 6: GARND   ← OOV
Feedback: G(x) A(✓) R(x) N(-) D(x)

→ FAILED
