In [1]:
from typing import List, Dict
from collections import Counter

import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import Sequential, Model
from keras.layers import LSTM, StackedRNNCells, Dense, Flatten, Input, Dropout

In [2]:
def read_content(file_path: str = "tinyshakespeare.txt", encoding: str = "utf-8") -> str:
    """Return the entire contents of a text file as a single string.

    Args:
        file_path: Path of the text file to load.
        encoding: Text encoding used to decode the file. It is passed
            explicitly so the result does not depend on the platform's
            default locale encoding.

    Returns:
        The full text of the file, line breaks included.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file's bytes are not valid in ``encoding``.
    """
    with open(file_path, 'r', encoding=encoding) as in_file:
        # A single read() replaces the per-line `content += line` loop, which
        # built the same string one concatenation at a time.
        return in_file.read()


def index_vucabulary(content: str):
    """Build index <-> character lookup tables for the characters of a text.

    Characters are ranked by how often they occur in ``content``, most
    frequent first, and numbered from 0 in that order.

    Args:
        content: Text whose distinct characters form the vocabulary.

    Returns:
        A ``(vocab, reverse_vocab)`` pair. ``vocab`` maps each index to its
        character; ``reverse_vocab`` maps each character back to its index.
    """
    ranked_chars = [char for char, _count in Counter(content).most_common()]

    vocab = dict(enumerate(ranked_chars))
    reverse_vocab = {char: index for index, char in vocab.items()}

    return vocab, reverse_vocab

In [3]:
# Build both lookup tables (index -> char, char -> index) from the default corpus file.
vocab, reverse_vocab = index_vucabulary(content=read_content())

In [4]:
def generate_dataset(
        text: List[str], reverse_vocab: Dict[str, int],
        sequence_length: int=50, batch_size: int=64) -> tf.data.Dataset:
    """Turn text into shuffled batches of (input ids, one-hot next-character targets).

    The characters are mapped to integer ids and cut into consecutive,
    non-overlapping windows of ``sequence_length`` ids. Each window becomes
    one example: the input is every id except the last, the target is every
    id except the first, one-hot encoded over the vocabulary.

    Args:
        text: Sequence of characters; each one is looked up in
            ``reverse_vocab``.
        reverse_vocab: Mapping from character to integer id. Its size is the
            one-hot depth of the targets.
        sequence_length: Number of ids per window, so inputs and targets each
            span ``sequence_length - 1`` positions.
        batch_size: Number of examples per batch.

    Returns:
        A dataset of ``(inputs, targets)`` batches. Incomplete trailing
        windows and batches are dropped.

    Raises:
        KeyError: If ``text`` contains a character missing from
            ``reverse_vocab``.
    """
    vocab_size = len(reverse_vocab)
    token_ids = np.array([reverse_vocab[char] for char in text])

    def split_window(window):
        # Pair every position with the character that follows it.
        next_chars = tf.one_hot(window[1:], depth=vocab_size, dtype=tf.float32)
        return window[:-1], next_chars

    windows = tf.data.Dataset.from_tensor_slices(tensors=token_ids)
    windows = windows.batch(sequence_length, drop_remainder=True)

    examples = windows.map(split_window).shuffle(buffer_size=100)
    return examples.batch(batch_size, drop_remainder=True)

In [5]:
# Build the training dataset from the default corpus, then print the shapes of
# the first batch's inputs and targets as a sanity check.
dataset = generate_dataset(text=read_content(), reverse_vocab=reverse_vocab)
batch = next(iter(dataset))
print(*(tensor.shape for tensor in batch))

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
(64, 49) (64, 49, 65)


In [6]:
def sample_text(model, sample_len: int=200) -> str:
    """Generate text from ``model``, seeded with one sequence taken from ``dataset``.

    The seed is the first input sequence of the first batch drawn from the
    module-level ``dataset``. The returned string is the seed's characters,
    then ``"|"``, then ``sample_len`` generated characters. Ids are turned
    into characters with the module-level ``vocab``.

    Decoding is greedy: each step emits the highest-scoring character.
    TODO: draw from the predicted distribution if stochastic samples are wanted.

    Args:
        model: Callable invoked as ``model((context, h1, h2))`` and expected
            to return ``(predictions, h1, h2)``.
            NOTE(review): predictions are assumed to be shaped
            (batch, time, vocabulary), with the last time step scoring the
            next character — confirm against the model this is used with.
            NOTE(review): the Sequential model built later in this notebook
            takes a single input, so it does not match this call convention —
            confirm which model this function targets.
        sample_len: Number of characters to generate after the seed.

    Returns:
        The seed text, a ``"|"`` separator, and the generated characters.
    """
    # generate_dataset yields integer character ids as inputs, so the seed is
    # decoded by direct lookup; taking an argmax of a single id was a bug.
    context = np.asarray(next(iter(dataset))[0][:1])
    h1 = tf.zeros((1, 512), dtype=tf.float32)
    h2 = tf.zeros((1, 512), dtype=tf.float32)

    sampled_text = "".join(vocab[int(char_id)] for char_id in context[0])
    sampled_text += "|"

    for _ in range(sample_len):
        predictions, h1, h2 = model((context, h1, h2))
        # Read the new character from the model's prediction for the last
        # position (the original read it from the unchanged context).
        char_id = int(np.argmax(np.asarray(predictions)[0, -1]))
        sampled_text += vocab[char_id]

        # Slide the window: drop the oldest id and append the new one along
        # the time axis so the next step sees the character just emitted.
        next_id = np.array([[char_id]], dtype=context.dtype)
        context = np.concatenate((context[:, 1:], next_id), axis=1)

    return sampled_text


class TextSampleCallback(keras.callbacks.Callback):
    """Keras callback that prints a generated text sample after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Print text sampled from the model that is being trained.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric results for the epoch (unused).
        """
        # Keras attaches the model under training to the callback as
        # `self.model`; using it removes the dependency on a global named
        # `model` existing and being the same object that is being fitted.
        print(f"\nsampled text:\n{sample_text(self.model)}")

In [7]:
# Character-level model: embed each input id, run a bidirectional LSTM over the
# sequence, and predict a distribution over the vocabulary at every position.
model = tf.keras.models.Sequential([
    # input_length=49 matches the input length printed for the dataset batches
    # (the default sequence_length of 50 minus the id held back for the target).
    # It is hard-coded, so it must be changed by hand if sequence_length changes.
    tf.keras.layers.Embedding(len(vocab), 512, input_length=49),
    # NOTE(review): the backward half of a bidirectional LSTM reads the
    # positions after t, and the target at position t is the character at t+1,
    # so the model may be able to see the characters it is asked to predict —
    # confirm this is intended for next-character prediction.
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(512, return_sequences=True)),
    Dropout(rate=0.2),
    tf.keras.layers.Dense(len(vocab), activation='softmax')
])

# categorical_crossentropy pairs with the one-hot targets built by generate_dataset.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
# The per-epoch text-sampling callback is currently disabled (commented-out argument).
model.fit(epochs=30, x=dataset) #, callbacks=[TextSampleCallback()])

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       (None, 49, 512)           33280     
                                                                 
 bidirectional (Bidirectiona  (None, 49, 1024)         4198400   
 l)                                                              
                                                                 
 dropout (Dropout)           (None, 49, 1024)          0         
                                                                 
 dense (Dense)               (None, 49, 65)            66625     
                                                                 
Total params: 4,298,305
Trainable params: 4,298,305
Non-trainable params: 0
_________________________________________________________________
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
E

<keras.callbacks.History at 0x7f47f1795f40>