# Density × speed-transfer decay sweep

Sweeps density (controls agents per group) and speed-transfer decay to see how much rhythmicity leaks into the nominally non-rhythmic group.

- Agents per group = `500 * density_factor` (so 0.40 ≈ 200 agents).
- Day duration: 200 steps; simulation length: 800 steps (4 days).
- Defaults otherwise unchanged; `.mat` saving disabled for speed.

In [28]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from simulation import (
    AblationConfig,
    OutputOptions,
    SimulationParameters,
    default_simulation_config,
    run_ablation_study,
)

plt.style.use("seaborn-v0_8")

In [29]:
# Sweep definition
density_factors = [0.20, 0.40, 0.60, 0.80, 1.00, 1.20, 1.40]
speed_transfer_decays = [0.0, 0.4, 0.8, 0.95]
num_runs = 5
rhythmicity_permutations = 100

cfg = default_simulation_config()
cfg.sim = SimulationParameters(day_duration=300, sim_duration=1200)
cfg.output = OutputOptions(output_dir=Path("visualizations/output"))


In [30]:
cache_dir = Path("visualizations/sweep_cache")
cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / "density_decay_sweep.csv"

sweep_meta = {
    "day_duration": cfg.sim.day_duration,
    "sim_duration": cfg.sim.sim_duration,
    "rhythmicity_permutations": rhythmicity_permutations,
    "num_runs": num_runs,
}

df = None

if cache_file.exists():
    cached = pd.read_csv(cache_file)
    required_cols = {
        "density_factor",
        "speed_transfer_decay",
        "day_duration",
        "sim_duration",
        "rhythmicity_permutations",
        "num_runs",
        "run_id",
    }
    if required_cols.issubset(cached.columns):
        cached_filtered = cached[
            cached["density_factor"].isin(density_factors)
            & cached["speed_transfer_decay"].isin(speed_transfer_decays)
            & (cached["day_duration"] == sweep_meta["day_duration"])
            & (cached["sim_duration"] == sweep_meta["sim_duration"])
            & (cached["rhythmicity_permutations"] == sweep_meta["rhythmicity_permutations"])
            & (cached["num_runs"] == sweep_meta["num_runs"])
        ].copy()
        counts = cached_filtered.groupby(["density_factor", "speed_transfer_decay"]).run_id.nunique()
        if counts.reindex(pd.MultiIndex.from_product([density_factors, speed_transfer_decays], names=["density_factor", "speed_transfer_decay"])).fillna(0).ge(num_runs).all():
            df = cached_filtered
            print(
                f"Loaded cached sweep results from {cache_file}. Delete this file to force recompute after model changes."
            )
        else:
            print(f"Cached sweep at {cache_file} is incomplete for this config; recomputing and overwriting.")
    else:
        print(f"Cached sweep at {cache_file} missing metadata columns; recomputing and overwriting.")

if df is None:
    records = []
    seed_gen = np.random.default_rng(42)

    for density in density_factors:
        for decay in speed_transfer_decays:
            seeds = seed_gen.integers(0, 1_000_000_000, size=num_runs)
            for run_id, seed in enumerate(seeds):
                label = f"d{density:.2f}_decay{decay:.2f}_r{run_id}"
                ablations = {label: AblationConfig(speed_transfer_decay=decay)}
                rng = np.random.default_rng(int(seed))
                results = run_ablation_study(
                    density_factor=density,
                    ablations=ablations,
                    config=cfg,
                    rhythmicity_permutations=rhythmicity_permutations,
                    rng=rng,
                )
                summary = results[label]
                records.append(
                    {
                        "label": label,
                        "density_factor": density,
                        "agents_per_group": 500 * density,
                        "speed_transfer_decay": decay,
                        "amplitude_group1": summary.amplitude_group1,
                        "amplitude_group2": summary.amplitude_group2,
                        "phase_group1": summary.phase_group1,
                        "phase_group2": summary.phase_group2,
                        "phase_shift": summary.phase_shift_g2_minus_g1,
                        "p_value_group1": summary.rhythmicity_p_value_group1,
                        "p_value_group2": summary.rhythmicity_p_value_group2,
                        "day_duration": sweep_meta["day_duration"],
                        "sim_duration": sweep_meta["sim_duration"],
                        "rhythmicity_permutations": sweep_meta["rhythmicity_permutations"],
                        "num_runs": sweep_meta["num_runs"],
                        "run_id": run_id,
                        "seed": int(seed),
                    }
                )

    df = pd.DataFrame(records)
    df.sort_values(["density_factor", "speed_transfer_decay", "run_id"], inplace=True)
    df.to_csv(cache_file, index=False)
    print(f"Saved sweep results to {cache_file} for reuse. Delete this file to recompute.")

# Aggregate across runs for plotting with confidence intervals
agg_df = (
    df.groupby(["density_factor", "speed_transfer_decay"])
    .agg(
        agents_per_group=("agents_per_group", "first"),
        amplitude_group2_mean=("amplitude_group2", "mean"),
        amplitude_group2_std=("amplitude_group2", "std"),
        amplitude_group2_count=("amplitude_group2", "count"),
        amplitude_group1_mean=("amplitude_group1", "mean"),
        amplitude_group1_std=("amplitude_group1", "std"),
        phase_shift_mean=("phase_shift", "mean"),
        phase_shift_std=("phase_shift", "std"),
    )
    .reset_index()
)
agg_df["amplitude_group2_sem"] = agg_df["amplitude_group2_std"] / np.sqrt(agg_df["amplitude_group2_count"].clip(lower=1))
agg_df["amplitude_group2_ci95"] = 1.96 * agg_df["amplitude_group2_sem"]
agg_df



=== Ablation: d0.20_decay0.00_r0 ===
density factor: 0.200
Sim density 0.200 [##############################] 100.0%
group 1 (blue): speeds mu=3.7596, std=1.0030
group 2 (green): speeds mu=1.0368, std=0.5746
Summary density 0.2: amp1=1.247, amp2=0.014, phase1=359.9°, phase2=1.3°, phase_shift=1.4°, p1=0.0099, p2=0.0099

=== Ablation: d0.20_decay0.00_r1 ===
density factor: 0.200
Sim density 0.200 [##############################] 100.0%
group 1 (blue): speeds mu=3.7632, std=1.0042
group 2 (green): speeds mu=1.0418, std=0.5835
Summary density 0.2: amp1=1.255, amp2=0.015, phase1=0.4°, phase2=17.0°, phase_shift=16.5°, p1=0.0099, p2=0.0099

=== Ablation: d0.20_decay0.00_r2 ===
density factor: 0.200
Sim density 0.200 [##############################] 100.0%
group 1 (blue): speeds mu=3.7647, std=1.0030
group 2 (green): speeds mu=1.0365, std=0.5795
Summary density 0.2: amp1=1.252, amp2=0.019, phase1=0.0°, phase2=358.3°, phase_shift=-1.8°, p1=0.0099, p2=0.0099

=== Ablation: d0.20_decay0.00_r3 ==

Unnamed: 0,density_factor,speed_transfer_decay,agents_per_group,amplitude_group2_mean,amplitude_group2_std,amplitude_group2_count,amplitude_group1_mean,amplitude_group1_std,phase_shift_mean,phase_shift_std,amplitude_group2_sem,amplitude_group2_ci95
0,0.2,0.0,100.0,0.015291,0.002189,5,1.250802,0.003217,4.111357,7.525904,0.000979,0.001919
1,0.2,0.4,100.0,0.01801,0.00225,5,1.252753,0.003095,11.484133,16.994736,0.001006,0.001972
2,0.2,0.8,100.0,0.064346,0.005779,5,1.252567,0.001887,17.880986,9.714101,0.002585,0.005066
3,0.2,0.95,100.0,0.190228,0.016312,5,1.245406,0.004609,20.786926,6.132461,0.007295,0.014298
4,0.4,0.0,200.0,0.027222,0.004345,5,1.248716,0.003156,2.405929,8.776775,0.001943,0.003809
5,0.4,0.4,200.0,0.043739,0.002504,5,1.249946,0.00305,7.738854,6.101238,0.00112,0.002194
6,0.4,0.8,200.0,0.113039,0.006556,5,1.249952,0.006138,7.097428,4.489869,0.002932,0.005747
7,0.4,0.95,200.0,0.281959,0.0137,5,1.241706,0.006027,16.728592,6.244572,0.006127,0.012009
8,0.6,0.0,300.0,0.036809,0.002995,5,1.249419,0.002682,7.461744,1.630266,0.001339,0.002625
9,0.6,0.4,300.0,0.061083,0.003802,5,1.249428,0.001799,7.22004,3.782369,0.0017,0.003332


In [None]:
# Heatmaps for non-rhythmic group metrics
pivot_amp = agg_df.pivot(
    index="density_factor", columns="speed_transfer_decay", values="amplitude_group2_mean"
).sort_index()
pivot_phase = agg_df.pivot(
    index="density_factor", columns="speed_transfer_decay", values="phase_shift_mean"
).sort_index()

fig, axes = plt.subplots(1, 2, figsize=(12, 4), constrained_layout=True)

im0 = axes[0].imshow(pivot_amp.values, origin="lower", aspect="auto", cmap="magma")
axes[0].set_xticks(range(len(pivot_amp.columns)))
axes[0].set_xticklabels([f"{c:.2f}" for c in pivot_amp.columns])
axes[0].set_yticks(range(len(pivot_amp.index)))
axes[0].set_yticklabels([f"{r:.2f}" for r in pivot_amp.index])
axes[0].set_xlabel("speed_transfer_decay")
axes[0].set_ylabel("density_factor")
axes[0].set_title("Amplitude (Group 2, mean across runs)")
fig.colorbar(im0, ax=axes[0], label="speed amplitude")

im1 = axes[1].imshow(pivot_phase.values, origin="lower", aspect="auto", cmap="twilight")
axes[1].set_xticks(range(len(pivot_phase.columns)))
axes[1].set_xticklabels([f"{c:.2f}" for c in pivot_phase.columns])
axes[1].set_yticks(range(len(pivot_phase.index)))
axes[1].set_yticklabels([f"{r:.2f}" for r in pivot_phase.index])
axes[1].set_xlabel("speed_transfer_decay")
axes[1].set_title("Phase shift (Group2 - Group1, mean degrees)")
fig.colorbar(im1, ax=axes[1], label="degrees")

plt.show()


In [None]:
# Line view: amplitude of non-rhythmic group across decay for each density (mean ± 95% CI)
fig, ax = plt.subplots(figsize=(8, 4))
for density in density_factors:
    subset = agg_df[agg_df["density_factor"] == density]
    ax.errorbar(
        subset["speed_transfer_decay"],
        subset["amplitude_group2_mean"],
        yerr=subset["amplitude_group2_ci95"],
        marker="o",
        capsize=4,
        label=f"density {density:.2f} (agents≈{int(500*density)})",
    )
ax.set_xlabel("speed_transfer_decay")
ax.set_ylabel("Amplitude (Group 2)")
ax.set_title("Non-rhythmic group entrainment vs decay (mean ±95% CI)")
ax.legend()
plt.tight_layout()
plt.show()
