In [None]:
# Import required libraries
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential, save_model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Load CIFAR-10 dataset (images arrive as uint8 arrays, labels as integer class ids)
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Normalize pixel values to be between 0 and 1.
# Cast to float32 *before* dividing: `uint8_array / 255.0` would yield float64,
# doubling memory for no benefit since Keras trains in float32 anyway.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Convert labels to categorical (one-hot encoding) over the 10 CIFAR-10 classes,
# matching the 10-unit softmax output and categorical_crossentropy loss.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Sanity checks: images must match the model's Input(shape=(32, 32, 3)).
assert x_train.shape[1:] == x_test.shape[1:] == (32, 32, 3)
assert y_train.shape[1] == y_test.shape[1] == 10

In [None]:
# Seed Python, NumPy and TensorFlow RNGs before building the model so weight
# initialisation (and, when cells run in order, dropout masks and per-epoch
# shuffling in the training cell) repeat on a fresh kernel.
# NOTE: some GPU ops remain nondeterministic unless
# tf.config.experimental.enable_op_determinism() is also called.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Define the model: three Conv2D feature extractors (the first two followed by
# 2x2 max-pooling), then a dense classifier head regularised with dropout.
model = Sequential([
    Input(shape=(32, 32, 3)),  # Explicitly define the input layer
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    Flatten(),
    Dense(64, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')  # 10 output classes for CIFAR-10
])

# Compile the model. categorical_crossentropy pairs with the one-hot labels
# produced by to_categorical and the softmax output layer.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [None]:
# Train the model, holding out the last 10% of the *training* data for
# validation. The test set is deliberately not used here: monitoring it during
# training invites tuning epochs/hyperparameters against it, after which the
# accuracy reported by the evaluation cell is no longer an independent estimate.
try:
    history = model.fit(
        x_train, y_train,
        epochs=10,
        batch_size=64,
        validation_split=0.1,
    )
except Exception as e:
    print(f"Error during training: {e}")
    raise

In [None]:
# Evaluate the model on the test split. The model is compiled with a single
# 'accuracy' metric, so evaluate() yields a (loss, accuracy) pair.
try:
    test_loss, test_acc = model.evaluate(x=x_test, y=y_test, verbose=2)
except Exception as eval_error:
    # Surface a short context line, then let the full traceback propagate.
    print(f"Error during evaluation: {eval_error}")
    raise
else:
    print("Test accuracy:", test_acc)

In [None]:
# Persist the trained model to disk in the native Keras (.keras) format.
MODEL_PATH = 'image_classification_model.keras'

try:
    save_model(model, MODEL_PATH)
except Exception as save_error:
    # Add context to the failure, then re-raise so the cell still errors out.
    print(f"Error saving the model: {save_error}")
    raise
else:
    print("Model saved successfully.")

In [None]:
# Make predictions and visualize
# CIFAR-10 label index -> human-readable class name (standard dataset ordering).
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

try:
    # Class-probability rows for the whole test set; kept in full so
    # `predictions` stays available to any later analysis.
    predictions = model.predict(x_test)
    pred_labels = predictions.argmax(axis=1)
    true_labels = y_test.argmax(axis=1)  # y_test is one-hot encoded

    # Display the first 10 test images in a 2x5 grid with predicted vs. actual class.
    fig, axes = plt.subplots(2, 5, figsize=(12, 5))
    for i, ax in enumerate(axes.flat):
        # x_test is already scaled to [0, 1]; imshow renders float RGB in that
        # range directly, avoiding a truncating uint8 cast (254.999 -> 254).
        ax.imshow(x_test[i])
        ax.axis('off')

        correct = pred_labels[i] == true_labels[i]
        ax.set_title(
            f"Pred: {CIFAR10_CLASSES[pred_labels[i]]}\n"
            f"Actual: {CIFAR10_CLASSES[true_labels[i]]}",
            color="black" if correct else "red",  # red flags a misclassification
        )

    fig.tight_layout()
    plt.show()
except Exception as e:
    print(f"Error during prediction or visualization: {e}")
    raise