In [55]:
import jax.numpy as jnp
import numpy as np
import time
import chex
import jax
import gymnasium as gym
import random
import matplotlib.pyplot as plt
import diffrax
from collections import OrderedDict
from flax.core import FrozenDict
import jax_dataclasses as jdc

In [57]:
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial
import chex
from abc import ABC
from abc import abstractmethod
from exciting_environments import spaces
from exciting_environments.core_env import CoreEnvironment
import diffrax
from exciting_environments.registration import make
from collections import OrderedDict


class GymWrapper(ABC):
    """Stateful, Gym-style wrapper around a functional, batched environment.

    The wrapped environment's ``step`` receives the states explicitly and returns the
    updated states. This wrapper keeps the current batch of states in ``self.states``
    as an array of shape (batch_size, n_states) and exposes ``reset``/``step``.
    """

    # Fallback initial state (theta=pi, omega=0), used when the environment does not
    # define ``env_states_initials``; equals the values this wrapper used to hard-code.
    _DEFAULT_INITIALS = (jnp.pi, 0.0)

    def __init__(self, env):
        """
        Args:
            env: Environment instance (e.g. ``Pendulum``). It must expose ``batch_size``,
                ``env_state_constraints``, ``env_max_action``, ``step`` and the
                ``states_array_to_dataclass``/``states_dataclass_to_array`` converters.
        """
        self.env = env
        self.states = self._initial_states()

    @classmethod
    def fromName(cls, env_id: str, **env_kwargs):
        """Create the environment via ``make(env_id, **env_kwargs)`` and wrap it."""
        env = make(env_id, **env_kwargs)
        return cls(env)

    def _initial_states(self):
        """Return the env's initial states tiled to shape (batch_size, n_states)."""
        initials = getattr(self.env, "env_states_initials", self._DEFAULT_INITIALS)
        return jnp.tile(jnp.asarray(initials), (self.env.batch_size, 1))

    def _normalize(self, states):
        """Scale states by ``env.env_state_constraints`` (observation in [-1, 1] when within bounds)."""
        return states / jnp.asarray(self.env.env_state_constraints)

    def step(self, action):
        """Perform one simulation step of the environment with an action of the action space.

        Args:
            action: Normalized action matrix (shape=(batch_size, n_actions)).

        Returns:
            Multiple Outputs:

            observation(ndarray(float)): Observation/State Matrix (shape=(batch_size,states)).

            reward(ndarray(float)): Amount of reward received for the last step (shape=(batch_size,1)).

            terminated(bool): Flag, indicating if Agent has reached the terminal state.

            truncated(ndarray(bool)): Flag, indicating if state has gone out of bounds (shape=(batch_size,states)).

            {}: An empty dictionary for consistency with the OpenAi Gym interface.

            The first four values are whatever the wrapped environment's ``step`` returns.
        """
        obs, reward, terminated, truncated, self.states = self.gym_step(
            action, self.states)

        return obs, reward, terminated, truncated, {}

    @partial(jax.jit, static_argnums=0)
    def gym_step(self, action, states):
        """Pure, jitted step: denormalize the action, advance the env, return array states.

        Args:
            action: Normalized action matrix (shape=(batch_size, n_actions)).
            states: State matrix (shape=(batch_size, n_states)).

        Returns:
            (obs, reward, terminated, truncated, states) with ``states`` converted back
            to an array of shape (batch_size, n_states).
        """
        # Scale the normalized action to physical units.
        action = action * jnp.array(self.env.env_max_action)

        # The environment works on its structured (dataclass) state; convert to and from it.
        states = self.env.states_array_to_dataclass(states)
        obs, reward, terminated, truncated, states = self.env.step(action, states)
        states = self.env.states_dataclass_to_array(states)

        return obs, reward, terminated, truncated, states

    def reset(self, random_key: "jax.Array | None" = None, initial_values: jnp.ndarray = jnp.array([])):
        """Reset the batch of states and return the normalized initial observation.

        Args:
            random_key: If given, states are sampled uniformly from [-1, 1) and scaled by
                ``env.env_state_constraints``; ``initial_values`` is then ignored.
            initial_values: Explicit states of shape (batch_size, n_states). When empty
                (the default), the environment's initial states are used.

        Returns:
            obs: States divided by ``env.env_state_constraints``.
            {}: An empty dictionary for consistency with the OpenAi Gym interface.

        Raises:
            ValueError: If ``initial_values`` is non-empty and has the wrong shape.
        """
        constraints = jnp.asarray(self.env.env_state_constraints)
        n_states = constraints.shape[-1]

        if random_key is not None:
            unit = jax.random.uniform(
                random_key, shape=(self.env.batch_size, n_states), minval=-1.0, maxval=1.0)
            self.states = unit * constraints
        else:
            initial_values = jnp.asarray(initial_values)
            if initial_values.size == 0:
                self.states = self._initial_states()
            else:
                expected = (self.env.batch_size, n_states)
                if initial_values.shape != expected:
                    raise ValueError(
                        f"initial_values must have shape {expected}, got {initial_values.shape}")
                self.states = initial_values

        return self._normalize(self.states), {}

    def render(self, *_, **__):
        """
        Update the visualization of the environment.

        NotImplemented
        """
        raise NotImplementedError("To be implemented!")

    def close(self):
        """Called when the environment is deleted.

        NotImplemented
        """
        raise NotImplementedError("To be implemented!")


In [58]:
@jdc.pytree_dataclass
class Optional:
    """Container for auxiliary (non-physical) environment state.

    NOTE: this class name shadows ``typing.Optional`` if both are used in one module.
    """

    # Placeholder payload; Pendulum's converters currently fill it with theta.
    something: jax.Array

In [60]:
@jdc.pytree_dataclass
class PhysicalStates:
    """Physical pendulum state: one entry per batch element (a scalar inside ``jax.vmap``)."""

    # Pendulum angle; wrapped into [-pi, pi) by the ODE step.
    theta: jax.Array
    # Angular velocity (time derivative of theta).
    omega: jax.Array

In [61]:
from flax.core import FrozenDict
import jax_dataclasses as jdc
@jdc.pytree_dataclass
class States:
    """Complete environment state that is passed through ``CoreEnvironment.step``."""

    # Physical ODE state (theta, omega); always a PhysicalStates instance in this notebook.
    physical_states: PhysicalStates
    # NOTE(review): Pendulum currently fills this with omega as a placeholder, not a real key.
    PRNGKey: jax.Array
    # Auxiliary state; this is the local ``Optional`` dataclass, not ``typing.Optional``.
    optional: Optional

In [63]:
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial
import chex
from abc import ABC
from abc import abstractmethod
from exciting_environments import spaces
import diffrax
from collections import OrderedDict


class CoreEnvironment(ABC):
    """
    Description:
        Core structure of the provided environments.

    State Variables:
        Each environment has a list of state variables that are defined by the physical system represented.

        Example:
            ``['theta', 'omega']``

    Action Variable:
        Each environment has an action which is applied to the physical system represented.

        Example:
            ``['torque']``

    Observation Space (State Space):
        Type: Box()
            The observation space is the state space of the physical system.
            This space is a normalized, continuous, multidimensional box in [-1,1].

    Action Space:
        Type: Box()
            The action space of the environments are the action spaces of the physical systems.
            This space is a continuous, multidimensional box.


    Initial State:
        Initial state values depend on the physical system.

    """

    def __init__(self, batch_size: int, tau: float = 1e-4, solver=diffrax.Euler(), reward_func=None):
        """
        Args:
            batch_size(int): Number of parallel environments (rows of every state/action matrix).
            tau(float): Duration of one control step in seconds. Default: 1e-4.
            solver: diffrax solver used for the ODE step. Default: diffrax.Euler().
            reward_func(function): Optional custom reward function, validated with
                ``_test_rew_func``. Default: None (use ``default_reward_func``).
        """
        self.batch_size = batch_size
        self.tau = tau
        self._solver = solver

        if reward_func:
            if self._test_rew_func(reward_func):
                self.reward_func = reward_func
        else:
            self.reward_func = self.default_reward_func

    @property
    def default_reward_function(self):
        """Returns the default reward function for the given environment."""
        return self.default_reward_func

    def _batched_param(self, value, err_msg):
        """Return ``value`` as an array of shape (batch_size,), broadcasting scalars."""
        if jnp.isscalar(value):
            return jnp.full((self.batch_size,), value)
        assert len(value) == self.batch_size, err_msg
        return jnp.array(value)

    def sim_paras(self, static_params_, env_state_constraints_, env_max_actions_):
        """Broadcast static parameters, state constraints and max actions to batch_size.

        Every value may be a scalar (repeated for each batch element) or a sequence of
        length batch_size (one value per batch element).

        Args:
            static_params_(dict): Model parameters, e.g. ``{"l": 1, "m": 1, "g": 9.81}``.
            env_state_constraints_(dict): Constraint per state name.
            env_max_actions_(dict): Maximum action value per action name.

        Returns:
            static_params(dict): Same keys, each value an array of shape (batch_size,).
            env_state_constraints(jnp.ndarray): Shape (batch_size, n_states), columns in dict order.
            env_max_actions(jnp.ndarray): Shape (batch_size, n_actions), columns in dict order.

        Raises:
            AssertionError: If a non-scalar value does not have length batch_size.
        """
        static_params = {
            key: self._batched_param(
                value, f"{key} is expected to be a scalar or a list with len(list)=batch_size")
            for key, value in static_params_.items()
        }

        env_state_constraints = jnp.array([
            self._batched_param(
                value, f"Constraint of {key} is expected to be a scalar or a list with len(list)=batch_size")
            for key, value in env_state_constraints_.items()
        ]).T

        env_max_actions = jnp.array([
            self._batched_param(
                value, f"Max action of {key} is expected to be a scalar or a list with len(list)=batch_size")
            for key, value in env_max_actions_.items()
        ]).T

        return static_params, env_state_constraints, env_max_actions

    def _test_rew_func(self, func):
        """Checks if passed reward function is compatible with given environment.

        The function is called with a zero observation matrix of shape
        (batch_size, len(obs_description)) as its only argument.

        Args:
            func(function): Reward function to test.

        Returns:
            compatible(bool): Always True; incompatibility is reported by raising.

        Raises:
            Exception: If ``func`` fails on the observation matrix or does not return
                an array of shape (batch_size, 1).
        """
        dummy_obs = jnp.zeros([self.batch_size, len(self.obs_description)])
        try:
            out = func(dummy_obs)
        except Exception as err:
            raise Exception(
                "Reward function should be using obs matrix as only parameter") from err
        # A missing ``shape`` attribute is reported with the same message as a wrong shape.
        if getattr(out, "shape", None) != (self.batch_size, 1):
            raise Exception(
                "Reward function should be returning vector in shape (batch_size,1)")
        return True

    @partial(jax.jit, static_argnums=0)
    def step(self, action, states):
        """Advance all batch elements by one control step (jitted).

        Args:
            action(ndarray(float)): Action Matrix in physical units (shape=(batch_size,actions)).
            states: Batched state pytree (e.g. ``States``) whose leaves have batch_size as leading axis.

        Returns:
            Multiple Outputs:

            observation, reward, terminated, truncated: currently placeholders (empty dicts),
            see the TODO in the body.

            states: Updated batched state pytree.
        """
        # ODE step, vectorized over the batch dimension.
        states = jax.vmap(self._ode_exp_euler_step)(
            states, action, self.static_params)

        # TODO: observation, reward, truncation and termination are not wired up yet for
        # the dataclass-based states. Intended pipeline (note: Pendulum stores the max
        # actions as ``env_max_action``, not ``env_max_actions``):
        #   obs = jax.vmap(self.generate_observation)(states, self.env_state_constraints)
        #   reward = jax.vmap(self.reward_func)(
        #       obs, action, self.env_max_actions).reshape(-1, 1)
        #   truncated = jax.vmap(self.generate_truncated)(states, self.env_state_constraints)
        #   terminated = jax.vmap(self.generate_terminated)(states, reward)
        #   return obs, reward, terminated, truncated, states
        return {}, {}, {}, {}, states

    @property
    @abstractmethod
    def obs_description(self):
        """Returns a list of state names of all states in the observation (equal to state space)."""
        return self.states_description

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def default_reward_func(self, obs, action):
        """Returns the reward for the given observation and action (default reward function)."""
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def generate_observation(self, states):
        """Returns the observation for the given states (identity in this base implementation)."""
        return states

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def generate_truncated(self, states):
        """Returns flags indicating which states are out of bounds."""
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def generate_terminated(self, states, reward):
        """Returns flags indicating which batch elements reached a terminal state."""
        return

    @partial(jax.jit, static_argnums=0)
    @abstractmethod
    def _ode_exp_euler_step(self, states, action, static_params):
        """Integrate the system equations for one batch element over one step of length tau.

        ``step`` calls this through ``jax.vmap``, so every argument is a single batch element.

        Args:
            states: State pytree of one batch element.
            action(ndarray(float)): Action vector of one batch element (physical units).
            static_params(dict): Static model parameters of one batch element.

        Returns:
            states: Updated state pytree of one batch element.
        """
        return

    @abstractmethod
    def reset(self, initial_values: jnp.ndarray = jnp.array([])):
        """Reset the environment; concrete behavior is defined by subclasses."""
        return


In [64]:
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial
import diffrax
from collections import OrderedDict


class Pendulum(CoreEnvironment):
    """
    State Variables:
        ``['theta', 'omega']``

    Action Variable:
        ``['torque']``

    Observation Space (State Space):
        Box(low=[-1, -1], high=[1, 1])

    Action Space:
        Box(low=-1, high=1)

    Initial State:
        Unless chosen otherwise, theta equals 1 (normalized to pi) and omega is set to zero.

    Example:
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a batch of 5 pendulums with lengths 1..5 and a max torque of 1
        >>> env = Pendulum(batch_size=5, l=[1, 2, 3, 4, 5], env_max_action={"torque": 1})
        >>>
        >>> # Drive it through the stateful Gym wrapper with a normalized action
        >>> gym_env = GymWrapper(env)
        >>> obs, reward, terminated, truncated, info = gym_env.step(jnp.ones((5, 1)))

    """

    def __init__(self, batch_size: int = 8, l: float = 1, m: float = 1, env_max_action: dict = None, solver=diffrax.Euler(), reward_func=None, g: float = 9.81, tau: float = 1e-4, env_state_constraints: dict = None):
        """
        Args:
            batch_size(int): Number of parallel pendulums. Default: 8
            l(float): Length of the pendulum. Default: 1
            m(float): Mass of the pendulum tip. Default: 1
            env_max_action(dict): Maximum torque, keyed by action name. Default: {"torque": 20}
            solver: diffrax solver used for the ODE step. Default: diffrax.Euler()
            reward_func(function): Reward function for training. Needs Observation-Matrix as parameter.
                                    Default: None (default_reward_func from class)
            g(float): Gravitational acceleration. Default: 9.81
            tau(float): Duration of one control step in seconds. Default: 1e-4.
            env_state_constraints(dict): Constraint per state, used to normalize observations.
                                    Default: {"theta": np.pi, "omega": 10}

        Note: l, m, g and the values of env_max_action / env_state_constraints can also be passed
        as lists with length batch_size to set different values per batch element.
        """
        # Fresh dicts per instance instead of shared mutable default arguments.
        if env_max_action is None:
            env_max_action = {"torque": 20}
        if env_state_constraints is None:
            env_state_constraints = {"theta": np.pi, "omega": 10}

        # Must be set before super().__init__, which may validate reward_func via obs_description.
        self.env_states_name = ["theta", "omega"]
        self.env_actions_name = ["torque"]

        self.env_states_initials = [jnp.pi, 0]

        super().__init__(batch_size=batch_size, tau=tau,
                         solver=solver, reward_func=reward_func)

        self.static_params, self.env_state_constraints, self.env_max_action = self.sim_paras(
            {"l": l, "m": m, "g": g}, env_state_constraints, env_max_action)

    @partial(jax.jit, static_argnums=0)
    def _ode_exp_euler_step(self, states, action, static_params):
        """Advance one (unbatched) pendulum by tau seconds with the configured diffrax solver.

        Args:
            states(States): State of one batch element.
            action(ndarray(float)): Torque vector of one batch element (physical units).
            static_params(dict): ``l``, ``m``, ``g`` of one batch element.

        Returns:
            States: New state with theta wrapped into [-pi, pi).
        """
        physical = states.physical_states
        args = (action, static_params)

        def vector_field(t, y, args):
            theta, omega = y
            torque, params = args
            d_omega = (torque[0] + params["l"] * params["m"] * params["g"]
                       * jnp.sin(theta)) / (params["m"] * (params["l"]) ** 2)
            d_theta = omega
            return d_theta, d_omega

        term = diffrax.ODETerm(vector_field)
        t0 = 0
        t1 = self.tau
        y0 = (physical.theta, physical.omega)
        solver_state = self._solver.init(term, t0, t1, y0, args)
        y1, _, _, _, _ = self._solver.step(
            term, t0, t1, y0, args, solver_state, made_jump=False)

        theta_k1, omega_k1 = y1
        theta_k1 = ((theta_k1 + jnp.pi) % (2 * jnp.pi)) - jnp.pi

        phys = PhysicalStates(theta=theta_k1, omega=omega_k1)
        # NOTE(review): PRNGKey and optional are placeholders mirroring omega/theta, matching
        # states_array_to_dataclass; a real RNG key should probably be passed through instead.
        opt = Optional(something=theta_k1)
        return States(physical_states=phys, PRNGKey=omega_k1, optional=opt)

    @partial(jax.jit, static_argnums=0)
    def states_array_to_dataclass(self, states):
        """Convert a (batch_size, 2) [theta, omega] array into a ``States`` pytree.

        PRNGKey and optional are filled with placeholder copies of omega and theta.
        """
        phys = PhysicalStates(theta=states[:, 0], omega=states[:, 1])
        opt = Optional(something=states[:, 0])
        return States(physical_states=phys, PRNGKey=states[:, 1], optional=opt)

    @partial(jax.jit, static_argnums=0)
    def states_dataclass_to_array(self, states):
        """Convert a ``States`` pytree back into a (batch_size, 2) [theta, omega] array."""
        states_ = jnp.vstack((
            states.physical_states.theta,
            states.physical_states.omega,
        )).T
        return states_

    @partial(jax.jit, static_argnums=0)
    def default_reward_func(self, obs, action, env_max_actions):
        """Quadratic cost on normalized theta, omega and torque.

        NOTE(review): indexes ``action``/``env_max_actions`` by ``"torque"``, i.e. expects
        dict-like inputs; the array-based step pipeline does not call it yet.
        """
        return (obs[0])**2 + 0.1*(obs[1])**2 + 0.1*(action["torque"]/env_max_actions["torque"])**2

    @partial(jax.jit, static_argnums=0)
    def generate_observation(self, states, env_state_constraints):
        """Return the states divided by their constraints.

        NOTE(review): uses ``.values()``, so both arguments must be dicts; not yet adapted
        to the array/dataclass state representation.
        """
        return (jnp.array(list(states.values()))*(jnp.array(list(env_state_constraints.values())))**(-1)).T

    @partial(jax.jit, static_argnums=0)
    def generate_truncated(self, states, env_state_constraints):
        """Return True where |state / constraint| > 1.

        NOTE(review): expects dict inputs, like ``generate_observation``.
        """
        return jnp.abs((jnp.array(list(states.values()))/jnp.array(list(env_state_constraints.values()))).T) > 1

    @partial(jax.jit, static_argnums=0)
    def generate_terminated(self, states, reward):
        """Return True where the reward equals zero."""
        return reward == 0

    @property
    def obs_description(self):
        """Names of the observed states, in column order."""
        return self.env_states_name

    def reset(self, initial_values: jnp.ndarray = jnp.array([])):
        """Not implemented yet; returns None.

        TODO: validate ``initial_values`` (shape (batch_size, len(obs_description))) or fall
        back to ``env_states_initials`` tiled over the batch, then return (obs, states).
        """
        return

In [65]:
# Five pendulums with lengths 1..5 and a max torque of 1, wrapped for Gym-style stepping.
gym_pend = GymWrapper(
    env=Pendulum(
        batch_size=5,
        l=list(range(1, 6)),
        env_max_action={"torque": 1},
    )
)

In [66]:
# Apply a normalized torque of +1 to every batch element (action shape (5, 1)).
gym_pend.step(action=jnp.ones((5, 1)))

({}, {}, {}, {}, {})

In [67]:
gym_pend.states  # current (batch_size, n_states) state array stored by the wrapper after the step

Array([[-3.1415927e+00,  9.9999917e-05],
       [-3.1415927e+00,  2.4999956e-05],
       [-3.1415927e+00,  1.1111083e-05],
       [-3.1415927e+00,  6.2499780e-06],
       [-3.1415927e+00,  3.9999827e-06]], dtype=float32)