# Plot Insights and Cross-Model Comparisons

This notebook loads experiment result CSVs and produces publication-ready figures:
- Single-model figures for attack and mitigation analyses
- Cross-model comparisons between two result CSVs

Figures are also saved to disk under `figures/{MODEL_NAME}` and `figures/compare/{A}_vs_{B}`.


In [1]:
import os
from pathlib import Path
from typing import List

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

sns.set_context('talk')
plt.rcParams['figure.dpi'] = 120


def prepare_output_dir(output_dir: Path) -> None:
    output_dir.mkdir(parents=True, exist_ok=True)


def load_results(csv_path: Path) -> pd.DataFrame:
    df = pd.read_csv(csv_path)
    numeric_cols = [
        'overall_mae', 'mae_gt_zero', 'mae_gt_positive',
        'overall_jer', 'jer_gt_zero', 'jer_gt_positive',
        'delta_overall_mae', 'delta_mae_gt_zero', 'delta_mae_gt_positive',
        'delta_overall_jer', 'delta_jer_gt_zero', 'delta_jer_gt_positive',
    ]
    for c in numeric_cols:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors='coerce')
    return df


def savefig(fig: plt.Figure, outpath: Path) -> None:
    fig.tight_layout()
    fig.savefig(outpath, dpi=200, bbox_inches='tight')
    plt.close(fig)


def bar_by_mitigation(df: pd.DataFrame, metric: str, output_dir: Path, title_suffix: str = "") -> None:
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.barplot(
        data=df,
        x='mitigation_type', y=metric,
        estimator=np.mean, ci=95, capsize=0.1, errwidth=1.2,
        order=['none', 'user_prompt_hardening', 'system_prompt_hardening', 'few_shot'],
        ax=ax,
    )
    ax.set_title(f"{metric} by Mitigation {title_suffix}")
    ax.set_xlabel("Mitigation")
    ax.set_ylabel(metric)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=20, ha='right')
    savefig(fig, output_dir / f"bar_mitigation_{metric}.png")


def bar_by_attack(df: pd.DataFrame, metric: str, output_dir: Path, title_suffix: str = "") -> None:
    fig, ax = plt.subplots(figsize=(7, 5))
    sns.barplot(
        data=df,
        x='attack_type', y=metric,
        estimator=np.mean, ci=95, capsize=0.1, errwidth=1.2,
        order=['none', 'prepend', 'append', 'scatter'],
        ax=ax,
    )
    ax.set_title(f"{metric} by Attack {title_suffix}")
    ax.set_xlabel("Attack")
    ax.set_ylabel(metric)
    savefig(fig, output_dir / f"bar_attack_{metric}.png")


def bar_attack_by_mitigation(df: pd.DataFrame, metric: str, output_dir: Path, title_suffix: str = "") -> None:
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(
        data=df,
        x='attack_type', y=metric, hue='mitigation_type',
        estimator=np.mean, ci=95, capsize=0.08, errwidth=1.0,
        order=['none', 'prepend', 'append', 'scatter'],
        hue_order=['none', 'user_prompt_hardening', 'system_prompt_hardening', 'few_shot'],
        ax=ax,
    )
    ax.set_title(f"{metric} by Attack × Mitigation {title_suffix}")
    ax.set_xlabel("Attack")
    ax.set_ylabel(metric)
    ax.legend(title="Mitigation", bbox_to_anchor=(1.02, 1), loc='upper left')
    savefig(fig, output_dir / f"bar_attack_by_mitigation_{metric}.png")


def heatmap_attack_mitigation(df: pd.DataFrame, metric: str, output_dir: Path, title_suffix: str = "") -> None:
    pivot = df.pivot_table(
        index='attack_type', columns='mitigation_type', values=metric,
        aggfunc='mean'
    ).reindex(index=['none', 'prepend', 'append', 'scatter'],
              columns=['none', 'user_prompt_hardening', 'system_prompt_hardening', 'few_shot'])
    fig, ax = plt.subplots(figsize=(7.5, 5))
    sns.heatmap(pivot, annot=True, fmt='.3g', cmap='viridis', ax=ax)
    ax.set_title(f"{metric} Heatmap (Attack × Mitigation) {title_suffix}")
    ax.set_xlabel("Mitigation")
    ax.set_ylabel("Attack")
    savefig(fig, output_dir / f"heatmap_attack_mitigation_{metric}.png")


def facet_by_prompt(df: pd.DataFrame, metric: str, output_dir: Path, title_suffix: str = "") -> None:
    g = sns.catplot(
        data=df,
        x='attack_type', y=metric, hue='mitigation_type', col='prompt_type',
        estimator=np.mean, ci=95, capsize=0.05, errwidth=1.0,
        kind='bar', height=4, aspect=1,
        order=['none', 'prepend', 'append', 'scatter'],
        hue_order=['none', 'user_prompt_hardening', 'system_prompt_hardening', 'few_shot'],
    )
    g.set_titles("{col_name}")
    g.fig.suptitle(f"{metric} by Attack × Mitigation × Prompt {title_suffix}", y=1.05)
    for ax in g.axes.flat:
        ax.set_xlabel("Attack")
        ax.set_ylabel(metric)
    out = output_dir / f"facet_prompt_{metric}.png"
    g.savefig(out, dpi=200, bbox_inches='tight')
    plt.close(g.fig)


def distribution_by_mitigation(df: pd.DataFrame, metric: str, output_dir: Path, title_suffix: str = "") -> None:
    fig, ax = plt.subplots(figsize=(9, 5))
    sns.violinplot(data=df, x='mitigation_type', y=metric,
                   order=['none', 'user_prompt_hardening', 'system_prompt_hardening', 'few_shot'], ax=ax, cut=0)
    sns.swarmplot(data=df, x='mitigation_type', y=metric,
                  order=['none', 'user_prompt_hardening', 'system_prompt_hardening', 'few_shot'],
                  ax=ax, color='k', size=2, alpha=0.4)
    ax.set_title(f"Distribution of {metric} by Mitigation {title_suffix}")
    ax.set_xlabel("Mitigation")
    ax.set_ylabel(metric)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=20, ha='right')
    savefig(fig, output_dir / f"distribution_mitigation_{metric}.png")


def baseline_vs_configs(df: pd.DataFrame, metric: str, output_dir: Path, title_suffix: str = "") -> None:
    base_mask = (df['prompt_type'] == 'BASIC') & (df['attack_type'] == 'none') & (df['mitigation_type'] == 'none')
    baseline_value = df.loc[base_mask, metric].mean()
    temp = df.copy()
    temp['delta_vs_baseline'] = temp[metric] - baseline_value

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(
        data=temp,
        x='attack_type', y='delta_vs_baseline', hue='mitigation_type',
        estimator=np.mean, ci=95, capsize=0.08, errwidth=1.0,
        order=['none', 'prepend', 'append', 'scatter'],
        hue_order=['none', 'user_prompt_hardening', 'system_prompt_hardening', 'few_shot'],
        ax=ax,
    )
    ax.axhline(0, color='gray', linewidth=1)
    ax.set_title(f"Δ {metric} vs Baseline (BASIC/none/none) {title_suffix}")
    ax.set_xlabel("Attack")
    ax.set_ylabel(f"Δ {metric}")
    ax.legend(title="Mitigation", bbox_to_anchor=(1.02, 1), loc='upper left')
    savefig(fig, output_dir / f"delta_vs_baseline_{metric}.png")


In [2]:
# Configuration
MODEL_A = os.getenv('MODEL_NAME', 'qwen3_0.6b')
MODEL_B = os.getenv('MODEL_NAME_COMPARE', 'gemma3_1b')

CSV_A = Path(f'{MODEL_A}_results.csv')
CSV_B = Path(f'{MODEL_B}_results.csv')

OUT_DIR_A = Path('figures') / MODEL_A
OUT_DIR_B = Path('figures') / MODEL_B
OUT_DIR_COMPARE = Path('figures') / 'compare' / f'{MODEL_A}_vs_{MODEL_B}'

prepare_output_dir(OUT_DIR_A)
prepare_output_dir(OUT_DIR_B)
prepare_output_dir(OUT_DIR_COMPARE)

print('Primary model:', MODEL_A)
print('Compare model:', MODEL_B)
print('CSV A:', CSV_A)
print('CSV B:', CSV_B)
print('Output A:', OUT_DIR_A)
print('Output B:', OUT_DIR_B)
print('Compare out:', OUT_DIR_COMPARE)


Primary model: qwen3_0.6b
Compare model: gemma3_1b
CSV A: qwen3_0.6b_results.csv
CSV B: gemma3_1b_results.csv
Output A: figures\qwen3_0.6b
Output B: figures\gemma3_1b
Compare out: figures\compare\qwen3_0.6b_vs_gemma3_1b


In [3]:
# Load data
assert CSV_A.exists(), f"Missing {CSV_A}"
assert CSV_B.exists(), f"Missing {CSV_B}"

df_a = load_results(CSV_A)
df_b = load_results(CSV_B)

print(f"Loaded {len(df_a)} rows for {MODEL_A}")
print(f"Loaded {len(df_b)} rows for {MODEL_B}")


Loaded 48 rows for qwen3_0.6b
Loaded 48 rows for gemma3_1b


In [4]:
# Single-model figures for MODEL_A
metrics = ['overall_mae', 'overall_jer', 'mae_gt_zero', 'mae_gt_positive', 'jer_gt_zero', 'jer_gt_positive']

for metric in metrics:
    bar_by_mitigation(df_a, metric, OUT_DIR_A, f'[{MODEL_A}]')
    bar_by_attack(df_a, metric, OUT_DIR_A, f'[{MODEL_A}]')
    bar_attack_by_mitigation(df_a, metric, OUT_DIR_A, f'[{MODEL_A}]')
    heatmap_attack_mitigation(df_a, metric, OUT_DIR_A, f'[{MODEL_A}]')
    facet_by_prompt(df_a, metric, OUT_DIR_A, f'[{MODEL_A}]')
    distribution_by_mitigation(df_a, metric, OUT_DIR_A, f'[{MODEL_A}]')
    baseline_vs_configs(df_a, metric, OUT_DIR_A, f'[{MODEL_A}]')

print('Saved single-model figures to', OUT_DIR_A)



The `ci` parameter is deprecated. Use `errorbar=('ci', 95)` for the same effect.

  sns.barplot(

The `errwidth` parameter is deprecated. And will be removed in v0.15.0. Pass `err_kws={'linewidth': 1.2}` instead.

  sns.barplot(
  ax.set_xticklabels(ax.get_xticklabels(), rotation=20, ha='right')

The `ci` parameter is deprecated. Use `errorbar=('ci', 95)` for the same effect.

  sns.barplot(

The `errwidth` parameter is deprecated. And will be removed in v0.15.0. Pass `err_kws={'linewidth': 1.2}` instead.

  sns.barplot(

The `ci` parameter is deprecated. Use `errorbar=('ci', 95)` for the same effect.

  sns.barplot(

The `errwidth` parameter is deprecated. And will be removed in v0.15.0. Pass `err_kws={'linewidth': 1.0}` instead.

  sns.barplot(

The `ci` parameter is deprecated. Use `errorbar=('ci', 95)` for the same effect.

  g = sns.catplot(

The `errwidth` parameter is deprecated. And will be removed in v0.15.0. Pass `err_kws={'linewidth': 1.0}` instead.

  g = sns.catplot(
  ax

Saved single-model figures to figures\qwen3_0.6b


In [5]:
# Single-model figures for MODEL_B
for metric in metrics:
    bar_by_mitigation(df_b, metric, OUT_DIR_B, f'[{MODEL_B}]')
    bar_by_attack(df_b, metric, OUT_DIR_B, f'[{MODEL_B}]')
    bar_attack_by_mitigation(df_b, metric, OUT_DIR_B, f'[{MODEL_B}]')
    heatmap_attack_mitigation(df_b, metric, OUT_DIR_B, f'[{MODEL_B}]')
    facet_by_prompt(df_b, metric, OUT_DIR_B, f'[{MODEL_B}]')
    distribution_by_mitigation(df_b, metric, OUT_DIR_B, f'[{MODEL_B}]')
    baseline_vs_configs(df_b, metric, OUT_DIR_B, f'[{MODEL_B}]')

print('Saved single-model figures to', OUT_DIR_B)



The `ci` parameter is deprecated. Use `errorbar=('ci', 95)` for the same effect.

  sns.barplot(

The `errwidth` parameter is deprecated. And will be removed in v0.15.0. Pass `err_kws={'linewidth': 1.2}` instead.

  sns.barplot(
  ax.set_xticklabels(ax.get_xticklabels(), rotation=20, ha='right')

The `ci` parameter is deprecated. Use `errorbar=('ci', 95)` for the same effect.

  sns.barplot(

The `errwidth` parameter is deprecated. And will be removed in v0.15.0. Pass `err_kws={'linewidth': 1.2}` instead.

  sns.barplot(

The `ci` parameter is deprecated. Use `errorbar=('ci', 95)` for the same effect.

  sns.barplot(

The `errwidth` parameter is deprecated. And will be removed in v0.15.0. Pass `err_kws={'linewidth': 1.0}` instead.

  sns.barplot(

The `ci` parameter is deprecated. Use `errorbar=('ci', 95)` for the same effect.

  g = sns.catplot(

The `errwidth` parameter is deprecated. And will be removed in v0.15.0. Pass `err_kws={'linewidth': 1.0}` instead.

  g = sns.catplot(
  ax

Saved single-model figures to figures\gemma3_1b


In [6]:
# Cross-model comparison helpers

def compare_bar_by_mitigation(df_a: pd.DataFrame, df_b: pd.DataFrame, metric: str, out_dir: Path) -> None:
    xorder = ['none', 'user_prompt_hardening', 'system_prompt_hardening', 'few_shot']
    a = df_a.copy(); a['model'] = MODEL_A
    b = df_b.copy(); b['model'] = MODEL_B
    both = pd.concat([a, b], ignore_index=True)
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(data=both, x='mitigation_type', y=metric, hue='model', estimator=np.mean,
                ci=95, capsize=0.08, errwidth=1.0, order=xorder, ax=ax)
    ax.set_title(f"{metric}: {MODEL_A} vs {MODEL_B} by Mitigation")
    ax.set_xlabel("Mitigation")
    ax.set_ylabel(metric)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=20, ha='right')
    ax.legend(title='Model')
    savefig(fig, out_dir / f"compare_bar_mitigation_{metric}.png")


def compare_bar_by_attack(df_a: pd.DataFrame, df_b: pd.DataFrame, metric: str, out_dir: Path) -> None:
    xorder = ['none', 'prepend', 'append', 'scatter']
    a = df_a.copy(); a['model'] = MODEL_A
    b = df_b.copy(); b['model'] = MODEL_B
    both = pd.concat([a, b], ignore_index=True)
    fig, ax = plt.subplots(figsize=(9, 6))
    sns.barplot(data=both, x='attack_type', y=metric, hue='model', estimator=np.mean,
                ci=95, capsize=0.08, errwidth=1.0, order=xorder, ax=ax)
    ax.set_title(f"{metric}: {MODEL_A} vs {MODEL_B} by Attack")
    ax.set_xlabel("Attack")
    ax.set_ylabel(metric)
    ax.legend(title='Model')
    savefig(fig, out_dir / f"compare_bar_attack_{metric}.png")


def compare_heatmap_diff(df_a: pd.DataFrame, df_b: pd.DataFrame, metric: str, out_dir: Path) -> None:
    order_rows = ['none', 'prepend', 'append', 'scatter']
    order_cols = ['none', 'user_prompt_hardening', 'system_prompt_hardening', 'few_shot']
    pvt_a = df_a.pivot_table(index='attack_type', columns='mitigation_type', values=metric, aggfunc='mean').reindex(index=order_rows, columns=order_cols)
    pvt_b = df_b.pivot_table(index='attack_type', columns='mitigation_type', values=metric, aggfunc='mean').reindex(index=order_rows, columns=order_cols)
    diff = pvt_a - pvt_b
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.heatmap(diff, annot=True, fmt='.3g', cmap='coolwarm', center=0, ax=ax)
    ax.set_title(f"{metric} difference ({MODEL_A} - {MODEL_B})")
    ax.set_xlabel("Mitigation")
    ax.set_ylabel("Attack")
    savefig(fig, out_dir / f"compare_heatmap_diff_{metric}.png")


def compare_delta_vs_baseline(df_a: pd.DataFrame, df_b: pd.DataFrame, metric: str, out_dir: Path) -> None:
    base_mask = (df_a['prompt_type'] == 'BASIC') & (df_a['attack_type'] == 'none') & (df_a['mitigation_type'] == 'none')
    base_a = df_a.loc[base_mask, metric].mean()
    base_mask_b = (df_b['prompt_type'] == 'BASIC') & (df_b['attack_type'] == 'none') & (df_b['mitigation_type'] == 'none')
    base_b = df_b.loc[base_mask_b, metric].mean()

    ta = df_a.copy(); ta['model'] = MODEL_A; ta['delta_vs_own_baseline'] = ta[metric] - base_a
    tb = df_b.copy(); tb['model'] = MODEL_B; tb['delta_vs_own_baseline'] = tb[metric] - base_b
    both = pd.concat([ta, tb], ignore_index=True)

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(
        data=both,
        x='attack_type', y='delta_vs_own_baseline', hue='model',
        estimator=np.mean, ci=95, capsize=0.08, errwidth=1.0,
        order=['none', 'prepend', 'append', 'scatter'], ax=ax
    )
    ax.axhline(0, color='gray', linewidth=1)
    ax.set_title(f"Δ {metric} vs Each Model's Baseline: {MODEL_A} vs {MODEL_B}")
    ax.set_xlabel("Attack")
    ax.set_ylabel(f"Δ {metric}")
    ax.legend(title='Model')
    savefig(fig, out_dir / f"compare_delta_vs_baseline_{metric}.png")


In [7]:
# Generate comparison figures
for metric in metrics:
    compare_bar_by_mitigation(df_a, df_b, metric, OUT_DIR_COMPARE)
    compare_bar_by_attack(df_a, df_b, metric, OUT_DIR_COMPARE)
    compare_heatmap_diff(df_a, df_b, metric, OUT_DIR_COMPARE)
    compare_delta_vs_baseline(df_a, df_b, metric, OUT_DIR_COMPARE)

print('Saved comparison figures to', OUT_DIR_COMPARE)



The `ci` parameter is deprecated. Use `errorbar=('ci', 95)` for the same effect.

  sns.barplot(data=both, x='mitigation_type', y=metric, hue='model', estimator=np.mean,

The `errwidth` parameter is deprecated. And will be removed in v0.15.0. Pass `err_kws={'linewidth': 1.0}` instead.

  sns.barplot(data=both, x='mitigation_type', y=metric, hue='model', estimator=np.mean,
  ax.set_xticklabels(ax.get_xticklabels(), rotation=20, ha='right')

The `ci` parameter is deprecated. Use `errorbar=('ci', 95)` for the same effect.

  sns.barplot(data=both, x='attack_type', y=metric, hue='model', estimator=np.mean,

The `errwidth` parameter is deprecated. And will be removed in v0.15.0. Pass `err_kws={'linewidth': 1.0}` instead.

  sns.barplot(data=both, x='attack_type', y=metric, hue='model', estimator=np.mean,

The `ci` parameter is deprecated. Use `errorbar=('ci', 95)` for the same effect.

  sns.barplot(

The `errwidth` parameter is deprecated. And will be removed in v0.15.0. Pass `err_kws={'l

Saved comparison figures to figures\compare\qwen3_0.6b_vs_gemma3_1b
