# Consist End-to-End (Monte Carlo Predator–Prey Sweep)

This notebook is a **happy-path** walkthrough of Consist in a scientific simulation workflow:

1. Run a single simulation to preview the time series
2. Define a sweep (parameters + seeds)
3. Run one simulation per Consist step (each sweep member has provenance)
4. Export a SQLModel stub for the series artifact
5. Aggregate per-sim outputs with typed queries (no raw SQL)
6. Visualize summary metrics
7. Inspect what landed in the DuckDB file


## 0) Setup

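This notebook writes run outputs under `examples/runs/predator_prey_demo/` and a DuckDB file under `examples/`. The database file name carries a random session suffix (`predator_prey_demo_<session>.duckdb`), so each time the setup cell runs it points at a fresh database file.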


In [None]:
from __future__ import annotations

import sys
from pathlib import Path
import uuid

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.display import display
from tqdm import tqdm

from consist import Tracker, pivot_facets, run_query


def _find_repo_root(start: Path) -> Path:
    for candidate in (start, *start.parents):
        if (candidate / "pyproject.toml").exists():
            return candidate
    raise RuntimeError("Could not locate repo root (missing pyproject.toml)")


# Locate the repository root (the nearest directory at or above the current
# working directory that contains pyproject.toml) and put it on sys.path.
REPO_ROOT = _find_repo_root(Path.cwd())
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Imported after the sys.path adjustment so the `examples` package resolves
# even when the notebook is launched from a subdirectory of the repo.
from examples.pipeline_steps import (
    build_sweep_registry,
    make_run_id,
    run_one_simulation,
    write_parquet,
)
from examples.synth_simulation import PredatorPreyConfig

sns.set_theme(style="whitegrid")

# Output locations used by the rest of the notebook.
EXAMPLES_DIR = REPO_ROOT / "examples"
RUN_DIR = EXAMPLES_DIR / "runs" / "predator_prey_demo"
# Random 8-hex-character suffix, regenerated every time this cell runs, so each
# session points at its own DuckDB file name.
SESSION_ID = uuid.uuid4().hex[:8]
DB_PATH = EXAMPLES_DIR / f"predator_prey_demo_{SESSION_ID}.duckdb"

tracker = Tracker(run_dir=RUN_DIR, db_path=DB_PATH)

## 1) Single run preview

Run a single simulation to see the raw time-series shape before we sweep.


In [None]:
# Baseline simulation settings; later cells read this object as well.
base_config = PredatorPreyConfig(
    steps=250,
    sample_every=1,
    prey_init=80,
    predator_init=25,
    predator_birth_efficiency=0.20,
)

# One hand-picked parameter row, passed to the simulator as `registry_row`.
preview_row = dict(
    seed=7,
    prey_birth_rate=0.70,
    predation_rate=0.011,
    predator_birth_efficiency=0.20,
    predator_death_rate=0.18,
)

preview_series = run_one_simulation(base_config=base_config, registry_row=preview_row)
display(preview_series.head())

# Draw both populations against time on a single axis.
fig, ax = plt.subplots(figsize=(8, 4))
for species in ("prey", "predator"):
    ax.plot(preview_series["t"], preview_series[species], label=species)
ax.set(
    title="Single-run trajectory (preview)",
    xlabel="t",
    ylabel="population",
)
ax.legend()
plt.show()

## 2) Define the sweep

We sweep a small grid of parameters, with several replicates per setting (10 in this example). Each registry row gets a stable `sim_id` and a per-run seed.


In [None]:
# Tuned grid to highlight a transition region (coexistence ↔ extinction).
# (The stochastic dynamics can be sensitive; this set usually shows mixed regimes.)
sweep_axes = {
    "prey_birth_rates": [0.60, 0.7, 0.75],
    "predation_rates": [0.008, 0.011, 0.014],
    "predator_death_rates": [0.15],
}

# Build the sweep registry from the grid axes plus the fixed settings.
registry = build_sweep_registry(
    **sweep_axes,
    predator_birth_efficiency=0.20,
    replicates_per_setting=10,
    seed=7,
)

print(f"sims={len(registry)} settings={registry['setting_id'].nunique()}")
registry.head()

## 3) Cold run: one simulation per Consist step

Each step writes a sampled time-series table (`sim_*/series.parquet`).


In [None]:
from dataclasses import asdict


scenario_id = "predator_prey_sweep"
artifact_root = RUN_DIR / "artifacts" / scenario_id


def _sorted_unique(column: str) -> list:
    """Return the sorted distinct values of one registry column as a plain list."""
    return sorted(registry[column].unique().tolist())


# Scenario-level config: the base simulation settings plus the swept axes
# as they appear in the registry.
scenario_config = {
    "base": asdict(base_config),
    "sweep": {
        "prey_birth_rates": _sorted_unique("prey_birth_rate"),
        "predation_rates": _sorted_unique("predation_rate"),
        "predator_death_rates": _sorted_unique("predator_death_rate"),
        "replicates_per_setting": int(registry["replicate_id"].max() + 1),
    },
}

# Filled with the first logged series artifact; a later cell uses it as the sample.
sample_series_artifact = None

with tracker.scenario(
    scenario_id,
    config=scenario_config,
    tags=["examples", "simulation", "predator_prey"],
) as scenario:
    for record in tqdm(registry.to_dict(orient="records"), total=len(registry)):
        sim_cfg = dict(record)
        sim_id = int(sim_cfg["sim_id"])
        setting_id = int(sim_cfg["setting_id"])
        replicate_id = int(sim_cfg["replicate_id"])

        sim_label = f"sim_{sim_id:04d}"
        series_path = artifact_root / sim_label / "series.parquet"

        run_id = make_run_id(scenario_id=scenario_id, sim_id=sim_id)

        # Step config: fixed base settings first, then every registry column
        # (a registry column overrides a base setting of the same name).
        step_config = {
            "steps": int(base_config.steps),
            "dt": float(base_config.dt),
            "prey_init": int(base_config.prey_init),
            "predator_init": int(base_config.predator_init),
            **sim_cfg,
        }
        # Facet: the identifiers and swept parameters attached to this step.
        step_facet = {
            "sim_id": sim_id,
            "setting_id": setting_id,
            "replicate_id": replicate_id,
            "prey_birth_rate": float(sim_cfg["prey_birth_rate"]),
            "predation_rate": float(sim_cfg["predation_rate"]),
            "predator_death_rate": float(sim_cfg["predator_death_rate"]),
        }

        with scenario.step(
            name=sim_label,
            run_id=run_id,
            model="simulate",
            config=step_config,
            facet=step_facet,
            tags=["sim"],
        ) as t:
            series_df = run_one_simulation(base_config=base_config, registry_row=sim_cfg)
            write_parquet(series_df, series_path)

            series_artifact = t.log_artifact(
                series_path,
                key="sim_series",
                direction="output",
                rows=len(series_df),
            )

            # Keep only the first artifact as the sample.
            if sample_series_artifact is None:
                sample_series_artifact = series_artifact

assert sample_series_artifact is not None

## 4) Schema export (generated)

Export a SQLModel stub for the series artifact and compare it to the checked-in
version with richer constraints and annotations.


In [None]:
from examples.checked_models import PredatorPreySeriesChecked

# Ingest the sample series artifact into the tracker's database before
# exporting its schema.
tracker.ingest(sample_series_artifact)

# Generate a SQLModel stub for the ingested artifact.
# NOTE(review): this relies on export_schema_sqlmodel returning the generated
# source as a string when out_path is omitted — confirm against the Consist API.
# The previous code printed `series_stub_path.read_text()`, but no such
# variable is defined anywhere in the notebook, so the cell raised NameError.
series_stub_source = tracker.export_schema_sqlmodel(
    artifact_id=sample_series_artifact.id,
    # out_path=EXAMPLES_DIR / "generated" / "predator_prey_series_preview.py", # You can also save the model definition to a file path
    class_name="PredatorPreySeriesGenerated",
    table_name="sim_series",
)

print(series_stub_source)
print("Checked-in model:", PredatorPreySeriesChecked.__name__)

## 5) Hybrid views (hot + cold)

Section 4 ingested one series artifact, which created a **hot** table in DuckDB. Here we use Consist’s
typed hybrid views to query across that table and the remaining **cold** Parquet files.
The sweep parameters are joined in from the run config facets (they are not embedded in the artifacts); the summary statistics are then aggregated from the series rows.


In [None]:
from sqlalchemy import case, func
from sqlmodel import select

from examples.checked_models import PredatorPreySeriesChecked, PredatorPreySeriesHot

# Register the checked-in model as a view over the "sim_series" artifacts.
tracker.views.register(PredatorPreySeriesChecked, key="sim_series")
SeriesView = tracker.views.PredatorPreySeriesChecked


def _count_rows(model):
    """Run ``SELECT count(*)`` against ``model`` and return the first result."""
    return run_query(select(func.count()).select_from(model), tracker=tracker)[0]


def _zero_rate(column, label: str):
    """Average of a 0/1 indicator that is 1 where ``column`` equals zero."""
    return func.avg(case((column == 0, 1), else_=0)).label(label)


# Row counts: the hot table alone versus the hybrid view.
hot_count = _count_rows(PredatorPreySeriesHot)
view_count = _count_rows(SeriesView)

display(pd.DataFrame({"rows": [hot_count]}, index=["hot_table"]))
display(pd.DataFrame({"rows": [view_count]}, index=["hybrid_view"]))

# Sweep parameters per run, pivoted out of the "simulate" facets.
params_stmt = pivot_facets(
    namespace="simulate",
    keys=["prey_birth_rate", "predation_rate", "predator_death_rate"],
)

# Per-run rollup of the series: peaks, minima, and the values at the largest t.
run_key = SeriesView.consist_run_id
series_rollup = (
    select(
        run_key.label("run_id"),
        func.max(SeriesView.prey).label("prey_peak"),
        func.max(SeriesView.predator).label("predator_peak"),
        func.arg_max(SeriesView.prey, SeriesView.t).label("prey_final"),
        func.arg_max(SeriesView.predator, SeriesView.t).label("predator_final"),
        func.min(SeriesView.prey).label("prey_min"),
        func.min(SeriesView.predator).label("predator_min"),
    )
    .group_by(run_key)
    .subquery()
)

# Aggregate the per-run rollups by parameter setting.
param_columns = [
    params_stmt.c.prey_birth_rate,
    params_stmt.c.predation_rate,
    params_stmt.c.predator_death_rate,
]
rollup = series_rollup.c

summary_stmt = (
    select(
        *param_columns,
        func.count().label("sims"),
        _zero_rate(rollup.prey_min, "prey_extinct_rate"),
        _zero_rate(rollup.predator_min, "predator_extinct_rate"),
        func.avg(rollup.prey_final).label("mean_prey_final"),
        func.avg(rollup.predator_final).label("mean_predator_final"),
        func.avg(rollup.prey_peak).label("mean_prey_peak"),
        func.avg(rollup.predator_peak).label("mean_predator_peak"),
    )
    .join(series_rollup, rollup.run_id == params_stmt.c.run_id)
    .group_by(*param_columns)
)

summary_rows = run_query(summary_stmt, tracker=tracker)
summary_df = pd.DataFrame([row._mapping for row in summary_rows])

# Persist the summary so it can be reloaded from disk.
summary_path = artifact_root / "sweep_summary.parquet"
write_parquet(summary_df, summary_path)

# Show the settings with the highest extinction rates first.
summary_df.sort_values(
    by=["predator_extinct_rate", "prey_extinct_rate"],
    ascending=False,
).head(8)

## 6) Sweep summary

A single summary table makes it easy to see “regimes” (stable coexistence vs extinction) without looking at every trajectory.


In [None]:
# Reload the summary table from the Parquet file under artifact_root.
summary_df = pd.read_parquet(artifact_root / "sweep_summary.parquet")

# Example: visualize extinction risk as a heatmap (fix death rate to one slice)
death_rate = float(sorted(summary_df["predator_death_rate"].unique())[0])
in_slice = summary_df["predator_death_rate"] == death_rate
slice_df = summary_df.loc[in_slice]

# Rows are prey birth rates, columns are predation rates.
pivot = pd.pivot_table(
    slice_df,
    index="prey_birth_rate",
    columns="predation_rate",
    values="predator_extinct_rate",
)

ax = sns.heatmap(
    pivot.sort_index(),
    annot=True,
    fmt=".2f",
    cmap="mako",
    vmin=0,
    vmax=1,
)
ax.set(
    title=f"Predator extinction rate (predator_death_rate={death_rate})",
    xlabel="predation_rate",
    ylabel="prey_birth_rate",
)
plt.show()

## 7) Inspect the database

Consist persists run metadata in DuckDB; ingested artifacts live in 
`global_tables`, while views like `v_sim_series` stitch hot and cold data together.


In [None]:
# `func` is imported here as well: the count queries below use it, and the
# cell previously relied on an import made only in an earlier cell.
from sqlalchemy import func, inspect
from sqlmodel import select

from consist.models import Artifact, ConfigFacet, Run, RunConfigKV

# List what exists in the DuckDB file: default-schema tables, tables in the
# `global_tables` schema, and views.
inspector = inspect(tracker.engine)
main_tables = inspector.get_table_names()
global_tables = inspector.get_table_names(schema="global_tables")
views = inspector.get_view_names()

display(pd.DataFrame({"main_tables": main_tables}))
display(pd.DataFrame({"global_tables": global_tables}))
display(pd.DataFrame({"views": views}))

# Row counts for the Consist metadata models, one count query per model.
metadata_models = {
    "run": Run,
    "artifact": Artifact,
    "config_facet": ConfigFacet,
    "run_config_kv": RunConfigKV,
}
table_counts = {
    name: run_query(select(func.count()).select_from(model), tracker=tracker)[0]
    for name, model in metadata_models.items()
}
display(pd.DataFrame({"rows": table_counts}))

print(f"DuckDB file: {DB_PATH}")
print(f"Open UI in a terminal: duckdb -ui {DB_PATH}")