In [1]:
import os
import sys
import numpy as np
import tensorflow as tf
from tensorflow import keras

  from ._conv import register_converters as _register_converters


In [2]:
file = os.path.join('../data', 'haiku_reddit.tsv')
raw_text = open(file, 'r').read().lower()

In [3]:
chars = sorted(list(set(raw_text)))
char_to_int = dict((c, i) for i, c, in enumerate(chars))

In [4]:
n_chars = len(raw_text)
n_words = len(chars)
print("Number of unique characters (vocabulary): ", n_words)

Number of unique characters (vocabulary):  66


In [5]:
seq_length = 30
dataX = []
dataY = []
lines = raw_text.split('\n')
total, skip = [0, 0]
for haiku in lines:
    if len(haiku) < seq_length:
        skip = skip + 1
        continue
    haiku = haiku.strip()
    for i in range(0, len(haiku) - 1 - seq_length):
        total = total + 1
        seq_in = haiku[i:i+seq_length]
        seq_out = haiku[i+seq_length]
        dataX.append([char_to_int[char] for char in seq_in])
        dataY.append(char_to_int[seq_out])

In [6]:
input_size = len(dataX)

In [7]:
X = np.reshape(dataX, (input_size, seq_length, 1))
X = X / float(n_words)
y = tf.keras.utils.to_categorical(dataY)

In [8]:
model = keras.models.Sequential([
    keras.layers.LSTM(256, input_shape=(X.shape[1], X.shape[2]), return_sequences=True),
    keras.layers.Dropout(rate=0.2),
    keras.layers.LSTM(256),
    keras.layers.Dropout(rate=0.2),
    keras.layers.Dense(y.shape[1], activation='softmax')
])

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.


In [9]:
model.compile(loss='categorical_crossentropy', optimizer='adam')

In [10]:
path = os.path.join('./src', 'model-{epoch:02d}-{loss:.4f}=1.hdf5')
checkpoint_cb = keras.callbacks.ModelCheckpoint(filepath=path, monitor='loss', verbose=1, mode='min', save_best_only=False)
callbacks = [checkpoint_cb]

In [12]:
model.fit(X, y, epochs=20, batch_size=64, callbacks=callbacks)

Epoch 1/20
   1856/1550092 [..............................] - ETA: 7:34:04 - loss: 3.3064

KeyboardInterrupt: 