# **CIFAR-10 Training Notebook**
Complete preprocessing, visualization, model training, evaluation, and saving.

In [None]:

import tensorflow as tf
from tensorflow.keras.datasets import cifar10
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Load dataset: load_data() yields (images, labels) pairs for the train and test splits.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Human-readable label names, indexed by the integer class id stored in the label arrays.
class_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

# Normalize pixel values to [0, 1].
# Cast to float32 before dividing: dividing the raw integer images by a Python
# float would promote them to float64, doubling the memory held by the image
# arrays for no benefit (Keras computes in float32 by default).
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

print("Training shape:", x_train.shape)
print("Test shape:", x_test.shape)


## **Sample Visualizations**

In [None]:

# Preview the first ten training images in a 2x5 grid, each titled with its class name.
fig, axes = plt.subplots(2, 5, figsize=(10, 3))
for ax, image, label in zip(axes.flat, x_train[:10], y_train[:10]):
    ax.imshow(image)
    # Each label row holds a single class id; use it to look up the name.
    ax.set_title(class_names[label[0]])
    ax.axis('off')
plt.show()


## **Model Creation (CNN)**

In [None]:

# Seed the Python, NumPy and TensorFlow random generators so that weight
# initialization and batch shuffling repeat across "Restart & Run All".
# (Some GPU kernels may still introduce small run-to-run differences.)
tf.keras.utils.set_random_seed(42)

# Small CNN: two conv/max-pool stages, then a dense classifier head with one
# softmax output per class.
model = tf.keras.Sequential([
    # Declare the input shape with an explicit Input object rather than the
    # legacy `input_shape=` layer argument.
    tf.keras.Input(shape=(32,32,3)),
    tf.keras.layers.Conv2D(32,(3,3),activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64,(3,3),activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128,activation='relu'),
    tf.keras.layers.Dense(10,activation='softmax')
])

# The sparse variant of cross-entropy is used because labels are integer
# class ids, not one-hot vectors.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Hold out a 0.2 fraction of the training data for validation; the returned
# History object records per-epoch metrics for the plots that follow.
history = model.fit(x_train, y_train, epochs=5, validation_split=0.2)


## **Training Curves**

In [None]:

# Draw one figure per recorded metric, comparing the training curve with the
# validation curve (stored by Keras under the 'val_' prefix).
for metric, title in (('accuracy', "Accuracy"), ('loss', "Loss")):
    plt.plot(history.history[metric], label='train')
    plt.plot(history.history['val_' + metric], label='val')
    plt.legend()
    plt.title(title)
    plt.show()


## **Evaluation**

In [None]:

# Per-class scores for every test image; the predicted label is the index of
# the highest score in each row.
probabilities = model.predict(x_test)
y_pred = probabilities.argmax(axis=1)

# Flatten the label array to 1-D so it lines up with y_pred.
y_true = y_test.ravel()

# Per-class precision / recall / F1, labelled with the class names.
report = classification_report(y_true, y_pred, target_names=class_names)
print(report)


## **Confusion Matrix**

In [None]:

# Count how often each true class was predicted as each class.
cm = confusion_matrix(y_true, y_pred)

# Render the counts as an annotated heatmap on an explicit Axes object.
fig, ax = plt.subplots(figsize=(10, 7))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
plt.show()


## **Save Model & Class Names**

In [None]:

import json

# Persist the trained model to an HDF5 file.
model.save("cifar10_model.h5")

# Store the label names next to the model so predicted class ids can be
# mapped back to readable names later.
with open("class_names.json", "w") as names_file:
    names_file.write(json.dumps(class_names))

print("Model and class names saved.")
