In [38]:
'''Trains a stateful LSTM classifier on randomly generated dummy data.
The data is pure noise, so the model cannot learn anything meaningful;
the example only demonstrates how to build and fit a stateful LSTM.
# Notes
- RNNs are tricky. Choice of batch size is important,
choice of loss and optimizer is critical, etc.
Some configurations won't converge.
- LSTM loss decrease patterns during training can be quite different
from what you see with CNNs/MLPs/etc.
'''
from __future__ import print_function

from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import LSTM, Dense
import numpy as np



# Problem dimensions: every input sample has shape (timesteps, data_dim).
timesteps = 3      # sequence length of each sample
data_dim = 7       # number of features per timestep
num_classes = 9    # width of each target vector
# Fixed batch size shared by the model's input shape, the data sizes and fit().
batch_size = 32

# Expected input batch shape: (batch_size, timesteps, data_dim).
# Because the network is stateful, the full batch_input_shape (including the
# batch size) must be provided: the sample at index i in batch k is treated
# as the follow-up of the sample at index i in batch k-1.
model = Sequential()
model.add(LSTM(32, stateful=True,
               batch_input_shape=(batch_size, timesteps, data_dim)))
# To stack further LSTM layers, every LSTM except the last one needs
# return_sequences=True so the next LSTM receives a full sequence.

# One softmax unit per class; sized from num_classes (not a literal) so the
# output layer always matches the width of the target arrays.
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# Seeded local generator so the dummy data is identical on every run
# without touching NumPy's global random state.
dummy_rng = np.random.RandomState(0)

# categorical_crossentropy expects each target row to be a probability
# distribution over the classes, so labels are drawn as class indices and
# one-hot encoded (rows of the identity matrix) instead of being filled with
# unrelated uniform noise.
class_one_hot = np.eye(num_classes)

# Generate dummy training data: 10 full batches.
x_train = dummy_rng.random_sample((batch_size * 10, timesteps, data_dim))
y_train = class_one_hot[dummy_rng.randint(num_classes, size=batch_size * 10)]

# Generate dummy validation data: 2 full batches.
x_val = dummy_rng.random_sample((batch_size * 2, timesteps, data_dim))
y_val = class_one_hot[dummy_rng.randint(num_classes, size=batch_size * 2)]

# shuffle=False keeps the samples in order: the stateful LSTM treats sample i
# of each batch as the continuation of sample i of the previous batch.
model.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=5,
    shuffle=False,
    validation_data=(x_val, y_val),
)

Train on 320 samples, validate on 64 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7fc145376c18>

In [39]:
y_val.shape

(64, 9)

In [40]:
x_train.shape

(320, 3, 7)

In [41]:
y_train.shape

(320, 9)

In [42]:
y_train.shape

(320, 9)

In [43]:
x_val.shape

(64, 3, 7)