In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk
# import hijax as hx

import sys
sys.path.append('../')

In [2]:
from neural_chess.models.policy_net import build_policy_net

In [3]:
# model config: hyper-parameters passed to the policy-network builder
model_config = dict(
    vocab=13,  # 12 pieces plus the empty square
    embedding_dim=256,
    hidden_dim=256,
    head_dim=64,
    nb_layers=1,
    nb_heads=4,
    output_dim=4096,
    dropout=0.1,
)

In [4]:
# sample data
def get_dummy_batch(batch=8, seed=None):
    """
    Generate a batch of dummy data, for the purpose of initialising the network
    parameter dict (note: this is not a legal board position!)

    Args:
        batch: number of examples to generate.
        seed: optional seed for the NumPy random generator. Pass an int to get
            the same batch on every call (e.g. after restarting the kernel);
            the default ``None`` draws fresh entropy each time.

    Returns:
        Tuple ``(board_state, turn, castling_rights, en_passant, elo)`` where
        board_state is int32 of shape (batch, 64) with values in [0, 13),
        turn and castling_rights are int32 of shape (batch,) with values in {0, 1},
        en_passant is int32 of shape (batch,) with values in [0, 65), and
        elo is float32 of shape (batch,) with values in [0, 1).
    """
    # A local generator keeps the draw reproducible without touching the
    # global np.random state shared by the rest of the notebook.
    rng = np.random.default_rng(seed)
    board_state = rng.integers(0, 13, size=(batch, 64), dtype=np.int32)
    turn = rng.binomial(n=1, p=0.5, size=(batch,)).astype(np.int32)
    castling_rights = rng.binomial(n=1, p=0.5, size=(batch,)).astype(np.int32)
    en_passant = rng.integers(0, 65, size=(batch,), dtype=np.int32)
    elo = rng.random((batch,), dtype=np.float32)
    return board_state, turn, castling_rights, en_passant, elo

board_state, turn, castling_rights, en_passant, elo = get_dummy_batch(batch=3)

# show the shape of every input array, one per line
print(
    *(arr.shape for arr in (board_state, elo, turn, castling_rights, en_passant)),
    sep="\n",
)

(3, 64)
(3,)
(3,)
(3,)
(3,)


In [5]:
# fixed PRNG key so parameter initialisation is repeatable
key = jax.random.PRNGKey(42)

# build the forward function and turn it into a pure (init, apply) pair
forward_fn = build_policy_net(**model_config)
init, apply = hk.transform(forward_fn)

# initialise the parameters by running the network once on the dummy batch
params = init(
    key,
    is_training=True,
    board_state=board_state,
    turn=turn,
    castling_rights=castling_rights,
    en_passant=en_passant,
    elo=elo,
)

In [6]:
# view the parameter tree: replace every parameter array with its shape.
# jax.tree_util.tree_map is used because the top-level jax.tree_map alias is
# deprecated and absent from newer JAX releases.
shapes = jax.tree_util.tree_map(lambda x: x.shape, params)
print(shapes)

FlatMap({
  'board_emb': FlatMap({'embeddings': (13, 256)}),
  'board_pos_emb': FlatMap({'embeddings': (64, 256)}),
  'castle_emb': FlatMap({'embeddings': (2, 256)}),
  'elo_emb': FlatMap({'b': (256,), 'w': (1, 256)}),
  'en_passant_emb': FlatMap({'embeddings': (65, 256)}),
  'set_transformer/layer_norm': FlatMap({'offset': (256,), 'scale': (256,)}),
  'set_transformer/mlp/linear': FlatMap({'b': (512,), 'w': (256, 512)}),
  'set_transformer/mlp/linear_1': FlatMap({'b': (4096,), 'w': (512, 4096)}),
  'set_transformer/~_init_modules_for_layer/layer_norm': FlatMap({'offset': (256,), 'scale': (256,)}),
  'set_transformer/~_init_modules_for_layer/layer_norm_1': FlatMap({'offset': (256,), 'scale': (256,)}),
  'set_transformer/~_init_modules_for_layer/mlp/linear': FlatMap({'b': (512,), 'w': (256, 512)}),
  'set_transformer/~_init_modules_for_layer/mlp/linear_1': FlatMap({'b': (256,), 'w': (512, 256)}),
  'set_transformer/~_init_modules_for_layer/multi_head_attention/key': FlatMap({'b': (256,)

In [7]:
# perform a forward pass
# `key` was already used by init above; JAX keys should not be reused, so a
# distinct key is derived for the dropout randomness. fold_in is deterministic,
# so re-running this cell produces the same result.
dropout_key = jax.random.fold_in(key, 1)

# inputs are passed by keyword (the same names init was called with) so the
# call does not depend on the positional order of the forward function
output = apply(
    params,
    dropout_key,
    board_state=board_state,
    turn=turn,
    castling_rights=castling_rights,
    en_passant=en_passant,
    elo=elo,
    is_training=True,
)

In [8]:
# Show which device holds the forward-pass result (the recorded run printed gpu:0).
# NOTE(review): `.device()` as a method looks specific to older JAX arrays —
# confirm against the installed JAX version before upgrading.
print(output.device())

gpu:0
