# Curriculum Learning

In [1]:
import sys
sys.path.insert(0, '../../src/')

import numpy as np
import matplotlib.pyplot as plt
import pickle
import random
import config
import torch
from tqdm.notebook import tqdm
from copy import copy, deepcopy
import cmath
import chess

from basis_gates import *
from agents import *
from environments import *
from models import *
%matplotlib inline

# Print NumPy arrays with three digits of precision.
np.set_printoptions(precision=3)

# Report the PyTorch / CUDA / cuDNN setup of the running kernel, one line each.
_environment_report = (
    ("PyTorch version:", torch.__version__),
    ("CUDA toolkit version PyTorch was built with:", torch.version.cuda),
    ("cuDNN version:", torch.backends.cudnn.version()),
    ("cuda available:", torch.cuda.is_available()),
)
for _label, _value in _environment_report:
    print(_label, _value)

# Select PyTorch's 'high' float32 matmul precision mode.
torch.set_float32_matmul_precision('high')

PyTorch version: 2.7.1+cu128
CUDA toolkit version PyTorch was built with: 12.8
cuDNN version: 90701
cuda available: True


In [2]:
# Seed the random number generators used in this notebook from one constant.
# NumPy is seeded as well in case the project code draws from np.random
# (not visible from this notebook -- harmless if it does not).
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

logger = Logger(sample_freq=1000)

agent = Agent(board_logic=BoardLogic())

# One Adam optimizer per online network, both using the same learning rate.
LEARNING_RATE = 1e-6
opt_list = [
    torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
    for net in (agent.online_net1, agent.online_net2)
]

# Training wrapper; all hyperparameters are passed through unchanged.
# NOTE(review): the meaning of the four `temp_constants` entries is defined in
# the project's Model class, not here -- confirm before tuning them.
model = Model(
    agent=agent,
    environment=Environment(max_num_moves=200),
    mem_capacity=100000,
    batch_size=128,
    batch_size_min=128,
    policy_update=1,
    target_update=25000,
    temp_constants=(2, 1, 1e-5, 100000),
    opt_list=opt_list,
    scaler=torch.amp.GradScaler("cuda"),
)

In [3]:
# Build three successive positions for the batched sanity check below:
# the starting position, then the position after each of two plies, where each
# ply is the first move yielded by the board's legal-move generator.
board = chess.Board()
board1 = deepcopy(board)

board.push(next(iter(board.legal_moves)))
board2 = deepcopy(board)

board.push(next(iter(board.legal_moves)))
board3 = deepcopy(board)

In [4]:
# Batched sanity check: run the agent on all three boards at once and take the
# highest-valued action per board after masking.
board_list = [board1, board2, board3]

states = agent.board_logic.board_to_state(board_list)
Q = agent.forward(states)
mask = agent.get_mask_legal(board_list)  # presumably True where an action is legal -- confirm

# Every entry outside the mask is set to -1e9 so argmax cannot select it.
Q_masked = torch.masked_fill(Q, torch.logical_not(mask), -1e9)
action = torch.argmax(Q_masked, dim=1)
moves = agent.action_to_move(action)

In [5]:
# Manual cross-check of action masking: build the mask by hand from python-chess
# legal moves for two positions and apply it to the network output.
# NOTE: this cell rebinds `board` and `board2`, which earlier cells also define.
with torch.no_grad():
    board = chess.Board()
    board2 = chess.Board()
    for uci in ("e2e4", "e7e5"):
        board2.push(chess.Move.from_uci(uci))

    # Separate BoardLogic instance used only for the move -> action index lookup.
    board_logic = BoardLogic()

    state1 = agent.board_logic.board_to_state([board]).to(config.device)
    state2 = agent.board_logic.board_to_state([board2]).to(config.device)

    state = torch.concat([state1, state2], dim=0)
    print(state.shape)
    Q = agent.forward(state)
    print(Q.shape)

    # Width of the network output (64*76 = 4864 in the recorded run); read it
    # from Q instead of hard-coding it so the mask always matches the output.
    state_dim = Q.shape[1]

    # Action indices of every legal move, placed on the same device as Q.
    moves1 = torch.tensor(
        [board_logic.move_to_action(m) for m in board.legal_moves],
        dtype=torch.long,
        device=Q.device,
    )
    moves2 = torch.tensor(
        [board_logic.move_to_action(m) for m in board2.legal_moves],
        dtype=torch.long,
        device=Q.device,
    )

    # Boolean mask with Q's shape and device; True marks a legal action.
    mask = torch.zeros_like(Q, dtype=torch.bool)
    mask[0, moves1] = True
    mask[1, moves2] = True

    # Out-of-place fill, matching the masked_fill call in the batched check cell.
    Q = Q.masked_fill(~mask, -1e9)

    print(Q)

torch.Size([2, 12, 8, 8])
torch.Size([2, 4864])
tensor([[-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
         -1.0000e+09, -1.0000e+09],
        [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
         -1.0000e+09, -1.0000e+09]])


In [6]:
# Launch training for a fixed number of episodes, passing the notebook's logger.
NUM_EPISODES = 1000
model.train(logger=logger, num_episodes=NUM_EPISODES)

  0%|          | 0/1000 [00:00<?, ?it/s]

torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])


W0916 20:50:32.369000 525614 site-packages/torch/_dynamo/convert_frame.py:964] [1/8] torch._dynamo hit config.recompile_limit (8)
W0916 20:50:32.369000 525614 site-packages/torch/_dynamo/convert_frame.py:964] [1/8]    function: 'sample' (/home/kristian/Documents/chess-agent/notebooks/Experiments/../../src/models.py:64)
W0916 20:50:32.369000 525614 site-packages/torch/_dynamo/convert_frame.py:964] [1/8]    last reason: 1/7: len(self.buffer) == 135                                
W0916 20:50:32.369000 525614 site-packages/torch/_dynamo/convert_frame.py:964] [1/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0916 20:50:32.369000 525614 site-packages/torch/_dynamo/convert_frame.py:964] [1/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.


torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128]) torch.Size([128, 1]) torch.Size([128, 1])
torch.Size([128, 1]) t

KeyboardInterrupt: 