# 02: Feature Extraction - The Fast Solution

**Course:** 21CSE558T - Deep Neural Network Architectures  
**Module 4:** CNNs & Transfer Learning (Week 12)  
**Estimated Time:** 8-10 minutes  
**Prerequisites:** Notebook 01  
**Goal:** Use transfer learning to achieve 88-92% accuracy

---

## 📚 What You'll Learn

In this notebook, you will:
1. Load ResNet50 pre-trained on ImageNet (1.2M images)
2. Freeze the base model (~23.6M parameters once the ImageNet classifier is removed)
3. Train only the final classifier layer
4. Achieve **88-92% accuracy** with the same ~3,000 images!

**Key Message:** _"Don't reinvent the wheel - borrow knowledge from ImageNet!"_

---

In [None]:
# Setup
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.applications import ResNet50
from tensorflow.keras import Sequential
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense

print(f"✅ TensorFlow: {tf.__version__}")
print("✅ Ready to see transfer learning magic!\n")

## Step 1: Load Same Dataset (TF Flowers)

We'll use the EXACT same dataset as Notebook 01 to prove transfer learning works!

In [None]:
# Load flowers dataset
(train_ds, val_ds), info = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:]'],
    as_supervised=True,
    with_info=True
)

num_classes = info.features['label'].num_classes
class_names = info.features['label'].names

# Preprocess for ResNet50 (224×224 is its standard ImageNet input size)
IMG_SIZE = 224  # ResNet50 default
BATCH_SIZE = 32

def preprocess(image, label):
    # Resize to 224x224
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    
    # IMPORTANT: Use ResNet50's preprocessing function!
    # This applies ImageNet-specific preprocessing (RGB->BGR, zero-centering)
    image = tf.keras.applications.resnet50.preprocess_input(image)
    
    return image, label

train_ds = train_ds.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print("✅ Same 3,000 flowers dataset loaded!")
print(f"   Classes: {class_names}")
print("\n🔑 Using ResNet50-specific preprocessing!")
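
To make the preprocessing step less of a black box, here is a minimal NumPy sketch of what Caffe-style preprocessing (the mode ResNet50 uses) does to a single image, together with its inverse. The helper names `caffe_style_preprocess` and `undo_caffe_style_preprocess` are made up for this illustration and are not part of Keras; the notebook itself should keep calling `preprocess_input`.

```python
import numpy as np

# Per-channel ImageNet means in BGR order, as used by Caffe-style preprocessing.
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def caffe_style_preprocess(rgb):
    """Hypothetical helper: RGB -> BGR, then subtract the channel means."""
    bgr = rgb[..., ::-1].astype(np.float32)
    return bgr - IMAGENET_BGR_MEAN

def undo_caffe_style_preprocess(x):
    """Hypothetical helper: invert the step above so the image can be displayed."""
    bgr = x + IMAGENET_BGR_MEAN
    rgb = bgr[..., ::-1]
    # Round before casting so float error cannot turn 255 into 254.
    return np.clip(np.rint(rgb), 0, 255).astype(np.uint8)

# A tiny fake 2x2 RGB "image" with pixel values in [0, 255].
rgb_image = np.array([[[255, 0, 0], [0, 255, 0]],
                      [[0, 0, 255], [128, 128, 128]]], dtype=np.uint8)

processed = caffe_style_preprocess(rgb_image)
restored = undo_caffe_style_preprocess(processed)

print("Value range after preprocessing:", processed.min(), processed.max())
print("Round trip recovers the image:", np.array_equal(restored, rgb_image))
```

The values are no longer in the 0-255 range after preprocessing, so a preprocessed image has to be converted back like this before it can be plotted with sensible colors.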

## Step 2: Load Pre-trained ResNet50

**The Magic Step!**

We'll load ResNet50 that was trained on ImageNet:
- **Training data:** 1.2 million images, 1,000 categories
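- **Training time:** on the order of 2 weeks on 8 GPUs (ballpark figure; it depends heavily on hardware)
- **Training cost:** roughly $15,000 of compute (rough estimate)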
- **Our cost:** FREE! (We download pre-trained weights)

Watch for the download message...

In [None]:
# Load ResNet50 pre-trained on ImageNet
print("📥 Loading ResNet50 pre-trained weights...")
print("   (This downloads roughly 95 MB - first time only)\n")

base_model = ResNet50(
    weights='imagenet',        # 🔑 KEY: Use ImageNet pre-trained weights!
    include_top=False,         # Remove ImageNet classifier (1000 classes)
    input_shape=(IMG_SIZE, IMG_SIZE, 3)
)

print("\n✅ ResNet50 loaded with ImageNet knowledge!")
print(f"   Total parameters: {base_model.count_params():,}")
print("\n   This model already learned:")
print("   🔸 Edges, textures, colors (early layers)")
print("   🔸 Shapes, patterns (middle layers)")
print("   🔸 Object parts (deep layers)")
print("\n   These features transfer well to many image tasks - including flowers!")

## Step 3: FREEZE the Base Model

**THE MOST IMPORTANT LINE IN TRANSFER LEARNING:**

```python
base_model.trainable = False
```

This freezes all ~23.6M base-model parameters, so training never updates them!

**Why freeze?**
- These features are already strong, general-purpose image features (learned from 1.2M images)
- We only have ~3,000 images, too few to retrain them without overfitting
- Freezing means gradients are only needed for the small classifier, so training is much faster and overfitting is less likely
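
To see what "frozen" means mechanically, here is a toy NumPy sketch (not the notebook's ResNet50 model): a fixed random projection stands in for the pre-trained base, and only a small softmax classifier on top is updated by gradient descent. All sizes, names, and the learning rate are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: 200 samples, 8 raw inputs, 3 classes (all sizes are made up).
X = rng.normal(size=(200, 8))
y = np.argmax(X @ rng.normal(size=(8, 3)), axis=1)

# Stand-in for the pre-trained base: frozen, never updated below.
W_frozen = rng.normal(size=(8, 16))
W_frozen_before = W_frozen.copy()

# Trainable classifier head, initialised at zero.
W_head = np.zeros((16, 3))
b_head = np.zeros(3)

def forward(inputs):
    features = np.maximum(inputs @ W_frozen, 0.0)   # frozen ReLU features
    logits = features @ W_head + b_head
    logits -= logits.max(axis=1, keepdims=True)     # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    return features, probs

def cross_entropy(probs, labels):
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12))

_, probs = forward(X)
loss_before = cross_entropy(probs, y)

learning_rate = 0.01
for _ in range(300):
    features, probs = forward(X)
    grad_logits = probs.copy()
    grad_logits[np.arange(len(y)), y] -= 1.0
    grad_logits /= len(y)
    # Gradients are computed and applied for the head only.
    W_head -= learning_rate * features.T @ grad_logits
    b_head -= learning_rate * grad_logits.sum(axis=0)

_, probs = forward(X)
loss_after = cross_entropy(probs, y)

print(f"Loss before: {loss_before:.3f}, after: {loss_after:.3f}")
print("Base weights unchanged:", np.array_equal(W_frozen, W_frozen_before))
```

The loss drops even though the base weights never change, which is the whole idea of feature extraction: the classifier learns to read features that someone else already trained.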

In [None]:
# FREEZE the base model
base_model.trainable = False  # ❄️ THE KEY LINE!

print("❄️  Base model FROZEN!")
print(f"   Frozen parameters: {base_model.count_params():,}")
print("   These won't update during training.")
print("\n   ✅ We're borrowing ImageNet knowledge, not reinventing it!")

## Step 4: Add Custom Classifier

We'll add our own classifier for 5 flower classes:
```
ResNet50 (frozen, ~23.6M params)
    ↓
GlobalAveragePooling2D
    ↓
Dense(5, softmax)  ← ONLY THIS TRAINS!
```
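
The diagram above can be checked with a little arithmetic. The sketch below assumes the usual shapes for this setup: a 7×7×2048 feature map from ResNet50 for a 224×224 input, and a base-model size of 23,587,712 parameters (the figure Keras typically reports when the ImageNet classifier is removed). Treat both as assumptions and compare them against `model.summary()`.

```python
import numpy as np

# Assumed shapes: a batch of 4 feature maps, each 7x7 with 2048 channels.
batch, height, width, channels = 4, 7, 7, 2048
num_classes = 5

feature_maps = np.random.default_rng(0).random((batch, height, width, channels))

# GlobalAveragePooling2D: average over the two spatial axes.
pooled = feature_maps.mean(axis=(1, 2))   # shape (4, 2048)

# Dense(5): one weight per (feature, class) pair plus one bias per class.
head_params = channels * num_classes + num_classes

# Share of the full model that actually trains (assumed base size).
base_params = 23_587_712
trainable_share = head_params / (base_params + head_params)

print("Pooled shape:", pooled.shape)
print(f"Classifier parameters: {head_params:,}")
print(f"Trainable share of all parameters: {trainable_share:.4%}")
```

So the only trainable part is a classifier of roughly ten thousand parameters sitting on top of a base that is more than two thousand times larger.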

In [None]:
# Build complete model
model = Sequential([
    base_model,                              # Frozen ResNet50
    GlobalAveragePooling2D(),                # Reduce spatial dimensions
    Dense(num_classes, activation='softmax') # Our classifier (5 classes)
], name='ResNet50_FeatureExtraction')

# Compile
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Print summary
model.summary()

# Count trainable vs non-trainable parameters
trainable = sum(int(np.prod(w.shape)) for w in model.trainable_weights)
non_trainable = sum(int(np.prod(w.shape)) for w in model.non_trainable_weights)
total = trainable + non_trainable

print("\n📊 PARAMETER COUNT:")
print("="*50)
print(f"❄️  Frozen (non-trainable): {non_trainable:,}")
print(f"🔥 Training (trainable):   {trainable:,}")
# Percentage is relative to ALL parameters, not just the frozen ones
print(f"📉 We're training only {trainable/total*100:.3f}% of params!")
print("="*50)
print("\n✅ Model ready! Let's train ONLY the classifier.")

## Step 5: Train (Watch the Magic!)

**Prediction:** 
- Notebook 01 (from scratch): 45-55% accuracy ❌
- Notebook 02 (transfer learning): 88-92% accuracy ✅

**Watch for:**
- Epoch 1: Already ~75% accuracy (much higher than scratch!)
- Epoch 5: Reaches 88-92% accuracy
- Training time: ~3 minutes (about the same as the 2-3 minutes from scratch, but with MUCH better accuracy!)

In [None]:
# Train!
print("🚀 Training with transfer learning...\n")
print("🎯 Watch the validation accuracy - should start HIGH!\n")

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5,  # Only need 5 epochs!
    verbose=1
)

print("\n✅ Training complete!")
print("\n🎉 THE MAGIC HAPPENED! Let's analyze...")

## Step 6: Compare Results

Let's compare Notebook 01 (scratch) vs Notebook 02 (transfer learning)

In [None]:
# Plot comparison
plt.figure(figsize=(14, 5))

# Accuracy
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Acc', marker='o', linewidth=2)
plt.plot(history.history['val_accuracy'], label='Validation Acc', marker='s', linewidth=2)
plt.axhline(y=0.50, color='red', linestyle='--', label='Notebook 01 Result (50%)', alpha=0.7)
plt.title('Transfer Learning: Much Better Accuracy!', fontsize=14, fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True, alpha=0.3)

# Loss
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss', marker='o', linewidth=2)
plt.plot(history.history['val_loss'], label='Validation Loss', marker='s', linewidth=2)
plt.title('Transfer Learning: Lower Loss', fontsize=14, fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print results
final_val_acc = history.history['val_accuracy'][-1]

print("\n" + "="*70)
print("🎯 FINAL COMPARISON")
print("="*70)
print(f"\n{'Method':<30} {'Val Accuracy':<20} {'Improvement'}")
print("-"*70)
print(f"{'Notebook 01 (Scratch)':<30} {'~50%':<20} {'Baseline'}")
print(f"{'Notebook 02 (Transfer)':<30} {f'{final_val_acc:.1%}':<20} {'+' + f'{(final_val_acc - 0.50)*100:.0f}'} percentage points! ✨")
print("-"*70)
print(f"\n🎉 We achieved {final_val_acc:.1%} accuracy with SAME 3,000 images!")
print(f"   That's a {(final_val_acc/0.50 - 1)*100:.0f}% relative improvement!")
print("\n✨ THAT is the power of transfer learning!")
print("="*70)

## Step 7: Visualize Predictions

Let's see our model in action!

In [None]:
# Get predictions on validation set
plt.figure(figsize=(15, 10))
for i, (image_batch, label_batch) in enumerate(val_ds.take(1)):
    predictions = model.predict(image_batch)
    
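    for j in range(min(9, len(image_batch))):
        plt.subplot(3, 3, j + 1)
        # Undo ResNet50 preprocessing for display: add the ImageNet means back
        # (BGR order), flip BGR -> RGB, then clip to valid pixel values
        img = image_batch[j].numpy() + np.array([103.939, 116.779, 123.68])
        plt.imshow(np.clip(img[..., ::-1], 0, 255).astype('uint8'))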
        
        true_label = class_names[label_batch[j].numpy()]
        pred_label = class_names[np.argmax(predictions[j])]
        confidence = np.max(predictions[j]) * 100
        
        color = 'green' if true_label == pred_label else 'red'
        plt.title(f"True: {true_label}\nPred: {pred_label} ({confidence:.0f}%)", 
                 color=color, fontsize=10)
        plt.axis('off')

plt.suptitle('Sample Predictions (Green = Correct, Red = Wrong)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n✅ Model making confident, accurate predictions!")

---

## 🎓 Summary: What You Learned

### The Solution:
1. ✅ **Load pre-trained model:** ResNet50 with ImageNet weights
2. ✅ **Freeze base model:** Don't update ~23.6M parameters
3. ✅ **Add custom classifier:** Train only final layer (~10K params)
4. ✅ **Train fast:** 5 epochs, ~3 minutes
5. ✅ **Achieve great accuracy:** 88-92% vs 45-55% from scratch!

### The Impact:
| Metric | Scratch (Notebook 01) | Transfer (Notebook 02) | Improvement |
|--------|-----------------------|------------------------|-------------|
| Validation Accuracy | 45-55% | 88-92% | **~+40 percentage points** |
| Overfitting | Severe | Minimal | Much better |
| Training Time | 2-3 min | ~3 min | About the same |
| Data Required | 100K+ images | 3K images | **~30× fewer** |
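
A gain in accuracy can be quoted in two ways that are easy to mix up: percentage points (the plain difference) and relative improvement (the ratio). Here is a small sketch using the approximate midpoints from the table, 50% and 90%, as example inputs; the function name `accuracy_gain` is just for illustration.

```python
def accuracy_gain(baseline, improved):
    """Return (percentage-point gain, relative gain in percent)."""
    points = (improved - baseline) * 100
    relative = (improved / baseline - 1) * 100
    return points, relative

points, relative = accuracy_gain(baseline=0.50, improved=0.90)
print(f"+{points:.0f} percentage points, {relative:.0f}% relative improvement")
# prints: +40 percentage points, 80% relative improvement
```

Going from 50% to 90% is therefore a gain of 40 percentage points, which corresponds to an 80% relative improvement over the baseline.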

### The Key Insight:
**"ImageNet knowledge transfers to flowers!"**

ResNet50 learned universal features from 1.2M images:
- Edges → Work on flower petals ✅
- Textures → Work on flower patterns ✅
- Shapes → Work on flower structure ✅

We just taught it: "These features = daisy, these = rose, etc."

---

## ✅ Key Takeaways

Before moving to Notebook 03, make sure you understand:

- ✅ **Feature extraction strategy:** Freeze base, train classifier only
- ✅ **Why it works:** Universal features transfer across domains
- ✅ **When to use:** Small datasets (<5K images), limited compute
- ✅ **Code pattern:** `base.trainable = False` is the key line!

**Question:** Can we do EVEN BETTER than 90%?

**Answer:** YES! → Notebook 03: Fine-Tuning

---

## 🚀 Next Steps

**Two options:**

**Option 1: Learn Fine-Tuning (Advanced)**
👉 Open Notebook 03 to learn how to unfreeze top layers and reach 92-95% accuracy

**Option 2: Compare Models**
👉 Open Notebook 04 to compare VGG16, ResNet50, MobileNetV2

**Recommended:** Do both! But Notebook 03 teaches the most valuable skill.

---

**End of Notebook 02**

**Status:** ✅ Feature extraction mastered!

**Achievement Unlocked:** 🏆 90% accuracy with small dataset

**Time spent:** ~8-10 minutes

**Next:** Notebook 03 - Fine-Tuning (optional but powerful!) 🔥