Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 94 additions & 2 deletions src/microplex_us/pipelines/stage_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import traceback
from collections.abc import Mapping
from dataclasses import replace
from dataclasses import fields, replace
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Literal
Expand Down Expand Up @@ -281,14 +281,31 @@ def finalize_from_artifact_manifest(
self.manifest_payload,
):
existing = self._stage_payload(outputs.stage_id)
outputs = _rehydrate_outputs_from_stage_payload(outputs, existing)
if (
_terminal_lifecycle(existing) == "failed"
and not outputs.complete
and not outputs.missing_required_outputs(self.artifact_root)
):
outputs = replace(
outputs,
complete=True,
lifecycle_status="complete",
deferred_reason=None,
failure=None,
)
now = _now()
existing_events = tuple(
dict(event)
for event in existing.get("events", ())
if isinstance(event, dict)
)
existing_lifecycle = _terminal_lifecycle(existing)
if existing_lifecycle is not None:
preserve_existing_lifecycle = existing_lifecycle in {
"complete",
"deferred",
} or (existing_lifecycle == "failed" and not outputs.complete)
if preserve_existing_lifecycle:
lifecycle_status = existing_lifecycle
complete = bool(existing.get("complete"))
started_at = _optional_str(existing.get("startedAt"))
Expand Down Expand Up @@ -582,6 +599,81 @@ def _final_lifecycle_status(
return "complete" if outputs.complete else "pending"


def _rehydrate_outputs_from_stage_payload(
outputs: USStageOutputManifest,
payload: Mapping[str, Any],
) -> USStageOutputManifest:
serialized_outputs = payload.get("outputs")
if not isinstance(serialized_outputs, Mapping):
return outputs

hydrated: dict[str, Any] = {}
for item in fields(outputs):
name = item.name
if name in {
"schema_version",
"contract_version",
"input_stage_manifest",
"diagnostics",
"auxiliary_artifacts",
"metadata",
"complete",
"lifecycle_status",
"started_at",
"updated_at",
"completed_at",
"failed_at",
"deferred_reason",
"failure",
"events",
"stage_id",
}:
continue
if name not in serialized_outputs:
continue
current = getattr(outputs, name)
if not _typed_output_is_missing(current):
continue
value = _deserialize_stage_output_field(serialized_outputs[name])
if not _typed_output_is_missing(value):
hydrated[name] = value
if not hydrated:
return outputs
return replace(outputs, **hydrated)


def _deserialize_stage_output_field(value: Any) -> Any:
if isinstance(value, Mapping):
if "path" in value and "key" in value:
return USArtifactRef(
key=str(value["key"]),
path=str(value["path"]),
format=value.get("format", "unknown"),
required=bool(value.get("required", False)),
category=value.get("category", "required_output"),
resume_role=value.get("resume_role"),
assume_exists=bool(value.get("assume_exists", False)),
exists=(
value.get("exists")
if isinstance(value.get("exists"), bool)
else None
),
)
return value


def _typed_output_is_missing(value: Any) -> bool:
if value is None:
return True
if isinstance(value, Mapping):
return not bool(value)
if isinstance(value, (tuple, list, set, frozenset)):
return not bool(value)
if isinstance(value, str):
return not value
return False


def _terminal_lifecycle(
payload: Mapping[str, Any],
) -> USStageLifecycleStatus | None:
Expand Down
93 changes: 93 additions & 0 deletions tests/pipelines/test_stage_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import pytest
from microplex.core import RelationshipCardinality

import microplex_us.pipelines.stage_runtime as stage_runtime_module
from microplex_us.pipelines.stage_contracts import US_STAGE_CONTRACT_VERSION
from microplex_us.pipelines.stage_run import (
USArtifactRef,
USDiagnosticOutput,
USRunProfileOutputs,
USSourceLoadingOutputs,
USStageInputOverride,
USValidationBenchmarkingOutputs,
)
from microplex_us.pipelines.stage_runtime import USStageRuntimeWriter

Expand Down Expand Up @@ -130,6 +132,97 @@ def test_runtime_writer_finalize_preserves_completed_stage_lifecycle(tmp_path):
assert after["events"] == before["events"]


def test_runtime_writer_finalize_rehydrates_failed_stage_outputs(
tmp_path,
monkeypatch,
):
previous_stage = (
tmp_path / "stage_artifacts" / "manifests" / "08_dataset_assembly.json"
)
previous_stage.parent.mkdir(parents=True, exist_ok=True)
previous_stage.write_text(
json.dumps(
{
"stageId": "08_dataset_assembly",
"outputs": {
"policyengine_dataset": {
"key": "policyengine_dataset",
"path": "policyengine_us.h5",
"exists": True,
}
},
}
)
)
(tmp_path / "policyengine_us.h5").write_text("{}")
evidence_path = (
tmp_path
/ "stage_artifacts"
/ "09_validation_benchmarking"
/ "evidence_manifest.json"
)
evidence_path.parent.mkdir(parents=True, exist_ok=True)
evidence_path.write_text("{}")

writer = USStageRuntimeWriter(tmp_path)
writer.record_output(
"09_validation_benchmarking",
"validation_evidence",
USArtifactRef(
key="validation_evidence",
path=evidence_path.relative_to(tmp_path),
format="json",
required=True,
assume_exists=True,
),
)
writer.record_output(
"09_validation_benchmarking",
"benchmark_summary",
{"policyengine_native_scores": {"loss_delta": -0.1}},
)
writer.fail_stage("09_validation_benchmarking", ValueError("finalize failed"))

def _empty_rebuilt_stage_outputs(*_args, **_kwargs):
return (
USValidationBenchmarkingOutputs(
diagnostics=_diagnostics("09_validation_benchmarking"),
complete=False,
),
)

monkeypatch.setattr(
stage_runtime_module,
"build_us_stage_output_manifests_from_artifact_manifest",
_empty_rebuilt_stage_outputs,
)

writer.finalize_from_artifact_manifest(
{
"artifacts": {"policyengine_dataset": "policyengine_us.h5"},
"config": {"calibration_backend": "microcalibrate"},
}
)

stage9 = json.loads(
(
tmp_path
/ "stage_artifacts"
/ "manifests"
/ "09_validation_benchmarking.json"
).read_text()
)
assert stage9["complete"] is True
assert stage9["lifecycleStatus"] == "complete"
assert stage9["failedAt"] is None
assert stage9["failure"] is None
assert stage9["missingRequiredOutputs"] == []
assert stage9["outputs"]["benchmark_summary"] == {
"policyengine_native_scores": {"loss_delta": -0.1}
}
assert stage9["outputs"]["validation_evidence"]["exists"] is True


def test_runtime_writer_serializes_enum_outputs(tmp_path):
writer = USStageRuntimeWriter(
tmp_path,
Expand Down
Loading