# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, applications
import matplotlib.pyplot as plt
import numpy as np

# Input resolution shared by every model in this script. 224x224 is the
# default ImageNet input size of the Keras VGG16 and ResNet50 applications.
img_height = 224
img_width = 224

# 1. Load the MNIST dataset
def _resize_in_batches(images, size, batch_size):
    """Resize `images` chunk by chunk into one preallocated float32 array.

    Uses `tf.image.resize` with its default interpolation, exactly like a
    single whole-array call would, but only `batch_size` images are held as
    a TensorFlow tensor at any time.
    """
    height, width = size
    resized = np.empty((len(images), height, width, images.shape[-1]), dtype=np.float32)
    for start in range(0, len(images), batch_size):
        stop = start + batch_size
        resized[start:stop] = tf.image.resize(images[start:stop], [height, width]).numpy()
    return resized


def load_mnist_data(*, batch_size=1024):
    """Load MNIST, scale pixels to [0, 1] and resize to the model input size.

    Args:
        batch_size: Number of images resized per `tf.image.resize` call.
            Resizing the whole split at once creates a single tensor of
            roughly 12 GB for the 60,000 training images at 224x224, which
            can exhaust accelerator memory; chunking keeps the transient
            tensor small while producing the same pixels.

    Returns:
        `((train_images, train_labels), (test_images, test_labels))` where
        the image arrays are float32 NumPy arrays of shape
        `(N, img_height, img_width, 1)` and the labels are the integer
        digit classes as returned by `tf.keras.datasets.mnist.load_data()`.

    Raises:
        ValueError: If `batch_size` is smaller than 1.

    Note:
        The returned arrays are still fully materialised in host memory;
        only the peak size of the resize operation is bounded here.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

    # Add the trailing channel axis and normalize to [0, 1]
    train_images = np.expand_dims(train_images, -1).astype("float32") / 255.0
    test_images = np.expand_dims(test_images, -1).astype("float32") / 255.0

    # Resize images to fit the model's input size (img_height x img_width)
    target_size = (img_height, img_width)
    train_images = _resize_in_batches(train_images, target_size, batch_size)
    test_images = _resize_in_batches(test_images, target_size, batch_size)

    return (train_images, train_labels), (test_images, test_labels)

# 2. Build the simple CNN model (training happens where it is called)
def build_simple_cnn():
    """Return a compiled baseline CNN for 10-way digit classification.

    The network is three Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages with
    32, 64 and 128 filters, followed by Flatten, a 128-unit ReLU dense layer
    and a 10-unit softmax output. It expects single-channel inputs of size
    `img_height` x `img_width` and is compiled with Adam, sparse categorical
    cross-entropy and the accuracy metric.
    """
    cnn = models.Sequential()

    # The first stage also declares the input shape of the whole network.
    cnn.add(layers.Conv2D(32, (3, 3), activation='relu',
                          input_shape=(img_height, img_width, 1)))
    cnn.add(layers.MaxPooling2D((2, 2)))

    # Remaining convolution stages only differ in their filter count.
    for filters in (64, 128):
        cnn.add(layers.Conv2D(filters, (3, 3), activation='relu'))
        cnn.add(layers.MaxPooling2D((2, 2)))

    # Classification head: 10 classes for digits 0-9.
    cnn.add(layers.Flatten())
    cnn.add(layers.Dense(128, activation='relu'))
    cnn.add(layers.Dense(10, activation='softmax'))

    cnn.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return cnn

# 3. Build the "AlexNet" stand-in (a VGG16 backbone), optionally ImageNet-initialised
def build_alexnet(pretrained=False):
    """Build and compile a VGG16-based digit classifier for grayscale input.

    The function name is kept for its callers, but the backbone is
    `applications.VGG16`, not AlexNet.

    Args:
        pretrained: When False, the backbone is randomly initialised and
            built directly for 1-channel input. When True, ImageNet weights
            are loaded. Keras applications only accept ImageNet weights for
            3-channel inputs (a 1-channel `input_shape` raises ValueError),
            so a frozen 1x1 convolution with an all-ones kernel replicates
            the grey channel three times in front of the backbone.

    Returns:
        A compiled `Sequential` model taking `(img_height, img_width, 1)`
        images and producing 10 softmax probabilities.
    """
    if pretrained:
        backbone = applications.VGG16(weights='imagenet', include_top=False,
                                      input_shape=(img_height, img_width, 3))
        stem = [
            layers.Input(shape=(img_height, img_width, 1)),
            # All-ones 1x1 kernel without bias: each of the 3 output
            # channels is an exact copy of the single input channel.
            layers.Conv2D(3, (1, 1), use_bias=False, kernel_initializer='ones',
                          trainable=False),
        ]
        # NOTE(review): the ImageNet weights were trained on inputs passed
        # through applications.vgg16.preprocess_input, while the loader in
        # this file yields values in [0, 1] -- confirm this is acceptable.
    else:
        backbone = applications.VGG16(weights=None, include_top=False,
                                      input_shape=(img_height, img_width, 1))
        stem = []

    model = models.Sequential(stem + [
        backbone,
        layers.GlobalAveragePooling2D(),
        layers.Dense(1024, activation='relu'),
        layers.Dense(10, activation='softmax')  # 10 classes for digits 0-9
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

# 4. Build the ResNet50 classifier, optionally ImageNet-initialised
def build_resnet(pretrained=False):
    """Build and compile a ResNet50-based digit classifier for grayscale input.

    Args:
        pretrained: When False, the backbone is randomly initialised and
            built directly for 1-channel input. When True, ImageNet weights
            are loaded. Keras applications only accept ImageNet weights for
            3-channel inputs (a 1-channel `input_shape` raises ValueError),
            so a frozen 1x1 convolution with an all-ones kernel replicates
            the grey channel three times in front of the backbone.

    Returns:
        A compiled `Sequential` model taking `(img_height, img_width, 1)`
        images and producing 10 softmax probabilities.
    """
    if pretrained:
        backbone = applications.ResNet50(weights='imagenet', include_top=False,
                                         input_shape=(img_height, img_width, 3))
        stem = [
            layers.Input(shape=(img_height, img_width, 1)),
            # All-ones 1x1 kernel without bias: each of the 3 output
            # channels is an exact copy of the single input channel.
            layers.Conv2D(3, (1, 1), use_bias=False, kernel_initializer='ones',
                          trainable=False),
        ]
        # NOTE(review): the ImageNet weights were trained on inputs passed
        # through applications.resnet50.preprocess_input, while the loader in
        # this file yields values in [0, 1] -- confirm this is acceptable.
    else:
        backbone = applications.ResNet50(weights=None, include_top=False,
                                         input_shape=(img_height, img_width, 1))
        stem = []

    model = models.Sequential(stem + [
        backbone,
        layers.GlobalAveragePooling2D(),
        layers.Dense(1024, activation='relu'),
        layers.Dense(10, activation='softmax')  # 10 classes for digits 0-9
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

# 5. Fine-tuning: unfreeze a model and recompile it with a small learning rate
def fine_tune_model(base_model, base_learning_rate=1e-5):
    """Prepare a model for fine-tuning with all of its layers trainable.

    Args:
        base_model: Either a convolutional feature extractor or a complete
            classifier such as the ones returned by the `build_*` functions
            in this file.
        base_learning_rate: Learning rate for the Adam optimizer; kept small
            so pre-trained weights are only nudged.

    Returns:
        A compiled model. If `base_model` already produces a rank-2
        `(batch, classes)` output it is recompiled and returned as is;
        stacking `GlobalAveragePooling2D` on such an output would fail and
        would add a second classification head. Otherwise a new `Sequential`
        with the pooling + dense head on top of `base_model` is returned.
    """
    for layer in base_model.layers:
        layer.trainable = True  # Unfreeze all layers

    # An unbuilt model raises AttributeError for `output_shape`; fall back to
    # None so the original "add a head" path is used in that case.
    output_shape = getattr(base_model, 'output_shape', None)
    already_has_head = isinstance(output_shape, tuple) and len(output_shape) == 2

    if already_has_head:
        model = base_model
    else:
        model = models.Sequential([
            base_model,
            layers.GlobalAveragePooling2D(),
            layers.Dense(1024, activation='relu'),
            layers.Dense(10, activation='softmax')  # 10 classes for digits 0-9
        ])

    # (Re)compile so the trainable flags and the new learning rate take effect.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# Fetch the preprocessed MNIST splits once; every experiment below reuses them.
(train_images, train_labels), (test_images, test_labels) = load_mnist_data()

# All runs share the same schedule: 10 epochs, validated on the test split.
fit_options = dict(epochs=10, validation_data=(test_images, test_labels))

# 1. Baseline: the small hand-written CNN
simple_cnn_model = build_simple_cnn()
history_simple_cnn = simple_cnn_model.fit(train_images, train_labels, **fit_options)

# 2. build_alexnet variant trained from random initialisation
alexnet_model = build_alexnet(pretrained=False)
history_alexnet = alexnet_model.fit(train_images, train_labels, **fit_options)

# 3. build_resnet variant trained from random initialisation
resnet_model = build_resnet(pretrained=False)
history_resnet = resnet_model.fit(train_images, train_labels, **fit_options)

# 4. build_alexnet variant with ImageNet weights, passed through fine_tune_model
alexnet_model_ft = fine_tune_model(build_alexnet(pretrained=True))
history_alexnet_ft = alexnet_model_ft.fit(train_images, train_labels, **fit_options)

# 5. build_resnet variant with ImageNet weights, passed through fine_tune_model
resnet_model_ft = fine_tune_model(build_resnet(pretrained=True))
history_resnet_ft = resnet_model_ft.fit(train_images, train_labels, **fit_options)

# 6. Compare the results: helper that draws one model's accuracy curves
def plot_history(history, model_name):
    """Plot training and validation accuracy of one run on the current axes.

    Args:
        history: Object whose `.history` mapping holds the per-epoch
            'accuracy' and 'val_accuracy' sequences (as returned by
            `model.fit` in this file).
        model_name: Text used in the legend labels and in the axes title.

    The axes title, axis labels and legend are (re)set on every call.
    """
    per_epoch = history.history

    training_curve = per_epoch['accuracy']
    plt.plot(training_curve, label=f'{model_name} Training Accuracy')

    validation_curve = per_epoch['val_accuracy']
    plt.plot(validation_curve, label=f'{model_name} Validation Accuracy')

    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title(f'{model_name} Accuracy')
    plt.legend()

# Plot comparison: all accuracy curves share one figure
plt.figure(figsize=(12, 8))
plot_history(history_simple_cnn, 'Simple CNN')
plot_history(history_alexnet, 'AlexNet (No Fine-tune)')
plot_history(history_resnet, 'ResNet (No Fine-tune)')
plot_history(history_alexnet_ft, 'AlexNet (Fine-tune)')
plot_history(history_resnet_ft, 'ResNet (Fine-tune)')
# plot_history sets the title on every call, so without this line the shared
# figure would be titled after the last model only.
plt.title('Model Accuracy Comparison')
plt.show()

# Evaluate the models on the test set; evaluate() returns the loss followed
# by the compiled metrics (accuracy here).
for label, model in (
    ('Simple CNN', simple_cnn_model),
    ('AlexNet (No Fine-tune)', alexnet_model),
    ('ResNet (No Fine-tune)', resnet_model),
    ('AlexNet (Fine-tune)', alexnet_model_ft),
    ('ResNet (Fine-tune)', resnet_model_ft),
):
    print(f"Evaluating {label}:")
    print(model.evaluate(test_images, test_labels))