# Progressive Pretrain → Classification: Results Analysis

训练完成后，从磁盘加载各 run/stage 的结果文件，汇总并绘制：
1. **Accuracy vs Pretrain Stages** — 分类准确率随 pretrain source task 数量的变化
2. **Per-class Metrics** (ALL test) — 5 classes × 3 metrics (precision, recall, f1)
3. **Per-class Metrics** (n_elements ≥ 4) — 仅含组分元素数 ≥ 4 的 test 样本

支持多台 machine 上分别训练各 run，最终将 `runXX/` 目录复制到同一 `output_dir` 下后运行本 notebook。

In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

## 1. 设置路径

将 `OUTPUT_DIR` 指向训练脚本产出的根目录（包含 `run01/`, `run02/`, ... 子目录）。

In [None]:
# === Edit this path so it points at your training output directory ===
OUTPUT_DIR = Path("../logs/multi_task_suite/0211_0355")

# Class labels, specified by hand.
# NOTE: nothing in this cell reads them from experiment_records.json; edit the
# list manually if your label set differs.
CLASS_NAMES = ["DAC", "DQC", "IAC", "IQC", "others"]

print(f"Output dir: {OUTPUT_DIR}")
print(f"Exists: {OUTPUT_DIR.exists()}")
if OUTPUT_DIR.exists():
    # Names of the sub-directories that start with "run", in sorted order.
    runs = sorted(
        entry.name
        for entry in OUTPUT_DIR.iterdir()
        if entry.is_dir() and entry.name.startswith("run")
    )
    print(f"Found {len(runs)} run(s): {runs}")

## 2. 加载工具函数

In [None]:
def load_clf_reports_from_disk(root_dir: Path, report_name: str = "clf_report_all.json") -> dict:
    """Collect the classification-report JSON of every run/stage under ``root_dir``.

    Files are looked up with the pattern
    ``root_dir / runXX / pretrain_stageNN_xxx / finetune / material_type / <report_name>``.

    Args:
        root_dir: Directory that contains the ``run*`` sub-directories.
        report_name: File name of the report to load from each
            ``material_type`` directory.

    Returns:
        ``{run_label: {stage_idx: report}}`` where ``run_label`` is the run
        directory name, ``stage_idx`` is the integer parsed from the stage
        directory name (``pretrain_stage01_efermi`` -> ``1``) and ``report``
        is the decoded JSON content. Empty when no file matches.

    Raises:
        ValueError: If the second ``_``-separated token of a stage directory
            name is not ``stage`` followed by an integer.
    """
    results = {}
    pattern = f"run*/pretrain_stage*/finetune/material_type/{report_name}"
    for p in sorted(root_dir.glob(pattern)):
        parts = p.relative_to(root_dir).parts
        run_label = parts[0]
        stage_str = parts[1]  # e.g. 'pretrain_stage01_efermi'
        stage_idx = int(stage_str.split("_")[1].replace("stage", ""))
        # Decode explicitly as UTF-8: the default text encoding of open()
        # depends on the locale and can mis-read non-ASCII JSON content.
        with open(p, encoding="utf-8") as f:
            report = json.load(f)
        results.setdefault(run_label, {})[stage_idx] = report
    return results


def load_metrics_from_disk(root_dir: Path) -> dict:
    """Collect the ``metrics.json`` of every run/stage under ``root_dir``.

    Files are looked up with the pattern
    ``root_dir / run* / pretrain_stage* / finetune / material_type / metrics.json``.

    Args:
        root_dir: Directory that contains the ``run*`` sub-directories.

    Returns:
        ``{run_label: {stage_idx: metrics}}`` where ``run_label`` is the run
        directory name, ``stage_idx`` is the integer parsed from the stage
        directory name and ``metrics`` is the decoded JSON content. Empty
        when no file matches.

    Raises:
        ValueError: If the second ``_``-separated token of a stage directory
            name is not ``stage`` followed by an integer.
    """
    results = {}
    pattern = "run*/pretrain_stage*/finetune/material_type/metrics.json"
    for p in sorted(root_dir.glob(pattern)):
        parts = p.relative_to(root_dir).parts
        run_label = parts[0]
        stage_str = parts[1]  # e.g. 'pretrain_stage01_efermi'
        stage_idx = int(stage_str.split("_")[1].replace("stage", ""))
        # Decode explicitly as UTF-8: the default text encoding of open()
        # depends on the locale and can mis-read non-ASCII JSON content.
        with open(p, encoding="utf-8") as f:
            metrics = json.load(f)
        results.setdefault(run_label, {})[stage_idx] = metrics
    return results


def load_experiment_records(root_dir: Path):
    """Load ``experiment_records.json`` from ``root_dir`` if it is available.

    Args:
        root_dir: Directory expected to contain ``experiment_records.json``.

    Returns:
        The decoded JSON content, or ``None`` when the file does not exist.
    """
    path = root_dir / "experiment_records.json"
    try:
        # Open directly instead of checking exists() first (no check-then-use
        # race), and decode explicitly as UTF-8 rather than relying on the
        # locale-dependent default encoding.
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        # The records file is optional.
        return None


print("Utility functions defined.")

## 3. 加载数据

In [None]:
# Per-run, per-stage metrics.json contents: {run_label: {stage_idx: metrics}}
metrics_by_run = load_metrics_from_disk(OUTPUT_DIR)
print(f"Found {len(metrics_by_run)} run(s) with metrics")

# The number of stages is the largest stage index seen in any run
# (it stays 0 when no metrics were found).
n_stages = 0
for run_label, stages in metrics_by_run.items():
    stage_ids = sorted(stages)
    n_stages = max(n_stages, max(stage_ids))
    print(f"  {run_label}: stages {stage_ids}")

print(f"\nTotal stages: {n_stages}")

# One row per run (sorted by label), one column per stage 1..n_stages.
# A cell is None when the stage is missing for that run or when its metrics
# have no "test_accuracy" entry.
run_labels = sorted(metrics_by_run)
accuracy_matrix = [
    [
        metrics_by_run[label].get(stage, {}).get("test_accuracy")
        for stage in range(1, n_stages + 1)
    ]
    for label in run_labels
]

# Optional: experiment records carrying the task sequence of each run.
experiment_records = load_experiment_records(OUTPUT_DIR)
if not experiment_records:
    print("\nNo experiment_records.json found (optional).")
else:
    print(f"\nExperiment records: {len(experiment_records)} run(s)")

## 4. Accuracy vs Pretrain Stages

In [None]:
fig, ax = plt.subplots(figsize=(8, 5), dpi=150)

# One semi-transparent curve per run.
for run_label, run_accs in zip(run_labels, accuracy_matrix):
    ax.plot(
        list(range(1, len(run_accs) + 1)),
        run_accs,
        "o-",
        alpha=0.5,
        label=run_label,
    )

# Mean ± Std across runs. Missing results (None) are turned into NaN so the
# nan-aware reductions below ignore them.
acc_array = np.array(
    [[np.nan if acc is None else acc for acc in row] for row in accuracy_matrix]
)
stages_x = np.arange(1, n_stages + 1)
mean_acc = np.nanmean(acc_array, axis=0)
std_acc = np.nanstd(acc_array, axis=0)

ax.errorbar(
    stages_x,
    mean_acc,
    yerr=std_acc,
    fmt="s-",
    color="black",
    linewidth=2,
    markersize=8,
    capsize=4,
    label="Mean ± Std",
)

# Axis cosmetics.
ax.set_title("Material Type Classification vs. Progressive Pretraining", fontsize=14)
ax.set_xlabel("Number of Pretrain Source Tasks", fontsize=13)
ax.set_ylabel("Test Classification Accuracy", fontsize=13)
ax.set_xticks(stages_x)
ax.grid(True, alpha=0.3)
# Add one legend column for every five runs.
ax.legend(fontsize=9, loc="best", ncol=max(1, len(run_labels) // 5))

fig.savefig(OUTPUT_DIR / "accuracy_vs_stages.png", bbox_inches="tight")
plt.show()

print("\n=== Accuracy Summary ===")
for stage_no in range(1, n_stages + 1):
    mu = mean_acc[stage_no - 1]
    sd = std_acc[stage_no - 1]
    print(f"  Stage {stage_no}: mean={mu:.4f} ± {sd:.4f}")

## 5. Per-Run Summary

In [None]:
if experiment_records:
    print("=== Per-Run Summary ===")
    for record in experiment_records:
        run = record["run"]
        seq = record["task_sequence"]
        # The final accuracy is the "test_accuracy" of the last finetune entry.
        ft_records = record.get("finetune", [])
        final_acc = None
        if ft_records:
            final_acc = ft_records[-1].get("test_accuracy")
        if final_acc is None:
            print(f"  {run}: order={seq}, no result")
        else:
            print(f"  {run}: order={seq}, final_acc={final_acc:.4f}")
else:
    # Without experiment records, fall back to the accuracies collected from
    # metrics.json: the last column of each row is the final stage.
    print("=== Per-Run Summary (from metrics.json) ===")
    for run_label, accs in zip(run_labels, accuracy_matrix):
        final_acc = None
        if accs:
            final_acc = accs[-1]
        if final_acc is None:
            print(f"  {run_label}: no result")
        else:
            print(f"  {run_label}: final_acc={final_acc:.4f}")

## 6. Per-class Classification Metrics vs Pretrain Stages

从磁盘加载各 run/stage 的 `clf_report_all.json` 和 `clf_report_ge4.json`，  
汇总后绘制 5 classes × 3 metrics 的 subplot（支持不同 machine 上的 run 结果合并）。

In [None]:
def _collect_per_class_array(reports_by_run, class_names, n_stages, metric_names):
    """Pack per-class metric values into a NaN-filled 4-D array.

    Returns an array of shape ``[n_runs, n_stages, n_classes, n_metrics]``.
    Runs are ordered by sorted run label; every entry without data stays NaN.
    """
    arr = np.full(
        (len(reports_by_run), n_stages, len(class_names), len(metric_names)),
        np.nan,
    )
    for r_idx, run_label in enumerate(sorted(reports_by_run)):
        for stage_no, report in reports_by_run[run_label].items():
            # Stage numbers are 1-based. Anything outside 1..n_stages is
            # skipped: a larger value would raise IndexError, and 0 would
            # silently land on the LAST stage through negative indexing.
            if not 1 <= stage_no <= n_stages:
                print(f"Skipping {run_label} stage {stage_no}: outside 1..{n_stages}")
                continue
            if not report:
                continue
            for c_idx, cname in enumerate(class_names):
                if cname not in report:
                    continue
                for m_idx, metric in enumerate(metric_names):
                    arr[r_idx, stage_no - 1, c_idx, m_idx] = report[cname].get(metric, np.nan)
    return arr


def plot_per_class_metrics(
    reports_by_run,
    class_names,
    n_stages,
    title_suffix="",
    save_path=None,
):
    """Plot per-class precision / recall / f1-score against the pretrain stage.

    Draws a grid of ``len(class_names)`` rows x 3 columns
    (precision, recall, f1-score). Each panel shows the mean over runs as a
    line and mean ± std as a shaded band.

    Args:
        reports_by_run: ``{run_label: {stage_idx: clf_report_dict}}`` with
            1-based ``stage_idx``. A report is read as
            ``report[class_name][metric]``; missing classes or metrics are
            treated as NaN. Stages outside ``1..n_stages`` are skipped.
        class_names: Class labels to plot, one grid row each.
        n_stages: Number of stages on the x axis.
        title_suffix: Text appended to the figure title.
        save_path: When truthy, the figure is saved there (dpi=150) and the
            path is printed.

    Returns:
        Array of shape ``[n_runs, n_stages, n_classes, n_metrics]`` holding
        the raw values (NaN where no data was found), runs ordered by sorted
        run label.
    """
    metrics_of_interest = ["precision", "recall", "f1-score"]
    n_classes = len(class_names)
    n_metrics = len(metrics_of_interest)

    # shape [n_runs, n_stages, n_classes, n_metrics]
    arr = _collect_per_class_array(reports_by_run, class_names, n_stages, metrics_of_interest)

    # Aggregate over the run axis, ignoring NaN (missing) entries.
    mean_vals = np.nanmean(arr, axis=0)
    std_vals = np.nanstd(arr, axis=0)
    stages_x = np.arange(1, n_stages + 1)

    fig, axes = plt.subplots(
        n_classes, n_metrics,
        figsize=(4 * n_metrics, 3.2 * n_classes),
        sharex=True, squeeze=False,
    )
    colors = plt.cm.tab10.colors

    for c_idx, cname in enumerate(class_names):
        # Cycle through the palette so more classes than palette entries
        # does not raise IndexError.
        color = colors[c_idx % len(colors)]
        for m_idx, metric in enumerate(metrics_of_interest):
            ax = axes[c_idx, m_idx]
            mu = mean_vals[:, c_idx, m_idx]
            sd = std_vals[:, c_idx, m_idx]

            ax.plot(stages_x, mu, "o-", color=color, linewidth=1.5, markersize=4)
            ax.fill_between(stages_x, mu - sd, mu + sd, alpha=0.2, color=color)

            ax.set_ylim(-0.05, 1.05)
            ax.set_xticks(stages_x)

            # Column titles on the first row, class labels on the first
            # column, x labels on the last row only.
            if c_idx == 0:
                ax.set_title(metric.capitalize(), fontsize=12, fontweight="bold")
            if m_idx == 0:
                ax.set_ylabel(cname, fontsize=11, fontweight="bold")
            if c_idx == n_classes - 1:
                ax.set_xlabel("# pretrain source tasks")
            ax.grid(True, alpha=0.3)

    fig.suptitle(
        f"Per-class Metrics vs Progressive Pretrain Stages{title_suffix}",
        fontsize=14, fontweight="bold", y=1.01,
    )
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Saved to {save_path}")
    plt.show()
    return arr


print("plot_per_class_metrics() defined.")

In [None]:
print(f"Loading reports from: {OUTPUT_DIR}\n")

# 1) Reports stored as clf_report_all.json (ALL test)
reports_all = load_clf_reports_from_disk(OUTPUT_DIR, report_name="clf_report_all.json")
print(f"Found {len(reports_all)} run(s) with clf_report_all")
arr_all = plot_per_class_metrics(
    reports_by_run=reports_all,
    class_names=CLASS_NAMES,
    n_stages=n_stages,
    title_suffix=" (ALL test)",
    save_path=OUTPUT_DIR / "per_class_metrics_all_test.png",
)

In [None]:
# 2) Reports stored as clf_report_ge4.json (n_elements >= 4)
reports_ge4 = load_clf_reports_from_disk(OUTPUT_DIR, report_name="clf_report_ge4.json")
print(f"Found {len(reports_ge4)} run(s) with clf_report_ge4")
arr_ge4 = plot_per_class_metrics(
    reports_by_run=reports_ge4,
    class_names=CLASS_NAMES,
    n_stages=n_stages,
    title_suffix=" (n_elements ≥ 4)",
    save_path=OUTPUT_DIR / "per_class_metrics_ge4_test.png",
)

## 7. 数值总结

In [None]:
metrics_of_interest = ["precision", "recall", "f1-score"]


def _print_per_class_summary(header, arr):
    """Print mean±std over runs for every class, metric and stage in ``arr``.

    ``arr`` is indexed as [n_runs, n_stages, n_classes, n_metrics]; classes
    follow CLASS_NAMES and metrics follow metrics_of_interest.
    """
    print(header)
    for c_idx, cname in enumerate(CLASS_NAMES):
        print(f"\n  {cname}:")
        for m_idx, metric in enumerate(metrics_of_interest):
            per_run = arr[:, :, c_idx, m_idx]  # [n_runs, n_stages]
            mean_v = np.nanmean(per_run, axis=0)
            std_v = np.nanstd(per_run, axis=0)
            cells = [
                f"  S{s+1}: {mean_v[s]:.3f}±{std_v[s]:.3f}"
                for s in range(n_stages)
            ]
            print("    " + metric.ljust(12) + "".join(cells))


_print_per_class_summary("=== Per-class Metrics Summary (ALL test) ===", arr_all)
_print_per_class_summary("\n\n=== Per-class Metrics Summary (n_elements ≥ 4) ===", arr_ge4)