# Transfer Learning with Pre-trained ViT on CIFAR-10

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/juho127/ClassificationTest/blob/main/pretrained_vit.ipynb)

In this notebook, we'll take a **Vision Transformer (ViT)** model pre-trained on ImageNet and fine-tune it on CIFAR-10.

## What is Transfer Learning?

Transfer learning uses knowledge learned from one task (ImageNet classification) and applies it to another task (CIFAR-10 classification).

### Benefits:
- ✓ Much better performance (can reach 90%+ accuracy!)
- ✓ Faster training (fewer epochs needed)
- ✓ Works well with small datasets
- ✓ Learns better features

### Comparison:
- **Training from scratch**: ~65-70% (from previous notebook)
- **Transfer learning**: ~85-95% (this notebook)

## Learning Goals:
1. Load pre-trained models using `timm` library
2. Understand fine-tuning strategies
3. Compare different ViT model sizes
4. Reach high accuracy on CIFAR-10 (roughly 85-95%) in only a few epochs


## 0. Environment Setup


In [None]:
# Check if running on Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("✓ Running on Google Colab")
    print("📌 Tip: Runtime > Change runtime type > GPU for faster training!")
except ImportError:
    IN_COLAB = False
    print("✓ Running on local environment")

# Install required packages on Colab
if IN_COLAB:
    print("\nInstalling packages...")
    import sys
    # timm: PyTorch Image Models (for pre-trained models)
    !{sys.executable} -m pip install -q torch torchvision tqdm matplotlib timm
    print("✓ Packages installed!")
else:
    print("\nMake sure you have installed: torch torchvision tqdm matplotlib timm")


## 1. Import Libraries


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import time
import timm  # PyTorch Image Models

print(f"PyTorch version: {torch.__version__}")
print(f"timm version: {timm.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print("🎉 You can use GPU for faster training!")


## 2. Hyperparameters and Data Loading


In [None]:
# Hyperparameters
BATCH_SIZE = 128
LEARNING_RATE = 1e-4  # Lower learning rate for fine-tuning
NUM_EPOCHS = 10  # Fewer epochs needed with pre-trained model
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CIFAR-10 classes
CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

print(f"Device: {DEVICE}")
if DEVICE.type == 'cuda':
    print("✓ Using GPU!")
else:
    print("ℹ Using CPU (Colab: Runtime > Change runtime type > GPU)")


In [None]:
# Data preprocessing with strong augmentation
# Note: Pre-trained models expect 224x224 images
transform_train = transforms.Compose([
    transforms.Resize(224),  # Resize CIFAR-10 from 32x32 to 224x224
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    # ImageNet normalization (important for pre-trained models!)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load CIFAR-10 dataset
print("Loading dataset...")
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train
)

test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print("\n⚠️ Note: Images are resized from 32x32 to 224x224 for pre-trained models")


In [None]:
# Visualize sample images
def show_images(loader, num_images=10):
    dataiter = iter(loader)
    images, labels = next(dataiter)
    
    # Denormalize for visualization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    images = images * std + mean
    images = torch.clamp(images, 0, 1)
    
    fig, axes = plt.subplots(2, 5, figsize=(12, 5))
    fig.suptitle('CIFAR-10 Sample Images (Resized to 224x224)', fontsize=16, fontweight='bold')
    
    for idx, ax in enumerate(axes.flat):
        if idx < num_images:
            img = images[idx].numpy().transpose((1, 2, 0))
            ax.imshow(img)
            ax.set_title(f'{CLASSES[labels[idx]]}', fontsize=10)
            ax.axis('off')
    
    plt.tight_layout()
    plt.show()

show_images(train_loader)


## 3. Explore Available Pre-trained Models

Let's check what ViT models are available in the `timm` library.


In [None]:
# List available ViT models
vit_models = timm.list_models('vit*', pretrained=True)
print(f"Available pre-trained ViT models: {len(vit_models)}")
print("\nSome popular models:")
for model in vit_models[:10]:
    print(f"  - {model}")
    
print("\n💡 We'll use 'vit_tiny_patch16_224' (smallest, fastest for practice)")


## 4. Load Pre-trained Model

We'll load a pre-trained ViT model and modify it for CIFAR-10 (10 classes).


In [None]:
def create_pretrained_vit(model_name='vit_tiny_patch16_224', num_classes=10):
    """
    Create a pre-trained ViT model and modify the classifier head
    
    Args:
        model_name: Name of the pre-trained model
        num_classes: Number of output classes (10 for CIFAR-10)
    """
    print(f"Loading pre-trained model: {model_name}")
    
    # Load pre-trained model (trained on ImageNet with 1000 classes)
    model = timm.create_model(model_name, pretrained=True)
    
    # Get the number of features in the classifier
    num_features = model.head.in_features
    
    # Replace the classifier head for CIFAR-10 (10 classes)
    model.head = nn.Linear(num_features, num_classes)
    
    print("✓ Model loaded successfully!")
    print(f"  - Original task: ImageNet (1000 classes)")
    print(f"  - New task: CIFAR-10 ({num_classes} classes)")
    print(f"  - Classifier head replaced: {num_features} -> {num_classes}")
    
    return model

# Create model
model = create_pretrained_vit().to(DEVICE)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Parameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")


## 5. Fine-tuning Strategies

There are different ways to fine-tune a pre-trained model:

### Strategy 1: Fine-tune all layers
- Train all parameters
- More flexible but slower
- Risk of overfitting on small datasets

### Strategy 2: Feature extraction (freeze backbone)
- Only train the new classifier head
- Faster training
- Good for very small datasets

### Strategy 3: Gradual unfreezing
- Start with frozen backbone, then gradually unfreeze layers
- Keeps early training stable (only the new head changes at first) while still letting the backbone adapt later

**We'll use Strategy 1** for simplicity and good performance.
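
Strategy 3 can be sketched on a toy model. This is a minimal illustration, not this notebook's training code: `toy_model`, `unfreeze_last_blocks`, and `count_trainable` are hypothetical names, and the three small `nn.Linear` layers only stand in for transformer blocks. The idea is to freeze everything, always keep the head trainable, and re-enable the last few blocks stage by stage.

```python
import torch.nn as nn

# Hypothetical stand-in for a ViT: three "blocks" plus a classifier head.
toy_model = nn.ModuleDict({
    "blocks": nn.ModuleList([nn.Linear(8, 8) for _ in range(3)]),
    "head": nn.Linear(8, 10),
})

def unfreeze_last_blocks(model, num_unfrozen):
    """Freeze everything, then re-enable the head and the last `num_unfrozen` blocks."""
    for param in model.parameters():
        param.requires_grad = False
    for param in model["head"].parameters():
        param.requires_grad = True
    blocks = model["blocks"]
    # max(0, ...) avoids a negative start index when num_unfrozen exceeds the block count.
    start = max(0, len(blocks) - num_unfrozen)
    for block in blocks[start:]:
        for param in block.parameters():
            param.requires_grad = True

def count_trainable(model):
    """Number of parameters that will receive gradient updates."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Stage 0 trains only the head; each later stage unfreezes one more block.
trainable_by_stage = []
for stage in range(4):
    unfreeze_last_blocks(toy_model, num_unfrozen=stage)
    trainable_by_stage.append(count_trainable(toy_model))
print(trainable_by_stage)
```

In a real run, each stage would train for a few epochs before the next block is unfrozen. Parameters with `requires_grad = False` receive no gradients, so an optimizer built over all parameters simply skips them.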


In [None]:
# Optional: Freeze backbone for feature extraction only
# Uncomment the lines below to try Strategy 2

# for name, param in model.named_parameters():
#     if 'head' not in name:  # Freeze all except classifier head
#         param.requires_grad = False
# 
# print("Backbone frozen! Only classifier head will be trained.")
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"Trainable parameters: {trainable_params:,}")

print("Using Strategy 1: Fine-tune all layers")
print(f"All {trainable_params:,} parameters will be trained")


## 6. Training Functions
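
The training and evaluation loops in this section accumulate `loss.item() * images.size(0)` rather than `loss.item()` alone. `nn.CrossEntropyLoss` returns the mean loss over a batch by default, so multiplying by the batch size recovers the batch's summed loss, and dividing by the total sample count at the end gives a true per-sample average. This matters because the last batch is usually smaller (50,000 training images do not divide evenly by 128). The numbers below are made up purely for illustration.

```python
# Hypothetical per-batch mean losses and batch sizes (the last batch is smaller).
batch_mean_losses = [0.50, 0.50, 2.00]
batch_sizes = [128, 128, 16]

# Naive: average the batch means, which over-weights the small last batch.
naive_mean = sum(batch_mean_losses) / len(batch_mean_losses)

# Weighted: undo each batch mean (loss * batch_size), then divide by total samples.
running_loss = sum(loss * n for loss, n in zip(batch_mean_losses, batch_sizes))
total = sum(batch_sizes)
weighted_mean = running_loss / total

print(f"naive mean of batch means: {naive_mean:.4f}")
print(f"sample-weighted mean:      {weighted_mean:.4f}")
```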


In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, epoch):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS}')
    for images, labels in pbar:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        pbar.set_postfix({
            'loss': f'{running_loss/total:.4f}',
            'acc': f'{100*correct/total:.2f}%'
        })
    
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc


def evaluate(model, test_loader, criterion):
    """Evaluate model"""
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Evaluating'):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            test_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    test_loss = test_loss / total
    test_acc = 100 * correct / total
    return test_loss, test_acc


print("Training functions defined!")


## 7. Train the Model

Now let's fine-tune the pre-trained ViT model on CIFAR-10!
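
The training cell uses `CosineAnnealingLR` with `T_max=NUM_EPOCHS`. As a rough sketch of what that schedule does, the closed-form cosine curve can be computed by hand. The helper name `cosine_annealing_lr` is hypothetical, and `eta_min=0.0` is assumed to match the scheduler's default minimum learning rate.

```python
import math

def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

# Same settings as this notebook: LEARNING_RATE = 1e-4, NUM_EPOCHS = 10.
schedule = [cosine_annealing_lr(1e-4, epoch, t_max=10) for epoch in range(11)]
for epoch, lr in enumerate(schedule):
    print(f"after {epoch:2d} scheduler steps: lr = {lr:.6f}")
```

Because `scheduler.step()` runs at the end of each epoch, the learning rate printed for an epoch is the one the next epoch will use. The value printed after the final epoch is therefore the end of the curve, which is zero under these settings.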


In [None]:
# Setup training
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

# Learning rate scheduler (optional but recommended)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Training loop
train_losses = []
train_accs = []
test_losses = []
test_accs = []

print("Starting training...")
print("=" * 60)

start_time = time.time()
best_acc = 0.0

for epoch in range(NUM_EPOCHS):
    # Train
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, epoch)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    
    # Evaluate
    test_loss, test_acc = evaluate(model, test_loader, criterion)
    test_losses.append(test_loss)
    test_accs.append(test_acc)
    
    # Update learning rate
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
    print(f"  Test  - Loss: {test_loss:.4f}, Acc: {test_acc:.2f}%")
    print(f"  LR: {current_lr:.6f}")
    
    # Save best model
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), 'pretrained_vit_best.pth')
        print(f"  ✓ Best model saved (Accuracy: {best_acc:.2f}%)")
    
    print("-" * 60)

training_time = time.time() - start_time

print(f"\n{'='*60}")
print("Training Complete!")
print(f"Total training time: {training_time/60:.2f} minutes")
print(f"Best test accuracy: {best_acc:.2f}%")
print(f"{'='*60}")


## 8. Visualize Results


In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

epochs = range(1, NUM_EPOCHS + 1)

# Loss graph
ax1.plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2, marker='o')
ax1.plot(epochs, test_losses, 'r-', label='Test Loss', linewidth=2, marker='s')
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Training History: Loss', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy graph
ax2.plot(epochs, train_accs, 'b-', label='Train Accuracy', linewidth=2, marker='o')
ax2.plot(epochs, test_accs, 'r-', label='Test Accuracy', linewidth=2, marker='s')
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy (%)', fontsize=12)
ax2.set_title('Training History: Accuracy', fontsize=14, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('pretrained_vit_training.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Training history saved as 'pretrained_vit_training.png'")
print(f"\nFinal Results:")
print(f"  Best Test Accuracy: {best_acc:.2f}%")
print(f"  Final Train Accuracy: {train_accs[-1]:.2f}%")
print(f"  Final Test Accuracy: {test_accs[-1]:.2f}%")


## 9. Per-Class Accuracy Analysis
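
The cell in this section tallies correct predictions one sample at a time. The same counts can be computed in a vectorized way with `torch.bincount`, as sketched here. The function name `per_class_accuracy` and the tiny hand-made tensors are illustrative assumptions, not part of the notebook.

```python
import torch

def per_class_accuracy(predicted, labels, num_classes):
    """Return per-class accuracy in percent as a float tensor of length num_classes."""
    # How many samples of each class appear in `labels`.
    total = torch.bincount(labels, minlength=num_classes)
    # How many of those were predicted correctly (keep only the matching positions).
    correct = torch.bincount(labels[predicted == labels], minlength=num_classes)
    # clamp(min=1) avoids division by zero for classes with no samples.
    return 100.0 * correct / total.clamp(min=1)

# Made-up example with 3 classes and 4 samples.
labels = torch.tensor([0, 0, 1, 2])
predicted = torch.tensor([0, 1, 1, 0])
acc = per_class_accuracy(predicted, labels, num_classes=3)
print(acc.tolist())
```

To use this on the full test set, the per-batch `total` and `correct` counts would be summed across batches before dividing.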


In [None]:
# Calculate per-class accuracy
class_correct = [0] * 10
class_total = [0] * 10

model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels)  # no squeeze: a batch of size 1 must stay 1-D
        for i in range(len(labels)):
            label = labels[i].item()
            class_correct[label] += c[i].item()
            class_total[label] += 1

# Calculate and plot
class_acc = [100 * class_correct[i] / class_total[i] for i in range(10)]

fig, ax = plt.subplots(figsize=(12, 6))
bars = ax.bar(CLASSES, class_acc, color='steelblue', alpha=0.8)
ax.set_xlabel('Class', fontsize=12, fontweight='bold')
ax.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
ax.set_title('Per-Class Accuracy (Pre-trained ViT)', fontsize=14, fontweight='bold')
ax.set_ylim([0, 100])
ax.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, acc in zip(bars, class_acc):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{acc:.1f}%', ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig('pretrained_vit_per_class.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nPer-Class Accuracy:")
print("=" * 40)
for i, cls in enumerate(CLASSES):
    print(f"  {cls:10s}: {class_acc[i]:.2f}%")
print("=" * 40)


## 10. Visualize Predictions


In [None]:
# Visualize predictions
def visualize_predictions(model, loader, num_images=10):
    model.eval()
    dataiter = iter(loader)
    images, labels = next(dataiter)
    
    # Denormalize for visualization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    images_display = images * std + mean
    images_display = torch.clamp(images_display, 0, 1)
    
    images = images.to(DEVICE)
    with torch.no_grad():
        outputs = model(images)
        probabilities = torch.softmax(outputs, dim=1)
        confidences, predicted = torch.max(probabilities, 1)
    
    fig, axes = plt.subplots(2, 5, figsize=(14, 6))
    fig.suptitle('Pre-trained ViT Predictions', fontsize=16, fontweight='bold')
    
    for idx, ax in enumerate(axes.flat):
        if idx < num_images:
            img = images_display[idx].numpy().transpose((1, 2, 0))
            ax.imshow(img)
            
            pred_label = CLASSES[predicted[idx]]
            true_label = CLASSES[labels[idx]]
            conf = confidences[idx].item()
            
            color = 'green' if predicted[idx] == labels[idx] else 'red'
            ax.set_title(f'Pred: {pred_label} ({conf:.2%})\nTrue: {true_label}', 
                        color=color, fontsize=9, fontweight='bold')
            ax.axis('off')
    
    plt.tight_layout()
    plt.show()

visualize_predictions(model, test_loader)


## 11. Compare with Training from Scratch

Let's compare the performance of pre-trained vs. from-scratch models.


In [None]:
# Comparison table
comparison_data = {
    'Method': ['Training from Scratch', 'Transfer Learning (Pre-trained)'],
    'Accuracy': ['~65-70%', f'{best_acc:.2f}%'],
    'Training Time': ['~20-30 min (20 epochs)', f'{training_time/60:.1f} min ({NUM_EPOCHS} epochs)'],
    'Convergence': ['Slower', 'Faster'],
    'Data Efficiency': ['Needs more data', 'Works with less data']
}

print("=" * 80)
print("COMPARISON: Training from Scratch vs. Transfer Learning")
print("=" * 80)
print(f"{'Method':<35} {'Accuracy':<15} {'Training Time':<20}")
print("-" * 80)
for i in range(len(comparison_data['Method'])):
    print(f"{comparison_data['Method'][i]:<35} "
          f"{comparison_data['Accuracy'][i]:<15} "
          f"{comparison_data['Training Time'][i]:<20}")
print("=" * 80)

print("\n🎯 Key Takeaways:")
print("  1. Pre-trained models achieve MUCH better accuracy (roughly +20-25 percentage points)")
print("  2. Faster convergence (fewer epochs needed)")
print("  3. More stable training (less overfitting)")
print("  4. Better feature extraction from ImageNet knowledge")


## 12. Try Different Model Sizes (Optional)

The `timm` library provides various ViT model sizes. Let's compare them!
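
To see where the size differences come from, the parameter count of a plain ViT can be estimated from its architecture alone. The sketch below assumes the standard ViT-Tiny/Small/Base settings (embedding widths 192, 384, and 768, depth 12, MLP ratio 4, a class token, and learned position embeddings). Whether a given `timm` variant matches exactly is an assumption to check against the counts printed by the cell in this section. The function name `vit_param_count` is hypothetical.

```python
def vit_param_count(embed_dim, depth=12, patch_size=16, image_size=224,
                    in_channels=3, num_classes=1000, mlp_ratio=4):
    """Estimate the parameter count of a plain ViT classifier (weights and biases)."""
    num_patches = (image_size // patch_size) ** 2            # 14 * 14 = 196 for 224/16
    patch_embed = in_channels * patch_size ** 2 * embed_dim + embed_dim
    cls_token = embed_dim
    pos_embed = (num_patches + 1) * embed_dim                # +1 for the class token
    hidden = mlp_ratio * embed_dim
    # Attention: fused QKV projection plus output projection.
    attention = (embed_dim * 3 * embed_dim + 3 * embed_dim) + (embed_dim * embed_dim + embed_dim)
    # MLP: two linear layers (embed_dim -> hidden -> embed_dim).
    mlp = (embed_dim * hidden + hidden) + (hidden * embed_dim + embed_dim)
    layer_norms = 2 * (2 * embed_dim)                        # two LayerNorms per block
    block = attention + mlp + layer_norms
    final_norm = 2 * embed_dim
    head = embed_dim * num_classes + num_classes
    return patch_embed + cls_token + pos_embed + depth * block + final_norm + head

for name, dim in [("tiny", 192), ("small", 384), ("base", 768)]:
    print(f"ViT-{name:<6} ~{vit_param_count(dim) / 1e6:.1f}M parameters")
```

The count grows roughly with the square of the embedding width, because each block's attention and MLP weights scale as `embed_dim ** 2`. The number of attention heads does not change the total in this formula.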


In [None]:
# Compare different ViT model sizes
model_variants = [
    'vit_tiny_patch16_224',   # Smallest, fastest
    'vit_small_patch16_224',  # Medium size
    'vit_base_patch16_224',   # Larger, better accuracy but slower
]

print("Available ViT Model Variants:")
print("=" * 70)
print(f"{'Model Name':<30} {'Parameters':<20} {'Speed':<20}")
print("-" * 70)

for variant in model_variants:
    try:
        temp_model = timm.create_model(variant, pretrained=False)
        params = sum(p.numel() for p in temp_model.parameters()) / 1e6
        
        if 'tiny' in variant:
            speed = "⚡ Fast"
        elif 'small' in variant:
            speed = "→ Medium"
        else:
            speed = "🐌 Slower"
        
        print(f"{variant:<30} {params:.1f}M parameters    {speed:<20}")
        del temp_model
    except Exception:
        print(f"{variant:<30} Not available")

print("=" * 70)
print("\n💡 Recommendation:")
print("  - For practice: vit_tiny_patch16_224 (fast, good accuracy)")
print("  - For best results: vit_base_patch16_224 (slower, better accuracy)")
print("\nTo try a different model, change the model_name in Section 4!")


## 13. Key Concepts Summary

### What is Transfer Learning?
Transfer learning uses knowledge from a **source task** (ImageNet) to improve performance on a **target task** (CIFAR-10).

### Why does it work?
1. **Low-level features are universal**: Edge detectors and color filters work across datasets
2. **High-level features transfer**: Object parts and shapes are similar across natural-image datasets
3. **Pre-trained weights are a better starting point**: Training begins from useful features instead of random initialization

### When to use Transfer Learning?
✅ **Use when:**
- Limited training data
- Similar domain (images → images)
- Want faster training
- Want better performance

❌ **Don't use when:**
- Very different domains (images → text)
- Huge amount of target data
- Very specific task that's very different from source

### Fine-tuning Strategies:
1. **Feature Extraction**: Freeze backbone, train only classifier
2. **Fine-tune all**: Train all layers (we used this)
3. **Gradual unfreezing**: Start frozen, gradually unfreeze layers
4. **Discriminative learning rates**: Different learning rates for different layers
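
Discriminative learning rates (item 4) can be expressed with optimizer parameter groups. The sketch below uses a hypothetical two-part `toy_model` and illustrative learning rates (`1e-5` for the backbone, `1e-4` for the head). These values are assumptions for demonstration, not tuned settings from this notebook.

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in: a "backbone" layer and a new classifier "head".
toy_model = nn.ModuleDict({
    "backbone": nn.Linear(8, 8),
    "head": nn.Linear(8, 10),
})

# Each dict is one parameter group with its own learning rate;
# options given outside the list (weight_decay) apply to every group.
optimizer = optim.AdamW(
    [
        {"params": toy_model["backbone"].parameters(), "lr": 1e-5},  # pre-trained: small steps
        {"params": toy_model["head"].parameters(), "lr": 1e-4},      # new head: larger steps
    ],
    weight_decay=0.01,
)

group_lrs = [group["lr"] for group in optimizer.param_groups]
print(group_lrs)
```

The reasoning is that pre-trained layers already hold useful features and should change slowly, while the freshly initialized head needs larger updates.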


## 14. Exercises

Try these experiments to deepen your understanding:

### Easy:
1. **Change learning rate**: Try `1e-3`, `1e-5` and compare results
2. **Change number of epochs**: Try 5, 15, 20 epochs
3. **Try different model**: Use `vit_small_patch16_224` instead

### Medium:
4. **Feature extraction**: Uncomment the freeze code in Section 5
   - Compare training time and accuracy
5. **Data augmentation**: Remove some augmentations from transform_train
   - See how it affects overfitting

### Hard:
6. **Implement gradual unfreezing**:
   - Freeze all layers initially
   - Unfreeze one block at a time every few epochs
7. **Try other architectures**:
   - ResNet: `resnet50`, `resnet101`
   - EfficientNet: `efficientnet_b0`, `efficientnet_b3`
8. **Ensemble methods**: Train multiple models and combine predictions


## 15. Save Model (Optional)


In [None]:
# Save the final model
torch.save(model.state_dict(), 'pretrained_vit_final.pth')
print("✓ Model saved as 'pretrained_vit_final.pth'")

# To load the model later:
# model = create_pretrained_vit()
# model.load_state_dict(torch.load('pretrained_vit_final.pth', map_location=DEVICE))
# model.to(DEVICE)
# model.eval()

print("\n" + "=" * 70)
print("CONGRATULATIONS! 🎉")
print("=" * 70)
print(f"You successfully fine-tuned a pre-trained ViT model!")
print(f"Final test accuracy: {best_acc:.2f}%")
print(f"\nThis is much better than training from scratch (~65-70%)!")
print("=" * 70)
