In [16]:
# Install/upgrade kaggle-environments (output silenced) and copy the contents of
# ../input/lux-ai-2021 into the working directory.
# NOTE(review): the install is unpinned (-U), so the environment version can change
# between runs — consider pinning a known-good kaggle-environments version.
!pip install kaggle-environments -U > /dev/null 2>&1
!cp -r ../input/lux-ai-2021/* .

In [15]:
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split
import math

### Run Episode

In [None]:
from kaggle_environments import make

# Smoke test: play one 24x24 Lux AI game with agent.py controlling both players.
# `steps` holds the per-step history returned by env.run.
# NOTE(review): agent.py is (re)written by the %%writefile cell near the end of the
# notebook; on a fresh kernel this cell relies on an agent.py (and the model.pth it
# loads) already being present in the working directory.
env = make("lux_ai_2021", configuration={"width": 24, "height": 24, "loglevel": 2, "annotations": True}, debug=False)
steps = env.run(['agent.py', 'agent.py'])
# env.render(mode="ipython", width=1200, height=800)

### Make Inputs

In [5]:
# Global Variables
# < MAP >
# Map width/height, read (and declared `global`) by makeInputMap; 0 until set.
width, height = 0, 0

In [None]:
# Input for ReplayMemory
from collections import namedtuple

# One replay-memory record. Field sources, as sketched from the example dumps below:
# state: list(str) = state
# action: list(str) = step[0]['action']
# next_state: list(str) = step[0]['observation']['updates']
Data = namedtuple('Data',
                  ('state', 'action', 'next_state', 'reward'))

# Container for the example step dumps that follow (`step[0] = ...`, `step[1] = ...`,
# `step[100] = ...`). `step` is not defined anywhere else, so without this those
# assignments raise NameError and the notebook cannot run top to bottom.
step = {}

step[0] = [{'action': [], 'reward': 0, 'info': {}, 'observation': {'remainingOverageTime': 60, 'step': 0, 'width': 24, 'height': 24, 'reward': 0, 'globalUnitIDCount': 2, 'globalCityIDCount': 2, 'player': 0, 'updates': ['0', '24 24', 'rp 0 0', 'rp 1 0', 'r uranium 0 0 310', 'r uranium 0 14 344', 'r uranium 0 23 304', 'r wood 3 10 333', 'r coal 3 23 407', 'r wood 4 7 323', 'r wood 4 8 324', 'r wood 4 10 318', 'r wood 4 11 360', 'r wood 5 8 345', 'r wood 5 9 303', 'r wood 5 10 351', 'r wood 5 21 800', 'r wood 5 22 800', 'r wood 6 8 387', 'r wood 6 9 342', 'r wood 6 10 326', 'r wood 6 11 398', 'r wood 6 12 393', 'r wood 6 19 350', 'r wood 6 20 331', 'r wood 6 21 800', 'r wood 7 10 355', 'r wood 7 11 376', 'r wood 7 12 310', 'r wood 7 19 364', 'r wood 7 20 396', 'r wood 8 11 346', 'r wood 8 12 394', 'r wood 8 20 385', 'r wood 9 10 376', 'r uranium 10 0 324', 'r uranium 10 1 322', 'r uranium 11 0 349', 'r coal 11 5 391', 'r coal 11 6 418', 'r uranium 12 0 349', 'r coal 12 5 391', 'r coal 12 6 418', 'r uranium 13 0 324', 'r uranium 13 1 322', 'r wood 14 10 376', 'r wood 15 11 346', 'r wood 15 12 394', 'r wood 15 20 385', 'r wood 16 10 355', 'r wood 16 11 376', 'r wood 16 12 310', 'r wood 16 19 364', 'r wood 16 20 396', 'r wood 17 8 387', 'r wood 17 9 342', 'r wood 17 10 326', 'r wood 17 11 398', 'r wood 17 12 393', 'r wood 17 19 350', 'r wood 17 20 331', 'r wood 17 21 800', 'r wood 18 8 345', 'r wood 18 9 303', 'r wood 18 10 351', 'r wood 18 21 800', 'r wood 18 22 800', 'r wood 19 7 323', 'r wood 19 8 324', 'r wood 19 10 318', 'r wood 19 11 360', 'r wood 20 10 333', 'r coal 20 23 407', 'r uranium 23 0 310', 'r uranium 23 14 344', 'r uranium 23 23 304', 'u 0 0 u_1 6 22 0 0 0 0', 'u 0 1 u_2 17 22 0 0 0 0', 'c 0 c_1 0 23', 'c 1 c_2 0 23', 'ct 0 c_1 6 22 0', 'ct 1 c_2 17 22 0', 'ccd 6 22 6', 'ccd 17 22 6', 'D_DONE']}, 'status': 'ACTIVE'}, {'action': [], 'reward': 0, 'info': {}, 'observation': {'remainingOverageTime': 60, 'reward': 0, 'player': 1}, 'status': 'ACTIVE'}]

step[1] = [{'action': None, 'reward': None, 'info': {}, 'observation': {'remainingOverageTime': 60, 'step': 1, 'width': 24, 'height': 24, 'reward': 10001, 'globalUnitIDCount': 2, 'globalCityIDCount': 2, 'player': 0, 'updates': ['rp 0 0', 'rp 1 0', 'r uranium 0 0 310', 'r uranium 0 14 344', 'r uranium 0 23 304', 'r wood 3 10 342', 'r coal 3 23 407', 'r wood 4 7 332', 'r wood 4 8 333', 'r wood 4 10 326', 'r wood 4 11 369', 'r wood 5 8 354', 'r wood 5 9 311', 'r wood 5 10 360', 'r wood 5 21 800', 'r wood 5 22 780', 'r wood 6 8 397', 'r wood 6 9 351', 'r wood 6 10 335', 'r wood 6 11 408', 'r wood 6 12 403', 'r wood 6 19 359', 'r wood 6 20 340', 'r wood 6 21 780', 'r wood 7 10 364', 'r wood 7 11 386', 'r wood 7 12 318', 'r wood 7 19 374', 'r wood 7 20 406', 'r wood 8 11 355', 'r wood 8 12 404', 'r wood 8 20 395', 'r wood 9 10 386', 'r uranium 10 0 324', 'r uranium 10 1 322', 'r uranium 11 0 349', 'r coal 11 5 391', 'r coal 11 6 418', 'r uranium 12 0 349', 'r coal 12 5 391', 'r coal 12 6 418', 'r uranium 13 0 324', 'r uranium 13 1 322', 'r wood 14 10 386', 'r wood 15 11 355', 'r wood 15 12 404', 'r wood 15 20 395', 'r wood 16 10 364', 'r wood 16 11 386', 'r wood 16 12 318', 'r wood 16 19 374', 'r wood 16 20 406', 'r wood 17 8 397', 'r wood 17 9 351', 'r wood 17 10 335', 'r wood 17 11 408', 'r wood 17 12 403', 'r wood 17 19 359', 'r wood 17 20 340', 'r wood 17 21 780', 'r wood 18 8 354', 'r wood 18 9 311', 'r wood 18 10 360', 'r wood 18 21 800', 'r wood 18 22 780', 'r wood 19 7 332', 'r wood 19 8 333', 'r wood 19 10 326', 'r wood 19 11 369', 'r wood 20 10 342', 'r coal 20 23 407', 'r uranium 23 0 310', 'r uranium 23 14 344', 'r uranium 23 23 304', 'u 0 0 u_1 6 22 0 0 0 0', 'u 0 1 u_2 17 22 0 0 0 0', 'c 0 c_1 40 23', 'c 1 c_2 40 23', 'ct 0 c_1 6 22 0', 'ct 1 c_2 17 22 0', 'ccd 6 22 6', 'ccd 17 22 6', 'D_DONE']}, 'status': 'ERROR'}, {'action': None, 'reward': None, 'info': {}, 'observation': {'remainingOverageTime': 60, 'reward': 10001, 'player': 1}, 'status': 'ERROR'}]

step[100] = {'action': ['m u_1 w'],
 'reward': 10001,
 'info': {},
 'observation': {'remainingOverageTime': 60,
  'step': 100,
  'width': 24,
  'height': 24,
  'reward': 10001,
  'globalUnitIDCount': 2,
  'globalCityIDCount': 2,
  'player': 0,
  'updates': ['rp 0 0',
   'rp 1 0',
   'r coal 0 3 367',
   'r coal 0 4 389',
   'r coal 0 8 410',
   'r uranium 0 23 302',
   'r wood 1 21 500',
   'r wood 1 22 500',
   'r wood 1 23 500',
   'r wood 2 21 500',
   'r wood 2 22 500',
   'r wood 3 22 500',
   'r wood 6 11 500',
   'r wood 7 10 500',
   'r wood 8 1 500',
   'r wood 8 2 500',
   'r wood 8 3 500',
   'r wood 8 9 500',
   'r wood 8 10 500',
   'r uranium 8 16 328',
   'r uranium 8 17 310',
   'r coal 8 22 354',
   'r coal 8 23 396',
   'r wood 9 9 500',
   'r wood 9 10 500',
   'r uranium 9 17 342',
   'r coal 9 23 394',
   'r wood 10 10 500',
   'r wood 13 10 500',
   'r wood 14 9 500',
   'r wood 14 10 500',
   'r uranium 14 17 342',
   'r coal 14 23 394',
   'r wood 15 1 500',
   'r wood 15 2 500',
   'r wood 15 3 500',
   'r wood 15 9 500',
   'r wood 15 10 500',
   'r uranium 15 16 328',
   'r uranium 15 17 310',
   'r coal 15 22 354',
   'r coal 15 23 396',
   'r wood 16 10 500',
   'r wood 17 11 500',
   'r wood 20 22 500',
   'r wood 21 21 500',
   'r wood 21 22 500',
   'r wood 22 21 500',
   'r wood 22 22 500',
   'r wood 22 23 500',
   'r coal 23 3 367',
   'r coal 23 4 389',
   'r coal 23 8 410',
   'r uranium 23 23 302',
   'u 0 0 u_1 7 2 0 0 0 0',
   'u 0 1 u_2 16 2 0 0 0 0',
   'c 0 c_1 2240 23',
   'c 1 c_2 2240 23',
   'ct 0 c_1 7 2 0',
   'ct 1 c_2 16 2 0',
   'ccd 7 2 6',
   'ccd 16 2 6',
   'D_DONE']},
 'status': 'ACTIVE'}

In [None]:
# input for optimize_model

# return map states = [][]
def makeInputMap(updateList: list, player: int = 0, turn: int = 0) -> list:
    """Encode the map-wide state of a Lux AI observation as a 20 x 32 x 32 grid.

    The channel layout follows ``make_input`` (the LuxNet input encoder), but no
    unit is singled out. Channels 0-1 (focus-unit position/cargo) therefore stay
    zero, and every unit is written to the ally/enemy unit channels:

        2-4   ally units        (presence, cooldown / 6, cargo / 100)
        5-7   enemy units       (same)
        8-9   ally city tiles   (presence, min(fuel / light upkeep, 10) / 10)
        10-11 enemy city tiles  (same)
        12-14 wood / coal / uranium amount / 800
        15-16 ally / enemy research points, min(rp, 200) / 200 (whole plane)
        17    turn % 40 / 40
        18    turn / 360
        19    1 on tiles inside the map

    The map is centered in the 32 x 32 canvas. "Ally" means team == ``player``.

    NOTE(review): the earlier draft of this function did not compile. It had an
    ``if`` with no body and called ``np.zeros()`` without a shape. It also read an
    undefined ``lStep`` to collect action/reward. That step-level bookkeeping was
    dropped here; it belongs with the replay-memory code, not with map encoding.

    Args:
        updateList: observation ``updates`` strings, e.g. ``'r wood 3 10 333'``.
            A step-0 list starts with ``'<player>'`` and ``'<width> <height>'``.
            When that size line is present, the module-level ``width``/``height``
            are updated from it.
        player: team id whose perspective decides the ally/enemy channels.
        turn: game turn, used for channels 17-18.

    Returns:
        Nested lists indexed ``[channel][x][y]``. ``np.asarray`` recovers the array.

    Raises:
        ValueError: if the map size is unknown, i.e. there is no size line and
            the globals are not positive.
    """
    global width, height
    if len(updateList) > 1:
        dims = updateList[1].split(' ')
        if len(dims) == 2 and all(d.isdigit() for d in dims):
            width, height = int(dims[0]), int(dims[1])
    if width <= 0 or height <= 0:
        raise ValueError('map size unknown: pass a step-0 update list or set width/height')

    nXShift: int = (32 - width) // 2
    nYShift: int = (32 - height) // 2
    gameMap = np.zeros((20, 32, 32), dtype=np.float32)
    cities = {}

    for update in updateList:
        strs = update.split(' ')
        kind = strs[0]
        if kind == 'u':
            x = int(strs[4]) + nXShift
            y = int(strs[5]) + nYShift
            cargo = int(strs[7]) + int(strs[8]) + int(strs[9])
            idx = 2 + (int(strs[2]) - player) % 2 * 3
            gameMap[idx:idx + 3, x, y] = (1, float(strs[6]) / 6, cargo / 100)
        elif kind == 'ct':
            x = int(strs[3]) + nXShift
            y = int(strs[4]) + nYShift
            idx = 8 + (int(strs[1]) - player) % 2 * 2
            # 'c' lines precede their 'ct' lines in the update stream. Default to 0
            # rather than failing if a city's fuel line is missing.
            gameMap[idx:idx + 2, x, y] = (1, cities.get(strs[2], 0.0))
        elif kind == 'r':
            x = int(strs[2]) + nXShift
            y = int(strs[3]) + nYShift
            gameMap[{'wood': 12, 'coal': 13, 'uranium': 14}[strs[1]], x, y] = int(float(strs[4])) / 800
        elif kind == 'rp':
            gameMap[15 + (int(strs[1]) - player) % 2, :] = min(int(strs[2]), 200) / 200
        elif kind == 'c':
            fuel = float(strs[3])
            upkeep = float(strs[4])
            # Guard against a zero upkeep rather than dividing by zero.
            cities[strs[2]] = min(fuel / upkeep, 10) / 10 if upkeep > 0 else 1.0

    gameMap[17, :] = turn % 40 / 40
    gameMap[18, :] = turn / 360
    gameMap[19, nXShift:nXShift + width, nYShift:nYShift + height] = 1
    return gameMap.tolist()
    

### Replay Memory

### Select Action

### Optimize Model

### Training

In [17]:
def seed_everything(seed_value):
    """Seed every RNG the notebook uses so that runs are reproducible.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch. When CUDA is
    available, all CUDA devices are seeded too. cuDNN is configured for
    deterministic kernel selection.

    Note: setting ``PYTHONHASHSEED`` here only affects Python interpreters launched
    after this call. The running kernel's string-hash randomization was fixed when
    the interpreter started.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
    # benchmark=True lets cuDNN auto-tune and pick kernels per run, which defeats
    # deterministic=True. Keep it off for reproducibility. Both flags are harmless
    # on CPU-only machines.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed = 42
seed_everything(seed)

# Preprocessing

In [18]:
def to_label(action):
    """Translate one action command string into ``(unit_id, label)``.

    Moves ``m <id> <dir>`` map n/s/w/e to labels 0/1/2/3 and ``bcity <id>`` maps
    to 4. A center move (``m <id> c``) and every other command map to ``None``.
    The second whitespace-separated token is always returned as the id.
    """
    tokens = action.split(' ')
    command, unit_id = tokens[0], tokens[1]
    if command == 'bcity':
        return unit_id, 4
    if command != 'm':
        return unit_id, None
    direction_to_label = {'c': None, 'n': 0, 's': 1, 'w': 2, 'e': 3}
    return unit_id, direction_to_label[tokens[2]]


def depleted_resources(obs):
    """Return True when no update in ``obs['updates']`` is a resource ('r') entry."""
    has_resource = any(update.split(' ')[0] == 'r' for update in obs['updates'])
    return not has_resource


def create_dataset_from_json(episode_dir, team_name='Toad Brigade'):
    """Build an imitation-learning dataset from saved Lux AI episode replays.

    Every ``*.json`` replay in ``episode_dir`` is read, except files whose name
    contains ``'output'``. The winning side of each replay (the agent with the
    highest reward, where missing rewards count as 0) is kept only if it was
    played by ``team_name``. For each step on which that agent is ACTIVE, the
    observation of step ``i`` is paired with the actions stored on step ``i + 1``.
    Collection for an episode stops once the observation lists no resources.

    Returns:
        obses: ``{'<EpisodeId>_<step>': obs}``. Each obs keeps only 'step',
            'updates', 'player', 'width' and 'height', with 'player' set to the
            winning agent's index.
        samples: list of ``(obs_id, unit_id, label)`` from ``to_label``. Actions
            whose label is None are skipped.
    """
    kept_keys = ('step', 'updates', 'player', 'width', 'height')
    obses = {}
    samples = []

    episodes = [path for path in Path(episode_dir).glob('*.json') if 'output' not in path.name]
    for filepath in tqdm(episodes):
        with open(filepath) as f:
            json_load = json.load(f)

        ep_id = json_load['info']['EpisodeId']
        index = np.argmax([r or 0 for r in json_load['rewards']])
        if json_load['info']['TeamNames'][index] != team_name:
            continue

        replay_steps = json_load['steps']
        for i in range(len(replay_steps) - 1):
            if replay_steps[i][index]['status'] != 'ACTIVE':
                continue
            # The replay records an agent's actions on the step after the
            # observation they respond to. On a step whose status is e.g. 'ERROR'
            # the action is None, so treat that as "no actions" instead of
            # crashing on iteration.
            actions = replay_steps[i + 1][index]['action'] or []
            # Only agent 0's observation carries the full 'updates' list; the
            # other agent's entry holds just a few fields. Read from agent 0 and
            # set the perspective via 'player' below.
            obs = replay_steps[i][0]['observation']
            if depleted_resources(obs):
                break

            obs_id = f'{ep_id}_{i}'
            obs = {k: v for k, v in obs.items() if k in kept_keys}
            obs['player'] = index
            obses[obs_id] = obs

            for action in actions:
                unit_id, label = to_label(action)
                if label is not None:
                    samples.append((obs_id, unit_id, label))

    return obses, samples

In [19]:
# Parse every replay under episode_dir into observations keyed by '<EpisodeId>_<step>'
# plus (obs_id, unit_id, label) samples for the default team ('Toad Brigade').
# The literals below this call are example outputs kept for reference.
episode_dir = '../input/lux-ai-episodes'
obses, samples = create_dataset_from_json(episode_dir)
# print('obses:', len(obses), list(obses.items())[32574], '\nsamples:', len(samples), samples[0])

'26814974_0', { 'height': 12,  
                'width': 12,  
                'player': 0,   
                'step': 0,   
                'updates': ['0',  
                            '12 12',  
                            'rp 0 0',   
                            'rp 1 0',   
                            'r uranium 0 0 326',   
                            'r wood 0 5 800',   
                            'r wood 0 6 800',   
                            'r coal 0 10 386',   
                            'r coal 0 11 366',   
                            'r wood 1 5 800',   
                            'r wood 4 2 371',    
                            'r wood 4 3 340',   
                            'r wood 5 1 397',   
                            'r wood 5 2 344',   
                            'r wood 5 3 326',   
                            'r wood 5 10 377',   
                            'r wood 5 11 390',   
                            'r wood 6 1 397',   
                            'r wood 6 2 344',   
                            'r wood 6 3 326',   
                            'r wood 6 10 377',   
                            'r wood 6 11 390',   
                            'r wood 7 2 371',   
                            'r wood 7 3 340',   
                            'r wood 10 5 800',   
                            'r uranium 11 0 326',   
                            'r wood 11 5 800',   
                            'r wood 11 6 800',   
                            'r coal 11 10 386',   
                            'r coal 11 11 366',   
                            'u 0 0 u_1 1 6 0 0 0 0',   
                            'u 0 1 u_2 10 6 0 0 0 0',   
                            'c 0 c_1 0 23',   
                            'c 1 c_2 0 23',   
                            'ct 0 c_1 1 6 0',   
                            'ct 1 c_2 10 6 0',   
                            'ccd 1 6 6',   
                            'ccd 10 6 6',   
                            'D_DONE']}
# Example: samples[0] == ('26814974_0', 'u_1', 0)
# (Kept as a comment: the bare `samples[0] ==> ...` line was a SyntaxError that
# stopped this whole cell, including the dataset build, from running.)
                            
'26691692_313', {'height': 12, 'player': 0, 'step': 313, 'updates': ['rp 0 190', 'rp 1 131', 'r uranium 0 0 322', 'r uranium 0 1 311', 'r uranium 0 11 304', 'r uranium 11 0 322', 'r uranium 11 1 311', 'r uranium 11 11 304', 'u 0 0 u_3 6 10 0 0 0 0', 'u 0 0 u_7 7 1 0 0 0 0', 'u 0 0 u_15 5 10 0 0 0 0', 'u 0 0 u_26 5 1 0 0 0 0', 'u 0 0 u_28 6 10 0 0 0 0', 'u 0 0 u_29 4 10 0 0 0 0', 'u 0 0 u_32 4 9 0 0 0 0', 'u 0 0 u_35 4 11 0 0 0 0', 'u 0 0 u_37 6 1 0 0 0 0', 'u 0 0 u_41 7 10 0 0 0 0', 'u 0 1 u_2 5 9 0 0 0 0', 'u 0 1 u_8 5 9 0 0 0 0', 'u 0 1 u_14 5 9 0 0 0 0', 'u 0 1 u_18 5 9 0 0 0 0', 'u 0 1 u_24 5 9 0 0 0 0', 'c 0 c_19 364 62', 'c 1 c_30 17 36', 'c 0 c_31 3616 88', 'ct 0 c_19 5 1 0', 'ct 0 c_19 6 1 0', 'ct 0 c_19 5 0 9', 'ct 0 c_19 7 1 5', 'ct 1 c_30 5 9 9', 'ct 1 c_30 6 9 9', 'ct 0 c_31 7 10 8', 'ct 0 c_31 6 10 0', 'ct 0 c_31 4 10 0', 'ct 0 c_31 4 9 0', 'ct 0 c_31 4 11 6', 'ct 0 c_31 5 10 0', 'ccd 5 0 6', 'ccd 5 1 6', 'ccd 6 1 6', 'ccd 7 1 6', 'ccd 4 9 6', 'ccd 5 9 6', 'ccd 6 9 6', 'ccd 4 10 6', 'ccd 5 10 6', 'ccd 6 10 6', 'ccd 7 10 6', 'ccd 4 11 6', 'D_DONE'], 'width': 12}


In [6]:
# Show the current working directory (IPython automagic for %pwd). The relative
# paths used in this notebook ('../input/...', 'agent.py') resolve against it.
pwd

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Label vocabulary shared with to_label: list index == label id.
actions = ['north', 'south', 'west', 'east', 'bcity']
n_actions = len(actions)

# Class balance of the imitation-learning targets (last field of each sample).
labels = [sample[-1] for sample in samples]
for value, count in zip(*np.unique(labels, return_counts=True)):
    print(f'{actions[value]:^5}: {count:>3}')

# Training

In [20]:
# Input for Neural Network
def make_input(obs, unit_id):
    """Encode an observation from one unit's point of view as a 20x32x32 float32 array.

    The map is centered inside a fixed 32x32 canvas, so every map size shares
    one network input shape. Channels:

        0-1   the focus unit ``unit_id``: presence, cargo / 100
        2-4   other ally units: presence, cooldown / 6, cargo / 100
        5-7   enemy units: same as 2-4
        8-9   ally city tiles: presence, min(fuel / light upkeep, 10) / 10
        10-11 enemy city tiles: same as 8-9
        12-14 wood / coal / uranium amount / 800
        15-16 ally / enemy research points, min(rp, 200) / 200 (whole plane)
        17    obs['step'] % 40 / 40  (day/night cycle)
        18    obs['step'] / 360      (turns)
        19    1 on tiles inside the map

    "Ally" means team == obs['player']. City-tile values come from the 'c'
    updates, so those must appear before the matching 'ct' updates.
    """
    x_shift = (32 - obs['width']) // 2
    y_shift = (32 - obs['height']) // 2
    city_fuel = {}
    resource_channel = {'wood': 12, 'coal': 13, 'uranium': 14}

    planes = np.zeros((20, 32, 32), dtype=np.float32)

    for update in obs['updates']:
        fields = update.split(' ')
        tag = fields[0]

        if tag == 'u':
            px = int(fields[4]) + x_shift
            py = int(fields[5]) + y_shift
            cargo = int(fields[7]) + int(fields[8]) + int(fields[9])
            if fields[3] == unit_id:
                planes[:2, px, py] = (1, cargo / 100)
            else:
                base = 2 + (int(fields[2]) - obs['player']) % 2 * 3
                planes[base:base + 3, px, py] = (1, float(fields[6]) / 6, cargo / 100)
        elif tag == 'ct':
            px = int(fields[3]) + x_shift
            py = int(fields[4]) + y_shift
            base = 8 + (int(fields[1]) - obs['player']) % 2 * 2
            planes[base:base + 2, px, py] = (1, city_fuel[fields[2]])
        elif tag == 'r':
            px = int(fields[2]) + x_shift
            py = int(fields[3]) + y_shift
            amount = int(float(fields[4]))
            planes[resource_channel[fields[1]], px, py] = amount / 800
        elif tag == 'rp':
            planes[15 + (int(fields[1]) - obs['player']) % 2, :] = min(int(fields[2]), 200) / 200
        elif tag == 'c':
            city_fuel[fields[2]] = min(float(fields[3]) / float(fields[4]), 10) / 10

    turn = obs['step']
    planes[17, :] = turn % 40 / 40
    planes[18, :] = turn / 360
    planes[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return planes


class LuxDataset(Dataset):
    """Map-style dataset of ``(encoded_state, action_label)`` pairs.

    ``samples`` holds ``(obs_id, unit_id, label)`` tuples and ``obses`` maps each
    obs_id to its observation. States are encoded lazily with ``make_input``
    from the sample's unit's point of view.
    """

    def __init__(self, obses, samples):
        self.obses, self.samples = obses, samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        obs_id, unit_id, label = self.samples[idx]
        return make_input(self.obses[obs_id], unit_id), label

In [21]:
class CBasicConv2d(nn.Module):
    """'Same'-padded Conv2d, optionally followed by BatchNorm2d (no activation).

    Args:
        input_dim, output_dim: input/output channel counts.
        kernel_size: (kh, kw). Padding is (kh // 2, kw // 2), which preserves the
            spatial size for odd kernels at the default stride of 1.
        bn: when truthy, a BatchNorm2d over ``output_dim`` channels follows the conv.
    """

    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        half_pad = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size, padding=half_pad)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        out = self.conv(x)
        if self.bn is None:
            return out
        return self.bn(out)

class LuxNet(nn.Module):
    """Residual conv net that scores the unit actions for the focus unit.

    A stem CBasicConv2d lifts the 20 input planes to 32 channels. Twelve residual
    blocks follow, each computing ``relu(h + block(h))``. The feature map is then
    multiplied by input channel 0 (the focus-unit plane from make_input) and
    summed over space. When that plane is 1 on a single tile, this picks out the
    features at the unit's tile. A bias-free linear head maps them to
    ``n_actions`` logits.
    """

    def __init__( self, n_actions ):
        super().__init__()
        n_blocks, channels = 12, 32
        self.conv = CBasicConv2d(20, channels, (3, 3), True)
        self.blocks = nn.ModuleList(
            [CBasicConv2d(channels, channels, (3, 3), True) for _ in range(n_blocks)]
        )
        self.head = nn.Linear(channels, n_actions, bias=False)

    def forward(self, x):
        h = F.relu_(self.conv(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))
        unit_mask = x[:, :1]
        pooled = (h * unit_mask).view(h.size(0), h.size(1), -1).sum(-1)
        return self.head(pooled)

In [22]:
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
steps_done = 0
def select_action(state, model, num_actions=None):
    """Epsilon-greedy action selection for DQN.

    Epsilon decays exponentially from EPS_START towards EPS_END, with time
    constant EPS_DECAY counted in calls (tracked in the global ``steps_done``).
    With probability 1 - epsilon the action with the largest ``model`` output is
    returned; otherwise a uniformly random action is returned.

    Args:
        state: a single state tensor, either (C, H, W) or (1, C, H, W). A 3-D
            state gets a batch dimension before the model call, because LuxNet
            indexes ``x[:, :1]`` and BatchNorm2d requires 4-D input.
        model: the network to query. The passed model is used, not a global.
        num_actions: size of the action space for exploration. Defaults to the
            notebook-level ``n_actions``.

    Returns:
        A (1, 1) long tensor with the chosen action index.
    """
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        if state.dim() == 3:
            state = state.unsqueeze(0)
        with torch.no_grad():
            # max(1) returns (values, indices) per row. The index of the largest
            # expected reward is the greedy action.
            return model(state).max(1)[1].view(1, 1)
    if num_actions is None:
        num_actions = n_actions
    return torch.tensor([[random.randrange(num_actions)]], device=state.device, dtype=torch.long)

In [23]:
from collections import namedtuple, deque

# One stored experience. next_state is None when the step ended the episode
# (optimize_model masks those out of the bootstrap term).
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity FIFO buffer of Transitions with uniform random sampling.

    Once ``capacity`` transitions are stored, each push evicts the oldest one.
    """

    def __init__(self, capacity):
        # deque(maxlen=...) drops the oldest entry automatically when full.
        # deque must be imported here: it is not imported anywhere else in the
        # notebook, and constructing a ReplayMemory raised NameError without it.
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition (positional args in Transition field order)."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random.

        Raises ValueError (from random.sample) if fewer are stored.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [24]:
def optimize_model():
    """Run one DQN optimization step on ``policy_net`` from a replay batch.

    Reads these notebook globals: ``memory`` (a ReplayMemory), ``BATCH_SIZE``,
    ``GAMMA``, ``policy_net``, ``target_net``, ``optimizer`` and ``device``. Does
    nothing until the memory holds at least BATCH_SIZE transitions.

    The TD target is ``reward + GAMMA * max_a target_net(next_state)[a]``. The
    bootstrap term is 0 for final transitions (``next_state is None``). The loss
    is Huber (SmoothL1), and each gradient element is clipped to [-1, 1].

    NOTE(review): ``memory``, ``BATCH_SIZE`` and ``GAMMA`` are not defined in any
    cell shown here, so define them before calling this. The ``optimizer`` created
    in the setup cell wraps target_net's parameters, not policy_net's.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch: a list of Transitions becomes a Transition of tuples
    # (see https://stackoverflow.com/a/19343/3343043).
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a_t): policy_net's value for the action actually taken in each state.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}) from the older target_net. It stays 0 for final states.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final_mask.any():
        # torch.cat([]) raises, so only evaluate when some next state exists.
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Clamp every existing gradient to [-1, 1]. clip_grad_value_ skips parameters
    # whose .grad is None instead of failing on them.
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()

In [25]:
def train_model(model, dataloaders_dict, criterion, optimizer, num_epochs, model_path='model.pth'):
    """Supervised (imitation-learning) training loop with per-epoch validation.

    Each epoch runs a 'train' phase (gradient updates) and then a 'val' phase
    (evaluation only) over ``dataloaders_dict[phase]``. Batches are
    ``(states, action_labels)``, as produced by ``LuxDataset``. Whenever
    validation accuracy improves, the model is traced with TorchScript and saved
    to ``model_path``. agent.py loads that file with ``torch.jit.load``.

    NOTE(review): the earlier body mixed this loop with a DQN episode loop. It
    referenced names that are never defined (``phase``, ``count``, ``state``,
    ``get_screen``, ``i_episode``, ...) and raised NameError on the first epoch.
    The DQN episode loop still exists as its own cell further down.

    Args:
        model: network mapping float states to action logits.
        dataloaders_dict: ``{'train': DataLoader, 'val': DataLoader}``.
        criterion: loss on ``(logits, labels)``, e.g. ``nn.CrossEntropyLoss()``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of train + val passes.
        model_path: where the best traced model is written.

    Returns:
        The best validation accuracy seen (0.0 if it never improved).
    """
    # Pick the device here instead of hard-coding .cuda(), so CPU-only runs work.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    best_acc = 0.0

    for epoch in range(num_epochs):
        for phase in ('train', 'val'):
            model.train(phase == 'train')
            running_loss = 0.0
            running_correct = 0
            n_seen = 0

            for states, labels in tqdm(dataloaders_dict[phase], leave=False):
                states = states.to(device).float()
                labels = labels.to(device).long()

                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    logits = model(states)
                    loss = criterion(logits, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * len(states)
                running_correct += (logits.argmax(1) == labels).sum().item()
                n_seen += len(states)

            if n_seen == 0:
                # Empty loader: nothing to report or compare.
                continue
            epoch_loss = running_loss / n_seen
            epoch_acc = running_correct / n_seen
            print(f'Epoch {epoch + 1}/{num_epochs} | {phase:^5} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Trace on CPU while the model is in eval mode, so BatchNorm uses
                # its running statistics and agent.py can load the file without a GPU.
                example = states[:1].cpu()
                traced = torch.jit.trace(model.cpu(), example)
                traced.save(model_path)
                model.to(device)

    return best_acc

In [27]:
from kaggle_environments import make

# Lux AI environment with debug output enabled (reassigns `env`).
env = make("lux_ai_2021", configuration={"width": 24, "height": 24, "loglevel": 2, "annotations": True}, debug=True)
n_actions = 5
# Two LuxNets for DQN: target_net starts as an exact copy of policy_net and is
# put in eval mode.
policy_net = LuxNet(n_actions)
target_net = LuxNet(n_actions)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# 90/10 train/validation split of the (obs_id, unit_id, label) samples, stratified
# by label so that both splits keep the same action distribution.
train, val = train_test_split(samples, test_size=0.1, random_state=42, stratify=labels)
batch_size = 32
train_loader = DataLoader(
    LuxDataset(obses, train), 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=2
)
val_loader = DataLoader(
    LuxDataset(obses, val), 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=2
)
dataloaders_dict = {"train": train_loader, "val": val_loader}
criterion = nn.CrossEntropyLoss()
# NOTE(review): this optimizer wraps target_net's parameters, matching the
# train_model(target_net, ...) call below. optimize_model() backpropagates into
# policy_net but steps this same global optimizer, so it would never update
# policy_net. DQN training needs its own optimizer over policy_net.parameters().
optimizer = torch.optim.AdamW(target_net.parameters(), lr=1e-3)
# optimizer = torch.optim.RMSprop(target_net.parameters(), lr=1e-3)
# optimizer = torch.optim.RMSprop(target_net.parameters(), lr=1e-3)

In [15]:
# Train target_net on the labelled samples (train/val loaders built above) for 15 epochs.
train_model(target_net, dataloaders_dict, criterion, optimizer, num_epochs=15)

In [None]:
# DQN episode loop (draft).
# NOTE(review): this cell does not run as-is:
#   - `get_screen`, `memory`, `episode_durations`, `plot_durations` and
#     `TARGET_UPDATE` are not defined in any cell shown here;
#   - `count` is never imported (itertools.count presumably intended), and after
#     the label-count cell it is bound to a NumPy integer;
#   - `select_action(state)` omits the required `model` argument;
#   - `env.step(action.item())` passes a single int; the Lux environment
#     presumably expects per-agent action lists. TODO: confirm against kaggle_environments.
num_episodes = 50
for i_episode in range(num_episodes):
    # Initialize the environment and state
    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen
    for t in count():
        # Select and perform an action
        action = select_action(state)
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)

        # Observe new state
        last_screen = current_screen
        current_screen = get_screen()
        if not done:
            next_state = current_screen - last_screen
        else:
            next_state = None

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()
        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break
    # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

# Submission

In [28]:
%%writefile agent.py
import os
import numpy as np
import torch
from lux.game import Game


# When /kaggle_simulations exists (Kaggle's simulator), look for model.pth in
# /kaggle_simulations/agent; otherwise use the current working directory.
path = '/kaggle_simulations/agent' if os.path.exists('/kaggle_simulations') else '.'
# TorchScript archive: loading it does not need the LuxNet class definition.
model = torch.jit.load(f'{path}/model.pth')
model.eval()


def make_input(obs, unit_id):
    """Encode an observation from one unit's point of view as a 20x32x32 float32 array.

    This must stay behaviorally identical to the make_input used in the notebook
    to build training inputs. Otherwise the model receives states encoded
    differently from the ones it was trained on.

    The map is centered in a 32x32 canvas. Channels: 0-1 focus unit, 2-7
    ally/enemy units, 8-11 ally/enemy city tiles, 12-14 resources, 15-16
    ally/enemy research points, 17 day/night cycle, 18 turn, 19 map mask.
    "Ally" means team == obs['player'].
    """
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    # city_id -> min(fuel / light upkeep, 10) / 10, filled from 'c' updates and
    # read by the later 'ct' updates.
    cities = {}
    
    b = np.zeros((20, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if unit_id == strs[3]:
                # Position and Cargo
                b[:2, x, y] = (
                    1,
                    (wood + coal + uranium) / 100
                )
            else:
                # Units
                team = int(strs[2])
                cooldown = float(strs[6])
                # Ally units -> channels 2-4, enemy units -> 5-7.
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100
                )
        elif input_identifier == 'ct':
            # CityTiles
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            # Ally city tiles -> channels 8-9, enemy -> 10-11.
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                cities[city_id]
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            # Whole plane: channel 15 for our team, 16 for the opponent.
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Map Size
    b[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b


game_state = None
def get_game_state(observation):
    """Return the module-level Game, building it on step 0 and updating it afterwards.

    On step 0 a new Game is created. ``_initialize`` receives the full update
    list, and ``_update`` receives everything after the first two entries (at
    step 0 those are the player id and the map size). The game's ``id`` is then
    set to ``observation["player"]``. On later steps the update list is simply
    applied with ``_update``.
    """
    global game_state

    if observation["step"] != 0:
        game_state._update(observation["updates"])
        return game_state

    updates = observation["updates"]
    game_state = Game()
    game_state._initialize(updates)
    game_state._update(updates[2:])
    game_state.id = observation["player"]
    return game_state


def in_city(pos):
    """Return True if ``pos`` holds a city tile owned by this agent's team.

    Best effort: any ordinary exception during the lookup counts as "not in a
    city". Examples are ``game_state`` still being None, or
    ``get_cell_by_pos`` raising for the given position. Only ``Exception``
    subclasses are caught, so KeyboardInterrupt/SystemExit are no longer
    swallowed the way they were by the previous bare ``except:``.
    """
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id
    except Exception:
        return False


def call_func(obj, method, args=()):
    """Call ``obj.<method>(*args)`` and return its result.

    ``args`` defaults to an immutable empty tuple rather than a shared mutable list.
    """
    return getattr(obj, method)(*args)


unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]
def get_action(policy, unit, dest):
    """Pick the highest-scoring action whose target tile is still available.

    Labels are tried from best to worst score. The target is
    ``unit.pos.translate(action[-1], 1)``; for build_city the direction passed is
    'build_city', and a falsy result falls back to the unit's own tile. A target
    is accepted if it is not already claimed in ``dest``, or if it is one of our
    own city tiles. If nothing is accepted, the unit stays put with ``move('c')``.

    Returns:
        (action_command, target_position)
    """
    ranked_labels = np.argsort(policy)[::-1]
    for label in ranked_labels:
        action = unit_actions[label]
        target = unit.pos.translate(action[-1], 1) or unit.pos
        blocked = target in dest and not in_city(target)
        if not blocked:
            return call_func(unit, *action), target

    return unit.move('c'), unit.pos


def agent(observation, configuration):
    """Kaggle entry point: return this turn's list of action strings.

    City tiles that can act build a worker while the team has fewer units than
    city tiles. Otherwise they research, until uranium is researched. Each unit
    that can act is scored by the TorchScript model and assigned the best free
    action via get_action. Units standing in one of our cities while
    ``turn % 40 >= 30`` are skipped (presumably the night part of the
    day/night cycle).
    """
    global game_state
    
    game_state = get_game_state(observation)    
    player = game_state.players[observation.player]
    actions = []
    
    # City Actions
    # Counts units including workers queued this turn, so tiles don't over-build.
    unit_count = len(player.units)
    for city in player.cities.values():
        for city_tile in city.citytiles:
            if city_tile.can_act():
                if unit_count < player.city_tile_count: 
                    actions.append(city_tile.build_worker())
                    unit_count += 1
                elif not player.researched_uranium():
                    actions.append(city_tile.research())
                    # Count research queued this turn; presumably this lets
                    # researched_uranium() account for it on later tiles.
                    player.research_points += 1
    
    # Worker Actions
    # Target tiles already claimed by earlier units this turn (see get_action).
    dest = []
    for unit in player.units:
        if unit.can_act() and (game_state.turn % 40 < 30 or not in_city(unit.pos)):
            state = make_input(observation, unit.id)
            with torch.no_grad():
                p = model(torch.from_numpy(state).unsqueeze(0))

            # Per-action scores for this unit (batch dimension removed).
            policy = p.squeeze(0).numpy()

            action, pos = get_action(policy, unit, dest)
            actions.append(action)
            dest.append(pos)

    return actions

In [29]:
from kaggle_environments import make

# Play a 24x24 game with the freshly written agent.py on both sides and render it inline.
env = make("lux_ai_2021", configuration={"width": 24, "height": 24, "loglevel": 2, "annotations": True}, debug=False)
steps = env.run(['agent.py', 'agent.py'])
env.render(mode="ipython", width=1200, height=800)

In [30]:
# Inspect the history returned by env.run. `type(steps)` is evaluated but not
# displayed, because only a cell's last expression is shown. The print shows the
# per-agent entries of the step at index 2.
type(steps)
print( steps[2] )

In [None]:
# Package the working directory (agent.py, the model.pth it loads, and any other
# files present) as the submission archive.
# NOTE(review): `*` archives everything in the directory. That includes the files
# copied from ../input and, on a re-run, a previous submission.tar.gz.
!tar -czf submission.tar.gz *