From 8e8e0643f0c723a5d4c754fbf5257e2bb298e106 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 08:44:36 -0400 Subject: [PATCH 01/12] Fix SQL injection via UPDATE keys in update_user_policy JSON payload keys were interpolated directly into the UPDATE statement, so any attacker-controlled key could inject SQL (or silently tamper with identity columns like user_id). Restrict writable columns to a static whitelist and reject unknown keys with HTTP 400. Fixes #3445 --- changelog.d/3445.fixed.md | 1 + policyengine_api/endpoints/policy.py | 66 ++++++++- .../unit/endpoints/test_update_user_policy.py | 125 ++++++++++++++++++ 3 files changed, 189 insertions(+), 3 deletions(-) create mode 100644 changelog.d/3445.fixed.md create mode 100644 tests/unit/endpoints/test_update_user_policy.py diff --git a/changelog.d/3445.fixed.md b/changelog.d/3445.fixed.md new file mode 100644 index 000000000..634f0d804 --- /dev/null +++ b/changelog.d/3445.fixed.md @@ -0,0 +1 @@ +Reject unknown columns in `PUT //user-policy` with HTTP 400 and restrict writable fields to a whitelist, closing a SQL injection path where JSON keys were interpolated into the UPDATE statement. diff --git a/policyengine_api/endpoints/policy.py b/policyengine_api/endpoints/policy.py index 290fe895e..f5d33e938 100644 --- a/policyengine_api/endpoints/policy.py +++ b/policyengine_api/endpoints/policy.py @@ -315,18 +315,78 @@ def get_user_policy(country_id: str, user_id: str) -> dict: ) +# Whitelist of columns that callers are allowed to modify via +# update_user_policy. Identity columns (id, country_id, user_id, +# reform_id, baseline_id) are intentionally excluded because they +# define the record; allowing clients to rewrite them would both +# break referential assumptions and let the column name be used +# as a SQL injection vector (keys are interpolated into the +# UPDATE statement below). +UPDATE_USER_POLICY_ALLOWED_FIELDS = frozenset( + { + "reform_label", + "baseline_label", + "year", + "geography", + "dataset", + "number_of_provisions", + "api_version", + "added_date", + "updated_date", + "budgetary_impact", + "type", + } +) + + @validate_country def update_user_policy(country_id: str) -> dict: """ Update any parts of a user_policy, given a user_policy ID """ - # Construct the relevant UPDATE request - setter_array = [] - args = [] payload = request.json + if not isinstance(payload, dict) or "id" not in payload: + return Response( + json.dumps({"message": "Request body must include an 'id' field."}), + status=400, + mimetype="application/json", + ) + user_policy_id = payload.pop("id") + # Reject any unknown/unsafe keys. The keys end up interpolated + # into a SQL UPDATE statement, so we must validate them against + # a static whitelist instead of trusting the JSON payload. + unknown_keys = [ + key for key in payload if key not in UPDATE_USER_POLICY_ALLOWED_FIELDS + ] + if unknown_keys: + return Response( + json.dumps( + { + "message": ( + "Request body contains unsupported fields: " + f"{sorted(unknown_keys)}" + ) + } + ), + status=400, + mimetype="application/json", + ) + + if not payload: + return Response( + json.dumps( + {"message": "Request body must include at least one field to update."} + ), + status=400, + mimetype="application/json", + ) + + # Construct the relevant UPDATE request from whitelisted keys. + setter_array = [] + args = [] for key in payload: setter_array.append(f"{key} = ?") args.append(payload[key]) diff --git a/tests/unit/endpoints/test_update_user_policy.py b/tests/unit/endpoints/test_update_user_policy.py new file mode 100644 index 000000000..967c32ee3 --- /dev/null +++ b/tests/unit/endpoints/test_update_user_policy.py @@ -0,0 +1,125 @@ +"""Regression tests for issue #3445. + +update_user_policy (policy.py) previously interpolated untrusted +payload keys directly into an UPDATE statement, allowing arbitrary +SQL fragments (and identity-column tampering) via the JSON body. + +The fix rejects unknown keys with a 400 response and restricts +writable columns to a static whitelist. +""" + +import time + +from flask import Flask + +from policyengine_api.endpoints import update_user_policy + + +def _create_test_client() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + app.route("//user-policy", methods=["PUT"])(update_user_policy) + return app.test_client() + + +def _insert_user_policy(test_db) -> int: + now = int(time.time()) + test_db.query( + "INSERT INTO user_policies (country_id, reform_label, reform_id, " + "baseline_label, baseline_id, user_id, year, geography, dataset, " + "number_of_provisions, api_version, added_date, updated_date) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + "us", + "old label", + 2, + None, + 1, + "user1", + "2025", + "us", + "cps", + 3, + "1.0.0", + now, + now, + ), + ) + row = test_db.query( + "SELECT id FROM user_policies ORDER BY id DESC LIMIT 1" + ).fetchone() + return row["id"] + + +def test_update_user_policy_rejects_sql_injection_key(test_db): + """Unknown keys (including SQL injection attempts) must be rejected.""" + policy_id = _insert_user_policy(test_db) + + client = _create_test_client() + response = client.put( + "/us/user-policy", + json={ + "id": policy_id, + "username; DROP TABLE x --": "x", + }, + ) + + assert response.status_code == 400 + body = response.get_json() + assert "unsupported fields" in body["message"] + + # The row must be untouched. + row = test_db.query( + "SELECT reform_label FROM user_policies WHERE id = ?", + (policy_id,), + ).fetchone() + assert row["reform_label"] == "old label" + + +def test_update_user_policy_rejects_identity_column(test_db): + """Identity columns (user_id, country_id, ...) must not be writable.""" + policy_id = _insert_user_policy(test_db) + + client = _create_test_client() + response = client.put( + "/us/user-policy", + json={"id": policy_id, "user_id": "attacker"}, + ) + + assert response.status_code == 400 + row = test_db.query( + "SELECT user_id FROM user_policies WHERE id = ?", + (policy_id,), + ).fetchone() + assert row["user_id"] == "user1" + + +def test_update_user_policy_allows_whitelisted_field(test_db): + """Whitelisted fields (e.g. reform_label) can still be updated.""" + policy_id = _insert_user_policy(test_db) + + client = _create_test_client() + response = client.put( + "/us/user-policy", + json={"id": policy_id, "reform_label": "new label"}, + ) + + assert response.status_code == 200 + row = test_db.query( + "SELECT reform_label FROM user_policies WHERE id = ?", + (policy_id,), + ).fetchone() + assert row["reform_label"] == "new label" + + +def test_update_user_policy_requires_id(test_db): + client = _create_test_client() + response = client.put("/us/user-policy", json={"reform_label": "x"}) + assert response.status_code == 400 + + +def test_update_user_policy_requires_at_least_one_field(test_db): + policy_id = _insert_user_policy(test_db) + client = _create_test_client() + response = client.put("/us/user-policy", json={"id": policy_id}) + assert response.status_code == 400 From 84f083b2a35446666b914e6623e4de97f3e63322 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 08:46:14 -0400 Subject: [PATCH 02/12] Fix decorator order on economy and ai-prompt routes @validate_country was stacked above @bp.route, so Flask registered the unwrapped view and validate_country never ran. Requests with an unknown country_id reached the handler and produced 200/500 instead of a 400. Swap the order to match household_routes. Fixes #3446 --- changelog.d/3446.fixed.md | 1 + policyengine_api/routes/ai_prompt_routes.py | 2 +- policyengine_api/routes/economy_routes.py | 2 +- tests/unit/routes/__init__.py | 0 tests/unit/routes/test_decorator_order.py | 37 +++++++++++++++++++++ 5 files changed, 40 insertions(+), 2 deletions(-) create mode 100644 changelog.d/3446.fixed.md create mode 100644 tests/unit/routes/__init__.py create mode 100644 tests/unit/routes/test_decorator_order.py diff --git a/changelog.d/3446.fixed.md b/changelog.d/3446.fixed.md new file mode 100644 index 000000000..bcddc35cf --- /dev/null +++ b/changelog.d/3446.fixed.md @@ -0,0 +1 @@ +Move `@validate_country` below `@bp.route` in economy and AI-prompt routes so unknown country IDs are rejected with HTTP 400 instead of reaching the view function. diff --git a/policyengine_api/routes/ai_prompt_routes.py b/policyengine_api/routes/ai_prompt_routes.py index a15bd16dd..c497613a0 100644 --- a/policyengine_api/routes/ai_prompt_routes.py +++ b/policyengine_api/routes/ai_prompt_routes.py @@ -12,11 +12,11 @@ ai_prompt_service = AIPromptService() -@validate_country @ai_prompt_bp.route( "//ai-prompts/", methods=["POST"], ) +@validate_country def generate_ai_prompt(country_id, prompt_name: str) -> Response: """ Get an AI prompt with a given name, filled with the given data. diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index 4279a1b1b..8d2b8c6c4 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -13,11 +13,11 @@ economy_service = EconomyService() -@validate_country @economy_bp.route( "//economy//over/", methods=["GET"], ) +@validate_country def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int): policy_id = int(policy_id or get_current_law_policy_id(country_id)) diff --git a/tests/unit/routes/__init__.py b/tests/unit/routes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/routes/test_decorator_order.py b/tests/unit/routes/test_decorator_order.py new file mode 100644 index 000000000..5fe8c404f --- /dev/null +++ b/tests/unit/routes/test_decorator_order.py @@ -0,0 +1,37 @@ +"""Regression tests for issue #3446. + +economy_routes and ai_prompt_routes originally stacked +@validate_country above @bp.route. Because Flask only inspects the +function registered by bp.route, the wrapping logic ran in the wrong +order: validate_country bypassed the Response it returned, or Flask +saw a decorator that hadn't been registered as a route handler. +The fix puts @bp.route as the outermost decorator. + +An invalid country must now produce a 400 from validate_country +instead of a 200/500 from the view function. +""" + +from flask import Flask + +from policyengine_api.routes.ai_prompt_routes import ai_prompt_bp +from policyengine_api.routes.economy_routes import economy_bp + + +def _client_with(*blueprints) -> object: + app = Flask(__name__) + app.config["TESTING"] = True + for bp in blueprints: + app.register_blueprint(bp) + return app.test_client() + + +def test_economy_route_rejects_bogus_country(): + client = _client_with(economy_bp) + response = client.get("/bogus/economy/1/over/2?region=us&time_period=2025") + assert response.status_code == 400 + + +def test_ai_prompt_route_rejects_bogus_country(): + client = _client_with(ai_prompt_bp) + response = client.post("/bogus/ai-prompts/some_prompt", json={}) + assert response.status_code == 400 From dc4e57fdfaa497bbd65cf0f26b0b893a38831de1 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 08:47:23 -0400 Subject: [PATCH 03/12] Fix update_household to scope WHERE by country_id The UPDATE statement keyed only on household id, so two different countries with the same numeric id could clobber each other. Add country_id to the WHERE clause and surface missing rows with a LookupError rather than silently returning stale data. Fixes #3447 --- changelog.d/3447.fixed.md | 1 + .../services/household_service.py | 21 +++++++-- tests/unit/services/test_household_service.py | 44 +++++++++++++++---- 3 files changed, 55 insertions(+), 11 deletions(-) create mode 100644 changelog.d/3447.fixed.md diff --git a/changelog.d/3447.fixed.md b/changelog.d/3447.fixed.md new file mode 100644 index 000000000..c1cb5308f --- /dev/null +++ b/changelog.d/3447.fixed.md @@ -0,0 +1 @@ +Scope `update_household` by `country_id` so an update for one country cannot overwrite a household that shares the same numeric id under another country. Missing rows now raise `LookupError` instead of silently succeeding. diff --git a/policyengine_api/services/household_service.py b/policyengine_api/services/household_service.py index bedf64400..2d3601737 100644 --- a/policyengine_api/services/household_service.py +++ b/policyengine_api/services/household_service.py @@ -107,19 +107,34 @@ def update_household( household_hash: str = hash_object(household_json) api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) + # WHERE must include country_id so an update scoped to + # one country cannot silently overwrite a household that + # happens to share the same numeric id under another + # country. database.query( - f"UPDATE household SET household_json = ?, household_hash = ?, label = ?, api_version = ? WHERE id = ?", + "UPDATE household " + "SET household_json = ?, household_hash = ?, label = ?, api_version = ? " + "WHERE id = ? AND country_id = ?", ( json.dumps(household_json), household_hash, label, api_version, household_id, + country_id, ), ) - # Fetch the updated JSON back from the table - updated_household: dict = self.get_household(country_id, household_id) + # Fetch the updated JSON back from the table. If the + # household did not exist for this country, get_household + # returns None. + updated_household: dict | None = self.get_household( + country_id, household_id + ) + if updated_household is None: + raise LookupError( + f"Household #{household_id} not found for country {country_id}." + ) return updated_household except Exception as e: print(f"Error updating household #{household_id}. Details: {str(e)}") diff --git a/tests/unit/services/test_household_service.py b/tests/unit/services/test_household_service.py index b59c78e50..3d629e164 100644 --- a/tests/unit/services/test_household_service.py +++ b/tests/unit/services/test_household_service.py @@ -162,12 +162,40 @@ def test_update_household_given_nonexistent_record(self, test_db): existing_data = valid_db_row["household_json"] existing_label = valid_db_row["label"] - result = service.update_household( - existing_country_id, - NO_SUCH_RECORD_ID, - existing_data, - existing_label, - ) + # THEN update_household raises LookupError because the id + # does not exist for this country (issue #3447). + with pytest.raises(LookupError): + service.update_household( + existing_country_id, + NO_SUCH_RECORD_ID, + existing_data, + existing_label, + ) + + def test_update_household_rejects_cross_country_id( + self, test_db, mock_hash_object, existing_household_record + ): + """Regression for issue #3447. - # THEN no record will be modified - assert result is None + An existing US household must not be overwritten by a request + that targets the same numeric id under a different country. + """ + + existing_record_id = valid_db_row["id"] + existing_data = valid_db_row["household_json"] + + with pytest.raises(LookupError): + service.update_household( + "uk", # wrong country + existing_record_id, + existing_data, + "Attacker label", + ) + + # The original US row must be untouched. + row = test_db.query( + "SELECT label, country_id FROM household WHERE id = ?", + (existing_record_id,), + ).fetchone() + assert row["country_id"] == "us" + assert row["label"] == valid_db_row["label"] From b9490db72ed1cd6e45e36bbc746d03e9277c9581 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 08:49:19 -0400 Subject: [PATCH 04/12] Narrow exception handling in simulation and report routes The simulation and report_output create/update handlers wrapped every Exception as BadRequest, which silently downgraded DB failures and bugs to 400s with no traceback logged. Only ValueError / pydantic / jsonschema validation errors now become 400; everything else propagates as 500 via the Flask error handler, after being logged with logger.exception(). Fixes #3448 --- changelog.d/3448.fixed.md | 1 + .../routes/report_output_routes.py | 36 ++++++-- policyengine_api/routes/simulation_routes.py | 36 ++++++-- .../routes/test_route_exception_handling.py | 86 +++++++++++++++++++ 4 files changed, 145 insertions(+), 14 deletions(-) create mode 100644 changelog.d/3448.fixed.md create mode 100644 tests/unit/routes/test_route_exception_handling.py diff --git a/changelog.d/3448.fixed.md b/changelog.d/3448.fixed.md new file mode 100644 index 000000000..3144993c8 --- /dev/null +++ b/changelog.d/3448.fixed.md @@ -0,0 +1 @@ +Narrow the exception handlers in `simulation_routes` and `report_output_routes` so only `ValueError`, `pydantic.ValidationError`, and `jsonschema.ValidationError` are mapped to HTTP 400. Unexpected exceptions now propagate as 500 with a logged traceback instead of being silently relabelled as bad requests. diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index b3be7672c..2f1686e9a 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -1,7 +1,10 @@ -from flask import Blueprint, Response, request +from flask import Blueprint, Response, current_app, request from werkzeug.exceptions import NotFound, BadRequest import json +import jsonschema +import pydantic + from policyengine_api.services.report_output_service import ReportOutputService from policyengine_api.constants import CURRENT_YEAR from policyengine_api.utils.payload_validators import validate_country @@ -93,9 +96,16 @@ def create_report_output(country_id: str) -> Response: mimetype="application/json", ) - except Exception as e: - print(f"Error creating report output: {str(e)}") - raise BadRequest(f"Failed to create report output: {str(e)}") + except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e: + current_app.logger.warning( + "Bad request creating report output for country %s: %s", country_id, e + ) + raise BadRequest(f"Failed to create report output: {e}") + except Exception: + current_app.logger.exception( + "Unexpected error creating report output for country %s", country_id + ) + raise @report_output_bp.route("//report/", methods=["GET"]) @@ -206,6 +216,18 @@ def update_report_output(country_id: str) -> Response: except NotFound: raise - except Exception as e: - print(f"Error updating report output: {str(e)}") - raise BadRequest(f"Failed to update report output: {str(e)}") + except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e: + current_app.logger.warning( + "Bad request updating report #%s for country %s: %s", + report_id, + country_id, + e, + ) + raise BadRequest(f"Failed to update report output: {e}") + except Exception: + current_app.logger.exception( + "Unexpected error updating report #%s for country %s", + report_id, + country_id, + ) + raise diff --git a/policyengine_api/routes/simulation_routes.py b/policyengine_api/routes/simulation_routes.py index 5a16b807e..c7736df3e 100644 --- a/policyengine_api/routes/simulation_routes.py +++ b/policyengine_api/routes/simulation_routes.py @@ -1,7 +1,10 @@ -from flask import Blueprint, Response, request +from flask import Blueprint, Response, current_app, request from werkzeug.exceptions import NotFound, BadRequest import json +import jsonschema +import pydantic + from policyengine_api.services.simulation_service import SimulationService from policyengine_api.utils.payload_validators import validate_country @@ -93,9 +96,16 @@ def create_simulation(country_id: str) -> Response: mimetype="application/json", ) - except Exception as e: - print(f"Error creating simulation: {str(e)}") - raise BadRequest(f"Failed to create simulation: {str(e)}") + except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e: + current_app.logger.warning( + "Bad request creating simulation for country %s: %s", country_id, e + ) + raise BadRequest(f"Failed to create simulation: {e}") + except Exception: + current_app.logger.exception( + "Unexpected error creating simulation for country %s", country_id + ) + raise @simulation_bp.route("//simulation/", methods=["GET"]) @@ -210,6 +220,18 @@ def update_simulation(country_id: str) -> Response: except NotFound: raise - except Exception as e: - print(f"Error updating simulation: {str(e)}") - raise BadRequest(f"Failed to update simulation: {str(e)}") + except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e: + current_app.logger.warning( + "Bad request updating simulation #%s for country %s: %s", + simulation_id, + country_id, + e, + ) + raise BadRequest(f"Failed to update simulation: {e}") + except Exception: + current_app.logger.exception( + "Unexpected error updating simulation #%s for country %s", + simulation_id, + country_id, + ) + raise diff --git a/tests/unit/routes/test_route_exception_handling.py b/tests/unit/routes/test_route_exception_handling.py new file mode 100644 index 000000000..d0eaa224f --- /dev/null +++ b/tests/unit/routes/test_route_exception_handling.py @@ -0,0 +1,86 @@ +"""Regression tests for issue #3448. + +simulation_routes and report_output_routes caught every Exception and +converted it to BadRequest (400). That masked real 500s (DB errors, +coding bugs) and hid tracebacks. The fix: only ValueError / +pydantic.ValidationError / jsonschema.ValidationError become 400; +everything else propagates as 500 with logger.exception(). +""" + +from unittest.mock import patch + +from flask import Flask + +from policyengine_api.routes.report_output_routes import report_output_bp +from policyengine_api.routes.simulation_routes import simulation_bp + + +def _client_with(*blueprints): + app = Flask(__name__) + app.config["TESTING"] = True + # Required so werkzeug propagates exceptions to response as 500 + # rather than reraising in the test client. + app.config["PROPAGATE_EXCEPTIONS"] = False + for bp in blueprints: + app.register_blueprint(bp) + return app.test_client() + + +def test_simulation_create_runtime_error_becomes_500(): + client = _client_with(simulation_bp) + with patch( + "policyengine_api.routes.simulation_routes.simulation_service.find_existing_simulation", + side_effect=RuntimeError("db went away"), + ): + response = client.post( + "/us/simulation", + json={ + "population_id": "abc", + "population_type": "household", + "policy_id": 1, + }, + ) + assert response.status_code == 500 + + +def test_simulation_create_value_error_still_400(): + client = _client_with(simulation_bp) + with patch( + "policyengine_api.routes.simulation_routes.simulation_service.find_existing_simulation", + side_effect=ValueError("bad input"), + ): + response = client.post( + "/us/simulation", + json={ + "population_id": "abc", + "population_type": "household", + "policy_id": 1, + }, + ) + assert response.status_code == 400 + + +def test_report_create_runtime_error_becomes_500(): + client = _client_with(report_output_bp) + with patch( + "policyengine_api.routes.report_output_routes.report_output_service.find_existing_report_output", + side_effect=RuntimeError("db went away"), + ): + response = client.post( + "/us/report", + json={"simulation_1_id": 1, "year": "2025"}, + ) + assert response.status_code == 500 + + +def test_report_create_value_error_still_400(): + client = _client_with(report_output_bp) + with patch( + "policyengine_api.routes.report_output_routes.report_output_service.find_existing_report_output", + side_effect=ValueError("bad input"), + ): + response = client.post( + "/us/report", + json={"simulation_1_id": 1, "year": "2025"}, + ) + assert response.status_code == 400 From b4ee91d53898a0703f4e6c7fa40294675633f52d Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 08:50:52 -0400 Subject: [PATCH 05/12] Fix update_simulation to treat empty PATCH as no-op update_fields unconditionally appended api_version, so the "no fields to update" guard never fired and an empty PATCH still rewrote the row. Append api_version only after we know at least one user field was supplied; the route converts the resulting False return value into a 400. Fixes #3449 --- changelog.d/3449.fixed.md | 1 + .../services/simulation_service.py | 11 ++++-- .../routes/test_route_exception_handling.py | 22 ++++++++++++ .../unit/services/test_simulation_service.py | 35 +++++++++++++++++++ 4 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 changelog.d/3449.fixed.md diff --git a/changelog.d/3449.fixed.md b/changelog.d/3449.fixed.md new file mode 100644 index 000000000..5746d7e04 --- /dev/null +++ b/changelog.d/3449.fixed.md @@ -0,0 +1 @@ +`SimulationService.update_simulation` no longer rewrites `api_version` when the PATCH payload contains no user-supplied fields. The "no fields to update" guard now fires correctly so an empty PATCH returns 400 instead of silently bumping `api_version`. diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index dfb208db2..e5582ee17 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -492,13 +492,18 @@ def update_simulation( update_fields.append("error_message = ?") update_values.append(error_message) - update_fields.append("api_version = ?") - update_values.append(api_version) - + # Only refresh api_version when the caller is actually + # changing one of the user-supplied fields above. The + # previous code appended api_version unconditionally, so + # the "no fields to update" guard below never fired and a + # PATCH with an empty body still touched the row. if not update_fields: print("No fields to update") return False + update_fields.append("api_version = ?") + update_values.append(api_version) + def tx_callback(tx): simulation = self._get_simulation_row( simulation_id, diff --git a/tests/unit/routes/test_route_exception_handling.py b/tests/unit/routes/test_route_exception_handling.py index d0eaa224f..b6fb38a28 100644 --- a/tests/unit/routes/test_route_exception_handling.py +++ b/tests/unit/routes/test_route_exception_handling.py @@ -84,3 +84,25 @@ def test_report_create_value_error_still_400(): json={"simulation_1_id": 1, "year": "2025"}, ) assert response.status_code == 400 + + +def test_simulation_patch_empty_body_returns_400(test_db): + """Regression for issue #3449. + + PATCH /{country}/simulation with a body that only contains the + id field must return 400 (no fields to update) instead of + silently rewriting api_version. + """ + from policyengine_api.services.simulation_service import SimulationService + + simulation_service = SimulationService() + created = simulation_service.create_simulation( + country_id="us", + population_id="household_patch_empty", + population_type="household", + policy_id=50, + ) + + client = _client_with(simulation_bp) + response = client.patch("/us/simulation", json={"id": created["id"]}) + assert response.status_code == 400 diff --git a/tests/unit/services/test_simulation_service.py b/tests/unit/services/test_simulation_service.py index 254c8867c..34116287f 100644 --- a/tests/unit/services/test_simulation_service.py +++ b/tests/unit/services/test_simulation_service.py @@ -485,3 +485,38 @@ def fail_dual_write(tx, simulation_id, *, country_id=None): assert run is not None assert run["status"] == "pending" assert run["output"] is None + + def test_update_simulation_with_no_user_fields_returns_false(self, test_db): + """Regression for issue #3449. + + update_fields used to always append api_version, so a PATCH + with no status/output/error_message still passed the + "no fields to update" guard and rewrote the row. The guard + must fire before api_version is appended so an empty PATCH + returns False (and the route converts that to a 400). + """ + created_simulation = service.create_simulation( + country_id="us", + population_id="household_empty_patch", + population_type="household", + policy_id=16, + ) + + pre_row = test_db.query( + "SELECT * FROM simulations WHERE id = ?", + (created_simulation["id"],), + ).fetchone() + + success = service.update_simulation( + country_id="us", + simulation_id=created_simulation["id"], + ) + + assert success is False + + post_row = test_db.query( + "SELECT * FROM simulations WHERE id = ?", + (created_simulation["id"],), + ).fetchone() + assert post_row["api_version"] == pre_row["api_version"] + assert post_row["status"] == pre_row["status"] From 8f3a82ea57179df9b8932e7146ea11898b769ee5 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 08:52:25 -0400 Subject: [PATCH 06/12] Hash cache keys with SHA-256 instead of hash() Python's builtin `hash()` is salted per process by default (PYTHONHASHSEED), so two gunicorn workers computed different cache keys for identical requests and the cache rarely hit. Switch to sha256(full_path + data) so the digest is stable across workers and restarts. Fixes #3450 --- changelog.d/3450.fixed.md | 1 + policyengine_api/utils/cache_utils.py | 9 ++- tests/unit/utils/__init__.py | 0 tests/unit/utils/test_cache_utils.py | 101 ++++++++++++++++++++++++++ 4 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 changelog.d/3450.fixed.md create mode 100644 tests/unit/utils/__init__.py create mode 100644 tests/unit/utils/test_cache_utils.py diff --git a/changelog.d/3450.fixed.md b/changelog.d/3450.fixed.md new file mode 100644 index 000000000..368293bdb --- /dev/null +++ b/changelog.d/3450.fixed.md @@ -0,0 +1 @@ +Use `hashlib.sha256` for API cache keys instead of Python's builtin `hash()`. `hash()` is salted per process (PYTHONHASHSEED), so workers could produce different keys for identical inputs and miss the cache. diff --git a/policyengine_api/utils/cache_utils.py b/policyengine_api/utils/cache_utils.py index 612dbf88c..12f3c6d04 100644 --- a/policyengine_api/utils/cache_utils.py +++ b/policyengine_api/utils/cache_utils.py @@ -1,5 +1,6 @@ """Tools for caching API responses.""" +import hashlib import json import logging import flask @@ -10,6 +11,11 @@ def make_cache_key(*args, **kwargs): """make a hash to uniquely identify a cache entry. keep it fast, adding overhead to try to add some minor chance of a cache hit is not worth it. + + Use a cryptographic digest (SHA-256) rather than the builtin + `hash()`, whose output depends on PYTHONHASHSEED and is therefore + different across workers/restarts; that made same-input requests + miss the cache in production. """ data = "" if flask.request.content_type == "application/json": @@ -22,7 +28,8 @@ def make_cache_key(*args, **kwargs): if data != "": data = json.dumps(data, separators=("", "")) - cache_key = str(hash(flask.request.full_path + data)) + full_path = flask.request.full_path + cache_key = hashlib.sha256((full_path + data).encode("utf-8")).hexdigest() logging.basicConfig(level=logging.DEBUG) logging.getLogger().debug( "PATH: %s, CACHE_KEY: %s", flask.request.full_path, cache_key diff --git a/tests/unit/utils/__init__.py b/tests/unit/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/utils/test_cache_utils.py b/tests/unit/utils/test_cache_utils.py new file mode 100644 index 000000000..ddf0fac1c --- /dev/null +++ b/tests/unit/utils/test_cache_utils.py @@ -0,0 +1,101 @@ +"""Regression tests for issue #3450. + +make_cache_key previously used `str(hash(...))`, whose value depends +on PYTHONHASHSEED. Two workers (or even two processes) produced +different keys for identical inputs, defeating the cache. +Switch to SHA-256 so the digest is deterministic across processes. +""" + +import hashlib +import subprocess +import sys +import textwrap + +from flask import Flask + +from policyengine_api.utils.cache_utils import make_cache_key + + +def test_make_cache_key_deterministic_within_process(): + app = Flask(__name__) + + with app.test_request_context( + "/us/economy/1/over/2?foo=bar", + method="POST", + json={"alpha": 1, "beta": [2, 3]}, + ): + first = make_cache_key() + with app.test_request_context( + "/us/economy/1/over/2?foo=bar", + method="POST", + json={"alpha": 1, "beta": [2, 3]}, + ): + second = make_cache_key() + + assert first == second + + +def test_make_cache_key_is_sha256_hex(): + app = Flask(__name__) + with app.test_request_context( + "/us/economy/1/over/2", + method="POST", + json={"hello": "world"}, + ): + key = make_cache_key() + + # 64-character lowercase hex string with no non-hex characters. + assert len(key) == 64 + assert all(c in "0123456789abcdef" for c in key) + + # Exact digest value computed directly must match. + # full_path is "/us/economy/1/over/2?" (Flask appends '?' when no query string) + # and json.dumps(..., separators=("", "")) produces no separators. + import json + + expected = hashlib.sha256( + ( + "/us/economy/1/over/2?" + + json.dumps({"hello": "world"}, separators=("", "")) + ).encode("utf-8") + ).hexdigest() + assert key == expected + + +def test_make_cache_key_stable_across_processes(): + """Two independent Python processes must produce the same cache + key for the same inputs, even though they use different + PYTHONHASHSEED values.""" + script = textwrap.dedent( + """ + from flask import Flask + from policyengine_api.utils.cache_utils import make_cache_key + + app = Flask(__name__) + with app.test_request_context( + "/us/economy/1/over/2?foo=bar", + method="POST", + json={"alpha": 1}, + ): + print(make_cache_key()) + """ + ) + + def run_with_seed(seed: str) -> str: + env_cmd = [sys.executable, "-c", script] + result = subprocess.run( + env_cmd, + capture_output=True, + text=True, + env={ + "PATH": "/usr/bin:/bin:/usr/local/bin", + "PYTHONHASHSEED": seed, + }, + check=True, + ) + return result.stdout.strip() + + # Using PYTHONHASHSEED=0 (deterministic) vs random seed must match. + key_seed_0 = run_with_seed("0") + key_seed_1 = run_with_seed("1") + assert key_seed_0 == key_seed_1 From ff69c1bf9dacb6f5b77fe01f8aef89a0cc7b876b Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 08:53:27 -0400 Subject: [PATCH 07/12] Harden get_simulations LIMIT handling The f-string LIMIT clause allowed the caller-controlled max_results value into the SQL statement and could be omitted entirely. Always apply LIMIT, clamp max_results to [1, 1000], and bind the value as a parameter. Fixes #3451 --- changelog.d/3451.fixed.md | 1 + policyengine_api/endpoints/simulation.py | 25 +++++-- tests/unit/endpoints/test_get_simulations.py | 75 ++++++++++++++++++++ 3 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 changelog.d/3451.fixed.md create mode 100644 tests/unit/endpoints/test_get_simulations.py diff --git a/changelog.d/3451.fixed.md b/changelog.d/3451.fixed.md new file mode 100644 index 000000000..73ef006dc --- /dev/null +++ b/changelog.d/3451.fixed.md @@ -0,0 +1 @@ +`get_simulations` now always applies a LIMIT, clamps `max_results` to `[1, 1000]`, and binds it as a parameter, eliminating the f-string SQL fragment and the unbounded-scan risk. diff --git a/policyengine_api/endpoints/simulation.py b/policyengine_api/endpoints/simulation.py index 132e5b2d6..a0d9bd70d 100644 --- a/policyengine_api/endpoints/simulation.py +++ b/policyengine_api/endpoints/simulation.py @@ -21,15 +21,30 @@ """ +_MAX_SIMULATION_RESULTS = 1000 +_DEFAULT_SIMULATION_RESULTS = 100 + + def get_simulations( - max_results: int = 100, + max_results: int | None = 100, ): - # Get the last N simulations ordered by start time - - desc_limit = f"DESC LIMIT {max_results}" if max_results is not None else "" + # Get the last N simulations ordered by start time. + # + # LIMIT is always applied (unbounded scans against reform_impact + # are expensive) and max_results is clamped to [1, + # _MAX_SIMULATION_RESULTS] before being bound as a parameter, so + # the value can never be interpolated into the SQL string. + if max_results is None: + max_results = _DEFAULT_SIMULATION_RESULTS + try: + max_results = int(max_results) + except (TypeError, ValueError): + max_results = _DEFAULT_SIMULATION_RESULTS + max_results = max(1, min(max_results, _MAX_SIMULATION_RESULTS)) result = local_database.query( - f"SELECT * FROM reform_impact ORDER BY start_time {desc_limit}", + "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?", + (max_results,), ).fetchall() # Format into [{}] diff --git a/tests/unit/endpoints/test_get_simulations.py b/tests/unit/endpoints/test_get_simulations.py new file mode 100644 index 000000000..56cb4bf6e --- /dev/null +++ b/tests/unit/endpoints/test_get_simulations.py @@ -0,0 +1,75 @@ +"""Regression tests for issue #3451. + +get_simulations built its LIMIT via an f-string +(`f"DESC LIMIT {max_results}"`), which is a SQL injection vector +(max_results flows in from a caller) and had no cap, so a tall +integer could drop unbounded rows on a production MySQL. The fix: +always LIMIT, clamp to [1, 1000], and bind as a parameter. +""" + +from policyengine_api.endpoints.simulation import get_simulations + + +def _seed_reform_impacts(test_db, n: int) -> None: + for i in range(n): + test_db.query( + """INSERT INTO reform_impact + (baseline_policy_id, reform_policy_id, country_id, region, dataset, + time_period, options_json, options_hash, api_version, + reform_impact_json, status, start_time, execution_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + i + 1, + i + 2, + "us", + "us", + "cps", + "2025", + "{}", + f"hash-{i}", + "1.0.0", + "{}", + "complete", + f"2026-01-01 00:00:{i:02d}", + f"exec-{i}", + ), + ) + + +def test_get_simulations_default_limit_caps_at_100(test_db): + _seed_reform_impacts(test_db, 150) + result = get_simulations() + assert len(result["result"]) == 100 + + +def test_get_simulations_clamps_huge_max_results(test_db): + _seed_reform_impacts(test_db, 50) + # A caller passing an absurdly large value must not crash and + # must not cause a full scan; the value is clamped at 1000. + result = get_simulations(max_results=10**9) + assert len(result["result"]) == 50 # only 50 seeded + + +def test_get_simulations_clamps_negative_max_results(test_db): + _seed_reform_impacts(test_db, 5) + # max_results of 0 or negative must still return something sane. + result = get_simulations(max_results=0) + assert 1 <= len(result["result"]) <= 5 + + +def test_get_simulations_defaults_when_none(test_db): + _seed_reform_impacts(test_db, 10) + result = get_simulations(max_results=None) + assert len(result["result"]) == 10 # fewer than the default 100 + + +def test_get_simulations_rejects_non_integer_gracefully(test_db): + _seed_reform_impacts(test_db, 5) + # A string like "100; DROP TABLE reform_impact" must not reach + # the SQL statement; it falls back to the default. + result = get_simulations(max_results="100; DROP TABLE reform_impact") + assert len(result["result"]) == 5 + + # And the table must still exist. + rows = test_db.query("SELECT COUNT(*) AS c FROM reform_impact").fetchone() + assert rows["c"] == 5 From 515aa50d997b5f8216b4007948c9a09ce4f7e4de Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 08:55:14 -0400 Subject: [PATCH 08/12] Replace bare except in compare.py with except Exception A bare except suppressed BaseException-derived exits (SystemExit, KeyboardInterrupt) in addition to regular errors and dropped the traceback entirely. Use except Exception plus logger.exception so the failure is visible in logs while keeping the existing fallback to empty wealth-decile data. Fixes #3452 --- changelog.d/3452.fixed.md | 1 + policyengine_api/endpoints/economy/compare.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) create mode 100644 changelog.d/3452.fixed.md diff --git a/changelog.d/3452.fixed.md b/changelog.d/3452.fixed.md new file mode 100644 index 000000000..d403c87ba --- /dev/null +++ b/changelog.d/3452.fixed.md @@ -0,0 +1 @@ +Replace the bare `except:` around wealth-decile impact computation in `endpoints/economy/compare.py` with `except Exception:` plus `logger.exception(...)` so SystemExit/KeyboardInterrupt can still propagate and failures are visible in logs. diff --git a/policyengine_api/endpoints/economy/compare.py b/policyengine_api/endpoints/economy/compare.py index 3d3cf9c27..7f5527e54 100644 --- a/policyengine_api/endpoints/economy/compare.py +++ b/policyengine_api/endpoints/economy/compare.py @@ -1,3 +1,4 @@ +import logging from microdf import MicroDataFrame, MicroSeries import numpy as np import sys @@ -7,6 +8,8 @@ from pydantic import BaseModel from typing import Any +logger = logging.getLogger(__name__) + def budgetary_impact(baseline: dict, reform: dict) -> dict: tax_revenue_impact = reform["total_tax"] - baseline["total_tax"] @@ -813,7 +816,10 @@ def compare_economic_outputs( intra_wealth_decile_impact_data = intra_wealth_decile_impact( baseline, reform ) - except: + except Exception: + logger.exception( + "Wealth decile impact computation failed; returning empty breakdowns." + ) wealth_decile_impact_data = {} intra_wealth_decile_impact_data = {} From 9316c2dd4e594ddfc0341233c7a5975ac6004dca Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 08:57:33 -0400 Subject: [PATCH 09/12] Match constituency country codes by prefix, not substring uk_constituency_breakdown used `"E" not in code` (and friends) to skip rows outside a selected UK nation, so a Welsh code containing any 'E' in its tail was treated as English (and double-counted in the England bucket). Switch to `code.startswith("E"/"S"/"W"/"N")`, matching the local-authority code already at line 721+. Fixes #3453 --- changelog.d/3453.fixed.md | 1 + policyengine_api/endpoints/economy/compare.py | 22 ++++---- tests/unit/endpoints/economy/test_compare.py | 52 +++++++++++++++++++ 3 files changed, 66 insertions(+), 9 deletions(-) create mode 100644 changelog.d/3453.fixed.md diff --git a/changelog.d/3453.fixed.md b/changelog.d/3453.fixed.md new file mode 100644 index 000000000..c7a8e21ee --- /dev/null +++ b/changelog.d/3453.fixed.md @@ -0,0 +1 @@ +Constituency-level country filters in `uk_constituency_breakdown` now use `code.startswith("E"/"S"/"W"/"N")` instead of `"E"/"S"/"W"/"N" not in code`, matching the local-authority pattern already used elsewhere in the file and preventing a constituency's country letter from leaking into other buckets. diff --git a/policyengine_api/endpoints/economy/compare.py b/policyengine_api/endpoints/economy/compare.py index 7f5527e54..8437392cf 100644 --- a/policyengine_api/endpoints/economy/compare.py +++ b/policyengine_api/endpoints/economy/compare.py @@ -599,15 +599,19 @@ def uk_constituency_breakdown( if name != selected_constituency and code != selected_constituency: continue - # Filter to specific country if requested + # Filter to specific country if requested. Constituency codes + # are prefixed with a single letter that identifies the + # country (E=England, S=Scotland, W=Wales, N=Northern + # Ireland); use startswith so we don't accidentally match a + # letter anywhere else in the code. if selected_country is not None: - if selected_country == "ENGLAND" and "E" not in code: + if selected_country == "ENGLAND" and not code.startswith("E"): continue - elif selected_country == "SCOTLAND" and "S" not in code: + elif selected_country == "SCOTLAND" and not code.startswith("S"): continue - elif selected_country == "WALES" and "W" not in code: + elif selected_country == "WALES" and not code.startswith("W"): continue - elif selected_country == "NORTHERN_IRELAND" and "N" not in code: + elif selected_country == "NORTHERN_IRELAND" and not code.startswith("N"): continue weight: np.ndarray = weights[i] @@ -627,13 +631,13 @@ def uk_constituency_breakdown( } regions = ["uk"] - if "E" in code: + if code.startswith("E"): regions.append("england") - elif "S" in code: + elif code.startswith("S"): regions.append("scotland") - elif "W" in code: + elif code.startswith("W"): regions.append("wales") - elif "N" in code: + elif code.startswith("N"): regions.append("northern_ireland") if percent_household_income_change > 0.05: diff --git a/tests/unit/endpoints/economy/test_compare.py b/tests/unit/endpoints/economy/test_compare.py index 759cc7f26..86d2187dd 100644 --- a/tests/unit/endpoints/economy/test_compare.py +++ b/tests/unit/endpoints/economy/test_compare.py @@ -677,6 +677,58 @@ def test__given_no_region__returns_all_constituencies( assert result is not None assert len(result.by_constituency) == 3 + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") + @patch("policyengine_api.endpoints.economy.compare.h5py.File") + @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") + def test__country_filter_uses_prefix_not_substring( + self, mock_read_csv, mock_h5py_file, mock_download + ): + """Regression for issue #3453. + + Previously the filter used `"E" not in code`, so a Welsh + code containing any 'E' (e.g. "W12345E") or a Scottish code + containing 'W' would leak into the wrong country bucket. + Use startswith on the leading country-letter instead. + """ + mock_download.side_effect = [ + "/path/to/weights.h5", + "/path/to/names.csv", + ] + + mock_weights = np.ones((3, 10)) + mock_h5py_context = MagicMock() + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) + mock_h5py_context.__exit__ = MagicMock(return_value=False) + mock_h5py_file.return_value = mock_h5py_context + + # Welsh code "W12345E7" happens to contain 'E' in its tail. + mock_const_df = pd.DataFrame( + { + "code": ["W12345E7", "E12345678", "S12345678"], + "name": ["Cardiff Trap", "Aldershot", "Edinburgh East"], + "x": [10.0, 5.0, 3.0], + "y": [20.0, 15.0, 12.0], + } + ) + mock_read_csv.return_value = mock_const_df + + baseline = {"household_net_income": np.array([1000.0] * 10)} + reform = {"household_net_income": np.array([1050.0] * 10)} + + result = uk_constituency_breakdown(baseline, reform, "uk", "country/england") + + assert result is not None + # Only Aldershot (code starting with 'E') should pass the + # England filter; the Welsh trap code must be excluded. + assert "Aldershot" in result.by_constituency + assert "Cardiff Trap" not in result.by_constituency + assert "Edinburgh East" not in result.by_constituency + + # The Welsh trap code must also not be double-counted in the + # England regional bucket. + england_total = sum(result.outcomes_by_region["england"].values()) + assert england_total == 1 + def _make_economy( incomes, From bd8e0317bdbb9efbe75fbc529c26ecb9409bc86a Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 09:07:00 -0400 Subject: [PATCH 10/12] Update legacy household route test to expect country-scoped UPDATE tests/to_refactor/python/test_household_routes.py asserted the old UPDATE SQL that lacked `AND country_id = ?`. Align the mock assertion with the fix from #3447 so the legacy suite matches the corrected production query. Fixes #3447 --- tests/to_refactor/python/test_household_routes.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/to_refactor/python/test_household_routes.py b/tests/to_refactor/python/test_household_routes.py index 3fa5af319..3456429dc 100644 --- a/tests/to_refactor/python/test_household_routes.py +++ b/tests/to_refactor/python/test_household_routes.py @@ -132,14 +132,18 @@ def test_update_household_success( assert data["status"] == "ok" assert data["result"]["household_id"] == 1 # assert data["result"]["household_json"] == updated_data["data"] + # WHERE now includes country_id (issue #3447). mock_database.query.assert_any_call( - "UPDATE household SET household_json = ?, household_hash = ?, label = ?, api_version = ? WHERE id = ?", + "UPDATE household " + "SET household_json = ?, household_hash = ?, label = ?, api_version = ? " + "WHERE id = ? AND country_id = ?", ( json.dumps(updated_household), "some-hash", valid_request_body["label"], COUNTRY_PACKAGE_VERSIONS.get("us"), 1, + "us", ), ) From ffb8cf1bba36bc2528da5d249ae9462219ef3cb5 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 12:35:54 -0400 Subject: [PATCH 11/12] Fix test payload for #3446 decorator order regression check The original test sent json={}, which fails validate_sim_analysis_payload with 400 before reaching the country check. That meant the test passed with or without the decorator-order fix. Use a payload that satisfies validate_sim_analysis_payload so the only remaining reason to 400 is the unknown country_id. With the pre-#3446 decorator order the view runs first, calls ai_prompt_service.get_prompt which returns None for an unknown prompt name, and yields 404 - so the test now genuinely detects the regression. --- tests/unit/routes/test_decorator_order.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/unit/routes/test_decorator_order.py b/tests/unit/routes/test_decorator_order.py index 5fe8c404f..20a967299 100644 --- a/tests/unit/routes/test_decorator_order.py +++ b/tests/unit/routes/test_decorator_order.py @@ -33,5 +33,20 @@ def test_economy_route_rejects_bogus_country(): def test_ai_prompt_route_rejects_bogus_country(): client = _client_with(ai_prompt_bp) - response = client.post("/bogus/ai-prompts/some_prompt", json={}) + # Use a payload that passes validate_sim_analysis_payload so the only + # remaining reason to 400 is the unknown country_id. With the pre-#3446 + # decorator order the view runs first and reaches the service, so this + # request would not be rejected on country grounds. + valid_payload = { + "currency": "USD", + "selected_version": "v1.0", + "time_period": "2024", + "impact": {"value": 100}, + "policy_label": "Test Policy", + "policy": {"type": "tax", "rate": 0.1}, + "region": "NA", + "relevant_parameters": ["param1", "param2"], + "relevant_parameter_baseline_values": [1.0, 2.0], + } + response = client.post("/bogus/ai-prompts/some_prompt", json=valid_payload) assert response.status_code == 400 From 415d97ee23ba9bc1dc32abcd8a1ceaba2a6a3d90 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 12:36:00 -0400 Subject: [PATCH 12/12] Pass through HTTPException in simulation and report route handlers BadRequest and NotFound raised explicitly from within the create/update try blocks were being caught by the generic `except Exception:` arm and logged as "Unexpected error" before being re-raised. Add an `except HTTPException: raise` ahead of the generic handler so expected client-error responses pass through without polluting the logs. This also subsumes the existing `except NotFound: raise` in the PATCH handlers. --- policyengine_api/routes/report_output_routes.py | 10 ++++++++-- policyengine_api/routes/simulation_routes.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index 2f1686e9a..1100faf97 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -1,5 +1,5 @@ from flask import Blueprint, Response, current_app, request -from werkzeug.exceptions import NotFound, BadRequest +from werkzeug.exceptions import HTTPException, NotFound, BadRequest import json import jsonschema @@ -96,6 +96,10 @@ def create_report_output(country_id: str) -> Response: mimetype="application/json", ) + except HTTPException: + # Let explicit client-error responses (BadRequest/NotFound/etc.) pass + # through without being logged as "Unexpected error". + raise except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e: current_app.logger.warning( "Bad request creating report output for country %s: %s", country_id, e @@ -214,7 +218,9 @@ def update_report_output(country_id: str) -> Response: mimetype="application/json", ) - except NotFound: + except HTTPException: + # Let explicit client-error responses (BadRequest/NotFound/etc.) pass + # through without being logged as "Unexpected error". raise except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e: current_app.logger.warning( diff --git a/policyengine_api/routes/simulation_routes.py b/policyengine_api/routes/simulation_routes.py index c7736df3e..f2bacd6cb 100644 --- a/policyengine_api/routes/simulation_routes.py +++ b/policyengine_api/routes/simulation_routes.py @@ -1,5 +1,5 @@ from flask import Blueprint, Response, current_app, request -from werkzeug.exceptions import NotFound, BadRequest +from werkzeug.exceptions import HTTPException, NotFound, BadRequest import json import jsonschema @@ -96,6 +96,10 @@ def create_simulation(country_id: str) -> Response: mimetype="application/json", ) + except HTTPException: + # Let explicit client-error responses (BadRequest/NotFound/etc.) pass + # through without being logged as "Unexpected error". + raise except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e: current_app.logger.warning( "Bad request creating simulation for country %s: %s", country_id, e @@ -218,7 +222,9 @@ def update_simulation(country_id: str) -> Response: mimetype="application/json", ) - except NotFound: + except HTTPException: + # Let explicit client-error responses (BadRequest/NotFound/etc.) pass + # through without being logged as "Unexpected error". raise except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e: current_app.logger.warning(