# Curriculum Learning

In [1]:
import sys
sys.path.insert(0, '../../src/')

import numpy as np
import matplotlib.pyplot as plt
import pickle
import config
import torch
from tqdm.notebook import tqdm
from copy import copy, deepcopy
import cmath
import chess

from basis_gates import *
from agents import *
from environments import *
from models import *
%matplotlib inline

np.set_printoptions(precision = 3)

print("PyTorch version:", torch.__version__)  
print("CUDA toolkit version PyTorch was built with:", torch.version.cuda)  
print("cuDNN version:", torch.backends.cudnn.version()) 
print("cuda available:", torch.cuda.is_available())

torch.set_float32_matmul_precision('high')

PyTorch version: 2.8.0+cu126
CUDA toolkit version PyTorch was built with: 12.6
cuDNN version: 91002
cuda available: True


In [None]:
torch.manual_seed(42)
random.seed(42)
logger = Logger(sample_freq = 1000)

agent = Agent(board_logic = BoardLogic())


opt_list = [torch.optim.Adam(agent.online_net1.parameters(), lr=1e-5), 
            torch.optim.Adam(agent.online_net2.parameters(), lr=1e-5)]

model = Model(agent = agent,
               environment = Environment(max_num_moves=100,),
               mem_capacity = 100000,
               batch_size = 1024,
               policy_update = 10,
               target_update = 5000,
               temp_constants = (0.5, 0.1, 0, 10000),
               opt_list=opt_list,
               scaler=torch.amp.GradScaler("cuda")
             )

In [None]:
model.train(num_episodes = 5000, logger = logger)

  0%|          | 0/10000 [00:00<?, ?it/s]

0
4 checkmate! 90 0.3847747188337613
5 checkmate! 52 0.24419120785502393
6 checkmate! 70 0.1900977223469995
7 checkmate! 90 0.2019743594844126
9 checkmate! 47 0.13976936359246342
10 checkmate! 85 0.3214130058290622
11 checkmate! 17 0.15057705648993028
12 checkmate! 64 0.0575314031130205
19 checkmate! 99 0.26430646600922564
25 checkmate! 10 0.04996703431517383
26 checkmate! 90 0.42173142144467946
29 checkmate! 65 0.14668446152405698
34 checkmate! 74 0.2504006873464691
38 checkmate! 84 0.4530800126571689
39 checkmate! 54 0.06271743799865545
40 checkmate! 56 0.18259084361499722
42 checkmate! 72 0.1876706357422696
43 checkmate! 54 0.4455271428067112
44 checkmate! 59 0.06741519734860625
45 checkmate! 68 0.08037721030589551
51 checkmate! 53 0.2878813551757399
52 checkmate! 28 0.45310924852790835
54 checkmate! 70 0.24417327495539592
56 checkmate! 77 0.3690552717537619
60 checkmate! 43 0.04605338299328076
62 checkmate! 58 0.32463189346574656
67 checkmate! 92 0.34475244765490193
73 checkmate! 4

KeyboardInterrupt: 

In [285]:
board = chess.Board()

mirror = False
#random.seed(42)
#np.random.seed(42)
#torch.manual_seed(42)

In [293]:
state = agent.board_to_state([board])
#print(state)
action = agent.select_action(board, eps=0.1, greedy=False)

Q = agent.forward(state)
mask_legal = agent.get_mask_legal([board])
Q = Q.masked_fill(~mask_legal, -1e9)
print(Q[Q>-1])
print("Q max:", Q.max().item())

move = agent.action_to_move(action)[0]

board.push(move)

board = board.mirror()
mirror = not mirror
if mirror:
    print(board.mirror())
else:
    print(board)

if board.is_checkmate():
    print("checkmate!")

if board.can_claim_threefold_repetition():
    print("draw!")

tensor([-0.0931,  0.0613,  0.0190,  0.0528,  0.0229, -0.0144,  0.0538, -0.1166,
         0.0484, -0.0319, -0.0742, -0.0382,  0.0341, -0.0807, -0.0239,  0.0563,
         0.0262, -0.0308, -0.0383, -0.0311,  0.0375, -0.0989,  0.0211, -0.0266,
         0.0529, -0.0343,  0.0781, -0.0886], device='cuda:0',
       grad_fn=<IndexBackward0>)
Q max: 0.07812061905860901
r n b q k b n r
p p . p . p p p
. . . . . . . .
. . p . p P . .
. . P . . . . .
. . . . P . . .
P P . P . . P P
R N B Q K B N R


In [276]:
state, action, next_state, mask_legal, reward, done =model.memory.sample()

In [277]:
reward

tensor([[-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-