In [11]:
import torch
import chess
import torch.nn as nn 
import torch.optim as optim 
import numpy as np
from torch.utils.data import DataLoader
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
from lmdb_dataset import LMDBDataset

# Dataset built from the path below (presumably an LMDB store - confirm in lmdb_dataset).
dataset = LMDBDataset("../../data/lmdb/")

# Shuffled mini-batches of 128 samples, loaded in the main process
# (num_workers=0) and placed in pinned host memory.
loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

In [13]:
# Switch torch to the "full" print profile so the tensor below is printed
# without being summarized. This is a global setting and stays active until
# it is changed again.
torch.set_printoptions(profile="full")

# Fetch the pair stored at index 5 and show its first element
# (used as the model input in the training loop).
x, y = dataset[5]
print(x)

tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 0., 1., 1., 1., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
       

In [14]:
# Second element of the pair taken from dataset[5]; the training loop passes
# the batched version of this value to the loss as the target.
print(y)

# Go back to torch's default (summarizing) print profile.
torch.set_printoptions(profile="default")

tensor(4505)


In [15]:
# Report how many batches make up one epoch and how many samples they cover.
# The sample count is taken from the dataset itself: len(loader) * batch_size
# overcounts whenever the last batch is only partially filled.
print("Ilość batchy:", len(loader))
print("ilość pozycji: ", len(loader.dataset))

Ilość batchy: 18360
ilość pozycji:  2350080


In [16]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

cuda


In [17]:
import torch
from torch import nn, optim
from model import ChessPolicyNet

# Upper bound of the epoch loop; note that the resume cell further down
# assigns `epochs` again before training starts.
epochs = 12

# Network instance, moved to the device selected earlier in the notebook.
model = ChessPolicyNet().to(device)

# Cross-entropy loss on the model outputs, optimized with Adam (lr = 3e-4).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)

# Gradient scaler used together with autocast in the training loop.
scaler = torch.amp.GradScaler()

In [None]:
def load_checkpoint(model, optimizer, scaler, path="../../models/policy_network/DeltaChess.pt", map_location=None):
    """Restore training state from a checkpoint file and return the epoch to resume at.

    The checkpoint is expected to be a dict with the keys "model_state" and
    "optimizer_state", and optionally "scaler_state" and "epoch" (the same
    layout the training loop in this notebook writes with ``torch.save``).

    Args:
        model: Module whose ``load_state_dict`` receives ``checkpoint["model_state"]``.
        optimizer: Optimizer whose ``load_state_dict`` receives
            ``checkpoint["optimizer_state"]``.
        scaler: Gradient scaler to restore from ``checkpoint["scaler_state"]``
            when that key is present. May be ``None`` to skip restoring it.
        path: Location of the checkpoint file.
        map_location: Forwarded to ``torch.load``. When ``None`` (default) the
            device of the model's first parameter is used, so the function no
            longer depends on a notebook-global ``device`` variable.

    Returns:
        The stored "epoch" value plus one, i.e. the index of the next epoch to
        run. If the checkpoint has no "epoch" key, 1 is returned.

    Raises:
        KeyError: If "model_state" or "optimizer_state" is missing.
    """
    if map_location is None:
        # Load tensors onto the device the model already lives on; a model
        # without parameters falls back to the CPU.
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else torch.device("cpu")

    checkpoint = torch.load(path, map_location=map_location)

    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])

    # Older checkpoints may not contain a scaler state, and callers training
    # without mixed precision may pass scaler=None.
    if scaler is not None and "scaler_state" in checkpoint:
        scaler.load_state_dict(checkpoint["scaler_state"])
        print("Wczytano scaler")

    # The stored value is the index of the last finished epoch, so training
    # resumes one past it.
    start_epoch = checkpoint.get("epoch", 0) + 1
    print(f"Wczytano checkpoint z epoki {start_epoch-1}")

    return start_epoch

In [45]:
# To train from scratch instead of resuming from a checkpoint, use:
# start_epoch = 0
epochs = 10
start_epoch = load_checkpoint(model, optimizer, scaler)


Wczytano scaler
Wczytano checkpoint z epoki 6


In [None]:
# Train from `start_epoch` up to (but not including) `epochs`, writing one
# checkpoint file at the end of every epoch.
for epoch in range(start_epoch, epochs):

    model.train()
    running_loss = 0.0

    for batch_idx, (batch_x, batch_y) in enumerate(loader):
        # The loader is created with pin_memory=True, so the host-to-device
        # copies can be issued asynchronously. On CPU the flag has no effect.
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Follow the selected device instead of hard-coding "cuda": mixed
        # precision is active on CUDA and cleanly disabled on the CPU
        # fallback, which previously only worked via a runtime warning.
        with torch.amp.autocast(device_type=device.type, enabled=device.type == "cuda"):
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)

        # Scale the loss before backward, then let the scaler perform the
        # optimizer step and adjust its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

        # Report the mean loss over the last 100 steps, then start a new window.
        if (batch_idx + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{len(loader)}], Loss: {running_loss/100:.4f}")
            running_loss = 0.0

    # "epoch" stores the 0-based index of the epoch that just finished;
    # load_checkpoint resumes at this value plus one.
    torch.save({
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scaler_state": scaler.state_dict(),
        "epoch": epoch
    }, f"chess_policy_epoch{epoch+1}_v4.pt")

Epoch [7/10], Step [100/18360], Loss: 1.2670
Epoch [7/10], Step [200/18360], Loss: 1.2547
Epoch [7/10], Step [300/18360], Loss: 1.2534
Epoch [7/10], Step [400/18360], Loss: 1.2707
Epoch [7/10], Step [500/18360], Loss: 1.2497
Epoch [7/10], Step [600/18360], Loss: 1.2776
Epoch [7/10], Step [700/18360], Loss: 1.2504
Epoch [7/10], Step [800/18360], Loss: 1.2570
Epoch [7/10], Step [900/18360], Loss: 1.2700
Epoch [7/10], Step [1000/18360], Loss: 1.2531
Epoch [7/10], Step [1100/18360], Loss: 1.2758
Epoch [7/10], Step [1200/18360], Loss: 1.2650
Epoch [7/10], Step [1300/18360], Loss: 1.2625
Epoch [7/10], Step [1400/18360], Loss: 1.2619
Epoch [7/10], Step [1500/18360], Loss: 1.2771
Epoch [7/10], Step [1600/18360], Loss: 1.2755
Epoch [7/10], Step [1700/18360], Loss: 1.2718
Epoch [7/10], Step [1800/18360], Loss: 1.2853
Epoch [7/10], Step [1900/18360], Loss: 1.2739
Epoch [7/10], Step [2000/18360], Loss: 1.2942
Epoch [7/10], Step [2100/18360], Loss: 1.2893
Epoch [7/10], Step [2200/18360], Loss: 1.27

<table>
  <tr>
    <th>Name (.pt)</th>
    <th style=text-align:center>Model</th>
    <th>Epochs</th>
    <th>Precision</th>
    <th>Learning rate</th>
    <th>Positions</th>
    <th>Loss</th>
    <th>Best accuracy</th>
  </tr>
  <tr>
    <td align=center>AlphaChess</td>
    <td style=text-align:left>ChessPolicyNet</td>
    <td align=center>5</td>
    <td align=center>float32</td>
    <td align=center>0,0003</td>
    <td align=center>780 032</td>
    <td align=center>1.5345</td>
    <td align=center>79.17%</td>
  </tr>
  <tr>
    <td align=center>BetaChess</td>
    <td style=text-align:left>ChessPolicyNet</td>
    <td align=center>6</td>
    <td align=center>float32</td>
    <td align=center>0,0003</td>
    <td align=center>1 560 064</td>
    <td align=center>1.4802</td>
    <td align=center>75.00%</td>
  </tr>
  <tr>
    <td align=center>GammaChess</td>
    <td style=text-align:left>ChessPolicyNet</td>
    <td align=center>12</td>
    <td align=center>float16</td>
    <td align=center>0,0003</td>
    <td align=center>2 350 080</td>
    <td align=center>1.0552</td>
    <td align=center>66.67%</td>
  </tr>
  <tr>
    <td align=center>DeltaChess</td>
    <td style=text-align:left>ChessPolicyNet</td>
    <td align=center>10</td>
    <td align=center>float32</td>
    <td align=center>0,0003</td>
    <td align=center>2 350 080</td>
    <td align=center>1.2194</td>
    <td align=center>75.00%</td>
  </tr>
</table>
