# Notebook 6: Data Augmentation

**Course:** 21CSE558T - Deep Neural Network Architectures  
**Module 4:** CNNs - Practical Session  
**Date:** Monday, November 3, 2025  
**Duration:** 30 minutes  
**Objective:** Use data augmentation to improve model generalization and accuracy

---

## The Data Problem

**Scenario:** You have only 1,000 images but need 10,000 for good accuracy.

**Traditional solution:** Collect more data (expensive, time-consuming)

**Smart solution:** **Data Augmentation** - Create new training samples from existing ones!

---

## What is Data Augmentation?

**Definition:** Apply random transformations to training images:
- Rotation
- Shifting
- Flipping
- Zooming
- Brightness changes

**Result:** Each image yields many slightly different versions, so the model effectively trains on a larger, more varied dataset.

**Key:** Only augment **training** data, not test data!
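The train-only rule is easiest to see in code. Below is a minimal NumPy sketch of on-the-fly augmentation, assuming images in (N, H, W, C) layout. A generator yields a freshly flipped and shifted copy of each training batch, and validation arrays would be passed to the model untouched. The name `augmented_batches` is hypothetical. `np.roll`, which wraps pixels around the border, is a simplification chosen for brevity. This is not how `ImageDataGenerator` works internally.

```python
import numpy as np

def augmented_batches(x, y, batch_size, rng, max_shift=2):
    """Yield (x_batch, y_batch) forever; x has shape (N, H, W, C).

    Every batch gets new random transforms, so the stored dataset never grows.
    """
    n = len(x)
    while True:
        order = rng.permutation(n)  # reshuffle on every pass over the data
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            xb = x[idx].copy()  # never modify the original training array
            # Random horizontal flip with probability 0.5 per image (width axis = 2)
            flip = rng.random(len(idx)) < 0.5
            xb[flip] = xb[flip][:, :, ::-1, :]
            # Random integer shift per image; np.roll wraps pixels around the border
            # (a simplification: real pipelines fill the uncovered border instead)
            for k in range(len(idx)):
                dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
                xb[k] = np.roll(xb[k], shift=(dy, dx), axis=(0, 1))
            yield xb, y[idx]

# Usage: training batches are augmented; validation data would be used as-is
rng = np.random.default_rng(0)
x_train_demo = rng.random((10, 28, 28, 1)).astype("float32")
y_train_demo = np.arange(10)
gen = augmented_batches(x_train_demo, y_train_demo, batch_size=4, rng=rng)
xb, yb = next(gen)
print(xb.shape, yb.shape)
```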

---

## Benefits:

1. ✅ **Reduces overfitting** - Model sees more variations
2. ✅ **Improves generalization** - Better on unseen data
3. ✅ **Makes model robust** - Invariant to small changes
4. ✅ **No collection cost** - New variations come from existing images (at the price of some extra compute per epoch)

In [None]:
# Import libraries
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Conv2D, MaxPooling2D, GlobalAveragePooling2D,
    Dense, Dropout, BatchNormalization
)
# ImageDataGenerator is marked "not recommended for new code" in recent TF releases
# (Keras preprocessing layers are the modern alternative), but it still works here
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
import warnings
warnings.filterwarnings('ignore')

print(f"✅ TensorFlow version: {tf.__version__}")

# Set seed
tf.random.set_seed(42)
np.random.seed(42)

---

## Part 1: Load Data

In [None]:
# Load Fashion-MNIST
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Preprocessing
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0
y_train_cat = to_categorical(y_train, 10)
y_test_cat = to_categorical(y_test, 10)

print(f"✅ Data loaded: {x_train.shape[0]:,} training samples")

---

## Part 2: Visualize Augmentation Techniques

Let's see what each augmentation does!

In [None]:
# Select one sample image
sample_img = x_train[0]
sample_label = class_names[y_train[0]]

# Prepare for augmentation (need batch dimension)
img_batch = sample_img.reshape((1, 28, 28, 1))

print(f"Original image: {sample_label}")
print(f"Shape: {sample_img.shape}")

### Augmentation 1: Rotation

In [None]:
# Rotation augmentation
rotation_gen = ImageDataGenerator(rotation_range=30)  # ±30 degrees

fig, axes = plt.subplots(2, 5, figsize=(15, 6))

# Original
axes[0, 0].imshow(sample_img.reshape(28, 28), cmap='gray')
axes[0, 0].set_title('Original', fontsize=12, fontweight='bold')
axes[0, 0].axis('off')

# Generate rotated versions
aug_iter = rotation_gen.flow(img_batch, batch_size=1)
for i in range(9):
    row = (i + 1) // 5
    col = (i + 1) % 5
    aug_img = next(aug_iter)[0]
    axes[row, col].imshow(aug_img.reshape(28, 28), cmap='gray')
    axes[row, col].set_title(f'Rotated {i+1}', fontsize=10)
    axes[row, col].axis('off')

plt.suptitle(f'Rotation Augmentation (±30°) - {sample_label}', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("💡 Each training iteration shows the model a different rotation!")
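Conceptually, `rotation_range=30` draws an angle uniformly from [-30°, +30°] for each image and then resamples the pixels. The NumPy sketch below does the resampling by inverse mapping. For every output pixel it computes where that pixel came from in the input. It assumes nearest-neighbour sampling and clamps out-of-range coordinates to the edge, which mimics `fill_mode='nearest'`. Keras uses its own affine-transform routine with smoother interpolation, so exact pixel values will differ. `rotate_nearest` is a hypothetical helper.

```python
import numpy as np

def rotate_nearest(img, angle_deg):
    """Rotate a 2-D image about its centre using nearest-neighbour sampling.

    Out-of-bounds source pixels are clamped to the nearest edge pixel,
    which mimics fill_mode='nearest'.
    """
    h, w = img.shape
    theta = np.deg2rad(angle_deg)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    ys, xs = np.mgrid[0:h, 0:w]
    y0, x0 = ys - cy, xs - cx
    # Inverse mapping: for every output pixel, find its source location in the input
    src_x = np.cos(theta) * x0 + np.sin(theta) * y0 + cx
    src_y = -np.sin(theta) * x0 + np.cos(theta) * y0 + cy
    # Round to the nearest pixel and clamp to the image borders
    src_x = np.clip(np.rint(src_x), 0, w - 1).astype(int)
    src_y = np.clip(np.rint(src_y), 0, h - 1).astype(int)
    return img[src_y, src_x]

rng = np.random.default_rng(42)
angle = rng.uniform(-30, 30)  # the per-image draw implied by rotation_range=30
demo = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)
rotated = rotate_nearest(demo, angle)
```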

### Augmentation 2: Shifting

In [None]:
# Shifting augmentation
shift_gen = ImageDataGenerator(
    width_shift_range=0.2,   # Shift horizontally by ±20%
    height_shift_range=0.2   # Shift vertically by ±20%
)

fig, axes = plt.subplots(2, 5, figsize=(15, 6))

axes[0, 0].imshow(sample_img.reshape(28, 28), cmap='gray')
axes[0, 0].set_title('Original', fontsize=12, fontweight='bold')
axes[0, 0].axis('off')

aug_iter = shift_gen.flow(img_batch, batch_size=1)
for i in range(9):
    row = (i + 1) // 5
    col = (i + 1) % 5
    aug_img = next(aug_iter)[0]
    axes[row, col].imshow(aug_img.reshape(28, 28), cmap='gray')
    axes[row, col].set_title(f'Shifted {i+1}', fontsize=10)
    axes[row, col].axis('off')

plt.suptitle(f'Translation/Shifting Augmentation - {sample_label}', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("💡 Object can appear at different positions in the image!")
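When the value is a float below 1, `width_shift_range=0.2` is a fraction of the image width. On 28-pixel images the shift is therefore at most 0.2 × 28 = 5.6 pixels in each direction. The sketch below applies such a shift and fills the uncovered border by repeating edge pixels, roughly what `fill_mode='nearest'` does. Rounding the shift to whole pixels is a simplification of this sketch, since Keras shifts by fractional amounts. `shift_nearest` is a hypothetical helper.

```python
import numpy as np

def shift_nearest(img, dy, dx):
    """Shift a 2-D image down by dy and right by dx pixels (negative = up/left).

    Pixels uncovered by the shift repeat the nearest edge value.
    """
    h, w = img.shape
    pad = max(abs(dy), abs(dx))
    padded = np.pad(img, pad, mode="edge")  # replicate border pixels outward
    return padded[pad - dy: pad - dy + h, pad - dx: pad - dx + w]

rng = np.random.default_rng(0)
h, w = 28, 28
max_dy, max_dx = 0.2 * h, 0.2 * w  # up to 5.6 px each way
dy = int(round(rng.uniform(-max_dy, max_dy)))
dx = int(round(rng.uniform(-max_dx, max_dx)))
demo = np.arange(h * w, dtype=np.float32).reshape(h, w)
shifted = shift_nearest(demo, dy, dx)
```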

### Augmentation 3: Zooming

In [None]:
# Zoom augmentation
zoom_gen = ImageDataGenerator(zoom_range=0.2)  # Zoom factor in [0.8, 1.2], drawn separately for width and height

fig, axes = plt.subplots(2, 5, figsize=(15, 6))

axes[0, 0].imshow(sample_img.reshape(28, 28), cmap='gray')
axes[0, 0].set_title('Original', fontsize=12, fontweight='bold')
axes[0, 0].axis('off')

aug_iter = zoom_gen.flow(img_batch, batch_size=1)
for i in range(9):
    row = (i + 1) // 5
    col = (i + 1) % 5
    aug_img = next(aug_iter)[0]
    axes[row, col].imshow(aug_img.reshape(28, 28), cmap='gray')
    axes[row, col].set_title(f'Zoomed {i+1}', fontsize=10)
    axes[row, col].axis('off')

plt.suptitle(f'Zoom Augmentation - {sample_label}', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("💡 Objects can appear at different scales!")
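A float `zoom_range=0.2` corresponds to a zoom factor in [1 - 0.2, 1 + 0.2] = [0.8, 1.2]. In `ImageDataGenerator` the factor scales the sampling coordinates, so a factor below 1 enlarges the content (zoom in) and a factor above 1 shrinks it (zoom out). The NumPy sketch below shows that coordinate scaling. It assumes a single factor for both axes and nearest-neighbour sampling for simplicity. `zoom_nearest` is a hypothetical helper, not the Keras implementation.

```python
import numpy as np

def zoom_nearest(img, factor):
    """Zoom a 2-D image about its centre with nearest-neighbour sampling.

    The factor scales the *sampling* coordinates: factor < 1 samples a smaller
    central region (content looks larger, i.e. zoom in); factor > 1 zooms out,
    and out-of-bounds samples repeat the nearest edge pixel.
    """
    h, w = img.shape
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    ys, xs = np.mgrid[0:h, 0:w]
    src_y = np.clip(np.rint((ys - cy) * factor + cy), 0, h - 1).astype(int)
    src_x = np.clip(np.rint((xs - cx) * factor + cx), 0, w - 1).astype(int)
    return img[src_y, src_x]

rng = np.random.default_rng(1)
factor = rng.uniform(1 - 0.2, 1 + 0.2)  # zoom_range=0.2 -> [0.8, 1.2]
demo = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)
zoomed = zoom_nearest(demo, factor)
```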

### Augmentation 4: Horizontal Flip

In [None]:
# Flip augmentation
flip_gen = ImageDataGenerator(horizontal_flip=True)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

# Original
axes[0].imshow(sample_img.reshape(28, 28), cmap='gray')
axes[0].set_title('Original', fontsize=13, fontweight='bold')
axes[0].axis('off')

# Random flips: each draw is mirrored with 50% probability,
# so some panels may look identical to the original
aug_iter = flip_gen.flow(img_batch, batch_size=1)
for i in range(2):
    aug_img = next(aug_iter)[0]
    axes[i+1].imshow(aug_img.reshape(28, 28), cmap='gray')
    axes[i+1].set_title(f'Random flip {i+1}', fontsize=13, fontweight='bold')
    axes[i+1].axis('off')

plt.suptitle(f'Horizontal Flip Augmentation - {sample_label}', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("💡 Note: Not all classes should be flipped (e.g., text, digits)")
print("   For Fashion-MNIST, flipping makes sense for most items!")
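`horizontal_flip=True` mirrors each image left-to-right with probability 0.5. A minimal NumPy version for a batch in (N, H, W, C) layout follows. The probability is exposed as a parameter `p`, which is an assumption of this sketch, since `ImageDataGenerator` fixes it at 0.5. `random_hflip` is a hypothetical helper.

```python
import numpy as np

def random_hflip(batch, rng, p=0.5):
    """Mirror each image in an (N, H, W, C) batch left-to-right with probability p."""
    flip = rng.random(len(batch)) < p          # one coin toss per image
    out = batch.copy()
    out[flip] = batch[flip][:, :, ::-1, :]     # reverse the width axis
    return out, flip

rng = np.random.default_rng(7)
batch = rng.random((1000, 28, 28, 1)).astype("float32")
flipped, mask = random_hflip(batch, rng)
print(f"Fraction flipped: {mask.mean():.2f}")  # close to 0.5
```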

### Augmentation 5: Combined Transformations

In [None]:
# Combined augmentation (realistic scenario)
combined_gen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.15,
    height_shift_range=0.15,
    zoom_range=0.15,
    horizontal_flip=True,
    fill_mode='nearest'
)

fig, axes = plt.subplots(3, 5, figsize=(15, 9))

# Original
axes[0, 0].imshow(sample_img.reshape(28, 28), cmap='gray')
axes[0, 0].set_title('ORIGINAL', fontsize=12, fontweight='bold', color='red')
axes[0, 0].axis('off')

# Generate 14 augmented versions (panels 1-14 of the 3x5 grid; panel 0 is the original)
aug_iter = combined_gen.flow(img_batch, batch_size=1)
for i in range(14):
    row, col = divmod(i + 1, 5)
    aug_img = next(aug_iter)[0]
    axes[row, col].imshow(aug_img.reshape(28, 28), cmap='gray')
    axes[row, col].set_title(f'Augmented {i+1}', fontsize=10)
    axes[row, col].axis('off')

plt.suptitle(f'Combined Augmentation - {sample_label}\n(Rotation + Shift + Zoom + Flip)', 
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("💡 One image → endless random variations!")
print("💡 With continuous random parameters, the model almost never sees the exact same image twice!")

---

## Part 3: Train Without Augmentation (Baseline)

In [None]:
# Same CNN for both experiments; each call returns a freshly initialized model
def build_model():
    model = Sequential([
        Conv2D(64, (3, 3), padding='same', input_shape=(28, 28, 1)),
        BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        Conv2D(64, (3, 3), padding='same'),
        BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        MaxPooling2D((2, 2)),
        Dropout(0.25),
        
        Conv2D(128, (3, 3), padding='same'),
        BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        Conv2D(128, (3, 3), padding='same'),
        BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        MaxPooling2D((2, 2)),
        Dropout(0.25),
        
        GlobalAveragePooling2D(),
        Dense(256),
        BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        Dropout(0.5),
        Dense(10, activation='softmax')
    ])
    return model

# Build baseline model
baseline_model = build_model()
baseline_model.compile(optimizer='adam',
                       loss='categorical_crossentropy',
                       metrics=['accuracy'])

print("✅ Baseline model built")
print(f"Parameters: {baseline_model.count_params():,}")

In [None]:
# Train WITHOUT augmentation
print("🚀 Training WITHOUT augmentation...\n")

baseline_history = baseline_model.fit(
    x_train, y_train_cat,
    batch_size=128,
    epochs=20,
    validation_split=0.1,   # Keras holds out the LAST 10% of samples (selected before shuffling)
    verbose=1
)

baseline_test_loss, baseline_test_acc = baseline_model.evaluate(x_test, y_test_cat, verbose=0)
print(f"\n📊 Baseline Test Accuracy: {baseline_test_acc:.2%}")

---

## Part 4: Train WITH Augmentation

In [None]:
# Create augmentation generator
train_datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest'
)

# NO augmentation for validation/test (pixels were already scaled to [0, 1] above)
val_datagen = ImageDataGenerator()  # Identity generator: only batches the data

print("✅ Data augmentation generators created")
print("\n📋 Augmentation parameters:")
print("  • Rotation: ±15°")
print("  • Width shift: ±10%")
print("  • Height shift: ±10%")
print("  • Zoom: ±10%")
print("  • Horizontal flip: Yes")
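For 28×28 Fashion-MNIST images, these ranges translate into concrete bounds. The worked calculation below assumes two interpretations: shift ranges below 1 are fractions of the image size, and a float `zoom_range` z means a factor in [1 - z, 1 + z].

```python
IMG_H = IMG_W = 28
aug = {"rotation_range": 15, "width_shift_range": 0.1,
       "height_shift_range": 0.1, "zoom_range": 0.1}

max_dx_px = aug["width_shift_range"] * IMG_W    # about 2.8 px left/right
max_dy_px = aug["height_shift_range"] * IMG_H   # about 2.8 px up/down
zoom_lo, zoom_hi = 1 - aug["zoom_range"], 1 + aug["zoom_range"]  # 0.9 .. 1.1

print(f"Rotation: uniform in [-{aug['rotation_range']}, +{aug['rotation_range']}] degrees")
print(f"Shift:    up to {max_dx_px:.1f} px horizontally, {max_dy_px:.1f} px vertically")
print(f"Zoom:     factor in [{zoom_lo:.1f}, {zoom_hi:.1f}]")
```

So even the "conservative" settings move a 28-pixel item by almost 3 pixels, which is why the values stay small here.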

In [None]:
# Build new model for augmented training
augmented_model = build_model()
augmented_model.compile(optimizer='adam',
                        loss='categorical_crossentropy',
                        metrics=['accuracy'])

# Hold out the last 10% for validation (the same samples validation_split=0.1 used for the baseline)
split_idx = int(0.9 * len(x_train))
x_train_split = x_train[:split_idx]
y_train_split = y_train_cat[:split_idx]
x_val_split = x_train[split_idx:]
y_val_split = y_train_cat[split_idx:]

print(f"Training samples: {len(x_train_split):,}")
print(f"Validation samples: {len(x_val_split):,}")
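With a batch size of 128, one pass over this split is 54,000 / 128 = 421.875 batches. Flooring that number (`len(x) // 128`) leaves samples out of every pass. Ceiling it keeps the final partial batch, which matches how `len()` of a Keras array iterator is computed. The arithmetic, assuming the 54,000 / 6,000 split printed above:

```python
import math

batch_size = 128
n_train, n_val = 54_000, 6_000  # 90/10 split of the 60,000 Fashion-MNIST training images

floor_steps = n_train // batch_size           # full batches only
ceil_steps = math.ceil(n_train / batch_size)  # includes the last, partial batch
dropped = n_train - floor_steps * batch_size  # samples skipped by a floor-based epoch

val_floor = n_val // batch_size
val_dropped = n_val - val_floor * batch_size

print(floor_steps, ceil_steps, dropped)  # 421 422 112
print(val_floor, val_dropped)            # 46 112
```

For validation this matters more than for training: with a shuffled validation generator, a floor-based count evaluates a different 5,888-sample subset each epoch.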

In [None]:
# Train WITH augmentation
print("\n🚀 Training WITH augmentation...\n")

# Create generators (seeded training shuffle; fixed validation order)
train_generator = train_datagen.flow(x_train_split, y_train_split, batch_size=128, seed=42)
val_generator = val_datagen.flow(x_val_split, y_val_split, batch_size=128, shuffle=False)

# No steps_per_epoch / validation_steps: Keras uses len(generator) = ceil(n / 128),
# so every epoch covers all samples, including the final partial batch
augmented_history = augmented_model.fit(
    train_generator,
    epochs=20,
    validation_data=val_generator,
    verbose=1
)

augmented_test_loss, augmented_test_acc = augmented_model.evaluate(x_test, y_test_cat, verbose=0)
print(f"\n📊 Augmented Model Test Accuracy: {augmented_test_acc:.2%}")

---

## Part 5: Compare Results

In [None]:
# Comparison plots
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Accuracy comparison
axes[0].plot(baseline_history.history['accuracy'], 'b-', label='Train (No Aug)', linewidth=2)
axes[0].plot(baseline_history.history['val_accuracy'], 'b--', label='Val (No Aug)', linewidth=2)
axes[0].plot(augmented_history.history['accuracy'], 'g-', label='Train (With Aug)', linewidth=2)
axes[0].plot(augmented_history.history['val_accuracy'], 'g--', label='Val (With Aug)', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Accuracy', fontsize=12)
axes[0].set_title('Accuracy: With vs Without Augmentation', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# Loss comparison
axes[1].plot(baseline_history.history['loss'], 'b-', label='Train (No Aug)', linewidth=2)
axes[1].plot(baseline_history.history['val_loss'], 'b--', label='Val (No Aug)', linewidth=2)
axes[1].plot(augmented_history.history['loss'], 'g-', label='Train (With Aug)', linewidth=2)
axes[1].plot(augmented_history.history['val_loss'], 'g--', label='Val (With Aug)', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Loss', fontsize=12)
axes[1].set_title('Loss: With vs Without Augmentation', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)

plt.suptitle('Data Augmentation Impact', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Final comparison
import pandas as pd

comparison_data = {
    'Model': ['Without Augmentation', 'With Augmentation'],
    'Test Accuracy': [f"{baseline_test_acc:.2%}", f"{augmented_test_acc:.2%}"],
    'Test Loss': [f"{baseline_test_loss:.4f}", f"{augmented_test_loss:.4f}"],
    'Final Train Acc': [
        f"{baseline_history.history['accuracy'][-1]:.2%}",
        f"{augmented_history.history['accuracy'][-1]:.2%}"
    ],
    'Final Val Acc': [
        f"{baseline_history.history['val_accuracy'][-1]:.2%}",
        f"{augmented_history.history['val_accuracy'][-1]:.2%}"
    ],
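    # Caveat: the augmented model's train accuracy is measured on augmented (harder) batches,
    # so part of its smaller gap comes from that, not only from better generalization
    'Overfitting Gap': [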
        f"{baseline_history.history['accuracy'][-1] - baseline_history.history['val_accuracy'][-1]:.2%}",
        f"{augmented_history.history['accuracy'][-1] - augmented_history.history['val_accuracy'][-1]:.2%}"
    ]
}

df_comparison = pd.DataFrame(comparison_data)

print("\n" + "="*90)
print("DATA AUGMENTATION - FINAL COMPARISON")
print("="*90)
print(df_comparison.to_string(index=False))
print("="*90)

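# Calculate improvement and check the gaps instead of assuming they shrank
improvement = augmented_test_acc - baseline_test_acc
baseline_gap = baseline_history.history['accuracy'][-1] - baseline_history.history['val_accuracy'][-1]
augmented_gap = augmented_history.history['accuracy'][-1] - augmented_history.history['val_accuracy'][-1]
print(f"\n🎯 Test Accuracy Improvement: {improvement:+.2%}")

if improvement > 0:
    print("\n✅ Augmentation improved test accuracy on this run")
else:
    print("\n⚠️ No test-accuracy gain on this run: augmentation helps most when data is limited")

if augmented_gap < baseline_gap:
    print("✅ Train/val gap is smaller with augmentation")
else:
    print("⚠️ Train/val gap did not shrink with augmentation")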

---

## Part 6: When to Use Which Augmentation?

Not all augmentations work for all datasets!

In [None]:
# Visualize recommended augmentations on three different classes
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

# First training example of T-shirt/top (0), Trouser (1) and Pullover (2)
indices = [int(np.where(y_train == c)[0][0]) for c in (0, 1, 2)]
aug_types = [
    ('Original', None),
    ('Rotation', ImageDataGenerator(rotation_range=30)),
    ('Shift', ImageDataGenerator(width_shift_range=0.2, height_shift_range=0.2)),
    ('Random flip', ImageDataGenerator(horizontal_flip=True))
]

for i, idx in enumerate(indices):
    img = x_train[idx:idx+1]
    label = class_names[y_train[idx]]
    
    for j, (aug_name, aug_gen) in enumerate(aug_types):
        if aug_gen is None:
            axes[i, j].imshow(img[0].reshape(28, 28), cmap='gray')
        else:
            aug_img = next(aug_gen.flow(img, batch_size=1))[0]
            axes[i, j].imshow(aug_img.reshape(28, 28), cmap='gray')
        
        if i == 0:
            axes[i, j].set_title(aug_name, fontsize=12, fontweight='bold')
        
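        if j == 0:
            axes[i, j].set_ylabel(label, fontsize=11, fontweight='bold')

        # Hide ticks but keep the axes so the row label (ylabel) stays visible;
        # axis('off') would hide the ylabel too
        axes[i, j].set_xticks([])
        axes[i, j].set_yticks([])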

plt.suptitle('Data Augmentation Examples on Fashion-MNIST', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n📋 Augmentation Guidelines:")
print("\n✅ GOOD for Fashion-MNIST:")
print("  • Rotation (±15-20°) - clothes can be at angles")
print("  • Shifting (±10-15%) - items not always centered")
print("  • Zoom (±10-15%) - different sizes")
print("  • Horizontal flip - a mirrored shirt or shoe is still a valid item")
print("\n❌ BAD for Fashion-MNIST:")
print("  • Vertical flip - upside-down clothes don't make sense")
print("  • Extreme rotation (±45°+) - unrealistic")
print("  • Color jitter - not applicable, images are grayscale")

print("\n📋 For other datasets:")
print("\n• MNIST Digits:")
print("  ✅ Small rotation (±10°)")
print("  ✅ Small shift")
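print("  ❌ NO flipping (mirrored digits are not valid digits)")
print("  ❌ NO large rotations (a 180° turn makes a 6 look like a 9!)")
print("\n• Natural Images (cats, dogs):")
print("  ✅ Rotation, shift, zoom, horizontal flip (vertical flip is usually unrealistic)")
print("  ✅ Brightness, contrast changes")
print("  ✅ Color jitter")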
print("\n• Medical Images (X-rays):")
print("  ⚠️ Careful with augmentation")
print("  ✅ Small rotation, shift")
print("  ❌ Usually NO flip (left/right matters)")
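One way to keep these guidelines next to the code is a per-domain dictionary of `ImageDataGenerator` keyword arguments. The values below are illustrative starting points derived from the guidelines above, not tuned settings. `AUG_PRESETS` and `describe` are hypothetical names.

```python
# Illustrative starting points only; tune per dataset by inspecting augmented samples.
AUG_PRESETS = {
    "fashion_mnist": dict(rotation_range=15, width_shift_range=0.1,
                          height_shift_range=0.1, zoom_range=0.1,
                          horizontal_flip=True, fill_mode="nearest"),
    "mnist_digits": dict(rotation_range=10, width_shift_range=0.1,
                         height_shift_range=0.1, horizontal_flip=False,
                         vertical_flip=False),
    "natural_images": dict(rotation_range=20, width_shift_range=0.15,
                           height_shift_range=0.15, zoom_range=0.15,
                           horizontal_flip=True),
    "medical_xray": dict(rotation_range=5, width_shift_range=0.05,
                         height_shift_range=0.05, horizontal_flip=False,
                         vertical_flip=False),
}

def describe(preset_name):
    """Return a one-line summary of a preset (hypothetical helper)."""
    p = AUG_PRESETS[preset_name]
    flips = "flip" if p.get("horizontal_flip") or p.get("vertical_flip") else "no flip"
    return f"{preset_name}: rotation ±{p['rotation_range']}°, {flips}"

for name in AUG_PRESETS:
    print(describe(name))
```

Each entry can be unpacked into the constructor, e.g. `ImageDataGenerator(**AUG_PRESETS['mnist_digits'])`.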

---

## Summary: Key Takeaways 🎯

### What is Data Augmentation?

- **Transform training images** on-the-fly
- **Infinite variations** from limited data
- **Only apply to training** (not validation/test)
- **Cheap regularization** - no additional data collection, only extra compute

### Common Augmentation Techniques:

1. **✅ Rotation** - Rotate image by random angle
2. **✅ Translation** - Shift horizontally/vertically
3. **✅ Zoom** - Scale in/out
4. **✅ Flip** - Horizontal/vertical mirroring
5. **✅ Brightness** - Adjust lighting (grayscale or color images)
6. **✅ Contrast** - Change contrast (grayscale or color images)

### ImageDataGenerator Parameters:

```python
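from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    rotation_range=20,        # random angle in [-20°, +20°]
    width_shift_range=0.1,    # horizontal shift up to ±10% of image width
    height_shift_range=0.1,   # vertical shift up to ±10% of image height
    zoom_range=0.1,           # zoom factor in [0.9, 1.1]
    horizontal_flip=True,     # each image mirrored with 50% probability
    fill_mode='nearest'       # pixels shifted in from outside repeat the nearest edge value
)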
```
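To see what the generator draws for one image, here is a NumPy sketch of the per-image random parameters implied by the settings above. It follows the documented ranges, with width and height zoom drawn independently. `sample_transform` is hypothetical and is not Keras's internal code; Keras, for instance, combines these into a single affine transform before resampling.

```python
import numpy as np

def sample_transform(rng, h=28, w=28, rotation_range=20, width_shift_range=0.1,
                     height_shift_range=0.1, zoom_range=0.1, horizontal_flip=True):
    """Draw one set of random augmentation parameters for an h x w image."""
    return {
        "theta_deg": rng.uniform(-rotation_range, rotation_range),
        "shift_x_px": rng.uniform(-width_shift_range, width_shift_range) * w,
        "shift_y_px": rng.uniform(-height_shift_range, height_shift_range) * h,
        # width and height zoom factors are drawn independently
        "zoom_x": rng.uniform(1 - zoom_range, 1 + zoom_range),
        "zoom_y": rng.uniform(1 - zoom_range, 1 + zoom_range),
        "flip": bool(horizontal_flip and rng.random() < 0.5),
    }

rng = np.random.default_rng(3)
for _ in range(3):
    params = sample_transform(rng)
    print({k: round(v, 2) if isinstance(v, float) else v for k, v in params.items()})
```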

### Benefits:

1. **✅ Reduces overfitting** - More diverse training data
2. **✅ Improves test accuracy** - Better generalization
3. **✅ Makes model robust** - Invariant to transformations
4. **✅ No collection cost** - Generated from existing data (only extra compute per epoch)

### Best Practices:

- **Start conservative** - Small augmentation values
- **Check visualizations** - Ensure augmented images make sense
- **Domain-specific** - Different datasets need different augmentations
- **Monitor validation** - Too much augmentation can hurt
- **Combine with other regularization** - Dropout, BatchNorm, etc.

### When to Use:

| Scenario | Typical Benefit (rule of thumb) |
|----------|---------------------|
| Small dataset (<10K images) | ⭐⭐⭐⭐⭐ Very High |
| Medium dataset (10K-100K) | ⭐⭐⭐⭐ High |
| Large dataset (>100K) | ⭐⭐⭐ Moderate |
| Overfitting observed | ⭐⭐⭐⭐⭐ Very High |
| Balanced classes | ⭐⭐⭐⭐ High |
| Imbalanced classes | ⭐⭐⭐⭐⭐ Very High |
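The imbalanced-classes row assumes augmentation is used to generate extra samples for under-represented classes. Augmenting all classes uniformly does not change the class ratio. Below is a minimal NumPy sketch of the bookkeeping step. It computes how many extra augmented copies each class needs to match the largest class and picks which originals to augment. `oversample_indices` is a hypothetical helper; the chosen images would then be passed through an augmentation generator.

```python
import numpy as np

def oversample_indices(labels, rng):
    """Return indices of originals to augment so every class matches the largest one."""
    classes, counts = np.unique(labels, return_counts=True)
    target = counts.max()
    extra = []
    for c, n in zip(classes, counts):
        members = np.flatnonzero(labels == c)
        # Sample with replacement: a small class may need more copies than it has images
        extra.append(rng.choice(members, size=target - n, replace=True))
    return np.concatenate(extra)

rng = np.random.default_rng(0)
labels = np.array([0] * 50 + [1] * 10 + [2] * 5)  # toy imbalanced label vector
idx = oversample_indices(labels, rng)
balanced = np.concatenate([labels, labels[idx]])
print(np.bincount(balanced))  # [50 50 50]
```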

---

## Practice Exercises 📝

1. **Experiment:** Try extreme augmentation (rotation=45, zoom=0.5). What happens?

2. **Challenge:** Augment only the minority classes (imbalanced dataset simulation)

3. **Analysis:** How many training epochs do you need with vs without augmentation?

4. **Custom:** Create your own augmentation pipeline for a specific use case

---

## Next: Notebook 7 - Final Challenge! 🏆

**Coming up:** Combine everything you learned to build the best possible CNN!

---

*⏱️ Time spent: ~30 minutes*  
*💪 Difficulty: Intermediate*  
*🎓 Mastery: Data augmentation techniques*