In [1]:
import argparse
import os
import sys
import time
from typing import TypeVar

import joblib
import numpy as np
import plotly
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from gymnasium import spaces
from plotly.subplots import make_subplots
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from lasertag_dr import LasertagParallelWrapper
sys.path.append("../../..")
from lasertag import LasertagAdversarial  # noqa: E402
from syllabus.core import (  # noqa: E402
    DualCurriculumWrapper,
    TaskWrapper,
    make_multiprocessing_curriculum,
)

# noqa: E402
from syllabus.curricula import (  # noqa: E402
    DomainRandomization,
    PrioritizedFictitiousSelfPlay,
)
from syllabus.task_space import TaskSpace  # noqa: E402

# Generic type placeholders for annotating multi-agent (PettingZoo-style) code.
# Agent-side placeholders.
AgentID = TypeVar("AgentID")
AgentType = TypeVar("AgentType")
AgentTask = TypeVar("AgentTask")
# Environment-side placeholders.
ObsType = TypeVar("ObsType")
ActionType = TypeVar("ActionType")
EnvTask = TypeVar("EnvTask")


def batchify(x, device):
    """Stack a PettingZoo-style per-agent dict into a single torch tensor.

    Args:
        x: dict mapping agent id -> array-like (or scalar) value. Values are
            stacked along a new leading axis in the dict's iteration order.
        device: torch device (or device string) the result is moved to.

    Returns:
        torch.Tensor whose first dimension indexes agents, located on ``device``.
    """
    per_agent = [x[agent_id] for agent_id in x]
    return torch.tensor(np.stack(per_agent)).to(device)


def unbatchify(x, possible_agents: np.ndarray):
    """Split a batched torch tensor back into a PettingZoo-style per-agent dict.

    Args:
        x: torch tensor whose first dimension is indexed in the same order as
            ``possible_agents``; it is moved to CPU and converted to numpy.
        possible_agents: ordered agent ids (a list works as well).

    Returns:
        dict mapping each agent id to its slice ``x[i]`` as a numpy value.
    """
    as_numpy = x.cpu().numpy()
    per_agent = {}
    for position, agent_id in enumerate(possible_agents):
        per_agent[agent_id] = as_numpy[position]
    return per_agent


class Agent(nn.Module):
    """Actor-critic MLP used for the learner and for its self-play opponents.

    One hidden layer (``network``) is shared by a categorical policy head
    (``actor``) and a scalar value head (``critic``). Inputs are flattened and
    divided by 255 before entering the shared layer.

    Args:
        num_actions: number of discrete actions (policy logits).
        obs_size: number of elements in one flattened observation. Defaults to
            ``3 * 5 * 5``, the previously hard-coded size, which equals
            ``stack_size * frame_size[0] * frame_size[1]`` in this notebook's config.
        hidden_size: width of the shared hidden layer. Defaults to 512.
    """

    def __init__(self, num_actions, obs_size=3 * 5 * 5, hidden_size=512):
        super().__init__()

        self.network = nn.Sequential(
            self._layer_init(nn.Linear(obs_size, hidden_size)),
            nn.ReLU(),
        )
        # Small init std keeps the initial policy close to uniform.
        self.actor = self._layer_init(nn.Linear(hidden_size, num_actions), std=0.01)
        self.critic = self._layer_init(nn.Linear(hidden_size, 1))

    def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0):
        """Orthogonally initialise ``layer.weight`` with gain ``std``; set bias to a constant."""
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    def get_value(self, x, flatten_start_dim=1):
        """Return the critic estimate for ``x``.

        ``x`` is flattened from ``flatten_start_dim`` onward (use 1 for a batch,
        0 for a single observation). The result has a trailing dimension of size 1.
        """
        x = torch.flatten(x, start_dim=flatten_start_dim)
        return self.critic(self.network(x / 255.0))

    def get_action_and_value(self, x, action=None, flatten_start_dim=1):
        """Sample (or score) an action and return ``(action, log_prob, entropy, value)``.

        Args:
            x: observation(s), flattened from ``flatten_start_dim`` onward.
            action: if given, it is scored instead of sampling a new one
                (this is how PPO re-evaluates stored actions).
            flatten_start_dim: 1 for a batch of observations, 0 for a single one.
        """
        x = torch.flatten(x, start_dim=flatten_start_dim)
        hidden = self.network(x / 255.0)
        logits = self.actor(hidden)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)


"""ALGO PARAMS"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ent_coef = 0.0  # entropy-bonus weight (0 disables the bonus)
vf_coef = 0.5  # value-loss weight in the PPO objective
clip_coef = 0.2  # PPO ratio clip range; also clips the value update
learning_rate = 1e-4
epsilon = 1e-5  # Adam eps
gamma = 0.995  # discount factor
gae_lambda = 0.95  # GAE smoothing parameter
epochs = 5  # optimisation passes over each episode's data
batch_size = 32  # minibatch size
stack_size = 3
frame_size = (5, 5)
max_cycles = 201  # lasertag has 200 maximum steps by default
total_episodes = 500
n_agents = 2
num_actions = 5
fsp_update_frequency = 50  # intended period (in episodes) for refreshing the opponent pool

# Seed the RNGs used by training so runs are reproducible: np.random drives
# minibatch shuffling, torch drives weight init and action sampling.
# Checkpoint filenames are tagged "_seed_0", so seed 0 is used here.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

save_agent_checkpoints = True
# Integer division: a frequency measured in episodes should be an int.
checkpoint_frequency = total_episodes // 10
logging_dir = "./pfsp_checkpoints"

# LEARNER SETUP
agent = Agent(num_actions=num_actions).to(device)
optimizer = optim.Adam(agent.parameters(), lr=learning_rate, eps=epsilon)

In [2]:
""" ENV SETUP """

# Two-player Lasertag (2 agents by default), exposed through the parallel wrapper.
env = LasertagParallelWrapper(
    env=LasertagAdversarial(record_video=False), n_agents=n_agents
)

# Environment tasks: domain randomization over 200 discrete task ids.
env_curriculum = DomainRandomization(TaskSpace(spaces.Discrete(200)))
# Opponent tasks: prioritized fictitious self-play, storing agents under "pfsp_agents".
agent_curriculum = PrioritizedFictitiousSelfPlay(
    agent=agent, device=device, storage_path="pfsp_agents", max_agents=10
)
# Combine both curricula and make them usable from worker processes.
dual_curriculum = DualCurriculumWrapper(
    env=env, agent_curriculum=agent_curriculum, env_curriculum=env_curriculum
)
mp_curriculum = make_multiprocessing_curriculum(dual_curriculum)

# ALGO LOGIC: EPISODE STORAGE
# Rollout buffers indexed [step, agent, ...], allocated directly on `device`.
end_step = 0
total_episodic_return = 0
step_agent_shape = (max_cycles, n_agents)
rb_obs = torch.zeros((*step_agent_shape, stack_size, *frame_size), device=device)
rb_actions = torch.zeros(step_agent_shape, device=device)
rb_logprobs = torch.zeros(step_agent_shape, device=device)
rb_rewards = torch.zeros(step_agent_shape, device=device)
rb_terms = torch.zeros(step_agent_shape, device=device)
rb_values = torch.zeros(step_agent_shape, device=device)

# Run-level bookkeeping.
agent_tasks, env_tasks = [], []
agent_c_rew, opp_c_rew = 0, 0
n_ends, n_learner_wins = 0, 0
info = {}

In [3]:
from dataclasses import dataclass

In [6]:
@dataclass
class args:
    """Run configuration, read by the training loop as class attributes.

    The fields carry type annotations so that ``@dataclass`` actually registers
    them. Without annotations the decorator generated no fields at all. Access
    through ``args.<name>`` is unchanged.
    """

    total_episodes: int = 500  # number of training episodes
    agent_curriculum: str = "PFSP"  # "FSP"/"PFSP" enable opponent win-rate/pool updates
    logging_dir: str = "."  # checkpoints are written under <logging_dir>/test_checkpoints


def compute_gae(rewards, values, terms, next_value, gamma, gae_lambda):
    """Generalized Advantage Estimation over one single-agent trajectory.

    Args:
        rewards: (T,) tensor; ``rewards[t]`` was received after acting at step t.
        values: (T,) critic estimates recorded for the observation at each step.
        terms: (T,) termination flags; ``terms[t]`` is 1 if step t ended the
            episode (this is how the rollout below fills ``rb_terms``).
        next_value: scalar critic estimate for the observation after the last
            step, used to bootstrap when the trajectory ended without terminating.
        gamma: discount factor.
        gae_lambda: GAE smoothing parameter.

    Returns:
        ``(advantages, returns)``, both (T,) tensors, with ``returns = advantages + values``.
    """
    advantages = torch.zeros_like(rewards)
    last_gae_lam = 0.0
    num_steps = rewards.shape[0]
    for t in reversed(range(num_steps)):
        # terms[t] records whether step t ended the episode, so it masks the
        # bootstrap from the state that follows step t.
        next_non_terminal = 1.0 - terms[t]
        next_values = next_value if t == num_steps - 1 else values[t + 1]
        delta = rewards[t] + gamma * next_values * next_non_terminal - values[t]
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[t] = last_gae_lam
    return advantages, advantages + values


# Column of the learner ("agent_0") in the joint [step, agent] rollout buffers.
# Column 1 holds the frozen opponent. Its actions come from a different policy,
# so they are not valid on-policy PPO data for the learner.
LEARNER_IDX = 0
checkpoint_dir = f"{args.logging_dir}/test_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# train for args.total_episodes episodes
for episode in tqdm(range(args.total_episodes)):
    # collect an episode
    with torch.no_grad():
        env_task, agent_task = mp_curriculum.sample()

        env_tasks.append(env_task)
        agent_tasks.append(agent_task)

        # Fetch the sampled opponent once per episode. It used to be looked up
        # every step from the previous step's `info`, which also meant the first
        # step of an episode played against the previous episode's opponent.
        opponent = mp_curriculum.get_opponent(agent_task).to(device)

        next_obs = env.reset(env_task)
        # reset the episodic return
        total_episodic_return = 0

        for step in range(max_cycles):
            # Observation the actions below are conditioned on. PPO must be
            # trained on this, not on the observation returned by env.step.
            obs = batchify(next_obs, device)
            agent_obs, opponent_obs = obs.squeeze()

            # get action from the agent and the opponent
            actions, logprobs, _, values = agent.get_action_and_value(
                agent_obs, flatten_start_dim=0
            )
            opponent_action, *_ = opponent.get_action_and_value(
                opponent_obs, flatten_start_dim=0
            )

            # execute the environment and log data
            joint_actions = torch.stack((actions, opponent_action))
            next_obs, rewards, terms, truncs, info = env.step(
                unbatchify(joint_actions, env.possible_agents), device, agent_task
            )

            opp_reward = rewards["agent_1"]
            if opp_reward != 0:
                n_ends += 1
                if args.agent_curriculum in ["FSP", "PFSP"]:
                    mp_curriculum.update_winrate(info["agent_id"], opp_reward)
                if opp_reward == -1:
                    n_learner_wins += 1

            # add to episode storage
            rb_obs[step] = obs
            rb_rewards[step] = batchify(rewards, device)
            rb_terms[step] = batchify(terms, device)
            rb_actions[step] = joint_actions
            rb_logprobs[step] = logprobs
            rb_values[step] = values.flatten()

            # compute episodic return
            total_episodic_return += rb_rewards[step].cpu().numpy()

            # store learner checkpoints
            if save_agent_checkpoints and env.n_steps % 2000 == 0:
                print(f"saving checkpoint --{env.n_steps}")
                joblib.dump(
                    agent,
                    filename=(
                        f"{checkpoint_dir}/"
                        f"{mp_curriculum.curriculum.env_curriculum.name}_"
                        f"{mp_curriculum.curriculum.agent_curriculum.name}_{env.n_steps}"
                        "_seed_0.pkl"
                    ),
                )

            # Track the last filled index every step so it is never stale, even
            # if the loop runs to max_cycles without terminating.
            end_step = step
            if any(terms.values()) or any(truncs.values()):
                break

    # Transitions collected this episode, including the final one (and its
    # reward), which the previous slicing dropped.
    num_steps = end_step + 1

    with torch.no_grad():
        next_value = agent.get_value(
            torch.as_tensor(next_obs["agent_0"], device=device), flatten_start_dim=0
        ).squeeze()
        b_advantages, b_returns = compute_gae(
            rb_rewards[:num_steps, LEARNER_IDX],
            rb_values[:num_steps, LEARNER_IDX],
            rb_terms[:num_steps, LEARNER_IDX],
            next_value,
            gamma,
            gae_lambda,
        )

    # Learner-only batch of individual transitions.
    b_obs = rb_obs[:num_steps, LEARNER_IDX]
    b_logprobs = rb_logprobs[:num_steps, LEARNER_IDX]
    b_actions = rb_actions[:num_steps, LEARNER_IDX].long()
    b_values = rb_values[:num_steps, LEARNER_IDX]

    # Optimizing the policy and value network
    b_index = np.arange(num_steps)
    clip_fracs = []
    for repeat in range(epochs):
        # shuffle the indices we use to access the data
        np.random.shuffle(b_index)
        for start in range(0, num_steps, batch_size):
            batch_index = b_index[start : start + batch_size]

            _, newlogprob, entropy, value = agent.get_action_and_value(
                b_obs[batch_index], b_actions[batch_index]
            )
            logratio = newlogprob - b_logprobs[batch_index]
            ratio = logratio.exp()

            with torch.no_grad():
                # calculate approx_kl http://joschu.net/blog/kl-approx.html
                old_approx_kl = (-logratio).mean()
                approx_kl = ((ratio - 1) - logratio).mean()
                clip_fracs.append(
                    ((ratio - 1.0).abs() > clip_coef).float().mean().item()
                )

            # Normalize advantages per minibatch. The old code computed this
            # but then used the raw advantages. A 1-sample minibatch is left
            # as-is because its std() is NaN and would poison the update.
            mb_advantages = b_advantages[batch_index]
            if mb_advantages.numel() > 1:
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                    mb_advantages.std() + 1e-8
                )

            # Policy loss
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(
                ratio, 1 - clip_coef, 1 + clip_coef
            )
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            # Value loss (clipped around the rollout-time value estimate)
            value = value.flatten()
            v_loss_unclipped = (value - b_returns[batch_index]) ** 2
            v_clipped = b_values[batch_index] + torch.clamp(
                value - b_values[batch_index],
                -clip_coef,
                clip_coef,
            )
            v_loss_clipped = (v_clipped - b_returns[batch_index]) ** 2
            v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()

            entropy_loss = entropy.mean()
            loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
    var_y = np.var(y_true)
    explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

    # Add the current learner to the opponent pool. The old hard-coded
    # `% 500` never fired within a 500-episode run.
    if (
        args.agent_curriculum in ["FSP", "PFSP"]
        and episode > 0
        and episode % fsp_update_frequency == 0
    ):
        mp_curriculum.update_agent(agent)

  0%|          | 0/500 [00:00<?, ?it/s]

saving checkpoint --6000
saving checkpoint --8000
saving checkpoint --10000
saving checkpoint --12000
saving checkpoint --14000
saving checkpoint --16000
saving checkpoint --18000
saving checkpoint --20000
saving checkpoint --22000
saving checkpoint --24000
saving checkpoint --26000
saving checkpoint --28000


In [11]:
# joblib.load unpickles the file, which can run arbitrary code; only load
# checkpoints this notebook produced.
pfsp_checkpoint_path = "test_checkpoints/DR_PFSP_6000_seed_0.pkl"
joblib.load(pfsp_checkpoint_path)

Agent(
  (network): Sequential(
    (0): Linear(in_features=75, out_features=512, bias=True)
    (1): ReLU()
  )
  (actor): Linear(in_features=512, out_features=5, bias=True)
  (critic): Linear(in_features=512, out_features=1, bias=True)
)

In [10]:
# joblib.load unpickles the file, which can run arbitrary code; only load
# trusted checkpoints.
sp_checkpoint_path = "lasertag_DR_SP_checkpoints/DR_SP_2000_seed_0.pkl"
joblib.load(sp_checkpoint_path)

Agent(
  (network): Sequential(
    (0): Linear(in_features=75, out_features=512, bias=True)
    (1): ReLU()
  )
  (actor): Linear(in_features=512, out_features=5, bias=True)
  (critic): Linear(in_features=512, out_features=1, bias=True)
)