In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk

import sys

# Make the parent of the working directory importable so the local
# `neural_chess` package (imported in the next cell) can be resolved.
# Guarded so that re-running this cell does not keep appending duplicates.
if '../' not in sys.path:
    sys.path.append('../')

In [2]:
from neural_chess.models.policy_net import build_policy_net

In [3]:
# Hyper-parameters handed to `build_policy_net`.
model_config = dict(
    vocab=13,           # 12 pieces + the empty square
    embedding_dim=128,
    hidden_dim=128,
    head_dim=32,        # note: head_dim * nb_heads == hidden_dim (32 * 4 == 128)
    nb_layers=4,
    nb_heads=4,
    output_dim=4096,    # presumably 64 * 64 from/to square pairs -- TODO confirm
    dropout=0.1,
)

In [4]:
# sample data: a small random batch used to initialise and exercise the network
DATA_SEED = 42  # fixed seed so a fresh "Restart & Run All" reproduces the same batch
rng = np.random.default_rng(DATA_SEED)

batch = 3
# one integer token per square: 64 squares, ids drawn from [0, 13)
board_state = rng.integers(0, 13, size=(batch, 64), dtype=np.int32)
# one float in [0, 1) per example (presumably a normalised rating -- TODO confirm)
elo = rng.random((batch,), dtype=np.float32)
# one fair coin flip (0 or 1) per example
turn = rng.binomial(n=1, p=0.5, size=(batch,)).astype(np.int32)

print(board_state.shape)
print(elo.shape)
print(turn.shape)

(3, 64)
(3,)
(3,)


In [14]:
# Derive two independent PRNG keys from a single seed: `init_key` is consumed
# by parameter initialisation only, while `key` stays available for later
# cells, so initialisation never shares a key with subsequent uses of `key`.
key, init_key = jax.random.split(jax.random.PRNGKey(42))

# initialise the network!
# hk.transform wraps the forward function and yields an (init, apply) pair.
forward_fn = build_policy_net(model_config)
init, apply = hk.transform(forward_fn)

# The sample batch is passed so parameters can be created for these inputs.
params = init(init_key, board_state, turn, elo, is_training=True)

In [15]:
# view the parameter tree: replace every leaf array by its shape
# (jax.tree_util.tree_map is the stable spelling; the top-level `jax.tree_map`
# alias was deprecated and removed in newer JAX releases)
shapes = jax.tree_util.tree_map(lambda x: x.shape, params)
print(shapes)

FlatMap({
  'embed': FlatMap({'embeddings': (13, 128)}),
  'embed_1': FlatMap({'embeddings': (64, 128)}),
  'embed_2': FlatMap({'embeddings': (2, 128)}),
  'linear': FlatMap({'b': (128,), 'w': (1, 128)}),
  'set_transformer/layer_norm': FlatMap({'offset': (128,), 'scale': (128,)}),
  'set_transformer/mlp/linear': FlatMap({'b': (256,), 'w': (128, 256)}),
  'set_transformer/mlp/linear_1': FlatMap({'b': (4096,), 'w': (256, 4096)}),
  'set_transformer/~_init_modules_for_layer/layer_norm': FlatMap({'offset': (128,), 'scale': (128,)}),
  'set_transformer/~_init_modules_for_layer/layer_norm_1': FlatMap({'offset': (128,), 'scale': (128,)}),
  'set_transformer/~_init_modules_for_layer/mlp/linear': FlatMap({'b': (256,), 'w': (128, 256)}),
  'set_transformer/~_init_modules_for_layer/mlp/linear_1': FlatMap({'b': (128,), 'w': (256, 128)}),
  'set_transformer/~_init_modules_for_layer/multi_head_attention/key': FlatMap({'b': (128,), 'w': (128, 128)}),
  'set_transformer/~_init_modules_for_layer/multi

In [16]:
# Run the transformed network once on the sample batch. It is called in
# training mode, so the PRNG key is supplied alongside the parameters
# (presumably consumed by dropout, which the config sets to 0.1 -- TODO confirm).
output = apply(
    params,
    key,
    board_state,
    turn,
    elo,
    is_training=True,
)

In [19]:
# Sanity-check the result before showing it: one row per example in the batch
# and `output_dim` values per row. Fails loudly if the shapes ever drift.
assert output.shape == (batch, model_config['output_dim']), output.shape
print(output.shape)

(3, 4096)
