In [None]:
import chess
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

import encoding_tools as EncodingTools
from model import ChessNet


MODE = "RELEASE"  # Set to "DEBUG" to build small train/validation subsets; any other value uses the full-size splits

In [None]:
from tqdm import tqdm
from FEN_to_chessboard import FenToChessBoard
from encoder_decoder import *

# Load the Stockfish-labelled positions with pandas.
# Each CSV is listed exactly once: reading the same file twice puts identical rows
# at both ends of the frame, i.e. in both the training head and the validation tail.
# NOTE(review): the release split sizes (270000 + 30000) suggest a third file may
# have been intended -- add its path here if it exists.
CSV_PATHS = [
    'stockfish_data/chess_games_1.csv',
    'stockfish_data/chess_games_2.csv',
]
df = pd.concat([pd.read_csv(csv_path) for csv_path in CSV_PATHS])

# The "Winner != 0" filter is currently disabled, so every loaded row is kept:
# non_zero_winners = df[df['Winner'] != 0]
non_zero_winners = df.reset_index(drop=True)
print("Game with an existed winner:", non_zero_winners.shape)

# Requested split sizes; DEBUG mode uses a small subset for quick iteration.
max_train_rows, max_val_rows = (5000, 1000) if MODE == "DEBUG" else (270000, 30000)
n_rows = len(non_zero_winners)
n_val = min(max_val_rows, n_rows)
# Cap the training slice so it can never reach into the validation tail.
n_train = min(max_train_rows, n_rows - n_val)
if n_train <= 0:
    raise ValueError(
        f"Not enough rows ({n_rows}) to build disjoint train/validation splits"
    )

# Training rows come from the head of the frame, validation rows from the tail.
train_df = non_zero_winners.iloc[:n_train].reset_index(drop=True)
val_df = non_zero_winners.iloc[n_rows - n_val:].reset_index(drop=True)


def _check_fen_column(split_df):
    """Raise ValueError for the first row whose 'FEN' value is not a string."""
    for idx, fen_str in split_df['FEN'].items():
        if not isinstance(fen_str, str):
            raise ValueError(f"Invalid FEN string at index {idx}: {fen_str}")


def _encode_positions(split_df):
    """Parse every FEN, apply encode_board, and stack the result as (-1, 22, 8, 8)."""
    boards = split_df['FEN'].apply(FenToChessBoard.fen_to_board)
    return np.stack(boards.apply(encode_board)).reshape(-1, 22, 8, 8)


def _embed_best_moves(split_df, desc):
    """Return an array holding move_on_board(board, BestMove) for every row."""
    embeddings = []
    fen_move_pairs = zip(split_df['FEN'], split_df['BestMove'])
    for fen_str, move_str in tqdm(fen_move_pairs, total=len(split_df), desc=desc):
        # A fresh board is parsed for each row rather than reusing the one that was
        # passed to encode_board, in case either helper mutates its board argument.
        board = FenToChessBoard.fen_to_board(fen_str)
        embeddings.append(move_on_board(board, move_str))
    return np.array(embeddings)


##### Package the training data

# Validate before any parsing so a bad row produces the explicit error message.
_check_fen_column(train_df)
X_train = _encode_positions(train_df)
print("Training set size:", X_train.shape)

train_best_move_embedding = _embed_best_moves(train_df, "[TrainSet] BestMove to embedding")
print("Embedded Best move shape:", train_best_move_embedding.shape)
y_train = {'best_move': train_best_move_embedding, 'winner': train_df['Winner']}


##### Package the validation data

_check_fen_column(val_df)
X_val = _encode_positions(val_df)
print("Validation set size:", X_val.shape)

val_best_move_embedding = _embed_best_moves(val_df, "[ValSet] BestMove to embedding")
# Parsed boards for the validation rows (not used in this cell; kept for later use).
X_val_cur_board = np.stack(val_df['FEN'].apply(FenToChessBoard.fen_to_board))
y_val = {'best_move': val_best_move_embedding, 'winner': val_df['Winner']}
print("Datasets are prepared, you are all set!")

In [None]:
# Build the network; it is moved onto the selected device just below.
model = ChessNet()

# Use CUDA when it is available, otherwise fall back to the CPU.
# (The Apple M-series "mps" backend is not selected here.)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

if device == "cuda":
    # Name of the currently active GPU.
    gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
    # Total memory of that GPU, converted from bytes to GB (divided by 1e9).
    props = torch.cuda.get_device_properties(device)
    total_memory = props.total_memory / 1e9

    # Report the GPU model and its total memory.
    print("当前 GPU 型号是：{}，可用总显存为：{} GB".format(gpu_name, total_memory))

In [None]:
from model import ChessDataset

# Wrap the prepared arrays/targets in Dataset objects for training and validation.
train_dataset = ChessDataset(X_train, y_train)
val_dataset = ChessDataset(X_val, y_val)

# Create DataLoaders for batching.
batch_size = 3000
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation is not shuffled: the per-batch loss average then covers the same
# batches every epoch, so validation numbers are comparable between epochs.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Sanity check: fetch a single training batch and print the tensor shapes.
for X_batch, y_batch in train_loader:
    best_move_batch, winner_batch = y_batch  # y_batch unpacks into (best_move, winner)
    print("X_batch shape:", X_batch.shape)
    print("best_move shape:", best_move_batch.shape)
    print("winner shape:", winner_batch.shape)
    break  # one batch is enough

In [None]:
import torch.optim as optim
from train import AlphaLoss

n_epochs = 40
learning_rate = 0.0001  # lowered from 0.003
train_losses, val_losses = [], []
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = AlphaLoss().to(device)


def _batch_loss(X_batch, y_batch):
    """Move one batch to `device`, run the model, and return its AlphaLoss.

    `y_batch[0]` is used as the policy target and `y_batch[1]` as the value
    target; the value target is reshaped to a column (-1, 1) before the loss.
    """
    X_batch = X_batch.to(device)
    y_p = y_batch[0].to(device)
    y_v = y_batch[1].to(device).reshape(-1, 1)
    predictions = model(X_batch)
    return loss_fn(y_v, predictions['v'], y_p, predictions['p'])


for epoch in range(n_epochs):
    # ---- Training pass ----
    model.train()  # enable training-mode behaviour
    epoch_train_loss = 0
    for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{n_epochs} [Training]"):
        optimizer.zero_grad()
        loss = _batch_loss(X_batch, y_batch)
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item()

    # Mean of the per-batch training losses for this epoch.
    train_losses.append(epoch_train_loss / len(train_loader))
    print(f"Epoch {epoch + 1}: Training Loss = {train_losses[-1]}")

    # ---- Validation pass ----
    model.eval()  # switch to evaluation-mode behaviour
    epoch_val_loss = 0
    with torch.no_grad():  # no gradients are needed while validating
        for X_batch, y_batch in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{n_epochs} [Validating]"):
            epoch_val_loss += _batch_loss(X_batch, y_batch).item()

    # Mean of the per-batch validation losses for this epoch.
    val_losses.append(epoch_val_loss / len(val_loader))
    print(f"Epoch {epoch + 1}: Validation Loss = {val_losses[-1]}")

In [None]:
# Plot the per-epoch training and validation loss curves on one labelled axes.
plt.style.use('ggplot')
fig, ax = plt.subplots()
ax.plot(train_losses, label='Train Loss')
ax.plot(val_losses, label='Validation Loss')
# One point is appended per epoch, so the x axis is the (0-based) epoch index.
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Loss During Training')
ax.legend()
plt.show()