# Fail Rate From Running Logs

This notebook scans your training stdout/stderr logs (or W&B `wandb-history.jsonl`) and computes fail rate from the `risk/fail` metric.

By default, it treats `risk/fail == 0` as a failure (set `FAIL_VALUE = 1` if your runs use `1` to mark a failure).


## Plot Recommendation

- Rolling fail rate over time: line plot of a moving average of the failure indicator (`risk/fail == FAIL_VALUE`), in the order the entries appear in the logs (best for seeing stability/regressions during training).
- Per-run/per-file comparison: bar chart of fail rate grouped by log file (best for comparing datasets/experiments).


In [None]:
from __future__ import annotations

import json
import os
import re
from dataclasses import dataclass
from glob import glob
from pathlib import Path
from typing import Iterable, Iterator

import matplotlib.pyplot as plt


# Value of `risk/fail` that is counted as a "failure".
# The default is 0, i.e. an entry `risk/fail: 0` counts as one failure.
# If in your runs `risk/fail: 1` means failure, set FAIL_VALUE = 1.
FAIL_VALUE = 0


@dataclass(frozen=True)
class FailRecord:
    """One `risk/fail` entry found while scanning a log file."""

    # Path of the log file the entry was read from.
    path: str
    # Line number of the entry within that file (counting starts at 1).
    line_no: int
    # Parsed `risk/fail` value, converted to an integer.
    value: int
    # Text of the matching line with its trailing newline removed.
    line: str


# Matches a `risk/fail` key followed by `:` or `=` and a number; group 1 is the
# number as written in the log.
# - The surrounding quotes are optional, so JSON (`"risk/fail": 0`), printed
#   dicts (`'risk/fail': 0`) and plain text (`risk/fail: 0`, `risk/fail=0`)
#   are all recognised.
# - The lookbehind rejects longer keys that merely end in `risk/fail`
#   (e.g. `eval/risk/fail`), which matters once the opening quote is optional.
# - The number may carry a fraction and/or an exponent, so `1e-3` is captured
#   whole instead of being cut off after the leading `1`.
_RISK_FAIL_RE = re.compile(
    r"(?<![\w/])['\"]?risk/fail['\"]?\s*[:=]\s*"
    r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)"
)


def _iter_text_files(paths: Iterable[str]) -> Iterator[tuple[str, Iterator[str]]]:
    """Yield `(path, open_file)` for each entry of *paths* that can be read.

    Empty strings, directories and paths that do not exist are skipped
    silently. Every remaining path is opened as UTF-8 text with undecodable
    bytes replaced. The handle is handed to the caller still open; closing it
    is the caller's job.
    """
    for candidate in paths:
        usable = (
            bool(candidate)
            and not os.path.isdir(candidate)
            and os.path.exists(candidate)
        )
        if usable:
            handle = Path(candidate).open("r", encoding="utf-8", errors="replace")
            yield candidate, handle


def _coerce_risk_fail(raw) -> int | None:
    """Convert a raw `risk/fail` value to `int`, or return `None` if it is not numeric."""
    try:
        # NOTE(review): int() truncates toward zero, so a fractional value such
        # as 0.9 becomes 0 - confirm that `risk/fail` is only ever logged as 0/1.
        return int(float(raw))
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / list / dict; ValueError: non-numeric text or NaN;
        # OverflowError: infinity.
        return None


def _extract_risk_fail_value(line: str) -> int | None:
    """Extract the `risk/fail` value from one log line.

    Two strategies are tried in order:

    1. `_RISK_FAIL_RE` is searched against the raw line; when it matches, its
       first capture group is the value.
    2. Otherwise, if the stripped line starts with `{` and ends with `}`, it
       is parsed as JSON and the `"risk/fail"` key of the resulting object is
       read. This covers JSON lines the regex does not match, such as boolean
       values or an escaped key.

    Args:
        line: One line of log text; surrounding whitespace is tolerated.

    Returns:
        The value converted with `int(float(...))`, or `None` when the line
        has no `risk/fail` entry or the entry is not numeric.
    """
    # Fast path: regex
    m = _RISK_FAIL_RE.search(line)
    if m is not None:
        return _coerce_risk_fail(m.group(1))

    # JSON path (W&B history lines often are valid JSON)
    s = line.strip()
    if not (s.startswith("{") and s.endswith("}")):
        return None
    try:
        obj = json.loads(s)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; RecursionError guards against
        # pathologically nested input.
        return None
    if isinstance(obj, dict) and "risk/fail" in obj:
        return _coerce_risk_fail(obj["risk/fail"])
    return None


def parse_risk_fail_from_logs(paths: Iterable[str]) -> list[FailRecord]:
    """Collect every `risk/fail` entry found in the given log files.

    Files are read in the order `_iter_text_files` yields them and each file
    is scanned top to bottom, so the result keeps that order. Lines for which
    `_extract_risk_fail_value` returns `None` are skipped.

    Args:
        paths: Candidate log file paths.

    Returns:
        One `FailRecord` per matching line, with line numbers starting at 1
        and the trailing newline removed from the stored line text.
    """
    found: list[FailRecord] = []
    for log_path, handle in _iter_text_files(paths):
        # The generator below is fully consumed before the file is closed.
        with handle:
            found.extend(
                FailRecord(
                    path=log_path,
                    line_no=number,
                    value=parsed,
                    line=text.rstrip("\n"),
                )
                for number, text in enumerate(handle, start=1)
                if (parsed := _extract_risk_fail_value(text)) is not None
            )
    return found


In [None]:
# Point this at your captured stdout/stderr log(s).
# You can either list paths explicitly OR use globs.
# When EXPLICIT_LOG_PATHS is non-empty it is used as-is and glob discovery
# is skipped.

# Option A: explicit paths
EXPLICIT_LOG_PATHS: list[str] = [
    # "path/to/your/run.log",
]

# Option B: auto-discover common log locations
# Each pattern is expanded with glob(..., recursive=True), so relative
# patterns are resolved against the current working directory.
# "**/output.log" is already covered by "**/*.log"; paths matched by more than
# one pattern are only kept once during discovery.
LOG_GLOBS: list[str] = [
    "**/*.log",
    "**/*.out",
    "**/output.log",
    "**/wandb-history.jsonl",
]

# Paths to ignore (add your big artifact dirs here if needed)
# Each entry is matched as a plain substring of a discovered path; wrapping a
# directory name in os.sep keeps it from matching parts of other names.
IGNORE_SUBSTRINGS = [
    os.sep + ".git" + os.sep,
    os.sep + ".venv" + os.sep,
]


def discover_log_paths() -> list[str]:
    """Expand `LOG_GLOBS` into a de-duplicated list of candidate log files.

    Every pattern is expanded recursively. Each hit is normalised with
    `str(Path(...))`, then dropped if it was already seen, if it contains one
    of `IGNORE_SUBSTRINGS`, or if it is a directory.

    Returns:
        The surviving paths in first-seen order.
    """
    uniq: list[str] = []
    seen: set[str] = set()
    for pattern in LOG_GLOBS:
        for hit in glob(pattern, recursive=True):
            path = str(Path(hit))
            if path in seen:
                continue
            # Glob results for relative patterns have no leading separator, so
            # an ignore entry shaped like "<sep>name<sep>" would never match a
            # directory at the very start of the path ("name/run.log").
            # Matching against the separator-prefixed path closes that gap.
            anchored = os.sep + path
            if any(ignored in anchored for ignored in IGNORE_SUBSTRINGS):
                continue
            if os.path.isdir(path):
                continue
            seen.add(path)
            uniq.append(path)
    return uniq


# Explicitly listed paths take precedence; discovery only runs when none are given.
log_paths = EXPLICIT_LOG_PATHS if EXPLICIT_LOG_PATHS else discover_log_paths()
print(f"Found {len(log_paths)} candidate log file(s)")

# Show at most the first 30 paths, then summarise how many were left out.
preview = log_paths[:30]
for p in preview:
    print("-", p)
remaining = len(log_paths) - len(preview)
if remaining > 0:
    print(f"... and {remaining} more")


In [None]:
# Parse every candidate log and summarise the overall fail rate.
records = parse_risk_fail_from_logs(log_paths)
total = len(records)

# Nothing to report (and nothing to divide by) without at least one entry.
if not records:
    raise RuntimeError(
        "No 'risk/fail' entries found. Set EXPLICIT_LOG_PATHS to your run log, or adjust LOG_GLOBS."
    )

values = [r.value for r in records]
count0 = values.count(0)
count1 = values.count(1)
count_other = total - (count0 + count1)

failed = values.count(FAIL_VALUE)
fail_rate = failed / total

print(f"Total risk/fail entries: {total}")
print(f"risk/fail == 0: {count0} ({count0/total:.2%})")
print(f"risk/fail == 1: {count1} ({count1/total:.2%})")
if count_other != 0:
    print(f"risk/fail other: {count_other} ({count_other/total:.2%})")
print(f"Fail rate (treating risk/fail == {FAIL_VALUE} as failure): {fail_rate:.2%}")

# Tail of the match list, as a quick sanity check of what was parsed.
print("\nMost recent matches:")
for r in records[-5:]:
    print(f"- {r.path}:{r.line_no}  risk/fail={r.value}")


In [None]:
# Optional: per-file breakdown (handy for multi-dataset runs saved to separate logs)
from collections import defaultdict

# Group the parsed values by the log file they came from.
by_path: dict[str, list[int]] = defaultdict(list)
for r in records:
    by_path[r.path].append(r.value)

# One (fail rate, failures, entries, path) tuple per file, highest rate first.
rows = []
for p, vals in by_path.items():
    n_failed = vals.count(FAIL_VALUE)
    n_total = len(vals)
    rows.append((n_failed / n_total, n_failed, n_total, p))
rows = sorted(rows, reverse=True)

print(f"Per-file fail rate (FAIL_VALUE={FAIL_VALUE}):")
shown_rows = rows[:50]
for rate, f, t, p in shown_rows:
    print(f"- {rate:.2%}  ({f}/{t})  {p}")
if len(rows) > len(shown_rows):
    print(f"... and {len(rows) - 50} more")


In [None]:
# Visualization

def rolling_mean(xs: list[float], window: int) -> list[float]:
    """Return the trailing moving average of *xs*.

    Element `i` of the result is the mean of the last `window` values up to
    and including `xs[i]`; the first `window - 1` elements average over the
    values seen so far. A `window` of 1 or less disables smoothing and returns
    *xs* itself.
    """
    if window <= 1:
        return xs

    averages: list[float] = []
    running_total = 0.0
    for idx, value in enumerate(xs):
        running_total += value
        if idx < window:
            # Still filling the first window: average over what we have.
            span = idx + 1
        else:
            # Drop the value that just slid out of the window.
            running_total -= xs[idx - window]
            span = window
        averages.append(running_total / span)
    return averages


# 1) Rolling fail rate over the sequence of log matches
WINDOW = 50  # adjust for smoothing
# Failure indicator per entry: 1.0 where the value equals FAIL_VALUE, else 0.0.
y = [float(r.value == FAIL_VALUE) for r in records]
x = list(range(1, len(y) + 1))
y_roll = rolling_mean(y, WINDOW)

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(x, y_roll, linewidth=2)
ax.set_ylim(-0.05, 1.05)
ax.set_title(f"Rolling fail rate (window={WINDOW}, FAIL_VALUE={FAIL_VALUE})")
ax.set_xlabel("risk/fail match index (as encountered in logs)")
ax.set_ylabel("fail rate")
ax.grid(True, alpha=0.3)
plt.show()


# 2) Per-file fail rate bar chart
TOP_N = 20
paths = [row[3] for row in rows]
rates = [row[0] for row in rows]

# `rows` is sorted highest rate first; reverse the top slice so the highest
# rate ends up as the top bar of the horizontal chart.
paths_top = list(reversed(paths[:TOP_N]))
rates_top = list(reversed(rates[:TOP_N]))

fig, ax = plt.subplots(figsize=(10, max(3, 0.35 * len(paths_top))))
ax.barh(paths_top, rates_top)
ax.set_xlim(0, 1)
ax.set_title(f"Fail rate by log file (top {min(TOP_N, len(rows))}, FAIL_VALUE={FAIL_VALUE})")
ax.set_xlabel("fail rate")
fig.tight_layout()
plt.show()
