In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG16
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import matplotlib.pyplot as plt

In [None]:

# Load MNIST: grayscale 28x28 digit images with integer labels 0-9.
(input_train, target_train), (input_test, target_test) = mnist.load_data()

# Add a channel axis to the grayscale images: (N, 28, 28) -> (N, 28, 28, 1).
input_train = input_train.reshape(-1, 28, 28, 1)
input_test = input_test.reshape(-1, 28, 28, 1)

# Resize spatially from 28x28 to 32x32 (the smallest input VGG16 accepts without
# its top layers). tf.image.resize only changes height/width, so the result is
# still single-channel: (N, 32, 32, 1). The grayscale -> 3-channel conversion is
# done inside the model, not here. Resizing the whole batch in one call replaces
# a Python loop over every individual image and yields the same values.
input_train_resized = tf.image.resize(input_train, (32, 32)).numpy()
input_test_resized = tf.image.resize(input_test, (32, 32)).numpy()

# Scale pixel values from [0, 255] to [0, 1].
input_train_resized = input_train_resized.astype('float32') / 255
input_test_resized = input_test_resized.astype('float32') / 255

# Sanity-check the shapes the model's Input layer expects.
assert input_train_resized.shape[1:] == (32, 32, 1), input_train_resized.shape
assert input_test_resized.shape[1:] == (32, 32, 1), input_test_resized.shape

In [None]:
# Convert integer labels to one-hot vectors, as required by the
# 'categorical_crossentropy' loss the models are compiled with.
# Guarded on ndim so that re-running this cell does not one-hot encode labels
# that are already one-hot (to_categorical would expand (N, 10) into a 3-D array).
if target_train.ndim == 1:
    target_train = to_categorical(target_train, 10)
if target_test.ndim == 1:
    target_test = to_categorical(target_test, 10)

# Data augmentation: random geometric transforms applied on the fly when
# batches are drawn with datagen.flow(...). No datagen.fit(...) is needed:
# fit only computes dataset statistics for featurewise_center,
# featurewise_std_normalization and zca_whitening, none of which are enabled.
datagen = ImageDataGenerator(
    rotation_range=10,       # random rotation, in degrees
    width_shift_range=0.1,   # horizontal shift, as a fraction of image width
    height_shift_range=0.1,  # vertical shift, as a fraction of image height
    shear_range=0.1,         # NOTE(review): in degrees, so 0.1 is a negligible shear — confirm intent
    zoom_range=0.1,          # random zoom in [0.9, 1.1]
)

In [None]:
# Define a function to create the customized VGG16 model
def create_model(input_shape=(32, 32, 1), num_classes=10):
    """Build and compile a VGG16-based classifier for small grayscale images.

    Architecture: a trainable 3x3 Conv2D that maps the input channels to the
    3 channels VGG16 expects -> ImageNet-pretrained VGG16 convolutional base
    (include_top=False, frozen) -> Flatten -> Dense(512, relu) ->
    Dense(num_classes, softmax).

    Args:
        input_shape: (height, width, channels) of the input images. Defaults to
            (32, 32, 1), the shape produced by the preprocessing cell. Height and
            width must be at least 32 for VGG16 without its top layers.
        num_classes: Number of output classes. Defaults to 10 (MNIST digits).

    Returns:
        A compiled Sequential model (Adam optimizer, categorical cross-entropy
        loss, accuracy metric), so training labels must be one-hot encoded.
    """
    height, width = input_shape[0], input_shape[1]

    # Pretrained convolutional base only: include_top=False drops VGG16's
    # fully-connected classifier, so this model contains conv/pooling layers only.
    vgg16 = VGG16(weights='imagenet', include_top=False, input_shape=(height, width, 3))

    # Freeze the entire base (every conv layer) so only the adapter conv and the
    # new dense head are trained. Setting trainable on the model applies to all
    # of its inner layers.
    vgg16.trainable = False

    model = models.Sequential([
        layers.Input(shape=input_shape),
        # Learnable grayscale -> 3-channel adapter so the input matches VGG16.
        layers.Conv2D(3, (3, 3), padding='same'),
        vgg16,
        layers.Flatten(),
        layers.Dense(512, activation='relu'),             # New FC layer
        layers.Dense(num_classes, activation='softmax'),  # Output layer
    ])

    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model


In [None]:

# Reproducibility: seed Python, NumPy and TensorFlow RNGs (weight init,
# shuffling, augmentation). Some GPU kernels may still be nondeterministic.
tf.keras.utils.set_random_seed(42)

# Shared hyperparameters so the two runs differ only in their training data.
EPOCHS = 5
BATCH_SIZE = 64

# NOTE(review): validation_data is the test set, so the "validation" curves
# plotted below are really test-set curves. That is fine for monitoring, but
# hold out a separate split before using them for any model selection.

# Baseline: train on the un-augmented training images.
model_original = create_model()
history_original = model_original.fit(
    input_train_resized, target_train,
    epochs=EPOCHS, batch_size=BATCH_SIZE,
    validation_data=(input_test_resized, target_test),
)

# Augmented: same architecture, trained on randomly transformed batches.
train_generator = datagen.flow(input_train_resized, target_train, batch_size=BATCH_SIZE)
model_augmented = create_model()
history_augmented = model_augmented.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=(input_test_resized, target_test),
)

# Evaluate BOTH models on the same clean test set. Augmentation is a
# training-time regularizer; scoring one model on randomly augmented test images
# (datagen.flow) would make the scores incomparable and differ from run to run.
score_original = model_original.evaluate(input_test_resized, target_test)
score_augmented = model_augmented.evaluate(input_test_resized, target_test)

# evaluate() returns [loss, accuracy] for the compiled loss and metric.
print(f"Original Model - Test loss: {score_original[0]}, Test accuracy: {score_original[1]}")
print(f"Augmented Model - Test loss: {score_augmented[0]}, Test accuracy: {score_augmented[1]}")

# Plot per-epoch train/validation accuracy of both models on one axis.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(history_original.history['accuracy'], label='Original Train Accuracy', color='blue')
ax.plot(history_original.history['val_accuracy'], label='Original Validation Accuracy', color='cyan')
ax.plot(history_augmented.history['accuracy'], label='Augmented Train Accuracy', color='orange')
ax.plot(history_augmented.history['val_accuracy'], label='Augmented Validation Accuracy', color='red')
ax.set(title='Model Accuracy Comparison', xlabel='Epoch', ylabel='Accuracy')
ax.legend()
plt.show()

# Display model summary and number of parameters
print("Original Model Summary:")
model_original.summary()
print("\nAugmented Model Summary:")
model_augmented.summary()
