# Lesson 3: ResNet18 Transfer Learning for Flower Classification

## Overview
Learn transfer learning with ResNet18 on the Flowers102 dataset. This lesson demonstrates how pre-trained models can be adapted for new classification tasks with minimal training time.

### Learning Objectives
- Understand ResNet18 architecture and residual connections
- Implement transfer learning with pre-trained weights
- Use progressive training strategy (freeze → fine-tune)
- Evaluate model performance and analyze results

### Model Quick Facts
- **Architecture**: ResNet18 (18 layers, 11.7M parameters with the original 1000-class head)
- **Pre-training**: ImageNet dataset (1.2M images, 1000 classes)
- **Key Innovation**: Residual connections for deep network training
- **Transfer Method**: Feature extraction + fine-tuning
- **Expected Performance**: ~85%+ accuracy on Flowers102 is the target; the run recorded in this lesson peaks at 79.90% validation accuracy


## Step 1: Environment Setup and Library Imports

### Why This Step Matters
Setting up the environment correctly is crucial for:
- **Reproducibility**: Ensuring consistent results across different runs
- **Performance**: Optimizing GPU usage and memory management
- **Debugging**: Clean output without unnecessary warnings

### Key Libraries Explained
- **torch**: Core PyTorch library (tensors, automatic differentiation, neural networks)
- **torchvision**: Computer vision utilities (datasets, transforms, pre-trained models)
- **models**: Pre-trained model architectures (ResNet18, VGG, etc.)
- **optim**: Optimization algorithms (SGD, Adam, AdamW)
- **DataLoader**: Efficient batch processing and parallel data loading
- **tqdm**: Progress bars for training loops
- **matplotlib**: Data visualization and plotting
- **sklearn**: Machine learning utilities (metrics, confusion matrix)

### Configuration Settings
We configure matplotlib for high-quality visualizations and set up proper warning filters for cleaner output during training.


In [2]:
# Core PyTorch libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Computer vision utilities
import torchvision
import torchvision.transforms as transforms
from torchvision import models

# Data handling and visualization
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import copy

# Machine learning utilities
from sklearn.metrics import confusion_matrix, classification_report
import warnings
warnings.filterwarnings('ignore')

# Configure matplotlib for high-quality plots
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 10
plt.style.use('default')

print("✅ Libraries imported successfully!")
print(f"📦 PyTorch version: {torch.__version__}")
print(f"🖼️ Torchvision version: {torchvision.__version__}")
print(f"🔥 CUDA available: {torch.cuda.is_available()}")
print(f"🍎 MPS available: {torch.backends.mps.is_available()}")


✅ Libraries imported successfully!
📦 PyTorch version: 2.2.2
🖼️ Torchvision version: 0.17.2
🔥 CUDA available: False
🍎 MPS available: True


## Step 2: Device Detection and Configuration

### Device Selection Strategy
Transfer learning benefits significantly from GPU acceleration. Our device detection follows this priority:

1. **CUDA GPU** (NVIDIA): Optimal for deep learning training
   - Parallel processing with thousands of cores
   - Large memory capacity for batch processing
   - Highly optimized for matrix operations

2. **MPS (Apple Silicon)**: Apple's Metal Performance Shaders
   - Efficient on M1/M2 chips
   - Unified memory architecture
   - Good performance for development and medium-scale training

3. **CPU**: Universal fallback
   - Works on any system
   - Slower but sufficient for learning purposes

### Training Configuration
We use standardized parameters for consistent comparison across all lessons:
- **Batch Size**: 32 (balances memory usage and gradient quality)
- **Learning Rate**: 0.001 (standard for AdamW optimizer)
- **Epochs**: 50 total (20 frozen + 30 fine-tuning)
- **Optimizer**: AdamW with weight decay


In [3]:
# Device detection with fallback hierarchy
print("🔍 Detecting optimal compute device...")

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"🚀 Using NVIDIA GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("🍎 Using Apple Silicon GPU (MPS)")
    print("   Optimized for M1/M2 chips")
else:
    device = torch.device("cpu")
    print("💻 Using CPU (consider GPU for faster training)")

# Set training configuration
print("\n⚙️ Setting up training configuration...")
config = {
    'batch_size': 32,
    'learning_rate': 0.001,
    'epochs': 50,
    'freeze_epochs': 20,
    'finetune_epochs': 30,
    'num_workers': 2,
    'weight_decay': 0.01
}

print(f"   📦 Batch size: {config['batch_size']}")
print(f"   🎯 Learning rate: {config['learning_rate']}")
print(f"   🔄 Total epochs: {config['epochs']} (freeze: {config['freeze_epochs']}, fine-tune: {config['finetune_epochs']})")
print(f"   👥 Workers: {config['num_workers']}")
print(f"   ⚖️ Weight decay: {config['weight_decay']}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # autotuning can select non-deterministic kernels

print("\n✅ Configuration complete!")


🔍 Detecting optimal compute device...
🍎 Using Apple Silicon GPU (MPS)
   Optimized for M1/M2 chips

⚙️ Setting up training configuration...
   📦 Batch size: 32
   🎯 Learning rate: 0.001
   🔄 Total epochs: 50 (freeze: 20, fine-tune: 30)
   👥 Workers: 2
   ⚖️ Weight decay: 0.01

✅ Configuration complete!


## Step 3: Data Preprocessing and DataLoader Setup

### Data Augmentation Strategy

**Why Augmentation is Critical:**
- **Increases Effective Dataset Size**: Transforms create new variations of existing images
- **Improves Generalization**: Model learns to handle variations in real-world data
- **Reduces Overfitting**: Prevents memorization of specific image characteristics
- **Handles Data Scarcity**: Particularly important for smaller datasets like Flowers102

**Training vs. Validation Transforms:**
- **Training**: Aggressive augmentation for maximum variety and robustness
- **Validation/Test**: Minimal transforms for consistent, reproducible evaluation

### ImageNet Normalization
The pre-trained weights were learned on inputs normalized with ImageNet channel statistics, so the same normalization is applied to every image here:
- **Mean**: [0.485, 0.456, 0.406] for RGB channels
- **Std**: [0.229, 0.224, 0.225] for RGB channels
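
As a minimal sketch of what `transforms.Normalize` does with these statistics, the snippet below applies `(value - mean) / std` to a single RGB pixel in plain Python. The helper names `normalize_pixel` and `denormalize_pixel` are illustrative and not part of torchvision; the inverse is the step one would typically apply before displaying a normalized image.

```python
# ImageNet channel statistics quoted in this lesson (RGB order).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (value - mean) / std per channel to one RGB pixel in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]

def denormalize_pixel(rgb_norm, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert the normalization, e.g. before plotting an image."""
    return [v * s + m for v, m, s in zip(rgb_norm, mean, std)]

white = normalize_pixel([1.0, 1.0, 1.0])  # brightest possible pixel
black = normalize_pixel([0.0, 0.0, 0.0])  # darkest possible pixel
print([round(v, 3) for v in white])
print([round(v, 3) for v in black])
```

A pixel equal to the channel means maps to zero, so after normalization the inputs are roughly centered, which matches what the network saw during ImageNet pre-training.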


In [4]:
print("🔧 Creating data preprocessing pipeline...")

# Training transforms with augmentation
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation transforms (no augmentation)
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("   ✓ Training transforms: 5 augmentations + normalization")
print("   ✓ Validation transforms: resize + normalization only")

# Create datasets
print("\n📦 Loading Flowers102 dataset...")
train_dataset = torchvision.datasets.Flowers102(
    root='./data', split='train', transform=train_transforms, download=True)
val_dataset = torchvision.datasets.Flowers102(
    root='./data', split='val', transform=val_transforms, download=True)
test_dataset = torchvision.datasets.Flowers102(
    root='./data', split='test', transform=val_transforms, download=True)

print(f"   🏋️ Training samples: {len(train_dataset):,}")
print(f"   🔍 Validation samples: {len(val_dataset):,}")
print(f"   📝 Test samples: {len(test_dataset):,}")

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], 
                         shuffle=True, num_workers=config['num_workers'], pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], 
                       shuffle=False, num_workers=config['num_workers'], pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], 
                        shuffle=False, num_workers=config['num_workers'], pin_memory=True)

print(f"\n📊 DataLoader batches: {len(train_loader)} train, {len(val_loader)} val, {len(test_loader)} test")
print("✅ Data pipeline ready!")


🔧 Creating data preprocessing pipeline...
   ✓ Training transforms: 5 augmentations + normalization
   ✓ Validation transforms: resize + normalization only

📦 Loading Flowers102 dataset...
Downloading https://thor.robots.ox.ac.uk/flowers/102/102flowers.tgz to data/flowers-102/102flowers.tgz
100%|██████████| 344862509/344862509 [00:19<00:00, 17583027.18it/s]
Extracting data/flowers-102/102flowers.tgz to data/flowers-102
Downloading https://thor.robots.ox.ac.uk/flowers/102/imagelabels.mat to data/flowers-102/imagelabels.mat
100%|██████████| 502/502 [00:00<00:00, 177877.89it/s]
Downloading https://thor.robots.ox.ac.uk/flowers/102/setid.mat to data/flowers-102/setid.mat
100%|██████████| 14989/14989 [00:00<00:00, 3870255.03it/s]

   🏋️ Training samples: 1,020
   🔍 Validation samples: 1,020
   📝 Test samples: 6,149

📊 DataLoader batches: 32 train, 32 val, 193 test
✅ Data pipeline ready!
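
The batch counts in the output above follow from ceiling division, because `DataLoader` keeps the final partial batch by default (`drop_last=False`). The short sketch below reproduces them from the reported split sizes; `num_batches` is an illustrative helper, not a PyTorch function.

```python
import math

def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a dataset of the given size."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

# Split sizes and batch size as reported in this lesson.
split_sizes = {"train": 1020, "val": 1020, "test": 6149}
batch_size = 32

for split, n in split_sizes.items():
    print(f"{split}: {num_batches(n, batch_size)} batches")

# Size of the final, smaller training batch.
last_train_batch = split_sizes["train"] - (num_batches(1020, batch_size) - 1) * batch_size
print(f"last train batch holds {last_train_batch} images")
```

With only 1,020 training images, each epoch is just 32 optimizer steps, which is one reason the pre-trained features matter so much on this dataset.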





## Step 4: ResNet18 Model Setup and Transfer Learning

### ResNet18 Architecture Overview
ResNet18 is an 18-layer deep convolutional neural network built from residual blocks, the key idea introduced by the ResNet family:

**Key Features:**
- **Residual Blocks**: Skip connections that help training very deep networks
- **Batch Normalization**: Normalizes inputs to each layer for stable training
- **ReLU Activation**: Non-linear activation function
- **Global Average Pooling**: Reduces spatial dimensions before classification
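
To make the skip connection concrete, here is a simplified sketch of a same-shape residual block in PyTorch. It is an illustration under assumptions: the class name `BasicBlockSketch` is made up for this example, and the blocks inside torchvision's ResNet18 additionally handle strides and a downsampling shortcut, which are omitted here.

```python
import torch
import torch.nn as nn

class BasicBlockSketch(nn.Module):
    """Simplified residual block: out = relu(F(x) + x), input and output shapes match."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x                              # the skip path carries x unchanged
        out = self.relu(self.bn1(self.conv1(x)))  # first conv -> BN -> ReLU
        out = self.bn2(self.conv2(out))           # second conv -> BN
        out = out + identity                      # residual addition
        return self.relu(out)

block = BasicBlockSketch(channels=64)
x = torch.randn(2, 64, 56, 56)
y = block(x)
print(y.shape)
```

Because the block only has to learn a correction `F(x)` on top of the identity, gradients can flow through the addition directly, which is what makes deeper stacks trainable.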

### Transfer Learning Strategy

**Phase 1: Feature Extraction (Freeze Backbone)**
- Keep pre-trained weights frozen
- Only train the new classification head
- Fast training, good for small datasets
- Expected: ~75% accuracy after 20 epochs

**Phase 2: Fine-tuning (Unfreeze All Layers)**
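- Unfreeze all layers at once
- Train the entire network (this lesson keeps the Phase 1 learning rate of 0.001; a lower rate is a common, more conservative choice)
- Better adaptation to target domain
- Expected: ~85%+ accuracy after 30 more epochs (the recorded run peaks at 79.90% validation accuracy)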

### Model Modifications
- Replace final layer: 1000 classes → 102 classes
- Keep all other layers with ImageNet weights
- Use the same learning rate (0.001) in both phases; Step 7 discusses lower-rate alternatives


In [10]:
print("🏗️ Setting up ResNet18 model...")

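# Load ResNet18 with ImageNet weights (the `weights=` API replaces the deprecated `pretrained=True`)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)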
print(f"   ✓ Loaded pre-trained ResNet18")
print(f"   📊 Original classes: {model.fc.in_features} → 1000")

# Modify final layer for Flowers102 (102 classes)
num_classes = 102
model.fc = nn.Linear(model.fc.in_features, num_classes)
print(f"   🎯 Modified final layer: {model.fc.in_features} → {num_classes}")

# Move model to device
model = model.to(device)
print(f"   🚀 Model moved to {device}")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"   📈 Total parameters: {total_params:,}")
print(f"   🎯 Trainable parameters: {trainable_params:,}")

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

print(f"\n⚙️ Training setup:")
print(f"   🎯 Loss function: CrossEntropyLoss")
print(f"   🚀 Optimizer: AdamW (lr={config['learning_rate']}, weight_decay={config['weight_decay']})")

# Function to freeze/unfreeze model parameters
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze the backbone (feature_extracting=True) or unfreeze every layer (False)."""
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
        # Only train the classifier
        for param in model.fc.parameters():
            param.requires_grad = True
    else:
        for param in model.parameters():
            param.requires_grad = True

print("✅ Model setup complete!")


🏗️ Setting up ResNet18 model...
   ✓ Loaded pre-trained ResNet18
   📊 Original classes: 512 → 1000
   🎯 Modified final layer: 512 → 102
   🚀 Model moved to mps
   📈 Total parameters: 11,228,838
   🎯 Trainable parameters: 11,228,838

⚙️ Training setup:
   🎯 Loss function: CrossEntropyLoss
   🚀 Optimizer: AdamW (lr=0.001, weight_decay=0.01)
✅ Model setup complete!


## Step 5: Training Functions and Evaluation

### Training Function Design
Our training function implements:
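- **Batch Processing**: Mini-batch gradient descent over the DataLoader
- **Progress Tracking**: Running loss and accuracy shown in a tqdm progress bar
- **Device Placement**: Each batch is moved to the selected device before the forward pass
- **Gradient Reset**: `optimizer.zero_grad()` clears gradients before every batch, so each update uses only the current mini-batch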

### Evaluation Metrics
We track multiple metrics for comprehensive evaluation:
- **Loss**: CrossEntropyLoss for optimization
- **Accuracy**: Top-1 classification accuracy
- **Progress**: Real-time training progress with tqdm
- **Timing**: Training time per epoch for performance analysis
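
The two headline metrics can be sketched in plain Python to show what the numbers in the training logs mean. The function names below are illustrative; in the notebook the same quantities come from `nn.CrossEntropyLoss` and `outputs.max(1)`.

```python
import math

def cross_entropy(row, target):
    """Cross-entropy for one sample: log-sum-exp of the logits minus the target logit."""
    m = max(row)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
    return log_sum_exp - row[target]

def top1_accuracy(logits, targets):
    """Percentage of samples whose highest-scoring class equals the target."""
    correct = 0
    for row, target in zip(logits, targets):
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == target)
    return 100.0 * correct / len(targets)

# A classifier with no preference among 102 classes has loss ln(102).
chance_loss = cross_entropy([0.0] * 102, target=0)
print(f"chance-level loss: {chance_loss:.4f}")

# Toy batch of 4 samples and 3 classes (made-up numbers).
logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.0, 0.9, 0.4], [1.2, 1.1, 0.0]]
targets = [0, 2, 2, 0]
print(f"top-1 accuracy: {top1_accuracy(logits, targets):.2f}%")
```

The chance-level value of about 4.62 is a useful sanity check: a freshly initialized 102-class head should start with a loss in that neighborhood, and the first Phase 1 epoch in this lesson reports a training loss of 4.7374.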

### Two-Phase Training Process
1. **Phase 1 (Epochs 1-20)**: Feature extraction with frozen backbone
2. **Phase 2 (Epochs 21-50)**: Fine-tuning with unfrozen layers


In [11]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train model for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    progress_bar = tqdm(train_loader, desc="Training", leave=False)
    
    for batch_idx, (data, targets) in enumerate(progress_bar):
        data, targets = data.to(device), targets.to(device)
        
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        # Update progress bar
        progress_bar.set_postfix({
            'Loss': f'{running_loss/(batch_idx+1):.3f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })
    
    return running_loss / len(train_loader), 100. * correct / total

def evaluate(model, val_loader, criterion, device):
    """Evaluate model on validation set"""
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Evaluating", leave=False)
        
        for batch_idx, (data, targets) in enumerate(progress_bar):
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            loss = criterion(outputs, targets)
            
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            progress_bar.set_postfix({
                'Loss': f'{val_loss/(batch_idx+1):.3f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })
    
    return val_loss / len(val_loader), 100. * correct / total

print("✅ Training and evaluation functions defined!")


✅ Training and evaluation functions defined!


## Step 6: Phase 1 - Feature Extraction Training

### Feature Extraction Strategy
In the first phase, we freeze the pre-trained backbone and only train the classification head:

**Why Feature Extraction First?**
- **Preserves Pre-trained Features**: Keeps valuable ImageNet features intact
- **Faster Training**: Only 52,326 parameters to train vs. about 11.2M in the full model
- **Stable Learning**: Prevents catastrophic forgetting of pre-trained weights
- **Good Baseline**: Achieves decent performance quickly
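
The trainable-parameter count can be checked by hand. The sketch below reproduces the figures the notebook reports from the size of the new head, and also shows where the 11.7M figure for the stock 1000-class ResNet18 comes from. The variable names are illustrative.

```python
in_features = 512           # width of ResNet18's pooled feature vector (model.fc.in_features)
num_classes = 102           # Flowers102 categories
total_params = 11_228_838   # total reported by the notebook after replacing fc

# A linear layer has one weight per (input, output) pair plus one bias per output.
head_params = in_features * num_classes + num_classes
frozen_params = total_params - head_params

# The original ImageNet head has 1000 outputs instead of 102.
original_head_params = in_features * 1000 + 1000
original_total = frozen_params + original_head_params

print(f"New head parameters:   {head_params:,}")
print(f"Frozen backbone:       {frozen_params:,}")
print(f"Stock ResNet18 total:  {original_total:,}")
print(f"Trainable share:       {100 * head_params / total_params:.2f}%")
```

In Phase 1 well under 1% of the weights receive updates, which is why the head converges within a few epochs even on 1,020 training images.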

**Training Details:**
- **Frozen Layers**: All convolutional and batch normalization weights (BatchNorm running statistics still update, because `train_epoch` puts the model in train mode)
- **Trainable Layers**: Only the final classification layer (fc)
- **Learning Rate**: 0.001 (standard for new layers)
- **Duration**: 20 epochs (sufficient for classifier convergence)
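
Setting `requires_grad = False` freezes BatchNorm weights but not the running mean and variance, which keep updating whenever the model is in train mode. If a fully fixed backbone is wanted, one option is to switch the BatchNorm layers back to eval mode after calling `model.train()`. The sketch below is an assumption about how that could be done, not part of this lesson's training loop; `freeze_batchnorm_stats` and the toy model are hypothetical stand-ins.

```python
import torch
import torch.nn as nn

def freeze_batchnorm_stats(model):
    """Put every BatchNorm layer in eval mode so running statistics stop updating."""
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()

# Stand-in for a backbone plus classification head (the lesson itself uses ResNet18).
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 102),
)

toy.train()                    # what train_epoch() does at the start of each epoch
freeze_batchnorm_stats(toy)    # then BatchNorm alone is switched back to eval

stats_before = toy[1].running_mean.clone()
toy(torch.randn(4, 3, 16, 16))  # forward pass in (mostly) train mode
stats_after = toy[1].running_mean

print(toy[1].training, toy[5].training)
print(torch.equal(stats_before, stats_after))
```

The helper has to be called after every `model.train()` call, since `train()` resets all submodules to training mode.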


In [12]:
print("🎯 Phase 1: Feature Extraction Training")
print("="*50)

# Freeze backbone, only train classifier
set_parameter_requires_grad(model, feature_extracting=True)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"   🔒 Frozen parameters: {total_params - trainable_params:,}")
print(f"   🎯 Trainable parameters: {trainable_params:,}")

# Training tracking
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

print(f"\n🚀 Starting Phase 1 training ({config['freeze_epochs']} epochs)...")
phase1_start = time.time()

best_val_acc = 0.0
best_model_wts = copy.deepcopy(model.state_dict())

for epoch in range(config['freeze_epochs']):
    epoch_start = time.time()
    
    # Training
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validation
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())
    
    # Record metrics
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    
    epoch_time = time.time() - epoch_start
    
    print(f"Epoch {epoch+1:2d}/{config['freeze_epochs']} | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | "
          f"Time: {epoch_time:.1f}s")

phase1_time = time.time() - phase1_start

print(f"\n📊 Phase 1 Results:")
print(f"   ⏱️  Training time: {phase1_time:.1f}s ({phase1_time/60:.1f}m)")
print(f"   🎯 Best validation accuracy: {best_val_acc:.2f}%")
print(f"   📈 Final training accuracy: {train_accuracies[-1]:.2f}%")
print(f"   📉 Final validation loss: {val_losses[-1]:.4f}")

# Load best model weights
model.load_state_dict(best_model_wts)
print("✅ Phase 1 complete! Best model weights loaded.")


🎯 Phase 1: Feature Extraction Training
   🔒 Frozen parameters: 11,176,512
   🎯 Trainable parameters: 52,326

🚀 Starting Phase 1 training (20 epochs)...

Epoch  1/20 | Train Loss: 4.7374 | Train Acc: 3.14% | Val Loss: 4.0240 | Val Acc: 13.04% | Time: 34.6s
Epoch  2/20 | Train Loss: 3.6848 | Train Acc: 24.51% | Val Loss: 3.2799 | Val Acc: 40.00% | Time: 34.5s
Epoch  3/20 | Train Loss: 2.9507 | Train Acc: 50.10% | Val Loss: 2.7430 | Val Acc: 52.06% | Time: 36.2s
Epoch  4/20 | Train Loss: 2.3572 | Train Acc: 64.71% | Val Loss: 2.3180 | Val Acc: 57.16% | Time: 34.3s
Epoch  5/20 | Train Loss: 1.9245 | Train Acc: 76.76% | Val Loss: 1.9695 | Val Acc: 65.10% | Time: 35.4s
Epoch  6/20 | Train Loss: 1.6206 | Train Acc: 80.59% | Val Loss: 1.7696 | Val Acc: 67.25% | Time: 34.8s
Epoch  7/20 | Train Loss: 1.3733 | Train Acc: 81.86% | Val Loss: 1.5568 | Val Acc: 71.37% | Time: 34.1s
Epoch  8/20 | Train Loss: 1.1968 | Train Acc: 84.71% | Val Loss: 1.4594 | Val Acc: 71.47% | Time: 34.6s
Epoch  9/20 | Train Loss: 1.0417 | Train Acc: 87.65% | Val Loss: 1.3446 | Val Acc: 73.92% | Time: 35.0s
Epoch 10/20 | Train Loss: 0.9244 | Train Acc: 87.94% | Val Loss: 1.2627 | Val Acc: 74.12% | Time: 35.0s
Epoch 11/20 | Train Loss: 0.8165 | Train Acc: 91.18% | Val Loss: 1.2062 | Val Acc: 75.10% | Time: 34.6s
Epoch 12/20 | Train Loss: 0.7405 | Train Acc: 91.76% | Val Loss: 1.1422 | Val Acc: 75.69% | Time: 34.7s
Epoch 13/20 | Train Loss: 0.6826 | Train Acc: 93.24% | Val Loss: 1.0926 | Val Acc: 76.47% | Time: 34.9s
Epoch 14/20 | Train Loss: 0.6244 | Train Acc: 93.14% | Val Loss: 1.0665 | Val Acc: 77.35% | Time: 34.7s
Epoch 15/20 | Train Loss: 0.5694 | Train Acc: 94.80% | Val Loss: 1.0419 | Val Acc: 77.55% | Time: 34.5s
Epoch 16/20 | Train Loss: 0.5238 | Train Acc: 94.41% | Val Loss: 1.0391 | Val Acc: 76.27% | Time: 33.6s
Epoch 17/20 | Train Loss: 0.4786 | Train Acc: 95.88% | Val Loss: 1.0039 | Val Acc: 77.84% | Time: 34.1s
Epoch 18/20 | Train Loss: 0.4695 | Train Acc: 94.22% | Val Loss: 0.9707 | Val Acc: 77.65% | Time: 33.9s
Epoch 19/20 | Train Loss: 0.4249 | Train Acc: 95.59% | Val Loss: 0.9550 | Val Acc: 77.75% | Time: 34.5s
Epoch 20/20 | Train Loss: 0.3979 | Train Acc: 96.67% | Val Loss: 0.9296 | Val Acc: 77.65% | Time: 35.7s

📊 Phase 1 Results:
   ⏱️  Training time: 693.8s (11.6m)
   🎯 Best validation accuracy: 77.84%
   📈 Final training accuracy: 96.67%
   📉 Final validation loss: 0.9296
✅ Phase 1 complete! Best model weights loaded.




## Step 7: Phase 2 - Fine-tuning Training

### Fine-tuning Strategy
In the second phase, we unfreeze all layers and train the entire network:

**Why Fine-tuning After Feature Extraction?**
- **Better Adaptation**: Allows all layers to adapt to the new flower domain
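- **Higher Performance**: Often improves on feature extraction alone; in the recorded run the best validation accuracy rises from 77.84% to 79.90%
- **Stable Foundation**: Phase 1 provides a good starting point
- **Controlled Learning**: A lower learning rate limits catastrophic forgetting (this lesson keeps 0.001; see Learning Rate Considerations)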

**Training Details:**
- **Unfrozen Layers**: All 11,228,838 (~11.2M) parameters are now trainable
- **Learning Rate**: 0.001 (same as Phase 1, but now for entire network)
- **Duration**: 30 epochs (longer for full network convergence)
- **Expected Improvement**: ~75% → ~85%+ validation accuracy (the recorded run goes from 77.84% to 79.90%)

### Learning Rate Considerations
- **Same LR**: We keep 0.001 for simplicity. The cost is visible in the training log: validation accuracy drops to 38.14% in the first fine-tuning epoch and fluctuates for the rest of Phase 2
- **Lower LR Option**: Could use 0.0001 for more conservative fine-tuning
- **Scheduler**: Could add learning rate scheduling for better convergence
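
The lower-rate and scheduler options can be combined using optimizer parameter groups. The sketch below is one possible setup and not what this lesson runs: it uses a small stand-in model rather than ResNet18, gives the pre-trained part the 0.0001 rate mentioned above while the head keeps 0.001, and assumes cosine annealing as the schedule.

```python
import torch.nn as nn
import torch.optim as optim

# Stand-in model with a "backbone" and a "head" (hypothetical; not ResNet18).
backbone = nn.Sequential(nn.Linear(16, 8), nn.ReLU())
head = nn.Linear(8, 102)
model = nn.Sequential(backbone, head)

finetune_epochs = 30
optimizer = optim.AdamW(
    [
        {"params": backbone.parameters(), "lr": 1e-4},  # gentler updates for pre-trained layers
        {"params": head.parameters(), "lr": 1e-3},      # Phase 1 rate for the new head
    ],
    weight_decay=0.01,
)
# Cosine annealing decays every group's rate toward zero over the fine-tuning epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=finetune_epochs)

initial_lrs = [group["lr"] for group in optimizer.param_groups]

for epoch in range(finetune_epochs):
    # train_epoch(...) and evaluate(...) would run here in the real loop
    optimizer.step()   # placeholder step so the scheduler is called in the expected order
    scheduler.step()   # one scheduler step per epoch

final_lrs = [group["lr"] for group in optimizer.param_groups]
print(initial_lrs)
print(final_lrs)
```

With ResNet18, the backbone group would hold every parameter except `model.fc`, and the head group would hold `model.fc.parameters()`.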


In [13]:
print("🔥 Phase 2: Fine-tuning Training")
print("="*50)

# Unfreeze all layers
set_parameter_requires_grad(model, feature_extracting=False)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"   🔓 All parameters unfrozen")
print(f"   🎯 Trainable parameters: {trainable_params:,}")

# Create new optimizer for fine-tuning
optimizer_ft = optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

print(f"\n🚀 Starting Phase 2 training ({config['finetune_epochs']} epochs)...")
phase2_start = time.time()

# Continue from Phase 1 metrics
phase1_epochs = len(train_losses)
best_val_acc = max(val_accuracies)

for epoch in range(config['finetune_epochs']):
    epoch_start = time.time()
    
    # Training
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer_ft, device)
    
    # Validation
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())
    
    # Record metrics
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    
    epoch_time = time.time() - epoch_start
    
    print(f"Epoch {epoch+1:2d}/{config['finetune_epochs']} | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | "
          f"Time: {epoch_time:.1f}s")

phase2_time = time.time() - phase2_start
total_time = phase1_time + phase2_time

print(f"\n📊 Phase 2 Results:")
print(f"   ⏱️  Training time: {phase2_time:.1f}s ({phase2_time/60:.1f}m)")
print(f"   🎯 Best validation accuracy: {best_val_acc:.2f}%")

print(f"\n🎉 Complete Training Summary:")
print(f"   ⏱️  Total time: {total_time:.1f}s ({total_time/60:.1f}m)")
print(f"   📊 Phase 1 → Phase 2 improvement: {val_accuracies[phase1_epochs-1]:.2f}% → {best_val_acc:.2f}%")

# Load best model weights
model.load_state_dict(best_model_wts)
print("✅ Phase 2 complete! Best model weights loaded.")


🔥 Phase 2: Fine-tuning Training
   🔓 All parameters unfrozen
   🎯 Trainable parameters: 11,228,838

🚀 Starting Phase 2 training (30 epochs)...

Epoch  1/30 | Train Loss: 1.6831 | Train Acc: 53.82% | Val Loss: 3.2129 | Val Acc: 38.14% | Time: 44.8s
Epoch  2/30 | Train Loss: 0.8670 | Train Acc: 75.39% | Val Loss: 1.4263 | Val Acc: 65.20% | Time: 38.6s
Epoch  3/30 | Train Loss: 0.4729 | Train Acc: 86.57% | Val Loss: 1.8036 | Val Acc: 55.20% | Time: 39.7s
Epoch  4/30 | Train Loss: 0.4005 | Train Acc: 89.02% | Val Loss: 1.0879 | Val Acc: 71.27% | Time: 37.7s
Epoch  5/30 | Train Loss: 0.2731 | Train Acc: 92.94% | Val Loss: 1.2216 | Val Acc: 70.39% | Time: 38.1s
Epoch  6/30 | Train Loss: 0.1838 | Train Acc: 94.61% | Val Loss: 0.9953 | Val Acc: 75.39% | Time: 39.6s
Epoch  7/30 | Train Loss: 0.1191 | Train Acc: 97.16% | Val Loss: 0.9962 | Val Acc: 76.27% | Time: 39.7s
Epoch  8/30 | Train Loss: 0.1079 | Train Acc: 98.14% | Val Loss: 0.9595 | Val Acc: 75.78% | Time: 37.9s
Epoch  9/30 | Train Loss: 0.0991 | Train Acc: 97.75% | Val Loss: 1.4122 | Val Acc: 66.96% | Time: 38.0s
Epoch 10/30 | Train Loss: 0.1452 | Train Acc: 95.20% | Val Loss: 1.4275 | Val Acc: 68.04% | Time: 36.3s
Epoch 11/30 | Train Loss: 0.2046 | Train Acc: 94.22% | Val Loss: 1.2325 | Val Acc: 69.61% | Time: 35.9s
Epoch 12/30 | Train Loss: 0.2130 | Train Acc: 93.82% | Val Loss: 1.1445 | Val Acc: 69.90% | Time: 35.3s
Epoch 13/30 | Train Loss: 0.2279 | Train Acc: 94.02% | Val Loss: 1.5748 | Val Acc: 64.02% | Time: 36.5s
Epoch 14/30 | Train Loss: 0.2489 | Train Acc: 92.25% | Val Loss: 1.3972 | Val Acc: 67.25% | Time: 37.5s
Epoch 15/30 | Train Loss: 0.1713 | Train Acc: 94.90% | Val Loss: 1.2595 | Val Acc: 70.69% | Time: 38.2s
Epoch 16/30 | Train Loss: 0.0981 | Train Acc: 97.25% | Val Loss: 1.1442 | Val Acc: 74.31% | Time: 36.9s
Epoch 17/30 | Train Loss: 0.0934 | Train Acc: 97.75% | Val Loss: 1.0902 | Val Acc: 74.31% | Time: 37.5s
Epoch 18/30 | Train Loss: 0.1076 | Train Acc: 97.35% | Val Loss: 1.3564 | Val Acc: 70.10% | Time: 36.3s
Epoch 19/30 | Train Loss: 0.0988 | Train Acc: 97.25% | Val Loss: 1.0851 | Val Acc: 75.29% | Time: 37.6s
Epoch 20/30 | Train Loss: 0.1062 | Train Acc: 97.25% | Val Loss: 1.1164 | Val Acc: 74.12% | Time: 38.5s
Epoch 21/30 | Train Loss: 0.0996 | Train Acc: 96.86% | Val Loss: 1.1690 | Val Acc: 73.43% | Time: 37.4s
Epoch 22/30 | Train Loss: 0.0796 | Train Acc: 98.04% | Val Loss: 1.1268 | Val Acc: 74.22% | Time: 38.5s
Epoch 23/30 | Train Loss: 0.0870 | Train Acc: 97.55% | Val Loss: 0.8630 | Val Acc: 79.90% | Time: 35.3s
Epoch 24/30 | Train Loss: 0.0557 | Train Acc: 98.73% | Val Loss: 0.9525 | Val Acc: 76.86% | Time: 38.4s
Epoch 25/30 | Train Loss: 0.0703 | Train Acc: 97.94% | Val Loss: 1.1954 | Val Acc: 72.65% | Time: 38.7s
Epoch 26/30 | Train Loss: 0.0909 | Train Acc: 96.96% | Val Loss: 1.3484 | Val Acc: 69.22% | Time: 38.6s
Epoch 27/30 | Train Loss: 0.0653 | Train Acc: 97.94% | Val Loss: 1.1309 | Val Acc: 75.00% | Time: 39.0s
Epoch 28/30 | Train Loss: 0.0440 | Train Acc: 98.73% | Val Loss: 0.8884 | Val Acc: 78.24% | Time: 39.1s
Epoch 29/30 | Train Loss: 0.0551 | Train Acc: 98.63% | Val Loss: 0.8927 | Val Acc: 79.80% | Time: 38.3s
Epoch 30/30 | Train Loss: 0.0542 | Train Acc: 98.53% | Val Loss: 1.1363 | Val Acc: 75.39% | Time: 78.9s

📊 Phase 2 Results:
   ⏱️  Training time: 1182.5s (19.7m)
   🎯 Best validation accuracy: 79.90%

🎉 Complete Training Summary:
   ⏱️  Total time: 1876.3s (31.3m)
   📊 Phase 1 → Phase 2 improvement: 77.65% → 79.90%
✅ Phase 2 complete! Best model weights loaded.


