diff --git a/src/microplex_us/targets/arch.py b/src/microplex_us/targets/arch.py index f590b6c..6f2edbe 100644 --- a/src/microplex_us/targets/arch.py +++ b/src/microplex_us/targets/arch.py @@ -204,6 +204,13 @@ } ) +ARCH_COMPONENT_SUM_TARGETS = { + "salt_amount": ( + "state_local_income_or_sales_tax_amount", + "real_estate_taxes_amount", + ), +} + ARCH_NATIONAL_ROLLUP_STATE_FIPS = frozenset( state_fips for state_fips in US_STATE_ABBR_BY_FIPS if state_fips != "72" ) @@ -2462,10 +2469,146 @@ def _compose_arch_model_year_records( def _with_state_to_national_rollup_records( records: list[ArchTargetRecord], ) -> list[ArchTargetRecord]: - rollups = _state_to_national_rollup_records(records) + expanded_records = _with_component_sum_records(records) + rollups = _state_to_national_rollup_records(expanded_records) if not rollups: + return expanded_records + return [*expanded_records, *rollups] + + +def _with_component_sum_records( + records: list[ArchTargetRecord], +) -> list[ArchTargetRecord]: + component_records = _component_sum_records(records) + if not component_records: return records - return [*records, *rollups] + return [*records, *component_records] + + +def _component_sum_records( + records: list[ArchTargetRecord], +) -> list[ArchTargetRecord]: + existing_keys = { + _component_sum_record_key(record, output_variable=record.variable) + for record in records + if record.target_type == "AMOUNT" + } + grouped: dict[ + tuple[Any, ...], + dict[str, ArchTargetRecord], + ] = {} + for record in records: + if record.target_type != "AMOUNT": + continue + for output_variable, component_variables in ARCH_COMPONENT_SUM_TARGETS.items(): + if record.variable not in component_variables: + continue + key = _component_sum_record_key(record, output_variable=output_variable) + if key in existing_keys: + continue + components = grouped.setdefault(key, {}) + if record.variable in components: + components.clear() + break + components[record.variable] = record + + composite_records: list[ArchTargetRecord] = [] + for key, components_by_variable in grouped.items(): + output_variable = str(key[0]) + component_variables = ARCH_COMPONENT_SUM_TARGETS[output_variable] + if set(components_by_variable) != set(component_variables): + continue + composite_records.append( + _component_records_to_sum_record( + key, + [ + components_by_variable[component_variable] + for component_variable in component_variables + ], + ) + ) + return composite_records + + +def _component_sum_record_key( + record: ArchTargetRecord, + *, + output_variable: str, +) -> tuple[Any, ...]: + return ( + output_variable, + record.target_type, + record.period, + _arch_record_geo_level(record), + record.geography_id, + tuple(sorted(record.constraints)), + _normalize_arch_source(record.source), + record.source_period, + record.aging_factors, + record.unit, + ) + + +def _component_records_to_sum_record( + key: tuple[Any, ...], + records: list[ArchTargetRecord], +) -> ArchTargetRecord: + first = records[0] + digest = sha1(repr(key).encode("utf-8")).hexdigest() + component_labels = ", ".join(record.variable for record in records) + source_tables = tuple( + dict.fromkeys(record.source_table for record in records if record.source_table) + ) + source_urls = tuple( + dict.fromkeys(record.source_url for record in records if record.source_url) + ) + source_row_keys = tuple( + dict.fromkeys( + source_row_key + for record in records + for source_row_key in ( + record.source_row_keys + or (str(record.source_target_id or record.target_id),) + ) + ) + ) + source_cell_keys = tuple( + dict.fromkeys( + source_cell_key + for record in records + for source_cell_key in record.source_cell_keys + ) + ) + notes = ( + "Microplex component sum matching PolicyEngine salt sources: " + f"{component_labels}." + ) + return replace( + first, + target_id=-int(digest[:12], 16), + stratum_id=-int(digest[12:20], 16), + variable=str(key[0]), + value=sum(record.value for record in records), + source_table=( + source_tables[0] + if len(source_tables) == 1 + else "Microplex component sum from Arch source tables" + ), + source_url=source_urls[0] if len(source_urls) == 1 else None, + notes=f"{first.notes} {notes}" if first.notes else notes, + source_record_id=f"microplex_component_sum:{digest[:16]}", + source_cell_keys=source_cell_keys, + source_row_keys=source_row_keys, + aggregate_fact_key=None, + semantic_fact_key=None, + source_target_id=None, + source_stratum_id=None, + concept=None, + source_concept=None, + concept_relation="sum_of_components", + concept_authority="policyengine_us", + concept_evidence_notes=notes, + ) def _state_to_national_rollup_records( diff --git a/tests/targets/test_arch_facts.py b/tests/targets/test_arch_facts.py index 8154db8..0958b7c 100644 --- a/tests/targets/test_arch_facts.py +++ b/tests/targets/test_arch_facts.py @@ -1260,6 +1260,16 @@ def test_arch_consumer_fact_jsonl_provider_maps_table_2_1_itemized_details( period=period, unit="usd", ), + _consumer_fact( + "soi-real-estate-taxes", + concept="irs_soi.real_estate_taxes", + domain=domain, + source_name="irs_soi", + source_table=source_table, + value=108_606_373_000, + period=period, + unit="usd", + ), _consumer_fact( "soi-mortgage-financial", concept="irs_soi.home_mortgage_interest_paid_to_financial_institutions", @@ -1305,11 +1315,17 @@ def test_arch_consumer_fact_jsonl_provider_maps_table_2_1_itemized_details( "\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n" ) - target_set = ArchConsumerFactJSONLTargetProvider(consumer_jsonl).load_target_set( - TargetQuery(period=2023) - ) + provider = ArchConsumerFactJSONLTargetProvider(consumer_jsonl) + target_set = provider.load_target_set(TargetQuery(period=2023)) targets_by_arch_variable = { - target.metadata["arch_variable"]: target for target in target_set.targets + target.metadata["arch_variable"]: target + for target in target_set.targets + if target.metadata.get("arch_variable") is not None + } + targets_by_measure = { + str(target.measure): target + for target in target_set.targets + if target.measure is not None } charitable = targets_by_arch_variable["charitable_amount"] @@ -1336,6 +1352,13 @@ def test_arch_consumer_fact_jsonl_provider_maps_table_2_1_itemized_details( "state_local_income_or_sales_tax_amount" ] assert state_local_income_sales.measure == "state_and_local_sales_or_income_tax" + assert targets_by_arch_variable["real_estate_taxes_amount"].measure == ( + "real_estate_taxes" + ) + salt = targets_by_measure["salt"] + assert salt.measure == "salt" + assert salt.value == 327_149_456_000 + assert salt.metadata["variable"] == "salt" assert targets_by_arch_variable["mortgage_interest_paid_amount"].measure == ( "deductible_mortgage_interest" ) @@ -1350,7 +1373,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_table_2_1_itemized_details( ) coverage = summarize_arch_target_profile_coverage( - ArchConsumerFactJSONLTargetProvider(consumer_jsonl), + provider, period=2023, profile_name="custom", target_cells=( @@ -1380,6 +1403,11 @@ def test_arch_consumer_fact_jsonl_provider_maps_table_2_1_itemized_details( "state_and_local_sales_or_income_tax", geo_level="national", ), + PolicyEngineUSTargetCell( + "salt", + geo_level="national", + domain_variable="salt,tax_unit_itemizes", + ), ), ) assert coverage.covered_cell_count == coverage.target_cell_count