In [1]:
import chess
import numpy as np
import time
from typing import Tuple, Dict, List, Optional, Union
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
from collections import deque
from typing import List, Tuple, Dict
import logging
import math

In [2]:
from environment import GameState, BulletChessEnv
from agent_dqn import ChessQNetwork, BulletExperienceReplay, BulletChessDQNAgent
from agent_policy_value import ChessPolicyValueNetwork, BulletChessAlphaZeroAgent, MCTSNode, AdaptiveMCTS
from utils import get_time_pressure_level

In [3]:
def train_self_play(agent, env, episodes, max_steps_per_episode, target_update_freq):
    """Train ``agent`` by letting it interact with ``env`` for a number of episodes.

    Every episode starts from ``env.reset()`` and then repeats, for at most
    ``max_steps_per_episode`` steps: pick an action, step the environment,
    store the transition and call ``agent.train_step()``. An episode stops
    early as soon as ``env.step`` reports ``done``.

    After each episode whose zero-based index is a multiple of
    ``target_update_freq`` (this includes the very first episode),
    ``agent.update_target_network()`` is called. One summary line per episode
    is printed.

    Args:
        agent: Object providing ``select_action(state)``,
            ``store_experience(state, action, reward, next_state, done,
            time_pressure)``, ``train_step()`` and ``update_target_network()``.
        env: Object providing ``reset()`` and ``step(action)``; ``step`` must
            return ``(next_state, reward, done, info)`` where ``info``
            supports ``.get``.
        episodes: Number of episodes to run.
        max_steps_per_episode: Upper bound on steps taken in one episode.
        target_update_freq: Episode interval for the target-network update;
            must be non-zero because it is used as a modulus.

    Returns:
        None.
    """

    def run_episode():
        """Play a single episode and return the sum of its rewards."""
        current = env.reset()
        episode_return = 0

        for _ in range(max_steps_per_episode):
            chosen = agent.select_action(current)
            following, reward, done, info = env.step(chosen)

            # Take the time pressure from info when it is provided; fall back to 1.0.
            pressure = info.get("time_pressure", 1.0)

            agent.store_experience(current, chosen, reward, following, done, pressure)
            agent.train_step()

            current = following
            episode_return += reward

            if done:
                return episode_return

        return episode_return

    for episode_idx in range(episodes):
        episode_return = run_episode()

        # Fires on episode indices 0, freq, 2*freq, ...
        if not episode_idx % target_update_freq:
            agent.update_target_network()

        print(f"Episode {episode_idx+1}: Total Reward = {episode_return:.2f}")



In [None]:
# Train a DQN agent on the bullet-chess environment with a plain episode loop,
# saving a checkpoint at a fixed episode interval and the final Q-network
# weights at the end.
env = BulletChessEnv()
agent = BulletChessDQNAgent()

num_episodes = 1000
checkpoint_every = 100  # episodes between agent.save() checkpoints

# Constant label passed to the agent on every move.
# NOTE(review): the `info` dict returned by env.step() is discarded below, so
# the label never reflects the actual clock; train_self_play() reads
# info["time_pressure"] instead -- confirm which one is intended.
time_pressure_level = "moderate"

for episode in range(num_episodes):
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        # Re-query the legal actions before every move. Fetching them only
        # once after reset() would keep offering the starting position's
        # actions for the rest of the game.
        legal_actions = env.get_legal_actions()

        action = agent.select_action(state, legal_actions, time_pressure_level)
        next_state, reward, done, _ = env.step(action)
        agent.store_experience(state, action, reward, next_state, done, time_pressure_level)
        agent.train_step()

        state = next_state
        total_reward += reward

    # NOTE(review): unlike train_self_play(), this loop never calls
    # agent.update_target_network() -- verify that train_step() syncs the
    # target network itself.
    print(f"Episode {episode + 1}: Total Reward = {total_reward}")

    if (episode + 1) % checkpoint_every == 0:
        agent.save(f"models/checkpoint_{episode + 1}.pth")

# File name follows num_episodes so it stays truthful if the run length changes.
torch.save(agent.q_network.state_dict(), f"models/chess_dqn_model_{num_episodes}ep.pth")


Episode 1: Total Reward = -0.9990003026564916
Episode 2: Total Reward = -0.999000552535057
Episode 3: Total Reward = -0.999
Episode 4: Total Reward = -0.9990002816836039
Episode 5: Total Reward = -0.999
Episode 6: Total Reward = -0.999
Episode 7: Total Reward = -0.999
Episode 8: Total Reward = -0.999
Episode 9: Total Reward = -0.999
Episode 10: Total Reward = -0.999
Episode 11: Total Reward = -0.9990000334858894
Episode 12: Total Reward = -0.999
Episode 13: Total Reward = -0.999
Episode 14: Total Reward = -0.999000151360035
Episode 15: Total Reward = -0.999
Episode 16: Total Reward = -0.999
Episode 17: Total Reward = -0.999
Episode 18: Total Reward = -0.9990004037419955
Episode 19: Total Reward = -0.999
Episode 20: Total Reward = -0.9990000334183375
Episode 21: Total Reward = -0.999
Episode 22: Total Reward = -0.9990000334938367
Episode 23: Total Reward = -0.999000263941288
Episode 24: Total Reward = -0.999
Episode 25: Total Reward = -0.9990002634127935
Episode 26: Total Reward = -0.99