In [None]:
from active_critic.utils.gym_utils import *
from gym.spaces.box import Box

In [35]:
class MultiEnvWrapper(gym.Wrapper):
    def __init__(self, list_envs) -> None:
        super().__init__(list_envs[0])
        self.list_envs = list_envs
        self.current_env = 0
        #self._make_observation_space()


    def reset(self):
        self.current_env = (self.current_env + 1) % len(self.list_envs)
        #self._make_observation_space()
        obs =  super().reset()
        #obs = np.concatenate((obs[:-4], np.array([self.current_env]), obs[-4:]))
        return obs

    def step(self, action):

        obs, rew, done, info = super().step(action)
        #obs = np.append(obs, self.current_env)
        #obs = np.concatenate((obs[:-4], np.array([self.current_env]), obs[-4:]))

        return obs, rew, done, info

    def _make_observation_space(self):
        low = self.list_envs[self.current_env].observation_space.low
        high = self.list_envs[self.current_env].observation_space.high
        dtype = self.list_envs[self.current_env].observation_space.dtype
        new_low = np.concatenate((low[:-4], np.array([0]), low[-4:]))
        new_high = np.concatenate((high[:-4], np.array([len(self.list_envs)]), high[-4:]))
        new_obs_space = Box(new_low, new_high, dtype=dtype)
        self.observation_space = new_obs_space
        self.action_space = self.list_envs[self.current_env].action_space

def make_env_list(env_ids):
    policy_dict = make_policy_dict()
    list_envs = []
    list_experts = []
    for i in range(len(env_ids)):
        exp, env_id = policy_dict[env_ids[i]]
        env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id]()
        env._freeze_rand_vec = False
        list_envs.append(env)
        list_experts.append(exp)
    return list_envs, list_experts
    
class MultiImitationLearningWrapper:
    def __init__(self, policies, env: GymEnv):
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.policies = policies

    def predict(self, obsv, deterministic=None):
        actions = []
        for obs in obsv:
            current_env = int(obs[-4])
            actions.append(self.policies[current_env].get_action(obs))
        return actions

In [36]:
def make_multi_vec_env(env_ids, num_cpu, seq_len, sparse):

    def make_env(env_ids, rank, seed=0):
        def _init():
            max_episode_steps = seq_len
            list_envs, list_experts = make_env_list(env_ids)
            #multi_env = MultiEnvWrapper(list_envs=list_envs)
            multi_env = list_envs[0]
            timelimit = TimeLimit(env=multi_env, max_episode_steps=max_episode_steps)
            strict_time = StrictSeqLenWrapper(timelimit, seq_len=seq_len + 1, sparse=sparse)
            riw = RolloutInfoWrapper(strict_time)
            return riw
        return _init
        
    env = SubprocVecEnv([make_env(env_ids, i) for i in range(num_cpu)])

    list_envs, list_experts = make_env_list(env_ids)
    vec_expert = MultiImitationLearningWrapper(
        policies=list_experts, env=env)
    return env, vec_expert

In [39]:
vec_env, exp = make_vec_env(env_id='reach', num_cpu=2, seq_len=100, sparse=False)

  logger.warn(
  logger.warn(
  logger.warn(


  logger.warn(


In [46]:
obs = vec_env.reset()
done = False
while not done:
    actions = exp.predict(obs)
    ons, rew, dones, info = vec_env.step(actions)
    done = dones[0]
    print(rew)
#print(rew)

[1.55894833 1.53159285]
[1.58703984 1.55595658]
[1.6415646  1.60648643]
[1.72414922 1.68458073]
[1.8359031  1.79093174]
[1.97819434 1.92649214]
[2.15281861 2.09282307]
[2.36290938 2.29219682]
[2.59868941 2.51758656]
[2.84754165 2.76801682]
[3.11162683 3.05111536]
[3.39802752 3.3748716 ]
[3.71354096 3.74664321]
[4.06250615 4.17247163]
[4.44588242 4.65610494]
[4.86085653 5.1976238 ]
[5.30082664 5.79184953]
[5.75581184 6.42686222]
[6.21333788 7.08312245]
[6.65974702 7.73383596]
[7.08174318 8.34715776]
[7.48266593 8.8904181 ]
[7.88052648 9.33575212]
[8.27754587 9.6656564 ]
[8.65809009 9.87665889]
[9.0018782  9.97984908]
[ 9.29316904 10.        ]
[ 9.52471582 10.        ]
[ 9.69753549 10.        ]
[ 9.8185174 10.       ]
[ 9.8976044 10.       ]
[ 9.94546901 10.        ]
[ 9.97192564 10.        ]
[ 9.98496975 10.        ]
[ 9.99025301 10.        ]
[ 9.99086719 10.        ]
[ 9.98740269 10.        ]
[9.97828786 9.99117516]
[9.96037942 9.94748224]
[9.92970538 9.8643819 ]
[9.8829933  9.74521425

In [37]:
multi_vec, multi_exp = make_multi_vec_env(env_ids=['reach'], num_cpu=2, seq_len=100, sparse=False)

  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(


In [38]:
obs = multi_vec.reset()
done = False
while not done:
    actions = multi_exp.predict(obs)
    ons, rew, dones, info = multi_vec.step(actions)
    done = dones[0]
    print(rew)
#print(rew)

[1.33413614 1.36953196]
[1.36157165 1.39036371]
[1.41306213 1.43107989]
[1.48774436 1.4912391 ]
[1.58443556 1.56987766]
[1.70256555 1.66631762]
[1.84244509 1.78041685]
[2.00528003 1.91260036]
[2.19309415 2.06404861]
[2.41595215 2.23195083]
[2.66528375 2.41323278]
[2.93587565 2.61036733]
[3.23179822 2.82786478]
[3.56027636 3.07042932]
[3.92847458 3.34214131]
[4.34176951 3.64608275]
[4.80265849 3.98415242]
[5.30988232 4.35695084]
[5.85763262 4.76366699]
[6.43493326 5.20193036]
[7.02546329 5.66762212]
[7.60820084 6.15467538]
[8.15923193 6.65493274]
[8.65479586 7.15816162]
[9.07513532 7.65233835]
[9.40818036 8.12428499]
[9.65186027 8.5606693 ]
[9.81414712 8.94926957]
[9.91069541 9.2802965 ]
[9.9607659  9.54750439]
[9.98264841 9.74885091]
[9.98989988 9.88657886]
[9.98939963 9.96674797]
[9.98153409 9.99837479]
[ 9.96200317 10.        ]
[ 9.9242536 10.       ]
[ 9.86164175 10.        ]
[ 9.76888137 10.        ]
[ 9.64275007 10.        ]
[ 9.48223402 10.        ]
[ 9.29233646 10.        ]
[ 9.

In [None]:
obs = multi_vec.reset()

In [None]:
obs

In [None]:
multi_vec.reset()

In [None]:
b = np.concatenate((a, np.array([100]), a[10:]))

In [None]:
b