In [1]:
import pickle
import numpy as np
from keras.models import Model
from keras.layers import Input, Lambda, Reshape, LSTM, Dense
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras import backend as K

Using TensorFlow backend.


In [2]:
tokenized_text, char_to_idx, idx_to_char = pickle.load(open('../data/processed-alice.pickle', 'rb'))
X, Y = pickle.load(open('../data/alice-x-y-data.pickle', 'rb'))

Y = np.reshape(Y, (X.shape[1], X.shape[0], X.shape[2]))

# Global Variables
m, Tx, n_values = X.shape
n_a = 64 # for LSTM with 64-dimensional hidden states

# Train the Model
First we train the model to get the proper weights.
Remember, we must declare the layers in our model as global variables so they
can be reused in the Text Generation step

In [3]:
# Reused Layer for Text Generation: Reshape, LSTM, Dense
reshapor = Reshape((1, n_values)) # makes input 3-dimensional
LSTM_cell = LSTM(n_a, return_state=True)
densor = Dense(n_values, activation='softmax')

In [4]:
def learn_to_talk(Tx, n_a, n_values):
    X = Input(shape=(Tx, n_values))
    
    # Define initial hidden state for LSTM
    a0 = Input(shape=(n_a,), name='a0')
    c0 = Input(shape=(n_a,), name='c0')
    a = a0
    c = c0
    
    outputs = []
    for t in range(Tx):
        # Get t'th vector
        x = Lambda(lambda x: x[:,t,:])(X)
        # Reshape x to be (1,Tx)
        x = reshapor(x)
        # Perform 1 step of LSTM
        a, _, c = LSTM_cell(x)
        # apply densor to hidden state output of LSTM_cell
        out = densor(a)
        outputs.append(out)
    return Model([X, a0, c0], outputs)

In [5]:
model = learn_to_talk(Tx, n_a, n_values)

In [6]:
opt = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, decay=0.0005)
model.compile(optimizer=opt, loss='mse', metrics=['accuracy'])

In [7]:
# a0 = np.zeros((m, n_a))
# c0 = np.zeros((m, n_a))

a0 = np.random.rand(m, n_a)
c0 = np.random.rand(m, n_a)

In [8]:
model.fit([X, a0, c0], list(Y), epochs=30)

Epoch 1/30
Epoch 2/30
Epoch 3/30


  32/1031 [..............................] - ETA: 4s - loss: 2.9958 - dense_1_loss: 0.0215 - dense_1_acc: 0.2188 - dense_1_acc_1: 0.1562 - dense_1_acc_2: 0.1250 - dense_1_acc_3: 0.1875 - dense_1_acc_4: 0.1562 - dense_1_acc_5: 0.2812 - dense_1_acc_6: 0.2500 - dense_1_acc_7: 0.1875 - dense_1_acc_8: 0.1875 - dense_1_acc_9: 0.0938 - dense_1_acc_10: 0.2188 - dense_1_acc_11: 0.3438 - dense_1_acc_12: 0.2188 - dense_1_acc_13: 0.1562 - dense_1_acc_14: 0.2188 - dense_1_acc_15: 0.1250 - dense_1_acc_16: 0.1875 - dense_1_acc_17: 0.1562 - dense_1_acc_18: 0.1562 - dense_1_acc_19: 0.2188 - dense_1_acc_20: 0.2188 - dense_1_acc_21: 0.2188 - dense_1_acc_22: 0.0938 - dense_1_acc_23: 0.0312 - dense_1_acc_24: 0.1562 - dense_1_acc_25: 0.2188 - dense_1_acc_26: 0.2188 - dense_1_acc_27: 0.1875 - dense_1_acc_28: 0.5000 - dense_1_acc_29: 0.2500 - dense_1_acc_30: 0.1875 - dense_1_acc_31: 0.1250 - dense_1_acc_32: 0.2188 - dense_1_acc_33: 0.2500 - dense_1_acc_34: 0.1875 - dense_1_acc_35: 0.1875 - dense_1_acc_36: 0.2

Epoch 4/30
Epoch 5/30


  32/1031 [..............................] - ETA: 4s - loss: 2.8922 - dense_1_loss: 0.0210 - dense_1_acc: 0.3125 - dense_1_acc_1: 0.1875 - dense_1_acc_2: 0.0625 - dense_1_acc_3: 0.1875 - dense_1_acc_4: 0.2188 - dense_1_acc_5: 0.0938 - dense_1_acc_6: 0.2188 - dense_1_acc_7: 0.1875 - dense_1_acc_8: 0.1875 - dense_1_acc_9: 0.2188 - dense_1_acc_10: 0.2500 - dense_1_acc_11: 0.1562 - dense_1_acc_12: 0.2500 - dense_1_acc_13: 0.2188 - dense_1_acc_14: 0.2812 - dense_1_acc_15: 0.2812 - dense_1_acc_16: 0.2812 - dense_1_acc_17: 0.2500 - dense_1_acc_18: 0.1562 - dense_1_acc_19: 0.2188 - dense_1_acc_20: 0.3125 - dense_1_acc_21: 0.1562 - dense_1_acc_22: 0.1875 - dense_1_acc_23: 0.2188 - dense_1_acc_24: 0.2500 - dense_1_acc_25: 0.1562 - dense_1_acc_26: 0.1250 - dense_1_acc_27: 0.1875 - dense_1_acc_28: 0.4375 - dense_1_acc_29: 0.2188 - dense_1_acc_30: 0.1562 - dense_1_acc_31: 0.1562 - dense_1_acc_32: 0.0625 - dense_1_acc_33: 0.1562 - dense_1_acc_34: 0.1562 - dense_1_acc_35: 0.1250 - dense_1_acc_36: 0.1

Epoch 6/30
Epoch 7/30


  32/1031 [..............................] - ETA: 3s - loss: 2.8758 - dense_1_loss: 0.0207 - dense_1_acc: 0.2500 - dense_1_acc_1: 0.2500 - dense_1_acc_2: 0.2812 - dense_1_acc_3: 0.1562 - dense_1_acc_4: 0.2188 - dense_1_acc_5: 0.2500 - dense_1_acc_6: 0.1875 - dense_1_acc_7: 0.1250 - dense_1_acc_8: 0.3438 - dense_1_acc_9: 0.1562 - dense_1_acc_10: 0.1562 - dense_1_acc_11: 0.2812 - dense_1_acc_12: 0.2500 - dense_1_acc_13: 0.0625 - dense_1_acc_14: 0.2188 - dense_1_acc_15: 0.1250 - dense_1_acc_16: 0.0312 - dense_1_acc_17: 0.1562 - dense_1_acc_18: 0.1562 - dense_1_acc_19: 0.1250 - dense_1_acc_20: 0.2812 - dense_1_acc_21: 0.3125 - dense_1_acc_22: 0.1875 - dense_1_acc_23: 0.2812 - dense_1_acc_24: 0.3438 - dense_1_acc_25: 0.2188 - dense_1_acc_26: 0.3125 - dense_1_acc_27: 0.2500 - dense_1_acc_28: 0.3750 - dense_1_acc_29: 0.0938 - dense_1_acc_30: 0.1875 - dense_1_acc_31: 0.1562 - dense_1_acc_32: 0.1250 - dense_1_acc_33: 0.1250 - dense_1_acc_34: 0.2188 - dense_1_acc_35: 0.3125 - dense_1_acc_36: 0.1

Epoch 8/30
Epoch 9/30


  32/1031 [..............................] - ETA: 3s - loss: 2.8829 - dense_1_loss: 0.0213 - dense_1_acc: 0.1875 - dense_1_acc_1: 0.3438 - dense_1_acc_2: 0.1875 - dense_1_acc_3: 0.1875 - dense_1_acc_4: 0.3750 - dense_1_acc_5: 0.1562 - dense_1_acc_6: 0.2500 - dense_1_acc_7: 0.1250 - dense_1_acc_8: 0.2812 - dense_1_acc_9: 0.3125 - dense_1_acc_10: 0.1875 - dense_1_acc_11: 0.3438 - dense_1_acc_12: 0.1250 - dense_1_acc_13: 0.2188 - dense_1_acc_14: 0.1875 - dense_1_acc_15: 0.2188 - dense_1_acc_16: 0.1250 - dense_1_acc_17: 0.2188 - dense_1_acc_18: 0.0625 - dense_1_acc_19: 0.0625 - dense_1_acc_20: 0.1562 - dense_1_acc_21: 0.1562 - dense_1_acc_22: 0.1562 - dense_1_acc_23: 0.1250 - dense_1_acc_24: 0.1875 - dense_1_acc_25: 0.0938 - dense_1_acc_26: 0.2188 - dense_1_acc_27: 0.1562 - dense_1_acc_28: 0.3438 - dense_1_acc_29: 0.1562 - dense_1_acc_30: 0.1562 - dense_1_acc_31: 0.1562 - dense_1_acc_32: 0.1562 - dense_1_acc_33: 0.2500 - dense_1_acc_34: 0.2812 - dense_1_acc_35: 0.1562 - dense_1_acc_36: 0.2

Epoch 10/30
Epoch 11/30


  32/1031 [..............................] - ETA: 3s - loss: 2.8714 - dense_1_loss: 0.0203 - dense_1_acc: 0.2188 - dense_1_acc_1: 0.2812 - dense_1_acc_2: 0.2188 - dense_1_acc_3: 0.1875 - dense_1_acc_4: 0.2500 - dense_1_acc_5: 0.2500 - dense_1_acc_6: 0.1562 - dense_1_acc_7: 0.1562 - dense_1_acc_8: 0.2500 - dense_1_acc_9: 0.1875 - dense_1_acc_10: 0.1250 - dense_1_acc_11: 0.2188 - dense_1_acc_12: 0.2812 - dense_1_acc_13: 0.1562 - dense_1_acc_14: 0.2500 - dense_1_acc_15: 0.3438 - dense_1_acc_16: 0.1562 - dense_1_acc_17: 0.2188 - dense_1_acc_18: 0.3125 - dense_1_acc_19: 0.0938 - dense_1_acc_20: 0.2812 - dense_1_acc_21: 0.1562 - dense_1_acc_22: 0.3438 - dense_1_acc_23: 0.2188 - dense_1_acc_24: 0.1875 - dense_1_acc_25: 0.1562 - dense_1_acc_26: 0.2188 - dense_1_acc_27: 0.1875 - dense_1_acc_28: 0.2812 - dense_1_acc_29: 0.2188 - dense_1_acc_30: 0.1875 - dense_1_acc_31: 0.1562 - dense_1_acc_32: 0.1562 - dense_1_acc_33: 0.1875 - dense_1_acc_34: 0.1875 - dense_1_acc_35: 0.3438 - dense_1_acc_36: 0.2

Epoch 12/30
Epoch 13/30


  32/1031 [..............................] - ETA: 3s - loss: 2.8757 - dense_1_loss: 0.0207 - dense_1_acc: 0.2812 - dense_1_acc_1: 0.1875 - dense_1_acc_2: 0.1875 - dense_1_acc_3: 0.0625 - dense_1_acc_4: 0.3125 - dense_1_acc_5: 0.3125 - dense_1_acc_6: 0.2188 - dense_1_acc_7: 0.1562 - dense_1_acc_8: 0.1562 - dense_1_acc_9: 0.1562 - dense_1_acc_10: 0.2188 - dense_1_acc_11: 0.3750 - dense_1_acc_12: 0.4375 - dense_1_acc_13: 0.1250 - dense_1_acc_14: 0.1875 - dense_1_acc_15: 0.1875 - dense_1_acc_16: 0.1562 - dense_1_acc_17: 0.1562 - dense_1_acc_18: 0.2500 - dense_1_acc_19: 0.1562 - dense_1_acc_20: 0.2188 - dense_1_acc_21: 0.3125 - dense_1_acc_22: 0.2188 - dense_1_acc_23: 0.1250 - dense_1_acc_24: 0.0312 - dense_1_acc_25: 0.2500 - dense_1_acc_26: 0.1250 - dense_1_acc_27: 0.2500 - dense_1_acc_28: 0.3750 - dense_1_acc_29: 0.2500 - dense_1_acc_30: 0.1250 - dense_1_acc_31: 0.1562 - dense_1_acc_32: 0.2188 - dense_1_acc_33: 0.0938 - dense_1_acc_34: 0.2188 - dense_1_acc_35: 0.1562 - dense_1_acc_36: 0.1

Epoch 14/30
Epoch 15/30


  32/1031 [..............................] - ETA: 3s - loss: 2.8735 - dense_1_loss: 0.0203 - dense_1_acc: 0.2812 - dense_1_acc_1: 0.1250 - dense_1_acc_2: 0.1875 - dense_1_acc_3: 0.1875 - dense_1_acc_4: 0.2188 - dense_1_acc_5: 0.1875 - dense_1_acc_6: 0.2188 - dense_1_acc_7: 0.1875 - dense_1_acc_8: 0.2500 - dense_1_acc_9: 0.1875 - dense_1_acc_10: 0.1562 - dense_1_acc_11: 0.2188 - dense_1_acc_12: 0.2812 - dense_1_acc_13: 0.1250 - dense_1_acc_14: 0.1250 - dense_1_acc_15: 0.2500 - dense_1_acc_16: 0.1250 - dense_1_acc_17: 0.1875 - dense_1_acc_18: 0.1875 - dense_1_acc_19: 0.1562 - dense_1_acc_20: 0.0938 - dense_1_acc_21: 0.2188 - dense_1_acc_22: 0.1562 - dense_1_acc_23: 0.1250 - dense_1_acc_24: 0.0625 - dense_1_acc_25: 0.1875 - dense_1_acc_26: 0.2500 - dense_1_acc_27: 0.2812 - dense_1_acc_28: 0.2812 - dense_1_acc_29: 0.3125 - dense_1_acc_30: 0.2188 - dense_1_acc_31: 0.1562 - dense_1_acc_32: 0.2500 - dense_1_acc_33: 0.2188 - dense_1_acc_34: 0.1875 - dense_1_acc_35: 0.2812 - dense_1_acc_36: 0.1

Epoch 16/30
Epoch 17/30


  32/1031 [..............................] - ETA: 3s - loss: 2.8791 - dense_1_loss: 0.0201 - dense_1_acc: 0.2188 - dense_1_acc_1: 0.2812 - dense_1_acc_2: 0.1875 - dense_1_acc_3: 0.1875 - dense_1_acc_4: 0.2188 - dense_1_acc_5: 0.1875 - dense_1_acc_6: 0.2188 - dense_1_acc_7: 0.3125 - dense_1_acc_8: 0.2812 - dense_1_acc_9: 0.2188 - dense_1_acc_10: 0.2500 - dense_1_acc_11: 0.3125 - dense_1_acc_12: 0.2188 - dense_1_acc_13: 0.1250 - dense_1_acc_14: 0.1562 - dense_1_acc_15: 0.2500 - dense_1_acc_16: 0.2188 - dense_1_acc_17: 0.2500 - dense_1_acc_18: 0.0625 - dense_1_acc_19: 0.1562 - dense_1_acc_20: 0.1875 - dense_1_acc_21: 0.1875 - dense_1_acc_22: 0.0938 - dense_1_acc_23: 0.1250 - dense_1_acc_24: 0.1562 - dense_1_acc_25: 0.1562 - dense_1_acc_26: 0.1562 - dense_1_acc_27: 0.4688 - dense_1_acc_28: 0.3125 - dense_1_acc_29: 0.1250 - dense_1_acc_30: 0.1562 - dense_1_acc_31: 0.1875 - dense_1_acc_32: 0.0938 - dense_1_acc_33: 0.0312 - dense_1_acc_34: 0.1250 - dense_1_acc_35: 0.2188 - dense_1_acc_36: 0.3

Epoch 18/30
Epoch 19/30


  32/1031 [..............................] - ETA: 4s - loss: 2.8772 - dense_1_loss: 0.0215 - dense_1_acc: 0.2188 - dense_1_acc_1: 0.1250 - dense_1_acc_2: 0.1562 - dense_1_acc_3: 0.1875 - dense_1_acc_4: 0.2188 - dense_1_acc_5: 0.2188 - dense_1_acc_6: 0.3750 - dense_1_acc_7: 0.1875 - dense_1_acc_8: 0.1250 - dense_1_acc_9: 0.1875 - dense_1_acc_10: 0.1875 - dense_1_acc_11: 0.1250 - dense_1_acc_12: 0.2812 - dense_1_acc_13: 0.2812 - dense_1_acc_14: 0.1875 - dense_1_acc_15: 0.2188 - dense_1_acc_16: 0.2188 - dense_1_acc_17: 0.3125 - dense_1_acc_18: 0.0938 - dense_1_acc_19: 0.3438 - dense_1_acc_20: 0.1250 - dense_1_acc_21: 0.2812 - dense_1_acc_22: 0.1250 - dense_1_acc_23: 0.0938 - dense_1_acc_24: 0.1875 - dense_1_acc_25: 0.2812 - dense_1_acc_26: 0.1875 - dense_1_acc_27: 0.3125 - dense_1_acc_28: 0.5000 - dense_1_acc_29: 0.1562 - dense_1_acc_30: 0.0625 - dense_1_acc_31: 0.2500 - dense_1_acc_32: 0.0625 - dense_1_acc_33: 0.2188 - dense_1_acc_34: 0.3125 - dense_1_acc_35: 0.1875 - dense_1_acc_36: 0.3

Epoch 20/30
Epoch 21/30


  32/1031 [..............................] - ETA: 4s - loss: 2.8649 - dense_1_loss: 0.0205 - dense_1_acc: 0.1875 - dense_1_acc_1: 0.1875 - dense_1_acc_2: 0.2188 - dense_1_acc_3: 0.1250 - dense_1_acc_4: 0.3125 - dense_1_acc_5: 0.3438 - dense_1_acc_6: 0.3750 - dense_1_acc_7: 0.1562 - dense_1_acc_8: 0.2188 - dense_1_acc_9: 0.1875 - dense_1_acc_10: 0.2500 - dense_1_acc_11: 0.2812 - dense_1_acc_12: 0.4375 - dense_1_acc_13: 0.1250 - dense_1_acc_14: 0.1250 - dense_1_acc_15: 0.2188 - dense_1_acc_16: 0.1875 - dense_1_acc_17: 0.1875 - dense_1_acc_18: 0.2188 - dense_1_acc_19: 0.1562 - dense_1_acc_20: 0.1562 - dense_1_acc_21: 0.3125 - dense_1_acc_22: 0.1875 - dense_1_acc_23: 0.1875 - dense_1_acc_24: 0.2188 - dense_1_acc_25: 0.1562 - dense_1_acc_26: 0.1562 - dense_1_acc_27: 0.3438 - dense_1_acc_28: 0.3750 - dense_1_acc_29: 0.1875 - dense_1_acc_30: 0.2188 - dense_1_acc_31: 0.2500 - dense_1_acc_32: 0.2188 - dense_1_acc_33: 0.1875 - dense_1_acc_34: 0.1250 - dense_1_acc_35: 0.2500 - dense_1_acc_36: 0.1

Epoch 22/30
Epoch 23/30


  32/1031 [..............................] - ETA: 3s - loss: 2.8702 - dense_1_loss: 0.0206 - dense_1_acc: 0.2188 - dense_1_acc_1: 0.2500 - dense_1_acc_2: 0.2188 - dense_1_acc_3: 0.2500 - dense_1_acc_4: 0.2812 - dense_1_acc_5: 0.1562 - dense_1_acc_6: 0.1250 - dense_1_acc_7: 0.2500 - dense_1_acc_8: 0.3438 - dense_1_acc_9: 0.1250 - dense_1_acc_10: 0.1250 - dense_1_acc_11: 0.2500 - dense_1_acc_12: 0.2812 - dense_1_acc_13: 0.1250 - dense_1_acc_14: 0.1250 - dense_1_acc_15: 0.4688 - dense_1_acc_16: 0.2188 - dense_1_acc_17: 0.1562 - dense_1_acc_18: 0.2812 - dense_1_acc_19: 0.3438 - dense_1_acc_20: 0.1562 - dense_1_acc_21: 0.1250 - dense_1_acc_22: 0.1562 - dense_1_acc_23: 0.1250 - dense_1_acc_24: 0.1250 - dense_1_acc_25: 0.2500 - dense_1_acc_26: 0.1250 - dense_1_acc_27: 0.1250 - dense_1_acc_28: 0.4375 - dense_1_acc_29: 0.1250 - dense_1_acc_30: 0.2500 - dense_1_acc_31: 0.1562 - dense_1_acc_32: 0.1562 - dense_1_acc_33: 0.2812 - dense_1_acc_34: 0.2500 - dense_1_acc_35: 0.3438 - dense_1_acc_36: 0.1

Epoch 24/30
Epoch 25/30


  32/1031 [..............................] - ETA: 2s - loss: 2.8712 - dense_1_loss: 0.0199 - dense_1_acc: 0.1875 - dense_1_acc_1: 0.2500 - dense_1_acc_2: 0.2500 - dense_1_acc_3: 0.1250 - dense_1_acc_4: 0.1562 - dense_1_acc_5: 0.2812 - dense_1_acc_6: 0.2500 - dense_1_acc_7: 0.2500 - dense_1_acc_8: 0.1875 - dense_1_acc_9: 0.2812 - dense_1_acc_10: 0.2188 - dense_1_acc_11: 0.2500 - dense_1_acc_12: 0.3750 - dense_1_acc_13: 0.3125 - dense_1_acc_14: 0.0938 - dense_1_acc_15: 0.2500 - dense_1_acc_16: 0.2188 - dense_1_acc_17: 0.2188 - dense_1_acc_18: 0.0938 - dense_1_acc_19: 0.2812 - dense_1_acc_20: 0.1250 - dense_1_acc_21: 0.2500 - dense_1_acc_22: 0.1250 - dense_1_acc_23: 0.2500 - dense_1_acc_24: 0.1875 - dense_1_acc_25: 0.2500 - dense_1_acc_26: 0.1875 - dense_1_acc_27: 0.1875 - dense_1_acc_28: 0.5000 - dense_1_acc_29: 0.1875 - dense_1_acc_30: 0.2500 - dense_1_acc_31: 0.2188 - dense_1_acc_32: 0.2188 - dense_1_acc_33: 0.2188 - dense_1_acc_34: 0.2812 - dense_1_acc_35: 0.1250 - dense_1_acc_36: 0.0

Epoch 26/30
Epoch 27/30


  32/1031 [..............................] - ETA: 3s - loss: 2.8853 - dense_1_loss: 0.0214 - dense_1_acc: 0.3125 - dense_1_acc_1: 0.2500 - dense_1_acc_2: 0.1875 - dense_1_acc_3: 0.1250 - dense_1_acc_4: 0.1562 - dense_1_acc_5: 0.1562 - dense_1_acc_6: 0.2812 - dense_1_acc_7: 0.1875 - dense_1_acc_8: 0.3438 - dense_1_acc_9: 0.1250 - dense_1_acc_10: 0.2188 - dense_1_acc_11: 0.1875 - dense_1_acc_12: 0.3438 - dense_1_acc_13: 0.2188 - dense_1_acc_14: 0.1562 - dense_1_acc_15: 0.1562 - dense_1_acc_16: 0.2812 - dense_1_acc_17: 0.1250 - dense_1_acc_18: 0.1875 - dense_1_acc_19: 0.2500 - dense_1_acc_20: 0.1250 - dense_1_acc_21: 0.1250 - dense_1_acc_22: 0.1562 - dense_1_acc_23: 0.2188 - dense_1_acc_24: 0.1250 - dense_1_acc_25: 0.1875 - dense_1_acc_26: 0.1562 - dense_1_acc_27: 0.2188 - dense_1_acc_28: 0.4062 - dense_1_acc_29: 0.2188 - dense_1_acc_30: 0.1875 - dense_1_acc_31: 0.0625 - dense_1_acc_32: 0.0938 - dense_1_acc_33: 0.1875 - dense_1_acc_34: 0.1562 - dense_1_acc_35: 0.1250 - dense_1_acc_36: 0.2

Epoch 28/30
Epoch 29/30


  32/1031 [..............................] - ETA: 3s - loss: 2.8809 - dense_1_loss: 0.0206 - dense_1_acc: 0.1875 - dense_1_acc_1: 0.2188 - dense_1_acc_2: 0.1875 - dense_1_acc_3: 0.0938 - dense_1_acc_4: 0.1875 - dense_1_acc_5: 0.1875 - dense_1_acc_6: 0.1250 - dense_1_acc_7: 0.1875 - dense_1_acc_8: 0.2188 - dense_1_acc_9: 0.2188 - dense_1_acc_10: 0.1250 - dense_1_acc_11: 0.2812 - dense_1_acc_12: 0.0938 - dense_1_acc_13: 0.1562 - dense_1_acc_14: 0.0938 - dense_1_acc_15: 0.1562 - dense_1_acc_16: 0.1875 - dense_1_acc_17: 0.2188 - dense_1_acc_18: 0.2812 - dense_1_acc_19: 0.0938 - dense_1_acc_20: 0.1250 - dense_1_acc_21: 0.1562 - dense_1_acc_22: 0.0625 - dense_1_acc_23: 0.1875 - dense_1_acc_24: 0.1250 - dense_1_acc_25: 0.1562 - dense_1_acc_26: 0.1875 - dense_1_acc_27: 0.2500 - dense_1_acc_28: 0.3125 - dense_1_acc_29: 0.1875 - dense_1_acc_30: 0.1562 - dense_1_acc_31: 0.0938 - dense_1_acc_32: 0.2188 - dense_1_acc_33: 0.1562 - dense_1_acc_34: 0.3438 - dense_1_acc_35: 0.1250 - dense_1_acc_36: 0.1

Epoch 30/30


<keras.callbacks.History at 0xb40203d68>

# Use Trained Weights for Text Generation

In [9]:
def one_hot(softmax_v):
    best_idxs = K.argmax(softmax_v, axis=-1)
    return K.one_hot([best_idxs], num_classes=n_values)

def generate_text(Ty = 140):
    x0 = Input(shape=(1, n_values))
    a0 = Input(shape=(n_a,), name='a0')
    c0 = Input(shape=(n_a,), name='c0')
    a = a0
    c = c0
    x = x0
    outputs = []
    
    for t in range(Ty):
        # Get new activation and cell state
        a, _ , c = LSTM_cell(x, initial_state=[a, c])
        # Use Dense layer to get softmax output
        out = densor(a)
        outputs.append(out)
        
        x = Lambda(one_hot)(out)
    
    return Model([x0, a0, c0], outputs)
    

In [10]:
text_model = generate_text()

In [11]:
# x_initializer = np.zeros((1, 1, n_values))
# a_initializer = np.zeros((1, n_a))
# c_initializer = np.zeros((1, n_a))

x_initializer = np.random.rand(1, 1, n_values)
a_initializer = np.random.rand(1, n_a)
c_initializer = np.random.rand(1, n_a)

In [12]:
def predict_and_sample(x_initializer, a_initializer, c_initializer):
    pred = text_model.predict([x_initializer, a_initializer, c_initializer])
    indices = np.argmax(pred, axis=-1)
    results = to_categorical(indices, num_classes=n_values)
    return results, indices

In [13]:
results, indices = predict_and_sample(x_initializer, a_initializer, c_initializer)

In [14]:
new_chars = [idx_to_char[str(i)] for i in np.ndarray.flatten(indices)]

In [15]:
''.join(new_chars)

'                                                                                                                                            '

In [16]:
indices

array([[0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
    

In [17]:
pred = text_model.predict([x_initializer, a_initializer, c_initializer])