# Character-based LSTM

In [116]:
import sys
sys.path.append('..')

import numpy as np

from keras.callbacks import EarlyStopping
from keras.layers import (
    Dense,
    Lambda,
    LSTM,
)
from keras.models import Sequential

from src import utils

### Load data

In [2]:
# Load the corpus via the project helper. The cells below iterate the result as
# sonnets -> lines -> characters.
sonnets = utils.load_shakespeare()

### Pre-processing

Since we are training a character-based LSTM, the only pre-processing needed is to map each distinct character to an integer index, i.e. to its own dimension of a one-hot vector.

In [52]:
# Set of all distinct characters appearing in any line of any sonnet.
chars = set()
for sonnet in sonnets:
    for line in sonnet:
        chars |= set(line)
# Newline is used as the line separator when sonnets are encoded, so make sure
# it has an index even if no line contains it.
chars.add('\n')
n_chars = len(chars)

# Enumerate in sorted order rather than raw set order: iteration order of a set
# of strings changes between interpreter runs (string hash randomization), so
# the unsorted mapping would differ after every kernel restart and would not
# match previously trained weights.
ordered_chars = sorted(chars)
char_to_dim = {char: i for i, char in enumerate(ordered_chars)}
dim_to_char = dict(enumerate(ordered_chars))

In [73]:
# Encode every sonnet as one flat list of character indices, inserting the
# newline index between consecutive lines (but not after the final line).
newline_dim = char_to_dim['\n']
converted_sonnets = []
for sonnet in sonnets:
    last_line_no = len(sonnet) - 1
    encoded = []
    for line_no, line in enumerate(sonnet):
        encoded += [char_to_dim[ch] for ch in line]
        if line_no < last_line_no:
            encoded.append(newline_dim)
    converted_sonnets.append(encoded)

In [108]:
# Generate training data by sliding a `window_size`-character window over each
# sonnet. Every window is one sample, with each character one-hot encoded as an
# `n_chars`-dimensional vector; the character immediately after the window is
# the one-hot target.
window_size = 40  # number of characters
window_positions = np.arange(window_size)
X_parts = []
Y_parts = []
for sonnet in converted_sonnets:
    # Clamp at zero so a sonnet shorter than the window contributes no samples
    # instead of crashing np.zeros with a negative dimension.
    n_windows = max(len(sonnet) - window_size, 0)
    # float32 halves the memory of numpy's default float64 for these large,
    # mostly-zero one-hot tensors.
    x = np.zeros((n_windows, window_size, n_chars), dtype=np.float32)
    y = np.zeros((n_windows, n_chars), dtype=np.float32)

    for i in range(n_windows):
        # For each position k in the window, set x[i, k, sonnet[i + k]] = 1.
        x[i, window_positions, sonnet[i:i+window_size]] = 1
        y[i, sonnet[i+window_size]] = 1
    X_parts.append(x)
    Y_parts.append(y)
# Stack the per-sonnet sample arrays along the sample axis.
X = np.vstack(X_parts)
Y = np.vstack(Y_parts)

### Define model

In [117]:
# One LSTM layer (200 units) reading a window of one-hot characters, followed by
# a softmax over the `n_chars` possible next characters.
model = Sequential([
    LSTM(200, input_shape=(window_size, n_chars)),
    Dense(n_chars, activation='softmax'),
])
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [119]:
# Stop training once the *training* loss has gone 50 epochs without improving;
# no validation data is passed to fit(), so there is no validation loss to monitor.
es = EarlyStopping(
    monitor='loss',
    mode='min',
    patience=50,
    verbose=1,
)
model.fit(
    X,
    Y,
    batch_size=128,
    epochs=1000,
    callbacks=[es],
)

Epoch 1/1000

KeyboardInterrupt: 