# MCTS Training Notebook

In [1]:
import copy
import os, pickle, random
import time

import wandb

import gym
import chess
from torch.multiprocessing import Pool, set_start_method, Lock, Process

import adversarial_gym
from OBM_ChessNetwork import ChessNetworkSimple
from search import MonteCarloTreeSearch
from parallel import run_games_continuously, torch_safesave, ReplayBufferManager, ChessReplayDataset, duel, run_training_epoch

import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast

import sys
sys.path.append('/home/kage/chess_workspace/chess_utils')
from utils import RunningAverage
from chess_dataset import ChessDataset


  from .autonotebook import tqdm as notebook_tqdm


### Configuration: checkpoint paths and training settings

In [None]:
# --- Checkpoints -----------------------------------------------------------
# Starting weights, plus the MCTS "best" and "current" checkpoints whose
# filenames are derived from it by prefixing.
MODEL_PATH = 'best_baseSwinChessNet.pt'
BESTMODEL_PATH = f'bestMCTS{MODEL_PATH}'
CURRMODEL_PATH = f'currentMCTS{MODEL_PATH}'
DEVICE = 'cuda'

# --- Update / duel settings ------------------------------------------------
# NOTE(review): these two are not referenced by the training cells below —
# confirm whether they should be passed through to run_training.
TRAIN_EPOCHS = 5
DUEL_ROUNDS = 11

# --- Self-play budget ------------------------------------------------------
NUM_GAMES = 500
GAMES_IN_PARALLEL = 9

# Legacy in-notebook model construction, kept for reference; the training code
# below works with the checkpoint paths above instead.
# model = Chess42069NetworkSimple(hidden_dim=512, device=DEVICE, base_lr=0.1)
# best_model = Chess42069NetworkSimple(hidden_dim=512, device=DEVICE)

# if MODEL_PATH is not None:
#     model.load_state_dict(torch.load(MODEL_PATH))
#     best_model.load_state_dict(torch.load(MODEL_PATH))

# model = torch.compile(model)
# best_model = torch.compile(best_model)
# best_model.eval()

### Helper Code

In [None]:
def update_and_duel(selfplay_buffer_proxy, expert_dataset, file_lock, dataset_size=25_000,
                    epochs=5, duel_rounds=11, duel_winrate=0.55, num_sims=100, num_processes=5):
    """
    Train the current model for several epochs on self-play + expert data, dueling it
    against the best model after every epoch, and promote the best qualifying epoch.

    Each epoch is run by ``run_training_epoch`` in a one-worker pool on a dataset of
    ``dataset_size`` samples drawn from the self-play buffer and the expert dataset
    (expert_size + selfplay_size = dataset_size). Per the original design notes, once
    the replay buffer holds more than ``dataset_size`` samples only self-play data is
    used — that mixing is decided inside ``run_training_epoch``.

    An epoch "wins" when its duel score exceeds ``duel_winrate * 2 * duel_rounds``.
    Among winning epochs, the highest score is kept (ties broken by more wins); if any
    epoch won, its weights are saved to BESTMODEL_PATH and the self-play buffer is cleared.

    Args:
        selfplay_buffer_proxy: shared replay-buffer proxy; wrapped in ChessReplayDataset
            for training and cleared when a new best model is saved.
        expert_dataset: expert-game dataset mixed into training.
        file_lock: lock forwarded to ``duel``.
        dataset_size: number of training samples per epoch.
        epochs: number of train-then-duel iterations.
        duel_rounds: games played per duel.
        duel_winrate: required fraction of ``2 * duel_rounds`` to count as a win.
        num_sims: forwarded to ``duel`` (presumably MCTS simulations per move).
        num_processes: forwarded to ``duel`` as its number of worker processes.
    """
    selfplay_dataset = ChessReplayDataset(selfplay_buffer_proxy)

    # Threshold is a fraction of 2 * duel_rounds, which suggests duel() awards 2 points
    # per win and 1 per draw -- TODO confirm against duel().
    win_threshold = duel_winrate * 2 * duel_rounds

    best_score = 0
    best_wins = 0
    best_model_state = None

    # The opponent is fixed for this whole call: the saved MCTS best if present,
    # otherwise the starting model.
    bestmodel_path = BESTMODEL_PATH if os.path.exists(BESTMODEL_PATH) else MODEL_PATH

    for _ in range(epochs):
        # Train in a separate process; it leaves the updated weights at CURRMODEL_PATH.
        with Pool(1) as pool:
            stats = pool.apply(run_training_epoch, (CURRMODEL_PATH, selfplay_dataset, expert_dataset, dataset_size))

        duel_score_dict = duel(CURRMODEL_PATH, bestmodel_path, duel_rounds, file_lock,
                               num_sims=num_sims, num_processes=num_processes)

        print(f"Duel scoring: {duel_score_dict}")
        wandb.log(duel_score_dict)

        score = duel_score_dict['score']
        if score > win_threshold:
            print("MODEL WON!")
            wins = duel_score_dict['wins']
            if score > best_score or (score == best_score and wins > best_wins):
                best_score, best_wins = score, wins
                # Snapshot on CPU: keeps CUDA tensors out of the parent process and makes
                # the saved best checkpoint device-agnostic.
                best_model_state = torch.load(CURRMODEL_PATH, map_location='cpu')

        loss = stats.get_average('loss')
        ploss = stats.get_average('policy_loss')
        vloss = stats.get_average('value_loss')
        print(f"Epoch - Loss: {loss} - Ploss: {ploss} - Vloss {vloss}")
        wandb.log({"epoch_loss":  loss,
                   "epoch_ploss": ploss,
                   "epoch_vloss": vloss})

    # Promote the new best model and drop self-play data generated by the old one.
    if best_model_state is not None:
        torch.save(best_model_state, BESTMODEL_PATH)
        selfplay_buffer_proxy.clear()

##### Parallel self-play and training loop

In [None]:
def pickle_bufferproxy(buffer_proxy, path='replay_buffer_state.pkl'):
    """Snapshot a replay-buffer proxy's state to disk with pickle.

    Args:
        buffer_proxy: object exposing ``get_state()``; its return value is pickled.
        path: destination file (defaults to the original hard-coded filename).

    Note: only unpickle the resulting file if you produced it yourself —
    ``pickle.load`` can execute arbitrary code.
    """
    shared_buffer_state = buffer_proxy.get_state()
    with open(path, 'wb') as f:
        pickle.dump(shared_buffer_state, f)

        
def run_training(num_games, expert_dataset, games_in_parallel, train_every, duel_winrate=0.55, buffer_capacity=50_000,
                 dataset_size=25_000, epochs=5, duel_rounds=7, poll_interval=1.0):
    """
    Run self-play in a background process and periodically pause it to train and duel.

    Self-play (``run_games_continuously``) fills a shared replay buffer. Every
    ``train_every`` finished games, self-play is stopped (in-flight games finish),
    ``update_and_duel`` runs, and self-play is restarted. Once ``num_games`` games have
    finished, a final update is done if one is due, and the function returns without
    restarting self-play.

    Args:
        num_games: total finished self-play games after which training stops.
        expert_dataset: expert dataset passed to ``update_and_duel``.
        games_in_parallel: forwarded to ``run_games_continuously``.
        train_every: number of finished games between training updates.
        duel_winrate: win-rate threshold forwarded to ``update_and_duel``.
        buffer_capacity: capacity of the shared replay buffer.
        dataset_size, epochs, duel_rounds: forwarded to ``update_and_duel``.
        poll_interval: seconds to sleep between polls of the shared game counter.
    """
    # with open('replay_buffer_state.pkl', 'rb') as f:
    #     buffer_state = pickle.load(f)

    manager = ReplayBufferManager()
    manager.start()
    try:
        shared_replay_buffer = manager.ReplayBuffer(capacity=buffer_capacity)
        # shared_replay_buffer.from_dict(buffer_state)
        shutdown_event = manager.Event()
        buffer_lock = manager.Lock()
        global_game_counter = manager.GameCounter()  # Initialize a shared counter
        file_lock = Lock()

        # Load initial model for self-play process
        model_state = torch.load(MODEL_PATH)
        model_state = {k: v.cpu() for k, v in model_state.items()}  # can't share cuda tensors

        # Save current model so training and dueling processes can load/use it
        torch.save(model_state, CURRMODEL_PATH)

        def start_selfplay():
            """Launch run_games_continuously in a fresh background process."""
            proc = Process(target=run_games_continuously,
                           args=(model_state, BESTMODEL_PATH, shared_replay_buffer, games_in_parallel,
                                 buffer_lock, file_lock, global_game_counter, shutdown_event))
            proc.start()
            return proc

        def stop_selfplay(proc):
            """Signal self-play to stop and wait for the process to exit."""
            shutdown_event.set()
            proc.join()

        process = start_selfplay()
        try:
            next_train = train_every
            while True:
                game_count = global_game_counter.count
                finished = game_count >= num_games

                if game_count >= next_train:
                    # Finish current games before updating/dueling
                    print("Waiting for games to finish...")
                    stop_selfplay(process)
                    print("All done. Starting training...")

                    update_and_duel(shared_replay_buffer, expert_dataset, file_lock, dataset_size=dataset_size,
                                    epochs=epochs, duel_winrate=duel_winrate, duel_rounds=duel_rounds)
                    next_train += train_every

                    if finished:
                        # Final update done; restarting here (with the event cleared)
                        # would leave a self-play process that never stops.
                        break

                    print("Restarting self-play process...")
                    shutdown_event.clear()
                    process = start_selfplay()
                elif finished:
                    break
                else:
                    # Avoid hammering the manager with IPC calls in a hot loop.
                    time.sleep(poll_interval)
        finally:
            # On normal exit, error, or KeyboardInterrupt: make sure self-play is stopped.
            stop_selfplay(process)
    finally:
        manager.shutdown()

In [2]:
# Expert games mixed into training. Override via the PGN_FILE environment variable
# instead of editing this machine-specific absolute path.
PGN_FILE = os.environ.get(
    "PGN_FILE",
    "/home/kage/chess_workspace/PGN-data/tcec+alphastock/COMBINED_tcec+alphazero.pgn",
)
# PGN_FILE = "/home/kage/chess_workspace/PGN-data/tcec+alphastock/TCEC_Cup_1_Final_5.pgn"

# Fail fast with an actionable message rather than an error from deep inside the loader.
if not os.path.exists(PGN_FILE):
    raise FileNotFoundError(f"Expert PGN file not found: {PGN_FILE!r} (set the PGN_FILE environment variable)")

# Load the datasets
expert_dataset = ChessDataset(PGN_FILE)
print(len(expert_dataset))

FileNotFoundError: [Errno 2] No such file or directory: '/home/kage/chess_workspace/COMBINED_tcec+alphazero.pgn'

In [None]:
# Start experiment logging, then force the 'spawn' start method for the worker
# processes (presumably because they use CUDA, DEVICE = 'cuda'); force=True allows
# re-running this cell in the same kernel.
wandb.init(project="Chess")
set_start_method('spawn', force=True)

run_training(
    num_games=NUM_GAMES,
    expert_dataset=expert_dataset,
    games_in_parallel=GAMES_IN_PARALLEL,
    train_every=25,  # train + duel after every 25 finished self-play games
)