In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding,SimpleRNN,Dense

In [2]:
max_features =10000
(X_train,y_train),(X_test,y_test)=imdb.load_data(num_words=max_features)

print(X_train.shape)
print(X_test.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
[1m17464789/17464789[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
(25000,)
(25000,)


In [3]:
sample_review = X_train[0]
sample_label = y_train[0]

In [4]:
word_index = imdb.get_word_index()
reverse_word_index = dict([(value,key) for (key,value) in word_index.items()])

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb_word_index.json
[1m1641221/1641221[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [5]:
decoded_review = " ".join([reverse_word_index.get(word-3,'?') for word in sample_review])

In [6]:
max_len = 500
X_train = sequence.pad_sequences(X_train,maxlen=max_len)
X_test = sequence.pad_sequences(X_test,maxlen=max_len)

In [7]:
model = Sequential()
model.add(Embedding(max_features,128,input_length=max_len))
model.add(SimpleRNN(128,activation='relu'))
model.add(Dense(1,activation='sigmoid'))



In [8]:
model.summary

In [9]:
'''
If the loss doesn't improve for 5 more epochs (because patience=5). In this case, early stopping will stop the training at epoch 15 and restore the weights from epoch 10,
which had the best performance.
'''
from tensorflow.keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(monitor='val_loss',patience=5,restore_best_weights=True)

In [10]:
model.compile('adam','binary_crossentropy',metrics=['accuracy'])

In [11]:
model.fit(
    X_train,y_train,
    epochs=10,
    batch_size=32,
    validation_split=0.2,
    callbacks=[early_stopping])

Epoch 1/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m46s[0m 67ms/step - accuracy: 0.5659 - loss: 9672.1641 - val_accuracy: 0.6310 - val_loss: 0.6235
Epoch 2/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 56ms/step - accuracy: 0.6961 - loss: 0.5934 - val_accuracy: 0.6698 - val_loss: 0.5835
Epoch 3/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 57ms/step - accuracy: 0.7629 - loss: 0.4773 - val_accuracy: 0.7458 - val_loss: 0.6722
Epoch 4/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 56ms/step - accuracy: 0.8513 - loss: 0.3434 - val_accuracy: 0.7804 - val_loss: 0.4775
Epoch 5/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 56ms/step - accuracy: 0.9041 - loss: 0.3149 - val_accuracy: 0.6228 - val_loss: 0.6594
Epoch 6/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 55ms/step - accuracy: 0.8004 - loss: 0.4138 - val_accuracy: 0.7710 - val_loss: 0.6072
Epoch 7/10
[

<keras.src.callbacks.history.History at 0x7cfbbb7a3cd0>

In [14]:
model.save('simple_rnn_model.h5')

