# MCTS Training Notebook

In [None]:
import os, pickle, random
from tqdm.auto import tqdm 
import wandb
import numpy as np
import gym
import chess

import adversarial_gym
from OBM_ChessNetwork import ChessNetworkSimple
from search import MonteCarloTreeSearch

import torch
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast

### Initialize Gym Chess Environment

In [None]:
# Build the chess environment by its registered id; "Chess-v0" is presumably registered by the
# `adversarial_gym` import (TODO confirm).
env = gym.make(id="Chess-v0")

### Load Model

In [None]:
# Checkpoint to warm-start from (set to None to start from random weights) and where the
# best model found during training is written.
MODEL_PATH = '/home/kage/chess_workspace/chess-rl/monte-carlo-tree-search-NN/best_baseSwinChessNet.pt'
BESTMODEL_SAVEPATH = 'mcts_baseSwinChessNet_best.pt'
DEVICE = 'cuda'

model = ChessNetworkSimple(hidden_dim=512, device=DEVICE)
best_model = ChessNetworkSimple(hidden_dim=512, device=DEVICE)

if MODEL_PATH is not None:
    # Read the checkpoint once, directly onto the target device.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

# Start the "best" network as an exact copy of the training network. This also covers
# MODEL_PATH = None, where the two networks would otherwise get different random weights.
# load_state_dict copies values, so the two networks do not share parameter storage.
best_model.load_state_dict(model.state_dict())
best_model.eval()

### Initialize MCTS Tree

In [None]:
# One MCTS searcher bound to the global `env` and the trainable `model`. run_training reuses
# it for every self-play game and calls `tree.reset()` after each game.
tree = MonteCarloTreeSearch(env, model)

### Helper Code

In [None]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of (state, action, value) samples for training the policy/value network.

    While filling, samples are appended; once `capacity` samples are stored, each new push
    overwrites the oldest remaining sample.
    """

    def __init__(self, capacity, batch_size):
        """
        Args:
            capacity: maximum number of stored samples; must be positive.
            batch_size: default number of samples returned by `sample()`.

        Raises:
            ValueError: if `capacity` is not positive (the ring index would divide by zero).
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")

        self.actions = []
        self.states = []
        self.values = []

        self.capacity = capacity
        self.curr_length = 0  # number of stored samples (mirrors len(self))
        self.batch_size = batch_size
        self.position = 0  # next slot to (over)write; wraps to 0 once the buffer is full

    def __len__(self):
        return len(self.states)

    def push(self, state, action, value):
        """Store one sample, overwriting the oldest one when the buffer is full."""
        if len(self.states) < self.capacity:
            # Still filling: grow the lists. While filling, `position` equals the new last index.
            self.states.append(None)
            self.actions.append(None)
            self.values.append(None)

        self.states[self.position] = state
        self.actions[self.position] = action
        self.values[self.position] = value

        self.curr_length = len(self.states)
        self.position = (self.position + 1) % self.capacity

    def update(self, states, actions, winner):
        """Push one finished game's moves, with value targets derived from the winner.

        Moves are assumed to alternate starting with white (index 0), as in the self-play loop.
        - `winner == 1`: white's moves (even indices) get +1 and black's get -1.
        - `winner == -1`: the signs flip.
        - Any other value (e.g. a draw): every move gets 0.

        `states` and `actions` are paired with `zip`, so extra items in the longer one are ignored.
        """
        if winner == 1:
            values = [(-1) ** i for i in range(len(actions))]
        elif winner == -1:
            values = [(-1) ** (i + 1) for i in range(len(actions))]
        else:
            values = [0] * len(actions)

        for state, action, value in zip(states, actions, values):
            self.push(state, action, value)

    def sample(self, batch_size=None):
        """Sample stored experiences uniformly without replacement.

        Args:
            batch_size: number of samples; defaults to `self.batch_size`. If fewer samples
                are stored, all stored samples are returned (in random order) instead of raising.

        Returns:
            (states, actions, values) tuples of equal length, aligned by index.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.states:
            raise ValueError("cannot sample from an empty ReplayBuffer")

        if batch_size is None:
            batch_size = self.batch_size
        n = min(batch_size, len(self.states))

        indices = random.sample(range(len(self.states)), n)
        states, actions, values = zip(*[(self.states[i], self.actions[i], self.values[i]) for i in indices])
        return states, actions, values



In [None]:
def play_game_mcts(env, white, black, perspective=None, num_sims=1000):
    """
    Play one game between two MCTS searchers and score it from one side's perspective.

    NOTE: this was previously also named `play_game`. The file defines `play_game` again
    right after this function, so under that name this MCTS version was silently shadowed
    and unreachable. `play_game` still refers to the later definition.

    Args:
        env: chess environment; `env.reset()` is called at the start.
        white: searcher (with a `search(state, obs, simulations_number=...)` method) for even plies.
        black: searcher for odd plies.
        perspective: Chess.WHITE (1) or Chess.BLACK (0).
        num_sims: number of MCTS simulations per move.

    Returns:
        1 if the chosen perspective has won, else 0 (loss, draw, or no perspective given).
    """
    obs, info = env.reset()
    done = False
    step = 0

    while not done:
        searcher = white if step % 2 == 0 else black
        state = env.get_string_representation()
        action, _ucb = searcher.search(state, obs, simulations_number=num_sims)

        obs, reward, terminated, truncated, _ = env.step(action)
        # Stop on truncation as well as termination; stepping past either is invalid.
        done = terminated or truncated
        step += 1

    # The final reward is treated as +1 for a white win and -1 for a black win.
    if perspective == chess.WHITE and reward == 1:
        return 1
    if perspective == chess.BLACK and reward == -1:
        return 1
    return 0


def play_game(env, white, black, perspective: int = None, sample_n: int = 1):
    """
    Play one game between two policy networks and score it from one side's perspective.

    Moves come directly from each network's `get_action(...)`; no tree search is done.

    Args:
        env: chess environment; `env.reset()` is called at the start.
        white: network that moves on even plies.
        black: network that moves on odd plies.
        perspective: Chess.WHITE (1) or Chess.BLACK (0).
        sample_n: number of top moves each network samples from.

    Returns:
        1 if `perspective` won the game, else 0 (loss, draw, or no perspective given).
    """
    obs, info = env.reset()
    done = False
    step = 0

    while not done:
        mover = white if step % 2 == 0 else black
        action, _log_prob = mover.get_action(obs[0], env.board.legal_moves, sample_n=sample_n)

        obs, reward, terminated, truncated, _ = env.step(action)
        # Stop on truncation as well as termination; stepping past either is invalid.
        done = terminated or truncated
        step += 1

    # The final reward is treated as +1 for a white win and -1 for a black win.
    if perspective == chess.WHITE and reward == 1:
        return 1
    if perspective == chess.BLACK and reward == -1:
        return 1
    return 0


def duel(env, new_model, old_model, num_rounds, sample_n=2):
    """Duel `new_model` against the previous best and return its win ratio.

    Each round plays two games via `play_game`: once with `new_model` as white and once as
    black. Only wins count; draws and losses both score 0.

    Args:
        env: chess environment shared with the games (it is reset by each game).
        new_model: candidate network; its train/eval mode is restored afterwards.
        old_model: current best network; left in eval mode.
        num_rounds: number of rounds (2 games each); must be positive.
        sample_n: number of top moves each network samples from per move.

    Returns:
        wins / (2 * num_rounds), a float in [0, 1].

    Raises:
        ValueError: if `num_rounds` is not positive.
    """
    if num_rounds <= 0:
        raise ValueError(f"num_rounds must be positive, got {num_rounds}")

    was_training = new_model.training
    new_model.eval()
    old_model.eval()

    wins = 0
    try:
        with torch.inference_mode():
            for _ in range(num_rounds):
                wins += play_game(env, new_model, old_model, perspective=chess.WHITE, sample_n=sample_n)
                wins += play_game(env, old_model, new_model, perspective=chess.BLACK, sample_n=sample_n)
    finally:
        # Restore the caller's mode even if a game raises.
        new_model.train(was_training)

    return wins / (2 * num_rounds)


def update_model(model, replay_buffer, num_iterations):
    """Run `num_iterations` mixed-precision gradient steps on minibatches from `replay_buffer`.

    Each step does the following:
    - samples a batch;
    - computes `model.policy_loss + model.value_loss` under `autocast`;
    - steps `model.optimizer` through `model.grad_scaler`, clipping the gradient norm to 1.0.

    Returns:
        (total_loss, policy_loss, value_loss) from the LAST iteration, as Python floats.

    Raises:
        ValueError: if `num_iterations` < 1 (there would be no loss to return).
    """
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")

    for _ in range(num_iterations):
        states_batch, actions_batch, values_batch = replay_buffer.sample()

        # Stack the per-sample arrays into one ndarray first: building a tensor from a
        # sequence of ndarrays is very slow. unsqueeze(1) inserts a dim at axis 1
        # (presumably the network's channel axis -- TODO confirm against ChessNetworkSimple).
        states_batch = torch.as_tensor(np.asarray(states_batch), dtype=torch.float32, device=DEVICE).unsqueeze(1)
        actions_batch = torch.tensor(actions_batch, device=DEVICE)
        values_batch = torch.tensor(values_batch, device=DEVICE, dtype=torch.float32)

        # Forward pass and losses in mixed precision.
        # NOTE(review): squeeze() with no dim also drops the batch dim when the batch has size 1.
        with autocast():
            policy_output, value_output = model(states_batch)
            policy_loss = model.policy_loss(policy_output.squeeze(), actions_batch)
            value_loss = model.value_loss(value_output.squeeze(), values_batch)
            loss = policy_loss + value_loss

        # AMP with gradient clipping: unscale first so the clip threshold applies to true gradients.
        model.optimizer.zero_grad()
        model.grad_scaler.scale(loss).backward()
        model.grad_scaler.unscale_(model.optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        model.grad_scaler.step(model.optimizer)
        model.grad_scaler.update()

    return loss.item(), policy_loss.item(), value_loss.item()


def _log_metrics(metrics):
    """Send `metrics` to Weights & Biases only when a run is active (wandb.init was called)."""
    if wandb.run is not None:
        wandb.log(metrics)


def _self_play_game(observation, num_sims):
    """Play one MCTS self-play game on the global `env`/`tree`/`model`, starting from `observation`.

    Returns:
        (states, actions, final_reward): the recorded `observation[0]` before each move, the
        chosen actions, and the reward returned by the last `env.step`.
    """
    g_states = []
    g_actions = []
    terminal = truncated = False
    reward = 0

    with tqdm() as pbar:
        while not (terminal or truncated):
            state = env.get_string_representation()

            # Search with the network in eval mode, then switch back for training.
            model.eval()
            action, ucb = tree.search(state, observation, simulations_number=num_sims)
            model.train()

            g_actions.append(action)
            g_states.append(observation[0])

            # Log a plain Python number when search returns a one-element tensor.
            if torch.is_tensor(ucb) and ucb.numel() == 1:
                ucb = ucb.item()
            _log_metrics({'UCB': ucb})

            observation, reward, terminal, truncated, _info = env.step(action)
            pbar.update()

    return g_states, g_actions, reward


def run_training(num_games=100, duel_every=10, duel_winrate=0.55, buffer_capacity=1_000_000, buffer_batch_size=64,
                 buffer_fillup_period=5, num_sims=5000):
    """Train `model` by MCTS self-play, periodically dueling it against `best_model`.

    Each game proceeds as follows:
    1. Play one self-play game and push its moves into a replay buffer.
    2. After `buffer_fillup_period` games, run 10 training steps.
    3. Every `duel_every` games (from game `duel_every` on), duel against `best_model`. If the
       win ratio exceeds `duel_winrate`, save the weights to `BESTMODEL_SAVEPATH` and make them
       the new in-memory `best_model`.

    Args:
        num_games: number of self-play games.
        duel_every: duel period in games.
        duel_winrate: win ratio that must be exceeded to replace the best model.
        buffer_capacity: replay buffer capacity.
        buffer_batch_size: training minibatch size.
        buffer_fillup_period: training starts once the game index exceeds this.
        num_sims: MCTS simulations per move during self-play.
    """
    observation, info = env.reset()
    env.render()

    replay_buffer = ReplayBuffer(capacity=buffer_capacity, batch_size=buffer_batch_size)

    for g in range(num_games):
        print(f"Starting game number: {g}")

        g_states, g_actions, reward = _self_play_game(observation, num_sims)
        # The final reward is the game outcome as ReplayBuffer.update interprets it (+1 / -1 / other).
        replay_buffer.update(g_states, g_actions, reward)

        train_metrics = None
        if g > buffer_fillup_period:
            # update_model already returns Python floats; no .item() needed.
            loss, policy_loss, value_loss = update_model(model, replay_buffer, 10)

            print(f"Game: {g} - TotalLoss: {loss:.6f} - PolicyLoss: {policy_loss:.6f} - ValueLoss: {value_loss:.6f}")
            train_metrics = {"policy_loss": policy_loss, "value_loss": value_loss, "total_loss": loss}

        # Duel models and save best
        if (g % duel_every == 0) and (g > 0):
            winlose = duel(env, model, best_model, 7)

            _log_metrics({"win/loss:": winlose})
            print(winlose)
            if winlose > duel_winrate:
                torch.save(model.state_dict(), BESTMODEL_SAVEPATH)
                # Refresh the in-memory opponent too, so later duels are against the new best.
                best_model.load_state_dict(model.state_dict())

        # Losses only exist once training has started.
        if train_metrics is not None:
            _log_metrics(train_metrics)

        tree.reset()
        observation, info = env.reset()

    env.close()


In [None]:
# NOTE(review): if run_training calls wandb.log without checking for an active run,
# this init must be enabled first or logging will fail.
# wandb.init(project="Chess")
run_training(num_games=100, duel_every=10)