In [1]:
import time, os, random
from collections import OrderedDict
import re

import wandb

import chess
import chess.pgn

import torch
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast
from torch.utils.data import random_split

from chess_transformer.superChessNet import SuperChessNetwork

import sys
sys.path.append('../../chess-utils')
from chess_dataset import ChessDataset
from utils import RunningAverage
from adversarial_gym.chess_env import ChessEnv


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
PGN_DIR_TRAIN = '/home/kage/chess_workspace/ALL_PGN_FILES'
PGN_DIR_TEST = '/home/kage/chess_workspace/ccrl/test'
# PGN_DIR_TRAIN = '/home/kage/chess_workspace/tmp'
# PGN_DIR_TEST = '/home/kage/chess_workspace/tmp'

In [3]:
def align_state_dict_keys(state_dict):
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Remove the unexpected prefix from each key
        new_key = re.sub(r'^od\.|^_orig_mod\.', '', key)
        new_state_dict[new_key] = value
    return new_state_dict

def get_backbone_dict(state_dict):
    backbone_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith('swin_transformer'):
            backbone_dict[key] = value
    return backbone_dict

def load_backbone(model, pretrained_path):
    pretrained_dict = align_state_dict_keys(torch.load(pretrained_path))
    backbone_dict = get_backbone_dict(pretrained_dict)
    model.load_state_dict(backbone_dict, strict=False)
    return model


# Initialize model
MODEL_PATH = 'super-baseSwinChessNet.pt'
MODEL_PRETRAIN_PATH = '/home/kage/chess_workspace/chess-rl/monte-carlo-tree-search-NN/best_1024-baseSwinChessNet.pt'
MODEL_SAVEPATH = "super-baseSwinChessNet.pt"

model = SuperChessNetwork(memory_size=8500, topk=750, base_lr=0.2, device='cuda')
# if os.path.exists(MODEL_PATH):
#     print(f"Loading model at: {MODEL_PATH}")
#     model.load_state_dict(align_state_dict_keys(torch.load(MODEL_PATH)))

# # Load backbone
# model = load_backbone(model, MODEL_PRETRAIN_PATH)

model = model.to('cuda' if torch.cuda.is_available() else 'cpu')
# model = torch.compile(model)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:
def run_validation(model, val_loader, stats):
    model.eval()
    stats.reset(["val_loss", "val_ploss", "val_vloss", "val_hvloss", "val_mvloss"])
    t1 = time.perf_counter()

    with torch.no_grad():
        for i, (state, action, result) in enumerate(val_loader):
            state = state.float().to('cuda' if torch.cuda.is_available() else 'cpu')
            action = action.to('cuda' if torch.cuda.is_available() else 'cpu')
            result = result.float().to('cuda' if torch.cuda.is_available() else 'cpu')
            
            policy_output, hvalue_output, value_output, features = model(state.unsqueeze(1))
            
            # mem_policy_loss, mem_value_loss = model.memory_loss(features[1], features[2], action, result)
            mem_value_loss = model.value_loss(features[2].squeeze(), result)
            head_value_loss = model.value_loss(hvalue_output.squeeze(), result)
            
            policy_loss = model.policy_loss(policy_output.squeeze(), action)
            value_loss = model.value_loss(value_output.squeeze(), result)
            
            loss = policy_loss + value_loss + head_value_loss + mem_value_loss

            stats.update({
                "val_loss": loss.detach().item(),
                "val_ploss": policy_loss.detach().item(),
                "val_vloss": value_loss.detach().item(),
                "val_hvloss": head_value_loss.detach().item(),
                "val_mvloss": mem_value_loss.detach().item()
            })
        
    return stats.get_average('val_loss'), stats.get_average('val_ploss'), stats.get_average('val_vloss'), stats.get_average('val_mem_vloss'), stats.get_average('val_hvloss')

def training_round(model, train_loader, val_loader, num_epochs=10, log_every=1000, validation_every=20_000):
    stats = RunningAverage()
    stats.add(["train_loss", "train_vloss",
               "train_hvloss", "train_ploss",
               "train_mvloss", "val_loss",
               "val_vloss", "val_hvloss",
               "val_ploss", "val_mvloss"])

    best_val_loss = 1000
    
    for epoch in range(num_epochs): 
        model.train()
        t1 = time.perf_counter()
        
        for i, (state, action, result) in enumerate(train_loader):
            state = state.float().to('cuda' if torch.cuda.is_available() else 'cpu')
            action = action.to('cuda' if torch.cuda.is_available() else 'cpu')
            result = result.float().to('cuda' if torch.cuda.is_available() else 'cpu')
            
            model.optimizer.zero_grad()
            with autocast():
                policy_output, valueh_output, value_output, features = model(state.unsqueeze(1)) # features: [states, mem_actions, mem_result]
           
                # mem_policy_loss, mem_value_loss = model.memory_loss(features[1], features[2], action, result)
                mem_value_loss = model.value_loss(features[2].squeeze(), result)
                head_value_loss = model.value_loss(valueh_output.squeeze(), result)     

                policy_loss = model.policy_loss(policy_output, action)
                value_loss = model.value_loss(value_output.squeeze(), result)

                loss = policy_loss + value_loss + mem_value_loss + head_value_loss
            
            # AMP with gradient clipping
            model.grad_scaler.scale(loss).backward()
            model.grad_scaler.unscale_(model.optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            model.grad_scaler.step(model.optimizer)
            scale = model.grad_scaler.get_scale()
            model.grad_scaler.update()

            skip_lr_sched = scale > model.grad_scaler.get_scale()
            
            if not skip_lr_sched: model.scheduler.step()

            # Write data to memory
            model.write_to_memory(features[0], action, result)

            stats.update({
                "train_loss": loss.item(),
                "train_vloss": value_loss.item(),
                "train_ploss": policy_loss.item(),
                "train_hvloss": head_value_loss.item(),
                "train_mvloss": mem_value_loss.item(),
            })
            
            if i % log_every == 0:
                wandb.log({"lr": model.scheduler.get_last_lr()[0],
                            "train_loss": stats.get_average('train_loss'),
                            "train_vloss": stats.get_average('train_vloss'),
                            "train_ploss": stats.get_average('train_ploss'),
                            "train_hvloss": stats.get_average('train_hvloss'),
                            "train_mvloss": stats.get_average('train_mvloss'),
                            "iter": i})
            
            if i % validation_every == 0 and i > 0 :
                val_loss, val_ploss, val_vloss, val_mem_vloss, val_hvloss = run_validation(model, val_loader, stats)
                
                wandb.log({"val_loss": val_loss,
                           "val_vloss": val_vloss,
                           "val_ploss": val_ploss,
                           "val_hvloss": val_hvloss,
                           "val_mvloss": val_mem_vloss,
                           "iter": i})

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(model.state_dict(), "best_" + MODEL_SAVEPATH)

        print(f"Epoch took {time.perf_counter()-t1} seconds ")
        torch.save(model.state_dict(), MODEL_SAVEPATH)

def run_training(num_rounds):
    train_data = [pgn.path for pgn in os.scandir(PGN_DIR_TRAIN) if pgn.name.endswith(".pgn")]
    test_data = [pgn.path for pgn in os.scandir(PGN_DIR_TEST) if pgn.name.endswith(".pgn")]

    init_dataset = random.sample(train_data, 2)
    init_dataset = ChessDataset(init_dataset)
    init_loader = DataLoader(init_dataset, batch_size=10, shuffle=True, num_workers=0)
    model.initialize_memory(init_loader)

    print(f"Successfully initialized memory")
    for round in range(num_rounds):
        print(f"Starting round {round}")
        # build dataset 
        # randomly sample dataset_size pgn files 
        sampled_train_data = random.sample(train_data, DATASET_SIZE_TRAIN)
        sampled_test_data = random.sample(test_data, DATASET_SIZE_TEST)

        train_dataset = ChessDataset(sampled_train_data)
        test_dataset = ChessDataset(sampled_test_data)
        
        print(f"Successfully loaded dataset with {len(train_dataset)} / {len(test_dataset)} images")
        
        train_loader = DataLoader(train_dataset, batch_size=48, shuffle=True, num_workers=0)
        val_loader = DataLoader(test_dataset, batch_size=48, shuffle=False, num_workers=0)   

        training_round(model, train_loader, val_loader, num_epochs=3, log_every=1000, validation_every=20_000)


In [5]:
wandb.init(project='Chess')

NUM_ROUNDS = 50
DATASET_SIZE_TRAIN = 10
DATASET_SIZE_TEST = 1
# DATASET_SIZE_TRAIN = 1
# DATASET_SIZE_TEST = 1
run_training(NUM_ROUNDS)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkeithg33[0m ([33mopen_sim2real[0m). Use [1m`wandb login --relogin`[0m to force relogin


Successfully initialized memory
Starting round 0
Successfully loaded dataset with 4446043 / 292714 images
Variable val_mem_ploss is not being tracked.
Variable val_mem_vloss is not being tracked.
Variable val_mem_ploss is not being tracked.
Variable val_mem_vloss is not being tracked.
Variable val_mem_ploss is not being tracked.
Variable val_mem_vloss is not being tracked.
Variable val_mem_ploss is not being tracked.
Variable val_mem_vloss is not being tracked.


KeyboardInterrupt: 

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7fbe814b1dd0>> (for post_run_cell), with arguments args (<ExecutionResult object at 7fbe802b7b50, execution_count=5 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7fbe31d94ad0, raw_cell="wandb.init(project='Chess')

NUM_ROUNDS = 50
DATAS.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/home/kage/chess_workspace/chess-rl/monte-carlo-tree-search-NN/expert_super-pretraining.ipynb#X12sZmlsZQ%3D%3D> result=None>,),kwargs {}:


TypeError: _WandbInit._pause_backend() takes 1 positional argument but 2 were given