# Curriculum Learning

In [1]:
import sys
sys.path.insert(0, '../../src/')

import numpy as np
import matplotlib.pyplot as plt
import pickle
import config
import torch
from tqdm.notebook import tqdm
from copy import copy, deepcopy
import cmath
import chess

from basis_gates import *
from agents import *
from environments import *
from models import *
%matplotlib inline

np.set_printoptions(precision = 3)

print("PyTorch version:", torch.__version__)  
print("CUDA toolkit version PyTorch was built with:", torch.version.cuda)  
print("cuDNN version:", torch.backends.cudnn.version()) 
print("cuda available:", torch.cuda.is_available())

torch.set_float32_matmul_precision('high')

PyTorch version: 2.8.0+cu126
CUDA toolkit version PyTorch was built with: 12.6
cuDNN version: 91002
cuda available: True


In [23]:
torch.manual_seed(42)
random.seed(42)
logger = Logger(sample_freq = 1000)

agent = Agent(board_logic = BoardLogic())


opt_list = [torch.optim.Adam(agent.online_net1.parameters(), lr=1e-5), 
            torch.optim.Adam(agent.online_net2.parameters(), lr=1e-5)]

model = Model(agent = agent,
               environment = Environment(max_num_moves=100,),
               mem_capacity = 100000,
               batch_size = 1024,
               policy_update = 10,
               target_update = 5000,
               temp_constants = (0.5, 0.1, 0, 10000),
               opt_list=opt_list,
               scaler=torch.amp.GradScaler("cuda")
             )

In [540]:
model.train(num_episodes = 3000, logger = logger)

  0%|          | 0/3000 [00:00<?, ?it/s]

tensor(0.0056, device='cuda:0')
2 checkmate! 23 0.09811788258374475
10 checkmate! 35 0.09010676482243644
20 checkmate! 83 0.04902411062806558
25 checkmate! 41 0.05884949186217802
26 checkmate! 45 0.00503352553249522
31 checkmate! 74 0.09753862663359746
38 checkmate! 24 0.08789215552479687
39 checkmate! 80 0.05498695091334488
40 checkmate! 64 0.01937189490248419
42 checkmate! 79 0.07567316450330473
43 checkmate! 88 0.030985601500966853
44 checkmate! 22 0.06101689756936271
47 checkmate! 88 0.017650724512246797
50 checkmate! 76 0.018904276756862973
53 checkmate! 29 0.04953726894942109
54 checkmate! 72 0.02790332383855837
60 checkmate! 18 0.07839058967354462
68 checkmate! 71 0.003675422254480554
70 checkmate! 94 0.09969964322071467
71 checkmate! 35 0.04996718397642284
72 checkmate! 79 0.094337809458044
82 checkmate! 29 0.007418938206176695
83 checkmate! 48 0.09772775118998303
89 checkmate! 91 0.010440992556349971
90 checkmate! 83 0.09324646852369133
tensor(0.0056, device='cuda:0')
101 chec

In [541]:
saver(model, 'dqn_model.pth')

In [578]:
board = chess.Board()

mirror = False
#random.seed(42)
#np.random.seed(42)
#torch.manual_seed(42)

In [592]:
state = agent.board_to_state([board])
action = agent.select_action(board, eps=0.1, greedy=False)

Q1 = agent.online_net1(state).detach()
Q2 = agent.online_net2(state).detach()
mask_legal = agent.get_mask_legal([board])

Q1_legal = Q1[mask_legal]
Q2_legal = Q2[mask_legal]

diff = torch.abs(Q1_legal - Q2_legal)/torch.max(torch.abs(Q1_legal), torch.abs(Q2_legal))

#print(Q1_legal)
#print(Q2_legal)

print(f"{np.mean(diff.cpu().numpy()):.4f}")

Q_legal = Q1.masked_fill(~mask_legal, -1e9)
action_star = torch.argmax(Q_legal, dim=1).to(config.device)
score = Q2[0,action_star[0]]

if not mirror:
    print("White:")
else:
    print("Black:")
print(f"score: {score.item():.4f}")

move = agent.action_to_move(action)[0]

board.push(move)

board = board.mirror()
mirror = not mirror
if mirror:
    print(board.mirror())
else:
    print(board)

if board.is_checkmate():
    print("checkmate!")

if board.can_claim_threefold_repetition():
    print("draw!")

0.1631
Black:
score: 0.2512
. n b q k . n r
. p p p b p p p
. . . . p . . .
P . . . . . . .
. . P . . . P .
B . . . . . . .
P . . P P P . P
R N . . K B N R


In [276]:
state, action, next_state, mask_legal, reward, done =model.memory.sample()

In [295]:
model.memory.rewards[:1000]

tensor([[-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-