In [1]:
# Install dependencies: run this in your environment (e.g. a notebook cell)
!pip install python-chess

import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import chess

# Run everything on the CPU.
device = torch.device("cpu")

# Board geometry and vocabulary sizes
BOARD_SIZE = 8
NUM_SQUARES = BOARD_SIZE ** 2  # 64 squares
ACTION_SPACE_SIZE = NUM_SQUARES ** 2  # 4096 (from-square, to-square) pairs
NUM_PIECE_TYPES = 13  # 0 = empty, 1-6 white pawn..king, 7-12 black pawn..king
NUM_TOKENS = NUM_PIECE_TYPES + 1  # piece ids 0-12 plus id 13 for the class token

def board_to_tensor(board):
    """
    Encode a python-chess board as a flat list of NUM_SQUARES integer token ids.

    The list is indexed by square number. Ids: 0 for an empty square,
    1-6 for white pieces (the piece_type itself), 7-12 for black pieces
    (piece_type + 6).
    """
    def token_for(square):
        occupant = board.piece_at(square)
        if occupant is None:
            return 0
        # piece_type is 1..6; black pieces are shifted into 7..12
        return occupant.piece_type if occupant.color else occupant.piece_type + 6

    tokens = [0] * NUM_SQUARES
    for square in chess.SQUARES:
        tokens[square] = token_for(square)
    return tokens

class ChessEnv:
    """
    Thin wrapper around a python-chess board that accepts integer-encoded moves.
    The environment state is the board object itself.
    """

    def __init__(self):
        self.board = chess.Board()

    def reset(self):
        """Put the board back to the starting position and return it."""
        self.board.reset()
        return self.board

    def step(self, action_idx):
        """
        Decode ``action_idx`` as ``from_square * NUM_SQUARES + to_square`` and play it.

        A pawn move landing on the mover's last rank is promoted to a queen.
        If the decoded move is not legal, a randomly chosen legal move is played
        instead; if there are no legal moves the board is left untouched.

        Returns (board, done_flag) where done_flag is ``board.is_game_over()``.
        """
        from_sq, to_sq = divmod(action_idx, NUM_SQUARES)

        # Auto-queen: rank 7 is the last rank for white, rank 0 for black.
        last_rank_by_side = {chess.WHITE: 7, chess.BLACK: 0}
        promotion = None
        if self.board.piece_type_at(from_sq) == chess.PAWN:
            if to_sq // 8 == last_rank_by_side.get(self.board.turn):
                promotion = chess.QUEEN

        chosen = chess.Move(from_sq, to_sq, promotion=promotion)
        if chosen not in self.board.legal_moves:
            # Illegal request: substitute a random legal move when one exists.
            candidates = list(self.board.legal_moves)
            chosen = random.choice(candidates) if len(candidates) > 0 else None

        if chosen is not None:
            self.board.push(chosen)

        return self.board, self.board.is_game_over()

class TransformerChessAgent(nn.Module):
    """
    Transformer-based policy and value network for chess.

    Input: a batch of board states as token ids; a CLS token is prepended
    internally and its encoder output feeds both heads.
    Outputs: policy logits over ACTION_SPACE_SIZE actions and a scalar value.

    Args:
        d_model: embedding / hidden width of the encoder.
        nhead: number of attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        dropout: dropout probability inside the encoder layers.
        dim_feedforward: width of each encoder layer's feed-forward sublayer.
            The default (256) matches the previously hard-coded value, so
            existing checkpoints keep loading.
    """
    def __init__(self, d_model=128, nhead=8, num_layers=4, dropout=0.1, dim_feedforward=256):
        super().__init__()
        self.d_model = d_model
        # Token embedding: NUM_TOKENS ids (0-12 pieces + 13 CLS)
        self.token_embed = nn.Embedding(NUM_TOKENS, d_model)
        # Positional embedding: position 0 for CLS, 1..NUM_SQUARES for the squares
        self.pos_embed = nn.Embedding(NUM_SQUARES + 1, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.policy_head = nn.Linear(d_model, ACTION_SPACE_SIZE)
        self.value_head = nn.Linear(d_model, 1)

    def forward(self, x):
        """
        Args:
            x: long tensor of shape (batch_size, NUM_SQUARES) holding board tokens 0-12.

        Returns:
            (policy_logits, value) with shapes (batch_size, ACTION_SPACE_SIZE)
            and (batch_size,).
        """
        batch_size = x.size(0)
        # Prepend the CLS token (id = NUM_PIECE_TYPES = 13) to every sequence.
        cls_tokens = torch.full((batch_size, 1), NUM_PIECE_TYPES, dtype=torch.long, device=x.device)
        tokens = torch.cat([cls_tokens, x], dim=1)  # (batch, NUM_SQUARES + 1)
        # Position ids are identical for every row, so one (seq_len,) vector is
        # enough; its embedding broadcasts over the batch dimension.
        pos_ids = torch.arange(NUM_SQUARES + 1, device=x.device)
        x_emb = self.token_embed(tokens) + self.pos_embed(pos_ids)  # (batch, seq, d_model)
        # The encoder is built without batch_first, so it expects (seq, batch, d_model).
        x_emb = x_emb.permute(1, 0, 2)
        x_trans = self.transformer(x_emb)  # (seq, batch, d_model)
        cls_out = x_trans[0]  # (batch, d_model): encoder output at the CLS position
        policy_logits = self.policy_head(cls_out)  # (batch, ACTION_SPACE_SIZE)
        value = self.value_head(cls_out).squeeze(-1)  # (batch,)
        return policy_logits, value

def select_move(model, board):
    """
    Sample a move for the side to move on ``board`` from the model's policy.

    Only legal moves can be sampled; among promotions only queen promotions are
    considered (the action encoding from_square * NUM_SQUARES + to_square has no
    room for the promotion piece).

    Args:
        model: callable mapping a (1, 64) long tensor to (policy_logits, value).
        board: python-chess board.

    Returns:
        (action_index, log_prob, value), or (None, None, None) when there is no
        legal move. ``log_prob`` has shape (1,) and ``value`` is a 0-dim tensor;
        both stay attached to the autograd graph.
    """
    # Collect legal actions first so no forward pass is spent on a finished game.
    legal_indices = [
        move.from_square * NUM_SQUARES + move.to_square
        for move in board.legal_moves
        if move.promotion is None or move.promotion == chess.QUEEN
    ]
    if not legal_indices:
        return None, None, None  # no moves available (game over)

    state_tensor = torch.tensor([board_to_tensor(board)], dtype=torch.long, device=device)  # (1, 64)
    logits, value = model(state_tensor)  # (1, ACTION_SPACE_SIZE), (1,)
    value = value.squeeze(0)

    legal_mask = torch.zeros(ACTION_SPACE_SIZE, dtype=torch.bool, device=logits.device)
    legal_mask[legal_indices] = True
    # Push illegal actions to the most negative representable logit so they get
    # zero probability after normalisation.
    masked_logits = logits.masked_fill(~legal_mask, torch.finfo(logits.dtype).min)
    # Build the distribution straight from logits: going softmax -> probs ->
    # Categorical(probs) re-derives log-probs from clamped probabilities, which
    # distorts the log-prob (and its gradient) of near-certain or very unlikely moves.
    policy = Categorical(logits=masked_logits)
    action = policy.sample()
    log_prob = policy.log_prob(action)
    return action.item(), log_prob, value

def play_game(model, env):
    """
    Run one self-play game in which ``model`` chooses the moves for both sides.

    Returns three parallel lists with one entry per move played:
    log_probs, values, and rewards. Rewards are +1.0 for moves made by the
    side that delivered checkmate, -1.0 for the other side's moves, and 0.0
    for every move when the game did not end in checkmate.
    """
    board = env.reset()
    log_probs, values, movers = [], [], []
    finished = False
    while not finished:
        action, log_prob, value = select_move(model, board)
        if action is None:
            break
        log_probs.append(log_prob)
        values.append(value)
        movers.append(board.turn)  # who is about to move
        board, finished = env.step(action)

    # On checkmate the side to move is the one that has been mated,
    # so the winner is the opposite colour. Anything else counts as a draw.
    winner = (not board.turn) if env.board.is_checkmate() else None

    if winner is None:
        rewards = [0.0] * len(movers)
    else:
        rewards = [1.0 if mover == winner else -1.0 for mover in movers]
    return log_probs, values, rewards

def train_model(num_games=1000, learning_rate=1e-4, save_interval=100, checkpoint_path='chess_agent.pth'):
    """
    Self-play training loop (one gradient step per finished game).

    Args:
        num_games: number of self-play games to run.
        learning_rate: Adam learning rate.
        save_interval: write a checkpoint every this many games; 0 or None
            disables checkpointing.
        checkpoint_path: file to resume from (if it exists) and to save to.

    Returns:
        The trained model.
    """
    model = TransformerChessAgent().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Load from checkpoint if exists
    try:
        # map_location lets a checkpoint written on another device be restored here.
        # NOTE: torch.load unpickles the file - only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        print(f"Loaded checkpoint from {checkpoint_path}")
    except FileNotFoundError:
        print("No checkpoint found, starting fresh training.")
    model.train()
    env = ChessEnv()
    for game in range(1, num_games + 1):
        log_probs, values, rewards = play_game(model, env)
        if not log_probs:
            continue  # skip if no moves

        # Stack the per-move tensors once instead of accumulating the loss
        # move by move in Python; reshape(1) accepts both 0-dim and (1,) inputs.
        log_prob_t = torch.cat([lp.reshape(1) for lp in log_probs])
        value_t = torch.cat([v.reshape(1) for v in values])
        reward_t = torch.tensor(rewards, dtype=value_t.dtype, device=value_t.device)

        # The advantage is a constant weight for the policy gradient, so the
        # value estimate is detached there; the value head learns from value_loss.
        advantage_t = reward_t - value_t.detach()
        policy_loss = -(log_prob_t * advantage_t).sum()
        value_loss = 0.5 * (value_t - reward_t).pow(2).sum()
        loss = policy_loss + value_loss

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Logging
        if game % 10 == 0:
            print(f"Game {game}: Loss={loss.item():.4f}")
        # Save checkpoint (guarded so save_interval=0/None does not divide by zero)
        if save_interval and game % save_interval == 0:
            torch.save({'model_state': model.state_dict(),
                        'optimizer_state': optimizer.state_dict()},
                       checkpoint_path)
            print(f"Checkpoint saved at game {game}")
    return model

def _action_to_move(board, action):
    """Decode a from_square * NUM_SQUARES + to_square index into a chess.Move, auto-queening promotions."""
    from_sq, to_sq = divmod(action, NUM_SQUARES)
    promotion = None
    # A pawn arriving on rank 0 or 7 must promote; select_move only offers the
    # queen promotion, so the move has to carry promotion=QUEEN to be legal.
    if board.piece_type_at(from_sq) == chess.PAWN and chess.square_rank(to_sq) in (0, 7):
        promotion = chess.QUEEN
    return chess.Move(from_sq, to_sq, promotion=promotion)


def play_against_human(model):
    """
    Let a human play against the trained model in terminal.
    Human plays White, model plays Black.

    Human moves are read in UCI notation and re-prompted until legal. The model
    is switched to eval mode (dropout off) and queried without gradient tracking
    for the duration of the game; its previous train/eval mode is restored
    afterwards. If the model's decoded move is not legal, a random legal move is
    played instead.
    """
    board = chess.Board()
    print("Starting a new game. You are White.")
    was_training = model.training
    model.eval()
    try:
        while not board.is_game_over():
            print(board)
            human_move = input("Your move (in UCI, e.g. e2e4): ")
            try:
                move = chess.Move.from_uci(human_move.strip())
            except ValueError:
                # from_uci signals malformed input with a ValueError subclass;
                # anything else (e.g. KeyboardInterrupt) should propagate.
                print("Invalid move format. Try again.")
                continue
            if move not in board.legal_moves:
                print("Illegal move. Try again.")
                continue
            board.push(move)
            if board.is_game_over():
                break
            with torch.no_grad():
                action, _, _ = select_move(model, board)
            if action is None:
                print("Model has no moves. Game over.")
                break
            move = _action_to_move(board, action)
            if move in board.legal_moves:
                board.push(move)
                print(f"Model plays: {move}")
            else:
                legal_moves = list(board.legal_moves)
                if legal_moves:
                    move = random.choice(legal_moves)
                    board.push(move)
                    print(f"Model plays random: {move}")
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)
    print(board)
    result = board.result()
    print("Game over. Result:", result)

# Example usage:
# model = train_model(num_games=1000)
# play_against_human(model)


Collecting python-chess
  Downloading python_chess-1.999-py3-none-any.whl.metadata (776 bytes)
Collecting chess<2,>=1 (from python-chess)
  Downloading chess-1.11.2.tar.gz (6.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m30.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading python_chess-1.999-py3-none-any.whl (1.4 kB)
Building wheels for collected packages: chess
  Building wheel for chess (setup.py) ... [?25l[?25hdone
  Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147775 sha256=2b1a0f1b11951dfacc27ce5f07ace0ac25b7e0412312869d8f9dd90a63d77b31
  Stored in directory: /root/.cache/pip/wheels/83/1f/4e/8f4300f7dd554eb8de70ddfed96e94d3d030ace10c5b53d447
Successfully built chess
Installing collected packages: chess, python-chess
Successfully installed chess-1.11.2 python-chess-1.999


In [None]:
# Long self-play training run (10,000 games). train_model resumes from its
# default checkpoint file ('chess_agent.pth') when that file exists.
model = train_model(num_games=10000)



No checkpoint found, starting fresh training.
Game 10: Loss=-140.7212
Game 20: Loss=58.5411
Game 30: Loss=77.6480
Game 40: Loss=-40.2849
Game 50: Loss=187.5981
Game 60: Loss=107.1049
Game 70: Loss=0.8397
Game 80: Loss=79.8426
Game 90: Loss=13.1001
Game 100: Loss=6.4393
Checkpoint saved at game 100
Game 110: Loss=11.6769
Game 120: Loss=-1.2644
Game 130: Loss=23.2603
Game 140: Loss=-22.3618
Game 150: Loss=-10.7768
Game 160: Loss=-10.1002
Game 170: Loss=56.0475
Game 180: Loss=3.3601
Game 190: Loss=-16.4753
Game 200: Loss=116.8346
Checkpoint saved at game 200
Game 210: Loss=82.9838
Game 220: Loss=-2.2993
Game 230: Loss=-17.5820
Game 240: Loss=152.9500
Game 250: Loss=-6.7169
Game 260: Loss=-10.7660
Game 270: Loss=-1.8296
Game 280: Loss=-6.4584
Game 290: Loss=1.0208
Game 300: Loss=-11.6201
Checkpoint saved at game 300
Game 310: Loss=-3.7800
Game 320: Loss=-5.8576
