
# TensorFlow Training Pipeline (MobileNetV2 / EfficientNetV2) — Mirror of PyTorch Notebook

This notebook mirrors the structure of your original PyTorch pipeline but implements it with **TensorFlow/Keras**:
- GPU setup & memory
- Data input (tf.data) with **Keras preprocessing** augmentations
- Transfer learning with **MobileNetV2**, **EfficientNetB0** (the default used below), or **EfficientNetV2M**
- Freeze → Gradual unfreeze (incremental sessions)
- Training loops with callbacks (checkpointing best model)
- Metrics, confusion matrix, classification report
- Single-image inference utility
- History logging to CSV


## 1) Environment & GPU

In [None]:

import os, json, time, sys, platform
import tensorflow as tf

# Log the runtime versions so results from different machines can be compared.
for _label, _value in (
    ("Python:", sys.version),
    ("TensorFlow:", tf.__version__),
    ("Platform:", platform.platform()),
):
    print(_label, _value)

gpus = tf.config.list_physical_devices('GPU')
print("GPUs:", gpus)
if not gpus:
    print("No GPU detected — training will run on CPU.")
else:
    # Memory growth lets TensorFlow allocate GPU memory on demand instead of
    # reserving it all up front. TensorFlow rejects this setting once the GPUs
    # have been initialized, so any failure is reported rather than raised.
    try:
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical = tf.config.list_logical_devices('GPU')
        print(f"Enabled memory growth. Logical GPUs: {logical}")
    except Exception as e:
        print("Could not set memory growth:", e)


## 2) Config

In [None]:

from dataclasses import dataclass

@dataclass
class Config:
    """Paths and hyper-parameters shared by every stage of the pipeline.

    A single module-level instance, ``cfg``, is created right after the class
    and read by the data, model, training and evaluation cells. ``num_classes``
    is overwritten with the number of training class folders once the datasets
    are built, so its default only matters if a model is built before the
    data is loaded.
    """
    # Dataset roots; each is expected to contain one sub-folder per class
    # (labels are inferred from the folder names).
    train_dir: str = "datasets/train"
    val_dir: str = "datasets/val"
    test_dir: str = "datasets/test"
    image_size: tuple = (224, 224)  # (height, width) every image is resized to
    # Per-split batch sizes.
    batch_train: int = 128
    batch_val: int = 64
    batch_test: int = 32
    num_classes: int = 6  # adjust to your dataset
    seed: int = 42  # seed for the (training-split) shuffle
    base_learning_rate: float = 1e-4  # default Adam learning rate
    mixed_precision: bool = True  # request mixed_float16; only applied when a GPU is visible
    # Output locations: the best checkpoint and the history CSV live under model_dir.
    model_dir: str = "models_tf"
    best_model_name: str = "best_model.keras"
    history_csv: str = "history_log.csv"

cfg = Config()

# Switch the global Keras dtype policy to mixed_float16, but only when a GPU is
# visible and the config opts in. Failure (e.g. unsupported build) is reported,
# not raised, so the notebook keeps running in float32.
_gpu_visible = bool(tf.config.list_physical_devices('GPU'))
if _gpu_visible and cfg.mixed_precision:
    try:
        from tensorflow.keras import mixed_precision as mp
        mp.set_global_policy('mixed_float16')
    except Exception as e:
        print("Mixed precision not available:", e)
    else:
        print("Mixed precision enabled.")


## 3) Data: tf.data Datasets + Augmentations

In [None]:

from tensorflow import keras
from tensorflow.keras import layers

# Training-time augmentation. These layers only randomize when called with
# training=True, which the training-split map in build_datasets does.
_augmentation_layers = [
    layers.RandomFlip("horizontal"),        # mirror left/right
    layers.RandomRotation(0.1),             # rotate by up to ±10% of a full turn
    layers.RandomZoom(0.1),
    layers.RandomTranslation(0.05, 0.05),   # shift by up to 5% of height / width
    layers.RandomContrast(0.1),
    layers.GaussianNoise(0.05),             # additive noise with stddev 0.05
]
data_augmentation = keras.Sequential(_augmentation_layers, name="data_augmentation")

def build_datasets(cfg: Config):
    """Build the train/val/test ``tf.data`` pipelines from class-per-folder directories.

    Labels are inferred from sub-directory names as integer class ids
    (``label_mode="int"``). The class order discovered in the *training*
    directory is passed explicitly to the validation and test loaders. As a
    result, integer label ``i`` means ``class_names[i]`` in every split,
    instead of each split deriving its own order from whatever folders it
    happens to contain.

    Args:
        cfg: Config providing directories, image size, batch sizes and seed.

    Returns:
        ``(train_ds, val_ds, test_ds, class_names)``. Each dataset yields
        ``(images, labels)`` batches with images cast to float32 and divided
        by 255. Only ``train_ds`` is shuffled and augmented. ``class_names``
        are the class folder names of the training directory.
    """
    train_ds = keras.utils.image_dataset_from_directory(
        cfg.train_dir,
        labels="inferred",
        label_mode="int",
        image_size=cfg.image_size,
        batch_size=cfg.batch_train,
        shuffle=True,
        seed=cfg.seed,
    )
    # The training folders define the label mapping for every split.
    class_names = train_ds.class_names

    def load_eval_split(directory, batch_size):
        # Pinning class_names keeps label indices aligned with training even if
        # this split is missing a folder or contains an extra one.
        return keras.utils.image_dataset_from_directory(
            directory,
            labels="inferred",
            label_mode="int",
            class_names=class_names,
            image_size=cfg.image_size,
            batch_size=batch_size,
            shuffle=False,
            seed=cfg.seed,
        )

    val_ds = load_eval_split(cfg.val_dir, cfg.batch_val)
    test_ds = load_eval_split(cfg.test_dir, cfg.batch_test)

    AUTOTUNE = tf.data.AUTOTUNE

    def preprocess(images, labels):
        # Scale pixels to [0, 1]; load_and_preprocess_image applies the same
        # scaling for single-image inference.
        images = tf.cast(images, tf.float32) / 255.0
        return images, labels

    # Augment the training split only; validation/test inputs stay deterministic.
    train_ds = (train_ds
                .map(preprocess, num_parallel_calls=AUTOTUNE)
                .map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
                .prefetch(AUTOTUNE))

    val_ds = val_ds.map(preprocess, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)
    test_ds = test_ds.map(preprocess, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

    return train_ds, val_ds, test_ds, class_names


train_ds, val_ds, test_ds, class_names = build_datasets(cfg)
cfg.num_classes = len(class_names)
print("Classes:", class_names)


## 4) Models: MobileNetV2, EfficientNetB0 and EfficientNetV2M (Transfer Learning)

In [None]:

from tensorflow.keras import Model

def build_backbone(name: str, input_shape):
    """Create the model input and an ImageNet-pretrained backbone (no classifier head).

    This notebook feeds the model float images scaled to [0, 1]: the dataset
    ``preprocess`` step and ``load_and_preprocess_image`` both divide by 255.
    The Keras application backbones expect a different range, so a
    ``Rescaling`` layer (``x * scale + offset``) maps [0, 1] onto the range each
    backbone's pretrained weights were trained with:

    * MobileNetV2 expects [-1, 1]   -> ``x * 2 - 1``
    * EfficientNetB0 / EfficientNetV2M normalize internally and expect raw
      [0, 255] pixels              -> ``x * 255``

    Args:
        name: Backbone identifier (case-insensitive), one of the aliases below.
        input_shape: ``(height, width)`` of the input images.

    Returns:
        ``(inputs, base)``: the ``keras.Input`` tensor and the backbone model
        built on top of it (global-average-pooled features, no top).

    Raises:
        ValueError: If ``name`` is not a recognised backbone.
    """
    key = name.lower()
    if key in ["mobilenetv2", "mobilenet_v2"]:
        app, scale, offset = keras.applications.MobileNetV2, 2.0, -1.0
    elif key in ["efficientnetv2m", "efficientnetv2-m", "efficientnet_v2_m"]:
        app, scale, offset = keras.applications.EfficientNetV2M, 255.0, 0.0
    elif key in ["efficientnetb0", "efficientnet-b0", "efficientnet_b0"]:
        app, scale, offset = keras.applications.EfficientNetB0, 255.0, 0.0
    else:
        raise ValueError("Unknown backbone: " + name)

    inputs = keras.Input(shape=tuple(input_shape) + (3,))
    x = layers.Resizing(input_shape[0], input_shape[1])(inputs)
    # Convert the [0, 1] pipeline range to the backbone's expected input range.
    x = layers.Rescaling(scale=scale, offset=offset, name="input_scaling")(x)
    base = app(include_top=False, weights="imagenet", input_tensor=x, pooling="avg")
    return inputs, base

def build_classifier(backbone_name="efficientnetv2m", num_classes=6, input_shape=(224,224), dropout=0.5):
    """Assemble a transfer-learning classifier: frozen backbone -> Dropout -> softmax Dense.

    Args:
        backbone_name: Identifier forwarded to ``build_backbone``.
        num_classes: Number of output classes.
        input_shape: ``(height, width)`` of the input images.
        dropout: Dropout rate applied to the pooled backbone features.

    Returns:
        A functional ``keras.Model`` named ``"<backbone_name>_classifier"`` whose
        backbone starts with ``trainable = False``.
    """
    inputs, base = build_backbone(backbone_name, input_shape)
    base.trainable = False  # only the new head trains until layers are unfrozen
    regularized = layers.Dropout(dropout)(base.output)
    # dtype="float32" keeps the softmax output in float32 even under a
    # mixed_float16 global policy.
    probabilities = layers.Dense(num_classes, activation="softmax", dtype="float32")(regularized)
    return Model(inputs=inputs, outputs=probabilities, name=f"{backbone_name}_classifier")


model = build_classifier("efficientnetb0", num_classes=cfg.num_classes, input_shape=cfg.image_size)
model.summary()


## 5) Gradual Unfreezing Utilities

In [None]:

from tensorflow.keras import layers as L

def unfreeze_last_n_layers(model: keras.Model, n: int):
    """Make the last ``n`` layers of the backbone trainable, keeping BatchNorm frozen.

    The backbone is taken to be the nested ``keras.Model`` with the most layers.
    If ``model`` contains no nested models, the layers of ``model`` itself are
    used. That is presumably the case for ``build_classifier`` output, which
    wires the head onto ``base.output`` instead of calling ``base`` as a
    layer. In that case the count includes the head layers at the end (e.g.
    Dense/Dropout).

    Layers before the selected tail are left untouched, so repeated calls
    accumulate. BatchNormalization layers in the tail are explicitly kept
    non-trainable so their statistics stay fixed during fine-tuning.

    Args:
        model: Model whose backbone should be partially unfrozen.
        n: Number of trailing layers to unfreeze; clamped to ``[0, len(layers)]``.
           ``n <= 0`` changes nothing.
    """
    submodels = [l for l in model.layers if isinstance(l, keras.Model)]
    backbone = max(submodels, key=lambda m: len(m.layers)) if submodels else model

    total = len(backbone.layers)
    to_unfreeze = max(0, min(n, total))
    # Slice from an explicit start index: `layers[-0:]` would select *every*
    # layer when to_unfreeze is 0.
    for layer in backbone.layers[total - to_unfreeze:]:
        layer.trainable = not isinstance(layer, L.BatchNormalization)
    print(f"Unfroze last {to_unfreeze}/{total} layers of backbone: {backbone.name}")


## 6) Compile & Callbacks

In [None]:

import os
os.makedirs(cfg.model_dir, exist_ok=True)  # destination for checkpoints and history CSV


def compile_model(model, lr):
    """Compile ``model`` with a new Adam optimizer at learning rate ``lr``.

    Uses sparse categorical cross-entropy, matching the integer labels produced
    by the data pipeline (``label_mode="int"``), and tracks accuracy. Calling it
    again replaces the previous optimizer.
    """
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )


# Save the model whenever val_accuracy beats the best value seen so far.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    os.path.join(cfg.model_dir, cfg.best_model_name),
    verbose=1,
    save_best_only=True,
    monitor="val_accuracy",
)
# Stop after 8 epochs without val_accuracy improvement; restore the best weights on stop.
earlystop_cb = keras.callbacks.EarlyStopping(
    monitor="val_accuracy",
    patience=8,
    verbose=1,
    restore_best_weights=True,
)
# Halve the learning rate after 4 epochs without val_loss improvement.
reducelr_cb = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    patience=4,
    factor=0.5,
    verbose=1,
)

compile_model(model, cfg.base_learning_rate)


## 7) Train & Validate

In [None]:

import pandas as pd
import matplotlib.pyplot as plt

def train_model(model, train_ds, val_ds, epochs, callbacks):
    """Fit ``model`` on ``train_ds``, validating on ``val_ds`` after each epoch.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    fit_options = dict(
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1,
    )
    return model.fit(train_ds, **fit_options)

def plot_history(history):
    """Plot train vs. validation curves from a Keras ``History``, one figure per metric.

    Draws accuracy then loss, each with its ``val_``-prefixed counterpart.
    """
    # (history key, short legend name, title / y-axis label)
    panels = (
        ("accuracy", "acc", "Accuracy"),
        ("loss", "loss", "Loss"),
    )
    for key, short, title in panels:
        plt.figure()
        plt.plot(history.history[key], label=f"train_{short}")
        plt.plot(history.history[f"val_{key}"], label=f"val_{short}")
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(title)
        plt.legend()
        plt.grid(True)
        plt.show()

def save_history_csv(history, path):
    """Append the per-epoch metrics in ``history.history`` to a CSV file.

    If ``path`` already exists, its rows are kept and the new epochs are written
    after them. Columns are aligned by name, and a column present on only one
    side is left empty for the other rows. Otherwise a new file is created.

    Args:
        history: Object with a ``history`` dict mapping metric name to a list of
            per-epoch values (e.g. the ``History`` returned by ``model.fit``).
        path: Destination CSV path.
    """
    frames = []
    if os.path.exists(path):
        frames.append(pd.read_csv(path))
    frames.append(pd.DataFrame(history.history))
    pd.concat(frames, ignore_index=True).to_csv(path, index=False)
    print("Saved history to", path)

# First pass: train the new head while the backbone is still frozen.
initial_callbacks = [checkpoint_cb, earlystop_cb, reducelr_cb]
history = train_model(model, train_ds, val_ds, epochs=15, callbacks=initial_callbacks)
plot_history(history)
history_path = os.path.join(cfg.model_dir, cfg.history_csv)
save_history_csv(history, history_path)


## 8) Incremental Training Sessions (Freeze → Unfreeze)

In [None]:

def incremental_training(model, sessions, train_ds, val_ds, callbacks=None):
    """Run a sequence of fine-tuning sessions, optionally unfreezing more layers each time.

    Each session dict may contain:

    * ``"unfreeze_last_n"`` (default 0): if > 0, ``unfreeze_last_n_layers`` is
      called first. Unfreezing accumulates across sessions; layers made
      trainable earlier are not re-frozen.
    * ``"lr"`` (default ``cfg.base_learning_rate``): learning rate for this session.
    * ``"epochs"`` (default 10): number of epochs for this session.

    Args:
        model: Model to train in place.
        sessions: Iterable of session dicts as described above.
        train_ds: Training dataset.
        val_ds: Validation dataset.
        callbacks: Callbacks for every ``fit`` call. ``None`` (default) uses the
            module-level ``checkpoint_cb``, ``earlystop_cb`` and ``reducelr_cb``.
            The same callback objects are reused by every session.

    Returns:
        List of ``History`` objects, one per session, in order.
    """
    if callbacks is None:
        callbacks = [checkpoint_cb, earlystop_cb, reducelr_cb]

    session_histories = []
    for i, sess in enumerate(sessions, start=1):
        print(f"\n=== Session {i}: {sess} ===")
        unfreeze = sess.get("unfreeze_last_n", 0)
        if unfreeze > 0:
            unfreeze_last_n_layers(model, unfreeze)
        # Recompile every session: required for `trainable` changes to take
        # effect, and it installs a fresh optimizer at the session's rate.
        lr = sess.get("lr", cfg.base_learning_rate)
        compile_model(model, lr)
        epochs = sess.get("epochs", 10)
        hist = model.fit(train_ds, validation_data=val_ds, epochs=epochs,
                         callbacks=callbacks, verbose=1)
        session_histories.append(hist)
    return session_histories

# Fine-tuning schedule: each later session unfreezes more of the backbone and
# lowers the learning rate.
sessions = [
    dict(epochs=10, lr=1e-4, unfreeze_last_n=0),   # no additional layers unfrozen
    dict(epochs=15, lr=5e-5, unfreeze_last_n=20),
    dict(epochs=20, lr=1e-5, unfreeze_last_n=60),
]

session_histories = incremental_training(model, sessions, train_ds, val_ds)


## 9) Evaluation: Test Set, Classification Report & Confusion Matrix

In [None]:

import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import itertools
import matplotlib.pyplot as plt

best_model_path = os.path.join(cfg.model_dir, cfg.best_model_name)
print("Loading best model from:", best_model_path)
best_model = keras.models.load_model(best_model_path)

# Materialize the (unshuffled) test split once so images and labels stay aligned.
# NOTE: this holds the whole test set in host memory; for large test sets,
# predict batch-by-batch instead.
test_images, test_labels = [], []
for batch_x, batch_y in test_ds:
    test_images.append(batch_x.numpy())
    test_labels.append(batch_y.numpy())
test_images = np.concatenate(test_images, axis=0)
test_labels = np.concatenate(test_labels, axis=0)

probs = best_model.predict(test_images, batch_size=cfg.batch_test, verbose=1)
preds = probs.argmax(axis=1)

# Pin the label set to every class index. Otherwise sklearn derives labels only
# from classes present in y_true/y_pred. classification_report would then raise
# when a class never occurs (target_names size mismatch), and confusion_matrix
# would shrink, misaligning its rows/columns with class_names when plotted.
all_labels = np.arange(len(class_names))
print(classification_report(test_labels, preds, labels=all_labels, target_names=class_names))

cm = confusion_matrix(test_labels, preds, labels=all_labels)

def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix'):
    """Draw ``cm`` as a heat map with each cell annotated by its value.

    Args:
        cm: Square confusion matrix. Rows are true labels and columns are
            predicted labels, matching the axis labels drawn below.
        classes: Tick labels, one per row/column, in index order.
        normalize: If True, divide each row by its total so cells show the
            fraction of that true class. Rows with no samples are shown as 0
            instead of NaN.
        title: Figure title.
    """
    if normalize:
        cm = cm.astype('float')
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class with no true samples has a zero row total; leave that row at 0
        # instead of producing NaNs (and a RuntimeWarning) from 0/0.
        cm = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)

    plt.figure()
    plt.imshow(cm, interpolation='nearest')
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, ha='right')
    plt.yticks(tick_marks, classes)

    # Two decimals for fractions, plain integers for raw counts.
    fmt = '.2f' if normalize else 'd'
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()

plot_confusion_matrix(cm, class_names, normalize=False, title='Confusion Matrix')
plot_confusion_matrix(cm, class_names, normalize=True, title='Confusion Matrix (Normalized)')


## 10) Single-Image Inference Utility

In [None]:

from PIL import Image
import numpy as np

def load_and_preprocess_image(path, target_size):
    """Load an image file as a float32 RGB array scaled to [0, 1].

    Mirrors the training-time preprocessing in this notebook: resize to the
    model input size, then divide by 255.

    Args:
        path: Path to the image file.
        target_size: ``(height, width)``, the same convention as
            ``cfg.image_size`` and Keras ``image_size``.

    Returns:
        ``np.ndarray`` of shape ``(height, width, 3)`` and dtype float32.
    """
    height, width = target_size
    # The context manager closes the file handle promptly instead of waiting for GC.
    with Image.open(path) as src:
        # PIL's resize takes (width, height), the reverse of Keras' (height, width).
        # BILINEAR matches the default interpolation used when building the datasets.
        img = src.convert("RGB").resize((width, height), resample=Image.BILINEAR)
    return np.asarray(img, dtype=np.float32) / 255.0

def predict_single_image(model, image_path, class_names, target_size=(224,224)):
    """Classify a single image file with ``model``.

    The image is loaded and scaled by ``load_and_preprocess_image`` and run
    through the model as a batch of one.

    Returns:
        ``(pred_class, confidence, probs)``: the name of the highest-scoring
        class, that score as a Python float, and the model's full output vector
        for the image.
    """
    batch = load_and_preprocess_image(image_path, target_size)[np.newaxis, ...]
    scores = model.predict(batch, verbose=0)[0]
    top = int(np.argmax(scores))
    return class_names[top], float(scores[top]), scores

# Example:
# image_path = "datasets/val/<class>/<file>.jpg"
# pred_class, conf, prob = predict_single_image(best_model, image_path, class_names, target_size=cfg.image_size)
# print(pred_class, conf)


## 11) Switching Backbone (MobileNetV2 / EfficientNetB0 / EfficientNetV2M)

In [None]:

# To switch backbone, rebuild and retrain:
# model = build_classifier("mobilenetv2", num_classes=cfg.num_classes, input_shape=cfg.image_size)
# model = build_classifier("efficientnetb0", num_classes=cfg.num_classes, input_shape=cfg.image_size)
# model = build_classifier("efficientnetv2m", num_classes=cfg.num_classes, input_shape=cfg.image_size)
# compile_model(model, cfg.base_learning_rate)
# history = train_model(model, train_ds, val_ds, epochs=15, callbacks=[checkpoint_cb, earlystop_cb, reducelr_cb])
