# Lux AI Deep Reinforcement Learning Environment Example
See https://github.com/glmcdona/LuxPythonEnvGym for environment project and updates.

This is a python replica of the Lux game engine to speed up training. It reformats the agent problem into making a action decision per-unit for the team.

- \+ v01_global_local_observation
- \+ v02_reward_shaping + mixed_reward + local_reward (x,y,road_level) + global_reward (worker/cart amount of Wood, Coal, Uranium)

In [56]:
# !git clone https://github.com/glmcdona/LuxPythonEnvGym.git
# gym==0.17.0 
# !pip install git+https://github.com/glmcdona/LuxPythonEnvGym.git kaggle-environments -U
# !pip install kaggle-environments -U
# !pip install torchinfo
### !pip install pytest stable_baselines3==1.2.1a2 numpy tensorboard gym==0.19.0 kaggle-environments stable_baselines3 typing_extensions
# !pip install trochinfo
# !pip install stable_baselines3==2.3.0a3 kaggle-environments

In [57]:
# import stable_baselines3 as sb3
# print(sb3.__version__) 
# 2.3.0a3

In [58]:
# -- coding: cp936 --
!CHCP   65001

Active code page: 65001


## Use GPU if available
Note: GPU provides very little speedup. I recommend using a CPU-only notebook usually.

In [59]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


# Define the RL agent logic
Edit this agent logic to implement your own observations, action space, and reward shaping.

In [60]:
#add absolute path of parent dir to path
import sys
import os
import importlib

parent_path = os.path.dirname(os.path.dirname(os.path.abspath("./")))
if parent_path in sys.path:
    print(f"'{parent_path}' already in sys.path", file=sys.stdout)

if parent_path not in sys.path:
    sys.path.append(parent_path)
    print(f"append {parent_path} to sys.path", file=sys.stdout)

# import lib.self_lux_utils as self_lux_utils_file
# importlib.reload(self_lux_utils_file)
# import lib.replay_buffer as replay_buffer_file
# importlib.reload(replay_buffer_file)
# import lib.get_obs_utils as get_obs_utils_file
# importlib.reload(get_obs_utils_file)


'f:\rl' already in sys.path


In [61]:
%%writefile agent_policy.py
import torch
from luxai2021.game.match_controller import ActionSequence
import sys
import time
from functools import partial  # pip install functools

import numpy as np
from gym import spaces
import copy
import random

from luxai2021.env.agent import Agent
from luxai2021.game.actions import *
from luxai2021.game.game_constants import GAME_CONSTANTS
from luxai2021.game.position import Position


# https://codereview.stackexchange.com/questions/28207/finding-the-closest-point-to-a-list-of-points
def get_city_remaining_days(city):
    return city.fuel / city.get_light_upkeep()
def closest_node(node, nodes):
    dist_2 = np.sum((nodes - node) ** 2, axis=1)
    return np.argmin(dist_2)
def furthest_node(node, nodes):
    dist_2 = np.sum((nodes - node) ** 2, axis=1)
    return np.argmax(dist_2)

def smart_transfer_to_nearby(game, team, unit_id, unit, target_type_restriction=None, **kwarg):
    """
    Smart-transfers from the specified unit to a nearby neighbor. Prioritizes any
    nearby carts first, then any worker. Transfers the resource type which the unit
    has most of. Picks which cart/worker based on choosing a target that is most-full
    but able to take the most amount of resources.

    Args:
        team ([type]): [description]
        unit_id ([type]): [description]

    Returns:
        Action: Returns a TransferAction object, even if the request is an invalid
                transfer. Use TransferAction.is_valid() to check validity.
    """

    # Calculate how much resources could at-most be transferred
    resource_type = None
    resource_amount = 0
    target_unit = None

    if unit != None:
        for type, amount in unit.cargo.items():
            if amount > resource_amount:
                resource_type = type
                resource_amount = amount

        # Find the best nearby unit to transfer to
        unit_cell = game.map.get_cell_by_pos(unit.pos)
        adjacent_cells = game.map.get_adjacent_cells(unit_cell)

        
        for c in adjacent_cells:
            for id, u in c.units.items():
                # Apply the unit type target restriction
                if target_type_restriction == None or u.type == target_type_restriction:
                    if u.team == team:
                        # This unit belongs to our team, set it as the winning transfer target
                        # if it's the best match.
                        if target_unit is None:
                            target_unit = u
                        else:
                            # Compare this unit to the existing target
                            if target_unit.type == u.type:
                                # Transfer to the target with the least capacity, but can accept
                                # all of our resources
                                if( u.get_cargo_space_left() >= resource_amount and 
                                    target_unit.get_cargo_space_left() >= resource_amount ):
                                    # Both units can accept all our resources. Prioritize one that is most-full.
                                    if u.get_cargo_space_left() < target_unit.get_cargo_space_left():
                                        # This new target it better, it has less space left and can take all our
                                        # resources
                                        target_unit = u
                                    
                                elif( target_unit.get_cargo_space_left() >= resource_amount ):
                                    # Don't change targets. Current one is best since it can take all
                                    # the resources, but new target can't.
                                    pass
                                    
                                elif( u.get_cargo_space_left() > target_unit.get_cargo_space_left() ):
                                    # Change targets, because neither target can accept all our resources and 
                                    # this target can take more resources.
                                    target_unit = u
                            elif u.type == Constants.UNIT_TYPES.CART:
                                # Transfer to this cart instead of the current worker target
                                target_unit = u
    
    # Build the transfer action request
    target_unit_id = None
    if target_unit is not None:
        target_unit_id = target_unit.id

        # Update the transfer amount based on the room of the target
        if target_unit.get_cargo_space_left() < resource_amount:
            resource_amount = target_unit.get_cargo_space_left()
    
    return TransferAction(team, unit_id, target_unit_id, resource_type, resource_amount)

########################################################################################################################
# This is the Agent that you need to design for the competition
########################################################################################################################
class AgentPolicy(Agent):
    def __init__(self, mode="train", model=None) -> None:
        """
        Arguments:
            mode: "train" or "inference", which controls if this agent is for training or not.
            model: The pretrained model, or if None it will operate in training mode.
        """
        super().__init__()
        self.model = model
        self.mode = mode
        
        self.stats = None
        self.stats_last_game = None

        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions:
        self.actionSpaceMapUnits = [
            partial(MoveAction, direction=Constants.DIRECTIONS.CENTER),  # This is the do-nothing action
            partial(MoveAction, direction=Constants.DIRECTIONS.NORTH),
            partial(MoveAction, direction=Constants.DIRECTIONS.WEST),
            partial(MoveAction, direction=Constants.DIRECTIONS.SOUTH),
            partial(MoveAction, direction=Constants.DIRECTIONS.EAST),
            smart_transfer_to_nearby, # Transfer to nearby
            SpawnCityAction,
            #PillageAction,
        ]
        self.actionSpaceMapCities = [
            SpawnWorkerAction,
            SpawnCartAction,
            ResearchAction,
        ]

        self.action_space = spaces.Discrete(max(len(self.actionSpaceMapUnits), len(self.actionSpaceMapCities)))
        

        # Observation space: (Basic minimum for a miner agent)
        # Object:
        #   1x is worker
        #   1x is cart
        #   1x is citytile
        #
        #   5x direction_nearest_wood
        #   1x distance_nearest_wood
        #   1x amount
        #
        #   5x direction_nearest_coal
        #   1x distance_nearest_coal
        #   1x amount
        #
        #   5x direction_nearest_uranium
        #   1x distance_nearest_uranium
        #   1x amount
        #
        #   5x direction_nearest_city
        #   1x distance_nearest_city
        #   1x amount of fuel
        #
        #   28x (the same as above, but direction, distance, and amount to the furthest of each)
        #
        #   5x direction_nearest_worker
        #   1x distance_nearest_worker
        #   1x amount of cargo
        # Unit:
        #   1x cargo size
        # State:
        #   1x is night
        #   1x percent of game done
        #   2x citytile counts [cur player, opponent]
        #   2x worker counts [cur player, opponent]
        #   2x cart counts [cur player, opponent]
        #   1x research points [cur player]
        #   1x researched coal [cur player]
        #   1x researched uranium [cur player]
        # self.observation_shape = (3 + 7 * 5 * 2 + 1 + 1 + 1 + 2 + 2 + 2 + 3,)
        self.observation_shape_map = (18,32,32)
        self.observation_shape_player_status = (85+3,)


        ### YCHUANG ADDED START
        self.observation_space = spaces.Dict({
            'map': spaces.Box(low=0, high=255, shape=self.observation_shape_map, dtype=np.uint8),  # Example for a 64x64 RGB image
            'vector': spaces.Box(low=-np.inf, high=np.inf, shape=self.observation_shape_player_status, dtype=np.float32)  # Example for a vector of size 10
        })

        # self.observation_shape = self.observation_shape_map
        # self.observation_space = spaces.Box(low=0, high=1, shape=
        # self.observation_shape, dtype=np.float16)

        self.object_nodes = {}

        self.total_turns = 0
        ### YCHUANG ADDED END

    def get_agent_type(self):
        """
        Returns the type of agent. Use AGENT for inference, and LEARNING for training a model.
        """
        if self.mode == "train":
            return Constants.AGENT_TYPE.LEARNING
        else:
            return Constants.AGENT_TYPE.AGENT


    '''''
    def get_observation(self, game, unit, city_tile, team, is_new_turn):
        map_height = game.map.height  # Actual map height
        map_width = game.map.width    # Actual map width
        map_info = torch.full((12, 32, 32), -1.0, dtype=torch.float)  # Initialize with -1 for padding
        
        # Determine the start index for padding to keep original positions
        # start_y = (32 - map_height) // 2
        # start_x = (32 - map_width) // 2
        map_info[:, :map_height, :map_width] = 0
        
        # Process resources
        for i, res_list in enumerate(game.map.resources_by_type.values()):
            for res in res_list:
                # Adjust indices with padding offsets
                if i == 0:
                    map_info[0][res.pos.y][res.pos.x] = res.resource.amount / 800
                elif i == 1:
                    map_info[1][res.pos.y][res.pos.x] = res.resource.amount / 425
                elif i == 2:
                    map_info[2][res.pos.y][res.pos.x] = res.resource.amount / 350
        
        # Process roads
        for cll in game.cells_with_roads:
            map_info[11][cll.pos.y][cll.pos.x] = cll.get_road()/6

        # Process city tiles
        for ct in game.cities.values():
            ct_remain_day = get_city_remaining_days(ct)  # This function needs to be defined
            for ctile in ct.city_cells:
                # Adjust indices with padding offsets and add city tile information
                map_info[3 + ct.team][ctile.pos.y][ctile.pos.x] = 1
                map_info[9 + ct.team][ctile.pos.y][ctile.pos.x] = ct_remain_day / 200
                map_info[11][ctile.pos.y][ctile.pos.x] = ctile.get_road() / 6
        
        # Process units
        opponent = 1 if team == 0 else 0
        for uNit in game.state['teamStates'][team]["units"].values():
            if uNit.type == Constants.UNIT_TYPES.WORKER:
                map_info[5][uNit.pos.y][uNit.pos.x] = 1
            elif uNit.type == Constants.UNIT_TYPES.CART:
                map_info[6][uNit.pos.y][uNit.pos.x] = 1

        for uNit in game.state['teamStates'][opponent]["units"].values():
            if uNit.type == Constants.UNIT_TYPES.WORKER:
                map_info[7][uNit.pos.y][uNit.pos.x] = 1
            elif uNit.type == Constants.UNIT_TYPES.CART:
                map_info[8][uNit.pos.y][uNit.pos.x] = 1

        # Vector
        vector = torch.zeros(10,dtype=torch.float)
        vector[0] = 1 if not city_tile == None else 0
        if not unit == None:
            vector[1] = 1 if unit.type == Constants.UNIT_TYPES.CART else 0
            vector[2] = 1 if unit.type == Constants.UNIT_TYPES.WORKER else 0
        rem_day = []
        for ct in game.cities.values():
            if ct.team == team: rem_day.append(get_city_remaining_days(ct))
        vector[3] = -1 if rem_day == [] else min(rem_day)/200
        vector[4] = 1 if game.state["teamStates"][team]["researched"]["coal"] else 0
        vector[5] = 1 if game.state["teamStates"][team]["researched"]["uranium"] else 0
        if not unit == None:
            if unit.type == Constants.UNIT_TYPES.WORKER:
                vector[6] = unit.cooldown/2
            elif unit.type == Constants.UNIT_TYPES.CART:
                vector[6] = unit.cooldown/3
        elif not city_tile == None:
            vector[6] = city_tile.cooldown/10
        if not unit == None:
            vector[7] = unit.pos.x
            vector[8] = unit.pos.y
        if not city_tile == None:
            vector[7] = city_tile.pos.x
            vector[8] = city_tile.pos.y

        return {'map': map_info, 'vector': vector}

        '''''
    def get_observation(self, game, unit, city_tile, team, is_new_turn):
        """
        Implements getting a observation from the current game for this unit or city
        """
        observation_index = 0
        if is_new_turn:
            # It's a new turn this event. This flag is set True for only the first observation from each turn.
            # Update any per-turn fixed observation space that doesn't change per unit/city controlled.

            # Build a list of object nodes by type for quick distance-searches
            self.object_nodes = {}

            # Add resources
            for cell in game.map.resources:
                if cell.resource.type not in self.object_nodes:
                    self.object_nodes[cell.resource.type] = np.array([[cell.pos.x, cell.pos.y]])
                else:
                    self.object_nodes[cell.resource.type] = np.concatenate(
                        (
                            self.object_nodes[cell.resource.type],
                            [[cell.pos.x, cell.pos.y]]
                        ),
                        axis=0
                    )

            # Add your own and opponent units
            for t in [team, (team + 1) % 2]:
                for u in game.state["teamStates"][team]["units"].values():
                    key = str(u.type)
                    if t != team:
                        key = str(u.type) + "_opponent"

                    if key not in self.object_nodes:
                        self.object_nodes[key] = np.array([[u.pos.x, u.pos.y]])
                    else:
                        self.object_nodes[key] = np.concatenate(
                            (
                                self.object_nodes[key],
                                [[u.pos.x, u.pos.y]]
                            )
                            , axis=0
                        )

            # print(team)

            # Add your own and opponent cities
            for city in game.cities.values():
                for cells in city.city_cells:
                    key = "city"
                    if city.team != team:
                        key = "city_opponent"

                    if key not in self.object_nodes:
                        self.object_nodes[key] = np.array([[cells.pos.x, cells.pos.y]])
                    else:
                        self.object_nodes[key] = np.concatenate(
                            (
                                self.object_nodes[key],
                                [[cells.pos.x, cells.pos.y]]
                            )
                            , axis=0
                        )

        # Observation space: (Basic minimum for a miner agent)
        # Object:
        #   1x is worker
        #   1x is cart
        #   1x is citytile
        #
        #   YCHUANG ADDED START
        #   1x x coordinate / 32
        #   1x y coordinate / 32
        #   1x road level at (x,y) / 6
        #   YCHUANG ADDED END
        #
        #   5x direction_nearest_wood
        #   1x distance_nearest_wood
        #   1x amount
        #
        #   5x direction_nearest_coal
        #   1x distance_nearest_coal
        #   1x amount
        #
        #   5x direction_nearest_uranium
        #   1x distance_nearest_uranium
        #   1x amount
        #
        #   5x direction_nearest_city
        #   1x distance_nearest_city
        #   1x amount of fuel
        #
        #   5x direction_nearest_worker
        #   1x distance_nearest_worker
        #   1x amount of cargo
        #
        #   28x (the same as above, but direction, distance, and amount to the furthest of each)
        #
        # Unit:
        #   1x cargo size
        # State:
        #   1x is night
        #   1x percent of game done
        #   2x citytile counts [cur player, opponent]
        #   2x worker counts [cur player, opponent]
        #   2x cart counts [cur player, opponent]
        #   1x research points [cur player]
        #   1x researched coal [cur player]
        #   1x researched uranium [cur player]
        obs = np.zeros(self.observation_shape_player_status)
        
        # Update the type of this object
        #   1x is worker
        #   1x is cart
        #   1x is citytile
        observation_index = 0
        if unit is not None:
            if unit.type == Constants.UNIT_TYPES.WORKER:
                obs[observation_index] = 1.0 # Worker
            else:
                obs[observation_index+1] = 1.0 # Cart
        if city_tile is not None:
            obs[observation_index+2] = 1.0 # CityTile
        observation_index += 3
        
        pos = None
        if unit is not None:
            pos = unit.pos
        else:
            pos = city_tile.pos

        # YCHUANG ADDED START
        if pos is not None:
            obs[observation_index] = pos.x / 32
            obs[observation_index+1] = pos.y / 32
            observation_index += 2
            obs[observation_index] = game.map.get_cell_by_pos(pos).get_road() / 6
            observation_index += 1
        # YCHUANG ADDED END

        if pos is None:
            observation_index += 7 * 5 * 2
        else:
            # Encode the direction to the nearest objects
            #   5x direction_nearest
            #   1x distance
            for distance_function in [closest_node, furthest_node]:
                for key in [
                    Constants.RESOURCE_TYPES.WOOD,
                    Constants.RESOURCE_TYPES.COAL,
                    Constants.RESOURCE_TYPES.URANIUM,
                    "city",
                    str(Constants.UNIT_TYPES.WORKER)]:
                    # Process the direction to and distance to this object type

                    # Encode the direction to the nearest object (excluding itself)
                    #   5x direction
                    #   1x distance
                    if key in self.object_nodes:
                        if (
                                (key == "city" and city_tile is not None) or
                                (unit is not None and str(unit.type) == key and len(game.map.get_cell_by_pos(unit.pos).units) <= 1 )
                        ):
                            # Filter out the current unit from the closest-search
                            closest_index = closest_node((pos.x, pos.y), self.object_nodes[key])
                            filtered_nodes = np.delete(self.object_nodes[key], closest_index, axis=0)
                        else:
                            filtered_nodes = self.object_nodes[key]

                        if len(filtered_nodes) == 0:
                            # No other object of this type
                            obs[observation_index + 5] = 1.0
                        else:
                            # There is another object of this type
                            closest_index = distance_function((pos.x, pos.y), filtered_nodes)

                            if closest_index is not None and closest_index >= 0:
                                closest = filtered_nodes[closest_index]
                                closest_position = Position(closest[0], closest[1])
                                direction = pos.direction_to(closest_position)
                                mapping = {
                                    Constants.DIRECTIONS.CENTER: 0,
                                    Constants.DIRECTIONS.NORTH: 1,
                                    Constants.DIRECTIONS.WEST: 2,
                                    Constants.DIRECTIONS.SOUTH: 3,
                                    Constants.DIRECTIONS.EAST: 4,
                                }
                                obs[observation_index + mapping[direction]] = 1.0  # One-hot encoding direction

                                # 0 to 1 distance
                                distance = pos.distance_to(closest_position)
                                obs[observation_index + 5] = min(distance / 20.0, 1.0)

                                # 0 to 1 value (amount of resource, cargo for unit, or fuel for city)
                                if key == "city":
                                    # City fuel as % of upkeep for 200 turns
                                    c = game.cities[game.map.get_cell_by_pos(closest_position).city_tile.city_id]
                                    obs[observation_index + 6] = min(
                                        c.fuel / (c.get_light_upkeep() * 200.0),
                                        1.0
                                    )
                                elif key in [Constants.RESOURCE_TYPES.WOOD, Constants.RESOURCE_TYPES.COAL,
                                             Constants.RESOURCE_TYPES.URANIUM]:
                                    # Resource amount
                                    obs[observation_index + 6] = min(
                                        game.map.get_cell_by_pos(closest_position).resource.amount / 500,
                                        1.0
                                    )
                                else:
                                    # Unit cargo
                                    obs[observation_index + 6] = min(
                                        next(iter(game.map.get_cell_by_pos(
                                            closest_position).units.values())).get_cargo_space_left() / 100,
                                        1.0
                                    )

                    observation_index += 7

        if unit is not None:
            # Encode the cargo space
            #   1x cargo size
            obs[observation_index] = unit.get_cargo_space_left() / GAME_CONSTANTS["PARAMETERS"]["RESOURCE_CAPACITY"][
                "WORKER"]
            observation_index += 1
        else:
            observation_index += 1

        # Game state observations

        #   1x is night
        obs[observation_index] = game.is_night()
        observation_index += 1

        #   1x percent of game done
        obs[observation_index] = game.state["turn"] / GAME_CONSTANTS["PARAMETERS"]["MAX_DAYS"]
        observation_index += 1

        #   2x citytile counts [cur player, opponent]
        #   2x worker counts [cur player, opponent]
        #   2x cart counts [cur player, opponent]
        max_count = 30
        for key in ["city", str(Constants.UNIT_TYPES.WORKER), str(Constants.UNIT_TYPES.CART)]:
            if key in self.object_nodes:
                obs[observation_index] = len(self.object_nodes[key]) / max_count
            if (key + "_opponent") in self.object_nodes:
                obs[observation_index + 1] = len(self.object_nodes[(key + "_opponent")]) / max_count
            observation_index += 2

        #   1x research points [cur player]
        #   1x researched coal [cur player]
        #   1x researched uranium [cur player]
        obs[observation_index] = game.state["teamStates"][team]["researchPoints"] / 200.0
        obs[observation_index+1] = float(game.state["teamStates"][team]["researched"]["coal"])
        obs[observation_index+2] = float(game.state["teamStates"][team]["researched"]["uranium"])


        map_height = game.map.height  # Actual map height
        map_width = game.map.width    # Actual map width
        map_info = torch.full((18, 32, 32), -1.0, dtype=torch.float)  # Initialize with -1 for padding
        
        # Determine the start index for padding to keep original positions
        # start_y = (32 - map_height) // 2
        # start_x = (32 - map_width) // 2
        map_info[:, :map_height, :map_width] = 0
        
        # Process resources
        for i, res_list in enumerate(game.map.resources_by_type.values()):
            for res in res_list:
                # Adjust indices with padding offsets
                if i == 0:
                    map_info[0][res.pos.y][res.pos.x] = res.resource.amount / 800
                elif i == 1:
                    map_info[1][res.pos.y][res.pos.x] = res.resource.amount / 425
                elif i == 2:
                    map_info[2][res.pos.y][res.pos.x] = res.resource.amount / 350
        
        # Process city tiles
        for ct in game.cities.values():
            ct_remain_day = get_city_remaining_days(ct)  # This function needs to be defined
            for ctile in ct.city_cells:
                team_offset = 0 if ct.team == team else 1
                # Adjust indices with padding offsets and add city tile information
                map_info[3 + team_offset][ctile.pos.y][ctile.pos.x] = 1
                map_info[5 + team_offset][ctile.pos.y][ctile.pos.x] = ct_remain_day / 200
            # TODO: road of all the map
            map_info[7][ctile.pos.y][ctile.pos.x] = ctile.get_road() / 6

        # Process roads
        for cll in game.cells_with_roads:
            map_info[7][cll.pos.y][cll.pos.x] = cll.get_road()/6

        # Process units
        # opponent = 1 if team == 0 else 0
        opponent = (team + 1) % 2
        for uNit in game.state['teamStates'][team]["units"].values():
            if uNit.type == Constants.UNIT_TYPES.WORKER:
                map_info[8][uNit.pos.y][uNit.pos.x] = 1
                if uNit.cargo != None:
                    map_info[9][uNit.pos.y][uNit.pos.x] = uNit.cargo["wood"] / 100
                    map_info[10][uNit.pos.y][uNit.pos.x] = uNit.cargo["coal"] / 100
                    map_info[11][uNit.pos.y][uNit.pos.x] = uNit.cargo["uranium"] / 100
                
            elif uNit.type == Constants.UNIT_TYPES.CART:
                map_info[12][uNit.pos.y][uNit.pos.x] = 1
                if uNit.cargo != None:
                    map_info[13][uNit.pos.y][uNit.pos.x] = uNit.cargo["wood"] / 2000
                    map_info[14][uNit.pos.y][uNit.pos.x] = uNit.cargo["coal"] / 2000
                    map_info[15][uNit.pos.y][uNit.pos.x] = uNit.cargo["uranium"] / 2000

        for uNit in game.state['teamStates'][opponent]["units"].values():
            if uNit.type == Constants.UNIT_TYPES.WORKER:
                map_info[16][uNit.pos.y][uNit.pos.x] = 1
            elif uNit.type == Constants.UNIT_TYPES.CART:
                map_info[17][uNit.pos.y][uNit.pos.x] = 1
        
        return {'map': map_info, 'vector': obs}

    
    def action_code_to_action(self, action_code, game, unit=None, city_tile=None, team=None):
        """
        Takes an action in the environment according to actionCode:
            actionCode: Index of action to take into the action array.
        Returns: An action.
        """
        # Map actionCode index into to a constructed Action object
        try:
            x = None
            y = None
            if city_tile is not None:
                x = city_tile.pos.x
                y = city_tile.pos.y
            elif unit is not None:
                x = unit.pos.x
                y = unit.pos.y
            
            if city_tile != None:
                action =  self.actionSpaceMapCities[action_code%len(self.actionSpaceMapCities)](
                    game=game,
                    unit_id=unit.id if unit else None,
                    unit=unit,
                    city_id=city_tile.city_id if city_tile else None,
                    citytile=city_tile,
                    team=team,
                    x=x,
                    y=y
                )

                # If the city action is invalid, default to research action automatically
                if not action.is_valid(game, actions_validated=[]):
                    action = ResearchAction(
                        game=game,
                        unit_id=unit.id if unit else None,
                        unit=unit,
                        city_id=city_tile.city_id if city_tile else None,
                        citytile=city_tile,
                        team=team,
                        x=x,
                        y=y
                    )
            else:
                action =  self.actionSpaceMapUnits[action_code%len(self.actionSpaceMapUnits)](
                    game=game,
                    unit_id=unit.id if unit else None,
                    unit=unit,
                    city_id=city_tile.city_id if city_tile else None,
                    citytile=city_tile,
                    team=team,
                    x=x,
                    y=y
                )
            
            return action
        except Exception as e:
            # Not a valid action
            print(e)
            return None

    def take_action(self, action_code, game, unit=None, city_tile=None, team=None):
        """
        Takes an action in the environment according to actionCode:
            actionCode: Index of action to take into the action array.
        """
        action = self.action_code_to_action(action_code, game, unit, city_tile, team)
        self.match_controller.take_action(action)
    
    def game_start(self, game):
        """
        This funciton is called at the start of each game. Use this to
        reset and initialize per game. Note that self.team may have
        been changed since last game. The game map has been created
        and starting units placed.

        Args:
            game ([type]): Game.
        """
        self.last_generated_fuel = game.stats["teamStats"][self.team]["fuelGenerated"]
        self.last_resources_collected = copy.deepcopy(game.stats["teamStats"][self.team]["resourcesCollected"])
        if self.stats != None:
            self.stats_last_game =  self.stats
        self.stats = {
            "rew/result": 0,
            "rew/r_total": 0,
            "rew/r_wood": 0,
            "rew/r_coal": 0,
            "rew/r_uranium": 0,
            "rew/r_research": 0,
            "rew/r_city_tiles_end": 0,
            "rew/r_fuel_collected":0,
            "rew/r_units":0,
            "rew/r_city_tiles":0,
            "game/turns": 0,
            "game/research": 0,
            "game/unit_count": 0,
            "game/cart_count": 0,
            "game/city_count": 0,
            "game/city_tiles": 0,
            "game/wood_rate_mined": 0,
            "game/coal_rate_mined": 0,
            "game/uranium_rate_mined": 0,
        }
        self.is_last_turn = False

        # Calculate starting map resources
        type_map = {
            Constants.RESOURCE_TYPES.WOOD: "WOOD",
            Constants.RESOURCE_TYPES.COAL: "COAL",
            Constants.RESOURCE_TYPES.URANIUM: "URANIUM",
        }

        self.fuel_collected_last = 0
        self.fuel_start = {}
        self.fuel_last = {}
        for type, type_upper in type_map.items():
            self.fuel_start[type] = 0
            self.fuel_last[type] = 0
            for c in game.map.resources_by_type[type]:
                self.fuel_start[type] += c.resource.amount * game.configs["parameters"]["RESOURCE_TO_FUEL_RATE"][type_upper]

        self.research_last = 0
        self.units_last = 0
        self.city_tiles_last = 0

        # YCHUANG ADDED START
        self.episode_start = True
        self.total_turns += 1
        # YCHUANG ADDED END


    def get_reward(self, game, is_game_finished, is_new_turn, is_game_error):
        """
        Returns the reward function for this step of the game.
        """
        if is_game_error:
            # Game environment step failed, assign a game lost reward to not incentivise this
            print("Game failed due to error")
            return -1.0

        if not is_new_turn and not is_game_finished:
            # Only apply rewards at the start of each turn
            return 0

        # Get some basic stats
        unit_count = len(game.state["teamStates"][self.team % 2]["units"])
        cart_count = 0
        for id, u in game.state["teamStates"][self.team % 2]["units"].items():
            if u.type == Constants.UNIT_TYPES.CART:
                cart_count += 1

        unit_count_opponent = len(game.state["teamStates"][(self.team + 1) % 2]["units"])
        research = min(game.state["teamStates"][self.team]["researchPoints"], 200.0) # Cap research points at 200
        city_count = 0
        city_count_opponent = 0
        city_tile_count = 0
        city_tile_count_opponent = 0
        for city in game.cities.values():
            if city.team == self.team:
                city_count += 1
            else:
                city_count_opponent += 1

            for cell in city.city_cells:
                if city.team == self.team:
                    city_tile_count += 1
                else:
                    city_tile_count_opponent += 1
        
        # Basic stats
        self.stats["game/research"] = research
        self.stats["game/city_tiles"] = city_tile_count
        self.stats["game/city_count"] = city_count
        self.stats["game/unit_count"] = unit_count
        self.stats["game/cart_count"] = cart_count
        self.stats["game/turns"] = game.state["turn"]

        rewards = {}

        # Give up to 1.0 reward for each resource based on % of total mined.
        type_map = {
            Constants.RESOURCE_TYPES.WOOD: "WOOD",
            Constants.RESOURCE_TYPES.COAL: "COAL",
            Constants.RESOURCE_TYPES.URANIUM: "URANIUM",
        }
        fuel_now = {}
        for type, type_upper in type_map.items():
            fuel_now = game.stats["teamStats"][self.team]["resourcesCollected"][type] * game.configs["parameters"]["RESOURCE_TO_FUEL_RATE"][type_upper]
            rewards["rew/r_%s" % type] = (fuel_now - self.fuel_last[type]) / self.fuel_start[type]
            self.stats["game/%s_rate_mined" % type] = fuel_now / self.fuel_start[type]
            self.fuel_last[type] = fuel_now
        
        # Give more incentive for coal and uranium
        rewards["rew/r_%s" % Constants.RESOURCE_TYPES.COAL] *= 2
        rewards["rew/r_%s" % Constants.RESOURCE_TYPES.URANIUM] *= 4
        
        # Give a reward based on amount of fuel collected. 1.0 reward for each 20K fuel gathered.
        fuel_collected = game.stats["teamStats"][self.team]["fuelGenerated"]
        rewards["rew/r_fuel_collected"] = ( (fuel_collected - self.fuel_collected_last) / 20000 )
        self.fuel_collected_last = fuel_collected

        # Give a reward for unit creation/death. 0.05 reward per unit.
        rewards["rew/r_units"] = (unit_count - self.units_last) * 0.05
        self.units_last = unit_count

        # Give a reward for unit creation/death. 0.1 reward per city.
        rewards["rew/r_city_tiles"] = (city_tile_count - self.city_tiles_last) * 0.1
        self.city_tiles_last = city_tile_count

        # Tiny reward for research to help. Up to 0.5 reward for this.
        rewards["rew/r_research"] = (research - self.research_last) / (200 * 2)
        self.research_last = research
        
        # Give a reward up to around 50.0 based on number of city tiles at the end of the game
        rewards["rew/r_city_tiles_end"] = 0
        if is_game_finished:
            self.is_last_turn = True
            rewards["rew/r_city_tiles_end"] = city_tile_count
        

        if is_game_finished:
            if self.team > 1:
                raise Exception('Team should not exceed 2')
            if game.get_winning_team() == self.team:
                self.stats["rew/result"] = 1.
            else:
                self.stats["rew/result"] = -1.
                
        reward = 0
        for name, value in rewards.items():
            self.stats[name] += value
            reward += value
        self.stats["rew/r_total"] += reward
        
        # self.stats["rew/real_total"] = 0
        # self.stats["rew/real_total"] = rewards["rew/r_city_tiles"]*2
        # if is_game_finished:
        #     self.stats["rew/real_total"] += rewards["rew/r_city_tiles_end"]/2.5
        #     self.stats["rew/real_total"] += self.stats["rew/result"]

        # Print the final game stats sometimes
        if is_game_finished and random.random() <= 0.15:
            stats_string = []
            for key, value in self.stats.items():
                stats_string.append("%s=%.2f" % (key, value))
            print(",".join(stats_string))

        
        return                                      self.stats["rew/result"]
        # return self.stats["rew/real_total"]

    # def get_reward(self, game, is_game_finished, is_new_turn, is_game_error):
    #     """
    #     Returns the reward function for this step of the game.
    #     """
    #     if is_game_error:
    #         # Game environment step failed, assign a game lost reward to not incentivise this
    #         print("Game failed due to error")
    #         return -1.0

    #     if not is_new_turn and not is_game_finished:
    #         # Only apply rewards at the start of each turn
    #         return 0

    #     # Get some basic stats
    #     unit_count = len(game.state["teamStates"][self.team % 2]["units"])
    #     cart_count = 0
    #     for id, u in game.state["teamStates"][self.team % 2]["units"].items():
    #         if u.type == Constants.UNIT_TYPES.CART:
    #             cart_count += 1

    #     unit_count_opponent = len(game.state["teamStates"][(self.team + 1) % 2]["units"])
    #     research = min(game.state["teamStates"][self.team]["researchPoints"], 200.0) # Cap research points at 200
    #     city_count = 0
    #     city_count_opponent = 0
    #     city_tile_count = 0
    #     city_tile_count_opponent = 0
    #     for city in game.cities.values():
    #         if city.team == self.team:
    #             city_count += 1
    #         else:
    #             city_count_opponent += 1

    #         for cell in city.city_cells:
    #             if city.team == self.team:
    #                 city_tile_count += 1
    #             else:
    #                 city_tile_count_opponent += 1
        
    #     # Basic stats
    #     self.stats["game/research"] = research
    #     self.stats["game/city_tiles"] = city_tile_count
    #     self.stats["game/city_count"] = city_count
    #     self.stats["game/unit_count"] = unit_count
    #     self.stats["game/cart_count"] = cart_count
    #     self.stats["game/turns"] = game.state["turn"]


    #     rewards = {}

    #     # Give up to 1.0 reward for each resource based on % of total mined.
    #     type_map = {
    #         Constants.RESOURCE_TYPES.WOOD: "WOOD",
    #         Constants.RESOURCE_TYPES.COAL: "COAL",
    #         Constants.RESOURCE_TYPES.URANIUM: "URANIUM",
    #     }
    #     fuel_now = {}
    #     for type, type_upper in type_map.items():
    #         fuel_now = game.stats["teamStats"][self.team]["resourcesCollected"][type] * game.configs["parameters"]["RESOURCE_TO_FUEL_RATE"][type_upper]
    #         rewards["rew/r_%s" % type] = (fuel_now - self.fuel_last[type]) / self.fuel_start[type]
    #         self.stats["game/%s_rate_mined" % type] = fuel_now / self.fuel_start[type]
    #         self.fuel_last[type] = fuel_now
        
    #     # Give more incentive for coal and uranium
    #     rewards["rew/r_%s" % Constants.RESOURCE_TYPES.COAL] *= 2
    #     rewards["rew/r_%s" % Constants.RESOURCE_TYPES.URANIUM] *= 4
        
    #     # Give a reward based on amount of fuel collected. 1.0 reward for each 20K fuel gathered.
    #     fuel_collected = game.stats["teamStats"][self.team]["fuelGenerated"]
    #     rewards["rew/r_fuel_collected"] = ( (fuel_collected - self.fuel_collected_last) / 20000 )
    #     self.fuel_collected_last = fuel_collected

    #     # Give a reward for unit creation/death. 0.05 reward per unit.
    #     rewards["rew/r_units"] = (unit_count - self.units_last) * 0.05
    #     self.units_last = unit_count

    #     # Give a reward for unit creation/death. 0.1 reward per city.
    #     city_tile_count_prev = self.city_tiles_last
    #     rewards["rew/r_city_tiles"] = (city_tile_count - self.city_tiles_last) * 0.1
    #     self.city_tiles_last = city_tile_count

    #     # Tiny reward for research to help. Up to 0.5 reward for this.
    #     research_prev = self.research_last
    #     rewards["rew/r_research"] = (research - self.research_last) / (200 * 2)
    #     self.research_last = research
        
    #     # Give a reward up to around 50.0 based on number of city tiles at the end of the game
    #     rewards["rew/r_city_tiles_end"] = 0
    #     if is_game_finished:
    #         self.is_last_turn = True
    #         rewards["rew/r_city_tiles_end"] = city_tile_count
       
        
    #     # Update the stats and total reward
    #     reward_virtual = 0
    #     for name, value in rewards.items():
    #         self.stats[name] += value
    #         reward_virtual += value
    #     self.stats["rew/r_total"] += reward_virtual


    #     # Print the final game stats sometimes
    #     if is_game_finished and random.random() <= 0.15:
    #         stats_string = []
    #         for key, value in self.stats.items():
    #             stats_string.append("%s=%.2f" % (key, value))
    #         print(",".join(stats_string))

    #     # reward for agent
    #     reward = 0
    #     if self.episode_start == True:
    #         self.episode_start = False
    #         reward = 0
    #     else:
    #         delta_city_tile_count = city_tile_count - city_tile_count_prev
    #         if(delta_city_tile_count > 0):
    #             reward += 0.1
    #         else:
    #             reward = 0

    #         delta_research = research - research_prev
    #         if(delta_research > 0):
    #             reward += 0.1
    #         else:
    #             reward = 0

    #         if(is_game_finished):
    #             if(self.is_last_turn):
    #                 reward += 0.1
    #             else:
    #                 reward = 0

    #             info_string = ""
    #             info_string += f"game fisnished: {is_game_finished}, is_last_turn: {self.is_last_turn} | "
    #             if(city_tile_count > city_tile_count_opponent):
    #                 reward += 1
    #                 info_string += f"city_tile_count > city_tile_count_opponent | "
    #             elif(city_tile_count == city_tile_count_opponent):
    #                 reward += 0.5
    #                 info_string += f"city_tile_count == city_tile_count_opponent | "
    #             else:
    #                 reward -= 1
    #                 info_string += f"city_tile_count < city_tile_count_opponent | "
    #             info_string += f"reward: {reward} | "
    #             info_string += f"city_tile_count: {city_tile_count}, city_tile_count_opponent: {city_tile_count_opponent}"
                
    #             if random.random() <= 0.005:
    #                 print(info_string)

    #     # if self.total_turns >= 20:
    #     reward_mixed_up = np.exp(-self.total_turns/10000)/np.exp(0) * reward_virtual + reward
    #     if is_game_finished and random.random() <= 0.001:
    #                 print(f"unit_count = {unit_count}, cart_count = {cart_count}, city_count = {city_count}, city_tile_count = {city_tile_count}, reward_mixed_up = {reward_mixed_up}")
    #     # return reward
    #     return reward_mixed_up


    # def get_reward(self, game, is_game_finished, is_new_turn, is_game_error):
    #     """
    #     Returns the reward function for this step of the game.
    #     """
    #     if is_game_error:
    #         # Game environment step failed, assign a game lost reward to not incentivise this
    #         print("Game failed due to error")
    #         return -1.0

    #     if not is_new_turn and not is_game_finished:
    #         # Only apply rewards at the start of each turn
    #         return 0

    #     # Get some basic stats
    #     unit_count = len(game.state["teamStates"][self.team % 2]["units"])
    #     cart_count = 0
    #     for id, u in game.state["teamStates"][self.team % 2]["units"].items():
    #         if u.type == Constants.UNIT_TYPES.CART:
    #             cart_count += 1

    #     unit_count_opponent = len(game.state["teamStates"][(self.team + 1) % 2]["units"])
    #     research = min(game.state["teamStates"][self.team]["researchPoints"], 200.0) # Cap research points at 200
    #     city_count = 0
    #     city_count_opponent = 0
    #     city_tile_count = 0
    #     city_tile_count_opponent = 0
    #     for city in game.cities.values():
    #         if city.team == self.team:
    #             city_count += 1
    #         else:
    #             city_count_opponent += 1

    #         for cell in city.city_cells:
    #             if city.team == self.team:
    #                 city_tile_count += 1
    #             else:
    #                 city_tile_count_opponent += 1
        
    #     # Basic stats
    #     self.stats["game/research"] = research
    #     self.stats["game/city_tiles"] = city_tile_count
    #     self.stats["game/city_count"] = city_count
    #     self.stats["game/unit_count"] = unit_count
    #     self.stats["game/cart_count"] = cart_count
    #     self.stats["game/turns"] = game.state["turn"]

    #     rewards = {}

    #     # Give up to 1.0 reward for each resource based on % of total mined.
    #     type_map = {
    #         Constants.RESOURCE_TYPES.WOOD: "WOOD",
    #         Constants.RESOURCE_TYPES.COAL: "COAL",
    #         Constants.RESOURCE_TYPES.URANIUM: "URANIUM",
    #     }
    #     fuel_now = {}
    #     for type, type_upper in type_map.items():
    #         fuel_now = game.stats["teamStats"][self.team]["resourcesCollected"][type] * game.configs["parameters"]["RESOURCE_TO_FUEL_RATE"][type_upper]
    #         rewards["rew/r_%s" % type] = (fuel_now - self.fuel_last[type]) / self.fuel_start[type]
    #         self.stats["game/%s_rate_mined" % type] = fuel_now / self.fuel_start[type]
    #         self.fuel_last[type] = fuel_now
        
    #     # Give more incentive for coal and uranium
    #     rewards["rew/r_%s" % Constants.RESOURCE_TYPES.COAL] *= 2
    #     rewards["rew/r_%s" % Constants.RESOURCE_TYPES.URANIUM] *= 4
        
    #     # Give a reward based on amount of fuel collected. 1.0 reward for each 20K fuel gathered.
    #     fuel_collected = game.stats["teamStats"][self.team]["fuelGenerated"]
    #     rewards["rew/r_fuel_collected"] = ( (fuel_collected - self.fuel_collected_last) / 20000 )
    #     self.fuel_collected_last = fuel_collected

    #     # Give a reward for unit creation/death. 0.05 reward per unit.
    #     rewards["rew/r_units"] = (unit_count - self.units_last) * 0.05
    #     self.units_last = unit_count

    #     # Give a reward for unit creation/death. 0.1 reward per city.
    #     rewards["rew/r_city_tiles"] = (city_tile_count - self.city_tiles_last) * 0.1
    #     self.city_tiles_last = city_tile_count

    #     # Tiny reward for research to help. Up to 0.5 reward for this.
    #     rewards["rew/r_research"] = (research - self.research_last) / (200 * 2)
    #     self.research_last = research
        
    #     # Give a reward up to around 50.0 based on number of city tiles at the end of the game
    #     rewards["rew/r_city_tiles_end"] = 0
    #     if is_game_finished:
    #         self.is_last_turn = True
    #         rewards["rew/r_city_tiles_end"] = city_tile_count
        
        
    #     # Update the stats and total reward
    #     reward = 0
    #     for name, value in rewards.items():
    #         self.stats[name] += value
    #         reward += value
    #     self.stats["rew/r_total"] += reward

    #     # Print the final game stats sometimes
    #     if is_game_finished and random.random() <= 0.15:
    #         stats_string = []
    #         for key, value in self.stats.items():
    #             stats_string.append("%s=%.2f" % (key, value))
    #         print(",".join(stats_string))


    #     return reward
        
    

    def process_turn(self, game, team):
        """
        Decides on a set of actions for the current turn. Not used in training, only inference.
        Returns: Array of actions to perform.
        """
        start_time = time.time()
        actions = []
        new_turn = True

        # Inference the model per-unit
        units = game.state["teamStates"][team]["units"].values()
        for unit in units:
            if unit.can_act():
                obs = self.get_observation(game, unit, None, unit.team, new_turn)
                action_code, _states = self.model.predict(obs, deterministic=False)
                if action_code is not None:
                    actions.append(
                        self.action_code_to_action(action_code, game=game, unit=unit, city_tile=None, team=unit.team))
                new_turn = False

        # Inference the model per-city
        cities = game.cities.values()
        for city in cities:
            if city.team == team:
                for cell in city.city_cells:
                    city_tile = cell.city_tile
                    if city_tile.can_act():
                        obs = self.get_observation(game, None, city_tile, city.team, new_turn)
                        action_code, _states = self.model.predict(obs, deterministic=False)
                        if action_code is not None:
                            actions.append(
                                self.action_code_to_action(action_code, game=game, unit=None, city_tile=city_tile,
                                                           team=city.team))
                        new_turn = False

        time_taken = time.time() - start_time
        if time_taken > 0.5:  # Warn if larger than 0.5 seconds.
            print("WARNING: Inference took %.3f seconds for computing actions. Limit is 1 second." % time_taken,
                  file=sys.stderr)

        return actions



Overwriting agent_policy.py


In [62]:
%%writefile agent_policy2.py
import torch
from luxai2021.game.match_controller import ActionSequence
import sys
import time
from functools import partial  # pip install functools

import numpy as np
from gym import spaces
import copy
import random

from luxai2021.env.agent import Agent
from luxai2021.game.actions import *
from luxai2021.game.game_constants import GAME_CONSTANTS
from luxai2021.game.position import Position


# https://codereview.stackexchange.com/questions/28207/finding-the-closest-point-to-a-list-of-points
def get_city_remaining_days(city):
    return city.fuel / city.get_light_upkeep()
def closest_node(node, nodes):
    dist_2 = np.sum((nodes - node) ** 2, axis=1)
    return np.argmin(dist_2)
def furthest_node(node, nodes):
    dist_2 = np.sum((nodes - node) ** 2, axis=1)
    return np.argmax(dist_2)

def smart_transfer_to_nearby(game, team, unit_id, unit, target_type_restriction=None, **kwarg):
    """
    Smart-transfers from the specified unit to a nearby neighbor. Prioritizes any
    nearby carts first, then any worker. Transfers the resource type which the unit
    has most of. Picks which cart/worker based on choosing a target that is most-full
    but able to take the most amount of resources.

    Args:
        team ([type]): [description]
        unit_id ([type]): [description]

    Returns:
        Action: Returns a TransferAction object, even if the request is an invalid
                transfer. Use TransferAction.is_valid() to check validity.
    """

    # Calculate how much resources could at-most be transferred
    resource_type = None
    resource_amount = 0
    target_unit = None

    if unit != None:
        for type, amount in unit.cargo.items():
            if amount > resource_amount:
                resource_type = type
                resource_amount = amount

        # Find the best nearby unit to transfer to
        unit_cell = game.map.get_cell_by_pos(unit.pos)
        adjacent_cells = game.map.get_adjacent_cells(unit_cell)

        
        for c in adjacent_cells:
            for id, u in c.units.items():
                # Apply the unit type target restriction
                if target_type_restriction == None or u.type == target_type_restriction:
                    if u.team == team:
                        # This unit belongs to our team, set it as the winning transfer target
                        # if it's the best match.
                        if target_unit is None:
                            target_unit = u
                        else:
                            # Compare this unit to the existing target
                            if target_unit.type == u.type:
                                # Transfer to the target with the least capacity, but can accept
                                # all of our resources
                                if( u.get_cargo_space_left() >= resource_amount and 
                                    target_unit.get_cargo_space_left() >= resource_amount ):
                                    # Both units can accept all our resources. Prioritize one that is most-full.
                                    if u.get_cargo_space_left() < target_unit.get_cargo_space_left():
                                        # This new target it better, it has less space left and can take all our
                                        # resources
                                        target_unit = u
                                    
                                elif( target_unit.get_cargo_space_left() >= resource_amount ):
                                    # Don't change targets. Current one is best since it can take all
                                    # the resources, but new target can't.
                                    pass
                                    
                                elif( u.get_cargo_space_left() > target_unit.get_cargo_space_left() ):
                                    # Change targets, because neither target can accept all our resources and 
                                    # this target can take more resources.
                                    target_unit = u
                            elif u.type == Constants.UNIT_TYPES.CART:
                                # Transfer to this cart instead of the current worker target
                                target_unit = u
    
    # Build the transfer action request
    target_unit_id = None
    if target_unit is not None:
        target_unit_id = target_unit.id

        # Update the transfer amount based on the room of the target
        if target_unit.get_cargo_space_left() < resource_amount:
            resource_amount = target_unit.get_cargo_space_left()
    
    return TransferAction(team, unit_id, target_unit_id, resource_type, resource_amount)

########################################################################################################################
# This is the Agent that you need to design for the competition
########################################################################################################################
class AgentPolicy2(Agent):
    def __init__(self, mode="train", model=None, md_list = None) -> None:
        """
        Arguments:
            mode: "train" or "inference", which controls if this agent is for training or not.
            model: The pretrained model, or if None it will operate in training mode.
        """
        super().__init__()
        self.model = model
        self.mode = mode
        self.md_list = md_list
        
        self.stats = None
        self.stats_last_game = None

        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions:
        self.actionSpaceMapUnits = [
            partial(MoveAction, direction=Constants.DIRECTIONS.CENTER),  # This is the do-nothing action
            partial(MoveAction, direction=Constants.DIRECTIONS.NORTH),
            partial(MoveAction, direction=Constants.DIRECTIONS.WEST),
            partial(MoveAction, direction=Constants.DIRECTIONS.SOUTH),
            partial(MoveAction, direction=Constants.DIRECTIONS.EAST),
            smart_transfer_to_nearby, # Transfer to nearby
            SpawnCityAction,
            #PillageAction,
        ]
        self.actionSpaceMapCities = [
            SpawnWorkerAction,
            SpawnCartAction,
            ResearchAction,
        ]

        self.action_space = spaces.Discrete(max(len(self.actionSpaceMapUnits), len(self.actionSpaceMapCities)))
        

        # Observation space: (Basic minimum for a miner agent)
        # Object:
        #   1x is worker
        #   1x is cart
        #   1x is citytile
        #
        #   5x direction_nearest_wood
        #   1x distance_nearest_wood
        #   1x amount
        #
        #   5x direction_nearest_coal
        #   1x distance_nearest_coal
        #   1x amount
        #
        #   5x direction_nearest_uranium
        #   1x distance_nearest_uranium
        #   1x amount
        #
        #   5x direction_nearest_city
        #   1x distance_nearest_city
        #   1x amount of fuel
        #
        #   28x (the same as above, but direction, distance, and amount to the furthest of each)
        #
        #   5x direction_nearest_worker
        #   1x distance_nearest_worker
        #   1x amount of cargo
        # Unit:
        #   1x cargo size
        # State:
        #   1x is night
        #   1x percent of game done
        #   2x citytile counts [cur player, opponent]
        #   2x worker counts [cur player, opponent]
        #   2x cart counts [cur player, opponent]
        #   1x research points [cur player]
        #   1x researched coal [cur player]
        #   1x researched uranium [cur player]
        # self.observation_shape = (3 + 7 * 5 * 2 + 1 + 1 + 1 + 2 + 2 + 2 + 3,)
        self.observation_shape_map = (18,32,32)
        self.observation_shape_player_status = (85+3,)


        ### YCHUANG ADDED START
        self.observation_space = spaces.Dict({
            'map': spaces.Box(low=0, high=255, shape=self.observation_shape_map, dtype=np.uint8),  # Example for a 64x64 RGB image
            'vector': spaces.Box(low=-np.inf, high=np.inf, shape=self.observation_shape_player_status, dtype=np.float32)  # Example for a vector of size 10
        })

        # self.observation_shape = self.observation_shape_map
        # self.observation_space = spaces.Box(low=0, high=1, shape=
        # self.observation_shape, dtype=np.float16)

        self.object_nodes = {}

        self.total_turns = 0
        ### YCHUANG ADDED END

    def get_agent_type(self):
        """
        Returns the type of agent. Use AGENT for inference, and LEARNING for training a model.
        """
        if self.mode == "train":
            return Constants.AGENT_TYPE.LEARNING
        else:
            return Constants.AGENT_TYPE.AGENT


    '''''
    def get_observation(self, game, unit, city_tile, team, is_new_turn):
        map_height = game.map.height  # Actual map height
        map_width = game.map.width    # Actual map width
        map_info = torch.full((12, 32, 32), -1.0, dtype=torch.float)  # Initialize with -1 for padding
        
        # Determine the start index for padding to keep original positions
        # start_y = (32 - map_height) // 2
        # start_x = (32 - map_width) // 2
        map_info[:, :map_height, :map_width] = 0
        
        # Process resources
        for i, res_list in enumerate(game.map.resources_by_type.values()):
            for res in res_list:
                # Adjust indices with padding offsets
                if i == 0:
                    map_info[0][res.pos.y][res.pos.x] = res.resource.amount / 800
                elif i == 1:
                    map_info[1][res.pos.y][res.pos.x] = res.resource.amount / 425
                elif i == 2:
                    map_info[2][res.pos.y][res.pos.x] = res.resource.amount / 350
        
        # Process roads
        for cll in game.cells_with_roads:
            map_info[11][cll.pos.y][cll.pos.x] = cll.get_road()/6

        # Process city tiles
        for ct in game.cities.values():
            ct_remain_day = get_city_remaining_days(ct)  # This function needs to be defined
            for ctile in ct.city_cells:
                # Adjust indices with padding offsets and add city tile information
                map_info[3 + ct.team][ctile.pos.y][ctile.pos.x] = 1
                map_info[9 + ct.team][ctile.pos.y][ctile.pos.x] = ct_remain_day / 200
                map_info[11][ctile.pos.y][ctile.pos.x] = ctile.get_road() / 6
        
        # Process units
        opponent = 1 if team == 0 else 0
        for uNit in game.state['teamStates'][team]["units"].values():
            if uNit.type == Constants.UNIT_TYPES.WORKER:
                map_info[5][uNit.pos.y][uNit.pos.x] = 1
            elif uNit.type == Constants.UNIT_TYPES.CART:
                map_info[6][uNit.pos.y][uNit.pos.x] = 1

        for uNit in game.state['teamStates'][opponent]["units"].values():
            if uNit.type == Constants.UNIT_TYPES.WORKER:
                map_info[7][uNit.pos.y][uNit.pos.x] = 1
            elif uNit.type == Constants.UNIT_TYPES.CART:
                map_info[8][uNit.pos.y][uNit.pos.x] = 1

        # Vector
        vector = torch.zeros(10,dtype=torch.float)
        vector[0] = 1 if not city_tile == None else 0
        if not unit == None:
            vector[1] = 1 if unit.type == Constants.UNIT_TYPES.CART else 0
            vector[2] = 1 if unit.type == Constants.UNIT_TYPES.WORKER else 0
        rem_day = []
        for ct in game.cities.values():
            if ct.team == team: rem_day.append(get_city_remaining_days(ct))
        vector[3] = -1 if rem_day == [] else min(rem_day)/200
        vector[4] = 1 if game.state["teamStates"][team]["researched"]["coal"] else 0
        vector[5] = 1 if game.state["teamStates"][team]["researched"]["uranium"] else 0
        if not unit == None:
            if unit.type == Constants.UNIT_TYPES.WORKER:
                vector[6] = unit.cooldown/2
            elif unit.type == Constants.UNIT_TYPES.CART:
                vector[6] = unit.cooldown/3
        elif not city_tile == None:
            vector[6] = city_tile.cooldown/10
        if not unit == None:
            vector[7] = unit.pos.x
            vector[8] = unit.pos.y
        if not city_tile == None:
            vector[7] = city_tile.pos.x
            vector[8] = city_tile.pos.y

        return {'map': map_info, 'vector': vector}

        '''''
    def get_observation(self, game, unit, city_tile, team, is_new_turn):
        """
        Implements getting a observation from the current game for this unit or city
        """
        observation_index = 0
        if is_new_turn:
            # It's a new turn this event. This flag is set True for only the first observation from each turn.
            # Update any per-turn fixed observation space that doesn't change per unit/city controlled.

            # Build a list of object nodes by type for quick distance-searches
            self.object_nodes = {}

            # Add resources
            for cell in game.map.resources:
                if cell.resource.type not in self.object_nodes:
                    self.object_nodes[cell.resource.type] = np.array([[cell.pos.x, cell.pos.y]])
                else:
                    self.object_nodes[cell.resource.type] = np.concatenate(
                        (
                            self.object_nodes[cell.resource.type],
                            [[cell.pos.x, cell.pos.y]]
                        ),
                        axis=0
                    )

            # Add your own and opponent units
            for t in [team, (team + 1) % 2]:
                for u in game.state["teamStates"][team]["units"].values():
                    key = str(u.type)
                    if t != team:
                        key = str(u.type) + "_opponent"

                    if key not in self.object_nodes:
                        self.object_nodes[key] = np.array([[u.pos.x, u.pos.y]])
                    else:
                        self.object_nodes[key] = np.concatenate(
                            (
                                self.object_nodes[key],
                                [[u.pos.x, u.pos.y]]
                            )
                            , axis=0
                        )

            # print(team)

            # Add your own and opponent cities
            for city in game.cities.values():
                for cells in city.city_cells:
                    key = "city"
                    if city.team != team:
                        key = "city_opponent"

                    if key not in self.object_nodes:
                        self.object_nodes[key] = np.array([[cells.pos.x, cells.pos.y]])
                    else:
                        self.object_nodes[key] = np.concatenate(
                            (
                                self.object_nodes[key],
                                [[cells.pos.x, cells.pos.y]]
                            )
                            , axis=0
                        )

        # Observation space: (Basic minimum for a miner agent)
        # Object:
        #   1x is worker
        #   1x is cart
        #   1x is citytile
        #
        #   YCHUANG ADDED START
        #   1x x coordinate / 32
        #   1x y coordinate / 32
        #   1x road level at (x,y) / 6
        #   YCHUANG ADDED END
        #
        #   5x direction_nearest_wood
        #   1x distance_nearest_wood
        #   1x amount
        #
        #   5x direction_nearest_coal
        #   1x distance_nearest_coal
        #   1x amount
        #
        #   5x direction_nearest_uranium
        #   1x distance_nearest_uranium
        #   1x amount
        #
        #   5x direction_nearest_city
        #   1x distance_nearest_city
        #   1x amount of fuel
        #
        #   5x direction_nearest_worker
        #   1x distance_nearest_worker
        #   1x amount of cargo
        #
        #   28x (the same as above, but direction, distance, and amount to the furthest of each)
        #
        # Unit:
        #   1x cargo size
        # State:
        #   1x is night
        #   1x percent of game done
        #   2x citytile counts [cur player, opponent]
        #   2x worker counts [cur player, opponent]
        #   2x cart counts [cur player, opponent]
        #   1x research points [cur player]
        #   1x researched coal [cur player]
        #   1x researched uranium [cur player]
        obs = np.zeros(self.observation_shape_player_status)
        
        # Update the type of this object
        #   1x is worker
        #   1x is cart
        #   1x is citytile
        observation_index = 0
        if unit is not None:
            if unit.type == Constants.UNIT_TYPES.WORKER:
                obs[observation_index] = 1.0 # Worker
            else:
                obs[observation_index+1] = 1.0 # Cart
        if city_tile is not None:
            obs[observation_index+2] = 1.0 # CityTile
        observation_index += 3
        
        pos = None
        if unit is not None:
            pos = unit.pos
        else:
            pos = city_tile.pos

        # YCHUANG ADDED START
        if pos is not None:
            obs[observation_index] = pos.x / 32
            obs[observation_index+1] = pos.y / 32
            observation_index += 2
            obs[observation_index] = game.map.get_cell_by_pos(pos).get_road() / 6
            observation_index += 1
        # YCHUANG ADDED END

        if pos is None:
            observation_index += 7 * 5 * 2
        else:
            # Encode the direction to the nearest objects
            #   5x direction_nearest
            #   1x distance
            for distance_function in [closest_node, furthest_node]:
                for key in [
                    Constants.RESOURCE_TYPES.WOOD,
                    Constants.RESOURCE_TYPES.COAL,
                    Constants.RESOURCE_TYPES.URANIUM,
                    "city",
                    str(Constants.UNIT_TYPES.WORKER)]:
                    # Process the direction to and distance to this object type

                    # Encode the direction to the nearest object (excluding itself)
                    #   5x direction
                    #   1x distance
                    if key in self.object_nodes:
                        if (
                                (key == "city" and city_tile is not None) or
                                (unit is not None and str(unit.type) == key and len(game.map.get_cell_by_pos(unit.pos).units) <= 1 )
                        ):
                            # Filter out the current unit from the closest-search
                            closest_index = closest_node((pos.x, pos.y), self.object_nodes[key])
                            filtered_nodes = np.delete(self.object_nodes[key], closest_index, axis=0)
                        else:
                            filtered_nodes = self.object_nodes[key]

                        if len(filtered_nodes) == 0:
                            # No other object of this type
                            obs[observation_index + 5] = 1.0
                        else:
                            # There is another object of this type
                            closest_index = distance_function((pos.x, pos.y), filtered_nodes)

                            if closest_index is not None and closest_index >= 0:
                                closest = filtered_nodes[closest_index]
                                closest_position = Position(closest[0], closest[1])
                                direction = pos.direction_to(closest_position)
                                mapping = {
                                    Constants.DIRECTIONS.CENTER: 0,
                                    Constants.DIRECTIONS.NORTH: 1,
                                    Constants.DIRECTIONS.WEST: 2,
                                    Constants.DIRECTIONS.SOUTH: 3,
                                    Constants.DIRECTIONS.EAST: 4,
                                }
                                obs[observation_index + mapping[direction]] = 1.0  # One-hot encoding direction

                                # 0 to 1 distance
                                distance = pos.distance_to(closest_position)
                                obs[observation_index + 5] = min(distance / 20.0, 1.0)

                                # 0 to 1 value (amount of resource, cargo for unit, or fuel for city)
                                if key == "city":
                                    # City fuel as % of upkeep for 200 turns
                                    c = game.cities[game.map.get_cell_by_pos(closest_position).city_tile.city_id]
                                    obs[observation_index + 6] = min(
                                        c.fuel / (c.get_light_upkeep() * 200.0),
                                        1.0
                                    )
                                elif key in [Constants.RESOURCE_TYPES.WOOD, Constants.RESOURCE_TYPES.COAL,
                                             Constants.RESOURCE_TYPES.URANIUM]:
                                    # Resource amount
                                    obs[observation_index + 6] = min(
                                        game.map.get_cell_by_pos(closest_position).resource.amount / 500,
                                        1.0
                                    )
                                else:
                                    # Unit cargo
                                    obs[observation_index + 6] = min(
                                        next(iter(game.map.get_cell_by_pos(
                                            closest_position).units.values())).get_cargo_space_left() / 100,
                                        1.0
                                    )

                    observation_index += 7

        if unit is not None:
            # Encode the cargo space
            #   1x cargo size
            obs[observation_index] = unit.get_cargo_space_left() / GAME_CONSTANTS["PARAMETERS"]["RESOURCE_CAPACITY"][
                "WORKER"]
            observation_index += 1
        else:
            observation_index += 1

        # Game state observations

        #   1x is night
        obs[observation_index] = game.is_night()
        observation_index += 1

        #   1x percent of game done
        obs[observation_index] = game.state["turn"] / GAME_CONSTANTS["PARAMETERS"]["MAX_DAYS"]
        observation_index += 1

        #   2x citytile counts [cur player, opponent]
        #   2x worker counts [cur player, opponent]
        #   2x cart counts [cur player, opponent]
        max_count = 30
        for key in ["city", str(Constants.UNIT_TYPES.WORKER), str(Constants.UNIT_TYPES.CART)]:
            if key in self.object_nodes:
                obs[observation_index] = len(self.object_nodes[key]) / max_count
            if (key + "_opponent") in self.object_nodes:
                obs[observation_index + 1] = len(self.object_nodes[(key + "_opponent")]) / max_count
            observation_index += 2

        #   1x research points [cur player]
        #   1x researched coal [cur player]
        #   1x researched uranium [cur player]
        obs[observation_index] = game.state["teamStates"][team]["researchPoints"] / 200.0
        obs[observation_index+1] = float(game.state["teamStates"][team]["researched"]["coal"])
        obs[observation_index+2] = float(game.state["teamStates"][team]["researched"]["uranium"])


        map_height = game.map.height  # Actual map height
        map_width = game.map.width    # Actual map width
        map_info = torch.full((18, 32, 32), -1.0, dtype=torch.float)  # Initialize with -1 for padding
        
        # Determine the start index for padding to keep original positions
        # start_y = (32 - map_height) // 2
        # start_x = (32 - map_width) // 2
        map_info[:, :map_height, :map_width] = 0
        
        # Process resources
        for i, res_list in enumerate(game.map.resources_by_type.values()):
            for res in res_list:
                # Adjust indices with padding offsets
                if i == 0:
                    map_info[0][res.pos.y][res.pos.x] = res.resource.amount / 800
                elif i == 1:
                    map_info[1][res.pos.y][res.pos.x] = res.resource.amount / 425
                elif i == 2:
                    map_info[2][res.pos.y][res.pos.x] = res.resource.amount / 350
        
        # Process city tiles
        for ct in game.cities.values():
            ct_remain_day = get_city_remaining_days(ct)  # This function needs to be defined
            for ctile in ct.city_cells:
                team_offset = 0 if ct.team == team else 1
                # Adjust indices with padding offsets and add city tile information
                map_info[3 + team_offset][ctile.pos.y][ctile.pos.x] = 1
                map_info[5 + team_offset][ctile.pos.y][ctile.pos.x] = ct_remain_day / 200
            # TODO: road of all the map
            map_info[7][ctile.pos.y][ctile.pos.x] = ctile.get_road() / 6

        # Process roads
        for cll in game.cells_with_roads:
            map_info[7][cll.pos.y][cll.pos.x] = cll.get_road()/6

        # Process units
        # opponent = 1 if team == 0 else 0
        opponent = (team + 1) % 2
        for uNit in game.state['teamStates'][team]["units"].values():
            if uNit.type == Constants.UNIT_TYPES.WORKER:
                map_info[8][uNit.pos.y][uNit.pos.x] = 1
                if uNit.cargo != None:
                    map_info[9][uNit.pos.y][uNit.pos.x] = uNit.cargo["wood"] / 100
                    map_info[10][uNit.pos.y][uNit.pos.x] = uNit.cargo["coal"] / 100
                    map_info[11][uNit.pos.y][uNit.pos.x] = uNit.cargo["uranium"] / 100
                
            elif uNit.type == Constants.UNIT_TYPES.CART:
                map_info[12][uNit.pos.y][uNit.pos.x] = 1
                if uNit.cargo != None:
                    map_info[13][uNit.pos.y][uNit.pos.x] = uNit.cargo["wood"] / 2000
                    map_info[14][uNit.pos.y][uNit.pos.x] = uNit.cargo["coal"] / 2000
                    map_info[15][uNit.pos.y][uNit.pos.x] = uNit.cargo["uranium"] / 2000

        for uNit in game.state['teamStates'][opponent]["units"].values():
            if uNit.type == Constants.UNIT_TYPES.WORKER:
                map_info[16][uNit.pos.y][uNit.pos.x] = 1
            elif uNit.type == Constants.UNIT_TYPES.CART:
                map_info[17][uNit.pos.y][uNit.pos.x] = 1
        
        return {'map': map_info, 'vector': obs}

    
    def action_code_to_action(self, action_code, game, unit=None, city_tile=None, team=None):
        """
        Takes an action in the environment according to actionCode:
            actionCode: Index of action to take into the action array.
        Returns: An action.
        """
        # Map actionCode index into to a constructed Action object
        try:
            x = None
            y = None
            if city_tile is not None:
                x = city_tile.pos.x
                y = city_tile.pos.y
            elif unit is not None:
                x = unit.pos.x
                y = unit.pos.y
            
            if city_tile != None:
                action =  self.actionSpaceMapCities[action_code%len(self.actionSpaceMapCities)](
                    game=game,
                    unit_id=unit.id if unit else None,
                    unit=unit,
                    city_id=city_tile.city_id if city_tile else None,
                    citytile=city_tile,
                    team=team,
                    x=x,
                    y=y
                )

                # If the city action is invalid, default to research action automatically
                if not action.is_valid(game, actions_validated=[]):
                    action = ResearchAction(
                        game=game,
                        unit_id=unit.id if unit else None,
                        unit=unit,
                        city_id=city_tile.city_id if city_tile else None,
                        citytile=city_tile,
                        team=team,
                        x=x,
                        y=y
                    )
            else:
                action =  self.actionSpaceMapUnits[action_code%len(self.actionSpaceMapUnits)](
                    game=game,
                    unit_id=unit.id if unit else None,
                    unit=unit,
                    city_id=city_tile.city_id if city_tile else None,
                    citytile=city_tile,
                    team=team,
                    x=x,
                    y=y
                )
            
            return action
        except Exception as e:
            # Not a valid action
            print(e)
            return None

    def take_action(self, action_code, game, unit=None, city_tile=None, team=None):
        """
        Takes an action in the environment according to actionCode:
            actionCode: Index of action to take into the action array.
        """
        action = self.action_code_to_action(action_code, game, unit, city_tile, team)
        self.match_controller.take_action(action)
    
    def game_start(self, game):
        """
        This funciton is called at the start of each game. Use this to
        reset and initialize per game. Note that self.team may have
        been changed since last game. The game map has been created
        and starting units placed.

        Args:
            game ([type]): Game.
        """
        if self.md_list != None:
            self.model = self.md_list[random.randint(0,len(self.md_list)-1)]
        self.last_generated_fuel = game.stats["teamStats"][self.team]["fuelGenerated"]
        self.last_resources_collected = copy.deepcopy(game.stats["teamStats"][self.team]["resourcesCollected"])
        if self.stats != None:
            self.stats_last_game =  self.stats
        self.stats = {
            "rew/result": 0,
            "rew/r_total": 0,
            "rew/r_wood": 0,
            "rew/r_coal": 0,
            "rew/r_uranium": 0,
            "rew/r_research": 0,
            "rew/r_city_tiles_end": 0,
            "rew/r_fuel_collected":0,
            "rew/r_units":0,
            "rew/r_city_tiles":0,
            "game/turns": 0,
            "game/research": 0,
            "game/unit_count": 0,
            "game/cart_count": 0,
            "game/city_count": 0,
            "game/city_tiles": 0,
            "game/wood_rate_mined": 0,
            "game/coal_rate_mined": 0,
            "game/uranium_rate_mined": 0,
        }
        self.is_last_turn = False

        # Calculate starting map resources
        type_map = {
            Constants.RESOURCE_TYPES.WOOD: "WOOD",
            Constants.RESOURCE_TYPES.COAL: "COAL",
            Constants.RESOURCE_TYPES.URANIUM: "URANIUM",
        }

        self.fuel_collected_last = 0
        self.fuel_start = {}
        self.fuel_last = {}
        for type, type_upper in type_map.items():
            self.fuel_start[type] = 0
            self.fuel_last[type] = 0
            for c in game.map.resources_by_type[type]:
                self.fuel_start[type] += c.resource.amount * game.configs["parameters"]["RESOURCE_TO_FUEL_RATE"][type_upper]

        self.research_last = 0
        self.units_last = 0
        self.city_tiles_last = 0

        # YCHUANG ADDED START
        self.episode_start = True
        self.total_turns += 1
        # YCHUANG ADDED END


    def get_reward(self, game, is_game_finished, is_new_turn, is_game_error):
        """
        Returns the reward function for this step of the game.
        """
        if is_game_error:
            # Game environment step failed, assign a game lost reward to not incentivise this
            print("Game failed due to error")
            return -1.0

        if not is_new_turn and not is_game_finished:
            # Only apply rewards at the start of each turn
            return 0

        # Get some basic stats
        unit_count = len(game.state["teamStates"][self.team % 2]["units"])
        cart_count = 0
        for id, u in game.state["teamStates"][self.team % 2]["units"].items():
            if u.type == Constants.UNIT_TYPES.CART:
                cart_count += 1

        unit_count_opponent = len(game.state["teamStates"][(self.team + 1) % 2]["units"])
        research = min(game.state["teamStates"][self.team]["researchPoints"], 200.0) # Cap research points at 200
        city_count = 0
        city_count_opponent = 0
        city_tile_count = 0
        city_tile_count_opponent = 0
        for city in game.cities.values():
            if city.team == self.team:
                city_count += 1
            else:
                city_count_opponent += 1

            for cell in city.city_cells:
                if city.team == self.team:
                    city_tile_count += 1
                else:
                    city_tile_count_opponent += 1
        
        # Basic stats
        self.stats["game/research"] = research
        self.stats["game/city_tiles"] = city_tile_count
        self.stats["game/city_count"] = city_count
        self.stats["game/unit_count"] = unit_count
        self.stats["game/cart_count"] = cart_count
        self.stats["game/turns"] = game.state["turn"]

        rewards = {}

        # Give up to 1.0 reward for each resource based on % of total mined.
        type_map = {
            Constants.RESOURCE_TYPES.WOOD: "WOOD",
            Constants.RESOURCE_TYPES.COAL: "COAL",
            Constants.RESOURCE_TYPES.URANIUM: "URANIUM",
        }
        fuel_now = {}
        for type, type_upper in type_map.items():
            fuel_now = game.stats["teamStats"][self.team]["resourcesCollected"][type] * game.configs["parameters"]["RESOURCE_TO_FUEL_RATE"][type_upper]
            rewards["rew/r_%s" % type] = (fuel_now - self.fuel_last[type]) / self.fuel_start[type]
            self.stats["game/%s_rate_mined" % type] = fuel_now / self.fuel_start[type]
            self.fuel_last[type] = fuel_now
        
        # Give more incentive for coal and uranium
        rewards["rew/r_%s" % Constants.RESOURCE_TYPES.COAL] *= 2
        rewards["rew/r_%s" % Constants.RESOURCE_TYPES.URANIUM] *= 4
        
        # Give a reward based on amount of fuel collected. 1.0 reward for each 20K fuel gathered.
        fuel_collected = game.stats["teamStats"][self.team]["fuelGenerated"]
        rewards["rew/r_fuel_collected"] = ( (fuel_collected - self.fuel_collected_last) / 20000 )
        self.fuel_collected_last = fuel_collected

        # Give a reward for unit creation/death. 0.05 reward per unit.
        rewards["rew/r_units"] = (unit_count - self.units_last) * 0.05
        self.units_last = unit_count

        # Give a reward for unit creation/death. 0.1 reward per city.
        rewards["rew/r_city_tiles"] = (city_tile_count - self.city_tiles_last) * 0.1
        self.city_tiles_last = city_tile_count

        # Tiny reward for research to help. Up to 0.5 reward for this.
        rewards["rew/r_research"] = (research - self.research_last) / (200 * 2)
        self.research_last = research
        
        # Give a reward up to around 50.0 based on number of city tiles at the end of the game
        rewards["rew/r_city_tiles_end"] = 0
        if is_game_finished:
            self.is_last_turn = True
            rewards["rew/r_city_tiles_end"] = city_tile_count
        

        
        if is_game_finished:
            if self.team > 1:
                raise Exception('Team should not exceed 2')
            if game.get_winning_team() == self.team:
                self.stats["rew/result"] = 1.
            else:
                self.stats["rew/result"] = -1.
                
        # Update the stats and total reward
        reward = 0
        for name, value in rewards.items():
            self.stats[name] += value
            reward += value
        self.stats["rew/r_total"] += reward

        # Print the final game stats sometimes
        if is_game_finished and random.random() <= 0.15:
            stats_string = []
            for key, value in self.stats.items():
                stats_string.append("%s=%.2f" % (key, value))
            print(",".join(stats_string))


        return self.stats["rew/result"]

    # def get_reward(self, game, is_game_finished, is_new_turn, is_game_error):
    #     """
    #     Returns the reward function for this step of the game.
    #     """
    #     if is_game_error:
    #         # Game environment step failed, assign a game lost reward to not incentivise this
    #         print("Game failed due to error")
    #         return -1.0

    #     if not is_new_turn and not is_game_finished:
    #         # Only apply rewards at the start of each turn
    #         return 0

    #     # Get some basic stats
    #     unit_count = len(game.state["teamStates"][self.team % 2]["units"])
    #     cart_count = 0
    #     for id, u in game.state["teamStates"][self.team % 2]["units"].items():
    #         if u.type == Constants.UNIT_TYPES.CART:
    #             cart_count += 1

    #     unit_count_opponent = len(game.state["teamStates"][(self.team + 1) % 2]["units"])
    #     research = min(game.state["teamStates"][self.team]["researchPoints"], 200.0) # Cap research points at 200
    #     city_count = 0
    #     city_count_opponent = 0
    #     city_tile_count = 0
    #     city_tile_count_opponent = 0
    #     for city in game.cities.values():
    #         if city.team == self.team:
    #             city_count += 1
    #         else:
    #             city_count_opponent += 1

    #         for cell in city.city_cells:
    #             if city.team == self.team:
    #                 city_tile_count += 1
    #             else:
    #                 city_tile_count_opponent += 1
        
    #     # Basic stats
    #     self.stats["game/research"] = research
    #     self.stats["game/city_tiles"] = city_tile_count
    #     self.stats["game/city_count"] = city_count
    #     self.stats["game/unit_count"] = unit_count
    #     self.stats["game/cart_count"] = cart_count
    #     self.stats["game/turns"] = game.state["turn"]


    #     rewards = {}

    #     # Give up to 1.0 reward for each resource based on % of total mined.
    #     type_map = {
    #         Constants.RESOURCE_TYPES.WOOD: "WOOD",
    #         Constants.RESOURCE_TYPES.COAL: "COAL",
    #         Constants.RESOURCE_TYPES.URANIUM: "URANIUM",
    #     }
    #     fuel_now = {}
    #     for type, type_upper in type_map.items():
    #         fuel_now = game.stats["teamStats"][self.team]["resourcesCollected"][type] * game.configs["parameters"]["RESOURCE_TO_FUEL_RATE"][type_upper]
    #         rewards["rew/r_%s" % type] = (fuel_now - self.fuel_last[type]) / self.fuel_start[type]
    #         self.stats["game/%s_rate_mined" % type] = fuel_now / self.fuel_start[type]
    #         self.fuel_last[type] = fuel_now
        
    #     # Give more incentive for coal and uranium
    #     rewards["rew/r_%s" % Constants.RESOURCE_TYPES.COAL] *= 2
    #     rewards["rew/r_%s" % Constants.RESOURCE_TYPES.URANIUM] *= 4
        
    #     # Give a reward based on amount of fuel collected. 1.0 reward for each 20K fuel gathered.
    #     fuel_collected = game.stats["teamStats"][self.team]["fuelGenerated"]
    #     rewards["rew/r_fuel_collected"] = ( (fuel_collected - self.fuel_collected_last) / 20000 )
    #     self.fuel_collected_last = fuel_collected

    #     # Give a reward for unit creation/death. 0.05 reward per unit.
    #     rewards["rew/r_units"] = (unit_count - self.units_last) * 0.05
    #     self.units_last = unit_count

    #     # Give a reward for unit creation/death. 0.1 reward per city.
    #     city_tile_count_prev = self.city_tiles_last
    #     rewards["rew/r_city_tiles"] = (city_tile_count - self.city_tiles_last) * 0.1
    #     self.city_tiles_last = city_tile_count

    #     # Tiny reward for research to help. Up to 0.5 reward for this.
    #     research_prev = self.research_last
    #     rewards["rew/r_research"] = (research - self.research_last) / (200 * 2)
    #     self.research_last = research
        
    #     # Give a reward up to around 50.0 based on number of city tiles at the end of the game
    #     rewards["rew/r_city_tiles_end"] = 0
    #     if is_game_finished:
    #         self.is_last_turn = True
    #         rewards["rew/r_city_tiles_end"] = city_tile_count
       
        
    #     # Update the stats and total reward
    #     reward_virtual = 0
    #     for name, value in rewards.items():
    #         self.stats[name] += value
    #         reward_virtual += value
    #     self.stats["rew/r_total"] += reward_virtual


    #     # Print the final game stats sometimes
    #     if is_game_finished and random.random() <= 0.15:
    #         stats_string = []
    #         for key, value in self.stats.items():
    #             stats_string.append("%s=%.2f" % (key, value))
    #         print(",".join(stats_string))

    #     # reward for agent
    #     reward = 0
    #     if self.episode_start == True:
    #         self.episode_start = False
    #         reward = 0
    #     else:
    #         delta_city_tile_count = city_tile_count - city_tile_count_prev
    #         if(delta_city_tile_count > 0):
    #             reward += 0.1
    #         else:
    #             reward = 0

    #         delta_research = research - research_prev
    #         if(delta_research > 0):
    #             reward += 0.1
    #         else:
    #             reward = 0

    #         if(is_game_finished):
    #             if(self.is_last_turn):
    #                 reward += 0.1
    #             else:
    #                 reward = 0

    #             info_string = ""
    #             info_string += f"game fisnished: {is_game_finished}, is_last_turn: {self.is_last_turn} | "
    #             if(city_tile_count > city_tile_count_opponent):
    #                 reward += 1
    #                 info_string += f"city_tile_count > city_tile_count_opponent | "
    #             elif(city_tile_count == city_tile_count_opponent):
    #                 reward += 0.5
    #                 info_string += f"city_tile_count == city_tile_count_opponent | "
    #             else:
    #                 reward -= 1
    #                 info_string += f"city_tile_count < city_tile_count_opponent | "
    #             info_string += f"reward: {reward} | "
    #             info_string += f"city_tile_count: {city_tile_count}, city_tile_count_opponent: {city_tile_count_opponent}"
                
    #             if random.random() <= 0.005:
    #                 print(info_string)

    #     # if self.total_turns >= 20:
    #     reward_mixed_up = np.exp(-self.total_turns/10000)/np.exp(0) * reward_virtual + reward
    #     if is_game_finished and random.random() <= 0.001:
    #                 print(f"unit_count = {unit_count}, cart_count = {cart_count}, city_count = {city_count}, city_tile_count = {city_tile_count}, reward_mixed_up = {reward_mixed_up}")
    #     # return reward
    #     return reward_mixed_up


    # def get_reward(self, game, is_game_finished, is_new_turn, is_game_error):
    #     """
    #     Returns the reward function for this step of the game.
    #     """
    #     if is_game_error:
    #         # Game environment step failed, assign a game lost reward to not incentivise this
    #         print("Game failed due to error")
    #         return -1.0

    #     if not is_new_turn and not is_game_finished:
    #         # Only apply rewards at the start of each turn
    #         return 0

    #     # Get some basic stats
    #     unit_count = len(game.state["teamStates"][self.team % 2]["units"])
    #     cart_count = 0
    #     for id, u in game.state["teamStates"][self.team % 2]["units"].items():
    #         if u.type == Constants.UNIT_TYPES.CART:
    #             cart_count += 1

    #     unit_count_opponent = len(game.state["teamStates"][(self.team + 1) % 2]["units"])
    #     research = min(game.state["teamStates"][self.team]["researchPoints"], 200.0) # Cap research points at 200
    #     city_count = 0
    #     city_count_opponent = 0
    #     city_tile_count = 0
    #     city_tile_count_opponent = 0
    #     for city in game.cities.values():
    #         if city.team == self.team:
    #             city_count += 1
    #         else:
    #             city_count_opponent += 1

    #         for cell in city.city_cells:
    #             if city.team == self.team:
    #                 city_tile_count += 1
    #             else:
    #                 city_tile_count_opponent += 1
        
    #     # Basic stats
    #     self.stats["game/research"] = research
    #     self.stats["game/city_tiles"] = city_tile_count
    #     self.stats["game/city_count"] = city_count
    #     self.stats["game/unit_count"] = unit_count
    #     self.stats["game/cart_count"] = cart_count
    #     self.stats["game/turns"] = game.state["turn"]

    #     rewards = {}

    #     # Give up to 1.0 reward for each resource based on % of total mined.
    #     type_map = {
    #         Constants.RESOURCE_TYPES.WOOD: "WOOD",
    #         Constants.RESOURCE_TYPES.COAL: "COAL",
    #         Constants.RESOURCE_TYPES.URANIUM: "URANIUM",
    #     }
    #     fuel_now = {}
    #     for type, type_upper in type_map.items():
    #         fuel_now = game.stats["teamStats"][self.team]["resourcesCollected"][type] * game.configs["parameters"]["RESOURCE_TO_FUEL_RATE"][type_upper]
    #         rewards["rew/r_%s" % type] = (fuel_now - self.fuel_last[type]) / self.fuel_start[type]
    #         self.stats["game/%s_rate_mined" % type] = fuel_now / self.fuel_start[type]
    #         self.fuel_last[type] = fuel_now
        
    #     # Give more incentive for coal and uranium
    #     rewards["rew/r_%s" % Constants.RESOURCE_TYPES.COAL] *= 2
    #     rewards["rew/r_%s" % Constants.RESOURCE_TYPES.URANIUM] *= 4
        
    #     # Give a reward based on amount of fuel collected. 1.0 reward for each 20K fuel gathered.
    #     fuel_collected = game.stats["teamStats"][self.team]["fuelGenerated"]
    #     rewards["rew/r_fuel_collected"] = ( (fuel_collected - self.fuel_collected_last) / 20000 )
    #     self.fuel_collected_last = fuel_collected

    #     # Give a reward for unit creation/death. 0.05 reward per unit.
    #     rewards["rew/r_units"] = (unit_count - self.units_last) * 0.05
    #     self.units_last = unit_count

    #     # Give a reward for unit creation/death. 0.1 reward per city.
    #     rewards["rew/r_city_tiles"] = (city_tile_count - self.city_tiles_last) * 0.1
    #     self.city_tiles_last = city_tile_count

    #     # Tiny reward for research to help. Up to 0.5 reward for this.
    #     rewards["rew/r_research"] = (research - self.research_last) / (200 * 2)
    #     self.research_last = research
        
    #     # Give a reward up to around 50.0 based on number of city tiles at the end of the game
    #     rewards["rew/r_city_tiles_end"] = 0
    #     if is_game_finished:
    #         self.is_last_turn = True
    #         rewards["rew/r_city_tiles_end"] = city_tile_count
        
        
    #     # Update the stats and total reward
    #     reward = 0
    #     for name, value in rewards.items():
    #         self.stats[name] += value
    #         reward += value
    #     self.stats["rew/r_total"] += reward

    #     # Print the final game stats sometimes
    #     if is_game_finished and random.random() <= 0.15:
    #         stats_string = []
    #         for key, value in self.stats.items():
    #             stats_string.append("%s=%.2f" % (key, value))
    #         print(",".join(stats_string))


    #     return reward
        
    

    def process_turn(self, game, team):
        """
        Decides on a set of actions for the current turn. Not used in training, only inference.
        Returns: Array of actions to perform.
        """
        start_time = time.time()
        actions = []
        new_turn = True

        # Inference the model per-unit
        units = game.state["teamStates"][team]["units"].values()
        for unit in units:
            if unit.can_act():
                obs = self.get_observation(game, unit, None, unit.team, new_turn)
                action_code, _states = self.model.predict(obs, deterministic=False)
                if action_code is not None:
                    actions.append(
                        self.action_code_to_action(action_code, game=game, unit=unit, city_tile=None, team=unit.team))
                new_turn = False

        # Inference the model per-city
        cities = game.cities.values()
        for city in cities:
            if city.team == team:
                for cell in city.city_cells:
                    city_tile = cell.city_tile
                    if city_tile.can_act():
                        obs = self.get_observation(game, None, city_tile, city.team, new_turn)
                        action_code, _states = self.model.predict(obs, deterministic=False)
                        if action_code is not None:
                            actions.append(
                                self.action_code_to_action(action_code, game=game, unit=None, city_tile=city_tile,
                                                           team=city.team))
                        new_turn = False

        time_taken = time.time() - start_time
        if time_taken > 0.5:  # Warn if larger than 0.5 seconds.
            print("WARNING: Inference took %.3f seconds for computing actions. Limit is 1 second." % time_taken,
                  file=sys.stderr)

        return actions



Overwriting agent_policy2.py


# Build the environment for training

Notes on metrics:
* An Episode is a single game between your RL agent and it's opponent. This is generally 360 turns, spanning more than 360 unit + city decision steps.
* Mean episode length (ep_len_mean) is the number of decision made per game. The larger this gets, means that it is making more unit + city decision per game, meaning that more units and cities were alive for longer during the game.
* Episode reward mean (ep_rew_mean), is set up as micro-reward funciton for faster learning. Per turn it gets a small reward based on the number of cities and units alive. It gets a really big reward based on the number of cities and units alive at the end of the game.

In [63]:
import argparse
import glob
import os
import random
from typing import Callable

from stable_baselines3 import PPO  # pip install stable-baselines3
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import SubprocVecEnv

from importlib import reload
import agent_policy
reload(agent_policy) # Reload the file from disk incase the above agent-writing cell block was edited
from agent_policy import AgentPolicy
from agent_policy2 import AgentPolicy2

from luxai2021.env.agent import Agent
from luxai2021.env.lux_env import LuxEnvironment
from luxai2021.game.constants import LuxMatchConfigs_Default


# Default Lux configs
configs = LuxMatchConfigs_Default

# # Create a default opponent agent
mdl = PPO.load("./models/rl_model_1_9000000_steps.zip")
mdl2 = PPO.load("./models/rl_model_1_5000000_steps.zip")
mdl3 = PPO.load("./models/rl_model_1_6000000_steps.zip")
mdl4 = PPO.load("./models/rl_model_1_7000000_steps.zip")
mdl5 = PPO.load("./models/rl_model_1_8000000_steps.zip")
mdl_list = [mdl,mdl2,mdl3,mdl4,mdl5]
opponent = AgentPolicy2(mode="inference", model=mdl, md_list = mdl_list)
# opponent = Agent()

# # Create a RL agent in training mode
ply_model = PPO.load("./models/rl_model_1_9000000_steps.zip")
player = AgentPolicy(mode="train", model = ply_model)

# # Create the Lux environment
env = LuxEnvironment(configs=configs,
                     learning_agent=player,
                     opponent_agent=opponent)

# Create the Lux environment
# env = LuxEnvironment(configs=configs,
#                      learning_agent=player,
#                      opponent_agent=player)

# # Define the model, you can pick other RL algos from Stable Baselines3 instead if you like
# model = PPO("MlpPolicy",
#                 env,
#                 verbose=1,
#                 tensorboard_log="./lux_tensorboard/",
#                 learning_rate=0.001,
#                 gamma=0.999,
#                 gae_lambda=0.95,
#                 batch_size=2048 * 8,
#                 n_steps=2048 * 8
#             )

# Define a learning rate schedule
# (number of steps, learning_rate)
schedule = [
    #(2000000, 0.01),
    #(6000000, 0.001),
    # (1000000, 0.01),
    (1000000, 0.0004),
    # (3000000, 0.0001),
]

Running in inference-only mode.


In [64]:
import inspect 
a = inspect.getsourcefile(AgentPolicy)
print(a)

# print(configs)

f:\rl\ProjectModel\ppo_with_revised_policy\agent_policy.py


# Define Own Model (YCHuang)

In [65]:
import torch
import numpy as np
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class CustomFeatureExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim=64):
        super(CustomFeatureExtractor, self).__init__(observation_space, features_dim=features_dim)
        # CNN for the map
        self.cnn = nn.Sequential(
            nn.Conv2d(18, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(512, 64),  # Adjust based on your image size after pooling
            nn.ReLU()
        )

        # MLP for the vector
        self.vector_processor = nn.Sequential(
            nn.Linear(85+3, 128),  # Adjust size according to your vector size
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )

        # Combine CNN and MLP features
        self.combiner = nn.Linear(64 + 64, features_dim)

    # def _get_conv_output(self, observation_space):
    #     # Dummy pass to get output size
    #     return self.cnn(torch.zeros(1, *observation_space.shape)).data.view(1, -1).size(1)

    def _get_conv_output(self, observation_space):
        # Create a dummy input to pass through the CNN to determine the output size
        # Specifically target the 'map' part of the observation space
        dummy_input = torch.zeros(1, *observation_space.spaces['map'].shape)
        dummy_output = self.cnn(dummy_input)
        return int(np.prod(dummy_output.size()))
    

    def forward(self, observations):
        # print("observations:", observations)
        # print(f"observations['map'].shape: {observations['map'].shape}")
        map_features = self.cnn(observations['map'].float())
        vector_features = self.vector_processor(observations['vector'].float())
        combined_features = torch.cat([map_features, vector_features], dim=1)

        return self.combiner(combined_features)
import torch.nn.functional as F

# class CustomFeatureExtractor(BaseFeaturesExtractor):
#     def __init__(self, observation_space, features_dim=64):
#         super(CustomFeatureExtractor, self).__init__(observation_space, features_dim=features_dim)
#         # CNN for the map
#         self.cnn = nn.Sequential(
#             nn.Conv2d(18, 32, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#             nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#             nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#             nn.Flatten(),
#             nn.Linear(512, 64),  # Adjust based on your image size after pooling
#             nn.ReLU()
#         )

#         # MLP for the vector
#         self.vector_processor = nn.Sequential(
#             nn.Linear(85+3, 128),  # Adjust size according to your vector size
#             nn.ReLU(),
#             nn.Linear(128, 64),
#             nn.ReLU()
#         )

#         # Combine CNN and MLP features
#         self.query = nn.Linear(128, 128)
#         self.key = nn.Linear(128, 128)
#         self.value = nn.Linear(128, 128)
#         self.final_layer = nn.Linear(128, features_dim)

#     # def _get_conv_output(self, observation_space):
#     #     # Dummy pass to get output size
#     #     return self.cnn(torch.zeros(1, *observation_space.shape)).data.view(1, -1).size(1)

#     def _get_conv_output(self, observation_space):
#         # Create a dummy input to pass through the CNN to determine the output size
#         # Specifically target the 'map' part of the observation space
#         dummy_input = torch.zeros(1, *observation_space.spaces['map'].shape)
#         dummy_output = self.cnn(dummy_input)
#         return int(np.prod(dummy_output.size()))
    

#     def forward(self, observations):
#         # print("observations:", observations)
#         # print(f"observations['map'].shape: {observations['map'].shape}")
#         map_features = self.cnn(observations['map'].float())
#         vector_features = self.vector_processor(observations['vector'].float())
#         combined_features = torch.cat([map_features, vector_features], dim=1)
#         Q = self.query(combined_features)
#         K = self.key(combined_features)
#         V = self.value(combined_features)

#         attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / 256 ** 0.5
#         attention_weights = F.softmax(attention_scores, dim=-1)
        
#         return self.final_layer(torch.matmul(attention_weights, V))
    


In [66]:
# import torchsummary
import torchinfo

# Create an example input in the dictionary format expected by your model
example_input = {'map': torch.zeros(2, 18, 32, 32), 'vector': torch.zeros(2, 85+3)}
# Use the 'input_data' parameter of torchinfo.summary to pass the example input
torchinfo.summary(CustomFeatureExtractor(env.observation_space), 
                        input_data=[example_input], 
                        device='cpu')   

Layer (type:depth-idx)                   Output Shape              Param #
CustomFeatureExtractor                   [2, 64]                   --
├─Sequential: 1-1                        [2, 64]                   --
│    └─Conv2d: 2-1                       [2, 32, 32, 32]           5,216
│    └─ReLU: 2-2                         [2, 32, 32, 32]           --
│    └─MaxPool2d: 2-3                    [2, 32, 16, 16]           --
│    └─Conv2d: 2-4                       [2, 32, 16, 16]           9,248
│    └─ReLU: 2-5                         [2, 32, 16, 16]           --
│    └─MaxPool2d: 2-6                    [2, 32, 8, 8]             --
│    └─Conv2d: 2-7                       [2, 32, 8, 8]             9,248
│    └─ReLU: 2-8                         [2, 32, 8, 8]             --
│    └─MaxPool2d: 2-9                    [2, 32, 4, 4]             --
│    └─Flatten: 2-10                     [2, 512]                  --
│    └─Linear: 2-11                      [2, 64]                   32,832
│ 

In [67]:
from stable_baselines3.common.policies import ActorCriticPolicy

            # self.observation_space,
            # self.action_space,
            # self.lr_schedule,
            # use_sde=self.use_sde,
class CustomActorCriticPolicy(ActorCriticPolicy):
    def __init__(self, observation_space, action_space, lr_schedule, **kwargs):
        super(CustomActorCriticPolicy, self).__init__(observation_space, action_space,
                                                      lr_schedule,
                                                      **kwargs)
        # print(kwargs)
        features_dim = kwargs["features_extractor_kwargs"]["features_dim"]
        # print(features_dim)
        # # Define additional network components here, e.g., RNN, MLP
        # # Example: An RNN followed by an MLP for action and value heads
        # self.rnn = nn.LSTM(input_size=kwargs['features_extractor_kwargs']['features_dim'], hidden_size=128, batch_first=True)
        self.action_head = torch.nn.Sequential(
            nn.Linear(features_dim, 64),
            #nn.BatchNorm1d(64),
            nn.Tanh(),
            nn.Linear(64, 64),
            #nn.BatchNorm1d(64),
            nn.Tanh()
        )        
        self.mlp_extractor.policy_net = self.action_head
        # nn.Linear(128, action_space.n)
        
        self.value_head = torch.nn.Sequential(
            nn.Linear(features_dim, 64),
            #nn.BatchNorm1d(64),
            nn.Tanh(),
            nn.Linear(64, 64),
            #nn.BatchNorm1d(64),
            nn.Tanh()
        )    
        self.mlp_extractor.value_net = self.value_head    
        # nn.Linear(128, 1)    

        # self.value_net = torch.nn.Linear(in_features=64, out_features=1, bias=True)
        # self.action_net = torch.nn.Linear(in_features=64, out_features=action_space.n, bias=True)


        self.value_net = torch.nn.Linear(in_features=64, out_features=1, bias=True)
        self.action_net = torch.nn.Linear(in_features=64, out_features=action_space.n, bias=True)


    def _build(self, lr_schedule):
        # This method is called to build the network
        # Make sure to call super()._build(lr_schedule) if you override it
        super()._build(lr_schedule)

    def forward(self, obs, deterministic=False):
        """
        Forward pass in all the networks (actor and critic)

        :param obs: Observation
        :param deterministic: Whether to sample or use deterministic actions
        :return: action, value and log probability of the action
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        if self.share_features_extractor:
            latent_pi, latent_vf = self.mlp_extractor(features)
        else:
            pi_features, vf_features = features
            latent_pi = self.mlp_extractor.forward_actor(pi_features)
            latent_vf = self.mlp_extractor.forward_critic(vf_features)
        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        actions = actions.reshape((-1, *self.action_space.shape))
        return actions, values, log_prob
    
        # # Custom forward pass
        # # print(f"obs: {obs}")
        # features = self.extract_features(obs)
        # rnn_out, _ = self.rnn(features.unsqueeze(0))
        # actions = self.action_head(rnn_out.squeeze(0))
        # values = self.value_head(rnn_out.squeeze(0))

        # log_probs = None
        # # pick actions based on the output of the network
        # if deterministic:
        #     actions = torch.argmax(actions, dim=1)
        # else:
        #     m = torch.distributions.Categorical(logits=actions)
        #     actions = m.sample()
        #     log_probs = m.log_prob(actions)
        
        # # log_probs = self.log_prob_layer(actions)  # Your custom log probability calculation
        # # log_probs = torch.log_softmax(actions, dim=1)

        # return actions, values, log_probs
        # # return actions, values


In [68]:
# import inspect
# a = inspect.getsource(PPO)
# print(a)

In [69]:
from stable_baselines3 import PPO


import warnings
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union

import numpy as np
import torch as th
from gym import spaces
from torch.nn import functional as F

from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn

class PPO_self(PPO):
    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        target_kl: Optional[float] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            batch_size=batch_size,
            n_epochs=n_epochs,
            gamma=gamma,
            gae_lambda=gae_lambda,
            clip_range=clip_range,
            clip_range_vf=clip_range_vf,
            normalize_advantage=normalize_advantage,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            rollout_buffer_class=rollout_buffer_class,
            rollout_buffer_kwargs=rollout_buffer_kwargs,
            target_kl=target_kl,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model
        )
        pass

    
    # def train(self) -> None:
    #     """
    #     Update policy using the currently gathered rollout buffer.
    #     """
    #     # Switch to train mode (this affects batch norm / dropout)
    #     self.policy.set_training_mode(True)
    #     # Update optimizer learning rate
    #     self._update_learning_rate(self.policy.optimizer)
    #     # Compute current clip range
    #     clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
    #     # Optional: clip range for the value function
    #     if self.clip_range_vf is not None:
    #         clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

    #     entropy_losses = []
    #     pg_losses, value_losses = [], []
    #     clip_fractions = []

    #     continue_training = True
    #     # train for n_epochs epochs
    #     for epoch in range(self.n_epochs):
    #         approx_kl_divs = []
    #         # Do a complete pass on the rollout buffer
    #         for rollout_data in self.rollout_buffer.get(self.batch_size):
    #             actions = rollout_data.actions
    #             if isinstance(self.action_space, spaces.Discrete):
    #                 # Convert discrete action from float to long
    #                 actions = rollout_data.actions.long().flatten()

    #             # Re-sample the noise matrix because the log_std has changed
    #             if self.use_sde:
    #                 self.policy.reset_noise(self.batch_size)

    #             values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
    #             values = values.flatten()
    #             # Normalize advantage
    #             advantages = rollout_data.advantages
    #             # Normalization does not make sense if mini batchsize == 1, see GH issue #325
    #             if self.normalize_advantage and len(advantages) > 1:
    #                 advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    #             # ratio between old and new policy, should be one at the first iteration
    #             ratio = th.exp(log_prob - rollout_data.old_log_prob)

    #             # clipped surrogate loss
    #             policy_loss_1 = advantages * ratio
    #             policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
    #             policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

    #             # Logging
    #             pg_losses.append(policy_loss.item())
    #             clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
    #             clip_fractions.append(clip_fraction)

    #             if self.clip_range_vf is None:
    #                 # No clipping
    #                 values_pred = values
    #             else:
    #                 # Clip the difference between old and new value
    #                 # NOTE: this depends on the reward scaling
    #                 values_pred = rollout_data.old_values + th.clamp(
    #                     values - rollout_data.old_values, -clip_range_vf, clip_range_vf
    #                 )
    #             # Value loss using the TD(gae_lambda) target
    #             value_loss = F.mse_loss(rollout_data.returns, values_pred)
    #             value_losses.append(value_loss.item())

    #             # Entropy loss favor exploration
    #             if entropy is None:
    #                 # Approximate entropy when no analytical form
    #                 entropy_loss = -th.mean(-log_prob)
    #             else:
    #                 entropy_loss = -th.mean(entropy)

    #             entropy_losses.append(entropy_loss.item())

    #             loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

    #             # Calculate approximate form of reverse KL Divergence for early stopping
    #             # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
    #             # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
    #             # and Schulman blog: http://joschu.net/blog/kl-approx.html
    #             with th.no_grad():
    #                 log_ratio = log_prob - rollout_data.old_log_prob
    #                 approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
    #                 approx_kl_divs.append(approx_kl_div)

    #             if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
    #                 continue_training = False
    #                 if self.verbose >= 1:
    #                     print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
    #                 break

    #             # Optimization step
    #             self.policy.optimizer.zero_grad()
    #             loss.backward()
    #             # Clip grad norm
    #             th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
    #             self.policy.optimizer.step()

    #         self._n_updates += 1
    #         if not continue_training:
    #             break

    #     explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

    #     # Logs
    #     self.logger.record("train/entropy_loss", np.mean(entropy_losses))
    #     self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
    #     self.logger.record("train/value_loss", np.mean(value_losses))
    #     self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
    #     self.logger.record("train/clip_fraction", np.mean(clip_fractions))
    #     self.logger.record("train/loss", loss.item())
    #     self.logger.record("train/explained_variance", explained_var)
    #     if hasattr(self.policy, "log_std"):
    #         self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

    #     self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
    #     self.logger.record("train/clip_range", clip_range)
    #     if self.clip_range_vf is not None:
    #         self.logger.record("train/clip_range_vf", clip_range_vf)


In [70]:
def compute_kl_loss(self, predicted_features, actual_features):
    # Assuming both predicted_features and actual_features are normalized
    kl_loss = torch.nn.functional.kl_div(predicted_features.log(), actual_features, reduction='batchmean')
    return kl_loss

# Incorporate this KL loss calculation into your training step


### DEFINE MODEL

In [71]:
from stable_baselines3 import PPO

# model = PPO(CustomActorCriticPolicy, env, verbose=1, policy_kwargs={'features_extractor_class': CustomCNNExtractor, 'features_dim': 256})
# model_self = PPO(CustomActorCriticPolicy,
# # model_self = PPO_self(CustomActorCriticPolicy,
#                 env,
#                 verbose=1,
#                 tensorboard_log="./lux_tensorboard/",
#                 learning_rate=0.001,
#                 gamma=0.999,
#                 gae_lambda=0.95,
#                 batch_size=2048 * 8,
#                 n_steps=2048 * 8,
#                 policy_kwargs={
#                                  'features_extractor_class': CustomFeatureExtractor,
#                                  'features_extractor_kwargs': {'features_dim': 128}  # Correct way
#                              }
#             )
model_self = PPO.load("./models/rl_model_1_9000000_steps.zip", env)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [72]:
# stop_sign
# print(model_self.policy.mlp_extractor)
# print(model_self.policy.value_net)
# print(model_self.policy.action_net)

### TRAIN MODEL

In [73]:
from stable_baselines3.common.utils import get_schedule_fn

print("Training model...")
run_id = 1

# Save a checkpoint every 1M steps
checkpoint_callback = CheckpointCallback(save_freq=1000000,
                                         save_path='./models/',
                                         name_prefix=f'rl_model_{run_id}')

# Train the policy 
for steps, learning_rate in schedule:
    model_self.lr_schedule = get_schedule_fn(learning_rate)
    model_self.learn(total_timesteps=steps,
                callback=checkpoint_callback,
                reset_num_timesteps = False)
7
# Save final model
model_self.save(path=f'models/model.zip')

print("Done training model.")

Training model...
Logging to ./lux_tensorboard/PPO_0
rew/result=1.00,rew/r_total=3.45,rew/r_wood=0.12,rew/r_coal=0.00,rew/r_uranium=0.00,rew/r_research=0.02,rew/r_city_tiles_end=3.00,rew/r_fuel_collected=0.01,rew/r_units=0.00,rew/r_city_tiles=0.30,game/turns=37.00,game/research=7.00,game/unit_count=0.00,game/cart_count=0.00,game/city_count=2.00,game/city_tiles=3.00,game/wood_rate_mined=0.12,game/coal_rate_mined=0.00,game/uranium_rate_mined=0.00
rew/result=1.00,rew/r_total=3.99,rew/r_wood=0.36,rew/r_coal=0.00,rew/r_uranium=0.00,rew/r_research=0.07,rew/r_city_tiles_end=3.00,rew/r_fuel_collected=0.07,rew/r_units=0.20,rew/r_city_tiles=0.30,game/turns=116.00,game/research=26.00,game/unit_count=4.00,game/cart_count=0.00,game/city_count=1.00,game/city_tiles=3.00,game/wood_rate_mined=0.36,game/coal_rate_mined=0.00,game/uranium_rate_mined=0.00
rew/result=1.00,rew/r_total=6.64,rew/r_wood=0.62,rew/r_coal=0.00,rew/r_uranium=0.00,rew/r_research=0.16,rew/r_city_tiles_end=5.00,rew/r_fuel_collected=0.

In [74]:
# print(sb3.__version__)

In [75]:
# !pip install stable-baselines3==2.3.0a4

### SAVE MODEL

In [76]:
model = model_self
# # Save final model
model.save(path=f'models/model.zip')

print("Done training model.")

Done training model.


In [77]:
# from kaggle_environments import make
# import json
# from pathlib import Path
# # run another match but with our empty agent
# env = make("lux_ai_2021", configuration={"seed": 5621242, "loglevel": 2, "annotations": True}, debug=True)


# # Play the environment where the RL agent plays against itself
# # steps = env.run(["./kaggle_submissions/main.py", "./kaggle_submissions/main.py"])
# steps = env.run([model, model])


In [78]:
# # Render the match
# env.render(mode="ipython", width=1200, height=800)

### LOAD MODEL

In [79]:
# model_loaded = PPO(CustomActorCriticPolicy,
# # model_self = PPO_self(CustomActorCriticPolicy,
#                 env,
#                 verbose=1,
#                 tensorboard_log="./lux_tensorboard/",
#                 learning_rate=0.001,
#                 gamma=0.999,
#                 gae_lambda=0.95,
#                 batch_size=2048 * 8,
#                 n_steps=2048 * 8,
#                 policy_kwargs={
#                                  'features_extractor_class': CustomFeatureExtractor,
#                                  'features_extractor_kwargs': {'features_dim': 128}  # Correct way
#                              }
#             )
# model_loaded.load("models_saved01/model.zip")

In [80]:
# from kaggle_environments import make
# import json
# from pathlib import Path
# # run another match but with our empty agent
# env = make("lux_ai_2021", configuration={"seed": 5621242, "loglevel": 2, "annotations": True}, debug=True)


# # Play the environment where the RL agent plays against itself
# # steps = env.run(["./kaggle_submissions/main.py", "./kaggle_submissions/main.py"])
# steps = env.run([model, model])


In [81]:
# # Render the match
# env.render(mode="ipython", width=1200, height=800)

# Train the agent against a dummy opponent

In [82]:
# from stable_baselines3.common.utils import get_schedule_fn

# print("Training model...")
# run_id = 1

# # Save a checkpoint every 1M steps
# checkpoint_callback = CheckpointCallback(save_freq=1000000,
#                                          save_path='./models/',
#                                          name_prefix=f'rl_model_{run_id}')

# # Train the policy 
# for steps, learning_rate in schedule:
#     model.lr_schedule = get_schedule_fn(learning_rate)
#     model.learn(total_timesteps=steps,
#                 callback=checkpoint_callback,
#                 reset_num_timesteps = False)

# # Save final model
# model.save(path=f'models/model.zip')

# print("Done training model.")

In [83]:
# # Save final model
# model.save(path=f'models/model.zip')

# print("Done training model.")

# Set up a Kaggle Submission and lux replay environment for the agent

In [84]:
"""
This downloads two required python package dependencies that are not pre-installed
by Kaggle yet.

This places the following two packages in the current working directory:
    luxai2021
    stable_baselines3
"""

import os
import shutil
import subprocess
import tempfile

def localize_package(git, branch, folder):
    if os.path.exists(folder):
        print("Already localized %s" % folder)
    else:
        # https://stackoverflow.com/questions/51239168/how-to-download-single-file-from-a-git-repository-using-python
        # Create temporary dir
        t = tempfile.mkdtemp()

        args = ['git', 'clone', '--depth=1', git, t, '-b', branch]
        res = subprocess.Popen(args, stdout=subprocess.PIPE)
        output, _error = res.communicate()

        if not _error:
            print(output)
        else:
            print(_error)
        
        # Copy desired file from temporary dir
        shutil.move(os.path.join(t, folder), '.')
        # Remove temporary dir
        shutil.rmtree(t, ignore_errors=True)

localize_package('https://github.com/glmcdona/LuxPythonEnvGym.git', 'main', 'luxai2021')
localize_package('https://github.com/glmcdona/LuxPythonEnvGym.git', 'main', 'kaggle_submissions')
localize_package('https://github.com/DLR-RM/stable-baselines3.git', 'master', 'stable_baselines3')

Already localized luxai2021
Already localized kaggle_submissions
Already localized stable_baselines3


In [85]:
# Move the dependent packages into kaggle submissions
!cp luxai2021 kaggle_submissions
!mv stable_baselines3 kaggle_submissions
!rm ./kaggle_submissions/agent_policy.py
!cp agent_policy.py kaggle_submissions

# Copy the agent and model to the submission 
!cp ./agent_policy.py kaggle_submissions
!cp ./models/rl_model_1_9000000_steps.zip kaggle_submissions

!ls kaggle_submissions

'cp' is not recognized as an internal or external command,
operable program or batch file.
'mv' is not recognized as an internal or external command,
operable program or batch file.
'rm' is not recognized as an internal or external command,
operable program or batch file.
'cp' is not recognized as an internal or external command,
operable program or batch file.
'cp' is not recognized as an internal or external command,
operable program or batch file.


'cp' is not recognized as an internal or external command,
operable program or batch file.
'ls' is not recognized as an internal or external command,
operable program or batch file.


In [86]:
import inspect
from kaggle_environments import make
print(inspect.getsourcefile(make))

d:\mnconda\envs\rlProject\lib\site-packages\kaggle_environments\core.py


In [87]:

# # # Create the Lux environment
# env = LuxEnvironment(configs=configs,
#                      learning_agent=player,
#                      opponent_agent=opponent)

### MAKE TESTING ENVIRONMENT AND RUN

In [88]:
from kaggle_environments import make
import json
from pathlib import Path
# run another match but with our empty agent
env = make("lux_ai_2021", configuration={"seed": 5621242, "loglevel": 2, "annotations": True}, debug=True)


# Play the environment where the RL agent plays against itself
steps = env.run(["./kaggle_submissions/main.py", "./kaggle_submissions/main.py"])


# # play with the trained model
# agent = model_self
# steps = env.run([agent, "./kaggle_submissions/main.py"])

# # store replay to a file
# import json
# replay = env.toJSON()
# with open("replay.json", "w") as f:
#     json.dump(replay, f)

Traceback (most recent call last):
  File "./main_lux-ai-2021.py", line 1, in <module>
    from stable_baselines3 import PPO  # pip install stable-baselines3
  File "f:\rl\ProjectModel\ppo_with_revised_policy\kaggle_submissions\stable_baselines3\__init__.py", line 3, in <module>
    from stable_baselines3.a2c import A2C
  File "f:\rl\ProjectModel\ppo_with_revised_policy\kaggle_submissions\stable_baselines3\a2c\__init__.py", line 1, in <module>
    from stable_baselines3.a2c.a2c import A2C
  File "f:\rl\ProjectModel\ppo_with_revised_policy\kaggle_submissions\stable_baselines3\a2c\a2c.py", line 7, in <module>
    from stable_baselines3.common.buffers import RolloutBuffer
  File "f:\rl\ProjectModel\ppo_with_revised_policy\kaggle_submissions\stable_baselines3\common\buffers.py", line 10, in <module>
    from stable_baselines3.common.type_aliases import (
  File "f:\rl\ProjectModel\ppo_with_revised_policy\kaggle_submissions\stable_baselines3\common\type_aliases.py", line 4, in <module>
    

In [89]:
# Render the match
env.render(mode="ipython", width=1200, height=800)

# Prepare and submit the kaggle submission

In [90]:
!tar -czf submission.tar.gz -C kaggle_submissions .
!ls

'ls' is not recognized as an internal or external command,
operable program or batch file.
