# Curriculum Learning

In [1]:
import sys
sys.path.insert(0, '../../src/')

import numpy as np
import matplotlib.pyplot as plt
import pickle
import config
import torch
from tqdm.notebook import tqdm
from copy import copy, deepcopy
import cmath
import chess

from basis_gates import *
from agents import *
from environments import *
from models import *
%matplotlib inline

np.set_printoptions(precision = 3)

print("PyTorch version:", torch.__version__)  
print("CUDA toolkit version PyTorch was built with:", torch.version.cuda)  
print("cuDNN version:", torch.backends.cudnn.version()) 
print("cuda available:", torch.cuda.is_available())

torch.set_float32_matmul_precision('high')

PyTorch version: 2.7.1+cu128
CUDA toolkit version PyTorch was built with: 12.8
cuDNN version: 90701
cuda available: True


In [None]:
torch.manual_seed(42)
random.seed(42)
logger = Logger(sample_freq = 1000)

agent = Agent(board_logic = BoardLogic())


opt_list = [torch.optim.Adam(agent.online_net1.parameters(), lr=1e-3), 
            torch.optim.Adam(agent.online_net2.parameters(), lr=1e-3)]

model = Model(agent = agent,
               environment = Environment(max_num_moves=100,),
               mem_capacity = 10000,
               batch_size = 256,
               policy_update = 10,
               target_update = 5000,
               temp_constants = (0.5, 0.05, 0, 1000),
               opt_list=opt_list,
               scaler=torch.amp.GradScaler("cuda")
             )

In [5]:
model.train(num_episodes = 1000, logger = logger)

  0%|          | 0/1000 [00:00<?, ?it/s]

0
2 checkmate! 71 0.10194576643228703
4 checkmate! 50 0.4128619823414247
6 checkmate! 15 0.21967222138617692
7 checkmate! 59 0.3275143483063791
10 checkmate! 99 0.017242887553177575
12 checkmate! 62 0.23348641593825867
15 checkmate! 20 0.417865283553795
17 checkmate! 62 0.18339642343537366
18 checkmate! 18 0.4661041417809405
20 checkmate! 60 0.01421354622621117
30 checkmate! 11 0.15203111448320392
35 checkmate! 21 0.20842430192568043
36 checkmate! 14 0.041094812208758304
37 checkmate! 72 0.40619790921851306
38 checkmate! 91 0.24450893677211352
44 checkmate! 58 0.07370900840464209
47 checkmate! 55 0.2883368397184388
48 checkmate! 74 0.3586411484977379
51 checkmate! 99 0.3664777530898692
53 checkmate! 64 0.09803511125488293
55 checkmate! 50 0.16578339735557984
56 checkmate! 15 0.18537445947463935
58 checkmate! 34 0.2420662919245842
60 checkmate! 74 0.3653507032055938
62 checkmate! 8 0.20725601262698612
63 checkmate! 92 0.18901379758025944
64 checkmate! 90 0.36850223113916714
66 checkmate

KeyboardInterrupt: 

In [58]:
board = chess.Board()

mirror = False
#random.seed(42)
#np.random.seed(42)
#torch.manual_seed(42)

In [64]:
state = agent.board_to_state([board])
action = agent.select_action(board, greedy=True)

Q = agent.forward(state)
mask_legal = agent.get_mask_legal([board])
Q = Q.masked_fill(~mask_legal, -1e9)
print(torch.max(Q))

move = agent.action_to_move(action)[0]

board.push(move)

board = board.mirror()
mirror = not mirror
if mirror:
    print(board.mirror())
else:
    print(board)

if board.is_checkmate():
    print("checkmate!")

if board.can_claim_threefold_repetition():
    print("draw!")

tensor(0.1202, device='cuda:0', grad_fn=<MaxBackward1>)
r n . q k b n r
p . p p p p p p
. p . B . . . .
. . . . . . . .
. . . . . . . .
. P . b . . . .
P . P P P P P P
R N . Q K B N R


In [None]:
state, action, next_state, mask_legal, reward, done =model.memory.sample()

In [None]:
action

tensor([[1672],
        [2185],
        [1425],
        [2331],
        [2184],
        [3897],
        [ 271],
        [1141],
        [ 422],
        [ 837],
        [1277],
        [ 989],
        [ 402],
        [2539],
        [ 880],
        [1624],
        [2720],
        [ 229],
        [ 988],
        [ 492],
        [2356],
        [2157],
        [ 175],
        [ 270],
        [ 782],
        [2280],
        [4533],
        [2946],
        [ 988],
        [ 304],
        [ 339],
        [1142],
        [ 263],
        [1106],
        [  21],
        [ 913],
        [1549],
        [1064],
        [ 700],
        [ 516],
        [2073],
        [  14],
        [ 424],
        [ 915],
        [1548],
        [1430],
        [ 401],
        [1748],
        [2681],
        [ 989],
        [ 912],
        [3635],
        [ 684],
        [ 518],
        [2793],
        [1881],
        [ 685],
        [1016],
        [1314],
        [ 947],
        [ 518],
        [1942],
        