# 03. Transfer Learning (ResNet50)

## Introduction
This notebook applies transfer learning with a ResNet50 model pre-trained on ImageNet.
It is fully self-contained: it does not depend on code or kernel state from other notebooks.

## Setup

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input

SEED = 42

# Seed the global NumPy and TensorFlow RNGs so runs are as repeatable as possible.
for seed_fn in (np.random.seed, tf.random.set_seed):
    seed_fn(SEED)

## 1. Data Loading
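We follow the same strategy as the baseline notebook. For demonstration purposes, the model is trained on the multi-class data in the target category's `test` folder. Roughly 80% of that data is used for training and 20% for validation.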

In [None]:
# --- Data / input configuration -------------------------------------------
IMG_SIZE = (256, 256)      # (height, width) every image is resized to
BATCH_SIZE = 32
DATA_DIR = "../data/raw"
TARGET_CATEGORY = 'bottle'

# Images are read from <DATA_DIR>/<TARGET_CATEGORY>/test; each sub-folder there
# is treated as one class by image_dataset_from_directory.
TEST_DIR = os.path.join(DATA_DIR, TARGET_CATEGORY, 'test')

print("Target Category:", TARGET_CATEGORY)

full_ds = tf.keras.utils.image_dataset_from_directory(
    TEST_DIR,
    seed=123,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE
)

class_names = full_ds.class_names
num_classes = len(class_names)

# Split
train_size = int(0.8 * len(full_ds))
train_ds = full_ds.take(train_size)
val_ds = full_ds.skip(train_size)

# Preprocessing for ResNet
def preprocess(image, label):
    """tf.data map function: ResNet50 input preprocessing on images, labels untouched.

    Args:
        image: image batch yielded by the dataset.
        label: matching label batch; returned unchanged.

    Returns:
        Tuple ``(preprocessed_image, label)``.
    """
    resnet_ready = preprocess_input(image)
    return resnet_ready, label

AUTOTUNE = tf.data.AUTOTUNE

# Preprocess, cache after the first pass, then prefetch. Only the training
# split is shuffled; validation order does not matter.
train_ds = (
    train_ds
    .map(preprocess)
    .cache()
    .shuffle(1000)
    .prefetch(AUTOTUNE)
)
val_ds = (
    val_ds
    .map(preprocess)
    .cache()
    .prefetch(AUTOTUNE)
)

## 2. Model Definition
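We load ResNet50 with ImageNet weights and without its top classifier, then freeze this base. On top we add a custom classification head: global average pooling, a 256-unit ReLU dense layer with dropout, and a softmax output layer.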

In [None]:
def create_resnet_model(input_shape, num_classes, dense_units=256, dropout_rate=0.5,
                        weights='imagenet'):
    """Build a ResNet50 transfer-learning classifier with a frozen backbone.

    The ResNet50 convolutional base (no top classifier) is frozen and called with
    ``training=False``. A small head is stacked on top:
    GlobalAveragePooling2D -> Dense(dense_units, relu) -> Dropout(dropout_rate)
    -> Dense(num_classes, softmax).

    Args:
        input_shape: Input shape for ResNet50 and ``tf.keras.Input``,
            e.g. ``(256, 256, 3)``.
        num_classes: Number of output classes (size of the softmax layer).
        dense_units: Width of the hidden dense layer in the head.
        dropout_rate: Dropout rate applied after the hidden dense layer.
        weights: Backbone weights forwarded to ``ResNet50``
            (``'imagenet'`` by default; ``None`` for random initialisation).

    Returns:
        An uncompiled ``tf.keras.Model`` named ``"resnet50_transfer"`` that contains
        the ResNet50 backbone as a nested layer.
    """
    base_model = ResNet50(weights=weights, include_top=False, input_shape=input_shape)

    # Freeze the whole backbone: only the new head is trained in the first stage.
    base_model.trainable = False

    inputs = tf.keras.Input(shape=input_shape)
    # training=False keeps the backbone's BatchNormalization layers in inference
    # mode (using their stored moving statistics), also if the base is later
    # unfrozen for fine-tuning.
    x = base_model(inputs, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(dense_units, activation='relu')(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    return models.Model(inputs, outputs, name="resnet50_transfer")

# Images are loaded as 3-channel RGB, so append the channel axis to (height, width).
model = create_resnet_model(input_shape=(*IMG_SIZE, 3), num_classes=num_classes)
model.summary()

## 3. Training
Train only the new classification head while the ResNet50 base stays frozen.

In [None]:
# Head-only stage. Labels are integer class indices (the loader's default),
# hence the sparse variant of categorical cross-entropy.
head_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=head_optimizer,
    metrics=['accuracy'],
)

# Stop when val_loss has not improved for 3 epochs. restore_best_weights=True
# rolls the model back to its best epoch when early stopping triggers, so the
# fine-tuning stage starts from the best head instead of from weights that are
# already `patience` epochs past the optimum.
head_early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=[head_early_stopping],
)

## 4. Fine-tuning (Optional)
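Unfreeze the last 20 layers of the ResNet50 base and continue training at a much lower learning rate (1e-5); this can further improve accuracy. Because the base was called with `training=False`, its BatchNormalization layers stay in inference mode during fine-tuning.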

In [None]:
# Look the backbone up by name rather than by position: Keras names the ResNet50
# model 'resnet50', so this keeps working if layers are inserted before it
# (model.layers[1] silently picks the wrong layer in that case).
base_model = model.get_layer('resnet50')
base_model.trainable = True

# Keep everything except the last N_FINE_TUNE_LAYERS layers frozen.
# Changes to `trainable` only take effect after model.compile() is called again.
N_FINE_TUNE_LAYERS = 20
for layer in base_model.layers[:-N_FINE_TUNE_LAYERS]:
    layer.trainable = False
    
# Recompile so the new trainable flags take effect. The learning rate is much
# lower than in the head-only stage to limit how far the pre-trained weights move.
fine_tune_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=fine_tune_optimizer,
    metrics=['accuracy'],
)

# Stop when val_loss has not improved for 3 epochs; restore_best_weights=True
# rolls the model back to its best fine-tuned epoch when early stopping
# triggers, rather than keeping the last (overfit) weights.
fine_tune_early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

history_fine = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=[fine_tune_early_stopping],
)