In [None]:
import os
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import numpy as np
import jax
import jax.numpy as jnp
jax.config.update("jax_platform_name", "cpu")
import chex
from functools import partial
from typing import Callable
import gymnasium as gym
from exciting_environments import PMSM
import matplotlib.pyplot as plt
from exciting_environments.utils import MinMaxNormalization
from exciting_environments import MujucoWrapper
jax.config.update("jax_enable_x64", True)

import mujoco
from mujoco import mjx
from etils import epath


In [None]:
# Option 1: load the humanoid MJCF that ships with the MuJoCo package's MJX test data.
HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'
# Alternative input that was tried before: HUMANOID_ROOT_PATH / 'humanoid.mjb'
model_from_xml = mujoco.MjModel.from_xml_path(
    (HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix()
)

In [None]:
# Option 2: inline MJCF model ("inverted pendulum" on a hinged beam)

# <!-- ======================================================

# 		MODEL: Inverted Pendulum
# 		AUTHOR: Atabak Dehban
# 		Modified from the following model

#     ======================================================
# 	Model 		:: Beam Balance

# 	Mujoco		:: Advanced physics simulation engine
# 		Source		: www.roboti.us
# 		Version		: 1.31
# 		Released 	: 23Apr16

# 	Author		:: Vikash Kumar
# 		Contacts 	: kumar@roboti.us
# 		Last edits 	: 30Apr'16, 30Nov'15, 10Oct'15
#     ======================================================

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# -->
# MJCF description of the "inverted pendulum" model: a "beam" body hinged at the "pivot"
# joint (axis 0 1 0) carrying "ballbody" at its upper end, two spatial tendons
# ("lThread", "rThread") from the beam's end sites to fixed ground sites, one torque
# "motor" on the pivot (forcelimited, forcerange [-10, 10]) and a joint-position
# sensor "encoder" with noise="0.001".
# NOTE: the string below is parsed by MuJoCo at runtime; edit it only as MJCF.
XML = """
<mujoco model="inverted pendulum">
	<compiler 	angle="radian"/>
			
	<option	timestep="0.001"
			iterations="20">
	</option>
	
	<default>
	<geom  rgba="0.45 0.6 0.6 1"/>
		<site type="sphere" size="0.02"/>
	</default>
	<worldbody>
	
		<light directional="true" cutoff="4" exponent="20" diffuse="1 1 1" specular="0 0 0" pos=".9 .3 2.5" dir="-.9 -.3 -2.5 "/>
	
		<!-- ======= Ground ======= -->
		<geom name="ground" type="plane" pos="0 0 0" size="0.5 1 2" rgba=" .25 .26 .25 1"/>
		<site name="rFix" pos="0 -.2 .005"/>
		<site name="lFix" pos="0 .2 .005"/>
			
		<!-- ======= Beam ======= -->
		<body name="beam" pos="0 0 .5">
			<!--density of the rod is ten times smaller than other objects (e.g. the ball) for a better point mass approximation-->
            <geom name="rod" type="cylinder" pos="0 0 0.0" size=".01 .1" density="100"/>
			<geom pos="0 0 -.1" type="capsule" size=".01 .2" euler="1.57 0 0"/>
			<joint name="pivot" pos="0 0 -0.1" axis="0 1 0" limited="false" damping=".05"/>
			<site name="rBeam" pos="0 -.2 -.1"/>
			<site name="lBeam" pos="0 .2 -.1"/>
			<body name="ballbody" pos="0 0 0.1">
				<geom name="ballgeom" type="sphere" size=".05"/>
			</body>
		</body>
	</worldbody>
	
	<tendon>
		<spatial name="lThread" width="0.01">
			<site site="lBeam"/>
			<site site="lFix"/>
		</spatial>
		
		<spatial name="rThread"  width="0.01">
			<site site="rBeam"/>
			<site site="rFix"/>
		</spatial>
	</tendon>
	
	<actuator>
        <motor joint='pivot' name='motor' gear="1" forcelimited="true" forcerange="-10 10"/>
	</actuator>

	<sensor>
		<jointpos name="encoder" joint="pivot" noise="0.001"/>
	</sensor>
</mujoco>
"""
# Compile the inline MJCF string into a MuJoCo model.
model_own_xml = mujoco.MjModel.from_xml_string(XML)

In [None]:
# Option 3: reuse the model file of a Gymnasium MuJoCo environment, so the wrapper can be
# compared against the Gymnasium reference further below.
gym_env = gym.make(
    'InvertedDoublePendulum-v5'  # single-pendulum alternative: 'InvertedPendulum-v5'
)
gym_env_unwrapped = gym_env.unwrapped

# `fullpath` is the model file the Gymnasium environment was built from.
model_from_gym = mujoco.MjModel.from_xml_path(gym_env_unwrapped.fullpath)

In [None]:
# Pick which of the models above to wrap; the Gymnasium one enables the comparison below.
model = model_from_gym
wrap = MujucoWrapper(mujoco_model=model)  # default normalizations (may contain NaN bounds)


In [None]:
# Ask the wrapper for the default physical-state and action normalization dataclasses
# derived from the MuJoCo model.
# NOTE(review): the class object itself is passed in the `self` slot, so these methods
# presumably do not rely on instance state — confirm, or call them on `wrap` instead.
phys_norm_def, act_norm_def = (
    MujucoWrapper.generate_physical_normalization_dataclasses(MujucoWrapper, model),
    MujucoWrapper.generate_action_normalization_dataclasses(MujucoWrapper, model),
)

In [None]:
# Replace NaN normalization bounds with valid (finite) values.

# Example: every state/action bound that is NaN becomes min=-1 or max=1; finite bounds are kept.
from dataclasses import fields, is_dataclass, replace

def replace_nans(obj, default_min=-1, default_max=1):
    """Recursively replace NaN bounds in (nested) normalization dataclasses.

    Every ``MinMaxNormalization`` reached through ``obj`` is rebuilt with a NaN
    ``min`` replaced by ``default_min`` and a NaN ``max`` replaced by
    ``default_max``. Finite bounds are kept unchanged. Other dataclass
    *instances* are rebuilt field-by-field with ``dataclasses.replace``. Any
    other value, including dataclass classes themselves, is returned as-is.

    Args:
        obj: A ``MinMaxNormalization``, a dataclass instance (possibly nested)
            containing them, or any other value.
        default_min: Value substituted for a NaN lower bound (default -1).
        default_max: Value substituted for a NaN upper bound (default 1).

    Returns:
        An object with the same structure as ``obj`` whose NaN bounds have been
        replaced.
    """
    if isinstance(obj, MinMaxNormalization):
        # Named lower/upper to avoid shadowing the builtins min/max.
        lower = default_min if jnp.isnan(obj.min) else obj.min
        upper = default_max if jnp.isnan(obj.max) else obj.max
        return MinMaxNormalization(min=lower, max=upper)  # replace NaNs directly

    # is_dataclass() is also True for dataclass *classes*, which replace() rejects.
    if is_dataclass(obj) and not isinstance(obj, type):
        # replace() raises ValueError for init=False fields, so only pass init fields.
        return replace(obj, **{
            f.name: replace_nans(getattr(obj, f.name), default_min, default_max)
            for f in fields(obj)
            if f.init
        })

    return obj

# NaN-free copies of the default normalizations (physical states first, then actions).
phys_norm, act_norm = replace_nans(phys_norm_def), replace_nans(act_norm_def)

In [None]:
# Rebuild the wrapper with the NaN-free normalizations computed above.
wrap = MujucoWrapper(
    mujoco_model=model,
    physical_normalizations=phys_norm,
    action_normalization=act_norm,
)

In [None]:
# Display the wrapper's `obs_description` as the cell output
# (presumably the names/order of the observation entries — confirm).
wrap.obs_description

In [None]:
# Single-environment smoke test: reset, then apply a normalized action of all ones for
# three steps and show the final observation.
obs, data_state = wrap.reset(wrap.env_properties)
for _ in range(3):
    obs, data_state = wrap.step(
        data_state,
        jnp.ones(wrap.mjx_model.nu),
        wrap.env_properties,
    )
obs

In [None]:
# Batched smoke test: reset all environments, apply a normalized action of all ones for
# three vmapped steps, and show the observation of the first environment.
# NOTE(review): the batch size 8 is hard-coded and presumably must match the wrapper's
# batch size — confirm, or derive it from the reset output.
obs, data_state = wrap.vmap_reset(None)
for _ in range(3):
    obs, data_state = wrap.vmap_step(
        data_state,
        jnp.ones((8, wrap.mjx_model.nu)),
    )
obs[0]

### Comparison: Gymnasium reference vs. own MJX wrapper

In [None]:
# Reset the Gymnasium reference env with a fixed seed so Restart & Run All reproduces the
# same initial state, and therefore the same comparison, every time.
GYM_RESET_SEED = 0
obs_gym, _ = gym_env.reset(seed=GYM_RESET_SEED)
print(obs_gym)
# Read the initial [qpos, qvel] straight from the env's MuJoCo data instead of rebuilding it
# from observation indices (arctan2 of sin/cos entries). This is exact and works unchanged
# for both 'InvertedDoublePendulum-v5' and 'InvertedPendulum-v5'.
obs_init = jnp.hstack([gym_env_unwrapped.data.qpos, gym_env_unwrapped.data.qvel])

In [None]:
# Start the MJX wrapper from the same initial state as the Gymnasium env and show qpos.
obs, data = wrap.reset(
    wrap.env_properties,
    initial_qpos_qvel=obs_init,
)
data.qpos

In [None]:
# Step the Gymnasium env once with every actuator at the upper end of its control range.
# This is one value per actuator, so it works for any `nu`, matching the wrapper's
# `jnp.ones(nu)` action below. It presumably corresponds to a normalized action of 1
# under `act_norm` — confirm.
obs, _, _, _, _ = gym_env.step(model.actuator_ctrlrange[:, 1])
obs

In [None]:
# Advance the wrapper once per Gymnasium `frame_skip` with the all-ones normalized action,
# then print the resulting state for comparison with the Gymnasium observation above.
for _ in range(gym_env_unwrapped.frame_skip):
    obs, data = wrap.step(
        data,
        1 * jnp.ones(wrap.mjx_model.nu),
        wrap.env_properties,
    )
print(data.qpos, data.qvel)