# In [23]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG16
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt

# ============================
# 1. Train CNN from Scratch
# ============================
def train_cnn_from_scratch(epochs=5):
    """Train a small CNN on CIFAR-10 from random weights and report the results.

    Loads CIFAR-10, scales the pixels to [0, 1], one-hot encodes the labels,
    fits a three-convolution network, prints the test accuracy and plots the
    training curves with ``plot_training_history``.

    Args:
        epochs: Number of passes over the training set. Defaults to 5, the
            value that was previously hard-coded.
    """
    num_classes = 10  # 10 classes for CIFAR-10

    # Load and preprocess CIFAR-10 dataset
    (train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

    # Cast to float32 *before* scaling: dividing the raw integer arrays by a
    # Python float would materialise float64 copies of the whole dataset.
    train_images = train_images.astype("float32") / 255.0
    test_images = test_images.astype("float32") / 255.0

    # One-hot encode; passing num_classes keeps the label width tied to the
    # size of the softmax layer below instead of being inferred from the data.
    train_labels = to_categorical(train_labels, num_classes)
    test_labels = to_categorical(test_labels, num_classes)

    # Define a CNN model. An explicit Input layer replaces the `input_shape=`
    # argument on the first Conv2D, which newer Keras releases warn about
    # inside Sequential models.
    model = models.Sequential([
        layers.Input(shape=(32, 32, 3)),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ])

    # Compile the model
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    model.summary()

    # Train the model.
    # NOTE(review): the test split doubles as validation data here, so the
    # validation curves and the final test score describe the same images.
    history = model.fit(
        train_images,
        train_labels,
        epochs=epochs,
        validation_data=(test_images, test_labels),
        verbose=1,
    )

    # Evaluate the model
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print(f"CNN from Scratch - Test Accuracy: {test_acc:.4f}")

    # Plot training results
    plot_training_history(history, title="CNN from Scratch")


# ============================
# 2. Use Pretrained VGG16 Model
# ============================
def train_pretrained_vgg16(epochs=5):
    """Train a classifier head on a frozen ImageNet VGG16 base using CIFAR-10.

    Loads CIFAR-10, applies the VGG16 input preprocessing, one-hot encodes the
    labels, stacks a small dense head on the frozen convolutional base, fits
    it, prints the test accuracy and plots the training curves with
    ``plot_training_history``.

    Args:
        epochs: Number of passes over the training set. Defaults to 5, the
            value that was previously hard-coded.
    """
    num_classes = 10  # 10 classes for CIFAR-10

    # Load and preprocess CIFAR-10 dataset
    (train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

    # The frozen ImageNet weights were trained on inputs produced by
    # vgg16.preprocess_input, not on [0, 1]-scaled pixels, so apply the same
    # transform here. astype() hands preprocess_input a fresh float32 copy.
    preprocess_input = tf.keras.applications.vgg16.preprocess_input
    train_images = preprocess_input(train_images.astype("float32"))
    test_images = preprocess_input(test_images.astype("float32"))

    # One-hot encode; passing num_classes keeps the label width tied to the
    # size of the softmax layer below instead of being inferred from the data.
    train_labels = to_categorical(train_labels, num_classes)
    test_labels = to_categorical(test_labels, num_classes)

    # Load VGG16 pretrained model (without fully connected layers)
    vgg_base = VGG16(weights='imagenet', include_top=False, input_shape=(32, 32, 3))
    vgg_base.trainable = False  # Freeze the pretrained layers

    # Define a new model with VGG16 as the base
    model = models.Sequential([
        vgg_base,
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ])

    # Compile the model
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # Print the architecture before training, matching train_cnn_from_scratch.
    model.summary()

    # Train the model.
    # NOTE(review): the test split doubles as validation data here, so the
    # validation curves and the final test score describe the same images.
    history = model.fit(
        train_images,
        train_labels,
        epochs=epochs,
        validation_data=(test_images, test_labels),
        verbose=1,
    )

    # Evaluate the model
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print(f"Pretrained VGG16 - Test Accuracy: {test_acc:.4f}")

    # Plot training results
    plot_training_history(history, title="Pretrained VGG16")


# ============================
# Utility: Plot Training History
# ============================
def plot_training_history(history, title="Training History"):
    """Show train/validation accuracy and loss curves side by side.

    Args:
        history: Object whose ``history`` attribute maps the keys
            ``'accuracy'``, ``'val_accuracy'``, ``'loss'`` and ``'val_loss'``
            to per-epoch value sequences.
        title: Prefix for each subplot title.
    """
    curves = history.history

    # One row per subplot:
    # (position, train key, validation key, train label, validation label, metric name)
    panels = (
        (1, 'accuracy', 'val_accuracy', 'Train Accuracy', 'Validation Accuracy', 'Accuracy'),
        (2, 'loss', 'val_loss', 'Train Loss', 'Validation Loss', 'Loss'),
    )

    plt.figure(figsize=(12, 5))
    for position, train_key, val_key, train_label, val_label, metric in panels:
        plt.subplot(1, 2, position)
        plt.plot(curves[train_key], label=train_label)
        plt.plot(curves[val_key], label=val_label)
        plt.title(f"{title} - {metric}")
        plt.xlabel('Epochs')
        plt.ylabel(metric)
        plt.legend()

    plt.show()


# Call both functions, but only when this file is executed directly (as a
# script or a notebook cell), so that importing it does not start two
# training runs as a side effect.
if __name__ == "__main__":
    train_cnn_from_scratch()
    train_pretrained_vgg16()


# Captured cell output (the run was interrupted manually during epoch 1):
#
# Epoch 1/5
#  415/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 9ms/step - accuracy: 0.2357 - loss: 2.0320
#
# KeyboardInterrupt