In [None]:
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import chess
import chess.engine
import pickle
from torch.utils.data import DataLoader, TensorDataset

# ChessNet Model 
class ChessNet(nn.Module):
    """Convolutional policy/value network over a 12-plane 8x8 board encoding.

    Three 3x3 convolutions (12 -> 128 -> 256 -> 256 channels, padding 1 so the
    8x8 spatial size is kept) feed a shared 1024-unit fully connected layer,
    which is followed by two heads:

    * ``policy_head``: 4672 raw (unnormalised) move logits;
    * ``value_head``: one scalar squashed into [-1, 1] by ``tanh``.

    NOTE(review): the callers in this notebook encode moves as
    ``from_square * 64 + to_square`` (< 4096), so only part of the 4672
    policy outputs is ever indexed.
    """

    def __init__(self):
        super().__init__()
        # Convolutional trunk; padding=1 preserves the 8x8 board size.
        self.conv1 = nn.Conv2d(12, 128, 3, padding=1)
        self.conv2 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv3 = nn.Conv2d(256, 256, 3, padding=1)
        # Flattened 256x8x8 feature map -> shared hidden layer.
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.policy_head = nn.Linear(1024, 4672)
        self.value_head = nn.Linear(1024, 1)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of encoded boards.

        After the convolutions the features are reshaped to
        ``(-1, 256*8*8)``, so ``x`` is expected to be ``(batch, 12, 8, 8)``.
        ``policy_logits`` is ``(batch, 4672)`` and ``value`` is ``(batch, 1)``.
        """
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        hidden = F.relu(self.fc1(features.view(-1, 256 * 8 * 8)))
        return self.policy_head(hidden), torch.tanh(self.value_head(hidden))

In [None]:
# Board to Tensor 
def board_to_tensor(board):
    """Encode a board as a ``(12, 8, 8)`` float32 one-hot tensor.

    Plane order is P, N, B, R, Q, K (planes 0-5) followed by
    p, n, b, r, q, k (planes 6-11), keyed on ``piece.symbol()``. Within a
    plane, ``square // 8`` selects the row and ``square % 8`` the column. A
    cell is 1.0 where that piece stands and 0.0 elsewhere.

    Only ``board.piece_map()`` is read, so side-to-move, castling rights and
    en-passant state are not part of the encoding.
    """
    plane_of = {symbol: plane for plane, symbol in enumerate("PNBRQKpnbrqk")}
    planes = np.zeros((12, 8, 8), dtype=np.float32)
    for square, piece in board.piece_map().items():
        row_idx, col_idx = square // 8, square % 8
        planes[plane_of[piece.symbol()], row_idx, col_idx] = 1
    return torch.tensor(planes)

In [None]:
# Game Generation 
def generate_stockfish_games(num_games=100, max_moves=40, depth=4, stockfish_path="/usr/games/stockfish", save_dir=Path("games")):
    """Generate (or load cached) Stockfish-vs-Stockfish positions for pretraining.

    Stockfish plays both sides of ``num_games`` games from the initial
    position, each cut off after at most ``max_moves`` plies. For every ply,
    the pair ``(board_to_tensor(board), from_square * 64 + to_square)`` is
    recorded for the move Stockfish chose.

    The result is pickled to ``save_dir / f"games_{num_games}_d{depth}.pkl"``.
    If that file already exists, it is loaded and returned instead of
    generating new games.

    NOTE(review): the cache file name does not include ``max_moves``, so a
    call with a different ``max_moves`` silently reuses an older cache.
    Delete the file to regenerate.

    Args:
        num_games: number of games to play.
        max_moves: maximum plies per game.
        depth: Stockfish search depth per move.
        stockfish_path: path to the UCI engine binary.
        save_dir: cache directory (``Path`` or ``str``); created if missing.

    Returns:
        A list of ``(state_tensor, move_index)`` tuples.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    game_file = save_dir / f"games_{num_games}_d{depth}.pkl"

    if game_file.exists():
        print("Loading existing game data...")
        # SECURITY: pickle.load can execute arbitrary code. Only load cache
        # files this notebook wrote itself, never ones from untrusted sources.
        with open(game_file, "rb") as f:
            return pickle.load(f)

    print("Generating new games...")
    engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
    data = []
    try:
        for _ in range(num_games):
            board = chess.Board()
            for _ in range(max_moves):
                if board.is_game_over():
                    break
                result = engine.play(board, chess.engine.Limit(depth=depth))
                move = result.move
                state = board_to_tensor(board)
                move_idx = move.from_square * 64 + move.to_square
                data.append((state, move_idx))
                board.push(move)
    finally:
        # Always shut the engine subprocess down, even if a game raises or
        # the cell is interrupted; otherwise the process is leaked.
        engine.quit()

    with open(game_file, "wb") as f:
        pickle.dump(data, f)

    return data

In [None]:
# Pretraining
def supervised_pretrain(model, data, epochs=5, batch_size=32, pretrained_path=Path("chess_pretrained.pth")):
    """Imitation-pretrain the policy head on ``(state, move_index)`` pairs.

    Trains ``model`` in place with Adam (lr 1e-3) and cross-entropy between
    the policy logits and the recorded move index. The value head is
    ignored. The resulting ``state_dict`` is saved to ``pretrained_path``.

    If ``pretrained_path`` already exists, training is skipped and the saved
    weights are loaded into ``model``. The caller therefore always ends up
    with a pretrained model, rather than the untouched random initialisation.

    Args:
        model: module whose forward returns ``(policy_logits, value)``.
        data: sequence of ``(state_tensor, move_index)`` pairs.
        epochs: passes over ``data``.
        batch_size: mini-batch size for the shuffled ``DataLoader``.
        pretrained_path: checkpoint location (``Path`` or ``str``).

    Raises:
        ValueError: if training is needed but ``data`` is empty.
    """
    pretrained_path = Path(pretrained_path)
    if pretrained_path.exists():
        print("Pretrained model found. Skipping training.")
        model.load_state_dict(torch.load(pretrained_path))
        return

    # torch.stack([]) would fail with an unhelpful message; fail clearly instead.
    if len(data) == 0:
        raise ValueError("supervised_pretrain needs at least one (state, move_index) pair")

    print("Starting supervised pretraining...")
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    states = torch.stack([d[0] for d in data])
    moves = torch.tensor([d[1] for d in data])
    dataset = TensorDataset(states, moves)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for epoch in range(epochs):
        total_loss = 0
        for x, y in loader:
            optimizer.zero_grad()
            policy, _ = model(x)
            loss = F.cross_entropy(policy, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}")
    torch.save(model.state_dict(), pretrained_path)
    print(f"Saved pretrained model to {pretrained_path}")

In [None]:
# Checkpointing 
def save_checkpoint(model, optimizer, episode, path):
    """Write the model state, optimizer state and episode counter to ``path``.

    The file holds a dict with the keys ``'episode'``, ``'model_state_dict'``
    and ``'optimizer_state_dict'``, which is the layout ``load_checkpoint``
    reads back.
    """
    payload = dict(
        episode=episode,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(payload, path)

def load_checkpoint(model, optimizer, path):
    """Restore the model and optimizer from a file written by ``save_checkpoint``.

    Both objects are updated in place. Returns the stored ``'episode'``
    value, so training can resume from it.
    """
    ckpt = torch.load(path)
    for target, key in ((model, 'model_state_dict'), (optimizer, 'optimizer_state_dict')):
        target.load_state_dict(ckpt[key])
    return ckpt['episode']

# RL Fine-tuning 
def _latest_checkpoint(checkpoint_dir):
    """Return the ``checkpoint_ep<N>.pth`` file with the largest N, or None.

    Episode numbers are compared as integers. A plain ``sorted()`` on the
    file names is lexicographic and would rank ``checkpoint_ep50.pth`` after
    ``checkpoint_ep300.pth``, so training would resume from the wrong
    checkpoint.
    """
    prefix = "checkpoint_ep"
    best_path, best_episode = None, -1
    for path in checkpoint_dir.glob(f"{prefix}*.pth"):
        suffix = path.stem[len(prefix):]
        if suffix.isdecimal() and int(suffix) > best_episode:
            best_path, best_episode = path, int(suffix)
    return best_path


def _play_episode(model, engine, max_moves, depth):
    """Play one game with the model as White and Stockfish as Black.

    The model samples each move from a softmax over the logits of the
    currently legal moves (indexed ``from_square * 64 + to_square``).
    Stockfish replies at the given search ``depth``. At most ``max_moves``
    model moves are played.

    Returns:
        ``(states, legal_indices, choices, outcome)``:
        - ``states[i]`` is the encoded board before the model's i-th move.
        - ``legal_indices[i]`` holds the policy indices of the legal moves
          in that position.
        - ``choices[i]`` is the position within that list that was sampled.
        - ``outcome`` is ``board.result()`` (``"*"`` if the game was cut off).
    """
    board = chess.Board()
    states, legal_indices, choices = [], [], []
    for _ in range(max_moves):
        if board.is_game_over():
            break
        state = board_to_tensor(board)
        legal_moves = list(board.legal_moves)
        move_indices = [m.from_square * 64 + m.to_square for m in legal_moves]
        with torch.no_grad():
            policy_logits, _ = model(state.unsqueeze(0))
        probs = torch.softmax(policy_logits[0][move_indices], dim=0).numpy()
        # Sample an index rather than the Move itself, so numpy does not
        # have to build an object array out of chess.Move instances.
        choice = int(np.random.choice(len(legal_moves), p=probs))
        states.append(state)
        legal_indices.append(move_indices)
        choices.append(choice)
        board.push(legal_moves[choice])
        if board.is_game_over():
            break
        board.push(engine.play(board, chess.engine.Limit(depth=depth)).move)
    return states, legal_indices, choices, board.result()


def _reinforce_update(model, optimizer, states, legal_indices, choices, reward):
    """Take one policy-gradient step (REINFORCE with a value baseline) on an episode.

    Every model move in the episode shares the final game ``reward``.
    - The policy term raises the log-probability of moves that did better
      than the value head predicted, and lowers it for worse ones.
    - The value head is regressed onto ``reward`` with MSE, so it tracks
      the expected outcome instead of being pushed into tanh saturation.
    """
    policy_logits, values = model(torch.stack(states))
    values = values.squeeze(1)
    returns = torch.full_like(values, float(reward))
    # Use log-probabilities under the same legal-move softmax the moves were
    # sampled from, not the softmax over every policy output.
    chosen_log_probs = torch.stack([
        torch.log_softmax(logits[indices], dim=0)[choice]
        for logits, indices, choice in zip(policy_logits, legal_indices, choices)
    ])
    advantage = (returns - values).detach()
    policy_loss = -(advantage * chosen_log_probs).mean()
    value_loss = F.mse_loss(values, returns)
    optimizer.zero_grad()
    (policy_loss + value_loss).backward()
    optimizer.step()


def rl_finetune(model, episodes=1000, max_moves=40, depth=4, checkpoint_interval=100, checkpoint_dir=Path("checkpoints"), stockfish_path="/usr/games/stockfish"):
    """Fine-tune ``model`` by playing White against Stockfish.

    Each episode is one game, of at most ``max_moves`` model moves. The
    reward is +1 if White (the model) wins, -1 if Black wins, and 0 for
    draws or games cut off at ``max_moves``. After each game, one
    actor-critic update is applied, so the reward actually shapes the
    policy. (The previous per-move cross-entropy imitated the model's own
    samples regardless of outcome.)

    Training resumes from the ``checkpoint_ep<N>.pth`` in ``checkpoint_dir``
    with the highest N. A checkpoint is written every
    ``checkpoint_interval`` episodes. The Stockfish process is always shut
    down, even if an episode raises.

    Args:
        model: module whose forward returns ``(policy_logits, value)``.
        episodes: total episode count (including resumed ones).
        max_moves: maximum model moves per game.
        depth: Stockfish search depth per reply.
        checkpoint_interval: episodes between checkpoints.
        checkpoint_dir: checkpoint directory (``Path`` or ``str``).
        stockfish_path: path to the UCI engine binary.
    """
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    last_episode = 0
    latest_ckpt = _latest_checkpoint(checkpoint_dir)
    if latest_ckpt is not None:
        last_episode = load_checkpoint(model, optimizer, latest_ckpt)
        print(f"Resuming from {latest_ckpt.name} (Episode {last_episode})")
    else:
        print("Starting fresh RL training...")

    engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
    try:
        for ep in range(last_episode, episodes):
            states, legal_indices, choices, outcome = _play_episode(model, engine, max_moves, depth)
            reward = 1 if outcome == "1-0" else -1 if outcome == "0-1" else 0
            if states:
                _reinforce_update(model, optimizer, states, legal_indices, choices, reward)

            if (ep + 1) % checkpoint_interval == 0:
                ckpt_path = checkpoint_dir / f"checkpoint_ep{ep+1}.pth"
                save_checkpoint(model, optimizer, ep + 1, ckpt_path)
                print(f"Saved checkpoint at episode {ep+1}")
    finally:
        engine.quit()

In [None]:
def save_model(model, path="chess_rl_model.pth"):
    """Write ``model.state_dict()`` (weights only, no optimizer state) to ``path``."""
    weights = model.state_dict()
    torch.save(weights, path)

def load_model(model, path="chess_rl_model.pth"):
    """Load a state dict from ``path`` into ``model`` in place; returns None."""
    weights = torch.load(path)
    model.load_state_dict(weights)


In [None]:
# SET PATHS
stockfish_path = "/usr/games/stockfish"
data_dir = Path("games")
checkpoint_dir = Path("checkpoints")
pretrained_path = Path("chess_pretrained.pth")
final_model_path = Path("chess_rl_model_final.pth")

# Seed the RNGs the pipeline draws from, so a Restart & Run All reproduces
# the same run. Torch drives weight init and DataLoader shuffling; numpy
# drives move sampling during RL.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Step 1: Load or generate game data
data = generate_stockfish_games(num_games=200, max_moves=40, depth=4,
                                stockfish_path=stockfish_path, save_dir=data_dir)

# Step 2: Pretrain if not already done
model = ChessNet()
supervised_pretrain(model, data, epochs=3, batch_size=32, pretrained_path=pretrained_path)

# Step 3: RL fine-tuning (resumes if checkpoints exist)
# Reload the pretrained weights explicitly, so this step starts from them
# whether or not Step 2 actually trained.
load_model(model, pretrained_path)
rl_finetune(model, episodes=300, max_moves=40, depth=4,
            checkpoint_interval=50, checkpoint_dir=checkpoint_dir,
            stockfish_path=stockfish_path)

# Step 4: Save final model
save_model(model, final_model_path)
print(f"Final model saved to {final_model_path}")


Generating new games...
Starting supervised pretraining...
Epoch 1, Loss: 5.9100
Epoch 2, Loss: 4.9378
Epoch 3, Loss: 4.5588
Saved pretrained model to chess_pretrained.pth
Starting fresh RL training...
Saved checkpoint at episode 50
Saved checkpoint at episode 100
Saved checkpoint at episode 150
Saved checkpoint at episode 200
Saved checkpoint at episode 250
Saved checkpoint at episode 300
Final model saved to chess_rl_model_final.pth


In [None]:
from google.colab import files

# Download the pretrained weights first, then the final RL-trained weights.
for weights_file in ("chess_pretrained.pth", "chess_rl_model_final.pth"):
    files.download(weights_file)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
import shutil

# Import `files` here as well, so this cell also works on a fresh kernel
# without first running the model-download cell. Re-importing is harmless.
from google.colab import files

# Zip each output folder into <folder>.zip next to it, then download both
# archives.
output_folders = ("games", "checkpoints")
for folder in output_folders:
    shutil.make_archive(folder, 'zip', folder)

for folder in output_folders:
    files.download(f"{folder}.zip")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>