# Drossel-Schwab Forest Fire Model

Four experiments:
1. **f/p Ratio Sweep** - Effect of fire-to-growth ratio
2. **p Parameter Sweep** - Effect of growth probability
3. **Grid Size Effects (RQ1)** - Scaling with system size
4. **f Parameter Sweep** - Effect of fire/lightning probability

In [None]:
import results
from utils import (
    create_experiment_dir, get_latest_experiment_dir,
    run_parallel_simulations, save_summary,
    load_experiment_data, load_summary_map,
    plot_fire_size_distribution, plot_density_timeseries, plot_cluster_size_distribution,
)
from simulations.drosselschwab import *
import matplotlib.pyplot as plt

## Test

In [None]:
# Suppression sweep set-up: one parameter set per `suppress` value, each
# repeated `runs_per_param` times on an L x L grid for `steps` steps.
EXP1_NAME = "suppression_test"
L = 256
steps = 2000
runs_per_param = 3
p = 0.01   # growth probability
f = 1e-4   # fire/lightning probability

# Values handed to the simulation as `suppress`
# (presumably a suppression limit -- confirm against the simulation code).
suppresions = [0, 1, 5, 10, 50, 100, 500, 1000, 5000, 10000]

exp1_param_list = []
param_idx = 0  # stays 0 when the sweep list is empty
for param_idx, sup in enumerate(suppresions, start=1):
    for run_idx in range(runs_per_param):
        exp1_param_list.append(dict(
            L=L, p=p, f=f, steps=steps,
            param_id=param_idx, run_id=run_idx, suppress=sup,
        ))

print(f"Experiment 1: {len(exp1_param_list)} simulations, {param_idx} parameter sets")

In [None]:
# Run the sweep: make an output directory for EXP1_NAME, execute every entry of
# exp1_param_list, then save a summary next to the results.
# NOTE(review): this cell is live (nothing is commented out); re-running it
# launches all simulations again.
exp1_outdir = create_experiment_dir(EXP1_NAME)
exp1_results = run_parallel_simulations(exp1_param_list, exp1_outdir)
save_summary(exp1_results, exp1_outdir)

In [None]:
# Analyze: load the latest EXP1_NAME results and draw the fire-size and
# tree-density plots, saving both via results.path(...).
# NOTE(review): the fire-size title still says "f/p Ratio" although this sweep
# varies `suppress` -- confirm whether the title should be updated.
try:
    exp1_dir = get_latest_experiment_dir(EXP1_NAME)
    exp1_data = load_experiment_data(exp1_dir)
    exp1_summary = load_summary_map(exp1_dir)
    plot_fire_size_distribution(exp1_data, exp1_summary, "Exp 1: Fire Size by f/p Ratio",
                                 results.path("exp1_fire_size_dist.png"))
    plot_density_timeseries(exp1_data, exp1_summary, "Exp 1: Tree Density Over Time",
                            results.path("exp1_density_timeseries.png"))
except FileNotFoundError as e:
    # Any FileNotFoundError raised inside the try block (loading or plotting)
    # is only reported, so the notebook keeps running without data.
    print(f"No data: {e}")

In [None]:
from notebooks.utils import _make_label


def plot_fire_size_distribution(runs_by_param: dict, summary_map: dict = None,
                                 title: str = "Fire Size Distribution", save_path = None):
    """Draw the log-log fire-size distribution.

    Thin wrapper: forwards everything to ``plot_size_distribution`` with the
    ``'fires_all'`` data key and a ``'Fire size'`` x-axis label.
    """
    plot_size_distribution(
        runs_by_param,
        data_key='fires_all',
        summary_map=summary_map,
        title=title,
        xlabel='Fire size',
        save_path=save_path,
    )

def plot_size_distribution(runs_by_param: dict, data_key: str = 'fires_all',
                           summary_map: dict = None, title: str = "Size Distribution",
                           xlabel: str = "Size", save_path= None, area: float = 256**2):
    """Plot a log-log size distribution with one curve per parameter set.

    Args:
        runs_by_param: maps a parameter id to a list of run records; each
            record is indexed with ``data_key`` to obtain that run's sizes
            (a list or a numpy array).
        data_key: record key holding the sizes ('fires_all' or 'clusters_all').
        summary_map: forwarded to ``_make_label`` to build the legend labels.
        title: figure title.
        xlabel: x-axis label.
        save_path: when truthy, the figure is also written there (dpi=150).
        area: divisor applied to the bin centres before plotting. The default
            ``256**2`` is the value that used to be hard-coded; pass ``L**2``
            for other grid sizes or ``1`` to plot raw sizes.

    Returns:
        None. Prints a notice and draws nothing when no sizes were recorded.
    """
    import numpy as np

    # Pool the runs of each parameter set once. Runs without data are dropped
    # here because np.concatenate raises on an empty sequence, and sizes may be
    # numpy arrays whose truth value is ambiguous.
    pooled = {}
    for pid, runs in runs_by_param.items():
        chunks = [np.asarray(r[data_key]).ravel() for r in runs]
        chunks = [c for c in chunks if c.size]
        if chunks:
            pooled[pid] = np.concatenate(chunks)

    if not pooled:
        print(f"No {data_key} recorded")
        return

    # Imported only once there is something to draw.
    import matplotlib.pyplot as plt

    # Shared log-spaced bin edges so every curve uses the same bins.
    all_data = np.concatenate(list(pooled.values()))
    lo = max(1, all_data.min())
    hi = all_data.max()
    if hi <= lo:
        # All sizes identical (or below 1): widen the range so the bins do not
        # collapse to zero width.
        hi = lo * 10
    bins = np.logspace(np.log10(lo), np.log10(hi), num=25)

    plt.figure(figsize=(10, 6))
    for pid in sorted(pooled):
        hist, edges = np.histogram(pooled[pid], bins=bins, density=True)
        centers = np.sqrt(edges[:-1] * edges[1:])  # geometric bin centres
        mask = hist > 0  # empty bins cannot be shown on a log axis
        plt.loglog(centers[mask] / area, hist[mask], 'o-', label=_make_label(pid, summary_map))

    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel('Probability density')
    plt.legend(fontsize='small')
    plt.grid(alpha=0.3)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150)
        print(f"Saved plot to {save_path}")
    plt.show()


In [None]:
    # Re-draw the Exp 1 fire-size plot and save it via results.path(...).
    # Assumes exp1_data / exp1_summary were loaded by the analysis cell and that
    # the plot_fire_size_distribution redefined in this notebook is the one in
    # scope -- TODO confirm cell execution order.
    plot_fire_size_distribution(exp1_data, exp1_summary, "Exp 1: Fire Size by f/p Ratio",
                                 results.path("exp1_fire_size_dist.png"))

In [None]:
# python
"""
Analyze the suppression paradox:
- Group raw fire sizes by suppress value (from debug jsons)
- Define megafire as absolute threshold or top quantile
- Compute CCDFs, tail probs, bootstrapped CIs, KS tests vs baseline
- Run logistic regression predicting whether a run had >=1 megafire
- Plot CCDFs and megafire probability by suppress
"""
import os
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
import statsmodels.api as sm

def _ids_from_fires_filename(filename):
    """Parse ``(param_id, run_id)`` out of a fires CSV file name.

    The name is split on ``.`` and ``_``. A token starting with ``param``
    supplies the parameter id; a token starting with ``run`` or ``id``
    supplies the run id. An id stays ``None`` when no token yields an integer.
    """
    pid = rid = None
    for tok in filename.replace(".", "_").split("_"):
        if tok.startswith("param"):
            try:
                pid = int(tok.replace("param", ""))
            except ValueError:
                pass
        for prefix in ("run", "id"):
            if tok.startswith(prefix):
                try:
                    rid = int(tok.replace(prefix, ""))
                except ValueError:
                    pass
    return pid, rid


def _load_fire_csv(path):
    """Return the fire sizes stored in ``path`` as a float array.

    Uses the ``fire_size`` column when present, otherwise the first numeric
    column. An unreadable file is reported and yields an empty array.
    """
    try:
        df = pd.read_csv(path)
        if "fire_size" in df.columns:
            return df["fire_size"].to_numpy(dtype=float)
        # fallback: first numeric column (IndexError when there is none)
        return df.select_dtypes(include=[np.number]).iloc[:, 0].to_numpy(dtype=float)
    except (OSError, ValueError, IndexError) as exc:
        print(f"Could not read fire sizes from {path}: {exc}")
        return np.array([], dtype=float)


def load_experiment_fires(exp_dir):
    """Pair every ``debug_param*.json`` in ``exp_dir`` with its fire sizes.

    For each debug file the ``params`` dict is read and a ``fires_param*.csv``
    file is selected:

    1. the first file whose parsed ids equal the run's ``param_id`` and
       ``run_id`` (any run id is accepted when the debug file has none);
    2. otherwise the first file whose name contains ``param<param_id>`` not
       followed by another digit, so ``param1`` never picks a ``param10`` file.
       NOTE(review): this fallback can hand a run the CSV of a different run
       of the same parameter set.

    Args:
        exp_dir: directory containing the debug JSON and fires CSV files.

    Returns:
        list of ``{"params": dict, "fires": np.ndarray}`` in debug-file name
        order. ``fires`` is empty when nothing matched. Debug files that cannot
        be read or parsed are reported and skipped.
    """
    import re

    dbg_files = sorted(glob.glob(os.path.join(exp_dir, "debug_param*.json")))
    fires_files = sorted(glob.glob(os.path.join(exp_dir, "fires_param*.csv")))

    # (param_id, run_id, basename, fires) for every CSV, in file-name order.
    fire_entries = []
    for path in fires_files:
        base = os.path.basename(path)
        pid, rid = _ids_from_fires_filename(base)
        fire_entries.append((pid, rid, base, _load_fire_csv(path)))

    runs = []
    for dbg in dbg_files:
        try:
            with open(dbg) as fh:
                info = json.load(fh)
        except (OSError, ValueError) as exc:  # JSONDecodeError is a ValueError
            print(f"Skipping unreadable debug file {dbg}: {exc}")
            continue
        params = info.get("params", {}) if isinstance(info, dict) else None
        if not isinstance(params, dict):
            print(f"Skipping debug file without a params mapping: {dbg}")
            continue
        pid = params.get("param_id")
        rid = params.get("run_id")

        match = None
        for ppid, prid, _base, fires in fire_entries:
            if ppid == pid and (rid is None or prid == rid):
                match = fires
                break
        if match is None and pid is not None:
            # Negative lookahead keeps "param1" from matching "param10".
            pattern = re.compile(rf"param{re.escape(str(pid))}(?!\d)")
            for _ppid, _prid, base, fires in fire_entries:
                if pattern.search(base):
                    match = fires
                    break

        fires = match if match is not None else np.array([], dtype=float)
        runs.append({"params": params, "fires": np.asarray(fires, dtype=float)})
    return runs

def bootstrap_ci(data, statfunc=np.mean, nboot=2000, ci=0.95, rng=None):
    """Percentile-bootstrap confidence interval for ``statfunc(data)``.

    Args:
        data: sample to resample from (with replacement, same length).
        statfunc: statistic applied to the data and to each resample.
        nboot: number of bootstrap resamples.
        ci: central coverage of the interval, e.g. 0.95.
        rng: seed or generator passed to ``np.random.default_rng``.

    Returns:
        ``(statistic, lower, upper)``; three NaNs when ``data`` is empty.
    """
    generator = np.random.default_rng(rng)
    size = len(data)
    if size == 0:
        return (np.nan, np.nan, np.nan)

    replicates = np.array([
        statfunc(generator.choice(data, size=size, replace=True))
        for _ in range(nboot)
    ])

    lower_pct = (1-ci)/2*100
    upper_pct = (1+ci)/2*100
    lower = np.percentile(replicates, lower_pct)
    upper = np.percentile(replicates, upper_pct)
    return statfunc(data), lower, upper

def analyze_suppression_paradox(exp_dir, megafire_threshold=None, megafire_quantile=0.99):
    """Compare fire-size statistics across ``suppress`` values and plot them.

    Loads the runs with ``load_experiment_fires``, groups their fire sizes by
    the ``suppress`` parameter and, per group, computes bootstrapped mean,
    median and megafire probability plus a two-sample KS test against the
    baseline group (``suppress == 0`` when present, otherwise the smallest
    key). It then fits a run-level logistic regression
    (``had_megafire ~ suppress``), draws a CCDF plot and a megafire-probability
    plot, and runs a chi-square test on the run-level contingency table.

    Args:
        exp_dir: experiment directory handed to ``load_experiment_fires``.
        megafire_threshold: absolute size from which a fire counts as a
            megafire. When None it is taken as ``megafire_quantile`` of all
            positive fire sizes pooled over every run.
        megafire_quantile: quantile used when ``megafire_threshold`` is None.

    Returns:
        dict with ``df_stats`` (one row per suppress value), ``df_runs`` (one
        row per run with a suppress value) and ``megafire_threshold``.

    Raises:
        RuntimeError: when no runs are found, or when the threshold has to be
            derived from the data and there is no positive fire size.
    """
    def _sup_order(key):
        # Runs without a suppress value (None) sort after every numeric value.
        return float('inf') if key is None else key

    def _pool(arrays):
        # np.concatenate raises on an empty sequence, so drop empty arrays and
        # fall back to an empty float array when nothing is left.
        kept = [np.asarray(a, dtype=float) for a in arrays if len(a) > 0]
        return np.concatenate(kept) if kept else np.array([], dtype=float)

    runs = load_experiment_fires(exp_dir)
    if not runs:
        raise RuntimeError("No fire data found")

    # group by suppress value
    grouped = {}
    per_run_summary = []
    for r in runs:
        sup = r["params"].get("suppress", None)
        fires = r["fires"]
        grouped.setdefault(sup, []).append(fires)
        per_run_summary.append({"suppress": sup, "fires": fires})

    # decide threshold: global quantile across all positive fires
    if megafire_threshold is None:
        positive_all = _pool(r["fires"] for r in runs)
        positive_all = positive_all[positive_all > 0]  # ignore zeros if present
        if positive_all.size == 0:
            raise RuntimeError("No fire data found")
        megafire_threshold = np.quantile(positive_all, megafire_quantile)

    # compute statistics per suppress
    ordered_sups = sorted(grouped, key=_sup_order)
    baseline_key = 0 if 0 in grouped else ordered_sups[0]
    baseline_all = _pool(grouped[baseline_key])
    baseline_pos = baseline_all[baseline_all > 0]

    stats_table = []
    for sup in ordered_sups:
        list_runs = grouped[sup]
        all_vec = _pool(list_runs)
        positive = all_vec[all_vec > 0]
        is_mega = (all_vec >= megafire_threshold).astype(int)
        mean, mean_lo, mean_hi = bootstrap_ci(positive, np.mean)
        median, med_lo, med_hi = bootstrap_ci(positive, np.median)
        prob, prob_lo, prob_hi = bootstrap_ci(is_mega, np.mean)
        # KS against baseline
        if len(positive) > 0 and len(baseline_pos) > 0:
            ks_stat, ks_p = stats.ks_2samp(positive, baseline_pos)
        else:
            ks_stat, ks_p = np.nan, np.nan
        stats_table.append({
            "suppress": sup,
            "n_runs": len(list_runs),
            "total_fires": int(len(all_vec)),
            "megafire_prob": prob,
            "megafire_prob_lo": prob_lo,
            "megafire_prob_hi": prob_hi,
            "mean_fire": mean, "mean_lo": mean_lo, "mean_hi": mean_hi,
            "median_fire": median, "median_lo": med_lo, "median_hi": med_hi,
            "ks_stat_vs_baseline": ks_stat, "ks_p_vs_baseline": ks_p
        })

    df_stats = pd.DataFrame(stats_table)
    print("Megafire threshold =", megafire_threshold)
    print(df_stats[["suppress","n_runs","total_fires","megafire_prob","mean_fire","median_fire","ks_p_vs_baseline"]])

    # run-level analysis: did a run have >=1 megafire?
    run_rows = []
    for r in per_run_summary:
        hits = np.asarray(r["fires"]) >= megafire_threshold
        run_rows.append({
            "suppress": r["suppress"],
            "had_megafire": int(np.any(hits)),
            "num_megafires": int(np.sum(hits)),
            "total_fires": int(len(r["fires"]))
        })
    df_runs = pd.DataFrame(run_rows).dropna(subset=["suppress"])

    # logistic regression (had_megafire ~ suppress); best effort, because the
    # fit fails for non-numeric suppress values or degenerate outcomes.
    try:
        df_runs["suppress_num"] = df_runs["suppress"].astype(float)
        X = sm.add_constant(df_runs[["suppress_num"]])
        y = df_runs["had_megafire"]
        logit = sm.Logit(y, X).fit(disp=False)
        print("Logistic regression summary (had_megafire ~ suppress):")
        print(logit.summary())
    except Exception as exc:
        print("Skipping logistic regression (non-numeric suppress or insufficient data)")
        print("  reason:", exc)

    # CCDF plot per suppress
    plt.figure(figsize=(8,6))
    for sup in ordered_sups:
        all_vec = _pool(grouped[sup])
        pos = all_vec[all_vec > 0]
        if len(pos) == 0:
            continue
        sorted_x = np.sort(pos)
        ccdf = 1.0 - np.arange(1, len(sorted_x)+1)/len(sorted_x)
        plt.loglog(sorted_x, ccdf, marker='.', linestyle='none', label=f"suppress={sup}")
    plt.xlabel("Fire size")
    plt.ylabel("CCDF")
    plt.title("CCDF of fire sizes by suppress")
    plt.legend()
    plt.grid(True, which='both', ls='--', alpha=0.4)
    plt.show()

    # megafire probability with CI; error-bar lengths are clipped at zero
    # because a negative length makes errorbar raise.
    plt.figure(figsize=(8,4))
    xs = np.arange(len(df_stats))
    err_lo = np.clip(df_stats["megafire_prob"] - df_stats["megafire_prob_lo"], 0, None)
    err_hi = np.clip(df_stats["megafire_prob_hi"] - df_stats["megafire_prob"], 0, None)
    plt.errorbar(xs, df_stats["megafire_prob"], yerr=[err_lo, err_hi], fmt='o', capsize=4)
    plt.xticks(xs, [str(x) for x in df_stats["suppress"]], rotation=45)
    plt.xlabel("suppress")
    plt.ylabel(f"P(fire >= {megafire_threshold:.3g})")
    plt.title("Megafire probability by suppress")
    plt.grid(True, ls='--', alpha=0.3)
    plt.show()

    # chi-square / contingency: counts of runs with/without megafire by suppress
    contingency = []
    for _sup, grp in df_runs.groupby("suppress"):
        n_with = grp["had_megafire"].sum()
        contingency.append([n_with, len(grp) - n_with])
    if len(contingency) >= 2:
        try:
            _, pval, _, _ = stats.chi2_contingency(np.array(contingency))
            print("Chi-square across suppress groups (run-level had_megafire): p =", pval)
        except ValueError as exc:
            # Raised when an expected frequency is zero, e.g. every run (or no
            # run) had a megafire; report it instead of losing the results.
            print("Chi-square test not applicable:", exc)
    else:
        print("Not enough groups for chi-square test")

    return {"df_stats": df_stats, "df_runs": df_runs, "megafire_threshold": megafire_threshold}

# Example usage: relies on `exp1_dir` set by the earlier "Analyze" cell. No
# absolute threshold is given, so the megafire threshold is the 0.995 quantile
# of the pooled positive fire sizes.
res = analyze_suppression_paradox(exp_dir=exp1_dir, megafire_quantile=0.995)

In [None]:
# Analysis: suppression paradox diagnostics (log-log density, means/medians, maxima, megafire rates)
import os
import glob
import json
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --- CONFIG --- edit these as you like ---
# Directory analysed by this cell: the latest "suppression_test" experiment.
# An explicit path also works, e.g. Path("data/.../experiment_1").
EXP_DIR = get_latest_experiment_dir("suppression_test")  # or set explicit path: Path("data/.../experiment_1")
NBINS = 30                # passed as nbins to log_density_plot, i.e. the number of log-spaced bin EDGES (NBINS-1 bins)
AREA_NORMALIZE = True      # when True the density plot's x-axis is shown as size / L**2 (fraction of grid area)
INCLUDE_ZEROS_IN_DENOM = False  # density normalisation: True divides by all events (including zeros), False by positive sizes only
MEGAFIRE_THRESHOLD = None  # absolute megafire size threshold (None => derive it from MEGAFIRE_FRACTION)
MEGAFIRE_FRACTION = 0.01   # used when MEGAFIRE_THRESHOLD is None: threshold = fraction * L**2 (L = median L over runs)
SHOW_BOXPLOT = True        # NOTE(review): not referenced by the code visible in this cell -- the boxplot is always drawn
VERBOSE = True             # print discovery details and file-read failures
# --- end CONFIG ---

def discover_runs(exp_dir):
    """Collect every run found in an experiment directory.

    Reads each ``debug_param*.json`` file, takes its ``params`` dict and pairs
    it with the fire sizes of a ``fires_param*.csv`` file chosen as follows:

    1. a file whose name tokens (split on ``_`` and ``.``; ``param<N>`` and
       ``id<N>``) give the same ``param_id`` and ``run_id``, or the same
       ``param_id`` with no run id in the name;
    2. otherwise the first file whose name contains ``param<param_id>`` not
       followed by another digit, so ``param1`` never picks a ``param10`` file.
       NOTE(review): this fallback can hand a run the CSV of a different run
       of the same parameter set.

    Fire sizes come from the ``fire_size`` column or, failing that, from the
    first numeric column.

    Args:
        exp_dir: directory (str or Path) holding the debug and fires files.

    Returns:
        list of ``{'params': dict, 'fires': np.ndarray of float}``, one per
        readable debug file in file-name order. ``fires`` is empty when no
        file matched or the CSV could not be read.
    """
    import re

    def _int_suffix(token, prefix):
        # Integer left after removing `prefix` from the token, or None.
        try:
            return int(token.replace(prefix, ""))
        except ValueError:
            return None

    def _read_fires(path, failure_msg):
        # Fire sizes from one CSV; an empty array on any read/convert problem.
        try:
            df = pd.read_csv(path)
            if 'fire_size' in df.columns:
                return df['fire_size'].to_numpy(dtype=float)
            # fallback: first numeric column
            nums = df.select_dtypes(include=[np.number])
            if not nums.empty:
                return nums.iloc[:, 0].to_numpy(dtype=float)
        except (OSError, ValueError, TypeError) as e:
            if VERBOSE: print(failure_msg, path, e)
        return np.array([], dtype=float)

    ed = Path(exp_dir)
    debug_files = sorted(ed.glob("debug_param*.json"))
    fire_files = sorted(ed.glob("fires_param*.csv"))

    # index fire files by the (param, run) ids parsed from their names
    fire_index = {}
    for fire_path in fire_files:
        pid = rid = None
        for tok in fire_path.name.replace(".", "_").split("_"):
            if tok.startswith("param"):
                parsed = _int_suffix(tok, "param")
                if parsed is not None:
                    pid = parsed
            if tok.startswith("id"):
                parsed = _int_suffix(tok, "id")
                if parsed is not None:
                    rid = parsed
        fire_index.setdefault((pid, rid), []).append(fire_path)

    runs = []
    for dbg in debug_files:
        try:
            with open(dbg, 'r') as fh:
                info = json.load(fh)
        except (OSError, ValueError):  # JSONDecodeError is a ValueError
            if VERBOSE: print("failed to read", dbg)
            continue
        params = info.get("params", {}) if isinstance(info, dict) else {}
        pid = params.get("param_id")
        rid = params.get("run_id")

        # prefer an exact (pid, rid) match, then a file for pid without a run id
        candidates = fire_index.get((pid, rid), []) + fire_index.get((pid, None), [])
        chosen = None
        failure_msg = "failed reading fires file"
        if candidates:
            chosen = candidates[0]  # should be unique
        elif pid is not None:
            # Negative lookahead keeps "param1" from matching "param10".
            pattern = re.compile(rf"param{re.escape(str(pid))}(?!\d)")
            chosen = next((fp for fp in fire_files if pattern.search(fp.name)), None)
            failure_msg = "fallback read failed"

        if chosen is None:
            fire_arr = np.array([], dtype=float)
        else:
            fire_arr = _read_fires(chosen, failure_msg)

        runs.append({'params': params, 'fires': np.asarray(fire_arr, dtype=float)})
    return runs

def group_by_suppress(runs):
    """Group per-run fire data by the run's ``suppress`` parameter.

    Args:
        runs: iterable of dicts with a ``'params'`` dict and a ``'fires'`` entry.

    Returns:
        ``(by_sup, Ls)`` where ``by_sup`` maps each suppress value (None when
        the parameter is missing) to the list of ``fires`` entries of its runs,
        and ``Ls`` lists the ``L`` parameter of every run for which it could be
        converted to int.
    """
    by_sup = {}
    Ls = []
    for r in runs:
        params = r['params']
        L = params.get('L', None)
        if L is not None:
            try:
                Ls.append(int(L))
            except (TypeError, ValueError, OverflowError):
                # An L that is not convertible to int is simply left out.
                pass
        by_sup.setdefault(params.get('suppress', None), []).append(r['fires'])
    return by_sup, Ls

def log_density_plot(grouped, area=None, area_norm=True, include_zeros_in_denom=False, nbins=40, ax=None):
    """Plot a log-log fire-size density, one curve per suppress value.

    Args:
        grouped: maps a suppress value to a list of per-run size arrays
            (None entries are skipped).
        area: when given and ``area_norm`` is True, plotted sizes (x values)
            are divided by it, e.g. ``L**2``. The density values are not
            rescaled.
        area_norm: enables the division by ``area``.
        include_zeros_in_denom: if True, counts are divided by the number of
            all events (including non-positive sizes); if False, by the number
            of positive sizes, so the density integrates to 1 over sizes > 0.
        nbins: number of log-spaced bin edges shared by all curves.
        ax: axes to draw on; a new 8x6 figure is created when None.

    Returns:
        The axes. When there is no positive size at all, a notice is printed
        and the axes are returned untouched.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8,6))

    # Pool each group's runs once so both the global limits and the per-group
    # histograms see the same data.
    pooled = {}
    for sup, runs in grouped.items():
        arrays = [np.asarray(arr, dtype=float).ravel() for arr in runs if arr is not None]
        pooled[sup] = np.concatenate(arrays) if arrays else np.array([], dtype=float)

    # global limits across groups (only positive sizes)
    positives = [vec[vec > 0] for vec in pooled.values()]
    all_pos = np.concatenate(positives) if positives else np.array([], dtype=float)
    if all_pos.size == 0:
        print("No positive fire sizes found.")
        return ax

    xmin = float(all_pos.min())
    xmax = float(all_pos.max())
    if xmin == xmax:
        # Every size is identical: widen the range, otherwise all bin edges
        # coincide and the per-width density becomes a division by zero.
        xmin, xmax = xmin / 2.0, xmax * 2.0
    bins = np.logspace(np.log10(xmin), np.log10(xmax), nbins)

    for sup, all_f in sorted(pooled.items(), key=lambda x: (float('inf') if x[0] is None else x[0])):
        positive = all_f[all_f > 0]
        if positive.size == 0:
            continue
        counts, edges = np.histogram(positive, bins=bins)
        widths = np.diff(edges)
        # positive.size > 0 here, so either denominator is non-zero
        denom = len(all_f) if include_zeros_in_denom else positive.size
        density = counts / (denom * widths)   # density per size-unit
        centers = np.sqrt(edges[:-1] * edges[1:])  # geometric bin centres
        # empty bins cannot be shown on a log scale, so they are dropped
        mask = density > 0
        if not np.any(mask):
            continue
        x = centers[mask].astype(float)
        y = density[mask].astype(float)
        if area_norm and area is not None and area > 0:
            x = x / float(area)
        # markers joined by a line; the line skips over the dropped empty bins
        ax.loglog(x, y, marker='o', label=f"suppress={sup}")

    ax.set_xlabel("Fire size (normalized by area if requested)" if area_norm and area is not None else "Fire size")
    ax.set_ylabel("Probability density (log-log)")
    ax.grid(True, which='both', ls='--', alpha=0.4)
    ax.legend(fontsize='small')
    return ax

def mean_median_by_suppress(grouped, include_zeros=True):
    """Summarise per-run mean and median fire size for each suppress value.

    For every run the mean and median are taken over all sizes
    (``include_zeros=True``) or over the positive sizes only; a run with no
    usable sizes contributes NaN. The per-run values are then averaged with
    NaNs ignored.

    Returns a DataFrame, ordered by suppress value (None last), with columns
    suppress, n_runs (runs with a non-NaN mean), mean_of_runs_mean,
    std_of_runs_mean, mean_of_runs_median and std_of_runs_median.
    """
    def _sup_order(item):
        key = item[0]
        return float('inf') if key is None else key

    records = []
    for sup, runs in sorted(grouped.items(), key=_sup_order):
        run_means = []
        run_medians = []
        for arr in runs:
            sample = arr if include_zeros else arr[arr>0]
            if sample.size == 0:
                run_means.append(np.nan)
                run_medians.append(np.nan)
                continue
            run_means.append(np.mean(sample))
            run_medians.append(np.median(sample))

        run_means = np.array(run_means, dtype=float)
        run_medians = np.array(run_medians, dtype=float)
        usable = ~np.isnan(run_means)
        records.append({
            'suppress': sup,
            'n_runs': int(np.sum(usable)),
            'mean_of_runs_mean': np.nanmean(run_means),
            'std_of_runs_mean': np.nanstd(run_means),
            'mean_of_runs_median': np.nanmean(run_medians),
            'std_of_runs_median': np.nanstd(run_medians),
        })
    return pd.DataFrame(records)

def per_run_maxima_boxplot(grouped, ax=None):
    """Box-plot the largest fire of each run, one box per suppress value.

    Runs without any fire (empty array) and runs whose maximum is NaN are left
    out. Boxes are ordered by suppress value with None last. A new 8x5 figure
    is created when ``ax`` is None. Returns the axes.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8,5))

    ordered = sorted(grouped.items(), key=lambda kv: float('inf') if kv[0] is None else kv[0])
    labels = [str(sup) for sup, _ in ordered]
    data = []
    for _, runs in ordered:
        maxima = (float(np.max(arr)) for arr in runs if arr.size != 0)
        data.append([m for m in maxima if not np.isnan(m)])

    ax.boxplot(data, labels=labels, showfliers=True)
    ax.set_xlabel("suppress")
    ax.set_ylabel("Per-run max fire size")
    ax.set_title("Per-run maximum fire size distribution by suppression")
    ax.grid(alpha=0.2)
    return ax

def megafire_stats(grouped, Ls, threshold=None, fraction=None):
    """Megafire frequency per suppress value.

    A fire is a megafire when its size is >= ``threshold``. When ``threshold``
    is None it is computed as ``fraction * L_rep**2`` with ``L_rep`` the
    integer median of ``Ls``.

    Returns a DataFrame ordered by suppress value (None last) with columns
    suppress, per_fire_prob (share of all fires that are megafires),
    per_run_prob (share of runs with at least one megafire), n_runs and
    threshold. Probabilities are NaN for groups without fires / runs.

    Raises ValueError when neither ``threshold`` nor ``fraction`` is given, or
    when ``fraction`` is given but ``Ls`` is empty.
    """
    # representative grid size
    L_rep = int(np.median(Ls)) if Ls else None
    if threshold is None:
        if fraction is None:
            raise ValueError("Either threshold or fraction must be provided")
        if L_rep is None:
            raise ValueError("No L information to compute absolute threshold from fraction")
        threshold = fraction * (L_rep**2)

    def _summarise(sup, runs):
        pooled = np.concatenate(runs) if runs else np.array([], dtype=float)
        n_fires = len(pooled)
        if n_fires == 0:
            fire_share = np.nan
        else:
            fire_share = np.sum(pooled >= threshold) / n_fires
        # 1 for every run that contains at least one megafire
        flags = [int(np.any(arr >= threshold)) for arr in runs]
        run_share = np.mean(flags) if len(flags)>0 else np.nan
        return {'suppress': sup, 'per_fire_prob': fire_share, 'per_run_prob': run_share,
                'n_runs': len(runs), 'threshold': threshold}

    ordered = sorted(grouped.items(), key=lambda x: (float('inf') if x[0] is None else x[0]))
    return pd.DataFrame([_summarise(sup, runs) for sup, runs in ordered])

# ---------------------- perform analysis ----------------------
# Load every run of EXP_DIR and group the per-run fire arrays by `suppress`.
runs = discover_runs(EXP_DIR)
if VERBOSE:
    print(f"Discovered {len(runs)} runs in {EXP_DIR}")

grouped, Ls = group_by_suppress(runs)
if VERBOSE:
    # None (runs without a suppress value) is sorted last
    print("Suppression keys found:", sorted(grouped.keys(), key=lambda x: (float('inf') if x is None else x)))

# choose threshold: the explicit MEGAFIRE_THRESHOLD wins; otherwise it is
# MEGAFIRE_FRACTION of the grid area, using the integer median L over all runs.
abs_threshold = MEGAFIRE_THRESHOLD
if abs_threshold is None:
    if MEGAFIRE_FRACTION is not None:
        if len(Ls)==0:
            raise RuntimeError("No L found in run params to compute fraction-based threshold.")
        L_rep = int(np.median(Ls))
        abs_threshold = MEGAFIRE_FRACTION * (L_rep**2)
    else:
        raise RuntimeError("No megafire threshold/fraction provided")

# 1) log-log density plot. With AREA_NORMALIZE the sizes themselves are divided
#    by the grid area inside log_density_plot, so tick positions and labels
#    always agree. Relabelling the ticks of a log axis by hand (set_xticklabels
#    without fixed tick positions) can leave labels in the wrong place.
#    The density values stay per raw size unit; only the x values are rescaled.
fig, ax = plt.subplots(figsize=(9,6))
# grid area from the integer median L over all runs (1 when no L is known)
area = (int(np.median(Ls))**2) if Ls else 1
log_density_plot(grouped, area=area, area_norm=AREA_NORMALIZE,
                 include_zeros_in_denom=INCLUDE_ZEROS_IN_DENOM, nbins=NBINS, ax=ax)
if AREA_NORMALIZE:
    ax.set_xlabel("Fire size (fraction of grid area)")
plt.title("Log-log fire-size density by suppression (bins log-spaced)")
plt.show()

# 2) Mean & median per suppression, with and without zeros
# NOTE(review): df_incl is only used by the commented-out errorbar calls below,
# so the plots currently show the zero-excluded statistics only.
df_incl = mean_median_by_suppress(grouped, include_zeros=True)
df_excl = mean_median_by_suppress(grouped, include_zeros=False)

fig, axes = plt.subplots(1,2, figsize=(12,4))
# means (error bars = standard deviation of the per-run means)
# axes[0].errorbar(df_incl['suppress'].astype(str), df_incl['mean_of_runs_mean'], yerr=df_incl['std_of_runs_mean'], fmt='o-', label='include zeros')
axes[0].errorbar(df_excl['suppress'].astype(str), df_excl['mean_of_runs_mean'], yerr=df_excl['std_of_runs_mean'], fmt='s--')
axes[0].set_xlabel("suppression limit"); axes[0].set_ylabel("Mean fire size (per-run average)"); axes[0].set_title("Mean fire size by suppression limit"); axes[0].grid(alpha=0.3)


# medians (error bars = standard deviation of the per-run medians)
# axes[1].errorbar(df_incl['suppress'].astype(str), df_incl['mean_of_runs_median'], yerr=df_incl['std_of_runs_median'], fmt='o-', label='include zeros')
axes[1].errorbar(df_excl['suppress'].astype(str), df_excl['mean_of_runs_median'], yerr=df_excl['std_of_runs_median'], fmt='s--')
axes[1].set_xlabel("suppression limit"); axes[1].set_ylabel("Median fire size (per-run median)"); axes[1].set_title("Median fire size by suppression limit"); axes[1].grid(alpha=0.3)

plt.tight_layout(); plt.show()

# 3) per-run maxima boxplot (and print overall maxima)
fig, ax = plt.subplots(figsize=(10,5))
per_run_maxima_boxplot(grouped, ax=ax)
plt.show()

# Largest fire over all runs of each suppress value; NaN when the group has no fires.
overall_max = {sup: (np.max(np.concatenate(runs)) if len(runs)>0 and np.concatenate(runs).size>0 else np.nan) for sup, runs in grouped.items()}
print("Overall max per suppress:")
for k,v in sorted(overall_max.items(), key=lambda x: (float('inf') if x[0] is None else x[0])):
    print("suppress", k, "max", v)

# 4) megafire percentages, using the absolute threshold chosen earlier in this cell
df_meg = megafire_stats(grouped, Ls, threshold=abs_threshold)
print("\nMegafire threshold (absolute):", abs_threshold)
print(df_meg[['suppress','per_fire_prob','per_run_prob','n_runs']])

# Left: share of individual fires that are megafires; right: share of runs with at least one.
fig, axes = plt.subplots(1,2, figsize=(12,4))
axes[0].bar(df_meg['suppress'].astype(str), df_meg['per_fire_prob']*100)
axes[0].set_ylabel("% of fires >= threshold")
axes[0].set_title("Per-fire megafire percentage (%) by suppress")
axes[0].grid(alpha=0.3)
axes[1].bar(df_meg['suppress'].astype(str), df_meg['per_run_prob']*100)
axes[1].set_ylabel("% of runs with >=1 megafire")
axes[1].set_title("Per-run megafire probability (%) by suppress")
axes[1].grid(alpha=0.3)
plt.tight_layout(); plt.show()