In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk

import sys
sys.path.append('../')

In [2]:
from neural_chess.models.policy_net import build_policy_net

In [3]:
# model config
model_config = {
    'vocab': 13,  # i.e) 12 pieces + empty square
    'embedding_dim': 128,
    'hidden_dim': 128,
    'head_dim': 32,
    'nb_layers': 4,
    'nb_heads': 4,
    'output_dim': 4096,
    'dropout': 0.1
}

In [60]:
# sample data
def get_dummy_batch(batch=8):
    """
    Generate a batch of dummy data, for the purpose of initialising the network parameter
    dict (note: this is not a legal board position!)
    """
    board_state = np.random.randint(0, 13, (batch, 64)).astype(np.int32)
    turn = np.random.binomial(p=0.5, n=1, size=(batch,)).astype(np.int32)
    castling_rights = np.random.binomial(p=0.5, n=1, size=(batch,)).astype(np.int32)
    en_passant = np.random.randint(0, 65, (batch,)).astype(np.int32)
    elo = np.random.random((batch,)).astype(np.float32)
    return board_state, turn, castling_rights, en_passant, elo

board_state, turn, castling_rights, en_passant, elo = get_dummy_batch(3)

print(board_state.shape)
print(elo.shape)
print(turn.shape)
print(castling_rights.shape)
print(en_passant.shape)

(3, 64)
(3,)
(3,)
(3,)
(3,)


In [5]:
# get a random key
key = jax.random.PRNGKey(42)

# initialise the network!
forward_fn = build_policy_net(model_config)
init, apply = hk.transform(forward_fn)

params = init(key, board_state, turn, castling_rights, en_passant, elo, is_training=True)



In [6]:
# view the parameter tree
shapes = jax.tree_map(lambda x: x.shape, params)
print(shapes)

FlatMap({
  'board_emb': FlatMap({'embeddings': (13, 128)}),
  'board_pos_emb': FlatMap({'embeddings': (64, 128)}),
  'castle_emb': FlatMap({'embeddings': (2, 128)}),
  'elo_emb': FlatMap({'b': (128,), 'w': (1, 128)}),
  'en_passant_emb': FlatMap({'embeddings': (65, 128)}),
  'set_transformer/layer_norm': FlatMap({'offset': (128,), 'scale': (128,)}),
  'set_transformer/mlp/linear': FlatMap({'b': (256,), 'w': (128, 256)}),
  'set_transformer/mlp/linear_1': FlatMap({'b': (4096,), 'w': (256, 4096)}),
  'set_transformer/~_init_modules_for_layer/layer_norm': FlatMap({'offset': (128,), 'scale': (128,)}),
  'set_transformer/~_init_modules_for_layer/layer_norm_1': FlatMap({'offset': (128,), 'scale': (128,)}),
  'set_transformer/~_init_modules_for_layer/mlp/linear': FlatMap({'b': (256,), 'w': (128, 256)}),
  'set_transformer/~_init_modules_for_layer/mlp/linear_1': FlatMap({'b': (128,), 'w': (256, 128)}),
  'set_transformer/~_init_modules_for_layer/multi_head_attention/key': FlatMap({'b': (128,)

In [8]:
# perform a forward pass
output = apply(params, key, board_state, turn, castling_rights, en_passant, elo, is_training=True)

In [59]:
type(jnp.ones([10, 10]))

jaxlib.xla_extension.DeviceArray

In [55]:
params['board_emb']

FlatMap({
  'embeddings': DeviceArray([[-0.0178461 ,  0.00952877,  0.02462453, ...,  0.01094404,
                               0.01913168, -0.00570193],
                             [-0.02583723,  0.01335358,  0.02944723, ..., -0.00906039,
                               0.01591307,  0.00239701],
                             [-0.04328944,  0.01514734, -0.01018374, ..., -0.01745041,
                               0.01425417,  0.0224924 ],
                             ...,
                             [-0.03935785,  0.01419419, -0.01098523, ..., -0.01804546,
                               0.03128698,  0.01153706],
                             [ 0.00549005, -0.00130867, -0.03411865, ..., -0.00316766,
                               0.01009901, -0.0195477 ],
                             [ 0.00896102, -0.01897887, -0.00986449, ..., -0.01925972,
                               0.00177846, -0.04149825]], dtype=float32),
})

In [54]:
a = {
    'foo': 'bar'
}

set(a, 'foo', 'baz')
a

TypeError: set expected at most 1 argument, got 3

In [36]:
x = np.random.randn(13, 128)
print(x)

[[-0.47527872 -2.11855561 -1.01028343 ...  0.68336207  0.62772787
   0.40962262]
 [-2.30653632 -0.27302224 -2.53096186 ... -0.6518099   1.09306558
   1.11955486]
 [-0.3245646   0.69697357 -1.79566429 ...  1.3283232   1.7157535
   0.91107243]
 ...
 [-2.0314028  -1.28422664  1.63000891 ...  2.16287812 -0.78944481
  -0.60569176]
 [ 0.13995263  0.02055319  0.40103754 ... -1.04085873  1.95326781
   1.56164687]
 [-1.73069098  1.54018322  0.42834058 ...  0.0446253   0.45858568
  -2.26032743]]


In [37]:
params['board_emb']['embeddings'].at[:].set(x)

DeviceArray([[-0.47527874, -2.1185555 , -1.0102835 , ...,  0.68336207,
               0.62772787,  0.4096226 ],
             [-2.3065364 , -0.27302223, -2.5309618 , ..., -0.6518099 ,
               1.0930656 ,  1.1195549 ],
             [-0.3245646 ,  0.69697356, -1.7956643 , ...,  1.3283232 ,
               1.7157536 ,  0.91107243],
             ...,
             [-2.0314028 , -1.2842267 ,  1.6300089 , ...,  2.162878  ,
              -0.7894448 , -0.60569173],
             [ 0.13995263,  0.02055319,  0.40103754, ..., -1.0408587 ,
               1.9532678 ,  1.5616468 ],
             [-1.730691  ,  1.5401832 ,  0.42834058, ...,  0.0446253 ,
               0.45858568, -2.2603273 ]], dtype=float32)

In [9]:
print(output.shape)

(3, 4096)


In [10]:
import torch

In [14]:
x = torch.tensor([-float('inf'), -5, -4])
torch.nn.functional.log_softmax(x)

  torch.nn.functional.log_softmax(x)


tensor([   -inf, -1.3133, -0.3133])

In [19]:
x = np.array([
    [1, 2, 3, 10],
    [4, 5, 6, 10],
    [7, 8, 9, 10]
])

mask = np.array([
    [1, 0, 1, 0],
    [1, 1, 0, 0],
    [0, 1, 1, 0]
])

w = np.where(mask)
print(w)

(array([0, 0, 1, 1, 2, 2]), array([0, 2, 0, 1, 1, 2]))
