Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 145 additions & 2 deletions src/microplex_us/targets/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,13 @@
}
)

ARCH_COMPONENT_SUM_TARGETS = {
"salt_amount": (
"state_local_income_or_sales_tax_amount",
"real_estate_taxes_amount",
),
}

ARCH_NATIONAL_ROLLUP_STATE_FIPS = frozenset(
state_fips for state_fips in US_STATE_ABBR_BY_FIPS if state_fips != "72"
)
Expand Down Expand Up @@ -2462,10 +2469,146 @@ def _compose_arch_model_year_records(
def _with_state_to_national_rollup_records(
records: list[ArchTargetRecord],
) -> list[ArchTargetRecord]:
rollups = _state_to_national_rollup_records(records)
expanded_records = _with_component_sum_records(records)
rollups = _state_to_national_rollup_records(expanded_records)
if not rollups:
return expanded_records
return [*expanded_records, *rollups]


def _with_component_sum_records(
records: list[ArchTargetRecord],
) -> list[ArchTargetRecord]:
component_records = _component_sum_records(records)
if not component_records:
return records
return [*records, *rollups]
return [*records, *component_records]


def _component_sum_records(
records: list[ArchTargetRecord],
) -> list[ArchTargetRecord]:
existing_keys = {
_component_sum_record_key(record, output_variable=record.variable)
for record in records
if record.target_type == "AMOUNT"
}
grouped: dict[
tuple[Any, ...],
dict[str, ArchTargetRecord],
] = {}
for record in records:
if record.target_type != "AMOUNT":
continue
for output_variable, component_variables in ARCH_COMPONENT_SUM_TARGETS.items():
if record.variable not in component_variables:
continue
key = _component_sum_record_key(record, output_variable=output_variable)
if key in existing_keys:
continue
components = grouped.setdefault(key, {})
if record.variable in components:
components.clear()
break
components[record.variable] = record

composite_records: list[ArchTargetRecord] = []
for key, components_by_variable in grouped.items():
output_variable = str(key[0])
component_variables = ARCH_COMPONENT_SUM_TARGETS[output_variable]
if set(components_by_variable) != set(component_variables):
continue
composite_records.append(
_component_records_to_sum_record(
key,
[
components_by_variable[component_variable]
for component_variable in component_variables
],
)
)
return composite_records


def _component_sum_record_key(
record: ArchTargetRecord,
*,
output_variable: str,
) -> tuple[Any, ...]:
return (
output_variable,
record.target_type,
record.period,
_arch_record_geo_level(record),
record.geography_id,
tuple(sorted(record.constraints)),
_normalize_arch_source(record.source),
record.source_period,
record.aging_factors,
record.unit,
)


def _component_records_to_sum_record(
key: tuple[Any, ...],
records: list[ArchTargetRecord],
) -> ArchTargetRecord:
first = records[0]
digest = sha1(repr(key).encode("utf-8")).hexdigest()
component_labels = ", ".join(record.variable for record in records)
source_tables = tuple(
dict.fromkeys(record.source_table for record in records if record.source_table)
)
source_urls = tuple(
dict.fromkeys(record.source_url for record in records if record.source_url)
)
source_row_keys = tuple(
dict.fromkeys(
source_row_key
for record in records
for source_row_key in (
record.source_row_keys
or (str(record.source_target_id or record.target_id),)
)
)
)
source_cell_keys = tuple(
dict.fromkeys(
source_cell_key
for record in records
for source_cell_key in record.source_cell_keys
)
)
notes = (
"Microplex component sum matching PolicyEngine salt sources: "
f"{component_labels}."
)
return replace(
first,
target_id=-int(digest[:12], 16),
stratum_id=-int(digest[12:20], 16),
variable=str(key[0]),
value=sum(record.value for record in records),
source_table=(
source_tables[0]
if len(source_tables) == 1
else "Microplex component sum from Arch source tables"
),
source_url=source_urls[0] if len(source_urls) == 1 else None,
notes=f"{first.notes} {notes}" if first.notes else notes,
source_record_id=f"microplex_component_sum:{digest[:16]}",
source_cell_keys=source_cell_keys,
source_row_keys=source_row_keys,
aggregate_fact_key=None,
semantic_fact_key=None,
source_target_id=None,
source_stratum_id=None,
concept=None,
source_concept=None,
concept_relation="sum_of_components",
concept_authority="policyengine_us",
concept_evidence_notes=notes,
)


def _state_to_national_rollup_records(
Expand Down
38 changes: 33 additions & 5 deletions tests/targets/test_arch_facts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,16 @@ def test_arch_consumer_fact_jsonl_provider_maps_table_2_1_itemized_details(
period=period,
unit="usd",
),
_consumer_fact(
"soi-real-estate-taxes",
concept="irs_soi.real_estate_taxes",
domain=domain,
source_name="irs_soi",
source_table=source_table,
value=108_606_373_000,
period=period,
unit="usd",
),
_consumer_fact(
"soi-mortgage-financial",
concept="irs_soi.home_mortgage_interest_paid_to_financial_institutions",
Expand Down Expand Up @@ -1305,11 +1315,17 @@ def test_arch_consumer_fact_jsonl_provider_maps_table_2_1_itemized_details(
"\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n"
)

target_set = ArchConsumerFactJSONLTargetProvider(consumer_jsonl).load_target_set(
TargetQuery(period=2023)
)
provider = ArchConsumerFactJSONLTargetProvider(consumer_jsonl)
target_set = provider.load_target_set(TargetQuery(period=2023))
targets_by_arch_variable = {
target.metadata["arch_variable"]: target for target in target_set.targets
target.metadata["arch_variable"]: target
for target in target_set.targets
if target.metadata.get("arch_variable") is not None
}
targets_by_measure = {
str(target.measure): target
for target in target_set.targets
if target.measure is not None
}

charitable = targets_by_arch_variable["charitable_amount"]
Expand All @@ -1336,6 +1352,13 @@ def test_arch_consumer_fact_jsonl_provider_maps_table_2_1_itemized_details(
"state_local_income_or_sales_tax_amount"
]
assert state_local_income_sales.measure == "state_and_local_sales_or_income_tax"
assert targets_by_arch_variable["real_estate_taxes_amount"].measure == (
"real_estate_taxes"
)
salt = targets_by_measure["salt"]
assert salt.measure == "salt"
assert salt.value == 327_149_456_000
assert salt.metadata["variable"] == "salt"
assert targets_by_arch_variable["mortgage_interest_paid_amount"].measure == (
"deductible_mortgage_interest"
)
Expand All @@ -1350,7 +1373,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_table_2_1_itemized_details(
)

coverage = summarize_arch_target_profile_coverage(
ArchConsumerFactJSONLTargetProvider(consumer_jsonl),
provider,
period=2023,
profile_name="custom",
target_cells=(
Expand Down Expand Up @@ -1380,6 +1403,11 @@ def test_arch_consumer_fact_jsonl_provider_maps_table_2_1_itemized_details(
"state_and_local_sales_or_income_tax",
geo_level="national",
),
PolicyEngineUSTargetCell(
"salt",
geo_level="national",
domain_variable="salt,tax_unit_itemizes",
),
),
)
assert coverage.covered_cell_count == coverage.target_cell_count
Expand Down
Loading