## PPO design notes

### Actor/critic networks design
### Env
- Convert the raw game observation into the proxy observation the networks consume (maps, state parameters, unit parameters).

- Compute the per-step reward returned to the learner.

In [1]:
import json
from IPython.display import display, Javascript
from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
import os
from my_agent.lux.utils import direction_to, direction_to_change
import matplotlib.pyplot as plt
import numpy as np
import random
from maps import EnergyMap, RelicMap, TileMap
from astar import *
import gymnasium as gym
from gymnasium.spaces import MultiDiscrete, Discrete, Tuple
from agent import Agent
from datetime import datetime

In [2]:
from dataclasses import dataclass


@dataclass
class Args:
    """Hyperparameters and run settings consumed by ``train``.

    Declared as a dataclass so that every field is a real instance attribute:
    ``vars(args)`` then lists all hyperparameters (not only the ones assigned
    at runtime) and individual values can be overridden with keyword
    arguments, e.g. ``Args(seed=3, num_envs=4)``.
    """

    exp_name: str = ""
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """value assigned to `torch.backends.cudnn.deterministic`"""
    cuda: bool = False
    """if True, train on CUDA when it is available; otherwise use the CPU"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Algorithm specific arguments
    env_id: str = "CartPole-v1"
    """the id of the environment"""
    total_timesteps: int = 500000
    """total timesteps of the experiments"""
    learning_rate_actor: float = 2.5e-4
    """the learning rate of the actor optimizer"""
    learning_rate_critic: float = 2.5e-3
    """the learning rate of the critic optimizer"""
    num_envs: int = 1
    """the number of parallel game environments"""
    num_steps: int = 504
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    minibatch_size: int = 32
    """the number of samples in each mini-batch"""
    update_epochs: int = 80
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.3
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.01
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    target_kl: float = 0.01
    """the target KL divergence threshold"""

    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""

In [3]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise a layer in place and return it.

    The weight is filled with an orthogonal matrix scaled by ``std`` and the
    bias, when the layer has one, is filled with ``bias_const``.

    Args:
        layer: module exposing a ``weight`` tensor and optionally a ``bias``.
        std: gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: constant written into every bias element.

    Returns:
        The same ``layer`` object, so the call can be nested inside
        ``nn.Sequential(...)``.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with ``bias=False`` expose ``bias`` as None; skip them
    # instead of failing inside ``constant_``.
    if getattr(layer, "bias", None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class Critic(nn.Module):
    """State-value network that scores each unit and sums the scores.

    Only the unit-parameter part of the observation is used: a small MLP maps
    every unit's feature vector to a scalar, and the per-unit scalars are
    summed into one value per batch element.
    """

    def __init__(self, env):
        """Read the space sizes from a vectorised ``env`` and build the MLP.

        Args:
            env: vector environment exposing ``observation_space``,
                ``single_observation_space``, ``single_action_space`` and
                ``get_attr``.
        """
        super().__init__()
        self.n_ens = env.observation_space[1].shape[0]
        self.n_maps = len(env.single_observation_space[0])
        self.n_state_params = env.single_observation_space[1].shape[0]
        self.n_action = env.single_action_space.shape[0]
        self.action_dim = env.single_action_space.nvec[-1,-1]
        self.n_unit_states = env.single_observation_space[2].shape[1]
        self.transformer_embedding_dim = env.get_attr("transformer_embedding_dim")[0]
        self.state_param_embedding_dim = env.get_attr("state_param_embedding_dim")[0]

        self.critic_net = nn.Sequential(
            layer_init(nn.Linear(self.n_unit_states, 32)),
            nn.ReLU(),
            layer_init(nn.Linear(32, 1)),
        )

    def get_value(self, x):
        """Return one value estimate per batch element.

        Args:
            x: ``(maps, state_params, unit_params)``; only ``unit_params``
                (B x N x n_unit_states) is read.

        Returns:
            Tensor of shape ``(B,)``.
        """
        _maps, _state_params, unit_params = x
        # Drop only the trailing feature axis. A bare ``squeeze()`` would also
        # remove the batch axis when B == 1 and turn the result into a 0-d
        # tensor.
        per_unit_value = self.critic_net(unit_params).squeeze(-1)  # B x N
        return torch.sum(per_unit_value, dim=-1)

        
# TODO network design
class Actor(nn.Module):
    """Policy network producing a 5-component discrete action for every unit.

    The map stack is embedded per tile and passed through a transformer
    encoder layer; the unit features are embedded and attend to the encoded
    map through a transformer decoder layer. The decoder output of each unit
    is concatenated with an embedding of the global state parameters and
    projected to the logits of one 2-way "move type" head and four
    ``action_dim``-way "target" heads.
    """

    def __init__(self, env):
        """Read the space sizes from a vectorised ``env`` and build the layers.

        Args:
            env: vector environment exposing ``observation_space``,
                ``single_observation_space``, ``single_action_space`` and
                ``get_attr``.
        """
        super().__init__()
        self.n_ens = env.observation_space[1].shape[0]
        self.n_maps = len(env.single_observation_space[0])
        self.n_state_params = env.single_observation_space[1].shape[0]
        self.n_action = env.single_action_space.shape[0]
        # Plain int so it can be used directly as a layer size.
        self.action_dim = int(env.single_action_space.nvec[-1, -1])
        self.n_unit_states = env.single_observation_space[2].shape[1]
        self.transformer_embedding_dim = env.get_attr("transformer_embedding_dim")[0]
        self.state_param_embedding_dim = env.get_attr("state_param_embedding_dim")[0]

        self.state_params_to_hidden = nn.Sequential(
            layer_init(nn.Linear(self.n_state_params, 32)),
            nn.ReLU(),
            layer_init(nn.Linear(32, self.state_param_embedding_dim)),
            nn.ReLU(),
        )

        self.embedding_maps = nn.Sequential(
            layer_init(nn.Linear(self.n_maps, 16)),
            layer_init(nn.Linear(16, self.transformer_embedding_dim)),
        )

        self.embedding_unit_params = nn.Sequential(
            layer_init(nn.Linear(self.n_unit_states, 16)),
            layer_init(nn.Linear(16, self.transformer_embedding_dim)),
        )

        self.actor_encoder = torch.nn.TransformerEncoderLayer(self.transformer_embedding_dim, 4, 64, batch_first=True)
        self.actor_decoder = torch.nn.TransformerDecoderLayer(self.transformer_embedding_dim, 4, 64, batch_first=True)

        # 2 move-type logits followed by 4 target heads of ``action_dim``
        # logits each; the width is derived from ``action_dim`` so it always
        # agrees with the reshape in ``get_action``.
        # NOTE(review): the trailing ReLU forces every logit to be >= 0, which
        # limits how peaked the policy can become -- confirm this is intended.
        self.out_to_logits = nn.Sequential(
            layer_init(nn.Linear(self.transformer_embedding_dim + self.state_param_embedding_dim, 32)),
            nn.ReLU(),
            layer_init(nn.Linear(32, 2 + 4 * self.action_dim)),
            nn.ReLU(),
        )

    def get_action(self, x, action=None):
        """Sample an action (or score a given one) for every unit.

        Args:
            x: ``(maps, state_params, unit_params)`` with shapes
                B x n_maps x H x W, B x n_state_params and
                B x N x n_unit_states.
            action: optional integer tensor B x N x 5 to evaluate; when
                ``None`` a new action is sampled.

        Returns:
            ``(action, log_probs, entropy)`` where ``action`` and
            ``log_probs`` are B x N x 5 and ``entropy`` is B x N x 4 (the
            move-type entropy broadcast-added to each target-head entropy).
        """
        maps, state_params, unit_params = x
        maps = torch.flatten(maps, start_dim=-2).permute(0, 2, 1)  # B x (H*W) x n_maps
        batch_size, n_units = unit_params.shape[0], unit_params.shape[1]  # B, N

        encoder_out = self.actor_encoder(self.embedding_maps(maps))  # B x (H*W) x E
        decoder_out = self.actor_decoder(self.embedding_unit_params(unit_params), encoder_out)  # B x N x E

        state_params_hidden = self.state_params_to_hidden(state_params)  # B x S
        # Give every unit the same global-state embedding.
        state_per_unit = state_params_hidden.unsqueeze(1).expand(-1, n_units, -1)  # B x N x S
        all_logits = self.out_to_logits(torch.cat((decoder_out, state_per_unit), dim=-1))  # B x N x (2 + 4*A)

        move_type_logits = all_logits[:, :, :2].reshape(batch_size, n_units, 1, 2)  # B x N x 1 x 2
        target_logits = all_logits[:, :, 2:].reshape(batch_size, n_units, 4, self.action_dim)  # B x N x 4 x A
        move_type_dist = Categorical(logits=move_type_logits)
        # ``logits=`` must be explicit: the first positional parameter of
        # Categorical is ``probs``, which would treat the logits as
        # unnormalised probabilities and fail on an all-zero row.
        target_dist = Categorical(logits=target_logits)

        if action is None:
            action_type = move_type_dist.sample()
            action_target = target_dist.sample()
            action = torch.cat((action_type, action_target), dim=-1)  # B x N x 5
        else:
            action_type = action[:, :, 0].unsqueeze(dim=-1)
            action_target = action[:, :, 1:]
        log_probs = torch.cat(
            (move_type_dist.log_prob(action_type), target_dist.log_prob(action_target)), dim=-1
        )  # B x N x 5
        return action, log_probs, move_type_dist.entropy() + target_dist.entropy()

In [4]:
class ActorCritic(nn.Module):
    """Combined policy and value network sharing the observation embeddings.

    The policy path matches ``Actor``: map tiles go through a transformer
    encoder layer, unit features attend to them through a decoder layer, and
    the result (concatenated with a global-state embedding) is projected to
    one 2-way move-type head and four ``action_dim``-way target heads. The
    value path currently in use (``critic``) scores each unit with an MLP over
    its raw features and sums the scores; ``critic_old`` is the earlier
    transformer-based value head, kept for reference.
    """

    def __init__(self, env):
        """Read the space sizes from a vectorised ``env`` and build the layers.

        Args:
            env: vector environment exposing ``observation_space``,
                ``single_observation_space``, ``single_action_space`` and
                ``get_attr``.
        """
        super().__init__()
        self.n_ens = env.observation_space[1].shape[0]
        self.n_maps = len(env.single_observation_space[0])
        self.n_state_params = env.single_observation_space[1].shape[0]
        self.n_action = env.single_action_space.shape[0]
        # Plain int so it can be used directly as a layer size.
        self.action_dim = int(env.single_action_space.nvec[-1, -1])
        self.n_unit_states = env.single_observation_space[2].shape[1]
        self.transformer_embedding_dim = env.get_attr("transformer_embedding_dim")[0]
        self.state_param_embedding_dim = env.get_attr("state_param_embedding_dim")[0]

        self.state_params_to_hidden = nn.Sequential(
            layer_init(nn.Linear(self.n_state_params, 32)),
            nn.ReLU(),
            layer_init(nn.Linear(32, self.state_param_embedding_dim)),
            nn.ReLU(),
        )

        self.embedding_maps = nn.Sequential(
            layer_init(nn.Linear(self.n_maps, 16)),
            layer_init(nn.Linear(16, self.transformer_embedding_dim)),
        )

        self.embedding_unit_params = nn.Sequential(
            layer_init(nn.Linear(self.n_unit_states, 16)),
            layer_init(nn.Linear(16, self.transformer_embedding_dim)),
        )

        self.actor_encoder = torch.nn.TransformerEncoderLayer(self.transformer_embedding_dim, 4, 64, batch_first=True)
        self.actor_decoder = torch.nn.TransformerDecoderLayer(self.transformer_embedding_dim, 4, 64, batch_first=True)

        self.critic_encoder = torch.nn.TransformerEncoderLayer(self.transformer_embedding_dim, 4, 64, batch_first=True)
        self.critic_decoder = torch.nn.TransformerDecoderLayer(self.transformer_embedding_dim, 4, 64, batch_first=True)
        # 2 move-type logits followed by 4 target heads of ``action_dim``
        # logits each; the width is derived from ``action_dim`` so it always
        # agrees with the reshape in ``get_action_and_value``.
        # NOTE(review): the trailing ReLU forces every logit to be >= 0, which
        # limits how peaked the policy can become -- confirm this is intended.
        self.out_to_logits = nn.Sequential(
            layer_init(nn.Linear(self.transformer_embedding_dim + self.state_param_embedding_dim, 32)),
            nn.ReLU(),
            layer_init(nn.Linear(32, 2 + 4 * self.action_dim)),
            nn.ReLU(),
        )

        # Not referenced by any method of this class; kept so existing
        # checkpoints still load.
        self.encoder_out_to_critic = nn.Sequential(
            layer_init(nn.Linear(8, 1)),
            nn.ReLU(),
        )

        self.critic_out_old = nn.Sequential(
            layer_init(nn.Linear(self.transformer_embedding_dim, 32)),
            nn.ReLU(),
            layer_init(nn.Linear(32, 1)),
        )
        self.critic_out = nn.Sequential(
            layer_init(nn.Linear(self.n_unit_states, 32)),
            nn.ReLU(),
            layer_init(nn.Linear(32, 1)),
        )

    def critic_old(self, maps, unit_params):
        """Earlier value head: transformer over maps and units, summed per unit.

        Args:
            maps: tile sequence, B x (H*W) x n_maps.
            unit_params: B x N x n_unit_states.

        Returns:
            Tensor of shape ``(B,)``.
        """
        encoder_out = self.critic_encoder(self.embedding_maps(maps))  # B x (H*W) x E
        decoder_out = self.critic_decoder(self.embedding_unit_params(unit_params), encoder_out)  # B x N x E
        # squeeze(-1) rather than squeeze(): keep the batch axis when B == 1.
        return torch.sum(self.critic_out_old(decoder_out).squeeze(-1), dim=-1)

    def critic(self, maps, unit_params):
        """Value head in use: per-unit MLP over raw unit features, summed.

        Args:
            maps: accepted for signature parity with ``critic_old``; unused.
            unit_params: B x N x n_unit_states.

        Returns:
            Tensor of shape ``(B,)``.
        """
        # squeeze(-1) rather than squeeze(): keep the batch axis when B == 1.
        return torch.sum(self.critic_out(unit_params).squeeze(-1), dim=-1)

    def get_value(self, x):
        """Return the value estimate, shape ``(B,)``, for observation ``x``.

        Args:
            x: ``(maps, state_params, unit_params)``.
        """
        maps, state_params, unit_params = x
        maps = torch.flatten(maps, start_dim=-2).permute(0, 2, 1)  # B x (H*W) x n_maps
        return self.critic(maps, unit_params)

    def get_action_and_value(self, x, action=None):
        """Sample (or score) an action per unit and estimate the state value.

        Args:
            x: ``(maps, state_params, unit_params)`` with shapes
                B x n_maps x H x W, B x n_state_params and
                B x N x n_unit_states.
            action: optional integer tensor B x N x 5 to evaluate; when
                ``None`` a new action is sampled.

        Returns:
            ``(action, log_probs, entropy, value)`` where ``action`` and
            ``log_probs`` are B x N x 5, ``entropy`` is B x N x 4 and
            ``value`` is ``(B,)``.
        """
        maps, state_params, unit_params = x
        maps = torch.flatten(maps, start_dim=-2).permute(0, 2, 1)  # B x (H*W) x n_maps
        batch_size, n_units = unit_params.shape[0], unit_params.shape[1]  # B, N

        encoder_out = self.actor_encoder(self.embedding_maps(maps))  # B x (H*W) x E
        decoder_out = self.actor_decoder(self.embedding_unit_params(unit_params), encoder_out)  # B x N x E

        state_params_hidden = self.state_params_to_hidden(state_params)  # B x S
        # Give every unit the same global-state embedding.
        state_per_unit = state_params_hidden.unsqueeze(1).expand(-1, n_units, -1)  # B x N x S
        all_logits = self.out_to_logits(torch.cat((decoder_out, state_per_unit), dim=-1))  # B x N x (2 + 4*A)

        move_type_logits = all_logits[:, :, :2].reshape(batch_size, n_units, 1, 2)  # B x N x 1 x 2
        target_logits = all_logits[:, :, 2:].reshape(batch_size, n_units, 4, self.action_dim)  # B x N x 4 x A
        move_type_dist = Categorical(logits=move_type_logits)
        # ``logits=`` must be explicit: the first positional parameter of
        # Categorical is ``probs``, which would treat the logits as
        # unnormalised probabilities and fail on an all-zero row.
        target_dist = Categorical(logits=target_logits)

        value = self.critic(maps, unit_params)

        if action is None:
            action_type = move_type_dist.sample()
            action_target = target_dist.sample()
            action = torch.cat((action_type, action_target), dim=-1)  # B x N x 5
        else:
            action_type = action[:, :, 0].unsqueeze(dim=-1)
            action_target = action[:, :, 1:]
        log_probs = torch.cat(
            (move_type_dist.log_prob(action_type), target_dist.log_prob(action_target)), dim=-1
        )  # B x N x 5
        return action, log_probs, move_type_dist.entropy() + target_dist.entropy(), value

In [5]:
def _collect_hyperparameters(args):
    """Return every hyperparameter on ``args`` as an ordered ``{name: value}`` dict.

    ``vars(args)`` alone only sees attributes assigned on the instance, so the
    annotated class-level defaults are gathered first and any extra instance
    attributes are appended afterwards.
    """
    names = list(getattr(type(args), "__annotations__", {}))
    names.extend(name for name in vars(args) if name not in names)
    return {name: getattr(args, name) for name in names if hasattr(args, name)}


def _obs_to_tensors(raw_obs, device):
    """Convert a batched ``(maps, state_params, unit_params)`` observation to float tensors.

    The vector env returns the maps as one array per map, each batched over
    environments, so the stacked array is (n_maps, num_envs, H, W); the first
    two axes are swapped to put the environment axis first.
    """
    return (
        torch.Tensor(np.array(raw_obs[0])).to(device).swapaxes(0, 1),
        torch.Tensor(np.array(raw_obs[1])).to(device),
        torch.Tensor(np.array(raw_obs[2])).to(device),
    )


def train(exp_name, args):
    """Run PPO with separate actor and critic networks and optimizers.

    Each iteration collects ``args.num_steps`` steps from every environment,
    computes GAE advantages and returns, then runs ``args.update_epochs``
    epochs of minibatch updates. Actor updates stop for the rest of the
    iteration once the approximate KL exceeds ``args.target_kl``; the critic
    is updated on every minibatch. Metrics go to TensorBoard under
    ``runs/<run_name>`` and checkpoints (a dict with ``"actor"`` and
    ``"critic"`` state dicts) are written to ``models/``.

    Args:
        exp_name: experiment name; stored on ``args.exp_name`` and used in the
            run name.
        args: hyperparameter object (see ``Args``). ``exp_name``,
            ``batch_size`` and ``num_iterations`` are overwritten here.
    """
    args.exp_name = exp_name
    args.batch_size = int(args.num_envs * args.num_steps)
    args.num_iterations = args.total_timesteps // args.batch_size
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = f"{args.exp_name}__{args.seed}__{timestamp}"
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s"
        % ("\n".join([f"|{key}|{value}|" for key, value in _collect_hyperparameters(args).items()])),
    )
    # Create the checkpoint directory up front so a missing folder cannot
    # abort the run at the first save.
    os.makedirs("models", exist_ok=True)

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [env_fn for i in range(args.num_envs)],
    )

    actor = Actor(envs).to(device)
    critic = Critic(envs).to(device)
    optimizer_actor = optim.Adam(actor.parameters(), lr=args.learning_rate_actor, eps=1e-5)
    optimizer_critic = optim.Adam(critic.parameters(), lr=args.learning_rate_critic, eps=1e-5)

    # ALGO Logic: Storage setup
    map_shape = np.array(envs.single_observation_space[0]).shape
    state_shape = envs.single_observation_space[1].shape
    unit_shape = np.array(envs.single_observation_space[2]).shape
    action_shape = envs.single_action_space.shape
    rollout_dims = (args.num_steps, args.num_envs)

    obs = (torch.zeros(rollout_dims + map_shape).to(device),
           torch.zeros(rollout_dims + state_shape).to(device),
           torch.zeros(rollout_dims + unit_shape).to(device))
    actions = torch.zeros(rollout_dims + action_shape).to(device)
    logprobs = torch.zeros(rollout_dims + action_shape).to(device)
    rewards = torch.zeros(rollout_dims).to(device)
    dones = torch.zeros(rollout_dims).to(device)
    values = torch.zeros(rollout_dims).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    save_thresh = 100000
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    # Same conversion as after every step. Reshaping the stacked maps instead
    # of swapping axes would mix data between environments when num_envs > 1.
    next_obs = _obs_to_tensors(next_obs, device)
    next_done = torch.zeros(args.num_envs).to(device)
    for iteration in range(1, args.num_iterations + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow_actor = frac * args.learning_rate_actor
            optimizer_actor.param_groups[0]["lr"] = lrnow_actor
            lrnow_critic = frac * args.learning_rate_critic
            optimizer_critic.param_groups[0]["lr"] = lrnow_critic

        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[0][step] = next_obs[0]
            obs[1][step] = next_obs[1]
            obs[2][step] = next_obs[2]
            dones[step] = next_done
            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _ = actor.get_action(next_obs)
                value = critic.get_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
            next_done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs = _obs_to_tensors(next_obs, device)
            next_done = torch.Tensor(next_done).to(device)
            if "final_info" in infos:
                for info in infos["final_info"]:
                    if info and "episode" in info:
                        print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                        writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                        writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
        writer.add_scalar("charts/reward", torch.sum(rewards), global_step)

        # bootstrap value if not done, then compute GAE backwards in time
        with torch.no_grad():
            next_value = critic.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # flatten the batch
        b_obs = (obs[0].reshape((-1,) + map_shape),
                 obs[1].reshape((-1,) + state_shape),
                 obs[2].reshape((-1,) + unit_shape))
        b_logprobs = logprobs.reshape((-1,) + action_shape)
        # Cast once here instead of once per minibatch.
        b_actions = actions.reshape((-1,) + action_shape).long()
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        approx_kl = 0
        actor_updates = 0
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]
                mb_obs = (b_obs[0][mb_inds], b_obs[1][mb_inds], b_obs[2][mb_inds])

                # Policy update: skipped for the rest of this iteration once
                # the approximate KL has passed the target (no target = never
                # skipped).
                if args.target_kl is None or approx_kl < args.target_kl:
                    actor_updates = epoch + 1
                    _, newlogprob, entropy = actor.get_action(mb_obs, b_actions[mb_inds])
                    logratio = newlogprob - b_logprobs[mb_inds]
                    ratio = logratio.exp()
                    entropy_loss = entropy.mean()

                    with torch.no_grad():
                        old_approx_kl = (-logratio).mean()
                        approx_kl = ((ratio - 1) - logratio).mean()
                        clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                    # One advantage per sample, broadcast over units and
                    # action components.
                    mb_advantages = b_advantages[mb_inds].view(-1, 1, 1)
                    if args.norm_adv:
                        mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                    pg_loss1 = -mb_advantages * ratio
                    pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                    pg_loss = torch.max(pg_loss1, pg_loss2).sum(dim=-1).sum(dim=-1).mean()

                    loss_actor = pg_loss - args.ent_coef * entropy_loss

                    optimizer_actor.zero_grad()
                    loss_actor.backward()
                    nn.utils.clip_grad_norm_(actor.parameters(), args.max_grad_norm)
                    optimizer_actor.step()

                # Value update
                newvalue = critic.get_value(mb_obs)
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()
                loss_critic = v_loss * args.vf_coef
                optimizer_critic.zero_grad()
                loss_critic.backward()
                nn.utils.clip_grad_norm_(critic.parameters(), args.max_grad_norm)
                optimizer_critic.step()

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate_actor", optimizer_actor.param_groups[0]["lr"], global_step)
        writer.add_scalar("charts/learning_rate_critic", optimizer_critic.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/epoch_to_kl", actor_updates, global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        #print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
        if global_step > save_thresh:
            torch.save({"actor": actor.state_dict(),
                        "critic": critic.state_dict()}, "models/" + run_name + f"_step_{global_step}")
            save_thresh += 100000

    envs.close()
    writer.close()
    # Final checkpoint, in the same format as the periodic ones.
    torch.save({"actor": actor.state_dict(),
                "critic": critic.state_dict()}, "models/" + run_name + f"_step_{global_step}")

In [6]:
def env_fn():
    """Build and return a new ``ProxyEnvironment`` instance."""
    environment = ProxyEnvironment()
    return environment

class ProxyEnvironment(gym.Env):
    """Gymnasium environment exposing one side of a Lux AI S3 match.

    ``player_0`` is driven from outside through proxy actions that a
    ``ProxyAgent`` translates into game actions; ``player_1`` is played by an
    ``Agent`` instance. Observations and rewards handed back to the caller are
    the proxy versions produced by the ``ProxyAgent``.
    """

    def __init__(self):
        """Declare the spaces, create the wrapped game and both agents."""
        # Sizes that other code can look up on the environment.
        self.n_maps = 6
        self.n_state_params = 3
        self.transformer_embedding_dim = 8
        self.state_param_embedding_dim = 8

        # Six 24x24 grids, each cell taking one of 24 values.
        self.map_space = Tuple(
            tuple(MultiDiscrete(np.full((24, 24), 24)) for _ in range(6))
        )
        # 16 rows of 7 features; the sixth feature starts at -10.
        unit_feature_sizes = np.array([24, 24, 24, 24, 401, 11, 2])
        unit_feature_starts = np.array([0, 0, 0, 0, 0, -10, 0])
        self.unit_param_space = MultiDiscrete(
            np.tile(unit_feature_sizes, (16, 1)),
            start=np.tile(unit_feature_starts, (16, 1)),
        )
        self.param_space = MultiDiscrete(np.array([505, 1000, 16*400]))
        self.observation_space = Tuple((self.map_space, self.param_space, self.unit_param_space))
        # 16 rows of 5 action components: one binary, four with 24 options.
        self.action_space = MultiDiscrete(np.tile(np.array([2, 24, 24, 24, 24]), (16, 1)))

        recorded_game = RecordEpisode(
            LuxAIS3GymEnv(numpy_output=True),
            save_on_close=False,
            save_on_reset=False,
            save_dir="replays",
        )
        self.env = recorded_game
        self.obs, game_info = self.env.reset()
        self.agent1 = ProxyAgent("player_0", game_info["params"])
        self.agent2 = Agent("player_1", game_info["params"])
        self.current_step = 0

    def close(self):
        """Close the wrapped game environment."""
        self.env.close()

    def reset(self, seed=42, options=0):
        """Start a new match and return ``(proxy_obs, info)``.

        Both agents are rebuilt from the parameters reported by the game.
        ``options`` is accepted but not used.
        NOTE(review): ``seed`` defaults to 42, so a call without a seed
        re-seeds the game with 42 -- confirm that is intended.
        """
        self.current_step = 0
        self.obs, game_info = self.env.reset(seed=seed)
        self.agent1 = ProxyAgent("player_0", game_info["params"])
        self.agent2 = Agent("player_1", game_info["params"])
        self.proxy_obs = self.agent1.get_init_proxy_obs(self.obs)
        return self.proxy_obs, game_info

    def step(self, proxy_action):
        """Advance the match by one step.

        Args:
            proxy_action: action for ``player_0`` in proxy form.

        Returns:
            ``(proxy_obs, proxy_reward, terminated, truncated, info)`` where
            the two flags are those reported for ``player_0``. The reward
            returned by the game itself is not forwarded.
        """
        self.current_step += 1
        joint_actions = {
            "player_0": self.agent1.proxy_to_act(proxy_action),
            "player_1": self.agent2.act(step=self.current_step, obs=self.obs[self.agent2.player]),
        }
        self.obs, _game_reward, terminated_by_player, truncated_by_player, info = self.env.step(joint_actions)
        terminated = terminated_by_player["player_0"]
        truncated = truncated_by_player["player_0"]
        self.proxy_obs, self.proxy_reward = self.agent1.step(self.obs[self.agent1.player], self.current_step)
        return self.proxy_obs, self.proxy_reward, terminated, truncated, info

In [7]:
class ProxyAgent():
    """Adapter between the raw Lux AI S3 game interface and the PPO proxy spaces.

    The agent keeps running maps (relics, tiles, energy) and per-unit
    bookkeeping. ``step`` turns a raw player observation into the proxy
    observation tuple plus a shaped reward; ``proxy_to_act`` turns a per-unit
    proxy action back into the ``(n_units, 3)`` game action array.
    """

    # Global step numbers at which ``step`` wipes the per-match state.
    MATCH_RESET_STEPS = (102, 203, 304, 405)

    def __init__(self, player: str, env_cfg, model_name=None) -> None:
        """Create the agent for one side of the game.

        Args:
            player: ``"player_0"`` or ``"player_1"``.
            env_cfg: game parameter mapping; the keys ``max_units``,
                ``unit_sensor_range``, ``unit_sap_range``, ``unit_sap_cost``,
                ``map_width`` and ``map_height`` are read here.
            model_name: optional path handed to ``torch.load``; needed only
                for ``act``.
        """
        self.player = player
        self.opp_player = "player_1" if self.player == "player_0" else "player_0"
        self.team_id = 0 if self.player == "player_0" else 1
        self.opp_team_id = 1 if self.team_id == 0 else 0
        np.random.seed(0)
        self.env_cfg = env_cfg
        if self.player == "player_0":
            self.start_pos = [0, 0]
            self.pnum = 1
        else:
            self.start_pos = [23, 23]
            self.pnum = 0
        self.unit_explore_locations = dict()
        self.relic_node_positions = []
        self.discovered_relic_nodes_ids = set()
        self.n_units = self.env_cfg["max_units"]
        self.match_num = 1
        self.relic_map = RelicMap(self.n_units)
        self.tile_map = TileMap()
        self.energy_map = EnergyMap()
        # Initial guesses; refined in ``step`` once they can be measured.
        self.move_cost = 3.0
        self.nebula_drain = 5.0
        self.move_check = 0
        self.nebula_check = 0

        self.range = self.env_cfg["unit_sensor_range"]
        self.sap_range = self.env_cfg["unit_sap_range"]
        self.sap_cost = self.env_cfg["unit_sap_cost"]
        self.width = self.env_cfg["map_width"]
        self.height = self.env_cfg["map_height"]

        self._clear_match_state()
        self.unit_energys = np.full((self.n_units), 100)
        self.unit_positions = -np.ones((self.n_units, 2))
        # Defined up front so ``check_hit`` is usable before the first ``step``.
        self.enemy_positions = []
        self.wins = 0
        self.losses = 0
        self.model = None
        if model_name:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            self.model = torch.load(model_name)

    def _clear_match_state(self):
        """(Re)initialise the bookkeeping that is wiped at every match boundary."""
        n = self.n_units
        self.unit_has_target = -np.ones((n))
        self.unit_targets = dict(zip(range(0, n), np.zeros((n, 2))))
        self.unit_targets_previous = dict(zip(range(0, n), np.zeros((n, 2))))
        self.unit_path = dict(zip(range(0, n), [[] for _ in range(0, n)]))
        self.available_unit_ids = []
        self.unit_moved = np.zeros((n))
        self.prev_points = 0
        self.prev_point_diff = 0
        self.prev_points_increase = 0
        # One (action, arg1, arg2) row per unit: ``step`` indexes it as [unit, 0].
        self.prev_actions = np.zeros((n, 3), dtype=int)
        self.previous_energys = 100 * np.ones((n))
        self.previous_positions = -np.ones((n, 2))

    def reset(self):
        """Advance the match counter and clear all per-match bookkeeping.

        The relic/tile/energy maps and the measured move cost and nebula drain
        are kept.
        """
        self.match_num += 1
        self._clear_match_state()

    def compare_positions(self, pos1, pos2):
        """Return True when both coordinates of ``pos1`` and ``pos2`` match."""
        return pos1[0] == pos2[0] and pos1[1] == pos2[1]

    # bunnyhop mechanic (maximize points by avoiding doubling on fragment)
    def bunnyhop(self, unit, unit_positions):
        """Swap plans with a unit (target state 2) standing on ``unit``'s next path cell.

        The blocking unit takes over the remainder of ``unit``'s path and its
        target, while ``unit`` retargets the blocker's current cell. The swap is
        then retried recursively for the displaced unit.

        Args:
            unit: index of the unit whose path is being checked.
            unit_positions: per-unit positions, indexable as ``[unit][0|1]``.
        """
        counter = 0
        for unit2 in range(self.n_units):
            if self.unit_has_target[unit2] != 2:
                continue
            other_pos = unit_positions[unit2]
            if self.tile_map.map[other_pos[0], other_pos[1]] == 2:
                continue
            if len(self.unit_path[unit]) <= 1:
                continue
            if not self.compare_positions(self.unit_path[unit][0], other_pos):
                continue
            self.unit_path[unit2] = self.unit_path[unit][1:]
            self.unit_targets[unit2] = self.unit_targets[unit]
            self.unit_has_target[unit2] = 1
            self.unit_path[unit] = [other_pos]
            self.unit_targets[unit] = other_pos
            self.unit_has_target[unit] = 1
            counter += 1
            if counter < 10:
                self.bunnyhop(unit2, unit_positions)

    def positions_to_map(self, unit_positions):
        """Rasterise unit positions into a 24x24 occupancy grid.

        Positions with a ``-1`` coordinate are skipped.

        Returns:
            A ``(24, 24)`` float array holding 1 on occupied cells.
        """
        unit_map = np.zeros((24, 24))
        for pos in unit_positions:
            if pos[0] == -1 or pos[1] == -1:
                continue
            unit_map[pos[0], pos[1]] = 1
        return unit_map

    # TODO: adjust for not only direct hits, but adjacent hits
    def check_hit(self, target):
        """Return 1 if a known enemy position equals ``target``, else 0.

        Enemy entries containing a ``-1`` coordinate are ignored.
        """
        for pos in self.enemy_positions:
            if pos[0] == -1 or pos[1] == -1:
                continue
            if pos[0] == target[0] and pos[1] == target[1]:
                return 1
        return 0

    def get_init_proxy_obs(self, obs):
        """Return an all-zero proxy observation; ``obs`` is not inspected.

        Returns:
            ``(maps, params, unit_params)`` with shapes ``(6, 24, 24)``,
            ``(3,)`` and ``(n_units, 7)``.
        """
        return (
            np.zeros((6, 24, 24), dtype=int),
            np.array([0, 0, 0]),
            np.zeros((self.n_units, 7), dtype=int),
        )

    def step(self, obs, step):
        """Digest one raw observation into a proxy observation and a reward.

        Args:
            obs: this player's raw observation; the keys ``units_mask``,
                ``units``, ``relic_nodes``, ``relic_nodes_mask``,
                ``team_points`` and ``map_features`` are read.
            step: global step counter of the game.

        Returns:
            ``(proxy_obs, reward)``. ``proxy_obs`` is a 3-tuple:
            six stacked integer maps (passable tiles, drain-adjusted energy,
            possible relic tiles, known relic tiles, own units, enemy units);
            ``[step, point difference, summed unit energy]``; and one row per
            unit ``[x, y, target_x, target_y, energy, tile_energy, on_known]``.
        """
        reward = 0
        unit_mask = np.array(obs["units_mask"][self.team_id])  # shape (max_units, )
        self.unit_positions = np.array(obs["units"]["position"][self.team_id])  # shape (max_units, 2)
        self.enemy_positions = np.array(obs["units"]["position"][self.opp_team_id]).tolist()
        self.unit_energys = np.array(obs["units"]["energy"][self.team_id])
        observed_relic_node_positions = np.array(obs["relic_nodes"])  # shape (max_relic_nodes, 2)
        observed_relic_nodes_mask = np.array(obs["relic_nodes_mask"])  # shape (max_relic_nodes, )
        team_points = np.array(obs["team_points"])
        increase = team_points[self.team_id] - self.prev_points
        diff = team_points[self.team_id] - team_points[self.opp_team_id]
        self.prev_point_diff = diff
        current_tile_map = obs["map_features"]["tile_type"]
        current_energy_map = obs["map_features"]["energy"]

        ### proxy reward calculation ###
        # Currently a placeholder shaping signal: pull units (and their targets)
        # towards first coordinate 4 and penalise action id 5.
        # Disabled ideas kept for later: reward point increases and match wins,
        # units standing on / targeting known or possible fragments, exploring
        # unknown tiles, sap hits (see ``check_hit``), penalise deaths/collisions.
        for unit in range(self.n_units):
            pos = self.unit_positions[unit]
            reward -= abs(pos[0] - 4)
            if pos[0] == 4:
                reward += 10
            t = self.unit_targets[unit]
            if t[0] == 4:
                reward += 10
            if self.prev_actions[unit, 0] == 5:
                reward -= 10

        if step in self.MATCH_RESET_STEPS:
            self.reset()

        # visible relic nodes
        visible_relic_node_ids = set(np.where(observed_relic_nodes_mask)[0])
        # save any new relic nodes that we discover for the rest of the game.
        for ii in visible_relic_node_ids:
            if ii in self.discovered_relic_nodes_ids:
                continue
            self.relic_map.new_relic(observed_relic_node_positions[ii])
            self.discovered_relic_nodes_ids.add(ii)
            # presumably the id of the mirrored relic node -- TODO confirm
            self.discovered_relic_nodes_ids.add((ii + 3) % 6)
            self.relic_node_positions.append(observed_relic_node_positions[ii])

        # update maps
        self.available_unit_ids = np.where(unit_mask)[0].tolist()
        self.relic_map.step(self.unit_positions, increase)
        self.tile_map.update(current_tile_map)
        self.energy_map.update(current_energy_map)

        # find out move cost (measured once, from unit 0 standing on a non-type-1 tile)
        if (
            step > 2
            and not self.move_check
            and self.tile_map.map[self.unit_positions[0][0], self.unit_positions[0][1]] != 1
            and self.unit_moved[0]
        ):
            self.move_cost = (
                self.previous_energys[0]
                - self.unit_energys[0]
                + self.energy_map.map[self.unit_positions[0][0], self.unit_positions[0][1]]
            )
            self.move_check = 1
        # find out nebula drain (measured once, from a moved unit on a type-1 tile)
        if not self.nebula_check and self.move_check:
            for unit in self.available_unit_ids:
                ux, uy = self.unit_positions[unit][0], self.unit_positions[unit][1]
                if self.unit_moved[unit] and self.tile_map.map[ux, uy] == 1:
                    self.nebula_check = 1
                    self.nebula_drain = -(
                        self.unit_energys[unit]
                        - self.previous_energys[unit]
                        - self.energy_map.map[ux, uy]
                        + self.move_cost
                    )
                    break

        self.previous_energys = self.unit_energys
        self.prev_points = team_points[self.team_id]
        self.prev_points_increase = increase
        self.previous_positions = self.unit_positions

        # TODO explore map
        tiles = np.ones((24, 24))
        tiles[self.tile_map.map == 2] = 0
        energy = self.energy_map.map.copy()
        # Plain assignment (not ``-=``) so the result is cast to ``energy``'s dtype.
        energy[self.tile_map.map == 1] = energy[self.tile_map.map == 1] - self.nebula_drain
        my_unit_map = self.positions_to_map(self.unit_positions)
        enemy_unit_map = self.positions_to_map(self.enemy_positions)
        on_known = np.zeros((self.n_units, 1))
        tile_energys = np.zeros((self.n_units, 1))
        for ii, p in enumerate(self.unit_positions):
            # Units at -1 are skipped: a negative index would silently read the
            # last row/column of the maps instead.
            if p[0] == -1 or p[1] == -1:
                continue
            if self.relic_map.map_knowns[p[0], p[1]] == 1:
                on_known[ii] = 1
            tile_energys[ii] = energy[p[0], p[1]]

        map_stack = np.array([
            tiles.astype(int),
            energy.astype(int),
            self.relic_map.map_possibles.astype(int),
            self.relic_map.map_knowns.astype(int),
            my_unit_map.astype(int),
            enemy_unit_map.astype(int),
        ])
        game_params = np.array([step, diff, np.sum(self.unit_energys)])
        unit_params = np.concatenate(
            (
                np.array(self.unit_positions).astype(int),
                np.array(list(self.unit_targets.values())).astype(int),
                np.expand_dims(self.unit_energys, -1).astype(int),
                tile_energys.astype(int),
                on_known.astype(int),
            ),
            axis=-1,
        )
        proxy_obs = (map_stack, game_params, unit_params)
        return proxy_obs, reward

    def act(self, obs, step):
        """Choose game actions for a raw observation using the loaded model.

        Raises:
            RuntimeError: if the agent was built without ``model_name``.
        """
        if self.model is None:
            raise RuntimeError("ProxyAgent.act requires a model; pass model_name to the constructor")
        proxy_obs, _ = self.step(obs, step)
        proxy_action, _ = self.model.get_value_and_action(proxy_obs)
        return self.proxy_to_act(proxy_action)

    def proxy_to_act(self, proxy_action):
        """Decode a per-unit proxy action into the ``(n_units, 3)`` game action array.

        Each proxy row is ``[mode, target_x, target_y, arg_a, arg_b]``. Mode 1
        emits ``[5, arg_a, arg_b]``; any other mode stores the target and emits
        a single move step towards it. Units not in ``available_unit_ids`` keep
        the all-zero action.

        Args:
            proxy_action: array or tensor indexable as ``[unit, column]``
                (tensors are squeezed and converted to NumPy first).

        Returns:
            Integer array of shape ``(n_units, 3)``.
        """
        if torch.is_tensor(proxy_action):
            proxy_action = proxy_action.squeeze().cpu().detach().numpy()
        actions = np.zeros((self.n_units, 3), dtype=int)
        for unit in self.available_unit_ids:
            if proxy_action[unit, 0] == 1:
                # NOTE(review): arg_a/arg_b are forwarded unchanged; verify whether the
                # game expects them relative to the unit's position.
                actions[unit] = [5, proxy_action[unit, 3], proxy_action[unit, 4]]
                continue
            self.unit_targets[unit] = [proxy_action[unit, 1], proxy_action[unit, 2]]
            direction = direction_to(self.unit_positions[unit], self.unit_targets[unit])
            change = direction_to_change(direction)
            # NOTE(review): stored as a flat [x, y], whereas ``bunnyhop`` treats
            # ``unit_path`` as a list of positions -- confirm before re-enabling it.
            self.unit_path[unit] = [
                self.unit_positions[unit][0] + change[0],
                self.unit_positions[unit][1] + change[1],
            ]
            actions[unit] = [direction, 0, 0]

        self.prev_actions = actions
        # Copy so later target updates do not also rewrite the "previous" record.
        self.unit_targets_previous = dict(self.unit_targets)
        return actions

In [8]:
# Label for this training run; passed to `train` together with the `Args` config class.
name = "best_more_ep"
train(name, Args)

  gym.logger.warn(
  gym.logger.warn(

KeyboardInterrupt

