# In [None]:
# Advanced Chess Transformer with Novel Architectures and Training Techniques
# Publication-Grade Implementation with Multiple Innovations

import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import chess
import chess.engine
from collections import deque, namedtuple
import json
import time
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict
import threading
import queue

# Device configuration: every tensor and model in this file is created on this device.
device = torch.device("cpu")

# Constants
BOARD_SIZE = 8
NUM_SQUARES = BOARD_SIZE * BOARD_SIZE  # 64
# A move is encoded as from_square * NUM_SQUARES + to_square, so promotions
# other than the queen have no index of their own.
ACTION_SPACE_SIZE = NUM_SQUARES * NUM_SQUARES  # 4096
NUM_PIECE_TYPES = 13  # 0 = empty, 1-6 white pawn..king, 7-12 black pawn..king
NUM_TOKENS = 16  # Size of the token vocabulary: 13 piece tokens + 3 special tokens below

# Special tokens
EMPTY_TOKEN = 0  # Also used as ignore_index for the piece-prediction auxiliary loss
CLS_TOKEN = 13  # Prepended to every board sequence; its output feeds the policy/value heads
MASK_TOKEN = 14  # NOTE(review): not referenced in the visible part of the file - confirm usage
ENDGAME_TOKEN = 15  # NOTE(review): not referenced in the visible part of the file - confirm usage

# Game phase detection thresholds
OPENING_MOVES = 15  # Compared against the move count, which is len(move_history), i.e. plies
# Compared against the NUMBER of occupied squares (not summed piece values):
# fewer than this many pieces on the board is treated as an endgame.
ENDGAME_MATERIAL = 13

# One self-play transition. Before insertion into the replay buffer the
# `reward` field is overwritten with the TD-lambda return, and `policy` holds
# the legality-masked logits produced at move-selection time.
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done', 'value', 'policy'])

@dataclass
class TrainingConfig:
    """Hyperparameters shared by the model and the trainer.

    The model reads ``d_model``, ``nhead``, ``num_layers`` and ``dropout``;
    the remaining fields are consumed by the training loop.
    """
    # Embedding width. The positional encoding concatenates four d_model // 4
    # wide embeddings, so this should be divisible by 4.
    d_model: int = 256
    # Total attention heads per block; half go to the local attention and
    # half to the global attention.
    nhead: int = 16
    # Number of stacked transformer blocks.
    num_layers: int = 8
    # Dropout probability used inside the blocks and before the output heads.
    dropout: float = 0.1
    # AdamW learning rate.
    learning_rate: float = 3e-4
    # Number of experiences sampled from the replay buffer per training step.
    batch_size: int = 32
    # Replay buffer capacity.
    memory_size: int = 100000
    # Number of games between target-network synchronisations.
    target_update_freq: int = 1000
    # Probability of playing a uniformly random legal move during training.
    exploration_noise: float = 0.3
    # Policy logits are divided by this value before sampling.
    temperature: float = 1.0
    # Lambda used when mixing returns and value estimates (TD-lambda).
    lambda_value: float = 0.95
    # Weight of the entropy regularisation term in the total loss.
    entropy_coeff: float = 0.01

class PositionalEncoding(nn.Module):
    """Learnable positional encoding with chess-specific geometry.

    The encoding of a sequence position is the sum of a generic learned
    position embedding and, for board squares only, the concatenation of
    four ``d_model // 4`` wide embeddings describing the square's rank,
    file, diagonal and colour.

    Sequence layout: position 0 is the CLS token, positions 1..64 are the
    board squares 0..63 (a1..h8).

    Args:
        d_model: Embedding width. Must be divisible by 4 so that the four
            geometry embeddings concatenate to exactly ``d_model`` values.
        max_len: Number of distinct sequence positions (CLS + 64 squares).

    Raises:
        ValueError: If ``d_model`` is not divisible by 4.
    """

    def __init__(self, d_model: int, max_len: int = 65):
        super().__init__()
        if d_model % 4 != 0:
            raise ValueError(
                f"d_model must be divisible by 4 for the chess geometry encoding, got {d_model}"
            )
        # forward() needs the width to build the zero encoding of the CLS token.
        self.d_model = d_model

        # Standard positional encoding
        self.pos_embed = nn.Embedding(max_len, d_model)

        # Chess-specific encodings
        self.rank_embed = nn.Embedding(8, d_model // 4)  # Rank (1-8)
        self.file_embed = nn.Embedding(8, d_model // 4)  # File (a-h)
        self.diagonal_embed = nn.Embedding(15, d_model // 4)  # Diagonals (rank + file: 0..14)
        self.color_embed = nn.Embedding(2, d_model // 4)  # Square color

    def forward(self, positions):
        """Return the encoding for a batch of position indices.

        Args:
            positions: Long tensor of shape ``(batch, seq_len)`` whose first
                column is the CLS position and whose remaining columns are
                board positions (square index + 1).

        Returns:
            Float tensor of shape ``(batch, seq_len, d_model)``.
        """
        batch_size = positions.shape[0]

        # Standard positional encoding
        pos_enc = self.pos_embed(positions)

        # Sequence position p (1..64) corresponds to square p - 1 (0..63).
        # Without the shift, position 64 would yield rank 8 and overflow the
        # 8-entry rank embedding.
        squares = positions[:, 1:] - 1

        ranks = squares // 8
        files = squares % 8
        diagonals = ranks + files  # 0..14
        colors = (ranks + files) % 2  # Square color (0=dark, 1=light)

        chess_enc = torch.cat(
            [
                self.rank_embed(ranks),
                self.file_embed(files),
                self.diagonal_embed(diagonals),
                self.color_embed(colors),
            ],
            dim=-1,
        )

        # The CLS token has no board geometry: give it an all-zero encoding.
        cls_enc = pos_enc.new_zeros(batch_size, 1, self.d_model)
        full_chess_enc = torch.cat([cls_enc, chess_enc], dim=1)

        return pos_enc + full_chess_enc

class MultiScaleAttention(nn.Module):
    """Multi-scale attention for capturing both local and global patterns.

    Two independent multi-head attentions run over the same input: one is
    restricted to squares within a Chebyshev-distance window, the other is
    unrestricted. Their outputs are concatenated and projected back to
    ``d_model``.

    Args:
        d_model: Width of the token representations.
        num_heads: Total head budget; each of the two attentions gets
            ``num_heads // 2`` heads.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.local_attention = nn.MultiheadAttention(d_model, num_heads // 2, batch_first=True)
        self.global_attention = nn.MultiheadAttention(d_model, num_heads // 2, batch_first=True)
        self.combine = nn.Linear(2 * d_model, d_model)
        # The local mask depends only on the sequence length and the device,
        # so it is built once and reused instead of on every forward pass.
        self._mask_cache = {}

    def create_local_mask(self, seq_len: int, window_size: int = 5):
        """Create the attention mask for local patterns (nearby squares).

        Position 0 is the CLS token; position ``i >= 1`` is board square
        ``i - 1``. Two squares may attend to each other when the larger of
        their rank and file differences is at most ``window_size``. The CLS
        token attends to, and is attended by, every position.

        Args:
            seq_len: Sequence length including the CLS token.
            window_size: Maximum Chebyshev distance between attending squares.

        Returns:
            Bool tensor of shape ``(seq_len, seq_len)`` on the CPU where
            ``True`` marks a pair that must NOT attend (the convention of
            ``nn.MultiheadAttention``'s boolean ``attn_mask``).
        """
        squares = torch.arange(seq_len - 1)
        ranks = squares // 8
        files = squares % 8

        rank_gap = (ranks.unsqueeze(1) - ranks.unsqueeze(0)).abs()
        file_gap = (files.unsqueeze(1) - files.unsqueeze(0)).abs()

        allowed = torch.zeros(seq_len, seq_len, dtype=torch.bool)
        allowed[1:, 1:] = torch.max(rank_gap, file_gap) <= window_size

        # CLS token can attend to all, and all can attend to it
        allowed[0, :] = True
        allowed[:, 0] = True

        return ~allowed  # Invert for attention mask

    def forward(self, x):
        """Apply local and global attention to ``x`` of shape ``(batch, seq_len, d_model)``."""
        seq_len = x.shape[1]

        cache_key = (seq_len, x.device)
        local_mask = self._mask_cache.get(cache_key)
        if local_mask is None:
            local_mask = self.create_local_mask(seq_len).to(x.device)
            self._mask_cache[cache_key] = local_mask

        # Local attention with restricted window
        local_out, _ = self.local_attention(x, x, x, attn_mask=local_mask)

        # Global attention (unrestricted)
        global_out, _ = self.global_attention(x, x, x)

        # Combine and project
        combined = torch.cat([local_out, global_out], dim=-1)
        return self.combine(combined)

class ChessTransformerBlock(nn.Module):
    """Post-norm transformer block: multi-scale attention plus a gated MLP.

    Both sub-layers are wrapped in a residual connection followed by layer
    normalisation. The feed-forward path multiplies the expanded activations
    by a sigmoid gate before the GELU non-linearity.
    """

    def __init__(self, d_model: int, nhead: int, dropout: float = 0.1):
        super().__init__()
        hidden_dim = 4 * d_model

        self.attention = MultiScaleAttention(d_model, nhead)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Feed-forward expansion, contraction, and the gate on the expansion
        self.ff1 = nn.Linear(d_model, hidden_dim)
        self.ff2 = nn.Linear(hidden_dim, d_model)
        self.gate = nn.Linear(d_model, hidden_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Transform ``x`` and return a tensor of the same shape."""
        # Attention sub-layer (residual + norm)
        attended = self.dropout(self.attention(x))
        x = self.norm1(x + attended)

        # Gated feed-forward sub-layer (residual + norm)
        expanded = self.ff1(x)
        gate_values = torch.sigmoid(self.gate(x))
        hidden = self.dropout(F.gelu(expanded * gate_values))
        projected = self.dropout(self.ff2(hidden))

        return self.norm2(x + projected)

class GamePhaseEncoder(nn.Module):
    """Summarise game progress (phase, move number, piece count) as one vector."""

    def __init__(self, d_model: int):
        super().__init__()
        self.phase_embed = nn.Embedding(3, d_model)  # 0=opening, 1=middle, 2=endgame
        self.move_embed = nn.Embedding(200, d_model)  # Move number encoding
        self.material_proj = nn.Linear(1, d_model)

    def detect_phase(self, board_tensor, move_count):
        """Classify each board as opening (0), middlegame (1) or endgame (2).

        A board with fewer than ``ENDGAME_MATERIAL`` occupied squares is an
        endgame regardless of the move count; otherwise it is an opening
        while ``move_count`` is below ``OPENING_MOVES`` and a middlegame
        after that.
        """
        occupied = (board_tensor > 0).sum(dim=1).float()
        is_opening = move_count < OPENING_MOVES
        is_endgame = occupied < ENDGAME_MATERIAL

        # Start from "middlegame" everywhere, then overwrite; the endgame
        # assignment comes last so that it takes precedence over the opening.
        phase = torch.ones(board_tensor.shape[0], dtype=torch.long, device=board_tensor.device)
        phase[is_opening] = 0
        phase[is_endgame] = 2
        return phase

    def forward(self, board_tensor, move_count):
        """Return the sum of the phase, move-number and piece-count encodings."""
        phase_enc = self.phase_embed(self.detect_phase(board_tensor, move_count))

        # The move embedding table has 200 rows, so clamp into its range
        move_index = move_count.clamp(0, 199).long()
        move_enc = self.move_embed(move_index)

        # Number of occupied squares, scaled by the 32 pieces of a full board
        piece_count = (board_tensor > 0).sum(dim=1, keepdim=True).float()
        material_enc = self.material_proj(piece_count / 32.0)

        return phase_enc + move_enc + material_enc

class AdvancedChessTransformer(nn.Module):
    """Transformer over the 64 board squares plus a CLS token.

    The CLS representation drives a policy head (one logit per from/to
    square pair) and a scalar value head. Two auxiliary heads predict the
    piece on every square and the game phase.
    """

    def __init__(self, config: TrainingConfig):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        width = config.d_model

        # Token and positional embeddings
        self.token_embed = nn.Embedding(NUM_TOKENS, width)
        self.pos_encoding = PositionalEncoding(width)
        self.phase_encoder = GamePhaseEncoder(width)

        # Transformer layers
        self.layers = nn.ModuleList(
            [ChessTransformerBlock(width, config.nhead, config.dropout) for _ in range(config.num_layers)]
        )

        # Output heads: Linear -> LayerNorm -> Linear, applied by _apply_head
        self.policy_head = nn.ModuleList([
            nn.Linear(width, width),
            nn.LayerNorm(width),
            nn.Linear(width, ACTION_SPACE_SIZE),
        ])

        self.value_head = nn.ModuleList([
            nn.Linear(width, width // 2),
            nn.LayerNorm(width // 2),
            nn.Linear(width // 2, 1),
        ])

        # Auxiliary heads for self-supervised learning
        self.piece_prediction_head = nn.Linear(width, NUM_PIECE_TYPES)
        self.phase_prediction_head = nn.Linear(width, 3)

        self.dropout = nn.Dropout(config.dropout)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise linear layers and draw embeddings from N(0, 0.02)."""
        for module in self.modules():
            if isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, 0, 0.02)
                continue
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def _apply_head(self, head, features):
        """Run ``features`` through a head: GELU after linear layers, dropout before the last."""
        *hidden_layers, output_layer = head
        for layer in hidden_layers:
            if isinstance(layer, nn.LayerNorm):
                features = layer(features)
            else:
                features = F.gelu(layer(features))
        return output_layer(self.dropout(features))

    def forward(self, board_tensor, move_count=None, return_aux=False):
        """Evaluate a batch of boards.

        Args:
            board_tensor: Long tensor ``(batch, 64)`` of piece tokens.
            move_count: Optional tensor ``(batch,)``; zeros when omitted.
            return_aux: Also return the auxiliary piece and phase logits.

        Returns:
            ``(policy_logits, value)``, or
            ``(policy_logits, value, piece_logits, phase_logits)`` when
            ``return_aux`` is true.
        """
        batch_size, seq_len = board_tensor.shape
        dev = board_tensor.device

        if move_count is None:
            move_count = torch.zeros(batch_size, device=dev)

        # Prepend the CLS token to every board
        cls_column = torch.full((batch_size, 1), CLS_TOKEN, dtype=torch.long, device=dev)
        tokens = torch.cat([cls_column, board_tensor], dim=1)

        # Position ids 0..seq_len, identical for every batch element
        pos_ids = torch.arange(seq_len + 1, device=dev).unsqueeze(0).repeat(batch_size, 1)

        x = self.token_embed(tokens) + self.pos_encoding(pos_ids)

        # Game phase information is injected into the CLS token only
        x[:, 0] += self.phase_encoder(board_tensor, move_count)

        for block in self.layers:
            x = block(x)

        cls_repr = x[:, 0]

        policy_logits = self._apply_head(self.policy_head, cls_repr)
        value = self._apply_head(self.value_head, cls_repr).squeeze(-1)

        if not return_aux:
            return policy_logits, value

        # Auxiliary predictions: pieces from the square tokens, phase from CLS
        piece_logits = self.piece_prediction_head(x[:, 1:])
        phase_logits = self.phase_prediction_head(cls_repr)
        return policy_logits, value, piece_logits, phase_logits

class PrioritizedReplayBuffer:
    """Ring buffer of experiences sampled proportionally to ``priority ** alpha``."""

    def __init__(self, capacity: int, alpha: float = 0.6, beta: float = 0.4):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.position = 0
        self.max_priority = 1.0

    def add(self, experience: Experience):
        """Store ``experience``, overwriting the oldest slot once the buffer is full."""
        slot = self.position
        if len(self.buffer) >= self.capacity:
            self.buffer[slot] = experience
        else:
            self.buffer.append(experience)

        # New experiences get the highest priority seen so far
        self.priorities[slot] = self.max_priority
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size: int):
        """Draw ``batch_size`` experiences with replacement.

        Returns:
            ``None`` when the buffer is empty, otherwise a tuple
            ``(experiences, indices, weights)`` where ``weights`` are the
            importance-sampling weights normalised by their maximum.
        """
        size = len(self.buffer)
        if size == 0:
            return None

        # Sampling distribution over the filled part of the buffer
        scaled = self.priorities[:size] ** self.alpha
        probabilities = scaled / scaled.sum()

        indices = np.random.choice(size, batch_size, p=probabilities, replace=True)

        # Importance-sampling correction
        weights = (size * probabilities[indices]) ** (-self.beta)
        weights = weights / weights.max()

        experiences = [self.buffer[i] for i in indices]
        return experiences, indices, weights

    def update_priorities(self, indices, priorities):
        """Assign new priorities and track the running maximum."""
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority
            if priority > self.max_priority:
                self.max_priority = priority

class ChessEnvironment:
    """Chess environment wrapping a ``chess.Board`` for self-play.

    Actions are integers ``from_square * NUM_SQUARES + to_square``. Pawn
    moves onto the last rank are always promoted to a queen. An optional
    UCI engine can be attached for position evaluation.

    Args:
        use_engine: Try to start Stockfish for ``get_engine_evaluation``.
        engine_depth: Search depth used for engine evaluations.
    """

    def __init__(self, use_engine=False, engine_depth=5):
        self.board = chess.Board()
        self.move_history = []
        self.use_engine = use_engine
        self.engine_depth = engine_depth
        self.engine = None

        if use_engine:
            try:
                # Try to use Stockfish if available
                self.engine = chess.engine.SimpleEngine.popen_uci("/usr/local/bin/stockfish")
            except Exception:
                # Engine support is best-effort. Catch Exception rather than
                # using a bare except so that KeyboardInterrupt and
                # SystemExit still propagate.
                print("Stockfish not found, continuing without engine evaluation")

    def reset(self):
        """Restore the starting position and return ``(board_tensor, move_count)``."""
        self.board.reset()
        self.move_history = []
        return self.board_to_tensor(), self.get_move_count()

    def board_to_tensor(self):
        """Encode the board as a long tensor of 64 piece tokens.

        Token 0 is an empty square, 1-6 are white pawn..king and 7-12 are
        black pawn..king, indexed by python-chess square number.
        """
        tokens = [EMPTY_TOKEN] * NUM_SQUARES
        for square, piece in self.board.piece_map().items():
            tokens[square] = piece.piece_type if piece.color else piece.piece_type + 6
        return torch.tensor(tokens, dtype=torch.long, device=device)

    def get_move_count(self):
        """Return the number of moves (plies) played since the last reset."""
        return len(self.move_history)

    def step(self, action_idx):
        """Play the move encoded by ``action_idx``.

        An illegal action yields a reward of -1.0 and a random legal move is
        played instead so that the game keeps progressing.

        Returns:
            ``(board_tensor, move_count, reward, done)``.
        """
        from_sq = action_idx // NUM_SQUARES
        to_sq = action_idx % NUM_SQUARES

        # A pawn reaching the far rank is always promoted to a queen
        promotion = None
        if self.board.piece_type_at(from_sq) == chess.PAWN:
            last_rank = 7 if self.board.turn == chess.WHITE else 0
            if to_sq // 8 == last_rank:
                promotion = chess.QUEEN

        move = chess.Move(from_sq, to_sq, promotion=promotion)

        if move in self.board.legal_moves:
            self.board.push(move)
            self.move_history.append(move)
            reward = self.calculate_reward()
        else:
            # Penalty for illegal moves
            reward = -1.0
            # Make random legal move as fallback
            legal_moves = list(self.board.legal_moves)
            if legal_moves:
                move = random.choice(legal_moves)
                self.board.push(move)
                self.move_history.append(move)

        done = self.board.is_game_over()
        return self.board_to_tensor(), self.get_move_count(), reward, done

    def calculate_reward(self):
        """Score the position after a move, from White's point of view.

        ``board.turn`` is the side to move, i.e. the side that has just been
        mated or checked; the sign is positive when that side is Black.
        """
        if self.board.is_checkmate():
            return 10.0 if self.board.turn == chess.BLACK else -10.0
        elif self.board.is_stalemate() or self.board.is_insufficient_material():
            return 0.0
        elif self.board.is_check():
            return 0.5 if self.board.turn == chess.BLACK else -0.5

        # Small reward for making legal moves
        return 0.01

    def get_engine_evaluation(self):
        """Return the engine score in pawns for the side to move, or ``None``.

        ``None`` is returned when no engine is attached or the analysis fails.
        """
        if self.engine is None:
            return None
        try:
            info = self.engine.analyse(self.board, chess.engine.Limit(depth=self.engine_depth))
            return info["score"].relative.score(mate_score=10000) / 100.0
        except Exception:
            # Evaluation is optional: any engine failure degrades to "no
            # evaluation" without swallowing KeyboardInterrupt/SystemExit.
            return None

    def close(self):
        """Shut down the engine process, if any. Safe to call more than once."""
        engine = getattr(self, "engine", None)
        self.engine = None
        if engine is not None:
            try:
                engine.quit()
            except Exception:
                # The engine may already have exited; there is nothing left
                # to clean up in that case.
                pass

    def __del__(self):
        # May run on a partially constructed object or during interpreter
        # shutdown, so it must never raise.
        try:
            self.close()
        except Exception:
            pass

class AdvancedTrainer:
    """Self-play trainer combining policy-gradient, value and auxiliary losses.

    Games are played with the current model, converted to TD-lambda returns
    and stored in a prioritized replay buffer. Training samples batches from
    that buffer and weights every sample's loss by its importance-sampling
    weight.
    """

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.model = AdvancedChessTransformer(config).to(device)
        self.target_model = AdvancedChessTransformer(config).to(device)
        self.target_model.load_state_dict(self.model.state_dict())

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=1e-4
        )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=1000, eta_min=1e-6
        )

        self.replay_buffer = PrioritizedReplayBuffer(config.memory_size)
        self.env = ChessEnvironment()

        # Training statistics
        self.stats = {
            'games_played': 0,
            'total_rewards': [],
            'policy_losses': [],
            'value_losses': [],
            'aux_losses': []
        }

    def select_move(self, board_tensor, move_count, training=True, temperature=1.0):
        """Choose a move for the position currently held by ``self.env``.

        Args:
            board_tensor: Long tensor ``(64,)`` encoding the board.
            move_count: Number of moves played so far.
            training: Put the model in train mode, keep gradients enabled and
                allow random exploration; otherwise evaluate without gradients.
            temperature: Divisor applied to the policy logits.

        Returns:
            ``(action, log_prob, value, masked_logits)``, or four ``None``
            values when no representable legal move exists.
        """
        self.model.train(training)

        with torch.set_grad_enabled(training):
            batch_tensor = board_tensor.unsqueeze(0)
            batch_count = torch.tensor([move_count], device=device)

            policy_logits, value = self.model(batch_tensor, batch_count)

            # Apply temperature
            if temperature != 1.0:
                policy_logits = policy_logits / temperature

            # Only queen promotions are representable in the 64x64 action space
            legal_indices = [
                move.from_square * NUM_SQUARES + move.to_square
                for move in self.env.board.legal_moves
                if not move.promotion or move.promotion == chess.QUEEN
            ]
            if not legal_indices:
                return None, None, None, None

            # Additive mask: 0 for legal moves, a large negative value otherwise
            mask = torch.full((ACTION_SPACE_SIZE,), -1e9, device=device)
            mask[legal_indices] = 0
            masked_logits = policy_logits.squeeze(0) + mask
            probs = F.softmax(masked_logits, dim=-1)

            if training and random.random() < self.config.exploration_noise:
                # Random exploration
                action = random.choice(legal_indices)
                log_prob = torch.log(probs[action] + 1e-8)
            else:
                # Policy sampling
                dist = Categorical(probs)
                sampled = dist.sample()
                log_prob = dist.log_prob(sampled)
                action = sampled.item()

            return action, log_prob, value.squeeze(0), masked_logits

    def play_game(self, training=True):
        """Play one self-play game and, when training, store its experiences.

        Returns:
            ``(total_reward, number_of_recorded_experiences)``.
        """
        state, move_count = self.env.reset()
        done = False
        experiences = []
        total_reward = 0

        while not done and len(experiences) < 200:  # Max game length
            action, _, value, policy_logits = self.select_move(
                state, move_count, training, self.config.temperature
            )

            if action is None:
                break

            next_state, next_move_count, reward, done = self.env.step(action)
            total_reward += reward

            if training:
                # Detach before storing: select_move runs with gradients
                # enabled in training mode, and a non-detached tensor would
                # keep its whole autograd graph alive for as long as the
                # experience stays in the replay buffer.
                experiences.append(Experience(
                    state=state.cpu(),
                    action=action,
                    reward=reward,
                    next_state=next_state.cpu(),
                    done=done,
                    value=value.detach().cpu() if value is not None else torch.tensor(0.0),
                    policy=(policy_logits.detach().cpu() if policy_logits is not None
                            else torch.zeros(ACTION_SPACE_SIZE)),
                ))

            state = next_state
            move_count = next_move_count

        # Add experiences to replay buffer with TD-lambda returns
        if training and experiences:
            returns = self.calculate_td_lambda_returns(experiences)
            for exp, ret in zip(experiences, returns):
                self.replay_buffer.add(exp._replace(reward=ret))

        return total_reward, len(experiences)

    def calculate_td_lambda_returns(self, experiences):
        """Compute TD-lambda returns for one game, oldest experience first.

        Uses the backward recursion
        ``G_t = r_t + gamma * (lambda * G_{t+1} + (1 - lambda) * V(s_{t+1}))``
        with ``gamma = 0.99``. ``V(s_{t+1})`` is the value stored on the
        following experience. A terminal step returns its reward unchanged.
        For a game cut off by the length limit no estimate of the state
        after the last move is available, so both terms start at zero.
        """
        gamma = 0.99
        lam = self.config.lambda_value

        returns = [0.0] * len(experiences)
        next_return = 0.0
        next_value = 0.0

        for i in reversed(range(len(experiences))):
            exp = experiences[i]
            if exp.done:
                g = exp.reward
            else:
                g = exp.reward + gamma * (lam * next_return + (1 - lam) * next_value)
            returns[i] = g
            next_return = g
            next_value = exp.value.item()

        return returns

    def train_batch(self):
        """Run one optimisation step on a prioritized batch.

        Does nothing until the replay buffer holds at least one batch.
        """
        if len(self.replay_buffer.buffer) < self.config.batch_size:
            return

        # Sample batch
        batch_data = self.replay_buffer.sample(self.config.batch_size)
        if batch_data is None:
            return

        experiences, indices, weights = batch_data

        # Prepare batch tensors. The `reward` field of a stored experience
        # holds the TD-lambda return (see play_game).
        states = torch.stack([exp.state for exp in experiences]).to(device)
        actions = torch.tensor([exp.action for exp in experiences], device=device)
        returns = torch.tensor([exp.reward for exp in experiences], device=device, dtype=torch.float)
        weights_tensor = torch.tensor(weights, device=device, dtype=torch.float)

        self.model.train()

        # Forward pass. Experiences do not record the move count, so zeros
        # are passed for it.
        policy_logits, pred_values, piece_logits, phase_logits = self.model(
            states,
            torch.zeros(len(experiences), device=device),
            return_aux=True
        )

        # Every loss term is kept per sample so that the importance-sampling
        # weights scale individual samples; multiplying an already averaged
        # scalar by the weights would only rescale the whole batch.
        policy_dist = Categorical(logits=policy_logits)
        log_probs = policy_dist.log_prob(actions)
        advantages = returns - pred_values.detach()
        policy_per_sample = -(log_probs * advantages)

        value_per_sample = F.mse_loss(pred_values, returns, reduction='none')

        # Auxiliary piece prediction (self-supervised): average the token
        # losses over the occupied squares of each board.
        token_losses = F.cross_entropy(
            piece_logits.reshape(-1, NUM_PIECE_TYPES),
            states.reshape(-1),
            ignore_index=EMPTY_TOKEN,
            reduction='none'
        ).reshape(states.shape)
        pieces_per_board = (states != EMPTY_TOKEN).sum(dim=1).clamp(min=1).float()
        aux_per_sample = token_losses.sum(dim=1) / pieces_per_board

        # Entropy regularization (negative entropy is minimised)
        entropy_per_sample = -policy_dist.entropy()

        per_sample_loss = (policy_per_sample +
                           value_per_sample +
                           0.1 * aux_per_sample +
                           self.config.entropy_coeff * entropy_per_sample)
        total_loss = (per_sample_loss * weights_tensor).mean()

        # Backward pass
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()

        # Update priorities
        td_errors = torch.abs(returns - pred_values).detach().cpu().numpy()
        self.replay_buffer.update_priorities(indices, td_errors + 1e-6)

        # Update statistics (unweighted batch means)
        self.stats['policy_losses'].append(policy_per_sample.mean().item())
        self.stats['value_losses'].append(value_per_sample.mean().item())
        self.stats['aux_losses'].append(aux_per_sample.mean().item())

    def train(self, num_games=1000, save_interval=100):
        """Main training loop: play games, train periodically, log and checkpoint.

        Returns:
            The trained model.
        """
        print(f"Starting training with {sum(p.numel() for p in self.model.parameters())} parameters")

        for game in range(1, num_games + 1):
            # Play game and collect experience
            reward, moves = self.play_game(training=True)
            self.stats['games_played'] += 1
            self.stats['total_rewards'].append(reward)

            # Train on batch
            if game % 4 == 0:  # Train every 4 games
                for _ in range(2):  # Multiple training steps
                    self.train_batch()

            # Update target network
            if game % self.config.target_update_freq == 0:
                self.target_model.load_state_dict(self.model.state_dict())

            # Logging
            if game % 10 == 0:
                avg_reward = np.mean(self.stats['total_rewards'][-10:])
                avg_policy_loss = np.mean(self.stats['policy_losses'][-10:]) if self.stats['policy_losses'] else 0
                avg_value_loss = np.mean(self.stats['value_losses'][-10:]) if self.stats['value_losses'] else 0

                print(f"Game {game}: Reward={avg_reward:.3f}, "
                      f"Policy Loss={avg_policy_loss:.4f}, "
                      f"Value Loss={avg_value_loss:.4f}, "
                      f"Moves={moves}")

            # Save checkpoint
            if game % save_interval == 0:
                self.save_checkpoint(f"advanced_chess_model_game_{game}.pth")
                print(f"Checkpoint saved at game {game}")

        return self.model

    def save_checkpoint(self, path):
        """Save model, optimizer, scheduler, config and statistics to ``path``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': self.config,
            'stats': self.stats
        }, path)

    def load_checkpoint(self, path):
        """Restore the state written by ``save_checkpoint``."""
        # NOTE(review): the checkpoint contains a pickled TrainingConfig;
        # confirm that the installed PyTorch's torch.load defaults accept it.
        checkpoint = torch.load(path, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        # Keep the target network consistent with the restored weights
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.stats = checkpoint['stats']
        print(f"Loaded checkpoint from {path}")

def _queen_promotion_candidates(board):
    """Return ``(move, action_index)`` pairs for the legal moves the policy can express.

    The policy has one logit per from/to square pair, so under-promotions
    are skipped; the queen promotion for the same squares is kept.
    """
    return [
        (move, move.from_square * NUM_SQUARES + move.to_square)
        for move in board.legal_moves
        if not move.promotion or move.promotion == chess.QUEEN
    ]


def play_against_human(model_path, config=None):
    """Play an interactive console game against a trained model.

    The human plays White and enters moves in UCI notation; ``hint`` prints
    the model's preferred move and ``quit`` ends the session.

    Args:
        model_path: Checkpoint file containing ``'model_state_dict'``.
        config: Model configuration; a default ``TrainingConfig`` when omitted.
    """
    if config is None:
        config = TrainingConfig()

    # Load model
    model = AdvancedChessTransformer(config).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    env = ChessEnvironment()
    board_state, move_count = env.reset()

    print("üöÄ Advanced Chess AI - Ready to play!")
    print("You are White. Enter moves in UCI format (e.g., e2e4)")
    print("Type 'quit' to exit, 'hint' for AI suggestion")

    while not env.board.is_game_over():
        print(f"\n{env.board}")
        print(f"Move {move_count + 1}")

        if env.board.turn == chess.WHITE:  # Human turn
            move_input = input("Your move: ").strip().lower()

            if move_input == 'quit':
                break
            elif move_input == 'hint':
                # The game is not over, so at least one candidate exists
                candidates = _queen_promotion_candidates(env.board)
                with torch.no_grad():
                    policy_logits, value = model(
                        board_state.unsqueeze(0),
                        torch.tensor([move_count], device=device)
                    )
                    indices = torch.tensor(
                        [idx for _, idx in candidates], dtype=torch.long, device=policy_logits.device
                    )
                    legal_logits = policy_logits[0, indices]
                    # Confidence is the probability of the suggested move
                    # among the legal moves.
                    legal_probs = torch.softmax(legal_logits, dim=-1)
                    best = int(torch.argmax(legal_logits))

                best_move = candidates[best][0]
                print(f"üí° AI suggests: {best_move} (confidence: {legal_probs[best].item():.3f})")
                print(f"üìä Position evaluation: {value.item():.3f}")
                continue

            # Only the parse can raise ValueError; keep the try body minimal
            try:
                move = chess.Move.from_uci(move_input)
            except ValueError:
                print("‚ùå Invalid move format! Use UCI notation (e.g., e2e4)")
                continue

            if move not in env.board.legal_moves:
                print("‚ùå Illegal move! Try again.")
                continue

            env.board.push(move)
            env.move_history.append(move)
            board_state = env.board_to_tensor()
            move_count += 1

        else:  # AI turn
            print("ü§ñ AI is thinking...")
            start_time = time.time()

            candidates = _queen_promotion_candidates(env.board)
            with torch.no_grad():
                policy_logits, value = model(
                    board_state.unsqueeze(0),
                    torch.tensor([move_count], device=device)
                )

                # Softmax with temperature over the legal moves only. Taking
                # it over all 4096 logits can leave every legal move with a
                # probability that underflows to zero, and random.choices
                # rejects all-zero weights.
                temperature = 0.7
                indices = torch.tensor(
                    [idx for _, idx in candidates], dtype=torch.long, device=policy_logits.device
                )
                legal_probs = F.softmax(policy_logits[0, indices] / temperature, dim=-1).tolist()

            # Select move (top-3 sampling for more variety)
            move_probs = sorted(
                zip((move for move, _ in candidates), legal_probs),
                key=lambda pair: pair[1],
                reverse=True
            )
            top_moves = move_probs[:3]
            selected_move = random.choices(
                [m[0] for m in top_moves],
                weights=[m[1] for m in top_moves]
            )[0]

            env.board.push(selected_move)
            env.move_history.append(selected_move)
            board_state = env.board_to_tensor()
            move_count += 1

            think_time = time.time() - start_time
            print(f"üéØ AI plays: {selected_move} (thought for {think_time:.2f}s)")
            print(f"üìä Position evaluation: {value.item():.3f}")

    # Game over (or the player quit, in which case the result is "*")
    print(f"\n{env.board}")
    result = env.board.result()
    if result == "1-0":
        print("üéâ You won! Congratulations!")
    elif result == "0-1":
        print("ü§ñ AI won! Better luck next time!")
    elif result == "1/2-1/2":
        print("ü§ù It's a draw!")

    print(f"Final result: {result}")

class ChessAnalyzer:
    """Position analysis front-end built on a trained chess transformer.

    Loads model weights from a checkpoint and exposes helpers that report the
    model's value estimate, predicted game phase and highest-scoring legal
    moves for a position.
    """

    def __init__(self, model_path, config=None):
        """Restore the model from ``model_path`` and create a fresh environment.

        Args:
            model_path: Checkpoint file containing a ``'model_state_dict'`` entry.
            config: Model configuration; a default ``TrainingConfig`` is
                created when omitted.
        """
        config = TrainingConfig() if config is None else config

        self.model = AdvancedChessTransformer(config).to(device)
        saved = torch.load(model_path, map_location=device)
        self.model.load_state_dict(saved['model_state_dict'])
        self.model.eval()

        self.env = ChessEnvironment()

    def analyze_position(self, fen_string=None):
        """Run the model on a position and summarise its outputs.

        Args:
            fen_string: Optional FEN to load into the environment's board
                first; when falsy, the board's current position is analysed.

        Returns:
            Dict with keys ``'position_value'``, ``'predicted_phase'``,
            ``'phase_probabilities'``, ``'top_moves'`` (up to five
            ``(uci, policy_logit)`` pairs, best first) and ``'turn'``.
        """
        board = self.env.board
        if fen_string:
            board.set_fen(fen_string)

        encoded = self.env.board_to_tensor()
        # NOTE(review): loading a FEN does not appear to touch
        # env.move_history, so this counter likely ignores the FEN's move
        # number -- confirm.
        plies = len(self.env.move_history)

        with torch.no_grad():
            policy_logits, value, _piece_logits, phase_logits = self.model(
                encoded.unsqueeze(0),
                torch.tensor([plies]),
                return_aux=True
            )

            # Score each legal move by its raw policy logit. Under-promotions
            # are skipped: they share a from/to index with the queen promotion.
            ranked = [
                (mv, policy_logits[0, mv.from_square * NUM_SQUARES + mv.to_square].item())
                for mv in list(board.legal_moves)
                if not mv.promotion or mv.promotion == chess.QUEEN
            ]
            ranked.sort(key=lambda pair: pair[1], reverse=True)

            # Game phase prediction from the auxiliary head
            phase_names = ['Opening', 'Middlegame', 'Endgame']
            phase_probs = F.softmax(phase_logits, dim=-1)[0]
            likeliest = phase_probs.argmax().item()

            return {
                'position_value': value.item(),
                'predicted_phase': phase_names[likeliest],
                'phase_probabilities': {
                    name: p.item() for name, p in zip(phase_names, phase_probs)
                },
                'top_moves': [(str(mv), logit) for mv, logit in ranked[:5]],
                'turn': 'White' if board.turn else 'Black'
            }

    def compare_positions(self, fen1, fen2):
        """Analyse two FEN positions and report which has the higher value.

        Returns:
            Dict holding both analyses, ``'value_difference'`` (second minus
            first) and ``'better_for'`` naming the position with the larger
            value, or ``'Equal'`` when the values match.
        """
        first = self.analyze_position(fen1)
        second = self.analyze_position(fen2)

        delta = second['position_value'] - first['position_value']

        if delta > 0:
            verdict = 'Position 2'
        elif delta < 0:
            verdict = 'Position 1'
        else:
            verdict = 'Equal'

        return {
            'position1': first,
            'position2': second,
            'value_difference': delta,
            'better_for': verdict
        }

def benchmark_model(model_path, num_games=100):
    """Benchmark the model against random and basic players.

    The model always plays White. ``num_games`` games are played against a
    uniformly random mover, then ``num_games`` against a greedy mover that
    prefers captures, then checks. Every game is cut off after 200 plies.
    A game whose ``board.result()`` is neither ``"1-0"`` nor ``"0-1"``
    (which includes cut-off, unfinished games) is counted as a draw.

    The checkpoint is loaded into a model built from the default
    ``TrainingConfig()``.

    Args:
        model_path: Checkpoint file containing a ``'model_state_dict'`` entry.
        num_games: Number of games to play against each opponent.

    Returns:
        Dict keyed by ``'vs_random'`` and ``'vs_greedy'``; each value is
        ``{'wins': int, 'losses': int, 'draws': int}`` from the model's side.
    """
    max_plies = 200  # hard cap on game length

    config = TrainingConfig()
    model = AdvancedChessTransformer(config).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # One environment reused purely as a position encoder.
    # NOTE(review): assumes board_to_tensor() encodes whatever position
    # env.board currently holds, as ChessAnalyzer.analyze_position relies on.
    encoder_env = ChessEnvironment()

    def random_player(board):
        moves = list(board.legal_moves)
        return random.choice(moves) if moves else None

    def greedy_player(board):
        """Simple greedy player that captures pieces when possible"""
        moves = list(board.legal_moves)
        if not moves:
            return None

        # Prioritize captures
        captures = [m for m in moves if board.is_capture(m)]
        if captures:
            return random.choice(captures)

        # Then checks
        checks = [m for m in moves if board.gives_check(m)]
        if checks:
            return random.choice(checks)

        return random.choice(moves)

    def model_player(board, move_count):
        """Return the legal move with the highest policy logit, or None."""
        # Encode the *current* position. Seeding from a fresh starting-position
        # tensor and overwriting only occupied squares would leave stale piece
        # tokens on every square vacated during the game.
        encoder_env.board.set_fen(board.fen())
        board_state = encoder_env.board_to_tensor()

        with torch.no_grad():
            policy_logits, _ = model(board_state.unsqueeze(0), torch.tensor([move_count]))

            best_score = -float('inf')
            best_move = None

            for move in board.legal_moves:
                # Under-promotions share a from/to index with the queen
                # promotion, so only the queen promotion is considered.
                if move.promotion and move.promotion != chess.QUEEN:
                    continue
                idx = move.from_square * NUM_SQUARES + move.to_square
                score = policy_logits[0, idx].item()
                if score > best_score:
                    best_score = score
                    best_move = move

            return best_move

    def play_series(opponent, label, tally):
        """Play ``num_games`` games (model as White) and update ``tally``."""
        for i in range(num_games):
            board = chess.Board()
            move_count = 0

            while not board.is_game_over() and move_count < max_plies:
                if board.turn == chess.WHITE:  # Model plays white
                    move = model_player(board, move_count)
                else:
                    move = opponent(board)

                if move is None:
                    break
                board.push(move)
                move_count += 1

            result = board.result()
            if result == "1-0":
                tally['wins'] += 1
            elif result == "0-1":
                tally['losses'] += 1
            else:
                tally['draws'] += 1

            if (i + 1) % 20 == 0:
                print(f"Progress: {i + 1}/{num_games} games vs {label}")

    def percent(count, total):
        """Share of ``total`` as a percentage; 0 when no games were played."""
        return count / total * 100 if total > 0 else 0

    results = {'vs_random': {'wins': 0, 'losses': 0, 'draws': 0},
               'vs_greedy': {'wins': 0, 'losses': 0, 'draws': 0}}

    print(f"üéØ Benchmarking model against {num_games} games each...")

    play_series(random_player, "random", results['vs_random'])
    play_series(greedy_player, "greedy", results['vs_greedy'])

    # Print results
    print("\nüìä BENCHMARK RESULTS:")
    print("-" * 50)

    for opponent, stats in results.items():
        total = sum(stats.values())
        win_rate = percent(stats['wins'], total)
        print(f"\n{opponent}:")
        print(f"  Wins: {stats['wins']} ({percent(stats['wins'], total):.1f}%)")
        print(f"  Losses: {stats['losses']} ({percent(stats['losses'], total):.1f}%)")
        print(f"  Draws: {stats['draws']} ({percent(stats['draws'], total):.1f}%)")
        print(f"  Win Rate: {win_rate:.1f}%")

    return results

def create_training_visualization():
    """Return a callback that plots training statistics on a 2x2 grid.

    matplotlib is imported when this factory is called, so an
    ``ImportError`` surfaces here rather than at module import time.

    The returned ``plot_training_stats(stats)`` reads ``stats['total_rewards']``,
    ``stats['policy_losses']``, ``stats['value_losses']`` and
    ``stats['aux_losses']``, saves the figure to ``training_progress.png``
    and then shows it.
    """
    import matplotlib.pyplot as plt

    # (grid cell, stats key, title, x label, y label) for each panel
    panels = (
        ((0, 0), 'total_rewards', 'Game Rewards Over Time', 'Game', 'Total Reward'),
        ((0, 1), 'policy_losses', 'Policy Loss Over Time', 'Training Step', 'Policy Loss'),
        ((1, 0), 'value_losses', 'Value Loss Over Time', 'Training Step', 'Value Loss'),
        ((1, 1), 'aux_losses', 'Auxiliary Loss Over Time', 'Training Step', 'Aux Loss'),
    )

    def plot_training_stats(stats):
        _fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        for (row, col), key, title, x_label, y_label in panels:
            panel = axes[row, col]
            panel.plot(stats[key])
            panel.set_title(title)
            panel.set_xlabel(x_label)
            panel.set_ylabel(y_label)

        plt.tight_layout()
        plt.savefig('training_progress.png', dpi=300, bbox_inches='tight')
        plt.show()

    return plot_training_stats

# 🚀 MAIN EXECUTION EXAMPLES
def _prompt_int(prompt, default):
    """Read an integer from stdin; blank or non-numeric input yields ``default``."""
    raw = input(prompt).strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        print(f"Invalid number {raw!r}; using default {default}.")
        return default


def main():
    """Main execution with different modes.

    Prints a banner, then prompts for one of four modes:

    * ``train``     - optionally resume from a checkpoint, train, then plot
                      the trainer's statistics if matplotlib is installed.
    * ``play``      - play an interactive game against a saved model.
    * ``analyze``   - repeatedly analyse FEN positions with a saved model
                      (blank input analyses the starting position).
    * ``benchmark`` - pit a saved model against the built-in opponents.
    """

    # Configuration for publication-grade model
    config = TrainingConfig(
        d_model=384,      # Larger model
        nhead=24,         # More attention heads
        num_layers=12,    # Deeper network
        learning_rate=1e-4,
        batch_size=64,
        memory_size=200000,
        exploration_noise=0.2,
        temperature=0.8
    )

    print("üß† Advanced Chess Transformer - Publication Grade")
    print("=" * 60)
    # A throwaway model instance is built only to count its parameters.
    print(f"Model parameters: ~{sum(p.numel() for p in AdvancedChessTransformer(config).parameters()):,}")
    print("Key innovations:")
    print("‚Ä¢ Multi-scale attention (local + global patterns)")
    print("‚Ä¢ Chess-specific positional encoding")
    print("‚Ä¢ Game phase awareness")
    print("‚Ä¢ Prioritized experience replay")
    print("‚Ä¢ Self-supervised auxiliary tasks")
    print("‚Ä¢ TD-lambda returns")
    print("‚Ä¢ Advanced exploration strategies")

    # Strip so that stray whitespace around the mode name is tolerated.
    mode = input("\nSelect mode (train/play/analyze/benchmark): ").strip().lower()

    if mode == 'train':
        trainer = AdvancedTrainer(config)

        # Optional: Load from checkpoint
        checkpoint_path = input("Load from checkpoint? (path or Enter to skip): ").strip()
        if checkpoint_path:
            try:
                trainer.load_checkpoint(checkpoint_path)
                print("‚úÖ Checkpoint loaded successfully!")
            except Exception as e:
                # Catch Exception rather than everything, so Ctrl-C and
                # SystemExit still propagate; report why loading failed.
                print("‚ùå Could not load checkpoint, starting fresh.")
                print(f"   Reason: {e}")

        num_games = _prompt_int("Number of training games (default 2000): ", 2000)

        print(f"\nüöÄ Starting training for {num_games} games...")
        trainer.train(num_games=num_games, save_interval=200)

        print("‚úÖ Training completed!")

        # Optionally plot results (if matplotlib available)
        try:
            plot_fn = create_training_visualization()
            plot_fn(trainer.stats)
        except ImportError:
            print("Install matplotlib for training visualizations")

    elif mode == 'play':
        model_path = input("Model path: ").strip()
        if not model_path:
            print("‚ùå Model path required!")
            return

        try:
            play_against_human(model_path, config)
        except Exception as e:
            print(f"‚ùå Error loading model: {e}")

    elif mode == 'analyze':
        model_path = input("Model path: ").strip()
        if not model_path:
            print("‚ùå Model path required!")
            return

        try:
            analyzer = ChessAnalyzer(model_path, config)

            while True:
                fen = input("\nEnter FEN string (or 'quit'): ").strip()
                if fen.lower() == 'quit':
                    break

                if not fen:
                    # Use starting position
                    fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"

                try:
                    analysis = analyzer.analyze_position(fen)
                except ValueError as e:
                    # A malformed FEN should not end the whole session.
                    print(f"‚ùå Error: {e}")
                    continue

                print(f"\nüìä POSITION ANALYSIS:")
                print(f"Turn: {analysis['turn']}")
                print(f"Evaluation: {analysis['position_value']:.3f}")
                print(f"Game Phase: {analysis['predicted_phase']}")
                print(f"\nTop 5 moves:")
                for i, (move, score) in enumerate(analysis['top_moves'], 1):
                    print(f"  {i}. {move} ({score:.3f})")

        except Exception as e:
            print(f"‚ùå Error: {e}")

    elif mode == 'benchmark':
        model_path = input("Model path: ").strip()
        if not model_path:
            print("‚ùå Model path required!")
            return

        num_games = _prompt_int("Games per opponent (default 50): ", 50)

        try:
            benchmark_model(model_path, num_games)
            print("\nüèÜ Benchmarking completed!")
        except Exception as e:
            print(f"‚ùå Error: {e}")

    else:
        print("‚ùå Invalid mode! Choose: train/play/analyze/benchmark")

if __name__ == "__main__":
    # Script entry point. Nothing runs by default: the interactive entry
    # point below is disabled. Uncomment the call to enable it.
    # main()

    # Quick training example, kept inert inside a bare string literal: the
    # string is evaluated and discarded, so none of the code in it executes.
    """
    # Quick test with smaller model
    config = TrainingConfig(d_model=128, num_layers=4, nhead=8)
    trainer = AdvancedTrainer(config)
    model = trainer.train(num_games=100)
    trainer.save_checkpoint("quick_test_model.pth")
    """

    # No-op placeholder statement.
    pass