Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
321 changes: 313 additions & 8 deletions src/microplex_us/supabase_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,92 @@
from typing import Any

import requests
from microplex.core import EntityType
from microplex.targets import (
FilterOperator,
TargetAggregation,
TargetFilter,
TargetQuery,
TargetSet,
TargetSpec,
apply_target_query,
)

from microplex_us.target_registry import (
US_TARGET_AVAILABLE_KEY,
US_TARGET_CATEGORY_KEY,
US_TARGET_GROUP_KEY,
US_TARGET_IMPUTATION_KEY,
US_TARGET_LEVEL_KEY,
TargetCategory,
TargetLevel,
)

# Metadata keys written onto every TargetSpec built by
# SupabaseTargetProvider.target_from_row, recording where the spec came from.
# Row "id" value.
SUPABASE_TARGET_ID_KEY = "supabase_target_id"
# Row "variable" value, as a string.
SUPABASE_VARIABLE_KEY = "supabase_variable"
# Lower-cased row "target_type" ("amount" when missing or empty).
SUPABASE_TARGET_TYPE_KEY = "supabase_target_type"
# Stratum "jurisdiction" value ("us" when missing or empty).
SUPABASE_JURISDICTION_KEY = "supabase_jurisdiction"
# Stratum "name" value, if any.
SUPABASE_STRATUM_NAME_KEY = "supabase_stratum_name"
# Source "institution" value, if any.
SUPABASE_SOURCE_INSTITUTION_KEY = "supabase_source_institution"
# True when the row's variable is a key of the loader's CPS_COLUMN_MAP.
SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY = "supabase_supported_by_column_map"

# Count variables that get no "measure > 0" filter in target_from_row.
# _category_for_variable files them under TargetCategory.DEMOGRAPHICS.
_COUNT_ALL_VARIABLES = {
"family_count",
"household_count",
"person_count",
"spm_unit_count",
"tax_unit_count",
}

# Entity assigned to each count variable by _entity_for_variable.
_COUNT_ENTITY_MAP = {
"family_count": EntityType.FAMILY,
"household_count": EntityType.HOUSEHOLD,
"person_count": EntityType.PERSON,
"spm_unit_count": EntityType.SPM_UNIT,
"tax_unit_count": EntityType.TAX_UNIT,
}

# Variables that _category_for_variable maps to TargetCategory.INCOME.
_INCOME_VARIABLES = {
"alimony_income",
"dividend_income",
"employment_income",
"farm_income",
"interest_income",
"long_term_capital_gains",
"partnership_s_corp_income",
"rental_income",
"self_employment_income",
"short_term_capital_gains",
"social_security",
"tax_exempt_pension_income",
"taxable_pension_income",
"unemployment_compensation",
}

# Variables that _category_for_variable maps to TargetCategory.BENEFITS.
_BENEFIT_VARIABLES = {
"eitc_spending",
"snap_households",
"snap_spending",
"social_security_spending",
"ssi_spending",
"unemployment_spending",
}

# Variables that _category_for_variable maps to TargetCategory.HEALTH.
_HEALTH_VARIABLES = {
"aca_enrollment",
"health_insurance_premiums",
"medicaid_enrollment",
"other_medical_expenses",
}

# Non-count variables that _entity_for_variable assigns to EntityType.TAX_UNIT.
_TAX_UNIT_VARIABLES = {
"eitc_spending",
}

# Non-count variables that _entity_for_variable assigns to
# EntityType.HOUSEHOLD; any variable in none of these sets falls back to
# EntityType.PERSON.
_HOUSEHOLD_VARIABLES = {
"snap_households",
"snap_spending",
}


class SupabaseTargetLoader:
Expand Down Expand Up @@ -217,13 +303,11 @@ def _parse_jurisdiction(self, jurisdiction: str) -> str | None:
return None

if jurisdiction.startswith("us-") and len(jurisdiction) == 5:
state = jurisdiction[3:].lower()
if len(state) == 2:
return state

if jurisdiction.startswith("us-") and len(jurisdiction) == 5:
fips = jurisdiction[3:]
return self.STATE_FIPS.get(fips)
suffix = jurisdiction[3:].lower()
if suffix in self.STATE_FIPS:
return self.STATE_FIPS[suffix]
if suffix in _state_abbr_to_fips(self.STATE_FIPS):
return suffix

return None

Expand Down Expand Up @@ -286,4 +370,225 @@ def get_summary(self) -> dict[str, Any]:
}


__all__ = ["SupabaseTargetLoader"]
class SupabaseTargetProvider(SupabaseTargetLoader):
    """Load Supabase targets as canonical core target specs.

    Rows are fetched through the inherited ``load_all`` /
    ``load_by_institution`` methods and translated one at a time into
    ``TargetSpec`` objects by :meth:`target_from_row`.
    """

    def load_target_set(self, query: TargetQuery | None = None) -> TargetSet:
        """Load a canonical target set through the core provider protocol.

        Recognised ``query.provider_filters`` keys:

        * ``institution`` -- when truthy, rows are fetched with
          ``load_by_institution`` instead of ``load_all``.
        * ``target_types`` -- a string or an iterable of values; rows whose
          target type is not listed are dropped. Matching is
          case-insensitive. An empty or missing value disables the filter.
        * ``include_unsupported`` (default ``True``) -- when falsy, drop
          targets whose variable is not in ``CPS_COLUMN_MAP``.
        * ``include_states`` (default ``True``) -- when falsy, drop
          state-level targets.

        Args:
            query: Core target query; ``None`` means an empty ``TargetQuery``.

        Returns:
            The surviving specs after ``apply_target_query`` has applied the
            query's period, entity, names and metadata filters.
        """
        query = query or TargetQuery()
        provider_filters = query.provider_filters
        period = _query_period(query.period)
        institution = provider_filters.get("institution")
        # Row target types are lower-cased by _target_type(), so the requested
        # types are lower-cased as well; otherwise a filter such as "Count"
        # would silently match nothing.
        target_types = {
            requested.lower()
            for requested in _as_string_set(provider_filters.get("target_types"))
        }
        include_unsupported = bool(provider_filters.get("include_unsupported", True))
        include_states = bool(provider_filters.get("include_states", True))

        if institution:
            rows = self.load_by_institution(str(institution), period=period)
        else:
            rows = self.load_all(period=period)

        specs: list[TargetSpec] = []
        for row in rows:
            # Cheap row-level filter first, before building the spec.
            if target_types and _target_type(row) not in target_types:
                continue

            spec = self.target_from_row(row)
            if (
                not include_states
                and spec.metadata.get(US_TARGET_LEVEL_KEY) == TargetLevel.STATE.value
            ):
                continue
            if (
                not include_unsupported
                and not spec.metadata[SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY]
            ):
                continue
            specs.append(spec)

        # Provider-specific filters were handled above; the remaining generic
        # filters are delegated to the core helper. A period that could not be
        # parsed to an int is passed through unchanged.
        return apply_target_query(
            TargetSet(specs),
            TargetQuery(
                period=period if period is not None else query.period,
                entity=query.entity,
                names=query.names,
                metadata_filters=query.metadata_filters,
            ),
        )

    def target_from_row(self, row: dict[str, Any]) -> TargetSpec:
        """Translate one Supabase target row into the canonical target IR.

        The row must contain ``variable``, ``value`` and ``period``. The
        optional ``source`` and ``stratum`` entries are only used when they
        are dicts. Count targets carry no measure; count variables outside
        ``_COUNT_ALL_VARIABLES`` instead get a ``measure > 0`` filter. A row
        whose jurisdiction resolves to a state gets a ``state_fips`` equality
        filter plus ``state_fips`` / ``state_abbr`` metadata.

        Args:
            row: One target row as returned by the loader methods.

        Returns:
            The ``TargetSpec`` describing the row.

        Raises:
            KeyError: If ``variable``, ``value`` or ``period`` is missing.
            TypeError: If ``value`` or ``period`` cannot be converted
                (for example ``None``).
            ValueError: If ``value`` or ``period`` is not numeric text.
        """
        variable = str(row["variable"])
        jurisdiction = _target_jurisdiction(row)
        state_fips, state_abbr = _jurisdiction_state(jurisdiction, self.STATE_FIPS)
        target_type = _target_type(row)
        aggregation = _aggregation_for_target_type(target_type)
        # Variables without a column mapping keep their own name as the
        # measure and are flagged as unsupported in the metadata.
        measure = self.CPS_COLUMN_MAP.get(variable, variable)
        supported = variable in self.CPS_COLUMN_MAP
        raw_source = row.get("source")
        source = raw_source if isinstance(raw_source, dict) else {}
        source_name = source.get("name") or source.get("institution")
        source_institution = source.get("institution")
        raw_stratum = row.get("stratum")
        stratum = raw_stratum if isinstance(raw_stratum, dict) else {}
        category = _category_for_variable(variable)
        level = TargetLevel.STATE if state_fips is not None else TargetLevel.NATIONAL

        filters: list[TargetFilter] = []
        if aggregation is TargetAggregation.COUNT and variable not in _COUNT_ALL_VARIABLES:
            # Count only the records where the measure is positive.
            filters.append(
                TargetFilter(
                    feature=measure,
                    operator=FilterOperator.GT,
                    value=0,
                )
            )

        if state_fips is not None:
            filters.append(
                TargetFilter(
                    feature="state_fips",
                    operator=FilterOperator.EQ,
                    value=state_fips,
                )
            )

        metadata: dict[str, Any] = {
            SUPABASE_TARGET_ID_KEY: row.get("id"),
            SUPABASE_VARIABLE_KEY: variable,
            SUPABASE_TARGET_TYPE_KEY: target_type,
            SUPABASE_JURISDICTION_KEY: jurisdiction,
            SUPABASE_STRATUM_NAME_KEY: stratum.get("name"),
            SUPABASE_SOURCE_INSTITUTION_KEY: source_institution,
            SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY: supported,
            US_TARGET_LEVEL_KEY: level.value,
            US_TARGET_GROUP_KEY: _group_for_category(category),
            US_TARGET_AVAILABLE_KEY: supported,
            US_TARGET_IMPUTATION_KEY: not supported,
        }
        if category is not None:
            metadata[US_TARGET_CATEGORY_KEY] = category.value
        if state_fips is not None:
            metadata["state_fips"] = state_fips
            metadata["state_abbr"] = state_abbr

        return TargetSpec(
            name=_target_name(variable, jurisdiction),
            entity=_entity_for_variable(variable),
            value=float(row["value"]),
            period=int(row["period"]),
            measure=None if aggregation is TargetAggregation.COUNT else measure,
            aggregation=aggregation,
            filters=tuple(filters),
            source=source_name,
            units=_units_for_target_type(target_type),
            description=row.get("notes"),
            metadata=metadata,
        )


def _target_type(row: dict[str, Any]) -> str:
return str(row.get("target_type") or "amount").lower()


def _aggregation_for_target_type(target_type: str) -> TargetAggregation:
    """Map a lower-cased target type onto its aggregation.

    ``"count"`` and ``"mean"`` have dedicated aggregations; every other
    target type is aggregated as a sum.
    """
    special_cases = {
        "count": TargetAggregation.COUNT,
        "mean": TargetAggregation.MEAN,
    }
    return special_cases.get(target_type, TargetAggregation.SUM)


def _target_jurisdiction(row: dict[str, Any]) -> str:
stratum = row.get("stratum") if isinstance(row.get("stratum"), dict) else {}
return str(stratum.get("jurisdiction") or "us")


def _target_name(variable: str, jurisdiction: str) -> str:
if jurisdiction in {"us", "us-national"}:
return variable
return f"{variable}_{jurisdiction.replace('-', '_')}"


def _query_period(period: int | str | None) -> int | None:
if isinstance(period, int):
return period
if isinstance(period, str) and period.isdigit():
return int(period)
return None


def _as_string_set(value: Any) -> set[str]:
if value is None:
return set()
if isinstance(value, str):
return {value}
return {str(item) for item in value}


def _state_abbr_to_fips(state_fips: dict[str, str]) -> dict[str, str]:
return {abbr: fips for fips, abbr in state_fips.items()}


def _jurisdiction_state(
    jurisdiction: str,
    state_fips: dict[str, str],
) -> tuple[str | None, str | None]:
    """Resolve a ``us-XX`` jurisdiction to a ``(fips, abbreviation)`` pair.

    ``XX`` is lower-cased and looked up first as a key of ``state_fips`` and
    then as one of its values. Anything that is not exactly ``us-`` plus two
    characters, or whose suffix is unknown, yields ``(None, None)``.
    """
    if jurisdiction.startswith("us-") and len(jurisdiction) == 5:
        code = jurisdiction[3:].lower()
        # The suffix is already a FIPS code.
        if code in state_fips:
            return code, state_fips[code]
        # Otherwise try it as a state abbreviation.
        fips_by_abbr = _state_abbr_to_fips(state_fips)
        if code in fips_by_abbr:
            return fips_by_abbr[code], code
    return None, None


def _category_for_variable(variable: str) -> TargetCategory | None:
    """Classify a Supabase variable into a target category.

    Explicit membership in the income, benefit and health sets is checked
    first (in that order), then the ``_tax`` / ``_credit`` name suffixes, and
    finally the plain count variables. Unrecognised variables give ``None``.
    """
    explicit_categories = (
        (_INCOME_VARIABLES, TargetCategory.INCOME),
        (_BENEFIT_VARIABLES, TargetCategory.BENEFITS),
        (_HEALTH_VARIABLES, TargetCategory.HEALTH),
    )
    for members, category in explicit_categories:
        if variable in members:
            return category
    if variable.endswith(("_tax", "_credit")):
        return TargetCategory.TAX
    return TargetCategory.DEMOGRAPHICS if variable in _COUNT_ALL_VARIABLES else None


def _entity_for_variable(variable: str) -> EntityType:
    """Pick the entity a variable's target applies to.

    Count variables use their mapped entity; otherwise tax-unit and household
    variables are recognised by set membership, and everything else is a
    person-level target.
    """
    count_entity = _COUNT_ENTITY_MAP.get(variable)
    if count_entity is not None:
        return count_entity
    if variable in _TAX_UNIT_VARIABLES:
        return EntityType.TAX_UNIT
    return EntityType.HOUSEHOLD if variable in _HOUSEHOLD_VARIABLES else EntityType.PERSON


def _group_for_category(category: TargetCategory | None) -> str:
if category is None:
return "supabase_targets"
return f"supabase_{category.value}"


def _units_for_target_type(target_type: str) -> str | None:
return "USD" if target_type == "amount" else None


# Public API: the Supabase metadata key constants plus the loader and
# provider classes.
__all__ = [
"SUPABASE_JURISDICTION_KEY",
"SUPABASE_SOURCE_INSTITUTION_KEY",
"SUPABASE_STRATUM_NAME_KEY",
"SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY",
"SUPABASE_TARGET_ID_KEY",
"SUPABASE_TARGET_TYPE_KEY",
"SUPABASE_VARIABLE_KEY",
"SupabaseTargetLoader",
"SupabaseTargetProvider",
]
Loading
Loading