In [1]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, SimpleRNN, Dense
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import SGD

# Train a single-layer SimpleRNN classifier on MNIST, treating each 28x28
# image as a sequence of 28 timesteps (rows) with 28 features (pixels) each.

# Load MNIST data. The second split returned by load_data() (the official test
# split) is used as the validation set here.
(X_train, y_train), (X_valid, y_valid) = mnist.load_data()

# Shape the images as (num_samples, num_timesteps, input_dim) for the RNN and
# scale pixel values from [0, 255] to [0, 1]. Using -1 for the sample axis
# instead of hard-coded counts keeps this working on subsets of the data.
n_timesteps, n_features = 28, 28
X_train = X_train.reshape(-1, n_timesteps, n_features).astype('float32') / 255
X_valid = X_valid.reshape(-1, n_timesteps, n_features).astype('float32') / 255

# One-hot encode the labels to match the categorical_crossentropy loss below.
n_classes = 10
y_train = to_categorical(y_train, n_classes)
y_valid = to_categorical(y_valid, n_classes)

# Define the RNN model. Declaring the input with an explicit Input layer
# (rather than passing input_shape= to the first layer) is the form Keras
# recommends for Sequential models and avoids its UserWarning about it.
model = Sequential([
    Input(shape=(n_timesteps, n_features)),
    SimpleRNN(128, activation='relu'),
    Dense(64, activation='relu'),
    Dense(n_classes, activation='softmax'),
])

# Compile the model.
# NOTE(review): a ReLU recurrence trained with SGD at lr=0.1 can be unstable
# (validation accuracy dipped at epochs 3-4 in the captured run); consider
# gradient clipping (e.g. SGD(..., clipnorm=1.0)) or a lower learning rate
# if training diverges.
model.compile(loss='categorical_crossentropy', optimizer=SGD(learning_rate=0.1), metrics=['accuracy'])

# Train the model, reporting validation metrics after every epoch.
history = model.fit(X_train, y_train, batch_size=128, epochs=5, verbose=1, validation_data=(X_valid, y_valid))

# Evaluate on the held-out split. With metrics=['accuracy'], evaluate()
# returns [loss, accuracy].
score = model.evaluate(X_valid, y_valid, verbose=0)
print(f"Test loss: {score[0]}")
print(f"Test accuracy: {score[1]}")


  super().__init__(**kwargs)


Epoch 1/5
469/469 ━━━━━━━━━━━━━━━━━━━━ 14s 26ms/step - accuracy: 0.3420 - loss: 1.8083 - val_accuracy: 0.7488 - val_loss: 0.7520
Epoch 2/5
469/469 ━━━━━━━━━━━━━━━━━━━━ 12s 26ms/step - accuracy: 0.6865 - loss: 0.9129 - val_accuracy: 0.9088 - val_loss: 0.3159
Epoch 3/5
469/469 ━━━━━━━━━━━━━━━━━━━━ 12s 26ms/step - accuracy: 0.8867 - loss: 0.3805 - val_accuracy: 0.8994 - val_loss: 0.3378
Epoch 4/5
469/469 ━━━━━━━━━━━━━━━━━━━━ 12s 26ms/step - accuracy: 0.9250 - loss: 0.2623 - val_accuracy: 0.8961 - val_loss: 0.3324
Epoch 5/5
469/469 ━━━━━━━━━━━━━━━━━━━━ 13s 27ms/step - accuracy: 0.9276 - loss: 0.2434 - val_accuracy: 0.9483 - val_loss: 0.1719
Test loss: 0.17194698750972748
Test accuracy: 0.9483000040054321
