In [1]:
!pip install kaggle-environments -U > /dev/null 2>&1s
!cp -r ../input/lux-ai-2021/* .

In [2]:
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split
import math
from itertools import count

In [3]:
from lux.constants import Constants
from kaggle_environments import make

In [4]:
def agent(observation, configuration):
    """Choose this turn's actions: scripted city tiles, network-driven workers.

    City tiles build a worker while we have fewer units than city tiles and
    otherwise research until uranium is researched. Every worker that can act
    is steered by the policy network, avoiding cells already claimed by another
    worker this turn. During turns 30-39 of each 40-turn cycle (presumably
    night) workers standing in one of our cities get no action.

    NOTE(review): this notebook copy relies on `get_game_state`, `make_input`,
    `model`, `in_city` and `get_action`, which are only defined inside the
    `agent.py` file written by the `%%writefile` cell; calling it from the
    notebook kernel raises NameError.

    Returns:
        list of action strings for the environment.
    """
    global game_state

    game_state = get_game_state(observation)
    me = game_state.players[observation.player]
    commands = []

    # --- city tiles ---
    workers = len(me.units)
    for tile in (t for c in me.cities.values() for t in c.citytiles):
        if not tile.can_act():
            continue
        if workers < me.city_tile_count:
            commands.append(tile.build_worker())
            workers += 1
        elif not me.researched_uranium():
            commands.append(tile.research())
            # count research queued this turn (presumably what
            # researched_uranium() checks for the following tiles)
            me.research_points += 1

    # --- workers ---
    early_in_cycle = game_state.turn % 40 < 30
    claimed = []
    for unit in me.units:
        if not unit.can_act():
            continue
        if not early_in_cycle and in_city(unit.pos):
            continue
        features = make_input(observation, unit.id)
        with torch.no_grad():
            scores = model(torch.from_numpy(features).unsqueeze(0))
        command, target = get_action(scores.squeeze(0).numpy(), unit, claimed)
        commands.append(command)
        claimed.append(target)

    return commands

### 0) Configuration and sample observations

In [5]:
# Global Variables
# < SYSTEM >
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAVE_PATH = './policy_network'   # checkpoint written by optimizeModel via torch.save
SEED = 42

# Seed the RNGs this notebook uses (replay sampling, epsilon-greedy, weight init).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# < TRAINING >
EPS_START = 0.5        # epsilon-greedy schedule used by select_action
EPS_END = 0.01
EPS_DECAY = 200
BATCH_SIZE = 32
NUM_EPOCHS = 100       # self-play games played by trainModel
GAMMA = 0.999          # discount factor
STEPS_DONE = 0         # global step counter (incremented by trainModel)
# The replay buffer must hold at least one batch: optimizeModel returns early
# while len(memory) < BATCH_SIZE, so a capacity below BATCH_SIZE means no
# optimisation step ever runs.
REPLAY_CAPACITY = 2000
TARGET_UPDATE = 10     # copy policy -> target network every N epochs
LEARNING_RATE = 1e-3
bestAccuracy = float('-inf')  # best batch accuracy so far, updated by optimizeModel

assert REPLAY_CAPACITY >= BATCH_SIZE, "replay buffer can never fill a batch"

# < ACTIONS >
N_ACTIONS = 5  # move n / s / w / e, build city

# Sample raw `env.run` steps kept for reference while writing the parsers.
# step[0] / step[1] hold both agents' states (only player 0's observation
# carries `step`, `updates`, ...); step[1] shows an 'ERROR' status with
# reward None; step[100] holds player 0's state only.
# The dict must exist before the item assignments below, otherwise this cell
# raises NameError.
step = {}

step[0] = [{'action': [], 'reward': 0, 'info': {}, 'observation': {'remainingOverageTime': 60, 'step': 0, 'width': 24, 'height': 24, 'reward': 0, 'globalUnitIDCount': 2, 'globalCityIDCount': 2, 'player': 0, 'updates': ['0', '24 24', 'rp 0 0', 'rp 1 0', 'r uranium 0 0 310', 'r uranium 0 14 344', 'r uranium 0 23 304', 'r wood 3 10 333', 'r coal 3 23 407', 'r wood 4 7 323', 'r wood 4 8 324', 'r wood 4 10 318', 'r wood 4 11 360', 'r wood 5 8 345', 'r wood 5 9 303', 'r wood 5 10 351', 'r wood 5 21 800', 'r wood 5 22 800', 'r wood 6 8 387', 'r wood 6 9 342', 'r wood 6 10 326', 'r wood 6 11 398', 'r wood 6 12 393', 'r wood 6 19 350', 'r wood 6 20 331', 'r wood 6 21 800', 'r wood 7 10 355', 'r wood 7 11 376', 'r wood 7 12 310', 'r wood 7 19 364', 'r wood 7 20 396', 'r wood 8 11 346', 'r wood 8 12 394', 'r wood 8 20 385', 'r wood 9 10 376', 'r uranium 10 0 324', 'r uranium 10 1 322', 'r uranium 11 0 349', 'r coal 11 5 391', 'r coal 11 6 418', 'r uranium 12 0 349', 'r coal 12 5 391', 'r coal 12 6 418', 'r uranium 13 0 324', 'r uranium 13 1 322', 'r wood 14 10 376', 'r wood 15 11 346', 'r wood 15 12 394', 'r wood 15 20 385', 'r wood 16 10 355', 'r wood 16 11 376', 'r wood 16 12 310', 'r wood 16 19 364', 'r wood 16 20 396', 'r wood 17 8 387', 'r wood 17 9 342', 'r wood 17 10 326', 'r wood 17 11 398', 'r wood 17 12 393', 'r wood 17 19 350', 'r wood 17 20 331', 'r wood 17 21 800', 'r wood 18 8 345', 'r wood 18 9 303', 'r wood 18 10 351', 'r wood 18 21 800', 'r wood 18 22 800', 'r wood 19 7 323', 'r wood 19 8 324', 'r wood 19 10 318', 'r wood 19 11 360', 'r wood 20 10 333', 'r coal 20 23 407', 'r uranium 23 0 310', 'r uranium 23 14 344', 'r uranium 23 23 304', 'u 0 0 u_1 6 22 0 0 0 0', 'u 0 1 u_2 17 22 0 0 0 0', 'c 0 c_1 0 23', 'c 1 c_2 0 23', 'ct 0 c_1 6 22 0', 'ct 1 c_2 17 22 0', 'ccd 6 22 6', 'ccd 17 22 6', 'D_DONE']}, 'status': 'ACTIVE'}, {'action': [], 'reward': 0, 'info': {}, 'observation': {'remainingOverageTime': 60, 'reward': 0, 'player': 1}, 'status': 'ACTIVE'}]

step[1] = [{'action': None, 'reward': None, 'info': {}, 'observation': {'remainingOverageTime': 60, 'step': 1, 'width': 24, 'height': 24, 'reward': 10001, 'globalUnitIDCount': 2, 'globalCityIDCount': 2, 'player': 0, 'updates': ['rp 0 0', 'rp 1 0', 'r uranium 0 0 310', 'r uranium 0 14 344', 'r uranium 0 23 304', 'r wood 3 10 342', 'r coal 3 23 407', 'r wood 4 7 332', 'r wood 4 8 333', 'r wood 4 10 326', 'r wood 4 11 369', 'r wood 5 8 354', 'r wood 5 9 311', 'r wood 5 10 360', 'r wood 5 21 800', 'r wood 5 22 780', 'r wood 6 8 397', 'r wood 6 9 351', 'r wood 6 10 335', 'r wood 6 11 408', 'r wood 6 12 403', 'r wood 6 19 359', 'r wood 6 20 340', 'r wood 6 21 780', 'r wood 7 10 364', 'r wood 7 11 386', 'r wood 7 12 318', 'r wood 7 19 374', 'r wood 7 20 406', 'r wood 8 11 355', 'r wood 8 12 404', 'r wood 8 20 395', 'r wood 9 10 386', 'r uranium 10 0 324', 'r uranium 10 1 322', 'r uranium 11 0 349', 'r coal 11 5 391', 'r coal 11 6 418', 'r uranium 12 0 349', 'r coal 12 5 391', 'r coal 12 6 418', 'r uranium 13 0 324', 'r uranium 13 1 322', 'r wood 14 10 386', 'r wood 15 11 355', 'r wood 15 12 404', 'r wood 15 20 395', 'r wood 16 10 364', 'r wood 16 11 386', 'r wood 16 12 318', 'r wood 16 19 374', 'r wood 16 20 406', 'r wood 17 8 397', 'r wood 17 9 351', 'r wood 17 10 335', 'r wood 17 11 408', 'r wood 17 12 403', 'r wood 17 19 359', 'r wood 17 20 340', 'r wood 17 21 780', 'r wood 18 8 354', 'r wood 18 9 311', 'r wood 18 10 360', 'r wood 18 21 800', 'r wood 18 22 780', 'r wood 19 7 332', 'r wood 19 8 333', 'r wood 19 10 326', 'r wood 19 11 369', 'r wood 20 10 342', 'r coal 20 23 407', 'r uranium 23 0 310', 'r uranium 23 14 344', 'r uranium 23 23 304', 'u 0 0 u_1 6 22 0 0 0 0', 'u 0 1 u_2 17 22 0 0 0 0', 'c 0 c_1 40 23', 'c 1 c_2 40 23', 'ct 0 c_1 6 22 0', 'ct 1 c_2 17 22 0', 'ccd 6 22 6', 'ccd 17 22 6', 'D_DONE']}, 'status': 'ERROR'}, {'action': None, 'reward': None, 'info': {}, 'observation': {'remainingOverageTime': 60, 'reward': 10001, 'player': 1}, 'status': 'ERROR'}]

step[100] = {'action': ['m u_1 w'],
 'reward': 10001,
 'info': {},
 'observation': {'remainingOverageTime': 60,
  'step': 100,
  'width': 24,
  'height': 24,
  'reward': 10001,
  'globalUnitIDCount': 2,
  'globalCityIDCount': 2,
  'player': 0,
  'updates': ['rp 0 0',
   'rp 1 0',
   'r coal 0 3 367',
   'r coal 0 4 389',
   'r coal 0 8 410',
   'r uranium 0 23 302',
   'r wood 1 21 500',
   'r wood 1 22 500',
   'r wood 1 23 500',
   'r wood 2 21 500',
   'r wood 2 22 500',
   'r wood 3 22 500',
   'r wood 6 11 500',
   'r wood 7 10 500',
   'r wood 8 1 500',
   'r wood 8 2 500',
   'r wood 8 3 500',
   'r wood 8 9 500',
   'r wood 8 10 500',
   'r uranium 8 16 328',
   'r uranium 8 17 310',
   'r coal 8 22 354',
   'r coal 8 23 396',
   'r wood 9 9 500',
   'r wood 9 10 500',
   'r uranium 9 17 342',
   'r coal 9 23 394',
   'r wood 10 10 500',
   'r wood 13 10 500',
   'r wood 14 9 500',
   'r wood 14 10 500',
   'r uranium 14 17 342',
   'r coal 14 23 394',
   'r wood 15 1 500',
   'r wood 15 2 500',
   'r wood 15 3 500',
   'r wood 15 9 500',
   'r wood 15 10 500',
   'r uranium 15 16 328',
   'r uranium 15 17 310',
   'r coal 15 22 354',
   'r coal 15 23 396',
   'r wood 16 10 500',
   'r wood 17 11 500',
   'r wood 20 22 500',
   'r wood 21 21 500',
   'r wood 21 22 500',
   'r wood 22 21 500',
   'r wood 22 22 500',
   'r wood 22 23 500',
   'r coal 23 3 367',
   'r coal 23 4 389',
   'r coal 23 8 410',
   'r uranium 23 23 302',
   'u 0 0 u_1 7 2 0 0 0 0',
   'u 0 1 u_2 16 2 0 0 0 0',
   'c 0 c_1 2240 23',
   'c 1 c_2 2240 23',
   'ct 0 c_1 7 2 0',
   'ct 1 c_2 16 2 0',
   'ccd 7 2 6',
   'ccd 16 2 6',
   'D_DONE']},
 'status': 'ACTIVE'}

In [6]:
# Short aliases for the Lux protocol constants.
INPUT_CONSTANTS, RESOURCE_TYPES = Constants.INPUT_CONSTANTS, Constants.RESOURCE_TYPES

def updateMap(nStep: int,
              nXShift: int,
              nYShift: int,
              nTeam: int,
              sUId: str,
              updateList: list) -> np.ndarray:
    """Encode one raw Lux observation as a (20, 32, 32) float32 feature map.

    The channel layout is the same as `make_input` in agent.py, so maps built
    here match what the submitted agent feeds the network. CLuxNet pools its
    features at the positions marked in channel 0, so channel 0 must be the
    acting-unit mask:

        0-1    acting unit `sUId`: presence, cargo / 100
        2-4    other friendly units: presence, cooldown / 6, cargo / 100
        5-7    enemy units: presence, cooldown / 6, cargo / 100
        8-9    friendly city tiles: presence, min(fuel / light upkeep, 10) / 10
        10-11  enemy city tiles: same
        12-14  wood / coal / uranium amount / 800
        15-16  research points (own, opponent): min(rp, 200) / 200
        17     step % 40 / 40 (day/night cycle position)
        18     step / 360
        19     1 on the (centred) map, 0 in the padding

    Args:
        nStep: game turn of the observation.
        nXShift, nYShift: padding that centres the map inside 32x32.
        nTeam: observing player; teams are encoded relative to it.
        sUId: id of the unit whose decision is being encoded.
        updateList: the observation's `updates` strings.

    Returns:
        np.ndarray of shape (20, 32, 32), dtype float32.
    """
    gameMap = np.zeros((20, 32, 32), dtype=np.float32)
    cityFuel: dict = {}
    resourceChannel = {'wood': 12, 'coal': 13, 'uranium': 14}

    for update in updateList:
        cmdList = update.split(' ')
        sIdentifier = cmdList[0]

        if sIdentifier == 'u':
            # u <type> <team> <id> <x> <y> <cooldown> <wood> <coal> <uranium>
            x = int(cmdList[4]) + nXShift
            y = int(cmdList[5]) + nYShift
            cargo = (int(cmdList[7]) + int(cmdList[8]) + int(cmdList[9])) / 100
            if cmdList[3] == sUId:
                gameMap[0:2, x, y] = (1, cargo)
            else:
                team = int(cmdList[2])
                cooldown = float(cmdList[6]) / 6
                idx = 2 + (team - nTeam) % 2 * 3
                gameMap[idx:idx + 3, x, y] = (1, cooldown, cargo)

        elif sIdentifier == 'ct':
            # ct <team> <city id> <x> <y> <cooldown>
            team = int(cmdList[1])
            x = int(cmdList[3]) + nXShift
            y = int(cmdList[4]) + nYShift
            idx = 8 + (team - nTeam) % 2 * 2
            # In the observed updates 'c' lines precede their 'ct' lines;
            # fall back to 0 rather than crash if one is missing.
            gameMap[idx:idx + 2, x, y] = (1, cityFuel.get(cmdList[2], 0.0))

        elif sIdentifier == 'r':
            # r <type> <x> <y> <amount>
            x = int(cmdList[2]) + nXShift
            y = int(cmdList[3]) + nYShift
            gameMap[resourceChannel[cmdList[1]], x, y] = int(float(cmdList[4])) / 800

        elif sIdentifier == 'rp':
            # rp <team> <points>
            team = int(cmdList[1])
            gameMap[15 + (team - nTeam) % 2, :] = min(int(cmdList[2]), 200) / 200

        elif sIdentifier == 'c':
            # c <team> <city id> <fuel> <light upkeep>
            fuel = float(cmdList[3])
            lightUpkeep = float(cmdList[4])
            cityFuel[cmdList[2]] = (min(fuel / lightUpkeep, 10) / 10) if lightUpkeep > 0 else 1.0

        # Header lines, 'ccd' (road) lines and 'D_DONE' are not encoded,
        # matching make_input.

    gameMap[17, :] = nStep % 40 / 40
    gameMap[18, :] = nStep / 360
    gameMap[19, nXShift:32 - nXShift, nYShift:32 - nYShift] = 1

    return gameMap
    

In [7]:
def toLabel(player, action):
    """Turn one agent's recorded action list into a (unit_id, label) pair.

    The label indexes the network's action head: 0-3 = move n/s/w/e,
    4 = build city, None = no trainable label (move to centre 'c', or other
    unit commands 'p' / 't' in the Lux protocol). Entries that are not unit
    commands (e.g. city-tile actions, which `agent` appends before its
    worker actions) are skipped, so the first *unit* action is used.

    Args:
        player: player index; only used for the fallback unit id.
        action: list of action strings for one step, or None.

    Returns:
        (unit_id, label), or (f'u_{player}', None) when there is no unit action.
    """
    directionLabel = {'n': 0, 's': 1, 'w': 2, 'e': 3}
    for command in action or ():
        strs = command.split(' ')
        verb = strs[0]
        if verb == 'm':
            return strs[1], directionLabel.get(strs[2])
        if verb == 'bcity':
            return strs[1], 4
        if verb in ('p', 't'):
            return strs[1], None
    return f'u_{player}', None

def depletedResources(obs):
    """Return True when the observation lists no resource ('r') tiles at all."""
    return not any(line.split(' ')[0] == 'r' for line in obs['updates'])

### 1) Network

In [8]:
class CBasicConv2d(nn.Module):
    """2-D convolution with half-kernel ("same") padding, optionally followed by BatchNorm."""

    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        # kernel_size // 2 on each axis keeps H and W unchanged for odd kernels.
        pad_h = kernel_size[0] // 2
        pad_w = kernel_size[1] // 2
        self.conv = nn.Conv2d(
            input_dim,
            output_dim,
            kernel_size=kernel_size,
            padding=(pad_h, pad_w),
        )
        if bn:
            self.bn = nn.BatchNorm2d(output_dim)
        else:
            self.bn = None

    def forward(self, x):
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        return out


class CLuxNet(nn.Module):
    """Residual conv net scoring `nActions` actions for the unit marked in channel 0.

    Input: (batch, 20, H, W) feature maps; channel 0 acts as the mask of the
    acting unit (the layout built by make_input). Output: (batch, nActions).
    """

    def __init__(self, nActions):
        super().__init__()
        n_blocks, width = 12, 32
        self.conv = CBasicConv2d(20, width, (3, 3), True)
        self.blocks = nn.ModuleList(
            [CBasicConv2d(width, width, (3, 3), True) for _ in range(n_blocks)]
        )
        self.head = nn.Linear(width, nActions, bias=False)

    def forward(self, x):
        h = F.relu_(self.conv(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))
        # Keep only the features where channel 0 is set and sum them over the
        # board, giving one feature vector per sample.
        unit_mask = x[:, :1]
        pooled = (h * unit_mask).flatten(start_dim=2).sum(dim=-1)
        return self.head(pooled)

### 2) Replay Memory

In [9]:
# Input for ReplayMemory
from collections import namedtuple, deque
Data = namedtuple('Data',
                  ('state', 'action', 'next_state', 'reward'))

# One stored transition:
#   state      - encoded observation before the action
#   action     - the action label that was taken
#   next_state - encoded observation after it (None is treated as terminal
#                by optimizeModel)
#   reward     - reward for the transition
class CReplayMemory:
    """Bounded FIFO buffer of `Data` transitions with uniform random sampling."""

    def __init__(self, capacity):
        # Once `capacity` items are stored, appending evicts the oldest one.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Store one transition given as (state, action, next_state, reward)."""
        transition = Data(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
    
# Global replay buffer: trainModel pushes transitions, optimizeModel samples them.
memory = CReplayMemory(capacity=REPLAY_CAPACITY)

### 3) Select Action (epsilon-greedy exploration)

In [10]:
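# NOTE: not called anywhere in this notebook; trainModel learns from recorded self-play games instead of acting online.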
def select_action(state, model: "CLuxNet"):
    """Epsilon-greedy action selection with exponentially decaying epsilon.

    epsilon = EPS_END + (EPS_START - EPS_END) * exp(-STEPS_DONE / EPS_DECAY).
    With probability epsilon a uniformly random action in [0, N_ACTIONS) is
    returned, otherwise the action with the highest Q-value. Increments the
    global STEPS_DONE on every call.

    Args:
        state: batched network input; assumes a (1, 20, 32, 32) tensor on
            the model's device — TODO confirm for callers.
        model: Q-network returning (batch, N_ACTIONS) scores.

    Returns:
        LongTensor of shape (1, 1) with the chosen action index.
    """
    # Without `global`, the `+=` below made STEPS_DONE a local and the first
    # read raised UnboundLocalError.
    global STEPS_DONE
    sample = random.random()  # uniform in [0, 1)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * STEPS_DONE / EPS_DECAY)
    STEPS_DONE += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # max(1)[1] is the argmax over actions for each row.
            return model(state).max(1)[1].view(1, 1)
    return torch.tensor([[random.randrange(N_ACTIONS)]], device=DEVICE, dtype=torch.long)

### 4) Optimize Model

In [11]:
def optimizeModel(memory: CReplayMemory,
                  policyNet: CLuxNet,
                  targetNet: CLuxNet,
                  optimizer) -> None:
    """Run one DQN update on a random minibatch sampled from `memory`.

    The loss is the Huber (SmoothL1) distance between Q(s, a) from
    `policyNet` and reward + GAMMA * max_a' Q_target(s', a') from `targetNet`.
    The bootstrap term is 0 when next_state is None. Transitions whose action
    is None carry no training label and are dropped from the batch.

    "Accuracy" is the share of the batch where policyNet's greedy action
    equals the recorded action. When it beats the global `bestAccuracy`, the
    measured network (policyNet) is saved to SAVE_PATH with `torch.save`.

    Args:
        memory: replay buffer of `Data(state, action, next_state, reward)`.
            States may be numpy arrays or tensors shaped (20, 32, 32) or
            (1, 20, 32, 32); actions int labels or (1, 1) tensors; rewards
            scalars or 1-element tensors.
        policyNet: network being optimised. Gradients only flow through it,
            so `optimizer` should own its parameters.
        targetNet: lagged copy providing the bootstrap values (no gradients).
        optimizer: torch optimizer to step.
    """
    global bestAccuracy

    # Not enough experience yet to fill a batch.
    if len(memory) < BATCH_SIZE:
        return

    # 1) Sample a batch and keep only transitions that carry an action label.
    batch = [d for d in memory.sample(BATCH_SIZE) if d.action is not None]
    if not batch:
        return
    datas = Data(*zip(*batch))

    device = next(policyNet.parameters()).device
    targetDevice = next(targetNet.parameters()).device

    def toBatch(x):
        # numpy or torch, (20, 32, 32) or (1, 20, 32, 32) -> float32 (1, 20, 32, 32)
        t = torch.as_tensor(x, dtype=torch.float32, device=device)
        return t.unsqueeze(0) if t.dim() == 3 else t

    # 2) Stack states, actions (a (B, 1) index column for gather) and rewards.
    states = torch.cat([toBatch(s) for s in datas.state])
    actions = torch.cat([
        torch.as_tensor(a, dtype=torch.long, device=device).reshape(1, 1)
        for a in datas.action
    ])
    rewards = torch.cat([
        torch.as_tensor(r, dtype=torch.float32, device=device).reshape(-1)
        for r in datas.reward
    ])

    # 3) Q(s_t, a_t) for the actions actually taken.
    qAll = policyNet(states)
    qValue = qAll.gather(1, actions)

    # 4) V(s_{t+1}) = max_a Q_target(s_{t+1}, a); stays 0 for terminal transitions.
    nextStateMask = torch.tensor([n_s is not None for n_s in datas.next_state],
                                 device=device, dtype=torch.bool)
    vValue = torch.zeros(len(batch), device=device)
    if bool(nextStateMask.any()):
        nextStates = torch.cat([toBatch(n_s) for n_s in datas.next_state if n_s is not None])
        with torch.no_grad():
            vValue[nextStateMask] = targetNet(nextStates.to(targetDevice)).max(1)[0].to(device)

    # 5) Expected Q-values and Huber loss.
    expcQValue = (vValue * GAMMA + rewards).unsqueeze(1)
    loss = nn.SmoothL1Loss()(qValue, expcQValue)

    # 6) Accuracy of the greedy policy against the recorded actions.
    acc = (qAll.argmax(dim=1) == actions.squeeze(1)).float().mean().item()

    # 7) Checkpoint the measured network whenever accuracy improves
    #    (bestAccuracy may not exist yet on the first call).
    if acc > globals().get('bestAccuracy', float('-inf')):
        bestAccuracy = acc
        torch.save(policyNet, SAVE_PATH)

    # 8) Log a compact summary.
    print(f':: {STEPS_DONE} :: Acc({acc:.3f}/{bestAccuracy:.3f}), '
          f'loss({loss.item():.4f}), QValue(mean={expcQValue.mean().item():.3f})')

    # 9) Gradient step, clamping every gradient element to [-1, 1];
    #    clip_grad_value_ skips parameters that have no gradient.
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policyNet.parameters(), 1)
    optimizer.step()

### 5) Training

In [12]:
def trainModel(policyNet: CLuxNet,
               targetNet: CLuxNet,
               optimizer,
               width: int,
               height: int) -> None:
    """Train `policyNet` from `NUM_EPOCHS` self-play games of `agent.py`.

    Each epoch plays one width x height game of `agent.py` against itself.
    Player 0's recorded steps become replay transitions
    (state, action, next_state, reward), which are pushed to the global
    `memory`, and `optimizeModel` runs after every step. Every `TARGET_UPDATE`
    epochs the target network is re-synchronised with the policy network.

    NOTE(review): `env.run` needs `agent.py` (and the `model.pth` it loads)
    in the working directory; `agent.py` is written by a later cell.
    NOTE(review): `step['reward']` looks like the running score reported by
    the environment (e.g. 10001 in the sample steps), not a per-turn delta.
    Confirm this before relying on it as an RL reward.
    """
    global STEPS_DONE

    # .to() moves parameters in place, so an optimizer built earlier stays valid.
    policyNet.to(DEVICE)
    targetNet.to(DEVICE)

    def asInput(gameMap):
        # (20, 32, 32) map -> float32 (1, 20, 32, 32) tensor, kept on the CPU so
        # the replay buffer does not occupy GPU memory.
        return torch.from_numpy(np.asarray(gameMap, dtype=np.float32)).unsqueeze(0)

    for epoch in range(NUM_EPOCHS):
        env = make("lux_ai_2021",
                   configuration={"width": width, "height": height,
                                  "loglevel": 2, "annotations": True},
                   debug=False)
        steps = env.run(['agent.py', 'agent.py'])

        xShift, yShift = (32 - width) // 2, (32 - height) // 2
        prevObservation = None
        for s in steps:
            # A step holds one state per agent; only player 0's state carries
            # the full observation (`step`, `updates`, ...).
            step = s[0] if isinstance(s, (list, tuple)) else s
            observation = step['observation']

            nStep: int = observation['step']
            nTeam: int = observation['player']
            sUId, action = toLabel(nTeam, step['action'])

            if nStep == 0:
                xShift = (32 - observation['width']) // 2
                yShift = (32 - observation['height']) // 2

            # The recorded action is paired with the previous observation (state)
            # and this one (next state). This assumes each step stores the action
            # that produced its observation — TODO confirm. Both are encoded from
            # the acting unit's point of view, so channel 0 marks the unit the
            # label belongs to.
            if action is not None and prevObservation is not None:
                state = updateMap(prevObservation['step'], xShift, yShift, nTeam,
                                  sUId, prevObservation['updates'])
                nextState = updateMap(nStep, xShift, yShift, nTeam,
                                      sUId, observation['updates'])
                # Errored steps report reward None (see the sample steps).
                reward = step['reward'] if step['reward'] is not None else 0
                memory.push(asInput(state),
                            torch.tensor([[action]], dtype=torch.long),
                            asInput(nextState),
                            torch.tensor([float(reward)]))
            prevObservation = observation

            optimizeModel(memory, policyNet, targetNet, optimizer)
            STEPS_DONE += 1

        if epoch % TARGET_UPDATE == 0:
            targetNet.load_state_dict(policyNet.state_dict())
        

### 6) Run

In [13]:
# Policy network (trained) and target network (periodically synced copy).
policyNet = CLuxNet(N_ACTIONS).to(DEVICE)
targetNet = CLuxNet(N_ACTIONS).to(DEVICE)
targetNet.load_state_dict(policyNet.state_dict())
# The target only supplies bootstrap values: eval() freezes its BatchNorm statistics.
targetNet.eval()
# optimizeModel back-propagates the loss through policyNet, so the optimizer
# must own policyNet's parameters; over targetNet's it would never update anything.
optimizer = torch.optim.AdamW(policyNet.parameters(), lr=LEARNING_RATE)

In [None]:
# Learn from self-play games on a 24x24 map.
trainModel(policyNet, targetNet, optimizer, width=24, height=24)

# Preprocessing (legacy supervised-learning setup)

In [None]:
# NOTE(review): legacy supervised (imitation-learning) setup. `samples`,
# `labels`, `obses`, `LuxDataset` and `train_model` are not defined anywhere in
# this notebook, so this cell raises NameError until they are provided.
# It also rebinds the global `optimizer` used by trainModel.
n_actions = N_ACTIONS  # same value as the global N_ACTIONS (5)
# The network class defined in this notebook is CLuxNet (`LuxNet` does not exist).
policy_net = CLuxNet(n_actions)
target_net = CLuxNet(n_actions)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

train, val = train_test_split(samples, test_size=0.1, random_state=42, stratify=labels)
batch_size = BATCH_SIZE
train_loader = DataLoader(
    LuxDataset(obses, train),
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)
val_loader = DataLoader(
    LuxDataset(obses, val),
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)
dataloaders_dict = {"train": train_loader, "val": val_loader}
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(target_net.parameters(), lr=LEARNING_RATE)
# optimizer = torch.optim.RMSprop(target_net.parameters(), lr=1e-3)
# optimizer = torch.optim.RMSprop(target_net.parameters(), lr=1e-3)

In [None]:
# NOTE(review): `train_model` is not defined in this notebook; this call raises
# NameError until a supervised training loop is provided.
train_model(
    target_net,
    dataloaders_dict,
    criterion,
    optimizer,
    num_epochs=15,
)

# Submission

In [None]:
%%writefile agent.py
import os
import numpy as np
import torch
from lux.game import Game


# Locate the TorchScript model: presumably the agent files live under
# /kaggle_simulations/agent in Kaggle's simulation sandbox; otherwise use the CWD.
if os.path.exists('/kaggle_simulations'):
    path = '/kaggle_simulations/agent'
else:
    path = '.'
# NOTE(review): nothing in this notebook writes model.pth (training saves to
# SAVE_PATH via torch.save); presumably it is copied from ../input by the first cell.
model = torch.jit.load(f'{path}/model.pth')
model.eval()


def make_input(obs, unit_id):
    """Encode an observation as a (20, 32, 32) float32 array, map centred in 32x32.

    Channels: 0-1 the unit `unit_id` (presence, cargo / 100); 2-4 other
    friendly units and 5-7 enemy units (presence, cooldown / 6, cargo / 100);
    8-9 friendly and 10-11 enemy city tiles (presence,
    min(fuel / light upkeep, 10) / 10); 12-14 wood / coal / uranium
    amount / 800; 15-16 own / opponent research points (min(rp, 200) / 200);
    17 step % 40 / 40; 18 step / 360; 19 on-map mask.
    """
    dx = (32 - obs['width']) // 2
    dy = (32 - obs['height']) // 2
    fuel_ratio = {}
    planes = np.zeros((20, 32, 32), dtype=np.float32)

    for line in obs['updates']:
        parts = line.split(' ')
        kind = parts[0]

        if kind == 'u':
            px, py = int(parts[4]) + dx, int(parts[5]) + dy
            cargo = (int(parts[7]) + int(parts[8]) + int(parts[9])) / 100
            if parts[3] == unit_id:
                # the unit being decided for: position and cargo
                planes[:2, px, py] = (1, cargo)
            else:
                base = 2 + (int(parts[2]) - obs['player']) % 2 * 3
                planes[base:base + 3, px, py] = (1, float(parts[6]) / 6, cargo)
        elif kind == 'ct':
            px, py = int(parts[3]) + dx, int(parts[4]) + dy
            base = 8 + (int(parts[1]) - obs['player']) % 2 * 2
            planes[base:base + 2, px, py] = (1, fuel_ratio[parts[2]])
        elif kind == 'r':
            px, py = int(parts[2]) + dx, int(parts[3]) + dy
            channel = {'wood': 12, 'coal': 13, 'uranium': 14}[parts[1]]
            planes[channel, px, py] = int(float(parts[4])) / 800
        elif kind == 'rp':
            planes[15 + (int(parts[1]) - obs['player']) % 2, :] = min(int(parts[2]), 200) / 200
        elif kind == 'c':
            fuel_ratio[parts[2]] = min(float(parts[3]) / float(parts[4]), 10) / 10

    planes[17, :] = obs['step'] % 40 / 40
    planes[18, :] = obs['step'] / 360
    planes[19, dx:32 - dx, dy:32 - dy] = 1

    return planes


game_state = None


def get_game_state(observation):
    """Return the module-level Game: created on step 0, updated in place afterwards.

    On step 0 the whole update list goes to `_initialize`, and the entries after
    the first two (presumably player id and map size, e.g. '0', '24 24') are
    applied with `_update`.
    """
    global game_state

    if observation["step"] != 0:
        game_state._update(observation["updates"])
        return game_state

    game_state = Game()
    game_state._initialize(observation["updates"])
    game_state._update(observation["updates"][2:])
    game_state.id = observation["player"]
    return game_state


def in_city(pos):
    """Return True if `pos` holds a city tile owned by us (team == game_state.id).

    Any lookup failure (no game yet, position past the board edge, ...) counts
    as "not in a city", so the agent never crashes here. Negative coordinates
    are rejected explicitly: presumably `translate` can step off the board
    (e.g. to x = -1). If `get_cell_by_pos` indexes nested lists, as the Lux kit
    does, a negative index would silently wrap to the opposite edge.
    """
    try:
        if pos.x < 0 or pos.y < 0:
            return False
        tile = game_state.map.get_cell_by_pos(pos).citytile
        return tile is not None and tile.team == game_state.id
    except Exception:  # best-effort, but no longer swallows KeyboardInterrupt/SystemExit
        return False


def call_func(obj, method, args=()):
    """Call `obj.<method>(*args)` and return its result.

    Uses an immutable empty tuple as the default instead of a shared mutable
    list. `get_action` passes a one-letter direction string as `args`
    (e.g. 'n'), which unpacks to the single argument 'n'.
    """
    return getattr(obj, method)(*args)


# Network output index -> (Unit method name, *args); order matches the action head.
unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]
def get_action(policy, unit, dest):
    """Pick the best-scoring action whose target cell is not already claimed.

    Candidates are tried from highest to lowest score. A target listed in
    `dest` is only allowed if it is one of our city tiles. Returns
    (action_string, target_position), falling back to a centre move on the
    unit's own cell when every candidate collides.
    """
    ranked = np.argsort(policy)[::-1]
    for label in ranked:
        act = unit_actions[label]
        # build_city has no direction; translate() presumably returns None for
        # it, in which case the unit's own cell is the target.
        target = unit.pos.translate(act[-1], 1) or unit.pos
        if target in dest and not in_city(target):
            continue
        return call_func(unit, *act), target

    return unit.move('c'), unit.pos


def agent(observation, configuration):
    """Choose this turn's actions: scripted city tiles, network-driven workers.

    City tiles build a worker while we have fewer units than city tiles and
    otherwise research until uranium is researched. Every worker that can act
    is steered by the policy network, avoiding cells already claimed by another
    worker this turn. During turns 30-39 of each 40-turn cycle (presumably
    night) workers standing in one of our cities get no action.

    Returns:
        list of action strings for the environment.
    """
    global game_state

    game_state = get_game_state(observation)
    me = game_state.players[observation.player]
    commands = []

    # --- city tiles ---
    workers = len(me.units)
    for tile in (t for c in me.cities.values() for t in c.citytiles):
        if not tile.can_act():
            continue
        if workers < me.city_tile_count:
            commands.append(tile.build_worker())
            workers += 1
        elif not me.researched_uranium():
            commands.append(tile.research())
            # count research queued this turn (presumably what
            # researched_uranium() checks for the following tiles)
            me.research_points += 1

    # --- workers ---
    early_in_cycle = game_state.turn % 40 < 30
    claimed = []
    for unit in me.units:
        if not unit.can_act():
            continue
        if not early_in_cycle and in_city(unit.pos):
            continue
        features = make_input(observation, unit.id)
        with torch.no_grad():
            scores = model(torch.from_numpy(features).unsqueeze(0))
        command, target = get_action(scores.squeeze(0).numpy(), unit, claimed)
        commands.append(command)
        claimed.append(target)

    return commands

In [None]:
from kaggle_environments import make

# Play one 24x24 game of the submission agent against itself and render it inline.
env = make(
    "lux_ai_2021",
    configuration={"width": 24, "height": 24, "loglevel": 2, "annotations": True},
    debug=False,
)
steps = env.run(['agent.py', 'agent.py'])
env.render(mode="ipython", width=1200, height=800)

In [None]:
# `steps` has one entry per turn; each entry is a list with one state per agent.
# (A bare `type(steps)` mid-cell was discarded, so it is printed explicitly.)
print(f"{type(steps).__name__} with {len(steps)} turns")
steps[2]

In [None]:
!tar -czf submission.tar.gz *