# Transfer Learning on Flower Photos (Keras)

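This notebook mirrors `train_transfer.py`, step by step:
- Load the config
- Build the training and validation datasets (augmentation is applied inside the model)
- Build the transfer-learning model: a pretrained backbone plus a new classification head
- Train the head (feature extraction), then optionally fine-tune the backbone
- Evaluate and optionally export the model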

In [1]:
# Imports, environment checks, and config loading
import os, json
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from pathlib import Path
from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score

print('TensorFlow:', tf.__version__)
print('NumPy:', np.__version__)

# Report whether NumPy's C-API table (_ARRAY_API) is reachable via np.core.multiarray.
# NOTE(review): presumably a sanity check for NumPy/compiled-extension ABI mismatches — confirm.
has_array_api = False
if hasattr(np, 'core') and hasattr(np.core, 'multiarray'):
    has_array_api = hasattr(np.core.multiarray, '_ARRAY_API')
print('Has _ARRAY_API:', has_array_api)

# ml_dtypes is optional here: report its version, or the import failure, but keep going.
try:
    import ml_dtypes
    print('ml_dtypes:', getattr(ml_dtypes, '__version__', 'unknown'))
except Exception as import_err:
    print('ml_dtypes import error:', import_err)

# Load the experiment configuration shared by every following cell.
with open('config_transfer.json', encoding='utf-8') as config_file:
    cfg = json.load(config_file)
print('Loaded config for task:', cfg['task'])

TensorFlow: 2.20.0
NumPy: 1.26.4
Has _ARRAY_API: True
ml_dtypes: 0.5.3
Loaded config for task: image_classification_transfer_learning


In [2]:
# Dataset builder

def build_datasets(cfg):
    """Create batched, prefetched training and validation datasets from one image folder.

    Both splits are read from ``cfg['dataset']['path']`` with the same
    ``validation_split`` and ``seed`` (Keras requires matching values for the two
    subsets to partition the files consistently).

    Args:
        cfg: Full config dict; reads the ``'dataset'`` section and
            ``cfg['training']['batch_size']``.

    Returns:
        ``(train_ds, val_ds, class_names)``; ``class_names`` comes from the
        training split.
    """
    data_cfg = cfg['dataset']
    size = tuple(data_cfg['image_size'])
    shared_kwargs = {
        'validation_split': data_cfg['validation_split'],
        'seed': data_cfg['seed'],
        'image_size': size,
        'batch_size': cfg['training']['batch_size'],
        'label_mode': data_cfg['class_mode'],
    }

    def _load(subset):
        # One directory scan per subset; only `subset` differs between the calls.
        return keras.preprocessing.image_dataset_from_directory(
            data_cfg['path'], subset=subset, **shared_kwargs)

    train_split = _load('training')
    val_split = _load('validation')
    autotune = tf.data.AUTOTUNE
    return train_split.prefetch(autotune), val_split.prefetch(autotune), train_split.class_names

train_ds, val_ds, class_names = build_datasets(cfg)
print('Classes:', class_names)


Found 0 files belonging to 0 classes.
Using 0 files for training.


ValueError: No images found in directory data/flowers/. Allowed formats: ('.bmp', '.gif', '.jpeg', '.jpg', '.png')

In [None]:
# Augmentation and model builder

def build_augmentation(cfg):
    """Build the on-the-fly image augmentation stage described by ``cfg['augmentation']``.

    Recognised keys (all optional): ``enabled``, ``horizontal_flip``,
    ``rotation_range`` (degrees), ``zoom_range``, ``width_shift_range`` and
    ``height_shift_range``.

    Args:
        cfg: Full config dict; only the ``'augmentation'`` section is read.

    Returns:
        A Keras layer mapping an image batch to a batch of the same shape. When
        augmentation is disabled, or enabled but no transform is configured, an
        identity layer is returned: Keras 3 refuses to build an empty
        ``keras.Sequential``, so it could not be called on the model input.
    """
    aug_cfg = cfg.get('augmentation') or {}
    if not aug_cfg.get('enabled', False):
        return layers.Identity(name='no_aug')

    layers_list = []
    if aug_cfg.get('horizontal_flip'):
        layers_list.append(layers.RandomFlip('horizontal'))
    if aug_cfg.get('rotation_range'):
        # RandomRotation takes a fraction of a full turn, so convert degrees.
        layers_list.append(layers.RandomRotation(aug_cfg['rotation_range'] / 360.0))
    if aug_cfg.get('zoom_range'):
        layers_list.append(layers.RandomZoom(aug_cfg['zoom_range']))
    if aug_cfg.get('width_shift_range') or aug_cfg.get('height_shift_range'):
        # `or 0` also covers keys present with a null value.
        layers_list.append(layers.RandomTranslation(
            height_factor=aug_cfg.get('height_shift_range') or 0,
            width_factor=aug_cfg.get('width_shift_range') or 0))

    if not layers_list:
        return layers.Identity(name='no_aug')
    return keras.Sequential(layers_list, name='augmentation')


def _resolve_preprocess_input(arch_fn):
    """Return the ``preprocess_input`` function belonging to a keras.applications constructor.

    Constructors such as ``keras.applications.MobileNetV2`` are plain functions;
    the matching ``preprocess_input`` is defined in the module the constructor
    comes from (e.g. ``keras.applications.mobilenet_v2``), not as an attribute of
    the constructor, so it is looked up there when the attribute is absent.

    Raises:
        ValueError: If no ``preprocess_input`` can be found.
    """
    import importlib

    prep_fn = getattr(arch_fn, 'preprocess_input', None)
    if prep_fn is None:
        module = importlib.import_module(arch_fn.__module__)
        prep_fn = getattr(module, 'preprocess_input', None)
    if prep_fn is None:
        raise ValueError(
            f"No preprocess_input found for architecture {getattr(arch_fn, '__name__', arch_fn)!r}")
    return prep_fn


def build_model(cfg, num_classes):
    """Build the transfer-learning classifier described by ``cfg``.

    Graph: input -> augmentation -> architecture-specific preprocessing ->
    pretrained backbone (called with ``training=False``) -> global pooling or
    flatten -> optional Dense/Dropout stack -> output Dense layer.

    Args:
        cfg: Full config dict; reads the ``'model'``, ``'head'``, ``'dataset'``
            (``image_size``) and ``'augmentation'`` sections.
        num_classes: Number of classes found in the dataset; sizes the output
            layer. ``cfg['head']['output_classes']``, if present, must agree with
            it (a single output unit is also accepted for a 2-class dataset).

    Returns:
        ``(model, base)``: the full ``keras.Model`` and the backbone nested in it.

    Raises:
        ValueError: If the configured output size disagrees with ``num_classes``,
            or no ``preprocess_input`` exists for the architecture.
    """
    model_cfg = cfg['model']
    head_cfg = cfg['head']
    input_shape = (*cfg['dataset']['image_size'], 3)

    # Size the head from the data; a stale 'output_classes' in the config would
    # otherwise build an output layer that cannot match the labels.
    out_units = head_cfg.get('output_classes', num_classes)
    if out_units != num_classes and not (out_units == 1 and num_classes == 2):
        raise ValueError(
            f"head.output_classes={out_units} does not match the "
            f"{num_classes} classes found in the dataset")

    arch_fn = getattr(keras.applications, model_cfg['base_architecture'])
    base = arch_fn(
        include_top=model_cfg['include_top'], weights=model_cfg['weights'],
        input_shape=input_shape,
    )
    # Frozen backbone for the feature-extraction strategies; trainable otherwise.
    base.trainable = model_cfg['trainable_strategy'] not in ('freeze_then_finetune', 'freeze')

    inputs = keras.Input(shape=input_shape)
    x = build_augmentation(cfg)(inputs)
    x = layers.Lambda(_resolve_preprocess_input(arch_fn))(x)
    # training=False keeps the backbone (including BatchNorm statistics) in
    # inference mode, even if its layers are unfrozen later for fine-tuning.
    x = base(x, training=False)

    pool = head_cfg.get('global_pool')
    if pool == 'avg':
        x = layers.GlobalAveragePooling2D()(x)
    elif pool == 'max':
        x = layers.GlobalMaxPooling2D()(x)
    else:
        x = layers.Flatten()(x)

    dropout = head_cfg.get('dropout') or 0
    for units in head_cfg.get('dense_units', []):
        x = layers.Dense(units, activation=head_cfg['activation'])(x)
        if dropout > 0:
            x = layers.Dropout(dropout)(x)
    outputs = layers.Dense(out_units, activation=head_cfg['output_activation'])(x)
    model = keras.Model(inputs, outputs)
    return model, base

model, base = build_model(cfg, len(class_names))
model.summary(line_length=120)


In [None]:
# Compile and feature extraction training

from pathlib import Path

def compile_model(model, lr, optimizer_name, loss='categorical_crossentropy'):
    """Compile ``model`` in place with the named optimizer and an accuracy metric.

    Args:
        model: Keras model to compile.
        lr: Learning rate passed to the optimizer.
        optimizer_name: ``'adam'`` or ``'sgd'`` (case-insensitive); SGD uses
            momentum 0.9.
        loss: Keras loss identifier. The default, ``'categorical_crossentropy'``,
            expects one-hot labels (``label_mode='categorical'``); pass
            ``'sparse_categorical_crossentropy'`` for integer labels.

    Raises:
        ValueError: If ``optimizer_name`` is not supported.
    """
    name = str(optimizer_name).lower()
    if name == 'adam':
        opt = keras.optimizers.Adam(learning_rate=lr)
    elif name == 'sgd':
        opt = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
    else:
        raise ValueError(f'Unsupported optimizer: {optimizer_name!r}')
    model.compile(optimizer=opt, loss=loss, metrics=['accuracy'])

tr_cfg = cfg['training']
compile_model(model, tr_cfg['learning_rates']['feature_extraction'], tr_cfg['optimizer'])

# Optional callbacks: each is added only when its config section is present and truthy.
callbacks = []
early_cfg = tr_cfg.get('early_stopping')
if early_cfg:
    callbacks.append(keras.callbacks.EarlyStopping(
        monitor=early_cfg['monitor'],
        patience=early_cfg['patience'],
        restore_best_weights=True,
    ))

ckpt_cfg = tr_cfg.get('checkpoint')
if ckpt_cfg:
    Path('checkpoints').mkdir(exist_ok=True)
    callbacks.append(keras.callbacks.ModelCheckpoint(
        'checkpoints/best.keras',
        monitor=ckpt_cfg['monitor'],
        save_best_only=ckpt_cfg['save_best_only'],
    ))

# Phase 1: train only what is trainable (the new head when the backbone is frozen).
history_fe = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=tr_cfg['feature_extraction_epochs'],
    callbacks=callbacks,
)


In [None]:
# Fine-tuning phase (optional): unfreeze the upper part of the backbone and keep
# training at the fine-tuning learning rate.

ft_cfg = tr_cfg.get('fine_tune', {"enabled": False})
if ft_cfg.get('enabled', False):
    # Locate the pretrained backbone nested inside `model`. keras.applications
    # typically names it after the architecture rather than 'base_model', so fall
    # back to the `base` object returned by build_model() in the model cell.
    try:
        base_model = model.get_layer('base_model')
    except ValueError:
        base_model = base

    unfreeze_from = ft_cfg.get('unfreeze_from', 0)
    # The backbone container must itself be trainable: while its own `trainable`
    # flag is False, Keras reports none of its sublayers' weights as trainable,
    # so toggling only the sublayers would leave the whole backbone frozen.
    base_model.trainable = True
    for i, layer in enumerate(base_model.layers):
        layer.trainable = i >= unfreeze_from
    print(f"Unfroze layers from index {unfreeze_from} (total {len(base_model.layers)})")

    # Changes to `trainable` only take effect after recompiling.
    compile_model(model, tr_cfg['learning_rates']['fine_tuning'], tr_cfg['optimizer'])
    history_ft = model.fit(train_ds, validation_data=val_ds, epochs=ft_cfg.get('epochs', 5), callbacks=callbacks)
else:
    history_ft = None


In [None]:
# Evaluation and reporting on the validation split.

from sklearn.metrics import classification_report, confusion_matrix
import numpy as np


def _to_class_indices(values):
    """Map labels or model outputs to a 1-D array of integer class ids.

    - 2-D with more than one column (one-hot labels, softmax probabilities): argmax.
    - Single column or 1-D floats (binary labels, sigmoid outputs): threshold at 0.5.
    - 1-D integers (``label_mode='int'`` labels): returned unchanged as ints.
    """
    values = np.asarray(values)
    if values.ndim == 2 and values.shape[1] > 1:
        return values.argmax(axis=1)
    flat = values.reshape(-1)
    if np.issubdtype(flat.dtype, np.floating):
        return (flat >= 0.5).astype(int)
    return flat.astype(int)


# Predict batch by batch instead of materialising every validation image at once.
# Images and labels come from the same pass over val_ds, so they stay aligned.
prob_batches, label_batches = [], []
for x_batch, y_batch in val_ds:
    prob_batches.append(np.asarray(model.predict_on_batch(x_batch)))
    label_batches.append(y_batch.numpy())
if not label_batches:
    raise ValueError('Validation dataset is empty; nothing to evaluate.')

pred_probs = np.concatenate(prob_batches, axis=0)
y_val = np.concatenate(label_batches, axis=0)
pred_labels = _to_class_indices(pred_probs)
true_labels = _to_class_indices(y_val)

# Pin the label set so the report and matrix keep one row per class even when a
# class is absent from this split or never predicted; without it,
# classification_report rejects a target_names list longer than the observed labels.
label_ids = list(range(len(class_names)))

print('Classification Report:')
print(classification_report(true_labels, pred_labels, labels=label_ids, target_names=class_names))

print('Confusion Matrix:')
print(confusion_matrix(true_labels, pred_labels, labels=label_ids))


In [None]:
# Plot training curves: feature extraction followed by the optional fine-tuning phase.
import matplotlib.pyplot as plt

# Concatenate each metric's per-epoch values across the phases that actually ran.
history = {}
for run in (history_fe, history_ft):
    if run is None:
        continue
    for metric, values in run.history.items():
        history.setdefault(metric, []).extend(values)

# One panel per metric, each showing the training and validation series.
panels = (('Loss', 'loss'), ('Accuracy', 'accuracy'))
plt.figure(figsize=(10, 4))
for position, (title, key) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(history.get(key, []), label='train')
    plt.plot(history.get(f'val_{key}', []), label='val')
    plt.title(title)
    plt.legend()
    plt.grid(True)
plt.tight_layout()
plt.show()
