In [3]:
import numpy as np
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy

# from imitation.algorithms.adversarial.airl import AIRL
from IRL_lib_mod.airl import AIRL
from imitation.algorithms.adversarial.airl import AIRL as AIRL_old
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from envs.wrappers import SequentialObservationWrapper
from imitation.policies.serialize import load_policy
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.util.networks import RunningNorm
from utils.irl_utils import make_vec_env_robosuite
from utils.demostration_utils import load_dataset_to_trajectories
import os
import h5py
import json
from robosuite.controllers import load_controller_config
from utils.demostration_utils import load_dataset_and_annotations_simutanously
from utils.annotation_utils import read_all_json
from imitation.util import logger as imit_logger
import imitation.scripts.train_adversarial as train_adversarial
import torch
import sys
import argparse

from stable_baselines3.common.callbacks import BaseCallback

In [4]:
# Global count of CustomLoggingPolicy.forward calls, shared by every instance.
print_cnt = 0

class CustomLoggingPolicy(MlpPolicy):
    """MlpPolicy that periodically prints diagnostics about the actions it produces.

    Every ``log_interval``-th call to :meth:`forward` prints two things. The call
    count is kept across all instances in the module-level ``print_cnt``.

    * the last action in the batch, and
    * the fraction of the batch whose final action component is positive.
    """

    # Number of forward() calls between diagnostic prints.
    log_interval = 2000

    def forward(self, obs: torch.Tensor, deterministic: bool = False):
        """Run the parent forward pass and occasionally print action statistics.

        Args:
            obs: Observation batch, passed unchanged to ``MlpPolicy.forward``.
            deterministic: Forwarded unchanged to the parent.

        Returns:
            The ``(actions, values, log_probs)`` tuple from the parent, unmodified.
        """
        global print_cnt
        print_cnt += 1

        actions, values, log_probs = super().forward(obs, deterministic)

        if print_cnt % self.log_interval == 0:
            # Convert once and reuse for both diagnostics.
            actions_np = actions.detach().cpu().numpy()
            print(f"Actions: {actions_np[-1]}")

            # Share of the batch whose last action dimension is > 0.
            # NOTE(review): with robosuite controllers the last dimension is
            # presumably the gripper command -- confirm.
            ratio = np.mean(actions_np[:, -1] > 0)
            print(f"Positive ratio: {ratio}")

        return actions, values, log_probs

In [16]:
# ---- Experiment configuration ------------------------------------------------
# Values a reader may want to change before a run.
env_name = "square"
exp_name = "default_experiment"            # names the logs/<exp_name> and checkpoints/<exp_name> dirs
dataset_type = "mh"
load_exp_name = "mh_sign_scale_loss_8m_1"  # experiment to resume from when continue_training is True
checkpoint = "320"                         # checkpoint sub-directory of load_exp_name to resume from
sequence_keys = []                         # obs keys for SequentialObservationWrapper; empty disables it
obs_seq_len = 1                            # sequence length used when sequence_keys is non-empty
# Kept as an int (not 1e6): step counts go through integer division / range().
num_timesteps = 1_000_000
num_envs = 1
continue_training = False

In [6]:
# Filesystem locations, all resolved relative to project_path.
project_path = ""
dataset_path = os.path.join(project_path, "human-demo/square/low_dim_v141_square_mh.hdf5")
log_dir = os.path.join(project_path, f"logs/{exp_name}")
print(dataset_path)

# Read the environment metadata stored alongside the demonstrations. The context
# manager closes the HDF5 file as soon as the metadata is read, so no handle is
# left open for the lifetime of the kernel.
with h5py.File(dataset_path, "r") as demo_file:
    env_meta = json.loads(demo_file["data"].attrs["env_args"])

config_path = os.path.join(project_path, "configs/osc_position.json")
with open(config_path, "r") as cfg_file:
    configs = json.load(cfg_file)

controller_config = load_controller_config(default_controller="OSC_POSE")
SEED = 42

human-demo/square/low_dim_v141_square_mh.hdf5


In [7]:
# robosuite keyword arguments used to build every environment instance.
make_env_kwargs = dict(
    robots="Panda",                    # single Panda arm
    gripper_types="default",           # the robot's default gripper
    controller_configs=env_meta["env_kwargs"]["controller_configs"],  # controller config stored in the demo file
    has_renderer=False,                # no on-screen rendering
    render_camera="frontview",         # camera used if rendering is enabled
    has_offscreen_renderer=False,      # no off-screen rendering
    control_freq=20,                   # 20 Hz control for applied actions
    horizon=500,                       # each episode terminates after 500 steps
    use_object_obs=True,               # include object observations
    use_camera_obs=False,              # no image observations
    reward_shaping=True,               # robosuite's shaped environment reward
)

# Optionally wrap environments with SequentialObservationWrapper for the keys
# listed in sequence_keys; an empty list disables the wrapper.
make_sequential_obs = len(sequence_keys) > 0
if make_sequential_obs:
    sequential_wrapper_kwargs = dict(
        sequential_observation_keys=sequence_keys,
        sequential_observation_length=obs_seq_len,
        use_half_gripper_obs=True,
    )
    # The misspelled name `seqential_wrapper_cls` is kept because the env
    # construction cell references it.
    seqential_wrapper_cls = SequentialObservationWrapper
else:
    sequential_wrapper_kwargs = None
    seqential_wrapper_cls = None


In [8]:
# Vectorized NutAssemblySquare environments for training.
envs = make_vec_env_robosuite(
    "NutAssemblySquare",
    obs_keys=["object-state", "robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"],
    rng=np.random.default_rng(SEED),
    n_envs=num_envs,  # use the configured count rather than a hard-coded 1
    parallel=True,
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # to compute rollouts
    env_make_kwargs=make_env_kwargs,
    sequential_wrapper=seqential_wrapper_cls,
    sequential_wrapper_kwargs=sequential_wrapper_kwargs,
)



In [12]:
# Annotations for the nut-assembly demos.
annotation_dict = read_all_json("nut_assembly_square_mh")

# Observation keys read from the HDF5 demonstrations.
# NOTE(review): "object" here vs "object-state" in the live env's obs_keys --
# presumably the same features under different names; confirm the layouts match.
demo_obs_keys = ["object", "robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"]

# Sequential-observation settings shared by both loaders so they stay in sync.
sequential_obs_kwargs = dict(
    make_sequential_obs=make_sequential_obs,
    sequential_obs_keys=sequence_keys,
    obs_seq_len=obs_seq_len,
    use_half_gripper_obs=True,
)

# Demonstrations passed to AIRL as expert data. Reuse the configured
# dataset_path instead of repeating the file literal.
trajs = load_dataset_to_trajectories(
    demo_obs_keys,
    dataset_path=dataset_path,
    **sequential_obs_kwargs,
)

# The same demos paired with their annotations, passed to AIRL for reward shaping.
trajs_for_shaping, annotation_list = load_dataset_and_annotations_simutanously(
    demo_obs_keys,
    annotation_dict=annotation_dict,
    dataset_path=dataset_path,
    **sequential_obs_kwargs,
)

# Reward-shaping losses to enable; an empty list disables reward shaping.
# Options: "progress_sign_loss", "value_sign_loss", "advantage_sign_loss".
shape_reward = []

In [17]:
# PPO generator using CustomLoggingPolicy.
learner = PPO(
    env=envs,
    policy=CustomLoggingPolicy,
    batch_size=256,
    ent_coef=0.01,
    learning_rate=3e-4,
    gamma=0.95,
    clip_range=0.2,
    vf_coef=0.5,
    n_epochs=10,
    seed=SEED,
)

# Reward network: reward MLP plus potential MLP, with RunningNorm input normalization.
reward_net = BasicShapedRewardNet(
    observation_space=envs.observation_space,
    action_space=envs.action_space,
    normalize_input_layer=RunningNorm,
    reward_hid_sizes=(64, 64),
    potential_hid_sizes=(64, 64),
)

# Checkpoint locations. os.path.join keeps these relative when project_path is "";
# the previous f"{project_path}/checkpoints/..." resolved to the filesystem root.
checkpoint_dir = os.path.join(project_path, "checkpoints", load_exp_name, checkpoint)
generator_model_path = os.path.join(checkpoint_dir, "gen_policy", "model")
if continue_training:
    # Replace the freshly built networks with the saved ones.
    # NOTE: torch.load unpickles a full module -- only load trusted checkpoints.
    # On torch versions where weights_only defaults to True this may need
    # weights_only=False.
    reward_net = torch.load(os.path.join(checkpoint_dir, "reward_train.pt"))
    learner = PPO.load(generator_model_path)

# Logger that writes TensorBoard event files to log_dir.
logger = imit_logger.configure(folder=log_dir, format_strs=["tensorboard"])
airl_trainer = AIRL(
    demonstrations=trajs,
    demo_batch_size=128,
    gen_replay_buffer_capacity=20000,
    n_disc_updates_per_round=10,
    venv=envs,
    gen_algo=learner,
    reward_net=reward_net,
    shape_reward=shape_reward,
    annotation_list=annotation_list,
    demostrations_for_shaping=trajs_for_shaping,
    custom_logger=logger,
    save_path=os.path.join(project_path, "checkpoints", exp_name),
)

creating save path


In [18]:
# Total environment steps for this training run.
# NOTE(review): this literal is used instead of the `num_timesteps` config
# value -- confirm which one is intended.
total_train_timesteps = 200_000
airl_trainer.train(total_train_timesteps)

round:   0%|          | 0/97 [00:17<?, ?it/s]


KeyboardInterrupt: 