In [1]:
import chess
import numpy as np
import time
from typing import Tuple, Dict, List, Optional, Union
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
from collections import deque
from typing import List, Tuple, Dict
import logging
import math

In [2]:
from environment import GameState, BulletChessEnv
from agent_dqn import ChessQNetwork, BulletExperienceReplay, BulletChessDQNAgent
from agent_policy_value import ChessPolicyValueNetwork, BulletChessAlphaZeroAgent, MCTSNode, AdaptiveMCTS
from utils import get_time_pressure_level

In [3]:
def train_self_play(agent, env, episodes, max_steps_per_episode, target_update_freq):
    """Play a fixed number of episodes, training the agent after every step.

    Each episode resets ``env`` and then repeatedly asks ``agent`` for an
    action, applies it, stores the resulting transition and runs one training
    step. An episode ends when the environment reports ``done`` or after
    ``max_steps_per_episode`` steps, whichever comes first.

    Args:
        agent: Object providing ``select_action(state)``,
            ``store_experience(state, action, reward, next_state, done,
            time_pressure)``, ``train_step()`` and ``update_target_network()``.
        env: Object providing ``reset()`` and ``step(action)``; ``step`` must
            return ``(next_state, reward, done, info)`` where ``info``
            supports ``.get``.
        episodes: Number of episodes to play.
        max_steps_per_episode: Upper bound on the steps taken in one episode.
        target_update_freq: The target network is refreshed after every
            episode whose zero-based index is a multiple of this value, so
            the very first episode always triggers a refresh.

    Returns:
        None. One summary line per episode is printed.
    """
    for episode_index in range(episodes):
        observation = env.reset()
        episode_return = 0

        for _ in range(max_steps_per_episode):
            chosen_action = agent.select_action(observation)
            transition = env.step(chosen_action)
            following_observation, step_reward, finished, step_info = transition

            # Read the time pressure from info when present; default to 1.0.
            pressure = step_info.get("time_pressure", 1.0)

            agent.store_experience(
                observation,
                chosen_action,
                step_reward,
                following_observation,
                finished,
                pressure,
            )
            agent.train_step()

            observation = following_observation
            episode_return += step_reward

            if finished:
                break

        # Zero-based index check: episode 0 always synchronises the target network.
        is_sync_episode = episode_index % target_update_freq == 0
        if is_sync_episode:
            agent.update_target_network()

        print(f"Episode {episode_index+1}: Total Reward = {episode_return:.2f}")



In [4]:
# Initialize environment and agent
env = BulletChessEnv()
agent = BulletChessDQNAgent()

num_episodes = 50
# Keep this <= num_episodes, otherwise no intermediate checkpoint is ever written.
checkpoint_every = 10

for episode in range(num_episodes):
    state = env.reset()
    done = False
    total_reward = 0
    # NOTE(review): the pressure level is fixed for the whole run;
    # utils.get_time_pressure_level is imported but not used here — confirm
    # whether it should be derived from the remaining clock time instead.
    time_pressure_level = "moderate"

    while not done:
        # The set of legal moves changes with every move, so it must be
        # queried for the current position on each step rather than once
        # per episode; otherwise the agent selects from a stale move list.
        legal_actions = env.get_legal_actions()
        action = agent.select_action(state, legal_actions, time_pressure_level)
        next_state, reward, done, _ = env.step(action)
        agent.store_experience(state, action, reward, next_state, done, time_pressure_level)
        agent.train_step()

        state = next_state
        total_reward += reward

    print(f"Episode {episode + 1}: Total Reward = {total_reward}")

    # Periodic checkpoint so a long run can be resumed or inspected.
    if (episode + 1) % checkpoint_every == 0:
        agent.save_model(f"checkpoint_{episode + 1}.pth")

# Persist only the online Q-network weights at the end of training.
torch.save(agent.q_network.state_dict(), "chess_dqn_model.pth")


Episode 1: Total Reward = -0.999
Episode 2: Total Reward = -0.999
Episode 3: Total Reward = -0.999
Episode 4: Total Reward = -0.999
Episode 5: Total Reward = -0.999
Episode 6: Total Reward = -0.9990000334103902
Episode 7: Total Reward = -0.999
Episode 8: Total Reward = -0.999
Episode 9: Total Reward = -0.999
Episode 10: Total Reward = -0.999
Episode 11: Total Reward = -0.999
Episode 12: Total Reward = -0.999
Episode 13: Total Reward = -0.999
Episode 14: Total Reward = -0.999
Episode 15: Total Reward = -0.999
Episode 16: Total Reward = -0.999
Episode 17: Total Reward = -0.9990000333825747
Episode 18: Total Reward = -0.999
Episode 19: Total Reward = -0.999
Episode 20: Total Reward = -0.999
Episode 21: Total Reward = -0.999
Episode 22: Total Reward = -0.999
Episode 23: Total Reward = -0.9990002008120219
Episode 24: Total Reward = -0.999
Episode 25: Total Reward = -0.9990000334938367
Episode 26: Total Reward = -0.999
Episode 27: Total Reward = -0.999
Episode 28: Total Reward = -0.999
Episo

In [None]:
# Train the AlphaZero-style agent by letting it play full games in the environment.
env = BulletChessEnv()
agent = BulletChessAlphaZeroAgent()
num_episodes = 50
# Keep this <= num_episodes, otherwise no intermediate checkpoint is ever written.
checkpoint_every = 10

for episode in range(num_episodes):
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        # Use White's clock when board.turn is truthy, otherwise Black's.
        time_remaining = env.game_state.white_time if env.game_state.board.turn else env.game_state.black_time
        action = agent.select_action(env, time_remaining)
        # The agent returns an action index; the environment is stepped with the UCI move.
        uci_action = agent.action_index_to_uci[action]
        next_state, reward, done, _ = env.step(uci_action)
        # NOTE(review): this trains after every single move, and nothing in
        # this cell adds samples to the agent's memory — confirm that
        # select_action records training data, otherwise this call has
        # nothing new to learn from.
        agent.train_network(epochs=10, batch_size=32)

        state = next_state
        total_reward += reward

    print(f"Episode {episode + 1}: Total Reward = {total_reward}")

    # Periodic checkpoint so a long run can be resumed or inspected.
    if (episode + 1) % checkpoint_every == 0:
        agent.save_model(f"checkpoint_{episode + 1}.pth")

# Persist the final network weights.
torch.save(agent.model.state_dict(), "chess_alpha_zero_model.pth")


Episode 1: Total Reward = -1
Episode 2: Total Reward = -1
Episode 3: Total Reward = -1
Episode 4: Total Reward = -1
Episode 5: Total Reward = -1
Episode 6: Total Reward = -1
Episode 7: Total Reward = -1


In [None]:
# Self-play training: collect (position, search policy, side to move) for each
# move of a game, label every position with the game's final reward, then train.
env = BulletChessEnv()
agent = BulletChessAlphaZeroAgent()
num_episodes = 50

for episode in range(num_episodes):
    state = env.reset()
    total_reward = 0
    done = False
    # One (board tensor, MCTS policy, side to move) record per visited position.
    episode_data = []

    while not done:
        # Use White's clock when board.turn is truthy, otherwise Black's.
        if env.game_state.board.turn:
            time_remaining = env.game_state.white_time
        else:
            time_remaining = env.game_state.black_time

        # The search returns both the chosen move index and its policy distribution.
        action_index, policy_probs = agent.select_action(env, time_remaining, return_probs=True)
        uci_action = agent.action_index_to_uci[action_index]

        # Record the position *before* the move is played; side to move is +1 / -1.
        board_tensor = agent.board_to_tensor(env.game_state.board)
        current_player = 1 if env.game_state.board.turn else -1
        episode_data.append((board_tensor, policy_probs, current_player))

        next_state, reward, done, _ = env.step(uci_action)
        total_reward += reward
        state = next_state

    # Label every stored position with the reward of the final step, signed
    # by the side that was to move in that position.
    # NOTE(review): this is only correct if the environment reports the final
    # reward from the +1 side's point of view — confirm against BulletChessEnv.
    for board_tensor, policy_probs, current_player in episode_data:
        final_value = reward * current_player
        agent.memory.append((board_tensor, policy_probs, final_value))

    # One round of network training per finished game.
    agent.train_network(epochs=5, batch_size=32)

    episode_number = episode + 1
    print(f"Episode {episode_number}: Total Reward = {total_reward}")

    # Write a checkpoint after every tenth episode.
    if episode_number % 10 == 0:
        agent.save_model(f"checkpoint_{episode_number}.pth")

# Save final model
torch.save(agent.model.state_dict(), "chess_alpha_zero_model.pth")
