diff --git a/docs/inference_workflows.rst b/docs/inference_workflows.rst
index c23862ee..2ddda6ea 100644
--- a/docs/inference_workflows.rst
+++ b/docs/inference_workflows.rst
@@ -304,6 +304,22 @@ For large particle counts, configure optional IFU accumulation controls:
 
 These settings are used by the particlewise IFU builders in ``rubix.core.ifu``.
 
+
+Synthetic Science Recipe
+------------------------
+
+Run an end-to-end synthetic workflow (optimize -> VI -> posterior predictive ->
+residual metrics) and persist science-ready outputs:
+
+.. code-block:: bash
+
+    python scripts/run_synthetic_science_recipe.py \
+      --output-dir outputs/science_recipe \
+      --nx 8 --ny 8 --nw 64 \
+      --optimize-steps 200 \
+      --vi-steps 200 \
+      --num-posterior-draws 16
+
 Benchmarking Full-IFU Optimization
 ----------------------------------
 
diff --git a/docs/rubix.inference.rst b/docs/rubix.inference.rst
index 2eb63867..a9e51d17 100644
--- a/docs/rubix.inference.rst
+++ b/docs/rubix.inference.rst
@@ -116,6 +116,14 @@ rubix.inference.vi_benchmark module
    :undoc-members:
    :show-inheritance:
 
+rubix.inference.workflows module
+--------------------------------
+
+.. automodule:: rubix.inference.workflows
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 Module contents
 ---------------
 
diff --git a/rubix/inference/__init__.py b/rubix/inference/__init__.py
index b1c26c64..6e0eacb1 100644
--- a/rubix/inference/__init__.py
+++ b/rubix/inference/__init__.py
@@ -63,6 +63,7 @@
     benchmark_variational_inference,
     vi_benchmark_result_to_dict,
 )
+from .workflows import run_synthetic_science_recipe, save_science_recipe_outputs
 
 __all__ = [
     "IdentityTransform",
@@ -122,4 +123,6 @@
     "summarize_predictive_cube_samples",
     "save_checkpoint",
     "value_and_grad",
+    "run_synthetic_science_recipe",
+    "save_science_recipe_outputs",
 ]
diff --git a/rubix/inference/workflows.py b/rubix/inference/workflows.py
new file mode 100644
index 00000000..98b2e154
--- /dev/null
+++ b/rubix/inference/workflows.py
@@ -0,0 +1,228 @@
+"""Synthetic end-to-end science recipe workflow and output persistence."""
+
+from __future__ import annotations
+
+import json
+import os
+from pathlib import Path
+from typing import Any
+
+import jax.numpy as jnp
+import numpy as np
+
+from rubix.core.data import Galaxy, GasData, RubixData, StarsData
+
+from .optimize import optimize_ifu_cube
+from .posterior_predictive import (
+    compute_residual_products,
+    sample_posterior_predictive_cubes,
+    summarize_masked_metrics,
+    summarize_predictive_cube_samples,
+)
+from .variational import optimize_variational_ifu_cube
+
+
+class SyntheticScalePipeline:
+    """Simple synthetic IFU pipeline used by the science recipe workflow.
+
+    The predicted cube is a fixed template multiplied by a single scalar
+    read from ``rubixdata.stars.age[0]``.
+    """
+
+    def __init__(self, template: jnp.ndarray):
+        self.template = template
+
+    def run_sharded(self, rubixdata: RubixData) -> jnp.ndarray:
+        """Return the template scaled by the first stellar age entry."""
+        scale = rubixdata.stars.age[0]
+        return scale * self.template
+
+
+def _make_static_data() -> RubixData:
+    """Build a minimal ``RubixData`` with single-entry star and gas arrays."""
+    return RubixData(
+        galaxy=Galaxy(),
+        stars=StarsData(
+            coords=jnp.zeros((1, 3)),
+            velocity=jnp.zeros((1, 3)),
+            mass=jnp.ones(1),
+            age=jnp.array([0.0]),
+            metallicity=jnp.array([0.01]),
+        ),
+        gas=GasData(
+            coords=jnp.zeros((1, 3)),
+            velocity=jnp.zeros((1, 3)),
+            mass=jnp.ones(1),
+        ),
+    )
+
+
+def _to_jsonable(value: Any) -> Any:
+    """Convert nested JAX/NumPy containers into JSON-serializable values.
+
+    Dicts, lists, and tuples are converted recursively (tuples become lists).
+    NumPy and JAX arrays become Python scalars (0-d) or nested lists, and
+    NumPy scalars become Python scalars. Anything else is returned unchanged.
+    """
+    if isinstance(value, dict):
+        # ``json.dumps`` rejects keys such as ``np.int64``, so unwrap NumPy
+        # scalar keys to their Python equivalents as well.
+        return {
+            (k.item() if isinstance(k, np.generic) else k): _to_jsonable(v)
+            for k, v in value.items()
+        }
+    if isinstance(value, (list, tuple)):
+        return [_to_jsonable(v) for v in value]
+    if isinstance(value, (np.ndarray, jnp.ndarray)):
+        arr = np.asarray(value)
+        return arr.item() if arr.ndim == 0 else arr.tolist()
+    if isinstance(value, np.generic):
+        return value.item()
+    return value
+
+
+def run_synthetic_science_recipe(
+    cube_shape: tuple[int, int, int] = (4, 4, 16),
+    target_scale: float = 1.7,
+    optimize_steps: int = 120,
+    vi_steps: int = 120,
+    num_vi_samples: int = 4,
+    num_posterior_draws: int = 8,
+    seed: int = 0,
+) -> dict[str, Any]:
+    """Run a compact end-to-end synthetic science workflow.
+
+    This workflow performs deterministic optimization, variational inference,
+    posterior predictive sampling, and residual/metric summarization.
+
+    Args:
+        cube_shape (tuple[int, int, int], optional): Synthetic IFU cube shape
+            ``(nx, ny, nw)``. Defaults to ``(4, 4, 16)``.
+        target_scale (float, optional): Multiplicative scale for the synthetic
+            target cube. Defaults to 1.7.
+        optimize_steps (int, optional): Maximum deterministic optimization
+            steps. Defaults to 120.
+        vi_steps (int, optional): Maximum variational optimization steps.
+            Defaults to 120.
+        num_vi_samples (int, optional): Monte Carlo samples per VI step.
+            Defaults to 4.
+        num_posterior_draws (int, optional): Number of posterior predictive
+            cube draws. Defaults to 8.
+        seed (int, optional): Random seed for VI and predictive sampling.
+            Defaults to 0.
+
+    Returns:
+        dict[str, Any]: Workflow outputs and diagnostic summaries.
+    """
+    template = jnp.ones(cube_shape, dtype=jnp.float32)
+    target = target_scale * template
+
+    pipeline = SyntheticScalePipeline(template)
+    static_data = _make_static_data()
+    params_init = {"stars": {"age": jnp.array([0.2])}}
+
+    opt_result = optimize_ifu_cube(
+        pipeline=pipeline,
+        params_init=params_init,
+        static_data=static_data,
+        target=target,
+        learning_rate=0.1,
+        max_steps=optimize_steps,
+        tol=1e-8,
+    )
+
+    vi_result = optimize_variational_ifu_cube(
+        pipeline=pipeline,
+        params_init=params_init,
+        static_data=static_data,
+        target=target,
+        sigma=jnp.ones_like(target),
+        learning_rate=5e-2,
+        max_steps=vi_steps,
+        tol=1e-8,
+        num_samples=num_vi_samples,
+        beta_kl=1e-4,
+        seed=seed,
+    )
+
+    predictive_samples = sample_posterior_predictive_cubes(
+        pipeline=pipeline,
+        posterior_mean_params=vi_result.posterior_mean_params,
+        posterior_log_std_params=vi_result.posterior_log_std_params,
+        static_data=static_data,
+        num_samples=num_posterior_draws,
+        # Offset the seed so the predictive draws use a different seed than VI.
+        seed=seed + 1,
+    )
+    predictive_summary = summarize_predictive_cube_samples(predictive_samples)
+    residual_products = compute_residual_products(
+        prediction=predictive_summary["mean"],
+        target=target,
+    )
+    metrics = summarize_masked_metrics(
+        prediction=predictive_summary["mean"],
+        target=target,
+    )
+
+    return {
+        "config": {
+            "cube_shape": cube_shape,
+            "target_scale": target_scale,
+            "optimize_steps": optimize_steps,
+            "vi_steps": vi_steps,
+            "num_vi_samples": num_vi_samples,
+            "num_posterior_draws": num_posterior_draws,
+            "seed": seed,
+        },
+        "optimization": {
+            "final_loss": opt_result.final_loss,
+            "best_loss": opt_result.best_loss,
+            "steps_run": opt_result.steps_run,
+            "converged": opt_result.converged,
+        },
+        "variational": {
+            "final_objective": vi_result.final_objective,
+            "best_objective": vi_result.best_objective,
+            "steps_run": vi_result.steps_run,
+            "converged": vi_result.converged,
+        },
+        "predictive_summary": predictive_summary,
+        "residual_products": residual_products,
+        "metrics": metrics,
+    }
+
+
+def save_science_recipe_outputs(
+    outputs: dict[str, Any],
+    output_dir: str | os.PathLike[str],
+) -> None:
+    """Persist workflow outputs to JSON and NPZ files.
+
+    Writes ``summary.json`` (the config, optimization, variational, and
+    metrics entries), ``predictive_summary.npz``, and
+    ``residual_products.npz`` into ``output_dir``, creating it if needed.
+
+    Args:
+        outputs (dict[str, Any]): Outputs from
+            :func:`run_synthetic_science_recipe`.
+        output_dir (str | os.PathLike[str]): Destination directory.
+    """
+    out_dir = Path(output_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    summary = {
+        "config": outputs["config"],
+        "optimization": outputs["optimization"],
+        "variational": outputs["variational"],
+        "metrics": outputs["metrics"],
+    }
+    json_summary = _to_jsonable(summary)
+    (out_dir / "summary.json").write_text(
+        json.dumps(json_summary, indent=2), encoding="utf-8"
+    )
+
+    predictive_np = {k: np.asarray(v) for k, v in outputs["predictive_summary"].items()}
+    residual_np = {k: np.asarray(v) for k, v in outputs["residual_products"].items()}
+
+    np.savez(out_dir / "predictive_summary.npz", **predictive_np)
+    np.savez(out_dir / "residual_products.npz", **residual_np)
diff --git a/scripts/run_synthetic_science_recipe.py b/scripts/run_synthetic_science_recipe.py
new file mode 100644
index 00000000..6d807925
--- /dev/null
+++ b/scripts/run_synthetic_science_recipe.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python
+"""Command-line entry point for the synthetic science recipe workflow."""
+
+from __future__ import annotations
+
+import argparse
+from collections.abc import Sequence
+
+from rubix.inference.workflows import (
+    run_synthetic_science_recipe,
+    save_science_recipe_outputs,
+)
+
+
+def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
+    """Parse command-line options for the synthetic science recipe.
+
+    Args:
+        argv (Sequence[str] | None, optional): Arguments to parse. Defaults
+            to ``None``, in which case ``argparse`` reads ``sys.argv[1:]``.
+
+    Returns:
+        argparse.Namespace: Parsed command-line options.
+    """
+    parser = argparse.ArgumentParser(
+        description="Run synthetic end-to-end science workflow and save outputs."
+    )
+    parser.add_argument("--output-dir", type=str, default="outputs/science_recipe")
+    parser.add_argument("--nx", type=int, default=4)
+    parser.add_argument("--ny", type=int, default=4)
+    parser.add_argument("--nw", type=int, default=16)
+    parser.add_argument("--target-scale", type=float, default=1.7)
+    parser.add_argument("--optimize-steps", type=int, default=120)
+    parser.add_argument("--vi-steps", type=int, default=120)
+    parser.add_argument("--num-vi-samples", type=int, default=4)
+    parser.add_argument("--num-posterior-draws", type=int, default=8)
+    parser.add_argument("--seed", type=int, default=0)
+    return parser.parse_args(argv)
+
+
+def main(argv: Sequence[str] | None = None) -> None:
+    """Run the synthetic science recipe and write its outputs to disk.
+
+    Args:
+        argv (Sequence[str] | None, optional): Arguments forwarded to
+            :func:`parse_args`. Defaults to ``None``.
+    """
+    args = parse_args(argv)
+    outputs = run_synthetic_science_recipe(
+        cube_shape=(args.nx, args.ny, args.nw),
+        target_scale=args.target_scale,
+        optimize_steps=args.optimize_steps,
+        vi_steps=args.vi_steps,
+        num_vi_samples=args.num_vi_samples,
+        num_posterior_draws=args.num_posterior_draws,
+        seed=args.seed,
+    )
+    save_science_recipe_outputs(outputs, args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/test_inference_workflows.py b/tests/test_inference_workflows.py
new file mode 100644
index 00000000..ff409d21
--- /dev/null
+++ b/tests/test_inference_workflows.py
@@ -0,0 +1,71 @@
+import json
+
+import numpy as np
+
+from rubix.inference.workflows import (
+    _to_jsonable,
+    run_synthetic_science_recipe,
+    save_science_recipe_outputs,
+)
+
+
+def test_run_synthetic_science_recipe_returns_expected_structure():
+    outputs = run_synthetic_science_recipe(
+        cube_shape=(2, 2, 4),
+        target_scale=1.5,
+        optimize_steps=20,
+        vi_steps=20,
+        num_vi_samples=2,
+        num_posterior_draws=4,
+        seed=0,
+    )
+
+    assert "config" in outputs
+    assert "optimization" in outputs
+    assert "variational" in outputs
+    assert "predictive_summary" in outputs
+    assert "residual_products" in outputs
+    assert "metrics" in outputs
+
+    assert outputs["predictive_summary"]["mean"].shape == (2, 2, 4)
+    assert outputs["residual_products"]["residual"].shape == (2, 2, 4)
+    assert outputs["metrics"]["mse"] >= 0.0
+    assert outputs["metrics"]["mae"] >= 0.0
+
+
+def test_save_science_recipe_outputs_writes_json_and_npz(tmp_path):
+    outputs = run_synthetic_science_recipe(
+        cube_shape=(2, 2, 4),
+        target_scale=1.2,
+        optimize_steps=10,
+        vi_steps=10,
+        num_vi_samples=2,
+        num_posterior_draws=3,
+        seed=1,
+    )
+
+    save_science_recipe_outputs(outputs, str(tmp_path))
+
+    summary_path = tmp_path / "summary.json"
+    predictive_path = tmp_path / "predictive_summary.npz"
+    residual_path = tmp_path / "residual_products.npz"
+
+    assert summary_path.exists()
+    assert predictive_path.exists()
+    assert residual_path.exists()
+
+    summary = json.loads(summary_path.read_text(encoding="utf-8"))
+    assert summary["config"]["cube_shape"] == [2, 2, 4]
+    assert "final_loss" in summary["optimization"]
+    assert "final_objective" in summary["variational"]
+
+    predictive = np.load(predictive_path)
+    residual = np.load(residual_path)
+
+    assert predictive["mean"].shape == (2, 2, 4)
+    assert residual["residual"].shape == (2, 2, 4)
+
+
+def test_to_jsonable_unwraps_numpy_scalar_keys():
+    payload = _to_jsonable({np.int64(3): np.float32(1.5)})
+    assert json.dumps(payload) == '{"3": 1.5}'