In [2]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
import numpy as np

# Fix random seeds up front so weight initialisation, fit-time shuffling and the
# random image picks further down are repeatable across Restart & Run All.
# (Some GPU kernels may still introduce small run-to-run differences.)
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values to [0, 1]. Cast to float32 first: dividing the raw integer
# arrays by 255.0 would otherwise yield float64, doubling memory for no benefit
# (Keras trains in float32 anyway).
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Add a trailing channel axis so each image is (28, 28, 1), matching the
# model's declared input shape.
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

# One-hot encode the labels; required by the 'categorical_crossentropy' loss
# used when compiling the model.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

# Sanity-check the preprocessing invariants before training.
assert x_train.shape[1:] == (28, 28, 1) and x_test.shape[1:] == (28, 28, 1)
assert y_train.shape == (len(x_train), 10) and y_test.shape == (len(x_test), 10)
assert 0.0 <= x_train.min() and x_train.max() <= 1.0, "pixels not scaled to [0, 1]"

# Define the model architecture: two Conv2D/MaxPooling2D stages, then a dense
# head with dropout ending in a 10-way softmax (one output per digit class).
model = Sequential([
    # Declare the input explicitly rather than passing `input_shape=` to the
    # first layer; that keyword is deprecated and warns under Keras 3.
    tf.keras.Input(shape=(28, 28, 1)),
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.2),  # drops 20% of dense activations; active only during training
    Dense(10, activation='softmax'),
])

# Compile the model. 'categorical_crossentropy' expects one-hot targets, which
# is what to_categorical produced for y_train / y_test.
model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model. Hold out a slice of the *training* data for validation
# instead of passing the test set as validation_data. Otherwise any epoch or
# hyperparameter choice made from the validation curves leaks test information,
# and the final "test accuracy" becomes optimistic. Keras takes the
# validation_split slice from the end of the arrays, before shuffling.
EPOCHS = 10
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.1
history = model.fit(
    x_train,
    y_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)

# Evaluate once on the held-out test set, which was not used during training.
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_acc * 100:.2f}%')

# Predict the label for a random handwritten test image of a user-chosen digit.
# This loop blocks on input(). If no stdin is available (EOF, or Jupyter's
# StdinNotImplementedError under a non-interactive runner such as nbconvert,
# which subclasses NotImplementedError), the loop exits instead of failing.
PROMPT = "Enter a digit between 0 and 9 to predict a random handwritten image (Enter -1 to exit): "
INVALID_MSG = "Invalid input. Please enter an integer between 0 and 9."

while True:
    try:
        raw_entry = input(PROMPT)
    except (EOFError, NotImplementedError):
        break

    # Non-integer entries ("", "abc", "3.5") used to raise ValueError and abort
    # the cell; treat them like any other invalid entry and re-prompt.
    try:
        digit = int(raw_entry)
    except ValueError:
        print(INVALID_MSG)
        continue

    if digit == -1:
        break
    if not 0 <= digit <= 9:
        print(INVALID_MSG)
        continue

    # Indices of test images whose one-hot label has a 1 in column `digit`.
    image_indices = np.where(y_test[:, digit] == 1)[0]
    if len(image_indices) == 0:
        print(f"No images found for digit {digit} in the test set.")
        continue

    # Select a random image index and add a batch axis for model.predict.
    random_index = np.random.choice(image_indices)
    sample_image = x_test[random_index].reshape(1, 28, 28, 1)
    prediction = model.predict(sample_image, verbose=0)  # no progress bar for one image
    predicted_label = int(np.argmax(prediction))
    print(f"Predicted label for image {random_index}: {predicted_label}")

    # Display the image with its predicted and true labels so mistakes are visible.
    fig, ax = plt.subplots()
    ax.imshow(sample_image.squeeze(), cmap='gray')
    ax.set_title(f'Predicted label: {predicted_label} (true label: {digit})')
    ax.axis('off')
    plt.show()

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Test accuracy: 99.29%
Enter a digit between 0 and 9 to predict a random handwritten image (Enter -1 to exit): -1
