# Curriculum Learning

In [3]:
import sys
sys.path.insert(0, '../../src/')

import numpy as np
import matplotlib.pyplot as plt
import pickle
import config
import torch
from tqdm.notebook import tqdm
from copy import copy, deepcopy
import cmath
import chess
from utils import saver, loader

from agents import *
from environments import *
from models import *
%matplotlib inline

np.set_printoptions(precision = 3)

print("PyTorch version:", torch.__version__)  
print("CUDA toolkit version PyTorch was built with:", torch.version.cuda)  
print("cuDNN version:", torch.backends.cudnn.version()) 
print("cuda available:", torch.cuda.is_available())

torch.set_float32_matmul_precision('high')

PyTorch version: 2.7.1+cu128
CUDA toolkit version PyTorch was built with: 12.8
cuDNN version: 90701
cuda available: True


In [None]:
torch.manual_seed(42)
random.seed(42)
logger = Logger(sample_freq = 1000)

agent = Agent(board_logic = BoardLogic())


opt_list = [torch.optim.Adam(agent.online_net1.parameters(), lr=1e-4), 
            torch.optim.Adam(agent.online_net2.parameters(), lr=1e-4)]

model = Model(agent = agent,
               environment = Environment(max_num_moves=200,),
               mem_capacity = 100000,
               batch_size = 1024,
               policy_update = 10,
               target_update = 10000,
               temp_constants = (0.5, 0.1, 0, 1000),
               opt_list=opt_list,
               scaler=torch.amp.GradScaler("cuda")
             )

In [9]:
model.train(num_episodes = 10000, logger = logger)

  0%|          | 0/10000 [00:00<?, ?it/s]

0 draw! 99 0.3197133992289419
0
1 draw! 99 0.00028571930552057524
2 checkmate! 88 0.2916214551539184
3 draw! 99 0.11094993831177742
4 draw! 99 0.22057859515358444
5 draw! 99 0.21333225059090274
6 draw! 99 0.10339526193779915
7 checkmate! 31 0.13735020868311876
8 checkmate! 22 0.2612924999259965
9 draw! 99 0.049891794096246576
10 checkmate! 73 0.3220227653571572
11 draw! 99 0.32259429911517495
12 checkmate! 58 0.03782306772464007
13 checkmate! 28 0.13997574269841398
14 draw! 99 0.2093321075096481
15 draw! 99 0.36318237049728175
16 checkmate! 64 0.3170583277211178
17 draw! 99 0.011430550280304328
18 draw! 99 0.1466644577684602
19 draw! 99 0.2354567674296243
20 draw! 99 0.18521841274550896
21 checkmate! 97 0.14746753319110617
22 checkmate! 83 0.16442243310450527
23 draw! 99 0.16244972294549542
24 checkmate! 56 0.2522691102868576
25 draw! 99 0.11473048232736553
26 draw! 99 0.13158709254596668
27 checkmate! 32 0.10518817021052827
28 checkmate! 62 0.18707814957167557
29 draw! 99 0.3938679398

In [10]:
saver(model,"model_10000_episodes.pth")

In [11]:
board = chess.Board()

mirror = False
#random.seed(42)
#np.random.seed(42)
#torch.manual_seed(42)

In [144]:
state = agent.board_to_state([board])
action = agent.select_action(board, eps=0.2, greedy=False)

Q1 = agent.online_net1(state).detach()
Q2 = agent.online_net2(state).detach()
mask_legal = agent.get_mask_legal([board])

Q1_legal = Q1[mask_legal]
Q2_legal = Q2[mask_legal]

diff = torch.abs(Q1_legal - Q2_legal)/torch.max(torch.abs(Q1_legal), torch.abs(Q2_legal))

#print(Q1_legal)
#print(Q2_legal)

print(f"{np.mean(diff.cpu().numpy()):.4f}")

Q_legal = Q1.masked_fill(~mask_legal, -1e9)
action_star = torch.argmax(Q_legal, dim=1).to(config.device)
score = Q2[0,action_star[0]]

if not mirror:
    print("White:")
else:
    print("Black:")
print(f"score: {score.item():.4f}")

move = agent.action_to_move(action)[0]

board.push(move)

board = board.mirror()
mirror = not mirror
if mirror:
    print(board.mirror())
else:
    print(board)

if board.is_checkmate():
    print("checkmate!")

if board.can_claim_threefold_repetition():
    print("draw!")

0.2632
White:
score: -0.1257
. . . . k b R .
. . . . . p p .
. . . p . . . .
p . . . . . n .
. . . B . P r .
P P . P . . . .
. . . . K . P .
R N . . . B . .


In [None]:
state, action, next_state, mask_legal, reward, done =model.memory.sample()

In [None]:
model.memory.rewards[:1000]

tensor([[-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-0.],
        [-