# Transfer Learning for Image Classification

This notebook applies transfer learning with ImageNet-pretrained models (VGG16, ResNet50, EfficientNetB0) to image classification.

## Objectives:
- Compare multiple pre-trained architectures
- Implement transfer learning with frozen base layers
- Fine-tune the best performing model
- Compare results with custom CNN
- Select the best overall model

## 1. Setup and Data Loading

In [None]:
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG16, ResNet50, EfficientNetB0
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

# Seed every RNG the notebook draws from (Python, NumPy - used by ImageDataGenerator -
# and TensorFlow weight initialisation) so Restart & Run All is repeatable.
# NOTE: some GPU kernels are still non-deterministic, so results may differ slightly.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {tf.config.list_physical_devices('GPU')}")

In [None]:
# Load and preprocess data (same as previous notebooks)
DATASET_CHOICE = "cifar10"  # only "cifar10" is implemented in this notebook

if DATASET_CHOICE == "cifar10":
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
    num_classes = len(class_names)
else:
    # Fail fast here instead of with a NameError on x_train a few lines below.
    raise ValueError(
        f"Unsupported DATASET_CHOICE {DATASET_CHOICE!r}; only 'cifar10' is implemented"
    )

# Scale pixels to [0, 1] and one-hot encode the labels
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
y_train_cat = to_categorical(y_train, num_classes)
y_test_cat = to_categorical(y_test, num_classes)

# Stratified train/validation split (80/20) so every class keeps its proportion
x_train_split, x_val_split, y_train_split, y_val_split = train_test_split(
    x_train, y_train_cat, test_size=0.2, random_state=42, stratify=y_train
)

print(f"Training set: {x_train_split.shape}")
print(f"Validation set: {x_val_split.shape}")
print(f"Test set: {x_test.shape}")

## 2. Data Preprocessing for Transfer Learning

In [None]:
# Resize images for pre-trained models (they expect larger input sizes)
def resize_images(images, target_size=(224, 224), batch_size=1000):
    """
    Resize images to target size for pre-trained models.

    Args:
        images: Array-like batch of images, assumed to be (N, H, W, C).
        target_size: (height, width) of the resized images.
        batch_size: Number of images resized per ``tf.image.resize`` call. Resizing in
            chunks into one preallocated array avoids holding the full resized tensor
            *and* its NumPy copy in memory at the same time. ``None`` or a non-positive
            value resizes everything in a single call, as before.

    Returns:
        NumPy array of shape (N, target_height, target_width, C).
    """
    images = np.asarray(images)
    if not batch_size or batch_size <= 0 or len(images) == 0:
        return tf.image.resize(images, target_size).numpy()

    resized = None
    for start in range(0, len(images), batch_size):
        chunk = tf.image.resize(images[start:start + batch_size], target_size).numpy()
        if resized is None:
            # Allocate once, using the dtype tf.image.resize actually produced.
            resized = np.empty((len(images),) + chunk.shape[1:], dtype=chunk.dtype)
        resized[start:start + len(chunk)] = chunk
    return resized

# Resize all datasets.
# NOTE: a 224x224x3 float32 image takes ~0.6 MB, so the 40,000-image training split alone
# needs ~24 GB of RAM. If that is too much, resize on the fly in a tf.data pipeline instead.
x_train_resized = resize_images(x_train_split)
x_val_resized = resize_images(x_val_split)
x_test_resized = resize_images(x_test)

print(f"Resized training set: {x_train_resized.shape}")
print(f"Resized validation set: {x_val_resized.shape}")
print(f"Resized test set: {x_test_resized.shape}")

In [None]:
# Data augmentation for transfer learning
transfer_datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.15,
    height_shift_range=0.15,
    horizontal_flip=True,
    zoom_range=0.15,
    fill_mode='nearest'
)

# No `transfer_datagen.fit(...)` here: fit() only computes the statistics needed by
# featurewise_center, featurewise_std_normalization and zca_whitening, none of which are
# enabled, so it would just be an extra pass over the very large resized training array.
print("Data augmentation configured for transfer learning")

## 3. Transfer Learning Model Creation Functions

In [None]:
# Backbone constructor and its matching `preprocess_input` for each supported architecture.
# The ImageNet weights expect [0, 255] inputs with architecture-specific preprocessing:
# VGG16/ResNet50 use caffe-style RGB->BGR + mean subtraction, while EfficientNet normalises
# internally (its preprocess_input is a pass-through).
_BASE_MODELS = {
    'vgg16': (VGG16, keras.applications.vgg16.preprocess_input),
    'resnet50': (ResNet50, keras.applications.resnet50.preprocess_input),
    'efficientnet': (EfficientNetB0, keras.applications.efficientnet.preprocess_input),
}


def create_transfer_model(base_model_name, input_shape, num_classes, trainable=False):
    """
    Create a transfer learning model with specified base architecture.

    The model accepts images scaled to [0, 1] (as produced by the data-loading cell). It
    rescales them to [0, 255], applies the backbone's own ImageNet preprocessing, runs the
    pretrained backbone, then a GlobalAveragePooling + Dense classification head.

    Args:
        base_model_name: 'vgg16', 'resnet50', or 'efficientnet'
        input_shape: Input image shape (height, width, channels)
        num_classes: Number of output classes
        trainable: Whether to make base model trainable (for fine-tuning)

    Returns:
        (model, base_model): the full Keras model and the pretrained backbone inside it, so
        the caller can toggle ``base_model.trainable`` later.

    Raises:
        ValueError: if ``base_model_name`` is not a supported architecture.
    """
    if base_model_name not in _BASE_MODELS:
        raise ValueError(f"Unknown base model: {base_model_name}")
    base_constructor, preprocess_input = _BASE_MODELS[base_model_name]

    base_model = base_constructor(
        weights='imagenet',
        include_top=False,
        input_shape=input_shape
    )

    # Freeze or unfreeze base model
    base_model.trainable = trainable

    inputs = keras.Input(shape=input_shape)
    x = layers.Rescaling(255.0)(inputs)   # [0, 1] -> [0, 255], the range the weights expect
    x = preprocess_input(x)
    # training=False keeps the backbone's BatchNormalization layers in inference mode even
    # after it is unfrozen for fine-tuning, so their ImageNet statistics are not overwritten
    # by mini-batch statistics.
    x = base_model(x, training=False)

    # Custom classification head
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = models.Model(inputs, outputs, name=f'{base_model_name}_transfer')
    return model, base_model

print("Transfer learning model creation functions defined")

## 4. Compare Multiple Pre-trained Architectures

In [None]:
# Model configurations to compare
# One entry per backbone: the key understood by create_transfer_model plus a
# human-readable label used in reports and plots.
model_configs = [
    {'name': key, 'display_name': label}
    for key, label in (
        ('vgg16', 'VGG16'),
        ('resnet50', 'ResNet50'),
        ('efficientnet', 'EfficientNetB0'),
    )
]

# Per-image shape with the batch axis dropped
input_shape = x_train_resized.shape[1:]  # (224, 224, 3)
results = dict()

print(f"Input shape for transfer learning: {input_shape}")
print(f"Will compare {len(model_configs)} architectures...")

In [None]:
# Training parameters
BATCH_SIZE = 32
INITIAL_EPOCHS = 20  # Fewer epochs for initial comparison

# Make sure the checkpoint directory exists before any callback writes into it.
os.makedirs('../models', exist_ok=True)

for config in model_configs:
    model_name = config['name']
    display_name = config['display_name']
    checkpoint_path = f'../models/{model_name}_transfer_best.h5'

    print(f"\n{'='*60}")
    print(f"Training {display_name}...")
    print(f"{'='*60}")

    # Frozen backbone: only the classification head is trained in this phase.
    model, base_model = create_transfer_model(
        model_name, input_shape, num_classes, trainable=False
    )

    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    callbacks = [
        EarlyStopping(
            monitor='val_accuracy',
            patience=8,
            restore_best_weights=True,
            verbose=1
        ),
        ModelCheckpoint(
            checkpoint_path,
            monitor='val_accuracy',
            save_best_only=True,
            verbose=1
        )
    ]

    # Count scalar elements of every trainable tensor. (`.numel()` is a PyTorch method;
    # TF/Keras variables only expose a shape.)
    trainable_params = sum(int(np.prod(v.shape)) for v in model.trainable_variables)
    print(f"Model parameters: {model.count_params():,}")
    print(f"Trainable parameters: {trainable_params:,}")

    history = model.fit(
        transfer_datagen.flow(x_train_resized, y_train_split, batch_size=BATCH_SIZE),
        steps_per_epoch=len(x_train_resized) // BATCH_SIZE,
        epochs=INITIAL_EPOCHS,
        validation_data=(x_val_resized, y_val_split),
        callbacks=callbacks,
        verbose=1
    )

    # Reload the best-validation checkpoint so the test score below always belongs to the
    # epoch reported as best_val_accuracy. On some Keras versions restore_best_weights only
    # takes effect when early stopping actually triggers, not when all epochs complete.
    model.load_weights(checkpoint_path)

    test_loss, test_accuracy = model.evaluate(x_test_resized, y_test_cat, verbose=0)

    results[model_name] = {
        'model': model,
        'history': history.history,
        'test_accuracy': test_accuracy,
        'test_loss': test_loss,
        'best_val_accuracy': max(history.history['val_accuracy']),
        'display_name': display_name
    }

    print(f"\n{display_name} Results:")
    print(f"Best validation accuracy: {results[model_name]['best_val_accuracy']:.4f}")
    print(f"Test accuracy: {test_accuracy:.4f}")

print("\nAll models trained and evaluated!")

## 5. Compare Model Performance

In [None]:
# Compare all models
print("\n" + "="*80)
print("TRANSFER LEARNING MODELS COMPARISON")
print("="*80)

comparison_data = [
    {
        'Model': result['display_name'],
        'Best Val Accuracy': f"{result['best_val_accuracy']:.4f}",
        'Test Accuracy': f"{result['test_accuracy']:.4f}",
        'Test Loss': f"{result['test_loss']:.4f}",
    }
    for result in results.values()
]
comparison_df = pd.DataFrame(comparison_data)
print(comparison_df.to_string(index=False))

# Select the architecture on *validation* accuracy. Choosing it on test accuracy would leak
# the test set into model selection and bias every test score reported afterwards.
best_model_name = max(results, key=lambda k: results[k]['best_val_accuracy'])
best_result = results[best_model_name]

print(f"\nBest performing model: {best_result['display_name']}")
print(f"Best validation accuracy: {best_result['best_val_accuracy']:.4f}")
print(f"Test accuracy: {best_result['test_accuracy']:.4f}")

In [None]:
# 2x2 dashboard: validation curves for every backbone (top), final test accuracy
# per backbone (bottom-left), train-vs-validation accuracy of the winner (bottom-right).
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Top row: the same per-epoch curve drawn for each architecture, one metric per panel
curve_panels = (
    (axes[0, 0], 'val_accuracy', 'Validation Accuracy Comparison', 'Accuracy'),
    (axes[0, 1], 'val_loss', 'Validation Loss Comparison', 'Loss'),
)
for panel, metric_key, panel_title, y_label in curve_panels:
    for result in results.values():
        panel.plot(result['history'][metric_key],
                   label=result['display_name'], linewidth=2)
    panel.set_title(panel_title)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.legend()
    panel.grid(True, alpha=0.3)

# Bottom-left: bar chart of final test accuracies, annotated with their values
bar_panel = axes[1, 0]
model_names = [entry['display_name'] for entry in results.values()]
test_accuracies = [entry['test_accuracy'] for entry in results.values()]
bars = bar_panel.bar(model_names, test_accuracies,
                     color=['skyblue', 'lightcoral', 'lightgreen'])
bar_panel.set_title('Test Accuracy Comparison')
bar_panel.set_ylabel('Test Accuracy')
bar_panel.tick_params(axis='x', rotation=45)
for bar, acc in zip(bars, test_accuracies):
    label_x = bar.get_x() + bar.get_width() / 2
    label_y = bar.get_height() + 0.005
    bar_panel.text(label_x, label_y, f'{acc:.3f}', ha='center', va='bottom')

# Bottom-right: training vs validation accuracy of the selected model
best_history = best_result['history']
best_panel = axes[1, 1]
for history_key, curve_label in (('accuracy', 'Training'), ('val_accuracy', 'Validation')):
    best_panel.plot(best_history[history_key], label=curve_label, linewidth=2)
best_panel.set_title(f'{best_result["display_name"]} - Best Model Training')
best_panel.set_xlabel('Epoch')
best_panel.set_ylabel('Accuracy')
best_panel.legend()
best_panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Fine-tuning the Best Model

In [None]:
# Fine-tune the best performing model
print(f"\nFine-tuning {best_result['display_name']}...")

# Rebuild the architecture with the backbone FROZEN, i.e. exactly as it was when the checkpoint
# was written. Keras orders a nested model's weights as trainable + non-trainable, so unfreezing
# before loading reorders them for backbones with BatchNormalization (ResNet50, EfficientNet)
# and the h5 weights would no longer line up.
finetune_model, base_model = create_transfer_model(
    best_model_name, input_shape, num_classes, trainable=False
)
finetune_model.load_weights(f'../models/{best_model_name}_transfer_best.h5')

# Only now unfreeze the backbone for fine-tuning.
base_model.trainable = True

# Recompile after changing `trainable` (the change only takes effect at compile time),
# with a much lower learning rate so the pretrained features are adjusted gently.
finetune_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# `.numel()` is PyTorch-only; count elements from each TF/Keras variable's shape instead.
finetune_trainable_params = sum(
    int(np.prod(v.shape)) for v in finetune_model.trainable_variables
)
print(f"Fine-tuning model with {finetune_model.count_params():,} total parameters")
print(f"Trainable parameters: {finetune_trainable_params:,}")

In [None]:
# Fine-tuning callbacks
finetune_callbacks = [
    EarlyStopping(
        monitor='val_accuracy',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    ModelCheckpoint(
        f'../models/{best_model_name}_finetuned_best.h5',
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-8,
        verbose=1
    )
]

# Fine-tune the model
FINETUNE_EPOCHS = 30

finetune_history = finetune_model.fit(
    transfer_datagen.flow(x_train_resized, y_train_split, batch_size=BATCH_SIZE),
    steps_per_epoch=len(x_train_resized) // BATCH_SIZE,
    epochs=FINETUNE_EPOCHS,
    validation_data=(x_val_resized, y_val_split),
    callbacks=finetune_callbacks,
    verbose=1
)

print("Fine-tuning completed!")

## 7. Final Model Evaluation

In [None]:
# Evaluate fine-tuned model
finetune_test_loss, finetune_test_accuracy = finetune_model.evaluate(
    x_test_resized, y_test_cat, verbose=0
)

print(f"\nFinal Results Comparison:")
print(f"{'='*50}")
print(f"Initial {best_result['display_name']}: {best_result['test_accuracy']:.4f}")
print(f"Fine-tuned {best_result['display_name']}: {finetune_test_accuracy:.4f}")
print(f"Improvement: {finetune_test_accuracy - best_result['test_accuracy']:+.4f}")
print(f"{'='*50}")

# Generate detailed predictions for fine-tuned model
y_pred_finetune = finetune_model.predict(x_test_resized)
y_pred_classes_finetune = np.argmax(y_pred_finetune, axis=1)
y_true_classes = np.argmax(y_test_cat, axis=1)

# Classification report
print("\nDetailed Classification Report (Fine-tuned Model):")
print(classification_report(y_true_classes, y_pred_classes_finetune, 
                          target_names=class_names))

In [None]:
# Confusion matrix for fine-tuned model
cm_finetune = confusion_matrix(y_true_classes, y_pred_classes_finetune)

plt.figure(figsize=(10, 8))
sns.heatmap(cm_finetune, annot=True, fmt='d', cmap='Blues', 
           xticklabels=class_names, yticklabels=class_names)
plt.title(f'Confusion Matrix - Fine-tuned {best_result["display_name"]}')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

# Per-class accuracy
class_accuracy_finetune = cm_finetune.diagonal() / cm_finetune.sum(axis=1)
print("\nPer-class Accuracy (Fine-tuned):")
for i, class_name in enumerate(class_names):
    print(f"{class_name}: {class_accuracy_finetune[i]:.4f}")

## 8. Model Comparison with Custom CNN

In [None]:
# Load custom CNN results for comparison (if available)
try:
    # NOTE: pickle.load can execute arbitrary code - only load files produced by this project.
    with open(f'../models/custom_cnn_{DATASET_CHOICE}_history.pkl', 'rb') as f:
        custom_cnn_history = pickle.load(f)

    # Load custom CNN model to get test accuracy. It is evaluated on the original-resolution
    # x_test (presumably the input it was trained on - confirm against 02_custom_cnn).
    custom_cnn_model = keras.models.load_model(f'../models/custom_cnn_{DATASET_CHOICE}_final.h5')
    custom_cnn_test_loss, custom_cnn_test_accuracy = custom_cnn_model.evaluate(
        x_test, y_test_cat, verbose=0
    )

    print("\nFINAL MODEL COMPARISON")
    print("="*60)
    print(f"Custom CNN Test Accuracy: {custom_cnn_test_accuracy:.4f}")
    print(f"Transfer Learning ({best_result['display_name']}) Test Accuracy: {finetune_test_accuracy:.4f}")

    if finetune_test_accuracy > custom_cnn_test_accuracy:
        print(f"\n🏆 WINNER: Transfer Learning ({best_result['display_name']})")
        print(f"Improvement over Custom CNN: {finetune_test_accuracy - custom_cnn_test_accuracy:+.4f}")
    else:
        print(f"\n🏆 WINNER: Custom CNN")
        print(f"Advantage over Transfer Learning: {custom_cnn_test_accuracy - finetune_test_accuracy:+.4f}")

    print("="*60)

# OSError (not just FileNotFoundError): keras.models.load_model reports a missing/unreadable
# model file as IOError/OSError, which is not a FileNotFoundError subclass.
except OSError:
    print("\nCustom CNN results not found. Run notebook 02_custom_cnn.ipynb first for comparison.")

## 9. Save Final Results

In [None]:
# Save the best transfer learning model
final_model_path = f'../models/best_transfer_model_{DATASET_CHOICE}.h5'
finetune_model.save(final_model_path)
print(f"Best transfer learning model saved: {final_model_path}")

# Save all results. Drop the live Keras model objects from each per-architecture entry:
# they cannot be pickled on many TF/Keras versions, and their weights are already on disk
# as the .h5 checkpoints written during training.
initial_comparison = {
    name: {key: value for key, value in result.items() if key != 'model'}
    for name, result in results.items()
}
all_results = {
    'initial_comparison': initial_comparison,
    'best_model_name': best_model_name,
    'finetune_history': finetune_history.history,
    'finetune_test_accuracy': finetune_test_accuracy,
    'finetune_test_loss': finetune_test_loss
}

with open(f'../models/transfer_learning_results_{DATASET_CHOICE}.pkl', 'wb') as f:
    pickle.dump(all_results, f)
print(f"Transfer learning results saved")

# Summary
print("\n" + "="*70)
print("TRANSFER LEARNING SUMMARY")
print("="*70)
print(f"Dataset: {DATASET_CHOICE.upper()}")
print(f"Best architecture: {best_result['display_name']}")
print(f"Initial transfer learning accuracy: {best_result['test_accuracy']:.4f}")
print(f"Fine-tuned accuracy: {finetune_test_accuracy:.4f}")
print(f"Model parameters: {finetune_model.count_params():,}")
print(f"Final model saved: {final_model_path}")
print("="*70)