In [None]:
import json
import math
import os
from pathlib import Path
from typing import List

import gymnasium as gym
import h5py
import mani_skill2.envs
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from imitation.algorithms.adversarial.gail import GAIL
from imitation.data.types import Trajectory
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.logger import configure
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
from mani_skill2.utils.wrappers import RecordEpisode
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import (DummyVecEnv, SubprocVecEnv,
                                              VecFrameStack)
from stable_baselines3.ppo import MlpPolicy
from torch.nn import (Flatten, Linear, TransformerEncoder,
                      TransformerEncoderLayer)
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from data.dataset import StackDatasetOriginalSequential
from utils.data_utils import flatten_obs, make_path
from utils.train_utils import init_deque, update_deque

In [None]:
# Use the second CUDA device when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")

# Output locations for this GAIL run, resolved through the project's make_path helper.
ckpt_path = make_path('GAIL', 'checkpoints')
log_path = make_path('GAIL', 'logs')
tensorboard_path = make_path('GAIL', 'logs', 'tensorboard')

# HDF5 file with the demonstration trajectories (path is relative to the working directory).
data_path = os.path.join('..', 'datasets', 'trajectory_state_original.h5')

# Create each output directory, including missing parents; directories that
# already exist are left untouched.
for out_dir in (ckpt_path, log_path, tensorboard_path):
    Path(out_dir).mkdir(parents=True, exist_ok=True)

In [None]:
# prepare trajectory for imitation package
def prep_trajectory(file_path: str) -> List[Trajectory]:
    """Read every demonstration in an HDF5 file into an ``imitation`` ``Trajectory``.

    Args:
        file_path: Path of the HDF5 file. Each top-level key is read as one
            demonstration exposing an ``'obs'`` and an ``'actions'`` entry.

    Returns:
        One ``Trajectory`` per top-level key, in the file's key iteration
        order, each created with ``infos=None`` and ``terminal=True``.
    """
    def _to_trajectory(group) -> Trajectory:
        # Observations go through the project's ``flatten_obs`` helper;
        # actions are materialised as a NumPy array.
        return Trajectory(
            flatten_obs(group['obs']),
            np.array(group['actions']),
            infos=None,
            terminal=True,
        )

    # The file is only open while the trajectories are being built.
    with h5py.File(file_path, 'r') as h5_file:
        return [_to_trajectory(h5_file[key]) for key in h5_file.keys()]

In [None]:
# venv = make_vec_env(
#     "StackCube-v0",
#     rng=np.random.default_rng(seed=42),
#     parallel=False,
#     n_envs=1,
#     log_dir=log_path,
#     max_episode_steps=250,
#     post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
#     env_make_kwargs=dict(obs_mode='state',
#                          control_mode='pd_joint_delta_pos',
#                          reward_mode='normalized_dense',
#                          render_mode='cameras')
# )
# stack_env = VecFrameStack(venv, n_stack=8)

In [None]:

# Seed shared by the vectorised environment's RNG and the PPO learner.
SEED = 42

# Demonstrations handed to GAIL below.
trajectories = prep_trajectory(data_path)

# Keyword arguments forwarded to the StackCube-v0 environment constructor.
env_make_kwargs = dict(
    obs_mode='state',
    control_mode='pd_joint_delta_pos',
    reward_mode='normalized_dense',
    render_mode='cameras',
)

# Four non-parallel environment copies, each wrapped in RolloutInfoWrapper.
venv = make_vec_env(
    "StackCube-v0",
    rng=np.random.default_rng(seed=SEED),
    n_envs=4,
    parallel=False,
    max_episode_steps=250,
    log_dir=log_path,
    env_make_kwargs=env_make_kwargs,
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
)

# Generator policy. It is constructed before the reward net, as in the
# original cell, so the order of seeded initialisation is unchanged.
learner = PPO(
    policy=MlpPolicy,
    env=venv,
    learning_rate=0.0004,
    batch_size=64,
    n_epochs=5,
    gamma=0.95,
    ent_coef=1e-6,
    seed=SEED,
    tensorboard_log=tensorboard_path,
)

# Reward network over the environment's observation/action spaces, with a
# RunningNorm input-normalisation layer.
reward_net = BasicRewardNet(
    observation_space=venv.observation_space,
    action_space=venv.action_space,
    normalize_input_layer=RunningNorm,
)

# Logger writing tensorboard, stdout and csv output under log_path.
custom_logger = configure(log_path, ('tensorboard', 'stdout', 'csv'))

gail_trainer = GAIL(
    demonstrations=trajectories,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
    allow_variable_horizon=True,
    log_dir=log_path,
    init_tensorboard=True,
    init_tensorboard_graph=True,
    custom_logger=custom_logger,
)

# Disabled: evaluation of the untrained learner.
# venv.seed(SEED)
# learner_rewards_before_training, _ = evaluate_policy(model=learner,
#                                                      env=venv,
#                                                      n_eval_episodes=100,
#                                                      return_episode_rewards=True)

# Train in fixed-size chunks. After each chunk, save the PPO learner and the
# reward net's state dict, both named by the cumulative step count.
ckpts = 10
steps_per_ckpt = 100_000
for ckpt in range(1, ckpts + 1):
    gail_trainer.train(steps_per_ckpt)
    total_steps = ckpt * steps_per_ckpt
    learner.save(os.path.join(ckpt_path, f'PPO_{total_steps}'))
    torch.save(reward_net.state_dict(), os.path.join(ckpt_path, f'Reward_{total_steps}'))

# Disabled: evaluation of the trained learner.
# venv.seed(SEED)
# learner_rewards_after_training, _ = evaluate_policy(model=learner,
#                                                     env=venv,
#                                                     n_eval_episodes=100,
#                                                     return_episode_rewards=True)

In [7]:
# Calls the PPO learner's own `learn` for 100,000 timesteps, with a progress bar.
# NOTE(review): this bypasses `gail_trainer.train`; which reward the learner
# optimises here depends on the environment currently attached to it — confirm
# this extra training pass is intended.
learner.learn(total_timesteps=100_000, progress_bar=True)

Output()

In [None]:
# Roll the learner out for one episode in a fresh, single StackCube-v0
# environment and save the episode as a video under log_path.
# NOTE(review): max_episode_steps is 400 here, while the training venv in
# this notebook uses 250 — confirm the longer horizon is intended.
env = gym.make('StackCube-v0',
               render_mode="cameras",
               enable_shadow=True,
               obs_mode="state",
               control_mode="pd_joint_delta_pos",
               max_episode_steps=400)

# Record video only (with info overlaid); trajectory saving is disabled.
env = RecordEpisode(
    env,
    log_path,
    info_on_video=True,
    save_trajectory=False
)

try:
    obs, _ = env.reset()

    terminated = False
    truncated = False

    with torch.no_grad():
        # Step until the environment reports termination or truncation.
        while not (terminated or truncated):
            # `predict` returns a pair; its second element is kept in its own
            # name so it is not confused with the `info` dict from `env.step`.
            action, _state = learner.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(action)

    # The video is written only after the episode has finished cleanly.
    env.flush_video(suffix='GAIL')
finally:
    # Release the environment even if the rollout or the video flush raised.
    env.close()