# Fashion MNIST Classification with CNN
This Colab notebook implements a Convolutional Neural Network (CNN) to classify images from the Fashion MNIST dataset.

In [None]:
# 1. (Optional) Mount Google Drive for saving models
# `google.colab` only exists inside the Colab runtime, so the import is guarded:
# on any other runtime (local Jupyter, CI) this optional cell becomes a no-op
# instead of aborting "Run All" with ModuleNotFoundError.
try:
    from google.colab import drive
except ImportError:
    drive = None  # not running in Colab; Drive mounting is unavailable
# drive.mount('/content/drive')

## 2. Install and Import Dependencies

In [None]:
!pip install tensorflow matplotlib

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import Input, Sequential
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.utils import set_random_seed

## 3. Load and Preprocess Data

In [None]:
# Load the Fashion MNIST train/test splits. `fashion_mnist` is already imported
# in the dependencies cell, so it is not imported a second time here.
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Reshape and normalize: add the trailing channel axis that Conv2D expects
# (N, 28, 28) -> (N, 28, 28, 1), convert to float32 and scale by 1/255.
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0

# Sanity check: every image must have exactly one label.
assert len(x_train) == len(y_train), "train images/labels length mismatch"
assert len(x_test) == len(y_test), "test images/labels length mismatch"

# Human-readable names, indexed by the integer class label.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
print(f"Training samples: {len(x_train)}, Test samples: {len(x_test)}")

## 4. Build the CNN Model

In [None]:
# Seed the Python, NumPy and TensorFlow random generators before any weights are
# created, so initialisation, shuffling and dropout are repeatable across
# "Restart & Run All" (some GPU kernels may still be non-deterministic).
SEED = 42
set_random_seed(SEED)

# Two convolution + max-pooling stages extract features; a dense head with
# dropout maps them to a probability for each of the 10 classes.
model = Sequential([
    # Declare the input with an explicit Input object: passing `input_shape=`
    # to the first layer is discouraged by recent Keras and emits a warning.
    Input(shape=(28, 28, 1)),
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),  # regularisation: drop half of the dense activations while training
    Dense(10, activation='softmax')
])
# The "sparse" loss variant is used because the labels are integer class ids
# (they are also used directly as list indices later), not one-hot vectors.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

## 5. Train the Model

In [None]:
# Fit the network on the training split; 20% of it is held out for validation
# so that per-epoch validation metrics are recorded in `history`.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=10,
    verbose=1,
    validation_split=0.2,
)

## 6. Evaluate and Visualize

In [None]:
# Score the trained model on the test split and report its accuracy.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"Test accuracy: {test_acc:.3f}")

# Learning curves: left panel accuracy, right panel loss, each showing the
# training curve against the validation curve.
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))
curves = history.history

acc_ax.plot(curves['accuracy'], label='Train Acc')
acc_ax.plot(curves['val_accuracy'], label='Val Acc')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()
acc_ax.set_title('Accuracy over Epochs')

loss_ax.plot(curves['loss'], label='Train Loss')
loss_ax.plot(curves['val_loss'], label='Val Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.set_title('Loss over Epochs')

plt.show()

## 7. Visualize Sample Predictions

In [None]:
# Predicted class id per test image: the index of the largest softmax output.
probabilities = model.predict(x_test)
preds = np.argmax(probabilities, axis=1)

# 5x5 grid of the first 25 test images. Each caption reads
# "predicted\n(true)" and is green for a correct prediction, red otherwise.
fig, axes = plt.subplots(5, 5, figsize=(10, 10))
for idx, ax in enumerate(axes.flat):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    ax.imshow(x_test[idx].reshape(28, 28), cmap='gray')
    predicted, actual = preds[idx], y_test[idx]
    if predicted == actual:
        caption_color = 'green'
    else:
        caption_color = 'red'
    ax.set_xlabel(f"{class_names[predicted]}\n({class_names[actual]})", color=caption_color)
fig.tight_layout()
plt.show()

## 8. Extensions
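- Use `ImageDataGenerator` for data augmentation (note: it is deprecated in recent TensorFlow releases; Keras preprocessing layers such as `RandomFlip` and `RandomTranslation` are the recommended replacement)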
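- Try deeper or pretrained models (e.g., MobileNet — its pretrained ImageNet weights expect 3-channel inputs of at least 32×32 pixels, so the 28×28 grayscale images must first be resized and repeated across three channels)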
- Tune hyperparameters: learning rate, dropout rate, batch size