# Transfer Learning with ResNet50 on CIFAR-10

In this notebook, we'll perform **transfer learning** using a **ResNet50** model pretrained on **ImageNet**, and fine-tune it on the **CIFAR-10** dataset.

Transfer learning lets us reuse features learned by large models trained on large datasets, which cuts training time and often improves accuracy on smaller tasks.

### Steps we'll cover:
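1. Load and preprocess CIFAR-10
2. Load pretrained ResNet50, freeze the base layers, and add a custom head
3. Compile and train the model
4. Fine-tune the last layers (optional)
5. Evaluate model performance
6. Visualize accuracy and loss curves
7. Visualize predictions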

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
import numpy as np
import matplotlib.pyplot as plt

## 1️⃣ Load and Preprocess CIFAR-10

In [None]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()

class_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

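# Resize images to 224x224 (ResNet input size).
# Note: tf.image.resize returns float32, so the resized training set alone
# needs about 30 GB (50,000 x 224 x 224 x 3 x 4 bytes). On machines with less
# memory, use a subset of the data or resize batch by batch instead.
X_train_resized = tf.image.resize(X_train, (224, 224))
X_test_resized = tf.image.resize(X_test, (224, 224))

# Preprocess using ResNet's preprocessing (RGB -> BGR, then subtract the ImageNet channel means)
X_train_prep = preprocess_input(X_train_resized)
X_test_prep = preprocess_input(X_test_resized)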
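Resizing the whole dataset up front can exhaust memory. One possible alternative is a `tf.data` pipeline that resizes and preprocesses one batch at a time. The sketch below is an assumption about how this could be wired up, not part of the original notebook; the helper name `make_dataset` and the synthetic `demo_*` arrays are hypothetical.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input


def make_dataset(images, labels, batch_size=64, training=False, image_size=(224, 224)):
    """Build a tf.data pipeline that resizes and preprocesses one batch at a time."""
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    if training:
        # Shuffle the small 32x32 originals, not the resized images.
        ds = ds.shuffle(buffer_size=min(len(images), 10_000))
    ds = ds.batch(batch_size)

    def resize_and_preprocess(x, y):
        x = tf.image.resize(x, image_size)  # uint8 32x32 -> float32 224x224
        return preprocess_input(x), y

    ds = ds.map(resize_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.prefetch(tf.data.AUTOTUNE)


# Synthetic stand-in with the same layout as CIFAR-10 (hypothetical data).
demo_images = np.random.randint(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
demo_labels = np.random.randint(0, 10, size=(8, 1))

demo_ds = make_dataset(demo_images, demo_labels, batch_size=4)
batch_x, batch_y = next(iter(demo_ds))
print(batch_x.shape, batch_y.shape)
```

With the real data, `make_dataset(X_train, y_train, training=True)` could be passed to `model.fit` in place of the in-memory arrays; in that case `batch_size` is set in the pipeline rather than in `fit`.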

## 2️⃣ Load Pretrained ResNet50 Model
- We'll use `include_top=False` to drop ImageNet's pooling and 1000-class fully connected layer.
- We'll freeze the base so its pretrained weights stay fixed.
- We'll add our custom classifier on top.

In [None]:
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze base layers
for layer in base_model.layers:
    layer.trainable = False

model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax')
])

model.summary()
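As a sanity check on the summary, the size of the trainable head can be worked out by hand. This is a small arithmetic sketch, assuming the base outputs a 7x7x2048 feature map for a 224x224 input (the standard ResNet50 shape with `include_top=False`); the helper `dense_params` is a hypothetical name.

```python
# Feature map produced by ResNet50 (include_top=False) for a 224x224 input.
feature_height, feature_width, feature_channels = 7, 7, 2048

# GlobalAveragePooling2D averages over the 7x7 grid and has no weights,
# leaving one value per channel.
pooled_features = feature_channels


def dense_params(num_inputs, num_units):
    """Weights plus biases of a fully connected layer."""
    return num_inputs * num_units + num_units


hidden_params = dense_params(pooled_features, 256)  # Dense(256)
output_params = dense_params(256, 10)                # Dense(10); Dropout has no weights
head_params = hidden_params + output_params

print(hidden_params, output_params, head_params)  # 524544 2570 527114
```

If the base is frozen correctly, the "Trainable params" line of the summary should match `head_params`.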

## 3️⃣ Compile and Train Model

In [None]:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

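# Note: the test set doubles as validation data here, so the test accuracy
# reported later is not a fully held-out estimate.
history = model.fit(X_train_prep, y_train,
                    validation_data=(X_test_prep, y_test),
                    epochs=5, batch_size=64)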
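The training call above validates against the test set. If a separate validation set is preferred, one option is to hold out part of the training data. The following is a minimal sketch under that assumption; `train_val_split`, the 10% fraction, and the seed are illustrative choices, not values from the original notebook.

```python
import numpy as np


def train_val_split(num_samples, val_fraction=0.1, seed=42):
    """Return shuffled, non-overlapping index arrays for training and validation."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(num_samples)
    num_val = int(num_samples * val_fraction)
    return indices[num_val:], indices[:num_val]


# CIFAR-10 has 50,000 training images.
train_idx, val_idx = train_val_split(50_000)
print(len(train_idx), len(val_idx))  # 45000 5000
```

The index arrays can select rows from NumPy arrays directly, for example `X_train[val_idx]` and `y_train[val_idx]`; for tensors, `tf.gather` does the same job.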

## 4️⃣ Fine-Tuning (Optional)
Let's unfreeze the last 10 layers of ResNet50 and keep training at a much lower learning rate (1e-5), so the pretrained weights are adjusted only slightly.

In [None]:
for layer in base_model.layers[-10:]:
    layer.trainable = True

# Recompile so the change in trainable layers takes effect
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

fine_tune_history = model.fit(X_train_prep, y_train, validation_data=(X_test_prep, y_test), epochs=3, batch_size=64)
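A common recommendation when fine-tuning is to leave `BatchNormalization` layers frozen, so that their ImageNet statistics are not disturbed by small batches. The sketch below shows one way to do that. It is an assumption layered on top of the notebook's approach, and `unfreeze_top_layers` and the tiny `demo_base` model are hypothetical stand-ins used only to keep the example small.

```python
import tensorflow as tf
from tensorflow.keras import layers


def unfreeze_top_layers(base, num_layers=10):
    """Unfreeze the last `num_layers` layers of `base`, leaving BatchNormalization frozen."""
    for layer in base.layers[-num_layers:]:
        if not isinstance(layer, layers.BatchNormalization):
            layer.trainable = True


# Tiny stand-in for a pretrained base (not ResNet50).
demo_base = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    layers.Conv2D(8, 3, name="conv_a"),
    layers.BatchNormalization(name="bn_a"),
    layers.Conv2D(8, 3, name="conv_b"),
    layers.BatchNormalization(name="bn_b"),
])

# Freeze everything, then unfreeze only the last two layers.
for layer in demo_base.layers:
    layer.trainable = False
unfreeze_top_layers(demo_base, num_layers=2)

trainable_flags = {layer.name: layer.trainable for layer in demo_base.layers}
print(trainable_flags)
```

Applied to the notebook's model, the call would be `unfreeze_top_layers(base_model, num_layers=10)`, followed by recompiling before training.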

## 5️⃣ Evaluate Model Performance

In [None]:
test_loss, test_acc = model.evaluate(X_test_prep, y_test, verbose=2)
print(f"\n✅ Test Accuracy after fine-tuning: {test_acc*100:.2f}%")
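Overall accuracy can hide classes the model confuses often. A per-class breakdown is one way to look closer. This is a minimal NumPy sketch; `per_class_accuracy` is a hypothetical helper and the `demo_*` arrays are made-up labels for illustration, not results from this model.

```python
import numpy as np


def per_class_accuracy(true_labels, predicted_labels, num_classes=10):
    """Fraction of correctly predicted samples for each class (NaN if a class is absent)."""
    true_labels = np.asarray(true_labels).reshape(-1)        # accepts (N, 1) like y_test
    predicted_labels = np.asarray(predicted_labels).reshape(-1)
    accuracies = np.full(num_classes, np.nan)
    for class_id in range(num_classes):
        mask = true_labels == class_id
        if mask.any():
            accuracies[class_id] = np.mean(predicted_labels[mask] == class_id)
    return accuracies


# Made-up labels: class 0 is right half the time, class 1 always, class 2 never.
demo_true = np.array([[0], [0], [1], [1], [2]])
demo_pred = np.array([0, 1, 1, 1, 0])
demo_acc = per_class_accuracy(demo_true, demo_pred, num_classes=3)
print(demo_acc)  # [0.5 1.  0. ]
```

With the notebook's variables, the predicted labels would come from `np.argmax(model.predict(X_test_prep), axis=1)`, and the result can be paired with `class_names`.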

## 6️⃣ Visualize Accuracy and Loss Curves

In [None]:
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.plot(history.history['accuracy'] + fine_tune_history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'] + fine_tune_history.history['val_accuracy'], label='Val Accuracy')
plt.legend()
plt.title('Accuracy over Epochs')

plt.subplot(1,2,2)
plt.plot(history.history['loss'] + fine_tune_history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'] + fine_tune_history.history['val_loss'], label='Val Loss')
plt.legend()
plt.title('Loss over Epochs')
plt.show()
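The plots above join the two training phases by concatenating lists. A small helper can do the same for every metric at once and record where fine-tuning starts. This is only a sketch; `merge_histories` is a hypothetical name and the numbers are placeholders, not measured results.

```python
def merge_histories(*history_dicts):
    """Concatenate per-epoch metric lists from several training phases."""
    merged = {}
    for history_dict in history_dicts:
        for metric, values in history_dict.items():
            merged.setdefault(metric, []).extend(values)
    return merged


# Placeholder values standing in for history.history and fine_tune_history.history.
phase_one = {"accuracy": [0.70, 0.80], "val_accuracy": [0.68, 0.75]}
phase_two = {"accuracy": [0.85], "val_accuracy": [0.78]}

merged = merge_histories(phase_one, phase_two)
fine_tune_start = len(phase_one["accuracy"])  # index of the first fine-tuning epoch
print(merged["accuracy"], fine_tune_start)  # [0.7, 0.8, 0.85] 2
```

In the plotting cell, `plt.axvline(fine_tune_start - 0.5, linestyle='--')` would mark the boundary between head training and fine-tuning.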

## 7️⃣ Visualize Predictions

In [None]:
pred_probs = model.predict(X_test_prep[:9])
pred_classes = np.argmax(pred_probs, axis=1)

plt.figure(figsize=(10,10))
for i in range(9):
    plt.subplot(3,3,i+1)
    plt.imshow(X_test[i])
    plt.title(f"True: {class_names[y_test[i][0]]}\nPred: {class_names[pred_classes[i]]}")
    plt.axis('off')
plt.show()

## ✅ Summary
- Used **ResNet50** pretrained on ImageNet.
- Added a custom dense head for CIFAR-10 classification.
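- Trained only the new head for 5 epochs while the ResNet50 base stayed frozen.
- Unfroze the last 10 base layers and fine-tuned for 3 more epochs at a learning rate of 1e-5; compare the curves before and after to see how much it helped.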

**Next steps:** Try other architectures like **VGG16**, **MobileNetV2**, or **EfficientNet** for comparison.