In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

# Step 1: Load and Preprocess Data
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Append a trailing channel axis (the first Conv2D layer is declared with a
# 3-D per-image input shape) and divide by 255 to rescale the pixel values.
# Casting to float32 *before* the division keeps NumPy from promoting the
# integer image arrays to float64, which would double their memory footprint
# for no gain in training precision.
x_train = np.expand_dims(x_train, -1).astype('float32') / 255.0
x_test = np.expand_dims(x_test, -1).astype('float32') / 255.0

# One-hot encode the integer labels into 10 columns, as required by the
# categorical cross-entropy loss the models are compiled with.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Step 2: Define CNN Model
def create_model(input_shape=(28, 28, 1), num_classes=10):
    """Build and compile a small convolutional image classifier.

    The network stacks two Conv2D + MaxPooling2D stages (32 then 64 filters,
    3x3 kernels, 2x2 pooling), flattens the feature maps, and finishes with a
    128-unit ReLU layer followed by a softmax output layer. It is compiled
    with the Adam optimizer and categorical cross-entropy loss, so labels
    must be one-hot encoded; accuracy is the reported metric.

    Args:
        input_shape: Shape of a single input image, excluding the batch
            axis. Defaults to (28, 28, 1), the shape used elsewhere in
            this file.
        num_classes: Width of the softmax output layer, i.e. the number of
            classes to predict. Defaults to 10.

    Returns:
        A newly constructed, compiled Sequential model. Every call builds an
        independent model, so separate training runs do not share weights.
    """
    model = Sequential([
        Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        MaxPooling2D((2, 2)),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(num_classes, activation='softmax')  # One output unit per class
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# Step 3: Data Augmentation
# Random geometric transforms applied on the fly to each training batch.
datagen = ImageDataGenerator(
    rotation_range=20,  # Random rotation
    width_shift_range=0.2,  # Random width shift
    height_shift_range=0.2,  # Random height shift
    shear_range=0.2,  # Shearing transformation
    zoom_range=0.2,  # Random zoom
    # Handwritten digits are not mirror-symmetric: a flipped digit is not a
    # valid example of its class, so training on flipped images with the
    # original label injects label noise. Flipping is therefore disabled.
    horizontal_flip=False,
    fill_mode='nearest'  # Filling the newly created pixels
)

# NOTE: datagen.fit(x_train) is intentionally not called. Fitting only
# computes dataset statistics for featurewise_center,
# featurewise_std_normalization and zca_whitening; none of those options is
# enabled here, so the call would just copy the whole training set and
# discard the result.

# Step 4: Train Model with Augmented Data
# Training batches come from the augmenting generator; validation uses the
# test arrays directly, without passing them through the generator.
model_aug = create_model()
augmented_batches = datagen.flow(x_train, y_train, batch_size=64)
history_aug = model_aug.fit(
    augmented_batches,
    epochs=10,
    validation_data=(x_test, y_test),
)

# Step 5: Train Model on Original Data (Without Augmentation)
# A second, independent model is fit on the training arrays as they are,
# using the same epoch count, batch size and validation data as the
# augmented run so that the two histories can be compared.
model_no_aug = create_model()
baseline_fit_options = {
    'epochs': 10,
    'batch_size': 64,
    'validation_data': (x_test, y_test),
}
history_no_aug = model_no_aug.fit(x_train, y_train, **baseline_fit_options)

# Step 6: Plot Training and Validation Accuracy for Both Models
# Each entry describes one side-by-side panel: the training history to read,
# the legend labels for its two curves, and the panel title.
panels = [
    (
        history_aug,
        'Train Accuracy (Augmented)',
        'Validation Accuracy (Augmented)',
        'CNN with Data Augmentation',
    ),
    (
        history_no_aug,
        'Train Accuracy (Original)',
        'Validation Accuracy (Original)',
        'CNN without Data Augmentation',
    ),
]

plt.figure(figsize=(12, 6))
for column, (run, train_label, val_label, panel_title) in enumerate(panels, start=1):
    # Subplot positions are 1-based: the augmented run goes on the left,
    # the baseline run on the right.
    plt.subplot(1, 2, column)
    plt.plot(run.history['accuracy'], label=train_label)
    plt.plot(run.history['val_accuracy'], label=val_label)
    plt.title(panel_title)
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

plt.tight_layout()
plt.show()