In [5]:
import pandas as pd
pd.options.mode.chained_assignment = None

from tqdm.notebook import tqdm
tqdm.pandas()

import numpy as np
import torch
import torch.nn as nn
import os
from torch.utils.data import Dataset, DataLoader, random_split

from torch.utils.tensorboard import SummaryWriter


In [6]:
# --- Dataset settings (read by FenDataset / main; the experiment loop further down overwrites some of them) ---
# CSV file handed to FenDataset in main().
DATASET_NAME = "datasets/hce_full_train_small.csv"
# Column used as the training target.
EVAL_COLUMN = "eval_8"
# When set, rows whose value in this column is odd are dropped (the loader reports them as captures).
FLAG_COLUMN = None
# When True, a +-EVAL_MAX score at depth i-1 is copied over the score at depth i (i = 2..8).
FIX_MATE_ERROR = True
# An evaluation whose absolute value equals EVAL_MAX is treated as a mate score.
EVAL_MAX = 100000
# Rows whose earliest mate score appears at a depth <= this value are kept; other mate rows are dropped.
ACCEPTABLE_MATE_DEPTH = 0


# --- Training settings ---
# Batch size of the training DataLoader (the validation DataLoader uses 10x this).
MICRO_BATCH_SIZE = 512
# Fraction of the selected samples used for training; the rest is used for validation.
TRAIN_SPLIT = 0.80
# num_workers passed to both DataLoaders.
NUM_WORKERS = 5
# Number of epochs passed to train().
TRAIN_EPOCHS = 40
# The training loss is written to TensorBoard every REPORT_FREQ batches (and on the last batch of an epoch).
REPORT_FREQ = 500



# Length of one multi-hot feature vector: 64 squares * 12 (6 piece types * 2 colours),
# matching the index formula square * 12 + piece_id * 2 + color in FenDataset.fen_to_features.
INPUT_SIZE = 768

# Device used for the models and for every batch.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# Show which device was selected (print() inserts the second space after "=").
print(f"device =  {DEVICE}")

In [7]:
class FenDataset(Dataset):
    """
    In-memory dataset of chess positions read from a CSV file.

    The whole CSV is loaded and preprocessed eagerly in ``__init__``, so this is
    only suitable for datasets that fit in memory. Preprocessing is controlled
    by the module-level settings ``FIX_MATE_ERROR``, ``EVAL_MAX``,
    ``FLAG_COLUMN`` and ``ACCEPTABLE_MATE_DEPTH``, which are read at
    construction time.

    Each item is a tuple ``(white_features, black_features, stm, evaluation)``
    of float32 tensors: two multi-hot vectors of length ``INPUT_SIZE`` followed
    by two one-element tensors.
    """
    def __init__(self, csv_file, eval_column):
        """
        Load ``csv_file`` and keep the rows that are usable for training.

        Args:
            csv_file: Path of the CSV file. It must contain a ``fen`` column and
                the columns ``eval_1`` ... ``eval_8``.
            eval_column: Name of the column used as the training target. It is
                clipped to [-10000, 10000] and divided by 100 (centipawns to pawns).
        """
        # Load data
        df = pd.read_csv(csv_file, dtype={"best_move": "string"})

        # Propagate mate scores: once depth i-1 reports +-EVAL_MAX, force the
        # same score at depth i so a found mate is never "lost" at a deeper search.
        if FIX_MATE_ERROR:
            for i in tqdm(range(2, 9), "Fixing mate errors in dataset"):
                prev = f"eval_{i - 1}"
                now = f"eval_{i}"
                df.loc[(df[prev] == EVAL_MAX) | (df[prev] == -EVAL_MAX), now] = df[prev]

        # Remove captures
        if FLAG_COLUMN:
            raw_n = len(df)
            df = df[df[FLAG_COLUMN] % 2 == 0]
            print("Removed", raw_n - len(df), "captures")

        # Remove mate positions: mate_in holds the smallest depth at which a mate
        # score shows up, or the sentinel 1000 when no depth reports a mate.
        print("Removing mates")
        mate_in = np.full(len(df.index), 1000)
        for i in range(1, 9):
            is_mate = (df[f"eval_{i}"].abs() == EVAL_MAX).to_numpy()
            mate_in[is_mate] = np.minimum(mate_in[is_mate], i)

        df["mate_in"] = mate_in
        df = df[(df["mate_in"] <= ACCEPTABLE_MATE_DEPTH) | (df["mate_in"] == 1000)]

        # Clip eval values
        df[eval_column] = df[eval_column].clip(lower=-10000, upper=10000)

        # centipawns to pawns
        df[eval_column] = df[eval_column] / 100

        # Safety check: drop rows without a target value
        df = df[~df[eval_column].isna()]

        # Parse fen
        print("Parsing fens to stm")
        df["stm"] = df["fen"].progress_apply(FenDataset.fen_to_side_to_move)

        print("Parsing fens to features")
        df["pieces"] = df["fen"].progress_apply(FenDataset.fen_to_features)

        # Keep plain Python records (pieces, evaluation, stm) for cheap per-item access
        self._dataset_py = list(df[["pieces", eval_column, "stm"]].to_records(index=False))

    def __len__(self):
        """Return the number of positions kept after preprocessing."""
        return len(self._dataset_py)

    def __getitem__(self, idx):
        """
        Build the training tensors for the position at ``idx``.

        Returns:
            ``(white_features, black_features, stm, evaluation)`` as float32
            tensors; the feature vectors have length ``INPUT_SIZE``, the other
            two tensors have one element each.
        """
        features = self._dataset_py[idx][0]

        # Multi hot encode (allocated as float32 so no conversion copy is needed)
        white_features = np.zeros(INPUT_SIZE, dtype=np.float32)
        black_features = np.zeros(INPUT_SIZE, dtype=np.float32)
        white_features[features[0]] = 1
        black_features[features[1]] = 1

        evaluation = self._dataset_py[idx][1]
        side_to_move = self._dataset_py[idx][2]

        return (
            torch.from_numpy(white_features),
            torch.from_numpy(black_features),
            torch.tensor([side_to_move]).to(torch.float32),
            torch.tensor([evaluation]).to(torch.float32)
        )

    @staticmethod
    def fen_to_features(fen_str):
        """
        Convert the piece-placement field of a FEN string into feature indices.

        A piece on ``square`` (rank * 8 + file, a1 = 0) gets the index
        ``square * 12 + piece_id * 2 + color`` from white's perspective and the
        vertically mirrored square with the opposite colour bit from black's.

        Args:
            fen_str: FEN string; only the first whitespace-separated field is read.

        Returns:
            ``(features_white, features_black)`` as int64 numpy arrays, one entry
            per piece (both empty for an empty board).

        Raises:
            ValueError: if the placement field contains a character that is
                neither ``/``, a digit, nor one of the piece letters ``pnbrqk``
                (in either case).
        """
        features_white = []
        features_black = []
        fen_parts = fen_str.split()
        rank = 7
        file = 0
        PIECES = "pnbrqk"

        for char in fen_parts[0]:
            if char == '/':
                rank -= 1
                file = 0
            elif char.isdigit():
                file += int(char)
            else:
                square = rank * 8 + file
                piece_id = PIECES.find(char.lower())
                if piece_id < 0:
                    # find() returns -1 here, which would silently alias another piece's index
                    raise ValueError(f"Unknown piece {char!r} in FEN {fen_str!r}")
                color = 1 if char.islower() else 0

                feature_white = square * 12 + piece_id * 2 + color
                # square ^ 56 flips the rank, i.e. mirrors the board vertically
                feature_black = (square ^ 56) * 12 + piece_id * 2 + (1 - color)
                features_white.append(feature_white)
                features_black.append(feature_black)

                file += 1

        # Explicit integer dtype: np.array([]) would default to float64, which
        # cannot be used as an index array in __getitem__.
        return np.array(features_white, dtype=np.int64), np.array(features_black, dtype=np.int64)

    @staticmethod
    def fen_to_side_to_move(fen_str):
        """Return 0 when the FEN's side-to-move field is ``w`` and 1 otherwise."""
        return 0 if fen_str.split()[1] == "w" else 1

In [8]:
class NNUE(nn.Module):
    """
    Small NNUE-style evaluation network.

    One shared feature transformer is applied to both perspectives; the two
    128-wide results are concatenated side-to-move first, clamped to [0, 1],
    and passed through two more linear layers down to a single output.
    """
    def __init__(self):
        super().__init__()

        # Layers are created in this order so parameter order stays ft, l1, l2.
        self.ft = nn.Linear(INPUT_SIZE, 128)
        self.l1 = nn.Linear(2 * 128, 32)
        self.l2 = nn.Linear(32, 1)

    def forward(self, white_features, black_features, stm):
        """
        Evaluate a batch of positions.

        Args:
            white_features: multi-hot features seen from white's perspective.
            black_features: multi-hot features seen from black's perspective.
            stm: side to move, 0 for white and 1 for black; it is broadcast
                against the concatenated accumulator.

        Returns:
            The raw (unscaled) output of the last linear layer.
        """
        white_acc = self.ft(white_features)
        black_acc = self.ft(black_features)

        # Both orderings are built and blended by stm, so the side to move
        # always ends up in the first half of the accumulator.
        white_to_move = torch.cat([white_acc, black_acc], dim=1)
        black_to_move = torch.cat([black_acc, white_acc], dim=1)
        blended = ((1 - stm) * white_to_move) + (stm * black_to_move)

        hidden = self.l1(blended.clamp(0.0, 1.0))
        return self.l2(hidden.clamp(0.0, 1.0))

In [9]:
class Trainer:
    """
    Bundles one model with its optimizer, loss function, TensorBoard writer
    and checkpointing.

    Predictions and targets are both squashed with ``tanh(x / scaling_factor)``
    before the loss is computed. A checkpoint is written whenever the summed
    validation loss of an epoch improves on the best one seen in this run.
    """
    def __init__(self, name, lr, scaling_factor, optimizer, loss_fn, model_init = lambda: NNUE()):
        """
        Args:
            name: Identifier used in the TensorBoard and checkpoint paths.
            lr: Learning rate handed to ``optimizer``.
            scaling_factor: Divisor applied before ``tanh`` to predictions and targets.
            optimizer: Callable ``optimizer(parameters, lr=...)`` returning the optimizer.
            loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.
            model_init: Zero-argument factory that builds the model.
        """
        self.name = name
        self.model = model_init()
        self.model.to(DEVICE)
        self.scaling_factor = scaling_factor
        self.optimizer = optimizer(self.model.parameters(), lr=lr)
        self.loss_fn = loss_fn

    def start_run(self, run_name, num_batches):
        """Reset the per-run state and open the TensorBoard writer for ``run_name``."""
        self.model.train()
        self.run_name = run_name
        self.writer = SummaryWriter(f'runs/{run_name}_{self.name}')
        self.running_loss = 0
        # Number of batches accumulated in running_loss since the last report
        self.running_batches = 0
        self.best_validation_loss = float("inf")
        self.num_batches = num_batches

    def micro_batch(self, epoch, i, white_f, black_f, stm, true_eval):
        """
        Run one optimisation step and periodically log the mean training loss.

        Args:
            epoch: Current epoch index (used for the TensorBoard step).
            i: Index of the batch within the epoch.
            white_f, black_f, stm: Model inputs.
            true_eval: Target evaluations.
        """
        self.optimizer.zero_grad()

        pred = self.model(white_f, black_f, stm)
        pred = torch.tanh(pred / self.scaling_factor)
        need = torch.tanh(true_eval / self.scaling_factor)

        loss = self.loss_fn(pred, need)
        loss.backward()

        self.optimizer.step()

        self.running_loss += loss.item()
        self.running_batches += 1

        # NaN is the only value that differs from itself: dump the offending batch
        if self.running_loss != self.running_loss:
            print(white_f, black_f, stm, true_eval, pred)

        if (i > 0 and i % REPORT_FREQ == 0) or i == self.num_batches - 1:
            # Divide by the number of batches actually accumulated, so the first
            # window of an epoch (which also contains batch 0) is averaged correctly.
            self.writer.add_scalar('training loss', self.running_loss / self.running_batches, epoch * self.num_batches + i)
            self.running_loss = 0
            self.running_batches = 0

    # Validation run
    def start_validation(self, epoch):
        """Switch the model to eval mode and clear the validation accumulators."""
        self.model.eval()
        self.val_loss_scaled = 0
        self.val_epoc = epoch
        self.num_val_samples = 0

    def validation_batch(self, white_f, black_f, stm, true_eval):
        """Add the loss of one validation batch to the running total."""
        # No gradients are needed for validation; skipping autograd saves memory
        with torch.no_grad():
            pred = self.model(white_f, black_f, stm)
            pred_scaled = torch.tanh(pred / self.scaling_factor)
            need_scaled = torch.tanh(true_eval / self.scaling_factor)
            self.val_loss_scaled += self.loss_fn(pred_scaled, need_scaled).item()
        # Despite the name this counts batches, one per call
        self.num_val_samples += 1

    def end_validation(self):
        """
        Log the mean validation loss, switch back to train mode and write a
        checkpoint when this epoch is the best one of the run so far.
        """
        self.model.train()

        if self.num_val_samples == 0:
            # Empty validation loader: nothing to log or compare
            return

        mean_loss = self.val_loss_scaled / self.num_val_samples
        self.writer.add_scalar('validation loss', mean_loss, (self.val_epoc + 1) * self.num_batches)

        if self.val_loss_scaled < self.best_validation_loss:
            save_dir = f"checkpoints/{self.run_name}/{self.name}"
            os.makedirs(save_dir, exist_ok=True)

            # The whole module is pickled (not just a state_dict); loaders rely on that
            torch.save(self.model, f"{save_dir}/epoch={str(self.val_epoc).zfill(3)},loss={round(mean_loss, 5)}")

            self.best_validation_loss = self.val_loss_scaled


class MultiTrainer:
    """
    Fan-out proxy: calling any method on it calls that method, with the same
    arguments, on every wrapped trainer in order. Return values are discarded.
    """
    def __init__(self, trainers):
        """
        Args:
            trainers: Iterable of objects that all expose the methods to be forwarded.
        """
        self.trainers = trainers

    def __getattr__(self, name):
        """
        Return a callable that forwards ``name(*args, **kwargs)`` to every trainer.

        Raises:
            AttributeError: for ``trainers`` itself and for dunder names.
        """
        # __getattr__ only runs when normal lookup fails. If `trainers` is
        # missing (instance created without __init__, e.g. by copy/pickle),
        # reading self.trainers below would re-enter __getattr__ forever.
        # Dunder probes (__setstate__, __deepcopy__, ...) must also see a real
        # AttributeError instead of a forwarding closure.
        if name == "trainers" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)

        def forward_calls(*args, **kwargs):
            for trainer in self.trainers:
                getattr(trainer, name)(*args, **kwargs)
        return forward_calls

In [10]:
def train(trainer, epochs, run_name, train_dl, validation_dl, epoch_start=0):
    """
    Drive ``trainer`` through epochs ``epoch_start`` .. ``epochs - 1``.

    Every epoch feeds all batches of ``train_dl`` to ``trainer.micro_batch`` and
    then runs one full validation pass over ``validation_dl``.

    Args:
        trainer: Object exposing start_run, micro_batch, start_validation,
            validation_batch and end_validation.
        epochs: Exclusive upper bound of the epoch range.
        run_name: Name forwarded to ``trainer.start_run``.
        train_dl: Iterable of training batches (4 tensors each).
        validation_dl: Iterable of validation batches (4 tensors each).
        epoch_start: First epoch index.
    """
    def on_device(batch):
        # Move every tensor of the batch to the configured device
        return tuple(tensor.to(DEVICE) for tensor in batch)

    trainer.start_run(run_name, len(train_dl))

    for epoch in tqdm(range(epoch_start, epochs), "Training"):
        # Training pass
        step = 0
        for batch in tqdm(train_dl, f"Epoch {epoch}"):
            white_f, black_f, stm, true_eval = on_device(batch)
            trainer.micro_batch(epoch, step, white_f, black_f, stm, true_eval)
            step += 1

        # Validation pass
        trainer.start_validation(epoch)
        for batch in tqdm(validation_dl, "Computing validation loss"):
            white_f, black_f, stm, true_eval = on_device(batch)
            trainer.validation_batch(white_f, black_f, stm, true_eval)

        trainer.end_validation()

In [11]:
def main(run_name, override_train_size = None):
    """
    Load the configured dataset, split it and train all configured trainers.

    The dataset and target column come from the module-level ``DATASET_NAME``
    and ``EVAL_COLUMN``.

    Args:
        run_name: Name of the run, forwarded to ``train``.
        override_train_size: Optional number of training samples. When given,
            only ``override_train_size / TRAIN_SPLIT`` samples (training plus
            validation) are used; the rest of the dataset is left out. Values
            that need more samples than the dataset holds fall back to the
            whole dataset.
    """
    print("Loading dataset")
    dataset = FenDataset(DATASET_NAME, EVAL_COLUMN)

    print("Preparing dataloaders")
    total_len = len(dataset) if override_train_size is None else int(override_train_size / TRAIN_SPLIT)
    if total_len > len(dataset):
        # Otherwise the "unused" split below would get a negative length and
        # random_split would fail with an unrelated-looking error.
        print("Requested", total_len, "samples but the dataset has only", len(dataset), "- using all of them")
        total_len = len(dataset)
    train_size = int(total_len * TRAIN_SPLIT)
    validation_size = total_len - train_size
    unused_size = len(dataset) - total_len
    train_ds, validation_ds, _ = random_split(dataset, [train_size, validation_size, unused_size])
    train_dl = DataLoader(train_ds, batch_size=MICRO_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    # Validation only accumulates a loss, so shuffling would be wasted work
    validation_dl = DataLoader(validation_ds, batch_size=MICRO_BATCH_SIZE*10, shuffle=False, num_workers=NUM_WORKERS)

    print("Train size: ", train_size, "validation size: ", validation_size)

    print("Preparing trainers")
    # Same learning rate and optimizer, two different tanh scaling factors
    trainers = [
        Trainer("lr00085_sf35_adam", 0.00085, 3.5, torch.optim.Adam, nn.MSELoss()),
        Trainer("lr00085_sf55_adam", 0.00085, 5.5, torch.optim.Adam, nn.MSELoss()),
    ]
    trainer = MultiTrainer(trainers)

    print("Start training")
    train(trainer, TRAIN_EPOCHS, run_name, train_dl, validation_dl)

    # del train_dl, validation_dl, dataset, train_ds, train_dl


In [None]:
# Experiment definitions consumed by the loop below. Each entry is a tuple
#   (run_name, dataset_csv, eval_column, flag_column, acceptable_mate_depth[, override_train_size])
# where the loop copies fields 1-4 into DATASET_NAME, EVAL_COLUMN, FLAG_COLUMN and
# ACCEPTABLE_MATE_DEPTH, and passes field 0 (plus the optional field 5) to main().
# Commented-out entries are kept for reference; only the uncommented ones are run.
params = [
    # ("full_d8_nocap", "datasets/hce_full_train.csv", "eval_8", None),
    # ("full_d8_cap", "datasets/hce_full_train.csv", "eval_8", "flags_8"),
    # ("full_d5_nocap", "datasets/hce_full_train.csv", "eval_5", None),
    # ("full_d5_cap", "datasets/hce_full_train.csv", "eval_5", "flags_5"),
    # ("simple_d8_nocap", "datasets/train_simple_hce.csv", "eval_8", None),
    # ("simple_d8_cap", "datasets/train_simple_hce.csv", "eval_8", "flags_8"),
    # ("simple_d5_nocap", "datasets/train_simple_hce.csv", "eval_5", None),
    # ("simplie_d5_cap", "datasets/train_simple_hce.csv", "eval_5", "flags_5"),

    #("full_d8_m2_norm", "datasets/hce_full_train.csv", "eval_8", None, 2, 2300000),
    #("full_d8_m4_norm", "datasets/hce_full_train.csv", "eval_8", None, 4, 2300000),
    # ("full_d8_m6_norm", "datasets/hce_full_train.csv", "eval_8", None, 6, 2300000),
    # ("full_d8_ma_norm", "datasets/hce_full_train.csv", "eval_8", None, 100, 2300000),
    # ("full_d8_m0_norm", "datasets/hce_full_train.csv", "eval_8", None, 0, 2300000),

    # ("full_d8_250k", "datasets/hce_full_train.csv", "eval_8", None, 0, 250000),
    # ("full_d8_500k", "datasets/hce_full_train.csv", "eval_8", None, 0, 500000),
    # ("full_d8_1000k", "datasets/hce_full_train.csv", "eval_8", None, 0, 1000000),
    # ("full_d8_2000k", "datasets/hce_full_train.csv", "eval_8", None, 0, 2000000),

    # ("full_d8_final", "datasets/hce_full_train.csv", "eval_8", None, 4),
    # ("full_d5_final", "datasets/hce_full_train.csv", "eval_5", None, 4),
    ("simple_d8_final", "datasets/train_simple_hce.csv", "eval_8", None, 4),
    # ("simple_d5_final", "datasets/train_simple_hce.csv", "eval_5", None, 4),
]

# Run every configured experiment. FenDataset and main() read their settings
# from module-level names, so these are rebound before each call.
for p in params:
    DATASET_NAME = p[1]
    EVAL_COLUMN = p[2]
    FLAG_COLUMN = p[3]
    # Entries with only four fields carry no mate depth; fall back to 0
    # instead of failing with an IndexError.
    ACCEPTABLE_MATE_DEPTH = p[4] if len(p) >= 5 else 0

    # The optional sixth field caps the number of training samples
    main(p[0], p[5] if len(p) >= 6 else None)

In [None]:
# NOTE(review): presumably a deliberate stop so that "Run All" halts after training
# and the inference/export cells below are run by hand — confirm before removing.
assert False

In [12]:
def infer(model_path, fen):
    """
    Evaluate a single position with a saved model.

    Args:
        model_path: Path of a checkpoint written with ``torch.save(model, ...)``
            (a whole pickled module, not a state_dict).
        fen: FEN string of the position.

    Returns:
        The raw model output as a tensor on ``DEVICE``.
    """
    # SECURITY: the checkpoint is a pickled module, so loading executes pickle
    # code; only load files you created yourself. weights_only=False is needed
    # for whole-module checkpoints, and map_location lets a checkpoint saved on
    # a GPU be loaded on a CPU-only machine.
    model = torch.load(model_path, map_location=DEVICE, weights_only=False)
    model.eval()

    fw, fb = FenDataset.fen_to_features(fen)
    stm = FenDataset.fen_to_side_to_move(fen)

    # Multi hot encode, same layout as FenDataset.__getitem__
    white_features = np.zeros(INPUT_SIZE, dtype=np.float32)
    black_features = np.zeros(INPUT_SIZE, dtype=np.float32)
    white_features[fw] = 1
    black_features[fb] = 1

    # Add the batch dimension; stm is shaped (1, 1) like a training batch of one
    white_features = torch.from_numpy(white_features).unsqueeze(0).to(DEVICE)
    black_features = torch.from_numpy(black_features).unsqueeze(0).to(DEVICE)
    stm = torch.tensor([[stm]], dtype=torch.float32, device=DEVICE)

    with torch.no_grad():
        return model(white_features, black_features, stm)

# infer("/home/eff/bakalauras/models_es/d8_lr00085_sc35_simple_epoch_8_0_0815059666832288.torch", "rnbqkbnr/pppppppp/8/4N3/8/8/PPPPPPPP/RNBQKB1R b KQkq - 0 1")

In [13]:
def export_nnue(model_path, nnue_name):
    """
    Dump the parameters of a saved model to a plain-text file.

    Parameters are written in ``model.parameters()`` order, one value per line:
    2-D weights row by row, 1-D biases as a single run of lines.

    Args:
        model_path: Path of a checkpoint written with ``torch.save(model, ...)``.
        nnue_name: Path of the text file to write; missing parent directories
            are created.

    Raises:
        ValueError: if a parameter is neither 1-D nor 2-D.
    """
    def torch_get_weights(tensor_row):
        # One value per line, always terminated by a newline
        return "\n".join([str(value) for value in tensor_row.tolist()]) + "\n"

    def torch_save_nn(model, filename):
        print("Saving weights...")
        with open(filename, "w") as file:
            for parameter in model.parameters():
                if parameter.dim() == 1:
                    print("B:", len(parameter.data), "x", 1)
                    row = parameter.data
                    file.write(torch_get_weights(row))
                elif parameter.dim() == 2:
                    print("W:", len(parameter.data), "x", len(parameter.data[0]))
                    for row in parameter.data:
                        file.write(torch_get_weights(row))
                else:
                    raise ValueError(f"Unsupported parameter with {parameter.dim()} dimensions")
        print("Weights saved")

    print("Saving", model_path, "as", nnue_name)
    # SECURITY: the checkpoint is a pickled module, so loading executes pickle
    # code; only load trusted files. weights_only=False is needed for
    # whole-module checkpoints; map_location="cpu" allows exporting a GPU
    # checkpoint on a machine without CUDA.
    model = torch.load(model_path, map_location="cpu", weights_only=False).to("cpu")

    out_dir = os.path.dirname(nnue_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch_save_nn(model, nnue_name)

In [14]:
import re
def get_best_network(checkpoint_name, run_name):
    """
    Return the path of the checkpoint with the highest epoch number.

    Looks in ``./checkpoints/{checkpoint_name}/{run_name}`` for files whose name
    contains ``epoch=<number>``; any other file in the folder is ignored.
    Checkpoints are only written when the validation loss improves, so within
    one training run the latest epoch is also the best one.

    Args:
        checkpoint_name: First folder level under ``./checkpoints`` (the run
            name used during training).
        run_name: Second folder level (the trainer name used during training).

    Raises:
        FileNotFoundError: if the folder contains no checkpoint file.
    """
    folder = f"./checkpoints/{checkpoint_name}/{run_name}"

    candidates = []
    for file_name in os.listdir(folder):
        match = re.search(r"epoch=(\d+)", file_name)
        if match is not None:
            candidates.append((int(match.group(1)), file_name))

    if not candidates:
        raise FileNotFoundError(f"No checkpoints found in {folder}")

    # max() keeps the first entry on ties, like a stable descending sort would
    _, best = max(candidates, key=lambda item: item[0])
    return f"{folder}/{best}"

In [16]:
# export_nnue(get_best_network("full_d8_nocap", "lr00085_sf35_adam"), "nets/D8_FULL.nnue")
# export_nnue(get_best_network("full_d8_nocap", "lr00085_sf55_adam"), "nets/D8_FULL_55.nnue")
# export_nnue(get_best_network("full_d8_cap", "lr00085_sf35_adam"), "nets/D8_FULL_NC.nnue")
# export_nnue(get_best_network("full_d8_cap", "lr00085_sf55_adam"), "nets/D8_FULL_NC_55.nnue")

# export_nnue(get_best_network("full_d5_nocap", "lr00085_sf35_adam"), "nets/D5_FULL.nnue")
# export_nnue(get_best_network("full_d5_nocap", "lr00085_sf55_adam"), "nets/D5_FULL_55.nnue")
# export_nnue(get_best_network("full_d5_cap", "lr00085_sf35_adam"), "nets/D5_FULL_NC.nnue")
# export_nnue(get_best_network("full_d5_cap", "lr00085_sf55_adam"), "nets/D5_FULL_NC_55.nnue")

# export_nnue(get_best_network("simple_d8_nocap", "lr00085_sf35_adam"), "nets/D8_SIMPLE.nnue")
# export_nnue(get_best_network("simple_d8_nocap", "lr00085_sf55_adam"), "nets/D8_SIMPLE_55.nnue")
# export_nnue(get_best_network("simple_d8_cap", "lr00085_sf35_adam"), "nets/D8_SIMPLE_NC.nnue")
# export_nnue(get_best_network("simple_d8_cap", "lr00085_sf55_adam"), "nets/D8_SIMPLE_NC_55.nnue")

# export_nnue(get_best_network("simple_d5_nocap", "lr00085_sf35_adam"), "nets/D5_SIMPLE.nnue")
# export_nnue(get_best_network("simple_d5_nocap", "lr00085_sf55_adam"), "nets/D5_SIMPLE_55.nnue")
# export_nnue(get_best_network("simplie_d5_cap", "lr00085_sf35_adam"), "nets/D5_SIMPLE_NC.nnue")
# export_nnue(get_best_network("simplie_d5_cap", "lr00085_sf55_adam"), "nets/D5_SIMPLE_NC_55.nnue")

# export_nnue(get_best_network("full_d8_m2_norm", "lr00085_sf55_adam"), "nets/D8_FULL_M2.nnue")
# export_nnue(get_best_network("full_d8_m4_norm", "lr00085_sf55_adam"), "nets/D8_FULL_M4.nnue")
# export_nnue(get_best_network("full_d8_m6_norm", "lr00085_sf55_adam"), "nets/D8_FULL_M6.nnue")
# export_nnue(get_best_network("full_d8_ma_norm", "lr00085_sf55_adam"), "nets/D8_FULL_MA.nnue")
# export_nnue(get_best_network("full_d8_m0_norm", "lr00085_sf55_adam"), "nets/D8_FULL_M0.nnue")

# export_nnue(get_best_network("full_d8_250k", "lr00085_sf55_adam"), "nets/D8_FULL_250k.nnue")
# export_nnue(get_best_network("full_d8_500k", "lr00085_sf55_adam"), "nets/D8_FULL_500k.nnue")
# export_nnue(get_best_network("full_d8_1000k", "lr00085_sf55_adam"), "nets/D8_FULL_1000k.nnue")
# export_nnue(get_best_network("full_d8_2000k", "lr00085_sf55_adam"), "nets/D8_FULL_2000k.nnue")

# Export the latest checkpoint of the sf55 trainer for each final run
for _final_run, _net_file in (
    ("full_d8_final", "nets/D8_FULL.nnue"),
    ("full_d5_final", "nets/D5_FULL.nnue"),
    ("simple_d8_final", "nets/D8_SIMPLE.nnue"),
    ("simple_d5_final", "nets/D5_SIMPLE.nnue"),
):
    export_nnue(get_best_network(_final_run, "lr00085_sf55_adam"), _net_file)

Saving ./checkpoints/full_d8_final/lr00085_sf55_adam/epoch=038,loss=0.04115 as nets/D8_FULL.nnue
Saving weights...
W: 128 x 768
B: 128 x 1
W: 32 x 256
B: 32 x 1
W: 1 x 32
B: 1 x 1
Weights saved
Saving ./checkpoints/full_d5_final/lr00085_sf55_adam/epoch=029,loss=0.03419 as nets/D5_FULL.nnue
Saving weights...
W: 128 x 768
B: 128 x 1
W: 32 x 256
B: 32 x 1
W: 1 x 32
B: 1 x 1
Weights saved
Saving ./checkpoints/simple_d8_final/lr00085_sf55_adam/epoch=028,loss=0.03943 as nets/D8_SIMPLE.nnue
Saving weights...
W: 128 x 768
B: 128 x 1
W: 32 x 256
B: 32 x 1
W: 1 x 32
B: 1 x 1
Weights saved
Saving ./checkpoints/simple_d5_final/lr00085_sf55_adam/epoch=030,loss=0.03598 as nets/D5_SIMPLE.nnue
Saving weights...
W: 128 x 768
B: 128 x 1
W: 32 x 256
B: 32 x 1
W: 1 x 32
B: 1 x 1
Weights saved


In [None]:
# Print one engine-configuration JSON object per exported network in ./nets,
# each followed by " ,", ready to be pasted into an engine list.
# NOTE(review): the template hardcodes machine-specific absolute Windows paths
# (NNUEPath and workingDirectory); adjust them before using the output elsewhere.
nnues = sorted(os.listdir("./nets"))  # sorted so the output order is reproducible
for nnue in nnues:
    if not nnue.endswith('.nnue'):
        continue

    # Strip only the final extension, so names containing extra dots stay intact
    name = os.path.splitext(nnue)[0]

    pattern = """{
		"command" : "BoomChess.exe",
		"name" : "LatestBuild_NNUE_{NAME}",
		"options" : [
			{
				"alias" : "",
				"default" : 256,
				"max" : 1024,
				"min" : 1,
				"name" : "Hash",
				"type" : "spin",
				"value" : 256
			},
			{
				"alias" : "",
				"choices" : [
					"FULL",
					"SIMPLE"
				],
				"default" : "FULL",
				"name" : "EvalType",
				"type" : "combo",
				"value" : "SIMPLE"
			},
			{
				"alias" : "",
				"default" : "<empty>",
				"name" : "NNUEPath",
				"type" : "folder",
				"value" : "C:/Users/marty/Desktop/Kursinis/nets/{NAME}.nnue"
			}
		],
		"protocol" : "uci",
		"stderrFile" : "",
		"variants" : [
			"standard",
			"atomic"
		],
		"workingDirectory" : "C:\\\\Users\\\\marty\\\\Desktop\\\\Kursinis\\\\BoomChess\\\\cmake-build-release-mingw"
	}
    """

    print(pattern.replace("{NAME}", name), ",")
    

In [1]:
# NOTE(review): leftover scratch cell — `test` is not defined anywhere in this
# notebook, so this line raises NameError when run; it looks safe to delete.
test

NameError: name 'test' is not defined