# Maps each sub-household entity name to the plural used in the
# ``<plural>_per_household`` summary keys.
_ENTITY_PLURALS = {
    "tax_unit": "tax_units",
    "spm_unit": "spm_units",
    "family": "families",
    "marital_unit": "marital_units",
}


def _entity_structure_summary(
    dataset_path: str | Path,
    *,
    period: int,
) -> dict[str, Any]:
    """Summarize the entity structure stored in one HDF5 dataset.

    The file is opened read-only and the ``household_id``, ``person_id`` and
    ``person_household_id`` arrays are read from ``<variable>/<period>``.
    For every entity in ``_ENTITY_PLURALS`` the full membership summary is
    nested under the entity name, and two headline figures are copied to the
    top level as ``<entity>_count`` and ``<plural>_per_household``.

    Args:
        dataset_path: Location of the dataset; ``~`` is expanded and the
            path is resolved before opening.
        period: Period whose arrays are read (used as ``str(period)``).

    Returns:
        A dict with ``dataset``, ``period``, ``household_count``,
        ``person_count`` and the per-entity entries described above.

    Raises:
        ValueError: If a required array is missing, or if the person-level
            arrays do not all have the same length.
    """
    path = Path(dataset_path).expanduser().resolve()
    period_key = str(period)
    with h5py.File(path, "r") as handle:
        household_ids = _read_period_array(
            handle,
            "household_id",
            period_key,
            dataset_path=path,
        )
        person_ids = _read_period_array(
            handle,
            "person_id",
            period_key,
            dataset_path=path,
        )
        person_household_ids = _read_period_array(
            handle,
            "person_household_id",
            period_key,
            dataset_path=path,
        )
        if person_ids.shape[0] != person_household_ids.shape[0]:
            raise ValueError(
                f"{path} person_id and person_household_id lengths differ"
            )

        household_count = int(household_ids.shape[0])
        summary: dict[str, Any] = {
            "dataset": str(path),
            "period": int(period),
            "household_count": household_count,
            "person_count": int(person_ids.shape[0]),
        }
        for entity in ("tax_unit", "spm_unit", "family", "marital_unit"):
            plural = _ENTITY_PLURALS[entity]
            entity_summary = _entity_membership_summary(
                handle,
                entity=entity,
                period_key=period_key,
                person_household_ids=person_household_ids,
                household_count=household_count,
                dataset_path=path,
            )
            summary[entity] = entity_summary
            summary[f"{entity}_count"] = entity_summary["unit_count"]
            summary[f"{plural}_per_household"] = entity_summary[
                "units_per_household"
            ]
    return summary


def _read_period_array(
    handle: h5py.File,
    variable: str,
    period_key: str,
    *,
    dataset_path: str | Path | None = None,
) -> np.ndarray:
    """Return ``handle[variable][period_key]`` as an ``int64`` array.

    Args:
        handle: Open dataset supporting ``in`` and ``[]`` lookups.
        variable: Top-level key of the variable to read.
        period_key: Key of the period under ``variable``.
        dataset_path: Optional location of the dataset; when given it is
            appended to the error message so the failing file can be
            identified.

    Raises:
        ValueError: If ``variable`` or ``period_key`` is not present.
    """
    if variable not in handle or period_key not in handle[variable]:
        message = f"Dataset is missing {variable}/{period_key}"
        if dataset_path is not None:
            message = f"{message} in {dataset_path}"
        raise ValueError(message)
    return np.asarray(handle[variable][period_key], dtype=np.int64)


def _entity_membership_summary(
    handle: h5py.File,
    *,
    entity: str,
    period_key: str,
    person_household_ids: np.ndarray,
    household_count: int,
    dataset_path: Path,
) -> dict[str, Any]:
    """Compute membership diagnostics for one entity type.

    Reads ``<entity>_id`` (the unit table) and ``person_<entity>_id`` (each
    person's unit reference) and compares them.

    Args:
        handle: Open dataset to read from.
        entity: Entity name, e.g. ``"tax_unit"``.
        period_key: Period key the arrays are stored under.
        person_household_ids: Household id of each person, aligned with the
            ``person_<entity>_id`` array.
        household_count: Number of households, used as the denominator of
            ``units_per_household``.
        dataset_path: Dataset location, used only in error messages.

    Returns:
        A dict of counts and shares. The two ratio entries are ``None`` when
        their denominator is zero.

    Raises:
        ValueError: If an array is missing or ``person_<entity>_id`` and
            ``person_household_ids`` differ in length.
    """
    entity_ids = _read_period_array(
        handle,
        f"{entity}_id",
        period_key,
        dataset_path=dataset_path,
    )
    person_entity_ids = _read_period_array(
        handle,
        f"person_{entity}_id",
        period_key,
        dataset_path=dataset_path,
    )
    if person_entity_ids.shape[0] != person_household_ids.shape[0]:
        raise ValueError(
            f"{dataset_path} person_{entity}_id and person_household_id "
            "lengths differ"
        )
    unique_entity_ids = np.unique(entity_ids)
    duplicate_unit_id_count = int(
        entity_ids.shape[0] - unique_entity_ids.shape[0]
    )
    # ``inverse`` relabels every person's unit reference as a dense index
    # into ``unique_person_entity_ids``, which is what bincount needs.
    unique_person_entity_ids, inverse = np.unique(
        person_entity_ids,
        return_inverse=True,
    )
    member_counts = np.bincount(inverse)
    singleton_count = int(np.count_nonzero(member_counts == 1))
    # Units listed in the unit table that no person points at.
    empty_unit_count = int(
        np.setdiff1d(unique_entity_ids, unique_person_entity_ids).size
    )
    # Unit ids that persons point at but the unit table does not list.
    missing_referenced_unit_count = int(
        np.setdiff1d(unique_person_entity_ids, unique_entity_ids).size
    )
    cross_household_count = _cross_household_entity_count(
        inverse,
        person_household_ids,
    )
    unit_count = int(entity_ids.shape[0])
    return {
        "unit_count": unit_count,
        "person_membership_count": int(person_entity_ids.shape[0]),
        "duplicate_unit_id_count": duplicate_unit_id_count,
        "units_per_household": (
            float(unit_count / household_count) if household_count else None
        ),
        "singleton_unit_count": singleton_count,
        "singleton_unit_share": (
            float(singleton_count / unit_count) if unit_count else None
        ),
        "empty_unit_count": empty_unit_count,
        "missing_referenced_unit_count": missing_referenced_unit_count,
        "cross_household_unit_count": cross_household_count,
    }


def _cross_household_entity_count(
    entity_inverse: np.ndarray,
    person_household_ids: np.ndarray,
) -> int:
    """Count units whose members belong to more than one household.

    Args:
        entity_inverse: Unit label of each person (any integer labels).
        person_household_ids: Household id of each person, aligned with
            ``entity_inverse``.

    Returns:
        The number of distinct unit labels associated with at least two
        distinct household ids; ``0`` for empty input.
    """
    if entity_inverse.size == 0:
        return 0
    # Sort by unit first and household second (lexsort's last key is the
    # primary one), so that inside each unit equal households are adjacent.
    order = np.lexsort((person_household_ids, entity_inverse))
    sorted_entity = entity_inverse[order]
    sorted_household = person_household_ids[order]
    # A unit spans households exactly when two neighbouring rows share the
    # unit but differ in household.
    straddles = (sorted_entity[1:] == sorted_entity[:-1]) & (
        sorted_household[1:] != sorted_household[:-1]
    )
    # A unit spread over three or more households yields several such
    # neighbours, so count distinct units rather than transitions.
    return int(np.unique(sorted_entity[1:][straddles]).size)
payload["entity_structure"]["candidate_matched"] + assert structure["household_count"] == 2 + assert structure["person_count"] == 3 + assert structure["tax_unit_count"] == 2 + assert structure["tax_unit"]["singleton_unit_count"] == 1 + assert structure["tax_unit"]["singleton_unit_share"] == pytest.approx(0.5) + assert structure["tax_unit"]["duplicate_unit_id_count"] == 0 + assert structure["tax_unit"]["missing_referenced_unit_count"] == 0 + assert structure["tax_unit"]["cross_household_unit_count"] == 0 + assert structure["spm_unit_count"] == 2 + assert structure["family_count"] == 2 + assert structure["marital_unit_count"] == 3 + assert structure["marital_unit"]["singleton_unit_share"] == pytest.approx(1.0) + assert payload["entity_structure"]["baseline_refit"]["household_count"] == 2 candidate_curve = payload["candidate_refit"]["loss_curve"] baseline_curve = payload["baseline_refit"]["loss_curve"] assert candidate_curve[0]["iteration"] == 0