Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions src/microplex_us/pipelines/ecps_replacement_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def build_sound_ecps_replacement_comparison(
optimizer_max_iter: int = 200,
optimizer_tol: float = 1e-8,
score_consistency_tol: float = 1e-6,
target_diagnostics_top_k: int = 50,
policyengine_us_data_repo: str | Path | None = None,
policyengine_us_data_python: str | Path | None = None,
skip_tax_expenditure_targets: bool = False,
Expand Down Expand Up @@ -206,6 +207,15 @@ def build_sound_ecps_replacement_comparison(
candidate_weights=np.asarray(candidate_refit["optimized_weights"]),
baseline_weights=np.asarray(baseline_refit["optimized_weights"]),
)
target_diagnostics = _target_loss_diagnostics(
target_names=target_names,
candidate_inputs=candidate_inputs,
baseline_inputs=baseline_inputs,
candidate_weights=np.asarray(candidate_refit["optimized_weights"]),
baseline_weights=np.asarray(baseline_refit["optimized_weights"]),
holdout_mask=holdout_mask,
top_k=target_diagnostics_top_k,
)

score_summary.update(
{
Expand Down Expand Up @@ -233,6 +243,7 @@ def build_sound_ecps_replacement_comparison(
"holdout_target_fraction": float(holdout_target_fraction),
"holdout_targets": int(holdout_mask.sum()),
"protected_family_losses": protected_family_losses,
"target_diagnostics": target_diagnostics["summary"],
}
)
payload = {
Expand Down Expand Up @@ -286,6 +297,7 @@ def build_sound_ecps_replacement_comparison(
},
"summary": score_summary,
"score": pe_native_scores,
"target_diagnostics": target_diagnostics,
"candidate_refit": _strip_weights(candidate_refit),
"baseline_refit": _strip_weights(baseline_refit),
"target_split": {
Expand All @@ -306,13 +318,26 @@ def build_sound_ecps_replacement_comparison(

def write_sound_ecps_replacement_comparison(
    output_path: str | Path,
    target_diagnostics_path: str | Path | None = None,
    **kwargs: Any,
) -> Path:
    """Build the comparison payload and persist it with a diagnostics sidecar.

    The payload's ``"target_diagnostics"`` entry is written to its own JSON
    file, a descriptor of that file is stored under
    ``payload["artifacts"]["target_loss_diagnostics"]``, and the full payload
    is then written to ``output_path``.

    Args:
        output_path: Destination of the full comparison payload.
        target_diagnostics_path: Destination of the diagnostics sidecar. When
            ``None``, ``target_loss_diagnostics.json`` in the same directory as
            the resolved ``output_path`` is used.
        **kwargs: Forwarded unchanged to
            ``build_sound_ecps_replacement_comparison``.

    Returns:
        The resolved path the comparison payload was written to.
    """

    payload = build_sound_ecps_replacement_comparison(**kwargs)

    destination = Path(output_path).expanduser().resolve()
    destination.parent.mkdir(parents=True, exist_ok=True)

    if target_diagnostics_path is None:
        diagnostics_destination = destination.parent / "target_loss_diagnostics.json"
    else:
        diagnostics_destination = Path(target_diagnostics_path).expanduser().resolve()
    diagnostics_destination.parent.mkdir(parents=True, exist_ok=True)

    diagnostics_json = json.dumps(
        payload["target_diagnostics"], indent=2, sort_keys=True
    )
    diagnostics_destination.write_text(diagnostics_json)

    # Describe the sidecar only after it has been written to disk.
    sidecar_descriptor = _dataset_descriptor(diagnostics_destination)
    artifacts = payload.setdefault("artifacts", {})
    artifacts["target_loss_diagnostics"] = sidecar_descriptor

    payload_json = json.dumps(payload, indent=2, sort_keys=True)
    destination.write_text(payload_json)
    return destination

Expand Down Expand Up @@ -726,6 +751,128 @@ def _protected_family_losses(
return rows


def _target_loss_diagnostics(
    *,
    target_names: list[str],
    candidate_inputs: dict[str, Any],
    baseline_inputs: dict[str, Any],
    candidate_weights: np.ndarray,
    baseline_weights: np.ndarray,
    holdout_mask: np.ndarray,
    top_k: int,
) -> dict[str, Any]:
    """Compare candidate and baseline loss terms target by target.

    Args:
        target_names: One name per target, aligned with the loss terms.
        candidate_inputs: Loss inputs passed to ``_loss_terms`` for the candidate.
        baseline_inputs: Loss inputs passed to ``_loss_terms`` for the baseline.
        candidate_weights: Weights used to evaluate the candidate loss terms.
        baseline_weights: Weights used to evaluate the baseline loss terms.
        holdout_mask: Per-target flags; truthy entries are reported as
            ``"holdout"`` and the rest as ``"train"``. Any array-like is
            accepted and interpreted as booleans.
        top_k: Length cap of the ranked lists; negative values are treated as 0.

    Returns:
        A dict with ``schema_version``, ``metric``, a ``summary`` of totals and
        win counts, a ``family_breakdown``, ``top_regressions`` and
        ``top_improvements`` (the ``top_k`` rows with the largest and smallest
        ``loss_delta``; no sign filter is applied, so with few targets the two
        lists can overlap), and the full per-target ``targets`` rows.

    Raises:
        ValueError: If the candidate and baseline term shapes differ, or if the
            number of target names or mask entries does not match the terms.
    """
    candidate_terms = _loss_terms(candidate_inputs, candidate_weights)
    baseline_terms = _loss_terms(baseline_inputs, baseline_weights)
    # Coerce once: on a 0/1 integer mask `~mask` is a bitwise NOT (-1/-2), which
    # would corrupt the train-target count in the summary below.
    holdout_mask = np.asarray(holdout_mask, dtype=bool)
    if candidate_terms.shape != baseline_terms.shape:
        raise ValueError("candidate and baseline target loss term shapes differ")
    if len(target_names) != candidate_terms.shape[0]:
        raise ValueError("target name count does not match loss terms")
    if holdout_mask.shape[0] != candidate_terms.shape[0]:
        raise ValueError("holdout mask length does not match loss terms")

    rows: list[dict[str, Any]] = []
    candidate_wins = 0
    baseline_wins = 0
    ties = 0
    for index, target_name in enumerate(target_names):
        candidate_loss = float(candidate_terms[index])
        baseline_loss = float(baseline_terms[index])
        loss_delta = candidate_loss - baseline_loss
        # Near-equal terms (numpy's default isclose tolerances) count as a tie
        # rather than as a win for either side.
        if np.isclose(candidate_loss, baseline_loss):
            winner = "tie"
            ties += 1
        elif candidate_loss < baseline_loss:
            winner = "candidate"
            candidate_wins += 1
        else:
            winner = "baseline"
            baseline_wins += 1
        rows.append(
            {
                "target_index": int(index),
                "target_name": str(target_name),
                "family": classify_pe_native_target_family(target_name),
                "split": "holdout" if bool(holdout_mask[index]) else "train",
                "candidate_loss_term": candidate_loss,
                "baseline_loss_term": baseline_loss,
                "loss_delta": float(loss_delta),
                # Square root of the loss term; presumably the term is a squared
                # scaled error -- confirm against `_loss_terms`.
                "candidate_abs_scaled_error": float(np.sqrt(candidate_loss)),
                "baseline_abs_scaled_error": float(np.sqrt(baseline_loss)),
                "winner": winner,
            }
        )

    top_k = max(0, int(top_k))
    regressions = sorted(
        rows,
        key=lambda row: float(row["loss_delta"]),
        reverse=True,
    )[:top_k]
    improvements = sorted(rows, key=lambda row: float(row["loss_delta"]))[:top_k]

    candidate_total = candidate_terms.sum()
    baseline_total = baseline_terms.sum()
    summary = {
        "n_targets": int(len(rows)),
        "candidate_loss": float(candidate_total),
        "baseline_loss": float(baseline_total),
        "loss_delta": float(candidate_total - baseline_total),
        "candidate_wins": int(candidate_wins),
        "baseline_wins": int(baseline_wins),
        "ties": int(ties),
        "train_targets": int((~holdout_mask).sum()),
        "holdout_targets": int(holdout_mask.sum()),
        "top_k": int(top_k),
    }
    return {
        "schema_version": 1,
        "metric": "sound_ecps_target_loss_diagnostics",
        "summary": summary,
        "family_breakdown": _target_family_breakdown(rows, len(rows)),
        "top_regressions": regressions,
        "top_improvements": improvements,
        "targets": rows,
    }


def _target_family_breakdown(
target_rows: list[dict[str, Any]],
total_targets: int,
) -> list[dict[str, Any]]:
families: dict[str, list[dict[str, Any]]] = {}
for row in target_rows:
families.setdefault(str(row["family"]), []).append(row)
denominator = float(total_targets) if total_targets else 1.0
breakdown = []
for family, rows in sorted(families.items()):
candidate_loss = sum(float(row["candidate_loss_term"]) for row in rows)
baseline_loss = sum(float(row["baseline_loss_term"]) for row in rows)
breakdown.append(
{
"family": family,
"n_targets": int(len(rows)),
"train_targets": int(
sum(1 for row in rows if row["split"] == "train")
),
"holdout_targets": int(
sum(1 for row in rows if row["split"] == "holdout")
),
"candidate_loss_contribution": float(
candidate_loss / denominator
),
"baseline_loss_contribution": float(baseline_loss / denominator),
"loss_delta": float(
(candidate_loss - baseline_loss) / denominator
),
"candidate_wins": int(
sum(1 for row in rows if row["winner"] == "candidate")
),
"baseline_wins": int(
sum(1 for row in rows if row["winner"] == "baseline")
),
"ties": int(sum(1 for row in rows if row["winner"] == "tie")),
}
)
return sorted(breakdown, key=lambda row: abs(float(row["loss_delta"])), reverse=True)


def _loss_terms(loss_inputs: dict[str, Any], weights: np.ndarray) -> np.ndarray:
matrix = np.asarray(loss_inputs["scaled_matrix"], dtype=np.float64)
target = np.asarray(loss_inputs["scaled_target"], dtype=np.float64)
Expand Down Expand Up @@ -793,6 +940,10 @@ def main(argv: list[str] | None = None) -> int:
"--output-path",
help="Defaults to <output-dir>/sound_ecps_replacement_comparison.json.",
)
parser.add_argument(
"--target-diagnostics-path",
help="Defaults to <output-dir>/target_loss_diagnostics.json.",
)
parser.add_argument("--period", type=int, default=2024)
parser.add_argument("--matched-household-count", type=int)
parser.add_argument("--random-seed", type=int, default=20260529)
Expand All @@ -810,6 +961,7 @@ def main(argv: list[str] | None = None) -> int:
parser.add_argument("--optimizer-max-iter", type=int, default=200)
parser.add_argument("--optimizer-tol", type=float, default=1e-8)
parser.add_argument("--score-consistency-tol", type=float, default=1e-6)
parser.add_argument("--target-diagnostics-top-k", type=int, default=50)
parser.add_argument("--policyengine-us-data-repo")
parser.add_argument("--policyengine-us-data-python")
parser.add_argument("--skip-tax-expenditure-targets", action="store_true")
Expand All @@ -824,6 +976,7 @@ def main(argv: list[str] | None = None) -> int:
)
written = write_sound_ecps_replacement_comparison(
output_path,
target_diagnostics_path=args.target_diagnostics_path,
candidate_dataset_path=args.candidate_dataset,
baseline_dataset_path=args.baseline_dataset,
output_dir=output_dir,
Expand All @@ -836,6 +989,7 @@ def main(argv: list[str] | None = None) -> int:
optimizer_max_iter=args.optimizer_max_iter,
optimizer_tol=args.optimizer_tol,
score_consistency_tol=args.score_consistency_tol,
target_diagnostics_top_k=args.target_diagnostics_top_k,
policyengine_us_data_repo=args.policyengine_us_data_repo,
policyengine_us_data_python=args.policyengine_us_data_python,
skip_tax_expenditure_targets=args.skip_tax_expenditure_targets,
Expand Down
50 changes: 50 additions & 0 deletions tests/pipelines/test_ecps_replacement_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,23 @@ def test_sound_ecps_replacement_comparison_satisfies_gate_contract(
"household_net_income",
}
assert summary["protected_family_losses"]["wages"]["n_targets"] == 1
target_diagnostics = payload["target_diagnostics"]
assert target_diagnostics["summary"]["n_targets"] == len(_TARGET_NAMES)
assert (
target_diagnostics["summary"]["candidate_wins"]
+ target_diagnostics["summary"]["baseline_wins"]
+ target_diagnostics["summary"]["ties"]
== len(_TARGET_NAMES)
)
assert target_diagnostics["summary"]["train_targets"] > 0
assert target_diagnostics["summary"]["holdout_targets"] > 0
assert target_diagnostics["top_regressions"]
assert target_diagnostics["top_improvements"]
assert len(target_diagnostics["targets"]) == len(_TARGET_NAMES)
assert {
row["split"] for row in target_diagnostics["targets"]
} == {"train", "holdout"}
assert target_diagnostics["family_breakdown"]
structure = payload["entity_structure"]["candidate_matched"]
assert structure["household_count"] == 2
assert structure["person_count"] == 3
Expand Down Expand Up @@ -313,6 +330,39 @@ def test_sound_ecps_replacement_comparison_satisfies_gate_contract(
assert gate_report["gates"]["ecps_comparison"]["status"] == "pass"


def test_sound_ecps_replacement_comparison_writes_target_diagnostics_sidecar(
    monkeypatch,
    tmp_path,
):
    """The writer emits a diagnostics sidecar and records it as an artifact."""
    candidate_h5 = _write_minimal_policyengine_dataset(tmp_path / "candidate.h5")
    baseline_h5 = _write_minimal_policyengine_dataset(tmp_path / "baseline.h5")
    comparison_dir = tmp_path / "comparison"
    # Replace loss-input extraction and scoring with the module's fakes.
    for attribute, replacement in (
        ("_extract_pe_native_loss_inputs", _fake_loss_inputs),
        ("compute_us_pe_native_scores", _fake_pe_native_scores),
    ):
        monkeypatch.setattr(ecps, attribute, replacement)

    written = ecps.write_sound_ecps_replacement_comparison(
        comparison_dir / "comparison.json",
        candidate_dataset_path=candidate_h5,
        baseline_dataset_path=baseline_h5,
        output_dir=comparison_dir,
        optimizer_max_iter=50,
        target_diagnostics_top_k=3,
    )

    # No explicit sidecar path was given, so it lands next to the payload.
    sidecar = comparison_dir / "target_loss_diagnostics.json"
    comparison = json.loads(written.read_text())
    diagnostics = json.loads(sidecar.read_text())
    artifact = comparison["artifacts"]["target_loss_diagnostics"]

    assert artifact["path"] == str(sidecar.resolve())
    assert artifact["size_bytes"] == sidecar.stat().st_size
    assert comparison["target_diagnostics"] == diagnostics
    assert diagnostics["summary"]["top_k"] == 3
    for ranking in ("top_regressions", "top_improvements"):
        assert len(diagnostics[ranking]) == 3


def test_sound_ecps_replacement_comparison_flags_score_mismatch(
monkeypatch,
tmp_path,
Expand Down
Loading