In [4]:
import os
import hydra
from omegaconf import OmegaConf, DictConfig
from utils.utils import print_config

if hydra.core.global_hydra.GlobalHydra().is_initialized():
    hydra.core.global_hydra.GlobalHydra().clear()

# Initialize Hydra configurations
with hydra.initialize(version_base="1.3", config_path="configs"):
    cfg = hydra.compose(config_name="config", overrides=["exp=dreamer_v3_dmc_walker_walk"])
    print_config(cfg)


In [6]:
import pathlib
from utils.utils import dotdict
from omegaconf import DictConfig, OmegaConf, open_dict

def resume_from_checkpoint(cfg: DictConfig) -> DictConfig:
    ckpt_path = pathlib.Path(cfg.checkpoint.resume_from)
    old_cfg = OmegaConf.load(ckpt_path.parent.parent / "config.yaml")
    old_cfg = dotdict(OmegaConf.to_container(old_cfg, resolve=True, throw_on_missing=True))
    if old_cfg.env.id != cfg.env.id:
        raise ValueError(
            "This experiment is run with a different environment from the one of the experiment you want to restart. "
            f"Got '{cfg.env.id}', but the environment of the experiment of the checkpoint was {old_cfg.env.id}. "
            "Set properly the environment for restarting the experiment."
        )
    if old_cfg.algo.name != cfg.algo.name:
        raise ValueError(
            "This experiment is run with a different algorithm from the one of the experiment you want to restart. "
            f"Got '{cfg.algo.name}', but the algorithm of the experiment of the checkpoint was {old_cfg.algo.name}. "
            "Set properly the algorithm name for restarting the experiment."
        )

    # Remove keys from the `old_cfg` that must not be overridden
    old_cfg.pop("root_dir", None)
    old_cfg.pop("run_name", None)
    old_cfg.checkpoint.pop("resume_from", None)
    # Substitute the config with the old one (except for the parameters removed before)
    # because the experiment must continue with the same parameters
    with open_dict(cfg):
        cfg.merge_with(old_cfg.as_dict())
    return cfg

In [7]:
if cfg.checkpoint.resume_from:
    cfg = resume_from_checkpoint(cfg)

In [8]:
cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True))

In [9]:
import warnings
from typing import Any, Dict
from utils.registry import algorithm_registry
from utils.imports import _IS_MLFLOW_AVAILABLE

In [10]:
def check_configs(cfg: Dict[str, Any]):
    """Check the validity of the configuration.

    Args:
        cfg (Dict[str, Any]): the loaded configuration to check.
    """
    if cfg.float32_matmul_precision not in {"medium", "high", "highest"}:
        raise ValueError(
            f"Invalid value '{cfg.float32_matmul_precision}' for the 'float32_matmul_precision' parameter. "
            "It must be one of 'medium', 'high' or 'highest'."
        )
    decoupled = False
    algo_name = cfg.algo.name
    for _, _algos in algorithm_registry.items():
        for _algo in _algos:
            if algo_name == _algo["name"]:
                decoupled = _algo["decoupled"]
                break
    if not (_IS_MLFLOW_AVAILABLE or cfg.model_manager.disabled):
        warnings.warn(
            "MLFlow is not installed. "
            "Please install it with 'pip install mlflow' if you want to use the MLFlow logger and log models. "
            "Setting `cfg.model_manager.disabled=True`",
            UserWarning,
        )
        cfg.model_manager.disabled = True

In [11]:
check_configs(cfg)

In [16]:
from envs.dmc import DMCWrapper
env = DMCWrapper(domain_name="walker", task_name="walk", from_pixels=True, from_vectors=True)
obs = env.reset()
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
print("Observation:", obs)
print("Reward:", reward)
print("Terminated:", terminated)
print("Truncated:", truncated)
print("Info:", info)


Observation: {'rgb': array([[[ 45,  45,  46, ...,  45,  45,  45],
        [ 44,  44,  44, ...,  44,  44,  44],
        [ 44,  44,  44, ...,  44,  44,  45],
        ...,
        [ 58,  45,  32, ...,  31,  31,  31],
        [ 55,  37,  30, ...,  31,  31,  31],
        [ 48,  32,  30, ...,  31,  31,  31]],

       [[ 67,  67,  68, ...,  67,  67,  67],
        [ 67,  67,  66, ...,  66,  67,  67],
        [ 66,  66,  66, ...,  66,  66,  67],
        ...,
        [ 87,  73,  62, ...,  63,  63,  63],
        [ 84,  66,  61, ...,  64,  63,  63],
        [ 78,  62,  60, ...,  64,  63,  63]],

       [[ 90,  90,  91, ...,  90,  90,  90],
        [ 89,  89,  89, ...,  89,  89,  89],
        [ 88,  88,  88, ...,  88,  88,  89],
        ...,
        [117, 103,  91, ...,  94,  94,  93],
        [113,  95,  90, ...,  95,  94,  94],
        [107,  91,  89, ...,  95,  94,  94]]], dtype=uint8), 'state': array([-0.68256835, -0.73082176, -0.634587  , -0.77285144,  0.98291309,
       -0.18407027,  0.799729

In [19]:
action_space = env.action_space
observation_space = env.observation_space

print('Action Space: ', action_space)
print('Observation Space: ', observation_space)


In [15]:
import gymnasium as gym
from functools import partial
from envs.wrappers import RestartOnException

vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv
envs = vectorized_env(
    [
        partial(
            RestartOnException,
            make_env(
                cfg,
                cfg.seed + rank * cfg.env.num_envs + i,
                rank * cfg.env.num_envs,
                log_dir if rank == 0 else None,
                "train",
                vector_env_idx=i,
            ),
        )
        for i in range(cfg.env.num_envs)
    ]
)


NameError: name 'partial' is not defined

In [30]:
from utils.env import make_env
from envs.wrappers import RestartOnException

# Environment setup
rank = 0
env = RestartOnException(
    make_env(
        cfg,
        cfg.seed + rank,
        rank,
        log_dir if rank == 0 else None,
        "train",
        vector_env_idx=0,
    )()
)

action_space = env.action_space
observation_space = env.observation_space


MuJoCo version: 3.1.4
MuJoCo Path: /home/jianheng/miniconda3/envs/sheeprl/lib/python3.10/site-packages/mujoco/__init__.py


ModuleNotFoundError: No module named 'mujoco.gl'