# Transfer Learning from Custom CNN

This notebook implements transfer learning using our best custom CNN as the base model.

## Objectives:
- Load the best trained custom CNN model
- Freeze convolutional layers (keep learned features)
- Replace dense layers with new classifier
- Fine-tune on the same dataset for improved performance
- Evaluate and compare results

## 1. Environment Setup

In [None]:
# Configure environment for Apple Silicon optimization
import json
import os
import pickle
import random
import shutil
import warnings
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow warnings (must be set before importing TensorFlow)

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# Configure TensorFlow for Apple Silicon.
# Enabling memory growth asks TensorFlow to allocate GPU memory as needed
# instead of reserving it all up front; any failure is reported, not fatal.
try:
    gpu_devices = tf.config.list_physical_devices('GPU')
    if not gpu_devices:
        print("ℹ️  No GPU found, using CPU")
    else:
        for gpu_device in gpu_devices:
            tf.config.experimental.set_memory_growth(gpu_device, True)
        print(f"✅ GPU configured: {len(gpu_devices)} device(s) found")
except Exception as e:
    print(f"⚠️  GPU configuration warning: {e}")

print(f"TensorFlow version: {tf.__version__}")
device_names = [device.name for device in tf.config.list_physical_devices()]
print(f"Available devices: {device_names}")

# Seed NumPy, TensorFlow and Python's `random` (in that order) for reproducibility.
SEED = 42
for seed_fn in (np.random.seed, tf.random.set_seed, random.seed):
    seed_fn(SEED)
print("✅ Random seeds set for reproducibility")

## 2. Data Loading and Preprocessing

Uses the same data split and preprocessing as the original training notebook, so results are directly comparable with the base model.

In [None]:
# Location of the raw Animals-10 download in the kagglehub cache layout.
# Resolved relative to the current user's home directory instead of a
# hardcoded, user-specific absolute path; set ANIMALS10_PATH to override.
# NOTE: `data_dir` is not referenced by the later cells of this notebook
# (FIRST_TIME_SETUP is False); it is kept for the data-splitting workflow.
path = os.environ.get(
    'ANIMALS10_PATH',
    os.path.expanduser('~/.cache/kagglehub/datasets/alessiocorrado99/animals10/versions/2'),
)
data_dir = os.path.join(path, 'raw-img')

In [None]:
# Configuration - must match the original training run so the loaded model
# receives inputs of the same size and batching.
FIRST_TIME_SETUP = False  # Data should already be split into train/val/test
base_dir = '../data/'
batch_size = 32
img_height, img_width = 224, 224  # Match original training resolution

# Paths to best models from previous training
BEST_MODEL_PATH = '../models/custom_costum_animals10_acc_0_60.h5'  # Update this to your best model
# BEST_MODEL_PATH = '../models/custom_simple_animals10_best.h5'  # Alternative path

config_summary = (
    f"🔍 Looking for best model at: {BEST_MODEL_PATH}",
    f"📂 Data directory: {base_dir}",
    f"🖼️  Image size: {img_height}x{img_width}",
    f"📦 Batch size: {batch_size}",
)
print(*config_summary, sep="\n")

In [None]:
# Load the dataset using tf.keras.preprocessing.image.ImageDataGenerator.
# Keep the same preprocessing as the original training so the loaded model
# sees inputs scaled the same way it was trained on.

print("🔧 Creating data generators...")

# Create training data augmentation (same as original)
train_datagen = ImageDataGenerator(
    rescale=1./255,  # Normalize to [0,1]
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Create validation/test generator without augmentation
test_val_datagen = ImageDataGenerator(
    rescale=1./255,  # Only rescaling for validation
)


def _make_split_iterator(datagen, split_name, shuffle, seed=None):
    """Return a categorical DirectoryIterator over ``base_dir/<split_name>`` using the shared image settings."""
    return datagen.flow_from_directory(
        os.path.join(base_dir, split_name),
        shuffle=shuffle,
        seed=seed,
        target_size=(img_height, img_width),
        batch_size=batch_size,
        class_mode='categorical'
    )


# Training split: augmented and shuffled. The explicit seed (same value as the
# global seed set in the setup cell) makes the shuffle order reproducible.
train_ds = _make_split_iterator(train_datagen, 'train', shuffle=True, seed=42)
# Validation / test splits: no augmentation and no shuffling, so batch order
# stays aligned with each iterator's `.classes` array.
val_ds = _make_split_iterator(test_val_datagen, 'val', shuffle=False)
test_ds = _make_split_iterator(test_val_datagen, 'test', shuffle=False)

# Each iterator assigns integer labels from its own sub-folder names. If the
# splits disagree, val/test labels would silently refer to different classes.
for split_name, split_ds in (('val', val_ds), ('test', test_ds)):
    if split_ds.class_indices != train_ds.class_indices:
        raise ValueError(
            f"Class folders in '{split_name}' do not match 'train': "
            f"{split_ds.class_indices} vs {train_ds.class_indices}"
        )

# Get class names from the dataset (ordered by label index)
class_names = list(train_ds.class_indices.keys())
print(f"Found {len(class_names)} classes: {class_names}")

# Calculate dataset sizes
print(f"Training samples: {train_ds.samples}")
print(f"Validation samples: {val_ds.samples}")
print(f"Test samples: {test_ds.samples}")
print(f"Training batches per epoch: {len(train_ds)}")
print(f"Validation batches per epoch: {len(val_ds)}")
print(f"Test batches per epoch: {len(test_ds)}")

print(f"\n✅ Data generators created successfully!")

## 3. Load Best Trained Model

In [None]:
# Load the best trained model produced by the original training notebook and
# record its validation performance as the baseline for comparison.
print("🔄 Loading best trained model...")
print(f"Model path: {BEST_MODEL_PATH}")

# Fail fast, listing any .h5 files in ../models/ so the path can be corrected.
if not os.path.exists(BEST_MODEL_PATH):
    print(f"❌ Model file not found: {BEST_MODEL_PATH}")
    print("\n📋 Available models in ../models/:")
    models_dir = '../models/'
    available_files = os.listdir(models_dir) if os.path.exists(models_dir) else []
    for h5_name in filter(lambda name: name.endswith('.h5'), available_files):
        print(f"   - {h5_name}")
    print("\n⚠️  Please update BEST_MODEL_PATH to point to your best model")
    raise FileNotFoundError(f"Model not found: {BEST_MODEL_PATH}")

try:
    # Load the complete trained model
    base_model = tf.keras.models.load_model(BEST_MODEL_PATH)
    print("✅ Model loaded successfully!")

    # Display model info
    print(f"\n📊 Loaded Model Information:")
    print(f"   Input shape: {base_model.input_shape}")
    print(f"   Output shape: {base_model.output_shape}")
    print(f"   Total parameters: {base_model.count_params():,}")
    base_trainable_count = sum(tf.keras.backend.count_params(weight) for weight in base_model.trainable_weights)
    print(f"   Trainable parameters: {base_trainable_count:,}")

    # Baseline validation performance of the untouched model.
    # NOTE(review): unpacking assumes the saved model was compiled with exactly
    # one metric (accuracy), so evaluate() returns [loss, accuracy].
    print(f"\n🧪 Evaluating loaded model on validation set...")
    val_ds.reset()
    base_val_loss, base_val_acc = base_model.evaluate(val_ds, verbose=0)
    print(f"   Base model validation accuracy: {base_val_acc:.4f}")
    print(f"   Base model validation loss: {base_val_loss:.4f}")
    val_ds.reset()

except Exception as e:
    # Report, then re-raise: nothing below can run without base_model.
    print(f"❌ Error loading model: {e}")
    raise

# Display the architecture
print(f"\n🏗️  Model Architecture:")
base_model.summary()

## 4. Create Transfer Learning Model

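Extract the convolutional layers (everything up to and including the Flatten / global-pooling layer) as a frozen feature extractor, then attach a new dense classifier head.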

In [None]:
# Locate the layer that separates the convolutional feature extractor from the
# dense classifier: the first Flatten layer, or failing that, the first global
# pooling layer. Layers are printed as they are scanned.
print("🔍 Analyzing model architecture...")
flatten_layer_index = None
for idx, lyr in enumerate(base_model.layers):
    print(f"   Layer {idx}: {lyr.name} ({type(lyr).__name__})")
    if not isinstance(lyr, tf.keras.layers.Flatten):
        continue
    flatten_layer_index = idx
    print(f"   ✅ Found Flatten layer at index {idx}")
    break

if flatten_layer_index is None:
    print("❌ No Flatten layer found! Model structure might be different.")
    print("   Looking for GlobalAveragePooling2D or other pooling layers...")
    global_pool_types = (tf.keras.layers.GlobalAveragePooling2D, tf.keras.layers.GlobalMaxPooling2D)
    pooling_hits = [
        (idx, lyr) for idx, lyr in enumerate(base_model.layers)
        if isinstance(lyr, global_pool_types)
    ]
    if pooling_hits:
        flatten_layer_index, first_pool_layer = pooling_hits[0]
        print(f"   ✅ Found {type(first_pool_layer).__name__} layer at index {flatten_layer_index}")

if flatten_layer_index is None:
    raise ValueError("Could not find a suitable layer to separate conv and dense parts")

print(f"\n🎯 Will keep layers 0 to {flatten_layer_index} (convolutional part)")
print(f"   Will replace layers {flatten_layer_index+1} onwards (dense part)")

In [None]:
# Build a feature extractor that reuses base_model's trained layers from the
# network input up to (and including) the split layer found above.
print("🔧 Creating feature extractor from convolutional layers...")

split_layer_output = base_model.layers[flatten_layer_index].output
conv_base = tf.keras.Model(inputs=base_model.input, outputs=split_layer_output)

print(f"✅ Feature extractor created:")
for label, shape in (("Input shape", conv_base.input_shape), ("Output shape", conv_base.output_shape)):
    print(f"   {label}: {shape}")
print(f"   Parameters: {conv_base.count_params():,}")

# Freeze the extractor so only the new head is trained. The layer objects are
# shared with base_model, so they are frozen there as well.
conv_base.trainable = False
print(f"\n❄️  Convolutional base frozen (trainable=False)")
print(f"   Frozen parameters: {conv_base.count_params():,}")

# Display the feature extractor architecture
print(f"\n🏗️  Feature Extractor Architecture:")
conv_base.summary()

In [None]:
# Create the new transfer learning model: frozen conv_base + fresh classifier head.
print("🚀 Building transfer learning model with new classifier...")

num_classes = len(class_names)
input_shape = (img_height, img_width, 3)

# conv_base ends AT the Flatten / global-pooling layer selected above, so it
# already emits a flat (batch, features) tensor. Do not add another
# GlobalAveragePooling2D here: it requires a 4-D feature map and rejects a
# flattened input. Only flatten if the split layer kept extra dimensions
# (e.g. a global pooling layer created with keepdims=True).
head_prefix = []
if len(conv_base.output_shape) != 2:
    head_prefix.append(tf.keras.layers.Flatten())

# Build the new model
transfer_model = tf.keras.Sequential([
    conv_base,  # Frozen convolutional base (outputs flat features)
    *head_prefix,

    # New classifier head - you can experiment with different architectures
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(num_classes, activation='softmax', name='predictions')
])

# Build the model so shapes and parameter counts are available before training
transfer_model.build(input_shape=(None, *input_shape))

trainable_param_count = sum(tf.keras.backend.count_params(w) for w in transfer_model.trainable_weights)
total_param_count = transfer_model.count_params()

print(f"✅ Transfer learning model created!")
print(f"\n📊 Transfer Learning Model Information:")
print(f"   Input shape: {transfer_model.input_shape}")
print(f"   Output shape: {transfer_model.output_shape}")
print(f"   Total parameters: {total_param_count:,}")
print(f"   Trainable parameters: {trainable_param_count:,}")
print(f"   Frozen parameters: {total_param_count - trainable_param_count:,}")

# Display the full architecture
print(f"\n🏗️  Complete Transfer Learning Model:")
transfer_model.summary()

## 5. Model Compilation and Training Setup

In [None]:
# Compile the transfer learning model
print("⚙️  Compiling transfer learning model...")

# Use a lower learning rate for transfer learning
LEARNING_RATE = 0.0001  # Lower LR since we're fine-tuning

# Callback hyper-parameters, named once so the summary printed below always
# matches what is actually configured.
EARLY_STOPPING_PATIENCE = 15
REDUCE_LR_FACTOR = 0.5
REDUCE_LR_PATIENCE = 5
MIN_LR = 1e-7
CHECKPOINT_PATH = '../models/transfer_learning_best.h5'

# NOTE(review): the legacy Adam is presumably used because the TF >= 2.11 Adam
# is reported to run slowly on M1/M2 Macs; `optimizers.legacy` is not
# supported under Keras 3 — confirm the installed TF/Keras version.
transfer_model.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=LEARNING_RATE),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

print(f"✅ Model compiled with learning rate: {LEARNING_RATE}")

# Setup callbacks for training (all driven by validation loss)
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=EARLY_STOPPING_PATIENCE,
        restore_best_weights=True,
        verbose=1
    ),
    ModelCheckpoint(
        CHECKPOINT_PATH,
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=False,
        verbose=1
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=REDUCE_LR_FACTOR,
        patience=REDUCE_LR_PATIENCE,
        min_lr=MIN_LR,
        verbose=1
    )
]

print("✅ Training callbacks configured:")
print(f"   - EarlyStopping: patience={EARLY_STOPPING_PATIENCE}")
print("   - ModelCheckpoint: saves best model")
print(f"   - ReduceLROnPlateau: factor={REDUCE_LR_FACTOR}, patience={REDUCE_LR_PATIENCE}")

# Smoke-test the model on one (augmented) training image before the long fit.
print("\n🧪 Testing model with sample batch...")
sample_batch_x, sample_batch_y = next(train_ds)
test_prediction = transfer_model.predict(sample_batch_x[:1], verbose=0)
print(f"   Sample prediction shape: {test_prediction.shape}")
print(f"   Sample prediction sum: {test_prediction.sum():.4f} (should be ~1.0)")
train_ds.reset()  # rewind so training starts from the first batch
print("✅ Model ready for training!")

## 6. Transfer Learning Training

In [None]:
# Training parameters
EPOCHS = 30  # Usually fewer epochs needed for transfer learning

trainable_param_count = sum(tf.keras.backend.count_params(w) for w in transfer_model.trainable_weights)

print(f"🚀 Starting transfer learning training...")
print("=" * 60)
print(f"Training Configuration:")
print(f"- Learning rate: {LEARNING_RATE}")
print(f"- Max epochs: {EPOCHS}")
print(f"- Batch size: {batch_size}")
# conv_base.layers counts every layer up to the split point (including the
# input layer and any pooling/flatten layers), not only Conv2D layers.
print(f"- Frozen layers: {len(conv_base.layers)} layers in the frozen base")
print(f"- Trainable parameters: {trainable_param_count:,}")
print(f"- Total parameters: {transfer_model.count_params():,}")
print("=" * 60)

# Start training
history = transfer_model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=callbacks,
    verbose=1
)

print("\n✅ Transfer learning training completed!")
train_acc_history = history.history.get('accuracy', [])
if train_acc_history:
    val_acc_history = history.history['val_accuracy']
    # NOTE: this is the best per-epoch val accuracy; EarlyStopping restores the
    # weights of the best *val_loss* epoch, which may be a different epoch.
    best_epoch_val_acc = max(val_acc_history)
    print(f"Total epochs trained: {len(train_acc_history)}")
    print(f"Best validation accuracy: {best_epoch_val_acc:.4f}")
    print(f"Final training accuracy: {train_acc_history[-1]:.4f}")
    print(f"Final validation accuracy: {val_acc_history[-1]:.4f}")

    # Compare with base model
    improvement = best_epoch_val_acc - base_val_acc
    print(f"\n📈 Performance Improvement:")
    print(f"   Base model validation accuracy: {base_val_acc:.4f}")
    print(f"   Transfer model best validation accuracy: {best_epoch_val_acc:.4f}")
    print(f"   Improvement: {improvement:+.4f} ({improvement*100:+.2f}%)")

# Reset generators for future use
train_ds.reset()
val_ds.reset()

## 7. Training History Visualization

In [None]:
def plot_transfer_learning_history(history, base_val_acc):
    """
    Plot training and validation metrics for transfer learning with comparison to base model.

    Args:
        history: Keras ``History`` returned by ``model.fit``. ``history.history``
            must contain the per-epoch lists 'accuracy', 'val_accuracy', 'loss'
            and 'val_loss'; a learning-rate series ('lr' or 'learning_rate') is
            plotted when present.
        base_val_acc: Validation accuracy of the base model, drawn as a
            reference line and used for the improvement figures.

    Shows a 2x2 figure (accuracy, loss, learning rate or improvement over the
    base model, text summary) and prints a text report that includes a simple
    train/validation gap check for overfitting. Returns None.
    """
    metrics = history.history
    epochs = range(1, len(metrics['accuracy']) + 1)
    best_val_acc = max(metrics['val_accuracy'])
    best_val_acc_epoch = metrics['val_accuracy'].index(best_val_acc) + 1  # 1-based epoch

    # Create a larger figure with subplots
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

    # Plot accuracy
    ax1.plot(epochs, metrics['accuracy'], 'b-', label='Training Accuracy', linewidth=2, marker='o', markersize=4)
    ax1.plot(epochs, metrics['val_accuracy'], 'r-', label='Validation Accuracy', linewidth=2, marker='s', markersize=4)

    # Add base model performance line
    ax1.axhline(y=base_val_acc, color='orange', linestyle='--', linewidth=2, label=f'Base Model Val Acc ({base_val_acc:.3f})')

    ax1.set_title('Transfer Learning - Model Accuracy Over Time', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim([0, 1])

    # Add best accuracy annotation
    ax1.annotate(f'Best: {best_val_acc:.3f}',
                xy=(best_val_acc_epoch, best_val_acc),
                xytext=(best_val_acc_epoch + 2, best_val_acc - 0.05),
                arrowprops=dict(arrowstyle='->', color='red', alpha=0.7))

    # Plot loss
    ax2.plot(epochs, metrics['loss'], 'b-', label='Training Loss', linewidth=2, marker='o', markersize=4)
    ax2.plot(epochs, metrics['val_loss'], 'r-', label='Validation Loss', linewidth=2, marker='s', markersize=4)
    ax2.set_title('Transfer Learning - Model Loss Over Time', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # ReduceLROnPlateau logs the learning rate as 'lr' in Keras 2 and as
    # 'learning_rate' in Keras 3; accept either key.
    lr_key = next((key for key in ('lr', 'learning_rate') if key in metrics), None)
    if lr_key is not None:
        ax3.plot(epochs, metrics[lr_key], 'g-', label='Learning Rate', linewidth=2)
        ax3.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Learning Rate')
        ax3.set_yscale('log')
        ax3.legend()
        ax3.grid(True, alpha=0.3)
    else:
        # Plot improvement over base model
        improvement = np.array(metrics['val_accuracy']) - base_val_acc
        ax3.plot(epochs, improvement, 'green', label='Improvement over Base Model', linewidth=2)
        ax3.axhline(y=0, color='black', linestyle='--', alpha=0.5, label='Base Model Level')
        ax3.set_title('Transfer Learning Improvement', fontsize=14, fontweight='bold')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Accuracy Improvement')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

    # Plot comparison summary (text-only panel)
    ax4.text(0.1, 0.9, 'Transfer Learning Summary', fontsize=16, fontweight='bold', transform=ax4.transAxes)
    ax4.text(0.1, 0.8, f'Base Model Val Acc: {base_val_acc:.4f}', fontsize=12, transform=ax4.transAxes)
    ax4.text(0.1, 0.7, f'Best Transfer Val Acc: {best_val_acc:.4f}', fontsize=12, transform=ax4.transAxes)
    ax4.text(0.1, 0.6, f'Improvement: {best_val_acc - base_val_acc:+.4f}', fontsize=12, transform=ax4.transAxes,
             color='green' if best_val_acc > base_val_acc else 'red')
    ax4.text(0.1, 0.5, f'Epochs Trained: {len(epochs)}', fontsize=12, transform=ax4.transAxes)
    ax4.text(0.1, 0.4, f'Final Train Acc: {metrics["accuracy"][-1]:.4f}', fontsize=12, transform=ax4.transAxes)
    ax4.text(0.1, 0.3, f'Final Val Acc: {metrics["val_accuracy"][-1]:.4f}', fontsize=12, transform=ax4.transAxes)

    # Add performance assessment (> 1 percentage point counts as a clear win)
    if best_val_acc > base_val_acc + 0.01:
        ax4.text(0.1, 0.2, '✅ Transfer Learning Successful!', fontsize=12, transform=ax4.transAxes, color='green')
    elif best_val_acc > base_val_acc:
        ax4.text(0.1, 0.2, '⚡ Marginal Improvement', fontsize=12, transform=ax4.transAxes, color='orange')
    else:
        ax4.text(0.1, 0.2, '⚠️  No Improvement', fontsize=12, transform=ax4.transAxes, color='red')

    ax4.set_xlim(0, 1)
    ax4.set_ylim(0, 1)
    ax4.axis('off')

    plt.tight_layout()
    plt.show()

    # Print comprehensive metrics
    print("\n" + "="*60)
    print("TRANSFER LEARNING RESULTS ANALYSIS")
    print("="*60)
    print(f"Base model validation accuracy: {base_val_acc:.4f}")
    print(f"Transfer learning best accuracy: {best_val_acc:.4f}")
    print(f"Improvement: {best_val_acc - base_val_acc:+.4f} ({(best_val_acc - base_val_acc)*100:+.2f}%)")
    print(f"Total epochs trained: {len(epochs)}")
    print(f"Best epoch: {best_val_acc_epoch}")
    print(f"Final training accuracy: {metrics['accuracy'][-1]:.4f}")
    print(f"Final validation accuracy: {metrics['val_accuracy'][-1]:.4f}")

    # Overfitting analysis: gap between final train and validation accuracy
    final_gap = metrics['accuracy'][-1] - metrics['val_accuracy'][-1]
    print(f"\nOverfitting Analysis:")
    print(f"Final accuracy gap: {final_gap:.4f}")
    if final_gap > 0.1:
        print("⚠️  Potential overfitting detected")
    elif final_gap > 0.05:
        print("⚡ Mild overfitting")
    else:
        print("✅ Good generalization")

    print("="*60)

# Plot the training history
plot_transfer_learning_history(history, base_val_acc)

## 8. Model Evaluation on Validation Set

In [None]:
# Evaluate on validation set
val_ds.reset()
val_loss, val_accuracy = transfer_model.evaluate(val_ds, verbose=0)
print(f"Transfer Learning Model - Validation Accuracy: {val_accuracy:.4f}")
print(f"Transfer Learning Model - Validation Loss: {val_loss:.4f}")

# The validation split drove EarlyStopping, ModelCheckpoint and
# ReduceLROnPlateau, so its score is optimistically biased. The held-out test
# split gives an unbiased estimate of generalisation.
test_ds.reset()
test_loss, test_accuracy = transfer_model.evaluate(test_ds, verbose=0)
print(f"Transfer Learning Model - Test Accuracy: {test_accuracy:.4f}")
print(f"Transfer Learning Model - Test Loss: {test_loss:.4f}")

# Generate predictions on validation set
print("\nGenerating predictions...")
val_ds.reset()
y_pred = transfer_model.predict(val_ds, verbose=1)
y_pred_classes = np.argmax(y_pred, axis=1)

# val_ds was created with shuffle=False, so its `.classes` array (one integer
# label per file) is in the same order as the predictions above. Reading it
# directly avoids decoding every validation image a second time.
y_true_classes = np.asarray(val_ds.classes)

# A length mismatch would mean predictions and labels are misaligned: fail
# loudly rather than silently truncating one side.
if len(y_pred_classes) != len(y_true_classes):
    raise ValueError(
        f"Got {len(y_pred_classes)} predictions for {len(y_true_classes)} validation samples"
    )

print(f"Number of samples evaluated: {len(y_true_classes)}")

# Classification report. labels= keeps rows aligned with class_names even if a
# class never appears in y_true or y_pred.
print("\nClassification Report - Transfer Learning Model:")
print(classification_report(
    y_true_classes,
    y_pred_classes,
    labels=np.arange(len(class_names)),
    target_names=class_names,
))

# Reset generator
val_ds.reset()

## 9. Confusion Matrix Visualization

In [None]:
# Generate and plot confusion matrix. labels= fixes one row/column per entry in
# class_names; without it the matrix shrinks when a class is absent from both
# y_true and y_pred, and the tick labels / per-class indices misalign.
label_ids = np.arange(len(class_names))
cm = confusion_matrix(y_true_classes, y_pred_classes, labels=label_ids)

plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
           xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix - Transfer Learning Model', fontsize=16, fontweight='bold')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

# Per-class accuracy (i.e. recall: correct predictions / true samples of the
# class). Classes with no true samples get NaN instead of a divide-by-zero.
class_support = cm.sum(axis=1)
class_accuracy = np.divide(
    cm.diagonal(),
    class_support,
    out=np.full(len(class_names), np.nan),
    where=class_support > 0,
)
print("\nPer-class Accuracy - Transfer Learning Model:")
for class_name, class_acc in zip(class_names, class_accuracy):
    print(f"{class_name}: {class_acc:.4f}")

# Macro average skips classes without samples. The support-weighted average of
# per-class recall equals overall accuracy, i.e. trace / total.
macro_avg = np.nanmean(class_accuracy)
total_samples = cm.sum()
weighted_avg = cm.trace() / total_samples if total_samples else float('nan')
print(f"\nOverall Performance:")
print(f"Macro average accuracy: {macro_avg:.4f}")
print(f"Weighted average accuracy: {weighted_avg:.4f}")
print(f"Overall accuracy: {val_accuracy:.4f}")

## 10. Model Comparison and Saving

In [None]:
# Compare base model vs transfer learning model, then persist the model,
# its training history and a JSON summary of the run.
total_params = int(transfer_model.count_params())
trainable_params = int(sum(tf.keras.backend.count_params(w) for w in transfer_model.trainable_weights))
frozen_params = total_params - trainable_params

print("📊 MODEL COMPARISON SUMMARY")
print("="*60)
print(f"Base Model (Original CNN):")
print(f"   Validation Accuracy: {base_val_acc:.4f}")
print(f"   Total Parameters: {base_model.count_params():,}")
# NOTE(review): refers to the original training run. base_model shares its
# (now frozen) layers with conv_base, so its current trainable_weights differ.
print(f"   All parameters were trainable")

print(f"\nTransfer Learning Model:")
print(f"   Validation Accuracy: {val_accuracy:.4f}")
print(f"   Total Parameters: {total_params:,}")
print(f"   Trainable Parameters: {trainable_params:,}")
print(f"   Frozen Parameters: {frozen_params:,}")

improvement = val_accuracy - base_val_acc
print(f"\nPerformance Improvement: {improvement:+.4f} ({improvement*100:+.2f}%)")

if improvement > 0.01:
    print("✅ Transfer learning was successful!")
elif improvement > 0:
    print("⚡ Marginal improvement achieved")
else:
    print("⚠️  No significant improvement - consider different approach")

print("="*60)

# All artifacts share one prefix, e.g. ../models/transfer_learning_animals10_acc_0_72
val_accuracy_str = f"{val_accuracy:.2f}".replace('.', '_')
artifact_prefix = f'../models/transfer_learning_animals10_acc_{val_accuracy_str}'

# Save the transfer learning model
model_filename = f'{artifact_prefix}.h5'
transfer_model.save(model_filename)
print(f"\n💾 Transfer learning model saved as: {model_filename}")

# Save training history
with open(f'{artifact_prefix}_history.pkl', 'wb') as f:
    pickle.dump(history.history, f)
print(f"📈 Training history saved")

# Save model configuration (plain Python types so it is JSON-serialisable)
model_config = {
    'model_name': 'Transfer Learning from Custom CNN',
    'base_model_path': BEST_MODEL_PATH,
    'base_model_accuracy': float(base_val_acc),
    'dataset': 'Animals10',
    'input_shape': list(input_shape),
    'num_classes': num_classes,
    'batch_size': batch_size,
    'learning_rate': LEARNING_RATE,
    'epochs_trained': len(history.history['accuracy']),
    'best_val_accuracy': float(max(history.history['val_accuracy'])),
    'final_val_accuracy': float(val_accuracy),
    'improvement_over_base': float(improvement),
    'total_parameters': total_params,
    'trainable_parameters': trainable_params,
    'frozen_parameters': frozen_params,
    'architecture': 'Custom CNN base + new classifier head'
}

with open(f'{artifact_prefix}_config.json', 'w') as f:
    json.dump(model_config, f, indent=2)

print(f"⚙️  Model configuration saved")
print(f"\n🎉 Transfer learning experiment completed successfully!")

## 11. Optional: Fine-tuning the Entire Model

Uncomment and run this section to unfreeze the entire convolutional base and fine-tune the whole model at a very low learning rate.

In [None]:
# # Optional: Fine-tune the entire model
# NOTE(review): the code below sets conv_base.trainable = True, which unfreezes
# EVERY layer of the convolutional base, not just the top few. If conv_base
# contains BatchNormalization layers, they will also start updating their
# moving statistics during fit(); the Keras transfer-learning guide recommends
# keeping BN layers in inference mode when fine-tuning — confirm before enabling.
# The recompile step is required: changes to `trainable` only take effect
# after compile() is called again.
# print("🔓 Unfreezing convolutional base for fine-tuning...")

# # Unfreeze the convolutional base
# conv_base.trainable = True

# # Use a much lower learning rate for fine-tuning
# FINE_TUNE_LR = 0.00001

# transfer_model.compile(
#     optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=FINE_TUNE_LR),
#     loss='categorical_crossentropy',
#     metrics=['accuracy']
# )

# print(f"✅ Model recompiled for fine-tuning with LR: {FINE_TUNE_LR}")
# print(f"   Now trainable parameters: {sum([tf.keras.backend.count_params(w) for w in transfer_model.trainable_weights]):,}")

# # Fine-tune for a few more epochs
# FINE_TUNE_EPOCHS = 10

# print(f"\n🎯 Starting fine-tuning for {FINE_TUNE_EPOCHS} epochs...")
# fine_tune_history = transfer_model.fit(
#     train_ds,
#     epochs=FINE_TUNE_EPOCHS,
#     validation_data=val_ds,
#     callbacks=[
#         EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
#         ModelCheckpoint('../models/fine_tuned_best.h5', save_best_only=True)
#     ],
#     verbose=1
# )

# print("✅ Fine-tuning completed!")
# final_val_loss, final_val_acc = transfer_model.evaluate(val_ds, verbose=0)
# print(f"Final fine-tuned accuracy: {final_val_acc:.4f}")
# print(f"Total improvement from base: {final_val_acc - base_val_acc:+.4f}")