In [4]:
import jax.numpy as jnp
import jax
import numpy as np
import time
import gymnasium as gym
import sys
sys.path.append("..")
import exciting_environments as excenvs
import diffrax
from exciting_environments import GymWrapper
import jax_dataclasses as jdc
from dataclasses import fields
from exciting_environments.utils import MinMaxNormalization
import os
from pathlib import Path
import pickle
jax.config.update("jax_enable_x64", True)

In [8]:
# --- Restore the simulation properties --------------------------------------
# NOTE(review): pickle.load can execute code embedded in the file; only use
# this with a "data/sim_properties.pkl" that comes from a trusted source.
with open("data/sim_properties.pkl", "rb") as f:
    loaded_data = pickle.load(f)

loaded_params = loaded_data["params"]
loaded_action_normalizations = loaded_data["action_normalizations"]
loaded_physical_normalizations = loaded_data["physical_normalizations"]
loaded_tau = loaded_data["tau"]

# --- Rebuild the environment from the restored properties -------------------
env = excenvs.make(
    "CartPole-v0",
    tau=loaded_tau,
    solver=diffrax.Euler(),
    static_params=loaded_params,
    physical_normalizations=loaded_physical_normalizations,
    action_normalizations=loaded_action_normalizations,
)

# --- Replay the recorded actions --------------------------------------------
stored_observations = jnp.load("data/observations.npy")
actions_data = jnp.load("data/actions.npy")

# Start from the state that corresponds to the first stored observation and
# keep that observation as the first entry of the replayed trajectory.
state = env.generate_state_from_observation(stored_observations[0], env.env_properties)
generated_observations = [stored_observations[0]]

# Apply the first 10000 recorded actions one step at a time.
# NOTE(review): assumes actions_data holds at least 10000 entries - confirm.
for i in range(10000):
    action = actions_data[i]
    obs, state = env.step(state, action, env.env_properties)
    generated_observations.append(obs)

# Collect the per-step observations into a single array.
generated_observations = jnp.array(generated_observations)

In [9]:
# Check that the replayed trajectory matches the stored observations.
# The third positional argument of allclose is `rtol`; `atol` keeps its
# default of 1e-8, so the absolute tolerance is what effectively decides the
# comparison for values of this magnitude.
jnp.allclose(generated_observations, stored_observations, rtol=1e-16, atol=1e-8)

Array(True, dtype=bool)

In [10]:
# Display the observations produced by replaying the recorded actions.
generated_observations

Array([[ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00],
       [ 0.00000000e+00, -1.52142068e-04, -1.00000000e+00,
        -2.28213103e-04],
       [-5.07140228e-08, -2.65808706e-04,  9.99999942e-01,
        -3.98713058e-04],
       ...,
       [ 3.39400971e-03,  6.48156987e-03,  9.99897298e-01,
         1.05608749e-02],
       [ 3.39617023e-03,  6.30010138e-03,  9.99899987e-01,
         1.02887315e-02],
       [ 3.39827026e-03,  6.53424452e-03,  9.99902607e-01,
         1.06400039e-02]], dtype=float64)

In [11]:
# Display the reference observations loaded from "data/observations.npy".
stored_observations

Array([[ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00,
         0.00000000e+00],
       [ 0.00000000e+00, -1.52142068e-04, -1.00000000e+00,
        -2.28213103e-04],
       [-5.07140228e-08, -2.65808706e-04,  9.99999942e-01,
        -3.98713058e-04],
       ...,
       [ 3.39400971e-03,  6.48156987e-03,  9.99897298e-01,
         1.05608749e-02],
       [ 3.39617023e-03,  6.30010138e-03,  9.99899987e-01,
         1.02887315e-02],
       [ 3.39827026e-03,  6.53424452e-03,  9.99902607e-01,
         1.06400039e-02]], dtype=float64)

In [12]:
import json
def safe_json_dump(obj, fp):
    """Write ``obj`` as JSON to the file-like object ``fp``.

    Values the json encoder does not know how to serialize are replaced by a
    placeholder string of the form ``<<non-serializable: TypeName>>`` instead
    of raising ``TypeError``.

    Returns:
        The result of ``json.dump`` (``None``).
    """

    def _placeholder(value):
        # Fallback used by the encoder for any unsupported value.
        return f"<<non-serializable: {type(value).__qualname__}>>"

    return json.dump(obj, fp, default=_placeholder)

In [14]:
# Pull the individual entries out of the unpickled properties again, using
# the same keys that were read when the environment was created.
(
    loaded_params,
    loaded_action_normalizations,
    loaded_physical_normalizations,
    loaded_tau,
) = (
    loaded_data["params"],
    loaded_data["action_normalizations"],
    loaded_data["physical_normalizations"],
    loaded_data["tau"],
)

In [21]:
# Inspect the action normalizations restored from the pickled properties.
loaded_action_normalizations


{'force': MinMaxNormalization(min=-20, max=20)}

In [20]:
# First, naive attempt: hand the loaded objects to the json encoder as they are.
json_data = {
    "params": loaded_params,
    "action_normalizations": loaded_action_normalizations,
    "physical_normalizations": loaded_physical_normalizations,
    "tau": loaded_tau,  # tau is presumably a float or similar
}

# The normalization entries are MinMaxNormalization objects, which the
# standard json encoder rejects. Encode in memory first and report the error
# instead of raising: the notebook then survives "Restart & Run All", and no
# truncated "data/props_2.json" is left behind when encoding fails.
try:
    encoded = json.dumps(json_data, indent=4)
except TypeError as err:
    print(f"TypeError: {err}")
else:
    with open("data/props_2.json", "w") as f:
        f.write(encoded)

TypeError: Object of type MinMaxNormalization is not JSON serializable

In [25]:
import json
from dataclasses import asdict

# Convert every normalization dataclass into a plain dict so that the
# standard json encoder can handle it.
action_norm_serialized = {
    name: asdict(norm) for name, norm in loaded_action_normalizations.items()
}
physical_norm_serialized = {
    name: asdict(norm) for name, norm in loaded_physical_normalizations.items()
}

# Bundle everything for export; params and tau are passed through unchanged.
# (Other structures of the same kind would need the same conversion.)
data = dict(
    params=loaded_params,
    action_normalizations=action_norm_serialized,
    physical_normalizations=physical_norm_serialized,
    tau=loaded_tau,
)

with open("data/sim_properties.json", "w") as f:
    json.dump(data, f, indent=4)

In [None]:
def save_to_json(params, action_normalizations, physical_normalizations, tau, filename):
    """Write simulation properties to a JSON file.

    Each normalization object is converted to a plain dict with
    ``dataclasses.asdict`` so the standard json encoder can handle it;
    ``params`` and ``tau`` are handed to the encoder unchanged.

    Args:
        params: JSON-serializable parameters, stored under ``"params"``.
        action_normalizations: Mapping from name to a dataclass instance,
            stored under ``"action_normalizations"``.
        physical_normalizations: Mapping from name to a dataclass instance,
            stored under ``"physical_normalizations"``.
        tau: JSON-serializable value, stored under ``"tau"``.
        filename: Path of the JSON file to create or overwrite.

    Raises:
        TypeError: If a normalization value is not a dataclass instance or if
            any value cannot be encoded as JSON. In that case ``filename`` is
            not touched.
    """
    data = {
        "params": params,
        "action_normalizations": {
            name: asdict(norm) for name, norm in action_normalizations.items()
        },
        "physical_normalizations": {
            name: asdict(norm) for name, norm in physical_normalizations.items()
        },
        "tau": tau,
    }
    # Encode completely in memory before opening the file: opening with "w"
    # truncates immediately, so encoding while writing would destroy an
    # existing file whenever serialization fails part-way through.
    encoded = json.dumps(data, indent=4)
    with open(filename, 'w') as f:
        f.write(encoded)

def load_from_json(filename):
    """Read simulation properties back from a JSON file.

    The file must contain the keys ``"params"``, ``"action_normalizations"``,
    ``"physical_normalizations"`` and ``"tau"``. Every entry of the two
    normalization mappings is expanded as keyword arguments into
    ``MinMaxNormalization``; ``params`` and ``tau`` are returned as decoded.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        Tuple ``(params, action_normalizations, physical_normalizations, tau)``.
    """
    with open(filename, 'r') as fh:
        raw = json.load(fh)

    def _rebuild(serialized):
        # Turn each stored keyword dict back into a MinMaxNormalization.
        return {
            name: MinMaxNormalization(**kwargs)
            for name, kwargs in serialized.items()
        }

    # Read all four entries first, so a missing key is reported before any
    # normalization object is constructed.
    params, action_raw, physical_raw, tau = (
        raw[key]
        for key in ("params", "action_normalizations", "physical_normalizations", "tau")
    )
    action_normalizations = _rebuild(action_raw)
    physical_normalizations = _rebuild(physical_raw)
    return params, action_normalizations, physical_normalizations, tau