# Curriculum Learning

In [1]:
import sys
sys.path.insert(0, '../../src/')

import numpy as np
import matplotlib.pyplot as plt
import pickle
import config
import torch
from tqdm.notebook import tqdm
from copy import copy, deepcopy
import cmath
import chess
from utils import saver, loader

from agents import *
from environments import *
from models import *
%matplotlib inline

np.set_printoptions(precision = 3)

print("PyTorch version:", torch.__version__)  
print("CUDA toolkit version PyTorch was built with:", torch.version.cuda)  
print("cuDNN version:", torch.backends.cudnn.version()) 
print("cuda available:", torch.cuda.is_available())

torch.set_float32_matmul_precision('high')

PyTorch version: 2.7.1+cu128
CUDA toolkit version PyTorch was built with: 12.8
cuDNN version: 90701
cuda available: True


In [2]:
torch.manual_seed(42)
random.seed(42)
logger = Logger(sample_freq = 1000)

agent = Agent(board_logic = BoardLogic())


opt_list = [torch.optim.Adam(agent.online_net1.parameters(), lr=1e-4), 
            torch.optim.Adam(agent.online_net2.parameters(), lr=1e-4)]

model = Model(agent = agent,
               environment = Environment(max_num_moves=200,),
               mem_capacity = 100000,
               batch_size = 1024,
               policy_update = 10,
               target_update = 10000,
               temp_constants = (0.5, 0.15, 0.05, 2000),
               opt_list=opt_list,
               scaler=torch.amp.GradScaler("cuda")
             )

In [3]:
model.train(num_episodes = 10000, logger = logger)

  0%|          | 0/10000 [00:00<?, ?it/s]

0 checkmate! 135 0.33774205930604767
0
1 checkmate! 74 0.17526603526285234
2 checkmate! 62 0.32930425923246914
3 checkmate! 160 0.09858628167746916
4 checkmate! 43 0.0877175947493867
5 checkmate! 175 0.15906972803373948
6 draw! 199 0.49361612147626854
7 checkmate! 43 0.280328364426065
8 draw! 199 0.0730787114756079
9 checkmate! 54 0.35329804338239285
10 draw! 199 0.3720874949329221
11 checkmate! 116 0.4061735756067923
12 draw! 199 0.3891869438403617
13 draw! 199 0.41499186308210445
14 checkmate! 62 0.47599686494088966
15 checkmate! 116 0.36079303008611024
16 draw! 199 0.4001518318103631
17 checkmate! 68 0.2000884048484619
18 checkmate! 56 0.3135192607433652
19 checkmate! 115 0.43335202180085675
20 checkmate! 91 0.11732257841403473
21 checkmate! 95 0.26453233425845696
22 checkmate! 174 0.41821249846454617
23 checkmate! 43 0.33605850936214926
24 checkmate! 16 0.16175653868450693
25 draw! 199 0.40010978743875314
26 draw! 199 0.05207198910013093
27 checkmate! 161 0.24674934756559141
28 dra

KeyboardInterrupt: 

In [10]:
saver(model,"model_10000_episodes.pth")

In [9]:
board = chess.Board()

mirror = False
#random.seed(42)
#np.random.seed(42)
#torch.manual_seed(42)

In [61]:
state = agent.board_to_state([board])
action = agent.select_action(board, eps=0., greedy=False)

Q1 = agent.online_net1(state).detach()
Q2 = agent.online_net2(state).detach()
mask_legal = agent.get_mask_legal([board])

Q1_legal = Q1[mask_legal]
Q2_legal = Q2[mask_legal]

diff = torch.abs(Q1_legal - Q2_legal)/torch.max(torch.abs(Q1_legal), torch.abs(Q2_legal))

#print(Q1_legal)
#print(Q2_legal)

print(f"{np.mean(diff.cpu().numpy()):.4f}")

Q_legal = Q1.masked_fill(~mask_legal, -1e9)
action_star = torch.argmax(Q_legal, dim=1).to(config.device)
score = Q2[0,action_star[0]]

if not mirror:
    print("White:")
else:
    print("Black:")
print(f"score: {score.item():.4f}")

move = agent.action_to_move(action)[0]

board.push(move)

board = board.mirror()
mirror = not mirror
if mirror:
    print(board.mirror())
else:
    print(board)

if board.is_checkmate():
    print("checkmate!")

if board.can_claim_threefold_repetition():
    print("draw!")

0.6977
Black:
score: 0.0188
r . . q k . n r
. . p p p . . .
n p . . . p . b
p . B . . . N p
B . . . P P P .
P P N P . . . P
. P Q . . . . .
R . . . K . R .


In [373]:
state, action, next_state, mask_legal, reward, done =model.memory.sample()

In [None]:
model.memory.rewards[:1000]

tensor([[-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-