# Handwash Full Pipeline (Kaggle) - QAT
Self-contained Kaggle notebook: dataset download, float warm-up, quantization-aware training (QAT), and INT8 TFLite export.


In [None]:
# Install dependencies
!pip install -q --no-cache-dir scikit-learn pandas numpy opencv-python-headless matplotlib seaborn tqdm requests tensorflow-model-optimization


In [None]:
import os
# Ask TensorFlow to use the legacy Keras 2 implementation. Presumably needed
# because tensorflow-model-optimization (used for QAT below) targets Keras 2 —
# confirm against the installed TF version. setdefault keeps any value the
# user already exported.
# NOTE(review): this only has an effect if TensorFlow has not yet been
# imported in this kernel, so keep this cell before the TensorFlow imports.
os.environ.setdefault("TF_USE_LEGACY_KERAS", "1")
print("TF_USE_LEGACY_KERAS=", os.environ.get("TF_USE_LEGACY_KERAS"))


In [None]:
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
from tensorflow import keras

# =========================
# Standard library
# =========================
import math
import random
from pathlib import Path
from typing import List, Dict, Tuple

# =========================
# Third-party
# =========================
import numpy as np
import cv2
import pandas as pd
from tqdm import tqdm

# =========================
# TensorFlow / Keras
# =========================
import tensorflow as tf

RANDOM_SEED = 42
# (height, width): used as the Keras input shape (*IMG_SIZE, 3).
IMG_SIZE = (224, 224)
# Index 0 ("Unused") is never assigned to a video: folders mapped to -1 in
# KAGGLE_CLASS_MAPPING (e.g. "0", "other") are skipped when indexing videos.
CLASS_NAMES = ['Unused', 'Step1_PalmToPalm', 'Step2_PalmOverDorsum', 'Step3_InterlacedFingers', 'Step4_BackOfFingers', 'Step5_ThumbRub', 'Step6_Fingertips']
# Derived from CLASS_NAMES (== 7) so the classifier width and the label list
# can never drift apart.
NUM_CLASSES = len(CLASS_NAMES)
KAGGLE_URL = 'https://github.com/atiselsts/data/raw/master/kaggle-dataset-6classes.tar'
# Dataset folder name (lower-cased) -> class id; -1 marks folders to skip.
KAGGLE_CLASS_MAPPING = {'0': -1, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, 'step1': 1, 'step2': 2, 'step3': 3, 'step4': 4, 'step5': 5, 'step6': 6, 'other': -1}
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15

# The split code only uses VAL_RATIO and TEST_RATIO (train gets the rest), so a
# TRAIN_RATIO that does not complete them to 1.0 would be silently ignored.
_ratio_total = TRAIN_RATIO + VAL_RATIO + TEST_RATIO
if not math.isclose(_ratio_total, 1.0):
    raise ValueError(f"TRAIN_RATIO + VAL_RATIO + TEST_RATIO must be 1.0, got {_ratio_total}")

np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
try:
    tf.random.set_seed(RANDOM_SEED)
except Exception as exc:  # best effort: seeding must not abort the notebook
    print("Warning: could not set TensorFlow seed:", exc)


In [None]:
import sys
import subprocess


def _pip_install(pkg):
    """Install ``pkg`` quietly with pip into the running interpreter's environment.

    ``sys.executable -m pip`` is used so the package lands in the same
    site-packages the kernel imports from.

    Returns:
        pip's exit status (0 on success).
    """
    pip_cmd = [sys.executable, "-m", "pip", "install", "-q", "--no-cache-dir", pkg]
    completed = subprocess.run(pip_cmd)
    return completed.returncode

# Import tensorflow-model-optimization, installing it on the fly if missing.
try:
    import tensorflow_model_optimization as tfmot
except ModuleNotFoundError:
    import importlib

    # Try known pinned versions first, then whatever pip resolves.
    for candidate in (
        "tensorflow-model-optimization==0.8.0",
        "tensorflow-model-optimization==0.7.5",
        "tensorflow-model-optimization",
    ):
        if _pip_install(candidate) == 0:
            break
    else:
        print("Warning: every tensorflow-model-optimization install attempt failed")
    # Packages installed after interpreter start-up may be invisible to the
    # import system's cached directory listings until the caches are cleared.
    importlib.invalidate_caches()
    import tensorflow_model_optimization as tfmot


In [None]:
import os, sys, json, time, math, random, shutil, subprocess
from pathlib import Path

RUN_NAME = os.environ.get("RUN_NAME", "handwash_qat_run")

# Prefer Kaggle's writable output directory; anywhere else use the current directory.
_kaggle_working_default = Path("/kaggle/working")
KAGGLE_WORKING = _kaggle_working_default if _kaggle_working_default.exists() else Path.cwd()

# Per-run artifacts live under WORK_DIR; downloaded/processed data is shared
# between runs under DATA_DIR.
WORK_DIR = KAGGLE_WORKING / "handwash_runs" / RUN_NAME
DATA_DIR = KAGGLE_WORKING / "handwash_data"

RAW_DIR, PROCESSED_DIR = DATA_DIR / "raw", DATA_DIR / "processed"
MODELS_DIR, CKPT_DIR = WORK_DIR / "models", WORK_DIR / "checkpoints"
LOGS_DIR, SAMPLES_DIR = WORK_DIR / "logs", WORK_DIR / "samples"

for _directory in (WORK_DIR, DATA_DIR, RAW_DIR, PROCESSED_DIR, MODELS_DIR, CKPT_DIR, LOGS_DIR, SAMPLES_DIR):
    _directory.mkdir(parents=True, exist_ok=True)

print("WORK_DIR:", WORK_DIR)
print("DATA_DIR:", DATA_DIR)


## Configuration
All options below are user-editable.


In [None]:
# User config (edit these)
# NOTE(review): DATASETS is not referenced by the visible cells; only the
# Kaggle dataset is loaded.
DATASETS = ["kaggle"]

# Training config
BATCH_SIZE = 32  # batch_size passed to make_dataset for every split
EPOCHS_FLOAT = 2  # float warm-up epochs before the model is wrapped for QAT
EPOCHS_QAT = 50  # quantization-aware fine-tuning epochs
LEARNING_RATE_FLOAT = 1e-4
LEARNING_RATE_QAT = 1e-5

# Quantization config
# When True, the export cell sets the TFLite converter's private
# `_experimental_disable_per_channel` flag.
DISABLE_PER_CHANNEL = False

# Data sampling config
# NOTE(review): FRAME_STRIDE and MAX_FRAMES_PER_VIDEO are not referenced by the
# visible cells. The tf.data pipeline decodes one random frame per video, and
# run_sample_inference uses its own defaults.
FRAME_STRIDE = 2
MAX_FRAMES_PER_VIDEO = 8

# Augmentation config
AUGMENT_TRAIN = True  # only the training split is ever augmented
AUGMENT_CONFIG = {
    "hflip": True,
    "rotation": 15,  # max |angle| in degrees
    "zoom": 0.1,  # scale drawn from [1 - zoom, 1 + zoom]
    "shear": 0.1,
    "shift": 0.1,  # max shift as a fraction of the image size
    "brightness": (0.9, 1.1),  # multiplicative factor range
    "contrast": (0.9, 1.1),
    "gamma": (0.9, 1.1),
    "shadow": True,  # when enabled, applied with probability 0.5
}

# Mixed precision
MIXED_PRECISION = False


In [None]:
# Auto-tune and mixed precision
# NOTE(review): nothing is auto-tuned here; this cell only sets the precision
# policy and picks a distribution strategy.
if MIXED_PRECISION:
    from tensorflow.keras import mixed_precision
    # Global Keras policy: compute in float16, keep variables in float32.
    mixed_precision.set_global_policy("mixed_float16")
    print("Mixed precision enabled")

# Multi-GPU strategy
# Replicate across all visible GPUs when there is more than one; otherwise use
# the default strategy. Models are built/compiled under STRATEGY.scope() later.
# NOTE(review): under MirroredStrategy, BATCH_SIZE acts as the global batch and
# is split across replicas.
_gpus = tf.config.list_physical_devices('GPU')
if len(_gpus) > 1:
    STRATEGY = tf.distribute.MirroredStrategy()
else:
    STRATEGY = tf.distribute.get_strategy()
NUM_REPLICAS = STRATEGY.num_replicas_in_sync
print('Using strategy:', STRATEGY, 'replicas:', NUM_REPLICAS)


## Download and preprocess


In [None]:
import requests
from tqdm import tqdm
import tarfile
from IPython.display import Video, display

# File suffixes treated as videos (collect_videos lower-cases the suffix first).
VIDEO_EXTS = (".mp4", ".avi", ".mov", ".mkv")

# Download/extraction locations, used only when no attached Kaggle input
# dataset is found.
KAGGLE_DIR = RAW_DIR / "kaggle"
KAGGLE_TAR = KAGGLE_DIR / "kaggle-dataset-6classes.tar"
# Assumed top-level folder inside the archive — verify if the source changes.
KAGGLE_EXTRACTED = KAGGLE_DIR / "kaggle-dataset-6classes"


def download_with_progress(url: str, dest: Path):
    """Stream ``url`` to ``dest`` while showing a tqdm progress bar.

    The body is first written to a sibling ``<name>.part`` file and only renamed
    onto ``dest`` once the transfer completes. An interrupted download is
    therefore never mistaken for a finished one on the next run, which skips
    the download whenever ``dest`` exists.

    Args:
        url: HTTP(S) URL to fetch.
        dest: final file path; parent directories are created as needed.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    if dest.exists():
        print("skip", dest)
        return
    tmp = dest.with_name(dest.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            # A missing content-length becomes an indeterminate progress bar.
            total = int(r.headers.get("content-length", 0)) or None
            with open(tmp, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=dest.name) as pbar:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))
    except BaseException:
        # Never leave a half-written file behind (also on KeyboardInterrupt).
        tmp.unlink(missing_ok=True)
        raise
    tmp.replace(dest)


def extract_tar(tar_path: Path, extract_root: Path):
    """Extract ``tar_path`` into ``extract_root`` without writing outside it.

    The archive comes from the network, so member paths are not trusted.

    - With tarfile extraction filters (Python 3.12+ and security backports),
      the ``"data"`` filter is used. It rejects absolute paths, ``..``
      components and links that leave the destination.
    - On older interpreters, every member path (and link target) is resolved
      and checked manually before extraction.

    Raises:
        tarfile.TarError: if a member would land outside ``extract_root``.
    """
    extract_root.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path, "r") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=extract_root, filter="data")
            return
        root = extract_root.resolve()
        for member in tar.getmembers():
            targets = [root / member.name]
            if member.issym():
                # Symlink targets are relative to the link's own directory.
                targets.append((root / member.name).parent / member.linkname)
            elif member.islnk():
                # Hard-link targets are relative to the archive root.
                targets.append(root / member.linkname)
            for target in targets:
                resolved = target.resolve()
                if resolved != root and root not in resolved.parents:
                    raise tarfile.TarError(f"Refusing to extract unsafe member: {member.name}")
        tar.extractall(path=extract_root)


def find_kaggle_input_dataset(root: Path = Path("/kaggle/input")):
    """Locate an attached copy of the hand-washing dataset under ``root``.

    Each sub-directory of ``root`` is checked in sorted (deterministic) order.
    A match is either a nested ``kaggle-dataset-6classes`` folder, or a
    directory that directly contains class folders ``"0"`` ..
    ``str(NUM_CLASSES - 1)``.

    Args:
        root: directory holding attached datasets (Kaggle's ``/kaggle/input``
            by default).

    Returns:
        The dataset root ``Path``, or ``None`` if ``root`` is missing or no
        candidate matches.
    """
    if not root.exists():
        return None
    for item in sorted(root.iterdir()):
        if not item.is_dir():
            continue
        candidate = item / "kaggle-dataset-6classes"
        if candidate.exists():
            return candidate
        if all((item / str(i)).exists() for i in range(NUM_CLASSES)):
            return item
    return None


# Prefer an attached Kaggle input dataset; otherwise download and unpack the
# archive once into KAGGLE_DIR (later runs reuse the extracted folder).
DATA_ROOT = find_kaggle_input_dataset()
if DATA_ROOT is not None:
    print("Using Kaggle input dataset:", DATA_ROOT)
else:
    print("Kaggle input dataset not found. Downloading...")
    if not KAGGLE_EXTRACTED.exists():
        download_with_progress(KAGGLE_URL, KAGGLE_TAR)
        extract_tar(KAGGLE_TAR, KAGGLE_DIR)
    DATA_ROOT = KAGGLE_EXTRACTED
    # Fail here with a clear message if the archive's top-level folder differs
    # from the expected name, instead of later inside collect_videos.
    if not DATA_ROOT.is_dir():
        raise FileNotFoundError(
            f"Expected extracted dataset at {DATA_ROOT}; check the archive layout in {KAGGLE_DIR}"
        )

print("Dataset root:", DATA_ROOT)


In [None]:
from sklearn.model_selection import train_test_split


def kaggle_class_id_from_folder(name: str):
    """Translate a dataset folder name into a class index of ``CLASS_NAMES``.

    Resolution order:

    1. The lower-cased name is looked up in ``KAGGLE_CLASS_MAPPING``. Entries
       whose id is falsy or non-positive (e.g. ``"0"``/``"other"`` -> -1)
       yield ``None``.
    2. Otherwise all digits in the name are concatenated and accepted if they
       form a valid non-zero index into ``CLASS_NAMES``.

    Returns:
        The class id, or ``None`` when the folder should be skipped.
    """
    key = name.lower()
    try:
        mapped = KAGGLE_CLASS_MAPPING[key]
    except KeyError:
        pass
    else:
        return mapped if mapped and mapped > 0 else None

    digit_str = "".join(filter(str.isdigit, key))
    if not digit_str:
        return None
    candidate = int(digit_str)
    return candidate if 0 < candidate < len(CLASS_NAMES) else None


def collect_videos(dataset_root: Path) -> pd.DataFrame:
    """Index the video files found in the class folders of ``dataset_root``.

    Class folders, and the files inside them, are visited in sorted order.
    The resulting row order — and therefore the seeded train/val/test split
    built from it — is thus reproducible across filesystems. Only regular
    files that are direct children of a class folder and whose (lower-cased)
    suffix is in ``VIDEO_EXTS`` are kept. Folders that
    ``kaggle_class_id_from_folder`` rejects are counted and skipped.

    Returns:
        DataFrame with columns ``video_path`` (str) and ``class_id`` (int).

    Raises:
        RuntimeError: if no videos are found.
    """
    rows = []
    skipped = 0
    for class_dir in sorted(dataset_root.iterdir()):
        if not class_dir.is_dir():
            continue
        class_id = kaggle_class_id_from_folder(class_dir.name)
        if class_id is None:
            skipped += 1
            continue
        for vid in sorted(class_dir.iterdir()):
            if vid.is_file() and vid.suffix.lower() in VIDEO_EXTS:
                rows.append({"video_path": str(vid), "class_id": class_id})
    df = pd.DataFrame(rows, columns=["video_path", "class_id"])
    if df.empty:
        raise RuntimeError(f"No videos found in {dataset_root}")
    print("Skipped class folders:", skipped)
    return df


# Index every video, then split at the video level so frames of a single
# video can never appear in more than one split.
videos_df = collect_videos(DATA_ROOT)
print("Total videos:", len(videos_df))

# First split: hold out VAL_RATIO + TEST_RATIO of the videos, stratified by class.
train_df, temp_df = train_test_split(
    videos_df,
    test_size=(VAL_RATIO + TEST_RATIO),
    stratify=videos_df["class_id"],
    random_state=RANDOM_SEED,
)
# Second split: divide the held-out videos into val/test in the proportion
# VAL_RATIO : TEST_RATIO (test_size is the test share of temp_df).
val_size = VAL_RATIO / (VAL_RATIO + TEST_RATIO)
val_df, test_df = train_test_split(
    temp_df,
    test_size=(1.0 - val_size),
    stratify=temp_df["class_id"],
    random_state=RANDOM_SEED,
)

print("Train:", len(train_df), "Val:", len(val_df), "Test:", len(test_df))


In [None]:
from collections import Counter


def print_dataset_info(df, title):
    """Print how many videos each class contributes to ``df``, with percentages.

    Args:
        df: DataFrame with an integer ``class_id`` column.
        title: label used in the header line.
    """
    per_class = Counter(df["class_id"].tolist())
    n_total = sum(per_class.values())
    print(f"{title} total videos: {n_total}")
    for cid, n in sorted(per_class.items()):
        label = str(cid) if cid >= len(CLASS_NAMES) else CLASS_NAMES[cid]
        share = (n / n_total * 100) if n_total else 0
        print(f"  {cid} {label}: {n} ({share:.2f}%)")


def plot_class_distribution(df, title):
    """Display a bar chart of video counts per class (ordered by class id).

    Args:
        df: DataFrame with an integer ``class_id`` column.
        title: chart title.
    """
    per_class = df["class_id"].value_counts().sort_index()
    tick_labels = []
    for cid in per_class.index:
        tick_labels.append(CLASS_NAMES[cid] if cid < len(CLASS_NAMES) else str(cid))

    plt.figure(figsize=(8, 4))
    plt.bar(tick_labels, per_class.values)
    plt.title(title)
    plt.xlabel("Class")
    plt.ylabel("Count")
    # Long step names overlap unless rotated and right-aligned.
    plt.xticks(rotation=30, ha="right")
    plt.tight_layout()
    plt.show()


# Summarise every split first, then plot each split's class balance.
_splits = (("Train", train_df), ("Val", val_df), ("Test", test_df))
for _split_name, _split_df in _splits:
    print_dataset_info(_split_df, _split_name)
for _split_name, _split_df in _splits:
    plot_class_distribution(_split_df, f"{_split_name} Class Distribution")


In [None]:
import math
import random
from IPython.display import clear_output


def random_shadow(img):
    """Darken the part of ``img`` that lies left of a random top-to-bottom line.

    The line runs from a random x on the top edge to a random x on the bottom
    edge. Every pixel on or left of it is multiplied by a random factor drawn
    from [0.5, 0.9). Both endpoints now shape the shadow; the former polygon
    placed the second endpoint on its own base edge, so it had no effect. The
    mask is computed directly with NumPy, so no polygon rasterisation (or
    point dtype handling) is involved.

    Args:
        img: uint8 image of shape (H, W) or (H, W, C).

    Returns:
        New uint8 array with the same shape as ``img``.
    """
    h, w = img.shape[:2]
    x_top = np.random.randint(0, w)
    x_bottom = np.random.randint(0, w)
    ys, xs = np.mgrid[0:h, 0:w]
    # x-coordinate of the shadow boundary on each row (linear interpolation).
    boundary = x_top + (x_bottom - x_top) * ys / max(h, 1)
    mask = xs <= boundary
    if img.ndim == 3:
        mask = mask[..., None]  # broadcast over channels
    alpha = np.random.uniform(0.5, 0.9)
    return np.where(mask, (img * alpha).astype(np.uint8), img)


def sample_aug_params():
    """Draw one random set of augmentation parameters from ``AUGMENT_CONFIG``.

    Each enabled option gets a randomly sampled value; disabled options keep
    their identity value (no flip, angle 0, zoom 1, shift 0, ...). Shifts are
    in pixels. ``IMG_SIZE`` is read as (height, width), like the model input
    shape ``(*IMG_SIZE, 3)``, so the horizontal shift ``tx`` scales with the
    width and the vertical shift ``ty`` with the height.

    Returns:
        Dict consumed by ``apply_aug``.
    """
    hflip_enabled = any(
        AUGMENT_CONFIG.get(k, False)
        for k in ("hflip", "mid_flip", "horizontal_flip")
    )
    params = {
        "hflip": hflip_enabled and random.random() < 0.5,
        "angle": 0.0,
        "zoom": 1.0,
        "shear": 0.0,
        "tx": 0,
        "ty": 0,
        "brightness": None,
        "contrast": None,
        "gamma": None,
        "shadow": False,
    }

    if AUGMENT_CONFIG.get("rotation", 0) > 0:
        params["angle"] = random.uniform(-AUGMENT_CONFIG["rotation"], AUGMENT_CONFIG["rotation"])

    if AUGMENT_CONFIG.get("zoom", 0) > 0:
        params["zoom"] = random.uniform(1 - AUGMENT_CONFIG["zoom"], 1 + AUGMENT_CONFIG["zoom"])

    if AUGMENT_CONFIG.get("shear", 0) > 0:
        params["shear"] = random.uniform(-AUGMENT_CONFIG["shear"], AUGMENT_CONFIG["shear"])

    if AUGMENT_CONFIG.get("shift", 0) > 0:
        img_h, img_w = IMG_SIZE
        params["tx"] = int(random.uniform(-AUGMENT_CONFIG["shift"], AUGMENT_CONFIG["shift"]) * img_w)
        params["ty"] = int(random.uniform(-AUGMENT_CONFIG["shift"], AUGMENT_CONFIG["shift"]) * img_h)

    if AUGMENT_CONFIG.get("brightness"):
        params["brightness"] = random.uniform(*AUGMENT_CONFIG["brightness"])

    if AUGMENT_CONFIG.get("contrast"):
        params["contrast"] = random.uniform(*AUGMENT_CONFIG["contrast"])

    if AUGMENT_CONFIG.get("gamma"):
        params["gamma"] = random.uniform(*AUGMENT_CONFIG["gamma"])

    if AUGMENT_CONFIG.get("shadow") and random.random() < 0.5:
        params["shadow"] = True

    return params


def apply_aug(img, params):
    """Apply the augmentations described by ``params`` to ``img``.

    ``params`` is a dict as produced by ``sample_aug_params``; missing keys
    mean "skip this step". Steps run in a fixed order:

    - geometric: flip, rotate, zoom, shift, shear;
    - photometric: brightness, contrast, gamma, shadow.

    Every geometric step keeps the original size. The working size comes from
    ``img`` itself, and OpenCV's ``dsize`` is passed as (width, height), so
    non-square images are handled correctly.

    Args:
        img: uint8 RGB image of shape (H, W, 3).
        params: augmentation parameters.

    Returns:
        Augmented uint8 image of shape (H, W, 3).
    """
    h, w = img.shape[:2]
    dsize = (w, h)  # OpenCV sizes are (width, height)

    if params.get("hflip"):
        img = cv2.flip(img, 1)

    angle = params.get("angle", 0.0)
    if angle:
        M = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)
        img = cv2.warpAffine(img, M, dsize, borderMode=cv2.BORDER_REFLECT)

    zoom = params.get("zoom", 1.0)
    if zoom != 1.0:
        # Resize, then centre-crop (zoom in) or reflect-pad (zoom out) back to (h, w).
        zoomed = cv2.resize(img, (max(1, int(w * zoom)), max(1, int(h * zoom))))
        zh, zw = zoomed.shape[:2]
        if zoom > 1:
            top = (zh - h) // 2
            left = (zw - w) // 2
            img = zoomed[top : top + h, left : left + w]
        else:
            top = (h - zh) // 2
            left = (w - zw) // 2
            img = cv2.copyMakeBorder(
                zoomed,
                top,
                h - zh - top,
                left,
                w - zw - left,
                cv2.BORDER_REFLECT,
            )

    tx, ty = params.get("tx", 0), params.get("ty", 0)
    if tx or ty:
        M = np.float32([[1, 0, tx], [0, 1, ty]])
        img = cv2.warpAffine(img, M, dsize, borderMode=cv2.BORDER_REFLECT)

    shear = params.get("shear", 0.0)
    if shear:
        # Horizontal shear about the top-left origin: x' = x + shear * y.
        M = np.float32([[1, shear, 0], [0, 1, 0]])
        img = cv2.warpAffine(img, M, dsize, borderMode=cv2.BORDER_REFLECT)

    brightness = params.get("brightness")
    if brightness is not None:
        img = np.clip(img.astype(np.float32) * brightness, 0, 255).astype(np.uint8)

    contrast = params.get("contrast")
    if contrast is not None:
        # Scale the distance from mid-grey (128).
        img = np.clip(128 + contrast * (img.astype(np.float32) - 128), 0, 255).astype(np.uint8)

    gamma = params.get("gamma")
    if gamma is not None:
        img = np.clip(((img / 255.0) ** gamma) * 255.0, 0, 255).astype(np.uint8)

    if params.get("shadow"):
        img = random_shadow(img)

    return img


In [None]:
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_v2_preprocess


def _load_random_frame_py(video_path_bytes, label, augment):
    """Decode one random frame of a video and preprocess it for MobileNetV2.

    Runs inside ``tf.py_function``, so arguments may arrive as eager tensors
    and are converted to Python values first. The frame index is drawn
    uniformly from ``CAP_PROP_FRAME_COUNT``. That count is only an estimate
    for many containers, so if seeking there fails the first frame is used
    instead.

    Args:
        video_path_bytes: UTF-8 encoded video path (bytes or bytes tensor).
        label: integer class id.
        augment: whether to apply ``sample_aug_params``/``apply_aug``.

    Returns:
        ``(frame, label)``. ``frame`` is a float32 array of shape
        ``(*IMG_SIZE, 3)`` after ``mobilenet_v2_preprocess``; ``label`` is
        ``np.int32``.

    Raises:
        RuntimeError: if the video reports no frames or none can be decoded.
    """
    if hasattr(video_path_bytes, "numpy"):
        video_path_bytes = video_path_bytes.numpy()
    if hasattr(label, "numpy"):
        label = int(label.numpy())
    if hasattr(augment, "numpy"):
        augment = bool(augment.numpy())
    video_path = video_path_bytes.decode("utf-8")

    cap = cv2.VideoCapture(video_path)
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count <= 0:
            raise RuntimeError(f"No frames in {video_path}")
        target_idx = np.random.randint(0, frame_count)
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_idx)
        ok, frame = cap.read()
        if not ok and target_idx != 0:
            # Seeking near an overestimated end failed; fall back to frame 0.
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ok, frame = cap.read()
    finally:
        cap.release()
    if not ok:
        raise RuntimeError(f"Failed reading frame from {video_path}")

    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # IMG_SIZE is (height, width) to match set_shape/model input; cv2 wants (width, height).
    frame = cv2.resize(frame, (IMG_SIZE[1], IMG_SIZE[0]))
    if augment:
        params = sample_aug_params()
        frame = apply_aug(frame, params)
    frame = frame.astype(np.float32)
    frame = mobilenet_v2_preprocess(frame)
    return frame, np.int32(label)


def _load_random_frame(video_path, label, augment):
    """Graph-side wrapper that runs ``_load_random_frame_py`` via ``tf.py_function``.

    ``tf.py_function`` drops static shape information. Both outputs therefore
    get their shapes restored here, so batching and the model input layer see
    fully defined shapes.
    """
    outputs = tf.py_function(
        _load_random_frame_py,
        inp=[video_path, label, augment],
        Tout=[tf.float32, tf.int32],
    )
    image, target = outputs
    image.set_shape((*IMG_SIZE, 3))
    target.set_shape(())
    return image, target


def make_dataset(df, batch_size=32, shuffle=True, augment=False):
    """Build a batched, prefetched ``tf.data`` pipeline over the videos in ``df``.

    Every time the dataset is iterated, each video contributes one randomly
    chosen (and optionally augmented) frame. Shuffling uses a buffer spanning
    the whole split, seeded with ``RANDOM_SEED`` and reshuffled each epoch.

    Args:
        df: DataFrame with ``video_path`` and ``class_id`` columns.
        batch_size: examples per batch.
        shuffle: whether to shuffle the videos.
        augment: whether frames are augmented.

    Returns:
        A ``tf.data.Dataset`` of ``(frames, labels)`` batches.
    """
    paths = df["video_path"].values
    labels = df["class_id"].values
    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))
    if shuffle:
        dataset = dataset.shuffle(len(df), seed=RANDOM_SEED, reshuffle_each_iteration=True)

    augment_flag = tf.constant(augment)

    def _decode(path, label):
        return _load_random_frame(path, label, augment_flag)

    dataset = dataset.map(_decode, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)


# One tf.data pipeline per split. Only the training split is shuffled and
# (when AUGMENT_TRAIN is set) augmented. Every split decodes one randomly
# chosen frame per video each time it is iterated.
train_ds = make_dataset(train_df, batch_size=BATCH_SIZE, shuffle=True, augment=AUGMENT_TRAIN)
val_ds = make_dataset(val_df, batch_size=BATCH_SIZE, shuffle=False, augment=False)
test_ds = make_dataset(test_df, batch_size=BATCH_SIZE, shuffle=False, augment=False)


## Build the float base model (wrapped for QAT later)


In [None]:
from tensorflow.keras import layers, models


def build_float_model():
    """Build the float MobileNetV2 classifier that is later wrapped for QAT.

    The classification head is attached directly to the backbone graph
    (``base.input`` -> ``base.output``) rather than calling ``base`` as a
    nested sub-model, for two reasons:

    - ``tfmot.quantization.keras.quantize_model`` does not support a Keras
      Model nested inside another Model.
    - Not calling ``base(..., training=True)`` stops BatchNormalization from
      being pinned to training mode during ``evaluate``/``predict``; Keras now
      selects the mode from the fit/inference phase.

    Returns:
        Uncompiled ``tf.keras.Model`` mapping MobileNetV2-preprocessed
        ``(*IMG_SIZE, 3)`` images to ``NUM_CLASSES`` softmax probabilities.
    """
    base = tf.keras.applications.MobileNetV2(
        input_shape=(*IMG_SIZE, 3),
        include_top=False,
        weights="imagenet",
    )
    base.trainable = True
    x = layers.GlobalAveragePooling2D()(base.output)
    x = layers.Dropout(0.3)(x)
    # Keep the probabilities in float32 so softmax stays numerically stable
    # when MIXED_PRECISION switches the global policy to float16.
    outputs = layers.Dense(NUM_CLASSES, activation="softmax", dtype="float32")(x)
    return models.Model(base.input, outputs)


# Build and compile under the strategy scope so model and optimizer variables
# are created for every replica when MirroredStrategy is active.
with STRATEGY.scope():
    float_model = build_float_model()
    float_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_FLOAT),
        # Integer class ids + softmax probabilities -> sparse categorical CE.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=["accuracy"],
    )

float_model.summary()


## Warmup (float) and QAT fine-tune


In [None]:
# Stage 1: short float warm-up (EPOCHS_FLOAT) so the new head adapts before
# quantization is simulated.
history_float = float_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS_FLOAT,
)

# Stage 2: wrap the warmed-up model for quantization-aware training and
# fine-tune at the lower QAT learning rate. quantize_model returns a new
# model; it is compiled inside the strategy scope like the float one.
with STRATEGY.scope():
    quantize_model = tfmot.quantization.keras.quantize_model
    qat_model = quantize_model(float_model)
    qat_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_QAT),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=["accuracy"],
    )

qat_model.summary()

# NOTE(review): no callbacks (checkpointing/early stopping) are attached,
# although CKPT_DIR and LOGS_DIR are created earlier.
history_qat = qat_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS_QAT,
)


## Evaluate


In [None]:
# Held-out evaluation of the QAT Keras model.
# NOTE(review): each test video contributes one randomly chosen frame per
# pass (see _load_random_frame_py), so these numbers vary between runs.
qat_eval = qat_model.evaluate(test_ds, verbose=1)
print("QAT test metrics:", dict(zip(qat_model.metrics_names, qat_eval)))


## Sample inference on test videos


In [None]:
import numpy as np


def iter_video_frames(video_path: str, stride: int = 2, max_frames=24):
    """Read every ``stride``-th frame of a video, as decoded by OpenCV (BGR).

    Args:
        video_path: path to a video readable by ``cv2.VideoCapture``.
        stride: keep frames 0, stride, 2*stride, ...; must be >= 1.
        max_frames: stop after this many kept frames; ``None`` means no limit
            and values <= 0 return no frames.

    Returns:
        List of BGR uint8 frames; empty if nothing could be decoded.

    Raises:
        ValueError: if ``stride`` < 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        idx = 0
        while max_frames is None or len(frames) < max_frames:
            ok, frame = cap.read()
            if not ok:
                break
            if idx % stride == 0:
                frames.append(frame)
            idx += 1
    finally:
        cap.release()
    return frames


def run_sample_inference(model, df, num_videos=3, frame_stride=2, max_frames=12):
    """Classify a few videos from ``df`` and show an annotated frame for each.

    For every sampled video (sampling seeded with ``RANDOM_SEED``):

    1. Up to ``max_frames`` frames, one every ``frame_stride``, are
       preprocessed like the training pipeline.
    2. They are predicted in one batch and reduced to a single label by
       majority vote over the per-frame argmax.
    3. The first frame is annotated with the ground-truth (green) and
       predicted (red) class names, saved as JPEG in ``SAMPLES_DIR`` and
       displayed inline.

    Args:
        model: Keras model mapping preprocessed ``(*IMG_SIZE, 3)`` frames to class scores.
        df: DataFrame with ``video_path`` and ``class_id`` columns.
        num_videos: number of videos to sample (capped at ``len(df)``).
        frame_stride: keep every ``frame_stride``-th frame.
        max_frames: maximum frames classified per video.
    """
    # IMG_SIZE is (height, width) like the model input; cv2.resize wants (width, height).
    dsize = (IMG_SIZE[1], IMG_SIZE[0])
    sample_df = df.sample(min(num_videos, len(df)), random_state=RANDOM_SEED)
    for row in sample_df.itertuples(index=False):
        frames = iter_video_frames(row.video_path, stride=frame_stride, max_frames=max_frames)
        if not frames:
            continue
        inputs = []
        for frame in frames:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_rgb = cv2.resize(frame_rgb, dsize).astype(np.float32)
            inputs.append(tf.keras.applications.mobilenet_v2.preprocess_input(frame_rgb))
        preds = model.predict(np.stack(inputs), verbose=0)
        pred_ids = np.argmax(preds, axis=1)
        majority = int(np.bincount(pred_ids).argmax())

        label_name = CLASS_NAMES[row.class_id]
        pred_name = CLASS_NAMES[majority]
        video_name = Path(row.video_path).name
        print(f"Video: {video_name} | GT: {row.class_id} {label_name} | Pred: {majority} {pred_name}")

        # Annotate a copy of the first frame; OpenCV colours are BGR.
        frame0 = frames[0].copy()
        cv2.putText(frame0, f"GT: {label_name}", (10, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
        cv2.putText(frame0, f"Pred: {pred_name}", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
        out_path = SAMPLES_DIR / f"sample_{Path(row.video_path).stem}.jpg"
        cv2.imwrite(str(out_path), frame0)

        plt.figure(figsize=(4, 3))
        plt.imshow(cv2.cvtColor(frame0, cv2.COLOR_BGR2RGB))
        plt.title(video_name)
        plt.axis("off")
        plt.show()


# Spot-check the QAT model on a few test videos; annotated frames are written to SAMPLES_DIR.
run_sample_inference(qat_model, test_df)


## Export INT8 TFLite


In [None]:
from typing import Iterable

# Persist the QAT Keras model in the native .keras format.
# NOTE(review): reloading presumably needs
# tfmot.quantization.keras.quantize_scope() so the quantization wrappers
# deserialize — confirm.
qat_model_path = MODELS_DIR / "mobilenetv2_qat.keras"
qat_model.save(qat_model_path)
print("Saved QAT model:", qat_model_path)


def representative_dataset_from_ds(ds, max_batches=10) -> Iterable[list[np.ndarray]]:
    """Yield single-example float32 inputs for TFLite quantization calibration.

    Args:
        ds: iterable of ``(inputs, labels)`` pairs, e.g. a batched ``tf.data``
            dataset. ``inputs`` may be a tensor exposing ``.numpy()`` or
            anything ``np.asarray`` accepts.
        max_batches: number of batches drawn from ``ds``. ``None`` uses every
            batch; values <= 0 yield nothing. Batches beyond the limit are
            never fetched.

    Yields:
        One-element lists ``[array]``, where ``array`` has shape
        ``(1, *example_shape)`` — the format
        ``TFLiteConverter.representative_dataset`` expects.
    """
    import itertools  # local: only this helper needs it

    limit = None if max_batches is None else max(0, int(max_batches))
    for batch, _ in itertools.islice(ds, limit):
        raw = batch.numpy() if hasattr(batch, "numpy") else batch
        batch_np = np.asarray(raw, dtype=np.float32)
        for i in range(batch_np.shape[0]):
            yield [batch_np[i : i + 1]]


tflite_path = MODELS_DIR / "mobilenetv2_qat_int8.tflite"
# Remove an artifact from a previous run so that a failed conversion below
# cannot leave a stale model for the verification cell to pick up.
tflite_path.unlink(missing_ok=True)

converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Calibrate on the non-augmented validation split: calibration samples should
# resemble inference inputs, not augmented training frames.
converter.representative_dataset = lambda: representative_dataset_from_ds(val_ds)
# Full-integer model with int8 input and output tensors.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
if DISABLE_PER_CHANNEL:
    # Private converter attribute; may change between TensorFlow versions.
    converter._experimental_disable_per_channel = True

try:
    tflite_model = converter.convert()
except Exception as exc:  # best effort: report and continue with the notebook
    print("INT8 conversion failed:", exc)
else:
    tflite_path.write_bytes(tflite_model)
    print("Saved TFLite:", tflite_path)


In [None]:
# Sanity-check the exported model: load it and print the dtype and
# quantization parameters of the first input and output tensor. Given the
# converter settings, int8 is expected for both.
if tflite_path.exists():
    interpreter = tf.lite.Interpreter(model_path=str(tflite_path))
    interpreter.allocate_tensors()
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]
    print("Input dtype:", input_detail["dtype"], "quant:", input_detail.get("quantization"))
    print("Output dtype:", output_detail["dtype"], "quant:", output_detail.get("quantization"))
