Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions src/microplex_us/pipelines/pe_us_recalibrate_from_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,27 @@

import argparse
import json
import os
import sys
from collections.abc import Sequence
from pathlib import Path
from typing import Sequence

from microplex_us.pipelines.us import (
USMicroplexBuildConfig,
recalibrate_policyengine_us_from_checkpoint,
)


def _prepare_output_root(output_root: Path) -> Path:
if not output_root.exists():
raise FileNotFoundError(f"--output-root does not exist: {output_root}")
if not output_root.is_dir():
raise NotADirectoryError(f"--output-root is not a directory: {output_root}")
if not os.access(output_root, os.W_OK | os.X_OK):
raise PermissionError(f"--output-root is not writable: {output_root}")
return output_root


def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(
description=(
Expand All @@ -51,7 +62,7 @@ def main(argv: Sequence[str] | None = None) -> int:
"--output-root",
type=Path,
required=True,
help="Output directory for the recalibrated bundle and summary.",
help="Existing output directory for the recalibrated bundle and summary.",
)
parser.add_argument(
"--targets-db",
Expand Down Expand Up @@ -96,6 +107,7 @@ def main(argv: Sequence[str] | None = None) -> int:
),
)
args = parser.parse_args(argv)
output_root = _prepare_output_root(args.output_root)

config_kwargs: dict[str, object] = {
"calibration_backend": args.calibration_backend,
Expand All @@ -116,20 +128,19 @@ def main(argv: Sequence[str] | None = None) -> int:
config = USMicroplexBuildConfig(**config_kwargs)
result = recalibrate_policyengine_us_from_checkpoint(config, args.checkpoint_path)

args.output_root.mkdir(parents=True, exist_ok=True)
result.calibrated_data.to_parquet(args.output_root / "calibrated_data.parquet")
result.calibrated_data.to_parquet(output_root / "calibrated_data.parquet")
result.policyengine_tables.households.to_parquet(
args.output_root / "households.parquet"
output_root / "households.parquet"
)
if result.policyengine_tables.persons is not None:
result.policyengine_tables.persons.to_parquet(
args.output_root / "persons.parquet"
output_root / "persons.parquet"
)
(args.output_root / "calibration_summary.json").write_text(
(output_root / "calibration_summary.json").write_text(
json.dumps(result.calibration_summary, indent=2, default=str)
)
print(
f"Recalibrated from {args.checkpoint_path} → {args.output_root} "
f"Recalibrated from {args.checkpoint_path} → {output_root} "
f"(stage={result.loaded_stage}, "
f"rows={len(result.calibrated_data)})"
)
Expand Down
141 changes: 135 additions & 6 deletions tests/pipelines/test_recalibrate_from_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@

1. The helper loads a post-imputation checkpoint and dispatches the
bundle to a fresh pipeline's calibrate method.
2. The helper rejects post-microsim checkpoints in v1 (resume from that
stage needs pickled constraints, which is a follow-up).
2. The helper also accepts post-microsim checkpoints, where materialized
target columns already exist on the bundle.
3. The helper raises a clear error if the checkpoint directory is
missing.
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -71,7 +71,9 @@ def test_checkpoint_dispatches_to_calibrate(
orchestrates the load and hand-off, so the parametrized test
covers both paths.
"""
from microplex_us.pipelines.us import recalibrate_policyengine_us_from_checkpoint
from microplex_us.pipelines.us import (
recalibrate_policyengine_us_from_checkpoint,
)

bundle = _make_bundle(n=40)
save_us_pipeline_checkpoint(
Expand Down Expand Up @@ -114,7 +116,9 @@ def _fake_calibrate(

def test_unsupported_stage_raises(self, tmp_path: Path) -> None:
"""A metadata.json with an unknown stage is rejected."""
from microplex_us.pipelines.us import recalibrate_policyengine_us_from_checkpoint
from microplex_us.pipelines.us import (
recalibrate_policyengine_us_from_checkpoint,
)

(tmp_path / "checkpoint").mkdir()
import json
Expand All @@ -127,8 +131,133 @@ def test_unsupported_stage_raises(self, tmp_path: Path) -> None:
recalibrate_policyengine_us_from_checkpoint(cfg, tmp_path / "checkpoint")

def test_missing_checkpoint_raises(self, tmp_path: Path) -> None:
from microplex_us.pipelines.us import recalibrate_policyengine_us_from_checkpoint
from microplex_us.pipelines.us import (
recalibrate_policyengine_us_from_checkpoint,
)

cfg = USMicroplexBuildConfig(policyengine_targets_db=tmp_path / "targets.db")
with pytest.raises(FileNotFoundError):
recalibrate_policyengine_us_from_checkpoint(cfg, tmp_path / "nope")


class TestRecalibrateFromCheckpointCli:
def test_prepare_output_root_accepts_existing_empty_directory(
self,
tmp_path: Path,
) -> None:
from microplex_us.pipelines.pe_us_recalibrate_from_checkpoint import (
_prepare_output_root,
)

output_root = tmp_path / "output"
output_root.mkdir()

assert _prepare_output_root(output_root) == output_root
assert output_root.is_dir()
assert list(output_root.iterdir()) == []

def test_prepare_output_root_rejects_missing_directory(
self,
tmp_path: Path,
) -> None:
from microplex_us.pipelines.pe_us_recalibrate_from_checkpoint import (
_prepare_output_root,
)

output_root = tmp_path / "output"

with pytest.raises(FileNotFoundError, match="--output-root does not exist"):
_prepare_output_root(output_root)
assert not output_root.exists()

def test_prepare_output_root_rejects_unwritable_directory(
self,
tmp_path: Path,
) -> None:
from microplex_us.pipelines.pe_us_recalibrate_from_checkpoint import (
_prepare_output_root,
)

output_root = tmp_path / "output"
output_root.mkdir()
original_mode = output_root.stat().st_mode
try:
output_root.chmod(0o500)
if os.access(output_root, os.W_OK | os.X_OK):
pytest.skip("current platform still reports chmod 0500 as writable")
with pytest.raises(PermissionError, match="--output-root is not writable"):
_prepare_output_root(output_root)
finally:
output_root.chmod(original_mode)

def test_main_rejects_output_file_before_recalibration(
self,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
import microplex_us.pipelines.pe_us_recalibrate_from_checkpoint as cli

called = False

def _fail_if_called(*args: Any, **kwargs: Any) -> None:
nonlocal called
called = True
raise AssertionError("recalibration should not start")

monkeypatch.setattr(
cli,
"recalibrate_policyengine_us_from_checkpoint",
_fail_if_called,
)
output_root = tmp_path / "output"
output_root.write_text("not a directory")

with pytest.raises(NotADirectoryError, match="--output-root is not a directory"):
cli.main(
[
"--checkpoint-path",
str(tmp_path / "checkpoint"),
"--output-root",
str(output_root),
"--targets-db",
str(tmp_path / "targets.db"),
]
)

assert called is False

def test_main_rejects_missing_output_directory_before_recalibration(
self,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
import microplex_us.pipelines.pe_us_recalibrate_from_checkpoint as cli

called = False

def _fail_if_called(*args: Any, **kwargs: Any) -> None:
nonlocal called
called = True
raise AssertionError("recalibration should not start")

monkeypatch.setattr(
cli,
"recalibrate_policyengine_us_from_checkpoint",
_fail_if_called,
)
output_root = tmp_path / "output"

with pytest.raises(FileNotFoundError, match="--output-root does not exist"):
cli.main(
[
"--checkpoint-path",
str(tmp_path / "checkpoint"),
"--output-root",
str(output_root),
"--targets-db",
str(tmp_path / "targets.db"),
]
)

assert called is False
assert not output_root.exists()
Loading