In [1]:
from active_critic.utils.gym_utils import *
from gym.spaces.box import Box

class MultiEnvWrapper(gym.Wrapper):
    def __init__(self, list_envs) -> None:
        super().__init__(list_envs[0])
        self.list_envs = list_envs
        self.current_env = 0
        self._make_observation_space()

    def set_current_env(self, env_num):
        self.current_env = int(env_num)
        self._make_observation_space()
        print(f'self.current_enc = {self.current_env}')
        
    def reset(self):
        obs =  super().reset()
        obs = self.list_envs[self.current_env].reset()
        obs = np.append(obs, self.current_env)
        print(f'obs: {obs}')

        return obs

    def step(self, action):

        obs, rew, done, info = self.list_envs[self.current_env].step(action)
        obs = np.append(obs, self.current_env)
        return obs, rew, done, info

    def _make_observation_space(self):
        low = self.list_envs[self.current_env].observation_space.low
        high = self.list_envs[self.current_env].observation_space.high
        dtype = self.list_envs[self.current_env].observation_space.dtype
        new_low = np.append(low, 0)
        new_high = np.append(high, len(self.list_envs))
        new_obs_space = Box(new_low, new_high, dtype=dtype)
        self.observation_space = new_obs_space
        self.action_space = self.list_envs[self.current_env].action_space

def make_env_list(env_ids):
    policy_dict = make_policy_dict()
    list_envs = []
    list_experts = []
    for i in range(len(env_ids)):
        exp, env_id = policy_dict[env_ids[i]]
        env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id]()
        env._freeze_rand_vec = False
        list_envs.append(env)
        list_experts.append(exp)
    return list_envs, list_experts
    
class MultiImitationLearningWrapper:
    def __init__(self, policies, env: GymEnv):
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.policies = policies

    def predict(self, obsv, deterministic=None):
        actions = []
        for obs in obsv:
            current_env = int(obs[-1])
            print(f'current_env: {current_env}')
            inpt_obsv = obs[:-1]
            print(inpt_obsv.shape)
            actions.append(self.policies[current_env].get_action(inpt_obsv))
        return actions

  from .autonotebook import tqdm as notebook_tqdm
  logger.warn(


In [2]:
def make_multi_vec_env(env_ids, num_cpu, seq_len, sparse):

    def make_env(env_ids, rank, seed=0):
        def _init():
            max_episode_steps = seq_len
            list_envs, list_experts = make_env_list(env_ids)
            multi_env = MultiEnvWrapper(list_envs=list_envs)
            timelimit = TimeLimit(env=multi_env, max_episode_steps=max_episode_steps)
            strict_time = StrictSeqLenWrapper(timelimit, seq_len=seq_len + 1, sparse=sparse)
            riw = RolloutInfoWrapper(strict_time)
            return riw
        return _init
        
    env = SubprocVecEnv([make_env(env_ids, i) for i in range(num_cpu)])

    list_envs, list_experts = make_env_list(env_ids)
    vec_expert = MultiImitationLearningWrapper(
        policies=list_experts, env=env)
    return env, vec_expert

In [3]:
multi_vec, multi_exp = make_multi_vec_env(env_ids=['reach', 'pickplace'], num_cpu=2, seq_len=100, sparse=False)

  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(


In [4]:
multi_vec.set_env_ids(0)

self.current_enc = 0
self.current_enc = 0


In [11]:
multi_vec.set_env_ids(1)

self.current_enc = 1
self.current_enc = 1


In [14]:
obs = multi_vec.reset()


::::::::::::::::::::::
<PickPlaceV2GoalObservable instance>
[ 6.15235164e-03  6.00189803e-01  1.94301175e-01  1.00000000e+00
 -3.81644812e-02  6.25547348e-01  1.99999996e-02  0.00000000e+00
  0.00000000e+00  0.00000000e+00  1.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  5.71231149e-03  6.00692459e-01
  1.94438913e-01  1.00000000e+00 -2.55433033e-02  6.75508668e-01
  1.99147583e-02  3.77197318e-04 -1.79747518e-04 -4.45427324e-09
  9.99999913e-01  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
 -7.60597992e-02  8.54389480e-01  2.62112178e-01]
obs: [ 6.15235164e-03  6.00189803e-01  1.94301175e-01  1.00000000e+00
 -3.81644812e-02  6.25547348e-01  1.99999996e-02  0.00000000e+00
  0.00000000e+00  0.00000000e+00  1.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  5.7123

In [16]:
actions = multi_exp.predict(obs)
obs, rew, dones, info = multi_vec.step(actions)

current_env: 1
(39,)
current_env: 1
(39,)
current step: 3
current step: 3


ValueError: too many values to unpack (expected 4)

In [None]:
actions

[array([-0.13961927,  1.44733373,  0.01294375,  0.        ]),
 array([0.34571412, 1.27925898, 0.16247949, 0.        ])]

In [None]:
multi_vec, multi_exp = make_multi_vec_env(env_ids=['reach', 'pickplace'], num_cpu=2, seq_len=100, sparse=False)
multi_vec.set_env_ids(0)
obs = multi_vec.reset()
print(obs)
done = False
while not done:
    actions = multi_exp.predict(obs)
    obs, rew, dones, info = multi_vec.step(actions)
    done = dones[0]
    print(rew)


  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(


init__________________
init__________________
self.current_enc = 0
[None None]


Process ForkServerProcess-5:
Process ForkServerProcess-6:
Traceback (most recent call last):
  File "/home/hendrik/anaconda3/envs/ac/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/hendrik/anaconda3/envs/ac/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/hendrik/anaconda3/envs/ac/lib/python3.10/site-packages/stable_baselines3/common/vec_env/subproc_vec_env.py", line 27, in _worker
    cmd, data = remote.recv()
  File "/home/hendrik/anaconda3/envs/ac/lib/python3.10/multiprocessing/connection.py", line 255, in recv
    buf = self._recv_bytes()
  File "/home/hendrik/anaconda3/envs/ac/lib/python3.10/multiprocessing/connection.py", line 419, in _recv_bytes
    buf = self._recv(4)
  File "/home/hendrik/anaconda3/envs/ac/lib/python3.10/multiprocessing/connection.py", line 384, in _recv
    chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset b

TypeError: 'NoneType' object is not subscriptable

self.current_enc = 0


In [None]:
obs = multi_vec.reset()

In [None]:
obs

In [None]:
multi_vec.reset()

In [None]:
b = np.concatenate((a, np.array([100]), a[10:]))

In [None]:
b