# Action Masking

© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

PrimAITE environments support action masking. The action mask shows which of the agent's actions are valid in the current environment state. For example, a node can only be turned on if it is currently turned off.

In [None]:
from primaite.session.environment import PrimaiteGymEnv
from primaite.config.load import data_manipulation_config_path
from prettytable import PrettyTable


In [None]:
# Create a single-agent Gym environment from the data-manipulation example config
config_path = data_manipulation_config_path()
env = PrimaiteGymEnv(config_path)
# Switch action masking on for this environment instance
env.action_masking = True

The action mask is a list of booleans that specifies whether each action in the agent's action map is currently possible, as demonstrated below:

In [None]:
# Tabulate every action in the agent's action map next to its current mask value.
act_table = PrettyTable(("number", "action", "parameters", "mask"))
mask = env.action_masks()
actions = env.agent.action_manager.action_map
max_str_len = 70  # long parameter reprs are truncated so the table stays readable
# Loop variable is `is_valid`, not `mask`, so the full mask list survives this cell.
for (act_num, (act_type, act_params)), is_valid in zip(actions.items(), mask):
    params_str = str(act_params)
    # Only truncate strings that exceed the limit; the "..." keeps the result at max_str_len chars.
    if len(params_str) > max_str_len:
        params_str = f"{params_str[:max_str_len - 3]}..."
    act_table.add_row((act_num, act_type, params_str, is_valid))
print(act_table)

## Action masking for Stable Baselines3 agents
Maskable agents from `sb3-contrib`, such as `MaskablePPO`, automatically call the environment's `action_masks` method during the training loop.

In [None]:
from sb3_contrib import MaskablePPO


In [None]:
# Train a MaskablePPO agent on the masked environment for 1024 timesteps.
model = MaskablePPO(
    policy="MlpPolicy",
    env=env,
    gamma=0.4,
    seed=32,
)
model.learn(total_timesteps=1024)

## Action masking for Ray RLlib agents
Ray uses a different API to obtain action masks, but this is handled by the `PrimaiteRayEnv` and `PrimaiteRayMARLEnv` classes.

In [None]:
from primaite.session.ray_envs import PrimaiteRayEnv
from ray.rllib.algorithms.ppo import PPOConfig
import yaml
from ray import air, tune
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec


In [None]:
# Load the data-manipulation example config. An explicit encoding avoids depending on
# the platform's default text encoding (e.g. cp1252 on Windows).
with open(data_manipulation_config_path(), "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)
# Set `flatten_obs` in the settings of the agent whose ref is "defender"; other agents
# are left unchanged. (Presumably flattens that agent's observation for RLlib — confirm
# against PrimAITE's agent settings docs.)
for agent in cfg["agents"]:
    if agent["ref"] == "defender":
        agent["agent_settings"]["flatten_obs"] = True
env_config = cfg


In [None]:
# Single-agent PPO on RLlib's new API stack, using the example action-masking RL module.
module_spec = SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule)

config = PPOConfig()
config = config.api_stack(
    enable_rl_module_and_learner=True,
    enable_env_runner_and_connector_v2=True,
)
config = config.environment(
    env=PrimaiteRayEnv,
    env_config=cfg,
    action_mask_key="action_mask",
)
config = config.rl_module(rl_module_spec=module_spec)
config = config.env_runners(num_env_runners=0)
config = config.training(train_batch_size=128)

algo = config.build()
# Run two training iterations; `results` keeps the return value of the last one.
for i in range(2):
    results = algo.train()

## Action masking with MARL in Ray RLlib
Each agent has its own action mask, which is useful when the agents have different action spaces.

In [None]:
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from primaite.session.ray_envs import PrimaiteRayMARLEnv
from primaite.config.load import data_manipulation_marl_config_path

In [None]:
# Load the multi-agent example config. An explicit encoding avoids depending on the
# platform's default text encoding.
with open(data_manipulation_marl_config_path(), "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)
env_config = cfg


In [None]:
# One policy (and one action-masking RL module) per defender agent. These IDs are the
# same as the agent names defined in the example MARL config.
agent_ids = ("defender_1", "defender_2")


def map_agent_to_policy(agent_id, *args, **kwargs):
    # Each agent is trained by the policy that shares its name.
    return agent_id


config = (
    PPOConfig()
    .multi_agent(
        policies=set(agent_ids),
        policy_mapping_fn=map_agent_to_policy,
    )
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment(
        env=PrimaiteRayMARLEnv,
        env_config=cfg,
        action_mask_key="action_mask",
    )
    .rl_module(
        rl_module_spec=MultiAgentRLModuleSpec(
            module_specs={
                agent_id: SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule)
                for agent_id in agent_ids
            }
        )
    )
    .env_runners(num_env_runners=0)
    .training(train_batch_size=128)
)

algo = config.build()
# Run two training iterations; `results` keeps the return value of the last one.
for i in range(2):
    results = algo.train()