# Transfer Learning for Image Classification

This notebook demonstrates transfer learning using a pre-trained ResNet model for image classification on the CIFAR-10 dataset.

## What is Transfer Learning?

Transfer learning is a machine learning technique in which a model trained on one task is reused as the starting point for a second, related task. In this notebook, a ResNet trained on ImageNet is adapted to classify CIFAR-10's ten classes, so the visual features it has already learned do not have to be learned again from scratch.

## Key Concepts

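1. **Feature Extraction**: Use the pre-trained model as a fixed feature extractor and train only a new output layer
2. **Fine-tuning**: Unfreeze some of the pre-trained layers (typically the later ones) and train them on the new data together with the new output layer
3. **Domain Adaptation**: Adapt a model trained on one data distribution (domain) so that it performs well on another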
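Both feature extraction and fine-tuning come down to choosing which parameters keep `requires_grad=True`. A minimal sketch, assuming a hypothetical helper `freeze_all_except` and a toy model whose submodule names only mimic torchvision's ResNet (`layer4`, `fc`):

```python
from collections import OrderedDict

import torch.nn as nn


def freeze_all_except(model: nn.Module, trainable_prefixes: tuple[str, ...]) -> int:
    """Hypothetical helper: freeze every parameter whose name does not start
    with one of `trainable_prefixes`; return the number of trainable values."""
    trainable = 0
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(trainable_prefixes)
        if param.requires_grad:
            trainable += param.numel()
    return trainable


# Toy stand-in for a backbone; the names mirror ResNet's but the layers are not real.
toy = nn.Sequential(OrderedDict([
    ("layer3", nn.Linear(8, 8)),
    ("layer4", nn.Linear(8, 8)),
    ("fc", nn.Linear(8, 3)),
]))

# Feature extraction: only the new head learns (8*3 + 3 values).
print(freeze_all_except(toy, ("fc",)))           # 27
# Fine-tuning: the last stage and the head learn (72 + 27 values).
print(freeze_all_except(toy, ("layer4", "fc")))  # 99
```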

## Benefits

- **Faster Training**: Start with learned features instead of random initialization
- **Less Data Required**: Pre-trained features work well even with limited data
- **Better Performance**: Often achieves higher accuracy than training from scratch

In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import copy
import warnings
warnings.filterwarnings('ignore')

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
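Seeding `torch` and `numpy` alone does not make GPU runs repeatable, because cuDNN may pick non-deterministic kernels. A minimal sketch of a broader seeding helper, assuming the hypothetical name `seed_everything`; the cuDNN flags trade some speed for determinism:

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 42) -> None:
    """Hypothetical helper: seed Python, NumPy and PyTorch (CPU and all GPUs)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # seeds the generators on all devices
    torch.cuda.manual_seed_all(seed)  # explicit; silently ignored without CUDA
    # Ask cuDNN for deterministic kernels and turn off autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)
a = torch.rand(3)
seed_everything(42)
b = torch.rand(3)
print(torch.equal(a, b))  # True
```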

## Load and Prepare CIFAR-10 Dataset

In [None]:
# CIFAR-10 class names
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

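# Data augmentation and normalization for training.
# CIFAR-10 images are 32x32; they are upsampled to the 224x224 input
# resolution that ResNet-18 was pre-trained at.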
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                       std=[0.229, 0.224, 0.225])  # ImageNet statistics
])

# Deterministic preprocessing for evaluation: same resizing, a center crop, and no random augmentation
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
])

# Load datasets
print('Loading CIFAR-10 dataset...')
train_dataset = datasets.CIFAR10(root='./data', train=True, 
                                 download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False,
                                download=True, transform=test_transform)

# Create a smaller training set for faster demo
train_size = 5000  # Use 5000 samples instead of full 50000
train_dataset, _ = random_split(train_dataset, [train_size, len(train_dataset) - train_size])

# Create dataloaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, 
                         shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=2)

print(f'Training samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')
print(f'Number of classes: {len(class_names)}')
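`random_split` draws the 5,000 training samples uniformly, so the per-class counts only approximately equal 500. If an exactly balanced subset is preferred, one option is to sample a fixed number of indices per class. A minimal sketch, assuming a hypothetical helper `stratified_indices` and synthetic labels in place of the real dataset:

```python
import random
from collections import Counter, defaultdict


def stratified_indices(targets, per_class, seed=42):
    """Hypothetical helper: pick `per_class` random indices for each label."""
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)
    rng = random.Random(seed)  # local RNG so the global state is untouched
    picked = []
    for label in sorted(by_class):
        picked.extend(rng.sample(by_class[label], per_class))
    return sorted(picked)


# Synthetic labels standing in for CIFAR-10's targets (10 classes x 50 each).
fake_targets = [i % 10 for i in range(500)]
idx = stratified_indices(fake_targets, per_class=5)
print(len(idx), Counter(fake_targets[i] for i in idx)[3])  # 50 5
```

With the real data, the indices would have to come from the full `datasets.CIFAR10` object (which exposes its labels as `.targets`) before `random_split` reassigns `train_dataset`, and would then be wrapped with `torch.utils.data.Subset`.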

## Visualize Sample Images

In [None]:
# Undo transforms.Normalize for display (x * std + mean per channel); modifies `tensor` in place
def denormalize(tensor, mean, std):
    for t, m, s in zip(tensor, mean, std):
        t.mul_(s).add_(m)
    return tensor

# Get a batch of training images
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Display images
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

for i, ax in enumerate(axes.flat):
    img = images[i].clone()
    img = denormalize(img, mean, std)
    img = torch.clamp(img, 0, 1)
    ax.imshow(img.permute(1, 2, 0))
    ax.set_title(class_names[labels[i]])
    ax.axis('off')
plt.tight_layout()
plt.show()
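`transforms.Normalize` computes `(x - mean) / std` per channel, and `denormalize` above inverts it in place, which is why each image is cloned first. A minimal out-of-place alternative using broadcasting; `denormalize_batch` is a hypothetical name, and it accepts either a single `C x H x W` image or an `N x C x H x W` batch:

```python
import torch

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])


def denormalize_batch(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Hypothetical out-of-place inverse of transforms.Normalize."""
    # View the stats as (C, 1, 1) so they broadcast over H, W and any batch dim.
    mean = mean.to(x.device).view(-1, 1, 1)
    std = std.to(x.device).view(-1, 1, 1)
    return x * std + mean


# Round trip: normalize a random image by hand, then undo it.
img = torch.rand(3, 4, 4)
normed = (img - IMAGENET_MEAN.view(-1, 1, 1)) / IMAGENET_STD.view(-1, 1, 1)
restored = denormalize_batch(normed)
print(torch.allclose(img, restored, atol=1e-6))  # True
```

For display, `torch.clamp(denormalize_batch(images[i]), 0, 1)` would leave `images[i]` itself unchanged.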

## Load Pre-trained ResNet Model

We'll use ResNet-18, a popular convolutional neural network pre-trained on ImageNet.

In [None]:
# Load pre-trained ResNet-18
print('Loading pre-trained ResNet-18...')
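# weights=...IMAGENET1K_V1 replaces pretrained=True, deprecated since torchvision 0.13
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)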

# Print model architecture
print('\nModel Architecture:')
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f'\nTotal parameters: {total_params:,}')
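The single total hides where the weights live. A per-submodule breakdown makes the freezing strategies below easier to reason about. A minimal sketch with a hypothetical helper `params_per_child`, shown on a toy model rather than the downloaded ResNet:

```python
from collections import OrderedDict

import torch.nn as nn


def params_per_child(model: nn.Module) -> dict:
    """Hypothetical helper: count parameters in each top-level submodule."""
    return {name: sum(p.numel() for p in child.parameters())
            for name, child in model.named_children()}


# Toy model: conv 3*8*3*3 + 8, BatchNorm 8 + 8 (running stats are buffers), fc 8*10 + 10.
toy = nn.Sequential(OrderedDict([
    ("conv1", nn.Conv2d(3, 8, kernel_size=3)),
    ("bn1", nn.BatchNorm2d(8)),
    ("fc", nn.Linear(8, 10)),
]))
print(params_per_child(toy))  # {'conv1': 224, 'bn1': 16, 'fc': 90}
```

Applied to ResNet-18, such a breakdown shows that `layer4` alone holds the majority of the weights, which is why unfreezing it adds far more trainable parameters than training the head alone.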

## Transfer Learning Strategy 1: Feature Extraction

In this approach, we freeze every pre-trained layer (convolutional and batch-norm weights alike) and train only a newly initialized 10-class output layer.

In [None]:
# Create feature extraction model
def create_feature_extractor():
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    
    # Freeze all layers
    for param in model.parameters():
        param.requires_grad = False
    
    # Replace the final fully connected layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, len(class_names))
    
    return model

# Create model
feature_extractor = create_feature_extractor().to(device)

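# Count trainable and frozen parameters on the modified model itself
# (its 10-class head is smaller than the original 1000-class one)
trainable_params = sum(p.numel() for p in feature_extractor.parameters() if p.requires_grad)
frozen_params = sum(p.numel() for p in feature_extractor.parameters() if not p.requires_grad)
print(f'Trainable parameters: {trainable_params:,}')
print(f'Frozen parameters: {frozen_params:,}')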
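`total_params` was counted on the original ResNet-18, whose ImageNet head maps 512 features to 1000 classes, while the replacement head maps them to 10. The size difference is easy to check directly with standalone layers of the same shapes:

```python
import torch.nn as nn


def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


imagenet_head = nn.Linear(512, 1000)  # weight 512*1000 + bias 1000
cifar_head = nn.Linear(512, 10)       # weight 512*10 + bias 10
print(n_params(imagenet_head), n_params(cifar_head))  # 513000 5130
```

So a frozen-parameter count should be taken from the modified model rather than derived from `total_params`, which still includes the discarded 513,000-value head.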

## Training Functions

In [None]:
def train_model(model, train_loader, test_loader, criterion, optimizer, num_epochs=5):
    """
    Train the model and track performance
    """
    history = {
        'train_loss': [],
        'train_acc': [],
        'test_acc': []
    }
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 50)
        
        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0
        
        for inputs, labels in tqdm(train_loader, desc='Training'):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        
        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects.double() / len(train_loader.dataset)
        
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc.item())
        
        # Evaluation phase
        test_acc = evaluate_model(model, test_loader)
        history['test_acc'].append(test_acc)
        
        print(f'Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f}')
        print(f'Test Acc: {test_acc:.4f}\n')
        
        # Save best model
        if test_acc > best_acc:
            best_acc = test_acc
            best_model_wts = copy.deepcopy(model.state_dict())
    
    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model, history

def evaluate_model(model, dataloader):
    """
    Evaluate model accuracy
    """
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    return correct / total
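`train_model` keeps the weights with the highest *test* accuracy, so the test set doubles as a model-selection set and the reported best test accuracy is optimistically biased. A common alternative is to hold out a validation split from the training data and select on it. A minimal sketch, assuming a 90/10 split and a fixed generator (both illustrative choices) and random tensors in place of CIFAR-10:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the 5000-sample training subset: random images and labels.
full = TensorDataset(torch.randn(5000, 3, 8, 8), torch.randint(0, 10, (5000,)))

n_val = len(full) // 10
gen = torch.Generator().manual_seed(42)  # reproducible split
train_part, val_part = random_split(full, [len(full) - n_val, n_val], generator=gen)
print(len(train_part), len(val_part))  # 4500 500
```

Passing a `DataLoader` over the validation part as the `test_loader` argument of `train_model` would select weights on validation accuracy and leave the real test set for one final `evaluate_model` call. A split of `train_dataset` inherits `train_transform`, so validation images would still be randomly augmented unless the split is drawn from a second dataset object built with `test_transform`.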

## Train Feature Extractor Model

In [None]:
# Setup training
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(feature_extractor.fc.parameters(), lr=0.001)

# Train
print('Training Feature Extractor Model...\n')
feature_extractor, history_fe = train_model(
    feature_extractor, 
    train_loader, 
    test_loader,
    criterion, 
    optimizer, 
    num_epochs=5
)

print(f'Best Test Accuracy (Feature Extraction): {max(history_fe["test_acc"]):.4f}')
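Setting `requires_grad = False` stops gradient updates, but `model.train()` still puts BatchNorm layers in training mode, where they keep updating their running mean and variance from CIFAR-10 batches. If the backbone should stay strictly fixed, one option is to switch BatchNorm layers back to eval mode after each `model.train()` call. A minimal sketch; `freeze_batchnorm` is a hypothetical helper and the model is a toy:

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_batchnorm(model: nn.Module) -> None:
    """Hypothetical helper: switch every BatchNorm layer to eval mode so its
    running mean/variance stop updating (call it right after model.train())."""
    for m in model.modules():
        if isinstance(m, BN_TYPES):
            m.eval()


# Toy backbone with frozen parameters, standing in for the ResNet.
toy = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
for p in toy.parameters():
    p.requires_grad = False

before = toy[1].running_mean.clone()
toy.train()
freeze_batchnorm(toy)
with torch.no_grad():
    toy(torch.randn(8, 3, 10, 10) + 5.0)  # shifted inputs would move the stats
print(torch.equal(before, toy[1].running_mean))  # True
```

In eval mode the BatchNorm layers normalize with the statistics stored during ImageNet pre-training. Whether that improves CIFAR-10 accuracy is an empirical question, but it makes the backbone a fixed feature extractor in the strict sense.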

## Transfer Learning Strategy 2: Fine-tuning

In this approach, we unfreeze the last residual stage (`layer4`) and train it together with the new classification layer, while the earlier layers stay frozen.

In [None]:
# Create fine-tuning model
def create_finetuning_model():
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    
    # Freeze early layers
    for name, param in model.named_parameters():
        if 'layer4' not in name and 'fc' not in name:
            param.requires_grad = False
    
    # Replace the final fully connected layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, len(class_names))
    
    return model

# Create model
finetuning_model = create_finetuning_model().to(device)

# Count trainable and frozen parameters on the modified model itself
trainable_params = sum(p.numel() for p in finetuning_model.parameters() if p.requires_grad)
frozen_params = sum(p.numel() for p in finetuning_model.parameters() if not p.requires_grad)
print(f'Trainable parameters: {trainable_params:,}')
print(f'Frozen parameters: {frozen_params:,}')

## Train Fine-tuning Model

In [None]:
# Smaller learning rate for the pre-trained layer4, larger one for the newly initialized head
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam([
    {'params': finetuning_model.layer4.parameters(), 'lr': 1e-4},
    {'params': finetuning_model.fc.parameters(), 'lr': 1e-3}
])

# Train
print('Training Fine-tuning Model...\n')
finetuning_model, history_ft = train_model(
    finetuning_model,
    train_loader,
    test_loader,
    criterion,
    optimizer,
    num_epochs=5
)

print(f'Best Test Accuracy (Fine-tuning): {max(history_ft["test_acc"]):.4f}')
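Each dictionary passed to `optim.Adam` becomes a parameter group with its own learning rate, and frozen parameters that are never handed to the optimizer stay untouched. A toy check of both points, with `layer3`, `layer4` and `fc` standing in for the ResNet submodules (the model itself is illustrative):

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
# Toy model reusing the submodule names the fine-tuning cell relies on.
toy = nn.Sequential(OrderedDict([
    ("layer3", nn.Linear(8, 8)),
    ("layer4", nn.Linear(8, 8)),
    ("fc", nn.Linear(8, 10)),
]))
for p in toy.layer3.parameters():
    p.requires_grad = False  # frozen "early" layer, left out of the optimizer

optimizer = optim.Adam([
    {"params": toy.layer4.parameters(), "lr": 1e-4},  # pre-trained: small steps
    {"params": toy.fc.parameters(), "lr": 1e-3},      # new head: larger steps
])

w3, w4 = toy.layer3.weight.clone(), toy.layer4.weight.clone()
loss = nn.functional.cross_entropy(toy(torch.randn(16, 8)), torch.randint(0, 10, (16,)))
optimizer.zero_grad()
loss.backward()
optimizer.step()

print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 0.001]
print(torch.equal(w3, toy.layer3.weight))         # True: frozen
print(torch.equal(w4, toy.layer4.weight))         # False: updated
```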

## Compare Results

In [None]:
# Plot comparison
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Training accuracy
axes[0].plot(history_fe['train_acc'], marker='o', label='Feature Extraction')
axes[0].plot(history_ft['train_acc'], marker='s', label='Fine-tuning')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Accuracy')
axes[0].set_title('Training Accuracy')
axes[0].legend()
axes[0].grid(True)

# Test accuracy
axes[1].plot(history_fe['test_acc'], marker='o', label='Feature Extraction')
axes[1].plot(history_ft['test_acc'], marker='s', label='Fine-tuning')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Test Accuracy')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.show()

print('\nFinal Comparison:')
print(f'Feature Extraction - Best Test Acc: {max(history_fe["test_acc"]):.4f}')
print(f'Fine-tuning - Best Test Acc: {max(history_ft["test_acc"]):.4f}')

## Detailed Evaluation on Test Set

In [None]:
def get_predictions(model, dataloader):
    """
    Get all predictions and true labels
    """
    model.eval()
    predictions = []
    true_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc='Evaluating'):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
    
    return predictions, true_labels

# Get predictions for fine-tuned model
predictions, true_labels = get_predictions(finetuning_model, test_loader)

# Classification report
print('Classification Report (Fine-tuned Model):\n')
print(classification_report(true_labels, predictions, target_names=class_names))

In [None]:
# Confusion matrix
cm = confusion_matrix(true_labels, predictions)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
           xticklabels=class_names, yticklabels=class_names)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix (Fine-tuned Model)')
plt.tight_layout()
plt.show()
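Per-class accuracy (recall) can be read straight off a confusion matrix: the diagonal holds the correct predictions, and each row sums to the number of true samples of that class. A small numeric illustration with a made-up 3-class matrix:

```python
import numpy as np

# Made-up matrix. Rows = true labels, columns = predictions (scikit-learn's convention).
cm_example = np.array([[8, 2, 0],
                       [1, 7, 2],
                       [0, 3, 7]])

per_class_acc = cm_example.diagonal() / cm_example.sum(axis=1)  # recall per class
overall_acc = cm_example.trace() / cm_example.sum()
print(per_class_acc)  # [0.8 0.7 0.7]
```

Applied to `cm` above, `cm.diagonal() / cm.sum(axis=1)` gives the same values as the recall column of the classification report.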

## Visualize Predictions

In [None]:
# Get a batch of test images
dataiter = iter(test_loader)
images, labels = next(dataiter)
images = images.to(device)
labels = labels.to(device)

# Make predictions
finetuning_model.eval()
with torch.no_grad():
    outputs = finetuning_model(images)
    _, preds = torch.max(outputs, 1)

# Display predictions
fig, axes = plt.subplots(2, 4, figsize=(15, 8))
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

for i, ax in enumerate(axes.flat):
    img = images[i].cpu().clone()
    img = denormalize(img, mean, std)
    img = torch.clamp(img, 0, 1)
    
    ax.imshow(img.permute(1, 2, 0))
    
    true_label = class_names[labels[i]]
    pred_label = class_names[preds[i]]
    color = 'green' if labels[i] == preds[i] else 'red'
    
    ax.set_title(f'True: {true_label}\nPred: {pred_label}', color=color)
    ax.axis('off')

plt.tight_layout()
plt.show()
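The titles above show only the arg-max class. Applying a softmax to the logits also yields the model's probability for that class, which helps spot confident mistakes. A minimal sketch on made-up logits:

```python
import torch

# Made-up logits for 2 images over 3 classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 0.15]])
probs = torch.softmax(logits, dim=1)  # each row sums to 1
conf, preds = probs.max(dim=1)        # top probability and its class index
print(preds.tolist())                 # [0, 1]
```

In the cell above, `torch.softmax(outputs, dim=1).max(dim=1)` would give `conf` alongside `preds`, for example for a title such as `f'Pred: {pred_label} ({conf[i].item():.0%})'`.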

## Save the Fine-tuned Model

In [None]:
# Save model
model_path = './finetuned_resnet18_cifar10.pth'
torch.save(finetuning_model.state_dict(), model_path)
print(f'Model saved to {model_path}')

# To load the model later:
# model = models.resnet18()
# model.fc = nn.Linear(model.fc.in_features, len(class_names))
# model.load_state_dict(torch.load(model_path, map_location='cpu'))  # map_location lets GPU-trained weights load on a CPU-only machine
# model.eval()
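Saving only the `state_dict` works, but the label order in `class_names` then lives only in this notebook. A minimal sketch of a checkpoint dictionary that stores both, using hypothetical helpers `save_checkpoint` and `load_checkpoint` and a toy model in place of the fine-tuned ResNet:

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(model, path, class_names):
    """Hypothetical helper: save weights plus the label order of the outputs."""
    torch.save({"state_dict": model.state_dict(), "class_names": class_names}, path)


def load_checkpoint(model, path):
    """Hypothetical helper: load onto CPU first so GPU-trained weights open anywhere."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"])
    return model.eval(), ckpt["class_names"]


# Round trip with a toy model.
src = nn.Linear(4, 3)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.pth")
    save_checkpoint(src, path, ["cat", "dog", "frog"])
    dst, names = load_checkpoint(nn.Linear(4, 3), path)

print(names)                                # ['cat', 'dog', 'frog']
print(torch.equal(src.weight, dst.weight))  # True
```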

## Key Takeaways

### Transfer Learning Strategies:

1. **Feature Extraction**:
   - Freeze all pre-trained layers
   - Only train new classification layer
   - Fast training, works well with limited data
   - Good when source and target domains are similar

2. **Fine-tuning**:
   - Unfreeze some layers (usually later layers)
   - Train unfrozen layers with smaller learning rate
   - Often higher accuracy, but requires more data and training time
   - Preferable when the target domain differs more from the source domain

### Best Practices:

1. **Learning Rates**: Use smaller learning rates for pre-trained layers
2. **Data Augmentation**: Important for preventing overfitting
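3. **Gradual Unfreezing**: Start with the pre-trained layers frozen, then unfreeze them progressively, beginning with the layers closest to the output
4. **Normalization**: Use the same normalization statistics as the pre-training dataset (here, ImageNet's per-channel mean and standard deviation)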
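Gradual unfreezing can be driven by a simple epoch-based schedule. A minimal sketch, assuming the backbone's stages are top-level children named as in torchvision's ResNet (`layer1` to `layer4`, then `fc`); the helper `unfreeze_for_epoch` and the one-stage-per-epoch pace are hypothetical, illustrative choices:

```python
from collections import OrderedDict

import torch.nn as nn


def unfreeze_for_epoch(model, stages, epoch):
    """Hypothetical schedule: `fc` always trains; one more stage, counted from
    the end of `stages` (earliest first), is unfrozen at each epoch."""
    active = set(stages[max(len(stages) - epoch, 0):])
    for name, child in model.named_children():
        trainable = name == "fc" or name in active
        for p in child.parameters():
            p.requires_grad = trainable
    return [s for s in stages if s in active]


# Toy model whose child names mirror ResNet's stages.
toy = nn.Sequential(OrderedDict(
    [(f"layer{i}", nn.Linear(4, 4)) for i in range(1, 5)] + [("fc", nn.Linear(4, 2))]
))
stages = ["layer1", "layer2", "layer3", "layer4"]
for epoch in range(3):
    print(epoch, unfreeze_for_epoch(toy, stages, epoch))
# 0 []
# 1 ['layer4']
# 2 ['layer3', 'layer4']
```

Parameters unfrozen later still have to reach the optimizer, for example by building it over all parameters (PyTorch's built-in optimizers skip parameters whose `.grad` is `None`) or by calling `optimizer.add_param_group` when a stage is unfrozen.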

### When to Use Transfer Learning:

- Limited training data
- Similar tasks (e.g., both are image classification)
- Want faster convergence
- Need good accuracy at a lower computational cost than training from scratch

## Next Steps

- Try different pre-trained models (VGG, EfficientNet, etc.)
- Experiment with different freezing strategies
- Apply to your own custom datasets
- Compare with training from scratch
- Explore domain adaptation techniques