In [5]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import numpy as np
import os
import datetime

In [6]:
EPOCHS = 20
SEQ_LENGTH = 200
BATCH_SIZE = 64
EMBEDDING_DIM = 20
RNN_UNITS = 1024 

In [7]:
text = open("thoughts.txt", 'rb').read().decode(encoding='utf-8')

In [8]:
vocab = sorted(set(text))
vocab_size = len(vocab)

In [9]:
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

textInt = np.array([char2idx[c] for c in text])

In [10]:
textTensor = tf.data.Dataset.from_tensor_slices(textInt)

In [11]:
sequences = textTensor.batch(SEQ_LENGTH+1, drop_remainder=True)

In [12]:
def split_text(seq):
    inSeq = seq[:-1]
    tSeq = seq[1:]
    return inSeq, tSeq

In [13]:
data = sequences.map(split_text)
data = data.shuffle(1000).batch(BATCH_SIZE, drop_remainder= True)

In [22]:
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, embedding_dim,
                              batch_input_shape=[batch_size, None]),
    tf.keras.layers.LSTM(rnn_units,
                        return_sequences=True,
                        stateful=True),
    tf.keras.layers.LSTM(rnn_units,
                        return_sequences=True,
                        stateful=True),
    tf.keras.layers.Dense(vocab_size)
  ])
    return model

In [23]:
model = build_model(
  vocab_size = len(vocab),
  embedding_dim=EMBEDDING_DIM,
  rnn_units=RNN_UNITS,
  batch_size=BATCH_SIZE)

In [24]:
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x263d5e38910>

In [25]:
model.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_3 (Embedding)      (64, None, 20)            1900      
_________________________________________________________________
lstm_6 (LSTM)                (64, None, 1024)          4280320   
_________________________________________________________________
lstm_7 (LSTM)                (64, None, 1024)          8392704   
_________________________________________________________________
dense_3 (Dense)              (64, None, 95)            97375     
Total params: 12,772,299
Trainable params: 12,772,299
Non-trainable params: 0
_________________________________________________________________


In [26]:
def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

In [27]:
model.compile(optimizer='adam', loss=loss, metrics=["accuracy"])

In [28]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

In [None]:
history = model.fit(data, epochs=EPOCHS, callbacks=[checkpoint_callback, tensorboard_callback])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20

In [4]:
def generate_text(model, startString, generateNum= 1000):
    inputE = [char2idx[s] for s in startString]
    inputE = tf.expand_dims(inputE, 0)
    generated = []

 
    temperature = 0.7

    model.reset_states()
    for i in range(generateNum):
        out = model(inputE)
        out = tf.squeeze(out, 0)

        out = out / temperature
        out_id = tf.random.categorical(out, num_samples=1)[-1,0].numpy()

        inputE = tf.expand_dims([out_id], 0)

        generated.append(idx2char[out_id])

    return (startString + ''.join(generated))

In [20]:
modelGen = build_model(vocab_size, EMBEDDING_DIM, RNN_UNITS, batch_size=1)

modelGen.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

modelGen.build(tf.TensorShape([1, None]))

In [21]:
print(generate_text(modelGen, startString=u"I"))



InvalidArgumentError: logits should be a matrix, got shape [95] [Op:Multinomial]