In [1]:
from gymnasium.spaces import MultiDiscrete
import numpy as np
action_space = MultiDiscrete(np.array([20,25]), seed=42)
action_space.sample()

array([15, 10])

In [2]:
action_space.sample()

array([17, 17])

In [2]:
from nocturne.envs.nocturne_gymnasium import NocturneGymnasium
import yaml
from nocturne.envs.base_env import BaseEnv
from nocturne.envs.vec_env_ma import MultiAgentAsVecEnv

# Load environment settings
with open(f"../configs/env_config.yaml", "r") as stream:
    env_config = yaml.safe_load(stream)

# Initialize environment
env = BaseEnv(config=env_config)

In [4]:
gymnasiumEnv = NocturneGymnasium(env)

In [5]:
gymnasiumEnv.action_space

MultiDiscrete([20 25])

In [6]:
gymnasiumEnv.reset()

{6: array([0.30362597, 0.54050583, 0.16309013, ..., 0.        , 0.        ,
        0.        ]),
 23: array([0.33037499, 0.583     , 0.15996636, ..., 0.        , 0.        ,
        0.        ]),
 2: array([0.29274434, 0.50419724, 0.15200901, ..., 0.        , 0.        ,
        0.        ]),
 8: array([0.30145869, 0.53301573, 0.16035738, ..., 0.        , 0.        ,
        0.        ]),
 9: array([0.27850392, 0.50491506, 0.16618462, ..., 0.        , 0.        ,
        0.        ]),
 1: array([0.28742164, 0.52419186, 0.17056067, ..., 0.        , 0.        ,
        0.        ]),
 14: array([0.27822894, 0.51315653, 0.1849147 , ..., 0.        , 0.        ,
        0.        ]),
 18: array([0.28400436, 0.51011646, 0.13745898, ..., 0.        , 0.        ,
        0.        ]),
 4: array([0.31240156, 0.52112043, 0.16890202, ..., 0.        , 0.        ,
        0.        ]),
 0: array([0.27941427, 0.51273805, 0.16114239, ..., 0.        , 0.        ,
        0.        ])}

In [None]:
# Reset
obs_dict = gymnasiumEnv.reset()

# Get info
agent_ids = [agent_id for agent_id in obs_dict.keys()]
dead_agent_ids = []
num_agents = len(agent_ids)
rewards = {agent_id: 0 for agent_id in agent_ids}

for step in range(1000):

    # Sample actions
    action_dict = {
        agent_id: env.action_space.sample() 
        for agent_id in agent_ids
        if agent_id not in dead_agent_ids
    }
    # Step in env
    obs_dict, rew_dict, done_dict, info_dict = gymnasiumEnv.step(action_dict)

    for agent_id in action_dict.keys():
        rewards[agent_id] += rew_dict[agent_id]

    # Update dead agents
    for agent_id, is_done in done_dict.items():
        if is_done and agent_id not in dead_agent_ids:
            dead_agent_ids.append(agent_id)

    # Reset if all agents are done
    if done_dict["__all__"]:
        print(f'Done after {env.step_num} steps -- total return in episode: {rewards}')
        obs_dict = gymnasiumEnv.reset()
        agent_ids = [agent_id for agent_id in obs_dict.keys()]
        dead_agent_ids = []
        rewards = {agent_id: 0 for agent_id in agent_ids}

# Close environment
env.close()

In [3]:
from stable_baselines3.common.vec_env import SubprocVecEnv

In [4]:
def make_env(env_config):
    return NocturneGymnasium(BaseEnv(config=env_config)) 

In [5]:
envs = SubprocVecEnv([lambda: make_env(env_config) for _ in range(4)])

In [6]:
# Reset
obs_dicts = envs.reset()

In [7]:
agent_ids_batch = []
dead_agent_ids_batch = []
num_agents_batch = []
rewards_batch = []
for obs_dict in obs_dicts:
    agent_ids = [agent_id for agent_id in obs_dict.keys()]
    dead_agent_ids = []
    num_agents = len(agent_ids)
    rewards = {agent_id: 0 for agent_id in agent_ids}
    agent_ids_batch.append(agent_ids)
    dead_agent_ids_batch.append(dead_agent_ids)
    num_agents_batch.append(num_agents)
    rewards_batch.append(rewards)


In [8]:
action_dicts = [
        {
            agent_id: env.action_space.sample() 
            for agent_id in agent_ids
            if agent_id not in dead_agent_ids
        }
        for agent_ids, dead_agent_ids in zip(agent_ids_batch, dead_agent_ids_batch)
    ]
action_dicts

[{6: 6, 23: 18, 2: 24, 8: 1, 9: 2, 1: 18, 14: 1, 18: 13, 4: 11, 0: 18},
 {6: 0, 23: 14, 2: 8, 8: 22, 9: 1, 1: 4, 14: 18, 18: 18, 4: 1, 0: 23},
 {6: 9, 23: 22, 2: 1, 8: 2, 9: 0, 1: 20, 14: 5, 18: 14, 4: 7, 0: 13},
 {6: 13, 23: 1, 2: 8, 8: 13, 9: 9, 1: 16, 14: 9, 18: 9, 4: 12, 0: 23}]

In [9]:
for step in range(1000):

    # Sample actions
    action_dicts = [
        {
            agent_id: env.action_space.sample() 
            for agent_id in agent_ids
            if agent_id not in dead_agent_ids
        }
        for agent_ids, dead_agent_ids in zip(agent_ids_batch, dead_agent_ids_batch)
    ]
    # Step in env
    obs_dicts, rew_dicts, done_dicts, info_dicts = envs.step(action_dicts)

    for rew_dict, rewards in zip(rew_dicts, rewards_batch):
        for agent_id in rew_dict.keys():
            rewards[agent_id] += rew_dict[agent_id] 
    
    # Update dead agents
    for done_dict, dead_agent_ids in zip(done_dicts, dead_agent_ids_batch):
        for agent_id, is_done in done_dict.items():
            if is_done and agent_id not in dead_agent_ids:
                dead_agent_ids.append(agent_id)

    # Reset if all agents are done
    if all([done_dict["__all__"] for done_dict in done_dicts]):
        print(f'Done after {env.step_num} steps -- total return in episode: {rewards}')
        obs_dicts = envs.reset()
        agent_ids_batch = []
        dead_agent_ids_batch = []
        num_agents_batch = []
        rewards_batch = []
        for obs_dict in obs_dicts:
            agent_ids = [agent_id for agent_id in obs_dict.keys()]
            dead_agent_ids = []
            num_agents = len(agent_ids)
            rewards = {agent_id: 0 for agent_id in agent_ids}
            agent_ids_batch.append(agent_ids)
            dead_agent_ids_batch.append(dead_agent_ids)
            num_agents_batch.append(num_agents)
            rewards_batch.append(rewards)

    # # Sample actions
    # action_dict = {
    #     agent_id: env.action_space.sample() 
    #     for agent_id in agent_ids
    #     if agent_id not in dead_agent_ids
    # }
    # # Step in env
    # obs_dict, rew_dict, done_dict, info_dict = gymnasiumEnv.step(action_dict)

    # for agent_id in action_dict.keys():
    #     rewards[agent_id] += rew_dict[agent_id]

    # # Update dead agents
    # for agent_id, is_done in done_dict.items():
    #     if is_done and agent_id not in dead_agent_ids:
    #         dead_agent_ids.append(agent_id)

    # # Reset if all agents are done
    # if done_dict["__all__"]:
    #     print(f'Done after {env.step_num} steps -- total return in episode: {rewards}')
    #     obs_dict = gymnasiumEnv.reset()
    #     agent_ids = [agent_id for agent_id in obs_dict.keys()]
    #     dead_agent_ids = []
    #     rewards = {agent_id: 0 for agent_id in agent_ids}

KeyError: 7