# Transfer Learning vs. Training From Scratch

This notebook compares the performance of:
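1. **Transfer Learning**: ResNet50 with pre-trained ImageNet weights; the backbone is frozen and only a new classification head is trained
2. **Training From Scratch**: Randomly initialized ResNet50 with the same classification head; all layers are trained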

We'll evaluate both approaches on the same dataset to understand the benefits of transfer learning.

## 1. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision import datasets, models, transforms

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import time

# Seed the NumPy and PyTorch RNGs so weight init, the data split and shuffling repeat across runs
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

## 2. Configuration

In [None]:
# Experiment hyper-parameters shared by both models
CONFIG = dict(
    batch_size=32,
    num_epochs=15,  # More epochs to see convergence difference
    learning_rate=0.001,
    num_classes=5,  # placeholder; replaced with the class count of the loaded dataset
    train_split=0.8,  # fraction of images used for training; the rest is validation
    image_size=224,
    num_workers=2,
)

# Root folder searched for an ImageFolder-style dataset (one sub-folder per class)
data_dir = './data'

## 3. Data Preparation

In [None]:
import copy  # gives the validation split its own dataset object (see below)

# Data transforms: random augmentation for training only; validation is deterministic.
# Mean/std are the ImageNet statistics the pre-trained weights were trained with.
train_transforms = transforms.Compose([
    transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def _load_dataset(transform):
    """Load the image dataset with `transform` attached.

    Uses an ImageFolder rooted at `data_dir` when it holds at least two class
    sub-folders; otherwise falls back to the Flowers102 training split.

    Returns:
        (dataset, class_names)
    """
    try:
        folder = datasets.ImageFolder(data_dir, transform=transform)
    except OSError as err:
        # ImageFolder raises FileNotFoundError (an OSError) when data_dir is missing
        # or contains no class folders / valid images. Catch only that, not everything.
        print(f"Could not load custom dataset from {data_dir!r}: {err}")
    else:
        if len(folder.classes) >= 2:
            print(f"Loaded custom dataset with {len(folder)} images")
            return folder, folder.classes
        # A single class folder cannot be a classification dataset. It can be the
        # Flowers102 download that a previous fallback run stored under ./data.
        print(f"Ignoring {data_dir!r}: found only {len(folder.classes)} class folder(s)")
    print("Using Flowers102 for demonstration...")
    flowers = datasets.Flowers102(root='./data', split='train', download=True, transform=transform)
    return flowers, [f'class_{i}' for i in range(102)]


# Load dataset
full_dataset, class_names = _load_dataset(train_transforms)
num_classes = len(class_names)

CONFIG['num_classes'] = num_classes

# random_split returns Subsets that share ONE parent dataset, so assigning
# `.transform` on that parent would also strip augmentation from the training split.
# Instead, give validation its own shallow copy (same samples, independent transform).
val_source = copy.copy(full_dataset)
val_source.transform = val_transforms

# Split indices once and apply them to the two dataset objects
train_size = int(CONFIG['train_split'] * len(full_dataset))
val_size = len(full_dataset) - train_size
train_split, val_split = random_split(range(len(full_dataset)), [train_size, val_size])
train_dataset = torch.utils.data.Subset(full_dataset, train_split.indices)
val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

# Create data loaders (pinned host memory only helps when copying to a GPU)
use_pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True,
                          num_workers=CONFIG['num_workers'], pin_memory=use_pin_memory)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False,
                        num_workers=CONFIG['num_workers'], pin_memory=use_pin_memory)

print(f"Train set: {len(train_dataset)} | Val set: {len(val_dataset)}")

## 4. Training and Evaluation Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over `dataloader`.

    Args:
        model: network to train; it is switched to train() mode.
        dataloader: iterable of (images, labels) batches.
        criterion: loss function applied to (logits, labels).
        optimizer: optimizer stepping the model's trainable parameters.
        device: device each batch is moved to.

    Returns:
        (mean_loss, accuracy_percent) over all samples seen. Raises
        ZeroDivisionError if `dataloader` yields no batches.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_x, batch_y in tqdm(dataloader, desc='Training', leave=False):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch loss is a per-sample mean
        # (assumes criterion uses mean reduction, the CrossEntropyLoss default)
        loss_sum += batch_loss.item() * batch_x.size(0)
        n_seen += batch_y.size(0)
        n_correct += (logits.argmax(dim=1) == batch_y).sum().item()

    return loss_sum / n_seen, 100 * n_correct / n_seen


def validate_epoch(model, dataloader, criterion, device):
    """Evaluate `model` on `dataloader` without updating any weights.

    Args:
        model: network to evaluate; it is switched to eval() mode.
        dataloader: iterable of (images, labels) batches.
        criterion: loss function applied to (logits, labels).
        device: device each batch is moved to.

    Returns:
        (mean_loss, accuracy_percent) over all samples seen. Raises
        ZeroDivisionError if `dataloader` yields no batches.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    # No backward pass is needed, so skip building the autograd graph
    with torch.no_grad():
        for batch_x, batch_y in tqdm(dataloader, desc='Validating', leave=False):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)

            loss_sum += batch_loss.item() * batch_x.size(0)
            n_seen += batch_y.size(0)
            n_correct += (logits.argmax(dim=1) == batch_y).sum().item()

    return loss_sum / n_seen, 100 * n_correct / n_seen


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, model_name):
    """Train `model` for `num_epochs` epochs, validating after every epoch.

    Args:
        model: network to train (updated in place).
        train_loader: iterable of (images, labels) training batches.
        val_loader: iterable of (images, labels) validation batches.
        criterion: loss function passed to train_epoch / validate_epoch.
        optimizer: optimizer stepping the trainable parameters.
        num_epochs: number of passes over `train_loader`.
        device: device the batches are moved to.
        model_name: label printed in the banner.

    Returns:
        dict with per-epoch lists 'train_loss', 'train_acc', 'val_loss', 'val_acc'
        and 'epoch_times' (seconds), plus 'total_time' (seconds) and
        'best_val_acc' (highest validation accuracy seen; 0.0 if no epoch ran).
        Only the best *score* is recorded; the model keeps its final-epoch weights.
    """
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'epoch_times': []}
    best_val_acc = 0.0

    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}\n")

    # perf_counter is monotonic and high-resolution; time.time() follows the wall
    # clock, which can jump (e.g. NTP adjustments) and corrupt measured durations.
    total_start_time = time.perf_counter()

    for epoch in range(num_epochs):
        epoch_start_time = time.perf_counter()

        print(f"Epoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)

        epoch_time = time.perf_counter() - epoch_start_time

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['epoch_times'].append(epoch_time)

        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
        print(f"  Time: {epoch_time:.2f}s\n")

        best_val_acc = max(best_val_acc, val_acc)

    total_time = time.perf_counter() - total_start_time
    history['total_time'] = total_time
    history['best_val_acc'] = best_val_acc

    print(f"\nTotal training time: {total_time:.2f}s ({total_time/60:.2f} min)")
    print(f"Best validation accuracy: {best_val_acc:.2f}%\n")

    return history

## 5. Model 1: Transfer Learning (Pre-trained)

In [None]:
# Backbone: ResNet50 initialised from torchvision's ImageNet (IMAGENET1K_V2) weights
model_transfer = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Freeze every pre-trained parameter. The head created below is added afterwards,
# so it is the only part left trainable.
# NOTE: requires_grad=False only stops weight updates; BatchNorm running statistics
# are still updated whenever the model is in train() mode.
model_transfer.requires_grad_(False)

# New classification head sized for this dataset's classes
num_features = model_transfer.fc.in_features
model_transfer.fc = nn.Sequential(
    nn.Linear(num_features, 512), nn.ReLU(), nn.Dropout(0.3), nn.Linear(512, CONFIG['num_classes']),
)
model_transfer = model_transfer.to(device)

criterion_transfer = nn.CrossEntropyLoss()
# The optimizer only receives the head's parameters
optimizer_transfer = optim.Adam(model_transfer.fc.parameters(), lr=CONFIG['learning_rate'])

total_params = sum(p.numel() for p in model_transfer.parameters())
trainable_params = sum(p.numel() for p in model_transfer.parameters() if p.requires_grad)

print("Transfer Learning Model:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.2f}%)")

In [None]:
# Fit the transfer-learning model; the returned dict holds per-epoch metrics and timings
history_transfer = train_model(
    model=model_transfer,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion_transfer,
    optimizer=optimizer_transfer,
    num_epochs=CONFIG['num_epochs'],
    device=device,
    model_name="Transfer Learning (Pre-trained ResNet50)",
)

## 6. Model 2: Training From Scratch

In [None]:
# Same architecture, but no pre-trained weights: every layer starts from random init
model_scratch = models.resnet50(weights=None)

# Same classification head as the transfer-learning model
num_features = model_scratch.fc.in_features
model_scratch.fc = nn.Sequential(
    nn.Linear(num_features, 512), nn.ReLU(), nn.Dropout(0.3), nn.Linear(512, CONFIG['num_classes']),
)
model_scratch = model_scratch.to(device)

criterion_scratch = nn.CrossEntropyLoss()
# Nothing is frozen: the optimizer updates every parameter of the network
optimizer_scratch = optim.Adam(model_scratch.parameters(), lr=CONFIG['learning_rate'])

total_params_scratch = sum(p.numel() for p in model_scratch.parameters())
trainable_params_scratch = sum(p.numel() for p in model_scratch.parameters() if p.requires_grad)

print("From-Scratch Model:")
print(f"  Total parameters: {total_params_scratch:,}")
print(f"  Trainable parameters: {trainable_params_scratch:,} ({100*trainable_params_scratch/total_params_scratch:.2f}%)")

In [None]:
# Fit the randomly initialised model with the same loaders and epoch budget
history_scratch = train_model(
    model=model_scratch,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion_scratch,
    optimizer=optimizer_scratch,
    num_epochs=CONFIG['num_epochs'],
    device=device,
    model_name="Training From Scratch (Random Init ResNet50)",
)

## 7. Comparison Visualizations

In [None]:
# 2x2 grid of learning curves: rows = loss / accuracy, columns = training / validation
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

epochs = range(1, CONFIG['num_epochs'] + 1)

# (axis, history key, y-axis label, panel title)
curve_panels = [
    (axes[0, 0], 'train_loss', 'Loss', 'Training Loss Comparison'),
    (axes[0, 1], 'val_loss', 'Loss', 'Validation Loss Comparison'),
    (axes[1, 0], 'train_acc', 'Accuracy (%)', 'Training Accuracy Comparison'),
    (axes[1, 1], 'val_acc', 'Accuracy (%)', 'Validation Accuracy Comparison'),
]

for panel_ax, metric_key, y_label, panel_title in curve_panels:
    panel_ax.plot(epochs, history_transfer[metric_key], 'b-o', label='Transfer Learning', linewidth=2)
    panel_ax.plot(epochs, history_scratch[metric_key], 'r-s', label='From Scratch', linewidth=2)
    panel_ax.set_xlabel('Epoch', fontsize=12)
    panel_ax.set_ylabel(y_label, fontsize=12)
    panel_ax.set_title(panel_title, fontsize=14, fontweight='bold')
    panel_ax.legend(fontsize=11)
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('transfer_vs_scratch_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

## 8. Performance Metrics Comparison

In [None]:
def _epochs_to_fraction_of_best(history, fraction=0.9):
    """Return (as text) the 1-based epoch at which validation accuracy first reaches
    `fraction` of the run's best, or 'n/a' if no epochs were recorded."""
    target = fraction * history['best_val_acc']
    for epoch, acc in enumerate(history['val_acc'], start=1):
        if acc >= target:
            return str(epoch)
    return 'n/a'


# Create comparison table. Every value is derived from the recorded histories.
# The convergence row used to hard-code 'Fast' / 'Slow' regardless of the results.
comparison_data = {
    'Metric': [
        'Best Validation Accuracy (%)',
        'Final Validation Accuracy (%)',
        'Final Training Accuracy (%)',
        'Total Training Time (min)',
        'Avg Time per Epoch (s)',
        'Trainable Parameters',
        'Epochs to 90% of Best Val Acc'
    ],
    'Transfer Learning': [
        f"{history_transfer['best_val_acc']:.2f}",
        f"{history_transfer['val_acc'][-1]:.2f}",
        f"{history_transfer['train_acc'][-1]:.2f}",
        f"{history_transfer['total_time']/60:.2f}",
        f"{np.mean(history_transfer['epoch_times']):.2f}",
        f"{trainable_params:,}",
        _epochs_to_fraction_of_best(history_transfer)
    ],
    'From Scratch': [
        f"{history_scratch['best_val_acc']:.2f}",
        f"{history_scratch['val_acc'][-1]:.2f}",
        f"{history_scratch['train_acc'][-1]:.2f}",
        f"{history_scratch['total_time']/60:.2f}",
        f"{np.mean(history_scratch['epoch_times']):.2f}",
        f"{trainable_params_scratch:,}",
        _epochs_to_fraction_of_best(history_scratch)
    ]
}

print("\n" + "="*80)
print("PERFORMANCE COMPARISON: TRANSFER LEARNING VS. TRAINING FROM SCRATCH")
print("="*80)
print(f"{'Metric':<35} {'Transfer Learning':<20} {'From Scratch':<20}")
print("-"*80)
for metric, transfer_value, scratch_value in zip(
    comparison_data['Metric'],
    comparison_data['Transfer Learning'],
    comparison_data['From Scratch'],
):
    print(f"{metric:<35} {transfer_value:<20} {scratch_value:<20}")
print("="*80)

## 9. Bar Chart Comparison

In [None]:
# Side-by-side grouped bar charts: accuracy metrics (left) and training time (right)
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
width = 0.35  # width of each bar; the two models' bars sit either side of the tick

# Each spec: (axis, tick labels, transfer values, scratch values, y label, title, tick-label kwargs)
bar_panels = [
    (
        axes[0],
        ['Best Val Acc', 'Final Val Acc', 'Final Train Acc'],
        [history_transfer['best_val_acc'], history_transfer['val_acc'][-1], history_transfer['train_acc'][-1]],
        [history_scratch['best_val_acc'], history_scratch['val_acc'][-1], history_scratch['train_acc'][-1]],
        'Accuracy (%)',
        'Accuracy Comparison',
        {'rotation': 15, 'ha': 'right'},
    ),
    (
        # Note: this panel mixes minutes and seconds on one axis; the tick labels carry the units
        axes[1],
        ['Total Time (min)', 'Avg Epoch Time (s)'],
        [history_transfer['total_time']/60, np.mean(history_transfer['epoch_times'])],
        [history_scratch['total_time']/60, np.mean(history_scratch['epoch_times'])],
        'Time',
        'Training Time Comparison',
        {},
    ),
]

for bar_ax, tick_labels, transfer_vals, scratch_vals, y_label, panel_title, tick_kwargs in bar_panels:
    positions = np.arange(len(tick_labels))
    bar_ax.bar(positions - width/2, transfer_vals, width, label='Transfer Learning', color='steelblue')
    bar_ax.bar(positions + width/2, scratch_vals, width, label='From Scratch', color='coral')
    bar_ax.set_ylabel(y_label, fontsize=12)
    bar_ax.set_title(panel_title, fontsize=14, fontweight='bold')
    bar_ax.set_xticks(positions)
    bar_ax.set_xticklabels(tick_labels, **tick_kwargs)
    bar_ax.legend(fontsize=11)
    bar_ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('comparison_bar_charts.png', dpi=300, bbox_inches='tight')
plt.show()

## 10. Learning Curve Analysis

In [None]:
# Validation-accuracy gain between two epochs of a training history
def calculate_improvement(history, start_epoch=0, end_epoch=5):
    """Return the change in validation accuracy between two epochs.

    Args:
        history: dict with a 'val_acc' list (one entry per epoch), as returned by
            train_model.
        start_epoch: 0-based index of the baseline epoch.
        end_epoch: 0-based index of the comparison epoch.
            Indices past the end of the history are clamped to the last recorded epoch.

    Returns:
        val_acc[end] - val_acc[start] in percentage points (negative if accuracy dropped).

    Raises:
        ValueError: if history['val_acc'] is empty.
    """
    val_acc = history['val_acc']
    if not val_acc:
        raise ValueError("history['val_acc'] is empty; nothing to compare")
    last = len(val_acc) - 1
    # Clamp BOTH ends (previously only end_epoch) so short runs don't raise IndexError
    return val_acc[min(end_epoch, last)] - val_acc[min(start_epoch, last)]

# Validation-accuracy gain between epoch 1 and epoch 6 (indices 0 -> 5) for each approach
transfer_early_improvement = calculate_improvement(history_transfer, 0, 5)
scratch_early_improvement = calculate_improvement(history_scratch, 0, 5)

print("\nLearning Speed Analysis (First 5 epochs):")
print("-" * 60)
print(f"Transfer Learning improvement: {transfer_early_improvement:.2f}%")
print(f"From Scratch improvement: {scratch_early_improvement:.2f}%")

# The ratio is only meaningful when both runs improved. Dividing by a zero gain
# raised ZeroDivisionError before, and a negative gain flips the sign. A pre-trained
# model that starts high can legitimately show the *smaller* gain, so this is a
# ratio of gains, not a convergence speed-up.
if transfer_early_improvement > 0 and scratch_early_improvement > 0:
    improvement_ratio = transfer_early_improvement / scratch_early_improvement
    print(f"\nImprovement ratio (transfer / scratch): {improvement_ratio:.2f}x")
else:
    print("\nImprovement ratio not meaningful: at least one model did not improve over this window.")

## 11. Key Insights and Recommendations

In [None]:
print("\n" + "="*80)
print("KEY INSIGHTS: TRANSFER LEARNING VS. TRAINING FROM SCRATCH")
print("="*80)

print("\n1. ACCURACY:")
acc_diff = history_transfer['best_val_acc'] - history_scratch['best_val_acc']
if acc_diff > 0:
    print(f"   ✓ Transfer learning achieved {acc_diff:.2f}% higher accuracy")
elif acc_diff < 0:
    print(f"   ⚠ From-scratch training achieved {abs(acc_diff):.2f}% higher accuracy")
else:
    # Previously a tie was reported as a from-scratch win of 0.00%
    print("   = Both approaches reached the same best validation accuracy")

print("\n2. CONVERGENCE SPEED:")
print(f"   ✓ Transfer learning: {transfer_early_improvement:.2f}% improvement in 5 epochs")
print(f"   ✓ From scratch: {scratch_early_improvement:.2f}% improvement in 5 epochs")
# Only report a ratio when both gains are positive. The old max(..., 0.1) guard turned
# a zero or negative scratch gain into an inflated "faster" factor.
if transfer_early_improvement > 0 and scratch_early_improvement > 0:
    print(f"   → Improvement ratio (transfer / scratch): ~{transfer_early_improvement/scratch_early_improvement:.1f}x")
else:
    print("   → Improvement ratio not meaningful: at least one model did not improve in this window")

print("\n3. COMPUTATIONAL EFFICIENCY:")
# Percentage of the from-scratch trainable-parameter count that transfer learning does NOT train
# (the old expression printed the ratio itself, e.g. "4.5% fewer" instead of "95.5% fewer")
fewer_params_pct = 100 * (1 - trainable_params / trainable_params_scratch)
print(f"   ✓ Transfer learning trains {fewer_params_pct:.1f}% fewer parameters")
# Compare mean epoch durations so the ratio matches the "per epoch" wording
time_ratio = np.mean(history_scratch['epoch_times']) / np.mean(history_transfer['epoch_times'])
if time_ratio > 1:
    print(f"   ✓ Transfer learning is {time_ratio:.2f}x faster per epoch")
else:
    print(f"   ⚠ From-scratch is {1/time_ratio:.2f}x faster per epoch")

print("\n4. OVERFITTING:")
transfer_gap = history_transfer['train_acc'][-1] - history_transfer['val_acc'][-1]
scratch_gap = history_scratch['train_acc'][-1] - history_scratch['val_acc'][-1]
print(f"   ✓ Transfer learning train-val gap: {transfer_gap:.2f}%")
print(f"   ✓ From scratch train-val gap: {scratch_gap:.2f}%")
if transfer_gap < scratch_gap:
    print("   → Transfer learning shows better generalization")
elif transfer_gap > scratch_gap:
    print("   → From-scratch shows better generalization")
else:
    print("   → Both approaches show the same train-val gap")

print("\n" + "="*80)
print("RECOMMENDATIONS:")
print("="*80)
print("\n✓ Use TRANSFER LEARNING when:")
print("  • Working with small datasets (< 10k images)")
print("  • Limited computational resources")
print("  • Quick prototyping needed")
print("  • Task is similar to ImageNet (natural images)")
print("\n✓ Train FROM SCRATCH when:")
print("  • Large dataset available (> 100k images)")
print("  • Domain is very different from ImageNet (medical, satellite, etc.)")
print("  • Sufficient computational resources")
print("  • Need complete control over feature learning")
print("="*80)

## Summary

This comparison is designed to illustrate the typical advantages of transfer learning (the exact numbers depend on the dataset and the individual run):

### Transfer Learning Benefits:
1. **Faster convergence**: Achieves good performance in fewer epochs
2. **Better accuracy**: Leverages pre-trained features from ImageNet
3. **Less overfitting**: Pre-trained features provide good regularization
4. **Efficient training**: Only the new classification head is trained, so far fewer parameters need gradients and optimizer state
5. **Small data friendly**: Works well with limited training samples

### When Transfer Learning Works Best:
- Small to medium datasets (≤ 10k images)
- Similar domain to pre-training dataset (natural images)
- Limited computational resources
- Time-constrained projects

### Training From Scratch:
- May eventually match or exceed transfer learning with enough data and epochs
- Better for domains very different from ImageNet
- Requires significantly more training time and data
- Higher risk of overfitting on small datasets