diff --git a/.github/workflows/test_petab_sciml.yml b/.github/workflows/test_petab_sciml.yml index 5c5c962d35..64aeac77b8 100644 --- a/.github/workflows/test_petab_sciml.yml +++ b/.github/workflows/test_petab_sciml.yml @@ -1,14 +1,14 @@ name: PEtab SciML -on: - push: - branches: - - main - - 'release*' - pull_request: - branches: - - main - merge_group: - workflow_dispatch: +# on: +# push: +# branches: +# - main +# - 'release*' +# pull_request: +# branches: +# - main +# merge_group: +# workflow_dispatch: jobs: build: diff --git a/.github/workflows/test_petab_test_suite.yml b/.github/workflows/test_petab_test_suite.yml index ef5b7425f2..86d5fe30b0 100644 --- a/.github/workflows/test_petab_test_suite.yml +++ b/.github/workflows/test_petab_test_suite.yml @@ -172,7 +172,7 @@ jobs: git clone https://github.com/PEtab-dev/petab_test_suite \ && source ./venv/bin/activate \ && cd petab_test_suite \ - && git checkout c12b9dc4e4c5585b1b83a1d6e89fd22447c46d03 \ + && git checkout 9542847fb99bcbdffc236e2ef45ba90580a210fa \ && pip3 install -e . 
# TODO: once there is a PEtab v2 benchmark collection @@ -186,7 +186,7 @@ jobs: run: | source ./venv/bin/activate \ && python3 -m pip uninstall -y petab \ - && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@8dc6c1c4b801fba5acc35fcd25308a659d01050e \ + && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@d57d9fed8d8d5f8592e76d0b15676e05397c3b4b \ && python3 -m pip install git+https://github.com/pysb/pysb@master \ && python3 -m pip install sympy>=1.12.1 diff --git a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb index b23c53325f..8aa025bfaa 100644 --- a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -32,7 +32,8 @@ "outputs": [], "source": [ "import petab.v1 as petab\n", - "from amici.importers.petab.v1 import import_petab_problem\n", + "from amici.importers.petab import *\n", + "from petab.v2 import Problem\n", "\n", "# Define the model name and YAML file location\n", "model_name = \"Boehm_JProteomeRes2014\"\n", @@ -41,14 +42,20 @@ " f\"master/Benchmark-Models/{model_name}/{model_name}.yaml\"\n", ")\n", "\n", - "# Load the PEtab problem from the YAML file\n", - "petab_problem = petab.Problem.from_yaml(yaml_url)\n", + "# Load the PEtab problem from the YAML file as a PEtab v2 problem\n", + "# (the JAX backend only supports PEtab v2)\n", + "petab_problem = Problem.from_yaml(yaml_url)\n", "\n", "# Import the PEtab problem as a JAX-compatible AMICI problem\n", - "jax_problem = import_petab_problem(\n", - " petab_problem,\n", - " verbose=False, # no text output\n", - " jax=True, # return jax problem\n", + "pi = PetabImporter(\n", + " petab_problem=petab_problem,\n", + " module_name=model_name,\n", + " compile_=True,\n", + " jax=True,\n", + ")\n", + "\n", + "jax_problem = pi.create_simulator(\n", + " force_import=True,\n", ")" ] }, @@ -75,6 +82,16 @@ "llh, results = 
run_simulations(jax_problem)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c5b2980-13f0-42e9-b13e-0fce05793910", + "metadata": {}, + "outputs": [], + "source": [ + "results" + ] + }, { "cell_type": "markdown", "id": "415962751301c64a", @@ -90,11 +107,11 @@ "metadata": {}, "outputs": [], "source": [ - "# Define the simulation condition\n", - "simulation_condition = (\"model1_data1\",)\n", + "# # Define the simulation condition\n", + "experiment_condition = \"_petab_experiment_condition___default__\"\n", "\n", - "# Access the results for the specified condition\n", - "ic = results[\"simulation_conditions\"].index(simulation_condition)\n", + "# # Access the results for the specified condition\n", + "ic = results[\"dynamic_conditions\"].index(experiment_condition)\n", "print(\"llh: \", results[\"llh\"][ic])\n", "print(\"state variables: \", results[\"x\"][ic, :])" ] @@ -146,8 +163,8 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "# Define the simulation condition\n", - "simulation_condition = (\"model1_data1\",)\n", + "# Define the experiment condition\n", + "experiment_condition = \"_petab_experiment_condition___default__\"\n", "\n", "\n", "def plot_simulation(results):\n", @@ -158,7 +175,7 @@ " results (dict): Simulation results from run_simulations.\n", " \"\"\"\n", " # Extract the simulation results for the specific condition\n", - " ic = results[\"simulation_conditions\"].index(simulation_condition)\n", + " ic = results[\"dynamic_conditions\"].index(experiment_condition)\n", "\n", " # Create a new figure for the state trajectories\n", " plt.figure(figsize=(8, 6))\n", @@ -172,7 +189,7 @@ " # Add labels, legend, and grid\n", " plt.xlabel(\"Time\")\n", " plt.ylabel(\"State Values\")\n", - " plt.title(simulation_condition)\n", + " plt.title(experiment_condition)\n", " plt.legend()\n", " plt.grid(True)\n", " plt.show()\n", @@ -187,18 +204,7 @@ "id": "4fa97c33719c2277", "metadata": {}, "source": [ - "`run_simulations` 
enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7950774a3e989042", - "metadata": {}, - "outputs": [], - "source": [ - "llh, results = run_simulations(jax_problem, simulation_conditions=tuple())\n", - "results" + "`run_simulations` enables users to specify the simulation experiments to be executed. For more complex models, this allows for restricting simulations to a subset of experiments by passing a tuple of experiment ids under the keyword `simulation_experiments` to `run_simulations`." ] }, { @@ -384,8 +390,8 @@ "from amici.jax import ReturnValue\n", "\n", "# Define the simulation condition\n", - "simulation_condition = (\"model1_data1\",)\n", - "ic = jax_problem.simulation_conditions.index(simulation_condition)\n", + "experiment_condition = \"_petab_experiment_condition___default__\"\n", + "ic = 0\n", "\n", "# Load condition-specific data\n", "ts_dyn = jax_problem._ts_dyn[ic, :]\n", @@ -397,7 +403,7 @@ "nps = jax_problem._np_numeric[ic, :]\n", "\n", "# Load parameters for the specified condition\n", - "p = jax_problem.load_model_parameters(simulation_condition[0])\n", + "p = jax_problem.load_model_parameters(jax_problem._petab_problem.experiments[0], is_preeq=False)\n", "\n", "\n", "# Define a function to compute the gradient with respect to dynamic timepoints\n", @@ -431,13 +437,17 @@ "cell_type": "markdown", "id": "19ca88c8900584ce", "metadata": {}, - "source": "## Model training" + "source": [ + "## Model training" + ] }, { "cell_type": "markdown", "id": "7f99c046d7d4e225", "metadata": {}, - "source": "This setup makes it pretty straightforward to train models using [equinox](https://docs.kidger.site/equinox/) and 
[optax](https://optax.readthedocs.io/en/latest/) frameworks. Below we provide barebones implementation that runs training for 5 steps using Adam." + "source": [ + "This setup makes it pretty straightforward to train models using [equinox](https://docs.kidger.site/equinox/) and [optax](https://optax.readthedocs.io/en/latest/) frameworks. Below we provide barebones implementation that runs training for 5 steps using Adam." + ] }, { "cell_type": "code", @@ -569,16 +579,20 @@ "from amici.sim.sundials.petab.v1 import simulate_petab\n", "\n", "# Import the PEtab problem as a standard AMICI model\n", - "amici_model = import_petab_problem(\n", - " petab_problem,\n", - " verbose=False,\n", - " jax=False, # load the amici model this time\n", + "pi = PetabImporter(\n", + " petab_problem=petab_problem,\n", + " module_name=model_name,\n", + " compile_=True,\n", + " jax=False,\n", + ")\n", + "\n", + "amici_model = pi.create_simulator(\n", + " force_import=True,\n", ")\n", "\n", "# Configure the solver with appropriate tolerances\n", - "solver = amici_model.create_solver()\n", - "solver.set_absolute_tolerance(1e-8)\n", - "solver.set_relative_tolerance(1e-16)\n", + "amici_model.solver.set_absolute_tolerance(1e-8)\n", + "amici_model.solver.set_relative_tolerance(1e-16)\n", "\n", "# Prepare the parameters for the simulation\n", "problem_parameters = dict(\n", @@ -594,86 +608,65 @@ "outputs": [], "source": [ "# Profile simulation only\n", - "solver.set_sensitivity_order(SensitivityOrder.none)" + "amici_model.solver.set_sensitivity_order(SensitivityOrder.none)" ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "42cbc67bc09b67dc", + "metadata": {}, + "outputs": [], "source": [ "%%timeit\n", - "simulate_petab(\n", - " petab_problem,\n", - " amici_model=amici_model,\n", - " solver=solver,\n", - " problem_parameters=problem_parameters,\n", - " scaled_parameters=True,\n", - " scaled_gradients=True,\n", - ")" - ], - "id": "42cbc67bc09b67dc" + 
"amici_model.simulate(petab_problem.get_x_nominal_dict())" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "4f1c06c5893a9c07", + "metadata": {}, + "outputs": [], "source": [ "# Profile gradient computation using forward sensitivity analysis\n", - "solver.set_sensitivity_order(SensitivityOrder.first)\n", - "solver.set_sensitivity_method(SensitivityMethod.forward)" - ], - "id": "4f1c06c5893a9c07" + "amici_model.solver.set_sensitivity_order(SensitivityOrder.first)\n", + "amici_model.solver.set_sensitivity_method(SensitivityMethod.forward)" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "7367a19bcea98597", + "metadata": {}, + "outputs": [], "source": [ "%%timeit\n", - "simulate_petab(\n", - " petab_problem,\n", - " amici_model=amici_model,\n", - " solver=solver,\n", - " problem_parameters=problem_parameters,\n", - " scaled_parameters=True,\n", - " scaled_gradients=True,\n", - ")" - ], - "id": "7367a19bcea98597" + "amici_model.simulate(petab_problem.get_x_nominal_dict())" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "a31e8eda806c2d7", + "metadata": {}, + "outputs": [], "source": [ "# Profile gradient computation using adjoint sensitivity analysis\n", - "solver.set_sensitivity_order(SensitivityOrder.first)\n", - "solver.set_sensitivity_method(SensitivityMethod.adjoint)" - ], - "id": "a31e8eda806c2d7" + "amici_model.solver.set_sensitivity_order(SensitivityOrder.first)\n", + "amici_model.solver.set_sensitivity_method(SensitivityMethod.adjoint)" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "3f2ab1acb3ba818f", + "metadata": {}, + "outputs": [], "source": [ "%%timeit\n", - "simulate_petab(\n", - " petab_problem,\n", - " amici_model=amici_model,\n", - " solver=solver,\n", - " problem_parameters=problem_parameters,\n", - " scaled_parameters=True,\n", - " 
scaled_gradients=True,\n", - ")" - ], - "id": "3f2ab1acb3ba818f" + "amici_model.simulate(petab_problem.get_x_nominal_dict())" + ] } ], "metadata": { @@ -691,7 +684,8 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.12.3" } }, "nbformat": 4, diff --git a/python/sdist/amici/_symbolic/de_model.py b/python/sdist/amici/_symbolic/de_model.py index fc027d20c8..e9fef47edc 100644 --- a/python/sdist/amici/_symbolic/de_model.py +++ b/python/sdist/amici/_symbolic/de_model.py @@ -2679,11 +2679,16 @@ def has_priority_events(self) -> bool: def has_implicit_event_assignments(self) -> bool: """ Checks whether the model has event assignments with implicit triggers + (i.e. triggers that are not time based). :return: boolean indicating if event assignments with implicit triggers are present """ - return any(event.updates_state and not event.has_explicit_trigger_times({}) for event in self._events) + fixed_symbols = set([k._symbol for k in self._fixed_parameters]) + allowed_symbols = fixed_symbols | {amici_time_symbol} + # TODO: update to use has_explicit_trigger_times once + # https://github.com/AMICI-dev/AMICI/issues/3126 is resolved + return any(event.updates_state and event._has_implicit_triggers(allowed_symbols) for event in self._events) def toposort_expressions( self, reorder: bool = True diff --git a/python/sdist/amici/_symbolic/de_model_components.py b/python/sdist/amici/_symbolic/de_model_components.py index 23ec0cac8d..76b6c1f501 100644 --- a/python/sdist/amici/_symbolic/de_model_components.py +++ b/python/sdist/amici/_symbolic/de_model_components.py @@ -716,6 +716,7 @@ def __init__( assignments: dict[sp.Symbol, sp.Expr] | None = None, initial_value: bool | None = True, priority: sp.Basic | None = None, + is_negative_event: bool = False, ): """ Create a new Event instance. @@ -738,6 +739,11 @@ def __init__( :param priority: The priority of the event assignment. 
+ :param is_negative_event: + Whether this event is a "negative" event, i.e., an event that is + added to mirror an existing event with inverted trigger condition + to avoid immediate retriggering of the original event (JAX simulations). + :param use_values_from_trigger_time: Whether the event assignment is evaluated using the state from the time point at which the event triggered (True), or at the time @@ -771,6 +777,8 @@ def __init__( # the trigger can't be solved for `t` pass + self._is_negative_event = is_negative_event + def get_state_update( self, x: sp.Matrix, x_old: sp.Matrix ) -> sp.Matrix | None: @@ -855,11 +863,19 @@ def has_explicit_trigger_times( """ if allowed_symbols is None: return len(self._t_root) > 0 - + return len(self._t_root) > 0 and all( t.is_Number or t.free_symbols.issubset(allowed_symbols) for t in self._t_root ) + + def _has_implicit_triggers( + self, allowed_symbols: set[sp.Symbol] | None = None + ) -> bool: + """Check whether the event has implicit triggers. + """ + t = self.get_val() + return not t.free_symbols.issubset(allowed_symbols) def get_trigger_times(self) -> set[sp.Expr]: """Get the time points at which the event triggers. diff --git a/python/sdist/amici/importers/petab/_petab_importer.py b/python/sdist/amici/importers/petab/_petab_importer.py index fd1a4cb45e..e49d73d668 100644 --- a/python/sdist/amici/importers/petab/_petab_importer.py +++ b/python/sdist/amici/importers/petab/_petab_importer.py @@ -24,6 +24,7 @@ from amici._symbolic import DEModel, Event from amici.importers.utils import MeasurementChannel, amici_time_symbol from amici.logging import get_logger +from amici.jax.petab import JAXProblem from .v1.sbml_import import _add_global_parameter @@ -151,10 +152,6 @@ def __init__( "PEtab v2 importer currently only supports SBML and PySB " f"models. Got {self.petab_problem.model.type_id!r}." ) - if jax: - raise NotImplementedError( - "PEtab v2 importer currently does not support JAX. 
" - ) if self._debug: print("PetabImpoter.__init__: petab_problem:") @@ -356,6 +353,7 @@ def _do_import_sbml(self): model_name=self._module_name, output_dir=self.output_dir, observation_model=observation_model, + fixed_parameters=fixed_parameters, verbose=self._verbose, # **kwargs, ) @@ -577,6 +575,11 @@ def import_module(self, force_import: bool = False) -> amici.ModelModule: else: self._do_import_pysb() + if self._jax: + return amici.import_model_module( + Path(self.output_dir).stem, Path(self.output_dir).parent + ) + return amici.import_model_module( self._module_name, self.output_dir, @@ -601,6 +604,11 @@ def create_simulator( """ from amici.sim.sundials.petab import ExperimentManager, PetabSimulator + if self._jax: + model_module = self.import_module(force_import=force_import) + model = model_module.Model() + return JAXProblem(model, self.petab_problem) + model = self.import_module(force_import=force_import).get_model() em = ExperimentManager(model=model, petab_problem=self.petab_problem) return PetabSimulator(em=em) diff --git a/python/sdist/amici/importers/petab/v1/parameter_mapping.py b/python/sdist/amici/importers/petab/v1/parameter_mapping.py index b2b7837e7b..1738f15d1c 100644 --- a/python/sdist/amici/importers/petab/v1/parameter_mapping.py +++ b/python/sdist/amici/importers/petab/v1/parameter_mapping.py @@ -355,7 +355,7 @@ def create_parameter_mapping( converter_config = ( libsbml.SBMLLocalParameterConverter().getDefaultProperties() ) - petab_problem.sbml_document.convert(converter_config) + petab_problem.model.sbml_document.convert(converter_config) else: logger.debug( "No petab_problem.sbml_document is set. 
Cannot " diff --git a/python/sdist/amici/importers/pysb/__init__.py b/python/sdist/amici/importers/pysb/__init__.py index fa90163b11..75a2bae0cd 100644 --- a/python/sdist/amici/importers/pysb/__init__.py +++ b/python/sdist/amici/importers/pysb/__init__.py @@ -389,7 +389,7 @@ def ode_model_from_pysb_importer( pysb.bng.generate_equations(model, verbose=verbose) _process_pysb_species(model, ode) - _process_pysb_parameters(model, ode, fixed_parameters, jax) + _process_pysb_parameters(model, ode, fixed_parameters) if compute_conservation_laws: if _events: raise NotImplementedError( @@ -570,7 +570,6 @@ def _process_pysb_parameters( pysb_model: pysb.Model, ode_model: DEModel, fixed_parameters: list[str], - jax: bool = False, ) -> None: """ Converts pysb parameters into Parameters or Constants and adds them to @@ -582,9 +581,6 @@ def _process_pysb_parameters( :param fixed_parameters: model variables excluded from sensitivity analysis - :param jax: - if set to ``True``, the generated model will be compatible JAX export - :param ode_model: DEModel instance """ @@ -593,10 +589,6 @@ def _process_pysb_parameters( if par.name in fixed_parameters: comp = FixedParameter args.append(par.value) - elif jax and re.match(r"noiseParameter\d+", par.name): - comp = NoiseParameter - elif jax and re.match(r"observableParameter\d+", par.name): - comp = ObservableParameter else: comp = FreeParameter args.append(par.value) diff --git a/python/sdist/amici/importers/sbml/__init__.py b/python/sdist/amici/importers/sbml/__init__.py index 771f9f0cb4..93111651b2 100644 --- a/python/sdist/amici/importers/sbml/__init__.py +++ b/python/sdist/amici/importers/sbml/__init__.py @@ -438,6 +438,7 @@ def sbml2jax( self, model_name: str, output_dir: str | Path = None, + fixed_parameters: Iterable[str] = None, observation_model: list[MeasurementChannel] = None, verbose: int | bool = logging.ERROR, compute_conservation_laws: bool = True, @@ -465,6 +466,9 @@ def sbml2jax( :param output_dir: Directory where the 
generated model package will be stored. + :param fixed_parameters: + SBML Ids to be excluded from sensitivity analysis + :param observation_model: The different measurement channels that make up the observation model, see :class:`amici.importers.utils.MeasurementChannel`. @@ -513,6 +517,7 @@ def sbml2jax( set_log_level(logger, verbose) ode_model = self._build_ode_model( + fixed_parameters=fixed_parameters, observation_model=observation_model, verbose=verbose, compute_conservation_laws=compute_conservation_laws, @@ -1908,6 +1913,7 @@ def _process_events(self) -> None: "initial_value": not initial_value, "use_values_from_trigger_time": use_trig_val, "priority": self._sympify(event.getPriority()), + "is_negative_event": True, } @log_execution_time("processing observation model", logger) diff --git a/python/sdist/amici/jax/_simulation.py b/python/sdist/amici/jax/_simulation.py index 7b19e61517..b20748d098 100644 --- a/python/sdist/amici/jax/_simulation.py +++ b/python/sdist/amici/jax/_simulation.py @@ -26,6 +26,7 @@ def eq( tcl: jt.Float[jt.Array, "ncl"], h0: jt.Float[jt.Array, "ne"], x0: jt.Float[jt.Array, "nxs"], + h_mask: jt.Bool[jt.Array, "ne"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -109,6 +110,7 @@ def cond_fn(carry): def body_fn(carry): t_start, y0, h, event_index, stats = carry + sol, event_index, stats = _run_segment( t_start, jnp.inf, @@ -147,6 +149,7 @@ def body_fn(carry): term, root_cond_fn, delta_x, + h_mask, stats, ) @@ -172,10 +175,12 @@ def body_fn(carry): def solve( p: jt.Float[jt.Array, "np"], + t0: jnp.float_, ts: jt.Float[jt.Array, "nt_dyn"], tcl: jt.Float[jt.Array, "ncl"], h: jt.Float[jt.Array, "ne"], x0: jt.Float[jt.Array, "nxs"], + h_mask: jt.Bool[jt.Array, "ne"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -186,12 +191,15 @@ def solve( root_cond_fn: Callable, delta_x: Callable, known_discs: 
jt.Float[jt.Array, "*nediscs"], + observable_ids: list[str], ) -> tuple[jt.Float[jt.Array, "nt nxs"], jt.Float[jt.Array, "nt ne"], dict]: """ Simulate the ODE system for the specified timepoints. :param p: parameters + :param t0: + initial time point :param ts: time points at which solutions are evaluated :param tcl: @@ -216,6 +224,8 @@ def solve( function to compute state changes at events :param known_discs: known discontinuities, used to clip the step size controller + :param observable_ids: + list of observable IDs :return: solution+heaviside variables at time points ts and statistics """ @@ -223,7 +233,7 @@ def solve( if not root_cond_fns: # no events, we can just run a single segment sol, _, stats = _run_segment( - 0.0, + t0, ts[-1], x0, p, @@ -301,6 +311,7 @@ def body_fn(carry): term, root_cond_fn, delta_x, + h_mask, stats, ) @@ -310,12 +321,12 @@ def body_fn(carry): return ys, t0_next, y0_next, hs, h_next, stats # run the loop until we have reached the end of the time points - ys, _, _, hs, _, stats = eqxi.while_loop( + ys, t0_next, y0_next, hs, _, stats = eqxi.while_loop( cond_fn, body_fn, ( jnp.zeros((ts.shape[0], x0.shape[0]), dtype=x0.dtype) + x0, - 0.0, + t0, x0, jnp.zeros((ts.shape[0], h.shape[0]), dtype=h.dtype), h, @@ -325,6 +336,12 @@ def body_fn(carry): max_steps=2**6, ) + mask = ts == t0_next + n_obs = len(observable_ids) + y0_obs = y0_next[None, -n_obs:] + updated_last = jnp.where(mask[:, None], y0_obs, ys[:, -n_obs:]) + ys = ys.at[:, -n_obs:].set(updated_last) + return ys, hs, stats @@ -353,7 +370,6 @@ def _run_segment( triggered during the integration. ``None`` indicates that the solver reached ``t_end`` without any event firing. 
""" - # combine all discontinuity conditions into a single diffrax.Event event = ( diffrax.Event( @@ -419,6 +435,7 @@ def _handle_event( term: diffrax.ODETerm, root_cond_fn: Callable, delta_x: Callable, + h_mask: jt.Bool[jt.Array, "ne"], stats: dict, ): args = (p, tcl, h) @@ -446,6 +463,8 @@ def _handle_event( delta_x, ) + h_next = jnp.where(h_mask, h_next, h) + if os.getenv("JAX_DEBUG") == "1": jax.debug.print( "rootvals: {}, roots_found: {}, roots_dir: {}, h: {}, h_next: {}", @@ -456,8 +475,52 @@ def _handle_event( h_next, ) + y0_next = _check_cascading_events( + t0_next, + y0_next, + rootvals, + p, + tcl, + h_next, + root_finder, + term, + root_cond_fn, + delta_x, + h_mask, + ) + return y0_next, h_next, stats +def _check_cascading_events( + t0_next: float, + y0_next: jt.Float[jt.Array, "nxs"], + rootval_prev: jt.Float[jt.Array, "nroots"], + p: jt.Float[jt.Array, "np"], + tcl: jt.Float[jt.Array, "ncl"], + h: jt.Float[jt.Array, "ne"], + root_finder: AbstractRootFinder, + term: diffrax.ODETerm, + root_cond_fn: Callable, + delta_x: Callable, + h_mask: jt.Bool[jt.Array, "ne"], +): + args = (p, tcl, h) + rootvals = root_cond_fn(t0_next, y0_next, args) + root_vals_changed_sign = jnp.sign(rootvals) != jnp.sign(rootval_prev) + roots_dir = jnp.sign(jnp.sign(rootvals) - jnp.sign(rootval_prev)) + + y0_next, _ = _apply_event_assignments( + root_vals_changed_sign, + roots_dir, + y0_next, + p, + tcl, + h, + delta_x, + ) + + return y0_next + def _apply_event_assignments( roots_found, roots_dir, @@ -479,10 +542,22 @@ def _apply_event_assignments( for _ in range(y0_next.shape[0]) ] ).T - delx = delta_x(y0_next, p, tcl) - if y0_next.size: - delx = delx.reshape(delx.size // y0_next.shape[0], y0_next.shape[0],) - y0_up = jnp.where(mask, delx, 0.0) - y0_next = y0_next + jnp.sum(y0_up, axis=0) + + # apply one event at a time + if h_next.shape[0] and y0_next.shape[0]: + n_pairs = h_next.shape[0] // 2 + inds_seq = jnp.arange(n_pairs) + + def body(y, e): + inds = jnp.array([e * 2, e * 2 
+ 1]) + delx = delta_x(y, p, tcl) + if y.size: + delx = delx.reshape(delx.size // y.shape[0], y.shape[0]) + keep = jnp.zeros_like(mask).at[inds, :].set(True) + updated_mask = jnp.where(keep, mask, False) + y = y + jnp.sum(jnp.where(updated_mask, delx, 0.0), axis=0) + return y, None + + y0_next, _ = jax.lax.scan(body, y0_next, inds_seq) return y0_next, h_next diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index fe0ff12d8d..d50f44c7a2 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -21,16 +21,16 @@ class JAXModel_TPL_MODEL_NAME(JAXModel): def __init__(self): self.jax_py_file = Path(__file__).resolve() self.nns = {TPL_NETS} - self.parameters = TPL_P_VALUES + self.parameters = TPL_ALL_P_VALUES super().__init__() def _xdot(self, t, x, args): p, tcl, h = args TPL_X_SYMS = x - TPL_P_SYMS = p + TPL_ALL_P_SYMS = p TPL_TCL_SYMS = tcl - TPL_IH_SYMS = h + TPL_H_SYMS = h TPL_W_SYMS = self._w(t, x, p, tcl, h) TPL_XDOT_EQ @@ -39,16 +39,16 @@ def _xdot(self, t, x, args): def _w(self, t, x, p, tcl, h): TPL_X_SYMS = x - TPL_P_SYMS = p + TPL_ALL_P_SYMS = p TPL_TCL_SYMS = tcl - TPL_IH_SYMS = h + TPL_H_SYMS = h TPL_W_EQ return TPL_W_RET def _x0(self, t, p): - TPL_P_SYMS = p + TPL_ALL_P_SYMS = p TPL_X0_EQ @@ -71,7 +71,7 @@ def _x_rdata(self, x, tcl): def _tcl(self, x, p): TPL_X_RDATA_SYMS = x - TPL_P_SYMS = p + TPL_ALL_P_SYMS = p TPL_TOTAL_CL_EQ @@ -79,7 +79,7 @@ def _tcl(self, x, p): def _y(self, t, x, p, tcl, h, op): TPL_X_SYMS = x - TPL_P_SYMS = p + TPL_ALL_P_SYMS = p TPL_W_SYMS = self._w(t, x, p, tcl, h) TPL_OP_SYMS = op @@ -88,7 +88,7 @@ def _y(self, t, x, p, tcl, h, op): return TPL_Y_RET def _sigmay(self, y, p, np): - TPL_P_SYMS = p + TPL_ALL_P_SYMS = p TPL_Y_SYMS = y TPL_NP_SYMS = np @@ -110,7 +110,7 @@ def _nllh(self, t, x, p, tcl, h, my, iy, op, np): return TPL_JY_RET.at[iy].get() def _known_discs(self, p): - TPL_P_SYMS = p + TPL_ALL_P_SYMS = p return TPL_ROOTS @@ -118,9 +118,9 @@ 
def _root_cond_fn(self, t, y, args, **_): p, tcl, h = args TPL_X_SYMS = y - TPL_P_SYMS = p + TPL_ALL_P_SYMS = p TPL_TCL_SYMS = tcl - TPL_IH_SYMS = h + TPL_H_SYMS = h TPL_W_SYMS = self._w(t, y, p, tcl, h) TPL_IROOT_EQ @@ -130,7 +130,7 @@ def _root_cond_fn(self, t, y, args, **_): def _delta_x(self, y, p, tcl): TPL_X_SYMS = y - TPL_P_SYMS = p + TPL_ALL_P_SYMS = p TPL_TCL_SYMS = tcl # FIXME: workaround until state from event time is properly passed TPL_X_OLD_SYMS = y @@ -157,7 +157,7 @@ def state_ids(self): @property def parameter_ids(self): - return TPL_P_IDS + return TPL_ALL_P_IDS @property def expression_ids(self): diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 7a1636c651..ed072e3d2f 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -593,11 +593,14 @@ def simulate_condition_unjitted( ], max_steps: int | jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), + h_preeq: jt.Float[jt.Array, "*ne"] = jnp.array([]), mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]), x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]), init_override: jt.Float[jt.Array, "*nx"] = jnp.array([]), init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]), ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]), + h_mask: jt.Bool[jt.Array, "ne"] = jnp.array([]), + t_zero: jnp.float_ = 0.0, ret: ReturnValue = ReturnValue.llh, ) -> tuple[jt.Float[jt.Array, "*nt"], dict]: """ @@ -605,10 +608,13 @@ def simulate_condition_unjitted( See :meth:`simulate_condition` for full documentation. 
""" - t0 = 0.0 + t0 = t_zero if p is None: p = self.parameters + if not h_mask.shape[0]: + h_mask = jnp.ones(self.n_events, dtype=jnp.bool_) + if x_preeq.shape[0]: x = x_preeq elif init_override.shape[0]: @@ -625,6 +631,7 @@ def simulate_condition_unjitted( # Re-initialization if x_reinit.shape[0]: x = jnp.where(mask_reinit, x_reinit, x) + x_solver = self._x_solver(x) tcl = self._tcl(x, p) @@ -636,6 +643,8 @@ def simulate_condition_unjitted( root_finder, self._root_cond_fn, self._delta_x, + h_mask, + h_preeq, {}, ) @@ -643,10 +652,12 @@ def simulate_condition_unjitted( if ts_dyn.shape[0]: x_dyn, h_dyn, stats_dyn = solve( p, + t0, ts_dyn, tcl, h, x_solver, + h_mask, solver, controller, root_finder, @@ -657,6 +668,7 @@ def simulate_condition_unjitted( self._root_cond_fn, self._delta_x, self._known_discs(p), + self.observable_ids, ) x_solver = x_dyn[-1, :] else: @@ -671,6 +683,7 @@ def simulate_condition_unjitted( tcl, h, x_solver, + h_mask, solver, controller, root_finder, @@ -771,11 +784,14 @@ def simulate_condition( ], max_steps: int | jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), + h_preeq: jt.Bool[jt.Array, "*ne"] = jnp.array([]), mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]), x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]), init_override: jt.Float[jt.Array, "*nx"] = jnp.array([]), init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]), ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]), + h_mask: jt.Bool[jt.Array, "ne"] = jnp.array([]), + t_zero: jnp.float_ = 0.0, ret: ReturnValue = ReturnValue.llh, ) -> tuple[jt.Float[jt.Array, "*nt"], dict]: r""" @@ -828,6 +844,9 @@ def simulate_condition( :param ts_mask: mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2. + :param h_mask: + mask for heaviside variables. 
If `True`, the corresponding heaviside variable is updated during simulation, otherwise it + is marked as 1.0. :param ret: which output to return. See :class:`ReturnValue` for available options. :return: @@ -849,11 +868,14 @@ steady_state_event, max_steps, x_preeq, + h_preeq, mask_reinit, x_reinit, init_override, init_override_mask, ts_mask, + h_mask, + t_zero, ret, ) @@ -863,6 +885,7 @@ def preequilibrate_condition( p: jt.Float[jt.Array, "np"] | None, x_reinit: jt.Float[jt.Array, "*nx"], mask_reinit: jt.Bool[jt.Array, "*nx"], + h_mask: jt.Bool[jt.Array, "ne"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -881,6 +904,9 @@ re-initialized state vector. If not provided, the state vector is not re-initialized. :param mask_reinit: mask for re-initialization. If `True`, the corresponding state variable is re-initialized. + :param h_mask: + mask for heaviside variables. If `True`, the corresponding heaviside variable is updated during simulation, otherwise it + is marked as 1.0. 
:param solver: ODE solver :param controller: @@ -895,6 +921,9 @@ def preequilibrate_condition( if p is None: p = self.parameters + if not h_mask.shape[0]: + h_mask = jnp.ones(self.n_events, dtype=jnp.bool_) + x0 = self._x0(t0, p) if x_reinit.shape[0]: x0 = jnp.where(mask_reinit, x_reinit, x0) @@ -910,14 +939,17 @@ def preequilibrate_condition( root_finder, self._root_cond_fn, self._delta_x, + h_mask, + jnp.array([]), {}, ) - current_x, _, stats_preeq = eq( + current_x, h, stats_preeq = eq( p, tcl, h, current_x, + h_mask, solver, controller, root_finder, @@ -930,7 +962,7 @@ def preequilibrate_condition( max_steps, ) - return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq) + return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq), h def _handle_t0_event( self, @@ -941,10 +973,18 @@ def _handle_t0_event( root_finder: AbstractRootFinder, root_cond_fn: Callable, delta_x: Callable, + h_mask: jt.Bool[jt.Array, "ne"], + h_preeq: jt.Bool[jt.Array, "ne"], stats: dict, ): + y0 = y0_next.copy() rf0 = self.event_initial_values - 0.5 - h = jnp.heaviside(rf0, 0.0) + + if h_preeq.shape[0]: + # return immediately because preequilibration is equivalent to handling t0 event? 
+ return y0, t0_next, h_preeq, stats + else: + h = jnp.where(h_mask, jnp.heaviside(rf0, 0.0), jnp.ones_like(rf0)) args = (p, tcl, h) rfx = root_cond_fn(t0_next, y0_next, args) roots_dir = jnp.sign(rfx - rf0) diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 29b6458957..3cbe2a34b2 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -17,6 +17,7 @@ from pathlib import Path import sympy as sp +import numpy as np from amici import ( amiciModulePath, @@ -212,11 +213,13 @@ def _generate_jax_code(self) -> None: ) sym_names = ( "p", + "k", "np", "op", "x", "tcl", "ih", + "h", "w", "my", "y", @@ -259,6 +262,11 @@ def _generate_jax_code(self) -> None: # tuple of variable names (ids as they are unique) **_jax_variable_ids(self.model, ("p", "k", "y", "w", "x_rdata")), "P_VALUES": _jnp_array_str(self.model.val("p")), + "ALL_P_VALUES": _jnp_array_str(self.model.val("p") + self.model.val("k")), + "ALL_P_IDS": "".join(f'"{s.name}", ' for s in self._get_all_p_syms()) + if self._get_all_p_syms() else "tuple()", + "ALL_P_SYMS": "".join(f"{s.name}, " for s in self._get_all_p_syms()) + if self._get_all_p_syms() else "_", "ROOTS": _jnp_array_str( { _print_trigger_root(root) @@ -295,6 +303,9 @@ def _generate_jax_code(self) -> None: tpl_data, ) + def _get_all_p_syms(self) -> list[sp.Symbol]: + return list(self.model.sym("p")) + list(self.model.sym("k")) + def _generate_nn_code(self) -> None: for net_name, net in self.hybridization.items(): generate_equinox( diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 29425cf6c5..6ecb4abbe2 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -8,6 +8,7 @@ from numbers import Number from pathlib import Path +import os import diffrax import equinox as eqx import h5py @@ -17,7 +18,8 @@ import numpy as np import optimistix import pandas as pd -import petab.v1 as petab +import petab.v1 as petabv1 
+import petab.v2 as petabv2 from optimistix import AbstractRootFinder from amici import _module_from_path @@ -27,6 +29,12 @@ ) from amici.jax.model import JAXModel, ReturnValue from amici.logging import get_logger +from amici.sim.jax import ( + add_default_experiment_names_to_v2_problem, get_simulation_conditions_v2, _build_simulation_df_v2, _try_float +) + +import time +import tracemalloc DEFAULT_CONTROLLER_SETTINGS = { "atol": 1e-8, @@ -42,9 +50,9 @@ } SCALE_TO_INT = { - petab.LIN: 0, - petab.LOG: 1, - petab.LOG10: 2, + petabv2.C.LIN: 0, + petabv2.C.LOG: 1, + petabv2.C.LOG10: 2, } logger = get_logger(__name__, logging.WARNING) @@ -60,30 +68,43 @@ def jax_unscale( parameter: Parameter to be unscaled. scale_str: - One of ``petab.LIN``, ``petab.LOG``, ``petab.LOG10``. + One of ``petabv2.C.LIN``, ``petabv2.C.LOG``, ``petabv2.C.LOG10``. Returns: The unscaled parameter. """ - if scale_str == petab.LIN or not scale_str: + if scale_str == petabv2.C.LIN or not scale_str: return parameter - if scale_str == petab.LOG: + if scale_str == petabv2.C.LOG: return jnp.exp(parameter) - if scale_str == petab.LOG10: + if scale_str == petabv2.C.LOG10: return jnp.power(10, parameter) raise ValueError(f"Invalid parameter scaling: {scale_str}") # IDEA: Implement this class in petab-sciml instead? 
-class HybridProblem(petab.Problem): +class HybridProblem(petabv1.Problem): + hybridization_df: pd.DataFrame + + def __init__(self, petab_problem: petabv1.Problem): + self.__dict__.update(petab_problem.__dict__) + self.hybridization_df = _get_hybridization_df(petab_problem) + +class HybridV2Problem(petabv2.Problem): hybridization_df: pd.DataFrame + extensions_config: dict - def __init__(self, petab_problem: petab.Problem): + def __init__(self, petab_problem: petabv2.Problem): + if not hasattr(petab_problem, "extensions_config"): + self.extensions_config = {} self.__dict__.update(petab_problem.__dict__) self.hybridization_df = _get_hybridization_df(petab_problem) def _get_hybridization_df(petab_problem): + if not hasattr(petab_problem, "extensions_config"): + return None + if "sciml" in petab_problem.extensions_config: hybridizations = [ pd.read_csv(hf, sep="\t", index_col=0) @@ -95,7 +116,9 @@ def _get_hybridization_df(petab_problem): return hybridization_df -def _get_hybrid_petab_problem(petab_problem: petab.Problem): +def _get_hybrid_petab_problem(petab_problem: petabv1.Problem | petabv2.Problem): + if isinstance(petab_problem, petabv2.Problem): + return HybridV2Problem(petab_problem) return HybridProblem(petab_problem) @@ -132,9 +155,9 @@ class JAXProblem(eqx.Module): _np_mask: np.ndarray _np_indices: np.ndarray _petab_measurement_indices: np.ndarray - _petab_problem: petab.Problem | HybridProblem + _petab_problem: petabv1.Problem | HybridProblem | petabv2.Problem - def __init__(self, model: JAXModel, petab_problem: petab.Problem): + def __init__(self, model: JAXModel, petab_problem: petabv1.Problem | petabv2.Problem): """ Initialize a JAXProblem instance with a model and a PEtab problem. @@ -143,13 +166,18 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): :param petab_problem: PEtab problem to simulate. 
""" - scs = petab_problem.get_simulation_conditions_from_measurement_df() - self.simulation_conditions = tuple(tuple(sc) for sc in scs.values) + if isinstance(petab_problem, petabv1.Problem): + raise TypeError( + "JAXProblem does not support PEtab v1 problems. Upgrade the problem to PEtab v2." + ) + petab_problem = add_default_experiment_names_to_v2_problem(petab_problem) + scs = get_simulation_conditions_v2(petab_problem) + self.simulation_conditions = scs.conditionId.to_list() self._petab_problem = _get_hybrid_petab_problem(petab_problem) self.parameters, self.model = ( self._initialize_model_with_nominal_values(model) ) - self._parameter_mappings = self._get_parameter_mappings(scs) + self._parameter_mappings = None ( self._ts_dyn, self._ts_posteq, @@ -197,7 +225,7 @@ def load(cls, directory: Path): :return: Loaded problem instance. """ - petab_problem = petab.Problem.from_yaml( + petab_problem = petabv2.Problem.from_yaml( directory / "problem.yaml", ) model = _module_from_path("jax", directory / "jax_py_file.py").Model() @@ -205,40 +233,40 @@ def load(cls, directory: Path): with open(directory / "parameters.pkl", "rb") as f: return eqx.tree_deserialise_leaves(f, problem) - def _get_parameter_mappings( - self, simulation_conditions: pd.DataFrame - ) -> dict[str, ParameterMappingForCondition]: - """ - Create parameter mappings for the provided simulation conditions. - - :param simulation_conditions: - Simulation conditions to create parameter mappings for. Same format as returned by - :meth:`petab.Problem.get_simulation_conditions_from_measurement_df`. - :return: - Dictionary mapping simulation conditions to parameter mappings. 
- """ - scs = list(set(simulation_conditions.values.flatten())) - petab_problem = copy.deepcopy(self._petab_problem) - # remove observable and noise parameters from measurement dataframe as we are mapping them elsewhere - petab_problem.measurement_df.drop( - columns=[petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS], - inplace=True, - errors="ignore", - ) - mappings = create_parameter_mapping( - petab_problem=petab_problem, - simulation_conditions=[ - {petab.SIMULATION_CONDITION_ID: sc} for sc in scs - ], - scaled_parameters=False, - allow_timepoint_specific_numeric_noise_parameters=True, - ) - # fill in dummy variables - for mapping in mappings: - for sim_var, value in mapping.map_sim_var.items(): - if isinstance(value, Number) and not np.isfinite(value): - mapping.map_sim_var[sim_var] = 1.0 - return dict(zip(scs, mappings, strict=True)) + # def _get_parameter_mappings( + # self, simulation_conditions: pd.DataFrame + # ) -> dict[str, ParameterMappingForCondition]: + # """ + # Create parameter mappings for the provided simulation conditions. + + # :param simulation_conditions: + # Simulation conditions to create parameter mappings for. Same format as returned by + # :meth:`petabv1.Problem.get_simulation_conditions_from_measurement_df`. + # :return: + # Dictionary mapping simulation conditions to parameter mappings. 
+ # """ + # scs = list(set(simulation_conditions.conditionId)) + # petab_problem = copy.deepcopy(self._petab_problem) + # # remove observable and noise parameters from measurement dataframe as we are mapping them elsewhere + # petab_problem.measurement_df.drop( + # columns=[petabv2.C.OBSERVABLE_PARAMETERS, petabv2.C.NOISE_PARAMETERS], + # inplace=True, + # errors="ignore", + # ) + # mappings = create_parameter_mapping( + # petab_problem=petab_problem, + # simulation_conditions=[ + # {petabv2.C.SIMULATION_CONDITION_ID: sc} for sc in scs + # ], + # scaled_parameters=False, + # allow_timepoint_specific_numeric_noise_parameters=True, + # ) + # # fill in dummy variables + # for mapping in mappings: + # for sim_var, value in mapping.map_sim_var.items(): + # if isinstance(value, Number) and not np.isfinite(value): + # mapping.map_sim_var[sim_var] = 1.0 + # return dict(zip(scs, mappings, strict=True)) def _get_measurements( self, simulation_conditions: pd.DataFrame @@ -262,7 +290,7 @@ def _get_measurements( :param simulation_conditions: Simulation conditions to create parameter mappings for. Same format as returned by - :meth:`petab.Problem.get_simulation_conditions_from_measurement_df`. + :meth:`petabv1.Problem.get_simulation_conditions_from_measurement_df`. 
:return: tuple of padded - dynamic time points @@ -283,7 +311,7 @@ def _get_measurements( petab_indices = dict() n_pars = dict() - for col in [petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS]: + for col in [petabv2.C.OBSERVABLE_PARAMETERS, petabv2.C.NOISE_PARAMETERS]: n_pars[col] = 0 if col in self._petab_problem.measurement_df: if pd.api.types.is_numeric_dtype( @@ -295,7 +323,7 @@ def _get_measurements( else: n_pars[col] = ( self._petab_problem.measurement_df[col] - .str.split(petab.C.PARAMETER_SEPARATOR) + .str.split(petabv2.C.PARAMETER_SEPARATOR) .apply( lambda x: len(x) if isinstance(x, Sized) @@ -305,38 +333,49 @@ def _get_measurements( ) for _, simulation_condition in simulation_conditions.iterrows(): - query = " & ".join( - [f"{k} == '{v}'" for k, v in simulation_condition.items()] - ) + if "preequilibration" in simulation_condition[ + petabv2.C.CONDITION_ID + ]: + continue + + if isinstance(self._petab_problem, HybridV2Problem): + query = " & ".join( + [ + f"{k} == '{v}'" + if isinstance(v, str) + else f"{k} == {v}" + for k, v in simulation_condition.items() + if k != petabv2.C.CONDITION_ID + ] + ) + else: + query = " & ".join( + [f"{k} == '{v}'" for k, v in simulation_condition.items()] + ) m = self._petab_problem.measurement_df.query(query).sort_values( - by=petab.TIME + by=petabv2.C.TIME ) - ts = m[petab.TIME] + ts = m[petabv2.C.TIME] ts_dyn = ts[np.isfinite(ts)] ts_posteq = ts[np.logical_not(np.isfinite(ts))] index = pd.concat([ts_dyn, ts_posteq]).index ts_dyn = ts_dyn.values ts_posteq = ts_posteq.values - my = m[petab.MEASUREMENT].values + my = m[petabv2.C.MEASUREMENT].values iys = np.array( [ self.model.observable_ids.index(oid) - for oid in m[petab.OBSERVABLE_ID].values + for oid in m[petabv2.C.OBSERVABLE_ID].values ] ) - if ( - petab.OBSERVABLE_TRANSFORMATION - in self._petab_problem.observable_df - ): + if petabv2.C.NOISE_DISTRIBUTION in self._petab_problem.observable_df: iy_trafos = np.array( [ - SCALE_TO_INT[ - 
self._petab_problem.observable_df.loc[ - oid, petab.OBSERVABLE_TRANSFORMATION - ] - ] - for oid in m[petab.OBSERVABLE_ID].values + SCALE_TO_INT[petabv2.C.LOG] + if obs.noise_distribution == petabv2.C.LOG_NORMAL + else SCALE_TO_INT[petabv2.C.LIN] + for obs in self._petab_problem.observables ] ) else: @@ -350,16 +389,16 @@ def get_parameter_override(x): if ( x in self._petab_problem.parameter_df.index and not self._petab_problem.parameter_df.loc[ - x, petab.ESTIMATE + x, petabv2.C.ESTIMATE ] ): return self._petab_problem.parameter_df.loc[ - x, petab.NOMINAL_VALUE + x, petabv2.C.NOMINAL_VALUE ] return x - for col in [petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS]: - if col not in m or m[col].isna().all(): + for col in [petabv2.C.OBSERVABLE_PARAMETERS, petabv2.C.NOISE_PARAMETERS]: + if col not in m or m[col].isna().all() or all(m[col] == ''): mat_numeric = jnp.ones((len(m), n_pars[col])) par_mask = np.zeros_like(mat_numeric, dtype=bool) par_index = np.zeros_like(mat_numeric, dtype=int) @@ -368,7 +407,7 @@ def get_parameter_override(x): par_mask = np.zeros_like(mat_numeric, dtype=bool) par_index = np.zeros_like(mat_numeric, dtype=int) else: - split_vals = m[col].str.split(petab.C.PARAMETER_SEPARATOR) + split_vals = m[col].str.split(petabv2.C.PARAMETER_SEPARATOR) list_vals = split_vals.apply( lambda x: [get_parameter_override(y) for y in x] if isinstance(x, list) @@ -413,15 +452,15 @@ def get_parameter_override(x): iys, # 3 iy_trafos, # 4 parameter_overrides_numeric_vals[ - petab.OBSERVABLE_PARAMETERS + petabv2.C.OBSERVABLE_PARAMETERS ], # 5 - parameter_overrides_mask[petab.OBSERVABLE_PARAMETERS], # 6 + parameter_overrides_mask[petabv2.C.OBSERVABLE_PARAMETERS], # 6 parameter_overrides_par_indices[ - petab.OBSERVABLE_PARAMETERS + petabv2.C.OBSERVABLE_PARAMETERS ], # 7 - parameter_overrides_numeric_vals[petab.NOISE_PARAMETERS], # 8 - parameter_overrides_mask[petab.NOISE_PARAMETERS], # 9 - parameter_overrides_par_indices[petab.NOISE_PARAMETERS], # 10 + 
parameter_overrides_numeric_vals[petabv2.C.NOISE_PARAMETERS], # 8 + parameter_overrides_mask[petabv2.C.NOISE_PARAMETERS], # 9 + parameter_overrides_par_indices[petabv2.C.NOISE_PARAMETERS], # 10 ) petab_indices[tuple(simulation_condition)] = tuple(index.tolist()) @@ -522,10 +561,16 @@ def pad_and_stack(output_index: int): ) def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]: - simulation_conditions = ( - self._petab_problem.get_simulation_conditions_from_measurement_df() - ) - return tuple(tuple(row) for _, row in simulation_conditions.iterrows()) + if isinstance(self._petab_problem, HybridV2Problem): + simulation_conditions = ( + get_simulation_conditions_v2(self._petab_problem) + ) + return tuple(tuple([row.conditionId]) for _, row in simulation_conditions.iterrows()) + else: + simulation_conditions = ( + self._petab_problem.get_simulation_conditions_from_measurement_df() + ) + return tuple(tuple(row) for _, row in simulation_conditions.iterrows()) def _initialize_model_parameters(self, model: JAXModel) -> dict: """ @@ -657,11 +702,11 @@ def _extract_nominal_values_from_petab( scalar = True # Determine value source (scalar from PEtab or array from file) - if np.isnan(row[petab.NOMINAL_VALUE]): + if np.isnan(row[petabv2.C.NOMINAL_VALUE]): value = par_arrays[net] scalar = False else: - value = float(row[petab.NOMINAL_VALUE]) + value = float(row[petabv2.C.NOMINAL_VALUE]) # Parse parameter name and set values to_set = self._parse_parameter_name(pname, model_pars) @@ -753,14 +798,14 @@ def _create_scaled_parameter_array(self) -> jt.Float[jt.Array, "np"]: """ return jnp.array( [ - petab.scale( + petabv2.scale( float( self._petab_problem.parameter_df.loc[ - pval, petab.NOMINAL_VALUE + pval, petabv2.C.NOMINAL_VALUE ] ), self._petab_problem.parameter_df.loc[ - pval, petab.PARAMETER_SCALE + pval, petabv2.PARAMETER_SCALE ], ) for pval in self.parameter_ids @@ -802,7 +847,17 @@ def _initialize_model_with_nominal_values( model = 
self._set_input_arrays(model, nn_input_arrays, model_pars) # Create scaled parameter array - parameter_array = self._create_scaled_parameter_array() + if isinstance(self._petab_problem, HybridV2Problem): + param_map = { + p.id: p.nominal_value + for p in self._petab_problem.parameters + } + parameter_array = jnp.array([ + float(param_map[pval]) + for pval in self.parameter_ids + ]) + else: + parameter_array = self._create_scaled_parameter_array() return parameter_array, model @@ -826,7 +881,7 @@ def _get_inputs(self) -> dict: .max(axis=0) + 1 ) - inputs[row["netId"]][row[petab.MODEL_ENTITY_ID]] = data_flat[ + inputs[row["netId"]][row[petabv2.C.MODEL_ENTITY_ID]] = data_flat[ "value" ].values.reshape(shape) return inputs @@ -839,14 +894,7 @@ def parameter_ids(self) -> list[str]: :return: PEtab parameter ids """ - return self._petab_problem.parameter_df[ - self._petab_problem.parameter_df[petab.ESTIMATE] - == 1 - & pd.to_numeric( - self._petab_problem.parameter_df[petab.NOMINAL_VALUE], - errors="coerce", - ).notna() - ].index.tolist() + return self._petab_problem.parameter_df[petabv2.C.ESTIMATE].index.tolist() @property def nn_output_ids(self) -> list[str]: @@ -858,8 +906,10 @@ def nn_output_ids(self) -> list[str]: """ if self._petab_problem.mapping_df is None: return [] + if self._petab_problem.mapping_df[petabv2.C.MODEL_ENTITY_ID].isnull().all(): + return [] return self._petab_problem.mapping_df[ - self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + self._petab_problem.mapping_df[petabv2.C.MODEL_ENTITY_ID] .str.split(".") .str[1] .str.startswith("output") @@ -895,7 +945,7 @@ def _unscale( def _eval_nn(self, output_par: str, condition_id: str): net_id = self._petab_problem.mapping_df.loc[ - output_par, petab.MODEL_ENTITY_ID + output_par, petabv2.C.MODEL_ENTITY_ID ].split(".")[0] nn = self.model.nns[net_id] @@ -905,12 +955,12 @@ def _is_net_input(model_id): model_id_map = ( self._petab_problem.mapping_df[ - 
self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID].apply( + self._petab_problem.mapping_df[petabv2.C.MODEL_ENTITY_ID].apply( _is_net_input ) ] .reset_index() - .set_index(petab.MODEL_ENTITY_ID)[petab.PETAB_ENTITY_ID] + .set_index(petabv2.C.MODEL_ENTITY_ID)[petabv2.C.PETAB_ENTITY_ID] .to_dict() ) @@ -923,7 +973,7 @@ def _is_net_input(model_id): self._petab_problem.condition_df.loc[ condition_id, petab_id ], - petab.NOMINAL_VALUE, + petabv2.C.NOMINAL_VALUE, ], ) if self._petab_problem.condition_df.loc[ @@ -982,11 +1032,11 @@ def _is_net_input(model_id): else self.get_petab_parameter_by_id(petab_id) if petab_id in self.parameter_ids else self._petab_problem.parameter_df.loc[ - petab_id, petab.NOMINAL_VALUE + petab_id, petabv2.C.NOMINAL_VALUE ] if petab_id in set(self._petab_problem.parameter_df.index) else self._petab_problem.parameter_df.loc[ - hybridization_parameter_map[petab_id], petab.NOMINAL_VALUE + hybridization_parameter_map[petab_id], petabv2.C.NOMINAL_VALUE ] for model_id, petab_id in model_id_map.items() ] @@ -1004,7 +1054,7 @@ def _map_model_parameter_value( nn_output = self._eval_nn(pval, condition_id) if nn_output.size > 1: entityId = self._petab_problem.mapping_df.loc[ - pval, petab.MODEL_ENTITY_ID + pval, petabv2.C.MODEL_ENTITY_ID ] ind = int(re.search(r"\[\d+\]\[(\d+)\]", entityId).group(1)) return nn_output[ind] @@ -1015,36 +1065,97 @@ def _map_model_parameter_value( return self.get_petab_parameter_by_id(pval) def load_model_parameters( - self, simulation_condition: str + self, experiment: petabv2.Experiment, is_preeq: bool ) -> jt.Float[jt.Array, "np"]: """ - Load parameters for a simulation condition. + Load parameters for an experiment. - :param simulation_condition: - Simulation condition to load parameters for. + :param experiment: + Experiment to load parameters for. + :param is_preeq: + Whether to load preequilibration or simulation parameters. :return: - Parameters for the simulation condition. + Parameters for the experiment. 
""" - mapping = self._parameter_mappings[simulation_condition] - p = jnp.array( [ - self._map_model_parameter_value( - mapping, pname, simulation_condition + self._map_experiment_model_parameter_value( + pname, ind, experiment, is_preeq ) - for pname in self.model.parameter_ids + for ind, pname in enumerate(self.model.parameter_ids) ] ) pscale = tuple( [ - petab.LIN - if self._petab_problem.mapping_df is not None - and pname in self._petab_problem.mapping_df.index - else mapping.scale_map_sim_var[pname] - for pname in self.model.parameter_ids + petabv2.C.LIN + for _ in self.model.parameter_ids ] ) + return self._unscale(p, pscale) + + def _map_experiment_model_parameter_value( + self, pname: str, p_index: int, experiment: petabv2.Experiment, is_preeq: bool + ): + """ + Get values for the given parameter `pname` from the relevant petab tables. + + :param pname: PEtab parameter id + :param p_index: Index of the parameter in the model's parameter list + :param experiment: PEtab experiment + :param is_preeq: Whether to get preequilibration or simulation parameter value + :return: Value of the parameter + """ + condition_ids = [] + for p in experiment.sorted_periods: + if is_preeq: + if not p.is_preequilibration: + continue + else: + condition_ids = p.condition_ids + break + else: + if p.is_preequilibration: + continue + else: + condition_ids = p.condition_ids + break + + init_val = self.model.parameters[p_index] + params_nominals = {p.id: p.nominal_value for p in self._petab_problem.parameters} + targets_map = { + ch.target_id: ch.target_value + for c in self._petab_problem.conditions + for ch in c.changes + if c.id in condition_ids + } + if pname in params_nominals: + return params_nominals[pname] + elif pname in targets_map: + return float(targets_map[pname]) + else: + for placeholder_attr, param_attr in ( + ("observable_placeholders", "observable_parameters"), + ("noise_placeholders", "noise_parameters"), + ): + placeholders = [getattr(o, placeholder_attr) for o in 
self._petab_problem.observables] + + for placeholders in placeholders: + params_list = getattr(self._petab_problem.measurements[0], param_attr) + for i, p in enumerate(placeholders): + if str(p) == pname: + val = self._find_val(str(params_list[i]), params_nominals) + return val + return init_val + + def _find_val(self, param_entry: str, params_nominals: dict): + val_float = _try_float(param_entry) + if isinstance(val_float, float): + return val_float + elif param_entry in params_nominals: + return params_nominals[param_entry] + else: + return param_entry def _state_needs_reinitialisation( self, @@ -1114,12 +1225,12 @@ def _state_reinitialisation_value( return jax_unscale( self.get_petab_parameter_by_id(xval), self._petab_problem.parameter_df.loc[ - xval, petab.PARAMETER_SCALE + xval, petabv2.PARAMETER_SCALE ], ) # only remaining option is nominal value for PEtab parameter # that is not estimated, return nominal value - return self._petab_problem.parameter_df.loc[xval, petab.NOMINAL_VALUE] + return self._petab_problem.parameter_df.loc[xval, petabv2.C.NOMINAL_VALUE] def load_reinitialisation( self, @@ -1169,9 +1280,11 @@ def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem": """ return eqx.tree_at(lambda p: p.parameters, self, p) - def _prepare_conditions( + def _prepare_experiments( self, + experiments: list[petabv2.Experiment], conditions: list[str], + is_preeq: bool, op_numeric: np.ndarray | None = None, op_mask: np.ndarray | None = None, op_indices: np.ndarray | None = None, @@ -1186,10 +1299,14 @@ def _prepare_conditions( jt.Float[jt.Array, "nc nt nnp"], # noqa: F821, F722 ]: """ - Prepare conditions for simulation. + Prepare experiments for simulation. + :param experiments: + Experiments to prepare simulation arrays for. :param conditions: Simulation conditions to prepare. + :param is_preeq: + Whether to load preequilibration or simulation parameters. :param op_numeric: Numeric values for observable parameter overrides. 
If None, no overrides are used. :param op_mask: @@ -1207,24 +1324,49 @@ def _prepare_conditions( noise parameters. """ p_array = jnp.stack( - [self.load_model_parameters(sc) for sc in conditions] + [self.load_model_parameters(exp, is_preeq) for exp in experiments] ) + exp_ids = [exp.id for exp in experiments] + all_exp_ids = [exp.id for exp in self._petab_problem.experiments] + + h_mask = jnp.stack( + [ + jnp.ones(self.model.n_events) + if (exp_id in exp_ids) + else jnp.zeros(self.model.n_events) + for exp_id in all_exp_ids + ] + ) + + t_zeros = jnp.stack([ + exp.periods[0].time if exp.periods[0].time >= 0.0 else 0.0 for exp in experiments + ]) + if self.parameters.size: - unscaled_parameters = jnp.stack( - [ - jax_unscale( - self.parameters[ip], - self._petab_problem.parameter_df.loc[ - p_id, petab.PARAMETER_SCALE - ], - ) - for ip, p_id in enumerate(self.parameter_ids) - ] - ) + if isinstance(self._petab_problem, HybridV2Problem): + unscaled_parameters = jnp.stack( + [ + self.parameters[ip] + for ip, p_id in enumerate(self.parameter_ids) + ] + ) + else: + unscaled_parameters = jnp.stack( + [ + jax_unscale( + self.parameters[ip], + self._petab_problem.parameter_df.loc[ + p_id, petabv2.C.PARAMETER_SCALE + ], + ) + for ip, p_id in enumerate(self.parameter_ids) + ] + ) else: unscaled_parameters = jnp.zeros((*self._ts_masks.shape[:2], 0)) + # placeholder values from sundials code may be needed here if op_numeric is not None and op_numeric.size: op_array = jnp.where( op_mask, @@ -1259,7 +1401,7 @@ def _prepare_conditions( for sc, p in zip(conditions, p_array) ] ) - return p_array, mask_reinit_array, x_reinit_array, op_array, np_array + return p_array, mask_reinit_array, x_reinit_array, op_array, np_array, h_mask, t_zeros @eqx.filter_vmap( in_axes={ @@ -1281,6 +1423,7 @@ def run_simulation( x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 init_override: jt.Float[jt.Array, "nx"], # noqa: F821, F722 init_override_mask: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 + 
h_mask: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -1289,14 +1432,16 @@ def run_simulation( ], max_steps: jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722 + h_preeq: jt.Bool[jt.Array, "*ne"] = jnp.array([]), # noqa: F821, F722 ts_mask: np.ndarray = np.array([]), + t_zeros: jnp.float_ = 0.0, ret: ReturnValue = ReturnValue.llh, ) -> tuple[jnp.float_, dict]: """ - Run a simulation for a given simulation condition. + Run a simulation for a given simulation experiment. :param p: - Parameters for the simulation condition + Parameters for the simulation experiment :param ts_dyn: (Padded) dynamic time points :param ts_posteq: @@ -1315,6 +1460,8 @@ def run_simulation( Mask for states that need reinitialisation :param x_reinit: Reinitialisation values for states + :param h_mask: + Mask for the events that are part of the current experiment :param solver: ODE solver to use for simulation :param controller: @@ -1327,8 +1474,12 @@ def run_simulation( :param x_preeq: Pre-equilibration state. Can be empty if no pre-equilibration is available, in which case the states will be initialised to the model default values. + :param h_preeq: + Pre-equilibration event mask. Can be empty if no pre-equilibration is available :param ts_mask: padding mask, see :meth:`JAXModel.simulate_condition` for details. + :param t_zeros: + simulation start time for the current experiment. :param ret: which output to return. See :class:`ReturnValue` for available options. 
:return: @@ -1344,11 +1495,14 @@ def run_simulation( nps=nps, ops=ops, x_preeq=x_preeq, + h_preeq=h_preeq, mask_reinit=jax.lax.stop_gradient(mask_reinit), x_reinit=x_reinit, init_override=init_override, init_override_mask=jax.lax.stop_gradient(init_override_mask), ts_mask=jax.lax.stop_gradient(jnp.array(ts_mask)), + h_mask=jax.lax.stop_gradient(jnp.array(h_mask)), + t_zero=t_zeros, solver=solver, controller=controller, root_finder=root_finder, @@ -1362,8 +1516,9 @@ def run_simulation( def run_simulations( self, - simulation_conditions: list[str], + experiments: list[petabv2.Experiment], preeq_array: jt.Float[jt.Array, "ncond *nx"], # noqa: F821, F722 + h_preeqs: jt.Bool[jt.Array, "ncond *ne"], # noqa: F821 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -1374,13 +1529,16 @@ def run_simulations( ret: ReturnValue = ReturnValue.llh, ): """ - Run simulations for a list of simulation conditions. + Run simulations for a list of simulation experiments. - :param simulation_conditions: - List of simulation conditions to run simulations for. + :param experiments: + Experiments to run simulations for. :param preeq_array: Matrix of pre-equilibrated states for the simulation conditions. Ordering must match the simulation conditions. If no pre-equilibration is available for a condition, the corresponding row must be empty. + :param h_preeqs: + Matrix of pre-equilibration event heaviside variables indicating whether an event condition is false or + true after preequilibration. :param solver: ODE solver to use for simulation. :param controller: @@ -1396,9 +1554,15 @@ def run_simulations( Output value and condition specific results and statistics. Results and statistics are returned as a dict with arrays with the leading dimension corresponding to the simulation conditions. 
""" - p_array, mask_reinit_array, x_reinit_array, op_array, np_array = ( - self._prepare_conditions( - simulation_conditions, + simulation_conditions = [cid for exp in experiments for p in exp.periods for cid in p.condition_ids] + dynamic_conditions = list(sc for sc in simulation_conditions if "preequilibration" not in sc) + dynamic_conditions = list(dict.fromkeys(dynamic_conditions)) + + p_array, mask_reinit_array, x_reinit_array, op_array, np_array, h_mask, t_zeros = ( + self._prepare_experiments( + experiments, + dynamic_conditions, + False, self._op_numeric, self._op_mask, self._op_indices, @@ -1413,27 +1577,25 @@ def run_simulations( jnp.array( [ p - in set(self._parameter_mappings[sc].map_sim_var.keys()) + in set(self.model.parameter_ids) for p in self.model.state_ids ] ) - for sc in simulation_conditions + for _ in experiments ] ) init_override = jnp.stack( [ jnp.array( [ - self._eval_nn( - self._parameter_mappings[sc].map_sim_var[p], sc - ) + self._eval_nn(p, exp.periods[-1].condition_ids[0]) # TODO: Add mapping of p to eval_nn? 
if p - in set(self._parameter_mappings[sc].map_sim_var.keys()) + in set(self.model.parameter_ids) else 1.0 for p in self.model.state_ids ] ) - for sc in simulation_conditions + for exp in experiments ] ) @@ -1450,13 +1612,16 @@ def run_simulations( x_reinit_array, init_override, init_override_mask, + h_mask, solver, controller, root_finder, steady_state_event, max_steps, preeq_array, + h_preeqs, self._ts_masks, + t_zeros, ret, ) @@ -1471,6 +1636,7 @@ def run_preequilibration( p: jt.Float[jt.Array, "np"], # noqa: F821, F722 mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 + h_mask: jt.Bool[jt.Array, "ne"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -1480,14 +1646,16 @@ def run_preequilibration( max_steps: jnp.int_, ) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821 """ - Run a pre-equilibration simulation for a given simulation condition. + Run a pre-equilibration simulation for a given simulation experiment. 
:param p: - Parameters for the simulation condition + Parameters for the simulation experiment :param mask_reinit: Mask for states that need reinitialisation :param x_reinit: Reinitialisation values for states + :param h_mask: + Mask for the events that are part of the current experiment :param solver: ODE solver to use for simulation :param controller: @@ -1504,6 +1672,7 @@ def run_preequilibration( p=p, mask_reinit=mask_reinit, x_reinit=x_reinit, + h_mask=h_mask, solver=solver, controller=controller, root_finder=root_finder, @@ -1513,7 +1682,7 @@ def run_preequilibration( def run_preequilibrations( self, - simulation_conditions: list[str], + experiments: list[petabv2.Experiment], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -1522,13 +1691,19 @@ def run_preequilibrations( ], max_steps: jnp.int_, ): - p_array, mask_reinit_array, x_reinit_array, _, _ = ( - self._prepare_conditions(simulation_conditions, None, None) + simulation_conditions = [cid for exp in experiments for p in exp.periods for cid in p.condition_ids] + preequilibration_conditions = list( + {sc for sc in simulation_conditions if "preequilibration" in sc} + ) + + p_array, mask_reinit_array, x_reinit_array, _, _, h_mask, _ = ( + self._prepare_experiments(experiments, preequilibration_conditions, True, None, None) ) return self.run_preequilibration( p_array, mask_reinit_array, x_reinit_array, + h_mask, solver, controller, root_finder, @@ -1539,7 +1714,7 @@ def run_preequilibrations( def run_simulations( problem: JAXProblem, - simulation_conditions: Iterable[tuple[str, ...]] | None = None, + simulation_experiments: Iterable[str] | None = None, solver: diffrax.AbstractSolver = diffrax.Kvaerno5(), controller: diffrax.AbstractStepSizeController = diffrax.PIDController( **DEFAULT_CONTROLLER_SETTINGS @@ -1558,10 +1733,9 @@ def run_simulations( :param problem: Problem to run simulations for. 
- :param simulation_conditions: - Simulation conditions to run simulations for. This is a series of tuples, where each tuple contains the - simulation condition or the pre-equilibration condition followed by the simulation condition. Default is to run - simulations for all conditions. + :param simulation_experiments: + Simulation experiments to run simulations for. This is an iterable of experiment ids. + Default is to run simulations for all experiments. :param solver: ODE solver to use for simulation. :param controller: @@ -1578,67 +1752,73 @@ def run_simulations( :return: Overall output value and condition specific results and statistics. """ + if isinstance(problem, HybridProblem) or isinstance(problem._petab_problem, petabv1.Problem): + raise TypeError( + "run_simulations does not support PEtab v1 problems. Upgrade the problem to PEtab v2." + ) + if isinstance(ret, str): ret = ReturnValue[ret] - if simulation_conditions is None: - simulation_conditions = problem.get_all_simulation_conditions() - - dynamic_conditions = [sc[0] for sc in simulation_conditions] - preequilibration_conditions = list( - {sc[1] for sc in simulation_conditions if len(sc) > 1} - ) + if simulation_experiments is None: + experiments = problem._petab_problem.experiments + else: + experiments = [exp for exp in problem._petab_problem.experiments if exp.id in simulation_experiments] + simulation_conditions = [cid for exp in experiments for p in exp.periods for cid in p.condition_ids] + dynamic_conditions = list(sc for sc in simulation_conditions if "preequilibration" not in sc) + dynamic_conditions = list(dict.fromkeys(dynamic_conditions)) conditions = { "dynamic_conditions": dynamic_conditions, - "preequilibration_conditions": preequilibration_conditions, - "simulation_conditions": simulation_conditions, } - if preequilibration_conditions: - preeqs, preresults = problem.run_preequilibrations( - preequilibration_conditions, + has_preeq = any(exp.periods[0].time < 0.0 for exp in experiments) + 
+ if has_preeq: + preeqs, preresults, h_preeqs = problem.run_preequilibrations( + experiments, solver, controller, root_finder, steady_state_event, max_steps, ) + preeqs_array = preeqs else: preresults = { "stats_preeq": None, } - - if dynamic_conditions: - preeq_array = jnp.stack( + preeqs_array = jnp.stack( [ - preeqs[preequilibration_conditions.index(sc[1]), :] - if len(sc) > 1 - else jnp.array([]) - for sc in simulation_conditions + jnp.array([]) + for _ in experiments ] ) - output, results = problem.run_simulations( - dynamic_conditions, - preeq_array, - solver, - controller, - root_finder, - steady_state_event, - max_steps, - ret, + h_preeqs = jnp.stack( + [ + jnp.array([]) + for _ in experiments + ] ) - else: - output = jnp.array(0.0) - results = { - "llh": jnp.array([]), - "stats_dyn": None, - "stats_posteq": None, - "ts": jnp.array([]), - "x": jnp.array([]), - } + + output, results = problem.run_simulations( + experiments, + preeqs_array, + h_preeqs, + solver, + controller, + root_finder, + steady_state_event, + max_steps, + ret, + ) if ret in (ReturnValue.llh, ReturnValue.chi2): + if os.getenv("JAX_DEBUG") == "1": + jax.debug.print( + "ret: {}", + ret, + ) output = jnp.sum(output) return output, results | preresults | conditions @@ -1680,50 +1860,55 @@ def petab_simulate( max_steps=max_steps, ret=ReturnValue.y, ) - dfs = [] - for ic, sc in enumerate(r["dynamic_conditions"]): - obs = [ - problem.model.observable_ids[io] - for io in problem._iys[ic, problem._ts_masks[ic, :]] - ] - t = jnp.concat( - ( - problem._ts_dyn[ic, :], - problem._ts_posteq[ic, :], - ) - ) - df_sc = pd.DataFrame( - { - petab.SIMULATION: y[ic, problem._ts_masks[ic, :]], - petab.TIME: t[problem._ts_masks[ic, :]], - petab.OBSERVABLE_ID: obs, - petab.SIMULATION_CONDITION_ID: [sc] * len(t), - }, - index=problem._petab_measurement_indices[ic, :], - ) - if ( - petab.OBSERVABLE_PARAMETERS - in problem._petab_problem.measurement_df - ): - df_sc[petab.OBSERVABLE_PARAMETERS] = ( - 
problem._petab_problem.measurement_df.query( - f"{petab.SIMULATION_CONDITION_ID} == '{sc}'" - )[petab.OBSERVABLE_PARAMETERS] - ) - if petab.NOISE_PARAMETERS in problem._petab_problem.measurement_df: - df_sc[petab.NOISE_PARAMETERS] = ( - problem._petab_problem.measurement_df.query( - f"{petab.SIMULATION_CONDITION_ID} == '{sc}'" - )[petab.NOISE_PARAMETERS] + if isinstance(problem._petab_problem, HybridV2Problem): + return _build_simulation_df_v2(problem, y, r["dynamic_conditions"]) + else: + dfs = [] + for ic, sc in enumerate(r["dynamic_conditions"]): + obs = [ + problem.model.observable_ids[io] + for io in problem._iys[ic, problem._ts_masks[ic, :]] + ] + t = jnp.concat( + ( + problem._ts_dyn[ic, :], + problem._ts_posteq[ic, :], + ) ) - if ( - petab.PREEQUILIBRATION_CONDITION_ID - in problem._petab_problem.measurement_df - ): - df_sc[petab.PREEQUILIBRATION_CONDITION_ID] = ( - problem._petab_problem.measurement_df.query( - f"{petab.SIMULATION_CONDITION_ID} == '{sc}'" - )[petab.PREEQUILIBRATION_CONDITION_ID] + df_sc = pd.DataFrame( + { + petabv2.C.SIMULATION: y[ic, problem._ts_masks[ic, :]], + petabv2.C.TIME: t[problem._ts_masks[ic, :]], + petabv2.C.OBSERVABLE_ID: obs, + petabv2.C.CONDITION_ID: [sc] * len(t), + }, + index=problem._petab_measurement_indices[ic, :], ) - dfs.append(df_sc) - return pd.concat(dfs).sort_index() + if ( + petabv2.C.OBSERVABLE_PARAMETERS + in problem._petab_problem.measurement_df + ): + df_sc[petabv2.C.OBSERVABLE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petabv2.C.CONDITION_ID} == '{sc}'" + )[petabv2.C.OBSERVABLE_PARAMETERS] + ) + if petabv2.C.NOISE_PARAMETERS in problem._petab_problem.measurement_df: + df_sc[petabv2.C.NOISE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petabv2.C.CONDITION_ID} == '{sc}'" + )[petabv2.C.NOISE_PARAMETERS] + ) + if ( + petabv2.C.PREEQUILIBRATION_CONDITION_ID + in problem._petab_problem.measurement_df + ): + df_sc[petabv2.C.PREEQUILIBRATION_CONDITION_ID] = ( + 
problem._petab_problem.measurement_df.query( + f"{petabv2.C.CONDITION_ID} == '{sc}'" + )[petabv2.C.PREEQUILIBRATION_CONDITION_ID] + ) + dfs.append(df_sc) + return pd.concat(dfs).sort_index() + + diff --git a/python/sdist/amici/sim/jax/__init__.py b/python/sdist/amici/sim/jax/__init__.py index 07744420df..f17cda0e31 100644 --- a/python/sdist/amici/sim/jax/__init__.py +++ b/python/sdist/amici/sim/jax/__init__.py @@ -1 +1,125 @@ """Functionality for simulating JAX-based AMICI models.""" + +import petab.v2 as petabv2 + +import pandas as pd +import jax.numpy as jnp + +def add_default_experiment_names_to_v2_problem(petab_problem: petabv2.Problem): + """Add default experiment names to PEtab v2 problem. + + Args: + petab_problem: PEtab v2 problem to modify. + """ + if not hasattr(petab_problem, "extensions_config"): + petab_problem.extensions_config = {} + + petab_problem.visualization_df = None + + if petab_problem.condition_df is None: + default_condition = petabv2.core.Condition(id="__default__", changes=[], conditionId="__default__") + petab_problem.condition_tables[0].elements = [default_condition] + + if petab_problem.experiment_df is None or petab_problem.experiment_df.empty: + condition_ids = petab_problem.condition_df[petabv2.C.CONDITION_ID].values + condition_ids = [c for c in condition_ids if "preequilibration" not in c] + default_experiment = petabv2.core.Experiment( + id="__default__", + periods=[ + petabv2.core.ExperimentPeriod( + time=0.0, + condition_ids=condition_ids + ) + ], + ) + petab_problem.experiment_tables[0].elements = [default_experiment] + + measurement_tables = petab_problem.measurement_tables.copy() + for mt in measurement_tables: + for m in mt.elements: + m.experiment_id = "__default__" + + petab_problem.measurement_tables = measurement_tables + + return petab_problem + +def get_simulation_conditions_v2(petab_problem) -> pd.DataFrame: + """Get simulation conditions from PEtab v2 measurement DataFrame. 
+ + Returns: + A pandas DataFrame mapping experiment_ids to condition ids. + """ + experiment_df = petab_problem.experiment_df + exps = {} + for exp_id in experiment_df[petabv2.C.EXPERIMENT_ID].unique(): + exps[exp_id] = experiment_df[ + experiment_df[petabv2.C.EXPERIMENT_ID] == exp_id + ][petabv2.C.CONDITION_ID].unique() + + experiment_df = experiment_df.drop(columns=[petabv2.C.TIME]) + return experiment_df + +def _build_simulation_df_v2(problem, y, dyn_conditions): + """Build petab simulation DataFrame of similation results from a PEtab v2 problem.""" + dfs = [] + for ic, sc in enumerate(dyn_conditions): + experiment_id = _conditions_to_experiment_map( + problem._petab_problem.experiment_df + )[sc] + + if experiment_id == "__default__": + experiment_id = jnp.nan + + obs = [ + problem.model.observable_ids[io] + for io in problem._iys[ic, problem._ts_masks[ic, :]] + ] + t = jnp.concat( + ( + problem._ts_dyn[ic, :], + problem._ts_posteq[ic, :], + ) + ) + df_sc = pd.DataFrame( + { + petabv2.C.MODEL_ID: [float("nan")] * len(t), + petabv2.C.OBSERVABLE_ID: obs, + petabv2.C.EXPERIMENT_ID: [experiment_id] * len(t), + petabv2.C.TIME: t[problem._ts_masks[ic, :]], + petabv2.C.SIMULATION: y[ic, problem._ts_masks[ic, :]], + }, + index=problem._petab_measurement_indices[ic, :], + ) + if ( + petabv2.C.OBSERVABLE_PARAMETERS + in problem._petab_problem.measurement_df + ): + df_sc[petabv2.C.OBSERVABLE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'" + )[petabv2.C.OBSERVABLE_PARAMETERS] + ) + if petabv2.C.NOISE_PARAMETERS in problem._petab_problem.measurement_df: + df_sc[petabv2.C.NOISE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'" + )[petabv2.C.NOISE_PARAMETERS] + ) + dfs.append(df_sc) + return pd.concat(dfs).sort_index() + +def _conditions_to_experiment_map(experiment_df: pd.DataFrame) -> dict[str, str]: + condition_to_experiment = { + 
row.conditionId: row.experimentId + for row in experiment_df.itertuples() + } + return condition_to_experiment + +def _try_float(value): + try: + return float(value) + except Exception as e: + msg = str(e).lower() + if isinstance(e, ValueError) and "could not convert" in msg: + return value + raise \ No newline at end of file diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 838a9f8144..ee80af9f67 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -284,21 +284,27 @@ def check_fields_jax( def test_preequilibration_failure(lotka_volterra): # noqa: F811 petab_problem = lotka_volterra # oscillating system, preequilibation should fail when interaction is active - with TemporaryDirectoryWinSafe(prefix="normal") as model_dir: - jax_problem = import_petab_problem( - petab_problem, jax=True, output_dir=model_dir - ) - r = run_simulations(jax_problem) - assert not np.isinf(r[0].item()) - petab_problem.measurement_df[PREEQUILIBRATION_CONDITION_ID] = ( - petab_problem.measurement_df[SIMULATION_CONDITION_ID] - ) - with TemporaryDirectoryWinSafe(prefix="failure") as model_dir: - jax_problem = import_petab_problem( - petab_problem, jax=True, output_dir=model_dir + + try: + with TemporaryDirectoryWinSafe(prefix="normal") as model_dir: + jax_problem = import_petab_problem( + petab_problem, jax=True, output_dir=model_dir + ) + r = run_simulations(jax_problem) + assert not np.isinf(r[0].item()) + petab_problem.measurement_df[PREEQUILIBRATION_CONDITION_ID] = ( + petab_problem.measurement_df[SIMULATION_CONDITION_ID] ) - r = run_simulations(jax_problem) - assert np.isinf(r[0].item()) + with TemporaryDirectoryWinSafe(prefix="failure") as model_dir: + jax_problem = import_petab_problem( + petab_problem, jax=True, output_dir=model_dir + ) + r = run_simulations(jax_problem) + assert np.isinf(r[0].item()) + except (TypeError, NotImplementedError) as err: + if "JAXProblem does not support PEtab v1 problems" in str(err): + pytest.skip(str(err)) + raise 
err @skip_on_valgrind @@ -307,23 +313,28 @@ def test_serialisation(lotka_volterra): # noqa: F811 with TemporaryDirectoryWinSafe( prefix=petab_problem.model.model_id ) as model_dir: - jax_problem = import_petab_problem( - petab_problem, jax=True, output_dir=model_dir - ) - # change parameters to random values to test serialisation - jax_problem.update_parameters( - jax_problem.parameters - + jr.normal(jr.PRNGKey(0), jax_problem.parameters.shape) - ) - - with TemporaryDirectoryWinSafe() as outdir: - outdir = Path(outdir) - jax_problem.save(outdir) - jax_problem_loaded = JAXProblem.load(outdir) - assert_allclose( - jax_problem.parameters, jax_problem_loaded.parameters + try: + jax_problem = import_petab_problem( + petab_problem, jax=True, output_dir=model_dir + ) + # change parameters to random values to test serialisation + jax_problem.update_parameters( + jax_problem.parameters + + jr.normal(jr.PRNGKey(0), jax_problem.parameters.shape) ) + with TemporaryDirectoryWinSafe() as outdir: + outdir = Path(outdir) + jax_problem.save(outdir) + jax_problem_loaded = JAXProblem.load(outdir) + assert_allclose( + jax_problem.parameters, jax_problem_loaded.parameters + ) + except (TypeError, NotImplementedError) as err: + if "JAXProblem does not support PEtab v1 problems" in str(err): + pytest.skip(str(err)) + raise err + @skip_on_valgrind def test_time_dependent_discontinuity(tmp_path): @@ -362,10 +373,12 @@ def test_time_dependent_discontinuity(tmp_path): ys, _, _ = solve( p, + ts[0], ts, tcl, h, x0, + jnp.ones_like(h), diffrax.Tsit5(), diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS), optimistix.Newton(atol=1e-8, rtol=1e-8), @@ -376,6 +389,7 @@ def test_time_dependent_discontinuity(tmp_path): model._root_cond_fn, model._delta_x, model._known_discs(p), + model.observable_ids, ) assert ys.shape[0] == ts.shape[0] @@ -424,6 +438,7 @@ def test_time_dependent_discontinuity_equilibration(tmp_path): tcl, h, x0, + jnp.ones_like(h), diffrax.Tsit5(), 
diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS), optimistix.Newton(atol=1e-8, rtol=1e-8), @@ -438,7 +453,9 @@ def test_time_dependent_discontinuity_equilibration(tmp_path): assert_allclose(xs[0], 0.0, atol=1e-2) - except NotImplementedError as err: + except (TypeError, NotImplementedError) as err: if "The JAX backend does not support" in str(err): pytest.skip(str(err)) + elif "JAXProblem does not support PEtab v1 problems" in str(err): + pytest.skip(str(err)) raise err diff --git a/scripts/installAmiciSource.sh b/scripts/installAmiciSource.sh index 4c416a5788..436ac86f97 100755 --- a/scripts/installAmiciSource.sh +++ b/scripts/installAmiciSource.sh @@ -39,7 +39,7 @@ python -m pip install --upgrade pip wheel python -m pip install --upgrade pip setuptools cmake_build_extension==0.6.0 numpy petab swig python -m pip install git+https://github.com/pysb/pysb@master # for SPM with compartments python -m pip install git+https://github.com/patrick-kidger/diffrax@main # for events with direction -python -m pip install optax # for jax petab notebook +python -m pip install 'optax<0.2.7' # for jax petab notebook AMICI_BUILD_TEMP="${AMICI_PATH}/python/sdist/build/temp" \ python -m pip install --verbose -e "${AMICI_PATH}/python/sdist[petab,test,vis,jax]" --no-build-isolation deactivate diff --git a/tests/benchmark_models/test_petab_benchmark_jax.py b/tests/benchmark_models/test_petab_benchmark_jax.py index a3389498dd..066e90d94d 100644 --- a/tests/benchmark_models/test_petab_benchmark_jax.py +++ b/tests/benchmark_models/test_petab_benchmark_jax.py @@ -48,6 +48,11 @@ def test_jax_llh(benchmark_problem): f"Skipping {problem_id} due to non-supported events in JAX." ) + if problem_id == "Oliveira_NatCommun2021": + pytest.skip( + "Skipping Oliveira_NatCommun2021 due to non-supported events in JAX." 
+ ) + amici_solver = amici_model.create_solver() cur_settings = settings[problem_id] amici_solver.set_absolute_tolerance(1e-8) @@ -95,57 +100,64 @@ def test_jax_llh(benchmark_problem): r_amici = simulate_amici() llh_amici = r_amici[LLH] - jax_problem = import_petab_problem( - petab_problem, - output_dir=benchmark_outdir / (problem_id + "_jax"), - jax=True, - ) - if problem_parameters: - jax_problem = eqx.tree_at( - lambda x: x.parameters, - jax_problem, - jnp.array( - [problem_parameters[pid] for pid in jax_problem.parameter_ids] - ), + try: + jax_problem = import_petab_problem( + petab_problem, + output_dir=benchmark_outdir / (problem_id + "_jax"), + jax=True, ) + if problem_parameters: + jax_problem = eqx.tree_at( + lambda x: x.parameters, + jax_problem, + jnp.array( + [problem_parameters[pid] for pid in jax_problem.parameter_ids] + ), + ) - if problem_id in problems_for_gradient_check: - if problem_id == "Weber_BMC2015": - atol = cur_settings.atol_sim - rtol = cur_settings.rtol_sim - max_steps = 2 * 10**5 + if problem_id in problems_for_gradient_check: + if problem_id == "Weber_BMC2015": + atol = cur_settings.atol_sim + rtol = cur_settings.rtol_sim + max_steps = 2 * 10**5 + else: + atol = 1e-8 + rtol = 1e-8 + max_steps = 1024 + beartype(run_simulations)(jax_problem) + (llh_jax, _), sllh_jax = eqx.filter_value_and_grad( + run_simulations, has_aux=True + )( + jax_problem, + max_steps=max_steps, + controller=diffrax.PIDController( + atol=atol, + rtol=rtol, + ) + ) else: - atol = 1e-8 - rtol = 1e-8 - max_steps = 1024 - beartype(run_simulations)(jax_problem) - (llh_jax, _), sllh_jax = eqx.filter_value_and_grad( - run_simulations, has_aux=True - )( - jax_problem, - max_steps=max_steps, - controller=diffrax.PIDController( - atol=atol, - rtol=rtol, - ), - ) - else: - llh_jax, _ = beartype(run_simulations)(jax_problem) - - np.testing.assert_allclose( - llh_jax, - llh_amici, - rtol=1e-3, - atol=1e-3, - err_msg=f"LLH mismatch for {problem_id}", - ) + llh_jax, _ = 
beartype(run_simulations)(jax_problem) - if problem_id in problems_for_gradient_check: - sllh_amici = r_amici[SLLH] np.testing.assert_allclose( - sllh_jax.parameters, - np.array([sllh_amici[pid] for pid in jax_problem.parameter_ids]), - rtol=1e-2, - atol=1e-2, - err_msg=f"SLLH mismatch for {problem_id}, {dict(zip(jax_problem.parameter_ids, sllh_jax.parameters))}", + llh_jax, + llh_amici, + rtol=1e-3, + atol=1e-3, + err_msg=f"LLH mismatch for {problem_id}", ) + + if problem_id in problems_for_gradient_check: + sllh_amici = r_amici[SLLH] + np.testing.assert_allclose( + sllh_jax.parameters, + np.array([sllh_amici[pid] for pid in jax_problem.parameter_ids]), + rtol=1e-2, + atol=1e-2, + err_msg=f"SLLH mismatch for {problem_id}, {dict(zip(jax_problem.parameter_ids, sllh_jax.parameters))}", + ) + except (NotImplementedError, TypeError) as err: + if "JAXProblem does not support PEtab v1 problems" in str(err): + pytest.skip(str(err)) + elif "The JAX backend does not support" in str(err): + pytest.skip(str(err)) + raise err diff --git a/tests/petab_test_suite/test_petab_suite.py b/tests/petab_test_suite/test_petab_suite.py index 9271386e4a..581201713e 100755 --- a/tests/petab_test_suite/test_petab_suite.py +++ b/tests/petab_test_suite/test_petab_suite.py @@ -50,6 +50,8 @@ def test_case(case, model_type, version, jax): f"implemented: {e}" ) pytest.skip(str(e)) + elif "JAXProblem does not support PEtab v1" in str(e): + pytest.skip(str(e)) else: raise e diff --git a/tests/petab_test_suite/test_petab_v2_suite.py b/tests/petab_test_suite/test_petab_v2_suite.py index ba98eb3931..d43cec15c2 100755 --- a/tests/petab_test_suite/test_petab_v2_suite.py +++ b/tests/petab_test_suite/test_petab_v2_suite.py @@ -4,6 +4,7 @@ import logging import sys +import diffrax import pandas as pd import petabtests import pytest @@ -20,12 +21,15 @@ ) from amici.sim.sundials.petab import PetabSimulator from petab import v2 +import jax logger = get_logger(__name__, logging.DEBUG) 
set_log_level(get_logger("amici.petab_import"), logging.DEBUG) stream_handler = logging.StreamHandler() logger.addHandler(stream_handler) +jax.config.update("jax_enable_x64", True) + @pytest.mark.filterwarnings( "ignore:Event `_E0` has `useValuesFromTriggerTime=true'" @@ -66,28 +70,74 @@ def _test_case(case, model_type, version, jax): f"petab_{model_type}_test_case_{case}_{version.replace('.', '_')}" ) - pi = PetabImporter( - petab_problem=problem, - module_name=model_name, - compile_=True, - jax=jax, - ) - ps = pi.create_simulator( - force_import=True, - ) - ps.solver.set_steady_state_tolerance_factor(1.0) - - problem_parameters = problem.get_x_nominal_dict(free=True, fixed=True) - res = ps.simulate(problem_parameters=problem_parameters) - rdatas = res.rdatas - for rdata in rdatas: - assert rdata.status == AMICI_SUCCESS, ( - f"Simulation failed for {rdata.id}" + if jax: + from amici.jax import petab_simulate, run_simulations + from amici.jax.petab import DEFAULT_CONTROLLER_SETTINGS + + steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6) + + pi = PetabImporter( + petab_problem=problem, + module_name=model_name, + compile_=True, + jax=jax, + ) + + jax_problem = pi.create_simulator( + force_import=True, + ) + + if case.startswith("0016"): + controller = diffrax.PIDController( + **DEFAULT_CONTROLLER_SETTINGS, + dtmax=0.5 + ) + else: + controller = diffrax.PIDController( + **DEFAULT_CONTROLLER_SETTINGS + ) + + llh, _ = run_simulations( + jax_problem, + steady_state_event=steady_state_event, + controller=controller, + ) + chi2, _ = run_simulations( + jax_problem, + ret="chi2", + steady_state_event=steady_state_event, + controller=controller, + ) + simulation_df = petab_simulate( + jax_problem, + steady_state_event=steady_state_event, + controller=controller, + ) + else: + pi = PetabImporter( + petab_problem=problem, + module_name=model_name, + compile_=True, + jax=jax, ) - chi2 = sum(rdata.chi2 for rdata in rdatas) - llh = res.llh - simulation_df = 
rdatas_to_simulation_df(rdatas, ps.model, pi.petab_problem) + + ps = pi.create_simulator( + force_import=True, + ) + ps.solver.set_steady_state_tolerance_factor(1.0) + + problem_parameters = problem.get_x_nominal_dict(free=True, fixed=True) + res = ps.simulate(problem_parameters=problem_parameters) + + rdatas = res.rdatas + for rdata in rdatas: + assert rdata.status == AMICI_SUCCESS, ( + f"Simulation failed for {rdata.id}" + ) + chi2 = sum(rdata.chi2 for rdata in rdatas) + llh = res.llh + simulation_df = rdatas_to_simulation_df(rdatas, ps.model, pi.petab_problem) solution = petabtests.load_solution(case, model_type, version=version) gt_chi2 = solution[petabtests.CHI2] @@ -148,6 +198,8 @@ def _test_case(case, model_type, version, jax): else: if (case, model_type, version) in ( ("0016", "sbml", "v2.0.0"), + ("0024", "sbml", "v2.0.0"), + ("0025", "sbml", "v2.0.0"), ("0013", "pysb", "v2.0.0"), ): # FIXME: issue with events and sensitivities @@ -194,24 +246,24 @@ def run(): n_skipped = 0 n_total = 0 version = "v2.0.0" - jax = False - - cases = list(petabtests.get_cases("sbml", version=version)) - n_total += len(cases) - for case in cases: - try: - test_case(case, "sbml", version=version, jax=jax) - n_success += 1 - except Skipped: - n_skipped += 1 - except Exception as e: - # run all despite failures - logger.error(f"Case {case} failed.") - logger.exception(e) - - logger.info(f"{n_success} / {n_total} successful, {n_skipped} skipped") - if n_success != len(cases): - sys.exit(1) + + for jax in (False, True): + cases = list(petabtests.get_cases("sbml", version=version)) + n_total += len(cases) + for case in cases: + try: + test_case(case, "sbml", version=version, jax=jax) + n_success += 1 + except Skipped: + n_skipped += 1 + except Exception as e: + # run all despite failures + logger.error(f"Case {case} failed.") + logger.exception(e) + + logger.info(f"{n_success} / {n_total} successful, {n_skipped} skipped") + if n_success != len(cases): + sys.exit(1) if __name__ == 
"__main__":