In [1]:
import time
from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode
import flax.serialization
from typing import Dict
import sys
from argparse import Namespace
import jax
import numpy as np
from lux.utils import direction_to
from luxai_s3.params import EnvParams


# from lux.config import EnvConfig
from lux.kit import from_json
env = LuxAIS3GymEnv()




In [244]:
class Agent():
    def __init__(self, player: str, env_cfg) -> None:
        self.player = player
        self.opp_player = "player_1" if self.player == "player_0" else "player_0"
        self.team_id = 0 if self.player == "player_0" else 1
        self.opp_team_id = 1 if self.team_id == 0 else 0
        np.random.seed(0)
        self.env_cfg = env_cfg
        
        self.relic_node_positions = []
        self.discovered_relic_nodes_ids = set()
        self.unit_explore_locations = dict()

    def act(self, step: int, obs, remainingOverageTime: int = 60):
        """implement this function to decide what actions to send to each available unit. 
        
        step is the current timestep number of the game starting from 0 going up to max_steps_in_match * match_count_per_episode - 1.
        """
        unit_mask = np.array(obs["units_mask"][self.team_id]) # shape (max_units, )
        unit_positions = np.array(obs["units"]["position"][self.team_id]) # shape (max_units, 2)
        unit_energys = np.array(obs["units"]["energy"][self.team_id]) # shape (max_units, 1)
        observed_relic_node_positions = np.array(obs["relic_nodes"]) # shape (max_relic_nodes, 2)
        observed_relic_nodes_mask = np.array(obs["relic_nodes_mask"]) # shape (max_relic_nodes, )
        team_points = np.array(obs["team_points"]) # points of each team, team_points[self.team_id] is the points of the your team
        
        # ids of units you can control at this timestep
        available_unit_ids = np.where(unit_mask)[0]
        # visible relic nodes
        visible_relic_node_ids = set(np.where(observed_relic_nodes_mask)[0])
        
        actions = np.zeros((self.env_cfg["max_units"], 3), dtype=int)


        # basic strategy here is simply to have some units randomly explore and some units collecting as much energy as possible
        # and once a relic node is found, we send all units to move randomly around the first relic node to gain points
        # and information about where relic nodes are found are saved for the next match
        
        # save any new relic nodes that we discover for the rest of the game.
        for id in visible_relic_node_ids:
            if id not in self.discovered_relic_nodes_ids:
                self.discovered_relic_nodes_ids.add(id)
                self.relic_node_positions.append(observed_relic_node_positions[id])
            

        # unit ids range from 0 to max_units - 1
        for unit_id in available_unit_ids:
            unit_pos = unit_positions[unit_id]
            unit_energy = unit_energys[unit_id]
            if len(self.relic_node_positions) > 0:
                nearest_relic_node_position = self.relic_node_positions[0]
                manhattan_distance = abs(unit_pos[0] - nearest_relic_node_position[0]) + abs(unit_pos[1] - nearest_relic_node_position[1])
                
                # if close to the relic node we want to hover around it and hope to gain points
                if manhattan_distance <= 4:
                    random_direction = np.random.randint(0, 5)
                    actions[unit_id] = [random_direction, 0, 0]
                else:
                    # otherwise we want to move towards the relic node
                    actions[unit_id] = [direction_to(unit_pos, nearest_relic_node_position), 0, 0]
            else:
                # randomly explore by picking a random location on the map and moving there for about 20 steps
                if step % 20 == 0 or unit_id not in self.unit_explore_locations:
                    rand_loc = (np.random.randint(0, self.env_cfg["map_width"]), np.random.randint(0, self.env_cfg["map_height"]))
                    self.unit_explore_locations[unit_id] = rand_loc
                actions[unit_id] = [direction_to(unit_pos, self.unit_explore_locations[unit_id]), 0, 0]
        return actions


In [2]:
# training setup
np.random.seed(2)
env = LuxAIS3GymEnv()
env = RecordEpisode(env, save_dir="episodes")
env_params = EnvParams(map_type=1, max_steps_in_match=100)
obs, info = env.reset(seed=1, options=dict(params=env_params))

N = env_params.max_steps_in_match * env_params.match_count_per_episode
for _ in range(N):
    obs, reward, terminated, truncated, info=  env.step(env.action_space.sample())

env.close()

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.spaces import Box, Discrete
from collections import namedtuple, defaultdict

In [3]:
## QMIX Neural Network
class QMixNet(nn.Module):
    def __init__(self, num_agents, state_dim, action_dim=6, hidden_dim=32):
        super(QMixNet, self).__init__()
        self.num_agents = num_agents
        self.state_dim = state_dim
        self.action_dim = action_dim

        # State-dependent weights for mixing
        self.hyper_w1 = nn.LazyLinear(action_dim * hidden_dim)
        self.hyper_w2 = nn.LazyLinear(hidden_dim)
        self.hyper_w3 = nn.LazyLinear(1)
        self.hyper_b1 = nn.LazyLinear(num_agents * hidden_dim)
        self.hyper_b2 = nn.LazyLinear(num_agents)

        self.elu = nn.ELU()

    def forward(self, agent_qs, state_inputs):
        batch_size = agent_qs.size(0)
        
        # Flatten state inputs for the hypernetworks
        state_inputs = state_inputs.view(batch_size, -1)

        # Compute weights and biases
        w1 = self.hyper_w1(state_inputs).view(batch_size, self.action_dim, -1)
        w2 = self.hyper_w2(state_inputs).view(batch_size, -1, 1)
        b1 = self.hyper_b1(state_inputs).view(batch_size, self.num_agents, -1)
        b2 = self.hyper_b2(state_inputs).view(batch_size, self.num_agents, 1)

        # Mixing process
        hidden = self.elu(torch.bmm(agent_qs, w1) + b1)
        
        q_total = torch.bmm(hidden, w2) + b2
        q_total = self.hyper_w3(q_total.squeeze())

        return q_total

# Replay Buffer
Transition = namedtuple("Transition", ("state", "actions", "rewards", "next_state", "dones"))

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, *args):
        if len(self.memory) >= self.capacity:
            self.memory.pop(0)
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        indices = np.random.choice(len(self.memory), batch_size, replace=False)
        batch = [self.memory[idx] for idx in indices]
        return Transition(*zip(*batch))

    def __len__(self):
        return len(self.memory)

In [4]:
# Training Loop
num_agents = 16
np.random.seed(16)
env = LuxAIS3GymEnv()
env = RecordEpisode(env, save_dir="episodes")
env_params = EnvParams(map_type=1, max_steps_in_match=100)
state_dim = 2
action_dim = 6
mixing_dim = 32
N = env_params.max_steps_in_match * env_params.match_count_per_episode

q_networks = [nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, action_dim)) for _ in range(num_agents)]
optimizers = [optim.Adam(q.parameters(), lr=0.001) for q in q_networks]
mixing_network = QMixNet(num_agents, state_dim, action_dim, mixing_dim)
mixing_optimizer = optim.Adam(mixing_network.parameters(), lr=0.001)

buffer = ReplayBuffer(capacity=10000)
batch_size = 32
epsilon = 0.1
num_episodes = 1000

def select_action(q_values, epsilon):
    if np.random.rand() < epsilon:
        return np.random.randint(q_values.size(-1))
    return [torch.argmax(q_values).item(),0,0]

In [6]:
for episode in range(num_episodes):
    state, info = env.reset(seed=1, options=dict(params=env_params))
    done = False
    episode_transitions = []
    i =0
    while not done or i < N:
        actions = np.zeros((16, 3), dtype=int)
        q_values = {}
        state = state['player_0'].units.position[0]
        
        for agent_id, q_network in zip(range(num_agents), q_networks):
            obs = torch.tensor(np.array(state[agent_id]), dtype=torch.float32).unsqueeze(0)
            q_values[agent_id] = q_network(obs)
            actions[agent_id] = select_action(q_values[agent_id], epsilon)

        opp_actions = env.action_space.sample()['player_1'] # random action opponents
        act = {'player_0': actions, 'player_1': opp_actions}

        next_state, rewards, terminated, truncated, info = env.step(act)
        dones = terminated['player_0'] + truncated['player_0']
        buffer.push(state, actions, rewards['player_0'], next_state['player_0'].units.position[0], dones)
        state = next_state
        done = dones
        i = i + 1

    if len(buffer) >= batch_size:
        batch = buffer.sample(batch_size)
        state_batch = torch.tensor(np.array([s for s in batch.state]), dtype=torch.float32)
        action_batch = torch.tensor(np.array([a for a in batch.actions]), dtype=torch.int64)
        reward_batch = torch.tensor(np.array([r for r in batch.rewards]), dtype=torch.float32)
        next_state_batch = torch.tensor(np.array([ns for ns in batch.next_state]), dtype=torch.float32)
        done_batch = torch.tensor(np.array([d for d in batch.dones]), dtype=torch.float32)

        # Calculate individual Q-values
        agent_qs = []
        for i, q_network in enumerate(q_networks):
            agent_qs.append(q_network(state_batch[:, i, :]))
        agent_qs = torch.stack(agent_qs, dim=1)

        # Calculate total Q-value using mixing network
        state_inputs = state_batch.view(batch_size, -1)
        q_total = mixing_network(agent_qs, state_inputs)

        # Compute loss and update networks
        loss = torch.mean((reward_batch - q_total) ** 2)

        for optimizer in optimizers:
            optimizer.zero_grad()
        mixing_optimizer.zero_grad()
        loss.backward()

        for optimizer in optimizers:
            optimizer.step()
        mixing_optimizer.step()

    if episode % 1 == 0:
        print(f"Episode {episode}, Loss: {loss.item():.4f}")


Episode 0, Loss: 914.7701
Episode 1, Loss: 4320.4917
Episode 2, Loss: 260.4322
Episode 3, Loss: 135.6003
Episode 4, Loss: 5157.8965
Episode 5, Loss: 685.6636
Episode 6, Loss: 308.9046
Episode 7, Loss: 79.1759
Episode 8, Loss: 12.3900
Episode 9, Loss: 875.3238
Episode 10, Loss: 59.8169
Episode 11, Loss: 343.3664
Episode 12, Loss: 101.9702
Episode 13, Loss: 136.9545
Episode 14, Loss: 52.1247
Episode 15, Loss: 121.8430
Episode 16, Loss: 364.3985
Episode 17, Loss: 49.6874
Episode 18, Loss: 198.9029
Episode 19, Loss: 472.8383
Episode 20, Loss: 191.5628
Episode 21, Loss: 87.7426
Episode 22, Loss: 1579.8009
Episode 23, Loss: 435.2163
Episode 24, Loss: 677.3134
Episode 25, Loss: 91.5042
Episode 26, Loss: 410.9431
Episode 27, Loss: 82.1918
Episode 28, Loss: 927.1182
Episode 29, Loss: 821.4628
Episode 30, Loss: 94.2729
Episode 31, Loss: 94.2055


KeyboardInterrupt: 

In [13]:
rewards

{'player_0': Array(0, dtype=int32), 'player_1': Array(5, dtype=int32)}