# In [None]:
import os
import re
import pandas as pd
import shutil

# === Helper functions ===

def parse_test_envs(folder_name, file_text):
    """Return the test environment indices for one experiment run.

    The folder name is consulted first: a ``_T<digits>`` marker is split into
    single digits (e.g. ``'_T12'`` -> ``[1, 2]``). If the marker is absent,
    the ``test_envs: [...]`` entry in the log text is used instead.

    Args:
        folder_name: Name of the experiment folder.
        file_text: Full contents of the run's ``out.txt``.

    Returns:
        A list of ints; empty when neither source yields any index.
    """
    match = re.search(r"_T(\d+)", folder_name)
    if match:
        # Each digit is its own environment index, so '_T12' means envs 1 and 2.
        return [int(d) for d in match.group(1)]

    # Fallback if missing in name: try reading from out.txt header.
    m2 = re.search(r"test_envs:\s*\[([0-9,\s]+)\]", file_text)
    if m2:
        # Pull out digit runs rather than splitting on commas, so a trailing
        # comma or a whitespace-only list cannot feed '' into int().
        return [int(tok) for tok in re.findall(r"\d+", m2.group(1))]
    return []

def parse_table(file_text):
    """Parse the main results table from out.txt into a DataFrame.

    The table header is the LAST line in the text that starts with
    ``disc_loss`` or ``env0_in_acc``. Columns are separated by two or more
    whitespace characters or by tabs. Reading stops at the first line that
    starts with ``Updated state`` or ``Restarting``; rows whose cell count
    differs from the header are dropped.

    Args:
        file_text: Full contents of an ``out.txt`` log.

    Returns:
        A DataFrame with one row per table line. Columns that parse entirely
        as numbers are numeric; all other columns keep their string values.

    Raises:
        ValueError: If no header line is found or no row matches the header.
    """
    # Find all candidate headers
    matches = list(re.finditer(r"^(disc_loss|env0_in_acc)\s+", file_text, re.MULTILINE))
    if not matches:
        raise ValueError("Could not locate results table start (env0_in_acc or disc_loss)")

    # Use the last one (main training section)
    start_idx = matches[-1].start()
    table_text = file_text[start_idx:].strip()
    lines = [ln.strip() for ln in table_text.splitlines() if ln.strip()]

    # Handle case where the first header is env0_in_acc but next is disc_loss
    # (e.g. an indented disc_loss line that the anchored regex did not match).
    if lines[0].startswith("env0_in_acc") and len(lines) > 1 and lines[1].startswith("disc_loss"):
        lines = lines[1:]

    column_sep = re.compile(r"\s{2,}|\t+")
    header = column_sep.split(lines[0])
    data_lines = []
    for line in lines[1:]:
        if line.startswith(("Updated state", "Restarting")):
            break
        parts = column_sep.split(line)
        if len(parts) == len(header):
            data_lines.append(parts)

    if not data_lines:
        raise ValueError("No valid rows found in results table")

    df = pd.DataFrame(data_lines, columns=header)

    def _numeric_or_original(column):
        # Same result as the deprecated pd.to_numeric(errors="ignore"):
        # convert when every cell parses, otherwise leave the column as-is.
        try:
            return pd.to_numeric(column)
        except (ValueError, TypeError):
            return column

    # Convert numeric columns
    return df.apply(_numeric_or_original)

def get_best_checkpoint(root_dir):
    """Scan experiment folders under ``root_dir`` and pick each run's best row.

    For every entry of ``root_dir`` that contains an ``out.txt``, the results
    table is parsed, the test environments are resolved, and the row with the
    highest mean of the matching ``env{t}_out_acc`` columns is selected.
    Folders that cannot be read, parsed, or scored are reported on stdout and
    skipped, so one bad run never aborts the whole scan.

    Args:
        root_dir: Directory whose immediate children are experiment folders.

    Returns:
        A list of dicts in sorted folder order. Each dict has the keys
        ``folder``, ``folder_path``, ``best_checkpoint`` (the step as a
        string, or ``"N/A"`` when no usable ``step`` value exists),
        ``weighted_acc`` (the plain mean of the selected columns), and one
        ``env{t}_out_acc`` entry per selected column.
    """
    results = []

    for folder in sorted(os.listdir(root_dir)):
        folder_path = os.path.join(root_dir, folder)
        out_file = os.path.join(folder_path, "out.txt")
        if not os.path.isfile(out_file):
            continue

        try:
            # errors="ignore" drops undecodable bytes, so only OS-level
            # failures can occur while reading.
            with open(out_file, "r", encoding="utf-8", errors="ignore") as f:
                text = f.read()
        except OSError as e:
            print(f"‚ö†Ô∏è Failed to read {folder}: {e}")
            continue

        try:
            df = parse_table(text)
        except Exception as e:
            # Deliberately broad: this is a best-effort batch scan and any
            # malformed log should be reported and skipped.
            print(f"‚ö†Ô∏è Failed to parse table in {folder}: {e}")
            continue

        test_envs = parse_test_envs(folder, text)
        if not test_envs:
            print(f"‚ö†Ô∏è No test envs found in {folder}")
            continue

        env_accs = [f"env{t}_out_acc" for t in test_envs if f"env{t}_out_acc" in df.columns]
        if not env_accs:
            print(f"‚ö†Ô∏è No env*_out_acc columns found in {folder}")
            continue

        # Ensure numeric; unparseable cells become NaN and are ignored by mean().
        df[env_accs] = df[env_accs].apply(pd.to_numeric, errors="coerce")
        df["weighted_acc"] = df[env_accs].mean(axis=1)

        # idxmax() on an all-NaN column either returns NaN or raises,
        # depending on the pandas version; both would crash the scan.
        if df["weighted_acc"].isna().all():
            print(f"‚ö†Ô∏è No numeric env*_out_acc values found in {folder}")
            continue

        best_idx = df["weighted_acc"].idxmax()
        best_row = df.loc[best_idx]

        best_step = "N/A"
        if "step" in best_row:
            try:
                best_step = str(int(best_row["step"]))
            except (TypeError, ValueError, OverflowError):
                # NaN or non-numeric step: keep "N/A" so the copy stage skips
                # this run instead of the scan raising here.
                best_step = "N/A"

        results.append({
            "folder": folder,
            "folder_path": folder_path,
            "best_checkpoint": best_step,
            "weighted_acc": best_row["weighted_acc"],
            **{env_acc: best_row[env_acc] for env_acc in env_accs},
        })

    return results


def copy_best_checkpoints(results, root_dir):
    """Gather each run's selected checkpoint under ``<root_dir>/Oracle Selection``.

    For every entry in ``results`` a sub-folder named after the run is created
    inside the collection folder. If the entry's ``best_checkpoint`` is not
    ``"N/A"`` and ``model_step<step>.pkl`` exists in the run's ``folder_path``,
    that file is copied (with metadata) into the sub-folder. Each outcome is
    reported on stdout.

    Args:
        results: Iterable of dicts with ``folder``, ``folder_path`` and
            ``best_checkpoint`` keys.
        root_dir: Directory in which the ``Oracle Selection`` folder is made.
    """
    oracle_dir = os.path.join(root_dir, "Oracle Selection")
    os.makedirs(oracle_dir, exist_ok=True)

    for entry in results:
        source_dir = entry["folder_path"]
        target_dir = os.path.join(oracle_dir, entry["folder"])
        # The per-run destination is created up front, before knowing
        # whether anything will be copied into it.
        os.makedirs(target_dir, exist_ok=True)

        step = entry["best_checkpoint"]
        if step == "N/A":
            print(f"‚ö†Ô∏è Skipping {entry['folder']} (no valid checkpoint step)")
            continue

        checkpoint_name = "model_step{}.pkl".format(step)
        source_path = os.path.join(source_dir, checkpoint_name)

        if not os.path.exists(source_path):
            print(f"‚ö†Ô∏è No checkpoint found for step {step} in {entry['folder']} (expected {checkpoint_name})")
            continue

        shutil.copy2(source_path, os.path.join(target_dir, checkpoint_name))
        print(f"‚úÖ Copied {checkpoint_name} ‚Üí {target_dir}")

if __name__ == "__main__":
    # Experiments root: edit this path to point at your own runs.
    root_dir = r"D:\SPROJ_ABLATE_PACS"
    results = get_best_checkpoint(root_dir)

    # Report the selected row for every run before touching any files.
    print("\n=== Oracle Model Selection Results ===\n")
    for entry in results:
        print(f"üìÅ {entry['folder']}")
        print(f"  ‚Üí Best checkpoint: {entry['best_checkpoint']}")
        print(f"  ‚Üí Weighted Acc: {entry['weighted_acc']:.4f}")
        # Only the per-environment accuracy fields are listed individually.
        per_env = {
            name: acc
            for name, acc in entry.items()
            if name.startswith("env") and name.endswith("_out_acc")
        }
        for name, acc in per_env.items():
            print(f"     {name}: {acc:.4f}")
        print()

    print("üì¶ Copying best checkpoints to 'Oracle Selection'...\n")
    copy_best_checkpoints(results, root_dir)
    print("\n‚úÖ Done! All best checkpoints collected under 'Oracle Selection'.")
