# In [None]:
# Experiment-1: Design a Custom Convolutional Neural Network (CNN) for Handwritten Digit Classification
# Import necessary libraries
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

# Load MNIST, add an explicit single-channel axis, and divide by 255 so the
# uint8 pixel values become float32 values in [0, 1].
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = (
    images.reshape(-1, 28, 28, 1).astype('float32') / 255
    for images in (x_train, x_test)
)

# Data augmentation: random rotations (up to 10 degrees) and random horizontal /
# vertical shifts (up to 10% of the image size).
data_gen = ImageDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
)
# Per the Keras docs, fit() is only needed for featurewise normalization or ZCA
# whitening, neither of which is enabled above; kept for parity with the original.
data_gen.fit(x_train)

# Define the CNN model
def create_cnn_model(input_shape=(28, 28, 1), num_classes=10):
    """Build (but do not compile) the custom CNN used in this experiment.

    Architecture: two [Conv2D(3x3, ReLU) -> MaxPooling2D(3x3, stride 1)] stages,
    then Flatten -> Dense(64, ReLU) -> Dense(num_classes, softmax).

    Args:
        input_shape: Shape of one input image as (height, width, channels).
            Defaults to MNIST's (28, 28, 1).
        num_classes: Number of output classes. Defaults to 10 (digits 0-9).

    Returns:
        An uncompiled Keras ``Sequential`` model; the caller chooses the
        optimizer, loss and metrics via ``compile()``.
    """
    model = models.Sequential([
        # An explicit Input layer replaces the ``input_shape=`` keyword on the
        # first Conv2D; Keras 3 warns when input_shape is passed to a layer.
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu'),
        # NOTE: 3x3 pooling with stride 1 does not downsample. For a 28x28
        # input the maps shrink only 28->26->24->22->20, so Flatten feeds
        # 20*20*64 = 25,600 features into the Dense(64) layer.
        layers.MaxPooling2D((3, 3), strides=(1, 1)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((3, 3), strides=(1, 1)),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ])
    return model

# Train two identically built CNNs: one on the raw training images, one on
# randomly augmented batches drawn from data_gen. Both report validation
# metrics on the test set after each epoch.
def _build_compiled_model():
    """Return a fresh create_cnn_model() compiled with Adam and sparse categorical cross-entropy."""
    net = create_cnn_model()
    net.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return net


original_model = _build_compiled_model()
original_model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

augmented_model = _build_compiled_model()
augmented_model.fit(data_gen.flow(x_train, y_train), epochs=5, validation_data=(x_test, y_test))

# Evaluate both models on the test set. evaluate() returns [loss, accuracy]
# because each model was compiled with metrics=['accuracy'], so index 1 is accuracy.
original_accuracy = original_model.evaluate(x_test, y_test, verbose=0)[1]
augmented_accuracy = augmented_model.evaluate(x_test, y_test, verbose=0)[1]
print(f"Original test accuracy:  {original_accuracy:.4f}")
print(f"Augmented test accuracy: {augmented_accuracy:.4f}")

# Plot the comparison. When the two accuracies are close, bar heights alone
# hide the difference, so each bar is labelled with its exact value.
labels = ['Original', 'Augmented']
accuracies = [original_accuracy, augmented_accuracy]
bars = plt.bar(labels, accuracies)
for bar, acc in zip(bars, accuracies):
    plt.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height(),
        f"{acc:.4f}",
        ha='center',
        va='bottom',
    )
# Accuracy lies in [0, 1]; the extra headroom keeps the value labels visible.
plt.ylim(0, 1.05)
plt.title('Accuracy Comparison')
plt.ylabel('Accuracy')
plt.show()