In [1]:
import json
import math
import os
from pathlib import Path

import gymnasium as gym
import mani_skill2.envs
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from mani_skill2.utils.wrappers import RecordEpisode
from torch.nn import (Flatten, Linear, TransformerEncoder,
                      TransformerEncoderLayer)
from torch.nn.functional import relu
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from data.dataset import StackDatasetOriginalSequential
from utils.data_utils import flatten_obs, make_path
from utils.train_utils import init_deque, update_deque

from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy

import os
import h5py
from imitation.data.types import Trajectory




In [None]:
# prepare Trajectory for imitation package
def prep_Trajectory(file_path):
    traj_list = []
    with h5py.File(file_path,'r') as file:
        for traj_key in file.keys():
            traj_data = file[traj_key]
            obs = flatten_obs(traj_data['obs'])
            acts = np.array(traj_data['actions'])
            traj = Trajectory(obs, acts, infos=None,terminal=True) 
            traj_list.append(traj)  
            print(obs.shape) # (127, 55)
            print(acts.shape) #(126,8)
            #print(traj_list)
    return traj_list
dir_path = os.getcwd()

data_path = os.path.join(dir_path, '..', 'datasets')
file_path = os.path.join(data_path, 'trajectory_state_original.h5') 

traj_list = prep_Trajectory(file_path)

(127, 55)
(126, 8)
(138, 55)
(137, 8)
(134, 55)
(133, 8)
(137, 55)
(136, 8)
(143, 55)
(142, 8)
(147, 55)
(146, 8)
(120, 55)
(119, 8)
(130, 55)
(129, 8)
(153, 55)
(152, 8)
(145, 55)
(144, 8)
(130, 55)
(129, 8)
(157, 55)
(156, 8)
(183, 55)
(182, 8)
(156, 55)
(155, 8)
(135, 55)
(134, 8)
(141, 55)
(140, 8)
(173, 55)
(172, 8)
(135, 55)
(134, 8)
(147, 55)
(146, 8)
(140, 55)
(139, 8)
(133, 55)
(132, 8)
(159, 55)
(158, 8)
(137, 55)
(136, 8)
(130, 55)
(129, 8)
(155, 55)
(154, 8)
(150, 55)
(149, 8)
(159, 55)
(158, 8)
(170, 55)
(169, 8)
(150, 55)
(149, 8)
(149, 55)
(148, 8)
(135, 55)
(134, 8)
(144, 55)
(143, 8)
(134, 55)
(133, 8)
(145, 55)
(144, 8)
(140, 55)
(139, 8)
(156, 55)
(155, 8)
(137, 55)
(136, 8)
(142, 55)
(141, 8)
(161, 55)
(160, 8)
(152, 55)
(151, 8)
(178, 55)
(177, 8)
(170, 55)
(169, 8)
(122, 55)
(121, 8)
(147, 55)
(146, 8)
(160, 55)
(159, 8)
(129, 55)
(128, 8)
(132, 55)
(131, 8)
(136, 55)
(135, 8)
(167, 55)
(166, 8)
(155, 55)
(154, 8)
(166, 55)
(165, 8)
(135, 55)
(134, 8)
(199, 55)
(1

In [3]:
SEED = 42 

env = gym.make('StackCube-v0',
               obs_mode="state",
               control_mode="pd_joint_delta_pos",
               reward_mode="normalized_dense",
               render_mode="cameras",
               max_episode_steps=250)


[2023-11-16 16:49:49.652] [svulkan2] [error] GLFW error: X11: The DISPLAY environment variable is missing


In [8]:
from stable_baselines3.common.vec_env import DummyVecEnv

# Create a function that returns a new environment instance
def make_env():
    return gym.make('StackCube-v0',
               obs_mode="state",
               control_mode="pd_joint_delta_pos",
               reward_mode="normalized_dense",
               render_mode="cameras",
               max_episode_steps=250)
# Number of parallel environments
num_envs = 4

# Create a list of environment instances
envs = [make_env for _ in range(num_envs)]

# Wrap these environments with DummyVecEnv
env = DummyVecEnv(envs)


[2023-11-16 16:58:16.032] [svulkan2] [error] GLFW error: X11: The DISPLAY environment variable is missing


In [9]:

learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
)
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=traj_list,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
    allow_variable_horizon=True,
)
env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)

# train the learner and evaluate again
gail_trainer.train(800_000)
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)



Running with `allow_variable_horizon` set to True. Some algorithms are biased towards shorter or longer episodes, which may significantly confound results. Additionally, even unbiased algorithms can exploit the information leak from the termination condition, producing spuriously high performance. See https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html for more information.


round:   0%|          | 0/390 [00:00<?, ?it/s]

--------------------------------------
| raw/                        |      |
|    gen/time/fps             | 188  |
|    gen/time/iterations      | 1    |
|    gen/time/time_elapsed    | 10   |
|    gen/time/total_timesteps | 2048 |
--------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.38     |
|    disc/disc_acc_expert             | 0.706    |
|    disc/disc_acc_gen                | 0.0547   |
|    disc/disc_entropy                | 0.692    |
|    disc/disc_loss                   | 0.713    |
|    disc/disc_proportion_expert_pred | 0.826    |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 1.02e+03 |
|    disc/n_generated                 | 1.02e+03 |
--------------------------------------------------
--------------------------------------------------
| raw/       

round:   0%|          | 1/390 [00:13<1:24:23, 13.02s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 179         |
|    gen/time/fps                    | 193         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 10          |
|    gen/time/total_timesteps        | 4096        |
|    gen/train/approx_kl             | 0.015432902 |
|    gen/train/clip_fraction         | 0.143       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -11.4       |
|    gen/train/explained_variance    | -0.0149     |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.501       |
|    gen/train/n_updates             | 5           |
|    gen/train/policy_gradient_loss  | -0.0134     |
|    gen/train/std                   | 1.01        |
|    gen/train/value_loss            | 9.7         |
----------------------------------------------

round:   1%|          | 2/390 [00:25<1:21:35, 12.62s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 173         |
|    gen/time/fps                    | 187         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 10          |
|    gen/time/total_timesteps        | 6144        |
|    gen/train/approx_kl             | 0.010771541 |
|    gen/train/clip_fraction         | 0.0788      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -11.4       |
|    gen/train/explained_variance    | 0.313       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.134       |
|    gen/train/n_updates             | 10          |
|    gen/train/policy_gradient_loss  | -0.00988    |
|    gen/train/std                   | 1.01        |
|    gen/train/value_loss            | 2.04        |
----------------------------------------------

round:   1%|          | 3/390 [00:38<1:21:40, 12.66s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 157         |
|    gen/time/fps                    | 194         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 10          |
|    gen/time/total_timesteps        | 8192        |
|    gen/train/approx_kl             | 0.008055401 |
|    gen/train/clip_fraction         | 0.0762      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -11.4       |
|    gen/train/explained_variance    | 0.0132      |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0422      |
|    gen/train/n_updates             | 15          |
|    gen/train/policy_gradient_loss  | -0.0124     |
|    gen/train/std                   | 1.01        |
|    gen/train/value_loss            | 2.14        |
----------------------------------------------

round:   1%|          | 4/390 [00:50<1:20:38, 12.53s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 150         |
|    gen/time/fps                    | 188         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 10          |
|    gen/time/total_timesteps        | 10240       |
|    gen/train/approx_kl             | 0.011489572 |
|    gen/train/clip_fraction         | 0.134       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -11.4       |
|    gen/train/explained_variance    | 0.668       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0207      |
|    gen/train/n_updates             | 20          |
|    gen/train/policy_gradient_loss  | -0.0155     |
|    gen/train/std                   | 1.01        |
|    gen/train/value_loss            | 0.126       |
----------------------------------------------

round:   1%|▏         | 5/390 [01:03<1:20:43, 12.58s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 139         |
|    gen/time/fps                    | 186         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 10          |
|    gen/time/total_timesteps        | 12288       |
|    gen/train/approx_kl             | 0.009092348 |
|    gen/train/clip_fraction         | 0.0613      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -11.4       |
|    gen/train/explained_variance    | 0.000626    |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.116       |
|    gen/train/n_updates             | 25          |
|    gen/train/policy_gradient_loss  | -0.00966    |
|    gen/train/std                   | 1.01        |
|    gen/train/value_loss            | 2.25        |
----------------------------------------------

round:   2%|▏         | 6/390 [01:16<1:21:19, 12.71s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_rew_wrapped_mean | 134        |
|    gen/time/fps                    | 185        |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 11         |
|    gen/time/total_timesteps        | 14336      |
|    gen/train/approx_kl             | 0.01719587 |
|    gen/train/clip_fraction         | 0.175      |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -11.4      |
|    gen/train/explained_variance    | 0.722      |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | 0.0665     |
|    gen/train/n_updates             | 30         |
|    gen/train/policy_gradient_loss  | -0.0207    |
|    gen/train/std                   | 1          |
|    gen/train/value_loss            | 0.229      |
---------------------------------------------------
------------

round:   2%|▏         | 7/390 [01:28<1:21:17, 12.74s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 131         |
|    gen/time/fps                    | 183         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 11          |
|    gen/time/total_timesteps        | 16384       |
|    gen/train/approx_kl             | 0.011524811 |
|    gen/train/clip_fraction         | 0.123       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -11.4       |
|    gen/train/explained_variance    | 0.641       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.143       |
|    gen/train/n_updates             | 35          |
|    gen/train/policy_gradient_loss  | -0.0175     |
|    gen/train/std                   | 1           |
|    gen/train/value_loss            | 0.934       |
----------------------------------------------

round:   2%|▏         | 8/390 [01:41<1:21:29, 12.80s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_rew_wrapped_mean | 124        |
|    gen/time/fps                    | 190        |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 10         |
|    gen/time/total_timesteps        | 18432      |
|    gen/train/approx_kl             | 0.01042669 |
|    gen/train/clip_fraction         | 0.088      |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -11.4      |
|    gen/train/explained_variance    | 0.662      |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | 0.446      |
|    gen/train/n_updates             | 40         |
|    gen/train/policy_gradient_loss  | -0.0139    |
|    gen/train/std                   | 1          |
|    gen/train/value_loss            | 1.27       |
---------------------------------------------------
------------

round:   2%|▏         | 9/390 [01:54<1:20:45, 12.72s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 118         |
|    gen/time/fps                    | 190         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 10          |
|    gen/time/total_timesteps        | 20480       |
|    gen/train/approx_kl             | 0.012530882 |
|    gen/train/clip_fraction         | 0.111       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -11.4       |
|    gen/train/explained_variance    | 0.678       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.187       |
|    gen/train/n_updates             | 45          |
|    gen/train/policy_gradient_loss  | -0.0147     |
|    gen/train/std                   | 1           |
|    gen/train/value_loss            | 0.938       |
----------------------------------------------

round:   3%|▎         | 10/390 [02:09<1:21:58, 12.94s/it]


KeyboardInterrupt: 