# In [None]:
# Import the necessary libraries
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import fashion_mnist
import matplotlib.pyplot as plt

# Load the Fashion MNIST dataset: 28x28 images with integer class labels.
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Scale pixel values from [0, 255] to [0, 1]. Dividing the raw uint8 arrays by
# the Python float 255.0 would promote them to float64; casting to float32
# first (the dtype Keras layers compute in by default) halves their memory
# footprint without changing what the model sees.
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

# Append a trailing channel axis, (N, 28, 28) -> (N, 28, 28, 1), because the
# Conv2D layers expect a channel dimension on their input.
train_images = train_images[..., np.newaxis]
test_images = test_images[..., np.newaxis]

# Human-readable class names, indexed by the integer label values.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Build the CNN: three Conv2D + BatchNormalization stages (the first two are
# followed by 2x2 max pooling), then a dropout-regularised dense head that
# outputs a softmax distribution over the 10 classes.
# The input shape is declared with an explicit Input layer instead of the
# legacy `input_shape=` argument on Conv2D, which Keras 3 warns against.
model = models.Sequential([
    layers.Input(shape=(28, 28, 1)),  # one single-channel 28x28 image
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.BatchNormalization(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),  # only active during training
    layers.Dense(10, activation='softmax'),
])

# The labels are integer class indices, hence the *sparse* categorical
# cross-entropy; the final softmax means the loss receives probabilities
# rather than logits.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Halt training once validation loss has failed to improve for 3 epochs in a
# row, keeping the weights from the best epoch seen.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

# Train for at most 50 epochs. Keras holds out the last 20% of the training
# arrays as validation data, which is what 'val_loss' above is measured on.
model.fit(
    train_images,
    train_labels,
    batch_size=64,
    epochs=50,
    validation_split=0.2,
    callbacks=[early_stopping],
)

# Score the trained model on the separate test split.
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f'Test accuracy: {test_acc:.3f}')

# Softmax outputs (one probability vector per image) for the first two test images.
predictions = model.predict(test_images[0:2])

# Function to plot the predictions
def plot_predictions(images, labels, predictions, *, names=None, max_cols=5):
    """Show each image titled with its predicted and true class names.

    Images are laid out in a grid of at most ``max_cols`` columns, with extra
    images wrapping onto further rows. Two images therefore give the same
    1x2, 8x4-inch figure as before, and more than two no longer fail.

    Args:
        images: Sequence of images. Each is squeezed before display, so a
            trailing channel axis of size 1 is accepted. A flattened 1-D
            image is reshaped to 28x28.
        labels: True integer class index for each image (indexed by position).
        predictions: Per-image score vectors. The predicted class is the
            index of the largest score (``np.argmax``).
        names: Lookup from class index to display name. Defaults to the
            module-level ``class_names`` when omitted.
        max_cols: Maximum number of subplots per row.
    """
    if names is None:
        names = class_names

    count = len(images)
    if count == 0:
        return  # nothing to draw; don't open an empty figure

    ncols = min(count, max_cols)
    nrows = (count + ncols - 1) // ncols  # ceiling division
    # 4 inches per subplot in each direction.
    plt.figure(figsize=(4 * ncols, 4 * nrows))

    for idx, image in enumerate(images):
        plt.subplot(nrows, ncols, idx + 1)  # subplot positions are 1-based
        pixels = np.squeeze(image)
        if pixels.ndim == 1:
            # Flattened input: fall back to the original 28x28 layout.
            pixels = pixels.reshape(28, 28)
        plt.imshow(pixels, cmap=plt.cm.binary)
        predicted = names[np.argmax(predictions[idx])]
        actual = names[labels[idx]]
        plt.title(f"Predicted: {predicted} (True: {actual})")
        plt.axis('off')
    plt.show()

# Show the first two test images next to the model's predictions for them.
plot_predictions(
    images=test_images[:2],
    labels=test_labels[:2],
    predictions=predictions,
)