# MCTS Training Notebook

In [1]:
import os, pickle, random
from tqdm import tqdm 
import wandb
import numpy as np
import gym
import chess
from torch.multiprocessing import Pool, set_start_method, Lock, Process

import adversarial_gym
from OBM_ChessNetwork import ChessNetworkSimple
from search import MonteCarloTreeSearch
from parallel import run_games_continuously, torch_safesave, ReplayBufferManager, ReplayBuffer, ChessReplayDataset

import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast

import sys
sys.path.append('../../chess_utils')
from chess_dataset import ChessDataset


  from .autonotebook import tqdm as notebook_tqdm


### Load Model

In [2]:
# Checkpoint to warm-start from; leave as None to start from freshly initialised weights.
# MODEL_PATH = 'best_baseSwinChessNet.pt'
MODEL_PATH = None
BESTMODEL_SAVEPATH = 'mcts_baseSwinChessNet_best.pt'
CURRMODEL_SAVEPATH = 'currentMCTS' #+ MODEL_PATH
DEVICE = 'cuda'

# `model` is the network being trained; `best_model` is the frozen opponent used in duels.
model = ChessNetworkSimple(hidden_dim=512, device=DEVICE)
best_model = ChessNetworkSimple(hidden_dim=512, device=DEVICE)

if MODEL_PATH is not None:
    # Read the checkpoint once, mapped onto DEVICE, and copy it into both networks
    # (load_state_dict copies values, so sharing the dict is safe).
    checkpoint_state = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint_state)
    best_model.load_state_dict(checkpoint_state)

model = torch.compile(model)
# Dummy forward pass on the configured device - presumably to trigger compilation up front.
x = torch.randn((1,1,8,8), device=DEVICE, dtype=torch.float32)
out = model(x)
best_model = torch.compile(best_model)
best_model.eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]




OptimizedModule(
  (_orig_mod): ChessNetworkSimple(
    (swin_transformer): SwinTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))
        (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (layers): Sequential(
        (0): SwinTransformerStage(
          (downsample): Identity()
          (blocks): Sequential(
            (0): SwinTransformerBlock(
              (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attn): WindowAttention(
                (qkv): Linear(in_features=128, out_features=384, bias=True)
                (attn_drop): Dropout(p=0.0, inplace=False)
                (proj): Linear(in_features=128, out_features=128, bias=True)
                (proj_drop): Dropout(p=0.0, inplace=False)
                (softmax): Softmax(dim=-1)
              )
              (drop_path1): Identity()
              (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)

### Helper Code

In [3]:
# class ReplayBuffer:
#     """ Replay buffer to store past experiences for training policy/value network"""
#     def __init__(self, capacity=None):
#         self.actions = []
#         self.states = []
#         self.values = []

#         self.capacity = capacity
#         self.curr_length = 0
#         self.position = 0
    
#     def get_state(self):
#         return {
#             'actions': list(self.actions),
#             'states': list(self.states),
#             'values': list(self.values),
#             'capacity': self.capacity,
#             'curr_length': self.curr_length,
#             'position': self.position
#         }

#     def from_dict(self, buffer_state_dict):
#         for key, value in buffer_state_dict.items():
#             setattr(self, key, value)
    
#     def push(self, state, action, value):    
#         if len(self.actions) < self.capacity:
#             self.states.append(None)
#             self.actions.append(None)
#             self.values.append(None)
        
#         self.states[self.position] = state
#         self.actions[self.position] = action
#         self.values[self.position] = value

#         self.curr_length = len(self.states)
#         self.position = (self.position + 1) % self.capacity

#     def update(self, states, actions, winner):
#         # Create value targets based on who won
#         if winner == 1:
#             values = [(-1)**(i) for i in range(len(actions))]
#         elif winner == -1:
#             values = [(-1)**(i+1) for i in range(len(actions))]
#         else:
#             values = [0] * len(actions)

#         for state, action, value in zip(states, actions, values):
#             self.push(state, action, value)

# class ChessReplayDataset(Dataset):
#     def __init__(self, replay_buffer_proxy):
#         # Initialize the dataset with replay buffer data
#         self.replay_buffer = ReplayBuffer().from_dict(replay_buffer_proxy.get_state())

#     def __len__(self):
#         # Return the current size of the replay buffer
#         return self.replay_buffer.curr_length

#     def __getitem__(self, idx):
#         # Fetch a single experience at the specified index
#         if idx >= len(self):
#             raise IndexError('Index out of range in ChessReplayDataset')
#         state = self.replay_buffer.states[idx]
#         action = self.replay_buffer.actions[idx]
#         value = self.replay_buffer.values[idx]
#         return state, action, value


In [4]:
def play_game(env, white, black, perspective: int = None, sample_n: int = 1):
    """
    Play one game between two agents and report whether `perspective` won.

    Args:
        env: Environment exposing `reset()`, `step(action)` and `board.legal_moves`.
        white: Agent that moves on even plies; must provide `get_action(...)`.
        black: Agent that moves on odd plies; must provide `get_action(...)`.
        perspective: `chess.WHITE` or `chess.BLACK` - the side whose result is scored.
        sample_n: Number of top moves each agent samples from.

    Returns:
        1 if the side given by `perspective` won, otherwise 0. A final reward
        of 1 is treated as a white win and -1 as a black win; any other
        outcome, or `perspective=None`, scores 0.
    """
    step = 0
    done = False
    obs, info = env.reset()

    while not done:
        # White moves on even plies, black on odd plies.
        mover = white if step % 2 == 0 else black
        action, _ = mover.get_action(obs[0], env.board.legal_moves, sample_n=sample_n)

        # The step API reports termination and truncation separately; either one
        # ends the episode. Ignoring truncation would keep stepping a finished env.
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        step += 1

    # Convert the final reward into a win indicator for the requested side.
    if perspective == chess.BLACK and reward == -1:
        return 1
    if perspective == chess.WHITE and reward == 1:
        return 1
    return 0


def duel(env, new_model, old_model, num_rounds):
    """
    Pit `new_model` against `old_model` and return the new model's win ratio.

    Each round consists of two games, one with the new model as white and one
    with it as black, so `2 * num_rounds` games are played in total. The ratio
    is the sum of the per-game results returned by `play_game` divided by the
    number of games. Both models are put in eval mode for the match;
    `new_model` is switched back to train mode before returning.
    """
    for contestant in (new_model, old_model):
        contestant.eval()

    victories = 0
    with torch.inference_mode():
        for _ in range(num_rounds):
            # New model plays white first, then black, within the same round.
            as_white = play_game(env, new_model, old_model, perspective=chess.WHITE, sample_n=2)
            as_black = play_game(env, old_model, new_model, perspective=chess.BLACK, sample_n=2)
            victories += as_white + as_black

    new_model.train()
    total_games = 2 * num_rounds
    return victories / total_games


def update_model(model, selfplay_buffer_proxy, expert_dataset, dataset_size, batch_size=64):
    """
    Run one training pass over a mix of self-play and expert data.

    A training set of (at most) `dataset_size` samples is assembled so that

        expert_size + selfplay_size = dataset_size

    When the replay buffer holds fewer than `dataset_size` samples, all of it
    is used and the shortfall is filled with randomly chosen expert samples
    (capped at the number of expert samples available). Otherwise
    `dataset_size` samples are drawn from the self-play data only.

    Args:
        model: Network exposing `device`, `optimizer`, `grad_scaler`,
            `policy_loss(...)` and `value_loss(...)`.
        selfplay_buffer_proxy: Replay buffer (proxy) wrapped by `ChessReplayDataset`.
        expert_dataset: Indexable dataset yielding (state, action, value).
        dataset_size: Target number of training samples.
        batch_size: Mini-batch size for the data loader.

    Returns:
        Tuple `(avg_loss, avg_policy_loss, avg_value_loss)` averaged over batches.

    Raises:
        ValueError: If no training samples are available, or `dataset_size` is negative.
    """
    # Initialize buffer dataset
    print("updating model")
    selfplay_dataset = ChessReplayDataset(selfplay_buffer_proxy)
    num_selfplay = len(selfplay_dataset)

    expert_size = dataset_size - num_selfplay

    if expert_size > 0: # combine data
        # Never ask for more expert samples than exist, and sample from index 0
        # so the first expert position is not silently excluded.
        expert_size = min(expert_size, len(expert_dataset))
        indices = random.sample(range(len(expert_dataset)), expert_size)
        expert_subset = torch.utils.data.Subset(expert_dataset, indices)
        train_dataset = ConcatDataset([expert_subset, selfplay_dataset])

    else: # selfplay data
        # The full index range is required here: when the buffer holds exactly
        # dataset_size samples, a range starting at 1 is one element too short.
        indices = random.sample(range(num_selfplay), dataset_size)
        train_dataset = torch.utils.data.Subset(selfplay_dataset, indices)

    if len(train_dataset) == 0:
        raise ValueError("update_model: no training samples available")

    # Create dataloader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    print("dataset and trainloader created")
    # Initialize losses
    total_policy_loss = 0
    total_value_loss = 0
    total_loss = 0
    model.train()  # Set the model to training mode

    for (states_batch, actions_batch, values_batch) in train_loader:
        print("doing batch")
        # Insert a channel dimension after the batch dimension.
        states_batch = states_batch.to(model.device, dtype=torch.float32).unsqueeze(1)
        actions_batch = actions_batch.to(model.device, dtype=torch.long)
        values_batch = values_batch.to(model.device, dtype=torch.float32)

        # Forward pass and calculate loss
        with autocast():
            policy_output, value_output = model(states_batch)
            # NOTE(review): .squeeze() with no argument also drops the batch
            # dimension when a batch holds a single sample - confirm the loss
            # functions tolerate that for a trailing batch of size 1.
            policy_loss = model.policy_loss(policy_output.squeeze(), actions_batch)
            value_loss = model.value_loss(value_output.squeeze(), values_batch)
            loss = policy_loss + value_loss

        # Backward pass and optimization
        model.optimizer.zero_grad()
        model.grad_scaler.scale(loss).backward()
        # Unscale before clipping so the norm threshold applies to the true gradients.
        model.grad_scaler.unscale_(model.optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        model.grad_scaler.step(model.optimizer)
        model.grad_scaler.update()

        # Record the losses
        total_policy_loss += policy_loss.item()
        total_value_loss += value_loss.item()
        total_loss += loss.item()

    # Average the losses over the iterations
    num_batches = len(train_loader)
    avg_policy_loss = total_policy_loss / num_batches
    avg_value_loss = total_value_loss / num_batches
    avg_loss = total_loss / num_batches

    return avg_loss, avg_policy_loss, avg_value_loss

### Parallel Self-Play and Training

In [5]:
def pickle_bufferproxy(buffer_proxy, path='replay_buffer_state.pkl'):
    """
    Snapshot the replay buffer's state to disk with pickle.

    Args:
        buffer_proxy: Object exposing `get_state()`; its return value is pickled.
        path: Destination file. Defaults to 'replay_buffer_state.pkl' in the
            current working directory.

    The snapshot is written to a temporary sibling file and then moved into
    place, so an interrupted write never leaves a truncated file at `path`.
    """
    shared_buffer_state = buffer_proxy.get_state()
    tmp_path = f'{path}.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(shared_buffer_state, f)
    # os.replace overwrites an existing destination in a single rename step.
    os.replace(tmp_path, path)
        
def run_training(num_games, expert_dataset, games_in_parallel, train_every, duel_every=10, duel_winrate=0.55, buffer_capacity=1_000_000, train_dataset_size=100, duel_rounds=7, poll_interval=0.5):
    """
    Run self-play in a background process while training and evaluating here.

    A worker process (`run_games_continuously`) is started with a CPU copy of
    the current weights and fills a shared replay buffer. This function polls
    the shared game counter and
      * trains `model` whenever at least `train_every` games have finished
        since the previous training pass, then saves the weights and a
        snapshot of the buffer;
      * duels `model` against `best_model` whenever at least `duel_every`
        games have finished since the previous duel; on a win ratio above
        `duel_winrate` the weights are saved to BESTMODEL_SAVEPATH and copied
        into `best_model` so later duels face the new champion.
    Once `num_games` games are done the worker is asked to shut down and a
    final set of weights and the buffer state are written to disk.

    Args:
        num_games: Total number of self-play games to wait for.
        expert_dataset: Expert data passed through to `update_model`.
        games_in_parallel: Forwarded to the self-play worker.
        train_every: Games between training passes (must be positive).
        duel_every: Games between duels (must be positive).
        duel_winrate: Win ratio the new model must exceed to become the best.
        buffer_capacity: Capacity of the shared replay buffer.
        train_dataset_size: Samples per training pass (`update_model`'s dataset_size).
        duel_rounds: Rounds per duel (two games per round).
        poll_interval: Seconds to sleep between polls of the game counter.

    Raises:
        ValueError: If `train_every` or `duel_every` is not positive.
    """
    import time  # local import: only needed for the polling sleep below

    if train_every <= 0 or duel_every <= 0:
        raise ValueError("train_every and duel_every must be positive")

    # Multiprocessing stuff
    manager = ReplayBufferManager()
    manager.start()
    shared_replay_buffer = manager.ReplayBuffer(capacity=buffer_capacity)
    shutdown_event = manager.Event()
    buffer_lock = manager.Lock()
    global_game_counter = manager.GameCounter()  # Initialize a shared counter
    file_lock = Lock()

    model_state = {k: v.cpu() for k, v in model.state_dict().items()} # can't share cuda tensors

    # Start the continuous game running process in a separate process
    process = Process(target=run_games_continuously, args=(model_state, CURRMODEL_SAVEPATH, shared_replay_buffer, games_in_parallel, buffer_lock, file_lock, global_game_counter, shutdown_event))
    process.start()

    env = gym.make("Chess-v0")

    # Game counts at which training / duelling last ran. Comparing against these
    # (rather than testing `count % every == 0`) means each trigger fires once
    # per interval and is not missed when the counter jumps between polls.
    last_train_count = 0
    last_duel_count = 0

    try:
        # Main training loop
        while True:
            game_count = global_game_counter.count
            finished = game_count >= num_games
            if finished:
                # Stop self-play, but still fall through so a pending train/duel can run.
                shutdown_event.set()

            if game_count - last_train_count >= train_every:
                last_train_count = game_count
                loss, policy_loss, value_loss = update_model(model, shared_replay_buffer, expert_dataset, train_dataset_size)
                print(loss, policy_loss, value_loss)
                # wandb.log({"policy_loss": policy_loss, "value_loss": value_loss, "total_loss": loss})
                torch_safesave(model.state_dict(), CURRMODEL_SAVEPATH, file_lock)
                pickle_bufferproxy(shared_replay_buffer)

            if game_count - last_duel_count >= duel_every:
                last_duel_count = game_count
                winlose = duel(env, model, best_model, duel_rounds)
                # wandb.log({"win/loss:": winlose})
                print(winlose)
                if winlose > duel_winrate:
                    torch.save(model.state_dict(), BESTMODEL_SAVEPATH)
                    # Promote the winner so the next duel is against the new best.
                    best_model.load_state_dict(model.state_dict())

            if finished:
                break
            # Avoid hammering the manager process with counter reads.
            time.sleep(poll_interval)
    finally:
        # Always stop and reap the worker, even if training or duelling raised;
        # otherwise the self-play process keeps running in the background.
        shutdown_event.set()
        process.join()
        env.close()

    # Done. Save model and buffer
    torch.save(model.state_dict(), CURRMODEL_SAVEPATH)
    shared_buffer_state = shared_replay_buffer.get_state()
    with open('replay_buffer.pkl', 'wb') as f:
        pickle.dump(shared_buffer_state, f)

In [6]:
# Location of the expert PGN games. Set the CHESS_PGN_FILE environment variable
# to point at a different file; the default is the original author's local path.
PGN_FILE = os.environ.get(
    "CHESS_PGN_FILE",
    "/home/kage/chess_workspace/PGN-data/tcec+alphastock/TCEC_Cup_1_Final_5.pgn",
)

# Load the datasets
expert_dataset = ChessDataset(PGN_FILE)
print(len(expert_dataset))

1322


In [7]:
# wandb.init(project="Chess")

# Run configuration: total self-play games to wait for, and how many the
# worker plays concurrently.
NUM_GAMES = 25
GAMES_IN_PARALLEL = 4

# Force the 'spawn' start method before any worker process is created
# (presumably because the workers use CUDA - TODO confirm).
set_start_method('spawn', force=True)

run_training(
    num_games=NUM_GAMES,
    expert_dataset=expert_dataset,
    games_in_parallel=GAMES_IN_PARALLEL,
    train_every=5,
    duel_every=5,
)

  logger.warn(
  logger.warn(
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  logger.warn(
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  logger.warn(
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  logger.warn(
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):


gameover
gameover
gameover
gameover
gameoverNOW TRAINING
updating model

dataset and trainloader created
doing batch
gameover
gameover
gameover
gameover
gameover
gameover
gameover


  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
gameover
doing batch
8.83162260055542 8.57621955871582 0.2554032653570175


TypeError: pickle_safesave() takes 1 positional argument but 2 were given

gameover
gameover
gameover
gameover
