### Implementation details: MultiDiscrete action space

A multi-discrete action space consists of a series of discrete action spaces, each with its own number of actions.

In a MultiDiscrete action space, an action is made up of independent components; that is, the agent takes several discrete actions simultaneously, one per component.

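To account for this difference, PPO treats the components $[a_1, a_2]$ of an action $a$ as probabilistically independent, so the joint probability factorises: $\text{prob}(a) = \text{prob}(a_1) \cdot \text{prob}(a_2)$. Equivalently, the log-probability of the joint action is the sum of the per-component log-probabilities.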


In [None]:
!pip install gym
!pip install gym_microrts==0.3.2
!pip install pyvirtualdisplay
!pip install -y xvfb ffmpeg

In [None]:
import argparse
import random
import gym
import gym_microrts
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical

  logger.warn(


In [None]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight tensor receives an orthogonal initialisation with gain ``std``
    and the bias tensor is filled with the constant ``bias_const``.
    """
    weight, bias = layer.weight, layer.bias
    nn.init.orthogonal_(weight, gain=std)
    nn.init.constant_(bias, val=bias_const)
    return layer

def conv2d_output_size(W_in, kernel_size, stride=1):
    """Return the output length along one spatial axis of a 2D convolution.

    The formula contains no padding or dilation terms.
    """
    # (W_in - (kernel_size - 1) - 1) simplifies to (W_in - kernel_size).
    usable_span = W_in - kernel_size
    return usable_span // stride + 1

  and should_run_async(code)


In [None]:
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.permute``.

    Stores an axis order at construction time and applies it to every input,
    which lets an axis reordering be used as a layer.
    """

    def __init__(self, permutation):
        super().__init__()
        # Target order of the input's dimensions.
        self.permutation = permutation

    def forward(self, x):
        axis_order = self.permutation
        return x.permute(*axis_order)

In [None]:
class PPOAgent(nn.Module):
    """Actor-critic network for a MultiDiscrete action space.

    A shared convolutional trunk feeds two linear heads:

    * ``actor``  - one logit per choice of every action component, concatenated
      into a single vector of length ``action_nvec.sum()``;
    * ``critic`` - a scalar state-value estimate.

    Args:
        state_shape: observation shape, indexed as (height, width, channels).
        action_nvec: array with the number of choices of each action component.
    """

    def __init__(self, state_shape, action_nvec):
        super().__init__()

        self.action_nvec = action_nvec

        height, width, channels = state_shape[0], state_shape[1], state_shape[2]

        # Track both spatial axes separately so that non-square observations
        # produce the correct flattened feature size.
        conv1_h = conv2d_output_size(height, kernel_size=3, stride=2)
        conv1_w = conv2d_output_size(width, kernel_size=3, stride=2)
        conv2_h = conv2d_output_size(conv1_h, kernel_size=2, stride=1)
        conv2_w = conv2d_output_size(conv1_w, kernel_size=2, stride=1)

        # Shape comments below are for 8 envs with a 16x16x27 observation.
        self.network = nn.Sequential(
            Transpose((0, 3, 1, 2)),  # to batch_size, channels, height, width: 8 x 27x16x16
            layer_init(nn.Conv2d(channels, 16, kernel_size=3, stride=2)),  # 8 x 16x7x7
            nn.ReLU(),
            layer_init(nn.Conv2d(16, 32, kernel_size=2, stride=1)),  # 8 x 32x6x6
            nn.ReLU(),
            nn.Flatten(),  # 8 x 1152
            layer_init(nn.Linear(32 * conv2_h * conv2_w, 128)),  # 8 x 128
            nn.ReLU()
        )

        # Plain Python int for the layer width instead of a numpy integer.
        self.actor = layer_init(nn.Linear(128, int(action_nvec.sum())), std=0.01)
        self.critic = layer_init(nn.Linear(128, 1), std=1.0)

    def get_value(self, x):
        """Return the critic's value estimate for the batch of observations ``x``."""
        hidden = self.network(x)
        return self.critic(hidden)

    def get_action_and_value(self, x, action=None):
        """Sample (or score) a multi-discrete action and evaluate the state.

        Args:
            x: batch of observations.
            action: optional actions to score, indexed by action component
                along the first dimension (i.e. the transpose of what this
                method returns). When ``None`` a fresh action is sampled.

        Returns:
            Tuple ``(action, log_prob, entropy, value)`` where ``action`` is
            transposed so the batch dimension comes first, and ``log_prob`` /
            ``entropy`` are summed over the action components.
        """
        hidden = self.network(x)
        logits = self.actor(hidden)

        # One slice of logits, hence one categorical distribution, per component.
        split_logits = torch.split(logits, self.action_nvec.tolist(), dim=1)
        multi_probs = [Categorical(logits=component_logits) for component_logits in split_logits]

        if action is None:
            action = torch.stack([probs.sample() for probs in multi_probs])

        # Components are treated as independent: the joint log-probability and
        # the joint entropy are sums over the components.
        log_prob = torch.stack([probs.log_prob(a) for a, probs in zip(action, multi_probs)]).sum(0)
        entropy = torch.stack([probs.entropy() for probs in multi_probs]).sum(0)
        value = self.critic(hidden)

        return action.T, log_prob, entropy, value

In [None]:
# Core run settings: compute device, seed, environment and rollout geometry.
config = argparse.Namespace(
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    seed=1,
    env_id='MicrortsMining-v4',
    num_envs=8,
    num_steps=128,  # steps collected from each env per policy rollout
)

# Batch geometry derived from the rollout size.
config.batch_size = config.num_envs * config.num_steps
config.num_minibatches = 4  # number of mini-batches per epoch
config.minibatch_size = config.batch_size // config.num_minibatches

# Training length.
config.total_timesteps = 2000000
config.num_updates = config.total_timesteps // config.batch_size

# Remaining hyper-parameters, added in bulk (insertion order is kept).
vars(config).update(
    num_epochs=4,
    learning_rate=3e-4,
    gamma=0.99,
    anneal_lr=True,      # anneal the learning rate over the updates
    gae=True,            # use Generalized Advantage Estimation
    gae_lambda=0.95,
    norm_adv=True,       # normalise advantages per mini-batch
    clip_coef=0.1,       # surrogate clipping coefficient (policy and value function)
    clip_vloss=True,     # also clip the value function with clip_coef
    vf_coef=0.5,         # weight of the value loss
    ent_coef=0.01,       # weight of the entropy bonus
    max_grad_norm=0.5,   # max norm for gradient clipping
    target_kl=0.015,     # KL divergence threshold for early stopping
    track=True,
    record_video=False,
)

Connect to Weights and Biases

In [None]:
# Install the Weights & Biases client only when experiment tracking is enabled.
if config.track:
    !pip install wandb

In [None]:
# Import wandb and log in, only when tracking is enabled. Later cells that use
# `wandb` are guarded by the same `config.track` flag.
if config.track:
    import wandb
    wandb.login()

In [None]:
# Start a W&B run named after the environment id, passing the hyper-parameter
# namespace as the run configuration.
if config.track:
    wandb.init(
        project='ppo-implementation-details',
        config=config,
        name=config.env_id,
        monitor_gym=True,
        save_code=True,
    )

In [None]:
def make_env(env_id, seed, idx, record_video=False):
    """Return a zero-argument factory that builds and seeds one environment.

    Args:
        env_id: id passed to ``gym.make``.
        seed: seed applied to the env and to its action/observation spaces.
        idx: position of this env in the vector; only index 0 is ever recorded.
        record_video: when true, the env with ``idx == 0`` is wrapped in
            ``gym.wrappers.RecordVideo`` writing to ``videos/<env_id>``.
    """
    wants_recording = record_video and idx == 0

    def thunk():
        env = gym.make(env_id)
        if wants_recording:
            env = gym.wrappers.RecordVideo(env, f"videos/{env_id}")

        # Seed the environment first, then both of its spaces.
        env.seed(seed)
        for space in (env.action_space, env.observation_space):
            space.seed(seed)

        return env

    return thunk

In [None]:
# Seed every RNG in use (Python, NumPy, PyTorch).
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)

# One synchronous vector env; each sub-env gets its own seed (base seed + index).
envs = gym.vector.SyncVectorEnv(
    [make_env(config.env_id, config.seed + i, i, record_video=config.record_video) for i in range(config.num_envs)]
)

# The agent builds one categorical head per action component, so any other
# kind of action space is rejected up front.
assert isinstance(envs.single_action_space, gym.spaces.MultiDiscrete), "only multi discrete action space is supported"

print("Observation space:", envs.single_observation_space)
print("Action space:", envs.single_action_space.nvec)

agent = PPOAgent(envs.single_observation_space.shape, envs.single_action_space.nvec).to(config.device)
optimizer = optim.Adam(agent.parameters(), lr=config.learning_rate, eps=1e-5)

Observation space: Box(0, 1, (16, 16, 27), int32)
Action space: [256   6   4   4   4   4   7 256]


In [None]:
def compute_gae_advantages(rewards, values, next_return, dones, num_steps, gamma=0.99, gae_lambda=0.95):
    """Compute Generalized Advantage Estimation over a rollout, walking backwards in time.

    Args:
        rewards, values, dones: per-step tensors indexed by time step first.
        next_return: bootstrap value used for the step after the last one.
        num_steps: number of time steps to process.
        gamma: discount factor.
        gae_lambda: GAE smoothing factor.

    Returns:
        Tuple ``(returns, gae_advantages)`` with ``returns = gae_advantages + values``.
    """
    gae_advantages = torch.zeros_like(rewards)

    running_gae = 0
    bootstrap = next_return
    t = num_steps
    while t > 0:
        t -= 1
        # A done flag at step t cuts both the bootstrap and the accumulated advantage.
        keep = 1 - dones[t]
        td_error = rewards[t] + gamma * bootstrap * keep - values[t]
        running_gae = td_error + gamma * gae_lambda * running_gae * keep
        gae_advantages[t] = running_gae
        bootstrap = values[t]

    return gae_advantages + values, gae_advantages

In [None]:
def compute_n_step_return(rewards, values, next_return, dones, num_steps, gamma=0.99):
    """Compute discounted returns backwards over a rollout, bootstrapping from ``next_return``.

    Args:
        rewards, values, dones: per-step tensors indexed by time step first.
        next_return: return estimate used for the step after the last one.
        num_steps: number of time steps to process.
        gamma: discount factor.

    Returns:
        Tuple ``(returns, advantages)`` with ``advantages = returns - values``.
    """
    returns = torch.zeros_like(rewards)

    carried = next_return
    for offset in range(1, num_steps + 1):
        t = num_steps - offset
        # A done flag at step t stops the discounted return from flowing past it.
        returns[t] = rewards[t] + gamma * carried * (1 - dones[t])
        carried = returns[t]

    return returns, returns - values

In [None]:
def PPOTrain(envs, optimizer, config):
    """Run the PPO training loop for ``config.num_updates`` policy updates.

    Each update collects ``config.num_steps`` steps from every env, estimates
    returns/advantages (GAE or n-step, per ``config.gae``), then optimises the
    clipped PPO objective for up to ``config.num_epochs`` epochs of shuffled
    mini-batches, stopping early when the approximate KL exceeds
    ``config.target_kl``.

    Args:
        envs: vector environment exposing ``single_observation_space``,
            ``single_action_space``, ``reset`` and ``step``.
        optimizer: optimizer over the agent's parameters.
        config: namespace of hyper-parameters.

    Note:
        Reads the module-level ``agent`` and updates the module-level
        ``global_step``; uses ``wandb`` when ``config.track`` is true.
    """
    global global_step

    # Storage setup
    observations = torch.zeros((config.num_steps, config.num_envs) + envs.single_observation_space.shape).to(config.device)
    actions = torch.zeros((config.num_steps, config.num_envs) + envs.single_action_space.shape, dtype=torch.long).to(config.device)
    rewards = torch.zeros((config.num_steps, config.num_envs)).to(config.device)
    dones = torch.zeros((config.num_steps, config.num_envs)).to(config.device)
    values = torch.zeros((config.num_steps, config.num_envs)).to(config.device)
    logprobs = torch.zeros((config.num_steps, config.num_envs)).to(config.device)

    obs = envs.reset()

    for update in range(config.num_updates):
        # Update learning rate (linear decay towards zero over the updates)
        if config.anneal_lr:
            lr_frac = 1.0 - update / config.num_updates
            optimizer.param_groups[0]["lr"] = lr_frac * config.learning_rate

        # Do n-steps
        # NOTE: despite the name, this is the sum over the rollout of the
        # per-step reward averaged across envs, not a per-episode return.
        episodic_return = 0
        for step in range(config.num_steps):
            global_step += config.num_envs

            with torch.no_grad():
                obs_tensor = torch.Tensor(obs).to(config.device)
                action, logprob, _, value = agent.get_action_and_value(obs_tensor)

            next_obs, reward, done, _, = envs.step(action.cpu().numpy())

            # Save batch. `action`, `value` and `logprob` are already tensors
            # produced by the agent, so they are stored directly rather than
            # re-wrapped in torch.tensor / torch.Tensor (which warns and copies).
            observations[step] = obs_tensor
            actions[step] = action
            rewards[step] = torch.Tensor(reward).to(config.device)
            dones[step] = torch.Tensor(done).to(config.device)
            values[step] = value.flatten()
            logprobs[step] = logprob

            episodic_return += reward.mean()
            obs = next_obs

        print(f"global_step={global_step}, episodic_return={episodic_return}")
        if config.track:
            wandb.log({'episodic_return': episodic_return})

        # Bootstrap from the value of the last observation, then estimate
        # returns and advantages without tracking gradients.
        with torch.no_grad():
            next_obs_tensor = torch.Tensor(next_obs).to(config.device)
            next_value = agent.get_value(next_obs_tensor).reshape(1, -1)

            if config.gae:
                returns, advantages = compute_gae_advantages(rewards, values, next_value, dones, config.num_steps, config.gamma, config.gae_lambda)
            else:
                returns, advantages = compute_n_step_return(rewards, values, next_value, dones, config.num_steps, config.gamma)

        # Flatten the batch: num_steps * num_envs
        b_observations = observations.reshape((-1,) + envs.single_observation_space.shape)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)
        b_logprobs = logprobs.reshape(-1)
        b_advantages = advantages.reshape(-1)

        clip_fracs = []

        # Shuffles the indices of the batch and breaks it into mini-batches
        batch_inds = np.arange(config.batch_size)
        for epoch in range(config.num_epochs):
            np.random.shuffle(batch_inds)

            for start in range(0, config.batch_size, config.minibatch_size):
                end = start + config.minibatch_size
                minibatch_inds = batch_inds[start:end]

                # Mini-batches: targets
                mb_observations = b_observations[minibatch_inds]
                mb_actions = b_actions[minibatch_inds]
                mb_returns = b_returns[minibatch_inds]
                mb_values = b_values[minibatch_inds]
                mb_logprobs = b_logprobs[minibatch_inds]
                mb_advantages = b_advantages[minibatch_inds]

                # Advantages normalization
                if config.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Predictions. The agent expects actions indexed by component
                # first, hence the transpose.
                _, mb_logprobs_pred, mb_entropy_pred, mb_values_pred = agent.get_action_and_value(mb_observations, mb_actions.T)

                log_ratio = mb_logprobs_pred - mb_logprobs
                ratio = log_ratio.exp()

                with torch.no_grad():
                    clip_fracs.append(((ratio - 1.0).abs() > config.clip_coef).float().mean().item())

                # Policy loss (pessimistic choice between clipped and unclipped)
                pg_loss_unclipped = -mb_advantages * ratio
                pg_loss_clipped = -mb_advantages * torch.clamp(ratio, 1 - config.clip_coef, 1 + config.clip_coef)
                pg_loss = torch.max(pg_loss_unclipped, pg_loss_clipped).mean()

                # Value loss
                mb_values_pred = mb_values_pred.view(-1)

                v_loss_sqrt = (mb_values_pred - mb_returns) ** 2  # Unclipped
                if config.clip_vloss:
                    # Ensure that the value function updates do not deviate too far from the original values
                    v_clipped = torch.clamp(mb_values_pred, mb_values - config.clip_coef, mb_values + config.clip_coef)
                    v_loss_clipped = (v_clipped - mb_returns) ** 2
                    v_loss_sqrt = torch.max(v_loss_sqrt, v_loss_clipped)

                v_loss = 0.5 * v_loss_sqrt.mean()

                # Entropy Loss
                entropy_loss = mb_entropy_pred.mean()

                # Overall Loss
                loss = pg_loss - config.ent_coef * entropy_loss + config.vf_coef * v_loss

                optimizer.zero_grad()
                loss.backward()

                # Global Gradient Clipping
                nn.utils.clip_grad_norm_(agent.parameters(), config.max_grad_norm)

                optimizer.step()

            # Approximate KL, computed from the last mini-batch of this epoch.
            with torch.no_grad():
                # old_approx_kl = (-logratio).mean()
                approx_kl = ((ratio - 1) - log_ratio).mean()

            # Early stop using approx_kl: skip the remaining epochs of this update
            if config.target_kl is not None:
                if approx_kl > config.target_kl:
                    break

        # Explained variance of the rollout returns by the stored value estimates.
        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        if config.track:
            # Loss terms are those of the last mini-batch processed in this update.
            metrics = {
                "GLOBAL STEP": global_step,
                "loss": loss.item(),
                "learning_rate": optimizer.param_groups[0]["lr"],
                "value_loss": v_loss.item(),
                "policy_loss": pg_loss.item(),
                "entropy": entropy_loss.item(),
                "approx_kl": approx_kl.item(),
                "clip_frac": np.mean(clip_fracs),
                "explained_variance": explained_var
            }
            wandb.log(metrics)
            print(metrics)

In [None]:
# Running count of environment steps; PPOTrain declares it `global` and adds
# `config.num_envs` to it on every rollout step.
global_step = 0

In [None]:
# Launch training. PPOTrain also uses the module-level `agent` and
# `global_step` defined in earlier cells.
PPOTrain(envs, optimizer, config)

In [None]:
def evaluate(num_steps):
    """Roll the current policy out for at most ``num_steps`` steps while recording video.

    Uses the module-level ``envs``, ``agent``, ``config`` and ``video_recorder``.
    The rollout stops early as soon as any sub-environment reports ``done``.
    The recorder is always closed, even if the rollout raises.

    Args:
        num_steps: maximum number of vector-env steps to run.

    Returns:
        Sum over the rollout of the per-step reward averaged across envs.
    """
    episodic_return = 0
    state = envs.reset()

    try:
        for step in range(num_steps):
            with torch.no_grad():
                state = torch.Tensor(state).to(config.device)
                action, _, _, _ = agent.get_action_and_value(state)

            state, reward, done, _, = envs.step(action.cpu().numpy())

            episodic_return += reward.mean()
            video_recorder.capture_frame()

            # `done` holds one flag per sub-environment; a bare `if done:` on
            # such an array raises "truth value of an array is ambiguous".
            if np.any(done):
                break
    finally:
        video_recorder.close()

    return episodic_return

In [None]:
from pyvirtualdisplay import Display
from gym.wrappers.monitoring.video_recorder import VideoRecorder

# Start a headless virtual display (pyvirtualdisplay) before recording.
# NOTE(review): the name `display` is rebound by the next cell's
# `from IPython.display import display`, so this handle is lost afterwards.
display = Display(visible=False, size=(1400, 900))
_ = display.start()

# Request rgb_array rendering and attach a recorder that writes an mp4 named
# after the environment id.
envs.render_mode = 'rgb_array'
video_path = f"./videos/{config.env_id}.mp4"
video_recorder = VideoRecorder(envs, video_path, enabled=True)

# Roll out for at most 2000 steps; evaluate() reads the module-level
# `video_recorder` created above.
evaluate(2000)

In [None]:
from IPython.display import display, HTML
# Embed the recorded video in the cell output through an HTML <video> tag
# pointing at `video_path`.
display(HTML(f"""<video src="{video_path}" width=400 controls></video>"""))

In [None]:
# Release the vector environment and, when tracking is enabled, finish the W&B run.
envs.close()
if config.track: wandb.finish()