# Bidirectional LSTM on IMDB

In [5]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Hyperparameters.
max_features = 20000  # vocabulary cap: num_words for load_data and the Embedding input size
maxlen = 200  # every review is padded/truncated to this many tokens
batch_size = 32
epochs = 2

# Data loading: each review is a list of integer word indices.
# The second split returned by load_data is IMDB's test split; this
# notebook uses it as validation data during training.
(x_train, y_train), (x_val, y_val) = tf.keras.datasets.imdb.load_data(
    num_words=max_features
)
print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")

# Pad (or truncate) both splits to exactly `maxlen` tokens so they can be
# batched as fixed-width 2-D arrays.
x_train, x_val = [
    tf.keras.preprocessing.sequence.pad_sequences(split, maxlen=maxlen)
    for split in (x_train, x_val)
]


# Model construction: embedding -> two stacked bidirectional LSTMs -> sigmoid.
embedding_dim = 128  # size of each learned word vector
lstm_units = 64  # units per direction; Bidirectional concatenates to 2 * lstm_units

# Variable-length integer sequences of word indices.
inputs = tf.keras.Input(shape=(None,), dtype=tf.int32)
# mask_zero=True: pad_sequences fills with 0 by default (pre-padding), and
# imdb.load_data reserves index 0 (real tokens start at 1), so 0 means
# "padding" only. The mask lets both LSTM layers skip those steps; without it
# the backward direction would end its pass on the leading padding tokens.
# NOTE: on GPU, masked input that is not strictly right-padded cannot use the
# cuDNN LSTM kernel and falls back to the generic (slower) implementation.
x = layers.Embedding(max_features, embedding_dim, mask_zero=True)(inputs)
x = layers.Bidirectional(layers.LSTM(lstm_units, return_sequences=True))(x)
x = layers.Bidirectional(layers.LSTM(lstm_units))(x)
outputs = layers.Dense(1, activation="sigmoid")(x)

model = tf.keras.Model(inputs, outputs)
model.summary()

# Binary sentiment classification: a single sigmoid output trained with
# binary cross-entropy, reporting accuracy each epoch.
model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)
# Train and evaluate on the held-out split after every epoch. The returned
# History object is the cell's final expression, so the notebook displays it.
model.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_val, y_val),
    epochs=epochs,
    batch_size=batch_size,
)

25000 Training sequences
25000 Validation sequences
Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, None)]            0         
_________________________________________________________________
embedding_4 (Embedding)      (None, None, 128)         2560000   
_________________________________________________________________
bidirectional_8 (Bidirection (None, None, 128)         98816     
_________________________________________________________________
bidirectional_9 (Bidirection (None, 128)               98816     
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 129       
Total params: 2,757,761
Trainable params: 2,757,761
Non-trainable params: 0
_________________________________________________________________
Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x19941d450>