In [1]:
!pip install kaggle-environments -U > /dev/null 2>&1
!cp -r ../input/lux-ai-2021/* 

cp: target '../input/lux-ai-2021/main.py' is not a directory


In [2]:
!npm install -g @lux-ai/2021-challenge@latest &> /dev/null

In [3]:
!pip uninstall nvidia_cublas_cu11 -y
!pip install nvidia_cublas_cu12 



[0mCollecting nvidia_cublas_cu12
  Downloading nvidia_cublas_cu12-12.8.3.14-py3-none-manylinux_2_27_x86_64.whl.metadata (1.7 kB)
Downloading nvidia_cublas_cu12-12.8.3.14-py3-none-manylinux_2_27_x86_64.whl (609.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m609.6/609.6 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: nvidia_cublas_cu12
Successfully installed nvidia_cublas_cu12-12.8.3.14


In [4]:
%%bash

AGENT_ST='/kaggle/input/luxai-replay-dataset/data/23281649'

export AGENT_A_DIR=$AGENT_ST
rm -rf agentA/
mkdir -p agentA/
cp -r $AGENT_A_DIR/* agentA/

echo $AGENT_A_DIR > .AGENT_A_DIR

In [5]:
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split

In [6]:
from enum import IntEnum


MAP_SIZE = 32
UNIT_ACTIONS = 10
CITYTILE_ACTIONS = 3
STATE_CHANNELS = 33


class UnitAction(IntEnum):
    MOVE_NORTH = 0
    MOVE_WEST = 1
    MOVE_SOUTH = 2
    MOVE_EAST = 3
    MOVE_CENTER = 4
    BUILD_CITY = 5
    TRANSFER_NORTH = 6
    TRANSFER_WEST = 7
    TRANSFER_SOUTH = 8
    TRANSFER_EAST = 9


class CitytileAction(IntEnum):
    RESEARCH = 0
    BUILD_WORKER = 1
    DO_NOTHING = 2

In [7]:
def seed_everything(seed_value):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    
    if torch.cuda.is_available(): 
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True

seed = 42
seed_everything(seed)

In [8]:
def to_label(action, units):
    strs = action.split(' ')
    unit_id = strs[1]
    if strs[0] == 'm':
        label = {'c': 9, 'n': 0, 's': 1, 'w': 2, 'e': 3}[strs[2]]
    elif strs[0] == 'bcity':
        label = 4
    elif strs[0] == 't':
        from_pos = units[unit_id]
        if strs[2] in units:
            to_pos = units[strs[2]]
            if from_pos[1]  - 1 == to_pos[1]:
                label = 5 #n
            elif from_pos[1]  + 1 == to_pos[1]:
                label = 6 #s
            if from_pos[0] - 1 == to_pos[0]:
                label = 7  # w
            elif from_pos[0] + 1 == to_pos[0]:
                label = 8  # e
        else:
            label = None
    elif strs[0] == 'p':
        #pillage
        label = None
    else:
        if strs[0] not in ['r','bw','bc']:
            print("Unexpected no acton from",strs)
        label = None
    return unit_id, label





In [9]:
def turns_it_will_live(autonomy, steps_until_night,_next_night_number_turn=-1) ->int:
    autonomy=max(0,autonomy)
    if _next_night_number_turn == -1:
        next_night_number_turn = min(10, 10 + steps_until_night)
    else:
        next_night_number_turn = _next_night_number_turn

    turn_to_night= max(0,steps_until_night)
    # print('turn_to_night',turn_to_night, 'next_night_number_turn',next_night_number_turn, 'aut',autonomy)
    if autonomy>=next_night_number_turn:
       return turns_it_will_live(autonomy-next_night_number_turn,turn_to_night+40,10)
    else:
       return autonomy + turn_to_night

def get_shift_when_map_bigger_array(size_array, unit_coordinate, size_map, shift):
    if unit_coordinate - (size_array // 2) <= 0:
        shift = 0
    elif unit_coordinate - (size_map // 2) <= 0:
        shift = (size_array // 2) - unit_coordinate
    elif unit_coordinate + (size_array // 2) > size_map:
        shift = size_array - size_map
    elif unit_coordinate + (size_map // 2) >= size_map:
        shift = (size_array // 2) - unit_coordinate + 1

    # print("XY5", unit_coordinate, shift)

    return shift



def invalid_size(x,y,size):
    return (x>=size) or (y>=size) or (x<0) or (y<0)


In [10]:
def depleted_resources(obs):
    for u in obs['updates']:
        if u.split(' ')[0] == 'r':
            return False
    return True


def virtually_won_game_resources(obs, index, turn):

    our_units = 0
    enemy_units = 0
    our_city = 0
    enemy_city = 0
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]

        if input_identifier == 'u':
            team = int(strs[2])
            if team == index:
                #our team
                our_units += 1
            else:
                # enemy team
                enemy_units += 1

        elif input_identifier == 'ct':
            team = int(strs[1])
            if team == index:
                # our team
                our_city += 1
            else:
                # enemy team
                enemy_city += 1

    if min(our_city,our_units) > 8 * min(enemy_city,enemy_units):
        return True

    return False

In [11]:
def create_dataset_from_json(episode_dir, team_name='', set_sizes=[], exclude_turns_on_after=350):
    if team_name=='':
        print('Need to specify a team name')
    samples = {}
    num_samples = 0
    num_episodes = 0
    num_actions=0
    num_non_actions=0
    num_cannot_work = 0

    episodes = [path for path in Path(episode_dir).glob('*.json') if
                ('output' not in path.name and '_info' not in path.name)]

    print('create_dataset_from_json,"',team_name,'"',set_sizes, ':', episode_dir)
    for filepath in episodes:
        episode_samples = []
        episode_obses = {}
        with open(filepath) as f:
            json_load = json.load(f)

        ep_id = json_load['info']['EpisodeId']
        index = np.argmax([r or 0 for r in json_load['rewards']])
        if json_load['info']['TeamNames'][index] != team_name:
            continue

        if len(set_sizes) != 0:
            this_size = json_load['steps'][0][0]['observation']['height']
            if this_size not in set_sizes:
                continue
        num_episodes += 1
        for i in range(len(json_load['steps']) - 1):
            if json_load['steps'][i][index]['status'] == 'ACTIVE':
                actions = json_load['steps'][i + 1][index]['action']

                obs = json_load['steps'][i][0]['observation']

                # do not add samples after the cutoff turn. For example 350 if we do not want the last night
                if i >= exclude_turns_on_after:
                    break

                #do not add samples in which there are no resources, because they have no value for training (noise)
                if depleted_resources(obs):
                    break

                if virtually_won_game_resources(obs,index,i):
                    break

                obs['player'] = index
                obs = dict([
                    (k, v) for k, v in obs.items()
                    if k in ['step', 'updates', 'player', 'width', 'height']
                ])
                obs_id = f'{ep_id}_{i}'
                episode_obses[obs_id] = obs

                units = {}
                unit_can_work = []
                for update in obs['updates']:
                    strs = update.split(' ')
                    input_identifier = strs[0]

                    if input_identifier == 'u':
                        x = int(strs[4])
                        y = int(strs[5])
                        tyep = int(strs[1])
                        team = int(strs[2])
                        if team == index:
                            #our unit!
                            unit_id = strs[3]

                            if type == 1:
                                print('turn',obs['step'], unit_id , 'is a cart...')


                            units[unit_id] = (x, y)
                            cooldown = float(strs[6])
                            if cooldown == 0:
                                unit_can_work.append(unit_id)
                            else:
                                num_cannot_work += 1


                unit_that_worked = []
                for action in actions:
                    unit_id, label = to_label(action, units)
                    unit_that_worked.append(unit_id) # if we move this below "if label", then we consider pillage a stay
                    if label is not None:
                        num_actions += 1
                        episode_samples.append((obs_id, unit_id, label))

                #those units could have worked but it didn't, it is an important to record those
                lazy_units = [u for u in unit_can_work if u not in unit_that_worked]
                # if obs['step'] <=4:
                #     print('turn',obs['step'],len(unit_can_work), '-', len(unit_that_worked),'=',len(lazy_units))
                #     print('turn', obs['step'], unit_can_work, '-', unit_that_worked, '=', lazy_units)
                for unit_id in lazy_units:
                    num_non_actions += 1
                    episode_samples.append((obs_id, unit_id, 9))

        samples[ep_id] = (episode_obses, episode_samples)
        num_samples += len(episode_samples)

    print(episode_dir,'num_episodes=',num_episodes)
    print("non_actions",num_non_actions,";actions",num_actions,";cannotwork",num_cannot_work)
    return samples, num_samples


In [12]:
!mkdir train
!mkdir val

In [13]:
import os
import shutil
from pathlib import Path
from random import shuffle

# Define the source folder and destination folders
source_folder = Path("/kaggle/working/agentA")
destination_folder_1 = Path("/kaggle/working/train")
destination_folder_2 = Path("/kaggle/working/val")

# Ensure destination folders exist
destination_folder_1.mkdir(parents=True, exist_ok=True)
destination_folder_2.mkdir(parents=True, exist_ok=True)

# Get all subfolders in the source folder
all_folders = []
che = 0
maxi = 430
split = maxi // 4
for filename in os.listdir(source_folder):
    che += 1
    if che >= maxi:
        break
    file_path = os.path.join(source_folder, filename)
    if not os.path.isfile(file_path):
        for file in os.listdir(file_path):
            file_path2 = os.path.join(file_path, file)
            if not 'info' in file_path2:
                if che < split:
                    shutil.move(file_path2, destination_folder_1)
                    print(f"Copied {file_path2} to {destination_folder_1}")
                else:
                    shutil.move(file_path2, destination_folder_2)
                    print(f"Copied {file_path2} to {destination_folder_2}")
                    
                    
        if os.path.isfile(file_path) or os.path.islink(file_path):
            os.unlink(file_path)
        elif os.path.isdir(file_path):
            shutil.rmtree(file_path)





Copied /kaggle/working/agentA/30415692/30415692.json to /kaggle/working/train
Copied /kaggle/working/agentA/32161139/32161139.json to /kaggle/working/train
Copied /kaggle/working/agentA/31181451/31181451.json to /kaggle/working/train
Copied /kaggle/working/agentA/31477346/31477346.json to /kaggle/working/train
Copied /kaggle/working/agentA/30910350/30910350.json to /kaggle/working/train
Copied /kaggle/working/agentA/31452123/31452123.json to /kaggle/working/train
Copied /kaggle/working/agentA/30469243/30469243.json to /kaggle/working/train
Copied /kaggle/working/agentA/30696492/30696492.json to /kaggle/working/train
Copied /kaggle/working/agentA/30084251/30084251.json to /kaggle/working/train
Copied /kaggle/working/agentA/29365157/29365157.json to /kaggle/working/train
Copied /kaggle/working/agentA/30663943/30663943.json to /kaggle/working/train
Copied /kaggle/working/agentA/31829845/31829845.json to /kaggle/working/train
Copied /kaggle/working/agentA/29132272/29132272.json to /kaggle/

In [14]:
CONSTANT = {
  "UNIT_TYPES": {
    "WORKER": 0,
    "CART": 1
  },
  "RESOURCE_TYPES": {
    "WOOD": "wood",
    "COAL": "coal",
    "URANIUM": "uranium"
  },
  "DIRECTIONS": {
    "NORTH": "n",
    "WEST": "w",
    "EAST": "e",
    "SOUTH": "s",
    "CENTER": "c"
  },
  "PARAMETERS": {
    "DAY_LENGTH": 30,
    "NIGHT_LENGTH": 10,
    "MAX_DAYS": 360,
    "LIGHT_UPKEEP": {
      "CITY": 23,
      "WORKER": 4,
      "CART": 10
    },
    "WOOD_GROWTH_RATE": 1.025,
    "MAX_WOOD_AMOUNT": 500,
    "CITY_BUILD_COST": 100,
    "CITY_ADJACENCY_BONUS": 5,
    "RESOURCE_CAPACITY": {
      "WORKER": 100,
      "CART": 2000
    },
    "WORKER_COLLECTION_RATE": {
      "WOOD": 20,
      "COAL": 5,
      "URANIUM": 2
    },
    "RESOURCE_TO_FUEL_RATE": {
      "WOOD": 1,
      "COAL": 10,
      "URANIUM": 40
    },
    "RESEARCH_REQUIREMENTS": {
      "COAL": 50,
      "URANIUM": 200
    },
    "CITY_ACTION_COOLDOWN": 10,
    "UNIT_ACTION_COOLDOWN": {
      "CART": 3,
      "WORKER": 2
    },
    "MAX_ROAD": 6,
    "MIN_ROAD": 0,
    "CART_ROAD_DEVELOPMENT_RATE": 0.75,
    "PILLAGE_RATE": 0.5
  }
}



In [15]:
CHANNELS = 34


def make_input(obs, unit_id, size = 32):
    width, height = size, size
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    
    cities = {}
    
    b = np.zeros((34, 32, 32), dtype=np.float32)

    day_length = 30
    night_length = 10
    cycle_length = 40
    max_turn = 360
    

    for update in obs['updates']:
        turn = obs['step']
        strs = update.split(' ')
        input_id = strs[0]
        #0:9
        if input_id == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uran = int(strs[9])
            u_CL = float(strs[6])
            team = int(strs[2])
            if (team == obs['player']):
                b[0, x, y] = 1
                b[1, x, y] = wood / 100.
                b[2, x, y] = coal / 100.
                b[3, x, y] = uran / 100.
                if ( int(strs[1])) == 0 :
                    #worker
                    b[4, x, y] = float(strs[6]) / 2
                else:
                    #then cart
                    b[4, x, y] = float(strs[6]) / 3
                #u_can_ACt 29
                if u_CL == 0.0:
                    b[29, x, y] = 1    
                #31 units' weighted resource
                b[31, x, y] = (wood * 1 + coal * 10 + uran * 40) / (100. * 40.)


            
            else:
                b[5, x, y] = 1
                b[6, x, y] = wood / 100.
                b[7, x, y] = coal / 100.
                b[8, x, y] = uran / 100.
                if ( int(strs[1])) == 0 :
                    #worker
                    b[9, x, y] = float(strs[6]) / 2
                else:
                    #then cart
                    b[9, x, y] = float(strs[6]) / 3
                #u_can_ACt 30
                if u_CL == 0.0:
                    b[30, x, y] = 1
                #32 Enms' weighted resource
                b[32, x, y] = (wood * 1 + coal * 10 + uran * 40) / (100. * 40.)
                
        #10:12
        elif input_id == 'r':
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = float(strs[4])
            b[{'wood': 10, 'coal': 11, 'uranium': 12}[r_type], x, y] = amt / 800.
        #13, 14
        elif input_id == 'rp':
            team = int(strs[1])
            rp = float(strs[2])
            b[13 + (team - obs['player']) % 2, x_shift:32 - x_shift, y_shift:32 - y_shift ] = rp / 200.
        #15 : 24
        elif input_id == 'c':
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = (fuel / 1000., lightupkeep / 23. )
        elif input_id == 'ct':
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            c_CL = float(strs[5]) / 10.
            idx = 15 + (team - obs['player']) % 2 * 5
            b[idx:idx + 5, x, y] = (
                1,
                cities[city_id][0],
                cities[city_id][1],
                c_CL,
                int(c_CL == 0.0)
            )
            
        #25 turn
        b[25, x_shift:32 - x_shift, y_shift:32 - y_shift] = turn / max_turn
        #26 D/N
        b[26, x_shift:32 - x_shift, y_shift:32 - y_shift] = turn % cycle_length / cycle_length
        #27 is day?
        dayinturn = turn % cycle_length
        if 0 <= dayinturn < day_length:
            b[27, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1
            b[33, x_shift:32 - x_shift, y_shift:32 - y_shift] = dayinturn / 30.
        #28 MAP
        b[28, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1
        #29, 30 unit can act?
        #31,32 weighted r
        #33 how much night?

    return b


In [16]:
class LuxDataset(Dataset):
    def __init__(self, obses, samples, make_input_size=32):
        self.obses = obses
        self.samples = samples
        self.make_input_size = make_input_size

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        obs_id, unit_id, action = self.samples[idx]
        obs = self.obses[obs_id]
        state = make_input(obs, unit_id, self.make_input_size)

        return state, action


In [17]:
def samples_to_obs_sample_list(input_samples):
    obses  = {}
    samples = []
    for obs, sample in input_samples.values():
        obses.update(obs)
        samples.extend(sample)
    return obses, samples

In [18]:
episode_train = '/kaggle/working/val'
episode_eval = '/kaggle/working/train'
team_name = 'Toad Brigade'
dataset_sizes = []

ep_samples_eval, num_samples_eval = create_dataset_from_json(episode_eval, team_name=team_name,set_sizes=dataset_sizes)
ep_samples_train, num_samples_train = create_dataset_from_json(episode_train, team_name=team_name, set_sizes=dataset_sizes)

obses_eval, samples_eval  = samples_to_obs_sample_list(ep_samples_eval)
obses_train, samples_train= samples_to_obs_sample_list(ep_samples_train)


create_dataset_from_json," Toad Brigade " [] : /kaggle/working/train
/kaggle/working/train num_episodes= 88
non_actions 211077 ;actions 140235 ;cannotwork 101764
create_dataset_from_json," Toad Brigade " [] : /kaggle/working/val
/kaggle/working/val num_episodes= 291
non_actions 845969 ;actions 543913 ;cannotwork 404822


In [19]:
obs_id, unit_id, action = samples_eval[1]
obs = obses_eval[obs_id]
state = make_input(obs, unit_id, 32)

In [20]:
print(obs_id, unit_id, action, obs)

30521655_2 u_2 0 {'height': 12, 'player': 1, 'step': 2, 'updates': ['rp 0 1', 'rp 1 0', 'r wood 2 2 373', 'r wood 2 3 364', 'r wood 2 8 364', 'r wood 2 9 373', 'r wood 3 2 373', 'r wood 3 3 335', 'r wood 3 8 335', 'r wood 3 9 373', 'r wood 4 2 418', 'r wood 4 3 415', 'r wood 4 8 415', 'r wood 4 9 418', 'r coal 8 0 417', 'r coal 8 11 417', 'r coal 9 0 396', 'r coal 9 11 396', 'r wood 10 1 800', 'r wood 10 2 760', 'r wood 10 9 760', 'r wood 10 10 800', 'r uranium 11 0 311', 'r wood 11 2 760', 'r wood 11 5 394', 'r wood 11 6 394', 'r wood 11 9 760', 'r uranium 11 11 311', 'u 0 0 u_1 11 2 0 80 0 0', 'u 0 1 u_2 11 9 0 80 0 0', 'c 0 c_1 0 23', 'c 1 c_2 0 23', 'ct 0 c_1 11 1 8', 'ct 1 c_2 11 10 0', 'ccd 11 1 6', 'ccd 11 10 6', 'D_DONE'], 'width': 12}


In [21]:
print((state[:, :1].shape))

(34, 1, 32)


In [22]:
print('Train observations:', len(obses_train), 'samples:', len(samples_train))
print('Eval  observations:', len(obses_eval), 'samples:', len(samples_eval))
print(f'Ratio  observations: {len(obses_eval) / len(obses_train) :.4f}'
      f' samples:{len(samples_eval) / len(samples_train) :.4f}')

Train observations: 72278 samples: 1389882
Eval  observations: 20390 samples: 351312
Ratio  observations: 0.2821 samples:0.2528


In [23]:
train_loader = DataLoader(
    LuxDataset(obses_train, samples_train, make_input_size=32),
    batch_size=64,
    shuffle=True,
    num_workers=2
)
val_loader = DataLoader(
    LuxDataset(obses_eval, samples_eval, make_input_size=32),
    batch_size=64,
    shuffle=False,
    num_workers=4
)
dataloaders_dict = {"train": train_loader, "val": val_loader}
criterion = nn.CrossEntropyLoss()


In [24]:
number_train_cycle = 0
def train_model(model, dataloaders_dict, criterion, optimizer, scheduler, num_epochs, map_size=32,
                skip_first_train=True, Save=True):
    try:
        number_train_cycle += 1
    except NameError:
        number_train_cycle = 1

    best_acc = 0.0

    num_train = len(dataloaders_dict['train'])
    num_val = len(dataloaders_dict['val'])

    print(get_time(),f' {number_train_cycle} LR: {optimizer.param_groups[0]["lr"]} Epochs {num_epochs} | #train{num_train} #val{num_val}')

    for epoch in range(num_epochs):
        model.cuda()

        for phase in ['train', 'val']:
            if phase == 'train':
                if epoch == 0 and skip_first_train:
                    continue

                model.train()
            else:
                model.eval()

            epoch_loss = 0.0
            epoch_acc = 0

            dataloader = dataloaders_dict[phase]

            for item in tqdm(dataloader, leave=False):

                states = item[0].cuda().float()
                actions = item[1].cuda().long()

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    policy = model(states)
                    loss = criterion(policy, actions)
                    _, preds = torch.max(policy, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    epoch_loss += loss.item() * len(policy)
                    epoch_acc += torch.sum(preds == actions.data)

            data_size = len(dataloader.dataset)
            epoch_loss = epoch_loss / data_size
            epoch_acc = epoch_acc.double() / data_size

            print(get_time(),
                  f'LR: {optimizer.param_groups[0]["lr"]} Epoch {epoch + 1}/{num_epochs} of  {number_train_cycle} | {phase:^5}'
                  f' | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}')

            if not Save:
                r = (epoch_loss, float (f'{epoch_acc:.6f}'))
                return r

        if epoch_acc > best_acc and Save:
            traced = torch.jit.trace(model.cpu(), torch.rand(1, CHANNELS, map_size, map_size))
            traced.save('model.pth')
            print(
                f'Saved model.pth from epoch {epoch + 1} as it is the most accurate so far: Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            best_acc = epoch_acc

        if Save:
            traced = torch.jit.trace(model.cpu(), torch.rand(1, CHANNELS, map_size, map_size))
            suffix = datetime.now().strftime('%H%M')
            traced.save(f'model_{number_train_cycle}_{epoch + 1}_{suffix}.pth')

        scheduler.step(epoch_loss)

In [25]:
def do_train(criterion, dataloaders_dict, map_size, model, num_epochs, lr,
             scheduler_factor=.5, scheduler_patience=-1, skip_first=False):
    if scheduler_patience == -1:
        scheduler_patience = num_epochs
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=scheduler_factor, patience=scheduler_patience)
    train_model(model, dataloaders_dict, criterion, optimizer, scheduler, num_epochs=num_epochs, map_size=map_size,
                skip_first_train=skip_first)

In [26]:
from datetime import datetime
import time

def get_time():
    now = datetime.now()

    return now.strftime("%H:%M:%S")

In [27]:

class BasicConv2d(nn.Module):
    """
    This class refers to https://www.kaggle.com/shoheiazuma/lux-ai-with-imitation-learning
    """
    def __init__(self, input_dim: int, output_dim: int, kernel_size: int, bn: bool):
        super().__init__()
        self.conv = nn.Conv2d(
            input_dim, output_dim,
            kernel_size=kernel_size,
            padding=(kernel_size[0] // 2, kernel_size[1] // 2)
        )
        self.bn = nn.BatchNorm2d(output_dim) if bn else None


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.conv(x)
        h = self.bn(h) if self.bn is not None else h

        return h



class LuxNet(nn.Module):
    def __init__(self, filt=32):
        super().__init__()
        layers, filters = 16, filt
        self.conv0 = BasicConv2d(CHANNELS, filters, (3, 3), True)
        self.blocks = nn.ModuleList([BasicConv2d(filters, filters, (3, 3), True) for _ in range(layers)])

        encoder_layer1 = nn.TransformerEncoderLayer(d_model=32, nhead=2) 
        self.transformer_encoder1 = nn.TransformerEncoder(encoder_layer1, num_layers=2)

        encoder_layer2 = nn.TransformerEncoderLayer(d_model=32, nhead=2, batch_first=True) 
        self.transformer_encoder2 = nn.TransformerEncoder(encoder_layer2, num_layers=2)

    
        self.head_p = nn.Linear(filters , len(actions), bias=False)

    def forward(self, x):
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))

        B, C, H, W = h.shape

        h_trans1 = h.reshape(B, C, H * W).permute(2, 0, 1)
        h_trans1 = self.transformer_encoder1(h_trans1)


        h_trans1 = h_trans1.reshape(H, W, B, C).permute(2, 3, 0, 1)

        
        h_head1 = (h_trans1 * x[:, :1]).view(h_trans1.size(0), h_trans1.size(1), -1).sum(-1)

        
        p1 = self.head_p(h_head1)

        return p1


In [28]:
filters = 32
map_size = 32
 
actions = ['north', 'south', 'west', 'east', 'bcity', 't_north', 't_south', 't_west', 't_east', 'stay']

model = LuxNet(); print('Starting new model filters=',filters); skip_first = False

Starting new model filters= 32




In [29]:
#do_train(criterion, dataloaders_dict, map_size, model, num_epochs=15, lr=1e-03)


In [37]:
dataloader = dataloaders_dict['train']
for item in tqdm(dataloader, leave=False):

    states = item[0]
    actions = item[1]


    policy = model(states)
    loss = criterion(policy, actions)
    print(states.shape)
    print(actions[0])
    print(policy[0])
    _, preds = torch.max(policy, 1)
    print(preds[0])
    print(loss)
    break

  0%|          | 0/21717 [00:00<?, ?it/s]

torch.Size([64, 34, 32, 32])
tensor(9)
tensor([ 15.3159,  14.6482,  -9.3475,   1.6265, -10.8692, -10.5915,  -7.1157,
         -3.4912,  12.9663,  -4.5733], grad_fn=<SelectBackward0>)
tensor(0)
tensor(25.0800, grad_fn=<NllLossBackward0>)


In [None]:
# import torch
# import torch.nn as nn
# from torchview import draw_graph


# # Create an instance of the MLP model

# # Visualize the model
# batch_size = 2
# model_graph = draw_graph(model, input_size=(64, 34, 32, 32 ), device='meta',  expand_nested=False)
# model_graph.visual_graph


In [None]:
# model_graph.visual_graph.format = "png"  # You can choose 'pdf', 'svg', etc.
# model_graph.visual_graph.render('/kaggle/working/model_graph2')


In [None]:
#torch.save(model, 'model.pth')


1389882

Total parameters: 709216
Trainable parameters: 709216
