diff --git a/changelog.d/3445.fixed.md b/changelog.d/3445.fixed.md new file mode 100644 index 000000000..634f0d804 --- /dev/null +++ b/changelog.d/3445.fixed.md @@ -0,0 +1 @@ +Reject unknown columns in `PUT /<country_id>/user-policy` with HTTP 400 and restrict writable fields to a whitelist, closing a SQL injection path where JSON keys were interpolated into the UPDATE statement. diff --git a/changelog.d/3446.fixed.md b/changelog.d/3446.fixed.md new file mode 100644 index 000000000..bcddc35cf --- /dev/null +++ b/changelog.d/3446.fixed.md @@ -0,0 +1 @@ +Move `@validate_country` below `@bp.route` in economy and AI-prompt routes so unknown country IDs are rejected with HTTP 400 instead of reaching the view function. diff --git a/changelog.d/3447.fixed.md b/changelog.d/3447.fixed.md new file mode 100644 index 000000000..c1cb5308f --- /dev/null +++ b/changelog.d/3447.fixed.md @@ -0,0 +1 @@ +Scope `update_household` by `country_id` so an update for one country cannot overwrite a household that shares the same numeric id under another country. Missing rows now raise `LookupError` instead of silently succeeding. diff --git a/changelog.d/3448.fixed.md b/changelog.d/3448.fixed.md new file mode 100644 index 000000000..3144993c8 --- /dev/null +++ b/changelog.d/3448.fixed.md @@ -0,0 +1 @@ +Narrow the exception handlers in `simulation_routes` and `report_output_routes` so only `ValueError`, `pydantic.ValidationError`, and `jsonschema.ValidationError` are mapped to HTTP 400. Unexpected exceptions now propagate as 500 with a logged traceback instead of being silently relabelled as bad requests. diff --git a/changelog.d/3449.fixed.md b/changelog.d/3449.fixed.md new file mode 100644 index 000000000..5746d7e04 --- /dev/null +++ b/changelog.d/3449.fixed.md @@ -0,0 +1 @@ +`SimulationService.update_simulation` no longer rewrites `api_version` when the PATCH payload contains no user-supplied fields. 
The "no fields to update" guard now fires correctly so an empty PATCH returns 400 instead of silently bumping `api_version`. diff --git a/changelog.d/3450.fixed.md b/changelog.d/3450.fixed.md new file mode 100644 index 000000000..368293bdb --- /dev/null +++ b/changelog.d/3450.fixed.md @@ -0,0 +1 @@ +Use `hashlib.sha256` for API cache keys instead of Python's builtin `hash()`. `hash()` is salted per process (PYTHONHASHSEED), so workers could produce different keys for identical inputs and miss the cache. diff --git a/changelog.d/3451.fixed.md b/changelog.d/3451.fixed.md new file mode 100644 index 000000000..73ef006dc --- /dev/null +++ b/changelog.d/3451.fixed.md @@ -0,0 +1 @@ +`get_simulations` now always applies a LIMIT, clamps `max_results` to `[1, 1000]`, and binds it as a parameter, eliminating the f-string SQL fragment and the unbounded-scan risk. diff --git a/changelog.d/3452.fixed.md b/changelog.d/3452.fixed.md new file mode 100644 index 000000000..d403c87ba --- /dev/null +++ b/changelog.d/3452.fixed.md @@ -0,0 +1 @@ +Replace the bare `except:` around wealth-decile impact computation in `endpoints/economy/compare.py` with `except Exception:` plus `logger.exception(...)` so SystemExit/KeyboardInterrupt can still propagate and failures are visible in logs. diff --git a/changelog.d/3453.fixed.md b/changelog.d/3453.fixed.md new file mode 100644 index 000000000..c7a8e21ee --- /dev/null +++ b/changelog.d/3453.fixed.md @@ -0,0 +1 @@ +Constituency-level country filters in `uk_constituency_breakdown` now use `code.startswith("E"/"S"/"W"/"N")` instead of `"E"/"S"/"W"/"N" not in code`, matching the local-authority pattern already used elsewhere in the file and preventing a constituency's country letter from leaking into other buckets. 
diff --git a/policyengine_api/endpoints/economy/compare.py b/policyengine_api/endpoints/economy/compare.py index 3d3cf9c27..8437392cf 100644 --- a/policyengine_api/endpoints/economy/compare.py +++ b/policyengine_api/endpoints/economy/compare.py @@ -1,3 +1,4 @@ +import logging from microdf import MicroDataFrame, MicroSeries import numpy as np import sys @@ -7,6 +8,8 @@ from pydantic import BaseModel from typing import Any +logger = logging.getLogger(__name__) + def budgetary_impact(baseline: dict, reform: dict) -> dict: tax_revenue_impact = reform["total_tax"] - baseline["total_tax"] @@ -596,15 +599,19 @@ def uk_constituency_breakdown( if name != selected_constituency and code != selected_constituency: continue - # Filter to specific country if requested + # Filter to specific country if requested. Constituency codes + # are prefixed with a single letter that identifies the + # country (E=England, S=Scotland, W=Wales, N=Northern + # Ireland); use startswith so we don't accidentally match a + # letter anywhere else in the code. 
if selected_country is not None: - if selected_country == "ENGLAND" and "E" not in code: + if selected_country == "ENGLAND" and not code.startswith("E"): continue - elif selected_country == "SCOTLAND" and "S" not in code: + elif selected_country == "SCOTLAND" and not code.startswith("S"): continue - elif selected_country == "WALES" and "W" not in code: + elif selected_country == "WALES" and not code.startswith("W"): continue - elif selected_country == "NORTHERN_IRELAND" and "N" not in code: + elif selected_country == "NORTHERN_IRELAND" and not code.startswith("N"): continue weight: np.ndarray = weights[i] @@ -624,13 +631,13 @@ def uk_constituency_breakdown( } regions = ["uk"] - if "E" in code: + if code.startswith("E"): regions.append("england") - elif "S" in code: + elif code.startswith("S"): regions.append("scotland") - elif "W" in code: + elif code.startswith("W"): regions.append("wales") - elif "N" in code: + elif code.startswith("N"): regions.append("northern_ireland") if percent_household_income_change > 0.05: @@ -813,7 +820,10 @@ def compare_economic_outputs( intra_wealth_decile_impact_data = intra_wealth_decile_impact( baseline, reform ) - except: + except Exception: + logger.exception( + "Wealth decile impact computation failed; returning empty breakdowns." + ) wealth_decile_impact_data = {} intra_wealth_decile_impact_data = {} diff --git a/policyengine_api/endpoints/policy.py b/policyengine_api/endpoints/policy.py index 290fe895e..f5d33e938 100644 --- a/policyengine_api/endpoints/policy.py +++ b/policyengine_api/endpoints/policy.py @@ -315,18 +315,78 @@ def get_user_policy(country_id: str, user_id: str) -> dict: ) +# Whitelist of columns that callers are allowed to modify via +# update_user_policy. 
Identity columns (id, country_id, user_id, +# reform_id, baseline_id) are intentionally excluded because they +# define the record; allowing clients to rewrite them would both +# break referential assumptions and let the column name be used +# as a SQL injection vector (keys are interpolated into the +# UPDATE statement below). +UPDATE_USER_POLICY_ALLOWED_FIELDS = frozenset( + { + "reform_label", + "baseline_label", + "year", + "geography", + "dataset", + "number_of_provisions", + "api_version", + "added_date", + "updated_date", + "budgetary_impact", + "type", + } +) + + @validate_country def update_user_policy(country_id: str) -> dict: """ Update any parts of a user_policy, given a user_policy ID """ - # Construct the relevant UPDATE request - setter_array = [] - args = [] payload = request.json + if not isinstance(payload, dict) or "id" not in payload: + return Response( + json.dumps({"message": "Request body must include an 'id' field."}), + status=400, + mimetype="application/json", + ) + user_policy_id = payload.pop("id") + # Reject any unknown/unsafe keys. The keys end up interpolated + # into a SQL UPDATE statement, so we must validate them against + # a static whitelist instead of trusting the JSON payload. + unknown_keys = [ + key for key in payload if key not in UPDATE_USER_POLICY_ALLOWED_FIELDS + ] + if unknown_keys: + return Response( + json.dumps( + { + "message": ( + "Request body contains unsupported fields: " + f"{sorted(unknown_keys)}" + ) + } + ), + status=400, + mimetype="application/json", + ) + + if not payload: + return Response( + json.dumps( + {"message": "Request body must include at least one field to update."} + ), + status=400, + mimetype="application/json", + ) + + # Construct the relevant UPDATE request from whitelisted keys. 
+ setter_array = [] + args = [] for key in payload: setter_array.append(f"{key} = ?") args.append(payload[key]) diff --git a/policyengine_api/endpoints/simulation.py b/policyengine_api/endpoints/simulation.py index 132e5b2d6..a0d9bd70d 100644 --- a/policyengine_api/endpoints/simulation.py +++ b/policyengine_api/endpoints/simulation.py @@ -21,15 +21,30 @@ """ +_MAX_SIMULATION_RESULTS = 1000 +_DEFAULT_SIMULATION_RESULTS = 100 + + def get_simulations( - max_results: int = 100, + max_results: int | None = 100, ): - # Get the last N simulations ordered by start time - - desc_limit = f"DESC LIMIT {max_results}" if max_results is not None else "" + # Get the last N simulations ordered by start time. + # + # LIMIT is always applied (unbounded scans against reform_impact + # are expensive) and max_results is clamped to [1, + # _MAX_SIMULATION_RESULTS] before being bound as a parameter, so + # the value can never be interpolated into the SQL string. + if max_results is None: + max_results = _DEFAULT_SIMULATION_RESULTS + try: + max_results = int(max_results) + except (TypeError, ValueError): + max_results = _DEFAULT_SIMULATION_RESULTS + max_results = max(1, min(max_results, _MAX_SIMULATION_RESULTS)) result = local_database.query( - f"SELECT * FROM reform_impact ORDER BY start_time {desc_limit}", + "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?", + (max_results,), ).fetchall() # Format into [{}] diff --git a/policyengine_api/routes/ai_prompt_routes.py b/policyengine_api/routes/ai_prompt_routes.py index a15bd16dd..c497613a0 100644 --- a/policyengine_api/routes/ai_prompt_routes.py +++ b/policyengine_api/routes/ai_prompt_routes.py @@ -12,11 +12,11 @@ ai_prompt_service = AIPromptService() -@validate_country @ai_prompt_bp.route( "/<country_id>/ai-prompts/<prompt_name>", methods=["POST"], ) +@validate_country def generate_ai_prompt(country_id, prompt_name: str) -> Response: """ Get an AI prompt with a given name, filled with the given data. 
diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index 4279a1b1b..8d2b8c6c4 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -13,11 +13,11 @@ economy_service = EconomyService() -@validate_country @economy_bp.route( "/<country_id>/economy/<policy_id>/over/<baseline_policy_id>", methods=["GET"], ) +@validate_country def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int): policy_id = int(policy_id or get_current_law_policy_id(country_id)) diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index b3be7672c..1100faf97 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -1,7 +1,10 @@ -from flask import Blueprint, Response, request -from werkzeug.exceptions import NotFound, BadRequest +from flask import Blueprint, Response, current_app, request +from werkzeug.exceptions import HTTPException, NotFound, BadRequest import json +import jsonschema +import pydantic + from policyengine_api.services.report_output_service import ReportOutputService from policyengine_api.constants import CURRENT_YEAR from policyengine_api.utils.payload_validators import validate_country @@ -93,9 +96,20 @@ def create_report_output(country_id: str) -> Response: mimetype="application/json", ) - except Exception as e: - print(f"Error creating report output: {str(e)}") - raise BadRequest(f"Failed to create report output: {str(e)}") + except HTTPException: + # Let explicit client-error responses (BadRequest/NotFound/etc.) pass + # through without being logged as "Unexpected error". 
+ raise + except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e: + current_app.logger.warning( + "Bad request creating report output for country %s: %s", country_id, e + ) + raise BadRequest(f"Failed to create report output: {e}") + except Exception: + current_app.logger.exception( + "Unexpected error creating report output for country %s", country_id + ) + raise @report_output_bp.route("/<country_id>/report/<report_id>", methods=["GET"]) @@ -204,8 +218,22 @@ def update_report_output(country_id: str) -> Response: mimetype="application/json", ) - except NotFound: + except HTTPException: + # Let explicit client-error responses (BadRequest/NotFound/etc.) pass + # through without being logged as "Unexpected error". + raise + except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e: + current_app.logger.warning( + "Bad request updating report #%s for country %s: %s", + report_id, + country_id, + e, + ) + raise BadRequest(f"Failed to update report output: {e}") + except Exception: + current_app.logger.exception( + "Unexpected error updating report #%s for country %s", + report_id, + country_id, + ) raise - except Exception as e: - print(f"Error updating report output: {str(e)}") - raise BadRequest(f"Failed to update report output: {str(e)}") diff --git a/policyengine_api/routes/simulation_routes.py b/policyengine_api/routes/simulation_routes.py index 5a16b807e..f2bacd6cb 100644 --- a/policyengine_api/routes/simulation_routes.py +++ b/policyengine_api/routes/simulation_routes.py @@ -1,7 +1,10 @@ -from flask import Blueprint, Response, request -from werkzeug.exceptions import NotFound, BadRequest +from flask import Blueprint, Response, current_app, request +from werkzeug.exceptions import HTTPException, NotFound, BadRequest import json +import jsonschema +import pydantic + from policyengine_api.services.simulation_service import SimulationService from policyengine_api.utils.payload_validators import validate_country @@ -93,9 +96,20 @@ def 
create_simulation(country_id: str) -> Response: mimetype="application/json", ) - except Exception as e: - print(f"Error creating simulation: {str(e)}") - raise BadRequest(f"Failed to create simulation: {str(e)}") + except HTTPException: + # Let explicit client-error responses (BadRequest/NotFound/etc.) pass + # through without being logged as "Unexpected error". + raise + except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e: + current_app.logger.warning( + "Bad request creating simulation for country %s: %s", country_id, e + ) + raise BadRequest(f"Failed to create simulation: {e}") + except Exception: + current_app.logger.exception( + "Unexpected error creating simulation for country %s", country_id + ) + raise @simulation_bp.route("/<country_id>/simulation/<simulation_id>", methods=["GET"]) @@ -208,8 +222,22 @@ def update_simulation(country_id: str) -> Response: mimetype="application/json", ) - except NotFound: + except HTTPException: + # Let explicit client-error responses (BadRequest/NotFound/etc.) pass + # through without being logged as "Unexpected error". 
+ raise + except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e: + current_app.logger.warning( + "Bad request updating simulation #%s for country %s: %s", + simulation_id, + country_id, + e, + ) + raise BadRequest(f"Failed to update simulation: {e}") + except Exception: + current_app.logger.exception( + "Unexpected error updating simulation #%s for country %s", + simulation_id, + country_id, + ) raise - except Exception as e: - print(f"Error updating simulation: {str(e)}") - raise BadRequest(f"Failed to update simulation: {str(e)}") diff --git a/policyengine_api/services/household_service.py b/policyengine_api/services/household_service.py index bedf64400..2d3601737 100644 --- a/policyengine_api/services/household_service.py +++ b/policyengine_api/services/household_service.py @@ -107,19 +107,34 @@ def update_household( household_hash: str = hash_object(household_json) api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) + # WHERE must include country_id so an update scoped to + # one country cannot silently overwrite a household that + # happens to share the same numeric id under another + # country. database.query( - f"UPDATE household SET household_json = ?, household_hash = ?, label = ?, api_version = ? WHERE id = ?", + "UPDATE household " + "SET household_json = ?, household_hash = ?, label = ?, api_version = ? " + "WHERE id = ? AND country_id = ?", ( json.dumps(household_json), household_hash, label, api_version, household_id, + country_id, ), ) - # Fetch the updated JSON back from the table - updated_household: dict = self.get_household(country_id, household_id) + # Fetch the updated JSON back from the table. If the + # household did not exist for this country, get_household + # returns None. + updated_household: dict | None = self.get_household( + country_id, household_id + ) + if updated_household is None: + raise LookupError( + f"Household #{household_id} not found for country {country_id}." 
+ ) return updated_household except Exception as e: print(f"Error updating household #{household_id}. Details: {str(e)}") diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index dfb208db2..e5582ee17 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -492,13 +492,18 @@ def update_simulation( update_fields.append("error_message = ?") update_values.append(error_message) - update_fields.append("api_version = ?") - update_values.append(api_version) - + # Only refresh api_version when the caller is actually + # changing one of the user-supplied fields above. The + # previous code appended api_version unconditionally, so + # the "no fields to update" guard below never fired and a + # PATCH with an empty body still touched the row. if not update_fields: print("No fields to update") return False + update_fields.append("api_version = ?") + update_values.append(api_version) + def tx_callback(tx): simulation = self._get_simulation_row( simulation_id, diff --git a/policyengine_api/utils/cache_utils.py b/policyengine_api/utils/cache_utils.py index 612dbf88c..12f3c6d04 100644 --- a/policyengine_api/utils/cache_utils.py +++ b/policyengine_api/utils/cache_utils.py @@ -1,5 +1,6 @@ """Tools for caching API responses.""" +import hashlib import json import logging import flask @@ -10,6 +11,11 @@ def make_cache_key(*args, **kwargs): """make a hash to uniquely identify a cache entry. keep it fast, adding overhead to try to add some minor chance of a cache hit is not worth it. + + Use a cryptographic digest (SHA-256) rather than the builtin + `hash()`, whose output depends on PYTHONHASHSEED and is therefore + different across workers/restarts; that made same-input requests + miss the cache in production. 
""" data = "" if flask.request.content_type == "application/json": @@ -22,7 +28,8 @@ def make_cache_key(*args, **kwargs): if data != "": data = json.dumps(data, separators=("", "")) - cache_key = str(hash(flask.request.full_path + data)) + full_path = flask.request.full_path + cache_key = hashlib.sha256((full_path + data).encode("utf-8")).hexdigest() logging.basicConfig(level=logging.DEBUG) logging.getLogger().debug( "PATH: %s, CACHE_KEY: %s", flask.request.full_path, cache_key diff --git a/tests/to_refactor/python/test_household_routes.py b/tests/to_refactor/python/test_household_routes.py index 3fa5af319..3456429dc 100644 --- a/tests/to_refactor/python/test_household_routes.py +++ b/tests/to_refactor/python/test_household_routes.py @@ -132,14 +132,18 @@ def test_update_household_success( assert data["status"] == "ok" assert data["result"]["household_id"] == 1 # assert data["result"]["household_json"] == updated_data["data"] + # WHERE now includes country_id (issue #3447). mock_database.query.assert_any_call( - "UPDATE household SET household_json = ?, household_hash = ?, label = ?, api_version = ? WHERE id = ?", + "UPDATE household " + "SET household_json = ?, household_hash = ?, label = ?, api_version = ? " + "WHERE id = ? 
AND country_id = ?", ( json.dumps(updated_household), "some-hash", valid_request_body["label"], COUNTRY_PACKAGE_VERSIONS.get("us"), 1, + "us", ), ) diff --git a/tests/unit/endpoints/economy/test_compare.py b/tests/unit/endpoints/economy/test_compare.py index 759cc7f26..86d2187dd 100644 --- a/tests/unit/endpoints/economy/test_compare.py +++ b/tests/unit/endpoints/economy/test_compare.py @@ -677,6 +677,58 @@ def test__given_no_region__returns_all_constituencies( assert result is not None assert len(result.by_constituency) == 3 + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch("policyengine_api.endpoints.economy.compare.h5py.File") + @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") + def test__country_filter_uses_prefix_not_substring( + self, mock_read_csv, mock_h5py_file, mock_download + ): + """Regression for issue #3453. + + Previously the filter used `"E" not in code`, so a Welsh + code containing any 'E' (e.g. "W12345E") or a Scottish code + containing 'W' would leak into the wrong country bucket. + Use startswith on the leading country-letter instead. + """ + mock_download.side_effect = [ + "/path/to/weights.h5", + "/path/to/names.csv", + ] + + mock_weights = np.ones((3, 10)) + mock_h5py_context = MagicMock() + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__exit__ = MagicMock(return_value=False) + mock_h5py_file.return_value = mock_h5py_context + + # Welsh code "W12345E7" happens to contain 'E' in its tail. 
+ mock_const_df = pd.DataFrame( + { + "code": ["W12345E7", "E12345678", "S12345678"], + "name": ["Cardiff Trap", "Aldershot", "Edinburgh East"], + "x": [10.0, 5.0, 3.0], + "y": [20.0, 15.0, 12.0], + } + ) + mock_read_csv.return_value = mock_const_df + + baseline = {"household_net_income": np.array([1000.0] * 10)} + reform = {"household_net_income": np.array([1050.0] * 10)} + + result = uk_constituency_breakdown(baseline, reform, "uk", "country/england") + + assert result is not None + # Only Aldershot (code starting with 'E') should pass the + # England filter; the Welsh trap code must be excluded. + assert "Aldershot" in result.by_constituency + assert "Cardiff Trap" not in result.by_constituency + assert "Edinburgh East" not in result.by_constituency + + # The Welsh trap code must also not be double-counted in the + # England regional bucket. + england_total = sum(result.outcomes_by_region["england"].values()) + assert england_total == 1 + def _make_economy( incomes, diff --git a/tests/unit/endpoints/test_get_simulations.py b/tests/unit/endpoints/test_get_simulations.py new file mode 100644 index 000000000..56cb4bf6e --- /dev/null +++ b/tests/unit/endpoints/test_get_simulations.py @@ -0,0 +1,75 @@ +"""Regression tests for issue #3451. + +get_simulations built its LIMIT via an f-string +(`f"DESC LIMIT {max_results}"`), which is a SQL injection vector +(max_results flows in from a caller) and had no cap, so a tall +integer could drop unbounded rows on a production MySQL. The fix: +always LIMIT, clamp to [1, 1000], and bind as a parameter. 
+""" + +from policyengine_api.endpoints.simulation import get_simulations + + +def _seed_reform_impacts(test_db, n: int) -> None: + for i in range(n): + test_db.query( + """INSERT INTO reform_impact + (baseline_policy_id, reform_policy_id, country_id, region, dataset, + time_period, options_json, options_hash, api_version, + reform_impact_json, status, start_time, execution_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + i + 1, + i + 2, + "us", + "us", + "cps", + "2025", + "{}", + f"hash-{i}", + "1.0.0", + "{}", + "complete", + f"2026-01-01 00:00:{i:02d}", + f"exec-{i}", + ), + ) + + +def test_get_simulations_default_limit_caps_at_100(test_db): + _seed_reform_impacts(test_db, 150) + result = get_simulations() + assert len(result["result"]) == 100 + + +def test_get_simulations_clamps_huge_max_results(test_db): + _seed_reform_impacts(test_db, 50) + # A caller passing an absurdly large value must not crash and + # must not cause a full scan; the value is clamped at 1000. + result = get_simulations(max_results=10**9) + assert len(result["result"]) == 50 # only 50 seeded + + +def test_get_simulations_clamps_negative_max_results(test_db): + _seed_reform_impacts(test_db, 5) + # max_results of 0 or negative must still return something sane. + result = get_simulations(max_results=0) + assert 1 <= len(result["result"]) <= 5 + + +def test_get_simulations_defaults_when_none(test_db): + _seed_reform_impacts(test_db, 10) + result = get_simulations(max_results=None) + assert len(result["result"]) == 10 # fewer than the default 100 + + +def test_get_simulations_rejects_non_integer_gracefully(test_db): + _seed_reform_impacts(test_db, 5) + # A string like "100; DROP TABLE reform_impact" must not reach + # the SQL statement; it falls back to the default. + result = get_simulations(max_results="100; DROP TABLE reform_impact") + assert len(result["result"]) == 5 + + # And the table must still exist. 
+ rows = test_db.query("SELECT COUNT(*) AS c FROM reform_impact").fetchone() + assert rows["c"] == 5 diff --git a/tests/unit/endpoints/test_update_user_policy.py b/tests/unit/endpoints/test_update_user_policy.py new file mode 100644 index 000000000..967c32ee3 --- /dev/null +++ b/tests/unit/endpoints/test_update_user_policy.py @@ -0,0 +1,125 @@ +"""Regression tests for issue #3445. + +update_user_policy (policy.py) previously interpolated untrusted +payload keys directly into an UPDATE statement, allowing arbitrary +SQL fragments (and identity-column tampering) via the JSON body. + +The fix rejects unknown keys with a 400 response and restricts +writable columns to a static whitelist. +""" + +import time + +from flask import Flask + +from policyengine_api.endpoints import update_user_policy + + +def _create_test_client() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + app.route("/<country_id>/user-policy", methods=["PUT"])(update_user_policy) + return app.test_client() + + +def _insert_user_policy(test_db) -> int: + now = int(time.time()) + test_db.query( + "INSERT INTO user_policies (country_id, reform_label, reform_id, " + "baseline_label, baseline_id, user_id, year, geography, dataset, " + "number_of_provisions, api_version, added_date, updated_date) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + "us", + "old label", + 2, + None, + 1, + "user1", + "2025", + "us", + "cps", + 3, + "1.0.0", + now, + now, + ), + ) + row = test_db.query( + "SELECT id FROM user_policies ORDER BY id DESC LIMIT 1" + ).fetchone() + return row["id"] + + +def test_update_user_policy_rejects_sql_injection_key(test_db): + """Unknown keys (including SQL injection attempts) must be rejected.""" + policy_id = _insert_user_policy(test_db) + + client = _create_test_client() + response = client.put( + "/us/user-policy", + json={ + "id": policy_id, + "username; DROP TABLE x --": "x", + }, + ) + + assert response.status_code == 400 + body = response.get_json() + assert 
"unsupported fields" in body["message"] + + # The row must be untouched. + row = test_db.query( + "SELECT reform_label FROM user_policies WHERE id = ?", + (policy_id,), + ).fetchone() + assert row["reform_label"] == "old label" + + +def test_update_user_policy_rejects_identity_column(test_db): + """Identity columns (user_id, country_id, ...) must not be writable.""" + policy_id = _insert_user_policy(test_db) + + client = _create_test_client() + response = client.put( + "/us/user-policy", + json={"id": policy_id, "user_id": "attacker"}, + ) + + assert response.status_code == 400 + row = test_db.query( + "SELECT user_id FROM user_policies WHERE id = ?", + (policy_id,), + ).fetchone() + assert row["user_id"] == "user1" + + +def test_update_user_policy_allows_whitelisted_field(test_db): + """Whitelisted fields (e.g. reform_label) can still be updated.""" + policy_id = _insert_user_policy(test_db) + + client = _create_test_client() + response = client.put( + "/us/user-policy", + json={"id": policy_id, "reform_label": "new label"}, + ) + + assert response.status_code == 200 + row = test_db.query( + "SELECT reform_label FROM user_policies WHERE id = ?", + (policy_id,), + ).fetchone() + assert row["reform_label"] == "new label" + + +def test_update_user_policy_requires_id(test_db): + client = _create_test_client() + response = client.put("/us/user-policy", json={"reform_label": "x"}) + assert response.status_code == 400 + + +def test_update_user_policy_requires_at_least_one_field(test_db): + policy_id = _insert_user_policy(test_db) + client = _create_test_client() + response = client.put("/us/user-policy", json={"id": policy_id}) + assert response.status_code == 400 diff --git a/tests/unit/routes/__init__.py b/tests/unit/routes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/routes/test_decorator_order.py b/tests/unit/routes/test_decorator_order.py new file mode 100644 index 000000000..20a967299 --- /dev/null +++ 
b/tests/unit/routes/test_decorator_order.py @@ -0,0 +1,52 @@ +"""Regression tests for issue #3446. + +economy_routes and ai_prompt_routes originally stacked +@validate_country above @bp.route. Because Flask only inspects the +function registered by bp.route, the wrapping logic ran in the wrong +order: validate_country bypassed the Response it returned, or Flask +saw a decorator that hadn't been registered as a route handler. +The fix puts @bp.route as the outermost decorator. + +An invalid country must now produce a 400 from validate_country +instead of a 200/500 from the view function. +""" + +from flask import Flask + +from policyengine_api.routes.ai_prompt_routes import ai_prompt_bp +from policyengine_api.routes.economy_routes import economy_bp + + +def _client_with(*blueprints) -> object: + app = Flask(__name__) + app.config["TESTING"] = True + for bp in blueprints: + app.register_blueprint(bp) + return app.test_client() + + +def test_economy_route_rejects_bogus_country(): + client = _client_with(economy_bp) + response = client.get("/bogus/economy/1/over/2?region=us&time_period=2025") + assert response.status_code == 400 + + +def test_ai_prompt_route_rejects_bogus_country(): + client = _client_with(ai_prompt_bp) + # Use a payload that passes validate_sim_analysis_payload so the only + # remaining reason to 400 is the unknown country_id. With the pre-#3446 + # decorator order the view runs first and reaches the service, so this + # request would not be rejected on country grounds. 
+ valid_payload = { + "currency": "USD", + "selected_version": "v1.0", + "time_period": "2024", + "impact": {"value": 100}, + "policy_label": "Test Policy", + "policy": {"type": "tax", "rate": 0.1}, + "region": "NA", + "relevant_parameters": ["param1", "param2"], + "relevant_parameter_baseline_values": [1.0, 2.0], + } + response = client.post("/bogus/ai-prompts/some_prompt", json=valid_payload) + assert response.status_code == 400 diff --git a/tests/unit/routes/test_route_exception_handling.py b/tests/unit/routes/test_route_exception_handling.py new file mode 100644 index 000000000..b6fb38a28 --- /dev/null +++ b/tests/unit/routes/test_route_exception_handling.py @@ -0,0 +1,108 @@ +"""Regression tests for issue #3448. + +simulation_routes and report_output_routes caught every Exception and +converted it to BadRequest (400). That masked real 500s (DB errors, +coding bugs) and hid tracebacks. The fix: only ValueError / +pydantic.ValidationError / jsonschema.ValidationError become 400; +everything else propagates as 500 with logger.exception(). +""" + +from unittest.mock import patch + +from flask import Flask + +from policyengine_api.routes.report_output_routes import report_output_bp +from policyengine_api.routes.simulation_routes import simulation_bp + + +def _client_with(*blueprints): + app = Flask(__name__) + app.config["TESTING"] = True + # Required so werkzeug propagates exceptions to response as 500 + # rather than reraising in the test client. 
+ app.config["PROPAGATE_EXCEPTIONS"] = False + for bp in blueprints: + app.register_blueprint(bp) + return app.test_client() + + +def test_simulation_create_runtime_error_becomes_500(): + client = _client_with(simulation_bp) + with patch( + "policyengine_api.routes.simulation_routes.simulation_service.find_existing_simulation", + side_effect=RuntimeError("db went away"), + ): + response = client.post( + "/us/simulation", + json={ + "population_id": "abc", + "population_type": "household", + "policy_id": 1, + }, + ) + assert response.status_code == 500 + + +def test_simulation_create_value_error_still_400(): + client = _client_with(simulation_bp) + with patch( + "policyengine_api.routes.simulation_routes.simulation_service.find_existing_simulation", + side_effect=ValueError("bad input"), + ): + response = client.post( + "/us/simulation", + json={ + "population_id": "abc", + "population_type": "household", + "policy_id": 1, + }, + ) + assert response.status_code == 400 + + +def test_report_create_runtime_error_becomes_500(): + client = _client_with(report_output_bp) + with patch( + "policyengine_api.routes.report_output_routes.report_output_service.find_existing_report_output", + side_effect=RuntimeError("db went away"), + ): + response = client.post( + "/us/report", + json={"simulation_1_id": 1, "year": "2025"}, + ) + assert response.status_code == 500 + + +def test_report_create_value_error_still_400(): + client = _client_with(report_output_bp) + with patch( + "policyengine_api.routes.report_output_routes.report_output_service.find_existing_report_output", + side_effect=ValueError("bad input"), + ): + response = client.post( + "/us/report", + json={"simulation_1_id": 1, "year": "2025"}, + ) + assert response.status_code == 400 + + +def test_simulation_patch_empty_body_returns_400(test_db): + """Regression for issue #3449. 
+ + PATCH /{country}/simulation with a body that only contains the + id field must return 400 (no fields to update) instead of + silently rewriting api_version. + """ + from policyengine_api.services.simulation_service import SimulationService + + simulation_service = SimulationService() + created = simulation_service.create_simulation( + country_id="us", + population_id="household_patch_empty", + population_type="household", + policy_id=50, + ) + + client = _client_with(simulation_bp) + response = client.patch("/us/simulation", json={"id": created["id"]}) + assert response.status_code == 400 diff --git a/tests/unit/services/test_household_service.py b/tests/unit/services/test_household_service.py index b59c78e50..3d629e164 100644 --- a/tests/unit/services/test_household_service.py +++ b/tests/unit/services/test_household_service.py @@ -162,12 +162,40 @@ def test_update_household_given_nonexistent_record(self, test_db): existing_data = valid_db_row["household_json"] existing_label = valid_db_row["label"] - result = service.update_household( - existing_country_id, - NO_SUCH_RECORD_ID, - existing_data, - existing_label, - ) + # THEN update_household raises LookupError because the id + # does not exist for this country (issue #3447). + with pytest.raises(LookupError): + service.update_household( + existing_country_id, + NO_SUCH_RECORD_ID, + existing_data, + existing_label, + ) + + def test_update_household_rejects_cross_country_id( + self, test_db, mock_hash_object, existing_household_record + ): + """Regression for issue #3447. - # THEN no record will be modified - assert result is None + An existing US household must not be overwritten by a request + that targets the same numeric id under a different country. 
+ """ + + existing_record_id = valid_db_row["id"] + existing_data = valid_db_row["household_json"] + + with pytest.raises(LookupError): + service.update_household( + "uk", # wrong country + existing_record_id, + existing_data, + "Attacker label", + ) + + # The original US row must be untouched. + row = test_db.query( + "SELECT label, country_id FROM household WHERE id = ?", + (existing_record_id,), + ).fetchone() + assert row["country_id"] == "us" + assert row["label"] == valid_db_row["label"] diff --git a/tests/unit/services/test_simulation_service.py b/tests/unit/services/test_simulation_service.py index 254c8867c..34116287f 100644 --- a/tests/unit/services/test_simulation_service.py +++ b/tests/unit/services/test_simulation_service.py @@ -485,3 +485,38 @@ def fail_dual_write(tx, simulation_id, *, country_id=None): assert run is not None assert run["status"] == "pending" assert run["output"] is None + + def test_update_simulation_with_no_user_fields_returns_false(self, test_db): + """Regression for issue #3449. + + update_fields used to always append api_version, so a PATCH + with no status/output/error_message still passed the + "no fields to update" guard and rewrote the row. The guard + must fire before api_version is appended so an empty PATCH + returns False (and the route converts that to a 400). 
+ """ + created_simulation = service.create_simulation( + country_id="us", + population_id="household_empty_patch", + population_type="household", + policy_id=16, + ) + + pre_row = test_db.query( + "SELECT * FROM simulations WHERE id = ?", + (created_simulation["id"],), + ).fetchone() + + success = service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + ) + + assert success is False + + post_row = test_db.query( + "SELECT * FROM simulations WHERE id = ?", + (created_simulation["id"],), + ).fetchone() + assert post_row["api_version"] == pre_row["api_version"] + assert post_row["status"] == pre_row["status"] diff --git a/tests/unit/utils/__init__.py b/tests/unit/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/utils/test_cache_utils.py b/tests/unit/utils/test_cache_utils.py new file mode 100644 index 000000000..ddf0fac1c --- /dev/null +++ b/tests/unit/utils/test_cache_utils.py @@ -0,0 +1,101 @@ +"""Regression tests for issue #3450. + +make_cache_key previously used `str(hash(...))`, whose value depends +on PYTHONHASHSEED. Two workers (or even two processes) produced +different keys for identical inputs, defeating the cache. +Switch to SHA-256 so the digest is deterministic across processes. 
+""" + +import hashlib +import subprocess +import sys +import textwrap + +from flask import Flask + +from policyengine_api.utils.cache_utils import make_cache_key + + +def test_make_cache_key_deterministic_within_process(): + app = Flask(__name__) + + with app.test_request_context( + "/us/economy/1/over/2?foo=bar", + method="POST", + json={"alpha": 1, "beta": [2, 3]}, + ): + first = make_cache_key() + with app.test_request_context( + "/us/economy/1/over/2?foo=bar", + method="POST", + json={"alpha": 1, "beta": [2, 3]}, + ): + second = make_cache_key() + + assert first == second + + +def test_make_cache_key_is_sha256_hex(): + app = Flask(__name__) + with app.test_request_context( + "/us/economy/1/over/2", + method="POST", + json={"hello": "world"}, + ): + key = make_cache_key() + + # 64-character lowercase hex string with no non-hex characters. + assert len(key) == 64 + assert all(c in "0123456789abcdef" for c in key) + + # Exact digest value computed directly must match. + # full_path is "/us/economy/1/over/2?" (Flask appends '?' when no query string) + # and json.dumps(..., separators=("", "")) produces no separators. + import json + + expected = hashlib.sha256( + ( + "/us/economy/1/over/2?" 
+ + json.dumps({"hello": "world"}, separators=("", "")) + ).encode("utf-8") + ).hexdigest() + assert key == expected + + +def test_make_cache_key_stable_across_processes(): + """Two independent Python processes must produce the same cache + key for the same inputs, even though they use different + PYTHONHASHSEED values.""" + script = textwrap.dedent( + """ + from flask import Flask + from policyengine_api.utils.cache_utils import make_cache_key + + app = Flask(__name__) + with app.test_request_context( + "/us/economy/1/over/2?foo=bar", + method="POST", + json={"alpha": 1}, + ): + print(make_cache_key()) + """ + ) + + def run_with_seed(seed: str) -> str: + env_cmd = [sys.executable, "-c", script] + result = subprocess.run( + env_cmd, + capture_output=True, + text=True, + env={ + "PATH": "/usr/bin:/bin:/usr/local/bin", + "PYTHONHASHSEED": seed, + }, + check=True, + ) + return result.stdout.strip() + + # Keys computed under PYTHONHASHSEED=0 and PYTHONHASHSEED=1 must match. + key_seed_0 = run_with_seed("0") + key_seed_1 = run_with_seed("1") + assert key_seed_0 == key_seed_1