In [12]:
from eval_maia2 import parse_args
from utils import seed_everything, get_all_possible_moves, create_elo_dict
from net import MAIA2Model

cfg = parse_args(args="")
print("Configurations:", flush=True)
for arg in vars(cfg):
	print(f"\t{arg}: {getattr(cfg, arg)}", flush=True)
seed_everything(cfg.seed)

all_moves = get_all_possible_moves()
all_moves_dict = {move: i for i, move in enumerate(all_moves)}
# elo_dict = create_elo_dict()
elo_dict = create_elo_dict()

# Load MAIA2 model from checkpoint
model = MAIA2Model(len(all_moves), elo_dict, cfg)

Configurations:
	seed: 42
	num_workers: 0
	verbose: 1
	max_ply: 300
	clock_threshold: 30
	use_clock_filter: True
	chunk_size: 20000
	from_checkpoint: None
	max_games_per_elo_range: 20
	batch_size: 2048
	first_n_moves: 10
	last_n_moves: 10
	dim_cnn: 256
	dim_vit: 1024
	num_blocks_cnn: 5
	num_blocks_vit: 2
	input_channels: 18
	vit_length: 8
	elo_dim: 128
	side_info: True
	value: True
	max_depth: 1


In [1]:
# Load model from checkpoint
import torch
import torch.onnx
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

checkpoint = torch.load("tmp/0.0001_8192_1e-05_FT_All/epoch_2_2023-03.pgn.pt", map_location="cpu")
state_dict = {k.replace("module.", ""): v for k, v in checkpoint["model_state_dict"].items()}
model.load_state_dict(state_dict)
model.eval()

# Export model to ONNX
# (Pdb) boards.shape                                                                                                                               │
# torch.Size([2048, 18, 8, 8])                                                                                                                     │
# (Pdb) elos_self.shape                                                                                                                            │
# torch.Size([2048])                                                                                                                               │
# (Pdb) elos_oppo.shape                                                                                                                            │
# torch.Size([2048])
# Create dummy inputs for the model
dummy_input1 = torch.randn(1, 18, 8, 8)
dummy_input2 = torch.Tensor([0]).long()
dummy_input3 = torch.Tensor([0]).long()

# Define the file path for the ONNX model
onnx_model_path = "checkpoints/onnx/maia2_ft.onnx"
quantized_model_path = "checkpoints/onnx/maia2_ft_quant.onnx"

# Export the model
torch.onnx.export(
    model,                 # Model being run
    (dummy_input1, dummy_input2, dummy_input3),  # Model inputs as a tuple
    onnx_model_path,       # Where to save the model (can be a file or file-like object)
    export_params=True,    # Store the trained parameter weights inside the model file
    opset_version=11,      # The ONNX version to export the model to
    do_constant_folding=True,  # Whether to execute constant folding for optimization
    input_names=['boards', 'elo_self', 'elo_oppo'],  # The model's input names
    output_names=['logits_maia', 'logits_side_info', 'logits_value'],  # The model's output names
    dynamic_axes={
        'boards': {0: 'batch_size'}, 
        'elo_self': {0: 'batch_size'}, 
        'elo_oppo': {0: 'batch_size'}, 
        'logits_maia': {0: 'batch_size'},
        'logits_side_info': {0: 'batch_size'},
        'logits_value': {0: 'batch_size'},
    }  # Variable length axes
)

print(f"Model has been converted to ONNX and saved at {onnx_model_path}")

# Load the ONNX model
onnx_model = onnx.load(onnx_model_path)

# Apply dynamic quantization
quantized_model = quantize_dynamic(
    onnx_model_path,              # Path to the model to quantize
    quantized_model_path,         # Path to save the quantized model
    weight_type=QuantType.QUInt8  # Quantize weights to 8-bit unsigned integers
)

print(f"Quantized model saved at {quantized_model_path}")

  from .autonotebook import tqdm as notebook_tqdm


Quantized model saved at checkpoints/onnx/maia2_ft_quant.onnx


In [1]:
import onnx
import onnxruntime as ort
import numpy as np
import tqdm

# Load ONNX model
# onnx_model_path = "checkpoints/onnx/maia2_ft.onnx"
onnx_model_path = "checkpoints/onnx/maia2_ft_quant.onnx"
onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)

# Initialize ONNX runtime session
ort_session = ort.InferenceSession(onnx_model_path)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

def evaluate_onnx(model_path, dataloader):
    ort_session = ort.InferenceSession(model_path)
    
    counter = 0
    correct_move = 0

    maia_preds = []
    with torch.no_grad():
        for boards, labels, elos_self, elos_oppo, legal_moves, _, _ in tqdm.tqdm(dataloader):
            # Convert torch tensors to numpy arrays
            boards = to_numpy(boards)
            elos_self = to_numpy(elos_self)
            elos_oppo = to_numpy(elos_oppo)
            legal_moves = to_numpy(legal_moves)

            # Run the ONNX model
            ort_inputs = {
                'boards': boards,
                'elo_self': elos_self,
                'elo_oppo': elos_oppo
            }
            ort_outs = ort_session.run(None, ort_inputs)
            logits_maia_legal = ort_outs[0] * legal_moves
            preds = np.argmax(logits_maia_legal, axis=-1)
            correct_move += (preds == to_numpy(labels)).sum().item()

            counter += len(labels)
            maia_preds.append(preds)

            if counter > 10000:
                break

    return correct_move, counter, maia_preds


In [2]:
import pickle

def load_preprocessed_data(cache_path):
    with open(cache_path, "rb") as f:
        cache = pickle.load(f)
    return cache["data"], cache["game_count"], cache["chunk_count"]

In [3]:
import torch

from eval_maia2 import parse_args
from utils import seed_everything, get_all_possible_moves, create_elo_dict
from data import read_cache_data, MAIA2Dataset

cfg = parse_args(args="")
print("Configurations:", flush=True)
for arg in vars(cfg):
	print(f"\t{arg}: {getattr(cfg, arg)}", flush=True)
seed_everything(cfg.seed)

all_moves = get_all_possible_moves()
all_moves_dict = {move: i for i, move in enumerate(all_moves)}
elo_dict = create_elo_dict()

val_paths = read_cache_data(
	"/grace/u/dhkim2810/maia_gm/dataset/cache",
	"/grace/u/dhkim2810/maia_gm/file_list.pkl",
	mode="validation",
)

total = 0
maia_correct = 0
for val_file in val_paths[:1]:
	data, game_count, chunk_count = load_preprocessed_data(val_file)
	dset = MAIA2Dataset(data, all_moves_dict, cfg)
	loader = torch.utils.data.DataLoader(
		dset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers
	)
	correct, counter, preds = evaluate_onnx(onnx_model_path, loader)
	total += counter
	maia_correct += correct
maia2_acc = maia_correct / total
print("MAIA2 Accuracy: ", maia2_acc, flush=True)

  from .autonotebook import tqdm as notebook_tqdm


Configurations:
	seed: 42
	num_workers: 0
	verbose: 1
	max_ply: 300
	clock_threshold: 30
	use_clock_filter: True
	chunk_size: 20000
	from_checkpoint: None
	max_games_per_elo_range: 20
	batch_size: 2048
	first_n_moves: 10
	last_n_moves: 10
	dim_cnn: 256
	dim_vit: 1024
	num_blocks_cnn: 5
	num_blocks_vit: 2
	input_channels: 18
	vit_length: 8
	elo_dim: 128
	side_info: True
	value: True
	max_depth: 1


  0%|          | 4/6816 [00:51<24:18:47, 12.85s/it]

MAIA2 Accuracy:  0.54521484375





In [4]:
print(f"ONNX Model Performance : 55.29, 3.99s/batch, 2048 moves per batch")
print(f"UINT8 ONNX Model Performance : 54.52, 12.85s/batch, 2048 moves per batch")

ONNX Model Performance : 55.29, 3.99s/batch, 2048 per batch
UINT8 ONNX Model Performance : 54.52, 12.85s/batch, 2048 per batch


In [None]:
- UINT8 is not ideal, INT8 is ideal but not supported
- ONNX model can be improved using TensorRT(optimization kit)
- MCTS search requires 75700+ inference for each move search

 - Leela is faster due to optimized neural network (c++)
 - When using GPU, we have around 1.2s/batch on Grace(single GPU)

- hard to estimate but they say it can reach 20x speed 
- with onnx model, it's quite straightforward, but
need additional implementation on data loading and preprocessing

- Thats all