From 6a262b4e9eb0691d2c50d05c66b40215ddb0cd31 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 25 Apr 2026 19:22:45 -0400 Subject: [PATCH] Add Supabase target provider --- src/microplex_us/supabase_targets.py | 321 ++++++++++++++++++++++++++- tests/targets/test_supabase.py | 151 ++++++++++++- 2 files changed, 462 insertions(+), 10 deletions(-) diff --git a/src/microplex_us/supabase_targets.py b/src/microplex_us/supabase_targets.py index 3309e15..0889817 100644 --- a/src/microplex_us/supabase_targets.py +++ b/src/microplex_us/supabase_targets.py @@ -6,6 +6,92 @@ from typing import Any import requests +from microplex.core import EntityType +from microplex.targets import ( + FilterOperator, + TargetAggregation, + TargetFilter, + TargetQuery, + TargetSet, + TargetSpec, + apply_target_query, +) + +from microplex_us.target_registry import ( + US_TARGET_AVAILABLE_KEY, + US_TARGET_CATEGORY_KEY, + US_TARGET_GROUP_KEY, + US_TARGET_IMPUTATION_KEY, + US_TARGET_LEVEL_KEY, + TargetCategory, + TargetLevel, +) + +SUPABASE_TARGET_ID_KEY = "supabase_target_id" +SUPABASE_VARIABLE_KEY = "supabase_variable" +SUPABASE_TARGET_TYPE_KEY = "supabase_target_type" +SUPABASE_JURISDICTION_KEY = "supabase_jurisdiction" +SUPABASE_STRATUM_NAME_KEY = "supabase_stratum_name" +SUPABASE_SOURCE_INSTITUTION_KEY = "supabase_source_institution" +SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY = "supabase_supported_by_column_map" + +_COUNT_ALL_VARIABLES = { + "family_count", + "household_count", + "person_count", + "spm_unit_count", + "tax_unit_count", +} + +_COUNT_ENTITY_MAP = { + "family_count": EntityType.FAMILY, + "household_count": EntityType.HOUSEHOLD, + "person_count": EntityType.PERSON, + "spm_unit_count": EntityType.SPM_UNIT, + "tax_unit_count": EntityType.TAX_UNIT, +} + +_INCOME_VARIABLES = { + "alimony_income", + "dividend_income", + "employment_income", + "farm_income", + "interest_income", + "long_term_capital_gains", + "partnership_s_corp_income", + "rental_income", + "self_employment_income", + "short_term_capital_gains", + "social_security", + "tax_exempt_pension_income", + "taxable_pension_income", + "unemployment_compensation", +} + +_BENEFIT_VARIABLES = { + "eitc_spending", + "snap_households", + "snap_spending", + "social_security_spending", + "ssi_spending", + "unemployment_spending", +} + +_HEALTH_VARIABLES = { + "aca_enrollment", + "health_insurance_premiums", + "medicaid_enrollment", + "other_medical_expenses", +} + +_TAX_UNIT_VARIABLES = { + "eitc_spending", +} + +_HOUSEHOLD_VARIABLES = { + "snap_households", + "snap_spending", +} class SupabaseTargetLoader: @@ -217,13 +303,11 @@ def _parse_jurisdiction(self, jurisdiction: str) -> str | None: return None if jurisdiction.startswith("us-") and len(jurisdiction) == 5: - state = jurisdiction[3:].lower() - if len(state) == 2: - return state - - if jurisdiction.startswith("us-") and len(jurisdiction) == 5: - fips = jurisdiction[3:] - return self.STATE_FIPS.get(fips) + suffix = jurisdiction[3:].lower() + if suffix in self.STATE_FIPS: + return self.STATE_FIPS[suffix] + if suffix in _state_abbr_to_fips(self.STATE_FIPS): + return suffix return None @@ -286,4 +370,225 @@ def get_summary(self) -> dict[str, Any]: } -__all__ = ["SupabaseTargetLoader"] +class SupabaseTargetProvider(SupabaseTargetLoader): + """Load Supabase targets as canonical core target specs.""" + + def load_target_set(self, query: TargetQuery | None = None) -> TargetSet: + """Load a canonical target set through the core provider protocol.""" + query = query or TargetQuery() + provider_filters = query.provider_filters + period = _query_period(query.period) + institution = provider_filters.get("institution") + target_types = _as_string_set(provider_filters.get("target_types")) + include_unsupported = bool(provider_filters.get("include_unsupported", True)) + include_states = bool(provider_filters.get("include_states", True)) + + if institution: + rows = self.load_by_institution(str(institution), period=period) + else: + rows = self.load_all(period=period) + + specs: list[TargetSpec] = [] + for row in rows: + target_type = _target_type(row) + if target_types and target_type not in target_types: + continue + + spec = self.target_from_row(row) + if ( + not include_states + and spec.metadata.get(US_TARGET_LEVEL_KEY) == TargetLevel.STATE.value + ): + continue + if ( + not include_unsupported + and not spec.metadata[SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY] + ): + continue + specs.append(spec) + + return apply_target_query( + TargetSet(specs), + TargetQuery( + period=period if period is not None else query.period, + entity=query.entity, + names=query.names, + metadata_filters=query.metadata_filters, + ), + ) + + def target_from_row(self, row: dict[str, Any]) -> TargetSpec: + """Translate one Supabase target row into the canonical target IR.""" + variable = str(row["variable"]) + jurisdiction = _target_jurisdiction(row) + state_fips, state_abbr = _jurisdiction_state(jurisdiction, self.STATE_FIPS) + target_type = _target_type(row) + aggregation = _aggregation_for_target_type(target_type) + measure = self.CPS_COLUMN_MAP.get(variable, variable) + supported = variable in self.CPS_COLUMN_MAP + source = row.get("source") if isinstance(row.get("source"), dict) else {} + source_name = source.get("name") or source.get("institution") + source_institution = source.get("institution") + stratum = row.get("stratum") if isinstance(row.get("stratum"), dict) else {} + category = _category_for_variable(variable) + level = TargetLevel.STATE if state_fips is not None else TargetLevel.NATIONAL + + filters: list[TargetFilter] = [] + if aggregation is TargetAggregation.COUNT and variable not in _COUNT_ALL_VARIABLES: + filters.append( + TargetFilter( + feature=measure, + operator=FilterOperator.GT, + value=0, + ) + ) + + if state_fips is not None: + filters.append( + TargetFilter( + feature="state_fips", + operator=FilterOperator.EQ, + value=state_fips, + ) + ) + + metadata: dict[str, Any] = { + SUPABASE_TARGET_ID_KEY: row.get("id"), + SUPABASE_VARIABLE_KEY: variable, + SUPABASE_TARGET_TYPE_KEY: target_type, + SUPABASE_JURISDICTION_KEY: jurisdiction, + SUPABASE_STRATUM_NAME_KEY: stratum.get("name"), + SUPABASE_SOURCE_INSTITUTION_KEY: source_institution, + SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY: supported, + US_TARGET_LEVEL_KEY: level.value, + US_TARGET_GROUP_KEY: _group_for_category(category), + US_TARGET_AVAILABLE_KEY: supported, + US_TARGET_IMPUTATION_KEY: not supported, + } + if category is not None: + metadata[US_TARGET_CATEGORY_KEY] = category.value + if state_fips is not None: + metadata["state_fips"] = state_fips + metadata["state_abbr"] = state_abbr + + return TargetSpec( + name=_target_name(variable, jurisdiction), + entity=_entity_for_variable(variable), + value=float(row["value"]), + period=int(row["period"]), + measure=None if aggregation is TargetAggregation.COUNT else measure, + aggregation=aggregation, + filters=tuple(filters), + source=source_name, + units=_units_for_target_type(target_type), + description=row.get("notes"), + metadata=metadata, + ) + + +def _target_type(row: dict[str, Any]) -> str: + return str(row.get("target_type") or "amount").lower() + + +def _aggregation_for_target_type(target_type: str) -> TargetAggregation: + if target_type == "count": + return TargetAggregation.COUNT + if target_type == "mean": + return TargetAggregation.MEAN + return TargetAggregation.SUM + + +def _target_jurisdiction(row: dict[str, Any]) -> str: + stratum = row.get("stratum") if isinstance(row.get("stratum"), dict) else {} + return str(stratum.get("jurisdiction") or "us") + + +def _target_name(variable: str, jurisdiction: str) -> str: + if jurisdiction in {"us", "us-national"}: + return variable + return f"{variable}_{jurisdiction.replace('-', '_')}" + + +def _query_period(period: int | str | None) -> int | None: + if isinstance(period, int): + return period + if isinstance(period, str) and period.isdigit(): + return int(period) + return None + + +def _as_string_set(value: Any) -> set[str]: + if value is None: + return set() + if isinstance(value, str): + return {value} + return {str(item) for item in value} + + +def _state_abbr_to_fips(state_fips: dict[str, str]) -> dict[str, str]: + return {abbr: fips for fips, abbr in state_fips.items()} + + +def _jurisdiction_state( + jurisdiction: str, + state_fips: dict[str, str], +) -> tuple[str | None, str | None]: + if not jurisdiction.startswith("us-") or len(jurisdiction) != 5: + return None, None + + suffix = jurisdiction[3:].lower() + if suffix in state_fips: + return suffix, state_fips[suffix] + + abbr_to_fips = _state_abbr_to_fips(state_fips) + if suffix in abbr_to_fips: + return abbr_to_fips[suffix], suffix + + return None, None + + +def _category_for_variable(variable: str) -> TargetCategory | None: + if variable in _INCOME_VARIABLES: + return TargetCategory.INCOME + if variable in _BENEFIT_VARIABLES: + return TargetCategory.BENEFITS + if variable in _HEALTH_VARIABLES: + return TargetCategory.HEALTH + if variable.endswith("_tax") or variable.endswith("_credit"): + return TargetCategory.TAX + if variable in _COUNT_ALL_VARIABLES: + return TargetCategory.DEMOGRAPHICS + return None + + +def _entity_for_variable(variable: str) -> EntityType: + if variable in _COUNT_ENTITY_MAP: + return _COUNT_ENTITY_MAP[variable] + if variable in _TAX_UNIT_VARIABLES: + return EntityType.TAX_UNIT + if variable in _HOUSEHOLD_VARIABLES: + return EntityType.HOUSEHOLD + return EntityType.PERSON + + +def _group_for_category(category: TargetCategory | None) -> str: + if category is None: + return "supabase_targets" + return f"supabase_{category.value}" + + +def _units_for_target_type(target_type: str) -> str | None: + return "USD" if target_type == "amount" else None + + +__all__ = [ + "SUPABASE_JURISDICTION_KEY", + "SUPABASE_SOURCE_INSTITUTION_KEY", + "SUPABASE_STRATUM_NAME_KEY", + "SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY", + "SUPABASE_TARGET_ID_KEY", + "SUPABASE_TARGET_TYPE_KEY", + "SUPABASE_VARIABLE_KEY", + "SupabaseTargetLoader", + "SupabaseTargetProvider", +] diff --git a/tests/targets/test_supabase.py b/tests/targets/test_supabase.py index 238e3a1..ead8ca9 100644 --- a/tests/targets/test_supabase.py +++ b/tests/targets/test_supabase.py @@ -6,8 +6,22 @@ from typing import Any import pytest - -from microplex_us.supabase_targets import SupabaseTargetLoader +from microplex.core import EntityType +from microplex.targets import FilterOperator, TargetAggregation, TargetQuery + +from microplex_us.supabase_targets import ( + SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY, + SUPABASE_TARGET_TYPE_KEY, + SUPABASE_VARIABLE_KEY, + SupabaseTargetLoader, + SupabaseTargetProvider, +) +from microplex_us.target_registry import ( + US_TARGET_CATEGORY_KEY, + US_TARGET_LEVEL_KEY, + TargetCategory, + TargetLevel, +) SUPABASE_URL = "https://test.supabase.co" SUPABASE_KEY = "test-key" @@ -29,6 +43,11 @@ def loader() -> SupabaseTargetLoader: return SupabaseTargetLoader(SUPABASE_URL, SUPABASE_KEY) +@pytest.fixture +def provider() -> SupabaseTargetProvider: + return SupabaseTargetProvider(SUPABASE_URL, SUPABASE_KEY) + + @pytest.fixture def request_queue(monkeypatch: pytest.MonkeyPatch): calls = [] @@ -246,3 +265,131 @@ def test_get_summary(loader: SupabaseTargetLoader, request_queue) -> None: "by_variable": {"employment_income": 1, "person_count": 1}, "by_type": {"amount": 1, "count": 1}, } + + +def test_target_from_row_builds_national_sum_spec( + provider: SupabaseTargetProvider, +) -> None: + spec = provider.target_from_row( + { + "id": "target-1", + "variable": "employment_income", + "value": 9022400000000, + "target_type": "amount", + "period": 2024, + "source": {"name": "IRS SOI", "institution": "IRS"}, + "stratum": {"name": "National", "jurisdiction": "us"}, + } + ) + + assert spec.name == "employment_income" + assert spec.entity is EntityType.PERSON + assert spec.aggregation is TargetAggregation.SUM + assert spec.measure == "employment_income" + assert spec.filters == () + assert spec.value == 9022400000000 + assert spec.source == "IRS SOI" + assert spec.metadata[SUPABASE_VARIABLE_KEY] == "employment_income" + assert spec.metadata[SUPABASE_TARGET_TYPE_KEY] == "amount" + assert spec.metadata[SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY] is True + assert spec.metadata[US_TARGET_CATEGORY_KEY] == TargetCategory.INCOME.value + assert spec.metadata[US_TARGET_LEVEL_KEY] == TargetLevel.NATIONAL.value + + +def test_target_from_row_builds_state_count_spec( + provider: SupabaseTargetProvider, +) -> None: + spec = provider.target_from_row( + { + "id": "target-2", + "variable": "medicaid_enrollment", + "value": 14000000, + "target_type": "count", + "period": 2024, + "source": {"name": "CMS Medicaid", "institution": "HHS"}, + "stratum": {"name": "California", "jurisdiction": "us-ca"}, + } + ) + + assert spec.name == "medicaid_enrollment_us_ca" + assert spec.entity is EntityType.PERSON + assert spec.aggregation is TargetAggregation.COUNT + assert spec.measure is None + assert spec.filters[0].feature == "medicaid" + assert spec.filters[0].operator is FilterOperator.GT + assert spec.filters[0].value == 0 + assert spec.filters[1].feature == "state_fips" + assert spec.filters[1].operator is FilterOperator.EQ + assert spec.filters[1].value == "06" + assert spec.required_features == ("medicaid", "state_fips") + assert spec.metadata[US_TARGET_CATEGORY_KEY] == TargetCategory.HEALTH.value + assert spec.metadata[US_TARGET_LEVEL_KEY] == TargetLevel.STATE.value + + +def test_target_from_row_keeps_unsupported_variables_classifiable( + provider: SupabaseTargetProvider, +) -> None: + spec = provider.target_from_row( + { + "id": "target-3", + "variable": "unknown_cash_income", + "value": 100, + "target_type": "amount", + "period": 2024, + "source": {"name": "Unknown", "institution": "Other"}, + "stratum": {"name": "National", "jurisdiction": "us"}, + } + ) + + assert spec.measure == "unknown_cash_income" + assert spec.required_features == ("unknown_cash_income",) + assert spec.metadata[SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY] is False + + +def test_load_target_set_filters_rows_with_core_query( + provider: SupabaseTargetProvider, + request_queue, +) -> None: + calls = request_queue( + [ + { + "id": "target-1", + "variable": "employment_income", + "value": 9022400000000, + "target_type": "amount", + "period": 2024, + "source": {"name": "IRS SOI", "institution": "IRS"}, + "stratum": {"name": "National", "jurisdiction": "us"}, + }, + { + "id": "target-2", + "variable": "snap_spending", + "value": 103100000000, + "target_type": "amount", + "period": 2024, + "source": {"name": "USDA SNAP", "institution": "USDA"}, + "stratum": {"name": "National", "jurisdiction": "us"}, + }, + { + "id": "target-3", + "variable": "unknown_cash_income", + "value": 100, + "target_type": "amount", + "period": 2024, + "source": {"name": "Unknown", "institution": "Other"}, + "stratum": {"name": "National", "jurisdiction": "us"}, + }, + ] + ) + + target_set = provider.load_target_set( + TargetQuery( + period=2024, + entity=EntityType.PERSON, + metadata_filters={US_TARGET_CATEGORY_KEY: TargetCategory.INCOME.value}, + provider_filters={"include_unsupported": False}, + ) + ) + + assert [target.name for target in target_set.targets] == ["employment_income"] + assert calls[0]["params"]["period"] == "eq.2024"