In [3]:
# Install necessary packages
!pip install python-chess stockfish torch

Collecting stockfish
  Downloading stockfish-3.28.0-py3-none-any.whl.metadata (12 kB)
Downloading stockfish-3.28.0-py3-none-any.whl (13 kB)
Installing collected packages: stockfish
Successfully installed stockfish-3.28.0


In [4]:
# Download Stockfish binary if not available
!apt-get install -y stockfish

# Confirm stockfish is installed
!stockfish

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
Suggested packages:
  polyglot xboard | scid
The following NEW packages will be installed:
  stockfish
0 upgraded, 1 newly installed, 0 to remove and 34 not upgraded.
Need to get 24.8 MB of archives.
After this operation, 47.4 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 stockfish amd64 14.1-1 [24.8 MB]
Fetched 24.8 MB in 1s (19.2 MB/s)
Selecting previously unselected package stockfish.
(Reading database ... 126333 files and directories currently installed.)
Preparing to unpack .../stockfish_14.1-1_amd64.deb ...
Unpacking stockfish (14.1-1) ...
Setting up stockfish (14.1-1) ...
Processing triggers for man-db (2.10.2-1) ...
/bin/bash: line 1: stockfish: command not found


In [5]:
import chess
import chess.pgn
from stockfish import Stockfish
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import pickle
import os

# Paths
STOCKFISH_PATH = "/usr/games/stockfish"

# Constants
FEATURES_PER_PIECE = 81  # 9x9 king-centered
NUM_PIECE_TYPES = 6      # pawn, knight, bishop, rook, queen, king
COLOR_MULTIPLIER = 2     # white/black
INPUT_SIZE = FEATURES_PER_PIECE * NUM_PIECE_TYPES * COLOR_MULTIPLIER * 2  # White King + Black King

# Stockfish setup
stockfish = Stockfish(path=STOCKFISH_PATH, parameters={"Threads": 2, "Hash": 128})

In [6]:
def improved_relative_square(king_square, square):
    king_file = chess.square_file(king_square)
    king_rank = chess.square_rank(king_square)
    sq_file = chess.square_file(square)
    sq_rank = chess.square_rank(square)

    df = sq_file - king_file + 4
    dr = sq_rank - king_rank + 4
    if 0 <= df <= 8 and 0 <= dr <= 8:
        return dr * 9 + df
    else:
        return None

def improved_extract_features(board):
    features = torch.zeros(INPUT_SIZE)
    white_king = board.king(chess.WHITE)
    black_king = board.king(chess.BLACK)

    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            piece_type = piece.piece_type - 1
            color_offset = 0 if piece.color == chess.WHITE else NUM_PIECE_TYPES

            if piece.color == chess.WHITE and white_king is not None:
                rel = improved_relative_square(white_king, square)
                if rel is not None:
                    idx = piece_type * 81 + rel + color_offset * 81
                    features[idx] = 1
            elif piece.color == chess.BLACK and black_king is not None:
                rel = improved_relative_square(black_king, square)
                if rel is not None:
                    idx = (NUM_PIECE_TYPES * COLOR_MULTIPLIER) * 81 + piece_type * 81 + rel + color_offset * 81
                    features[idx] = 1
    return features

In [7]:
class SimpleNNUE(nn.Module):
    def __init__(self):
        super(SimpleNNUE, self).__init__()
        self.fc1 = nn.Linear(INPUT_SIZE, 512)
        self.fc2 = nn.Linear(512, 32)
        self.fc3 = nn.Linear(32, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [None]:
def random_board():
    board = chess.Board()
    moves = random.randint(5, 30)
    for _ in range(moves):
        if board.is_game_over():
            break
        move = random.choice(list(board.legal_moves))
        board.push(move)
    return board

def generate_dataset(num_positions=500):
    data = []
    for i in range(num_positions):
        board = random_board()
        stockfish.set_fen_position(board.fen())
        eval_cp = stockfish.get_evaluation()

        if eval_cp['type'] == 'cp':
            score = eval_cp['value'] / 100.0  # centipawn to pawn units
        else:
            continue  # Skip mate scores

        data.append((board.fen(), score))

        if i % 100 == 0:
            print(f"Generated {i} positions")

    with open("training_data.pkl", "wb") as f:
        pickle.dump(data, f)

# Generate 5000 positions (adjust if needed to fit memory)
generate_dataset(5000)


Generated 0 positions
Generated 100 positions
Generated 200 positions
Generated 300 positions
Generated 400 positions
Generated 500 positions
Generated 600 positions


In [None]:
# Load dataset
with open("training_data.pkl", "rb") as f:
    dataset = pickle.load(f)

model = SimpleNNUE()
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training loop
for epoch in range(10):
    total_loss = 0
    random.shuffle(dataset)

    for fen, score in dataset:
        board = chess.Board(fen)
        features = improved_extract_features(board).to(device)
        pred = model(features)
        target = torch.tensor([score], dtype=torch.float32).to(device)

        loss = loss_fn(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {total_loss/len(dataset)}")

# Save trained model
torch.save(model.state_dict(), "nnue_trained.pth")

In [None]:
# Load trained model
model = SimpleNNUE()
model.load_state_dict(torch.load("nnue_trained.pth", map_location=device))
model.eval()

def evaluate_board(board):
    features = improved_extract_features(board).to(device)
    with torch.no_grad():
        eval_score = model(features)
    return eval_score.item()

# Example usage
board = chess.Board()
print("Starting position evaluation:", evaluate_board(board))

board.push_san("e4")
print("After 1. e4 evaluation:", evaluate_board(board))

board.push_san("e5")
print("After 1... e5 evaluation:", evaluate_board(board))
