Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/report-output-run-stage-2.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add internal report and simulation spec, alias, and run services for the report-output run migration.
89 changes: 77 additions & 12 deletions policyengine_api/data/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sqlite3
from policyengine_api.constants import REPO, VERSION, COUNTRY_PACKAGE_VERSIONS
from policyengine_api.constants import REPO, COUNTRY_PACKAGE_VERSIONS
from policyengine_api.utils import hash_object
from pathlib import Path
from dotenv import load_dotenv
Expand Down Expand Up @@ -41,6 +41,29 @@ def fetchall(self):
return remaining


class _TransactionProxy:
"""Execute queries against an existing connection inside a transaction."""

def __init__(self, connection, local: bool):
self._connection = connection
self._local = local

def query(self, *query):
if self._local:
cursor = self._connection.cursor()
return cursor.execute(*query)

query = list(query)
main_query = query[0].replace("?", "%s")
query[0] = main_query
params = query[1] if len(query) > 1 else None
if params is not None:
result = self._connection.exec_driver_sql(main_query, params)
else:
result = self._connection.exec_driver_sql(main_query)
return _ResultProxy(result)


class PolicyEngineDatabase:
"""
A wrapper around the database connection.
Expand All @@ -50,6 +73,13 @@ class PolicyEngineDatabase:

household_cache: dict = {}

@staticmethod
def _dict_factory(cursor, row):
d = {}
for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx]
return d

def __init__(
self,
local: bool = False,
Expand Down Expand Up @@ -91,7 +121,7 @@ def _close_pool(self):
try:
self.pool.dispose()
self.connector.close()
except:
except Exception:
pass

def _execute_remote(self, query_args):
Expand All @@ -110,17 +140,22 @@ def _execute_remote(self, query_args):
# connection context closing
return _ResultProxy(result)

    def _execute_remote_transaction(self, callback):
        """Run *callback* inside a single transaction on one pooled connection.

        ``callback`` receives a ``_TransactionProxy`` bound to the connection;
        its return value is passed through. The transaction is committed on
        success and rolled back on any exception, which is then re-raised.
        """
        with self.pool.connect() as conn:
            transaction = conn.begin()
            proxy = _TransactionProxy(conn, local=False)
            try:
                result = callback(proxy)
                transaction.commit()
                return result
            except Exception:
                transaction.rollback()
                raise

def query(self, *query):
if self.local:
with sqlite3.connect(self.db_url) as conn:

def dict_factory(cursor, row):
d = {}
for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx]
return d

conn.row_factory = dict_factory
conn.row_factory = self._dict_factory
cursor = conn.cursor()
return cursor.execute(*query)
else:
Expand All @@ -134,14 +169,44 @@ def dict_factory(cursor, row):
except (
sqlalchemy.exc.InterfaceError,
sqlalchemy.exc.OperationalError,
) as e:
):
try:
self._close_pool()
self._create_pool()
return self._execute_remote(query)
except Exception as e:
raise e

    def transaction(self, callback):
        """Execute *callback* atomically, committing only if it succeeds.

        ``callback`` receives a ``_TransactionProxy``; whatever it returns is
        returned here. Any exception rolls the transaction back and
        propagates to the caller.
        """
        if self.local:
            # Reuse a connection already attached to the instance when
            # present; otherwise open our own and close it when done.
            connection = getattr(self, "_connection", None)
            owns_connection = connection is None
            if owns_connection:
                connection = sqlite3.connect(self.db_url)
            # NOTE(review): row_factory is overwritten even on a borrowed
            # connection and never restored — confirm callers expect dict rows.
            connection.row_factory = self._dict_factory
            try:
                # IMMEDIATE acquires the write lock up front, so a concurrent
                # writer fails fast rather than mid-callback.
                connection.execute("BEGIN IMMEDIATE")
                proxy = _TransactionProxy(connection, local=True)
                result = callback(proxy)
                connection.commit()
                return result
            except Exception:
                connection.rollback()
                raise
            finally:
                if owns_connection:
                    connection.close()

        try:
            return self._execute_remote_transaction(callback)
        except (
            sqlalchemy.exc.InterfaceError,
            sqlalchemy.exc.OperationalError,
        ):
            # Stale/broken pool: rebuild once and retry. NOTE(review): the
            # callback runs a second time; its first attempt was rolled back.
            self._close_pool()
            self._create_pool()
            return self._execute_remote_transaction(callback)

def initialize(self):
"""
Create the database tables.
Expand Down Expand Up @@ -175,7 +240,7 @@ def initialize(self):
range(1, 1 + len(COUNTRY_PACKAGE_VERSIONS)),
):
self.query(
f"INSERT INTO policy (id, country_id, label, api_version, policy_json, policy_hash) VALUES (?, ?, ?, ?, ?, ?)",
"INSERT INTO policy (id, country_id, label, api_version, policy_json, policy_hash) VALUES (?, ?, ?, ?, ?, ?)",
(
policy_id,
country_id,
Expand Down
97 changes: 97 additions & 0 deletions policyengine_api/services/report_output_alias_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from sqlalchemy.engine.row import Row

from policyengine_api.data import database


class ReportOutputAliasService:
    """Map legacy report-output IDs onto their canonical replacements."""

    def _get_report_output_row(self, report_output_id: int) -> dict | None:
        """Return the report_outputs row for *report_output_id*, or None."""
        fetched: Row | None = database.query(
            """
            SELECT id, country_id, simulation_1_id, simulation_2_id, year
            FROM report_outputs
            WHERE id = ?
            """,
            (report_output_id,),
        ).fetchone()
        if fetched is None:
            return None
        return dict(fetched)

    def get_alias(self, legacy_report_output_id: int) -> dict | None:
        """Return the alias row for a legacy report output, or None."""
        fetched: Row | None = database.query(
            """
            SELECT * FROM legacy_report_output_aliases
            WHERE legacy_report_output_id = ?
            """,
            (legacy_report_output_id,),
        ).fetchone()
        if fetched is None:
            return None
        return dict(fetched)

    def resolve_canonical_report_output_id(
        self, requested_report_output_id: int
    ) -> int | None:
        """Follow any alias to the canonical ID; None when nothing matches.

        Raises ValueError if an alias points at a missing report output.
        """
        alias = self.get_alias(requested_report_output_id)
        if alias is None:
            # No alias: the requested ID is canonical iff its row exists.
            direct: Row | None = database.query(
                "SELECT id FROM report_outputs WHERE id = ?",
                (requested_report_output_id,),
            ).fetchone()
            if direct is None:
                return None
            return direct["id"]

        canonical_report_output_id = alias["canonical_report_output_id"]
        if self._get_report_output_row(canonical_report_output_id) is None:
            raise ValueError(
                "Alias points to missing canonical report output "
                f"#{canonical_report_output_id}"
            )
        return canonical_report_output_id

    def set_alias(
        self,
        legacy_report_output_id: int,
        canonical_report_output_id: int,
    ) -> bool:
        """Create (or idempotently confirm) a legacy->canonical alias.

        Both report outputs must exist, be distinct, and describe the same
        logical report. Returns True on success; raises ValueError on any
        violation.
        """
        legacy_report_output = self._get_report_output_row(
            legacy_report_output_id
        )
        if legacy_report_output is None:
            raise ValueError(
                f"Legacy report output #{legacy_report_output_id} not found"
            )

        canonical_report_output = self._get_report_output_row(
            canonical_report_output_id
        )
        if canonical_report_output is None:
            raise ValueError(
                f"Canonical report output #{canonical_report_output_id} not found"
            )
        if legacy_report_output_id == canonical_report_output_id:
            raise ValueError("Legacy and canonical report outputs must be different")

        existing_alias = self.get_alias(legacy_report_output_id)
        if existing_alias is not None:
            # Idempotent success when the alias already points where asked.
            if (
                existing_alias["canonical_report_output_id"]
                == canonical_report_output_id
            ):
                return True
            raise ValueError(
                "Legacy report output alias already points to canonical report output "
                f"#{existing_alias['canonical_report_output_id']}"
            )

        # Both rows must agree on the fields identifying a logical report.
        for field in ("country_id", "simulation_1_id", "simulation_2_id", "year"):
            if legacy_report_output[field] != canonical_report_output[field]:
                raise ValueError(
                    "Legacy and canonical report outputs must describe the same report"
                )
        database.query(
            """
            INSERT INTO legacy_report_output_aliases
            (legacy_report_output_id, canonical_report_output_id)
            VALUES (?, ?)
            """,
            (legacy_report_output_id, canonical_report_output_id),
        )
        return True
152 changes: 152 additions & 0 deletions policyengine_api/services/report_run_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import json
import uuid
from typing import Any

from sqlalchemy.engine.row import Row

from policyengine_api.data import database


REPORT_RUN_VERSION_FIELDS = (
"country_package_version",
"policyengine_version",
"data_version",
"runtime_app_name",
"report_cache_version",
"simulation_cache_version",
"requested_version_override",
"resolved_dataset",
"resolved_options_hash",
)


class ReportRunService:
    """CRUD helpers for ``report_output_runs`` rows (one row per run of a report)."""

    def _serialize_json(
        self, value: dict[str, Any] | list[Any] | str | None
    ) -> str | None:
        """Return *value* as a JSON string; pass str/None through unchanged."""
        if value is None or isinstance(value, str):
            return value
        return json.dumps(value)

    def _parse_run_row(self, row: Row | dict | None) -> dict | None:
        """Convert a DB row to a dict, decoding the spec-snapshot JSON column."""
        if row is None:
            return None

        run = dict(row)
        if isinstance(run.get("report_spec_snapshot_json"), str):
            run["report_spec_snapshot_json"] = json.loads(
                run["report_spec_snapshot_json"]
            )
        return run

    def create_report_output_run(
        self,
        report_output_id: int,
        status: str = "pending",
        trigger_type: str = "initial",
        output: dict[str, Any] | list[Any] | str | None = None,
        error_message: str | None = None,
        source_run_id: str | None = None,
        report_spec_snapshot: dict[str, Any] | str | None = None,
        version_manifest: dict[str, str | None] | None = None,
        run_id: str | None = None,
    ) -> dict:
        """Insert a new run row for *report_output_id* and return it as a dict.

        The run is given the next per-report ``run_sequence``, assigned inside
        a transaction so concurrent creators cannot collide. Raises ValueError
        when the parent report output does not exist.
        """
        run_id = run_id or str(uuid.uuid4())
        version_manifest = version_manifest or {}
        # sqlite has no row locks; on the remote DB, lock the parent row so
        # the MAX(run_sequence) read below is race-free.
        lock_clause = "" if database.local else " FOR UPDATE"

        def create_run_transaction(tx) -> None:
            # Validate (and, remotely, lock) the parent report output.
            parent_row: Row | None = tx.query(
                f"SELECT id FROM report_outputs WHERE id = ?{lock_clause}",
                (report_output_id,),
            ).fetchone()
            if parent_row is None:
                raise ValueError(f"Report output #{report_output_id} not found")

            run_sequence_row: Row | None = tx.query(
                """
                SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence
                FROM report_output_runs
                WHERE report_output_id = ?
                """,
                (report_output_id,),
            ).fetchone()
            # An aggregate always yields a row; the None branch is defensive.
            run_sequence = (
                int(run_sequence_row["max_run_sequence"]) + 1
                if run_sequence_row is not None
                else 1
            )

            # 12 named columns + len(REPORT_RUN_VERSION_FIELDS) (9) = 21 "?"s.
            tx.query(
                f"""
                INSERT INTO report_output_runs (
                    id, report_output_id, run_sequence, status, output, error_message,
                    trigger_type, requested_at, started_at, finished_at, source_run_id,
                    report_spec_snapshot_json, {", ".join(REPORT_RUN_VERSION_FIELDS)}
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    run_id,
                    report_output_id,
                    run_sequence,
                    status,
                    self._serialize_json(output),
                    error_message,
                    trigger_type,
                    # requested_at/started_at/finished_at start NULL —
                    # presumably filled in by later status transitions; confirm.
                    None,
                    None,
                    None,
                    source_run_id,
                    self._serialize_json(report_spec_snapshot),
                    *[
                        version_manifest.get(field)
                        for field in REPORT_RUN_VERSION_FIELDS
                    ],
                ),
            )

        database.transaction(create_run_transaction)
        return self.get_report_output_run(run_id)

    def get_report_output_run(self, run_id: str) -> dict | None:
        """Fetch a single run by its UUID primary key, or None."""
        row: Row | None = database.query(
            "SELECT * FROM report_output_runs WHERE id = ?",
            (run_id,),
        ).fetchone()
        return self._parse_run_row(row)

    def list_report_output_runs(self, report_output_id: int) -> list[dict]:
        """Return all runs for a report output, oldest first."""
        rows = database.query(
            """
            SELECT * FROM report_output_runs
            WHERE report_output_id = ?
            ORDER BY run_sequence ASC
            """,
            (report_output_id,),
        ).fetchall()
        return [self._parse_run_row(row) for row in rows]

    def get_newest_report_output_run(self, report_output_id: int) -> dict | None:
        """Return the run with the highest run_sequence, or None if none exist."""
        row: Row | None = database.query(
            """
            SELECT * FROM report_output_runs
            WHERE report_output_id = ?
            ORDER BY run_sequence DESC
            LIMIT 1
            """,
            (report_output_id,),
        ).fetchone()
        return self._parse_run_row(row)

    def select_display_run(self, report_output: dict) -> dict | None:
        """Pick the run to display: active, else latest successful, else newest.

        *report_output* is a report_outputs row as a dict; missing or dangling
        run references fall through to the next preference.
        """
        if report_output.get("active_run_id"):
            active_run = self.get_report_output_run(report_output["active_run_id"])
            if active_run is not None:
                return active_run
        if report_output.get("latest_successful_run_id"):
            latest_successful_run = self.get_report_output_run(
                report_output["latest_successful_run_id"]
            )
            if latest_successful_run is not None:
                return latest_successful_run
        return self.get_newest_report_output_run(report_output["id"])
Loading
Loading