Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/managed-release-runtime.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added managed release-bundle runtime enforcement for bundled US and UK microsimulations, including manifest-backed dataset pinning and runtime bundle metadata.
10 changes: 3 additions & 7 deletions src/policyengine/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
)
from .release_manifest import get_data_release_manifest as get_data_release_manifest
from .release_manifest import get_release_manifest as get_release_manifest
from .release_manifest import (
resolve_managed_dataset_reference as resolve_managed_dataset_reference,
)
from .scoping_strategy import RegionScopingStrategy as RegionScopingStrategy
from .scoping_strategy import RowFilterStrategy as RowFilterStrategy
from .scoping_strategy import ScopingStrategy as ScopingStrategy
Expand All @@ -36,13 +39,6 @@
from .tax_benefit_model_version import (
TaxBenefitModelVersion as TaxBenefitModelVersion,
)
from .trace_tro import (
build_trace_tro_from_release_bundle as build_trace_tro_from_release_bundle,
)
from .trace_tro import (
compute_trace_composition_fingerprint as compute_trace_composition_fingerprint,
)
from .trace_tro import serialize_trace_tro as serialize_trace_tro
from .variable import Variable as Variable

# Rebuild models to resolve forward references
Expand Down
158 changes: 93 additions & 65 deletions src/policyengine/core/release_manifest.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import os
from functools import lru_cache
from importlib import import_module, metadata
from importlib import import_module
from importlib.resources import files
from pathlib import Path

import requests
from pydantic import BaseModel, Field

HF_REQUEST_TIMEOUT_SECONDS = 30


class DataReleaseManifestUnavailable(ValueError):
    """Signal that the data release manifest could not be obtained.

    Subclasses ``ValueError`` so callers that already catch ``ValueError``
    keep working, while callers that want to fall back to bundled metadata
    can catch this narrower type.
    """
# Hints for locating a local checkout of a country's data repository.
# Keys are country ids; each value is a 3-tuple of
# (model module name, data repo directory name, data package name).
LOCAL_DATA_REPO_HINTS = {
"us": ("policyengine_us", "policyengine-us-data", "policyengine_us_data"),
"uk": ("policyengine_uk", "policyengine-uk-data", "policyengine_uk_data"),
}


class PackageVersion(BaseModel):
Expand Down Expand Up @@ -126,28 +126,6 @@ def build_hf_uri(repo_id: str, path_in_repo: str, revision: str) -> str:
return f"hf://{repo_id}/{path_in_repo}@{revision}"


def get_runtime_model_build_metadata(package_name: str) -> dict[str, str | None]:
installed_version = metadata.version(package_name)
module_name = package_name.replace("-", "_")

try:
build_metadata_module = import_module(f"{module_name}.build_metadata")
except Exception:
return {
"name": package_name,
"version": installed_version,
"git_sha": None,
"data_build_fingerprint": None,
}

build_metadata = build_metadata_module.get_data_build_metadata()
build_metadata.setdefault("name", package_name)
build_metadata.setdefault("version", installed_version)
build_metadata.setdefault("git_sha", None)
build_metadata.setdefault("data_build_fingerprint", None)
return build_metadata


@lru_cache
def get_release_manifest(country_id: str) -> CountryReleaseManifest:
manifest_path = files("policyengine").joinpath(
Expand Down Expand Up @@ -183,15 +161,10 @@ def get_data_release_manifest(country_id: str) -> DataReleaseManifest:
timeout=HF_REQUEST_TIMEOUT_SECONDS,
)
if response.status_code in (401, 403):
raise DataReleaseManifestUnavailable(
raise ValueError(
"Could not fetch the data release manifest from Hugging Face. "
"If this country uses a private data repo, set HUGGING_FACE_TOKEN."
)
if response.status_code == 404:
raise DataReleaseManifestUnavailable(
"Could not find the data release manifest on Hugging Face for "
f"{data_package.repo_id}@{data_package.version}."
)
response.raise_for_status()
return DataReleaseManifest.model_validate_json(response.text)

Expand All @@ -208,7 +181,17 @@ def certify_data_release_compatibility(
runtime_data_build_fingerprint: str | None = None,
) -> DataCertification:
country_manifest = get_release_manifest(country_id)
data_release_manifest = get_data_release_manifest(country_id)
try:
data_release_manifest = get_data_release_manifest(country_id)
except Exception as exc:
bundled_certification = country_manifest.certification
if (
bundled_certification is not None
and bundled_certification.certified_for_model_version
== runtime_model_version
):
return bundled_certification
raise exc
built_with_model = (
data_release_manifest.build.built_with_model_package
if data_release_manifest.build is not None
Expand Down Expand Up @@ -295,37 +278,6 @@ def certify_data_release_compatibility(
)


def resolve_runtime_data_certification(
    country_id: str,
    runtime_model_version: str,
    runtime_data_build_fingerprint: str | None = None,
    bundled_certification: DataCertification | None = None,
) -> DataCertification:
    """Certify the runtime against the data release, with a bundled fallback.

    The live check via ``certify_data_release_compatibility`` is attempted
    first. Only when it raises ``DataReleaseManifestUnavailable`` is the
    bundled certification considered, and it is accepted only if:

    - it exists and was certified for ``runtime_model_version``, and
    - it does not rest on a matching data-build fingerprint that disagrees
      with the runtime fingerprint (checked only when both are known).

    In every other case the original ``DataReleaseManifestUnavailable`` is
    re-raised unchanged.
    """
    try:
        return certify_data_release_compatibility(
            country_id=country_id,
            runtime_model_version=runtime_model_version,
            runtime_data_build_fingerprint=runtime_data_build_fingerprint,
        )
    except DataReleaseManifestUnavailable:
        certified_for_runtime = (
            bundled_certification is not None
            and bundled_certification.certified_for_model_version
            == runtime_model_version
        )
        if not certified_for_runtime:
            raise

        bundled_fingerprint = bundled_certification.data_build_fingerprint
        fingerprint_conflict = (
            bundled_certification.compatibility_basis
            == "matching_data_build_fingerprint"
            and bundled_fingerprint is not None
            and runtime_data_build_fingerprint is not None
            and bundled_fingerprint != runtime_data_build_fingerprint
        )
        if fingerprint_conflict:
            # The bundled certificate vouches for a different data build.
            raise
        return bundled_certification


def resolve_dataset_reference(country_id: str, dataset: str) -> str:
if "://" in dataset:
return dataset
Expand All @@ -350,6 +302,82 @@ def resolve_dataset_reference(country_id: str, dataset: str) -> str:
return artifact.uri


def resolve_managed_dataset_reference(
    country_id: str,
    dataset: str | None = None,
    *,
    allow_unmanaged: bool = False,
) -> str:
    """Resolve a dataset reference under policyengine.py bundle enforcement.

    In managed mode the dataset choice is pinned to the bundled
    `policyengine.py` release manifest:

    - `dataset=None` yields the manifest's default dataset URI
    - a logical dataset name (no `://`) is resolved through
      `resolve_dataset_reference`
    - an explicit URI is returned unchanged only when it equals the
      manifest's default dataset URI or `allow_unmanaged=True` is passed

    Raises:
        ValueError: If an explicit URI other than the bundled default is
            given without `allow_unmanaged=True`.
    """

    bundle = get_release_manifest(country_id)

    if dataset is None:
        return bundle.default_dataset_uri

    # Logical names carry no scheme and are looked up in the manifests.
    if "://" not in dataset:
        return resolve_dataset_reference(country_id, dataset)

    if dataset == bundle.default_dataset_uri or allow_unmanaged:
        return dataset

    raise ValueError(
        "Explicit dataset URIs bypass the policyengine.py release bundle. "
        "Pass a manifest dataset name or omit `dataset` to use the certified "
        "default dataset. Set `allow_unmanaged=True` only if you intend to "
        "bypass bundle enforcement."
    )


def resolve_local_managed_dataset_source(country_id: str, dataset_uri: str) -> str:
    """Resolve a local mirror of a managed dataset when available.

    This preserves the bundled dataset URI for provenance while allowing local
    development environments with sibling data-repo checkouts to load the
    exact certified artifact from disk rather than re-downloading it.

    The lookup is strictly best-effort: whenever a local file cannot be
    identified, `dataset_uri` is returned unchanged instead of raising. That
    covers a non-`hf://` URI, a country without an entry in
    `LOCAL_DATA_REPO_HINTS`, a URI without an `owner/repo/path` shape, a model
    package that is not importable or has no usable `__file__`, and a
    candidate path that does not exist.

    Returns:
        The local filesystem path as a string when the mirrored artifact
        exists, otherwise `dataset_uri` itself.
    """

    scheme = "hf://"
    if not dataset_uri.startswith(scheme):
        return dataset_uri

    local_hint = LOCAL_DATA_REPO_HINTS.get(country_id)
    if local_hint is None:
        return dataset_uri

    # Drop the scheme and the trailing "@revision", then split off the
    # two leading repo-id segments so only the in-repo path remains.
    path_without_revision = dataset_uri.removeprefix(scheme).rsplit("@", 1)[0]
    parts = path_without_revision.split("/", 2)
    if len(parts) != 3:
        return dataset_uri
    path_in_repo = parts[2]

    model_module_name, data_repo_name, data_package_name = local_hint
    try:
        model_module = import_module(model_module_name)
    except ImportError:
        return dataset_uri

    # Namespace packages have no usable __file__; there is nothing to anchor
    # a sibling-checkout lookup on, so keep the remote URI.
    module_file = getattr(model_module, "__file__", None)
    if module_file is None:
        return dataset_uri

    ancestors = Path(module_file).resolve().parents
    if len(ancestors) < 2:
        return dataset_uri
    repo_root = ancestors[1]
    if not repo_root.name:
        # The filesystem root has no name, so it cannot have a sibling.
        return dataset_uri

    local_path = (
        repo_root.with_name(data_repo_name)
        / data_package_name
        / "storage"
        / path_in_repo
    )
    if local_path.exists():
        return str(local_path)
    return dataset_uri


def dataset_logical_name(dataset: str) -> str:
    """Return the logical name of a dataset reference.

    Everything after the last ``@`` (if any) is dropped, and the final path
    component of what remains is returned without its last suffix.
    """
    reference, separator, _revision = dataset.rpartition("@")
    if not separator:
        # No "@" present: the whole string is the reference.
        reference = dataset
    return Path(reference).stem

Expand Down
24 changes: 1 addition & 23 deletions src/policyengine/core/tax_benefit_model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,8 @@

from pydantic import BaseModel, Field

from .release_manifest import (
CountryReleaseManifest,
DataCertification,
PackageVersion,
get_data_release_manifest,
)
from .release_manifest import CountryReleaseManifest, DataCertification, PackageVersion
from .tax_benefit_model import TaxBenefitModel
from .trace_tro import build_trace_tro_from_release_bundle

if TYPE_CHECKING:
from .parameter import Parameter
Expand Down Expand Up @@ -207,22 +201,6 @@ def release_bundle(self) -> dict[str, str | None]:
),
}

@property
def trace_tro(self) -> dict:
    """Build the TRACE TRO for this model version's release bundle.

    Fetches the data release manifest for the bundled manifest's country
    and passes both manifests, along with this version's data
    certification, to ``build_trace_tro_from_release_bundle``.

    Raises:
        ValueError: If no country release manifest is attached.
    """
    country_manifest = self.release_manifest
    if country_manifest is None:
        raise ValueError(
            "TRACE TRO export requires a bundled country release manifest."
        )

    return build_trace_tro_from_release_bundle(
        country_manifest,
        get_data_release_manifest(country_manifest.country_id),
        certification=self.data_certification,
    )

def __repr__(self) -> str:
# Give the id and version, and the number of variables, parameters, parameter nodes, parameter values
return f"<TaxBenefitModelVersion id={self.id} variables={len(self.variables)} parameters={len(self.parameters)} parameter_nodes={len(self.parameter_nodes)} parameter_values={len(self.parameter_values)}>"
2 changes: 2 additions & 0 deletions src/policyengine/outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
UK_INEQUALITY_INCOME_VARIABLE,
US_INEQUALITY_INCOME_VARIABLE,
Inequality,
USInequalityPreset,
calculate_uk_inequality,
calculate_us_inequality,
)
Expand Down Expand Up @@ -76,6 +77,7 @@
"GENDER_GROUPS",
"RACE_GROUPS",
"Inequality",
"USInequalityPreset",
"UK_INEQUALITY_INCOME_VARIABLE",
"US_INEQUALITY_INCOME_VARIABLE",
"calculate_uk_inequality",
Expand Down
Loading
Loading