# MCTS Training Notebook

In [1]:
import os, pickle
from tqdm import tqdm  # notebook compatible
import wandb
import numpy as np
import gym
import chess

import adversarial_gym
from OBM_ChessNetwork import Chess42069NetworkSimple
from search import MonteCarloTreeSearch

import torch
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast

  from .autonotebook import tqdm as notebook_tqdm


### Initialize Gym Chess Environment

In [2]:
# Create the shared chess environment used by self-play and by the duels below.
# NOTE(review): the "Chess-v0" id is presumably registered as a side effect of
# `import adversarial_gym` in the imports cell — confirm.
env = gym.make("Chess-v0")

  logger.warn(


### Load Model

In [3]:
# Checkpoint used to warm-start both networks; set to None to start from the
# networks' default initialisation.
# NOTE(review): this is an absolute, machine-specific path — the notebook will
# not run elsewhere until it is made relative or configurable.
MODEL_PATH = '/home/kage/chess_workspace/chess-rl/monte-carlo-tree-search-NN/best_baseSwinChessNet.pt'
# Where run_training writes the weights of a candidate that wins its duel.
BESTMODEL_SAVEPATH = 'mcts_baseSwinChessNet_best.pt'
DEVICE = 'cuda'

# `model` is the network being trained; `best_model` is the frozen opponent
# it has to beat in duels.
model = Chess42069NetworkSimple(hidden_dim=512, device=DEVICE)
best_model = Chess42069NetworkSimple(hidden_dim=512, device=DEVICE)

if MODEL_PATH is not None:
    # Read the checkpoint from disk once (instead of once per network) and
    # stage it on CPU: load_state_dict copies the values into each network's
    # own parameters, so the staged copy is independent of where the file
    # was saved and does not need to occupy GPU memory.
    state_dict = torch.load(MODEL_PATH, map_location='cpu')
    model.load_state_dict(state_dict)
    best_model.load_state_dict(state_dict)
    del state_dict

best_model.eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Chess42069NetworkSimple(
  (swin_transformer): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (layers): Sequential(
      (0): SwinTransformerStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): SwinTransformerBlock(
            (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=128, out_features=384, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=128, out_features=128, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=128

### Initialize MCTS Tree

In [4]:
# Build the search tree around the shared `env` and the trainable `model`.
# `run_training` below calls `tree.search(...)` for every move and
# `tree.reset()` after every game.
tree = MonteCarloTreeSearch(env, model)

### Helper Code

In [5]:
# class ReplayBuffer:
#     """ Replay buffer to store past experiences for training policy/value network"""
#     def __init__(self, capacity):
#         self.actions = np.empty((capacity, 1), dtype=int)
#         self.states = np.empty((capacity, 8, 8), dtype=int)
#         self.values = np.empty((capacity, 1), dtype=float)

#         self.buffer = []
#         self.capacity = capacity
#         self.position = 0
    
#     def push(self, state, action, value):    
#         if len(self.buffer) < self.capacity:
#             self.buffer.append(None)
#         self.buffer[self.position] = (state, action, value)
#         self.position = (self.position + 1) % self.capacity

#     def update(self, states, actions, values):
#         for state, action, value in zip(states, actions, values):
#             self.push(state, action, value)

#     def sample(num_samples):
#         pass



In [6]:
def play_game(env, white, black, perspective=None, num_sims=1000):
    """
    Play one game with tree search on both sides and score it for one side.

    NOTE: a second ``play_game`` (policy sampling via ``get_action``) is
    defined further down in this same cell and rebinds the name, so this
    search-based version is shadowed once the whole cell has run.

    Args:
        env: Environment providing ``reset()``, ``step(action)`` and
            ``get_string_representation()``.
        white: Searcher used on even plies; called as
            ``white.search(state, obs, simulations_number=num_sims)`` and
            expected to return ``(action, ucb)``.
        black: Searcher used on odd plies, same interface as ``white``.
        perspective: ``chess.WHITE`` (1) or ``chess.BLACK`` (0) — the side
            the result is reported for.
        num_sims: Number of simulations per search call.

    Returns:
        int: 1 if ``perspective`` won (final reward +1 for white, -1 for
        black), otherwise 0. Draws, losses, truncated games and
        ``perspective=None`` all yield 0.
    """
    ply = 0
    done = False
    obs, _info = env.reset()

    while not done:
        state = env.get_string_representation()
        # White moves on even plies, black on odd plies.
        mover = white if ply % 2 == 0 else black
        action, _ucb = mover.search(state, obs, simulations_number=num_sims)

        obs, reward, terminated, truncated, _ = env.step(action)
        # Stop on truncation as well as termination; otherwise a truncated
        # episode would keep stepping an environment that has already ended.
        done = terminated or truncated
        ply += 1

    # Translate the white-centric reward into a win flag for `perspective`.
    black_won = perspective == chess.BLACK and reward == -1
    white_won = perspective == chess.WHITE and reward == 1
    return 1 if (black_won or white_won) else 0


def play_game(env, white, black, perspective: int = None, sample_n: int = 1):
    """
    Play one game by sampling moves from each player's policy and report
    whether the chosen side won.

    Args:
        env: Environment providing ``reset()``, ``step(action)`` and a
            ``board`` attribute with ``legal_moves``.
        white: Player used on even plies; called as
            ``white.get_action(obs[0], env.board.legal_moves, sample_n=sample_n)``
            and expected to return ``(action, log_prob)``.
        black: Player used on odd plies, same interface as ``white``.
        perspective: ``chess.WHITE`` (1) or ``chess.BLACK`` (0) — the side
            the result is reported for.
        sample_n: Number of top moves each player samples from.

    Returns:
        int: 1 if ``perspective`` won (final reward +1 for white, -1 for
        black), otherwise 0. Draws, losses, truncated games and
        ``perspective=None`` all yield 0.
    """
    ply = 0
    done = False
    obs, _info = env.reset()

    while not done:
        # White moves on even plies, black on odd plies.
        mover = white if ply % 2 == 0 else black
        action, _log_prob = mover.get_action(obs[0], env.board.legal_moves, sample_n=sample_n)

        obs, reward, terminated, truncated, _ = env.step(action)
        # Stop on truncation as well as termination; otherwise a truncated
        # episode would keep stepping an environment that has already ended.
        done = terminated or truncated
        ply += 1

    # Translate the white-centric reward into a win flag for `perspective`.
    black_won = perspective == chess.BLACK and reward == -1
    white_won = perspective == chess.WHITE and reward == 1
    return 1 if (black_won or white_won) else 0


def duel(env, new_model, old_model, num_rounds):
    """
    Pit ``new_model`` against ``old_model`` and return the new model's win ratio.

    Each round consists of two games: one with the new model as white and one
    with it as black, both sampling from the top 2 moves. The result is the
    number of games the new model won divided by the ``2 * num_rounds`` games
    played. Both models are put in eval mode for the match; the new model is
    switched back to train mode afterwards.
    """
    for net in (new_model, old_model):
        net.eval()

    new_model_wins = 0
    with torch.inference_mode():
        for _ in range(num_rounds):
            # New model plays white, then black, against the old model.
            as_white = play_game(env, new_model, old_model, perspective=chess.WHITE, sample_n=2)
            as_black = play_game(env, old_model, new_model, perspective=chess.BLACK, sample_n=2)
            new_model_wins += as_white + as_black

    new_model.train()
    games_played = 2 * num_rounds
    return new_model_wins / games_played


def _value_targets(reward, num_moves):
    """Build per-move value targets from the final reward, alternating sign by ply.

    With reward +1 (white win) even plies get +1 and odd plies -1; with
    reward -1 (black win) the signs are flipped; any other reward gives zeros.
    """
    if reward == 1:
        sign = 1
    elif reward == -1:
        sign = -1
    else:
        return [0] * num_moves
    return [sign * (-1) ** ply for ply in range(num_moves)]


def run_training(num_games=100, duel_every=10, duel_winrate=0.55):
    """
    Self-play training loop: play games with tree search, fit the network on
    each finished game, and periodically duel the previous best model.

    Uses the notebook-level ``env``, ``tree``, ``model``, ``best_model``,
    ``DEVICE`` and ``BESTMODEL_SAVEPATH``, and logs to the active W&B run.

    Args:
        num_games: Number of self-play games to play and train on.
        duel_every: Duel ``best_model`` after every ``duel_every`` completed
            games. A falsy value (0 / None) disables duelling.
        duel_winrate: Win ratio the current model must exceed to replace the
            best model.

    Side effects:
        Updates ``model`` in place, may overwrite ``BESTMODEL_SAVEPATH`` and
        the weights of ``best_model``, and closes ``env`` when finished.
    """
    observation, info = env.reset()
    env.render()

    for g in range(num_games):
        print(f"Starting game number: {g}")

        g_actions = []
        g_states = []

        # --- Self-play one game, recording (state, chosen action) per ply ---
        terminal = False
        with tqdm() as pbar:  # context manager closes the bar after each game
            while not terminal:
                state = env.get_string_representation()

                # Search with the network in eval mode, then restore train mode.
                model.eval()
                action, value = tree.search(state, observation, simulations_number=1000)  # value = ucb
                model.train()

                if isinstance(value, float):
                    value = torch.tensor(value, device=DEVICE)

                g_actions.append(action)
                g_states.append(observation[0])
                wandb.log({'UCB': value})

                observation, reward, terminated, truncated, info = env.step(action)
                # End the game on truncation too, so a truncated episode is
                # not stepped past its end.
                terminal = terminated or truncated
                pbar.update()

        # --- Build training tensors from the finished game ---
        g_values = torch.tensor(_value_targets(reward, len(g_actions)), device=DEVICE, dtype=torch.float32)
        # Stack into one ndarray first: building a tensor directly from a list
        # of arrays is a known slow path in PyTorch.
        g_states = torch.as_tensor(np.array(g_states), dtype=torch.float32, device=DEVICE)

        # --- Update model with game data ---
        with autocast():
            policy_output, value_output = model(g_states.unsqueeze(1))
            policy_loss = model.policy_loss(policy_output.squeeze(), torch.tensor(g_actions, device=DEVICE))
            value_loss = model.value_loss(value_output.squeeze(), g_values)
            loss = policy_loss + value_loss

        print(f"Game: {g} - TotalLoss: {loss.item():.6f} - PolicyLoss: {policy_loss.item():.6f} - ValueLoss: {value_loss.item():.6f}")

        # AMP with gradient clipping (unscale before clipping so the norm is
        # measured on the true gradients).
        model.optimizer.zero_grad()
        model.grad_scaler.scale(loss).backward()
        model.grad_scaler.unscale_(model.optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        model.grad_scaler.step(model.optimizer)
        model.grad_scaler.update()

        # --- Duel models and save best ---
        # Schedule on completed games; the previous check used the move count
        # of the last game, so duels only fired when a game happened to last
        # a multiple of `duel_every` moves.
        games_played = g + 1
        if duel_every and games_played % duel_every == 0:
            winlose = duel(env, model, best_model, 7)

            wandb.log({"win/loss:": winlose})
            print(winlose)
            if winlose > duel_winrate:
                torch.save(model.state_dict(), BESTMODEL_SAVEPATH)
                # Promote the winner so later duels are against the actual
                # best weights rather than the original checkpoint.
                best_model.load_state_dict(model.state_dict())

        wandb.log({"policy_loss": policy_loss.item(), "value_loss": value_loss.item(), "total_loss": loss.item()})

        # Fresh tree and board for the next game.
        tree.reset()
        observation, info = env.reset()

    env.close()


In [7]:
# Start a W&B run for this training session and train for 100 games,
# duelling the best model every 10 games.
wandb.init(project="Chess")
try:
    run_training(100, duel_every=10)
finally:
    # In a notebook the run stays open until finished explicitly; close it
    # even when training raises so no dangling run is left behind.
    wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkeithg33[0m ([33mopen_sim2real[0m). Use [1m`wandb login --relogin`[0m to force relogin


Starting game number: 0


83it [14:50,  9.96s/it]

TypeError: 'int' object is not iterable