In [13]:
import os
import re
import json
import csv
from collections import defaultdict, Counter

ROOT_DIR = "/scratch/doluk/Compact-Interference-PRAG/PRAG/warmup/output/freezeA/warmup_cot"
OUT_DIR = "metrics_by_rank_epoch"
REQUIRED_PREDICTIONS = 300

os.makedirs(OUT_DIR, exist_ok=True)

# Regex
RANK_ALPHA_RE = re.compile(r"rank=(\d+)_alpha=(\d+)")
LR_EPOCH_RE = re.compile(r"lr=([\d.]+)_epoch=(\d+)")

# -------------------------
# Helpers
# -------------------------

def parse_result_txt(path):
    metrics = {"EM": "", "F1": "", "PREC": "", "RECALL": ""}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split()
            if len(parts) < 2:
                continue
            key = parts[0].lower()
            if key == "em":
                metrics["EM"] = parts[1]
            elif key == "f1":
                metrics["F1"] = parts[1]
            elif key == "prec":
                metrics["PREC"] = parts[1]
            elif key == "recall":
                metrics["RECALL"] = parts[1]
    return metrics

def valid_predictions(path):
    try:
        with open(path) as f:
            data = json.load(f)
        if isinstance(data, list) and len(data) == REQUIRED_PREDICTIONS:
            return True
        else:
            print(f"[WARN] Unfinished prediction for {path}")
            return False
    except Exception:
        return False

def infer_subset(parts):
    """
    Dataset-agnostic subset inference.
    Handles 2Wiki compound subsets.
    """
    for p in reversed(parts):
        name = p.lower()
        if "bridge_comparison" in name:
            return "bridge_comp"
        if "compositional" in name:
            return "compositional"
        if "comparison" in name:
            return "comparison"
        if "inference" in name:
            return "inference"
        if name == "total":
            return "total"
        if "bridge" in name:
            return "bridge"
    return "total"

def infer_method(parts):
    for p in parts:
        if p in {"prag", "combine"}:
            return p
    return ""

# -------------------------
# Collect experiments
# -------------------------

groups = defaultdict(list)

for root, _, files in os.walk(ROOT_DIR):
    if "result.txt" not in files:
        continue

    result_path = os.path.join(root, "result.txt")
    predict_path = os.path.join(root, "predict.json")

    parts = result_path.split(os.sep)

    metrics = parse_result_txt(result_path)
    pred_ok = valid_predictions(predict_path)

    dataset = rank = epoch = None

    for i, p in enumerate(parts):
        m = RANK_ALPHA_RE.search(p)
        if m:
            rank = int(m.group(1))

        m = LR_EPOCH_RE.search(p)
        if m:
            epoch = int(m.group(2))

        if RANK_ALPHA_RE.match(p) and i + 1 < len(parts):
            dataset = parts[i + 1]

    if rank is None or epoch is None or dataset is None:
        continue

    subset = infer_subset(parts)
    method = infer_method(parts)

    rel_result_path = os.path.relpath(result_path, ROOT_DIR)

    groups[(rank, epoch)].append({
        "label": f"{dataset}/{subset}/{method}",
        "result_path": rel_result_path,
        "metrics": metrics,
        "pred_check": "" if pred_ok else "!",
    })

# -------------------------
# Write one CSV per group
# -------------------------

for (rank, epoch), experiments in sorted(groups.items()):

    # -------- ASSERTION: labels must be unique --------
    labels = [e["label"] for e in experiments]
    dupes = [lbl for lbl, c in Counter(labels).items() if c > 1]

    assert not dupes, (
        f"[ERROR] Duplicate labels in rank={rank}, epoch={epoch}: {dupes}\n"
        f"Check subset inference or directory structure."
    )
    # -------------------------------------------------

    out_csv = os.path.join(OUT_DIR, f"metrics_r{rank}_eps{epoch}.csv")

    with open(out_csv, "w", newline="") as f:
        writer = csv.writer(f)

        # Header row 1: logical labels
        writer.writerow([""] + [e["label"] for e in experiments])

        # Header row 2: exact result paths
        # writer.writerow(["result_path"] + [e["result_path"] for e in experiments])

        # Metric rows
        for metric in ["EM", "F1", "PREC", "RECALL"]:
            writer.writerow(
                [metric] +
                [
                    e["metrics"][metric] + e["pred_check"]
                    for e in experiments
                ]
            )

    print(f"Saved → {out_csv}")


[WARN] Unfinished prediction for /scratch/doluk/Compact-Interference-PRAG/PRAG/warmup/output/freezeA/warmup_cot/qwen2.5-1.5b-instruct/rank=2_alpha=32/popqa/lr=0.0003_epoch=1_direct/aug_model=qwen2.5-1.5b-instruct/prag/total/predict.json
Saved → metrics_by_rank_epoch/metrics_r2_eps1.csv
Saved → metrics_by_rank_epoch/metrics_r2_eps2.csv
Saved → metrics_by_rank_epoch/metrics_r2_eps4.csv
Saved → metrics_by_rank_epoch/metrics_r4_eps2.csv
Saved → metrics_by_rank_epoch/metrics_r4_eps4.csv
Saved → metrics_by_rank_epoch/metrics_r8_eps2.csv
Saved → metrics_by_rank_epoch/metrics_r8_eps4.csv
Saved → metrics_by_rank_epoch/metrics_r16_eps2.csv
Saved → metrics_by_rank_epoch/metrics_r16_eps4.csv


In [12]:
groups[(4,4)]

[{'label': 'hotpotqa/comparison/prag',
  'result_path': 'qwen2.5-1.5b-instruct/rank=4_alpha=32/hotpotqa/lr=0.0003_epoch=4_direct/aug_model=qwen2.5-1.5b-instruct/prag/comparison/result.txt',
  'metrics': {'EM': '0.4167',
   'F1': '0.5145',
   'PREC': '0.5367',
   'RECALL': '0.5219'},
  'pred_check': ''},
 {'label': 'hotpotqa/total/prag',
  'result_path': 'qwen2.5-1.5b-instruct/rank=4_alpha=32/hotpotqa/lr=0.0003_epoch=4_direct/aug_model=qwen2.5-1.5b-instruct/prag/bridge/result.txt',
  'metrics': {'EM': '0.0933',
   'F1': '0.1511',
   'PREC': '0.1602',
   'RECALL': '0.151'},
  'pred_check': ''},
 {'label': 'hotpotqa/comparison/combine',
  'result_path': 'qwen2.5-1.5b-instruct/rank=4_alpha=32/hotpotqa/lr=0.0003_epoch=4_direct/aug_model=qwen2.5-1.5b-instruct/combine/comparison/result.txt',
  'metrics': {'EM': '0.44',
   'F1': '0.5313',
   'PREC': '0.5616',
   'RECALL': '0.5266'},
  'pred_check': ''},
 {'label': 'hotpotqa/total/combine',
  'result_path': 'qwen2.5-1.5b-instruct/rank=4_alpha=3