# Iterative Workflows and Provenance Chains

Iterative equilibrium workflows show up all over scientific computing: transportation, economics,
agent-based simulation, MCMC, and more. The pattern is simple: outputs from one step become inputs
to the next until the system converges.

This tutorial focuses on Consist mechanics, not transportation modeling. You can treat the model
functions as black boxes and still learn the key patterns.

**Prerequisites:** Quickstart and Concepts Overview.

## What you'll learn

- How Consist tracks deep, multi-step iteration chains
- How incremental runs skip work you already computed
- How to query lineage to understand impact and reproducibility

## The model in 30 seconds

People choose how to travel, car trips create congestion, and congestion feeds back into the next
round of choices. We run 10 iterations and watch the system stabilize.
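
To make the feedback loop concrete, here is a minimal, self-contained sketch of the same idea with two modes and one road. It is not the tutorial's model: the function names, the logit coefficient, the capacity, and the damping step are all assumptions chosen for illustration. The delay curve uses the conventional BPR form with alpha = 0.15 and beta = 4.

```python
import math


def car_share(car_time: float, transit_time: float, beta_time: float = -0.1) -> float:
    """Binary logit share of travelers choosing car (toy utility: beta * time)."""
    u_car = beta_time * car_time
    u_transit = beta_time * transit_time
    return 1.0 / (1.0 + math.exp(u_transit - u_car))


def congested_time(free_flow: float, volume: float, capacity: float) -> float:
    """BPR-style travel time: free-flow time inflated by the volume/capacity ratio."""
    return free_flow * (1.0 + 0.15 * (volume / capacity) ** 4)


travelers = 1000.0  # hypothetical demand
car_time = 20.0  # start from free-flow conditions
shares = []
for iteration in range(10):
    # Choices respond to the travel time produced by the previous iteration.
    share = car_share(car_time, transit_time=30.0)
    new_time = congested_time(20.0, volume=travelers * share, capacity=600.0)
    # Average with the previous time (simple damping) so the loop settles.
    car_time = 0.5 * car_time + 0.5 * new_time
    shares.append(share)
```

In this sketch the car share starts high under free-flow times and then falls toward a stable value as congestion feeds back, which is the same qualitative behavior we expect from the full model.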


## Setup

We initialize the tracker and load the model functions. The details of the model live in
`examples/src/travel_demand_functions.py`, but we treat them as implementation details here.


In [None]:
from __future__ import annotations

import sys
from pathlib import Path


def _find_repo_root(start: Path) -> Path:
    for candidate in (start, *start.parents):
        if (candidate / "pyproject.toml").exists():
            return candidate
    raise RuntimeError("Could not locate repo root (missing pyproject.toml)")


REPO_ROOT = _find_repo_root(Path.cwd())
EXAMPLES_DIR = REPO_ROOT / "examples"
EXAMPLES_SRC = EXAMPLES_DIR / "src"

for path in (REPO_ROOT, EXAMPLES_SRC):
    if str(path) not in sys.path:
        sys.path.insert(0, str(path))

In [None]:
import os
from dataclasses import asdict, replace

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm

import consist
from consist import Tracker

from travel_demand_functions import (
    AssignmentParams,
    DestinationChoiceParams,
    ModeChoiceParams,
    TravelDemandScenarioConfig,
    ZoneParams,
    apply_congestion,
    apply_mode_choice,
    compute_mode_shares,
    compute_mode_utilities,
    compute_od_logsums,
    compute_od_volumes,
    create_skims_dataset,
    distribute_trips,
    generate_distances,
    generate_population,
    generate_zones,
    save_skims,
    summarize_iteration,
)

sns.set_theme(style="whitegrid")

In [None]:
EXAMPLES_DIR = REPO_ROOT / "examples"
RUN_DIR = EXAMPLES_DIR / "runs" / "travel_demand_demo"
SESSION_ID = os.getenv("CONSIST_SESSION_ID", "demo")
DB_PATH = RUN_DIR / f"travel_demand_demo_{SESSION_ID}.duckdb"
if DB_PATH.exists():
    DB_PATH.unlink()

tracker = Tracker(
    run_dir=RUN_DIR,
    db_path=DB_PATH,
    hashing_strategy="fast",
    project_root=str(RUN_DIR),
)

## Model Configuration

We set up a small, fixed configuration for a 5-zone linear city. The parameter values aren't important
for understanding Consist; what matters is that the model is iterative and produces artifacts at each step.
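
The notebook relies on two standard-library helpers, `dataclasses.asdict` and `dataclasses.replace`, to turn configuration objects into plain dictionaries and to derive variants from a base configuration. The sketch below shows how they behave on a pair of hypothetical stand-in classes (`ToyModeParams` and `ToyScenarioConfig` are not part of the tutorial's code, and the real parameter classes may be defined differently).

```python
from dataclasses import asdict, dataclass, field, replace


@dataclass(frozen=True)
class ToyModeParams:  # hypothetical stand-in for ModeChoiceParams
    beta_time: float = -0.05
    beta_cost: float = -0.2


@dataclass(frozen=True)
class ToyScenarioConfig:  # hypothetical stand-in for TravelDemandScenarioConfig
    n_iterations: int = 10
    seed: int = 0
    mode_params: ToyModeParams = field(default_factory=ToyModeParams)


base = ToyScenarioConfig()
# replace() copies the dataclass and overrides only the named fields.
extended = replace(base, n_iterations=15)

# asdict() recurses into nested dataclasses, producing plain nested dicts.
flat_base = asdict(base)
flat_extended = asdict(extended)
changed = {key for key in flat_base if flat_base[key] != flat_extended[key]}
```

Here `changed` contains only `n_iterations`, so every other parameter is shared between the two configurations. That is the property the incremental run later in this tutorial depends on.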


In [None]:
zone_params = ZoneParams()
mode_params = ModeChoiceParams()
dest_params = DestinationChoiceParams()
assignment_params = AssignmentParams()

DEFAULT_SEED = 0

base_config = TravelDemandScenarioConfig(
    n_iterations=10,
    seed=DEFAULT_SEED,
    zone_params=zone_params,
    mode_params=mode_params,
    dest_params=dest_params,
    assignment_params=assignment_params,
)

SCENARIO_NAME = "travel_demand_demo"

SKIM_PERTURBATION = 0.1
DESTINATION_UPDATE_SHARE = 0.2

## Workflow Structure (Iteration Loop)

The loop below is the conceptual anchor. Each iteration consumes the previous iteration's
`skims` (travel times), produces new trips, assigns congestion, and updates `skims` again.

```
┌─────────────┐
│    init     │ (iteration 0 only)
└──────┬──────┘
       │ skims, zones, population
       ▼
┌─────────────┐
│   logsums   │◄─── skims, zones
└──────┬──────┘
       │ logsums
       ▼
┌─────────────┐
│ trip_dist   │◄─── logsums, zones, population, (prev trips)
└──────┬──────┘
       │ trips
       ▼
┌─────────────┐
│ utilities   │◄─── trips, skims, zones
└──────┬──────┘
       │ utilities
       ▼
┌─────────────┐
│ mode_choice │◄─── utilities
└──────┬──────┘
       │ trips_with_modes
       ▼
┌─────────────┐
│ assignment  │◄─── trips_with_modes
└──────┬──────┘
       │ volumes
       ▼
┌─────────────┐
│ traffic_sim │◄─── volumes, skims
└──────┬──────┘
       │ updated skims ──► next iteration
       ▼
```

For a quick explanation of `run(...)` vs `trace(...)`, see the
[Concepts Overview](../docs/concepts.md#when-to-use-each-pattern).


### Simplified run loop (shape only)

Below is a compact outline of the iterative loop. The full implementation (with logging and plotting) follows.

```python
def run_scenario(scenario_config, scenario_run_id):
    with tracker.scenario(scenario_run_id, config=..., tags=[...]) as scenario:
        scenario.run(name="init", fn=initialize_scenario, ...)
        for i in range(scenario_config.n_iterations):
            with scenario.trace(name="logsums", ...):
                ...
            scenario.run(name="trip_distribution", ...)
            scenario.run(name="calculate_utilities", ...)
            scenario.run(name="mode_choice", ...)
            scenario.run(name="assignment", ...)
            scenario.run(name="traffic_simulation", ...)
        scenario.run(name="summaries", ...)
    return results
```


## Workflow Functions (black box)

The cell below contains the full workflow definition. It's long, but you don't need to read it
line-by-line to follow the Consist concepts. The main takeaway is that each step logs artifacts
that the next step can reuse.
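
The hand-off between steps can be pictured as a shared store of named artifacts: each step reads its inputs by key and publishes outputs under keys that later steps look up. The sketch below is a plain-Python illustration of that idea only. `run_step`, the toy step functions, and the in-memory `store` are hypothetical and say nothing about how Consist implements it; the `inputs` mapping from function argument to artifact key is modeled on calls such as `inputs={"trips": "trips_with_modes"}` in the cell below.

```python
def run_step(store: dict, fn, inputs: dict[str, str]) -> dict:
    """Resolve inputs by key, call fn, and publish its outputs back to the store."""
    kwargs = {arg: store[key] for arg, key in inputs.items()}
    outputs = fn(**kwargs)
    store.update(outputs)  # later steps see the newest value for each key
    return outputs


def init() -> dict:
    return {"skims": 10.0}  # toy travel time


def assign(travel_time: float) -> dict:
    return {"volumes": 100.0 / travel_time}


def simulate(volumes: float, skims: float) -> dict:
    return {"skims": skims + 0.1 * volumes}


store: dict = {}
run_step(store, init, inputs={})
history = [store["skims"]]
for _ in range(3):
    # The argument name (travel_time) need not match the artifact key (skims).
    run_step(store, assign, inputs={"travel_time": "skims"})
    run_step(store, simulate, inputs={"volumes": "volumes", "skims": "skims"})
    history.append(store["skims"])
```

Because `simulate` publishes a new value under the existing `skims` key, the next pass through the loop automatically picks up the updated travel times.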


In [None]:
@tracker.define_step(outputs=["skims", "persons", "zones"])
def initialize_scenario(*, zone_params, mode_params, skim_perturbation, _consist_ctx):
    zones = generate_zones(zone_params)
    population = generate_population(zones)
    distances = generate_distances(zone_params)

    skims = create_skims_dataset(zones, distances, mode_params)
    skims["time_car_mins"] *= skim_perturbation

    output_dir = _consist_ctx.run_dir
    output_dir.mkdir(parents=True, exist_ok=True)
    input_skims_path = output_dir / "skims_init.zarr"

    save_skims(skims, input_skims_path)

    consist.log_artifact(input_skims_path, key="skims", direction="output")
    consist.log_dataframe(
        population,
        key="persons",
        direction="output",
    )
    consist.log_dataframe(zones, key="zones", direction="output")


@tracker.define_step(outputs=["trips"])
def distribute_trips_step(*, dest_params, seed, update_share, _consist_ctx):
    zones = _consist_ctx.load("zones")
    population = _consist_ctx.load("persons")
    prev_trips_art = _consist_ctx.inputs.get("trips")
    prev_trips = _consist_ctx.load(prev_trips_art) if prev_trips_art else None
    logsums_df = _consist_ctx.load("logsums")
    if "origin" not in logsums_df.columns or "destination" not in logsums_df.columns:
        logsums_df = logsums_df.reset_index()
    logsums = logsums_df.set_index(["origin", "destination"])["logsum"].to_xarray()
    trips = distribute_trips(
        population,
        zones,
        logsums,
        dest_params,
        seed=seed,
        prev_trips=prev_trips,
        update_share=update_share,
    )
    consist.log_dataframe(trips, key="trips")


@tracker.define_step(
    outputs=[
        "mode_shares",
        "iteration_summaries",
        "pmt_totals",
        "mode_shares_plot",
        "iteration_totals_plot",
    ]
)
def summarize_results_step(*, mode_shares, summaries, pmt_totals, _consist_ctx):
    output_dir = _consist_ctx.run_dir
    summary_dir = output_dir / "summary"
    summary_dir.mkdir(parents=True, exist_ok=True)

    mode_shares_df = (
        pd.DataFrame.from_dict(mode_shares, orient="index")
        .sort_index()
        .rename_axis("iteration")
        .reset_index()
    )
    summaries_df = (
        pd.DataFrame.from_dict(summaries, orient="index")
        .sort_values("iteration")
        .reset_index(drop=True)
    )
    pmt_totals_df = (
        pd.DataFrame.from_dict(pmt_totals, orient="index")
        .sort_index()
        .rename_axis("iteration")
        .reset_index()
    )

    mode_shares_path = summary_dir / "mode_shares.csv"
    summaries_path = summary_dir / "iteration_summaries.csv"
    pmt_totals_path = summary_dir / "pmt_totals.csv"

    shares_long = mode_shares_df.melt(
        id_vars="iteration", var_name="mode", value_name="share"
    )
    plt.figure(figsize=(8, 4))
    sns.lineplot(data=shares_long, x="iteration", y="share", hue="mode", marker="o")
    plt.title("Mode Shares by Iteration")
    plt.tight_layout()
    mode_share_plot_path = summary_dir / "mode_shares.png"
    plt.savefig(mode_share_plot_path)
    plt.close()

    pmt_long = pmt_totals_df.melt(
        id_vars="iteration", var_name="mode", value_name="pmt"
    )
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    sns.lineplot(data=summaries_df, x="iteration", y="vmt", marker="o", ax=axes[0])
    axes[0].set_title("Vehicle Miles Traveled")
    sns.lineplot(
        data=pmt_long,
        x="iteration",
        y="pmt",
        hue="mode",
        marker="o",
        ax=axes[1],
    )
    axes[1].set_title("Person Miles Traveled by Mode")
    fig.tight_layout()
    totals_plot_path = summary_dir / "iteration_totals.png"
    fig.savefig(totals_plot_path)
    plt.close(fig)

    consist.log_dataframe(
        mode_shares_df,
        key="mode_shares",
        path=mode_shares_path,
    )
    consist.log_dataframe(
        summaries_df,
        key="iteration_summaries",
        path=summaries_path,
    )
    consist.log_dataframe(
        pmt_totals_df,
        key="pmt_totals",
        path=pmt_totals_path,
    )
    consist.log_artifact(
        mode_share_plot_path,
        key="mode_shares_plot",
        direction="output",
    )
    consist.log_artifact(
        totals_plot_path,
        key="iteration_totals_plot",
        direction="output",
    )


def run_scenario(scenario_config, scenario_run_id):
    mode_shares = {}
    summaries = {}
    pmt_totals = {}

    zone_params = scenario_config.zone_params
    mode_params = scenario_config.mode_params
    dest_params = scenario_config.dest_params
    assignment_params = scenario_config.assignment_params
    seed = scenario_config.seed

    with tracker.scenario(
        scenario_run_id,
        config={
            **asdict(scenario_config),
            "scenario_name": SCENARIO_NAME,
        },
        facet_from=["n_iterations"],
        tags=["examples", "simulation", "travel_demand"],
    ) as scenario:
        # Skip output checks on cache hits for speed; use "eager" to validate files.
        cache_validation = "lazy"
        # Copy cached inputs into the new run_dir on cache misses.
        cache_hydration = "inputs-missing"

        scenario.run(
            name="init",
            fn=initialize_scenario,
            config=asdict(mode_params)
            | asdict(zone_params)
            | {"skim_perturbation": SKIM_PERTURBATION},
            facet_from=["skim_perturbation"],
            inject_context=True,
            fn_args={
                "zone_params": zone_params,
                "mode_params": mode_params,
                "skim_perturbation": SKIM_PERTURBATION,
            },
        )

        for i in tqdm(range(scenario_config.n_iterations)):
            # Inline trace pattern (no wrapper function required).
            with scenario.trace(
                name="logsums",
                run_id=f"{scenario.run_id}_logsums_{i}",
                config=asdict(mode_params),
                inputs={"skims": "skims", "zones": "zones"},
                facet_from=[
                    "beta_cost",
                    "beta_time",
                    "asc_walk",
                    "asc_transit",
                    "fuel_cost_per_mile",
                    "transit_fare",
                ],
                validate_cached_outputs=cache_validation,
                cache_hydration=cache_hydration,
                iteration=i,
            ) as t:
                if not t.is_cached:
                    logsums = compute_od_logsums(
                        t.load(scenario.coupler.require("skims")),
                        t.load(scenario.coupler.require("zones")),
                        mode_params,
                    )
                    consist.log_dataframe(
                        logsums.to_dataframe().reset_index(),
                        key="logsums",
                    )

            scenario.run(
                name="trip_distribution",
                fn=distribute_trips_step,
                run_id=f"{scenario.run_id}_trip_distribution_{i}",
                config={
                    **asdict(dest_params),
                    "update_share": DESTINATION_UPDATE_SHARE,
                },
                inputs={
                    "skims": "skims",
                    "zones": "zones",
                    "persons": "persons",
                    "logsums": "logsums",
                },
                optional_input_keys=["trips"],
                facet_from=["beta_size", "beta_access", "update_share"],
                inject_context=True,
                fn_args={
                    "dest_params": dest_params,
                    "seed": seed + i,
                    "update_share": DESTINATION_UPDATE_SHARE,
                },
                iteration=i,
            )

            # `load_inputs=True` hydrates artifacts into function args by name.
            scenario.run(
                name="calculate_utilities",
                fn=compute_mode_utilities,
                run_id=f"{scenario.run_id}_utilities_{i}",
                config=asdict(mode_params),
                inputs={"trips": "trips", "skims": "skims", "zones": "zones"},
                facet_from=[
                    "beta_cost",
                    "beta_time",
                    "asc_walk",
                    "asc_transit",
                ],
                validate_cached_outputs=cache_validation,
                cache_hydration=cache_hydration,
                load_inputs=True,
                fn_args={"mode_params": mode_params},
                outputs=["utilities"],
                iteration=i,
            )

            mode_choice_result = scenario.run(
                name="mode_choice",
                fn=apply_mode_choice,
                run_id=f"{scenario.run_id}_mode_choice_{i}",
                config={"seed": seed},
                inputs={"utilities_df": "utilities"},
                fn_args={"seed": seed + i},
                outputs=["trips_with_modes"],
                iteration=i,
            )

            scenario.run(
                name="assignment",
                fn=compute_od_volumes,
                run_id=f"{scenario.run_id}_assignment_{i}",
                inputs={"trips": "trips_with_modes"},
                outputs=["volumes"],
                iteration=i,
            )

            scenario.run(
                name="traffic_simulation",
                fn=apply_congestion,
                run_id=f"{scenario.run_id}_traffic_simulation_{i}",
                config=asdict(assignment_params),
                inputs={"volumes": "volumes", "skims": "skims"},
                facet_from=["bpr_alpha", "bpr_beta", "base_capacity"],
                fn_args={"assignment_params": assignment_params},
                outputs=["skims"],
                iteration=i,
            )

            trips_with_modes = tracker.load(
                mode_choice_result.outputs["trips_with_modes"]
            )
            shares = compute_mode_shares(trips_with_modes)
            mode_shares[i] = pd.Series(shares)
            distance_by_mode = trips_with_modes.groupby("mode")["distance_miles"].sum()
            pmt_totals[i] = distance_by_mode.sort_index()
            summaries[i] = summarize_iteration(i, trips_with_modes, shares, 0, False)

        summary_result = scenario.run(
            name="summaries",
            fn=summarize_results_step,
            inputs={"trips_with_modes": "trips_with_modes"},
            inject_context=True,
            validate_cached_outputs=cache_validation,
            cache_hydration=cache_hydration,
            fn_args={
                "mode_shares": mode_shares,
                "summaries": summaries,
                "pmt_totals": pmt_totals,
            },
        )

        mode_shares_df = tracker.load(summary_result.outputs["mode_shares"])
        summaries_df = tracker.load(summary_result.outputs["iteration_summaries"])
        pmt_totals_df = tracker.load(summary_result.outputs["pmt_totals"])
        mode_share_plot_path = Path(
            tracker.resolve_uri(summary_result.outputs["mode_shares_plot"].uri)
        )
        totals_plot_path = Path(
            tracker.resolve_uri(summary_result.outputs["iteration_totals_plot"].uri)
        )

    return {
        "scenario_run_id": scenario_run_id,
        "mode_shares_df": mode_shares_df,
        "summaries_df": summaries_df,
        "pmt_totals_df": pmt_totals_df,
        "mode_share_plot_path": mode_share_plot_path,
        "totals_plot_path": totals_plot_path,
    }

## Baseline Run (10 iterations)

We run the baseline scenario and capture a single plot to confirm the model converges.


In [None]:
base_run_id = f"{SCENARIO_NAME}_{SESSION_ID}"
base_results = run_scenario(base_config, base_run_id)

analysis = base_results

mode_shares_df = analysis["mode_shares_df"]
summaries_df = analysis["summaries_df"]
pmt_totals_df = analysis["pmt_totals_df"]
mode_share_plot_path = analysis["mode_share_plot_path"]
totals_plot_path = analysis["totals_plot_path"]
SCENARIO_RUN_ID = analysis["scenario_run_id"]

### Convergence check

The plot below shows mode shares stabilizing over iterations.


In [None]:
from IPython.display import Image

Image(filename=str(mode_share_plot_path))

## Artifact Lineage

The lineage tree shows how the final `skims` artifact traces back through the computational graph.
Reading the tree:

- Each **artifact** (like `skims`, `volumes`) shows its key and unique ID
- Below each artifact is the **run** that produced it, with the step name, run ID, and iteration number
- Indented below each run are its **input artifacts**, which recursively show their own producers

This forms a DAG (directed acyclic graph) in which you can trace any output back to the original inputs.
With `max_depth=4`, we see four levels of this chain. In a 10-iteration model, the full lineage would
be much deeper.
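
To see how a depth limit truncates such a tree, here is a small sketch that builds nested lineage nodes from a hand-written producer table. The table, the `build_lineage` and `depth` helpers, and the exact meaning of `max_depth` here (number of producing runs followed) are assumptions for illustration; Consist's own depth semantics may differ. The node shape mirrors the `artifact` / `producing_run` / `inputs` keys used by the rendering cell below.

```python
# Hypothetical producer table: artifact id -> (producing step, input artifact ids).
PRODUCERS = {
    "skims_2": ("traffic_sim_1", ["volumes_1", "skims_1"]),
    "volumes_1": ("assignment_1", ["trips_1"]),
    "skims_1": ("traffic_sim_0", ["volumes_0", "skims_0"]),
    "volumes_0": ("assignment_0", ["trips_0"]),
    "skims_0": ("init", []),
}


def build_lineage(artifact_id: str, max_depth: int) -> dict:
    """Recursively expand producers until the table or the depth budget runs out."""
    node = {"artifact": artifact_id, "producing_run": None}
    producer = PRODUCERS.get(artifact_id)
    if producer is None or max_depth <= 0:
        return node
    run_name, input_ids = producer
    node["producing_run"] = {
        "run": run_name,
        "inputs": [build_lineage(i, max_depth - 1) for i in input_ids],
    }
    return node


def depth(node: dict) -> int:
    """Number of producing runs on the longest path below this artifact."""
    run = node["producing_run"]
    if run is None:
        return 0
    return 1 + max((depth(child) for child in run["inputs"]), default=0)


shallow = build_lineage("skims_2", max_depth=1)
full = build_lineage("skims_2", max_depth=10)
```

With `max_depth=1` only the run that produced `skims_2` is expanded, while the unrestricted call follows the `skims` chain all the way back to `init`.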


In [None]:
from rich import print as rprint
from rich.tree import Tree


def _add_lineage(branch, node):
    artifact = node["artifact"]
    art_label = f"{artifact.key} ({artifact.id})"
    art_branch = branch.add(art_label)
    run_node = node.get("producing_run")
    if not run_node:
        return
    run = run_node["run"]
    run_label = f"{run.model_name} run={run.id} iter={run.iteration}"
    run_branch = art_branch.add(run_label)
    for child in run_node.get("inputs", []):
        _add_lineage(run_branch, child)


final_skims = tracker.get_artifact("skims")
lineage = (
    tracker.get_artifact_lineage(final_skims.id, max_depth=4) if final_skims else None
)

if lineage:
    tree = Tree("lineage")
    _add_lineage(tree, lineage)
    rprint(tree)

## Incremental Computation Demo

We ran the baseline scenario for 10 iterations above. Now we extend to 15 iterations. Since the
extended run shares the same parameters for iterations 0-9, Consist should recognize that those steps
have already been computed and skip them, running only iterations 10 through 14.
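
The general mechanism behind this kind of reuse is a cache keyed by a hash of everything that determines a step's result. The sketch below illustrates that idea in plain Python under simplifying assumptions: the key is built from the step name, its config, and a single numeric input, and the cache is an in-memory dict. It is not a description of how Consist computes its cache keys, and `step_key`, `run_iteration`, and `run_toy_scenario` are hypothetical names.

```python
import hashlib
import json

cache: dict[str, float] = {}
executed: list[int] = []  # iterations that actually ran (cache misses)


def step_key(name: str, config: dict, input_value: float) -> str:
    """Hash the step identity, its config, and its input into a cache key."""
    payload = json.dumps(
        {"step": name, "config": config, "input": input_value}, sort_keys=True
    )
    return hashlib.sha256(payload.encode()).hexdigest()


def run_iteration(i: int, state: float, config: dict) -> float:
    key = step_key("update", config, state)
    if key not in cache:
        executed.append(i)
        cache[key] = 0.5 * state + 0.5 * config["target"]
    return cache[key]


def run_toy_scenario(n_iterations: int, config: dict) -> float:
    state = 0.0
    for i in range(n_iterations):
        state = run_iteration(i, state, config)
    return state


config = {"target": 4.0}
run_toy_scenario(10, config)
first_pass = list(executed)

executed.clear()
run_toy_scenario(15, config)
second_pass = list(executed)
```

The first call executes iterations 0 through 9. The second call replays the same inputs for those iterations, finds every key in the cache, and only executes iterations 10 through 14.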


In [None]:
import os

# Log cache hits/misses during the incremental run.
os.environ["CONSIST_CACHE_DEBUG"] = "1"
extended_config = replace(base_config, n_iterations=15)
extended_run_id = f"{SCENARIO_NAME}_{SESSION_ID}_extended"
extended_results = run_scenario(extended_config, extended_run_id)

In [None]:
# Quick cache-hit check for the extended run.
cached_runs = [
    run
    for run in tracker.find_runs(parent_id=extended_results["scenario_run_id"])
    if run.meta.get("cache_hit")
]
[(run.model_name, run.iteration, run.id) for run in cached_runs]

As expected, all steps from iterations 0-9 were retrieved from cache. The extended run only
computed the 5 new iterations, demonstrating how Consist enables incremental refinement of
iterative models.


## Cached runs and on-disk outputs

Because iterations 0-9 were cache hits, Consist does not re-run those steps or write their output
files under `examples/runs/travel_demand_demo/outputs/travel_demand_demo_demo_extended/`. The run
records and artifacts still exist in the database, and the original files still live in the earlier
run's output directory. If you want copies in a new location, you can materialize those cached
artifacts on demand.

See: [Caching and Hydration](../docs/caching-and-hydration.md) for more detail.
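
Conceptually, materializing means copying a file that already exists in an earlier run's directory to a destination of your choice, leaving the original in place. The sketch below shows that idea with the standard library only. The `materialize` helper and the directory names are hypothetical, and this is not the implementation of `materialize_artifacts`; it only mirrors the list of (source, destination) pairs that the next cell builds.

```python
import shutil
import tempfile
from pathlib import Path


def materialize(items: list[tuple[Path, Path]]) -> list[Path]:
    """Copy each (source, destination) pair, creating parent directories as needed."""
    written = []
    for source, destination in items:
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(source, destination)
        written.append(destination)
    return written


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Stand-in for an output file written by the earlier (baseline) run.
    original = root / "outputs" / "base_run" / "trips.parquet"
    original.parent.mkdir(parents=True)
    original.write_bytes(b"placeholder bytes, not real parquet")

    target = root / "outputs" / "extended_run_materialized" / original.name
    copied = materialize([(original, target)])
    same_content = target.read_bytes() == original.read_bytes()
    original_still_exists = original.exists()
```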


In [None]:
from consist.core.materialize import materialize_artifacts

extended_outputs_dir = RUN_DIR / "outputs" / extended_run_id
extended_parquet = sorted(
    p.relative_to(extended_outputs_dir)
    for p in extended_outputs_dir.rglob("*.parquet")
)
# Parquet files actually written under the extended run's output directory.
print(extended_parquet[:10])

cached_run = next(
    run
    for run in tracker.find_runs(parent_id=extended_results["scenario_run_id"], model="mode_choice")
    if run.iteration == 3 and run.meta.get("cache_hit")
)
cached_outputs = tracker.find_artifacts(creator=cached_run, direction="output")
cached_parquet = [
    art
    for art in cached_outputs
    if Path(tracker.resolve_uri(art.uri)).suffix == ".parquet"
]
cached_parquet_paths = [Path(tracker.resolve_uri(art.uri)) for art in cached_parquet]
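# Per the note above, cached artifacts should resolve to the earlier run's files,
# so we expect these checks to be False.
print([str(path).startswith(str(extended_outputs_dir)) for path in cached_parquet_paths[:3]])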

materialize_dir = RUN_DIR / "outputs" / f"{extended_run_id}_materialized"
items = [
    (art, materialize_dir / Path(tracker.resolve_uri(art.uri)).name)
    for art in cached_parquet[:1]
]
materialize_artifacts(tracker, items)


## Querying Provenance

Now let's use Consist's query capabilities to explore runs and artifacts for a specific iteration.


In [None]:
# Runs for a specific iteration (e.g., iteration 5).
iteration_runs = [
    run
    for run in tracker.find_runs(parent_id=SCENARIO_RUN_ID, status="completed")
    if run.iteration == 5
]
iteration_run_df = pd.DataFrame(
    [
        {
            "model": run.model_name,
            "run_id": run.id,
            "iteration": run.iteration,
        }
        for run in iteration_runs
    ]
)
iteration_run_df

We can find all the artifacts created during that iteration.


In [None]:
iter5_artifacts = {
    run.id: [artifact.key for artifact in tracker.find_artifacts(creator=run)]
    for run in iteration_runs
}
iter5_artifacts

In [None]:
# Which steps would re-run if parking costs changed?
# Parking costs live in the zones artifact, so any run that consumes 'zones'
# would be invalidated.
zone_consumers = []
for run in tracker.find_runs(parent_id=SCENARIO_RUN_ID, status="completed"):
    artifacts = tracker.get_artifacts_for_run(run.id)
    if "zones" in artifacts.inputs:
        zone_consumers.append(run)
zone_consumer_df = pd.DataFrame(
    [
        {"model": run.model_name, "run_id": run.id, "iteration": run.iteration}
        for run in zone_consumers
    ]
)
zone_consumer_df
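
The query above finds the runs that consume `zones` directly. A change would also propagate to any run that consumes their outputs. The sketch below shows one way to compute that transitive set with a breadth-first walk over a hand-written table of step inputs and outputs. The table is a simplified, hypothetical summary of one iteration (it omits `traffic_simulation` and the feedback into the next iteration), and `affected_steps` is not a Consist function.

```python
from collections import deque

# Hypothetical run records: step name -> (input keys, output keys).
RUNS = {
    "logsums": ({"skims", "zones"}, {"logsums"}),
    "trip_distribution": ({"logsums", "zones", "persons"}, {"trips"}),
    "calculate_utilities": ({"trips", "skims", "zones"}, {"utilities"}),
    "mode_choice": ({"utilities"}, {"trips_with_modes"}),
    "assignment": ({"trips_with_modes"}, {"volumes"}),
}


def affected_steps(changed_key: str) -> list[str]:
    """Return every step whose inputs depend, directly or not, on changed_key."""
    dirty = {changed_key}
    queue = deque([changed_key])
    affected: list[str] = []
    while queue:
        key = queue.popleft()
        for step, (inputs, outputs) in RUNS.items():
            if key in inputs and step not in affected:
                affected.append(step)
                # Outputs of an affected step are stale too, so follow them.
                for out in sorted(outputs - dirty):
                    dirty.add(out)
                    queue.append(out)
    return affected


direct = [step for step, (inputs, _) in RUNS.items() if "zones" in inputs]
downstream = affected_steps("zones")
persons_downstream = affected_steps("persons")
```

In this toy table, three steps read `zones` directly, but all five are affected once downstream outputs are followed. A change to `persons` leaves `logsums` untouched because nothing upstream of it reads the population.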

## Summary

This tutorial demonstrated three key provenance patterns:

1. **Deep chains**: Each iteration's outputs become the next iteration's inputs, creating
   traceable lineage dozens of steps deep.
2. **Incremental computation**: Running 15 iterations after already running 10 reused all
   prior work; only the new iterations executed.
3. **Impact analysis**: We can query which steps consume a given artifact to understand
   what would need to re-run if that artifact changed.

These patterns apply to any iterative workflow: MCMC sampling, neural network training checkpoints,
agent-based simulations, and economic equilibrium models.
