In [1]:
import os
# Flag for "running inside a Kaggle kernel": true exactly when the
# KAGGLE_KERNEL_RUN_TYPE environment variable is defined (presumably set by
# the Kaggle runtime -- confirm if this check is relied on elsewhere).
KAGGLE = os.environ.get("KAGGLE_KERNEL_RUN_TYPE") is not None
if KAGGLE:
    # Optional Kaggle-only dependency upgrade, intentionally left disabled:
    #!pip install kaggle-environments -U
    pass

In [2]:
import os
# Expose no CUDA devices to this process. This is set before torch is imported
# in the next cell, so the device check there falls back to the CPU.
os.environ.update(CUDA_VISIBLE_DEVICES="")

In [3]:
import os
import gc
import sys
from time import time, sleep
import json
from pathlib import Path
from datetime import datetime
from itertools import count
from collections import defaultdict

from tqdm.notebook import tqdm
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
import matplotlib.pyplot as plt

# Environment

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device={device}")

# Number of logical CPUs. os.cpu_count() replaces the former
# `!lscpu | grep` shell capture: it is plain Python (works outside IPython),
# does not depend on the lscpu binary, and cannot fail with an IndexError when
# the grep matches nothing. It may return None, hence the fallback to 1.
n_cpu = os.cpu_count() or 1
print(f"n_cpu={n_cpu}")

# Let cuDNN benchmark and pick convolution algorithms (only relevant on CUDA).
torch.backends.cudnn.benchmark = True

# Directory settings
#
# TEMP_DIR     -- scratch directory
# STORAGE_DIR  -- where this notebook's artifacts are written
# PROJECTS_DIR -- directory expected to contain the "kore2022" project
# Local runs share one "./015" directory for scratch and storage; Kaggle runs
# use "../temp" for scratch and the current directory for everything else.
if not KAGGLE:
    TEMP_DIR = STORAGE_DIR = Path("./015")
    PROJECTS_DIR = Path("../../")
else:
    TEMP_DIR = Path("../temp")
    STORAGE_DIR = Path()
    PROJECTS_DIR = Path()

# Create whichever of the two working directories is still missing,
# announcing each creation.
for _directory in (TEMP_DIR, STORAGE_DIR):
    if _directory.exists():
        continue
    print(f"mkdir {_directory}")
    _directory.mkdir()

# Make modules stored in the storage directory and in the kore2022 project
# importable from this notebook.
sys.path.extend([str(STORAGE_DIR), str(PROJECTS_DIR / "kore2022")])


device=cpu
n_cpu=16


In [4]:
N_GLOBAL_FEATURES = 9
N_SHIPYARD_FEATURES = 32800


class NNUE(nn.Module):
    """Shared-trunk network with a value head, an action-type head and per-type action heads.

    Trunk: a linear encoding of the global features is added to an
    ``EmbeddingBag`` (``mode="sum"``) over the shipyard feature indices, then
    passed through two 256-wide linear layers; every stage is followed by a
    leaky ReLU with slope 1/64.

    Heads (all linear, fed by the 256-wide trunk output):

    * ``value_decoder``  -- one logit per sample.
    * ``type_decoder``   -- 4 logits, one per action type. Following the section
      comments of the serialization order, the types are
      0 = Spawn, 1 = Move, 2 = Attack, 3 = Convert.
    * ``n_ships_decoder[i]``           -- present for every action type.
    * ``relative_position_decoder[i]`` -- present for types 1, 2 and 3.
    * ``n_steps_decoder[i]``           -- present for type 1 only.
    * ``direction_decoder[i]``         -- present for types 2 and 3.

    Heads that do not exist for an action type are stored as ``None`` entries
    so the list index always equals the action type. The attribute names and
    list positions determine the ``state_dict`` keys and must stay as they are
    for existing checkpoints to load.
    """

    # Value that marks an unused slot in ``shipyard_features`` rows.
    PADDING_SENTINEL = -100
    # Multiplier applied to the summed head losses of each action type (indexed by type).
    ACTION_LOSS_WEIGHTS = (1.0, 1.0, 5.0, 25.0)

    def __init__(self):
        super().__init__()
        self.global_feature_encoder = nn.Linear(N_GLOBAL_FEATURES, 256)
        # One extra row (index N_SHIPYARD_FEATURES) serves as the padding entry,
        # which EmbeddingBag leaves out of the sum.
        self.embedding = nn.EmbeddingBag(N_SHIPYARD_FEATURES + 1, 256, mode="sum", padding_idx=N_SHIPYARD_FEATURES)
        self.fc1 = nn.Linear(256, 256)
        self.fc2 = nn.Linear(256, 256)
        self.value_decoder = nn.Linear(256, 1)
        self.type_decoder = nn.Linear(256, 4)
        self.n_ships_decoder = nn.ModuleList([
            nn.Linear(256, 12),  # [1, 10]
            nn.Linear(256, 32),
            nn.Linear(256, 32),
            nn.Linear(256, 32),
        ])
        self.relative_position_decoder = nn.ModuleList([
            None,
            nn.Linear(256, 448),  # [0, 441)
            nn.Linear(256, 448),
            nn.Linear(256, 448),
        ])
        self.n_steps_decoder = nn.ModuleList([
            None,
            nn.Linear(256, 24),  # [1, 21]
            None,
            None,
        ])
        self.direction_decoder = nn.ModuleList([
            None,
            None,
            nn.Linear(256, 4),
            nn.Linear(256, 4),
        ])

    def _dump_sequence(self):
        """Return every parameter tensor in the exact order ``dump`` serializes them."""

        def weight_and_bias(layer):
            return [layer.weight, layer.bias]

        tensors = weight_and_bias(self.global_feature_encoder)
        tensors.append(self.embedding.weight)
        for layer in (self.fc1, self.fc2, self.value_decoder, self.type_decoder):
            tensors += weight_and_bias(layer)

        # Spawn
        tensors += weight_and_bias(self.n_ships_decoder[0])

        # Move
        for layer in (
            self.n_ships_decoder[1],
            self.relative_position_decoder[1],
            self.n_steps_decoder[1],
        ):
            tensors += weight_and_bias(layer)

        # Attack (2), then Convert (3)
        for action_type in (2, 3):
            for layer in (
                self.n_ships_decoder[action_type],
                self.relative_position_decoder[action_type],
                self.direction_decoder[action_type],
            ):
                tensors += weight_and_bias(layer)

        return tensors

    def _write_parameters(self, f):
        """Write the raw bytes of every parameter to the binary stream ``f``."""
        for tensor in self._dump_sequence():
            f.write(tensor.detach().cpu().numpy().ravel().tobytes())

    def dump(self, file):
        """Serialize all parameters as one headerless binary blob.

        Each tensor is flattened and written back to back as raw bytes in its
        own dtype, in this order: global feature encoder, embedding table,
        fc1, fc2, value head, type head, then the Spawn, Move, Attack and
        Convert heads (weight before bias for every linear layer).

        Args:
            file: A filename (``str`` or ``os.PathLike``) or an object with a
                ``write`` method accepting bytes. A path is opened in ``"wb"``
                mode and always closed again, even if writing fails; a
                caller-supplied file object is left open.

        Raises:
            ValueError: If ``file`` is neither a path nor a writable object.
        """
        if isinstance(file, (str, os.PathLike)):
            with open(file, "wb") as f:
                self._write_parameters(f)
        elif hasattr(file, "write"):
            self._write_parameters(file)
        else:
            raise ValueError(
                f"file must be a path or a writable binary file object, got {type(file).__name__}"
            )

    @staticmethod
    def _optional_head(decoder, features):
        """Apply ``decoder`` to ``features``, or return ``None`` when the head does not exist."""
        return None if decoder is None else decoder(features)

    def forward(
        self,
        shipyard_features,
        global_features,
        target_values,
        target_action_types,
        target_action_n_ships,  # quantized
        target_action_relative_position,
        target_action_n_steps,
        target_action_direction,
    ):
        """Run the network and compute the training loss in one pass.

        The per-type heads are evaluated only on the samples whose *target*
        action type matches, so the targets are required inputs.

        Args:
            shipyard_features: Integer index tensor, one bag of feature
                indices per row, with unused slots set to -100 (the original
                author's note gives the shape as [batch_size, 512] -- confirm
                against the data loader). The tensor is not modified.
            global_features: Dense features; the author's note gives
                [batch_size, N_GLOBAL_FEATURES].
            target_values: Targets for the binary-cross-entropy value loss
                (presumably floats in [0, 1], shape [batch_size] -- confirm).
            target_action_types: Class index in [0, 4) per sample; also
                selects which per-type heads each sample trains.
            target_action_n_ships: Class index for the n_ships head of the
                sample's action type (quantized, per the original comment).
            target_action_relative_position: Class index for the
                relative-position head (used for types 1-3).
            target_action_n_steps: Class index for the n_steps head (type 1).
            target_action_direction: Class index for the direction head
                (types 2 and 3).

        Returns:
            A 3-tuple ``(predictions, losses, loss)``:

            * ``predictions = (value, action_type, specific_predictions)``
              where ``value`` is [batch_size] logits, ``action_type`` is
              [batch_size, 4] logits, and ``specific_predictions[i]`` is
              ``[n_action_data, n_ships, relative_position, n_steps,
              direction]`` for action type ``i`` (missing heads are ``None``).
            * ``losses = (value_loss, type_loss, specific_losses)`` with
              summed (not averaged) losses; ``specific_losses[i]`` mirrors
              the layout of ``specific_predictions[i]``.
            * ``loss``: ``(10 * value_loss + type_loss + sum_i w_i *
              action_loss_i) * (1 / batch_size)`` with
              ``w = ACTION_LOSS_WEIGHTS``.
        """
        batch_size = shipyard_features.size(0)
        # Map the padding sentinel onto the embedding's padding row. This is done
        # out-of-place so the caller's tensor keeps its original contents.
        shipyard_features = shipyard_features.masked_fill(
            shipyard_features == self.PADDING_SENTINEL, self.embedding.padding_idx
        )

        # [batch_size, N_GLOBAL_FEATURES], [batch_size, 512] -> [batch_size, 256]
        x = self.global_feature_encoder(global_features) + self.embedding(shipyard_features)
        x = F.leaky_relu(x, 1.0 / 64.0)

        # [batch_size, 256]
        x = F.leaky_relu(self.fc1(x), 1.0 / 64.0)
        x = F.leaky_relu(self.fc2(x), 1.0 / 64.0)

        # [batch_size, 256] -> [batch_size]
        value = self.value_decoder(x).squeeze(1)
        # [batch_size, 256] -> [batch_size, 4]
        action_type = self.type_decoder(x)

        value_loss = F.binary_cross_entropy_with_logits(value, target_values, reduction="sum")
        type_loss = F.cross_entropy(action_type, target_action_types, reduction="sum")
        loss = value_loss * 10.0 + type_loss

        specific_predictions = []
        specific_losses = []
        for i in range(4):
            # Samples whose target action is type i: [batch_size, 256] -> [n_action_data, 256]
            selected = target_action_types == i
            xi = x[selected]
            n_action_data = len(xi)

            # [n_action_data, 256] -> [n_action_data, n_classes] for each head of this type
            n_ships = self.n_ships_decoder[i](xi)
            relative_position = self._optional_head(self.relative_position_decoder[i], xi)
            n_steps = self._optional_head(self.n_steps_decoder[i], xi)
            direction = self._optional_head(self.direction_decoder[i], xi)
            specific_predictions.append([
                n_action_data, n_ships, relative_position, n_steps, direction
            ])

            n_ships_loss = F.cross_entropy(n_ships, target_action_n_ships[selected], reduction="sum")
            # Accumulate out-of-place so n_ships_loss is returned unmodified.
            action_loss = n_ships_loss

            relative_position_loss = None
            if relative_position is not None:
                relative_position_loss = F.cross_entropy(
                    relative_position, target_action_relative_position[selected], reduction="sum"
                )
                action_loss = action_loss + relative_position_loss

            n_steps_loss = None
            if n_steps is not None:
                n_steps_loss = F.cross_entropy(n_steps, target_action_n_steps[selected], reduction="sum")
                action_loss = action_loss + n_steps_loss

            direction_loss = None
            if direction is not None:
                direction_loss = F.cross_entropy(direction, target_action_direction[selected], reduction="sum")
                action_loss = action_loss + direction_loss

            loss = loss + self.ACTION_LOSS_WEIGHTS[i] * action_loss
            specific_losses.append([
                n_action_data, n_ships_loss, relative_position_loss, n_steps_loss, direction_loss
            ])

        # All partial losses are sums over the batch; normalize once at the end.
        loss = loss * (1 / batch_size)

        return (value, action_type, specific_predictions), (value_loss, type_loss, specific_losses), loss

model = NNUE()

In [5]:
# #checkpoint_name = "01340000"
# checkpoint_name = "02180000"
# dict_checkpoint = torch.load(f"010/checkpoint_{checkpoint_name}.pt", map_location="cpu")
# model.load_state_dict(dict_checkpoint["state_dict"], strict=False)

In [6]:
#checkpoint_name = "02720000"
checkpoint_name = "03000000"
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
dict_checkpoint = torch.load(f"024/checkpoint_{checkpoint_name}.pt", map_location="cpu")
# strict=False does not raise on missing or unexpected keys, so check the
# returned report explicitly: any parameter absent from the checkpoint keeps
# its random initialization and would be exported as-is by the dump below.
load_result = model.load_state_dict(dict_checkpoint["state_dict"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: checkpoint does not fully match the model: "
        f"missing_keys={load_result.missing_keys}, "
        f"unexpected_keys={load_result.unexpected_keys}"
    )
# Keep the load report as the cell's displayed output.
load_result

<All keys matched successfully>

In [7]:
# Export the loaded weights as a raw binary file in the storage directory,
# named after the checkpoint they came from.
parameters_path = STORAGE_DIR / f"parameters_{checkpoint_name}.bin"
model.dump(str(parameters_path))