In [1]:
!pip3 install chess
from copy import deepcopy
import torch
import torch.nn.functional as F
from torch import nn, Tensor
import math
import chess
from tqdm import trange
import linecache
from random import randint

device = "cuda" if torch.cuda.is_available() else "cpu"

Collecting chess
  Downloading chess-1.9.3-py3-none-any.whl (148 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m148.5/148.5 kB[0m [31m703.2 kB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: chess
Successfully installed chess-1.9.3
[0m

# VNN DECLARATION

In [2]:
class PosEncIndex(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        
    def forward(self, x: Tensor) -> Tensor:
        length = torch.max(x).item()+1
        
        pe = torch.zeros((length, self.d_model)).to(device)
        position = torch.arange(0, length).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, self.d_model, 2, dtype=torch.float) *
                            -(math.log(10000.0) / self.d_model)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)

        return pe[x]

class VNNBlock (nn.Module):
    def __init__(self, d_model, weight_nn, bias_nn) -> None:
        super().__init__()
        self.weight_nn = weight_nn 
        self.bias_nn = bias_nn 
        self.pos_enc = PosEncIndex(d_model)

    def weight_propagation (self, x, output_size, extra_out):
        input_size = x.size(1)
        batch_size = x.size(0)
        
        #* Weight Generation
        # Generate the weight vector
        argument_one = torch.arange(input_size)
        argument_two = torch.arange(output_size)

        # Generate the repeat
        argument_one = argument_one.repeat(batch_size, output_size)
        argument_two = argument_two.repeat(batch_size, input_size)

        x_concat = x.repeat(1, output_size).unsqueeze(2).to(device)

        # Positional Encoding + Concat
        argument_one = self.pos_enc(argument_one.detach())
        argument_two = self.pos_enc(argument_two.detach())

        if extra_out != None:  
            argument = torch.concat((argument_one, argument_two, x_concat, extra_out.repeat(1, input_size, 1)), dim=2)
        else:
            argument = torch.concat((argument_one, argument_two, x_concat), dim=2)
        
        weights = self.weight_nn(argument.detach()).view(batch_size, input_size, output_size)
        x = x.view(batch_size, 1, input_size)
        out = torch.bmm(x, weights).squeeze(1)

        #* Bias Generation

        # Create Bias Argument
        argument_one = torch.arange(output_size)
        argument_one = self.pos_enc(argument_one.detach()).squeeze(1)
        argument_one = argument_one.repeat(batch_size, 1, 1)
        argument_two = out.unsqueeze(2)
        if extra_out != None:
            bias_argument = torch.concat((argument_one, argument_two, extra_out), dim=2)
        else:
            bias_argument = torch.concat((argument_one, argument_two), dim=2)

        # Add bias
        bias = self.bias_nn(bias_argument.detach()).squeeze(2)
        out += bias

        return out

    def return_gpu_desc (self):
        t = torch.cuda.get_device_properties(0).total_memory
        r = torch.cuda.memory_reserved(0)
        a = torch.cuda.memory_allocated(0)
        f = r-a  # free inside reserved
        return f"Free: {f/1024**2} MB; Allocated: {a/1024**2} MB"

    def forward (self, x, output_size, extra_out=None, chunks=None):
        # Extra Out size: 
        # first dim is the batch size, second dim is the output space, third dim is the vector added during weight 
        if chunks != None:
            if extra_out != None:
                assert x.size(0) == extra_out.size(0), f"Batch size of x ({x.size(0)}) is the same as the batch size of extra_out ({extra_out.size(0)})"

            arr = [output_size // chunks for _ in range(chunks)]        
            if output_size % chunks > 0: 
                arr.append(output_size % chunks) 

            out = torch.tensor([])
            output_size = 5

            index = 0
            for i in range(len(arr)):
                indx_arr = torch.arange(start=index, end=index+arr[i]).to(device)
                partial_extra_out = torch.index_select(extra_out, dim=1, index=indx_arr)
                output = self.weight_propagation(x, arr[i], partial_extra_out)
                if out.size(0) == 0: 
                    out = output 
                else:
                    out = torch.concat((out, output), dim=1)
                index += arr[i]
            return out
        else:
            return self.weight_propagation(x, output_size, extra_out)
    
    
d_model = 16

weight_model = nn.Sequential(
    nn.Linear(65, 16),
    nn.Tanh(),
    nn.Linear(16, 1),
) 

bias_model = nn.Sequential(
    nn.Linear(49, 12),
    nn.Tanh(),
    nn.Linear(12, 1),
)

model = VNNBlock(d_model, weight_model, bias_model).to("cuda")
out = model(torch.randn(1, 60).to("cuda"), 10, extra_out=torch.randn(1, 10, 32).to("cuda"), chunks=7)
print(out.shape)

torch.Size([1, 10])


# TRAINING CHESS

In [3]:

input_database = "../input/35-million-chess-games/all_with_filtered_anotations_since1998.txt"
output_model = "./ChessEngine.pth"
# line cache warmup
linecache.getline(input_database, 0) 

def encode_move (move:chess.Move):
    char_to_num = {
        "a":0,
        "b":1,
        "c":2,
        "d":3,
        "e":4,
        "f":5,
        "g":6,
        "h":7,
    }

    uci_str = move.uci()
    first_num = char_to_num[uci_str[0]]
    second_num = int(uci_str[1])-1
    third_num = char_to_num[uci_str[2]]
    fourth_num = int(uci_str[3])-1
    
    return torch.concat((
        F.one_hot(torch.tensor([first_num]), num_classes=8),
        F.one_hot(torch.tensor([second_num]), num_classes=8),
        F.one_hot(torch.tensor([third_num]), num_classes=8),
        F.one_hot(torch.tensor([fourth_num]), num_classes=8),
    ), dim=1).to(torch.float)

def decode_move (x:torch.tensor):
    num_to_char = {
        0:"a",
        1:"b",
        2:"c",
        3:"d",
        4:"e",
        5:"f",
        6:"g",
        7:"h",
    }
    x = x.view(-1, 8)
    x = torch.argmax(x, dim=1).cpu().tolist()
    first_char = num_to_char[x[0]] 
    second_char = str(x[1]+1) 
    third_char = num_to_char[x[2]] 
    fourth_char = str(x[3]+1) 
    return chess.Move.from_uci(first_char+second_char+third_char+fourth_char)

def encode_board (board):
    x = 0
    y = 0
    return_tensor = torch.zeros(1,13,8,8)
    for char in board.__str__():
        if char == " ": continue
        if char == "r":   return_tensor[0][0][x][y] = 1
        elif char == "n": return_tensor[0][1][x][y] = 1
        elif char == "b": return_tensor[0][2][x][y] = 1
        elif char == "q": return_tensor[0][3][x][y] = 1
        elif char == "k": return_tensor[0][4][x][y] = 1
        elif char == "k": return_tensor[0][5][x][y] = 1
        elif char == "P": return_tensor[0][6][x][y] = 1
        elif char == "R": return_tensor[0][7][x][y] = 1
        elif char == "N": return_tensor[0][8][x][y] = 1
        elif char == "B": return_tensor[0][9][x][y] = 1
        elif char == "Q": return_tensor[0][10][x][y] = 1
        elif char == "K": return_tensor[0][11][x][y] = 1
        if char == "p":   return_tensor[0][12][x][y] = 1

        x += 1
        if char == "\n": 
            y += 1
            x = 0
    return return_tensor

class ChessClassificationDatabase(torch.utils.data.Dataset):
    def __init__(self, num_games):
        assert num_games > 0
        self.x = torch.tensor([]) 
        self.y = torch.tensor([])
        self.possible_moves = []

        len_lines = 3561469
        for i in range(num_games):
            try:
                line = linecache.getline(input_database, randint(0, len_lines)+6) 
                board = chess.Board()

                line = line.split("###")[1].strip()
                moves = line.split(" ")

                for move in moves:
                    move = move.split(".")[1]
                    tensor_board = encode_board(board)
                    actual_move = board.parse_san(move)
                    possible_move = torch.tensor([])

                    # Encode possible moves
                    legal_moves = board.legal_moves
                    for move in legal_moves:
                        move_enc = encode_move(move)
                        if possible_move.size(0) == 0: possible_move = move_enc
                        else: possible_move = torch.concat((possible_move, move_enc), dim=0)

                    # Append to variables
                    self.possible_moves.append(possible_move)
                    y_enc = list(legal_moves).index(actual_move)
                    y_enc = torch.tensor([[y_enc]])
                    if self.x.size(0) == 0:
                        self.x = tensor_board
                        self.y = y_enc 
                    else:
                        self.x = torch.vstack((self.x, tensor_board))
                        self.y = torch.vstack((self.y, y_enc))

                    board.push(actual_move)
            except:
                continue
        self.x = self.x.to(device)
        self.y = self.y.to(device)

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, index):
        return self.x[index].detach(), self.y[index].detach(), self.possible_moves[index].to(device).detach()

class PolicyNeuralNetwork (nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.neuralNet = nn.Sequential(
            nn.Conv2d(13, 16, 3),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3),
            nn.ReLU(),
            nn.Flatten(),
            nn.LazyLinear(128),
            nn.Tanh(),
            nn.LazyLinear(64),
        )

        d_model = 16

        weight_model = nn.Sequential(
            nn.Linear(65, 16),
            nn.Tanh(),
            nn.Linear(16, 1),
        ) 

        bias_model = nn.Sequential(
            nn.Linear(49, 12),
            nn.Tanh(),
            nn.Linear(12, 1),
        )

        self.model = VNNBlock(d_model, weight_model, bias_model)

    def forward (self, x, possible_moves):
        x = self.neuralNet(x)
        return self.model(x, possible_moves.size(1), possible_moves)

if __name__ == "__main__":
    # Optimizer and training parameters
    batch_size = 16
    validation_size = 64
    num_games_per_itr = 4
    lr = 0.001
    itr = 10_000

    criterion = nn.CrossEntropyLoss()
    model = PolicyNeuralNetwork().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # Training Loop
    progress_bar = trange(itr)
    itr = 0
    mag = 1
    for i in progress_bar:
        # Get data
        dataset = ChessClassificationDatabase(num_games=num_games_per_itr),

        # Train on that data
        opt.zero_grad()
        losses = torch.tensor([])
        for data in dataset:
            x = data[0][0].unsqueeze(0)
            y = data[0][1]
            possible_moves = data[0][2].unsqueeze(0)
            out = model(x, possible_moves) 
            data_loss = criterion(out, y)
            data_loss = data_loss.unsqueeze(0)
            if losses.size(0) == 0: losses = data_loss
            else: losses = torch.concat((losses, data_loss), 0)
        loss = torch.mean(losses)
        loss.backward()
        opt.step()

        if i % 100 == 0 and i != 0:  
            # Save Model
            torch.save(model, output_model)
            print("\nSaved Model")
            
        progress_bar.set_description(f"Loss: {loss.item():.4f}")

Loss: 0.9315:   1%|          | 101/10000 [01:39<2:21:01,  1.17it/s]


Saved Model


Loss: 0.7540:   2%|▏         | 201/10000 [03:11<2:26:21,  1.12it/s]


Saved Model


Loss: 3.3411:   3%|▎         | 301/10000 [04:40<2:18:10,  1.17it/s]


Saved Model


Loss: 0.3665:   4%|▍         | 401/10000 [06:12<2:29:35,  1.07it/s]


Saved Model


Loss: 4.0071:   5%|▌         | 501/10000 [07:41<2:08:39,  1.23it/s]


Saved Model


Loss: 2.9564:   6%|▌         | 601/10000 [09:12<2:46:38,  1.06s/it]


Saved Model


Loss: 0.9245:   7%|▋         | 701/10000 [10:47<2:10:44,  1.19it/s]


Saved Model


Loss: 0.7778:   8%|▊         | 801/10000 [12:17<2:11:11,  1.17it/s]


Saved Model


Loss: 0.5498:   9%|▉         | 901/10000 [13:48<2:41:31,  1.07s/it]


Saved Model


Loss: 1.0063:  10%|█         | 1001/10000 [15:17<2:18:58,  1.08it/s]


Saved Model


Loss: 0.7832:  11%|█         | 1101/10000 [16:49<2:03:18,  1.20it/s]


Saved Model


Loss: 0.6470:  12%|█▏        | 1201/10000 [18:22<2:33:35,  1.05s/it]


Saved Model


Loss: 6.4815:  13%|█▎        | 1301/10000 [19:53<2:20:03,  1.04it/s]


Saved Model


Loss: 0.7968:  14%|█▍        | 1401/10000 [21:23<1:53:56,  1.26it/s]


Saved Model


Loss: 1.1010:  15%|█▌        | 1501/10000 [22:54<2:11:26,  1.08it/s]


Saved Model


Loss: 1.9307:  16%|█▌        | 1601/10000 [24:26<2:04:20,  1.13it/s]


Saved Model


Loss: 0.6431:  17%|█▋        | 1701/10000 [25:58<2:09:54,  1.06it/s]


Saved Model


Loss: 1.5876:  18%|█▊        | 1801/10000 [27:26<1:50:06,  1.24it/s]


Saved Model


Loss: 1.0383:  19%|█▉        | 1901/10000 [28:55<2:02:58,  1.10it/s]


Saved Model


Loss: 2.6342:  20%|██        | 2001/10000 [30:27<2:08:42,  1.04it/s]


Saved Model


Loss: 0.5535:  21%|██        | 2101/10000 [31:57<2:20:11,  1.06s/it]


Saved Model


Loss: 0.8529:  22%|██▏       | 2201/10000 [33:29<2:03:30,  1.05it/s]


Saved Model


Loss: 1.1879:  23%|██▎       | 2301/10000 [35:04<1:59:24,  1.07it/s]


Saved Model


Loss: 1.4640:  24%|██▍       | 2401/10000 [36:36<2:11:40,  1.04s/it]


Saved Model


Loss: 0.5413:  25%|██▌       | 2501/10000 [38:07<2:09:08,  1.03s/it]


Saved Model


Loss: 0.7380:  26%|██▌       | 2601/10000 [39:35<1:47:02,  1.15it/s]


Saved Model


Loss: 2.2787:  27%|██▋       | 2701/10000 [41:08<1:56:51,  1.04it/s]


Saved Model


Loss: 0.5960:  28%|██▊       | 2801/10000 [42:40<1:54:45,  1.05it/s]


Saved Model


Loss: 0.7318:  29%|██▉       | 2901/10000 [44:09<1:55:22,  1.03it/s]


Saved Model


Loss: 2.8989:  30%|███       | 3001/10000 [45:43<1:47:22,  1.09it/s]


Saved Model


Loss: 2.4744:  31%|███       | 3101/10000 [47:12<1:57:51,  1.03s/it]


Saved Model


Loss: 0.9035:  32%|███▏      | 3201/10000 [48:43<1:55:41,  1.02s/it]


Saved Model


Loss: 0.7651:  33%|███▎      | 3301/10000 [50:15<1:38:48,  1.13it/s]


Saved Model


Loss: 0.5597:  34%|███▍      | 3401/10000 [51:45<1:33:03,  1.18it/s]


Saved Model


Loss: 0.5590:  35%|███▌      | 3501/10000 [53:18<1:30:40,  1.19it/s]


Saved Model


Loss: 1.2809:  36%|███▌      | 3601/10000 [54:49<1:25:32,  1.25it/s]


Saved Model


Loss: 0.5946:  37%|███▋      | 3701/10000 [56:15<1:25:50,  1.22it/s]


Saved Model


Loss: 0.7210:  38%|███▊      | 3801/10000 [57:46<1:25:57,  1.20it/s]


Saved Model


Loss: 0.6097:  39%|███▉      | 3901/10000 [59:17<1:31:37,  1.11it/s]


Saved Model


Loss: 0.7838:  40%|████      | 4001/10000 [1:00:47<1:17:23,  1.29it/s]


Saved Model


Loss: 1.0776:  41%|████      | 4101/10000 [1:02:16<1:27:25,  1.12it/s]


Saved Model


Loss: 1.0488:  42%|████▏     | 4201/10000 [1:03:47<1:23:10,  1.16it/s]


Saved Model


Loss: 0.6943:  43%|████▎     | 4301/10000 [1:05:16<1:35:33,  1.01s/it]


Saved Model


Loss: 0.6323:  44%|████▍     | 4401/10000 [1:06:49<1:23:17,  1.12it/s]


Saved Model


Loss: 0.7138:  45%|████▌     | 4501/10000 [1:08:18<1:34:34,  1.03s/it]


Saved Model


Loss: 1.3607:  46%|████▌     | 4601/10000 [1:09:46<1:09:23,  1.30it/s]


Saved Model


Loss: 1.2788:  47%|████▋     | 4701/10000 [1:11:16<1:25:04,  1.04it/s]


Saved Model


Loss: 0.7292:  48%|████▊     | 4801/10000 [1:12:47<1:19:25,  1.09it/s]


Saved Model


Loss: 0.6798:  49%|████▉     | 4901/10000 [1:14:16<1:25:57,  1.01s/it]


Saved Model


Loss: 0.6550:  50%|█████     | 5001/10000 [1:15:45<1:33:44,  1.13s/it]


Saved Model


Loss: 0.7064:  51%|█████     | 5101/10000 [1:17:13<1:23:23,  1.02s/it]


Saved Model


Loss: 1.2253:  52%|█████▏    | 5201/10000 [1:18:44<1:09:08,  1.16it/s]


Saved Model


Loss: 0.9010:  53%|█████▎    | 5301/10000 [1:20:13<1:07:01,  1.17it/s]


Saved Model


Loss: 0.5731:  54%|█████▍    | 5401/10000 [1:21:45<1:19:02,  1.03s/it]


Saved Model


Loss: 0.8846:  55%|█████▌    | 5501/10000 [1:23:14<1:16:28,  1.02s/it]


Saved Model


Loss: 0.7041:  56%|█████▌    | 5601/10000 [1:24:46<1:19:56,  1.09s/it]


Saved Model


Loss: 5.7265:  57%|█████▋    | 5701/10000 [1:26:18<1:17:45,  1.09s/it]


Saved Model


Loss: 1.1278:  58%|█████▊    | 5801/10000 [1:27:47<49:48,  1.40it/s]


Saved Model


Loss: 0.6132:  59%|█████▉    | 5901/10000 [1:29:17<1:04:46,  1.05it/s]


Saved Model


Loss: 0.8232:  60%|██████    | 6001/10000 [1:30:48<1:05:37,  1.02it/s]


Saved Model


Loss: 2.5898:  61%|██████    | 6101/10000 [1:32:16<1:03:36,  1.02it/s]


Saved Model


Loss: 0.7578:  62%|██████▏   | 6201/10000 [1:33:48<55:02,  1.15it/s]


Saved Model


Loss: 1.1521:  63%|██████▎   | 6301/10000 [1:35:20<1:07:38,  1.10s/it]


Saved Model


Loss: 0.8892:  64%|██████▍   | 6401/10000 [1:36:50<54:53,  1.09it/s]


Saved Model


Loss: 2.2623:  65%|██████▌   | 6501/10000 [1:38:21<58:18,  1.00it/s]


Saved Model


Loss: 0.6194:  66%|██████▌   | 6601/10000 [1:39:56<56:43,  1.00s/it]  


Saved Model


Loss: 0.6884:  67%|██████▋   | 6701/10000 [1:41:28<48:38,  1.13it/s]


Saved Model


Loss: 0.6825:  68%|██████▊   | 6801/10000 [1:42:59<44:57,  1.19it/s]


Saved Model


Loss: 2.7184:  69%|██████▉   | 6901/10000 [1:44:29<43:04,  1.20it/s]


Saved Model


Loss: 0.7195:  70%|███████   | 7001/10000 [1:45:59<40:11,  1.24it/s]


Saved Model


Loss: 2.1470:  71%|███████   | 7101/10000 [1:47:31<45:16,  1.07it/s]


Saved Model


Loss: 0.7246:  72%|███████▏  | 7201/10000 [1:49:04<39:54,  1.17it/s]


Saved Model


Loss: 0.7800:  73%|███████▎  | 7301/10000 [1:50:35<41:09,  1.09it/s]


Saved Model


Loss: 2.7129:  74%|███████▍  | 7401/10000 [1:52:05<39:45,  1.09it/s]


Saved Model


Loss: 1.1188:  75%|███████▌  | 7501/10000 [1:53:37<41:42,  1.00s/it]


Saved Model


Loss: 1.4031:  76%|███████▌  | 7601/10000 [1:55:08<35:43,  1.12it/s]


Saved Model


Loss: 0.5504:  77%|███████▋  | 7701/10000 [1:56:39<32:18,  1.19it/s]


Saved Model


Loss: 0.5732:  78%|███████▊  | 7801/10000 [1:58:10<31:49,  1.15it/s]


Saved Model


Loss: 1.3473:  79%|███████▉  | 7901/10000 [1:59:43<36:27,  1.04s/it]


Saved Model


Loss: 10.3791:  80%|████████  | 8001/10000 [2:01:13<30:43,  1.08it/s]


Saved Model


Loss: 2.9449:  81%|████████  | 8101/10000 [2:02:43<27:22,  1.16it/s]


Saved Model


Loss: 0.6112:  82%|████████▏ | 8201/10000 [2:04:14<28:04,  1.07it/s]


Saved Model


Loss: 1.2212:  83%|████████▎ | 8301/10000 [2:05:44<23:52,  1.19it/s]


Saved Model


Loss: 1.2318:  84%|████████▍ | 8401/10000 [2:07:15<26:07,  1.02it/s]


Saved Model


Loss: 2.6495:  85%|████████▌ | 8501/10000 [2:08:43<20:30,  1.22it/s]


Saved Model


Loss: 1.0485:  86%|████████▌ | 8601/10000 [2:10:13<23:01,  1.01it/s]


Saved Model


Loss: 0.5985:  87%|████████▋ | 8701/10000 [2:11:44<18:29,  1.17it/s]


Saved Model


Loss: 1.0597:  88%|████████▊ | 8801/10000 [2:13:12<19:02,  1.05it/s]


Saved Model


Loss: 0.5983:  89%|████████▉ | 8901/10000 [2:14:47<15:38,  1.17it/s]


Saved Model


Loss: 1.1701:  90%|█████████ | 9001/10000 [2:16:18<14:14,  1.17it/s]


Saved Model


Loss: 2.3157:  91%|█████████ | 9101/10000 [2:17:51<15:32,  1.04s/it]


Saved Model


Loss: 1.3100:  92%|█████████▏| 9201/10000 [2:19:24<13:02,  1.02it/s]


Saved Model


Loss: 2.8601:  93%|█████████▎| 9301/10000 [2:20:55<11:26,  1.02it/s]


Saved Model


Loss: 1.0815:  94%|█████████▍| 9401/10000 [2:22:25<07:38,  1.31it/s]


Saved Model


Loss: 0.6583:  95%|█████████▌| 9501/10000 [2:23:55<09:11,  1.10s/it]


Saved Model


Loss: 0.5963:  96%|█████████▌| 9601/10000 [2:25:28<06:24,  1.04it/s]


Saved Model


Loss: 0.6126:  97%|█████████▋| 9701/10000 [2:27:01<04:15,  1.17it/s]


Saved Model


Loss: 0.6234:  98%|█████████▊| 9801/10000 [2:28:33<02:47,  1.19it/s]


Saved Model


Loss: 2.3190:  99%|█████████▉| 9901/10000 [2:30:05<01:20,  1.23it/s]


Saved Model


Loss: 1.0452: 100%|██████████| 10000/10000 [2:31:39<00:00,  1.10it/s]
