# MCTS Training Notebook

In [None]:
import os, pickle

import wandb

from torch.multiprocessing import Pool, set_start_method, Lock, Process

from parallel import run_games_continuously, torch_safesave, ReplayBufferManager, ChessReplayDataset, duel, run_training_epoch

import torch

import sys
sys.path.append('/home/kage/chess_workspace/chess-utils')
from utils import RunningAverage
from chess_dataset import ChessDataset


### Config 

In [None]:
# --- Checkpoints -------------------------------------------------------------
# Base model checkpoint; the best/current MCTS checkpoints reuse its file name
# with a prefix.
MODEL_PATH = 'best_baseSwinChessNet.pt'
BESTMODEL_PATH = f'bestMCTS{MODEL_PATH}'
CURRMODEL_PATH = f'currentMCTS{MODEL_PATH}'
DEVICE = 'cuda'  # not referenced in the cells shown here

# --- Data ----------------------------------------------------------------------
DATASET_SIZE = 25_000  # dataset size handed to run_training_epoch each epoch
BUFFER_SIZE = 50_000   # capacity of the shared self-play replay buffer

# --- Self-play / training schedule ---------------------------------------------
NUM_GAMES = 1000         # total finished games before run_training stops
GAMES_IN_PARALLEL = 8    # forwarded to run_games_continuously
TRAIN_EVERY = 30         # run an update every this many finished games
TRAIN_EPOCHS = 4         # train-then-duel rounds per update
SELFPLAY_SIMS = 700      # not referenced in the cells shown here
BATCH_SIZE = 96          # not referenced in the cells shown here

# --- Dueling (candidate vs. current best) --------------------------------------
DUEL_ROUNDS = 7
DUEL_WINRATE = 0.55      # promote when score > DUEL_WINRATE * 2 * DUEL_ROUNDS
DUEL_SIMS = 100          # num_sims passed to duel()
DUEL_PROCESSES = 4       # num_processes passed to duel()

### Helper Code

In [None]:
def update_and_duel(selfplay_buffer_proxy, expert_dataset, file_lock):
    """Run TRAIN_EPOCHS rounds of train-then-duel and promote the strongest winner.

    Each round:
      1. trains the checkpoint at CURRMODEL_PATH for one epoch via
         ``run_training_epoch`` inside a fresh single-worker ``Pool``, using the
         self-play replay data, the expert data and DATASET_SIZE;
      2. duels CURRMODEL_PATH against the current best checkpoint
         (BESTMODEL_PATH, or MODEL_PATH when no best checkpoint exists yet);
      3. logs the duel result and the epoch's loss averages to stdout and W&B.

    A round counts as a win when its duel score exceeds
    ``DUEL_WINRATE * 2 * DUEL_ROUNDS``. Among winning rounds, the one with the
    highest ``(score, wins)`` pair is kept: a higher score wins, and ties
    are broken by wins. If any round won, that round's checkpoint is saved to
    BESTMODEL_PATH and the self-play buffer is cleared.

    NOTE(review): earlier notes describe a training set of DATASET_SIZE samples
    split between expert and self-play data, using self-play data only once
    the replay buffer holds more than DATASET_SIZE. That mixing is done inside
    ``run_training_epoch`` and is not visible here; confirm there.

    Args:
        selfplay_buffer_proxy: shared replay-buffer proxy. It is wrapped in
            ``ChessReplayDataset`` for training and ``.clear()``-ed on promotion.
        expert_dataset: expert-game dataset forwarded to ``run_training_epoch``.
        file_lock: lock forwarded to ``duel()``. Presumably it guards checkpoint
            file access; confirm in ``duel()``.
    """
    replay_dataset = ChessReplayDataset(selfplay_buffer_proxy)
    opponent_path = BESTMODEL_PATH if os.path.exists(BESTMODEL_PATH) else MODEL_PATH
    promotion_threshold = DUEL_WINRATE * 2 * DUEL_ROUNDS

    best_key = (0, 0)   # (score, wins) of the best winning round so far
    best_state = None   # checkpoint contents of that round, if any

    for _ in range(TRAIN_EPOCHS):
        # Training runs in its own process; it (re)writes CURRMODEL_PATH, which
        # the duel below then reads.
        with Pool(1) as pool:
            stats = pool.apply(
                run_training_epoch,
                (CURRMODEL_PATH, replay_dataset, expert_dataset, DATASET_SIZE),
            )

        result = duel(
            CURRMODEL_PATH,
            opponent_path,
            DUEL_ROUNDS,
            file_lock,
            num_sims=DUEL_SIMS,
            num_processes=DUEL_PROCESSES,
        )
        print(f"Duel scoring: {result}")
        wandb.log(result)

        if result['score'] > promotion_threshold:
            print("MODEL WON!")
            # Tuple order: higher score wins; equal scores fall back to more wins.
            candidate_key = (result['score'], result['wins'])
            if candidate_key > best_key:
                best_key = candidate_key
                best_state = torch.load(CURRMODEL_PATH)

        loss = stats.get_average('loss')
        ploss = stats.get_average('policy_loss')
        vloss = stats.get_average('value_loss')
        print(f"Epoch - Loss: {loss} - Ploss: {ploss} - Vloss {vloss}")
        wandb.log({"epoch_loss": loss, "epoch_ploss": ploss, "epoch_vloss": vloss})

    if best_state is None:
        return

    # Promote the winner, then clear the self-play buffer.
    torch.save(best_state, BESTMODEL_PATH)
    selfplay_buffer_proxy.clear()

### Parallel Self-Play and Training Loop

In [None]:
def pickle_bufferproxy(buffer_proxy, path='replay_buffer_state.pkl'):
    """Write a pickle snapshot of the shared replay buffer's state to *path*.

    The state is obtained from ``buffer_proxy.get_state()``. It is written to
    ``<path>.tmp`` first and then moved into place with ``os.replace``. An
    interrupted or failed dump therefore leaves any previous snapshot at
    *path* intact instead of truncating it.

    Args:
        buffer_proxy: object exposing ``get_state()``; here the replay-buffer
            manager proxy.
        path: destination file. The default is the file name this function
            always used before the parameter existed.

    Raises:
        Any exception from ``get_state``, ``open``, ``pickle.dump`` or
        ``os.replace`` propagates. If the write itself fails, the temporary
        file is removed first.

    NOTE: the snapshot is a pickle, so only load it from a trusted location.
    """
    shared_buffer_state = buffer_proxy.get_state()
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(shared_buffer_state, f)
        # os.replace overwrites the destination, so readers see either the
        # old snapshot or the new one, never a half-written file.
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the error is re-raised.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

        
def run_training(num_games, expert_dataset, poll_interval=1.0):
    """Alternate background self-play with train/duel updates until num_games games finish.

    A self-play process (``run_games_continuously``) fills a shared replay
    buffer hosted by a ``ReplayBufferManager``. Each time the global game
    counter reaches the next multiple of TRAIN_EVERY, the following happens:
      1. self-play is asked to stop and joined;
      2. the buffer is snapshotted with ``pickle_bufferproxy``;
      3. ``update_and_duel`` runs;
      4. self-play is restarted.

    Once the counter reaches *num_games*, the loop stops, running one final
    update first if one is due. It then waits for self-play to exit.

    Args:
        num_games: number of finished self-play games after which training ends.
        expert_dataset: expert-game dataset forwarded to ``update_and_duel``.
        poll_interval: seconds to sleep between reads of the shared game
            counter. The loop previously polled with no delay, spinning a CPU
            core on manager IPC.

    Notes:
        * Self-play is no longer restarted after the final update. Before,
          restarting with a just-cleared shutdown event made the final
          ``join()`` block forever.
        * On any exception or interrupt, self-play is signalled and joined,
          and the manager is shut down, instead of being leaked.
    """
    import time

    manager = ReplayBufferManager()
    manager.start()
    process = None
    shutdown_event = None
    try:
        shared_replay_buffer = manager.ReplayBuffer(capacity=BUFFER_SIZE)
        # To resume from a snapshot written by pickle_bufferproxy():
        # with open('replay_buffer_state.pkl', 'rb') as f:
        #     shared_replay_buffer.from_dict(pickle.load(f))
        shutdown_event = manager.Event()
        buffer_lock = manager.Lock()
        global_game_counter = manager.GameCounter()
        file_lock = Lock()

        # Initial weights for self-play, moved to CPU because CUDA tensors
        # can't be shared with the child process.
        model_state = torch.load(MODEL_PATH)
        model_state = {k: v.cpu() for k, v in model_state.items()}

        def start_selfplay():
            """Create, start and return a fresh background self-play process."""
            proc = Process(
                target=run_games_continuously,
                args=(model_state, BESTMODEL_PATH, shared_replay_buffer,
                      GAMES_IN_PARALLEL, buffer_lock, file_lock,
                      global_game_counter, shutdown_event),
            )
            proc.start()
            return proc

        process = start_selfplay()
        next_train = TRAIN_EVERY

        while True:
            game_count = global_game_counter.count
            done = game_count >= num_games
            if done:
                shutdown_event.set()

            if game_count >= next_train:
                # Let in-flight games finish before updating/dueling.
                print("Waiting for games to finish...")
                shutdown_event.set()
                process.join()
                shutdown_event.clear()
                print("All done. Saving buffer and starting training...")

                pickle_bufferproxy(shared_replay_buffer)
                update_and_duel(shared_replay_buffer, expert_dataset, file_lock)
                next_train += TRAIN_EVERY

                # After the final update there is nothing left to play.
                if not done:
                    print("Restarting self-play process...")
                    process = start_selfplay()

            if done:
                break
            time.sleep(poll_interval)

        process.join()
    finally:
        # Stop self-play before tearing down the manager whose proxies it uses.
        if process is not None and process.is_alive():
            shutdown_event.set()
            process.join()
        manager.shutdown()

In [None]:
# Expert games, used as training data alongside the self-play buffer.
# Alternative source (swap the two assignments to use it):
#   "/home/kage/chess_workspace/TCEC_Cup_1_Final_5.pgn"
PGN_FILE = "/home/kage/chess_workspace/cclr/COMBINED_ccrltest.pgn"

expert_dataset = ChessDataset([PGN_FILE])
print(len(expert_dataset))  # sanity check on the loaded dataset's size

In [None]:
if __name__ == "__main__":
    # Guard: with the 'spawn' start method, child processes re-import the main
    # module when it is a script file, so unguarded top-level code would start
    # a new run inside every child. In a Jupyter kernel __name__ is
    # "__main__", so this cell behaves exactly as before.
    wandb.init(project="Chess")
    # force=True overrides a start method already set in this interpreter.
    # Presumably 'spawn' is needed because children use CUDA (DEVICE = 'cuda'),
    # which does not work with fork; confirm.
    set_start_method('spawn', force=True)
    try:
        run_training(NUM_GAMES, expert_dataset)
    finally:
        # Close the W&B run even if training raises or is interrupted, so a
        # long-lived kernel doesn't keep an open run.
        wandb.finish()