In [1]:
from torch.utils.data import DataLoader
from dataset import HDF5ChessDataset
import torch.nn as nn
import torch.optim as optim
from model import ChessBotCNN
import torch
from processing import preprocess_and_save_hdf5
import h5py


In [None]:
# Input parquet file and the HDF5 file that the dataset cell opens afterwards.
parquet_file, hdf5_file = 'chess_data.parquet', 'preprocessed_data.h5'

# Run the project's preprocessing helper on the parquet file with a chunk size
# of 10,000 (presumably it writes its result to `hdf5_file` — confirm in processing.py).
preprocess_and_save_hdf5(parquet_file, hdf5_file, chunk_size=10_000)

In [None]:
# Pick the compute device first so the loader settings can depend on it:
# Apple's MPS backend when it is available, otherwise the CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

# Initialize the dataset
dataset = HDF5ChessDataset(hdf5_file)
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,  # Use 0 to avoid HDF5 multi-threading issues
    # Page-locked host memory only speeds up host-to-CUDA copies. On MPS/CPU
    # requesting it has no benefit and makes the DataLoader emit a UserWarning,
    # so only turn it on when the selected device is a CUDA device.
    pin_memory=(device.type == 'cuda'),
)

In [None]:
# Training-loop settings
num_epochs = 10
log_interval = 100  # the training loop prints losses every `log_interval` batches

# Build the network, place it on the selected device and enable training mode
model = ChessBotCNN().to(device)
model.train()

# Objectives: mean squared error for the value output, and KL divergence
# (summed, then divided by the batch size) for the policy output
value_loss_fn = nn.MSELoss()
policy_loss_fn = nn.KLDivLoss(reduction='batchmean')

# Adam over every model parameter
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Train for `num_epochs` passes over the dataloader, optimizing the sum of the
# policy loss and the value loss, and printing both every `log_interval` batches.
for epoch in range(1, num_epochs + 1):
    model.train()  # make sure training mode is active at the start of each epoch
    for batch_idx, (board_tensor, policy_tensor, value_tensor) in enumerate(dataloader):
        # Move the batch onto the same device as the model
        board_tensor = board_tensor.to(device, non_blocking=True)
        policy_tensor = policy_tensor.to(device, non_blocking=True)
        value_tensor = value_tensor.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Forward pass
        policy_pred, value_pred = model(board_tensor)

        # Compute losses
        # NOTE(review): nn.KLDivLoss expects its input as log-probabilities —
        # confirm that the model's policy output is in log space.
        policy_loss = policy_loss_fn(policy_pred, policy_tensor)
        # Give the prediction exactly the target's shape. A bare .squeeze()
        # drops *every* size-1 dimension (including a batch dimension of 1 on
        # the last partial batch), and a leftover shape mismatch would be
        # silently broadcast inside MSELoss instead of raising.
        value_loss = value_loss_fn(value_pred.reshape(value_tensor.shape), value_tensor)

        # Total loss
        loss = policy_loss + value_loss

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], "
                  f"Policy Loss: {policy_loss.item():.4f}, Value Loss: {value_loss.item():.4f}")


In [None]:
# Training is finished: release the dataset (presumably this closes the
# underlying HDF5 file — confirm in dataset.py).
dataset.close()

model_name = "mark5.1-MCTS"

# Save the updated model
torch.save(model.state_dict(), f"{model_name}.pth")

model.eval()
# Create an example input on the same device as the model's weights. The model
# was moved to `device` earlier, so a CPU tensor here would make the traced
# forward pass fail with a device mismatch whenever MPS is in use.
model_device = next(model.parameters()).device
example_input = torch.randn(1, 13, 8, 8, device=model_device)
# Trace the model
traced_script_module = torch.jit.trace(model, example_input)
# Save the traced model
traced_script_module.save(f"{model_name}.pt")