# Transfer Learning with ImageNet Pre-trained Models

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

## What is Transfer Learning?

Transfer learning reuses models pre-trained on large datasets (such as ImageNet) to solve related tasks for which only limited labeled data is available.

In [None]:
# Load a ResNet50 with ImageNet-1k weights. torchvision >= 0.13 replaced the
# deprecated `pretrained=True` flag with the `weights=` enum API;
# ResNet50_Weights.IMAGENET1K_V1 is the same checkpoint `pretrained=True`
# loads. Fall back to the legacy flag on older torchvision releases that do
# not provide the enum.
resnet50_weights = getattr(models, 'ResNet50_Weights', None)
if resnet50_weights is not None:
    model_pretrained = models.resnet50(weights=resnet50_weights.IMAGENET1K_V1)
else:
    model_pretrained = models.resnet50(pretrained=True)
print('ResNet50 loaded with pretrained ImageNet weights')

# Count parameters. A freshly constructed model has requires_grad=True on every
# parameter, so both counts are equal at this point.
total_params = sum(p.numel() for p in model_pretrained.parameters())
trainable_params = sum(p.numel() for p in model_pretrained.parameters() if p.requires_grad)
print(f'Total parameters: {total_params / 1e6:.1f}M')
print(f'Trainable parameters: {trainable_params / 1e6:.1f}M')

## Fine-tuning Strategy: Freeze the Backbone, Train a New Classifier Head

In [None]:
# Freeze every pretrained parameter so the ImageNet backbone acts as a fixed
# feature extractor (no gradients are computed or applied for it).
# Note: requires_grad=False does not stop BatchNorm layers from updating their
# running mean/var buffers while the model is in train() mode.
model_pretrained.requires_grad_(False)

# Replace the ImageNet classifier with a fresh head sized for our task.
# A newly constructed nn.Linear has requires_grad=True, so it becomes the only
# trainable module.
num_classes = 10  # e.g., CIFAR-10
num_ftrs = model_pretrained.fc.in_features
model_pretrained.fc = nn.Linear(num_ftrs, num_classes)

# Only the new classifier layer will be trained. Report the exact count: the
# head has just num_ftrs * num_classes + num_classes parameters (~20K here),
# which would round to a misleading "0.0M" if shown in millions.
trainable_params = sum(p.numel() for p in model_pretrained.parameters() if p.requires_grad)
print(f'Trainable parameters after modification: {trainable_params:,}')

## Prepare Data

In [None]:
# Prepare data with ImageNet normalization: the pretrained weights expect
# inputs scaled by the ImageNet per-channel mean/std. CIFAR-10's 32x32 images
# are upsampled so their shorter side is 224 px.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.CIFAR10(root='./data', train=True,
                                 download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False,
                                download=True, transform=transform)

# Use a random subset for quick training. A seeded generator makes the chosen
# subset (and therefore the reported results) reproducible across runs.
subset_size = 1000
subset_seed = 0
subset_generator = torch.Generator().manual_seed(subset_seed)
indices = torch.randperm(len(train_dataset), generator=subset_generator)[:subset_size].tolist()
train_dataset = torch.utils.data.Subset(train_dataset, indices)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

## Train the Fine-tuned Model

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_pretrained = model_pretrained.to(device)

loss_fn = nn.CrossEntropyLoss()
# Only optimize the parameters that require gradients (the new classifier head).
optimizer = optim.Adam([p for p in model_pretrained.parameters() if p.requires_grad], lr=0.001)

epochs = 5
train_losses = []
val_accs = []


def freeze_batchnorm_stats(model):
    """Switch BatchNorm layers with no trainable parameters to eval mode.

    Setting requires_grad=False stops gradient updates, but BatchNorm still
    updates its running mean/var buffers whenever it runs in train() mode.
    Without this, the supposedly frozen backbone's normalization statistics
    would drift toward the small training subset.
    """
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    for module in model.modules():
        if isinstance(module, batchnorm_types) and not any(
            p.requires_grad for p in module.parameters(recurse=False)
        ):
            module.eval()


for epoch in range(epochs):
    # Training: train() for the trainable head, but keep frozen BatchNorm
    # layers in eval mode so their running statistics stay at ImageNet values.
    model_pretrained.train()
    freeze_batchnorm_stats(model_pretrained)
    running_loss = 0.0
    seen = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model_pretrained(images)
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is the mean over the batch; weight it by the batch size so the
        # smaller final batch does not skew the epoch average.
        running_loss += loss.item() * labels.size(0)
        seen += labels.size(0)

    train_loss = running_loss / seen
    train_losses.append(train_loss)

    # Validation.
    # NOTE(review): this evaluates on the CIFAR-10 *test* split; for model
    # selection, hold out part of the training data instead.
    model_pretrained.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model_pretrained(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_acc = 100 * correct / total
    val_accs.append(val_acc)

    print(f'Epoch {epoch+1}/{epochs}, Loss: {train_loss:.4f}, Val Acc: {val_acc:.2f}%')

## Results

In [None]:
from matplotlib.ticker import MaxNLocator

# Plot the per-epoch curves against 1-based epoch numbers so the x-axis matches
# the "Epoch k/N" numbering printed during training.
loss_epochs = range(1, len(train_losses) + 1)
acc_epochs = range(1, len(val_accs) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(loss_epochs, train_losses, marker='o')
ax1.set_title('Training Loss')
ax1.set_ylabel('Loss')
ax1.set_xlabel('Epoch')
ax1.grid(True)

ax2.plot(acc_epochs, val_accs, marker='o')
ax2.set_title('Validation Accuracy')
ax2.set_ylabel('Accuracy (%)')
ax2.set_xlabel('Epoch')
ax2.grid(True)

# Epochs are discrete; avoid fractional tick labels such as 1.5.
for ax in (ax1, ax2):
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))

plt.tight_layout()
plt.show()

# Guard against an empty history (e.g. epochs = 0), where val_accs[-1] would
# raise IndexError.
if val_accs:
    print(f'\nTransfer learning achieved {val_accs[-1]:.2f}% accuracy with minimal training!')