# Consist End-to-End (Monte Carlo Predator–Prey Sweep)

This notebook is a **happy-path** walkthrough of Consist in a scientific simulation workflow:

1. Create a sweep registry (parameters + seeds)
2. Run one simulation per Consist step (so each sweep member has provenance)
3. Aggregate per-sim outputs into canonical analysis tables
4. Ingest into DuckDB and persist schema profiles
5. Export SQLModel classes from the persisted schema
6. Import the generated module and do typed joins/analysis
7. Re-run to demonstrate caching / minimal recomputation


## 0) Setup

This notebook writes run outputs under `examples/runs/`, a new per-session DuckDB file (`examples/predator_prey_demo_<session-id>.duckdb`), and generated SQLModel code under `examples/generated/`.

In [None]:
from __future__ import annotations

import sys
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.display import display
from sqlmodel import select
import uuid

import consist
from consist import Tracker
from consist.models import Artifact, Run

def _find_repo_root(start: Path) -> Path:
    for candidate in (start, *start.parents):
        if (candidate / "pyproject.toml").exists():
            return candidate
    raise RuntimeError("Could not locate repo root (missing pyproject.toml)")


# Put the repo root at the front of the import path so the `examples.*` helper
# modules imported below resolve regardless of the kernel's working directory.
REPO_ROOT = _find_repo_root(Path.cwd())
_repo_root_entry = str(REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)

from examples.pipeline_steps import (
    aggregate_parquet,
    build_sweep_registry,
    make_run_id,
    run_one_simulation_with_raw,
    write_npz,
    write_parquet,
)
from examples.synth_simulation import PredatorPreyConfig

sns.set_theme(style="whitegrid")

# --- Output locations -------------------------------------------------------
EXAMPLES_DIR = REPO_ROOT / "examples"
RUN_DIR = EXAMPLES_DIR / "runs" / "predator_prey_demo"
GENERATED_DIR = EXAMPLES_DIR / "generated"

# The DuckDB file name carries a random 8-hex-char suffix, so each execution of
# this cell points the tracker at a new database file (old files are not removed).
SESSION_ID = uuid.uuid4().hex[:8]
DB_PATH = EXAMPLES_DIR / f"predator_prey_demo_{SESSION_ID}.duckdb"

GENERATED_DIR.mkdir(parents=True, exist_ok=True)
tracker = Tracker(run_dir=RUN_DIR, db_path=DB_PATH)


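**Caching note:** Consist’s cache signature includes a code identity (Git SHA). If your repo has uncommitted *tracked* changes, Consist uses a stable hash of the diff, so cache hits still work during iteration. Untracked run outputs and DB files are ignored. Run history is recorded in the DuckDB file, and the setup cell creates a new one each session, so expect cache hits only for work done earlier in the same session.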

## 1) Define the sweep

We sweep a small grid of parameters, with several replicates per setting. Each registry row is one simulation and gets a stable `sim_id` and its own seed.

In [None]:
# Tuned grid to highlight a transition region (coexistence ↔ extinction).
# The stochastic dynamics can be sensitive; this set usually shows mixed regimes.
registry = build_sweep_registry(
    prey_birth_rates=[0.60, 0.75],
    predation_rates=[0.008, 0.011, 0.014],
    predator_death_rates=[0.15, 0.20],
    predator_birth_efficiency=0.20,
    replicates_per_setting=10,
    seed=7,
)

n_sims = len(registry)
n_settings = registry["setting_id"].nunique()
print(f"sims={n_sims} settings={n_settings}")
registry.head()

## 2) Cold run: one simulation per Consist step

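We open a scenario (a parent "header" run) and log the sweep registry as its first step. We then execute each `sim_id` as a child step and finally summarize the sweep. Each simulation step writes into its own `sim_<id>/` directory:

- a one-row metrics table (`sim_*/metrics.parquet`)
- a sampled time-series table (`sim_*/series.parquet`)
- the raw arrays (`sim_*/raw.npz`)

Sims whose step is a cache hit are skipped.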


In [None]:
from dataclasses import asdict
from tqdm import tqdm


def _py_scalar(v):
    return v.item() if hasattr(v, "item") else v


# Fixed (non-swept) simulation settings shared by every sweep member.
# predator_birth_efficiency matches the value passed to build_sweep_registry above.
base_config = PredatorPreyConfig(
    steps=250,
    sample_every=1,
    prey_init=80,
    predator_init=25,
    predator_birth_efficiency=0.20,
)

# Every artifact of this scenario lives under one directory tree.
scenario_id = "predator_prey_sweep"
artifact_root = RUN_DIR / "artifacts" / scenario_id
registry_path = artifact_root / "sweep_registry.parquet"

# Artifact handles filled in by the scenario below. The `sample_*` handles are
# only set by a freshly computed (non-cached) sim.
registry_artifact: Artifact | None = None
summary_artifact: Artifact | None = None
sample_metrics_artifact: Artifact | None = None
sample_series_artifact: Artifact | None = None

# Tally of cache outcomes across the per-sim steps.
cache_hits = cache_misses = 0

with tracker.scenario(
    scenario_id,
    config={
        "base": asdict(base_config),
        "sweep": {
            "prey_birth_rates": sorted(registry["prey_birth_rate"].unique().tolist()),
            "predation_rates": sorted(registry["predation_rate"].unique().tolist()),
            "predator_death_rates": sorted(registry["predator_death_rate"].unique().tolist()),
            # Assumes replicate_id is 0-based (replicate count = max id + 1).
            "replicates_per_setting": int(registry["replicate_id"].max() + 1),
        },
    },
    tags=["examples", "simulation", "predator_prey"],
) as scenario:
    # Step: write a sweep registry artifact (parameters + seeds). It is the
    # declared input of every sim step and of the summary step below.
    with scenario.step(
        name="registry",
        model="registry",
        cache_mode="overwrite",
        facet={"rows": int(len(registry)), "settings": int(registry["setting_id"].nunique())},
    ) as t:
        write_parquet(registry, registry_path)
        registry_artifact = t.log_artifact(
            registry_path,
            key="sweep_registry",
            direction="output",
            rows=int(len(registry)),
        )

    assert registry_artifact is not None

    # Steps: run one simulation per sweep row, each as its own child step, so every
    # sweep member gets its own provenance record and cache decision.
    for row in tqdm(registry.to_dict(orient="records"), total=len(registry)):
        sim_cfg = {k: _py_scalar(v) for k, v in dict(row).items()}
        sim_id = int(sim_cfg["sim_id"])
        setting_id = int(sim_cfg["setting_id"])
        replicate_id = int(sim_cfg["replicate_id"])

        sim_dir = artifact_root / f"sim_{sim_id:04d}"
        metrics_path = sim_dir / "metrics.parquet"
        series_path = sim_dir / "series.parquet"
        raw_path = sim_dir / "raw.npz"

        run_id = make_run_id(scenario_id=scenario_id, sim_id=sim_id)
        # NOTE(review): only these base_config fields plus the registry row form the
        # step config; other base_config fields (e.g. `sample_every`) are not listed
        # explicitly — confirm that changing them invalidates cached sims.
        with scenario.step(
            name=f"sim_{sim_id:04d}",
            run_id=run_id,
            model="simulate",
            cache_mode="reuse",
            inputs=[registry_artifact],
            config={
                **{
                    "steps": int(base_config.steps),
                    "dt": float(base_config.dt),
                    "prey_init": int(base_config.prey_init),
                    "predator_init": int(base_config.predator_init),
                },
                **sim_cfg,
            },
            facet={
                "sim_id": sim_id,
                "setting_id": setting_id,
                "replicate_id": replicate_id,
                "prey_birth_rate": float(sim_cfg["prey_birth_rate"]),
                "predation_rate": float(sim_cfg["predation_rate"]),
                "predator_death_rate": float(sim_cfg["predator_death_rate"]),
            },
            tags=["sim"],
        ) as t:
            if t.is_cached:
                # Consist matched an earlier identical run; skip the simulation.
                cache_hits += 1
                continue

            cache_misses += 1
            raw_arrays, series_df, metrics_df = run_one_simulation_with_raw(
                base_config=base_config,
                registry_row=sim_cfg,
            )
            write_parquet(metrics_df, metrics_path)
            write_parquet(series_df, series_path)
            write_npz(raw_arrays, raw_path)

            metrics_art = t.log_artifact(
                metrics_path, key="sim_metrics", direction="output", rows=int(len(metrics_df))
            )
            series_art = t.log_artifact(
                series_path, key="sim_series", direction="output", rows=int(len(series_df))
            )
            t.log_artifact(raw_path, key="sim_raw", direction="output")

            # Keep the first freshly computed sim's artifacts for schema ingest/export.
            # NOTE: if every sim is a cache hit (e.g. this cell is re-run in the same
            # session), these stay None and the asserts after this cell fail.
            if sample_metrics_artifact is None:
                sample_metrics_artifact = metrics_art
            if sample_series_artifact is None:
                sample_series_artifact = series_art

    # Step: summarize the sweep into a single results table
    with scenario.step(
        name="summarize",
        model="summarize",
        cache_mode="overwrite",
        inputs=[registry_artifact],
    ) as t:
        # Aggregate exactly the sims listed in the current registry. `artifact_root`
        # is a fixed path reused across sessions, so globbing `sim_*/metrics.parquet`
        # could also pick up stale sim directories left by an earlier, larger sweep.
        metrics_paths = [
            artifact_root / f"sim_{sid:04d}" / "metrics.parquet"
            for sid in sorted(int(s) for s in registry["sim_id"])
        ]
        metrics_all = aggregate_parquet(paths=metrics_paths)
        # One row per parameter setting. The extinction "rates" are means of the
        # per-sim extinction columns (presumably 0/1 flags — confirm).
        summary = (
            metrics_all.groupby(
                ["prey_birth_rate", "predation_rate", "predator_death_rate"], dropna=False
            )
            .agg(
                sims=("sim_id", "count"),
                prey_extinct_rate=("prey_extinct", "mean"),
                predator_extinct_rate=("predator_extinct", "mean"),
                mean_prey_final=("prey_final", "mean"),
                mean_predator_final=("predator_final", "mean"),
                mean_prey_std=("prey_std", "mean"),
                mean_predator_std=("predator_std", "mean"),
            )
            .reset_index()
        )
        # Combined variability of both populations; larger means stronger oscillation.
        summary["oscillation_score"] = summary["mean_prey_std"] + summary["mean_predator_std"]
        summary_path = artifact_root / "sweep_summary.parquet"
        write_parquet(summary, summary_path)
        summary_artifact = t.log_artifact(
            summary_path,
            key="sweep_summary",
            direction="output",
            rows=int(len(summary)),
        )

print(f"cache_misses={cache_misses} cache_hits={cache_hits}")

# Later cells (ingest / schema export) need all of these handles. The sample
# handles remain None when no sim was freshly computed in the run above.
assert registry_artifact is not None
assert sample_metrics_artifact is not None
assert sample_series_artifact is not None
assert summary_artifact is not None

# Parameter settings with the strongest combined oscillation first.
top_oscillating = pd.read_parquet(artifact_root / "sweep_summary.parquet").sort_values(
    ["oscillation_score"], ascending=False
)
top_oscillating.head(8)


## 3) Sweep summary

A single summary artifact makes it easy to see “regimes” (stable coexistence vs extinction) without looking at every trajectory.

In [None]:
summary_df = pd.read_parquet(artifact_root / "sweep_summary.parquet")

# A heatmap shows two parameters, so fix predator_death_rate to its smallest value.
death_rate = float(sorted(summary_df["predator_death_rate"].unique())[0])
slice_df = summary_df.loc[summary_df["predator_death_rate"] == death_rate]

# Rows: prey_birth_rate, columns: predation_rate, cells: predator extinction rate.
pivot = slice_df.pivot_table(
    index="prey_birth_rate",
    columns="predation_rate",
    values="predator_extinct_rate",
)

ax = sns.heatmap(pivot.sort_index(), annot=True, fmt=".2f", cmap="mako", vmin=0, vmax=1)
ax.set(
    title=f"Predator extinction rate (predator_death_rate={death_rate})",
    xlabel="predation_rate",
    ylabel="prey_birth_rate",
)
plt.show()


## 4) Ingest + schema export

Ingest the sweep registry plus one representative sim's metrics and series tables into DuckDB, which persists their schema profiles. Then export SQLModel code for each of them into `examples/generated/`.

In [None]:
# Ingest the sweep registry plus ONE representative sim's metrics and series so
# each has a persisted schema profile to export. The other sims are not ingested;
# they stay "cold" on disk and are queried through views later.
for _artifact in (registry_artifact, sample_metrics_artifact, sample_series_artifact):
    tracker.ingest(_artifact)

metrics_stub_path = GENERATED_DIR / "predator_prey_metrics.py"
registry_stub_path = GENERATED_DIR / "predator_prey_registry.py"
series_stub_path = GENERATED_DIR / "predator_prey_series.py"

# (source artifact, output module path, generated class name, table name)
_schema_exports = [
    (sample_metrics_artifact, metrics_stub_path, "PredatorPreyMetrics", "sim_metrics"),
    (registry_artifact, registry_stub_path, "PredatorPreySweepRegistry", "sweep_registry"),
    (sample_series_artifact, series_stub_path, "PredatorPreySeries", "sim_series"),
]
for _artifact, _out_path, _class_name, _table_name in _schema_exports:
    tracker.export_schema_sqlmodel(
        artifact_id=_artifact.id,
        out_path=_out_path,
        class_name=_class_name,
        table_name=_table_name,
    )

print("Wrote:")
for _out_path in (metrics_stub_path, registry_stub_path, series_stub_path):
    print("-", _out_path)


## 5) Warm analysis (typed)

Import the generated SQLModel classes for comparison, and register hybrid views using the checked-in "contract" models. Then run a typed join and a small analysis (extinction rates by parameter setting).

In [None]:
import importlib

# Compare: generated stubs vs checked-in "contract" models.
# - Generated stubs are conservative and meant for editing.
# - Checked-in models (imported below) demonstrate reviewed PK/FK/index choices.
#
# `import_module` returns the already-loaded module if it was imported earlier in
# this kernel; `reload` re-executes it so freshly regenerated stub files are used.
_generated_modules = {
    _name: importlib.import_module(f"examples.generated.{_name}")
    for _name in ("predator_prey_metrics", "predator_prey_registry", "predator_prey_series")
}
for _module in _generated_modules.values():
    importlib.reload(_module)

predator_prey_metrics = _generated_modules["predator_prey_metrics"]
predator_prey_registry = _generated_modules["predator_prey_registry"]
predator_prey_series = _generated_modules["predator_prey_series"]

GeneratedPredatorPreyMetrics = predator_prey_metrics.PredatorPreyMetrics
GeneratedPredatorPreySweepRegistry = predator_prey_registry.PredatorPreySweepRegistry
GeneratedPredatorPreySeries = predator_prey_series.PredatorPreySeries

from examples.checked_models import (
    PredatorPreyMetricsChecked,
    PredatorPreySeriesChecked,
    PredatorPreySweepRegistryChecked,
)

print(
    "Generated (example):",
    GeneratedPredatorPreyMetrics.__name__,
    "abstract=",
    getattr(GeneratedPredatorPreyMetrics, "__abstract__", False),
)
print("Checked-in:", PredatorPreyMetricsChecked.__name__)

# Register typed hybrid views. Each view unifies
#   - "hot" rows (ingested tables), and
#   - "cold" rows (Parquet artifacts across many runs),
# presumably selected by the artifact `key` used in `log_artifact`.
_view_specs = (
    (PredatorPreyMetricsChecked, "sim_metrics"),
    (PredatorPreySweepRegistryChecked, "sweep_registry"),
    (PredatorPreySeriesChecked, "sim_series"),
)
for _model, _key in _view_specs:
    tracker.views.register(_model, key=_key)

# Registered views are exposed on `tracker.views` under the model's class name.
MetricsView = tracker.views.PredatorPreyMetricsChecked
RegistryView = tracker.views.PredatorPreySweepRegistryChecked
SeriesView = tracker.views.PredatorPreySeriesChecked

with consist.db_session(tracker=tracker) as session:
    # Typed join: per-sim metrics + sweep registry, for the cold-run scenario.
    # NOTE(review): only the metrics side is filtered by scenario; if RegistryView
    # sees more than one sweep_registry artifact (e.g. the registry step was re-run),
    # joined rows may be duplicated — confirm.
    metrics_with_registry = (
        select(MetricsView, RegistryView)
        .join(RegistryView, MetricsView.sim_id == RegistryView.sim_id)
        .where(MetricsView.consist_scenario_id == scenario_id)
    )
    joined = session.exec(metrics_with_registry).all()

# Each joined result is a (metrics_row, registry_row) pair.
metrics_df = pd.DataFrame([m.model_dump() for m, _ in joined])
registry_df = pd.DataFrame([r.model_dump() for _, r in joined])

display(metrics_df.head(3))
display(registry_df.head(3))

# A small analysis: extinction rate by parameter setting (from the typed view).
# `mean_oscillation` here averages predator_std only. It differs from the sweep
# summary's `oscillation_score`, which adds mean_prey_std + mean_predator_std.
_setting_cols = ["prey_birth_rate", "predation_rate", "predator_death_rate"]
ext_summary = metrics_df.groupby(_setting_cols, dropna=False).agg(
    sims=("sim_id", "count"),
    predator_extinct_rate=("predator_extinct", "mean"),
    prey_extinct_rate=("prey_extinct", "mean"),
    mean_oscillation=("predator_std", "mean"),
)
ext_summary = ext_summary.reset_index()
ext_summary.sort_values(
    by=["predator_extinct_rate", "mean_oscillation"],
    ascending=[False, False],
).head(8)


## 6) Caching + recomputation demo

We re-run the first `DEMO_N` sims of the sweep under a new scenario id with `cache_mode="reuse"`.

- Most sims should be cache hits.
- We intentionally delete one cached output file to force exactly one recomputation.
- Optionally (set `MATERIALIZE_CACHED_OUTPUTS = True`), cache hits also request **materialization** of cached outputs into a new folder.

In [None]:
warm_scenario_id = f"{scenario_id}_warm"
warm_root = RUN_DIR / "artifacts" / warm_scenario_id

from tqdm import tqdm

DEMO_N = 12
# When True, cache hits request materialization of cached outputs at warm_root paths.
MATERIALIZE_CACHED_OUTPUTS = False

# Re-run only the first DEMO_N sims (by sim_id) to keep the demo quick.
warm_registry = registry.sort_values("sim_id").head(DEMO_N)

# Force one cache miss by deleting a cached output file from the original run.
# The victim MUST be one of the sims re-run below. The sweep-wide max sim_id lies
# outside the first DEMO_N sims, so deleting its file would force no recomputation.
# We delete `raw.npz` (not ingested) so Consist’s cache validation can detect it.
victim_sim_id = int(warm_registry["sim_id"].max())
victim_raw = artifact_root / f"sim_{victim_sim_id:04d}" / "raw.npz"
if victim_raw.exists():
    victim_raw.unlink()
    print("Deleted cached output to force recomputation:", victim_raw)
else:
    print("Victim file already missing:", victim_raw)

warm_hits = 0
warm_misses = 0

print(f"warm demo sims={len(warm_registry)} of total={len(registry)}")

with tracker.scenario(
    warm_scenario_id,
    config={"reuse_from": scenario_id},
    tags=["examples", "cache_demo"],
) as scenario:
    assert registry_artifact is not None
    for row in tqdm(warm_registry.to_dict(orient="records"), total=len(warm_registry)):
        sim_cfg = {k: _py_scalar(v) for k, v in dict(row).items()}
        sim_id = int(sim_cfg["sim_id"])

        # This scenario writes under warm_root, not the cold run's artifact tree.
        sim_dir = warm_root / f"sim_{sim_id:04d}"
        metrics_path = sim_dir / "metrics.parquet"
        series_path = sim_dir / "series.parquet"
        raw_path = sim_dir / "raw.npz"

        run_id = make_run_id(scenario_id=warm_scenario_id, sim_id=sim_id)

        # Optionally ask for cached outputs to be materialized at the warm_root
        # paths, keyed by the artifact keys the cold run logged.
        step_kwargs = {}
        if MATERIALIZE_CACHED_OUTPUTS:
            step_kwargs = {
                "materialize_cached_outputs": "requested",
                "materialize_cached_output_paths": {
                    "sim_metrics": metrics_path,
                    "sim_series": series_path,
                    "sim_raw": raw_path,
                },
            }

        # Same model, inputs, and config as the cold-run sim step, so the cache
        # lookup can match the earlier run even though run_id and scenario differ.
        with scenario.step(
            name=f"sim_{sim_id:04d}",
            run_id=run_id,
            model="simulate",
            cache_mode="reuse",
            inputs=[registry_artifact],
            config={
                **{
                    "steps": int(base_config.steps),
                    "dt": float(base_config.dt),
                    "prey_init": int(base_config.prey_init),
                    "predator_init": int(base_config.predator_init),
                },
                **sim_cfg,
            },
            **step_kwargs,
        ) as t:
            if t.is_cached:
                warm_hits += 1
                continue

            warm_misses += 1
            # Sim-local names avoid clobbering the notebook-level `metrics_df`
            # (the joined view data built in the typed-analysis cell).
            sim_raw_arrays, sim_series_df, sim_metrics_df = run_one_simulation_with_raw(
                base_config=base_config,
                registry_row=sim_cfg,
            )
            write_parquet(sim_metrics_df, metrics_path)
            write_parquet(sim_series_df, series_path)
            write_npz(sim_raw_arrays, raw_path)
            # Record row counts exactly as the cold run does.
            t.log_artifact(
                metrics_path, key="sim_metrics", direction="output", rows=int(len(sim_metrics_df))
            )
            t.log_artifact(
                series_path, key="sim_series", direction="output", rows=int(len(sim_series_df))
            )
            t.log_artifact(raw_path, key="sim_raw", direction="output")

print(f"warm cache hits={warm_hits} misses={warm_misses}")

# Inspect the warm scenario's simulate runs and their cache-hit metadata.
# NOTE(review): assumes the scenario's parent run id equals warm_scenario_id — confirm.
with consist.db_session(tracker=tracker) as session:
    warm_runs_query = (
        select(Run)
        .where(Run.parent_run_id == warm_scenario_id)
        .where(Run.model_name == "simulate")
    )
    warm_runs = session.exec(warm_runs_query).all()
    cached = [r for r in warm_runs if (r.meta or {}).get("cache_hit")]
    print(f"warm simulate runs={len(warm_runs)} cached={len(cached)}")
    if cached:
        first_cached_meta = cached[0].meta or {}
        print("example cached run meta keys:", sorted(first_cached_meta.keys()))
