In [6]:
import os
from pathlib import Path
import json
import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def find_run_summaries(base_dir):
    return list(Path(base_dir).rglob("**/run_summary.json"))


def extract_metrics(summary_path, metrics_mode="standard", additional_config_keys=[]):
    with open(summary_path, "r") as f:
        d = json.load(f)
    res = {}
    subset_size = None
    # old: at root; new: sometimes nested under "config"
    config = d.get("config", d)
    subset_size = config.get("subset_size", None)
    if subset_size is not None:
        subset_size = float(subset_size)
    # Collect metrics per main.py convention
    res["subset_size"] = subset_size
    if metrics_mode == "test_model":
        tmr = d.get("test_model_result", {}) or {}
        res["acc"] = tmr.get("test_acc", None)
        res["fpr"] = tmr.get("test_fpr", None)
        res["fnr"] = tmr.get("test_fnr", None)
        res["f1"] = tmr.get("test_f1", None)
    else:
        res["acc"] = d.get("test_acc", None)
        res["fpr"] = d.get("test_fpr", None)
        res["fnr"] = d.get("test_fnr", None)
        res["f1"] = d.get("test_f1", None)
    res["category_distribution"] = None
    res["label_distribution"] = None
    for key in additional_config_keys:
        keys = key.split("->")
        temp = config
        for k in keys:
            v = temp.get(k, None)
            if v is None:
                break
            temp = v
        res[key] = temp
    dataset_summary = d.get("dataset_summary", None)
    if dataset_summary:
        cat_dist = dataset_summary.get("category_distribution", None)
        if cat_dist:
            res["category_distribution"] = cat_dist
        label_dist = dataset_summary.get("label_distribution", None)
        if label_dist:
            res["label_distribution"] = label_dist
    return res


def merge_nested_dicts(list_of_dicts):
    # Merge a list of dicts with same keys, summing values for each key
    merged = {}
    for d in list_of_dicts:
        for k, v in d.items():
            merged.setdefault(k, [])
            merged[k].append(v)
    # now, for each key, make sure all lists are the same length, pad with zeros if needed
    all_keys = list(merged)
    maxlen = max(len(v) for v in merged.values())
    for k in all_keys:
        n = len(merged[k])
        if n < maxlen:
            merged[k] += [0] * (maxlen - n)
    return merged


def plot_metrics(metrics, outdir, metrics_mode="standard"):
    subset_sizes = sorted(
        set(m["subset_size"] for m in metrics if m["subset_size"] is not None)
    )
    # Group metrics by subset size
    metrics_by_subset = {sz: [] for sz in subset_sizes}
    for m in metrics:
        if m["subset_size"] is not None:
            metrics_by_subset[m["subset_size"]].append(m)
    avgs = {k: [] for k in ["acc", "fpr", "fnr", "f1"]}
    stds = {k: [] for k in ["acc", "fpr", "fnr", "f1"]}
    for sz in subset_sizes:
        ms = metrics_by_subset[sz]
        for k in avgs:
            vals = [m[k] for m in ms if m[k] is not None]
            avgs[k].append(np.nanmean(vals) if vals else np.nan)
            stds[k].append(np.nanstd(vals, ddof=1) if vals else np.nan)
    plt.figure(figsize=(8, 5), dpi=150)
    for k, label, marker in zip(
        ["acc", "fpr", "fnr", "f1"],
        ["Accuracy", "FPR", "FNR", "F1"],
        ["o", "s", "^", "d"],
    ):
        plt.errorbar(
            subset_sizes, avgs[k], yerr=stds[k], label=label, marker=marker, capsize=3
        )
    plt.xlabel("Subset size")
    plt.ylabel("Metric value")
    mode_label = "Standard" if metrics_mode == "standard" else "Test-Model"
    plt.title(f"Average evaluation metrics vs subset size ({mode_label})")
    plt.xlim(min(subset_sizes), max(subset_sizes))
    plt.xscale("log")
    plt.ylim(0.0, 1.0)
    plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.7)
    plt.legend()
    plt.tight_layout()
    os.makedirs(outdir, exist_ok=True)
    suffix = "standard" if metrics_mode == "standard" else "test_model"
    outpath = os.path.join(outdir, f"metrics_vs_subset_{suffix}.png")
    plt.show()
    plt.savefig(outpath)
    plt.close()
    print(f"Saved metrics plot to {outpath}")


def plot_metric_by_config(metrics, outdir, config_key="num_phi"):
    """
    Plot F1 score vs subset size, one line per config_key value (e.g., num_phi).
    Error bars denote std over seeds for each (config_key, subset_size).
    Args:
        metrics: List of dicts representing each run's metrics and config
        outdir: Path to save plots
        config_key: The config key to group by (e.g., "num_phi")
    """
    # Group: (config_val) -> {subset_size: list of f1s over seeds}
    groups = dict()
    for m in metrics:
        # Get config value: must look inside config if available, fallback to top-level.
        config_val = None
        if config_key not in m:
            print(f"Config key {config_key} not in {m}")
            continue
        config_val = m[config_key]
        subset_size = m.get("subset_size", None)
        f1 = m.get("f1", None)
        if subset_size is None or f1 is None:
            continue
        if config_val not in groups:
            groups[config_val] = dict()
        if subset_size not in groups[config_val]:
            groups[config_val][subset_size] = []
        groups[config_val][subset_size].append(f1)

    # Sort config values numerically if possible
    try:
        config_vals = sorted(groups.keys(), key=lambda v: float(v))
    except Exception:
        config_vals = sorted(groups.keys(), key=lambda v: str(v))

    # Collect all subset sizes and sort
    all_subset_sizes = set()
    for group in groups.values():
        all_subset_sizes.update(group.keys())
    all_subset_sizes = sorted(all_subset_sizes)

    print(f"All subset sizes: {all_subset_sizes}")
    print(f"Config values: {config_vals}")

    # Prepare for plotting: For each config_val, collect mean/std f1 for each subset_size
    plt.figure(figsize=(8, 6), dpi=150)
    colors = plt.cm.viridis(np.linspace(0, 1, len(config_vals)))
    for color, conf_val in zip(colors, config_vals):
        means, stds = [], []
        ss_ordered = []
        for sz in all_subset_sizes:
            vals = groups[conf_val].get(sz, [])
            if vals:
                means.append(np.nanmean(vals))
                stds.append(np.nanstd(vals, ddof=1) if len(vals) > 1 else 0.0)
            else:
                means.append(np.nan)
                stds.append(np.nan)
            ss_ordered.append(sz)
        plt.errorbar(
            ss_ordered,
            means,
            yerr=stds,
            label=f"{config_key}={conf_val}",
            marker="o",
            capsize=3,
            color=color,
        )
    plt.xlabel("Subset size")
    plt.ylabel("F1 score")
    plt.title(f"F1 vs Subset Size (per {config_key})")
    plt.xlim(min(all_subset_sizes), max(all_subset_sizes))
    plt.xscale("log")
    plt.ylim(0.0, 1.0)
    plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.7)
    plt.legend(title=config_key)
    plt.tight_layout()
    os.makedirs(outdir, exist_ok=True)
    outpath = os.path.join(outdir, f"f1_vs_subset_by_{config_key}.png")
    plt.savefig(outpath)
    plt.close()
    print(f"Saved F1 vs subset size plot by {config_key} to {outpath}")


def plot_category_label_distribution(metrics, outdir):
    # We want to show category and label counts for each subset size, averaged and error-barred,
    # Group per subset size, then over random seeds, average.
    subset_sizes = sorted(
        set(m["subset_size"] for m in metrics if m["subset_size"] is not None)
    )
    metrics_by_subset = {sz: [] for sz in subset_sizes}
    for m in metrics:
        if m["subset_size"] is not None:
            metrics_by_subset[m["subset_size"]].append(m)

    # Get all possible categories and labels
    all_categories = set()
    all_labels = set()
    for m in metrics:
        cat_dist = m.get("category_distribution", None)
        if cat_dist:
            all_categories.update(cat_dist.keys())
        label_dist = m.get("label_distribution", None)
        if label_dist:
            all_labels.update(label_dist.keys())
    all_categories = sorted(list(all_categories))
    all_labels = sorted(list(all_labels))

    # Category distribution
    cat_mean = {sz: [] for sz in subset_sizes}
    cat_std = {sz: [] for sz in subset_sizes}
    for sz in subset_sizes:
        per_seed_counts = []
        for m in metrics_by_subset[sz]:
            cat_dist = m.get("category_distribution", {})
            # Ensure all categories present
            counts = [cat_dist.get(cat, 0) for cat in all_categories]
            per_seed_counts.append(counts)
        arr = np.array(per_seed_counts, dtype=float)
        cat_mean[sz] = (
            arr.mean(axis=0) if len(arr) > 0 else np.zeros(len(all_categories))
        )
        cat_std[sz] = (
            arr.std(axis=0, ddof=1) if len(arr) > 1 else np.zeros(len(all_categories))
        )
    # Plot as grouped barplot with error bars, one panel per subset size (or one panel w/ grouping)
    fig, ax = plt.subplots(figsize=(max(10, len(all_categories) * 0.7), 6))
    bar_width = 0.8 / len(subset_sizes)
    indices = np.arange(len(all_categories))
    colors = plt.cm.viridis(np.linspace(0, 1, len(subset_sizes)))
    for i, sz in enumerate(subset_sizes):
        offset = (i - len(subset_sizes) / 2) * bar_width + bar_width / 2
        ax.bar(
            indices + offset,
            cat_mean[sz],
            bar_width,
            yerr=cat_std[sz],
            label=f"{sz:.1f}",
            capsize=3,
            color=colors[i],
        )
    ax.set_xticks(indices)
    ax.set_xticklabels(all_categories, rotation=45, ha="right")
    ax.set_ylabel("Avg count per subset (± std)")
    ax.set_title("Category distribution per subset size")
    ax.legend(title="Subset size")
    plt.tight_layout()
    outpath = os.path.join(outdir, "category_distribution_vs_subset.png")
    plt.savefig(outpath)
    plt.close()
    print(f"Saved category distribution plot to {outpath}")

    # Label distribution
    label_mean = {sz: [] for sz in subset_sizes}
    label_std = {sz: [] for sz in subset_sizes}
    for sz in subset_sizes:
        per_seed_counts = []
        for m in metrics_by_subset[sz]:
            label_dist = m.get("label_distribution", {})
            counts = [label_dist.get(label, 0) for label in all_labels]
            per_seed_counts.append(counts)
        arr = np.array(per_seed_counts, dtype=float)
        label_mean[sz] = arr.mean(axis=0) if len(arr) > 0 else np.zeros(len(all_labels))
        label_std[sz] = (
            arr.std(axis=0, ddof=1) if len(arr) > 1 else np.zeros(len(all_labels))
        )
    fig, ax = plt.subplots(figsize=(max(7, len(all_labels) * 1.5), 4))
    bar_width = 0.8 / len(subset_sizes)
    indices = np.arange(len(all_labels))
    for i, sz in enumerate(subset_sizes):
        offset = (i - len(subset_sizes) / 2) * bar_width + bar_width / 2
        ax.bar(
            indices + offset,
            label_mean[sz],
            bar_width,
            yerr=label_std[sz],
            label=f"{sz:.1f}",
            capsize=3,
            color=colors[i],
        )
    ax.set_xticks(indices)
    ax.set_xticklabels(all_labels, rotation=0, ha="center")
    ax.set_ylabel("Avg label count (± std)")
    ax.set_title("Label distribution per subset size")
    ax.legend(title="Subset size")
    plt.tight_layout()
    outpath = os.path.join(outdir, "label_distribution_vs_subset.png")
    plt.savefig(outpath)
    plt.close()
    print(f"Saved label distribution plot to {outpath}")

In [28]:
base_dir = [
    "/scratch/gpfs/KOROLOVA/cl6486/SafetyPolytope/multirun/beaver_tails/2025-10-21/polytope_hb-16-08-57",
    "/scratch/gpfs/KOROLOVA/cl6486/SafetyPolytope/multirun/beaver_tails/2025-10-21/polytope_hb-17-30-46"
]
metrics_mode = "standard"
outdir = "robustness_plots"
run_summaries = []
additional_config_keys = ["dataset->num_phi"]
for bd in base_dir:
    run_summaries.extend(find_run_summaries(bd))
print(f"Found {len(run_summaries)} run summary files.")
metrics = [
    extract_metrics(path, metrics_mode, additional_config_keys)
    for path in run_summaries
]
print(metrics)
# plot_metrics(metrics, outdir, metrics_mode)
plot_metric_by_config(metrics, outdir, config_key="dataset->num_phi")

Found 100 run summary files.
[{'subset_size': 0.05, 'acc': 0.7916516948137502, 'fpr': 0.25480229011718125, 'fnr': 0.14931665193445298, 'f1': 0.7824265165728579, 'category_distribution': {'animal_abuse': 185, 'child_abuse': 88, 'controversial_topics,politics': 452, 'discrimination,stereotype,injustice': 1153, 'drug_abuse,weapons,banned_substance': 806, 'financial_crime,property_crime,theft': 1424, 'hate_speech,offensive_language': 624, 'misinformation_regarding_ethics,laws_and_safety': 149, 'non_violent_unethical_behavior': 1204, 'privacy_violation': 569, 'safety,_ethics,_and_legality': 6693, 'self_harm': 77, 'sexually_explicit,adult_content': 226, 'terrorism,organized_crime': 52, 'violence,aiding_and_abetting,incitement': 1326}, 'label_distribution': {'0': 8335, '1': 6693}, 'dataset->num_phi': 50}, {'subset_size': 0.005, 'acc': 0.6866990058689664, 'fpr': 0.07844186419819145, 'fnr': 0.6117495070374651, 'f1': 0.5218662889000594, 'category_distribution': {'animal_abuse': 16, 'child_abuse'

In [4]:
# ministral-8b increasing subset size
base_dir = [
    "/scratch/gpfs/KOROLOVA/cl6486/SafetyPolytope/multirun/beaver_tails/2025-10-22/polytope_hb-18-25-05"
]

metrics_mode = "standard"
out_dir = "robustness_plots"
run_summaries = find_run_summaries(base_dir[0])
metrics = [extract_metrics(path, metrics_mode) for path in run_summaries]
plot_metrics(metrics, out_dir, metrics_mode)

Saved metrics plot to robustness_plots/metrics_vs_subset_standard.png


In [7]:
# qwen3-4b increasing subset size
base_dir = [
    "/scratch/gpfs/KOROLOVA/cl6486/SafetyPolytope/multirun/beaver_tails/2025-10-22/polytope_hb-19-00-27"
]
metrics_mode = "standard"
out_dir = "robustness_plots"
run_summaries = find_run_summaries(base_dir[0])
metrics = [extract_metrics(path, metrics_mode) for path in run_summaries]
plot_metrics(metrics, out_dir, metrics_mode)

Saved metrics plot to robustness_plots/metrics_vs_subset_standard.png
