# CNN Image Classifier (CIFAR-10)

This notebook demonstrates image classification using a Convolutional Neural Network (CNN) on the CIFAR-10 dataset. It covers data loading, model training, data augmentation, transfer learning, and evaluation.

In [1]:
# Imports
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

## 1. Load and Visualize CIFAR-10 Data

In [2]:
# Load the CIFAR-10 train/test splits through the Keras datasets helper.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Display name for each integer class id; the position in the list is the id.
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
print('Train shape:', x_train.shape, y_train.shape)
print('Test shape:', x_test.shape, y_test.shape)

# Preview the first ten training images, one per column, titled with their class.
fig, axes = plt.subplots(1, 10, figsize=(10, 2))
for ax, image, label in zip(axes, x_train[:10], y_train[:10]):
    ax.imshow(image)
    ax.axis('off')
    # Each label row holds a single class id, hence the [0].
    ax.set_title(class_names[label[0]], fontsize=8)
plt.show()

## 2. Normalize Data

In [3]:
# Normalize pixel values from [0, 255] to [0, 1].
# The arrays are rebound under the same names, so the scaling is guarded on the
# raw uint8 dtype: re-running this cell must not divide by 255 a second time.
# Casting to float32 first avoids NumPy's default promotion to float64, which
# would double the memory footprint of both splits.
if x_train.dtype == np.uint8:
    x_train = x_train.astype('float32') / 255.0
if x_test.dtype == np.uint8:
    x_test = x_test.astype('float32') / 255.0

## 3. Build and Train Baseline CNN

In [4]:
# Seed NumPy and TensorFlow before any weights are created so that a fresh
# Restart & Run All starts from the same random state each time.
# (Some GPU kernels may still introduce small run-to-run differences.)
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Baseline CNN: three 3x3 convolution layers (32, 64, 64 filters) with 2x2
# max-pooling after the first two, followed by a 64-unit dense layer and a
# 10-way softmax output (one unit per CIFAR-10 class).
model = models.Sequential([
    layers.Conv2D(32, (3,3), activation='relu', input_shape=(32,32,3)),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')
])
# Labels are integer class ids (not one-hot), hence the sparse loss.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()

In [5]:
# Fit the baseline CNN for ten epochs. The test split is passed as validation
# data, so the per-epoch val_* metrics in `history` are measured on it.
history = model.fit(
    x=x_train,
    y=y_train,
    epochs=10,
    validation_data=(x_test, y_test),
)
# Persist the fitted baseline to an .h5 file.
model.save(filepath='models/baseline_cnn.h5')

## 4. Plot Training History

In [6]:
# Plot per-epoch training vs. validation accuracy for the baseline CNN and
# save the figure alongside displaying it.
# savefig() does not create missing directories, so make sure 'plots/' exists;
# otherwise this cell fails with FileNotFoundError on a fresh checkout.
os.makedirs('plots', exist_ok=True)

fig, ax = plt.subplots()
ax.plot(history.history['accuracy'], label='train acc')
ax.plot(history.history['val_accuracy'], label='val acc')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Baseline CNN Accuracy')
ax.legend()
fig.savefig('plots/baseline_cnn_accuracy.png')
plt.show()

## 5. Data Augmentation

In [7]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Random perturbations applied on the fly to every batch drawn from the
# generator: small rotations, horizontal/vertical shifts and mirror flips.
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True
)
# No datagen.fit(x_train) call: fit() only computes dataset statistics for
# featurewise_center, featurewise_std_normalization and zca_whitening, none of
# which are enabled here, so it would just copy the whole training set.

# Train model with augmentation.
# clone_model() copies the architecture only; set_weights() then loads whatever
# weights `model` holds when this cell runs. In a top-to-bottom run that is the
# trained baseline, so this continues training it rather than starting fresh.
aug_model = models.clone_model(model)
aug_model.set_weights(model.get_weights())
aug_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Augmented batches of 64 are used for training; validation uses the
# unaugmented test split.
aug_history = aug_model.fit(datagen.flow(x_train, y_train, batch_size=64),
                           epochs=10, validation_data=(x_test, y_test))
aug_model.save('models/augmented_cnn.h5')

In [8]:
# Plot accuracy with augmentation: per-epoch training vs. validation accuracy
# of the augmented model, saved to disk and displayed.
# savefig() does not create missing directories, so make sure 'plots/' exists;
# otherwise this cell fails with FileNotFoundError on a fresh checkout.
os.makedirs('plots', exist_ok=True)

fig, ax = plt.subplots()
ax.plot(aug_history.history['accuracy'], label='train acc (aug)')
ax.plot(aug_history.history['val_accuracy'], label='val acc (aug)')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Augmented CNN Accuracy')
ax.legend()
fig.savefig('plots/augmented_cnn_accuracy.png')
plt.show()

## 6. Transfer Learning with ResNet50

In [9]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout


def resnet_preprocess(images):
    """Convert [0, 1]-scaled RGB images to the input format ResNet50 expects.

    This notebook rescales pixels to [0, 1], but the ImageNet weights were
    trained on inputs passed through `preprocess_input`, which operates on
    0-255 RGB values (RGB->BGR reordering and ImageNet mean subtraction, per
    the Keras ResNet documentation). The 0-255 scale is therefore restored
    before delegating to it.

    Args:
        images: batch of RGB images with values in [0, 1].

    Returns:
        The batch preprocessed for the ImageNet-pretrained ResNet50.
    """
    return preprocess_input(images * 255.0)


# Pre-trained convolutional base without its ImageNet classification head.
# It is frozen so only the new layers stacked on top are trained.
resnet = ResNet50(weights='imagenet', include_top=False, input_shape=(32,32,3))
resnet.trainable = False

# Preprocessing lives inside the model so every caller (fit, validation,
# predict) can keep passing the same [0, 1]-scaled images used elsewhere.
# NOTE(review): the saved file now contains a Lambda layer; reloading it may
# need `custom_objects` / `safe_mode=False` depending on the Keras version.
tl_model = models.Sequential([
    layers.Lambda(resnet_preprocess, name='resnet_preprocess'),
    resnet,
    GlobalAveragePooling2D(),
    Dense(128, activation='relu'),
    Dropout(0.3),
    Dense(10, activation='softmax')
])
tl_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the new head on augmented batches; validate on the unaugmented test split.
tl_history = tl_model.fit(datagen.flow(x_train, y_train, batch_size=64),
                         epochs=5, validation_data=(x_test, y_test))
tl_model.save('models/resnet_cnn.h5')

In [10]:
# Plot accuracy for transfer learning: per-epoch training vs. validation
# accuracy of the ResNet50-based model, saved to disk and displayed.
# savefig() does not create missing directories, so make sure 'plots/' exists;
# otherwise this cell fails with FileNotFoundError on a fresh checkout.
os.makedirs('plots', exist_ok=True)

fig, ax = plt.subplots()
ax.plot(tl_history.history['accuracy'], label='train acc (ResNet)')
ax.plot(tl_history.history['val_accuracy'], label='val acc (ResNet)')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Transfer Learning (ResNet50) Accuracy')
ax.legend()
fig.savefig('plots/resnet_cnn_accuracy.png')
plt.show()

## 7. Evaluate and Show Sample Predictions

In [11]:
# Evaluate on the full test split. The model was compiled with a single
# 'accuracy' metric, so evaluate() returns [loss, accuracy].
test_loss, test_acc = tl_model.evaluate(x_test, y_test, verbose=0)
print(f'Test loss: {test_loss:.4f}  Test accuracy: {test_acc:.4f}')

# Predict the first ten test images; argmax over the softmax outputs gives the
# predicted class id for each image.
sample_images = x_test[:10]
sample_labels = y_test[:10, 0]  # labels are stored one id per row
preds = np.argmax(tl_model.predict(sample_images), axis=1)

# Show each image with its predicted class; the title is green when the
# prediction matches the true label and red when it does not.
fig, axes = plt.subplots(1, 10, figsize=(10, 2))
for ax, image, pred, truth in zip(axes, sample_images, preds, sample_labels):
    ax.imshow(image)
    ax.axis('off')
    ax.set_title(f"Pred: {class_names[pred]}", fontsize=8,
                 color='green' if pred == truth else 'red')
plt.show()