In [1]:
from abc import ABC
from abc import abstractmethod
from functools import partial
from dataclasses import fields
from typing import Callable

import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, tree_leaves
import equinox as eqx
import diffrax
import numpy as np
import chex
jax.config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt

In [114]:
class CoreEnvironment(eqx.Module):
    tau: jax.Array #= eqx.field(static=True)
    _solver: diffrax.AbstractSolver #= eqx.field(static=True)
    env_properties: eqx.Module #= eqx.field(static=True)
    in_axes_env_properties: eqx.Module #= eqx.field(static=True)
    action_dim: int #= eqx.field(static=True) 
    physical_state_dim: int #= eqx.field(static=True)

    """
    Core Structure of provided Environments. Any new environments needs to inherit from this class
    and implement its abstract properties and methods.

    The simulations are all done with physical state space models. That means that the underlying description
    of the system is given through the differential equation describing the relationship between
    the change of the physical state x(t) w.r.t. the time as a function of the physical state and the
    input/action u(t) applied:

    dx(t)/dt = f(x(t), u(t)).

    The actual outputs of these simulations are discretized from this equation through the use of
    ODE solvers.

    NOTE: There is a difference between the state of the environment and the physical state x(t)
    of the underlying system. The former can also hold various helper variables such as PRNGKeys
    for stochastic environments, while the latter is reserved for the actual physical state of the
    ODE. The physical state is only a part of the full state.
    """

    def __init__(
        self,
        env_properties: eqx.Module,
        tau: float = 1e-4,
        solver=diffrax.Euler(),
    ):
        """Initialization of an environment.

        Args:
            batch_size (int): Number of parallel environment simulations.
            env_properties(eqx.Module): All parameters and properties of the environment.
            tau (float): Duration of one control step in seconds. Default: 1e-4.
            solver (diffrax.solver): ODE solver used to approximate the ODE solution.
        """
        self.tau = tau
        self._solver = solver
        self.env_properties = env_properties
        self.in_axes_env_properties = self.create_in_axes_dataclass(env_properties)
        self.action_dim = len(fields(self.Action))
        self.physical_state_dim = len(fields(self.PhysicalState))

    @abstractmethod
    class PhysicalState(eqx.Module):
        """The physical state x(t) of the underlying system and whose derivative
        w.r.t. time is described in the underlying ODE.

        The values stored in this dataclass are expected to be actual physical values
        that are unnormalized and given in SI units.
        """

        pass

    @abstractmethod
    class Additions(eqx.Module):
        """
        Stores additional environment state variables that may change over time.

        These variables do not directly belong to the physical state but are
        necessary for computations (e.g., internal buffers).
        """

        pass

    @abstractmethod
    class StaticParams(eqx.Module):
        """
        Holds static parameters of the environment that remain constant during simulation.

        Examples:
            - Length of a pendulum
            - Capacitance of a capacitor
            - Mass of an object
        """

        pass

    @abstractmethod
    class Action(eqx.Module):
        """
        Represents the input/action applied to the environment.

        The action influences the system dynamics through the function `f(x(t), u(t))`.
        """

        pass

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def _ode_solver_step(self, state, action):
        """
        Performs a single step of state evolution using the ODE solver.

        Args:
            state: The current state of the system.
            action: The action applied at the current step.
            static_params: Static parameters of the environment.

        Returns:
            state: The updated state after one simulation step.
        """
        return

    @partial(jax.jit, static_argnums=[0, 3, 4])
    @abstractmethod
    def _ode_solver_simulate_ahead(self, init_state, actions, obs_stepsize, action_stepsize):
        """
        Simulates a trajectory by applying a sequence of actions.

        Args:
            init_state: Initial state at the start of the trajectory.
            actions: Sequence of actions to be applied (shape=(n_action_steps, action_dim)).
            static_params: Static environment parameters.
            obs_stepsize: Sampling interval for observations.
            action_stepsize: Interval between consecutive action updates.

        Returns:
            states: Simulated trajectory states over time.
        """
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def init_state(self, rng: chex.PRNGKey = None, vmap_helper=None):
        """
        Generates an initial state for the environment.

        Args:
            env_properties: Environment properties.
            rng (optional): Random key for random initialization.
            vmap_helper (optional): Helper variable for vectorized computation.

        Returns:
            state: The initial state.
        """
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def generate_observation(self, state):
        """
        Generates an observation from the given state.

        Args:
            state: Current state of the environment.
            env_properties: Environment properties.

        Returns:
            observation: The computed observation.
        """
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def generate_state_from_observation(self, obs, key=None):
        """
        Generates state from a given observation.

        Args:
            obs: The given observation.
            env_properties: Environment properties.
            key (optional): Random key.

        Returns:
            state: Computed state.
        """
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def generate_reward(self, state, action):
        """
        Computes the reward for a given state-action pair.

        Args:
            state: The current environment state.
            action: The action applied at the current step.
            env_properties: Environment properties.

        Returns:
            reward: Computed reward value.
        """
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def generate_truncated(self, state):
        """
        Computes truncated flag for given state.

        Args:
            state: The current environment state.
            env_properties: Environment properties.

        Returns:
            truncated: Computed truncated flag.
        """
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def generate_terminated(self, state, reward):
        """
        Computes terminated flag for given state and reward.

        Args:
            state: The current environment state.
            reward: The reward for current state-action pair.
            env_properties: Environment properties.

        Returns:
            terminated: Computed terminated flag.
        """
        return

    class State(eqx.Module):
        """The state of the environment."""

        physical_state: eqx.Module
        PRNGKey: jax.Array
        additions: eqx.Module
        reference: eqx.Module

    class EnvProperties(eqx.Module):
        """The properties of the environment that stay constant during simulation."""

        physical_normalizations: eqx.Module
        action_normalizations: eqx.Module
        static_params: eqx.Module

    def create_in_axes_dataclass(self, dataclass):

        def filter_function(value):
            if value is None:
                return None
            elif isinstance(value, list):
                raise ValueError(
                    f"Leaf needs to be a jnp.array to have different setting per batch, but list is given."
                )
            elif jnp.isscalar(value):
                return None
            elif isinstance(value, jax.Array):
                if value.shape[0] == self.batch_size:
                    return 0
                else:
                    return None
            else:
                raise ValueError(f"Leaf needs to be a scalar, jnp.array, but {type(value)} is given.")

        dataclass_in_axes = jax.tree.map(filter_function, dataclass)
        return dataclass_in_axes

    def repeat_values(self, x, n_repeat):
        """Repeats the values of x n_repeat times."""
        if x == None:
            return None
        elif isinstance(x, tuple):
            return tuple(self.repeat_values(i, n_repeat) for i in x)
        elif isinstance(x, jax.numpy.ndarray):
            return jnp.full(n_repeat, x)
        elif isinstance(x, float) or isinstance(x, bool):
            return jnp.full(n_repeat, x)
        else:
            raise ValueError(f"State needs to consist of jnp.array, tuple, float or bool, but {type(x)} is given.")

    @eqx.filter_jit
    def normalize_state(self, state):
        """
        Normalizes the state using predefined normalization parameters.

        Args:
            state: Current environment state.
            env_properties: Environment properties containing normalization parameters.

        Returns:
            norm_state: Normalized state.
        """
        env_properties = self.env_properties
        physical_normalizations = env_properties.physical_normalizations

        new_physical_state = jax.tree.map(
            lambda value, norm: norm.normalize(value),
            state.physical_state,
            physical_normalizations,
        )
        new_reference = jax.tree.map(
            lambda value, norm: norm.normalize(value),
            state.reference,
            physical_normalizations,
        )
        new_state = eqx.tree_at(
            lambda s: (s.physical_state, s.reference),
            state,
            (new_physical_state, new_reference),
        )

        return new_state

    @eqx.filter_jit
    def denormalize_state(self, norm_state):
        """
        Denormalizes a given normalized state.

        Args:
            norm_state: The normalized state to be converted back.
            env_properties: Environment properties containing normalization parameters.

        Returns:
            state: The denormalized state.
        """
        env_properties = self.env_properties
        physical_normalizations = env_properties.physical_normalizations

        new_physical_state = jax.tree.map(
            lambda value, norm: norm.denormalize(value),
            norm_state.physical_state,
            physical_normalizations,
        )
        new_reference = jax.tree.map(
            lambda value, norm: norm.denormalize(value),
            norm_state.reference,
            physical_normalizations,
        )
        new_state = eqx.tree_at(
            lambda s: (s.physical_state, s.reference),
            norm_state,
            (new_physical_state, new_reference),
        )

        return new_state

    @eqx.filter_jit
    def denormalize_action(self, action_norm):
        """
        Denormalizes a given normalized action.

        Args:
            action_norm: The normalized action to be denormalized.
            env_properties: Environment properties containing normalization parameters.

        Returns:
            action: The denormalized action.
        """
        env_properties = self.env_properties
        normalizations = env_properties.action_normalizations
        norm_objects = [getattr(normalizations, name) for name in normalizations.__annotations__]

        denorm_values = jnp.array([norm.denormalize(val) for norm, val in zip(norm_objects, action_norm)])

        return denorm_values

    def reset(
        self,
        rng: chex.PRNGKey = None,
        initial_state: eqx.Module = None,
        vmap_helper=None,
    ):
        """
        Resets environment to default, random or passed initial state.

        Args:
            env_properties: Environment properties.
            rng (optional): Random key for random initialization.
            initial_state (optional): The initial_state to which the environment will be reset.
            vmap_helper (optional): Helper variable for vectorized computation.

        Returns:
            obs: Observation of initial state.
            state: The initial state.
        """
        if initial_state is not None:
            assert tree_structure(self.init_state()) == tree_structure(
                initial_state
            ), f"initial_state should have the same dataclass structure as init_state()"
            state = initial_state
        else:
            state = self.init_state(rng)
        obs = self.generate_observation(state)

        return obs, state

    @eqx.filter_jit
    def step(self, state, action_norm):
        """Computes one JAX-JIT compiled simulation step for one batch.

        Args:
            state: The current state of the simulation from which to calculate the next state.
            action: The action to apply to the environment.
            env_properties: Contains action/state constraints and static parameters.

        Returns:
            observation: The gathered observation.
            state: New state for the next step.
        """
        # assert action_norm.shape == (self.action_dim,), (
        #     f"The action needs to be of shape (action_dim,) which is "
        #     + f"{(self.action_dim,)}, but {action_norm.shape} is given"
        # )

        # physical_state_shape = jnp.array(tree_flatten(state.physical_state)[0]).T.shape

        # assert physical_state_shape == (self.physical_state_dim,), (
        #     "The physical state needs to be of shape (physical_state_dim,) which is "
        #     + f"{(self.physical_state_dim,)}, but {physical_state_shape} is given"
        # )

        # denormalize action
        action = self.denormalize_action(action_norm)

        state = self._ode_solver_step(state, action)
        obs = self.generate_observation(state)

        return obs, state

    @eqx.filter_jit
    def sim_ahead(self, init_state, actions, obs_stepsize, action_stepsize):
        """Computes multiple JAX-JIT compiled simulation steps for one batch.

        The length of the set of inputs together with the action_stepsize determine the
        overall length of the simulation -> overall_time = actions.shape[0] * action_stepsize
        The actions are interpolated with zero order hold inbetween their values.

        Warning:
            Depending on the underlying ODE solver (e.g., Tsit5 or other higher-order solvers),
            intermediate evaluations during integration may internally access actions at future time steps.
            Therefore `sim_ahead` is not guaranteed to be numerically equivalent to repeated
            calls of `step`.


        Args:
            init_state: The initial state of the simulation
            actions: A set of actions to be applied to the environment, the value changes every
            action_stepsize (shape=(n_action_steps, action_dim))
            env_properties: The constant properties of the simulation
            obs_stepsize: The sampling time for the observations
            action_stepsize: The time between changes in the input/action

        Returns:
            observations: The gathered observations.
            states: The computed states during the simulated trajectory.
            last_state: The last state of the simulations.
        """

        # assert actions.ndim == 2, "The actions need to have two dimensions: (n_action_steps, action_dim)"
        # assert (
        #     actions.shape[-1] == self.action_dim
        # ), f"The last dimension does not correspond to the action dim which is {self.action_dim}, but {actions.shape[-1]} is given"

        # init_physical_state_shape = jnp.array(tree_flatten(init_state.physical_state)[0]).T.shape
        # assert init_physical_state_shape == (self.physical_state_dim,), (
        #     "The initial physical state needs to be of shape (env.physical_state_dim,) which is "
        #     + f"{(self.physical_state_dim,)}, but {init_physical_state_shape} is given"
        # )

        # denormalize actions
        actions = jax.vmap(self.denormalize_action, in_axes=(0, None))(actions)

        single_state_struct = tree_structure(init_state)
        # compute states trajectory for given actions
        states = self._ode_solver_simulate_ahead(
            init_state,
            actions,
            obs_stepsize,
            action_stepsize,
        )

        # generate observations for all timesteps
        observations = jax.vmap(self.generate_observation, in_axes=(0, None))(states)

        # get last state so that the simulation can be continued from the end point
        states_flatten, _ = tree_flatten(states)
        last_state = tree_unflatten(single_state_struct, jnp.array(states_flatten)[:, -1])

        return observations, states, last_state

    @eqx.filter_jit
    def generate_rew_trunc_term_ahead(self, states, actions):
        """
        Computes rewards, truncated flags and terminated flags for data generated by `sim_ahead`.

        Args:
            states: A set of environment states over time, including the initial state.
            actions: A set of actions applied sequentially (shape=(n_action_steps, action_dim)).
            env_properties: The environment properties required for calculations.

        Returns:
            reward: Rewards computed for each step.
            truncated: Truncated flags at each step.
            terminated : Terminated flag at each step.
        """
        # assert actions.ndim == 2, "The actions need to have two dimensions: (n_action_steps, action_dim)"
        # assert (
        #     actions.shape[-1] == self.action_dim
        # ), f"The last dimension does not correspond to the action dim which is {self.action_dim}, but {actions.shape[-1]} is given"

        actions = jax.vmap(self.denormalize_action, in_axes=(0, None))(actions)

        states_flatten, struct = tree_flatten(states)

        states_without_init_state = tree_unflatten(struct, jnp.array(states_flatten)[:, 1:])

        reward = jax.vmap(self.generate_reward, in_axes=(0, 0, None))(
            states_without_init_state,
            jnp.expand_dims(
                jnp.repeat(
                    actions,
                    int((jnp.array(states_flatten).shape[1] - 1) / actions.shape[0]),
                ),
                1,
            ),
        )
        truncated = jax.vmap(self.generate_truncated, in_axes=(0, None))(states)
        terminated = jax.vmap(self.generate_terminated, in_axes=(0, 0, None))(
            states_without_init_state, reward
        )
        return reward, truncated, terminated



In [None]:
from exciting_environments.utils import MinMaxNormalization
def pendulum_soft_constraints(instance, state, action_norm):
    state_norm = instance.normalize_state(state)
    physical_state_norm = state_norm.physical_state
    phys_soft_const = jax.tree.map(lambda _: jnp.nan, physical_state_norm)
    phys_soft_const = eqx.tree_at(
        lambda s: s.omega, phys_soft_const,
        jax.nn.relu(jnp.abs(physical_state_norm.omega) - 1.0)
    )
    act_soft_constr = jax.nn.relu(jnp.abs(action_norm) - 1.0)
    return phys_soft_const, act_soft_constr

class Pendulum(CoreEnvironment):
    control_state: list = eqx.field(static=True)
    soft_constraints_logic: Callable = eqx.field(static=True)
    """
    State Variables:
        ``['theta', 'omega']``

    Action Variable:
        ``['torque']``

    Initial State:
        Unless chosen otherwise, theta=pi and omega=0

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>>
        >>> import exciting_environments as excenvs
        >>> from exciting_environments import GymWrapper
        >>>
        >>> # Create the environment
        >>> pend=excenvs.Pendulum(batch_size=4)
        >>>
        >>> # Use GymWrapper for Simulation (optional)
        >>> gym_pend=GymWrapper(env=pend)
        >>>
        >>> # Reset the environment with default initial values
        >>> gym_pend.reset()
        >>>
        >>> # Perform step
        >>> obs, reward, terminated,  truncated = gym_pend.step(action=jnp.ones(4).reshape(-1,1))
        >>>

    """

    def __init__(
        self,
        physical_normalizations: dict = None,
        action_normalizations: dict = None,
        soft_constraints: Callable = None,
        static_params: dict = None,
        control_state: list = None,
        solver=diffrax.Euler(),
        tau: float = 1e-4,
    ):
        """
        Args:
            batch_size (int): Number of parallel environment simulations. Default: 8
            physical_normalizations (dict): Min and max values of the physical state of the environment for normalization.
                theta (MinMaxNormalization): Rotation angle. Default: min=-jnp.pi, max=jnp.pi
                omega (MinMaxNormalization): Angular velocity. Default: min=-10, max=10
            action_normalizations (dict): Min and max values of the input/action for normalization.
                torque (MinMaxNormalization): Maximum torque that can be applied to the system as an action. Default: min=-20, max=20
            soft_constraints (Callable): Function that returns soft constraints values for state and/or action.
            static_params (dict): Parameters of environment which do not change during simulation.
                l (float): Length of the pendulum. Default: 1
                m (float): Mass of the pendulum tip. Default: 1
                g (float): Gravitational acceleration. Default: 9.81
            control_state (list): Components of the physical state that are considered in reference tracking.
            solver (diffrax.solver): Solver used to compute state for next step.
            tau (float): Duration of one control step in seconds. Default: 1e-4.

        Note: Attributes of MinMaxNormalization of physical_normalizations and action_normalizations as well as static_params can also be
            passed as jnp.Array with the length of the batch_size to set different values per batch.
        """

        if not physical_normalizations:
            physical_normalizations = {
                "theta": MinMaxNormalization(min=jnp.array(-jnp.pi), max=jnp.array(jnp.pi)),
                "omega": MinMaxNormalization(min=jnp.array(-10), max=jnp.array(10)),
            }

        if not action_normalizations:
            action_normalizations = {"torque": MinMaxNormalization(min=jnp.array(-20), max=jnp.array(20))}
        
        if not static_params:
            static_params = {"g": jnp.array(9.81), "l": jnp.array(2), "m": jnp.array(1)}

        if not control_state:
            control_state = []
        
        logic = soft_constraints if soft_constraints else pendulum_soft_constraints
        self.soft_constraints_logic = logic
        self.control_state = control_state

        physical_normalizations = self.PhysicalState(**physical_normalizations)
        action_normalizations = self.Action(**action_normalizations)
        static_params = self.StaticParams(**static_params)

        env_properties = self.EnvProperties(
            physical_normalizations=physical_normalizations,
            action_normalizations=action_normalizations,
            static_params=static_params,
        )
        super().__init__(env_properties=env_properties, tau=tau, solver=solver)

    class PhysicalState(eqx.Module):
        """Dataclass containing the physical state of the environment."""

        theta: jax.Array
        omega: jax.Array

    class Additions(eqx.Module):
        """Dataclass containing additional information for simulation."""

        solver_state: tuple
        active_solver_state: bool

    class StaticParams(eqx.Module):
        """Dataclass containing the static parameters of the environment."""

        g: jax.Array
        l: jax.Array
        m: jax.Array

    class Action(eqx.Module):
        """Dataclass containing the action, that can be applied to the environment."""

        torque: jax.Array

    def _ode(self, t, y, args, action):
        theta, omega = y
        params = args
        d_omega = (action(t)[0] + params.l * params.m * params.g * jnp.sin(theta)) / (params.m * (params.l) ** 2)
        d_theta = omega
        d_y = d_theta, d_omega
        return d_y

    @eqx.filter_jit
    def _ode_solver_step(self, state, action):
        """Computes the next state by simulating one step.

        Args:
            state: The state from which to calculate state for the next step.
            action: The action to apply to the environment.
            static_params: Parameter of the environment, that do not change over time.

        Returns:
            next_state: The computed next state after the one step simulation.
        """
        static_params = self.env_properties.static_params
        physical_state = state.physical_state
        args = static_params

        torque = lambda t: action

        vector_field = partial(self._ode, action=torque)

        term = diffrax.ODETerm(vector_field)
        t0 = 0
        t1 = self.tau
        y0 = tuple([physical_state.theta, physical_state.omega])

        def false_fn(_):
            return self.Additions(solver_state=self._solver.init(term, t0, t1, y0, args), active_solver_state=True)

        def true_fn(_):
            return state.additions

        additions = jax.lax.cond(state.additions.active_solver_state, false_fn, true_fn, operand=None)
        y, _, _, solver_state_k1, _ = self._solver.step(term, t0, t1, y0, args, additions.solver_state, made_jump=False)

        theta_k1 = y[0]
        omega_k1 = y[1]
        theta_k1 = ((theta_k1 + jnp.pi) % (2 * jnp.pi)) - jnp.pi

        new_physical_state = self.PhysicalState(theta=theta_k1, omega=omega_k1)
        new_additions = self.Additions(solver_state=solver_state_k1, active_solver_state=True)
        new_state = eqx.tree_at(lambda s: (s.physical_state, s.additions), state, (new_physical_state, new_additions))
        return new_state

    @eqx.filter_jit
    def _ode_solver_simulate_ahead(self, init_state, actions, obs_stepsize, action_stepsize):
        """Computes multiple simulation steps for one batch.

        Args:
            init_state: The initial state of the simulation.
            actions: A set of actions to be applied to the environment, the value changes every.
            action_stepsize (shape=(n_action_steps, action_dim)).
            static_params: The constant properties of the simulation.
            obs_stepsize: The sampling time for the observations.
            action_stepsize: The time between changes in the input/action.

        Returns:
            next_states: The computed states during the multiple step simulation.
        """
        static_params = self.env_properties.static_params
        init_physical_state = init_state.physical_state
        args = static_params

        def torque(t):
            return actions[jnp.array(t / action_stepsize, int)]

        vector_field = partial(self._ode, action=torque)

        term = diffrax.ODETerm(vector_field)
        t0 = 0
        t1 = action_stepsize * actions.shape[0]
        init_physical_state_array, _ = tree_flatten(init_physical_state)
        y0 = tuple(init_physical_state_array)
        saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, 1 + int(t1 / obs_stepsize)))  #
        sol = diffrax.diffeqsolve(
            term,
            self._solver,
            t0,
            t1,
            dt0=obs_stepsize,
            y0=y0,
            args=args,
            saveat=saveat,
        )

        theta_t = sol.ys[0]
        omega_t = sol.ys[1]
        obs_len = omega_t.shape[0]
        # keep theta between -pi and pi
        theta_t = ((theta_t + jnp.pi) % (2 * jnp.pi)) - jnp.pi

        physical_states = self.PhysicalState(theta=theta_t, omega=omega_t)
        ref = self.PhysicalState(
            theta=jnp.full(obs_len, init_state.reference.theta),
            omega=jnp.full(obs_len, init_state.reference.omega),
        )
        y0 = tuple([theta_t[-1], omega_t[-1]])
        solver_state = self._solver.init(term, t1, t1 + self.tau, y0, args)
        additions = self.Additions(
            solver_state=self.repeat_values(solver_state, obs_len), active_solver_state=jnp.full(obs_len, True)
        )
        PRNGKey = jnp.full(obs_len, init_state.PRNGKey)
        return self.State(
            physical_state=physical_states,
            PRNGKey=PRNGKey,
            additions=additions,
            reference=ref,
        )

    @eqx.filter_jit
    def init_state(self, rng: chex.PRNGKey = None, vmap_helper=None):
        env_properties = self.env_properties
        """Returns default or random initial state for one batch."""
        if rng is None:
            phys = self.PhysicalState(
                theta=jnp.array(1.0),
                omega=jnp.array(0.0),
            )
            subkey = jnp.nan
        else:
            state_norm = jax.random.uniform(rng, minval=-1, maxval=1, shape=(2,))
            phys = self.PhysicalState(
                theta=state_norm[0],
                omega=state_norm[1],
            )
            key, subkey = jax.random.split(rng)

        torque = lambda t: jnp.array([0])

        args = env_properties.static_params

        vector_field = partial(self._ode, action=torque)

        term = diffrax.ODETerm(vector_field)
        t0 = 0
        t1 = self.tau
        y0 = tuple([phys.theta, phys.omega])

        solver_state = self._solver.init(term, t0, t1, y0, args)
        #dummy_solver_state = jax.tree.map(lambda x: x * jnp.nan, solver_state)
        dummy_solver_state = jax.tree.map(lambda x: jnp.full_like(x, jnp.nan) if jnp.issubdtype(x.dtype, jnp.floating) else x, solver_state)

        additions = self.Additions(solver_state=dummy_solver_state, active_solver_state=False)
        ref = self.PhysicalState(theta=jnp.nan, omega=jnp.nan)
        norm_state = self.State(physical_state=phys, PRNGKey=subkey, additions=additions, reference=ref)
        return self.denormalize_state(norm_state)

    @eqx.filter_jit
    def generate_reward(self, state, action):
        """Returns reward for one batch."""
        reward = 0
        norm_state = self.normalize_state(state)
        for name in self.control_state:
            if name == "theta":
                theta = getattr(state.physical_state, name)
                theta_ref = getattr(state.reference, name)
                reward += -((jnp.sin(theta) - jnp.sin(theta_ref)) ** 2 + (jnp.cos(theta) - jnp.cos(theta_ref)) ** 2)
            else:
                reward += -((getattr(norm_state.physical_state, name) - getattr(norm_state.reference, name)) ** 2)
        return jnp.array([reward])

    @eqx.filter_jit
    def generate_observation(self, state):
        """Returns observation for one batch."""
        norm_state = self.normalize_state(state)
        norm_state_phys = norm_state.physical_state
        obs = jnp.hstack(
            (
                norm_state_phys.theta,
                norm_state_phys.omega,
            )
        )
        for name in self.control_state:
            obs = jnp.hstack(
                (
                    obs,
                    getattr(norm_state.reference, name),
                )
            )
        return obs

    @eqx.filter_jit
    def generate_state_from_observation(self, obs, key=None):
        env_properties = self.env_properties
        """Generates state from observation for one batch."""
        phys = self.PhysicalState(
            theta=obs[0],
            omega=obs[1],
        )
        if key is not None:
            subkey = key
        else:
            subkey = jnp.nan

        torque = lambda t: jnp.array([0])

        args = env_properties.static_params

        vector_field = partial(self._ode, action=torque)

        term = diffrax.ODETerm(vector_field)
        t0 = 0
        t1 = self.tau
        y0 = tuple([phys.theta, phys.omega])

        solver_state = self._solver.init(term, t0, t1, y0, args)

        dummy_solver_state = jax.tree.map(lambda x: x * jnp.nan, solver_state)

        additions = self.Additions(solver_state=dummy_solver_state, active_solver_state=False)  # None
        ref = self.PhysicalState(theta=jnp.nan, omega=jnp.nan)
        new_ref = ref
        for i, name in enumerate(self.control_state):
            new_ref = eqx.tree_at(lambda r: getattr(r, name), new_ref, obs[2 + i])
        norm_state = self.State(physical_state=phys, PRNGKey=subkey, additions=additions, reference=new_ref)
        return self.denormalize_state(norm_state)
    
    def soft_constraints(self, state, action_norm):
        return self.soft_constraints_logic(self, state, action_norm)

    @eqx.filter_jit
    def generate_truncated(self, state):
        """Returns truncated information for one batch."""
        obs = self.generate_observation(state)
        return jnp.abs(obs) > 1

    @eqx.filter_jit
    def generate_terminated(self, state, reward):
        """Returns terminated information for one batch."""
        return reward == 0

    @property
    def obs_description(self):
        return np.hstack(
            [
                np.array(["theta", "omega"]),
                np.array([name + "_ref" for name in self.control_state]),
            ]
        )

    @property
    def action_description(self):
        return np.array(["torque"])


In [131]:
pend_env = Pendulum(solver=diffrax.Heun())

In [128]:
obs, state = pend_env.reset(jax.random.PRNGKey(123))

In [None]:
key=jax.random.PRNGKey(1234)
obs, state = pend_env.reset(key)
generated_observations = []
generated_actions= []
generated_observations.append(obs)
for i in range(10000):
    key,subkey= jax.random.split(key)
    action = jax.random.uniform(subkey,(1,),minval=-1,maxval=1)
    obs, state = pend_env.step(state, action)
    generated_actions.append(action)
    generated_observations.append(obs)
generated_observations = jnp.array(generated_observations)
generated_actions = jnp.array(generated_actions)

In [None]:
plt.plot(jnp.sin(generated_observations[:,0]),-jnp.cos(generated_observations[:,0]))

In [12]:
keys = jax.random.split(jax.random.PRNGKey(0), 3)
envs = [Pendulum(solver=diffrax.Tsit5(),batch_size=1,static_params={"g": jnp.array(9.81), "l": jnp.array(float(i+1)),  "m": jnp.array(1.0)}) for i in range(4)]

trainables, statics = zip(*(eqx.partition(m, eqx.is_array) for m in envs))

batched_trainables = jax.tree.map(lambda *xs: jnp.stack(xs), *trainables)

batched_module = eqx.combine(batched_trainables, statics[0])

actions = jnp.ones((4, 1))
key=jax.random.PRNGKey(1234)
keys = jax.random.split(key,num=4)
obs, states = batched_module.reset(keys)

#obs, states = batched_module.step(states,jnp.array([0,2,5,1]))

ValueError: Mismatch custom node data: (<equinox._module._flatten._Missing object at 0x7f28f320b3d0>, Tsit5(), functools.partial(<function pendulum_soft_constraints at 0x7f271c6245e0>, Partial(
  func=_JitWrapper(
    fn='CoreEnvironment.normalize_state',
    filter_warning=False,
    donate_first=False,
    donate_rest=False
  ),
  args=(
    Pendulum(
      batch_size=1,
      tau=0.0001,
      _solver=Tsit5(),
      env_properties=CoreEnvironment.EnvProperties(
        physical_normalizations=Pendulum.PhysicalState(
          theta=MinMaxNormalization(min=weak_f64[], max=weak_f64[]),
          omega=MinMaxNormalization(min=weak_i64[], max=weak_i64[])
        ),
        action_normalizations=Pendulum.Action(
          torque=MinMaxNormalization(min=weak_i64[], max=weak_i64[])
        ),
        static_params=Pendulum.StaticParams(
          g=weak_f64[], l=weak_f64[], m=weak_f64[]
        )
      ),
      in_axes_env_properties=CoreEnvironment.EnvProperties(
        physical_normalizations=Pendulum.PhysicalState(
          theta=MinMaxNormalization(min=None, max=None),
          omega=MinMaxNormalization(min=None, max=None)
        ),
        action_normalizations=Pendulum.Action(
          torque=MinMaxNormalization(min=None, max=None)
        ),
        static_params=Pendulum.StaticParams(g=None, l=None, m=None)
      ),
      action_dim=1,
      physical_state_dim=2,
      control_state=[],
      soft_constraints=partial(
        <function pendulum_soft_constraints>, <recursive>
      )
    ),
  ),
  keywords={}
))) != (<equinox._module._flatten._Missing object at 0x7f28f320b3d0>, Tsit5(), functools.partial(<function pendulum_soft_constraints at 0x7f271c6245e0>, Partial(
  func=_JitWrapper(
    fn='CoreEnvironment.normalize_state',
    filter_warning=False,
    donate_first=False,
    donate_rest=False
  ),
  args=(
    Pendulum(
      batch_size=1,
      tau=0.0001,
      _solver=Tsit5(),
      env_properties=CoreEnvironment.EnvProperties(
        physical_normalizations=Pendulum.PhysicalState(
          theta=MinMaxNormalization(min=weak_f64[], max=weak_f64[]),
          omega=MinMaxNormalization(min=weak_i64[], max=weak_i64[])
        ),
        action_normalizations=Pendulum.Action(
          torque=MinMaxNormalization(min=weak_i64[], max=weak_i64[])
        ),
        static_params=Pendulum.StaticParams(
          g=weak_f64[], l=weak_f64[], m=weak_f64[]
        )
      ),
      in_axes_env_properties=CoreEnvironment.EnvProperties(
        physical_normalizations=Pendulum.PhysicalState(
          theta=MinMaxNormalization(min=None, max=None),
          omega=MinMaxNormalization(min=None, max=None)
        ),
        action_normalizations=Pendulum.Action(
          torque=MinMaxNormalization(min=None, max=None)
        ),
        static_params=Pendulum.StaticParams(g=None, l=None, m=None)
      ),
      action_dim=1,
      physical_state_dim=2,
      control_state=[],
      soft_constraints=partial(
        <function pendulum_soft_constraints>, <recursive>
      )
    ),
  ),
  keywords={}
))); value: Pendulum(
  batch_size=None,
  tau=None,
  _solver=Tsit5(),
  env_properties=CoreEnvironment.EnvProperties(
    physical_normalizations=Pendulum.PhysicalState(
      theta=MinMaxNormalization(min=weak_f64[], max=weak_f64[]),
      omega=MinMaxNormalization(min=weak_i64[], max=weak_i64[])
    ),
    action_normalizations=Pendulum.Action(
      torque=MinMaxNormalization(min=weak_i64[], max=weak_i64[])
    ),
    static_params=Pendulum.StaticParams(
      g=weak_f64[], l=weak_f64[], m=weak_f64[]
    )
  ),
  in_axes_env_properties=CoreEnvironment.EnvProperties(
    physical_normalizations=Pendulum.PhysicalState(
      theta=MinMaxNormalization(min=None, max=None),
      omega=MinMaxNormalization(min=None, max=None)
    ),
    action_normalizations=Pendulum.Action(
      torque=MinMaxNormalization(min=None, max=None)
    ),
    static_params=Pendulum.StaticParams(g=None, l=None, m=None)
  ),
  action_dim=None,
  physical_state_dim=None,
  control_state=[],
  soft_constraints=partial(
    <function pendulum_soft_constraints>,
    Partial(
      func=_JitWrapper(
        fn='CoreEnvironment.normalize_state',
        filter_warning=False,
        donate_first=False,
        donate_rest=False
      ),
      args=(
        Pendulum(
          batch_size=1,
          tau=0.0001,
          _solver=Tsit5(),
          env_properties=CoreEnvironment.EnvProperties(
            physical_normalizations=Pendulum.PhysicalState(
              theta=MinMaxNormalization(min=weak_f64[], max=weak_f64[]),
              omega=MinMaxNormalization(min=weak_i64[], max=weak_i64[])
            ),
            action_normalizations=Pendulum.Action(
              torque=MinMaxNormalization(min=weak_i64[], max=weak_i64[])
            ),
            static_params=Pendulum.StaticParams(
              g=weak_f64[], l=weak_f64[], m=weak_f64[]
            )
          ),
          in_axes_env_properties=CoreEnvironment.EnvProperties(
            physical_normalizations=Pendulum.PhysicalState(
              theta=MinMaxNormalization(min=None, max=None),
              omega=MinMaxNormalization(min=None, max=None)
            ),
            action_normalizations=Pendulum.Action(
              torque=MinMaxNormalization(min=None, max=None)
            ),
            static_params=Pendulum.StaticParams(g=None, l=None, m=None)
          ),
          action_dim=1,
          physical_state_dim=2,
          control_state=[],
          soft_constraints=<recursive>
        ),
      ),
      keywords={}
    )
  )
).

In [132]:

keys = jax.random.split(jax.random.PRNGKey(0), 2)
#envs = [Pendulum(solver=diffrax.Euler(),batch_size=1,static_params={"g": jnp.array(9.81), "l": jnp.array(float(1+i)),  "m": jnp.array(1.0)}) for i in range(4)]
envs = [Pendulum(solver=diffrax.Euler(),control_state=["theta"],tau=1e-2,static_params={"g": jnp.array(9.81), "l": jnp.array(float(1)),  "m": jnp.array(1.0)}),Pendulum(solver=diffrax.Euler(),control_state=["theta"],tau=1e-3,static_params={"g": jnp.array(9.81), "l": jnp.array(float(2)),  "m": jnp.array(1.0)})]
batched_envs = jax.tree.map(lambda *args: jnp.stack(args), *envs)
#batched_envs.reset(rng=jax.random.PRNGKey(321))

TypeError: stack requires ndarray or scalar arguments, got <class 'function'> at position 0.

In [119]:
keys = jax.random.split(jax.random.PRNGKey(0), 2)
obs, states = jax.vmap(lambda e,k: e.reset(k))(batched_envs,keys)

In [99]:
actions = jnp.ones((4,1))

next_obs, next_states = jax.vmap(
    lambda e, s, a: e.step(s, a)
)(batched_envs, states, actions)
next_obs

Array([[ 0.93977834, -0.36941787],
       [-0.83848137,  0.84792387],
       [ 0.30408614,  0.75913929],
       [ 0.50112481, -0.04066952]], dtype=float64)

In [None]:
key=jax.random.PRNGKey(123)
obs, state = pend_env.reset(key)

In [None]:

import equinox as eqx
import jax
import jax.numpy as jnp

envs = [Pendulum(batch_size=1,static_params={"g": 9.81, "l": float(i+1), "m": 1.0}) for i in range(4)]
actions = jnp.ones((4, 1))  # eine Action pro Environment
key=jax.random.PRNGKey(1234)
obs, state = pend_env.reset(key)
# vmap 端ber Module (0. Dimension) und Actions
# obs, states = jax.vmap(lambda env, a: env.step(state, a))(envs, actions)
keys = jax.random.split(key,num=4)
states = jax.vmap(lambda e, k: e.reset(k)[1])(envs, keys)
obs_batch, states_batch = jax.vmap(lambda e, s, a: e.step(s, a))(envs, states, actions)

In [None]:
class MyModule(eqx.Module):
    weight: jax.Array
    bias: jax.Array
    config: dict = eqx.field(static=True)  # static, nicht gemappt

# Erzeuge mehrere Module
mods = [MyModule(weight=jnp.array(i), bias=jnp.array(i*2),config={}) for i in range(4)]


# vmap 端ber die batch dimension
def forward(mod, x):
    return mod.weight * x + mod.bias

xs = jnp.ones(4)
out = jax.vmap(forward)(mods, xs)
print(out)

In [None]:
envs = [
    Pendulum(batch_size=1, static_params={"g": 9.81, "l": float(i+1), "m": 1.0})
    for i in range(4)
]

# 2. Partitioniere alles in Arrays vs. Rest
trainables_list, statics_list = zip(*(eqx.partition(env, eqx.is_array) for env in envs))

# 3. Jetzt nur die Arrays selbst stacken, keine Sub-Module!
def stack_arrays(*args):
    return jax.tree_map(lambda *x: jnp.stack(x), *args)

batched_trainables = stack_arrays(*trainables_list)

# 4. Einen "batched Module" erzeugen
batched_env = eqx.combine(batched_trainables, statics_list[0])  # statics[0] reicht

# 5. Actions vorbereiten
actions = jnp.ones((4, 1))
key = jax.random.PRNGKey(1234)

# 6. Reset (funktioniert f端r batched_env)
obs, state = batched_env.reset( key)

# 7. vmap 端ber Actions
batched_step = jax.vmap(lambda a: batched_env.step(state, a))
obs, states = batched_step(actions)

print("obs:", obs)
print("states:", states)

In [None]:
import equinox as eqx
import jax
import jax.numpy as jnp

class SimpleModule(eqx.Module):
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.weight = jax.random.normal(k1, (3, 3))
        self.bias = jax.random.normal(k2, (3,))

    def __call__(self, x):
        return jnp.dot(self.weight, x) + self.bias


keys = jax.random.split(jax.random.PRNGKey(0), 3)
modules = [SimpleModule(k) for k in keys]

trainables, statics = zip(*(eqx.partition(m, eqx.is_array) for m in modules))

batched_trainables = jax.tree.map(lambda *xs: jnp.stack(xs), *trainables)

batched_module = eqx.combine(batched_trainables, statics[0])

x_inputs = jnp.ones((3, 3)) 
batched_apply = jax.vmap(lambda m, x: jnp.sum(m(x)), in_axes=(None, 0))

results = batched_apply(batched_module, x_inputs)

print("Results:", results)   # shape (3,)

In [None]:
import equinox as eqx

# envs = [Pendulum(...), Pendulum(...), ...]
batched_env = eqx.combine(envs)

actions = jnp.ones((4, 1))  # eine Action pro Env

# state muss ggf. auch gebatched sein
obs, states = jax.vmap(lambda env, a: env.step(state, a))(
    batched_env, actions
)