Imports

In [44]:
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing

import numpy as np
import os
import time

Get the dataset

In [45]:
# path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
path_to_file = "discord_data.txt"

Check out the dataset

In [46]:
# Read, then decode for py2 compat.
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
# length of text is the number of characters in it
print(f'Length of text: {len(text)} characters')

# Take a look at the first 250 characters in text
print(text[:250])

# The unique characters in the file
vocab = sorted(set(text))
print(f'{len(vocab)} unique characters')

Length of text: 4387105 characters
no clue AND NOW IT'S FUCKING BROKE AGAIN What is life I want to shoot it  I saw that Also, why is it saved as a .png.jpg? I... I don't know  Nice Neat So how is everyone Just got done with a concert. Nice. Who'd you see? Trans Siberian Orchestra Fuck
907 unique characters


Vectorize the dataset

In [47]:
ids_from_chars = preprocessing.StringLookup(
    vocabulary=list(vocab), mask_token=None)
chars_from_ids = tf.keras.layers.experimental.preprocessing.StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)

all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)

for ids in ids_dataset.take(10):
    print(chars_from_ids(ids).numpy().decode('utf-8'))

def text_from_ids(ids):
  return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)

n
o
 
c
l
u
e
 
A
N


In [48]:
seq_length = 100
examples_per_epoch = len(text)//(seq_length+1)

sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)

for seq in sequences.take(1):
  print(chars_from_ids(seq))

for seq in sequences.take(5):
  print(text_from_ids(seq).numpy())

tf.Tensor(
[b'n' b'o' b' ' b'c' b'l' b'u' b'e' b' ' b'A' b'N' b'D' b' ' b'N' b'O'
 b'W' b' ' b'I' b'T' b"'" b'S' b' ' b'F' b'U' b'C' b'K' b'I' b'N' b'G'
 b' ' b'B' b'R' b'O' b'K' b'E' b' ' b'A' b'G' b'A' b'I' b'N' b' ' b'W'
 b'h' b'a' b't' b' ' b'i' b's' b' ' b'l' b'i' b'f' b'e' b' ' b'I' b' '
 b'w' b'a' b'n' b't' b' ' b't' b'o' b' ' b's' b'h' b'o' b'o' b't' b' '
 b'i' b't' b' ' b' ' b'I' b' ' b's' b'a' b'w' b' ' b't' b'h' b'a' b't'
 b' ' b'A' b'l' b's' b'o' b',' b' ' b'w' b'h' b'y' b' ' b'i' b's' b' '
 b'i' b't' b' '], shape=(101,), dtype=string)
b"no clue AND NOW IT'S FUCKING BROKE AGAIN What is life I want to shoot it  I saw that Also, why is it "
b"saved as a .png.jpg? I... I don't know  Nice Neat So how is everyone Just got done with a concert. Ni"
b"ce. Who'd you see? Trans Siberian Orchestra Fucking amazing Partially deaf now Lol, that\xe2\x80\x99s not good  "
b'Oh well. Fucking worth it That\xe2\x80\x99s awesome \xf0\x9f\x91\x8d Indeed TSO is awesome! Kinda jealous. Did you add

In [49]:
def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text

In [50]:
dataset = sequences.map(split_input_target)

for input_example, target_example in dataset.take(1):
    print("Input :", text_from_ids(input_example).numpy())
    print("Target:", text_from_ids(target_example).numpy())

Input : b"no clue AND NOW IT'S FUCKING BROKE AGAIN What is life I want to shoot it  I saw that Also, why is it"
Target: b"o clue AND NOW IT'S FUCKING BROKE AGAIN What is life I want to shoot it  I saw that Also, why is it "


It's Model time baby

In [51]:
# Batch size
BATCH_SIZE = 64

# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000

dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE))

dataset

<PrefetchDataset shapes: ((64, 100), (64, 100)), types: (tf.int64, tf.int64)>

In [52]:
# Length of the vocabulary in chars
vocab_size = len(vocab)

# The embedding dimension
embedding_dim = 256

# Number of RNN units
rnn_units = 1024

class MyModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__(self)
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   return_sequences=True,
                                   return_state=True)
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x

model = MyModel(
    # Be sure the vocabulary size matches the `StringLookup` layers.
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)

In [53]:
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

(64, 100, 908) # (batch_size, sequence_length, vocab_size)


In [54]:
model.summary()

Model: "my_model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_1 (Embedding)      multiple                  232448    
_________________________________________________________________
gru_1 (GRU)                  multiple                  3938304   
_________________________________________________________________
dense_1 (Dense)              multiple                  930700    
Total params: 5,101,452
Trainable params: 5,101,452
Non-trainable params: 0
_________________________________________________________________


In [55]:
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

In [56]:
sampled_indices

array([498,  69, 154, 326, 850,  26, 336, 671, 697, 651, 763,  18, 268,
       198, 171, 368, 795, 414, 601, 403, 826, 554, 679, 729, 776, 115,
       334, 429, 350, 335, 398, 554, 539, 353, 468, 209,  33, 267, 116,
       564, 409, 457,  52, 185, 798, 440, 325, 338, 454, 647, 203,  66,
       613, 717, 901, 841, 344, 649, 328, 281, 563, 719, 138, 629, 461,
       200,  94, 127, 237, 630, 613, 860, 675, 296, 702, 552, 421, 404,
        92, 363, 232, 173, 463, 858,  28, 338, 803, 864, 526, 869, 288,
       123,  78, 209, 772,  34, 805, 161, 512, 197], dtype=int64)

In [57]:
print("Input:\n", text_from_ids(input_example_batch[0]).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())

Input:
 b", what frustrates me the most is that villains in stories are relatable. I get where they're coming "

Next Char Predictions:
 b'\xe2\x92\xba_\xc6\x90\xcf\x83\xf0\x9f\x98\x964\xd0\xb5\xef\xbd\x83\xf0\x9d\x90\xa8\xed\x95\x9c\xf0\x9f\x87\xa9,\xcd\x84\xcb\x80\xc8\xbd\xe0\xb4\x82\xf0\x9f\x8e\xb2\xe1\x9a\xa0\xe3\x83\x89\xe1\x8a\xa0\xf0\x9f\x95\x88\xe3\x80\x87\xef\xbd\x96\xf0\x9d\x94\x82\xf0\x9f\x87\xb7\xc3\x98\xd0\xb0\xe1\x9b\x8f\xd7\xa5\xd0\xb4\xe0\xb7\xb4\xe3\x80\x87\xe2\x9a\xbe\xd8\xad\xe2\x80\x99\xcc\x85;\xcd\x83\xc3\xa0\xe3\x81\x93\xe1\x97\x9d\xe1\xb6\xa4N\xc9\xb9\xf0\x9f\x90\x8a\xe1\xb4\xba\xcf\x81\xd0\xbd\xe1\xb5\x9b\xea\xb8\xb0\xcb\xa2\\\xe4\xbd\x95\xf0\x9d\x93\xa2\xf0\x9f\xa5\x9a\xf0\x9f\x98\x88\xd2\xaf\xec\x9e\xac\xcf\x85\xcd\x92\xe3\x81\x8c\xf0\x9d\x93\xaa\xc4\x8f\xe6\x8c\x81\xe1\xb9\x87\xcb\x8cx\xc3\xb1\xcc\xa2\xe6\x93\x8d\xe4\xbd\x95\xf0\x9f\x98\xa9\xef\xbd\x8d\xcd\xa2\xf0\x9d\x92\xbd\xe3\x80\x80\xe1\x9a\xb4\xe1\x8b\xb3v\xe0\xa8\xae\xcc\x9c\xc9\x8e\xe2\x80\x8d\xf0\x9f

It training time now

In [58]:
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)

example_batch_loss = loss(target_example_batch, example_batch_predictions)
mean_loss = example_batch_loss.numpy().mean()
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Mean loss:        ", mean_loss)

Prediction shape:  (64, 100, 908)  # (batch_size, sequence_length, vocab_size)
Mean loss:         6.8115444


In [59]:
tf.exp(mean_loss).numpy()

908.27246

In [60]:
model.compile(optimizer='adam', loss=loss)

In [61]:
# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints_discord'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}.ckpt")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
    verbose=1)

In [62]:
'''
restore_from = 18

print(checkpoint_prefix.format(epoch=str(restore_from)))

model.load_weights(checkpoint_prefix.format(epoch=str(restore_from)))
'''

'\nrestore_from = 18\n\nprint(checkpoint_prefix.format(epoch=str(restore_from)))\n\nmodel.load_weights(checkpoint_prefix.format(epoch=str(restore_from)))\n'

In [63]:
EPOCHS = 20

history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

Epoch 1/20

Epoch 00001: saving model to ./training_checkpoints_discord\ckpt_1.ckpt
Epoch 2/20

Epoch 00002: saving model to ./training_checkpoints_discord\ckpt_2.ckpt
Epoch 3/20

Epoch 00003: saving model to ./training_checkpoints_discord\ckpt_3.ckpt
Epoch 4/20

Epoch 00004: saving model to ./training_checkpoints_discord\ckpt_4.ckpt
Epoch 5/20

Epoch 00005: saving model to ./training_checkpoints_discord\ckpt_5.ckpt
Epoch 6/20

Epoch 00006: saving model to ./training_checkpoints_discord\ckpt_6.ckpt
Epoch 7/20

Epoch 00007: saving model to ./training_checkpoints_discord\ckpt_7.ckpt
Epoch 8/20

Epoch 00008: saving model to ./training_checkpoints_discord\ckpt_8.ckpt
Epoch 9/20

KeyboardInterrupt: 

Aight, let's see how you did

In [None]:
class OneStep(tf.keras.Model):
  def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
    super().__init__()
    self.temperature = temperature
    self.model = model
    self.chars_from_ids = chars_from_ids
    self.ids_from_chars = ids_from_chars

    # Create a mask to prevent "[UNK]" from being generated.
    skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
    sparse_mask = tf.SparseTensor(
        # Put a -inf at each bad index.
        values=[-float('inf')]*len(skip_ids),
        indices=skip_ids,
        # Match the shape to the vocabulary
        dense_shape=[len(ids_from_chars.get_vocabulary())])
    self.prediction_mask = tf.sparse.to_dense(sparse_mask)

  @tf.function
  def generate_one_step(self, inputs, states=None):
    # Convert strings to token IDs.
    input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
    input_ids = self.ids_from_chars(input_chars).to_tensor()

    # Run the model.
    # predicted_logits.shape is [batch, char, next_char_logits]
    predicted_logits, states = self.model(inputs=input_ids, states=states,
                                          return_state=True)
    # Only use the last prediction.
    predicted_logits = predicted_logits[:, -1, :]
    predicted_logits = predicted_logits/self.temperature
    # Apply the prediction mask: prevent "[UNK]" from being generated.
    predicted_logits = predicted_logits + self.prediction_mask

    # Sample the output logits to generate token IDs.
    predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
    predicted_ids = tf.squeeze(predicted_ids, axis=-1)

    # Convert from token ids to characters
    predicted_chars = self.chars_from_ids(predicted_ids)

    # Return the characters and model state.
    return predicted_chars, states

one_step_model = OneStep(model, chars_from_ids, ids_from_chars)

In [None]:
start = time.time()
states = None
next_char = tf.constant(['How are you?'])
result = [next_char]

for n in range(1000):
  next_char, states = one_step_model.generate_one_step(next_char, states=states)
  result.append(next_char)

result = tf.strings.join(result)
end = time.time()
print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)
print('\nRun time:', end - start)