In [1]:
import jax.numpy as jnp
import numpy as np
import time
import chex
import jax
from functools import partial
import gymnasium as gym
from exciting_environments import core_env
from diffrax import diffeqsolve, ODETerm, Dopri5

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [176]:
class Pendulum(core_env.CoreEnvironment):

    def __init__(self, batch_size: int = 8, l: float = 1, m: float = 1,  max_torque: float = 20, reward_func=None, g: float = 9.81, tau: float = 1e-4, constraints: list = [10]):

        self.params = {"g": g, "l": l, "m": m}
        self.state_constraints = [np.pi, constraints[0]]  # ["theta", "omega"]
        self.state_initials = [1, 0]
        self.max_action = [max_torque]

        super().__init__(batch_size=batch_size, tau=tau)

    @partial(jax.jit, static_argnums=0)
    def _ode_exp_euler_step(self, states_norm, torque_norm):

        torque = torque_norm*self.action_normalizer
        states = self.state_normalizer * states_norm
        theta = states[:, 0].reshape(-1, 1)
        omega = states[:, 1].reshape(-1, 1)

        domega = (torque+self.params["l"]*self.params["m"]*self.params["g"]
                  * jnp.sin(theta)) / (self.params["m"] * (self.params["l"])**2)

        omega_k1 = omega + self.tau * domega  # explicit Euler
        dtheta = omega_k1
        theta_k1 = theta + self.tau * dtheta  # explicit Euler
        theta_k1 = ((theta_k1+jnp.pi) % (2*jnp.pi))-jnp.pi

        states_k1 = jnp.hstack((
            theta_k1,
            omega_k1,
        ))
        states_k1_norm = states_k1/self.state_normalizer

        return states_k1_norm

    @partial(jax.jit, static_argnums=0)
    def default_reward_func(self, obs, action):
        return ((obs[:, 0])**2 + 0.1*(obs[:, 1])**2 + 0.1*(action[:, 0])**2).reshape(-1, 1)

    @property
    def obs_description(self):
        return self.states_description

    @property
    def states_description(self):
        return np.array(["theta", "omega"])

    @property
    def action_description(self):
        return np.array(["torque"])


In [177]:
env=Pendulum(batch_size=4,tau=1e-4)
env.reset()

(Array([[1, 0],
        [1, 0],
        [1, 0],
        [1, 0]], dtype=int32),
 {})

In [178]:
obs,_,_,_,_=env.step(jnp.array([0.5,0.5,0.5,1]).reshape(-1,1))

In [179]:
obs

Array([[-1.000000e+00,  1.000000e-04],
       [-1.000000e+00,  1.000000e-04],
       [-1.000000e+00,  1.000000e-04],
       [-9.999998e-01,  2.000000e-04]], dtype=float32)

In [180]:
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController

In [187]:
class Pendulum2(core_env.CoreEnvironment):

    def __init__(self, batch_size: int = 8, l: float = 1, m: float = 1,  max_torque: float = 20, reward_func=None, g: float = 9.81, tau: float = 1e-4, constraints: list = [10]):

        self.params = {"g": g, "l": l, "m": m}
        self.state_constraints = [np.pi, constraints[0]]  # ["theta", "omega"]
        self.state_initials = [1, 0]
        self.max_action = [max_torque]

        super().__init__(batch_size=batch_size, tau=tau)


    @partial(jax.jit, static_argnums=0)
    def _ode_exp_euler_step(self, states_norm, torque_norm):

        torque = torque_norm*self.action_normalizer
        states = self.state_normalizer * states_norm
        theta = states[:, 0].reshape(-1, 1)
        omega = states[:, 1].reshape(-1, 1)

        #voraussichtlich funktionen ausladen

        def d_omega(t, y, args):
            action,theta,l, m, g= args
            d_omega= (action+l*m*g* jnp.sin(theta)) / (m * (l)**2)
            return d_omega

        def d_theta(t, y, args):
            omega,_= args
            d_theta= omega
            return d_theta

        solver = Dopri5()
        t0 = 0
        t1 = self.tau
        dt0 = self.tau/10
        saveat = SaveAt(ts=[self.tau])

        term = ODETerm(d_omega)
        omega_k = omega.reshape(1,-1)
        args = (torque[0], theta[0],self.params["l"][0], self.params["m"][0],self.params["g"][0])
        #args = (torque, theta,self.params["l"], self.params["m"],self.params["g"])
        sol_omega = diffeqsolve(term, solver, t0, t1, dt0, omega_k, args=args, saveat=saveat)
        omega_k1=sol_omega.ys[0].reshape(-1,1)

        term2 = ODETerm(d_theta)
        theta_k = theta.reshape(1,-1)
        args2 = (omega_k1[0],1)
        sol_theta = diffeqsolve(term2, solver, t0, t1, dt0, theta_k, args=args2, saveat=saveat)
        theta_k1=sol_theta.ys[0].reshape(-1,1)


        theta_k1 = ((theta_k1+jnp.pi) % (2*jnp.pi))-jnp.pi

        states_k1 = jnp.hstack((
            theta_k1,
            omega_k1,
        ))
        states_k1_norm = states_k1/self.state_normalizer

        return states_k1_norm

    @partial(jax.jit, static_argnums=0)
    def default_reward_func(self, obs, action):
        return ((obs[:, 0])**2 + 0.1*(obs[:, 1])**2 + 0.1*(action[:, 0])**2).reshape(-1, 1)

    @property
    def obs_description(self):
        return self.states_description

    @property
    def states_description(self):
        return np.array(["theta", "omega"])

    @property
    def action_description(self):
        return np.array(["torque"])

In [188]:
env2=Pendulum2(batch_size=2,tau=1e-4)
env2.reset()

(Array([[1, 0],
        [1, 0]], dtype=int32),
 {})

In [189]:
obs,_,_,_,_=env2.step(jnp.array([0.5,0.5]).reshape(1,-1))

In [169]:
obs

Array([[-1.e+00,  3.e-04],
       [-1.e+00,  3.e-04]], dtype=float32)

In [87]:

def d_omega(t, y, args):
    action,theta,l, m, g= args
    d_omega= (action+l*m*g* jnp.sin(theta)) / (m * (l)**2)
    return d_omega

term = ODETerm(d_omega)
solver = Dopri5()
t0 = 0
t1 = 1e-4
dt0 = 1e-4
omega0 = np.array([0,0])*10
args = (0.5*20, 1*jnp.pi, 1, 1,9.81)
saveat = SaveAt(ts=[1e-4])
sol = diffeqsolve(term, solver, t0, t1, dt0, omega0, args=args, saveat=saveat)

In [88]:
print(sol.ts)  # DeviceArray([0.   , 1.   , 2.   , 3.    ])
print(sol.ys/10)

[1.e-04]
[[9.999999e-05 9.999999e-05]]


In [94]:
sol.ys[0].reshape(-1,1)

Array([[0.001],
       [0.001]], dtype=float32, weak_type=True)

In [21]:

def vector_field(t, y, args):
    d_y=-y
    return d_y

vector_field1 = lambda t, y, args: -y
term = ODETerm(vector_field1)
solver = Dopri5()
saveat = SaveAt(ts=[0., 2., 3.])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = diffeqsolve(term, solver, t0=0, t1=3,dt0=0.1, y0=(0,1), saveat=saveat,
                  stepsize_controller=stepsize_controller)


TypeError: bad operand type for unary -: 'tuple'

In [9]:
bla=(1,2,3)
type(bla)

tuple

In [20]:
print(sol.ts)  # DeviceArray([0.   , 1.   , 2.   , 3.    ])
print(sol.ys)

[0. 2. 3.]
[[1.         0.        ]
 [0.13533796 0.        ]
 [0.04979085 0.        ]]
