# Lesson 5: EfficientNet-B0 Transfer Learning for Flower Classification

## Overview
Learn transfer learning with EfficientNet-B0, a highly efficient neural architecture that achieves excellent performance with fewer parameters than traditional CNNs. This lesson demonstrates how modern efficient architectures revolutionize the accuracy-vs-efficiency trade-off.

### Learning Objectives
- Understand EfficientNet architecture and compound scaling
- Implement transfer learning with efficient architectures  
- Compare efficiency metrics: accuracy vs parameters vs speed
- Analyze the trade-offs between model complexity and performance

### Model Quick Facts
- **Architecture**: EfficientNet-B0 (efficient CNN with ~5.3M parameters including its 1000-class ImageNet head; ~4.1M after swapping in a 102-class head)
- **Pre-training**: ImageNet dataset (1.2M images, 1000 classes)
- **Key Innovation**: Compound scaling + MBConv blocks + SE attention
- **Transfer Method**: Feature extraction + fine-tuning
- **Expected Performance**: ~90%+ accuracy on Flowers102 (best yet!)
- **Efficiency**: ~2.2× fewer parameters than ResNet18 (5.3M vs 11.7M), with ~5 percentage points higher expected accuracy


## Step 1: Environment Setup and Library Imports

### Why This Step Matters
Setting up the environment correctly is crucial for:
- **Reproducibility**: Ensuring consistent results across different runs
- **Performance**: Optimizing GPU usage and memory management  
- **Compatibility**: EfficientNet-B0 is only bundled with recent torchvision releases (0.11 and later)

### Key Libraries Explained
- **torch**: Core PyTorch library (tensors, automatic differentiation, neural networks)
- **torchvision**: Computer vision utilities (datasets, transforms, pre-trained models)
- **models**: Pre-trained model architectures (EfficientNet, ResNet, etc.)
- **optim**: Optimization algorithms (SGD, Adam, AdamW)
- **DataLoader**: Efficient batch processing and parallel data loading
- **tqdm**: Progress bars for training loops
- **matplotlib**: Data visualization and plotting
- **sklearn**: Machine learning utilities (metrics, confusion matrix)

### EfficientNet Compatibility
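`torchvision.models.efficientnet_b0` was added in torchvision 0.11 (paired with PyTorch 1.10). Its Swish/SiLU activation (`nn.SiLU`) has been available since PyTorch 1.7. From torchvision 0.13 onward, pretrained weights are selected with the `weights=` argument, and `pretrained=` is deprecated.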


In [None]:
# Core PyTorch libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Computer vision utilities
import torchvision
import torchvision.transforms as transforms
from torchvision import models

# Data handling and visualization
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import copy

# Machine learning utilities
from sklearn.metrics import confusion_matrix, classification_report
import warnings
warnings.filterwarnings('ignore')

# Configure matplotlib for high-quality plots.
# Apply the base style FIRST: style.use('default') resets rcParams, so any
# tweaks made before it would be silently overwritten.
plt.style.use('default')
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 10

# torch.backends.mps only exists on PyTorch >= 1.12; probe it defensively so
# older installs report False instead of crashing with AttributeError.
_mps_backend = getattr(torch.backends, "mps", None)
_mps_available = _mps_backend is not None and _mps_backend.is_available()

print("✅ Libraries imported successfully!")
print(f"📦 PyTorch version: {torch.__version__}")
print(f"🖼️ Torchvision version: {torchvision.__version__}")
print(f"🔥 CUDA available: {torch.cuda.is_available()}")
print(f"🍎 MPS available: {_mps_available}")

# Check EfficientNet availability. Calling the builder with no arguments builds
# an untrained model on every torchvision release that ships EfficientNet, and
# avoids the `pretrained=` keyword that torchvision >= 0.13 deprecates.
try:
    test_model = models.efficientnet_b0()
except Exception as e:  # broad on purpose: this is a capability probe
    print(f"❌ EfficientNet not available: {e}")
    print("💡 Please update PyTorch/torchvision to latest version")
else:
    print("✅ EfficientNet-B0 available!")
    del test_model  # Clean up: the probe model is not used again


In [None]:
# Define data transformations.
# ImageNet channel statistics: the pretrained backbone expects inputs
# normalized the same way as during its ImageNet training.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: resize to 256x256, then take a random 224x224 crop plus a random
# horizontal flip for augmentation.
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation (validation and test): use the same 256x256 resize as training and
# a deterministic center 224x224 crop, so objects appear at the same scale the
# model was trained on. Resizing straight to 224x224 would shrink them by ~12%
# relative to the training crops.
val_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Load datasets. Only the training split is created with download=True; the
# validation and test splits reuse the files already present under ./data.
print("📁 Loading Flowers102 dataset...")
try:
    train_dataset = torchvision.datasets.Flowers102(
        root='./data',
        split='train',
        transform=train_transforms,
        download=True
    )

    val_dataset = torchvision.datasets.Flowers102(
        root='./data',
        split='val',
        transform=val_transforms,
        download=False
    )

    test_dataset = torchvision.datasets.Flowers102(
        root='./data',
        split='test',
        transform=val_transforms,
        download=False
    )
except Exception as e:
    print(f"❌ Error loading dataset: {e}")
    print("💡 Make sure you have internet connection for first download")
    # Re-raise: every later cell needs these datasets, so continuing would only
    # surface as a confusing NameError further down.
    raise
else:
    print(f"✅ Dataset loaded successfully!")
    print(f"📊 Training images: {len(train_dataset)}")
    print(f"📊 Validation images: {len(val_dataset)}")
    print(f"📊 Test images: {len(test_dataset)}")

# Create data loaders
BATCH_SIZE = 32


def _make_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader using the settings shared by every split.

    Pinned host memory is requested only when CUDA is present, since it speeds
    up host-to-GPU copies.
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )


# Training order is reshuffled each epoch; evaluation splits keep a fixed order.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

print(f"🔄 Data loaders created with batch size: {BATCH_SIZE}")
print(f"📦 Training batches: {len(train_loader)}")
print(f"📦 Validation batches: {len(val_loader)}")
print(f"📦 Test batches: {len(test_loader)}")


In [None]:
# Device configuration: prefer CUDA, then Apple-silicon MPS, then CPU.
# torch.backends.mps only exists on PyTorch >= 1.12, so probe it defensively.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"🔧 Using device: {device}")

# Initialize EfficientNet-B0 with ImageNet weights.
print("🏗️ Initializing EfficientNet-B0 model...")
_weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
if _weights_enum is not None:
    # torchvision >= 0.13: explicit weights enum (`pretrained=` is deprecated).
    model = models.efficientnet_b0(weights=_weights_enum.IMAGENET1K_V1)
else:
    # Older torchvision has no weights enums; the legacy flag loads the same checkpoint.
    model = models.efficientnet_b0(pretrained=True)

# Print model information
print(f"📊 Model loaded with pre-trained ImageNet weights")
print(f"🔢 Original classifier input features: {model.classifier[1].in_features}")
print(f"🔢 Original classifier output classes: {model.classifier[1].out_features}")

# Replace the ImageNet head with one sized for the flower classes. The input
# width is read from the original head before it is overwritten.
NUM_CLASSES = 102
num_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(num_features, NUM_CLASSES),
)

# Move model to device
model = model.to(device)

# Count parameters
def count_parameters(model, trainable_only=True):
    """Return the number of scalar parameters in ``model``.

    Args:
        model: any ``torch.nn.Module``.
        trainable_only: when True (the default, preserving the original
            behaviour), count only parameters with ``requires_grad=True``.
            When False, count every parameter, frozen or not.

    Returns:
        The summed ``numel()`` of the selected parameters (0 for a module
        without parameters).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

# Headline numbers for the adapted model.
total_params = count_parameters(model)
model_size_mb = total_params * 4 / 1024 / 1024  # 4 bytes per float32 parameter

print(f"\n📈 Model Architecture Summary:")
print(f"   Total parameters: {total_params:,}")
print(f"   Model size: ~{model_size_mb:.1f}MB")

# Split the count between the convolutional backbone and the new classifier head.
feature_params, classifier_params = (
    sum(p.numel() for p in part.parameters())
    for part in (model.features, model.classifier)
)
print(f"\n🏗️ Model Structure:")
print(f"   Features: {feature_params:,} parameters")
print(f"   Classifier: {classifier_params:,} parameters")

# Reference figures for comparison. The parameter counts are the published sizes
# of each architecture with its 1000-class ImageNet head, so the EfficientNet-B0
# figure is larger than the 102-class total printed above.
_REFERENCE_MODELS = (
    ("ResNet18", "~11.7M", "~85%"),
    ("ResNet50", "~25.6M", "~88%"),
    ("EfficientNet-B0", "~5.3M", "~90%"),
)
print(f"\n⚡ Efficiency Comparison:")
for _ref_name, _ref_params, _ref_acc in _REFERENCE_MODELS:
    print(f"   {_ref_name}: {_ref_params} params, {_ref_acc} expected accuracy")
print(f"   Efficiency ratio: {90/5.3:.1f}% per M params (EfficientNet-B0)")


In [None]:
# Training configuration
NUM_EPOCHS = 50
PHASE1_EPOCHS = 20  # head-only epochs; Phase 2 (full fine-tuning) runs for the rest
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.01

if not 0 < PHASE1_EPOCHS < NUM_EPOCHS:
    raise ValueError(
        f"PHASE1_EPOCHS ({PHASE1_EPOCHS}) must be between 1 and NUM_EPOCHS - 1 ({NUM_EPOCHS - 1})"
    )

# Phase 1: Feature Extraction (freeze backbone)
print("🔒 Phase 1: Feature Extraction Setup")
print("   Freezing feature layers...")

# Freeze all feature layers so only the new head learns.
# NOTE: requires_grad=False stops weight updates, but BatchNorm layers inside
# model.features still update their running statistics whenever the model is
# in train() mode.
for param in model.features.parameters():
    param.requires_grad = False

# Keep classifier trainable
for param in model.classifier.parameters():
    param.requires_grad = True

# Count trainable parameters for Phase 1
trainable_params_phase1 = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"   Trainable parameters in Phase 1: {trainable_params_phase1:,}")

# Loss function and optimizer for Phase 1 (only the head's parameters are optimized)
criterion = nn.CrossEntropyLoss()
optimizer_phase1 = optim.AdamW(model.classifier.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

print(f"✅ Phase 1 configuration complete")
print(f"   Loss function: CrossEntropyLoss")
print(f"   Optimizer: AdamW (lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY})")
print(f"   Training epochs: 1-{PHASE1_EPOCHS}")

# Prepare for Phase 2 setup (will be used later). Epoch boundaries are derived
# from PHASE1_EPOCHS / NUM_EPOCHS so the printed plan cannot drift from them.
print(f"\n🔓 Phase 2: Fine-tuning Setup (will be activated at epoch {PHASE1_EPOCHS + 1})")
print(f"   Will unfreeze all layers")
print(f"   Will train entire model with same learning rate")
print(f"   Training epochs: {PHASE1_EPOCHS + 1}-{NUM_EPOCHS}")

# Training tracking variables (one entry appended per epoch by the training loop)
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
epoch_times = []

print(f"\n📊 Training tracking initialized")
print(f"   Metrics: loss, accuracy, training time")
print(f"   Total epochs: {NUM_EPOCHS}")
print(f"   Expected training time: ~10-15 minutes")


In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one full pass over ``train_loader``.

    Args:
        model: the network to train; it is switched to ``train()`` mode.
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
            Assumed to return the batch *mean* (the default reduction of
            ``nn.CrossEntropyLoss``).
        optimizer: optimizer over the parameters to update.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean loss per sample over the epoch
        and top-1 accuracy in percent.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0  # sum of per-sample losses across the epoch
    correct = 0
    total = 0

    progress_bar = tqdm(train_loader, desc="Training", leave=False)

    for inputs, labels in progress_bar:
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward / backward / update
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Statistics. Weight each batch's mean loss by its size so a short
        # final batch does not skew the epoch average.
        batch_size = labels.size(0)
        batch_loss = loss.item()
        running_loss += batch_loss * batch_size
        predicted = outputs.argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

        progress_bar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })

    if total == 0:
        raise ValueError("train_loader yielded no samples")

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

def evaluate(model, val_loader, criterion, device):
    """Evaluate ``model`` on every batch of ``val_loader`` without updating weights.

    Usable for any held-out split (validation or test).

    Args:
        model: the network to evaluate; it is switched to ``eval()`` mode and
            left there.
        val_loader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
            Assumed to return the batch *mean* (the default reduction of
            ``nn.CrossEntropyLoss``).
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean loss per sample and top-1
        accuracy in percent.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()  # disable dropout; BatchNorm uses its running statistics
    running_loss = 0.0  # sum of per-sample losses
    correct = 0
    total = 0

    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Validation", leave=False)

        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Weight by batch size so a short final batch does not skew the mean.
            batch_size = labels.size(0)
            batch_loss = loss.item()
            running_loss += batch_loss * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

            progress_bar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

    if total == 0:
        raise ValueError("val_loader yielded no samples")

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

# Announce the helpers defined in this cell (one line per entry).
print("\n".join([
    "✅ Training and evaluation functions defined",
    "   train_epoch(): Trains model for one epoch with progress tracking",
    "   evaluate(): Evaluates model on validation set with accuracy calculation",
]))


In [None]:
# Core PyTorch libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Computer vision utilities
import torchvision
import torchvision.transforms as transforms
from torchvision import models

# Data handling and visualization
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import copy

# Machine learning utilities
from sklearn.metrics import confusion_matrix, classification_report
import warnings
warnings.filterwarnings('ignore')

# Configure matplotlib for high-quality plots.
# Apply the base style FIRST: style.use('default') resets rcParams, so any
# tweaks made before it would be silently overwritten.
plt.style.use('default')
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 10

# torch.backends.mps only exists on PyTorch >= 1.12; probe it defensively so
# older installs report False instead of crashing with AttributeError.
_mps_backend = getattr(torch.backends, "mps", None)
_mps_available = _mps_backend is not None and _mps_backend.is_available()

print("✅ Libraries imported successfully!")
print(f"📦 PyTorch version: {torch.__version__}")
print(f"🖼️ Torchvision version: {torchvision.__version__}")
print(f"🔥 CUDA available: {torch.cuda.is_available()}")
print(f"🍎 MPS available: {_mps_available}")

# Check EfficientNet availability. Calling the builder with no arguments builds
# an untrained model on every torchvision release that ships EfficientNet, and
# avoids the `pretrained=` keyword that torchvision >= 0.13 deprecates.
try:
    test_model = models.efficientnet_b0()
except Exception as e:  # broad on purpose: this is a capability probe
    print(f"❌ EfficientNet not available: {e}")
    print("💡 Please update PyTorch/torchvision to latest version")
else:
    print("✅ EfficientNet-B0 available!")
    del test_model  # Clean up: free the probe model instead of keeping it alive
