In [2]:
import tensorflow as tf
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense
from sklearn.metrics import classification_report

In [3]:
# Load IMDB dataset with the top 10,000 most frequent words
vocab_size = 10000
max_length = 200 # max words per review

(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=vocab_size)

# Pad sequences
x_train = pad_sequences(x_train, maxlen=max_length)
x_test = pad_sequences(x_test, maxlen=max_length)

model = Sequential([
 Embedding(input_dim=vocab_size, output_dim=128,
input_length=max_length),
 LSTM(64, dropout=0.2, recurrent_dropout=0.2),
 Dense(1, activation='sigmoid') # binary classification
])

model.compile(loss='binary_crossentropy', optimizer='adam',
metrics=['accuracy'])
model.summary()


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
[1m17464789/17464789[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step




In [4]:
model.fit(x_train, y_train, epochs=3, batch_size=64, validation_split=0.2)

Epoch 1/3
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m104s[0m 316ms/step - accuracy: 0.6960 - loss: 0.5659 - val_accuracy: 0.8240 - val_loss: 0.3916
Epoch 2/3
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m140s[0m 312ms/step - accuracy: 0.8621 - loss: 0.3287 - val_accuracy: 0.8378 - val_loss: 0.3845
Epoch 3/3
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m143s[0m 314ms/step - accuracy: 0.8700 - loss: 0.3171 - val_accuracy: 0.8110 - val_loss: 0.4190


<keras.src.callbacks.history.History at 0x7bcb7d2e3510>

In [5]:
y_pred_prob = model.predict(x_test)
y_pred = (y_pred_prob > 0.5).astype("int32")
print(classification_report(y_test, y_pred))

[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 42ms/step
              precision    recall  f1-score   support

           0       0.83      0.80      0.81     12500
           1       0.81      0.83      0.82     12500

    accuracy                           0.82     25000
   macro avg       0.82      0.82      0.82     25000
weighted avg       0.82      0.82      0.82     25000

