Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/3445.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reject unknown columns in `PUT /<country_id>/user-policy` with HTTP 400 and restrict writable fields to a whitelist, closing a SQL injection path where JSON keys were interpolated into the UPDATE statement.
1 change: 1 addition & 0 deletions changelog.d/3446.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Move `@validate_country` below `@bp.route` in economy and AI-prompt routes so unknown country IDs are rejected with HTTP 400 instead of reaching the view function.
1 change: 1 addition & 0 deletions changelog.d/3447.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Scope `update_household` by `country_id` so an update for one country cannot overwrite a household that shares the same numeric id under another country. Missing rows now raise `LookupError` instead of silently succeeding.
1 change: 1 addition & 0 deletions changelog.d/3448.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Narrow the exception handlers in `simulation_routes` and `report_output_routes` so only `ValueError`, `pydantic.ValidationError`, and `jsonschema.ValidationError` are mapped to HTTP 400. Unexpected exceptions now propagate as 500 with a logged traceback instead of being silently relabelled as bad requests.
1 change: 1 addition & 0 deletions changelog.d/3449.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`SimulationService.update_simulation` no longer rewrites `api_version` when the PATCH payload contains no user-supplied fields. The "no fields to update" guard now fires correctly so an empty PATCH returns 400 instead of silently bumping `api_version`.
1 change: 1 addition & 0 deletions changelog.d/3450.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `hashlib.sha256` for API cache keys instead of Python's builtin `hash()`. `hash()` is salted per process (PYTHONHASHSEED), so workers could produce different keys for identical inputs and miss the cache.
1 change: 1 addition & 0 deletions changelog.d/3451.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`get_simulations` now always applies a LIMIT, clamps `max_results` to `[1, 1000]`, and binds it as a parameter, eliminating the f-string SQL fragment and the unbounded-scan risk.
1 change: 1 addition & 0 deletions changelog.d/3452.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Replace the bare `except:` around wealth-decile impact computation in `endpoints/economy/compare.py` with `except Exception:` plus `logger.exception(...)` so SystemExit/KeyboardInterrupt can still propagate and failures are visible in logs.
1 change: 1 addition & 0 deletions changelog.d/3453.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Constituency-level country filters in `uk_constituency_breakdown` now use `code.startswith("E"/"S"/"W"/"N")` instead of `"E"/"S"/"W"/"N" not in code`, matching the local-authority pattern already used elsewhere in the file and preventing a constituency's country letter from leaking into other buckets.
30 changes: 20 additions & 10 deletions policyengine_api/endpoints/economy/compare.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from microdf import MicroDataFrame, MicroSeries
import numpy as np
import sys
Expand All @@ -7,6 +8,8 @@
from pydantic import BaseModel
from typing import Any

logger = logging.getLogger(__name__)


def budgetary_impact(baseline: dict, reform: dict) -> dict:
tax_revenue_impact = reform["total_tax"] - baseline["total_tax"]
Expand Down Expand Up @@ -596,15 +599,19 @@ def uk_constituency_breakdown(
if name != selected_constituency and code != selected_constituency:
continue

# Filter to specific country if requested
# Filter to specific country if requested. Constituency codes
# are prefixed with a single letter that identifies the
# country (E=England, S=Scotland, W=Wales, N=Northern
# Ireland); use startswith so we don't accidentally match a
# letter anywhere else in the code.
if selected_country is not None:
if selected_country == "ENGLAND" and "E" not in code:
if selected_country == "ENGLAND" and not code.startswith("E"):
continue
elif selected_country == "SCOTLAND" and "S" not in code:
elif selected_country == "SCOTLAND" and not code.startswith("S"):
continue
elif selected_country == "WALES" and "W" not in code:
elif selected_country == "WALES" and not code.startswith("W"):
continue
elif selected_country == "NORTHERN_IRELAND" and "N" not in code:
elif selected_country == "NORTHERN_IRELAND" and not code.startswith("N"):
continue

weight: np.ndarray = weights[i]
Expand All @@ -624,13 +631,13 @@ def uk_constituency_breakdown(
}

regions = ["uk"]
if "E" in code:
if code.startswith("E"):
regions.append("england")
elif "S" in code:
elif code.startswith("S"):
regions.append("scotland")
elif "W" in code:
elif code.startswith("W"):
regions.append("wales")
elif "N" in code:
elif code.startswith("N"):
regions.append("northern_ireland")

if percent_household_income_change > 0.05:
Expand Down Expand Up @@ -813,7 +820,10 @@ def compare_economic_outputs(
intra_wealth_decile_impact_data = intra_wealth_decile_impact(
baseline, reform
)
except:
except Exception:
logger.exception(
"Wealth decile impact computation failed; returning empty breakdowns."
)
wealth_decile_impact_data = {}
intra_wealth_decile_impact_data = {}

Expand Down
66 changes: 63 additions & 3 deletions policyengine_api/endpoints/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,18 +315,78 @@ def get_user_policy(country_id: str, user_id: str) -> dict:
)


# Whitelist of columns that callers are allowed to modify via
# update_user_policy. Identity columns (id, country_id, user_id,
# reform_id, baseline_id) are intentionally excluded because they
# define the record; allowing clients to rewrite them would both
# break referential assumptions and let the column name be used
# as a SQL injection vector (keys are interpolated into the
# UPDATE statement below).
UPDATE_USER_POLICY_ALLOWED_FIELDS = frozenset(
{
"reform_label",
"baseline_label",
"year",
"geography",
"dataset",
"number_of_provisions",
"api_version",
"added_date",
"updated_date",
"budgetary_impact",
"type",
}
)


@validate_country
def update_user_policy(country_id: str) -> dict:
"""
Update any parts of a user_policy, given a user_policy ID
"""

# Construct the relevant UPDATE request
setter_array = []
args = []
payload = request.json
if not isinstance(payload, dict) or "id" not in payload:
return Response(
json.dumps({"message": "Request body must include an 'id' field."}),
status=400,
mimetype="application/json",
)

user_policy_id = payload.pop("id")

# Reject any unknown/unsafe keys. The keys end up interpolated
# into a SQL UPDATE statement, so we must validate them against
# a static whitelist instead of trusting the JSON payload.
unknown_keys = [
key for key in payload if key not in UPDATE_USER_POLICY_ALLOWED_FIELDS
]
if unknown_keys:
return Response(
json.dumps(
{
"message": (
"Request body contains unsupported fields: "
f"{sorted(unknown_keys)}"
)
}
),
status=400,
mimetype="application/json",
)

if not payload:
return Response(
json.dumps(
{"message": "Request body must include at least one field to update."}
),
status=400,
mimetype="application/json",
)

# Construct the relevant UPDATE request from whitelisted keys.
setter_array = []
args = []
for key in payload:
setter_array.append(f"{key} = ?")
args.append(payload[key])
Expand Down
25 changes: 20 additions & 5 deletions policyengine_api/endpoints/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,30 @@
"""


_MAX_SIMULATION_RESULTS = 1000
_DEFAULT_SIMULATION_RESULTS = 100


def get_simulations(
max_results: int = 100,
max_results: int | None = 100,
):
# Get the last N simulations ordered by start time

desc_limit = f"DESC LIMIT {max_results}" if max_results is not None else ""
# Get the last N simulations ordered by start time.
#
# LIMIT is always applied (unbounded scans against reform_impact
# are expensive) and max_results is clamped to [1,
# _MAX_SIMULATION_RESULTS] before being bound as a parameter, so
# the value can never be interpolated into the SQL string.
if max_results is None:
max_results = _DEFAULT_SIMULATION_RESULTS
try:
max_results = int(max_results)
except (TypeError, ValueError):
max_results = _DEFAULT_SIMULATION_RESULTS
max_results = max(1, min(max_results, _MAX_SIMULATION_RESULTS))

result = local_database.query(
f"SELECT * FROM reform_impact ORDER BY start_time {desc_limit}",
"SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?",
(max_results,),
).fetchall()

# Format into [{}]
Expand Down
2 changes: 1 addition & 1 deletion policyengine_api/routes/ai_prompt_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
ai_prompt_service = AIPromptService()


@validate_country
@ai_prompt_bp.route(
"/<country_id>/ai-prompts/<string:prompt_name>",
methods=["POST"],
)
@validate_country
def generate_ai_prompt(country_id, prompt_name: str) -> Response:
"""
Get an AI prompt with a given name, filled with the given data.
Expand Down
2 changes: 1 addition & 1 deletion policyengine_api/routes/economy_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
economy_service = EconomyService()


@validate_country
@economy_bp.route(
"/<country_id>/economy/<int:policy_id>/over/<int:baseline_policy_id>",
methods=["GET"],
)
@validate_country
def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int):

policy_id = int(policy_id or get_current_law_policy_id(country_id))
Expand Down
46 changes: 37 additions & 9 deletions policyengine_api/routes/report_output_routes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from flask import Blueprint, Response, request
from werkzeug.exceptions import NotFound, BadRequest
from flask import Blueprint, Response, current_app, request
from werkzeug.exceptions import HTTPException, NotFound, BadRequest
import json

import jsonschema
import pydantic

from policyengine_api.services.report_output_service import ReportOutputService
from policyengine_api.constants import CURRENT_YEAR
from policyengine_api.utils.payload_validators import validate_country
Expand Down Expand Up @@ -93,9 +96,20 @@ def create_report_output(country_id: str) -> Response:
mimetype="application/json",
)

except Exception as e:
print(f"Error creating report output: {str(e)}")
raise BadRequest(f"Failed to create report output: {str(e)}")
except HTTPException:
# Let explicit client-error responses (BadRequest/NotFound/etc.) pass
# through without being logged as "Unexpected error".
raise
except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e:
current_app.logger.warning(
"Bad request creating report output for country %s: %s", country_id, e
)
raise BadRequest(f"Failed to create report output: {e}")
except Exception:
current_app.logger.exception(
"Unexpected error creating report output for country %s", country_id
)
raise


@report_output_bp.route("/<country_id>/report/<int:report_id>", methods=["GET"])
Expand Down Expand Up @@ -204,8 +218,22 @@ def update_report_output(country_id: str) -> Response:
mimetype="application/json",
)

except NotFound:
except HTTPException:
# Let explicit client-error responses (BadRequest/NotFound/etc.) pass
# through without being logged as "Unexpected error".
raise
except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e:
current_app.logger.warning(
"Bad request updating report #%s for country %s: %s",
report_id,
country_id,
e,
)
raise BadRequest(f"Failed to update report output: {e}")
except Exception:
current_app.logger.exception(
"Unexpected error updating report #%s for country %s",
report_id,
country_id,
)
raise
except Exception as e:
print(f"Error updating report output: {str(e)}")
raise BadRequest(f"Failed to update report output: {str(e)}")
46 changes: 37 additions & 9 deletions policyengine_api/routes/simulation_routes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from flask import Blueprint, Response, request
from werkzeug.exceptions import NotFound, BadRequest
from flask import Blueprint, Response, current_app, request
from werkzeug.exceptions import HTTPException, NotFound, BadRequest
import json

import jsonschema
import pydantic

from policyengine_api.services.simulation_service import SimulationService
from policyengine_api.utils.payload_validators import validate_country

Expand Down Expand Up @@ -93,9 +96,20 @@ def create_simulation(country_id: str) -> Response:
mimetype="application/json",
)

except Exception as e:
print(f"Error creating simulation: {str(e)}")
raise BadRequest(f"Failed to create simulation: {str(e)}")
except HTTPException:
# Let explicit client-error responses (BadRequest/NotFound/etc.) pass
# through without being logged as "Unexpected error".
raise
except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e:
current_app.logger.warning(
"Bad request creating simulation for country %s: %s", country_id, e
)
raise BadRequest(f"Failed to create simulation: {e}")
except Exception:
current_app.logger.exception(
"Unexpected error creating simulation for country %s", country_id
)
raise


@simulation_bp.route("/<country_id>/simulation/<int:simulation_id>", methods=["GET"])
Expand Down Expand Up @@ -208,8 +222,22 @@ def update_simulation(country_id: str) -> Response:
mimetype="application/json",
)

except NotFound:
except HTTPException:
# Let explicit client-error responses (BadRequest/NotFound/etc.) pass
# through without being logged as "Unexpected error".
raise
except (ValueError, pydantic.ValidationError, jsonschema.ValidationError) as e:
current_app.logger.warning(
"Bad request updating simulation #%s for country %s: %s",
simulation_id,
country_id,
e,
)
raise BadRequest(f"Failed to update simulation: {e}")
except Exception:
current_app.logger.exception(
"Unexpected error updating simulation #%s for country %s",
simulation_id,
country_id,
)
raise
except Exception as e:
print(f"Error updating simulation: {str(e)}")
raise BadRequest(f"Failed to update simulation: {str(e)}")
21 changes: 18 additions & 3 deletions policyengine_api/services/household_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,34 @@ def update_household(
household_hash: str = hash_object(household_json)
api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id)

# WHERE must include country_id so an update scoped to
# one country cannot silently overwrite a household that
# happens to share the same numeric id under another
# country.
database.query(
f"UPDATE household SET household_json = ?, household_hash = ?, label = ?, api_version = ? WHERE id = ?",
"UPDATE household "
"SET household_json = ?, household_hash = ?, label = ?, api_version = ? "
"WHERE id = ? AND country_id = ?",
(
json.dumps(household_json),
household_hash,
label,
api_version,
household_id,
country_id,
),
)

# Fetch the updated JSON back from the table
updated_household: dict = self.get_household(country_id, household_id)
# Fetch the updated JSON back from the table. If the
# household did not exist for this country, get_household
# returns None.
updated_household: dict | None = self.get_household(
country_id, household_id
)
if updated_household is None:
raise LookupError(
f"Household #{household_id} not found for country {country_id}."
)
return updated_household
except Exception as e:
print(f"Error updating household #{household_id}. Details: {str(e)}")
Expand Down
Loading
Loading