In [5]:
import tensorflow as tf

In [6]:
# Load the move corpus. The text is later split with
# tf.strings.unicode_split(..., 'UTF-8'), so decode it explicitly as UTF-8
# instead of relying on the platform's default locale encoding.
with open('moves.txt', 'r', encoding='utf-8') as infile:
    moves = infile.read()

# Sorted list of every distinct character in the corpus.
vocab = sorted(set(moves))

# Character -> integer id. mask_token=None disables the mask token; the
# lookup still reserves an out-of-vocabulary '[UNK]' entry in its vocabulary.
ids_from_chars = tf.keras.layers.experimental.preprocessing.StringLookup(
    vocabulary=vocab, mask_token=None
)

# Integer id -> character, built from the forward lookup's exact vocabulary
# (including '[UNK]') so the two mappings stay inverses of each other.
chars_from_ids = tf.keras.layers.experimental.preprocessing.StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None
)

print(f'{len(vocab)} unique characters')

33 unique characters


In [7]:
class OneStep(tf.keras.Model):
  """Samples one next character per call from a trained character model.

  Args:
    model: Model called as ``model(inputs=ids, states=states, return_state=True)``
      that returns ``(logits, states)``.
    chars_from_ids: Lookup layer mapping token ids back to characters.
    ids_from_chars: Lookup layer mapping characters to token ids; its
      vocabulary size fixes the length of the prediction mask.
    temperature: Logits are divided by this before sampling. Values below 1.0
      make sampling greedier, values above 1.0 make it more random.

  Raises:
    ValueError: If ``temperature`` is not strictly positive.
  """

  def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
    super().__init__()
    # A zero temperature would divide the logits by zero (inf/NaN logits) and a
    # negative one would invert the distribution; reject both up front.
    if temperature <= 0:
      raise ValueError(f'temperature must be > 0, got {temperature!r}')
    self.temperature = temperature
    self.model = model
    self.chars_from_ids = chars_from_ids
    self.ids_from_chars = ids_from_chars

    # Create a mask to prevent "[UNK]" from being generated.
    skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
    sparse_mask = tf.SparseTensor(
        # Put a -inf at each bad index.
        values=[-float('inf')]*len(skip_ids),
        indices=skip_ids,
        # Match the shape to the vocabulary
        dense_shape=[len(ids_from_chars.get_vocabulary())])
    # Dense vector with one entry per vocabulary id: 0 everywhere except -inf at
    # the skipped ids. It is added to the logits so those ids are never sampled.
    self.prediction_mask = tf.sparse.to_dense(sparse_mask)

  @tf.function
  def generate_one_step(self, inputs, states=None):
    """Sample the next character for each string in ``inputs``.

    Args:
      inputs: String tensor with one prompt (or previously sampled character)
        per batch element.
      states: Model state returned by the previous call, or None to start fresh.

    Returns:
      ``(predicted_chars, states)``: one sampled character per batch element and
      the updated model state to pass into the next call.
    """
    # Convert strings to token IDs.
    input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
    input_ids = self.ids_from_chars(input_chars).to_tensor()

    # Run the model.
    # predicted_logits.shape is [batch, char, next_char_logits]
    predicted_logits, states = self.model(inputs=input_ids, states=states,
                                          return_state=True)
    # Only use the last prediction.
    predicted_logits = predicted_logits[:, -1, :]
    predicted_logits = predicted_logits/self.temperature
    # Apply the prediction mask: prevent "[UNK]" from being generated.
    predicted_logits = predicted_logits + self.prediction_mask

    # Sample the output logits to generate token IDs.
    predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
    predicted_ids = tf.squeeze(predicted_ids, axis=-1)

    # Convert from token ids to characters
    predicted_chars = self.chars_from_ids(predicted_ids)

    # Return the characters and model state.
    return predicted_chars, states

In [8]:
# Model hyperparameters.
# NOTE(review): `vocab_size` counts only the distinct corpus characters. The
# model below is sized with len(ids_from_chars.get_vocabulary()), which also
# includes the lookup's '[UNK]' entry, so the two presumably differ by one.
vocab_size, embedding_dim, rnn_units = len(vocab), 256, 1024

class MyModel(tf.keras.Model):
  """Character-level language model: Embedding -> GRU -> Dense logits.

  Args:
    vocab_size: Number of token ids the embedding accepts, and the number of
      logits produced per position.
    embedding_dim: Width of the learned character embeddings.
    rnn_units: Number of units in the GRU layer.
  """

  def __init__(self, vocab_size, embedding_dim, rnn_units):
    # Call the base initializer with no positional arguments. The previous
    # `super().__init__(self)` passed the instance itself as a stray positional
    # argument that tf.keras.Model.__init__ does not expect.
    super().__init__()
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    # return_sequences: emit logits at every position, not only the last.
    # return_state: expose the final state so generation can carry it forward.
    self.gru = tf.keras.layers.GRU(rnn_units, return_sequences=True, return_state=True)
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    """Run the network over a batch of token-id sequences.

    Args:
      inputs: Integer token ids; presumably shaped (batch, seq_len) — TODO confirm.
      states: Initial GRU state, or None to use the GRU's own initial state.
      return_state: If True, also return the final GRU state.
      training: Forwarded to every sub-layer.

    Returns:
      Per-position logits, or ``(logits, states)`` when ``return_state`` is True.
    """
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x

# Size the model with the lookup layer's vocabulary (corpus characters plus
# '[UNK]') so every id ids_from_chars can produce has an embedding and a logit.
model = MyModel(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units
)

# Restore trained weights for inference. The training checkpoint presumably also
# holds values (e.g. optimizer state) that this model never uses. expect_partial()
# suppresses the warnings about those unrestored values without hiding load
# errors. Binding the status keeps its repr out of the cell output.
load_status = model.load_weights(filepath='training_checkpoints/ckpt_40')
load_status.expect_partial()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x28147cf40>

In [9]:
import numpy as np

# Encode the whole corpus to token ids and cut it into fixed-length windows.
all_ids = ids_from_chars(tf.strings.unicode_split(moves, 'UTF-8'))
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)
seq_length = 100
# Each example consumes seq_length + 1 characters (the input plus its
# one-step-shifted target), and the batching below drops the remainder, so
# the corpus yields len(moves) // (seq_length + 1) examples, not len(moves).
examples_per_epoch = len(moves) // (seq_length + 1)
sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)

@tf.autograph.experimental.do_not_convert
def split_input_target(sequence):
    """Turn one window of ids into an (input, target) pair shifted by one step.

    The input drops the last element and the target drops the first, so the
    target at position i is the element that follows input position i.
    """
    return sequence[:-1], sequence[1:]

BATCH_SIZE = 64
BUFFER_SIZE = 10000

# Pair each window with its shifted target, then shuffle, batch and prefetch.
dataset = (
    sequences
    .map(split_input_target)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Push a single batch through the model. The predictions are not used in the
# visible cells; the call also creates the model's variables.
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)

2021-10-25 11:28:07.894086: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2021-10-25 11:28:07.894583: W tensorflow/core/platform/profile_utils/cpu_utils.cc:126] Failed to get CPU frequency: 0 Hz


In [10]:
from tensorflow.keras.layers.experimental import preprocessing

# Wrap the trained model in a sampler that emits one character per call.
one_step_model = OneStep(
    model=model,
    chars_from_ids=chars_from_ids,
    ids_from_chars=ids_from_chars,
)

In [11]:
# This cell previously re-declared `MyModel` verbatim and rebuilt `model` from
# scratch. The rebuilt model never had load_weights() called on it, so the
# rebinding silently replaced the trained network with a randomly initialised
# one. The problem was masked because one_step_model kept the old, trained
# instance. The class, hyperparameters and trained `model` are already defined
# in the model-building cell, so nothing is redefined here.

# Guard against that hidden-state bug: the sampler must wrap the model that
# `model` currently refers to.
assert one_step_model.model is model, 'one_step_model does not wrap the current `model`'

In [40]:
# Generate text autoregressively from a seed prompt. The first call sees the
# whole prompt. Later calls feed back only the character just sampled and carry
# the model state forward through `states`.
num_generated_chars = 1000  # characters to sample after the prompt

states = None
next_char = tf.constant(['W1.e4'])
result = [next_char]

for _ in range(num_generated_chars):
  next_char, states = one_step_model.generate_one_step(next_char, states=states)
  result.append(next_char)

result = tf.strings.join(result)
# `result` holds one byte string per prompt in the batch. Decode each so the
# generated games print as readable text with real line breaks, not as a
# bytes-array repr with escaped newlines.
for generated_text in result.numpy():
  print(generated_text.decode('utf-8'))

[b'W1.e4 B1.c6 W2.d4 B2.d5 W3.exd5 B3.cxd5 W4.c4 B4.Nf6 W5.Nc3 B5.g6 W6.cxd5 B6.Bg7 W7.Nf3 B7.O-O W8.Be2 B8.c5 W9.Nf3 B9.Qc7 W10.bxc3 B10.bxc5 W11.Bd3 B11.Bc4 W12.O-O B12.Be7 W13.Bf4 B13.Rfe8 W14.Qc1 B14.Nd6 W15.Nxe6 B15.fxe6 W16.Rhe1 B16.Be6 W17.Nc3 B17.Qa5 W18.Bxe4 B18.Rxe4 W19.Qf2 B19.Kg7 W20.Bf4 B20.Rxd2 W21.Qxe6+ B21.Qc7 W22.Qa4+ B22.Ka6 W23.Rf6 B23.Kc6 W24.Qh6+ B24.Kg6 W25.Bd3 B25.Qxh4+  \n4.Nd2 B14.Nxf3 W15.gxf3 B15.Bxf3 W16.h4 B16.Qh4 W17.Qf3 B17.Bd7 W18.Bxd7+ B18.Kxd7 W19.Qf4+ B19.Kg6 W20.Qxh8 B20.Qxh4+ W21.Kg2 B21.Rf8+ W22.Kf3 B22.Bxh3 W23.gxh3 B23.Kxg5 W24.Qc2 B24.Rh5 W25.Qa4+ B25.Ke6  \n4.Qxb6 B14.gxh6 W15.Ne2 B15.Rhc8 W16.Nf3 B16.Be6 W17.Rhe1 B17.Rxc1+ W18.Bxc1 B18.Nh6 W19.Nf4 B19.Rc6 W20.Ndxc6+ B20.Kb7 W21.Nd8 B21.e5 W22.Be3 B22.g6 W23.Rc6+ B23.Kb6 W24.f3 B24.f6 W25.f4+ B25.Kf6  \n4.fxe6 B14.fxe6 W15.Nf3 B15.f6 W16.Ne2 B16.Rhe8 W17.Kf1 B17.Ke7 W18.Nc4 B18.Ke6 W19.Rd1 B19.Nbd5 W20.Rd6 B20.Ng4 W21.Nd5+ B21.Kf6 W22.Nxg6+ B22.fxg6 W23.Bxf6 B23.Rhg8 W24.Bd7+ B24.Kc6 W25.Be6+ B