In [2]:
from typing import TypeVar

import torch.nn as nn
import torch.optim as optim
import plotly.express as px
import pandas as pd

from gymnasium import spaces
import numpy as np
import torch
import gym
from copy import deepcopy
from torch.distributions.categorical import Categorical
from tqdm.auto import tqdm
import os
import joblib
from lasertag import LasertagAdversarial

import matplotlib.pyplot as plt
from syllabus.core import (
    Curriculum,
    TaskWrapper,
    PettingZooMultiProcessingSyncWrapper,
    make_multiprocessing_curriculum,
)
from syllabus.task_space import TaskSpace
from syllabus.curricula import DomainRandomization

# Generic placeholder types used in the annotations of the PettingZoo-style
# `reset` / `step` signatures of `LasertagParallelWrapper`.
ObsType = TypeVar("ObsType")
ActionType = TypeVar("ActionType")
AgentID = TypeVar("AgentID")

In [3]:
def batchify(x, device):
    """Stack a PettingZoo-style per-agent mapping into one torch tensor.

    Args:
        x: Mapping from agent id to that agent's value. Values are stacked
            in the mapping's iteration order along a new leading axis.
        device: Torch device the resulting tensor is moved to.

    Returns:
        A tensor whose first dimension indexes the agents.
    """
    per_agent_values = list(x.values())
    stacked = np.stack(per_agent_values, axis=0)
    return torch.tensor(stacked).to(device)


def unbatchify(x, possible_agents: np.ndarray):
    """Split a batched tensor into a PettingZoo-style per-agent mapping.

    Args:
        x: Tensor whose first dimension indexes agents, in the same order
            as `possible_agents`.
        possible_agents: Agent ids used as the keys of the returned mapping.

    Returns:
        Dict mapping each agent id to the matching slice of `x`, converted
        to NumPy.
    """
    # `.detach()` allows tensors that still track gradients to be converted;
    # calling `.numpy()` directly on such a tensor raises a RuntimeError.
    x = x.detach().cpu().numpy()
    x = {a: x[i] for i, a in enumerate(possible_agents)}

    return x

In [4]:
class LasertagParallelWrapper(TaskWrapper):
    """
    Wrapper ensuring compatibility with the PettingZoo Parallel API.

    Lasertag Environment:
        * Action shape:  `n_agents` * `Discrete(5)`
        * Observation shape: Dict('image': Box(0, 255, (`n_agents`, 3, 5, 5), uint8))

    Values returned by `reset` and `step` are dictionaries keyed by the agent
    index converted to `str` ("0", "1", ...).
    """

    def __init__(self, n_agents, *args, **kwargs):
        """
        Args:
            n_agents: Number of agents; used to build `possible_agents` and to
                broadcast scalar flags to one entry per agent.
            *args, **kwargs: Forwarded unchanged to `TaskWrapper.__init__`.
        """
        super().__init__(*args, **kwargs)
        self.n_agents = n_agents
        self.task = None
        self.episode_return = 0
        self.task_space = TaskSpace(spaces.MultiDiscrete(np.array([[2], [5]])))
        # NOTE(review): these ids are integers, while the dictionaries produced
        # by this wrapper are keyed by `str(index)` -- confirm callers expect that.
        self.possible_agents = np.arange(self.n_agents)

    def __getattr__(self, name):
        """
        Delegate attribute lookup to the wrapped environment if the attribute
        is not found in the LasertagParallelWrapper instance.

        Raises:
            AttributeError: if `env` itself has not been set on the instance.
        """
        # `__getattr__` only runs after normal lookup has failed. If `env` is the
        # missing attribute (e.g. an instance created without `__init__` by
        # copy/pickle), evaluating `self.env` below would re-enter this method
        # until a RecursionError is raised.
        if name == "env":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.env, name)

    def _np_array_to_pz_dict(self, array: np.ndarray) -> dict[str, np.ndarray]:
        """
        Returns a dictionary containing individual observations for each agent.
        Assumes that the batch dimension represents individual agents.
        """
        return {str(idx): item for idx, item in enumerate(array)}

    def _singleton_to_pz_dict(self, value: bool) -> dict[str, bool]:
        """
        Broadcasts the `done` and `trunc` flags to dictionaries keyed by agent id.
        """
        return {str(idx): value for idx in range(self.n_agents)}

    def reset(self) -> dict[AgentID, ObsType]:
        """
        Resets the environment and returns a dictionary of observations
        keyed by agent ID.
        """
        obs = self.env.reset()
        pz_obs = self._np_array_to_pz_dict(obs["image"])

        return pz_obs

    def step(self, action: dict[AgentID, ActionType], device: str = "cpu") -> tuple[
        dict[AgentID, ObsType],
        dict[AgentID, float],
        dict[AgentID, bool],
        dict[AgentID, bool],
        dict[AgentID, dict],
    ]:
        """
        Takes inputs in the PettingZoo (PZ) Parallel API format, performs a step and
        returns outputs in PZ format.

        Args:
            action: Mapping from agent id to that agent's action.
            device: Torch device used by `batchify` when stacking the actions.
                Defaults to "cpu" so the method can also be called with the
                single-argument `step(actions)` form.

        Returns:
            `(obs, rew, done, trunc, info)`, each a dictionary keyed by the
            stringified agent index. `info` holds the value returned by
            `self._task_completion(...)` broadcast to every agent; the `info`
            returned by the wrapped environment is not passed through.
        """
        action = batchify(action, device)
        obs, rew, done, info = self.env.step(action)
        obs = obs["image"]
        trunc = 0  # there is no `truncated` flag in this environment
        self.task_completion = self._task_completion(obs, rew, done, trunc, info)
        # convert outputs back to PZ format
        obs, rew = tuple(map(self._np_array_to_pz_dict, [obs, rew]))
        done, trunc, info = tuple(
            map(self._singleton_to_pz_dict, [done, trunc, self.task_completion])
        )

        return self.observation(obs), rew, done, trunc, info

In [5]:
class SelfPlay(Curriculum):
    """Self-play curriculum that keeps an independent copy of the current agent.

    The copy is stored on `storage_device` and handed out as the opponent.
    """

    def __init__(self, agent, device: str, store_agents_on_cpu: bool = False):
        """
        Args:
            agent: Agent to copy; the original object is left untouched.
            device: Device the copy is stored on unless `store_agents_on_cpu`
                is set.
            store_agents_on_cpu: If True, keep the stored copy on the CPU.
        """
        # NOTE(review): `Curriculum.__init__` is not called here -- confirm the
        # base class does not need to be initialised for this use.
        self.store_agents_on_cpu = store_agents_on_cpu
        self.storage_device = "cpu" if self.store_agents_on_cpu else device
        self.update_agent(agent)

    def update_agent(self, agent):
        """Replace the stored opponent with a fresh deep copy of `agent`."""
        self.agent = deepcopy(agent).to(self.storage_device)

    def get_opponent(self, agent_id):
        """Return the stored copy of the current agent.

        Args:
            agent_id: Must be 0, since only the current agent is tracked.

        Raises:
            AssertionError: if `agent_id` is not 0.
        """
        assert (
            agent_id == 0
        ), f"Self play only tracks the current agent. Expected agent id 0, got {agent_id}"
        # Previously the method fell through and returned None after the check.
        return self.agent

    def sample(self, k=1):
        """Return the single opponent id, 0; `k` is ignored."""
        return 0

In [6]:
class Agent(nn.Module):
    """Actor-critic network with a shared encoder and separate policy/value heads."""

    def __init__(self, num_actions):
        """
        Args:
            num_actions: Number of discrete actions produced by the actor head.
        """
        super().__init__()

        encoder_layer = self._layer_init(nn.Linear(3 * 5 * 5, 512))
        self.network = nn.Sequential(encoder_layer, nn.ReLU())
        self.actor = self._layer_init(nn.Linear(512, num_actions), std=0.01)
        self.critic = self._layer_init(nn.Linear(512, 1))

    def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        """Orthogonally initialise `layer.weight` (gain `std`) and fill the bias with `bias_const`."""
        nn.init.orthogonal_(layer.weight, gain=std)
        nn.init.constant_(layer.bias, bias_const)
        return layer

    def _encode(self, x, flatten_start_dim):
        """Flatten `x` from `flatten_start_dim`, scale by 1/255 and run the shared encoder."""
        flat = torch.flatten(x, start_dim=flatten_start_dim)
        return self.network(flat / 255.0)

    def get_value(self, x, flatten_start_dim=1):
        """Return the critic's value estimate for the observations `x`."""
        features = self._encode(x, flatten_start_dim)
        return self.critic(features)

    def get_action_and_value(self, x, action=None, flatten_start_dim=1):
        """Return `(action, log_prob, entropy, value)` for the observations `x`.

        When `action` is None an action is sampled from the policy; otherwise
        the supplied action is evaluated under the current policy.
        """
        features = self._encode(x, flatten_start_dim)
        policy = Categorical(logits=self.actor(features))
        chosen = policy.sample() if action is None else action
        return chosen, policy.log_prob(chosen), policy.entropy(), self.critic(features)

In [7]:
"""ALGO PARAMS"""
# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
ent_coef = 0.1
vf_coef = 0.1
clip_coef = 0.1
gamma = 0.99
batch_size = 32
stack_size = 3
frame_size = (5, 5)
max_cycles = 201 # lasertag has 200 maximum steps by default
total_episodes = 500
n_agents = 2
num_actions = 5

""" LEARNER SETUP """
agent = Agent(num_actions=num_actions).to(device)
optimizer = optim.Adam(agent.parameters(), lr=0.001, eps=1e-5)

""" ENV SETUP """
env = LasertagParallelWrapper(
    env=LasertagAdversarial(record_video=False),  # 2 agents by default
    n_agents=n_agents,
)
curriculum = SelfPlay(agent=agent, device=device, store_agents_on_cpu=True)
# Drop the leading dimension of the "image" space to get the per-agent shape.
observation_size = env.observation_space["image"].shape[1:]

""" ALGO LOGIC: EPISODE STORAGE"""
end_step = 0
total_episodic_return = 0
# Rollout buffers with one slot per step and agent, allocated on `device`.
rb_obs = torch.zeros((max_cycles, n_agents, stack_size, *frame_size), device=device)
rb_actions, rb_logprobs, rb_rewards, rb_terms, rb_values = (
    torch.zeros((max_cycles, n_agents), device=device) for _ in range(5)
)

losses = []
episode_rewards = []

In [14]:
# Exploratory check: call `reset_random()` (presumably resets the env to a
# randomly generated level -- TODO confirm) and display the observation dict it returns.
env.reset_random()

{'image': array([[[[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 1],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]],
 
         [[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]],
 
         [[0, 0, 1, 0, 1],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 1, 0],
          [0, 0, 0, 1, 0],
          [0, 0, 1, 1, 1]]],
 
 
        [[[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 1],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]],
 
         [[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]],
 
         [[0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1],
          [1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0],
          [0, 1, 0, 0, 0]]]], dtype=uint8)}

In [15]:
# Exploratory check: display the `level` attribute. It is not defined on the
# wrapper in this file, so the lookup presumably falls through
# `__getattr__` to the wrapped environment -- TODO confirm.
env.level

(b'.\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00w\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00w\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00.\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00.\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00.\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00.\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00w\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00w\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00.\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00w\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00.\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00.\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00