In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded automatically before code is executed, so edits to local source
# files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os.path as osp
import random
from typing import Dict,List

import gym.spaces as spaces
import hydra
import numpy as np
import torch
from hydra.utils import instantiate as hydra_instantiate
from omegaconf import DictConfig
from rl_utils.envs import create_vectorized_envs
from rl_utils.logging import Logger
from tensordict.tensordict import TensorDict
from torchrl.envs.utils import step_mdp

from imitation_learning.common.evaluator import Evaluator

In [3]:
from cassie import CassieEnv 

Device is  cuda


In [4]:
# Build a single standalone Cassie environment from an empty config dict.
env = CassieEnv({})
# Set the environment's render_mode attribute to "rgb_array".
# NOTE(review): presumably this makes render() return image frames instead of
# opening a viewer -- confirm against the CassieEnv implementation.
env.render_mode = "rgb_array"

In [5]:
import yaml 
# Read the experiment configuration from "bc-irl-cassie.yaml" (path is relative
# to the current working directory) and wrap it in an OmegaConf DictConfig so
# later cells can use attribute access (cfg.seed, cfg.num_envs, ...).
# The context manager guarantees the file handle is closed, and
# yaml.safe_load is the shorthand for yaml.load(..., Loader=yaml.SafeLoader).
with open("bc-irl-cassie.yaml", "r") as cfg_file:
    cfg = DictConfig(yaml.safe_load(cfg_file))

In [6]:
from rl_utils.envs.pointmass.pointmass_env import PointMassParams

In [7]:

def set_seed(seed: int) -> None:
    """
    Seed every random number generator this notebook relies on.

    The same integer is handed to Python's ``random`` module, NumPy's global
    generator and PyTorch's generator; when CUDA is available, all CUDA
    devices are seeded as well.

    :param seed: integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

In [8]:
from gymnasium.envs.mujoco.mujoco_env import MujocoEnv
class vectorized_env():
    """Steps a list of environments sequentially and batches their results.

    Observations and rewards are returned as ``float32`` tensors and done
    flags as a ``bool`` tensor, each with a leading dimension of ``num_envs``.
    The observation and action spaces are taken from the first environment.

    Each wrapped environment must provide ``reset()`` returning a sequence
    whose first element is the observation, and ``step(action)`` returning
    either ``(obs, reward, done, info)`` or the five-element
    ``(obs, reward, terminated, truncated, info)`` form.

    :param envs: non-empty list of environments to step in lockstep.
    :param auto_reset: when True (default), an environment that reports the
        end of an episode is reset immediately. The observation returned for
        that slot is then the first observation of the new episode, and the
        terminal observation is exposed as ``info["final_obs"]`` (a float32
        tensor), which is the key the rollout loop in this notebook looks for.
        Pass False to get the raw, never-reset behaviour.
    """

    # The annotation is quoted so the class can be defined even where the
    # MujocoEnv name is not importable.
    def __init__(self, envs: "List[MujocoEnv]", auto_reset: bool = True):
        self.envs = envs
        self.num_envs = len(self.envs)
        self.observation_space = self.envs[0].observation_space
        self.action_space = self.envs[0].action_space
        self.auto_reset = auto_reset

    @staticmethod
    def _to_float_tensor(values):
        """Stack per-environment values into one float32 tensor.

        Going through a single NumPy array avoids the slow path of building a
        tensor from a Python list of arrays.
        """
        return torch.as_tensor(np.asarray(values, dtype=np.float32))

    def reset(self):
        """Reset every environment and return the stacked initial observations."""
        return self._to_float_tensor([env.reset()[0] for env in self.envs])

    def step(self, action):
        """Apply ``action[i]`` to environment ``i`` and batch the outcomes.

        :param action: indexable batch of actions, one entry per environment.
        :returns: ``(observations, rewards, dones, infos)`` where the first
            three are tensors (float32, float32, bool) and ``infos`` is a list
            with one info object per environment.
        """
        observations, rewards, dones, infos = [], [], [], []
        for env_idx, env in enumerate(self.envs):
            result = env.step(action[env_idx])
            if len(result) == 5:
                # Five-element API: the episode is over if it either
                # terminated or was truncated.
                obs, reward, terminated, truncated, info = result
                done = bool(terminated) or bool(truncated)
            else:
                obs, reward, done, info = result
                done = bool(done)

            if done and self.auto_reset:
                # Keep the terminal observation available to the caller, then
                # start a fresh episode so the next step() call is valid.
                # Copy the info dict so the environment's own dict is not mutated.
                info = dict(info)
                info["final_obs"] = self._to_float_tensor(obs)
                obs = env.reset()[0]

            observations.append(obs)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)

        return (
            self._to_float_tensor(observations),
            self._to_float_tensor(rewards),
            torch.tensor(dones, dtype=torch.bool),
            infos,
        )

In [9]:
# Seed all random number generators before anything stochastic is created.
set_seed(cfg.seed)

device = torch.device(cfg.device)

# Setup the environments
# Instantiate any nested config nodes in the environment settings.
# NOTE(review): set_env_settings is not consumed anywhere in this cell; the
# environments below are built directly from CassieEnv. It is kept in case
# other cells rely on it.
set_env_settings = {
    k: hydra_instantiate(v) if isinstance(v, DictConfig) else v
    for k, v in cfg.env.env_settings.items()
}
envs = vectorized_env([CassieEnv({}) for _ in range(cfg.num_envs)])

# Each update consumes num_steps transitions from each of the num_envs
# environments; the total update count follows from the step budget.
steps_per_update = cfg.num_steps * cfg.num_envs
num_updates = int(cfg.num_env_steps) // steps_per_update

# Set dynamic variables in the config.
cfg.obs_shape = envs.observation_space.shape
cfg.action_dim = envs.action_space.shape[0]
# Test the action *space* for discreteness. The previous check inspected
# cfg.action_dim, which is an integer and can never be a Discrete space, so
# the flag was unconditionally False.
# NOTE(review): `spaces` is gym.spaces; confirm CassieEnv builds its spaces
# from the same package, otherwise this isinstance check cannot match.
cfg.action_is_discrete = isinstance(envs.action_space, spaces.Discrete)
cfg.total_num_updates = num_updates

# Build the logger, policy and policy updater from their config nodes.
logger: Logger = hydra_instantiate(cfg.logger, full_cfg=cfg)
policy = hydra_instantiate(cfg.policy)
policy = policy.to(device)
updater = hydra_instantiate(cfg.policy_updater, policy=policy, device=device)



DEBUG:matplotlib:matplotlib data path: /home/alhussein.jamil/.pyenv/versions/3.7.16/envs/bcirl2/lib/python3.7/site-packages/matplotlib/mpl-data
DEBUG:matplotlib:CONFIGDIR=/home/alhussein.jamil/.config/matplotlib
DEBUG:matplotlib:interactive is False
DEBUG:matplotlib:platform is linux


Assigning full prefix 530-3-BSxwVo


DEBUG:matplotlib:CACHEDIR=/home/alhussein.jamil/.cache/matplotlib
DEBUG:matplotlib.font_manager:Using fontManager instance from /home/alhussein.jamil/.cache/matplotlib/fontlist-v330.json


50000 5 16


In [10]:

# Index of the first update to run; stays 0 unless training is resumed.
start_update = 0
if cfg.load_checkpoint is not None:
    # Load a checkpoint for the policy/reward. Also potentially resume
    # training.
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint written on a GPU can still be loaded on a machine (or config)
    # that does not have that device.
    ckpt = torch.load(cfg.load_checkpoint, map_location=device)
    # Optimizer state is only restored when training is being resumed.
    updater.load_state_dict(ckpt["updater"], should_load_opt=cfg.resume_training)
    if cfg.load_policy:
        policy.load_state_dict(ckpt["policy"])
    if cfg.resume_training:
        # Continue with the update after the one stored in the checkpoint.
        start_update = ckpt["update_i"] + 1

# Summary printed at the end of training; the training loop adds "last_ckpt".
eval_info = {"run_name": logger.run_name}


In [11]:
import warnings 
# Suppress all Python warnings from here on to keep the training output
# readable. NOTE(review): this also hides deprecation and numerical warnings
# that may be relevant when debugging.
warnings.filterwarnings("ignore")

In [12]:

# Reset every environment and start the per-step TensorDict from the initial
# observations.
obs = envs.reset()
td = TensorDict({"observation": obs}, batch_size=[cfg.num_envs])
# Storage for the rollouts
storage_td = TensorDict({}, batch_size=[cfg.num_envs, cfg.num_steps], device=device)

for update_i in range(start_update, num_updates):
    is_last_update = update_i == num_updates - 1
    for step_idx in range(cfg.num_steps):
        # Collect experience.
        # policy.act is expected to write the chosen "action" into td, which
        # is read back on the next line.
        with torch.no_grad():
            policy.act(td)
        next_obs, reward, done, infos = envs.step(td["action"])

        # Work on a copy of next_obs: the stored transition must end in the
        # terminal observation ("final_obs") when an episode finished, while
        # the policy must continue from next_obs itself. Writing into next_obs
        # directly would leak the terminal observation into the next step's
        # "observation" because both entries would share one tensor.
        next_observation = next_obs.clone()
        for env_i, info in enumerate(infos):
            if "final_obs" in info:
                # as_tensor accepts tensors as well as array-like values.
                next_observation[env_i] = torch.as_tensor(
                    info["final_obs"], dtype=next_observation.dtype
                )
        td["next_observation"] = next_observation
        td["reward"] = reward.reshape(-1, 1)

        td["done"] = done
        # Write this step's transition into column step_idx of the rollout storage.
        storage_td[:, step_idx] = td
        td["observation"] = next_obs
        # Log to CLI/wandb.
        logger.collect_env_step_info(infos)

    # Call method specific update function
    updater.update(policy, storage_td, logger, envs=envs)


    # A log/save interval of -1 disables the corresponding action.
    if cfg.log_interval != -1 and (
        update_i % cfg.log_interval == 0 or is_last_update
    ):
        logger.interval_log(update_i, steps_per_update * (update_i + 1))

    if cfg.save_interval != -1 and (
        (update_i + 1) % cfg.save_interval == 0 or is_last_update
    ):
        # The checkpoint stores update_i so a later run can resume after it.
        save_name = osp.join(logger.save_path, f"ckpt.{update_i}.pth")
        torch.save(
            {
                "policy": policy.state_dict(),
                "updater": updater.state_dict(),
                "update_i": update_i,
            },
            save_name,
        )
        print(f"Saved to {save_name}")
        eval_info["last_ckpt"] = save_name

logger.close()
print(eval_info)

DEBUG:matplotlib.pyplot:Loaded backend module://matplotlib_inline.backend_inline version unknown.
DEBUG:matplotlib.pyplot:Loaded backend module://matplotlib_inline.backend_inline version unknown.
DEBUG:matplotlib.font_manager:findfont: Matching sans\-serif:style=normal:variant=normal:weight=normal:stretch=normal:size=10.0.
DEBUG:matplotlib.font_manager:findfont: score(FontEntry(fname='/home/alhussein.jamil/.pyenv/versions/3.7.16/envs/bcirl2/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf/cmb10.ttf', name='cmb10', style='normal', variant='normal', weight=400, stretch='normal', size='scalable')) = 10.05
DEBUG:matplotlib.font_manager:findfont: score(FontEntry(fname='/home/alhussein.jamil/.pyenv/versions/3.7.16/envs/bcirl2/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSerif.ttf', name='DejaVu Serif', style='normal', variant='normal', weight=400, stretch='normal', size='scalable')) = 10.05
DEBUG:matplotlib.font_manager:findfont: score(FontEntry(fname='/home/alhus


Updates 0, Steps 80, FPS 14
Over the last 0 episodes:
    - value_loss: 0.07510656788945198
    - action_loss: 0.14939691945910455
    - dist_entropy: 14.189386367797852
    - irl_loss: 63.07669448852539

Updates 10, Steps 880, FPS 27
Over the last 0 episodes:
    - value_loss: 0.025855140760540962
    - action_loss: -0.08418859876692295
    - dist_entropy: 14.188786506652832
    - irl_loss: 63.06956481933594
    - inferred_episode_reward: 0.04512657821178436

Updates 20, Steps 1680, FPS 28
Over the last 0 episodes:
    - value_loss: 0.045288025960326196
    - action_loss: -0.0027310811681672932
    - dist_entropy: 14.188984870910645
    - irl_loss: 63.09466552734375
    - inferred_episode_reward: -0.00014535840600728988

Updates 30, Steps 2480, FPS 28
Over the last 0 episodes:
    - value_loss: 0.06999862119555474
    - action_loss: -0.04938741829246283
    - dist_entropy: 14.189784049987793
    - irl_loss: 63.15843811035156
    - inferred_episode_reward: 0.14909648299217224

Updates