image-classification 

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
import matplotlib.pyplot as plt

# 1. Load the CIFAR-10 train and test splits
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Scale pixel values from [0, 255] into [0, 1].
# Cast to float32 *before* dividing: dividing the raw integer arrays by a Python
# float yields float64, which doubles memory use while Keras (default floatx
# float32) would only cast it back down during training.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# 2. Define a Simple CNN Model
def create_model(input_shape=(32, 32, 3), num_classes=10):
    """Build a small, uncompiled convolutional classifier.

    Architecture: three 3x3 ReLU convolutions (32, 64, 64 filters) with 2x2
    max-pooling after the first two, then Flatten, a 64-unit ReLU dense layer,
    and a linear output layer.

    Args:
        input_shape: Shape of a single input image. The default (32, 32, 3)
            matches CIFAR-10 images.
        num_classes: Number of output units, one per class.

    Returns:
        A ``models.Sequential`` model. Its last layer has no activation, so it
        outputs raw logits; pair it with a loss configured ``from_logits=True``.
    """
    return models.Sequential([
        # Declare the input explicitly instead of passing input_shape= to the
        # first Conv2D, which newer Keras versions warn against.
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        # No activation: the model emits logits (softmax is folded into the loss)
        layers.Dense(num_classes),
    ])

# 3. Compile and Train Function for Experimentation with Different Optimizers
def compile_and_train(model, optimizer, epochs=10, batch_size=64, validation_data=None):
    """Compile ``model`` with ``optimizer`` and fit it on the training split.

    Args:
        model: Keras model whose output is raw logits, as built by
            ``create_model`` (its final Dense layer has no activation).
        optimizer: Any value ``model.compile`` accepts as an optimizer, e.g. a
            name such as ``'adam'`` or an optimizer instance.
        epochs: Number of passes over ``x_train``.
        batch_size: Number of samples per gradient update.
        validation_data: ``(x, y)`` pair evaluated after every epoch. Defaults
            to the module-level ``(x_test, y_test)``.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # NOTE(review): the default validates on the test split, so the "val"
    # curves are really test curves and the test set is no longer a clean
    # held-out estimate. Pass a slice of the training data to avoid this.
    if validation_data is None:
        validation_data = (x_test, y_test)

    model.compile(optimizer=optimizer,
                  # from_logits=True because the model outputs unnormalized logits
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    # Train the model
    return model.fit(x_train, y_train,
                     epochs=epochs,
                     batch_size=batch_size,
                     validation_data=validation_data)

# 4. Experiment: train the same architecture once per optimizer and keep each
#    run's training history, keyed by optimizer name.
optimizers = ['sgd', 'adam', 'rmsprop', 'adagrad']
histories = {}

for opt in optimizers:
    print(f"Training with {opt} optimizer:")
    # Build a brand-new model every time so no run inherits weights from
    # a previous one.
    model = create_model()
    histories[opt] = history = compile_and_train(model, opt)

# 5. Plot the results: accuracy in the left panel, loss in the right panel.
# One row of two panels, since only two plots are drawn.
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

for opt in optimizers:
    history = histories[opt]
    # Show 1-based epoch numbers on the x-axis rather than list indices
    epoch_nums = range(1, len(history.history['loss']) + 1)

    # Each optimizer gets one color: solid line for training, dashed line
    # (same color) for validation.
    (train_acc_line,) = acc_ax.plot(epoch_nums, history.history['accuracy'],
                                    label=f'{opt} train accuracy')
    acc_ax.plot(epoch_nums, history.history['val_accuracy'], linestyle='--',
                color=train_acc_line.get_color(), label=f'{opt} val accuracy')

    (train_loss_line,) = loss_ax.plot(epoch_nums, history.history['loss'],
                                      label=f'{opt} train loss')
    loss_ax.plot(epoch_nums, history.history['val_loss'], linestyle='--',
                 color=train_loss_line.get_color(), label=f'{opt} val loss')

# Titles, axis labels and legends only need setting once, after all curves
# have been drawn.
acc_ax.set_title('Training and Validation Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()

loss_ax.set_title('Training and Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

fig.tight_layout()
plt.show()
