In [1]:
# Download the checkpoint archive from Google Drive and extract it into /sim2real.
import os
import zipfile

import gdown

url = "https://drive.google.com/file/d/1XDWulLZBuHaMpabrOZp0JPBc09J3Oh93/view?usp=sharing"
output = "/sim2real/checkpoint.zip"
# Directory that receives both the zip and its extracted contents ("/sim2real").
extract_dir = os.path.dirname(output)

# gdown may not create missing parent directories, so make sure the target exists
# on a fresh machine.
os.makedirs(extract_dir, exist_ok=True)

# fuzzy=True lets gdown extract the file id from this "share" URL
# (it is rewritten to a uc?id=... download URL).
downloaded = gdown.download(url, output, quiet=False, fuzzy=True)
# Some gdown versions report failures (e.g. permission denied) by returning None
# instead of raising; fail here rather than with a confusing zipfile error below.
if downloaded is None:
    raise RuntimeError(f"Failed to download checkpoint from {url}")

# Extract next to the zip; actor.pkl is loaded from this directory later.
with zipfile.ZipFile(output, "r") as zip_ref:
    zip_ref.extractall(extract_dir)

Downloading...
From: https://drive.google.com/uc?id=1XDWulLZBuHaMpabrOZp0JPBc09J3Oh93
To: /sim2real/checkpoint.zip
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 316k/316k [00:00<00:00, 3.81MB/s]


# Actor checkpoint

In [2]:
import os

# These variables are read when JAX / MuJoCo are first imported, so they are set in
# this first configuration cell, before any later cell pulls those libraries in.

# Stop JAX (XLA) from preallocating most of the GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Headless rendering through EGL (the video cell uses mujoco.egl). MuJoCo selects its
# GL backend from this variable at import time, so setting it only right before
# rendering is too late once mujoco has already been imported.
os.environ["MUJOCO_GL"] = "egl"

In [3]:
import cloudpickle

In [4]:
# Load the exported actor dictionary from the extracted checkpoint directory.
# SECURITY: cloudpickle.load can execute arbitrary code embedded in the file. This
# file comes from a downloaded archive, so only load checkpoints from a trusted source.
CHECKPOINT_PATH = "/sim2real"
pickle_path = CHECKPOINT_PATH + "/actor.pkl"
with open(pickle_path, mode="rb") as checkpoint_file:
    loaded_actor = cloudpickle.load(checkpoint_file)

In [5]:
# Display the checkpoint dictionary to inspect its contents: actor function, action
# wrapper, observation normalization statistics, environment name and dimensions.
loaded_actor

{'actor_fn': <PjitFunction>,
 'get_action': <function __main__.get_actor_from_checkpoint.<locals>.get_action(obs: numpy.ndarray) -> numpy.ndarray>,
 'obs_mean': array(0),
 'obs_std': array(1),
 'env_name': 'Go2JoystickFlatTerrain',
 'state_dim': 48,
 'action_dim': 12,
 'max_action': 1.0}

### Uso:

O dicionário `loaded_actor` contém os seguintes elementos:
- `actor_fn`: Função do ator compilada com JIT
- `get_action`: Função wrapper para obter ações a partir de observações
- `obs_mean`: Média das observações para normalização (neste caso, 0)
- `obs_std`: Desvio padrão das observações para normalização (neste caso, 1)
- `env_name`: Nome do ambiente ('Go2JoystickFlatTerrain')
- `state_dim`: Dimensão do espaço de estados (48)
- `action_dim`: Dimensão do espaço de ações (12)
- `max_action`: Valor máximo da ação (1.0)
Para usar o ator:
```python
actor = loaded_actor["get_action"]
action = actor(obs=observation)  # observation deve ser um numpy array de shape (48,)
```

# Teste no simulador

In [6]:
import os
import sys

# Disable JAX's up-front preallocation of GPU memory. JAX reads this variable when it
# is first imported, so it must be set BEFORE the project import below (which
# presumably pulls in JAX — verify). Setting it after that import could not take
# effect. The first configuration cell already sets it; it is repeated here so this
# cell also works when run on its own.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Make the repository root importable. The guard keeps re-runs of this cell from
# appending ".." to sys.path again and again.
if ".." not in sys.path:
    sys.path.append("..")

from algorithms.utils.wrapper_gym import get_env

In [7]:
import numpy as np
from datetime import datetime

In [8]:
# Environment used for evaluation. It matches the `env_name` stored in the checkpoint
# ('Go2JoystickFlatTerrain' in the loaded_actor output above).
ENV_NAME = "Go2JoystickFlatTerrain"

In [9]:
# States handed to the render callback are collected here. They are later turned
# into a video by env_wrapped.save_video(...).
render_trajectory = []


def render_callback(_, state):
    """Store ``state`` for the video; the first argument is ignored."""
    render_trajectory.append(state)


env_wrapped = get_env(
    ENV_NAME,
    "cuda",
    render_callback,
    # NOTE(review): "fowardfixed" looks like a misspelling of "forwardfixed", but it is
    # the key the project's get_env receives. Confirm against get_env before changing.
    command_type="fowardfixed",
)

In [10]:
def evaluate(
    actor,
    env,
    num_episodes,
    render=False,
    return_all=False,
):
    """Roll out ``actor`` in ``env`` for ``num_episodes`` complete episodes.

    Args:
        actor: Callable invoked as ``actor(obs=observation)``; its result is passed
            to ``env.step``.
        env: Gymnasium-style environment. ``reset()`` returns ``(obs, info)`` and
            ``step(action)`` returns ``(obs, reward, terminated, truncated, info)``.
        num_episodes: Number of episodes to run; must be at least 1.
        render: If True, call ``env.render()`` after every step.
        return_all: If True, return the per-episode returns as a NumPy array instead
            of their mean, so callers can compute the spread (std, min, max, ...).

    Returns:
        The mean undiscounted episode return (sum of rewards per episode), or the
        array of per-episode returns when ``return_all`` is True.

    Raises:
        ValueError: If ``num_episodes`` is less than 1 (the mean of zero episodes
            would silently be NaN).
    """
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    episode_returns = []
    for _ in range(num_episodes):
        episode_return = 0
        observation, _ = env.reset()
        done = truncated = False
        while not done and not truncated:
            action = actor(obs=observation)
            observation, reward, done, truncated, _ = env.step(action)
            if render:
                env.render()
            episode_return += reward
        episode_returns.append(episode_return)

    if return_all:
        return np.asarray(episode_returns)
    return np.mean(episode_returns)

In [11]:
# evaluate() returns the MEAN over the episodes it runs, so calling .std() on its
# result always printed 0.0. Run one episode per call instead, to keep every
# episode's return and report a real mean / std.
NUM_EVAL_EPISODES = 1

# Drop frames left over from a previous run of this cell, so the saved video only
# shows the episodes evaluated here.
render_trajectory.clear()

episode_rewards = np.array([
    evaluate(actor=loaded_actor["get_action"], env=env_wrapped, num_episodes=1, render=True)
    for _ in range(NUM_EVAL_EPISODES)
])
print("Mean: ", episode_rewards.mean())
print("Std: ", episode_rewards.std())

Mean:  31.51585
Std:  0.0


In [12]:
import os

# MuJoCo selects its OpenGL backend from MUJOCO_GL when the `mujoco` package is first
# imported, so the variable must be set BEFORE the import. Setting it afterwards (as
# this cell originally did) has no effect. If an earlier cell already imported mujoco
# (presumably via get_env), it must be set even earlier, in the first configuration
# cell. The GLFW "DISPLAY environment variable is missing" warnings printed when the
# video is saved are consistent with this variable being set too late.
os.environ["MUJOCO_GL"] = "egl"
import mujoco.egl

# Create a headless 1024x1024 EGL context and make it current for rendering.
gl_context = mujoco.egl.GLContext(1024, 1024)
gl_context.make_current()

In [13]:
# Write the states collected by render_callback to an .mp4 in the checkpoint
# directory. The file name is tagged with the current date/time so runs in different
# seconds get different files.
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
video_path = f"{CHECKPOINT_PATH}/{ENV_NAME}-{timestamp}.mp4"
env_wrapped.save_video(render_trajectory, save_path=video_path)

/usr/local/lib/python3.10/dist-packages/glfw/__init__.py:917: GLFWError: (65550) b'X11: The DISPLAY environment variable is missing'
/usr/local/lib/python3.10/dist-packages/glfw/__init__.py:917: GLFWError: (65537) b'The GLFW library is not initialized'
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:28<00:00, 17.60it/s]
