# Handwash Full Pipeline (Kaggle) - QAT
Self-contained Kaggle notebook that trains a handwashing-step classifier with quantization-aware training (QAT) and exports it as an INT8 TFLite model.


In [None]:
# Install the notebook's Python dependencies via pip (IPython shell escape).
# -q keeps pip output quiet; --no-cache-dir stops pip from keeping a download cache.
# tensorflow-model-optimization provides the `tensorflow_model_optimization` module imported below.
!pip install -q --no-cache-dir scikit-learn pandas numpy opencv-python-headless matplotlib seaborn tqdm requests tensorflow-model-optimization


In [None]:
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
from tensorflow import keras

# =========================
# Standard library
# =========================
import math
import random
from pathlib import Path
from typing import List, Dict, Tuple

# =========================
# Third-party
# =========================
import numpy as np
import cv2
import pandas as pd
from tqdm import tqdm

# =========================
# TensorFlow / Keras
# =========================
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Global constants shared by every cell below.
RANDOM_SEED = 42
# (width, height) handed to cv2.resize and used as the model's spatial input size.
IMG_SIZE = (224, 224)
# Seven outputs: the "Other" class plus the six handwashing steps listed in CLASS_NAMES.
NUM_CLASSES = 7
CLASS_NAMES = ['Other', 'Step1_PalmToPalm', 'Step2_PalmOverDorsum', 'Step3_InterlacedFingers', 'Step4_BackOfFingers', 'Step5_ThumbRub', 'Step6_Fingertips']
KAGGLE_URL = 'https://github.com/atiselsts/data/raw/master/kaggle-dataset-6classes.tar'
# Lower-cased class folder name -> class index into CLASS_NAMES.
KAGGLE_CLASS_MAPPING = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, 'step1': 1, 'step2': 2, 'step3': 3, 'step4': 4, 'step5': 5, 'step6': 6, 'other': 0}
# Fractions of the videos assigned to each split.
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15

# Seed every RNG the notebook uses so shuffles and splits are repeatable.
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
try:
    tf.random.set_seed(RANDOM_SEED)
except Exception as exc:
    # Seeding TensorFlow stays best-effort, but a failure is reported instead of
    # being swallowed, so a non-reproducible run is visible in the cell output.
    print("Warning: could not seed TensorFlow RNG:", exc)


In [None]:
import os, sys, json, time, math, random, shutil, subprocess
from pathlib import Path

# Name of this run; taken from the RUN_NAME environment variable when set.
RUN_NAME = os.getenv("RUN_NAME", "handwash_qat_run")

# Use Kaggle's writable directory when it exists, otherwise the current directory.
KAGGLE_WORKING = Path("/kaggle/working")
KAGGLE_WORKING = KAGGLE_WORKING if KAGGLE_WORKING.exists() else Path.cwd()

# Per-run outputs live under WORK_DIR; data is shared across runs under DATA_DIR.
WORK_DIR = KAGGLE_WORKING.joinpath("handwash_runs", RUN_NAME)
DATA_DIR = KAGGLE_WORKING.joinpath("handwash_data")

RAW_DIR, PROCESSED_DIR = DATA_DIR / "raw", DATA_DIR / "processed"
MODELS_DIR, CKPT_DIR, LOGS_DIR = (
    WORK_DIR / leaf for leaf in ("models", "checkpoints", "logs")
)

# Create the whole directory tree up front; existing directories are left alone.
for p in (WORK_DIR, DATA_DIR, RAW_DIR, PROCESSED_DIR, MODELS_DIR, CKPT_DIR, LOGS_DIR):
    p.mkdir(parents=True, exist_ok=True)

print("WORK_DIR:", WORK_DIR)
print("DATA_DIR:", DATA_DIR)


## Configuration
All options in the next cell are user-editable.


In [None]:
# User config (edit these)
# NOTE(review): DATASETS is not read by any of the cells visible here; only the
# Kaggle dataset is loaded. Confirm whether other cells consume it.
DATASETS = ["kaggle"]

# Training config
# BATCH_SIZE is passed to make_dataset for the train, validation and test pipelines.
BATCH_SIZE = 32
# Epoch counts for the float warm-up fit and the subsequent QAT fit.
EPOCHS_FLOAT = 1
EPOCHS_QAT = 5
# Adam learning rates for the float model and the QAT model respectively.
LEARNING_RATE_FLOAT = 1e-4
LEARNING_RATE_QAT = 1e-5

# Quantization config
# When True, the export cell sets converter._experimental_disable_per_channel.
DISABLE_PER_CHANNEL = False

# Data sampling config
# NOTE(review): FRAME_STRIDE and MAX_FRAMES_PER_VIDEO are not referenced by the
# cells visible here (the loader samples one random frame per video) - confirm use.
FRAME_STRIDE = 2
MAX_FRAMES_PER_VIDEO = 8

# Mixed precision
# When True, the next cell sets the global Keras policy to "mixed_float16".
MIXED_PRECISION = False


In [None]:
# Optional mixed precision: switch the global Keras policy to float16 compute.
if MIXED_PRECISION:
    from tensorflow.keras import mixed_precision
    mixed_precision.set_global_policy("mixed_float16")
    print("Mixed precision enabled")

# Distribution strategy: mirror across GPUs only when more than one is visible,
# otherwise fall back to whatever strategy is currently active.
_gpus = tf.config.list_physical_devices('GPU')
STRATEGY = (
    tf.distribute.MirroredStrategy()
    if len(_gpus) > 1
    else tf.distribute.get_strategy()
)
NUM_REPLICAS = STRATEGY.num_replicas_in_sync
print('Using strategy:', STRATEGY, 'replicas:', NUM_REPLICAS)


## Download and preprocess


In [None]:
import requests
from tqdm import tqdm
import tarfile
from IPython.display import Video, display

# File suffixes (lower-case) that collect_videos accepts as video files.
VIDEO_EXTS = (".mp4", ".avi", ".mov", ".mkv")

# Local layout used when the dataset has to be downloaded:
# the tarball is saved to KAGGLE_TAR and unpacked into KAGGLE_DIR.
KAGGLE_DIR = RAW_DIR / "kaggle"
KAGGLE_TAR = KAGGLE_DIR / "kaggle-dataset-6classes.tar"
# Presumably the top-level directory the tarball unpacks to - TODO confirm against the archive.
KAGGLE_EXTRACTED = KAGGLE_DIR / "kaggle-dataset-6classes"


def download_with_progress(url: str, dest: Path):
    """Stream ``url`` to ``dest`` while showing a tqdm progress bar.

    The download is skipped when ``dest`` already exists. Data is first written
    to a sibling ``<name>.part`` file and only renamed to ``dest`` once the
    transfer finished, so an interrupted download can never be mistaken for a
    complete file on the next run.

    Args:
        url: HTTP(S) address of the file to fetch.
        dest: Final location of the downloaded file; parent directories are
            created when missing.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    if dest.exists():
        print("skip", dest)
        return

    partial = dest.with_name(dest.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            # A missing Content-Length becomes None so tqdm shows a plain counter.
            total = int(r.headers.get("content-length", 0)) or None
            with open(partial, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=dest.name) as pbar:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if not chunk:
                        continue
                    f.write(chunk)
                    pbar.update(len(chunk))
        # Publish the file only after every byte has been written.
        partial.replace(dest)
    except BaseException:
        # Also covers a notebook interrupt: never leave a truncated file behind.
        partial.unlink(missing_ok=True)
        raise


def extract_tar(tar_path: Path, extract_root: Path):
    """Unpack ``tar_path`` into ``extract_root`` without letting members escape it.

    The archive comes from the network, so every member path is checked before
    anything is written: absolute paths and ``..`` components that would land
    outside ``extract_root`` are rejected. When the running Python provides
    tarfile extraction filters, the ``"data"`` filter is applied as well, which
    additionally rejects unsafe links and special files.

    Args:
        tar_path: Archive to read.
        extract_root: Destination directory; created when missing.

    Raises:
        ValueError: If a member would be written outside ``extract_root``.
    """
    extract_root.mkdir(parents=True, exist_ok=True)
    root = extract_root.resolve()

    def _escapes(member_name: str) -> bool:
        # Resolve the would-be target and test containment in the destination.
        target = (root / member_name).resolve()
        return target != root and root not in target.parents

    with tarfile.open(tar_path, "r") as tar:
        members = tar.getmembers()
        for member in members:
            if _escapes(member.name):
                raise ValueError(
                    f"Refusing to extract {member.name!r}: path escapes {extract_root}"
                )
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=extract_root, members=members, filter="data")
        else:
            # Older interpreters have no extraction filters; the path check above still applies.
            tar.extractall(path=extract_root, members=members)


def find_kaggle_input_dataset():
    """Locate an attached copy of the dataset under ``/kaggle/input``.

    Each directory directly under ``/kaggle/input`` is examined in turn. A
    directory matches when it contains a ``kaggle-dataset-6classes`` entry
    (that entry is returned) or when it contains one entry per class index
    ``0 .. NUM_CLASSES - 1`` (the directory itself is returned).

    Returns:
        The matching path, or ``None`` when ``/kaggle/input`` does not exist
        or nothing matches.
    """
    input_root = Path("/kaggle/input")
    if not input_root.exists():
        return None

    for entry in filter(Path.is_dir, input_root.iterdir()):
        nested = entry / "kaggle-dataset-6classes"
        if nested.exists():
            return nested
        # Alternative layout: the class folders sit directly inside the entry.
        has_all_class_dirs = all(
            (entry / str(class_idx)).exists() for class_idx in range(NUM_CLASSES)
        )
        if has_all_class_dirs:
            return entry
    return None


# Prefer a dataset attached as Kaggle input; otherwise download and unpack it locally.
DATA_ROOT = find_kaggle_input_dataset()
if DATA_ROOT is None:
    print("Kaggle input dataset not found. Downloading...")
    # An already extracted copy from an earlier run is reused as-is.
    if not KAGGLE_EXTRACTED.exists():
        download_with_progress(KAGGLE_URL, KAGGLE_TAR)
        extract_tar(KAGGLE_TAR, KAGGLE_DIR)
    DATA_ROOT = KAGGLE_EXTRACTED
else:
    print("Using Kaggle input dataset:", DATA_ROOT)

print("Dataset root:", DATA_ROOT)


In [None]:
from sklearn.model_selection import train_test_split


def kaggle_class_id_from_folder(name: str) -> int:
    """Translate a class folder name into its integer class id.

    The lower-cased name is first looked up in ``KAGGLE_CLASS_MAPPING``. If it
    is not listed there, all digit characters of the name are concatenated and
    interpreted as the class id, provided the result indexes ``CLASS_NAMES``.

    Raises:
        ValueError: If neither approach yields a valid class id.
    """
    key = name.lower()
    try:
        return int(KAGGLE_CLASS_MAPPING[key])
    except KeyError:
        pass  # Not an exact match; fall through to the digit heuristic.

    digit_run = "".join(filter(str.isdigit, key))
    if digit_run:
        candidate = int(digit_run)
        if 0 <= candidate < len(CLASS_NAMES):
            return candidate
    raise ValueError(f"Unknown Kaggle class folder: {name}")


def collect_videos(dataset_root: Path) -> pd.DataFrame:
    """Build a table of every video file found in the class folders.

    Each sub-directory of ``dataset_root`` is treated as one class folder and
    mapped to a class id with ``kaggle_class_id_from_folder``. Both the class
    folders and the files inside them are visited in sorted order: directory
    listing order is filesystem-dependent, and the seeded train/val/test split
    is only reproducible when the row order is stable.

    Args:
        dataset_root: Directory containing one sub-directory per class.

    Returns:
        DataFrame with columns ``video_path`` (str) and ``class_id`` (int).

    Raises:
        ValueError: If a sub-directory name cannot be mapped to a class id.
        RuntimeError: If no video file is found.
    """
    rows = []
    for class_dir in sorted(dataset_root.iterdir()):
        if not class_dir.is_dir():
            continue
        class_id = kaggle_class_id_from_folder(class_dir.name)
        for vid in sorted(class_dir.iterdir()):
            # Only regular files with a known video suffix count as videos.
            if not vid.is_file() or vid.suffix.lower() not in VIDEO_EXTS:
                continue
            rows.append({"video_path": str(vid), "class_id": class_id})
    df = pd.DataFrame(rows)
    if df.empty:
        raise RuntimeError(f"No videos found in {dataset_root}")
    return df


def _stratified_split(frame, holdout_fraction):
    """Split ``frame`` by class-stratified sampling; returns (kept, held_out)."""
    return train_test_split(
        frame,
        test_size=holdout_fraction,
        stratify=frame["class_id"],
        random_state=RANDOM_SEED,
    )


videos_df = collect_videos(DATA_ROOT)
print("Total videos:", len(videos_df))

# Stage 1: carve off the combined validation+test share from the training data.
train_df, temp_df = _stratified_split(videos_df, (VAL_RATIO + TEST_RATIO))

# Stage 2: divide the held-out part between validation and test.
val_size = VAL_RATIO / (VAL_RATIO + TEST_RATIO)
val_df, test_df = _stratified_split(temp_df, (1.0 - val_size))

print("Train:", len(train_df), "Val:", len(val_df), "Test:", len(test_df))


In [None]:
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_v2_preprocess


def _load_random_frame_py(video_path_bytes, label):
    """Decode one randomly chosen frame of a video and preprocess it for MobileNetV2.

    Runs as plain Python inside ``tf.py_function``, which hands its inputs over
    as eager tensors rather than ``bytes``/ints, so both arguments are unwrapped
    with ``.numpy()`` first. Raw ``bytes`` and integers are still accepted.

    Args:
        video_path_bytes: UTF-8 encoded video path (eager string tensor or bytes).
        label: Class id (eager integer tensor or plain integer).

    Returns:
        Tuple ``(frame, label)``: ``frame`` is the RGB frame resized with
        ``IMG_SIZE``, converted to float32 and passed through the MobileNetV2
        preprocessing function; ``label`` is an ``np.int32``.

    Raises:
        RuntimeError: If the video reports no frames or no frame can be decoded.
    """
    if hasattr(video_path_bytes, "numpy"):
        video_path_bytes = video_path_bytes.numpy()
    if hasattr(label, "numpy"):
        label = label.numpy()
    video_path = video_path_bytes.decode("utf-8")

    cap = cv2.VideoCapture(video_path)
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count <= 0:
            raise RuntimeError(f"No frames in {video_path}")
        target_idx = np.random.randint(0, frame_count)
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_idx)
        ok, frame = cap.read()
        if not ok and target_idx != 0:
            # The reported frame count can exceed what is actually decodable;
            # fall back to the first frame instead of aborting the whole pipeline.
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ok, frame = cap.read()
    finally:
        # Release the capture on every path, including unexpected errors.
        cap.release()
    if not ok:
        raise RuntimeError(f"Failed reading frame from {video_path}")

    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes frames as BGR
    frame = cv2.resize(frame, IMG_SIZE)
    frame = frame.astype(np.float32)
    frame = mobilenet_v2_preprocess(frame)
    return frame, np.int32(label)


def _load_random_frame(video_path, label):
    """Graph-side wrapper that calls ``_load_random_frame_py`` via ``tf.py_function``.

    ``tf.py_function`` returns tensors without static shape information, so the
    shapes are restored here: ``(*IMG_SIZE, 3)`` for the image and a scalar for
    the label.
    """
    image, class_id = tf.py_function(
        _load_random_frame_py,
        inp=[video_path, label],
        Tout=[tf.float32, tf.int32],
    )
    image_shape = tuple(IMG_SIZE) + (3,)
    image.set_shape(image_shape)
    class_id.set_shape(())
    return image, class_id


def make_dataset(df, batch_size=32, shuffle=True):
    """Create a batched ``tf.data`` pipeline yielding (frame, label) pairs.

    Args:
        df: DataFrame with ``video_path`` and ``class_id`` columns.
        batch_size: Number of samples per batch.
        shuffle: When true, shuffle over the full table (buffer of ``len(df)``)
            with ``RANDOM_SEED`` and reshuffle on each iteration.

    Returns:
        A dataset that loads one random frame per video, batches and prefetches.
    """
    paths = df["video_path"].values
    labels = df["class_id"].values
    pipeline = tf.data.Dataset.from_tensor_slices((paths, labels))
    if shuffle:
        pipeline = pipeline.shuffle(
            len(df),
            seed=RANDOM_SEED,
            reshuffle_each_iteration=True,
        )
    pipeline = pipeline.map(_load_random_frame, num_parallel_calls=tf.data.AUTOTUNE)
    return pipeline.batch(batch_size).prefetch(tf.data.AUTOTUNE)


# One pipeline per split; only the training split is shuffled.
train_ds, val_ds, test_ds = (
    make_dataset(split_df, batch_size=BATCH_SIZE, shuffle=do_shuffle)
    for split_df, do_shuffle in ((train_df, True), (val_df, False), (test_df, False))
)


## Build QAT model


In [None]:
from tensorflow.keras import layers, models


def build_float_model():
    """Build the float MobileNetV2 classifier that is later quantized.

    The backbone is constructed directly on the model's own input tensor, so the
    resulting model is one flat graph of layers instead of a model nested inside
    another model. Two reasons:

    * ``tfmot.quantization.keras.quantize_model`` refuses to quantize a Keras
      model that contains another Keras model as a layer.
    * Without a hard-coded ``training=True`` on the backbone call, its
      BatchNormalization layers follow the usual Keras mode: batch statistics
      during ``fit`` and moving averages during ``evaluate``/``predict``.

    Returns:
        An uncompiled ``keras.Model`` mapping ``(*IMG_SIZE, 3)`` images to
        ``NUM_CLASSES`` softmax probabilities, with all layers trainable.
    """
    inputs = layers.Input(shape=(*IMG_SIZE, 3))
    base = tf.keras.applications.MobileNetV2(
        input_tensor=inputs,
        include_top=False,
        weights="imagenet",
    )
    base.trainable = True  # fine-tune the whole backbone
    x = layers.GlobalAveragePooling2D()(base.output)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)
    model = models.Model(inputs, outputs)
    return model


# Build and compile the float model inside the distribution strategy scope.
with STRATEGY.scope():
    float_model = build_float_model()
    _float_optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_FLOAT)
    _float_loss = tf.keras.losses.SparseCategoricalCrossentropy()
    float_model.compile(
        optimizer=_float_optimizer,
        loss=_float_loss,
        metrics=["accuracy"],
    )

float_model.summary()


## Warmup (float) and QAT fine-tune


In [None]:
# Phase 1: short float warm-up before quantization is introduced.
history_float = float_model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS_FLOAT)

# Phase 2: wrap the warmed-up model for quantization-aware training and
# compile it with the QAT learning rate, again inside the strategy scope.
quantize_model = tfmot.quantization.keras.quantize_model
with STRATEGY.scope():
    qat_model = quantize_model(float_model)
    _qat_optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_QAT)
    _qat_loss = tf.keras.losses.SparseCategoricalCrossentropy()
    qat_model.compile(
        optimizer=_qat_optimizer,
        loss=_qat_loss,
        metrics=["accuracy"],
    )

qat_model.summary()

history_qat = qat_model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS_QAT)


## Evaluate


In [None]:
# Score the QAT model on the held-out test split and print metric name -> value.
qat_eval = qat_model.evaluate(test_ds, verbose=1)
print(
    "QAT test metrics:",
    {metric: value for metric, value in zip(qat_model.metrics_names, qat_eval)},
)


## Export INT8 TFLite


In [None]:
from typing import Iterable

# Save the trained QAT model under MODELS_DIR before the TFLite conversion below.
qat_model_path = MODELS_DIR / "mobilenetv2_qat.keras"
qat_model.save(qat_model_path)
print("Saved QAT model:", qat_model_path)


def representative_dataset_from_ds(ds, max_batches=10) -> Iterable[list[np.ndarray]]:
    """Yield single-sample calibration inputs taken from a batched dataset.

    Args:
        ds: Iterable of ``(batch, labels)`` pairs. ``batch`` may be a tensor
            exposing ``.numpy()`` or anything ``np.asarray`` accepts.
        max_batches: Maximum number of batches to consume. ``0`` (or a negative
            value) yields nothing; ``None`` consumes the whole dataset.

    Yields:
        One-element lists holding a float32 array whose leading dimension is 1,
        i.e. one sample at a time in the form the TFLite converter expects.
    """
    from itertools import islice  # local import: only this function needs it

    limit = None if max_batches is None else max(int(max_batches), 0)
    # islice stops before pulling batch number ``limit + 1`` from the dataset.
    for batch, _ in islice(ds, limit):
        raw = batch.numpy() if hasattr(batch, "numpy") else np.asarray(batch)
        batch_np = raw.astype(np.float32)
        for i in range(batch_np.shape[0]):
            # Slice (not index) to keep the batch dimension of size 1.
            yield [batch_np[i : i + 1]]


tflite_path = MODELS_DIR / "mobilenetv2_qat_int8.tflite"


def _representative_dataset():
    """Calibration sample source handed to the converter; draws from ``train_ds``."""
    return representative_dataset_from_ds(train_ds)


# Configure a full-integer conversion: INT8 builtin ops with int8 input and output.
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = _representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
if DISABLE_PER_CHANNEL:
    converter._experimental_disable_per_channel = True

# Conversion is best-effort: a failure is printed and the notebook continues.
try:
    tflite_model = converter.convert()
    tflite_path.write_bytes(tflite_model)
    print("Saved TFLite:", tflite_path)
except Exception as exc:
    print("INT8 conversion failed:", exc)


In [None]:
# Sanity check of the exported file: load it and print the dtype and
# quantization entry of the first input and first output tensor.
if tflite_path.exists():
    interpreter = tf.lite.Interpreter(model_path=str(tflite_path))
    interpreter.allocate_tensors()
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]
    for _tag, _detail in (("Input", input_detail), ("Output", output_detail)):
        print(f"{_tag} dtype:", _detail["dtype"], "quant:", _detail.get("quantization"))
