# Comparing Searchless Chess Policies
### Alex Kim, CSCI 381

The `searchless_chess.src` package is copyrighted by DeepMind.

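Here, I use their library to build small transformer predictors with randomly initialized (dummy) parameters. No trained checkpoint is loaded, so the results illustrate the pipeline rather than real playing strength.
Then, I vary the policy and the engine temperature to run some quick, illustrative ablations.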

In [None]:
import os
import chess
import chess.svg
from jax import random as jrandom
import numpy as np

In [None]:
from searchless_chess.src import tokenizer
from searchless_chess.src import training_utils
from searchless_chess.src import transformer
from searchless_chess.src import utils
from searchless_chess.src.engines import engine
from searchless_chess.src.engines import neural_engines

In [None]:
# Create predictor (adapted from a cell written by DeepMind).
# The size of the output head depends on the policy: the action-value and
# state-value policies predict over `num_return_buckets` outputs, while
# behavioral cloning predicts one output per action (utils.NUM_ACTIONS).

policy = 'action_value'
num_return_buckets = 128

if policy in ('action_value', 'state_value'):
    output_size = num_return_buckets
elif policy == 'behavioral_cloning':
    output_size = utils.NUM_ACTIONS
else:
    raise ValueError(f'Unknown policy {policy}')

predictor_config = transformer.TransformerConfig(
    vocab_size=utils.NUM_ACTIONS,
    output_size=output_size,
    pos_encodings=transformer.PositionalEncodings.LEARNED,
    # tokenizer.SEQUENCE_LENGTH plus two extra token positions.
    max_sequence_length=tokenizer.SEQUENCE_LENGTH + 2,
    # A deliberately small model: 4 layers x 4 heads, 64-dim embeddings.
    num_layers=4,
    num_heads=4,
    embedding_dim=64,
    apply_post_ln=True,
    apply_qk_layernorm=False,
    use_causal_mask=False,
)

predictor = transformer.build_transformer_predictor(config=predictor_config)

In [None]:
# Load dummy params: the predictor is initialised from a fixed PRNG seed
# (0) rather than restored from a trained checkpoint.

params = predictor.initial_params(
    targets=np.zeros(shape=(1, 1), dtype=np.uint32),
    rng=jrandom.PRNGKey(seed=0),
)

In [None]:
# Initialize engine

def create_engine(policy, num_return_buckets, temperature):
    """Build a neural engine for ``policy`` backed by a policy-specific predictor.

    The output head size depends on the policy (see the predictor cell):
    action-value and state-value models use ``num_return_buckets`` outputs,
    behavioral cloning uses ``utils.NUM_ACTIONS``. The single module-level
    ``predictor``/``params`` are built for one policy only, so reusing them
    for every policy would give behavioral cloning the wrong head and would
    ignore ``num_return_buckets``. A dedicated predictor (same hyperparameters,
    same PRNG seed) is therefore built on every call.

    Args:
      policy: One of 'action_value', 'behavioral_cloning', 'state_value'.
      num_return_buckets: Number of uniform return buckets; sizes the value
        heads and the bucket values handed to the engine.
      temperature: Forwarded unchanged to the engine constructor.

    Returns:
      An instance of ``neural_engines.ENGINE_FROM_POLICY[policy]``.

    Raises:
      ValueError: If ``policy`` is not one of the known policies.
    """
    if policy in ('action_value', 'state_value'):
        output_size = num_return_buckets
    elif policy == 'behavioral_cloning':
        output_size = utils.NUM_ACTIONS
    else:
        raise ValueError(f'Unknown policy {policy}')

    # Same architecture as the module-level predictor_config; only the
    # output head size varies with the policy.
    config = transformer.TransformerConfig(
        vocab_size=utils.NUM_ACTIONS,
        output_size=output_size,
        pos_encodings=transformer.PositionalEncodings.LEARNED,
        max_sequence_length=tokenizer.SEQUENCE_LENGTH + 2,
        num_heads=4,
        num_layers=4,
        embedding_dim=64,
        apply_post_ln=True,
        apply_qk_layernorm=False,
        use_causal_mask=False,
    )
    policy_predictor = transformer.build_transformer_predictor(config=config)
    # Dummy (randomly initialised) params with the notebook's seed, so the
    # action_value / state_value results match the module-level params.
    policy_params = policy_predictor.initial_params(
        rng=jrandom.PRNGKey(0),
        targets=np.zeros((1, 1), dtype=np.uint32),
    )
    predict_fn = neural_engines.wrap_predict_fn(
        policy_predictor, policy_params, batch_size=1
    )
    _, return_buckets_values = utils.get_uniform_buckets_edges_values(
        num_return_buckets
    )
    return neural_engines.ENGINE_FROM_POLICY[policy](
        return_buckets_values=return_buckets_values,
        predict_fn=predict_fn,
        temperature=temperature,
    )

In [None]:
def compute_win_percentages(neural_engine, board, return_buckets_values=None):
    """Estimate a win probability for each legal move on ``board``.

    Calls ``neural_engine.analyse(board)``, exponentiates its ``'log_probs'``
    entry to get per-bucket probabilities, and takes their inner product with
    the return-bucket values. This is the expected bucket value, assuming the
    log-probs are normalised over the last axis.

    Args:
      neural_engine: Engine whose ``analyse`` result contains ``'log_probs'``
        with return buckets along the last axis.
      board: Position to analyse.
      return_buckets_values: Optional value of each return bucket. When
        omitted, uniform bucket values are rebuilt for the number of buckets
        found in ``'log_probs'``.

    Returns:
      ``(win_probs, sorted_legal_moves)``, where ``sorted_legal_moves`` is
      ``engine.get_ordered_legal_moves(board)``. The plotting cells pair
      ``win_probs[i]`` with ``sorted_legal_moves[i]``. This presumes
      ``analyse`` yields one row per legal move in that order (TODO confirm).

    NOTE(review): confirm every engine in ENGINE_FROM_POLICY exposes a
    return-bucket distribution under ``'log_probs'``. The behavioral-cloning
    and state-value engines may return a different shape or different keys.
    Also note that nothing here reads the engine's temperature. Whether
    ``analyse`` itself depends on it must be checked before interpreting a
    temperature sweep.
    """
    results = neural_engine.analyse(board)
    buckets_log_probs = np.asarray(results['log_probs'])
    if return_buckets_values is None:
        # The bucket values are local to create_engine, so rebuild them here
        # from the bucket count instead of relying on a global.
        _, return_buckets_values = utils.get_uniform_buckets_edges_values(
            buckets_log_probs.shape[-1]
        )
    win_probs = np.inner(np.exp(buckets_log_probs), return_buckets_values)
    sorted_legal_moves = engine.get_ordered_legal_moves(board)
    return win_probs, sorted_legal_moves

In [None]:
# Vary the policy while holding the engine temperature fixed at 0.005.
policies = ['action_value', 'behavioral_cloning', 'state_value']
num_return_buckets = 128
board = chess.Board()  # no FEN given: the standard starting position

win_percentages = {}
for policy in policies:
    neural_engine = create_engine(
        policy=policy,
        num_return_buckets=num_return_buckets,
        temperature=0.005,
    )
    win_probs, sorted_legal_moves = compute_win_percentages(
        neural_engine=neural_engine,
        board=board,
    )
    win_percentages[policy] = win_probs

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Plot win percentages
fig, ax = plt.subplots(figsize=(10, 6))
x = np.arange(len(sorted_legal_moves))

ax.bar(x - 0.25, win_percentages['action_value'] * 100, label='Action Value')
ax.bar(x, win_percentages['behavioral_cloning'] * 100, label='Behavioral Cloning')
ax.bar(x + 0.25, win_percentages['state_value'] * 100, label='State Value')

ax.set_xlabel('Legal Moves')
ax.set_ylabel('Win Percentage')
ax.set_title('Win Percentages for Different Policies')
ax.set_xticks(x)
ax.set_xticklabels([move.uci() for move in sorted_legal_moves], rotation=45)
ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Hold the policy fixed at action_value and sweep the engine temperature.
# NOTE(review): compute_win_percentages only looks at analyse() output. If the
# engine uses temperature only when choosing a move, all curves will coincide.
policy = 'action_value'
num_return_buckets = 128
temperatures = [0.005, 0.1, 1.0]
board = chess.Board()  # no FEN given: the standard starting position

win_percentages = {}
for temp in temperatures:
    neural_engine = create_engine(
        policy=policy,
        num_return_buckets=num_return_buckets,
        temperature=temp,
    )
    win_probs, sorted_legal_moves = compute_win_percentages(
        neural_engine=neural_engine,
        board=board,
    )
    win_percentages[temp] = win_probs

In [None]:
# Plot win percentages: one group per legal move, one bar per temperature.
fig, ax = plt.subplots(figsize=(10, 6))
x = np.arange(len(sorted_legal_moves))
# Axes.bar defaults to width=0.8, so fixed 0.25 offsets made the bars overlap.
# Split a 0.75-wide group evenly across the temperatures and centre it on each
# tick. For three temperatures this gives width 0.25 at offsets -0.25/0/+0.25,
# and it stays correct for any number of temperatures.
bar_width = 0.75 / len(temperatures)

for i, temp in enumerate(temperatures):
    offset = (i - (len(temperatures) - 1) / 2) * bar_width
    ax.bar(
        x + offset,
        win_percentages[temp] * 100,
        width=bar_width,
        label=f'Temperature {temp}',
    )

ax.set_xlabel('Legal Moves')
ax.set_ylabel('Win Percentage')
ax.set_title('Win Percentages for Different Temperatures')
ax.set_xticks(x)
# Right-align rotated labels so each one ends under its own tick.
ax.set_xticklabels([move.uci() for move in sorted_legal_moves],
                   rotation=45, ha='right')
ax.legend()

plt.tight_layout()
plt.show()