From cb6f9dc80e005f83127fb43c12f589024a1c1fd1 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 25 Apr 2026 17:46:48 -0400 Subject: [PATCH] Move Supabase target loader behind core shim --- src/microplex/supabase_targets.py | 333 ++---------------------------- tests/test_supabase_targets.py | 299 ++------------------------- 2 files changed, 41 insertions(+), 591 deletions(-) diff --git a/src/microplex/supabase_targets.py b/src/microplex/supabase_targets.py index 782677d..2b41fbd 100644 --- a/src/microplex/supabase_targets.py +++ b/src/microplex/supabase_targets.py @@ -1,331 +1,36 @@ -""" -Load calibration targets from Supabase. - -Provides SupabaseTargetLoader for loading PE calibration targets from the -microplex Supabase schema and mapping them to CPS columns for calibration. +"""Compatibility shim for US Supabase calibration targets. -Deprecated: - This is US-specific compatibility code and will move to microplex-us. +US-specific Supabase target loading now lives in `microplex-us`. """ from __future__ import annotations -import os import warnings -from typing import Any - -import requests warnings.warn( - "microplex.supabase_targets is US-specific compatibility code and will move " - "to microplex-us.", + "microplex.supabase_targets is deprecated; use " + "microplex_us.supabase_targets instead.", DeprecationWarning, stacklevel=2, ) +_moved_loader_import_error: ImportError | None = None -class SupabaseTargetLoader: - """Load calibration targets from Supabase.""" - - # Mapping from Supabase variable names to CPS column names - CPS_COLUMN_MAP = { - # IRS Income targets - "employment_income": "employment_income", - "self_employment_income": "self_employment_income", - "dividend_income": "dividend_income", - "interest_income": "interest_income", - "rental_income": "rental_income", - "social_security": "social_security", - "unemployment_compensation": "unemployment_compensation", - "taxable_pension_income": "taxable_pension_income", - "tax_exempt_pension_income": "tax_exempt_pension_income", - "long_term_capital_gains": "long_term_capital_gains", - "short_term_capital_gains": "short_term_capital_gains", - "partnership_s_corp_income": "partnership_s_corp_income", - "farm_income": "farm_income", - "alimony_income": "alimony_income", - # Benefit spending targets - "snap_spending": "snap", - "ssi_spending": "ssi", - "eitc_spending": "eitc", - "social_security_spending": "social_security", - "unemployment_spending": "unemployment_compensation", - # Benefit enrollment/count targets - "medicaid_enrollment": "medicaid", - "aca_enrollment": "aca", - "snap_households": "snap", - # Healthcare targets - "health_insurance_premiums": "health_insurance_premiums", - "other_medical_expenses": "medical_expenses", - } - - # State FIPS to abbreviation - STATE_FIPS = { - "01": "al", "02": "ak", "04": "az", "05": "ar", "06": "ca", - "08": "co", "09": "ct", "10": "de", "11": "dc", "12": "fl", - "13": "ga", "15": "hi", "16": "id", "17": "il", "18": "in", - "19": "ia", "20": "ks", "21": "ky", "22": "la", "23": "me", - "24": "md", "25": "ma", "26": "mi", "27": "mn", "28": "ms", - "29": "mo", "30": "mt", "31": "ne", "32": "nv", "33": "nh", - "34": "nj", "35": "nm", "36": "ny", "37": "nc", "38": "nd", - "39": "oh", "40": "ok", "41": "or", "42": "pa", "44": "ri", - "45": "sc", "46": "sd", "47": "tn", "48": "tx", "49": "ut", - "50": "vt", "51": "va", "53": "wa", "54": "wv", "55": "wi", - "56": "wy" - } - - def __init__( - self, - url: str | None = None, - key: str | None = None, - schema: str = "microplex", - ): - """Initialize the loader. - - Args: - url: Supabase URL. Defaults to SUPABASE_URL env var. - key: Supabase key. Defaults to COSILICO_SUPABASE_SERVICE_KEY env var. - schema: Schema to use. Defaults to 'microplex'. - """ - self.url = url or os.environ.get( - "SUPABASE_URL", - "https://nsupqhfchdtqclomlrgs.supabase.co" - ) - self.key = key or os.environ.get("COSILICO_SUPABASE_SERVICE_KEY") - if not self.key: - raise ValueError( - "Supabase service key must be provided via the key argument or " - "COSILICO_SUPABASE_SERVICE_KEY." - ) - self.base_url = f"{self.url}/rest/v1" - self.headers = { - "apikey": self.key, - "Authorization": f"Bearer {self.key}", - "Content-Type": "application/json", - "Accept-Profile": schema, - "Content-Profile": schema, - } - self._cache = {} - - def _get( - self, - endpoint: str, - params: dict[str, Any] | None = None, - paginate: bool = True, - ) -> list[dict[str, Any]]: - """Make a GET request to Supabase with optional pagination. - - Args: - endpoint: API endpoint. - params: Query parameters. - paginate: If True, fetch all results using pagination. - - Returns: - List of result dicts. - """ - url = f"{self.base_url}/{endpoint}" - params = params or {} - - if not paginate: - resp = requests.get(url, headers=self.headers, params=params, timeout=30) - resp.raise_for_status() - return resp.json() - - # Paginate to get all results - all_results = [] - offset = 0 - limit = 1000 # Supabase default max - - while True: - page_params = {**params, "limit": limit, "offset": offset} - resp = requests.get(url, headers=self.headers, params=page_params, timeout=30) - resp.raise_for_status() - results = resp.json() - - if not results: - break - - all_results.extend(results) - offset += limit - - # If we got fewer than limit, we're done - if len(results) < limit: - break - - return all_results - - def load_all(self, period: int | None = None) -> list[dict[str, Any]]: - """Load all targets with source and stratum info. - - Args: - period: Optional year to filter by. - - Returns: - List of target dicts with nested source and stratum info. - """ - # Use PostgREST's embedded resources to join - params = { - "select": "id,variable,value,target_type,period,notes,source:sources(id,name,institution),stratum:strata(id,name,jurisdiction)", - } - if period: - params["period"] = f"eq.{period}" - - return self._get("targets", params) - - def load_by_institution( - self, - institution: str, - period: int | None = None, - ) -> list[dict[str, Any]]: - """Load targets from a specific institution. - - Args: - institution: Institution name (e.g., 'IRS', 'Census', 'USDA'). - period: Optional year to filter by. - - Returns: - List of target dicts. - """ - # First get source IDs for this institution - sources = self._get("sources", {"institution": f"eq.{institution}"}) - source_ids = [s["id"] for s in sources] - - if not source_ids: - return [] - - # Filter targets by source IDs - params = { - "select": "id,variable,value,target_type,period,notes,source:sources(id,name,institution),stratum:strata(id,name,jurisdiction)", - "source_id": f"in.({','.join(source_ids)})", - } - if period: - params["period"] = f"eq.{period}" - - return self._get("targets", params) - - def load_by_period(self, period: int) -> list[dict[str, Any]]: - """Load targets for a specific year. - - Args: - period: Year to filter by. - - Returns: - List of target dicts. - """ - return self.load_all(period=period) - - def get_cps_column_map(self) -> dict[str, str]: - """Get the mapping from Supabase variable names to CPS columns. - - Returns: - Dict mapping variable -> CPS column name. - """ - return self.CPS_COLUMN_MAP.copy() - - def _parse_jurisdiction(self, jurisdiction: str) -> str | None: - """Parse jurisdiction to get state code if applicable. - - Args: - jurisdiction: Jurisdiction string (e.g., 'us', 'us-ca', 'us-06'). - - Returns: - State abbreviation if state-level, None for national. - """ - if jurisdiction == "us" or jurisdiction == "us-national": - return None - - # Handle us-XX format (state abbrev) - if jurisdiction.startswith("us-") and len(jurisdiction) == 5: - state = jurisdiction[3:].lower() - if len(state) == 2: - return state - - # Handle us-FIPS format - if jurisdiction.startswith("us-") and len(jurisdiction) == 5: - fips = jurisdiction[3:] - return self.STATE_FIPS.get(fips) - - return None - - def build_calibration_constraints( - self, - period: int = 2024, - include_states: bool = False, - target_types: list[str] | None = None, - ) -> dict[str, float]: - """Build calibration constraint dict from Supabase targets. - - Args: - period: Year to get targets for. - include_states: Whether to include state-level targets. - target_types: List of target types to include ('amount', 'count'). - Defaults to all. - - Returns: - Dict mapping CPS column name to target value. - """ - targets = self.load_all(period=period) - constraints = {} - - for target in targets: - variable = target["variable"] - value = target["value"] - target_type = target.get("target_type", "amount") - stratum = target.get("stratum", {}) - jurisdiction = stratum.get("jurisdiction", "us") - - # Filter by target type - if target_types and target_type not in target_types: - continue - - # Map variable to CPS column - cps_col = self.CPS_COLUMN_MAP.get(variable) - if not cps_col: - continue - - # Handle national vs state targets - state = self._parse_jurisdiction(jurisdiction) - - if state and include_states: - # State-level target: append state code - key = f"{cps_col}_{state}" - constraints[key] = value - elif not state: - # National target - # Avoid duplicates (prefer first encountered) - if cps_col not in constraints: - constraints[cps_col] = value - - return constraints - - def get_summary(self) -> dict[str, Any]: - """Get summary of available targets in Supabase. - - Returns: - Dict with counts by institution, variable, etc. - """ - targets = self.load_all() - - by_institution = {} - by_variable = {} - by_type = {} +try: + from microplex_us.supabase_targets import SupabaseTargetLoader +except ImportError as _import_error: + _moved_loader_import_error = _import_error - for t in targets: - # By institution - inst = t.get("source", {}).get("institution", "Unknown") - by_institution[inst] = by_institution.get(inst, 0) + 1 + class SupabaseTargetLoader: # type: ignore[no-redef] + """Placeholder that explains how to access the moved US loader.""" - # By variable - var = t["variable"] - by_variable[var] = by_variable.get(var, 0) + 1 + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + raise ImportError( + "SupabaseTargetLoader moved to microplex-us. Install " + "`microplex-us` and import " + "`microplex_us.supabase_targets.SupabaseTargetLoader`." + ) from _moved_loader_import_error - # By type - tt = t.get("target_type", "amount") - by_type[tt] = by_type.get(tt, 0) + 1 - return { - "total": len(targets), - "by_institution": by_institution, - "by_variable": by_variable, - "by_type": by_type, - } +__all__ = ["SupabaseTargetLoader"] diff --git a/tests/test_supabase_targets.py b/tests/test_supabase_targets.py index 5d080d5..a0df65c 100644 --- a/tests/test_supabase_targets.py +++ b/tests/test_supabase_targets.py @@ -1,292 +1,37 @@ -""" -TDD tests for loading calibration targets from Supabase. +"""Compatibility tests for the US Supabase target loader shim.""" -These tests verify that: -1. Targets can be loaded from Supabase with proper filtering -2. Target variables are mapped correctly to CPS columns -3. Calibration constraints can be built from targets -""" +from __future__ import annotations import importlib.util -import sys from pathlib import Path import pytest -import responses -# Direct import to avoid torch dependency in __init__.py -src_path = Path(__file__).parent.parent / "src" / "microplex" -sys.path.insert(0, str(src_path.parent)) -# Import directly to avoid package __init__.py -spec = importlib.util.spec_from_file_location("supabase_targets", src_path / "supabase_targets.py") -module = importlib.util.module_from_spec(spec) -spec.loader.exec_module(module) -SupabaseTargetLoader = module.SupabaseTargetLoader +def _load_supabase_targets_module(): + src_path = Path(__file__).parent.parent / "src" / "microplex" + spec = importlib.util.spec_from_file_location( + "supabase_targets", + src_path / "supabase_targets.py", + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + with pytest.warns(DeprecationWarning, match="microplex_us.supabase_targets"): + spec.loader.exec_module(module) + return module -SUPABASE_URL = "https://test.supabase.co" -SUPABASE_KEY = "test-key" +def test_supabase_target_loader_is_compatibility_shim() -> None: + module = _load_supabase_targets_module() + assert module.__all__ == ["SupabaseTargetLoader"] -class TestSupabaseTargetLoader: - """Tests for loading targets from Supabase.""" - @pytest.fixture - def loader(self): - return SupabaseTargetLoader(SUPABASE_URL, SUPABASE_KEY) +def test_missing_microplex_us_loader_raises_actionable_import_error() -> None: + module = _load_supabase_targets_module() - def test_missing_service_key_raises(self, monkeypatch): - """Should never fall back to an embedded service-role key.""" - monkeypatch.delenv("COSILICO_SUPABASE_SERVICE_KEY", raising=False) + if module.SupabaseTargetLoader.__module__.startswith("microplex_us"): + pytest.skip("microplex-us is installed; shim resolved to the moved loader") - with pytest.raises(ValueError, match="COSILICO_SUPABASE_SERVICE_KEY"): - SupabaseTargetLoader(SUPABASE_URL) - - @responses.activate - def test_load_all_targets(self, loader): - """Should load all targets with source and stratum info.""" - # Mock the targets query with joined data - responses.add( - responses.GET, - f"{SUPABASE_URL}/rest/v1/targets", - json=[ - { - "id": "t1", - "variable": "employment_income", - "value": 9022400000000, - "target_type": "amount", - "period": 2024, - "source": {"name": "IRS SOI", "institution": "IRS"}, - "stratum": {"name": "National", "jurisdiction": "us"}, - }, - { - "id": "t2", - "variable": "snap_spending", - "value": 103100000000, - "target_type": "amount", - "period": 2024, - "source": {"name": "USDA SNAP", "institution": "USDA"}, - "stratum": {"name": "National", "jurisdiction": "us"}, - }, - ], - status=200, - ) - - targets = loader.load_all() - - assert len(targets) == 2 - assert targets[0]["variable"] == "employment_income" - assert targets[0]["value"] == 9022400000000 - assert targets[1]["variable"] == "snap_spending" - - @responses.activate - def test_load_by_institution(self, loader): - """Should filter targets by source institution.""" - # Mock the sources query first - responses.add( - responses.GET, - f"{SUPABASE_URL}/rest/v1/sources", - json=[{"id": "src-1", "institution": "IRS", "name": "IRS SOI"}], - status=200, - ) - # Then mock the targets query - responses.add( - responses.GET, - f"{SUPABASE_URL}/rest/v1/targets", - json=[ - { - "id": "t1", - "variable": "employment_income", - "value": 9022400000000, - "target_type": "amount", - "period": 2024, - "source": {"name": "IRS SOI", "institution": "IRS"}, - "stratum": {"name": "National", "jurisdiction": "us"}, - }, - ], - status=200, - ) - - targets = loader.load_by_institution("IRS") - - assert len(targets) == 1 - assert targets[0]["source"]["institution"] == "IRS" - - @responses.activate - def test_load_by_period(self, loader): - """Should filter targets by period/year.""" - responses.add( - responses.GET, - f"{SUPABASE_URL}/rest/v1/targets", - json=[ - { - "id": "t1", - "variable": "employment_income", - "value": 9022400000000, - "target_type": "amount", - "period": 2024, - "source": {"name": "IRS SOI", "institution": "IRS"}, - "stratum": {"name": "National", "jurisdiction": "us"}, - }, - ], - status=200, - ) - - targets = loader.load_by_period(2024) - - assert len(targets) == 1 - assert targets[0]["period"] == 2024 - - -class TestTargetToCPSMapping: - """Tests for mapping Supabase targets to CPS columns.""" - - @pytest.fixture - def loader(self): - return SupabaseTargetLoader(SUPABASE_URL, SUPABASE_KEY) - - def test_income_variable_mapping(self, loader): - """Should map PE income variables to CPS columns.""" - mapping = loader.get_cps_column_map() - - # IRS income targets should map to CPS columns - assert mapping["employment_income"] == "employment_income" - assert mapping["self_employment_income"] == "self_employment_income" - assert mapping["dividend_income"] == "dividend_income" - assert mapping["interest_income"] == "interest_income" - assert mapping["social_security"] == "social_security" - assert mapping["unemployment_compensation"] == "unemployment_compensation" - - def test_benefit_variable_mapping(self, loader): - """Should map benefit targets to CPS columns.""" - mapping = loader.get_cps_column_map() - - # Benefit spending targets - assert mapping["snap_spending"] == "snap" - assert mapping["ssi_spending"] == "ssi" - assert mapping["eitc_spending"] == "eitc" - - -class TestBuildCalibrationConstraints: - """Tests for building calibration constraints from targets.""" - - @pytest.fixture - def loader(self): - return SupabaseTargetLoader(SUPABASE_URL, SUPABASE_KEY) - - @responses.activate - def test_build_continuous_targets(self, loader): - """Should build continuous calibration targets dict.""" - responses.add( - responses.GET, - f"{SUPABASE_URL}/rest/v1/targets", - json=[ - { - "id": "t1", - "variable": "employment_income", - "value": 9022400000000, - "target_type": "amount", - "period": 2024, - "source": {"name": "IRS SOI", "institution": "IRS"}, - "stratum": {"name": "National", "jurisdiction": "us"}, - }, - { - "id": "t2", - "variable": "snap_spending", - "value": 103100000000, - "target_type": "amount", - "period": 2024, - "source": {"name": "USDA SNAP", "institution": "USDA"}, - "stratum": {"name": "National", "jurisdiction": "us"}, - }, - ], - status=200, - ) - - constraints = loader.build_calibration_constraints() - - # Should return dict with CPS column names as keys - assert "employment_income" in constraints - assert constraints["employment_income"] == 9022400000000 - assert "snap" in constraints - assert constraints["snap"] == 103100000000 - - @responses.activate - def test_build_state_targets(self, loader): - """Should build state-level calibration targets.""" - responses.add( - responses.GET, - f"{SUPABASE_URL}/rest/v1/targets", - json=[ - { - "id": "t1", - "variable": "medicaid_enrollment", - "value": 14000000, - "target_type": "count", - "period": 2024, - "source": {"name": "CMS Medicaid", "institution": "HHS"}, - "stratum": {"name": "California", "jurisdiction": "us-ca"}, - }, - ], - status=200, - ) - - constraints = loader.build_calibration_constraints(include_states=True) - - # State targets should include state code - assert "medicaid_ca" in constraints or "medicaid_enrollment_ca" in constraints - - -class TestIntegrationWithCalibrator: - """Integration tests with the Calibrator.""" - - @pytest.fixture - def loader(self): - return SupabaseTargetLoader(SUPABASE_URL, SUPABASE_KEY) - - @pytest.mark.skip(reason="Integration test requires real Supabase connection") - def test_calibration_with_supabase_targets(self, loader): - """End-to-end test: load targets from Supabase and run calibration.""" - import numpy as np - import pandas as pd - try: - # Direct import to avoid torch dependency - import importlib.util - cal_spec = importlib.util.spec_from_file_location( - "calibration", - Path(__file__).parent.parent / "src" / "microplex" / "calibration.py" - ) - cal_module = importlib.util.module_from_spec(cal_spec) - cal_spec.loader.exec_module(cal_module) - calibrator_cls = cal_module.Calibrator - except Exception as e: - pytest.skip(f"Cannot import Calibrator: {e}") - - # Create mock CPS data - np.random.seed(42) - n = 1000 - df = pd.DataFrame({ - "weight": np.ones(n) * 100, - "employment_income": np.random.lognormal(10, 1, n), - "snap": np.random.choice([0, 500], n, p=[0.9, 0.1]), - }) - - # Load targets from Supabase (uses live connection) - targets = loader.build_calibration_constraints() - - if not targets: - pytest.skip("No targets loaded from Supabase") - - # Filter to available columns - available = {k: v for k, v in targets.items() if k in df.columns} - - if not available: - pytest.skip("No matching targets for test data") - - # Run calibration - calibrator = calibrator_cls(method="ipf", max_iter=100) - calibrator.fit(df, marginal_targets={}, continuous_targets=available, weight_col="weight") - - assert calibrator.weights_ is not None - assert len(calibrator.weights_) == n + with pytest.raises(ImportError, match="microplex-us"): + module.SupabaseTargetLoader()