# ***Lasertag Vecenv***

In [35]:
import argparse
import os
import sys
import time
from typing import Dict, Tuple, TypeVar

import joblib
import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from gymnasium import spaces
from plotly.subplots import make_subplots
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import supersuit as ss
sys.path.append("../../..")
from lasertag import LasertagAdversarial  # noqa: E402
from syllabus.core import (  # noqa: E402
    DualCurriculumWrapper,
    TaskWrapper,
    make_multiprocessing_curriculum,
)

# noqa: E402
from syllabus.curricula import (  # noqa: E402
    CentralizedPrioritizedLevelReplay,
    DomainRandomization,
    FictitiousSelfPlay,
    PrioritizedFictitiousSelfPlay,
    SelfPlay,
)
from syllabus.task_space import TaskSpace  # noqa: E402

# Type variables used purely as placeholders in type annotations
# (e.g. the `reset` / `step` signatures of the wrapper defined in this file).
ActionType = TypeVar("ActionType")
AgentID = TypeVar("AgentID")
AgentType = TypeVar("AgentType")
EnvTask = TypeVar("EnvTask")
AgentTask = TypeVar("AgentTask")
ObsType = TypeVar("ObsType")

In [36]:
from pettingzoo import ParallelEnv

In [37]:
def batchify(x, device):
    """
    Stack a PettingZoo-style per-agent mapping into a single torch tensor.

    The values of `x` are stacked along a new leading axis, in the iteration
    order of `x`, and the resulting tensor is moved to `device`.
    """
    per_agent_values = [x[agent_id] for agent_id in x]
    stacked = np.stack(per_agent_values, axis=0)
    return torch.tensor(stacked).to(device)


def unbatchify(x, possible_agents: np.ndarray):
    """
    Convert a batched torch tensor into a PettingZoo-style per-agent dict.

    Args:
        x: Tensor whose leading axis is indexed by agent position.
        possible_agents: Ordered agent identifiers; entry `i` becomes the key
            for `x[i]`.

    Returns:
        A dict mapping each agent identifier to the matching slice of `x`
        as a NumPy value.

    Raises:
        IndexError: If `possible_agents` has more entries than `x` has rows.
    """
    # `detach()` lets tensors that are still attached to an autograd graph be
    # converted; without it `.numpy()` raises a RuntimeError for such tensors.
    values = x.detach().cpu().numpy()
    return {agent: values[idx] for idx, agent in enumerate(possible_agents)}


class LasertagParallelWrapper(ParallelEnv, TaskWrapper):
    """
    Wrapper ensuring compatibility with the PettingZoo Parallel API.

    Lasertag Environment:
        * Action shape:  `n_agents` * `Discrete(5)`
        * Observation shape: Dict('image': Box(0, 255, (`n_agents`, 3, 5, 5), uint8))

    Every per-agent dictionary produced by this wrapper is keyed by the
    entries of `possible_agents` ("agent_0", "agent_1", ...).
    """

    def __init__(self, n_agents, *args, **kwargs):
        """
        Args:
            n_agents: Number of agents; determines `possible_agents` and the
                per-agent observation / action spaces.
            *args, **kwargs: Forwarded unchanged to the parent constructors.
        """
        super().__init__(*args, **kwargs)
        self.n_agents = n_agents
        self.task = None
        self.episode_return = 0
        self.possible_agents = [f"agent_{i}" for i in range(self.n_agents)]
        self.n_steps = 0
        # `agent_view_size` is looked up on this wrapper; if it is not defined
        # here the lookup falls through `__getattr__` to the wrapped env.
        view_size = self.agent_view_size
        self.observation_spaces = {
            agent: spaces.Box(
                low=0,
                high=255,
                shape=(self.n_agents, 3, view_size, view_size),
                dtype="uint8",
            )
            for agent in self.possible_agents
        }
        self.action_spaces = {
            agent: spaces.Discrete(5) for agent in self.possible_agents
        }

    def __getattr__(self, name):
        """
        Delegate attribute lookup to the wrapped environment if the attribute
        is not found in the LasertagParallelWrapper instance.

        Raises:
            AttributeError: If `env` itself is requested but has not been set.
        """
        # `__getattr__` only runs after normal lookup has failed. If the
        # missing attribute is `env` itself (e.g. on an instance that was
        # created without running `__init__`, as copy/pickle do), evaluating
        # `self.env` below would re-enter this method forever.
        if name == "env":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.env, name)

    def _np_array_to_pz_dict(self, array: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Returns a dictionary containing individual observations for each agent.
        Assumes that the batch dimension represents individual agents.
        """
        return {
            self.possible_agents[idx]: value for idx, value in enumerate(array)
        }

    def _singleton_to_pz_dict(self, value: bool) -> Dict[str, bool]:
        """
        Broadcasts a single value (e.g. the `done` and `trunc` flags) to a
        dictionary keyed by agent id.
        """
        # Use the same keys as `_np_array_to_pz_dict` so that observations,
        # rewards and flags all share one set of agent ids.
        return {agent: value for agent in self.possible_agents}

    def reset(self, env_task: int) -> Dict[AgentID, ObsType]:
        """
        Resets the environment and returns a dictionary of observations
        keyed by agent ID.

        Args:
            env_task: Value passed to the wrapped env's `seed` before the
                level is regenerated.

        Returns:
            The per-agent observation dict only (no info dict is returned).
        """
        self.env.seed(env_task)
        obs = self.env.reset_random()  # random level generation
        return self._np_array_to_pz_dict(obs["image"])

    def step(
        self, action: Dict[AgentID, ActionType], device: str, agent_task: int
    ) -> Tuple[
        Dict[AgentID, ObsType],
        Dict[AgentID, float],
        Dict[AgentID, bool],
        Dict[AgentID, bool],
        Dict[AgentID, dict],
    ]:
        """
        Takes inputs in the PettingZoo (PZ) Parallel API format, performs a step and
        returns outputs in PZ format.

        Args:
            action: Per-agent actions; stacked with `batchify` before being
                handed to the wrapped env.
            device: Torch device passed to `batchify`.
            agent_task: Stored in the returned info dict under "agent_id".

        Returns:
            `(obs, rew, done, trunc, info)`, each a dict keyed by agent id.
            `info` additionally carries the "agent_id" entry.
        """
        action = batchify(action, device)
        obs, rew, done, info = self.env.step(action)
        obs = obs["image"]
        trunc = False  # there is no `truncated` flag in this environment
        self.task_completion = self._task_completion(obs, rew, done, trunc, info)
        # convert outputs back to PZ format
        obs, rew = map(self._np_array_to_pz_dict, [obs, rew])
        # the env's own `info` is replaced by the broadcast task completion
        done, trunc, info = map(
            self._singleton_to_pz_dict, [done, trunc, self.task_completion]
        )
        # NOTE(review): "agent_id" sits next to the per-agent keys in `info`;
        # confirm that consumers expect this extra, non-agent entry.
        info["agent_id"] = agent_task
        self.n_steps += 1

        return self.observation(obs), rew, done, trunc, info

In [38]:
# Build the raw Lasertag env (2 agents by default, no video recording) and
# wrap it in the PettingZoo-parallel adapter in one expression.
env = LasertagParallelWrapper(
    env=LasertagAdversarial(record_video=False),
    n_agents=2,
)

In [39]:
# `num_envs // 2` copies of the vector env are requested below
# (i.e. a single copy when `num_envs` is 2).
num_envs = 2

# Convert the PettingZoo parallel env into a vector env, then concatenate
# copies of it without worker processes (`num_cpus=0`).
env = ss.pettingzoo_env_to_vec_env_v1(env)
# NOTE(review): the recorded output of this cell is
# `TypeError: cannot pickle 'python_griddly.GDY' object`. Presumably the
# concatenation step copies `env` by pickling it and the underlying Griddly
# object is not picklable -- confirm against the supersuit implementation.
envs = ss.concat_vec_envs_v1(
    env, num_envs // 2, num_cpus=0, base_class="gymnasium"
)



TypeError: cannot pickle 'python_griddly.GDY' object