Chess RL + Expert Learning + Weirdness 

In [1]:
import wandb
import chess
import gym
import chess
import os, sys, copy
import torch
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast
from tqdm import tqdm
import numpy as np
import time
import adversarial_gym
from adversarial_gym.chess_env import ChessEnv

from OBM_ChessNetwork import Chess42069NetworkSimple

sys.path.append('../../chess_utils')
from chess_dataset import ChessDataset
from utils import RunningAverage

  from .autonotebook import tqdm as notebook_tqdm


Model Gameplay Functions

In [2]:
def play_game(env, white, black, perspective=None, sample_n=1, duel=False):
    """Play one full game in ``env`` between two models and collect training data.

    Args:
        env: chess environment. ``env.reset()[0]`` and the first element of
            ``env.step(action)`` are observations; ``obs[0]`` is what gets fed
            to a model, and ``env.board.legal_moves`` is passed to
            ``to_action``.
        white, black: models called as ``model(obs) -> (action_logits, value)``
            and exposing ``to_action(logits, legal_moves, sample_n) ->
            (action, log_prob)``. White moves on even plies, black on odd.
        perspective: ``chess.WHITE`` / ``chess.BLACK`` to record only that
            side's moves, or ``None`` to record every move (self-play).
        sample_n: forwarded to ``to_action``.
        duel: when True and ``perspective`` is set, the returned reward is a
            single scalar: 1 if ``perspective`` won, 0 otherwise.

    Returns:
        ``(observations, actions, log_probs, values, done_mask, rewards)``.
        The first five are per recorded move. ``done_mask`` is 0 for every
        recorded move except the last one of the game, which is 1.
        ``rewards`` is the per-move list from ``prepare_game_rewards``, or
        the scalar reward when ``duel`` is True.
    """
    observations, actions, log_probs, values, done_mask = [], [], [], [], []
    step = 0
    done = False
    obs = env.reset()[0]
    while not done:
        is_white_move = step % 2 == 0
        mover = white if is_white_move else black
        action_logits, value_estimate = mover(obs[0])
        action, log_prob = mover.to_action(action_logits, env.board.legal_moves, sample_n)

        record = (
            perspective is None
            or (perspective == chess.WHITE and is_white_move)
            or (perspective == chess.BLACK and not is_white_move)
        )
        if record:
            observations.append(obs[0])
            actions.append(action)
            log_probs.append(log_prob)
            values.append(value_estimate)
            # Whether this is terminal is only known once the game is over;
            # see the fix-up after the loop.
            done_mask.append(0)

        obs, reward, done, _, _ = env.step(action)
        step += 1

    # The last move recorded for this side ends its trajectory, whether the
    # game ended on that move or on the opponent's reply.
    if done_mask:
        done_mask[-1] = 1

    if reward in [-1, 1]:
        print(f"PERSPECTIVE: {perspective} - GAME OUTCOME: {reward}")
    else:
        print("DRAW")

    # Reward is 1 if the chosen perspective won and -1 if it lost (the raw
    # env reward is treated as white's point of view, hence the flip for
    # black). When dueling, only a win counts: reward becomes 1 or 0.
    if perspective is not None:
        if perspective == chess.BLACK:
            reward *= -1
        if duel and reward != 1:
            reward = 0

    rewards = prepare_game_rewards(reward, perspective, len(actions)) if not duel else reward

    return observations, actions, log_probs, values, done_mask, rewards


def prepare_game_rewards(reward, perspective, game_len, gamma=0.99):
    """Expand a final game result into one reward per recorded move.

    Args:
        reward: game result in {-1, 0, 1}. With ``perspective`` set it is
            already from that side's point of view. In self-play
            (``perspective is None``) it is treated as white's point of view,
            matching ``play_game``, which negates it only for black.
        perspective: ``None`` for self-play, otherwise the recorded side.
        game_len: number of recorded moves.
        gamma: per-ply discount used in self-play. It defaults to the 0.99
            previously hard-coded via ``compute_discounted_rewards``.

    Returns:
        A list of length ``game_len``. With a perspective, every entry equals
        ``reward``. In self-play, entry ``t`` is the result seen by the side
        that moved at ply ``t`` (white on even plies, black on odd), scaled
        by ``gamma ** (game_len - 1 - t)`` so the final move gets the full
        result.
    """
    if perspective is not None:
        # Reward in [-1, 0, 1] for every recorded move.
        return [reward] * game_len

    # Only the terminal outcome carries reward, so each ply's discounted
    # return is that outcome (signed for the mover) discounted by the number
    # of plies left. This replaces discounting a list that repeated the
    # reward every ply with alternating sign: there the running sum
    # cancelled, leaving ~0 for the side that did not make the last move.
    last = game_len - 1
    return [
        (reward if t % 2 == 0 else -reward) * gamma ** (last - t)
        for t in range(game_len)
    ]
    
    
def compute_discounted_rewards(reward, gamma=0.99):
    """Return the discounted return for every step of a reward sequence.

    Walks the sequence backwards using G_t = r_t + gamma * G_{t+1}, with
    G = 0 past the final step.

    Args:
        reward: sequence of per-step rewards.
        gamma: discount factor.

    Returns:
        A list of the same length as ``reward`` holding G_t for each step.
    """
    returns = []
    acc = 0
    for r in reversed(reward):
        acc = acc * gamma + r
        returns.append(acc)
    # Built back-to-front; flip into step order.
    returns.reverse()
    return returns


def duel(env, old_model, new_model, num_rounds):
    """Pit ``new_model`` against ``old_model`` and return ``new_model``'s win ratio.

    Each round plays two games, one with ``new_model`` as white and one with
    it as black, via ``play_game(..., duel=True)``. Each game therefore
    scores 1 for a ``new_model`` win and 0 for a draw or loss. The result is
    ``wins / (2 * num_rounds)``.

    ``new_model`` is put in eval mode for the match and switched back to
    train mode afterwards. This also happens if the match is interrupted
    (e.g. a KeyboardInterrupt in the notebook), so training does not quietly
    continue in eval mode. ``old_model``'s mode is not changed here.
    """
    new_model.eval()
    try:
        with torch.no_grad():
            wins = 0
            for _ in range(num_rounds):
                *_, r_w = play_game(env, new_model, old_model, perspective=chess.WHITE, sample_n=2, duel=True)
                *_, r_b = play_game(env, old_model, new_model, perspective=chess.BLACK, sample_n=2, duel=True)
                wins += r_w + r_b
    finally:
        new_model.train()
    return wins / (2 * num_rounds)


def self_play(env, model, num_games):
    """Play ``num_games`` games of ``model`` against itself and collect the data.

    Returns:
        ``(observations, actions, log_probs, rewards)``: four lists with one
        entry per game. Each entry is the corresponding per-move list
        returned by ``play_game(..., perspective=None)``. ``play_game`` also
        returns value estimates and a done mask; those are discarded here to
        keep this function's 4-tuple interface.
    """
    # TODO: check if numpy array of shape (num_games, 4) is faster, each row could be output of play_game
    observations, actions, log_probs, rewards = [], [], [], []
    for _ in range(num_games):
        # play_game returns six values; unpacking only four raised ValueError.
        g_obs, g_actions, g_log_probs, _values, _done_mask, g_reward = play_game(
            env, model, model, perspective=None
        )
        observations.append(g_obs)
        actions.append(g_actions)
        log_probs.append(g_log_probs)
        rewards.append(g_reward)
    return observations, actions, log_probs, rewards

Expert Learning

In [3]:
def run_validation(model, val_loader, stats):
    """Compute the mean policy + value loss of ``model`` over ``val_loader``.

    Args:
        model: network exposing ``device``, ``policy_loss`` and ``val_loss``
            (the same attributes ``expert_study`` uses).
        val_loader: iterable of ``(state, action, result)`` batches.
        stats: ``RunningAverage`` with a ``"val_loss"`` key. It is reset,
            then updated once per batch.

    Returns:
        The running average of ``"val_loss"`` after the pass.

    The model runs in eval mode under ``torch.no_grad()`` and is restored to
    its previous train/eval mode afterwards. This keeps mode-dependent layers
    (such as the Dropout modules in the printed architecture) in training
    mode for a caller that keeps training.
    """
    # Same device the training code uses, so inputs always match the weights.
    device = model.device
    was_training = model.training
    model.eval()
    stats.reset("val_loss")
    t1 = time.perf_counter()
    try:
        with torch.no_grad():
            for state, action, result in val_loader:
                state = state.float().to(device)
                action = action.to(device)
                result = result.float().to(device)

                policy_output, value_output = model(state.unsqueeze(1))
                # NOTE(review): .squeeze() also drops the batch dim when a
                # batch has size 1 - confirm the loss functions accept that.
                policy_loss = model.policy_loss(policy_output.squeeze(), action)
                value_loss = model.val_loss(value_output.squeeze(), result)

                loss = policy_loss + value_loss
                stats.update("val_loss", loss.item())
    finally:
        model.train(was_training)

    print(f"Mean Validation Loss: {stats.get_average('val_loss')}, time elapsed: {time.perf_counter()-t1} seconds")
    return stats.get_average('val_loss')


def expert_study(model, dataset, percent_dataset=0.1):
    """Behaviour-clone ``model`` on expert game data for one pass over a random subset.

    A random ``percent_dataset`` fraction of ``dataset`` is drawn without
    replacement and split 90/10 into train/validation. The model is trained
    on the train split once, using AMP and gradient clipping (max norm 1.0),
    then evaluated with ``run_validation``.

    Args:
        model: network exposing ``device``, ``optimizer``, ``grad_scaler``,
            ``policy_loss`` and ``val_loss``.
        dataset: indexable dataset yielding ``(state, action, result)``.
        percent_dataset: fraction of ``dataset`` to study.

    Returns:
        The mean validation loss reported by ``run_validation``.
    """
    # Sample indices from the whole dataset without replacement. This avoids
    # duplicates that would leak across the train/validation split.
    study_size = int(percent_dataset * len(dataset))
    random_indices = np.random.permutation(len(dataset))[:study_size].tolist()
    study_dataset = Subset(dataset, random_indices)

    train_ratio = 0.9
    train_size = int(train_ratio * study_size)
    val_size = study_size - train_size
    train_dataset, val_dataset = random_split(study_dataset, [train_size, val_size])

    # Create data loaders for the training and validation sets
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True,
                              pin_memory=False, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False,
                            pin_memory=False, num_workers=2)

    stats = RunningAverage()
    stats.add(["train_loss", "val_loss", "train_p_loss", "train_v_loss"])

    model.train()
    for state, action, result in train_loader:
        state = state.float().to(model.device)
        action = action.to(model.device)
        result = result.float().to(model.device)

        with autocast():
            policy_output, value_output = model(state.unsqueeze(1))
            policy_loss = model.policy_loss(policy_output.squeeze(), action)
            value_loss = model.val_loss(value_output.squeeze(), result)
            loss = policy_loss + value_loss

        # AMP with gradient clipping: unscale before clipping so the norm is
        # measured on the true gradients.
        model.optimizer.zero_grad()
        model.grad_scaler.scale(loss).backward()
        model.grad_scaler.unscale_(model.optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        model.grad_scaler.step(model.optimizer)
        model.grad_scaler.update()

        stats.update({
            "train_loss": loss.item(),
            "train_p_loss": policy_loss.item(),
            "train_v_loss": value_loss.item()
            })

    print(f"Study Train Loss: {stats.get_average('train_loss')}")

    valid_loss = run_validation(model, val_loader, stats)
    # run_validation calls model.eval(); put the model back into the training
    # mode set above so RL updates after this study are not run in eval mode.
    model.train()
    return valid_loss


Load Models

In [4]:
# Load Model: `model` is the one being trained; `best_model` is the frozen
# opponent it must beat in duels. Both start from the same checkpoint if it
# exists.
MODEL_PATH = '/home/kage/chess_workspace/simpler_SwinChessNet42069.pt'

model = Chess42069NetworkSimple(hidden_dim=256, device='cuda')
best_model = Chess42069NetworkSimple(hidden_dim=256, device='cuda')

if os.path.exists(MODEL_PATH):
    # f-prefix was missing, so the literal "{MODEL_PATH}" was printed.
    print(f"Loading model at: {MODEL_PATH}")
    # Read the checkpoint once. load_state_dict copies the tensors into each
    # model's own parameters, so sharing the dict between them is safe.
    checkpoint = torch.load(MODEL_PATH)
    model.load_state_dict(checkpoint)
    best_model.load_state_dict(checkpoint)

best_model.eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Loading model at: {MODEL_PATH}


Chess42069NetworkSimple(
  (swin_transformer): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(1, 96, kernel_size=(1, 1), stride=(1, 1))
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    )
    (layers): Sequential(
      (0): SwinTransformerStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): SwinTransformerBlock(
            (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=96, out_features=288, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=96, out_features=96, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=96, out_fe

Train - VPG with Self-Play and Dueling

In [5]:
# --- Training configuration --------------------------------------------------
# Expert games used by expert_study (behaviour cloning).
PGN_FILE = '/home/kage/chess_workspace/PGN-data/alphazero_stockfish_all/alphazero_vs_stockfish_all.pgn'
# Checkpoint written whenever the trained model wins a duel against best_model.
MODEL_SAVEPATH = '/home/kage/chess_workspace/WACKY_RL_MODEL.pt'

# Schedule, counted in training-loop iterations.
NUM_EPOCHS = 100
STUDY_EVERY = 1   # expert_study cadence; the call is commented out in the training loop
DUEL_EVERY = 50   # how often the trained model duels best_model

# Game environment and expert dataset.
env = gym.make("Chess-v0")
chess_dataset = ChessDataset(PGN_FILE)

  logger.warn(


In [9]:
def _update_from_game(model, game, **update_kwargs):
    """Convert one ``play_game`` result to tensors on ``model.device`` and update ``model``.

    ``game`` is the 6-tuple returned by ``play_game``. Extra keyword
    arguments (e.g. ``selfplay=True``) are forwarded to
    ``model.update_network``.
    """
    _observations, _actions, log_probs, values, done_mask, rewards = game
    device = model.device
    # Value of the next recorded state; 0 after the final action.
    next_values = values[1:] + [0]
    model.update_network(
        torch.stack(log_probs).to(device),
        torch.as_tensor(rewards, dtype=torch.float32, device=device),
        torch.as_tensor(values, dtype=torch.float32, device=device),
        torch.as_tensor(next_values, dtype=torch.float32, device=device),
        torch.as_tensor(done_mask, dtype=torch.float32, device=device),
        gamma=0.99,
        **update_kwargs,
    )


for i in tqdm(range(NUM_EPOCHS)):
    # Play as white or black (alternating) against the previous best model
    if i % 2 == 0:
        game = play_game(env, model, best_model, perspective=chess.WHITE, sample_n=3)
    else:
        game = play_game(env, best_model, model, perspective=chess.BLACK, sample_n=3)
    _update_from_game(model, game)

    # Expert study (behaviour cloning) is currently disabled; to re-enable:
    # if i % STUDY_EVERY == 0:
    #     expert_study(model, chess_dataset, percent_dataset=0.05)

    # Darwinian duel to the death: replace best_model if clearly beaten
    if i % DUEL_EVERY == 0:
        win_ratio = duel(env, best_model, model, num_rounds=10)
        print(f"Model win ratio: {win_ratio}")
        if win_ratio > 0.6:
            print("Best model was defeated!")
            best_model = copy.deepcopy(model)
            torch.save(model.state_dict(), MODEL_SAVEPATH)
            best_model.eval()

    # Self play
    game = play_game(env, model, model, perspective=None, sample_n=3)
    _update_from_game(model, game, selfplay=True)


  0%|          | 0/100 [00:00<?, ?it/s]

DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
PERSPECTIVE: False - GAME OUTCOME: 1
DRAW
DRAW
DRAW
DRAW
DRAW
PERSPECTIVE: False - GAME OUTCOME: 1
PERSPECTIVE: True - GAME OUTCOME: 1
DRAW
DRAW
DRAW
Model win ratio: 0.05
DRAW


  1%|          | 1/100 [01:19<2:11:17, 79.57s/it]

PERSPECTIVE: None - GAME OUTCOME: 1


  2%|▏         | 2/100 [01:31<1:04:31, 39.50s/it]

DRAW


  3%|▎         | 3/100 [01:45<45:07, 27.91s/it]  

DRAW


  4%|▍         | 4/100 [01:58<35:35, 22.24s/it]

PERSPECTIVE: None - GAME OUTCOME: 1


  5%|▌         | 5/100 [02:04<25:42, 16.23s/it]

DRAW


  6%|▌         | 6/100 [02:24<27:38, 17.64s/it]

DRAW


  7%|▋         | 7/100 [02:44<28:31, 18.40s/it]

PERSPECTIVE: None - GAME OUTCOME: -1


  8%|▊         | 8/100 [02:56<25:12, 16.44s/it]

DRAW


  9%|▉         | 9/100 [03:26<31:02, 20.47s/it]

DRAW


 10%|█         | 10/100 [03:40<27:38, 18.43s/it]

DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
PERSPECTIVE: False - GAME OUTCOME: 1
DRAW
DRAW
DRAW
DRAW
PERSPECTIVE: True - GAME OUTCOME: 1
DRAW
DRAW
DRAW
DRAW
PERSPECTIVE: False - GAME OUTCOME: 1
DRAW
PERSPECTIVE: False - GAME OUTCOME: -1
Model win ratio: 0.1
DRAW


 11%|█         | 11/100 [04:51<51:27, 34.69s/it]

DRAW


 12%|█▏        | 12/100 [05:05<41:47, 28.50s/it]

DRAW


 13%|█▎        | 13/100 [05:25<37:27, 25.83s/it]

DRAW


 14%|█▍        | 14/100 [05:42<33:14, 23.19s/it]

DRAW


 15%|█▌        | 15/100 [05:58<29:49, 21.05s/it]

PERSPECTIVE: None - GAME OUTCOME: 1


 16%|█▌        | 16/100 [06:07<24:12, 17.30s/it]

DRAW


 17%|█▋        | 17/100 [06:22<22:48, 16.49s/it]

DRAW


 18%|█▊        | 18/100 [06:36<21:41, 15.87s/it]

DRAW


 19%|█▉        | 19/100 [06:51<21:15, 15.74s/it]

PERSPECTIVE: None - GAME OUTCOME: -1


 20%|██        | 20/100 [06:59<17:39, 13.24s/it]

DRAW
PERSPECTIVE: False - GAME OUTCOME: 1
DRAW
DRAW
PERSPECTIVE: True - GAME OUTCOME: -1
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
DRAW
Model win ratio: 0.0
DRAW


 21%|██        | 21/100 [08:07<39:08, 29.73s/it]

DRAW


 22%|██▏       | 22/100 [08:15<30:04, 23.13s/it]

DRAW


 23%|██▎       | 23/100 [08:22<23:32, 18.34s/it]

DRAW


 24%|██▍       | 24/100 [08:38<22:22, 17.67s/it]

PERSPECTIVE: None - GAME OUTCOME: -1


 25%|██▌       | 25/100 [08:43<17:31, 14.02s/it]

PERSPECTIVE: None - GAME OUTCOME: -1


 26%|██▌       | 26/100 [08:50<14:35, 11.83s/it]

PERSPECTIVE: None - GAME OUTCOME: -1


 27%|██▋       | 27/100 [08:54<11:33,  9.50s/it]

PERSPECTIVE: None - GAME OUTCOME: -1


 28%|██▊       | 28/100 [08:58<09:27,  7.88s/it]

DRAW


 29%|██▉       | 29/100 [09:08<10:02,  8.49s/it]

DRAW


 30%|███       | 30/100 [09:19<10:47,  9.24s/it]

DRAW
DRAW
DRAW
DRAW
PERSPECTIVE: True - GAME OUTCOME: 1
DRAW
DRAW
PERSPECTIVE: False - GAME OUTCOME: 1
DRAW
DRAW
DRAW
PERSPECTIVE: False - GAME OUTCOME: -1
DRAW
DRAW
PERSPECTIVE: True - GAME OUTCOME: 1
DRAW
DRAW
DRAW
DRAW
DRAW
Model win ratio: 0.15
DRAW


 31%|███       | 31/100 [10:24<29:54, 26.00s/it]

DRAW


 32%|███▏      | 32/100 [10:37<24:58, 22.04s/it]

DRAW


 33%|███▎      | 33/100 [10:52<22:08, 19.82s/it]

DRAW


 34%|███▍      | 34/100 [11:08<20:43, 18.84s/it]

DRAW


 35%|███▌      | 35/100 [11:16<16:49, 15.53s/it]

DRAW


 36%|███▌      | 36/100 [11:40<19:09, 17.96s/it]

DRAW


 37%|███▋      | 37/100 [11:59<19:12, 18.30s/it]

DRAW


 38%|███▊      | 38/100 [12:17<18:46, 18.17s/it]

DRAW
