# Introduction to MNIST Dataset

## 📚 Learning Objectives

By the end of this notebook, you will:
- Understand the MNIST dataset structure
- Load and explore the data
- Visualize sample images
- Analyze class distribution
- Understand image representation as tensors

## 1. What is MNIST?

**MNIST** (Modified National Institute of Standards and Technology) is a classic dataset in machine learning:

- **Purpose**: Handwritten digit recognition (0-9)
- **Training samples**: 60,000 images
- **Test samples**: 10,000 images
- **Image size**: 28 × 28 pixels
- **Color**: Grayscale (1 channel)
- **Classes**: 10 (digits 0-9)
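
The numbers above are enough to work out how large MNIST is in memory. This is a quick sketch of that arithmetic. It assumes each pixel is stored as one `uint8` byte, which the `dtype` printed in section 4 confirms.

```python
import numpy as np

# Dimensions from the dataset summary above
HEIGHT, WIDTH = 28, 28
N_TRAIN, N_TEST = 60_000, 10_000

# Each grayscale pixel is one unsigned byte (uint8)
BYTES_PER_PIXEL = np.dtype(np.uint8).itemsize

pixels_per_image = HEIGHT * WIDTH  # length of a flattened image vector
train_bytes = N_TRAIN * pixels_per_image * BYTES_PER_PIXEL
test_bytes = N_TEST * pixels_per_image * BYTES_PER_PIXEL

print(f"Pixels per image: {pixels_per_image}")                       # 784
print(f"Training images: {train_bytes / 1e6:.2f} MB as uint8")       # 47.04 MB
print(f"Test images: {test_bytes / 1e6:.2f} MB as uint8")            # 7.84 MB
# Converting to float32 for training uses 4 bytes per pixel instead of 1
print(f"Training images: {train_bytes * 4 / 1e6:.2f} MB as float32")  # 188.16 MB
```

The whole dataset fits comfortably in memory, even after conversion to `float32`. This is part of why MNIST is fast to experiment with.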

### Why MNIST?
- **Benchmark dataset**: Standard for testing ML algorithms
- **Manageable size**: Small enough to fit in memory and to train simple models quickly, even on a CPU
- **Real-world application**: OCR (Optical Character Recognition)
- **Educational**: Perfect for learning CNNs

In [None]:
# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

# Set random seed for reproducibility
np.random.seed(42)

# Configure matplotlib
plt.style.use('default')
%matplotlib inline

## 2. Loading the Dataset

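Keras provides `mnist.load_data()`. It downloads MNIST on first use, caches it locally, and returns two tuples of NumPy arrays: `(x_train, y_train)` and `(x_test, y_test)`.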

In [None]:
# Load MNIST dataset
print("Loading MNIST dataset...")
(x_train, y_train), (x_test, y_test) = mnist.load_data()

print("\n✓ Dataset loaded successfully!")
print(f"\nTraining set: {x_train.shape[0]:,} images")
print(f"Test set: {x_test.shape[0]:,} images")
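
After the first download, the arrays can also be read straight from the cached archive without importing TensorFlow. This is a minimal sketch that makes two assumptions. First, the cache is at Keras's default location, `~/.keras/datasets/mnist.npz`. Second, the archive stores the arrays under the keys `x_train`, `y_train`, `x_test`, and `y_test`. Check both on your own install. `load_mnist_npz` is a hypothetical helper, not part of Keras.

```python
import os
import numpy as np

# Hypothetical helper; the default path assumes Keras's default cache directory
def load_mnist_npz(path=os.path.expanduser("~/.keras/datasets/mnist.npz")):
    """Read the cached MNIST archive into (x_train, y_train), (x_test, y_test)."""
    with np.load(path) as data:
        # Indexing an NpzFile reads each array fully, so the results
        # stay valid after the archive is closed at the end of the with-block
        return (data["x_train"], data["y_train"]), (data["x_test"], data["y_test"])
```

Usage: `(x_train, y_train), (x_test, y_test) = load_mnist_npz()`.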

## 3. Exploring Data Shapes

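Checking shapes early pays off. Models expect inputs with a fixed number of dimensions, and a mismatch (for example, a missing channel axis) is a frequent source of errors when building a network.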

In [None]:
print("Data Shapes:")
print("="*50)
print(f"x_train shape: {x_train.shape}")
print(f"  → {x_train.shape[0]} images")
print(f"  → {x_train.shape[1]} × {x_train.shape[2]} pixels")
print()
print(f"y_train shape: {y_train.shape}")
print(f"  â†’ {y_train.shape[0]} labels")
print()
print(f"x_test shape: {x_test.shape}")
print(f"y_test shape: {y_test.shape}")
print("="*50)
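
Printing shapes relies on someone reading the output. A small check can instead fail loudly when a split does not match the layout described above. This is a sketch. `check_mnist_split` is a hypothetical helper, and its expectations (images `(N, 28, 28)` of `uint8`, one label per image in 0..9) are taken from the shapes printed in this notebook.

```python
import numpy as np

# Hypothetical helper encoding the expected MNIST layout
def check_mnist_split(x, y, name="split"):
    """Raise ValueError if one MNIST split does not have the expected layout."""
    if x.ndim != 3 or x.shape[1:] != (28, 28):
        raise ValueError(f"{name}: expected images of shape (N, 28, 28), got {x.shape}")
    if y.shape != (x.shape[0],):
        raise ValueError(f"{name}: labels shape {y.shape} does not match {x.shape[0]} images")
    if x.dtype != np.uint8:
        raise ValueError(f"{name}: expected uint8 pixels, got {x.dtype}")
    if y.size and (y.min() < 0 or y.max() > 9):
        raise ValueError(f"{name}: labels must lie in 0..9")
    return True
```

Usage: `check_mnist_split(x_train, y_train, "train")` and `check_mnist_split(x_test, y_test, "test")`. The same check is worth re-running after any later transformation that is supposed to preserve the layout.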

## 4. Understanding Image Data

Each image is a 28×28 matrix of 8-bit unsigned integers (`uint8`), one per pixel, with intensities from 0 to 255.

In [None]:
# Look at a single image
sample_image = x_train[0]
sample_label = y_train[0]

print(f"Sample image shape: {sample_image.shape}")
print(f"Sample label: {sample_label}")
print(f"\nPixel value range: [{sample_image.min()}, {sample_image.max()}]")
print(f"Data type: {sample_image.dtype}")

# Display the top-left 10×10 corner of the pixel matrix
# (near the border, so mostly background zeros)
print("\nTop-left 10×10 pixels:")
print(sample_image[:10, :10])
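
Treating an image as a tensor mostly comes down to three reshapes. You can flatten it into a vector, add a channel axis, or stack images into a batch. The sketch below uses a random stand-in image so it runs on its own; substitute `x_train[0]` in the notebook. The `(N, H, W, C)` layout matches the channels-last default of Keras convolution layers.

```python
import numpy as np

# Random stand-in for one 28x28 uint8 MNIST image
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)

# 1) Flatten to a 784-element vector (row-major order, NumPy's default)
flat = image.reshape(-1)
r, c = 14, 7
assert flat[r * 28 + c] == image[r, c]  # pixel (r, c) lands at index r*28 + c

# 2) Add an explicit channel axis: (H, W) -> (H, W, 1)
with_channel = image[..., np.newaxis]

# 3) Stack images along a new leading axis, then add channels: (N, H, W, 1)
batch = np.stack([image, image])[..., np.newaxis]

print(flat.shape, with_channel.shape, batch.shape)
# (784,) (28, 28, 1) (2, 28, 28, 1)
```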

## 5. Visualizing Sample Images

Let's see what these digits actually look like!

In [None]:
# Display first 25 images
fig, axes = plt.subplots(5, 5, figsize=(10, 10))
axes = axes.flatten()

for i in range(25):
    axes[i].imshow(x_train[i], cmap='gray')
    axes[i].set_title(f'Label: {y_train[i]}', fontsize=12)
    axes[i].axis('off')

plt.suptitle('First 25 MNIST Images', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()
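
An alternative to 25 separate subplots is to tile the images into one large array and draw it with a single `imshow` call. This is a sketch of that reshape-and-transpose trick. `make_grid` is a hypothetical helper.

```python
import numpy as np

# Hypothetical helper: tile a stack of equally sized images into one 2D array
def make_grid(images, rows, cols):
    """Tile the first rows*cols images of shape (H, W) into one (rows*H, cols*W) array."""
    n, h, w = images.shape
    if n < rows * cols:
        raise ValueError(f"need {rows * cols} images, got {n}")
    # (rows*cols, H, W) -> (rows, cols, H, W): tile (i, j) holds image i*cols + j
    grid = images[: rows * cols].reshape(rows, cols, h, w)
    # Swap to (rows, H, cols, W) so each pixel row spans all tiles in a grid row
    return grid.transpose(0, 2, 1, 3).reshape(rows * h, cols * w)
```

Usage: `plt.imshow(make_grid(x_train, 5, 5), cmap='gray')`. One design difference to be aware of: a single `imshow` scales every tile with one shared intensity range, whereas separate calls scale each image independently. There are no per-image titles either, so the labels would need to be printed separately.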

## 6. Class Distribution Analysis

Let's check if the dataset is balanced across all digit classes.

In [None]:
# Count samples per class
unique, counts = np.unique(y_train, return_counts=True)

# Create bar plot
plt.figure(figsize=(10, 6))
bars = plt.bar(unique, counts, color='steelblue', alpha=0.7, edgecolor='black')

# Add count labels
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
            f'{int(height):,}',
            ha='center', va='bottom', fontsize=10)

plt.xlabel('Digit Class', fontsize=12)
plt.ylabel('Number of Samples', fontsize=12)
plt.title('Class Distribution in Training Set', fontsize=14, fontweight='bold')
plt.xticks(unique)
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

# Print statistics
print("\nClass Distribution:")
for digit, count in zip(unique, counts):
    percentage = (count / len(y_train)) * 100
    print(f"Digit {digit}: {count:,} samples ({percentage:.2f}%)")
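
"Balanced" can be reduced to a single number: the ratio between the largest and smallest class counts, where 1.0 means perfectly balanced. This is a sketch using `np.bincount`. `class_balance` is a hypothetical helper, and it assumes the labels are non-negative integers.

```python
import numpy as np

# Hypothetical helper summarizing class balance for integer labels
def class_balance(labels, num_classes=10):
    """Return per-class counts and the largest-to-smallest count ratio."""
    counts = np.bincount(labels, minlength=num_classes)
    # An empty class makes the imbalance unbounded
    ratio = counts.max() / counts.min() if counts.min() > 0 else float("inf")
    return counts, ratio
```

Usage: compare `class_balance(y_train)` with `class_balance(y_test)` to see whether both splits have a similar mix of digits.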

## 7. Pixel Value Distribution

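The pixel value distribution shows which range the raw inputs cover and how much of each image is empty background. This informs how we normalize the data before training.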

In [None]:
# Flatten all images and plot histogram
all_pixels = x_train.flatten()

plt.figure(figsize=(10, 6))
plt.hist(all_pixels, bins=50, color='steelblue', alpha=0.7, edgecolor='black')
plt.xlabel('Pixel Value', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.title('Pixel Value Distribution', fontsize=14, fontweight='bold')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

print("\nPixel Statistics:")
print(f"Min value: {all_pixels.min()}")
print(f"Max value: {all_pixels.max()}")
print(f"Mean value: {all_pixels.mean():.2f}")
print(f"Std deviation: {all_pixels.std():.2f}")
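
The histogram is dominated by a spike at 0 because most of every image is background. The sketch below measures that share directly and reports the mean and standard deviation on the 0-1 scale used after normalization. Dividing the statistics by 255 gives the same result as computing them on `x / 255`. `pixel_summary` is a hypothetical helper. For the training set, the mean and standard deviation should come out close to the 0.1307 and 0.3081 often quoted for MNIST normalization.

```python
import numpy as np

# Hypothetical helper for raw uint8 image arrays of any shape
def pixel_summary(images):
    """Return the share of zero pixels and the mean/std on a 0-1 scale."""
    zero_fraction = np.count_nonzero(images == 0) / images.size
    # Accumulate in float64 for precision; std(x / 255) equals std(x) / 255
    mean = images.mean(dtype=np.float64) / 255.0
    std = images.std(dtype=np.float64) / 255.0
    return zero_fraction, mean, std
```

Usage: `zero_fraction, mean, std = pixel_summary(x_train)`.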

## 8. Visualizing One Digit from Each Class

Let's see one example of each digit (0-9).

In [None]:
# Find first occurrence of each digit
fig, axes = plt.subplots(2, 5, figsize=(12, 6))
axes = axes.flatten()

for digit in range(10):
    # Find first image of this digit
    idx = np.where(y_train == digit)[0][0]
    
    axes[digit].imshow(x_train[idx], cmap='gray')
    axes[digit].set_title(f'Digit: {digit}', fontsize=12, fontweight='bold')
    axes[digit].axis('off')

plt.suptitle('One Sample from Each Class', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
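
The loop above calls `np.where` once per digit. `np.unique` can return the index of the first occurrence of every class in a single pass through `return_index=True`. This is a sketch; `first_index_per_class` is a hypothetical helper.

```python
import numpy as np

# Hypothetical helper built on np.unique(..., return_index=True)
def first_index_per_class(labels):
    """Map each class present in `labels` to the index of its first occurrence."""
    classes, first_idx = np.unique(labels, return_index=True)
    return dict(zip(classes.tolist(), first_idx.tolist()))

print(first_index_per_class(np.array([3, 1, 3, 0, 1, 2])))
# {0: 3, 1: 1, 2: 5, 3: 0}
```

In the notebook, `first_index_per_class(y_train)[digit]` gives the same index as `np.where(y_train == digit)[0][0]`.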

## 9. Exploring Digit Variations

Let's see multiple examples of the same digit to understand variation.

In [None]:
# Show 10 different examples of digit '5'
target_digit = 5
indices = np.where(y_train == target_digit)[0][:10]

fig, axes = plt.subplots(2, 5, figsize=(12, 6))
axes = axes.flatten()

for i, idx in enumerate(indices):
    axes[i].imshow(x_train[idx], cmap='gray')
    axes[i].axis('off')

plt.suptitle(f'10 Different Examples of Digit {target_digit}', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
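
The first ten examples of a digit are simply wherever they happen to fall in the file. Drawing a random sample gives a less arbitrary view of the variation. This is a sketch using NumPy's `Generator` API with a fixed seed for reproducibility. `sample_class_indices` is a hypothetical helper, and the choice of the `Generator` API over the legacy `np.random.seed` used earlier is ours.

```python
import numpy as np

# Hypothetical helper: k distinct random indices for one class
def sample_class_indices(labels, digit, k, seed=0):
    """Draw k distinct random indices whose label equals `digit`."""
    candidates = np.flatnonzero(labels == digit)
    rng = np.random.default_rng(seed)
    # replace=False guarantees no image is shown twice; raises if k > len(candidates)
    return rng.choice(candidates, size=k, replace=False)
```

Usage: replace the `indices = ...` line above with `indices = sample_class_indices(y_train, target_digit, 10)`.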

## 10. Key Takeaways

### What We Learned:

1. **Dataset Structure**:
   - 60,000 training images, 10,000 test images
   - Each image is 28×28 pixels
   - Pixel values range from 0 (background) to 255 (full stroke intensity); with `cmap='gray'`, the background renders black and the strokes white

2. **Data Representation**:
   - Images are stored as NumPy arrays
   - Shape: (num_samples, height, width)
   - Labels are integers from 0 to 9

3. **Dataset Characteristics**:
   - Roughly balanced across classes: each digit has between 5,421 (digit 5) and 6,742 (digit 1) training samples
   - Significant variation within each class
   - Grayscale images (simpler than color)

4. **Preprocessing Needs**:
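   - Normalization (0-255 → 0-1)
   - Reshaping for CNN input: (N, 28, 28) → (N, 28, 28, 1) to add a channel axis
   - One-hot label encoding, if the loss function expects it (Keras's `sparse_categorical_crossentropy` accepts integer labels directly)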
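
The three preprocessing steps listed above can be sketched in a few lines of NumPy. This is a minimal version, not the pipeline the next notebook necessarily uses. `preprocess` is a hypothetical helper, and the one-hot step is written by hand with `np.eye`. Keras also offers `tensorflow.keras.utils.to_categorical` for one-hot encoding.

```python
import numpy as np

# Hypothetical helper implementing normalization, channel reshape, and one-hot labels
def preprocess(x, y, num_classes=10):
    """Sketch of the three preprocessing steps for MNIST-style arrays."""
    x = x.astype(np.float32) / 255.0                      # normalize 0-255 -> 0.0-1.0
    x = x[..., np.newaxis]                                # (N, 28, 28) -> (N, 28, 28, 1)
    y_onehot = np.eye(num_classes, dtype=np.float32)[y]   # label k -> row k of identity
    return x, y_onehot
```

Usage: `x_train_p, y_train_p = preprocess(x_train, y_train)`. If the model is compiled with `sparse_categorical_crossentropy`, keep the integer `y_train` instead of `y_onehot`.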

### Next Steps:
- **Notebook 02**: Data preprocessing and normalization
- **Notebook 03**: CNN architecture design
- **Notebook 04**: Model training
- **Notebook 05**: Evaluation and analysis

## 🎯 Practice Exercises

Try these exercises to reinforce your understanding:

1. **Find the darkest and brightest images** in the training set (e.g., by mean pixel intensity)
2. **Calculate the average image** for each digit class
3. **Identify which digit has the most variation** in pixel values
4. **Create a collage** showing 100 random images
5. **Analyze the test set** - is it similar to the training set?

Good luck! 🚀