# LOADING DATASET

In [1]:
import pathlib
import urllib
import zstandard
import chess
import torch
import numpy as np
from torch import nn

In [2]:
def __download(url: str, name: str) -> str:
    path, _ = urllib.request.urlretrieve(url, name)
    return path


def __unpack(path: str, name: str):
    input_file = pathlib.Path(path)
    with open(input_file, 'rb') as compressed:
        decomp = zstandard.ZstdDecompressor()
        output_path = name
        with open(output_path, 'wb') as destination:
            decomp.copy_stream(compressed, destination)
            destination.close()
        compressed.close()


def __remove(path: str):
    pathlib.Path.unlink(pathlib.Path(path))

In [3]:
path = __download("https://database.lichess.org/lichess_db_puzzle.csv.zst", "lichess_db_puzzle.csv.zst")

In [4]:
__unpack(path, "lichess_db_puzzle.csv")

In [5]:
__remove("lichess_db_puzzle.csv.zst")

In [6]:
class Puzzle:
    def __init__(self, row: str):
        fields = row.split(',')
        self.fen = fields[1]
        self.moves = fields[2].split(" ")
        self.tags = fields[7].split(" ")

    def __str__(self):
        return "{fen: " + self.fen + " ,tags: [" + ", ".join(self.tags) + "],moves: [" + ",".join(self.moves) + "]}"

In [7]:
def load(k: int) -> [Puzzle]:
    f = open("lichess_db_puzzle.csv")
    f.readline()
    result = []
    for i in range(k):
        result.append(Puzzle(f.readline()))
    f.close()
    return result

In [8]:
load(10)[0].__str__()

'{fen: r6k/pp2r2p/4Rp1Q/3p4/8/1N1P2R1/PqP2bPP/7K b - - 0 24 ,tags: [crushing, hangingPiece, long, middlegame],moves: [f2g3,e6e7,b2b1,b3c1,b1c1,h6c1]}'

# FILTER DATASET

In [9]:
expected_tags = {
    'attraction',
    'discoveredAttack',
    'doubleCheck',
    'fork',
    'pin',
    'sacrifice',
    'skewer',
    'xRayAttack',
    'zugzwang',
    'deflection',
    'clearance'
}

In [10]:
expected_tags_list = list(expected_tags)

In [11]:
def filter_data(data: [Puzzle]) -> [Puzzle]:
    return list(filter(lambda p: len(set(p.tags) & expected_tags) == 1, data))

In [12]:
len(filter_data(load(100)))

37

# CONVERSION TO TENSOR

In [13]:
def bitboard_to_tensor(bitboard: int) -> torch.Tensor:
    li = [1 if digit == '1' else 0 for digit in bin(bitboard)[2:]]
    li = [0 for _ in range(64 - len(li))] + li
    return torch.tensor(li).reshape((8, 8))

In [14]:
def fen_to_tensors_list(fen: str) -> [torch.Tensor]:
    board = chess.Board(fen)
    return [
        bitboard_to_tensor(board.occupied_co[chess.WHITE]),
        bitboard_to_tensor(board.occupied_co[chess.BLACK]),
        bitboard_to_tensor(board.pawns),
        bitboard_to_tensor(board.kings),
        bitboard_to_tensor(board.queens),
        bitboard_to_tensor(board.knights),
        bitboard_to_tensor(board.bishops),
        bitboard_to_tensor(board.rooks)
    ]

In [15]:
fen_to_tensors_list(load(1)[0].fen)

[tensor([[0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 1, 0, 1, 0],
         [1, 1, 0, 0, 0, 1, 0, 1],
         [1, 0, 0, 0, 0, 0, 0, 0]]),
 tensor([[1, 0, 0, 0, 0, 0, 0, 1],
         [1, 0, 0, 1, 0, 0, 1, 1],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0]]),
 tensor([[0, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 1, 1],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [1, 1, 0, 0, 0, 1, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0]]),
 tensor([[1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0

In [16]:
def move_to_tensor(move: str) -> torch.Tensor:
    x1 = 7 - ord(move[0]) + ord('a')
    y1 = 8 - int(move[1])
    x2 = 7 - ord(move[2]) + ord('a')
    y2 = 8 - int(move[3])
    tensor = torch.zeros(8, 8)
    tensor[y1][x1] = 1
    tensor[y2][x2] = 1
    return tensor

In [17]:
print(move_to_tensor('e2e4'))

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])


In [18]:
def puzzle_to_tensor(puzzle: Puzzle) -> torch.Tensor:
    fen_tensors = fen_to_tensors_list(puzzle.fen)
    move_tensors = [move_to_tensor(puzzle.moves[0]), move_to_tensor(puzzle.moves[1])]  # FIRST TWO MOVES
    return torch.stack(fen_tensors + move_tensors)

In [19]:
puzzle_to_tensor(load(1)[0])

tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 1., 0., 1., 0.],
         [1., 1., 0., 0., 0., 1., 0., 1.],
         [1., 0., 0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 1.],
         [1., 0., 0., 1., 0., 0., 1., 1.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 1., 1.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0.],
         [1., 1., 0., 0., 0., 1., 0., 1.],
       

# CONVERT AND BATCH DATASET

In [45]:
def puzzle_to_truth(puzzle: Puzzle) -> torch.Tensor:
    tensor = torch.zeros(len(expected_tags_list))
    [tag] = set(puzzle.tags) & expected_tags
    index = expected_tags_list.index(tag)
    tensor[index] = 1
    return torch.zeros(1) + index

In [47]:
puzzle_to_truth(filter_data(load(100))[0])

tensor([9.])

In [48]:
BATCH_SIZE = 64

In [49]:
def convert_dataset(puzzles: [Puzzle]) -> list[tuple[torch.Tensor, torch.Tensor]]:
    return [(puzzle_to_tensor(puzzle), puzzle_to_truth(puzzle)) for puzzle in puzzles]

In [62]:
def dataset_to_batches(dataset: list[tuple[torch.Tensor, torch.Tensor]]) -> list[tuple[torch.Tensor, torch.Tensor]]:
    batches = []
    index = 0
    while index + BATCH_SIZE <= len(dataset):
        batch = []
        truth = []
        max_index = index + BATCH_SIZE
        while index < max_index:
            batch.append(dataset[index][0])
            truth.append(dataset[index][1])
            index += 1
        batches.append((torch.stack(batch).cuda(), torch.tensor(truth).cuda().type(torch.long)))

    return batches

In [63]:
batched_dataset=dataset_to_batches(convert_dataset(filter_data(load(10000))))
print(len(batched_dataset))
print(batched_dataset[0][0].shape,batched_dataset[0][1].shape)

47
torch.Size([64, 10, 8, 8]) torch.Size([64])


# TRAIN

In [70]:
def accuracy(out,truth):
    return torch.argmax(out,dim=1) == truth

In [73]:
class Model(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Model, self).__init__()
        self.classifier = nn.Sequential(*args, **kwargs)

    def forward(self, X):
        return self.classifier.forward(X)


model = Model(nn.Conv2d(10, 8 * 8, kernel_size=4, padding=2),
              nn.ReLU(),
              nn.Conv2d(8 * 8, 4 * 4, kernel_size=2, padding=2),
              nn.ReLU(),
              nn.MaxPool2d(kernel_size=4, stride=1),
              nn.Conv2d(4*4, 8 * 8, kernel_size=2, padding=2),
              nn.ReLU(),
              nn.Conv2d(8 * 8, 1, kernel_size=4, padding=2),
              nn.ReLU(),
              nn.Flatten(),
              nn.Linear(169, 256),
              nn.ReLU(),
              nn.Linear(256, 64),
              nn.ReLU(),
              nn.Linear(64, 11),
              nn.LogSoftmax(),
              )
criterion = (
    nn.NLLLoss()
)


In [66]:
size_to_load=2000000
test_batches_count=400

In [67]:
all_batches=dataset_to_batches(convert_dataset(filter_data(load(size_to_load))))
train_batches=all_batches[test_batches_count:]
test_batches=all_batches[:test_batches_count]
print(len(all_batches),len(train_batches),len(test_batches))

9421 9021 400


In [74]:
def train(model, criterion, optimizer, epoch):
    model.cuda()
    criterion.cuda()
    batches = train_batches
    print("Dataset size:", len(batches))
    batch_index = 0
    for i in range(epoch):

        batch = batches[batch_index][0]
        truth = batches[batch_index][1]

        if batch_index == len(batches):
            batch_index = 0

        optimizer.zero_grad()
        out = model.forward(batch)
        loss = criterion(out, truth)
        loss.backward()
        optimizer.step()

        if i % 1000 == 0:
            print(loss.item(),(accuracy(out,truth).sum()/BATCH_SIZE).item())

In [75]:
train(model,
      criterion,
      torch.optim.SGD(model.classifier.parameters(), lr=0.001),
      200000)

Dataset size: 9021
2.3968663215637207 0.125
2.2532296180725098 0.28125
1.9107728004455566 0.28125
1.771907091140747 0.28125
1.7604730129241943 0.28125
1.7562994956970215 0.28125
1.7538591623306274 0.28125
1.7520688772201538 0.28125
1.7505725622177124 0.28125
1.7491289377212524 0.28125
1.7476325035095215 0.28125
1.7459619045257568 0.28125
1.7439974546432495 0.28125
1.7415149211883545 0.28125
1.7381435632705688 0.28125
1.7334425449371338 0.28125
1.7254749536514282 0.34375
1.7112091779708862 0.3125
1.6790845394134521 0.34375
1.563393235206604 0.4375
1.1087039709091187 0.640625
0.5687057971954346 0.859375
0.2833808660507202 0.921875
0.1510951966047287 0.9375
0.07308834791183472 1.0
0.03386355936527252 1.0
0.016486238688230515 1.0
0.009407673962414265 1.0
0.006124743260443211 1.0
0.004372854717075825 1.0
0.003319905139505863 1.0
0.0026326931547373533 1.0
0.0021554080303758383 1.0
0.00181114231236279 1.0
0.001551235793158412 1.0
0.0013495150487869978 1.0
0.0011893026530742645 1.0
0.001059889

In [87]:
torch.save(model,'model.pt')

In [89]:
def test(model, criterion):
    model.cuda()
    criterion.cuda()
    batches = test_batches
    print("Dataset size:", len(batches))
    batch_index = 0
    
    total_loss = 0
    total_accuracy = 0
    for i in range(len(batches)):

        batch = batches[batch_index][0]
        truth = batches[batch_index][1]

        if batch_index == len(batches):
            batch_index = 0

        out = model.forward(batch)
        loss = criterion(out, truth)
        print(loss.item(),(accuracy(out,truth).sum()/BATCH_SIZE).item())
        total_loss += loss.item()
        total_accuracy+=(accuracy(out,truth).sum()/BATCH_SIZE).item()

    return (total_loss / len(batches)),total_accuracy / len(batches)

In [90]:
test(model, criterion)

Dataset size: 400
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.54104995727539 0.109375
25.5410499

(25.54104995727539, 0.109375)

In [88]:
def eye_test(model,puzzle):
    tensor=puzzle_to_tensor(puzzle).cuda()
    out=model.forward(tensor)
    return sorted(zip(expected_tags_list,out.squeeze().tolist()),key=lambda x:-x[1]),puzzle.tags

eye_test(model,filter_data(load(100))[0])

([('discoveredAttack', -0.0025150116998702288),
  ('pin', -6.047675132751465),
  ('zugzwang', -8.906311988830566),
  ('clearance', -11.255400657653809),
  ('fork', -17.663063049316406),
  ('deflection', -18.95039939880371),
  ('attraction', -19.286771774291992),
  ('sacrifice', -20.158817291259766),
  ('skewer', -53.08643341064453),
  ('doubleCheck', -65.40817260742188),
  ('xRayAttack', -68.17584991455078)],
 ['crushing', 'endgame', 'exposedKing', 'long', 'skewer'])