In [1]:
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure
import jax_dataclasses as jdc
import diffrax
import chex
from functools import partial
from abc import ABC
from abc import abstractmethod
from exciting_environments import spaces
from dataclasses import fields

class CoreEnvironment(ABC):
    """Abstract base class defining the common structure of all environments.

    A concrete environment provides the nested dataclasses ``PhysicalStates``,
    ``Optional``, ``StaticParams`` and ``Actions`` and implements the abstract
    methods declared below. This base class bundles the constraints and static
    parameters into ``env_properties``, derives the matching ``jax.vmap``
    ``in_axes`` structure from them, and offers a jitted single-environment
    ``step`` as well as a batched ``vmap_step``.
    """

    def __init__(
            self,
            batch_size: int,
            physical_constraints,
            action_constraints,
            static_params,
            tau: float = 1e-4,
            solver=diffrax.Euler(),
            reward_func=None
    ):
        """
        Args:
            batch_size(int): Number of training examples utilized in one iteration.
            physical_constraints(jdc.pytree_dataclass): Constraints of physical states of the environment.
            action_constraints(jdc.pytree_dataclass): Constraints of actions.
            static_params(jdc.pytree_dataclass): Parameters of environment which do not change during simulation.
            tau(float): Duration of one control step in seconds. Default: 1e-4.
            solver(diffrax.solver): Solver used to compute states for next step.
            reward_func(function): Reward function for training. Needs observation vector, action and action_constraints as Parameters.
                                    Default: None (default_reward_func from class)

        Raises:
            TypeError: If a reward_func is passed that cannot be called as
                ``reward_func(obs, action, action_constraints)``.
        """
        self.batch_size = batch_size
        self.tau = tau
        self._solver = solver
        self.env_properties = self.EnvProperties(
            physical_constraints=physical_constraints,
            action_constraints=action_constraints,
            static_params=static_params,
        )
        # Mirror of env_properties holding the vmap axis (0 or None) per leaf.
        self.in_axes_env_properties = self.create_in_axes_dataclass(self.env_properties)

        if reward_func:
            # Fail loudly instead of leaving ``self.reward_func`` unset, which
            # would only surface later as an AttributeError inside ``step``.
            if not self._test_rew_func(reward_func):
                raise TypeError(
                    "reward_func is not compatible with this environment; it must be "
                    "callable as reward_func(obs, action, action_constraints)."
                )
            self.reward_func = reward_func
        else:
            self.reward_func = self.default_reward_func

    def _test_rew_func(self, func):
        """Checks that a user supplied reward function can be used by ``step``.

        ``step`` calls the reward function positionally as
        ``func(obs, action, action_constraints)``, so ``func`` has to be a
        callable that accepts three positional arguments.

        Args:
            func: Candidate reward function.

        Returns:
            bool: True if the check passed.

        Raises:
            TypeError: If ``func`` is not callable or cannot be called with
                three positional arguments.
        """
        import inspect  # local import: only needed for this one-off validation

        if not callable(func):
            raise TypeError(
                f"reward_func must be callable, got {type(func).__name__}."
            )
        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            # The signature cannot be introspected (e.g. some builtin or
            # extension callables); accept it rather than reject a valid function.
            return True
        try:
            signature.bind(None, None, None)
        except TypeError as err:
            raise TypeError(
                "reward_func must accept the three positional arguments "
                "(obs, action, action_constraints)."
            ) from err
        return True

    @property
    def default_reward_function(self):
        """Returns the default reward function for the given environment."""
        return self.default_reward_func

    @abstractmethod
    @jdc.pytree_dataclass
    class PhysicalStates:
        """Placeholder; concrete environments define their physical states here."""
        pass

    @abstractmethod
    @jdc.pytree_dataclass
    class Optional:
        """Placeholder; concrete environments define additional simulation data here."""
        pass

    @abstractmethod
    @jdc.pytree_dataclass
    class StaticParams:
        """Placeholder; concrete environments define their static parameters here."""
        pass

    @abstractmethod
    @jdc.pytree_dataclass
    class Actions:
        """Placeholder; concrete environments define their actions here."""
        pass

    @jdc.pytree_dataclass
    class States:
        """Dataclass used for simulation which contains environment specific dataclasses."""
        physical_states: jdc.pytree_dataclass
        PRNGKey: jax.Array
        optional: jdc.pytree_dataclass

    @jdc.pytree_dataclass
    class EnvProperties:
        """Dataclass used for simulation which contains environment specific dataclasses."""
        physical_constraints: jdc.pytree_dataclass
        action_constraints: jdc.pytree_dataclass
        static_params: jdc.pytree_dataclass

    def create_in_axes_dataclass(self, dataclass):
        """Builds the ``jax.vmap`` ``in_axes`` counterpart of a dataclass.

        Works on a mutable copy of ``dataclass`` and replaces every field by
        its vmap axis: nested dataclasses are processed recursively, scalar
        values become ``None`` (shared by the whole batch) and every other
        value becomes ``0`` (one entry per batch element).

        Args:
            dataclass: Dataclass instance whose fields are scalars, nested
                dataclasses or values of length ``batch_size``.

        Returns:
            A copy of ``dataclass`` with each field replaced by ``None``, ``0``
            or a nested in_axes dataclass.

        Raises:
            AssertionError: If a non-scalar field does not have length ``batch_size``.
        """
        with jdc.copy_and_mutate(dataclass, validate=False) as dataclass_in_axes:
            for field in fields(dataclass_in_axes):
                name = field.name
                value = getattr(dataclass_in_axes, name)
                if jdc.is_dataclass(value):
                    setattr(dataclass_in_axes, name, self.create_in_axes_dataclass(value))
                elif jnp.isscalar(value):
                    setattr(dataclass_in_axes, name, None)
                else:
                    assert len(
                        value) == self.batch_size, f"{name} is expected to be a scalar a pytree_dataclass or a jnp.Array with len(jnp.Array)=batch_size={self.batch_size}"
                    setattr(dataclass_in_axes, name, 0)
        return dataclass_in_axes

    @partial(jax.jit, static_argnums=0)
    def step(self, states, action, env_properties):
        """Computes one simulation step for one batch.

        Args:
            states: The states from which to calculate states for the next step.
            action: The action to apply to the environment.
            env_properties: Contains action/state constraints and static parameter.

        Returns:
            Multiple Outputs:

            observation: The gathered observation.
            reward: Amount of reward received for the last step.
            terminated: Flag, indicating if Agent has reached the terminal state.
            truncated: Flag, e.g. indicating if state has gone out of bounds.
            states: New states for the next step.
        """

        # ode step
        states = self._ode_solver_step(
            states, action, env_properties.static_params
        )

        # observation
        obs = self.generate_observation(
            states, env_properties.physical_constraints
        )

        # reward
        reward = self.reward_func(
            obs, action, env_properties.action_constraints
        )

        # bound check
        truncated = self.generate_truncated(
            states, env_properties.physical_constraints
        )

        terminated = self.generate_terminated(
            states, reward
        )

        return obs, reward, terminated, truncated, states

    @partial(jax.jit, static_argnums=0)
    def vmap_step(self, action, states):
        """JAX jit compiled and vmapped step for batch_size of environment.

        The environment properties are not passed in; ``self.env_properties`` is
        used and mapped according to ``self.in_axes_env_properties``.

        Args:
            action: The action to apply to the environment (mapped over axis 0).
            states: The states from which to calculate states for the next step
                (mapped over axis 0).

        Returns:
            Multiple Outputs:

            observation: The gathered observations (shape=(batch_size,obs_dim)).
            reward: Amount of reward received for the last step (shape=(batch_size,1)).
            terminated: Flag, indicating if Agent has reached the terminal state (shape=(batch_size,1)).
            truncated: Flag, indicating if state has gone out of bounds (shape=(batch_size,states_dim)).
            states: New states for the next step.

        """
        # vmap single operations; note that ``step`` takes (states, action, ...)
        obs, reward, terminated, truncated, states = jax.vmap(self.step, in_axes=(0, 0, self.in_axes_env_properties))(
            states, action, self.env_properties
        )

        return obs, reward, terminated, truncated, states

    @property
    @abstractmethod
    def obs_description(self):
        """Returns a list of state names of all states in the observation."""
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def default_reward_func(self, obs, action, action_constraints):
        """Returns the default RewardFunction of the environment.

        Called by ``step`` as ``reward_func(obs, action, action_constraints)``.
        """
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def generate_observation(self, states, physical_constraints):
        """Returns observation.

        Called by ``step`` with the new states and the physical constraints.
        """
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def generate_truncated(self, states, physical_constraints):
        """Returns truncated information.

        Called by ``step`` with the new states and the physical constraints.
        """
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def generate_terminated(self, states, reward):
        """Returns terminated information."""
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def _ode_solver_step(self, states, action, static_params):
        """Computes states by simulating one step.

        Args:
            states: The states from which to calculate states for the next step.
            action: The action to apply to the environment.
            static_params: Parameter of the environment, that do not change over time.

        Returns:
            states: The computed states after the one step simulation.
        """

        return

    @abstractmethod
    def reset(self, rng: chex.PRNGKey = None, initial_states: jdc.pytree_dataclass = None):
        """Resets environment to default or passed initial values."""
        return

In [2]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure
import jax_dataclasses as jdc
import chex
from functools import partial
import diffrax
from exciting_environments import core_env


class Pendulum(CoreEnvironment):
    """
    State Variables:
        ``['theta', 'omega']``

    Action Variable:
        ``['torque']``

    Initial State:
        Unless chosen otherwise, theta equals pi and omega is set to zero.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import exciting_environments as excenvs
        >>> from exciting_environments import GymWrapper
        >>>
        >>> # Create the environment
        >>> pend=excenvs.Pendulum(batch_size=4,action_constraints={"torque":10})
        >>>
        >>> # Use GymWrapper for Simulation (optional)
        >>> gym_pend=GymWrapper(env=pend)
        >>>
        >>> # Reset the environment with default initial values
        >>> gym_pend.reset()
        >>>
        >>> # Perform step
        >>> obs,reward,terminated,truncated= gym_pend.step(action=jnp.ones(4).reshape(-1,1))
        >>>

    """

    def __init__(
        self,
        batch_size: int = 8,
        physical_constraints: dict = None,
        action_constraints: dict = None,
        static_params: dict = None,
        solver=diffrax.Euler(),
        reward_func=None,
        tau: float = 1e-4,
    ):
        """
        Args:
            batch_size(int): Number of training examples utilized in one iteration. Default: 8
            physical_constraints(dict): Constraints of physical states of the environment.
                theta(float): Rotation angle. Default: jnp.pi
                omega(float): Angular velocity. Default: 10
            action_constraints(dict): Constraints of actions.
                torque(float): Maximum torque that can be applied to the system as action. Default: 20
            static_params(dict): Parameters of environment which do not change during simulation.
                l(float): Length of the pendulum. Default: 2
                m(float): Mass of the pendulum tip. Default: 1
                g(float): Gravitational acceleration. Default: 9.81
            solver(diffrax.solver): Solver used to compute states for next step.
            reward_func(function): Reward function for training. Needs observation vector, action and action_constraints as Parameters.
                                    Default: None (default_reward_func from class)
            tau(float): Duration of one control step in seconds. Default: 1e-4.

        Note: Attributes of physical_constraints, action_constraints and static_params can also be passed as jnp.Array with the length of the batch_size to set different values per batch.
        """

        if not physical_constraints:
            physical_constraints = {"theta": jnp.pi, "omega": 10}
        if not action_constraints:
            action_constraints = {"torque": 20}

        if not static_params:
            static_params = {"g": 9.81, "l": 2, "m": 1}

        # The constraint containers reuse the state/action dataclasses, so each
        # constraint is stored under the name of the quantity it limits.
        physical_constraints = self.PhysicalStates(**physical_constraints)
        action_constraints = self.Actions(**action_constraints)
        static_params = self.StaticParams(**static_params)

        super().__init__(batch_size, physical_constraints, action_constraints, static_params, tau=tau,
                         solver=solver, reward_func=reward_func)

    @jdc.pytree_dataclass
    class PhysicalStates:
        """Dataclass containing the physical states of the environment."""
        theta: jax.Array
        omega: jax.Array

    @jdc.pytree_dataclass
    class Optional:
        """Dataclass containing additional information for simulation."""
        something: jax.Array

    @jdc.pytree_dataclass
    class StaticParams:
        """Dataclass containing the static parameters of the environment."""
        g: jax.Array
        l: jax.Array
        m: jax.Array

    @jdc.pytree_dataclass
    class Actions:
        """Dataclass containing the actions, that can be applied to the environment."""
        torque: jax.Array

    @partial(jax.jit, static_argnums=0)
    def _ode_solver_step(self, states, action, static_params):
        """Computes states by simulating one step.

        Integrates ``d_theta = omega`` and
        ``d_omega = (action[0] + l*m*g*sin(theta)) / (m*l**2)`` over one interval
        of length ``self.tau`` with ``self._solver`` and wraps the resulting
        angle into ``[-pi, pi)``.

        Args:
            states: The states from which to calculate states for the next step.
            action: The action to apply to the environment.
            static_params: Parameter of the environment, that do not change over time.

        Returns:
            states: The computed states after the one step simulation. ``PRNGKey``
                and ``optional`` of the returned states are ``None``.
        """

        env_states = states.physical_states
        args = (action, static_params)

        def vector_field(t, y, args):
            theta, omega = y
            action, params = args
            d_omega = (action[0]+params.l*params.m*params.g
                       * jnp.sin(theta)) / (params.m * (params.l)**2)
            d_theta = omega
            d_y = d_theta, d_omega
            return d_y

        term = diffrax.ODETerm(vector_field)
        t0 = 0
        t1 = self.tau
        y0 = (env_states.theta, env_states.omega)
        solver_state = self._solver.init(term, t0, t1, y0, args)
        y, _, _, solver_state, _ = self._solver.step(
            term, t0, t1, y0, args, solver_state, made_jump=False)

        theta_k1 = y[0]
        omega_k1 = y[1]
        # Wrap the angle back into [-pi, pi).
        theta_k1 = ((theta_k1+jnp.pi) % (2*jnp.pi))-jnp.pi

        phys = self.PhysicalStates(
            theta=theta_k1, omega=omega_k1)
        return self.States(physical_states=phys, PRNGKey=None, optional=None)

    @partial(jax.jit, static_argnums=0)
    def init_states(self):
        """Returns default initial states for all batches.

        Every batch element starts at ``theta = pi`` and ``omega = 0``;
        ``PRNGKey`` and ``optional`` are ``None``.
        """
        phys = self.PhysicalStates(theta=jnp.full(
            self.batch_size, jnp.pi), omega=jnp.zeros(self.batch_size))
        opt = None  # self.Optional(something=jnp.zeros(self.batch_size))
        return self.States(physical_states=phys, PRNGKey=None, optional=opt)

    @partial(jax.jit, static_argnums=0)
    def default_reward_func(self, obs, action, action_constraints):
        """Returns reward for one batch.

        Computes ``obs[0]**2 + 0.1*obs[1]**2 + 0.1*(action[0]/torque_constraint)**2``
        and returns it as an array with a single entry.
        """
        reward = ((obs[0])**2 + 0.1*(obs[1])**2
                  + 0.1 * (action[0]/action_constraints.torque)**2)
        return jnp.array([reward])

    @partial(jax.jit, static_argnums=0)
    def generate_observation(self, states, physical_constraints):
        """Returns observation for one batch.

        The observation stacks theta and omega, each divided by its constraint.
        """
        obs = jnp.hstack((
            states.physical_states.theta / physical_constraints.theta,
            states.physical_states.omega / physical_constraints.omega,
        ))
        return obs

    @property
    def obs_description(self):
        """Returns the names of the states contained in the observation."""
        return np.array(["theta", "omega"])

    @partial(jax.jit, static_argnums=0)
    def generate_truncated(self, states, physical_constraints):
        """Returns truncated information for one batch.

        A state is flagged when the magnitude of the state divided by its own
        constraint exceeds 1. The result has one flag for theta and one for omega.
        """
        _states = jnp.hstack((
            states.physical_states.theta / physical_constraints.theta,
            # omega has to be checked against the omega constraint (theta was
            # used here before, so omega was never bound-checked).
            states.physical_states.omega / physical_constraints.omega,
        ))
        return jnp.abs(_states) > 1

    @partial(jax.jit, static_argnums=0)
    def generate_terminated(self, states, reward):
        """Returns terminated information for one batch (True where reward equals 0)."""
        return reward == 0

    def reset(self, rng: chex.PRNGKey = None, initial_states: jdc.pytree_dataclass = None):
        """Resets environment to default or passed initial states.

        Args:
            rng: Currently unused by this environment.
            initial_states: Optional states to start from. Must have the same
                pytree structure as ``self.init_states()``.

        Returns:
            obs: Observations of the (initial) states for all batch elements.
            states: The states the environment was reset to.

        Raises:
            AssertionError: If ``initial_states`` has a different pytree
                structure than ``self.init_states()``.
        """
        if initial_states is not None:
            assert tree_structure(self.init_states()) == tree_structure(initial_states), (
                "initial_states should have the same dataclass structure as self.init_states()"
            )
            states = initial_states
        else:
            states = self.init_states()

        obs = jax.vmap(self.generate_observation, in_axes=(0, self.in_axes_env_properties.physical_constraints))(
            states, self.env_properties.physical_constraints
        )

        return obs, states


In [3]:
# Smoke test: build a pendulum environment using only the default arguments.
pend = Pendulum()

In [5]:
# Scratch check of jnp.tile: the length-1 array is repeated into a (4, 1) column.
jnp.tile(jnp.array([65]), (4, 1))

Array([[65],
       [65],
       [65],
       [65]], dtype=int32)

In [116]:
from exciting_environments import GymWrapper

In [117]:
# Give every batch element its own torque limit: the array length has to
# equal batch_size so the constraint is mapped per environment.
pend = Pendulum(
    batch_size=5,
    action_constraints={"torque": jnp.array([10, 20, 30, 40, 50])},
)
gym_pend = GymWrapper(env=pend)

In [118]:
# Reset the wrapped environment; the displayed result is a tuple of the
# per-environment observations and a second element that is an empty dict here.
gym_pend.reset()

(Array([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.]], dtype=float32),
 {})

In [119]:
# Apply an action of 1 to each of the five environments (action shape (5, 1)).
gym_pend.step(action=jnp.ones((5, 1)))

(Array([[-1.0000000e+00,  2.4999996e-05],
        [-1.0000000e+00,  4.9999999e-05],
        [-1.0000000e+00,  7.4999996e-05],
        [-1.0000000e+00,  9.9999997e-05],
        [-1.0000000e+00,  1.2500001e-04]], dtype=float32),
 Array([[1.1],
        [1.1],
        [1.1],
        [1.1],
        [1.1]], dtype=float32),
 Array([[False],
        [False],
        [False],
        [False],
        [False]], dtype=bool),
 Array([[False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False]], dtype=bool))