In [1]:
from luxai_s3.wrappers import LuxAIS3GymEnv
from enum import IntEnum

# taken from https://www.kaggle.com/code/yizhewang3/ppo-stable-baselines3

SPACE_SIZE = 24
NUM_TEAMS = 2
MAX_UNITS = 16
RELIC_REWARD_RANGE = 2
MAX_STEPS_IN_MATCH = 100
MAX_ENERGY_PER_TILE = 20
MAX_RELIC_NODES = 6
LAST_MATCH_STEP_WHEN_RELIC_CAN_APPEAR = 50
LAST_MATCH_WHEN_RELIC_CAN_APPEAR = 2

# We will find the exact value of these constants during the game
UNIT_MOVE_COST = 1  # OPTIONS: list(range(1, 6))
UNIT_SAP_COST = 30  # OPTIONS: list(range(30, 51))
UNIT_SAP_RANGE = 3  # OPTIONS: list(range(3, 8))
UNIT_SENSOR_RANGE = 2  # OPTIONS: [1, 2, 3, 4]
OBSTACLE_MOVEMENT_PERIOD = 20  # OPTIONS: 6.67, 10, 20, 40
OBSTACLE_MOVEMENT_DIRECTION = (0, 0)  # OPTIONS: [(1, -1), (-1, 1)]

# We will NOT find the exact value of these constants during the game
NEBULA_ENERGY_REDUCTION = 5  # OPTIONS: [0, 1, 2, 3, 5, 25]

# Exploration flags:

ALL_RELICS_FOUND = False
ALL_REWARDS_FOUND = False
OBSTACLE_MOVEMENT_PERIOD_FOUND = False
OBSTACLE_MOVEMENT_DIRECTION_FOUND = False

# Game logs:

# REWARD_RESULTS: [{"nodes": Set[Node], "points": int}, ...]
# A history of reward events, where each entry contains:
# - "nodes": A set of nodes where our ships were located.
# - "points": The number of points scored at that location.
# This data will help identify which nodes yield points.
REWARD_RESULTS = []

# obstacles_movement_status: list of bool
# A history log of obstacle (asteroids and nebulae) movement events.
# - `True`: The ships' sensors detected a change in the obstacles' positions at this step.
# - `False`: The sensors did not detect any changes.
# This information will be used to determine the speed and direction of obstacle movement.
OBSTACLES_MOVEMENT_STATUS = []

# Others:

# The energy on the unknown tiles will be used in the pathfinding
HIDDEN_NODE_ENERGY = 0



class NodeType(IntEnum):
    unknown = -1
    empty = 0
    nebula = 1
    asteroid = 2

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name


_DIRECTIONS = [
    (0, 0),  # center
    (0, -1),  # up
    (1, 0),  # right
    (0, 1),  #  down
    (-1, 0),  # left
    (0, 0),  # sap
]


class ActionType(IntEnum):
    center = 0
    up = 1
    right = 2
    down = 3
    left = 4
    sap = 5

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    @classmethod
    def from_coordinates(cls, current_position, next_position):
        dx = next_position[0] - current_position[0]
        dy = next_position[1] - current_position[1]

        if dx < 0:
            return ActionType.left
        elif dx > 0:
            return ActionType.right
        elif dy < 0:
            return ActionType.up
        elif dy > 0:
            return ActionType.down
        else:
            return ActionType.center

    def to_direction(self):
        return _DIRECTIONS[self]


def get_match_step(step: int) -> int:
    return step % (MAX_STEPS_IN_MATCH + 1)


def get_match_number(step: int) -> int:
    return step // (MAX_STEPS_IN_MATCH + 1)


# def warp_int(x):
#     if x >= SPACE_SIZE:
#         x -= SPACE_SIZE
#     elif x < 0:
#         x += SPACE_SIZE
#     return x


# def warp_point(x, y) -> tuple:
#     return warp_int(x), warp_int(y)


def get_opposite(x, y) -> tuple:
    # Returns the mirrored point across the diagonal
    return SPACE_SIZE - y - 1, SPACE_SIZE - x - 1


def is_upper_sector(x, y) -> bool:
    return SPACE_SIZE - x - 1 >= y


def is_lower_sector(x, y) -> bool:
    return SPACE_SIZE - x - 1 <= y


def is_team_sector(team_id, x, y) -> bool:
    return is_upper_sector(x, y) if team_id == 0 else is_lower_sector(x, y)



In [2]:
import numpy as np
import gymnasium as gym
from gymnasium import spaces

# modified from https://www.kaggle.com/code/yizhewang3/ppo-stable-baselines3

def flatten_obs(base_obs, env_cfg, pid: int):
    """
    Convert the multi-agent observation dictionary for the *current* player
    into the flattened dict structure you had.
    """

    # Sometimes there's an extra "obs" key; adapt as needed
    if "obs" in base_obs:
        base_obs = base_obs["obs"]

    flat_obs = {}
    flat_obs['pid'] = np.array([pid])
        # 处理 units 数据
    if "units" in base_obs:
        flat_obs["units_position"] = base_obs["units"]["position"]
        flat_obs["units_energy"] = base_obs["units"]["energy"]
        # 如果 units_energy 的 shape 为 (NUM_TEAMS, MAX_UNITS) 则扩展一个维度
        if flat_obs["units_energy"].ndim == 2:
            flat_obs["units_energy"] = np.expand_dims(flat_obs["units_energy"], axis=-1)
    else:
        flat_obs["units_position"] = np.array(base_obs["units_position"], dtype=np.int32)
        flat_obs["units_energy"] = np.array(base_obs["units_energy"], dtype=np.int32)
        if flat_obs["units_energy"].ndim == 2:
            flat_obs["units_energy"] = np.expand_dims(flat_obs["units_energy"], axis=-1)
    
    # 处理 units_mask
    if "units_mask" in base_obs:
        flat_obs["units_mask"] = np.array(base_obs["units_mask"], dtype=np.int8)
    else:
        flat_obs["units_mask"] = np.zeros(flat_obs["units_position"].shape[:2], dtype=np.int8)
    
    # 处理 sensor_mask：若返回的是 3D 数组，则取逻辑 or 得到全局 mask
    sensor_mask_arr = np.array(base_obs["sensor_mask"], dtype=np.int8)
    if sensor_mask_arr.ndim == 3:
        sensor_mask = np.any(sensor_mask_arr, axis=0).astype(np.int8)
    else:
        sensor_mask = sensor_mask_arr
    flat_obs["sensor_mask"] = sensor_mask

    # 处理 map_features（tile_type 与 energy）
    if "map_features" in base_obs:
        mf = base_obs["map_features"]
        flat_obs["map_features_tile_type"] = np.array(mf["tile_type"], dtype=np.int8)
        flat_obs["map_features_energy"] = np.array(mf["energy"], dtype=np.int8)
    else:
        flat_obs["map_features_tile_type"] = np.array(base_obs["map_features_tile_type"], dtype=np.int8)
        flat_obs["map_features_energy"] = np.array(base_obs["map_features_energy"], dtype=np.int8)

    # 处理 relic 节点信息
    if "relic_nodes_mask" in base_obs:
        flat_obs["relic_nodes_mask"] = np.array(base_obs["relic_nodes_mask"], dtype=np.int8)
    else:
        max_relic = env_cfg.get("max_relic_nodes", 6) if env_cfg is not None else 6
        flat_obs["relic_nodes_mask"] = np.zeros((max_relic,), dtype=np.int8)
    if "relic_nodes" in base_obs:
        flat_obs["relic_nodes"] = np.array(base_obs["relic_nodes"], dtype=np.int32)
    else:
        max_relic = env_cfg.get("max_relic_nodes", 6) if env_cfg is not None else 6
        flat_obs["relic_nodes"] = np.full((max_relic, 2), -1, dtype=np.int32)

    # 处理团队得分与胜局
    if "team_points" in base_obs:
        flat_obs["team_points"] = np.array(base_obs["team_points"], dtype=np.int32)
    else:
        flat_obs["team_points"] = np.zeros(2, dtype=np.int32)
    if "team_wins" in base_obs:
        flat_obs["team_wins"] = np.array(base_obs["team_wins"], dtype=np.int32)
    else:
        flat_obs["team_wins"] = np.zeros(2, dtype=np.int32)

    # 处理步数信息
    if "steps" in base_obs:
        flat_obs["steps"] = np.array([base_obs["steps"]], dtype=np.int32)
    else:
        flat_obs["steps"] = np.array([0], dtype=np.int32)
    if "match_steps" in base_obs:
        flat_obs["match_steps"] = np.array([base_obs["match_steps"]], dtype=np.int32)
    else:
        flat_obs["match_steps"] = np.array([0], dtype=np.int32)

    # 注意：不在此处处理 remainingOverageTime，
    # 将在 Agent.act 中利用传入的参数添加

    # 补全环境配置信息
    assert env_cfg is not None
    flat_obs["env_cfg_map_width"] = np.array([env_cfg["map_width"]], dtype=np.int32)
    flat_obs["env_cfg_map_height"] = np.array([env_cfg["map_height"]], dtype=np.int32)
    flat_obs["env_cfg_max_steps_in_match"] = np.array([env_cfg["max_steps_in_match"]], dtype=np.int32)
    flat_obs["env_cfg_unit_move_cost"] = np.array([env_cfg["unit_move_cost"]], dtype=np.int32)
    flat_obs["env_cfg_unit_sap_cost"] = np.array([env_cfg["unit_sap_cost"]], dtype=np.int32)
    flat_obs["env_cfg_unit_sap_range"] = np.array([env_cfg["unit_sap_range"]], dtype=np.int32)


    return flat_obs


class LuxCustomWrapper(gym.Wrapper):

    def __init__(self, base_env):
        """
        :param base_env: Your two-agent environment, e.g. LuxAIS3GymEnv, that returns:
                         obs: {"player_0": {...}, "player_1": {...}}
                         reward: {"player_0": float, "player_1": float}
                         done/trunc: {"player_0": bool, "player_1": bool}, etc.
        """
        super().__init__(base_env)
        self.agents = ["player_0", "player_1"]
        self.obs_space_single = spaces.Dict({
            "pid": spaces.Box(low=0, high=1, shape=(1,), dtype=np.int8),
            "units_position": spaces.Box(low=0, high=SPACE_SIZE - 1,
                                         shape=(NUM_TEAMS, MAX_UNITS, 2), dtype=np.int32),
            "units_energy": spaces.Box(low=0, high=400, shape=(NUM_TEAMS, MAX_UNITS, 1), dtype=np.int32),
            "units_mask": spaces.Box(low=0, high=1, shape=(NUM_TEAMS, MAX_UNITS), dtype=np.int8),
            "sensor_mask": spaces.Box(low=0, high=1, shape=(SPACE_SIZE, SPACE_SIZE), dtype=np.int8),
            "map_features_tile_type": spaces.Box(low=-1, high=2, shape=(SPACE_SIZE, SPACE_SIZE), dtype=np.int8),
            "map_features_energy": spaces.Box(low=-1, high=MAX_ENERGY_PER_TILE, shape=(SPACE_SIZE, SPACE_SIZE), dtype=np.int8),
            "relic_nodes_mask": spaces.Box(low=0, high=1, shape=(MAX_RELIC_NODES,), dtype=np.int8),
            "relic_nodes": spaces.Box(low=-1, high=SPACE_SIZE - 1, shape=(MAX_RELIC_NODES, 2), dtype=np.int32),
            "team_points": spaces.Box(low=0, high=1000, shape=(NUM_TEAMS,), dtype=np.int32),
            "team_wins": spaces.Box(low=0, high=1000, shape=(NUM_TEAMS,), dtype=np.int32),
            "steps": spaces.Box(low=0, high=MAX_STEPS_IN_MATCH, shape=(1,), dtype=np.int32),
            "match_steps": spaces.Box(low=0, high=MAX_STEPS_IN_MATCH, shape=(1,), dtype=np.int32),
            "env_cfg_map_width": spaces.Box(low=0, high=SPACE_SIZE, shape=(1,), dtype=np.int32),
            "env_cfg_map_height": spaces.Box(low=0, high=SPACE_SIZE, shape=(1,), dtype=np.int32),
            "env_cfg_max_steps_in_match": spaces.Box(low=0, high=MAX_STEPS_IN_MATCH, shape=(1,), dtype=np.int32),
            "env_cfg_unit_move_cost": spaces.Box(low=0, high=100, shape=(1,), dtype=np.int32),
            "env_cfg_unit_sap_cost": spaces.Box(low=0, high=100, shape=(1,), dtype=np.int32),
            "env_cfg_unit_sap_range": spaces.Box(low=0, high=100, shape=(1,), dtype=np.int32),
        })

        self.act_space_single =  spaces.MultiDiscrete(np.array([[len(ActionType), SPACE_SIZE, SPACE_SIZE]] * MAX_UNITS).flatten())
        self.action_space = spaces.Dict(
            {player: self.act_space_single for player in self.agents}
        )
        self.observation_space = spaces.Dict(
            {player: self.obs_space_single for player in self.agents}
        )

    @staticmethod
    def get_agent_id(name: str) -> int:
        return int(name.split("_", 1)[-1])

    def reset(self, seed=None, return_info=True, options=None):
        obs_dict, info = super().reset(seed=seed, options=options)
        self.env_cfg = info.get("params", {})
        print(f"ENV CFG: {self.env_cfg}")
        obs = {ag: flatten_obs(obs_dict[ag], self.env_cfg, self.get_agent_id(ag)) for ag in self.agents}
        return obs, info
        

    def step(self, simplified_actions_dict):
        action_dict = {}
        for ag in self.agents:
            actions = simplified_actions_dict[ag].reshape(-1, 3)
            actions[:,1:][actions[:, 0] != ActionType.sap.value] = 0

            # for i, (action_type, aoe_x, aoe_y) in enumerate(simplified_actions_dict[ag]):
            #     actions[i, 0] = action_type  # ActionType (0 to 5)
            #     actions[i, 1] = 0 if action_type != ActionType.sap else aoe_x # delta_x
            #     actions[i, 2] = 0 if action_type != ActionType.sap else aoe_y # delta_y
            action_dict[ag] = actions
        obs, rew, terminated_dict, truncated_dict, info = super().step(
            action_dict
        )
        observations = {}
        rewards = {}
        terminations = {}
        truncations = {}
        infos = {}

        for ag in self.agents:
            observations[ag] = flatten_obs(obs[ag], self.env_cfg,  0 if ag == 'player_0' else 1)
            rewards[ag] = float(rew[ag])  # ensure it's float
            terminations[ag] = bool(terminated_dict[ag])
            truncations[ag] = bool(truncated_dict[ag])
            infos[ag] = info.get(ag, {})

            # If not done, we keep them in the list
            # if not (terminations[ag] or truncations[ag]):
            #     new_active_agents.append(ag)

        return observations, rewards, terminations, truncations, infos

    def render(self):
        return self.base_env.render()

    def close(self):
        pass


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from contextlib import nullcontext


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from stable_baselines3.common.policies import MultiInputActorCriticPolicy

class Agent(nn.Module):
    def __init__(self, env, obs_space, action_space):
        super().__init__()
        self.observation_space = obs_space
        self.action_space = action_space
        # We assume env.observation_spaces["player_0"] and env.action_spaces["player_0"] exist
        self.net = MultiInputActorCriticPolicy(
            observation_space=self.observation_space,
            action_space=self.action_space,
            lr_schedule=MultiInputActorCriticPolicy._dummy_schedule
        )
        self.env = env

    def get_value(self, obs):
        features_vf = self.net.vf_features_extractor(obs).float()
        latent_vf = self.net.mlp_extractor.value_net(features_vf)
        value = self.net.value_net(latent_vf)
        return value.squeeze(-1)

    def get_action_and_value(self, obs, action=None):
        features_pi = self.net.pi_features_extractor(obs).float()
        latent_pi = self.net.mlp_extractor.policy_net(features_pi)
        dist = self.net._get_action_dist_from_latent(latent_pi)

        if action is None:
            action = dist.sample()
        log_prob = dist.log_prob(action)
        if log_prob.ndim > 1:
            # MultiDiscrete: sum log_probs across action dimensions
            log_prob = log_prob.sum(-1)
        entropy = dist.entropy()
        if entropy.ndim > 1:
            entropy = entropy.sum(-1)

        # Compute value using separate extractor and head
        features_vf = self.net.vf_features_extractor(obs).float()
        latent_vf = self.net.mlp_extractor.value_net(features_vf)
        value = self.net.value_net(latent_vf).squeeze(-1)
        return action, log_prob, entropy, value

    def set_env_cfg(self, first_step_info):
        self.env_cfg = first_step_info["params"]


    def batchify_obs(self, obs):
        """
        obs is like:
        {
            'player_0': { 'units_position': ..., 'units_energy': ..., ... },
            'player_1': { 'units_position': ..., 'units_energy': ..., ... }
        }
        Returns a dict of Tensors with shape [2, ...].
        """
        batched_obs = {}
        for key in obs['player_0']:
            # For each key, convert each player's observation to a tensor and stack
            stacked = []
            for agent in ('player_0', 'player_1'):
                # TODO: maybe flip for player 1
                stacked.append(torch.tensor(obs[agent][key], device=device))
            batched_obs[key] = torch.stack(stacked, dim=0)
        return batched_obs

    @staticmethod
    def unbatchify_actions(action_tensor):
        """
        Converts a [2, ...] torch tensor of actions into:
        { 'player_0': action_for_player_0, 'player_1': action_for_player_1 }
        For a MultiDiscrete action of shape [16], each row is the 16-dim action.
        """
        action_np = action_tensor.cpu().numpy()
        # bsize = action_tensor.shape[0]
        # print(f"dbg: bsize was {bsize}")
        return {
            "player_0": action_np[0],
            "player_1": action_np[1],
        }
    
    def step(self, obs_tensor: dict[torch.Tensor], no_grad=True):
        cm = torch.no_grad() if no_grad else nullcontext()
        with cm:
            actions, logprobs, entropy, values = self.get_action_and_value(obs_tensor)
        action_dict = self.unbatchify_actions(actions)
        return actions, logprobs, entropy, values, action_dict


In [4]:
base_env = LuxAIS3GymEnv(numpy_output=True)
wrapped_env = LuxCustomWrapper(base_env)

obs, info = wrapped_env.reset()
agent = Agent(wrapped_env, wrapped_env.obs_space_single, wrapped_env.act_space_single)
agent.set_env_cfg(info)

batched_obs = agent.batchify_obs(obs)
print(f"batched obs {batched_obs}")
actions, logprobs, entropy, values, action_dict = agent.step(batched_obs)
print(f"forward actions {action_dict}")
obs, _, _, _, _ = wrapped_env.step(action_dict)


ENV CFG: {'max_units': 16, 'match_count_per_episode': 5, 'max_steps_in_match': 100, 'map_height': 24, 'map_width': 24, 'num_teams': 2, 'unit_move_cost': 4, 'unit_sap_cost': 45, 'unit_sap_range': 3, 'unit_sensor_range': 2}
batched obs {'pid': tensor([[0],
        [1]]), 'units_position': tensor([[[[-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1]],

         [[-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1],
          [-1, -1]]],


        [[[-1, -1],
          [-1, -1],
          [-1, -1],
      

In [24]:
# modified from https://pettingzoo.farama.org/tutorials/cleanrl/implementing_PPO/
import wandb
from datetime import datetime

def train_ppo(env, total_episodes=2, rollout_num_steps=200, checkpoint_freq: int=100):
    ent_coef = 0.1
    vf_coef = 0.1
    clip_coef = 0.1
    gamma = 0.99
    batch_size = 32
    train_epochs = 4
    gae_lambda = 0.95
    max_grad_norm = 0.5
    lr = 2.5e-4
    anneal_lr = True
    seed = 2025
    run = wandb.init(
    # mode="disabled",
    entity="ay2425s2-cs3263-group-13",
    project="lux-ppo",
    config={
        "ent_coef": ent_coef,
        "vf_coef": vf_coef,
        "clip_coef": clip_coef,
        "gamma": gamma,
        "learning_rate": lr,
        "batch_size": batch_size,
        "train_epochs": train_epochs,
        "total_episodes": total_episodes,
        "max_steps_per_episode": rollout_num_steps
        },
        save_code=True,
    )
    agent = Agent(env, env.obs_space_single, env.act_space_single).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=lr, eps=1e-5)
    for episode in range(1, total_episodes + 1):
        next_obs, info = env.reset(seed=seed + episode)
        total_episodic_return = torch.zeros(2, dtype=float)
        
        rb_obs = []
        rb_actions = []
        rb_logprobs = []
        rb_rewards = []
        rb_dones = []
        rb_values = []

        end_step = 0  # track how many steps actually took place
         
        if anneal_lr:
            frac = 1.0 - (episode - 1.0) / total_episodes
            lrnow = frac * lr
            optimizer.param_groups[0]["lr"] = lrnow

        # 1. Collect experience
        for step in range(rollout_num_steps):
            obs_tensor = agent.batchify_obs(next_obs)
            
            actions, logprobs, entropy, values, action_dict = agent.step(obs_tensor, no_grad=True)
            next_obs, rewards, terms, truncs, infos = env.step(action_dict)

            rb_obs.append(obs_tensor)
            rb_actions.append(actions)
            rb_logprobs.append(logprobs)
            rb_values.append(values)
            
            # env rewards IS total reward, subtract prev
            new_total_return = torch.tensor([rewards["player_0"], rewards["player_1"]], dtype=float)
            reward = new_total_return - total_episodic_return
            total_episodic_return = new_total_return
            trunc0, trunc1 = truncs["player_0"], truncs["player_1"]
            term0, term1 = terms["player_0"], terms["player_1"]
            next_done = torch.tensor([np.logical_or(trunc0, term0), np.logical_or(trunc1, term1)])
            rb_rewards.append(torch.tensor(reward, device=device))
            rb_dones.append(next_done)
            end_step = step + 1

            if all(terms.values()) or all(truncs.values()):
                break

        # 2. Bootstrapping if not done
        with torch.no_grad():
            if not all(terms.values()):
                final_obs_tensor = agent.batchify_obs(next_obs)
                _, _, _, next_values = agent.get_action_and_value(final_obs_tensor)
            else:
                next_values = torch.zeros(2, device=device)

        # 3. Convert lists -> stacked Tensors

        num_steps = len(rb_obs)
        stacked_obs = {}
        for key in rb_obs[0].keys():
            cat_list = [step_dict[key] for step_dict in rb_obs]
            stacked_obs[key] = torch.stack(cat_list, dim=0)

        rb_actions  = torch.stack(rb_actions, dim=0)   # [num_steps, 2, (action_dim)]
        rb_logprobs = torch.stack(rb_logprobs, dim=0) # [num_steps, 2]
        rb_values   = torch.stack(rb_values, dim=0)   # [num_steps, 2]
        rb_rewards  = torch.stack(rb_rewards, dim=0)  # [num_steps, 2]
        rb_dones    = torch.stack(rb_dones, dim=0)    # [num_steps, 2]

        # 4. GAE or simple advantage
        rb_advantages = torch.zeros_like(rb_rewards)
        rb_returns = torch.zeros_like(rb_rewards)
        gae = torch.zeros(2, device=device)

        for t in reversed(range(num_steps)):
            if t == num_steps - 1:
                next_val = next_values
                done_mask = 1.0 - rb_dones[t].float()
            else:
                next_val = rb_values[t + 1]
                done_mask = 1.0 - rb_dones[t + 1].float()

            delta = rb_rewards[t] + gamma * next_val * done_mask - rb_values[t]
            gae = delta + gamma * gae_lambda * gae * done_mask
            rb_advantages[t] = gae
            rb_returns[t] = gae + rb_values[t]

        # 5. Flatten time & agent
        b_obs = {}
        for key, val in stacked_obs.items():
            b_obs[key] = val.view(num_steps * 2, *val.shape[2:])

        b_actions    = rb_actions.view(num_steps * 2, -1)
        b_logprobs   = rb_logprobs.view(num_steps * 2)
        b_values     = rb_values.view(num_steps * 2)
        b_advantages = rb_advantages.view(num_steps * 2)
        b_returns    = rb_returns.view(num_steps * 2)

        # We'll track these for logging outside the minibatch loop
        clip_fracs = []
        old_approx_kls = []
        approx_kls = []
        last_v_loss = 0.0
        last_pg_loss = 0.0

        # 6. PPO update
        total_batch = num_steps * 2
        indices = np.arange(total_batch)
        for _ in range(train_epochs):
            np.random.shuffle(indices)
            for start in range(0, total_batch, batch_size):
                end = start + batch_size
                batch_inds = indices[start:end]

                mb_obs = {k: v[batch_inds] for k, v in b_obs.items()}
                mb_actions = b_actions[batch_inds]
                mb_old_logprob = b_logprobs[batch_inds]
                mb_adv = b_advantages[batch_inds]
                mb_returns = b_returns[batch_inds]
                mb_values = b_values[batch_inds]
                 # Evaluate new logprob
                _, new_logprob, entropy, value = agent.get_action_and_value(
                    mb_obs, action=mb_actions
                )
                logratio = new_logprob - mb_old_logprob
                ratio = logratio.exp()
                with torch.no_grad():
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()

                # Record them
                old_approx_kls.append(old_approx_kl.item())
                approx_kls.append(approx_kl.item())

                # Compute clip fraction
                clip_fraction = ((ratio - 1.0).abs() > clip_coef).float().mean().item()
                clip_fracs.append(clip_fraction)

                # Normalize advantages
                mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_adv * ratio
                pg_loss2 = -mb_adv * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss (clipped)
                value = value.view(-1)
                v_loss_unclipped = (value - mb_returns) ** 2
                v_clipped = mb_values + torch.clamp(value - mb_values, -clip_coef, clip_coef)
                v_loss_clipped = (v_clipped - mb_returns) ** 2
                v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()

                # Entropy
                entropy_loss = entropy.mean()

                loss = pg_loss + vf_coef * v_loss - ent_coef * entropy_loss

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
                optimizer.step()

                # We'll keep track of the last minibatch losses for logging
                last_v_loss = v_loss.item()
                last_pg_loss = pg_loss.item()

        # 7. Explained Variance
        y_pred = b_values.detach().cpu().numpy()
        y_true = b_returns.detach().cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
        run.log({
            "episode": episode,
            "episode_length": end_step,
            "player0_return": total_episodic_return[0],
            "player1_return": total_episodic_return[1],
            "policy_loss": -last_pg_loss,
            "value_loss": last_v_loss,
            "old_approx_kl": np.mean(old_approx_kls),
            "approx_kl": np.mean(approx_kls),
            "clip_fraction": np.mean(clip_fracs),
            "explained_variance": explained_var
        })

        # 8. Logging (similar structure to your snippet)
        print(f"Training episode {episode}")
        print(f"player0 Return: {total_episodic_return[0]}")
        print(f"player1 return: {total_episodic_return[1]}")
        print(f"Episode Length: {end_step}")
        print("")
        print(f"Value Loss: {last_v_loss}")
        print(f"Policy Loss: {-last_pg_loss}")
        print(f"Old Approx KL: {np.mean(old_approx_kls)}")
        print(f"Approx KL: {np.mean(approx_kls)}")
        print(f"Clip Fraction: {np.mean(clip_fracs)}")
        print(f"Explained Variance: {explained_var}")
        print("\n-------------------------------------------\n")
        
        if episode % checkpoint_freq == 0:
            timestamp = datetime.now().strftime("%m-%d_%H-%M")
            torch.save(agent.state_dict(), f"./checkpoints/ppo/model_ep{episode}_{timestamp}.pt")
            # Upload to wandb
            wandb.save(f"checkpoints/ppo/model_ep{episode}_{timestamp}.pt", base_path="checkpoints/ppo/")
    run.finish()
    print("Training complete.")


train_ppo(wrapped_env, total_episodes=1000, rollout_num_steps=512)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


ENV CFG: {'max_units': 16, 'match_count_per_episode': 5, 'max_steps_in_match': 100, 'map_height': 24, 'map_width': 24, 'num_teams': 2, 'unit_move_cost': 4, 'unit_sap_cost': 43, 'unit_sap_range': 5, 'unit_sensor_range': 4}


  next_done = torch.tensor([np.logical_or(trunc0, term0), np.logical_or(trunc1, term1)])
  rb_rewards.append(torch.tensor(reward, device=device))


Training episode 1
player0 Return: 3.0
player1 return: 2.0
Episode Length: 505

Value Loss: 0.04757641107576102
Policy Loss: 0.05148946555455953
Old Approx KL: 0.017166604258818552
Approx KL: 0.06813003428396769
Clip Fraction: 0.7766655818559229
Explained Variance: -0.6805361572896931

-------------------------------------------

ENV CFG: {'max_units': 16, 'match_count_per_episode': 5, 'max_steps_in_match': 100, 'map_height': 24, 'map_width': 24, 'num_teams': 2, 'unit_move_cost': 2, 'unit_sap_cost': 34, 'unit_sap_range': 6, 'unit_sensor_range': 2}
Training episode 2
player0 Return: 2.0
player1 return: 3.0
Episode Length: 505

Value Loss: 0.03998905524971922
Policy Loss: 0.029482478944158095
Old Approx KL: 0.0690812116372399
Approx KL: 0.11851569298596587
Clip Fraction: 0.8646375862881541
Explained Variance: 0.03156617056660982

-------------------------------------------

ENV CFG: {'max_units': 16, 'match_count_per_episode': 5, 'max_steps_in_match': 100, 'map_height': 24, 'map_width': 

0,1
approx_kl,▅█▇▇▆▆▆▅▅▄▅▅▅▃▅▄▂▃▆▄▃▃▂▂▂▂▂▃▃▃▂▂▂▁▁▁▁▁▁▁
clip_fraction,▆████▇██▇▇█▇▆██▆▅▆▆▇▇▆▆▆▆▅▅▇▅▅▆▄▆▅▅▃▄▃▂▁
episode,▁▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇███
episode_length,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
explained_variance,▁▂▂▄▅▄▄▂▇▄▆▄▅▄▃▇▇▆▆▇▅▄▆▆▇▄▇█▄▇▆▇▇██▆▇▆▇▇
old_approx_kl,▇█▅▄▄▄▃▄▂▃▁▁▃▃▃▂▃▅▁▃▂▃▂▃▂▂▂▂▁▂▂▁▁▁▂▂▂▂▁▁
player0_return,▃▅▆▃▃▁▃▅▅▃▃▆▆█▃▅▆▃▅▅▅▅▃▅▃▅▁▃▆▅▃▅▆█▃▅▆▃▁▅
player1_return,▅▄▄▇▄▅▅▅█▂▄▇▇▂▅▄▅▂▇▁▇▅▄▄▅▇▇▂▅▄▅▅▇▇▂▇▅▅█▂
policy_loss,█▇▅▁▆█▇▄▅▇█▆▆▆▇▄▆▅▆▆▃▅▃▅▆▄▇▄▂▁▂▂▅▅▄▅▄▂▂▄
value_loss,▁▂▁▆▂▃▁▄▂▂▁▂▁▁▁▃█▁▁▅▂▄▂▂▄▂▁▄▃▁▂▂▁▄▃▃▂▂▄▂

0,1
approx_kl,0.01186
clip_fraction,0.43994
episode,1000.0
episode_length,505.0
explained_variance,0.6522
old_approx_kl,0.05701
player0_return,2.0
player1_return,3.0
policy_loss,0.00617
value_loss,0.00362


Training complete.


In [6]:

from datetime import datetime
from pathlib import Path
import random
import os
import copy
import wandb
from collections import deque

def train_with_league(env, total_episodes=1000, rollout_num_steps=512, 
                      snapshot_freq=10, num_exploiters=2, main_checkpoint: str | None = None, exploiter_checkpoints: list[str] | None = None):
    """
    Train PPO agents using a simplified AlphaStar-like league system
    
    Args:
        env: Environment to train in
        total_episodes: Total episodes to train for
        rollout_num_steps: Number of steps per rollout
        snapshot_freq: How often to take snapshots of the main agent
        num_exploiters: Number of exploiter agents to maintain
    """
    # Hyperparameters
    ent_coef = 0.1
    vf_coef = 0.1
    clip_coef = 0.1
    gamma = 0.99
    batch_size = 32
    train_epochs = 4
    gae_lambda = 0.95
    max_grad_norm = 0.5
    lr = 2.5e-4
    seed = 2025
    anneal_lr = True
    
    # Initialize wandb
    run = wandb.init(
        # mode="disabled",
        entity="ay2425s2-cs3263-group-13",
        project="lux-alphastar-league",
        config={
            "ent_coef": ent_coef,
            "vf_coef": vf_coef,
            "clip_coef": clip_coef,
            "gamma": gamma,
            "learning_rate": lr,
            "batch_size": batch_size,
            "train_epochs": train_epochs,
            "total_episodes": total_episodes,
            "max_steps_per_episode": rollout_num_steps,
            "num_exploiters": num_exploiters,
        },
        save_code=True,
    )
    
    # Create checkpoints directory
    Path("./checkpoints").mkdir(exist_ok=True)
    
    # Initialize league components
    
    # 1. Main agent - our best agent that trains against everyone
    main_agent = Agent(env, env.obs_space_single, env.act_space_single).to(device)
    if main_checkpoint:
        main_agent.load_state_dict(torch.load(main_checkpoint))
    main_optimizer = optim.Adam(main_agent.parameters(), lr=lr, eps=1e-5)
    
    # 2. Exploiter agents - specialize in exploiting the main agent
    exploiters = []
    for i in range(num_exploiters):
        exploiter_agent = Agent(env, env.obs_space_single, env.act_space_single)
        if exploiter_checkpoints:
            exploiter_agent.load_state_dict(torch.load(exploiter_checkpoints[i]))
        exploiter_agent.to(device)
        exploiter = {
            "agent": exploiter_agent,
            "optimizer": optim.Adam(exploiter_agent.parameters(), lr=lr, eps=1e-5),
            "id": i,
            "episodes_trained": 0,
            "wins_against_main": 0,
            "matches_against_main": 0
        }
        exploiters.append(exploiter)
    
    # 3. Historical snapshots of the main agent
    snapshots = []  # Will store (episode_number, model_path) tuples
    
    # Track main agent's performance against exploiters and snapshots
    main_agent_stats = {
        "wins_vs_exploiters": 0,
        "matches_vs_exploiters": 0,
        "wins_vs_snapshots": 0,
        "matches_vs_snapshots": 0
    }
    
    # Main training loop
    for episode in range(1, total_episodes + 1):
        # Determine which agent to train this episode
        if episode % (num_exploiters + 1) == 0:
            # Train main agent
            current_agent = main_agent
            current_optimizer = main_optimizer
            agent_type = "main"
            agent_id = 0
        else:
            # Train exploiter agent
            exploiter_idx = (episode % (num_exploiters + 1)) - 1
            current_agent = exploiters[exploiter_idx]["agent"]
            current_optimizer = exploiters[exploiter_idx]["optimizer"]
            exploiters[exploiter_idx]["episodes_trained"] += 1
            agent_type = "exploiter"
            agent_id = exploiter_idx
        
        # Apply learning rate annealing if enabled
        if anneal_lr:
            frac = 1.0 - (episode - 1.0) / total_episodes
            lrnow = frac * lr
            current_optimizer.param_groups[0]["lr"] = lrnow
        
        # Select opponent based on agent type
        if agent_type == "main":
            # Main agent trains against a mix of exploiters and snapshots
            if random.random() < 0.7 and exploiters:  # 70% chance to play against exploiters
                # Select exploiter, preferring those that win more often
                if random.random() < 0.7:  # 80% chance to select strongest exploiter
                    # Sort exploiters by win rate against main agent
                    sorted_exploiters = sorted(
                        exploiters, 
                        key=lambda x: x["wins_against_main"] / max(1, x["matches_against_main"]),
                        reverse=True
                    )
                    opponent = sorted_exploiters[0]["agent"]
                    opponent_type = "strongest_exploiter"
                else:
                    # Random exploiter
                    opponent = random.choice(exploiters)["agent"]
                    opponent_type = "random_exploiter"
            elif snapshots:  # Otherwise use historical snapshot if available
                # 50% chance to play against most recent snapshot, otherwise random snapshot
                if random.random() < 0.5:
                    opponent_path = snapshots[-1][1]  # Most recent snapshot
                    opponent_type = "recent_snapshot"
                else:
                    opponent_path = random.choice(snapshots)[1]  # Random snapshot
                    opponent_type = "random_snapshot"
                
                # Load snapshot weights
                opponent = Agent(env, env.obs_space_single, env.act_space_single).to(device)
                opponent.load_state_dict(torch.load(opponent_path))
            else:
                # If no snapshots yet, self-play
                opponent = current_agent
                opponent_type = "self"
        else:
            # Exploiters only train against the main agent
            opponent = main_agent
            opponent_type = "main"

        # Run a single training episode
        metrics = train_league_episode(
            env,
            current_agent,
            opponent,
            current_optimizer,
            rollout_num_steps=rollout_num_steps,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            clip_coef=clip_coef,
            gamma=gamma,
            gae_lambda=gae_lambda,
            max_grad_norm=max_grad_norm,
            train_epochs=train_epochs,
            batch_size=batch_size,
            seed=seed + episode
        )
        
        # Update win statistics
        if agent_type == "main":
            if opponent_type in ["strongest_exploiter", "random_exploiter"]:
                main_agent_stats["matches_vs_exploiters"] += 1
                if metrics["win"]:
                    main_agent_stats["wins_vs_exploiters"] += 1
            elif opponent_type in ["recent_snapshot", "random_snapshot"]:
                main_agent_stats["matches_vs_snapshots"] += 1
                if metrics["win"]:
                    main_agent_stats["wins_vs_snapshots"] += 1
        else:  # exploiter
            exploiters[agent_id]["matches_against_main"] += 1
            if metrics["win"]:
                exploiters[agent_id]["wins_against_main"] += 1
        
        # Take snapshot of main agent at regular intervals
        timestamp = datetime.now().strftime("%m-%d_%H-%M")
        if episode % snapshot_freq == 0:
            snapshot_path = f"./checkpoints/league/main_snapshot_ep{episode}_{timestamp}.pt"
            torch.save(main_agent.state_dict(), snapshot_path)
            snapshots.append((episode, snapshot_path))
            print(f"Created snapshot at episode {episode}")
            
            # Keep only the most recent 10 snapshots to manage storage
            if len(snapshots) > 10:
                oldest = snapshots.pop(0)
                # Optionally delete the file to save space
                os.remove(oldest[1])
        
        # Log metrics
        log_data = {
            "episode": episode,
            "agent_type": agent_type,
            "agent_id": agent_id,
            "opponent_type": opponent_type,
            "episode_length": metrics["episode_length"],
            "player0_return": metrics["player0_return"],
            "policy_loss": metrics["policy_loss"],
            "value_loss": metrics["value_loss"],
            "win": metrics["win"],
            "lr": current_optimizer.param_groups[0]["lr"],
        }
        
        # Add league statistics
        if main_agent_stats["matches_vs_exploiters"] > 0:
            log_data["main_win_rate_vs_exploiters"] = main_agent_stats["wins_vs_exploiters"] / main_agent_stats["matches_vs_exploiters"]
        
        if main_agent_stats["matches_vs_snapshots"] > 0:
            log_data["main_win_rate_vs_snapshots"] = main_agent_stats["wins_vs_snapshots"] / main_agent_stats["matches_vs_snapshots"]
        
        for i, exploiter in enumerate(exploiters):
            if exploiter["matches_against_main"] > 0:
                log_data[f"exploiter{i}_win_rate"] = exploiter["wins_against_main"] / exploiter["matches_against_main"]
        
        run.log(log_data)
        
        # Print summary every 10 episodes
        if episode % 10 == 0:
            print(f"\n--- Episode {episode} Summary ---")
            print(f"Agent type: {agent_type}, Agent ID: {agent_id}")
            print(f"Opponent type: {opponent_type}")
            print(f"Episode return: {metrics['player0_return']:.4f}")
            print(f"Episode length: {metrics['episode_length']}")
            print(f"Win: {metrics['win']}")
            
            if main_agent_stats["matches_vs_exploiters"] > 0:
                main_vs_exp = main_agent_stats["wins_vs_exploiters"] / main_agent_stats["matches_vs_exploiters"]
                print(f"Main win rate vs exploiters: {main_vs_exp:.4f}")
            
            if main_agent_stats["matches_vs_snapshots"] > 0:
                main_vs_snap = main_agent_stats["wins_vs_snapshots"] / main_agent_stats["matches_vs_snapshots"]
                print(f"Main win rate vs snapshots: {main_vs_snap:.4f}")
            
            for i, exploiter in enumerate(exploiters):
                if exploiter["matches_against_main"] > 0:
                    exp_win_rate = exploiter["wins_against_main"] / exploiter["matches_against_main"]
                    print(f"Exploiter {i} win rate: {exp_win_rate:.4f}")
            
            print("----------------------------\n")
    
    # Save final models
    timestamp = datetime.now().strftime("%m-%d_%H-%M")
    
    # Save main agent
    main_path = f"./checkpoints/main_final_{timestamp}.pt"
    torch.save(main_agent.state_dict(), main_path)
    wandb.save(main_path)
    
    # Save exploiters
    for i, exploiter in enumerate(exploiters):
        exploiter_path = f"./checkpoints/exploiter{i}_final_{timestamp}.pt"
        torch.save(exploiter["agent"].state_dict(), exploiter_path)
    
    run.finish()
    print("League training complete.")
    return main_agent

def train_league_episode(env: LuxCustomWrapper, agent: Agent, opponent: Agent, optimizer, rollout_num_steps=512, 
                        ent_coef=0.1, vf_coef=0.1, clip_coef=0.1,
                        gamma=0.99, gae_lambda=0.95, max_grad_norm=0.5,
                        train_epochs=4, batch_size=32, seed=2025):
    """
    Run a single training episode for the league training system
    
    Args:
        env: The environment
        agent: The agent being trained (randomly assigned to player 0 or player 1)
        opponent: The opponent agent (plays the opposite side from agent)
        optimizer: The optimizer for the agent
        Other args: Standard PPO parameters
        
    Returns:
        Dictionary of metrics from the episode
    """
    # Randomly determine which side the training agent plays
    opponent.eval()
    agent.train()
    agent_side = int(random.random() < 0.5)  # 0 or 1
    opponent_side = 1 - agent_side  # opposite side
    
    next_obs, info = env.reset(seed=seed)
    total_episodic_return = torch.zeros(2, dtype=float)

    rb_obs = []
    rb_actions = []
    rb_logprobs = []
    rb_rewards = []
    rb_dones = []
    rb_values = []

    end_step = 0

    # 1. Collect experience
    for step in range(rollout_num_steps):
        obs_tensor = agent.batchify_obs(next_obs)
        
        # Get actions for both agents
        with torch.no_grad():
            # Get actions from both models
            agent_actions, agent_logprobs, agent_entropy, agent_values = agent.get_action_and_value(obs_tensor)
            opponent_actions, _, _, _ = opponent.get_action_and_value(obs_tensor)
        
        # Combine actions based on which side each agent is playing
        actions = torch.zeros_like(agent_actions)
        actions[agent_side] = agent_actions[agent_side]  # Agent's action
        actions[opponent_side] = opponent_actions[opponent_side]  # Opponent's action
        
        # Store logprobs and values only for the training agent's side
        logprobs = torch.zeros_like(agent_logprobs)
        logprobs[agent_side] = agent_logprobs[agent_side]
        
        values = torch.zeros_like(agent_values)
        values[agent_side] = agent_values[agent_side]
        
        # Step environment
        action_dict = agent.unbatchify_actions(actions)
        next_obs, rewards, terms, truncs, infos = env.step(action_dict)

        # Store the experience
        rb_obs.append(obs_tensor)
        rb_actions.append(actions)
        rb_logprobs.append(logprobs)
        rb_values.append(values)

        new_total_return = torch.tensor([rewards["player_0"], rewards["player_1"]], dtype=float)
        reward = total_episodic_return - new_total_return
        total_episodic_return = new_total_return
        trunc0, trunc1 = truncs["player_0"], truncs["player_1"]
        term0, term1 = terms["player_0"], terms["player_1"]
        next_done = torch.tensor([
            np.logical_or(trunc0, term0),
            np.logical_or(trunc1, term1),
        ])
        rb_rewards.append(torch.tensor(reward, device=device))
        rb_dones.append(next_done)

        end_step = step + 1

        if all(terms.values()) or all(truncs.values()):
            break

    # Determine if the agent won (based on agent_side)
    won = total_episodic_return[agent_side] > total_episodic_return[opponent_side]

    # 2. Bootstrap if not done
    with torch.no_grad():
        if not all(terms.values()):
            final_obs_tensor = agent.batchify_obs(next_obs)
            _, _, _, next_values = agent.get_action_and_value(final_obs_tensor)
        else:
            next_values = torch.zeros(2, device=device)

    # 3. Convert lists -> Tensors
    num_steps = len(rb_obs)
    stacked_obs = {}
    for key in rb_obs[0].keys():
        cat_list = [step_dict[key] for step_dict in rb_obs]
        stacked_obs[key] = torch.stack(cat_list, dim=0)

    rb_actions = torch.stack(rb_actions, dim=0)
    rb_logprobs = torch.stack(rb_logprobs, dim=0)
    rb_values = torch.stack(rb_values, dim=0)
    rb_rewards = torch.stack(rb_rewards, dim=0)
    rb_dones = torch.stack(rb_dones, dim=0)

    # 4. GAE calculation - Only compute for the agent_side (the side we're training)
    rb_advantages = torch.zeros_like(rb_rewards)
    rb_returns = torch.zeros_like(rb_rewards)
    gae = torch.zeros(2, device=device)

    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            next_val = next_values
            done_mask = 1.0 - rb_dones[t].float().to(device)
        else:
            next_val = rb_values[t + 1]
            done_mask = 1.0 - rb_dones[t + 1].float().to(device)

        delta = rb_rewards[t] + gamma * next_val * done_mask - rb_values[t]
        gae = delta + gamma * gae_lambda * gae * done_mask
        rb_advantages[t] = gae
        rb_returns[t] = gae + rb_values[t]

    # 5. Flatten batch - but only use the agent's side data (the side we're training)
    b_obs = {}
    for key, val in stacked_obs.items():
        # Extract only the agent side's observations
        b_obs[key] = val[:, agent_side].unsqueeze(1)  # Keep dimension for compatibility

    # Extract only the agent side's data
    b_actions = rb_actions[:, agent_side]
    b_logprobs = rb_logprobs[:, agent_side]
    b_values = rb_values[:, agent_side]
    b_advantages = rb_advantages[:, agent_side]
    b_returns = rb_returns[:, agent_side]

    # 6. PPO update
    clip_fracs = []
    pg_losses = []
    v_losses = []
    
    total_batch = num_steps
    indices = np.arange(total_batch)
    
    for _ in range(train_epochs):
        np.random.shuffle(indices)
        for start in range(0, total_batch, batch_size):
            end = start + batch_size
            if end > total_batch:
                continue  # Skip incomplete batches
                
            batch_inds = indices[start:end]

            mb_obs = {k: v[batch_inds] for k, v in b_obs.items()}
            mb_actions = b_actions[batch_inds]
            mb_old_logprob = b_logprobs[batch_inds]
            mb_adv = b_advantages[batch_inds]
            mb_returns = b_returns[batch_inds]
            mb_values = b_values[batch_inds]
            
            # Forward pass with current parameters
            _, new_logprob, entropy, value = agent.get_action_and_value(mb_obs, action=mb_actions)
            
            # Important: viewing as -1 since we're only updating the agent's policy
            new_logprob = new_logprob.view(-1)
            value = value.view(-1)
            
            logratio = new_logprob - mb_old_logprob
            ratio = logratio.exp()
            
            # Normalize advantages
            mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)

            # Policy loss
            pg_loss1 = -mb_adv * ratio
            pg_loss2 = -mb_adv * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()
            pg_losses.append(pg_loss.item())

            # Value loss
            v_loss_unclipped = (value - mb_returns) ** 2
            v_clipped = mb_values + torch.clamp(
                value - mb_values, -clip_coef, clip_coef
            )
            v_loss_clipped = (v_clipped - mb_returns) ** 2
            v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
            v_losses.append(v_loss.item())

            # Entropy loss
            entropy_loss = entropy.mean()

            # Total loss
            loss = pg_loss + vf_coef * v_loss - ent_coef * entropy_loss

            # Gradient step
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
            optimizer.step()

            # Calculate clip fraction
            clip_fraction = ((ratio - 1.0).abs() > clip_coef).float().mean().item()
            clip_fracs.append(clip_fraction)

    # Return episode metrics
    return {
        "episode_length": end_step,
        "player0_return": total_episodic_return[agent_side].item(),
        "agent_side": agent_side,
        "agent_return": total_episodic_return[agent_side].item(),
        "opponent_return": total_episodic_return[opponent_side].item(),
        "policy_loss": -np.mean(pg_losses),
        "value_loss": np.mean(v_losses),
        "clip_fraction": np.mean(clip_fracs),
        "win": won
    }



train_with_league(
    LuxCustomWrapper(LuxAIS3GymEnv(numpy_output=True)),
    total_episodes=10000,
    rollout_num_steps=512,
    snapshot_freq=50,
    num_exploiters=2,
    main_checkpoint="/Users/jayanth/Desktop/AY2425S2/CS3263/Lux-Design-S3/kits/python/checkpoints/main_final_04-20_02-01.pt",
    exploiter_checkpoints=["/Users/jayanth/Desktop/AY2425S2/CS3263/Lux-Design-S3/kits/python/checkpoints/exploiter0_final_04-20_02-01.pt", "/Users/jayanth/Desktop/AY2425S2/CS3263/Lux-Design-S3/kits/python/checkpoints/exploiter1_final_04-20_02-01.pt"]
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjayanth-b[0m ([33may2425s2-cs3263-group-13[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


ENV CFG: {'max_units': 16, 'match_count_per_episode': 5, 'max_steps_in_match': 100, 'map_height': 24, 'map_width': 24, 'num_teams': 2, 'unit_move_cost': 4, 'unit_sap_cost': 43, 'unit_sap_range': 5, 'unit_sensor_range': 4}


  next_done = torch.tensor([
  rb_rewards.append(torch.tensor(reward, device=device))


ENV CFG: {'max_units': 16, 'match_count_per_episode': 5, 'max_steps_in_match': 100, 'map_height': 24, 'map_width': 24, 'num_teams': 2, 'unit_move_cost': 2, 'unit_sap_cost': 34, 'unit_sap_range': 6, 'unit_sensor_range': 2}
ENV CFG: {'max_units': 16, 'match_count_per_episode': 5, 'max_steps_in_match': 100, 'map_height': 24, 'map_width': 24, 'num_teams': 2, 'unit_move_cost': 1, 'unit_sap_cost': 39, 'unit_sap_range': 6, 'unit_sensor_range': 2}
ENV CFG: {'max_units': 16, 'match_count_per_episode': 5, 'max_steps_in_match': 100, 'map_height': 24, 'map_width': 24, 'num_teams': 2, 'unit_move_cost': 4, 'unit_sap_cost': 33, 'unit_sap_range': 3, 'unit_sensor_range': 2}
ENV CFG: {'max_units': 16, 'match_count_per_episode': 5, 'max_steps_in_match': 100, 'map_height': 24, 'map_width': 24, 'num_teams': 2, 'unit_move_cost': 5, 'unit_sap_cost': 46, 'unit_sap_range': 4, 'unit_sensor_range': 4}
ENV CFG: {'max_units': 16, 'match_count_per_episode': 5, 'max_steps_in_match': 100, 'map_height': 24, 'map_width

KeyboardInterrupt: 

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x329ac9c10>> (for post_run_cell), with arguments args (<ExecutionResult object at 153e66bd0, execution_count=6 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 165a171d0, raw_cell="
from datetime import datetime
from pathlib import.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/jayanth/Desktop/AY2425S2/CS3263/Lux-Design-S3/kits/python/ppo.ipynb#W5sZmlsZQ%3D%3D> result=None>,),kwargs {}:


MailboxClosedError: 