In [None]:
from active_critic.utils.gym_utils import *
from gym.spaces.box import Box

In [45]:
class MultiEnvWrapper(gym.Wrapper):
    def __init__(self, list_envs) -> None:
        super().__init__(list_envs[0])
        self.list_envs = list_envs
        self.current_env = 0

    def reset(self):
        self.current_env = (self.current_env + 1) % len(self.list_envs)
        self._make_observation_space()
        obs =  super().reset()
        obs = np.concatenate((obs[:-4], np.array([self.current_env]), obs[-4:]))
        return obs

    def step(self, action):

        obs, rew, done, info = super().step(action)
        obs = np.append(obs, self.current_env)
        obs = np.concatenate((obs[:-4], np.array([self.current_env]), obs[-4:]))

        return obs, rew, done, info

    def _make_observation_space(self):
        low = self.list_envs[self.current_env].observation_space.low
        high = self.list_envs[self.current_env].observation_space.high
        dtype = self.list_envs[self.current_env].observation_space.dtype
        new_low = np.concatenate((low[:-4], np.array([0]), low[-4:]))
        new_high = np.concatenate((high[:-4], np.array([len(self.list_envs)]), high[-4:]))
        new_obs_space = Box(new_low, new_high, dtype=dtype)
        self.observation_space = new_obs_space
        self.action_space = self.list_envs[self.current_env].action_space

def make_env_list(env_ids):
    policy_dict = make_policy_dict()
    list_envs = []
    list_experts = []
    for i in range(len(env_ids)):
        exp, env_id = policy_dict[env_ids[i]]
        env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id]()
        env._freeze_rand_vec = False
        list_envs.append(env)
        list_experts.append(exp)
    return list_envs, list_experts
    
class MultiImitationLearningWrapper:
    def __init__(self, policies, env: GymEnv):
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.policies = policies

    def predict(self, obsv, deterministic=None):
        actions = []
        for obs in obsv:
            current_env = int(obs[-4])
            actions.append(self.policies[current_env].get_action(obs))
        return actions

In [46]:
policy_dict = make_policy_dict()

env_tag = 'reach'
reach = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[policy_dict[env_tag][1]]()

env_tag = 'push'
push = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[policy_dict[env_tag][1]]()

  logger.warn(


In [47]:
def make_multi_vec_env(env_ids, num_cpu, seq_len, sparse):

    def make_env(env_ids, rank, seed=0):
        def _init():
            max_episode_steps = seq_len
            list_envs, list_experts = make_env_list(env_ids)
            #multi_env = MultiEnvWrapper(list_envs=list_envs)
            multi_env = list_envs[0]
            timelimit = TimeLimit(env=multi_env, max_episode_steps=max_episode_steps)
            strict_time = StrictSeqLenWrapper(timelimit, seq_len=seq_len + 1, sparse=sparse)
            riw = RolloutInfoWrapper(strict_time)
            return riw
        return _init
        
    env = SubprocVecEnv([make_env(env_ids, i) for i in range(num_cpu)])

    list_envs, list_experts = make_env_list(env_ids)
    vec_expert = MultiImitationLearningWrapper(
        policies=list_experts, env=env)
    return env, vec_expert

In [49]:
multi_vec, multi_exp = make_multi_vec_env(env_ids=['reach'], num_cpu=2, seq_len=100, sparse=False)

  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(


In [50]:
obs = multi_vec.reset()
done = False
while not done:
    actions = multi_exp.predict(obs)
    ons, rew, dones, info = multi_vec.step(actions)
    done = dones[0]
    print(rew)
#print(rew)

[1.63983315 1.62969233]
[1.6666043 1.6545956]
[1.72306106 1.70609198]
[1.81173919 1.78664716]
[1.93444149 1.89792339]
[2.09326614 2.04161068]
[2.29097137 2.21969374]
[2.53009393 2.43166739]
[2.7978799 2.6652496]
[3.09596583 2.91782885]
[3.43434156 3.19500362]
[3.82297247 3.50417463]
[4.2700624  3.85171766]
[4.78077234 4.24173759]
[5.35545633 4.67524409]
[5.98742219 5.14946804]
[6.6608195 5.6574065]
[7.34964935 6.18779138]
[8.01906555 6.72566649]
[8.62977985 7.25365942]
[9.14524937 7.75386073]
[9.53974859 8.21001678]
[9.80442921 8.60958516]
[9.94900149 8.94517708]
[9.99855636 9.21505294]
[10.          9.42259959]
[10.        9.574993]
[10.          9.68143395]
[10.          9.75138841]
[10.          9.79318441]
[10.          9.81316577]
[10.          9.81935918]
[10.         9.8253385]
[9.99612303 9.83953393]
[9.96038556 9.85907395]
[9.88194299 9.87850302]
[9.75792938 9.89394666]
[9.58903683 9.90371983]
[9.37869386 9.90744784]
[9.13219424 9.90512339]
[8.86484456 9.89599956]
[8.59902351 



In [32]:
obs = multi_vec.reset()

In [33]:
obs

array([[ 0.00615235,  0.6001898 ,  0.19430117,  1.        ,  0.00944557,
         0.68723255,  0.02      ,  0.        ,  0.        ,  0.        ,
         1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.00615235,  0.6001898 ,
         0.19430117,  1.        ,  0.01988441,  0.6219765 ,  0.02      ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        , -0.07862544,  0.88846986,  0.13023407],
       [ 0.00615235,  0.6001898 ,  0.19430117,  1.        ,  0.08263704,
         0.67112505,  0.02      ,  0.        ,  0.        ,  0.        ,
         1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.00615235,  0.6001898 ,
         0.19430117,  1.        ,  0.04957775,  0.67235942,  0.02      ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.

In [51]:
multi_vec.reset()

array([[ 0.00615235,  0.6001898 ,  0.19430117,  1.        , -0.07985431,
         0.63109425,  0.02      ,  0.        ,  0.        ,  0.        ,
         1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.00615235,  0.6001898 ,
         0.19430117,  1.        ,  0.06500112,  0.65404595,  0.02      ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        , -0.02662927,  0.85583629,  0.0702654 ],
       [ 0.00615235,  0.6001898 ,  0.19430117,  1.        , -0.00392577,
         0.65428067,  0.02      ,  0.        ,  0.        ,  0.        ,
         1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.00615235,  0.6001898 ,
         0.19430117,  1.        ,  0.03766309,  0.66409218,  0.02      ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.

In [43]:
b = np.concatenate((a, np.array([100]), a[10:]))

In [44]:
b

array([ 6.15235164e-03,  6.00189803e-01,  1.94301175e-01,  1.00000000e+00,
        9.44556741e-03,  6.87232548e-01,  1.99999996e-02,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  1.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  6.15235164e-03,  6.00189803e-01,
        1.94301175e-01,  1.00000000e+00,  1.98844086e-02,  6.21976500e-01,
        1.99999996e-02,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        1.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
       -7.86254363e-02,  8.88469862e-01,  1.30234065e-01,  1.00000000e+02,
        1.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        6.15235164e-03,  6.00189803e-01,  1.94301175e-01,  1.00000000e+00,
        1.98844086e-02,  