# Quantization-aware training (GPU)

This notebook:
- builds a MobileNetV2 classifier with full-layer QAT
- trains on Kaggle WHO6 using random-frame sampling from videos
- evaluates the quantized model
- exports an INT8 TFLite model

Note: QAT training is much faster on a GPU, but the notebook also runs (slowly) on CPU.


In [None]:
%pip install -q numpy pandas opencv-python tensorflow tensorflow-model-optimization tqdm requests


In [None]:
import os
import sys
import time
import random
import tarfile
from pathlib import Path
import importlib.util

import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import requests


def find_repo_root(start=None):
    """Walk upwards from *start* until a directory holding a config is found.

    A directory qualifies when it contains ``inference/config.py`` or
    ``training/config.py``.

    Args:
        start: Directory to begin the search from; the current working
            directory is used when omitted.

    Returns:
        The first qualifying directory among *start* and its parents, or
        *start* itself (as a ``Path``) when none qualifies.
    """
    origin = Path(start) if start is not None else Path.cwd()
    markers = (
        Path("inference") / "config.py",
        Path("training") / "config.py",
    )
    for candidate in (origin, *origin.parents):
        if any((candidate / marker).exists() for marker in markers):
            return candidate
    return origin


def _load_module(path: Path, name: str):
    """Import the Python file at *path* as a module called *name*.

    Args:
        path: Location of the source file to execute.
        name: Module name assigned to the resulting module object.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for *path*.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    loader = getattr(spec, "loader", None)
    if loader is None:
        raise ImportError(f"Cannot load module from {path}")
    loaded = importlib.util.module_from_spec(spec)
    loader.exec_module(loaded)
    return loaded


# Locate the repository and load the shared config module. inference/ is
# checked before training/, so it wins when both provide a config.py.
REPO_ROOT = find_repo_root()
INFERENCE_DIR = REPO_ROOT / "inference"
TRAINING_DIR = REPO_ROOT / "training"

for _config_dir in (INFERENCE_DIR, TRAINING_DIR):
    _config_file = _config_dir / "config.py"
    if _config_file.exists():
        cfg = _load_module(_config_file, "cfg")
        break
else:
    raise FileNotFoundError("config.py not found in inference/ or training/")

# Seed every RNG this notebook relies on so splits and sampling are repeatable.
np.random.seed(cfg.RANDOM_SEED)
random.seed(cfg.RANDOM_SEED)
try:
    tf.random.set_seed(cfg.RANDOM_SEED)
except Exception as exc:
    # Seeding stays best-effort, but a failure is reported because it means
    # TensorFlow-side randomness will not be reproducible.
    print("Warning: could not seed TensorFlow:", exc)

# Raw datasets live under the repo root; training artifacts go to training/outputs.
RAW_DIR = REPO_ROOT / "datasets" / "raw"
OUTPUT_DIR = TRAINING_DIR / "outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print("Repo root:", REPO_ROOT)
print("Output dir:", OUTPUT_DIR)


## GPU check


In [None]:
# Report the TensorFlow build and visible GPUs, then ask TensorFlow to grow
# GPU memory on demand instead of reserving it all up front.
print("TF version:", tf.__version__)
gpus = tf.config.list_physical_devices("GPU")
print("GPUs:", gpus)
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except (RuntimeError, ValueError) as exc:
        # Memory growth must be set before the runtime initializes the device;
        # if that already happened, say so instead of failing silently.
        print(f"Could not enable memory growth on {gpu.name}: {exc}")


## Download Kaggle WHO6 dataset (if needed)


In [None]:
# Download URL for the Kaggle archive, read from the shared config.
KAGGLE_URL = cfg.DATASETS["kaggle"]["url"]
# Directory under the raw-data root that holds the archive and its contents.
KAGGLE_DIR = RAW_DIR / "kaggle"
# Local path the archive is downloaded to.
KAGGLE_TAR = KAGGLE_DIR / "kaggle-dataset-6classes.tar"
# Directory whose presence is treated as "dataset already extracted".
# NOTE(review): assumes the archive unpacks into a folder of this name — confirm.
KAGGLE_EXTRACTED = KAGGLE_DIR / "kaggle-dataset-6classes"


def download_with_progress(url, dest: Path):
    """Stream *url* to *dest* while showing a tqdm progress bar.

    The payload is first written to ``<dest>.part`` and only renamed to *dest*
    once the transfer has finished. An interrupted download therefore never
    leaves a truncated file that a later run would mistake for a complete one.

    Args:
        url: HTTP(S) address of the file to fetch.
        dest: Final location of the downloaded file. Its parent directory is
            created when missing. Nothing is downloaded if *dest* exists.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    if dest.exists():
        print("skip", dest)
        return

    partial = dest.with_name(dest.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            # Without a content-length header tqdm gets None (unknown size)
            # rather than a misleading total of zero bytes.
            total = int(r.headers.get("content-length", 0)) or None
            with open(partial, "wb") as f, tqdm(
                total=total, unit="B", unit_scale=True, desc=dest.name
            ) as pbar:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if not chunk:
                        continue
                    f.write(chunk)
                    pbar.update(len(chunk))
        partial.replace(dest)
    finally:
        # After a successful rename the partial file is gone and this is a
        # no-op; after a failure it removes the incomplete data.
        partial.unlink(missing_ok=True)


def extract_tar(tar_path: Path, target_dir: Path):
    """Unpack *tar_path* into *target_dir* unless its contents are already there.

    "Already there" means every top-level entry of the archive exists inside
    *target_dir*. Checking only for the directory itself is not enough: the
    archive is usually downloaded into that same directory, so it always
    exists by the time extraction is requested.

    Args:
        tar_path: Archive to unpack.
        target_dir: Directory to unpack into; created when missing.
    """
    if target_dir.exists() and not tar_path.exists():
        # Nothing to unpack from; keep the old "directory exists -> skip" result.
        print("skip", target_dir)
        return

    with tarfile.open(tar_path, "r") as tar:
        if target_dir.exists():
            top_level = set()
            for member_name in tar.getnames():
                parts = [p for p in member_name.split("/") if p not in ("", ".")]
                if parts:
                    top_level.add(parts[0])
            if all((target_dir / entry).exists() for entry in top_level):
                print("skip", target_dir)
                return

        target_dir.mkdir(parents=True, exist_ok=True)
        if hasattr(tarfile, "data_filter"):
            # The archive comes from the network: refuse members that would
            # escape target_dir (absolute paths, "..", unsafe links).
            tar.extractall(path=target_dir, filter="data")
        else:
            # Older Python without extraction filters.
            tar.extractall(path=target_dir)


# Fetch and unpack the dataset only when the extracted folder is missing.
if KAGGLE_EXTRACTED.exists():
    print("Kaggle dataset already extracted:", KAGGLE_EXTRACTED)
else:
    print("Downloading Kaggle WHO6...")
    download_with_progress(KAGGLE_URL, KAGGLE_TAR)
    print("Extracting...")
    extract_tar(KAGGLE_TAR, KAGGLE_DIR)


## Index videos and create splits


In [None]:
# Lower-case file suffixes that are treated as videos when indexing the dataset.
VIDEO_EXTS = (".mp4", ".avi", ".mov", ".mkv")


def kaggle_class_id_from_folder(name: str) -> int:
    """Translate a Kaggle class folder name into a class id.

    The lower-cased name is looked up in ``cfg.KAGGLE_CLASS_MAPPING`` first.
    Failing that, all digits in the name are joined into one number, which is
    accepted when it is a valid index into ``cfg.CLASS_NAMES``.

    Args:
        name: Folder name as found on disk.

    Returns:
        The class id for the folder.

    Raises:
        ValueError: If neither the mapping nor the digits give a valid id.
    """
    key = name.lower()
    if key in cfg.KAGGLE_CLASS_MAPPING:
        return int(cfg.KAGGLE_CLASS_MAPPING[key])

    numeric_part = "".join(filter(str.isdigit, key))
    if numeric_part:
        candidate = int(numeric_part)
        if 0 <= candidate < len(cfg.CLASS_NAMES):
            return candidate

    raise ValueError(f"Unknown Kaggle class folder: {name}")


def collect_videos(dataset_root: Path) -> pd.DataFrame:
    """Index every video file below *dataset_root* with its class id.

    Each immediate sub-directory of *dataset_root* is a class folder. Files
    inside it whose suffix is listed in ``VIDEO_EXTS`` become one row each.
    Class folders and files are both visited in sorted order, so the row
    order is the same on every filesystem. The seeded train/val/test split
    built from this frame depends on that.

    Args:
        dataset_root: Directory containing one sub-directory per class.

    Returns:
        DataFrame with the columns ``video_path`` (str) and ``class_id``.

    Raises:
        RuntimeError: If no video file is found.
        ValueError: If a class folder name cannot be mapped to a class id.
    """
    rows = []
    for class_dir in sorted(dataset_root.iterdir()):
        if not class_dir.is_dir():
            continue
        class_id = kaggle_class_id_from_folder(class_dir.name)
        rows.extend(
            {"video_path": str(vid), "class_id": class_id}
            for vid in sorted(class_dir.iterdir())
            if vid.is_file() and vid.suffix.lower() in VIDEO_EXTS
        )
    df = pd.DataFrame(rows)
    if df.empty:
        raise RuntimeError(f"No videos found in {dataset_root}")
    return df


# Index all videos once, then carve out stratified train/val/test subsets.
videos_df = collect_videos(KAGGLE_EXTRACTED)
print("Total videos:", len(videos_df))

# Share of the data that does not go to training (validation + test).
holdout_ratio = cfg.VAL_RATIO + cfg.TEST_RATIO
train_df, temp_df = train_test_split(
    videos_df,
    test_size=holdout_ratio,
    stratify=videos_df["class_id"],
    random_state=cfg.RANDOM_SEED,
)

# Fraction of the held-out part that becomes validation; the rest is test.
val_size = cfg.VAL_RATIO / holdout_ratio
val_df, test_df = train_test_split(
    temp_df,
    test_size=(1.0 - val_size),
    stratify=temp_df["class_id"],
    random_state=cfg.RANDOM_SEED,
)

print("Train:", len(train_df), "Val:", len(val_df), "Test:", len(test_df))


## tf.data pipeline (random frames per video)


In [None]:
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_v2_preprocess

# Two-element frame size from the shared config; used when resizing frames and,
# with a channel dimension of 3 appended, as the model input shape.
IMG_SIZE = cfg.IMG_SIZE
# Number of units in the classifier's softmax output layer.
NUM_CLASSES = cfg.NUM_CLASSES


def _load_random_frame_py(video_path_bytes, label):
    """Decode one randomly chosen frame of a video and preprocess it.

    Args:
        video_path_bytes: Path of the video. ``tf.py_function`` passes an
            eager string tensor; raw ``bytes`` or ``str`` are accepted too.
        label: Class id of the video (eager tensor or plain integer).

    Returns:
        Tuple ``(frame, label)``. ``frame`` is a float32 RGB array resized to
        ``IMG_SIZE`` and run through the MobileNetV2 preprocessing function.
        ``label`` is a NumPy int32 scalar.

    Raises:
        RuntimeError: If the video reports no frames or no frame can be read.
    """
    # tf.py_function passes EagerTensors, which have no .decode(); unwrap them.
    if hasattr(video_path_bytes, "numpy"):
        video_path_bytes = video_path_bytes.numpy()
    if isinstance(video_path_bytes, bytes):
        video_path = video_path_bytes.decode("utf-8")
    else:
        video_path = str(video_path_bytes)
    if hasattr(label, "numpy"):
        label = label.numpy()

    cap = cv2.VideoCapture(video_path)
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count <= 0:
            raise RuntimeError(f"No frames in {video_path}")
        target_idx = np.random.randint(0, frame_count)
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_idx)
        ok, frame = cap.read()
        if not ok and target_idx != 0:
            # NOTE(review): the reported frame count can presumably exceed the
            # number of decodable frames for some containers. Fall back to the
            # first frame rather than aborting the whole input pipeline.
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ok, frame = cap.read()
    finally:
        # Always release the capture handle, even if OpenCV raises.
        cap.release()
    if not ok:
        raise RuntimeError(f"Failed reading frame from {video_path}")

    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # The model input shape is (*IMG_SIZE, 3), i.e. (height, width, channels),
    # whereas cv2.resize expects (width, height).
    height, width = IMG_SIZE
    frame = cv2.resize(frame, (width, height))
    frame = frame.astype(np.float32)
    frame = mobilenet_v2_preprocess(frame)
    return frame, np.int32(label)


def _load_random_frame(video_path, label):
    """Graph-side wrapper that runs ``_load_random_frame_py`` for one sample.

    Args:
        video_path: Scalar string tensor with the video location.
        label: Scalar tensor with the class id.

    Returns:
        ``(frame, label)`` tensors with their static shapes set: the frame is
        ``(*IMG_SIZE, 3)`` float32 and the label a scalar int32.
    """
    image, class_id = tf.py_function(
        func=_load_random_frame_py,
        inp=[video_path, label],
        Tout=[tf.float32, tf.int32],
    )
    # py_function outputs have unknown static shape; restore it explicitly.
    image.set_shape(tuple(IMG_SIZE) + (3,))
    class_id.set_shape(())
    return image, class_id


def make_dataset(df, batch_size=32, shuffle=True):
    """Build a batched ``tf.data`` pipeline that yields random video frames.

    Args:
        df: DataFrame with ``video_path`` and ``class_id`` columns.
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle the videos, reshuffling each iteration,
            with a buffer covering the whole frame and the configured seed.

    Returns:
        A prefetched dataset of ``(frames, labels)`` batches.
    """
    paths = df["video_path"].values
    labels = df["class_id"].values
    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))
    if shuffle:
        dataset = dataset.shuffle(
            len(df),
            seed=cfg.RANDOM_SEED,
            reshuffle_each_iteration=True,
        )
    return (
        dataset.map(_load_random_frame, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )


BATCH_SIZE = 32
# One pipeline per split; only the training split is shuffled.
train_ds, val_ds, test_ds = (
    make_dataset(split_df, batch_size=BATCH_SIZE, shuffle=is_training)
    for split_df, is_training in (
        (train_df, True),
        (val_df, False),
        (test_df, False),
    )
)


## Build MobileNetV2 + QAT


In [None]:
from tensorflow.keras import layers, models


def build_model():
    """Build the float MobileNetV2 classifier that is later wrapped for QAT.

    The classification head is attached directly to the backbone's output
    tensor, so the result is one flat functional model. Two reasons:

    * ``tfmot.quantization.keras.quantize_model`` rejects a Keras model that
      is nested as a layer inside another model.
    * Calling the backbone with ``training=True`` would pin its
      BatchNormalization layers to training mode during evaluation and export.

    Returns:
        An uncompiled ``tf.keras`` model mapping ``(*IMG_SIZE, 3)`` images to
        ``NUM_CLASSES`` softmax probabilities. Weights are randomly
        initialized and all layers are trainable.
    """
    base = tf.keras.applications.MobileNetV2(
        input_shape=(*IMG_SIZE, 3),
        include_top=False,
        weights=None,
    )
    base.trainable = True
    x = layers.GlobalAveragePooling2D()(base.output)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)
    # Reuse the backbone's own input so its layers become layers of this model.
    model = models.Model(inputs=base.input, outputs=outputs)
    return model


# Build the float model, then wrap it with TFMOT's quantize_model to get the
# quantization-aware model that is trained below.
float_model = build_model()
quantize_model = tfmot.quantization.keras.quantize_model
qat_model = quantize_model(float_model)

# Labels are integer class ids and the model outputs softmax probabilities,
# hence the sparse categorical cross-entropy loss.
qat_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)

qat_model.summary()


## Train


In [None]:
# Number of passes over the training split.
EPOCHS = 5

# Train the quantization-aware model, validating after every epoch.
history = qat_model.fit(x=train_ds, epochs=EPOCHS, validation_data=val_ds)


## Evaluate


In [None]:
# Score the trained QAT model on the held-out test split and print each
# metric next to its name.
qat_eval = qat_model.evaluate(test_ds, verbose=1)
print(
    "QAT test metrics:",
    {metric: value for metric, value in zip(qat_model.metrics_names, qat_eval)},
)


## Save QAT model and export INT8 TFLite

Axis DLPU devices require INT8 TFLite with built-in ops.


In [None]:
from typing import Iterable

# Persist the trained QAT model in the output directory.
qat_model_path = OUTPUT_DIR / "mobilenetv2_qat.keras"
qat_model.save(qat_model_path)
print("Saved QAT model:", qat_model_path)

# When True, the converter's experimental "disable per-channel" flag is set
# before conversion.
# NOTE(review): presumably this forces per-tensor weight quantization — confirm.
DISABLE_PER_CHANNEL = False


def representative_dataset_from_ds(ds, max_batches=10) -> Iterable[list[np.ndarray]]:
    """Yield single-sample calibration inputs taken from a batched dataset.

    Args:
        ds: Iterable of ``(batch, labels)`` pairs. ``batch`` may be a
            TensorFlow eager tensor or anything ``np.asarray`` accepts.
        max_batches: Maximum number of batches to draw from *ds*. Zero or a
            negative value yields nothing.

    Yields:
        One-element lists, each holding a float32 array with a leading batch
        dimension of 1.
    """
    from itertools import islice

    # islice stops after max_batches items without pulling an extra batch
    # from the (potentially expensive) input pipeline.
    for batch, _ in islice(ds, max(int(max_batches), 0)):
        if hasattr(batch, "numpy"):
            batch = batch.numpy()
        batch_np = np.asarray(batch, dtype=np.float32)
        for sample in batch_np:
            # Re-add the batch dimension so every sample has shape (1, ...).
            yield [sample[np.newaxis, ...]]


# Convert the QAT model to TFLite, restricted to the INT8 built-in op set and
# with int8 input and output tensors.
tflite_path = OUTPUT_DIR / "mobilenetv2_qat_int8.tflite"
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Calibration samples are drawn from the training pipeline on each call.
converter.representative_dataset = lambda: representative_dataset_from_ds(train_ds)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
if DISABLE_PER_CHANNEL:
    # Private converter attribute; may change between TensorFlow releases.
    converter._experimental_disable_per_channel = True

# Best-effort export: a failure while converting or writing is printed, not
# raised, so the notebook keeps running. A .tflite file left by an earlier run
# is not removed in that case.
try:
    tflite_model = converter.convert()
    tflite_path.write_bytes(tflite_model)
    print("Saved TFLite:", tflite_path)
except Exception as exc:
    print("INT8 conversion failed:", exc)


In [None]:
# Load the exported model, if present, and print the dtype and quantization
# parameters of its first input and first output tensor.
if tflite_path.exists():
    interpreter = tf.lite.Interpreter(model_path=str(tflite_path))
    interpreter.allocate_tensors()
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]
    for dtype_label, detail in (
        ("Input dtype:", input_detail),
        ("Output dtype:", output_detail),
    ):
        print(dtype_label, detail["dtype"], "quant:", detail.get("quantization"))
