# Handwash quantization + evaluation (MobileNetV2 + LSTM)

This notebook:
- loads the trained MobileNetV2 and LSTM/GRU handwash models
- exports post-training quantized TFLite models (full INT8, with a dynamic-range fallback)
- evaluates the float and quantized models on a subset of the Kaggle WHO6 test split
- saves the TFLite files to `inference/outputs/`
- summarizes Axis Model Zoo compatibility notes


In [None]:
%pip install -q numpy pandas opencv-python tensorflow matplotlib scikit-learn tqdm requests seaborn


In [None]:
import os
import sys
import math
import time
import random
import tarfile
from pathlib import Path
import importlib.util

import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import requests


def find_repo_root(start=None):
    """Locate the repository root by walking upward from *start*.

    A directory counts as the root when it contains ``inference/config.py``
    or ``training/config.py``. The search begins at *start* (the current
    working directory when omitted) and continues through its parents.

    Returns:
        The first matching directory, or the starting directory itself when
        no ancestor contains a config file.
    """
    origin = Path(start) if start is not None else Path.cwd()
    markers = (("inference", "config.py"), ("training", "config.py"))
    for candidate in (origin, *origin.parents):
        if any(candidate.joinpath(*marker).exists() for marker in markers):
            return candidate
    return origin


def _load_module(path: Path, name: str):
    """Import the Python file at *path* as a module called *name*.

    The module is executed but not registered in ``sys.modules``.

    Raises:
        ImportError: if no import spec or loader can be created for *path*.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    # ``spec`` itself may be None; getattr covers both "no spec" and "no loader".
    loader = getattr(spec, "loader", None)
    if loader is None:
        raise ImportError(f"Cannot load module from {path}")
    loaded = importlib.util.module_from_spec(spec)
    loader.exec_module(loaded)
    return loaded


# Resolve the repository layout starting from the current working directory.
REPO_ROOT = find_repo_root()
INFERENCE_DIR = REPO_ROOT / "inference"
TRAINING_DIR = REPO_ROOT / "training"

# Load the project config as `cfg`; inference/config.py takes precedence over
# training/config.py when both exist.
if (INFERENCE_DIR / "config.py").exists():
    cfg = _load_module(INFERENCE_DIR / "config.py", "cfg")
elif (TRAINING_DIR / "config.py").exists():
    cfg = _load_module(TRAINING_DIR / "config.py", "cfg")
else:
    raise FileNotFoundError("config.py not found in inference/ or training/")

# Seed NumPy and the stdlib RNG with the configured seed.
np.random.seed(cfg.RANDOM_SEED)
random.seed(cfg.RANDOM_SEED)

def _seed_tf(seed):
    """Seed TensorFlow's global RNG on a best-effort basis.

    A seeding failure never aborts the notebook, but it is reported so that a
    non-reproducible run does not go unnoticed.
    """
    try:
        tf.random.set_seed(seed)
    except Exception as exc:  # best-effort by design: report, do not raise
        print("Warning: could not seed TensorFlow RNG:", exc)

# Seed TensorFlow with the same configured seed.
_seed_tf(cfg.RANDOM_SEED)

# Dataset and output locations, all derived from the repository root.
RAW_DIR = REPO_ROOT / "datasets" / "raw"
PROCESSED_DIR = REPO_ROOT / "datasets" / "processed"
OUTPUT_DIR = INFERENCE_DIR / "outputs"
# Create the output folder up front so later cells can write into it.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print("Repo root:", REPO_ROOT)
print("Output dir:", OUTPUT_DIR)


## Download Kaggle WHO6 dataset (if needed)


In [None]:
# The download URL comes from the project config. The remaining paths are where
# the archive is stored and where its extracted folder is expected to appear.
KAGGLE_URL = cfg.DATASETS["kaggle"]["url"]
KAGGLE_DIR = RAW_DIR / "kaggle"
KAGGLE_TAR = KAGGLE_DIR / "kaggle-dataset-6classes.tar"
KAGGLE_EXTRACTED = KAGGLE_DIR / "kaggle-dataset-6classes"


def download_with_progress(url, dest: Path):
    """Stream *url* to *dest* while showing a tqdm progress bar.

    Nothing is downloaded when *dest* already exists. Data is first written
    to a sibling ``<name>.part`` file and moved into place only after the
    transfer completed, so an interrupted download cannot be mistaken for a
    finished archive by the existence check on the next run.

    Args:
        url: Address to fetch with ``requests.get``.
        dest: Final location of the downloaded file; parent directories are
            created as needed.

    Raises:
        Whatever ``raise_for_status`` or the transfer itself raises; the
        partial file is removed before the exception propagates.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    if dest.exists():
        print("skip", dest)
        return
    partial = dest.with_name(dest.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            # A missing or zero Content-Length means the size is unknown;
            # pass None so tqdm does not render a bar with a total of 0.
            total = int(r.headers.get("content-length", 0)) or None
            with open(partial, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=dest.name) as pbar:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if not chunk:
                        continue
                    f.write(chunk)
                    pbar.update(len(chunk))
        partial.replace(dest)
    finally:
        # Still present only when the transfer failed before the rename.
        partial.unlink(missing_ok=True)


def extract_tar(tar_path: Path, extract_root: Path):
    """Extract *tar_path* into *extract_root* without letting members escape it.

    The archive is downloaded from the network, so member names are not
    trusted. On interpreters that provide tarfile extraction filters the
    ``"data"`` filter is used; otherwise member paths and link targets are
    validated by hand before extraction.

    Raises:
        tarfile.TarError: if a member (or a link target) would land outside
            *extract_root*.
    """
    extract_root.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path, "r") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=extract_root, filter="data")
            return

        # Fallback for interpreters without extraction filters.
        root = extract_root.resolve()

        def _escapes(candidate):
            resolved = candidate.resolve()
            return resolved != root and root not in resolved.parents

        for member in tar.getmembers():
            target = root / member.name
            if _escapes(target):
                raise tarfile.TarError(f"Unsafe path in archive: {member.name}")
            if member.issym() or member.islnk():
                # Symlink targets are relative to the link's own directory,
                # hard-link targets to the archive root.
                base = target.parent if member.issym() else root
                if _escapes(base / member.linkname):
                    raise tarfile.TarError(f"Unsafe link in archive: {member.name}")
        tar.extractall(path=extract_root)


# Fetch and unpack the dataset only when the extracted folder is missing.
if KAGGLE_EXTRACTED.exists():
    print("Kaggle dataset already extracted:", KAGGLE_EXTRACTED)
else:
    print("Downloading Kaggle WHO6...")
    download_with_progress(KAGGLE_URL, KAGGLE_TAR)
    print("Extracting...")
    extract_tar(KAGGLE_TAR, KAGGLE_DIR)


## Index videos and create splits


In [None]:
# File suffixes (lower-case, with dot) that are treated as videos when indexing.
VIDEO_EXTS = (".mp4", ".avi", ".mov", ".mkv")


def kaggle_class_id_from_folder(name: str) -> int:
    """Translate a Kaggle class-folder name into a class id.

    The lower-cased name is looked up in ``cfg.KAGGLE_CLASS_MAPPING`` first.
    Otherwise all digits in the name are concatenated and accepted when the
    resulting number is a valid index into ``cfg.CLASS_NAMES``.

    Raises:
        ValueError: if neither rule yields a class id.
    """
    key = name.lower()
    mapping = cfg.KAGGLE_CLASS_MAPPING
    if key in mapping:
        return int(mapping[key])
    digit_run = "".join(filter(str.isdigit, key))
    if digit_run:
        candidate = int(digit_run)
        if 0 <= candidate < len(cfg.CLASS_NAMES):
            return candidate
    raise ValueError(f"Unknown Kaggle class folder: {name}")


def collect_videos(dataset_root: Path) -> pd.DataFrame:
    """Index every video file below *dataset_root* into a DataFrame.

    Each sub-directory of *dataset_root* is treated as one class; its name is
    mapped to a class id with ``kaggle_class_id_from_folder``. Only regular
    files whose suffix is listed in ``VIDEO_EXTS`` are kept.

    Returns:
        DataFrame with columns ``video_path`` (str) and ``class_id`` (int).

    Raises:
        ValueError: propagated from ``kaggle_class_id_from_folder`` for an
            unknown class folder.
        RuntimeError: if no video file was found.
    """
    rows = []
    for class_dir in sorted(dataset_root.iterdir()):
        if not class_dir.is_dir():
            continue
        class_id = kaggle_class_id_from_folder(class_dir.name)
        # Sort the files too: directory enumeration order is filesystem
        # dependent, and the seeded train/val/test split is only reproducible
        # when the rows always arrive in the same order.
        rows.extend(
            {"video_path": str(vid), "class_id": class_id}
            for vid in sorted(class_dir.iterdir())
            if vid.is_file() and vid.suffix.lower() in VIDEO_EXTS
        )
    df = pd.DataFrame(rows)
    if df.empty:
        raise RuntimeError(f"No videos found in {dataset_root}")
    return df


# Index the extracted dataset, then split it by video (not by frame) into
# train / val / test with class-stratified, seeded splits.
videos_df = collect_videos(KAGGLE_EXTRACTED)
print("Total videos:", len(videos_df))

# First split: hold out the combined validation + test share.
train_df, temp_df = train_test_split(
    videos_df,
    test_size=(cfg.VAL_RATIO + cfg.TEST_RATIO),
    stratify=videos_df["class_id"],
    random_state=cfg.RANDOM_SEED,
)
# Validation's share of the held-out portion; the remainder becomes the test set.
val_size = cfg.VAL_RATIO / (cfg.VAL_RATIO + cfg.TEST_RATIO)
val_df, test_df = train_test_split(
    temp_df,
    test_size=(1.0 - val_size),
    stratify=temp_df["class_id"],
    random_state=cfg.RANDOM_SEED,
)

print("Train:", len(train_df), "Val:", len(val_df), "Test:", len(test_df))


## Frame and sequence sampling helpers


In [None]:
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_v2_preprocess


class MobileNetPreprocessingLayer(tf.keras.layers.Layer):
    """Keras layer that rescales its input as ``x / 127.5 - 1.0``.

    Registered in ``CUSTOM_OBJECTS`` so saved models that reference a layer of
    this name can be deserialized.
    """

    def call(self, x):
        # Maps 0 -> -1.0 and 255 -> 1.0 (assumes pixel values in [0, 255] — TODO confirm).
        return (x / 127.5) - 1.0


# Names passed as `custom_objects` to `load_model`, so models that reference
# them can be loaded.
CUSTOM_OBJECTS = {
    "MobileNetPreprocessingLayer": MobileNetPreprocessingLayer,
    "preprocess_input": mobilenet_v2_preprocess,
}


def iter_video_frames(video_path: str, stride: int = 2, max_frames=None):
    """Decode a video and return every *stride*-th frame as a list.

    Despite the name, the frames are collected eagerly and returned as a
    list, in decoding order.

    Args:
        video_path: Path of the video file (``str`` or path-like).
        stride: Keep frames whose index is a multiple of this value.
        max_frames: Stop after this many kept frames; ``None`` reads the whole
            video, and a value of 0 or less yields an empty list.

    Returns:
        List of frames exactly as returned by ``cv2.VideoCapture.read``.
    """
    if max_frames is not None and max_frames <= 0:
        # Nothing requested: do not open the file at all.
        return []
    cap = cv2.VideoCapture(str(video_path))
    frames = []
    try:
        idx = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            if idx % stride == 0:
                frames.append(frame)
                if max_frames is not None and len(frames) >= max_frames:
                    break
            idx += 1
    finally:
        # Release the capture even when decoding raises.
        cap.release()
    return frames


def resize_and_preprocess(frame_bgr, size):
    """Convert a BGR frame to a MobileNetV2-preprocessed float32 RGB array.

    The frame is colour-converted with ``cv2.COLOR_BGR2RGB``, resized with
    ``cv2.resize(..., size)``, cast to float32 and finally passed through
    ``mobilenet_v2_preprocess``.
    """
    resized_rgb = cv2.resize(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB), size)
    return mobilenet_v2_preprocess(resized_rgb.astype(np.float32))


def build_frame_dataset(df, size, frame_stride=2, max_frames_per_video=8, max_videos=None):
    """Build a per-frame evaluation set from the videos listed in *df*.

    For each row (up to *max_videos* rows) the video is sampled with
    ``iter_video_frames`` and every sampled frame is preprocessed with
    ``resize_and_preprocess``; each frame is labelled with the row's
    ``class_id``.

    Returns:
        Tuple ``(xs, ys)``: the stacked preprocessed frames and an int64
        label array of the same length.

    Raises:
        RuntimeError: if no frame could be collected.
    """
    samples, labels = [], []
    for video_idx, record in enumerate(df.itertuples(index=False)):
        if max_videos is not None and video_idx >= max_videos:
            break
        clip = iter_video_frames(record.video_path, stride=frame_stride, max_frames=max_frames_per_video)
        processed = [resize_and_preprocess(frame, size) for frame in clip]
        samples.extend(processed)
        labels.extend([record.class_id] * len(processed))
    if not samples:
        raise RuntimeError("No frames collected for evaluation")
    return np.stack(samples), np.array(labels, dtype=np.int64)


def build_sequence_dataset(df, size, sequence_length=16, frame_stride=2, max_sequences_per_video=4, max_videos=None):
    """Build an evaluation set of fixed-length frame sequences.

    For each row of *df* (up to *max_videos* rows) at most
    ``sequence_length * max_sequences_per_video`` frames are sampled. The
    frames are cut into consecutive, non-overlapping windows of
    *sequence_length* frames; videos with fewer frames than one window are
    skipped. Every window is labelled with the row's ``class_id``.

    Returns:
        Tuple ``(xs, ys)``: the stacked sequences and an int64 label array.

    Raises:
        RuntimeError: if no sequence could be collected.
    """
    sequences, labels = [], []
    frame_budget = sequence_length * max_sequences_per_video
    for video_idx, record in enumerate(df.itertuples(index=False)):
        if max_videos is not None and video_idx >= max_videos:
            break
        clip = iter_video_frames(record.video_path, stride=frame_stride, max_frames=frame_budget)
        if len(clip) < sequence_length:
            continue
        window_starts = range(0, len(clip) - sequence_length + 1, sequence_length)
        for taken, begin in enumerate(window_starts, start=1):
            window = clip[begin : begin + sequence_length]
            sequences.append(np.stack([resize_and_preprocess(frame, size) for frame in window]))
            labels.append(record.class_id)
            if taken >= max_sequences_per_video:
                break
    if not sequences:
        raise RuntimeError("No sequences collected for evaluation")
    return np.stack(sequences), np.array(labels, dtype=np.int64)


## MobileNetV2: load model + evaluate float


In [None]:
# File name of the trained frame classifier; located via resolve_model_path().
MOBILENET_MODEL_NAME = "mobilenetv2_final.keras"


def resolve_model_path(name: str) -> Path:
    """Find the model file *name* in the known locations.

    Search order: ``inference/``, ``<repo>/models/``, the repository root and
    finally every match below ``<repo>/Runs`` (most recently modified first).

    Returns:
        The first existing candidate, or ``Path(name)`` when none exists.
    """
    search_order = [
        INFERENCE_DIR / name,
        REPO_ROOT / "models" / name,
        REPO_ROOT / name,
    ]
    runs_dir = REPO_ROOT / "Runs"
    if runs_dir.exists():
        run_hits = list(runs_dir.rglob(name))
        run_hits.sort(key=lambda hit: hit.stat().st_mtime, reverse=True)
        search_order += run_hits
    return next((path for path in search_order if path.exists()), Path(name))


mobilenet_path = resolve_model_path(MOBILENET_MODEL_NAME)
print("MobileNet model path:", mobilenet_path)

# NOTE(review): safe_mode=False relaxes Keras' deserialization safety checks —
# only load model files from a trusted source.
mobilenet_model = tf.keras.models.load_model(
    mobilenet_path,
    custom_objects=CUSTOM_OBJECTS,
    compile=False,
    safe_mode=False,
)

# Multi-input models report a list of shapes; use the first input.
input_shape = mobilenet_model.input_shape
if isinstance(input_shape, list):
    input_shape = input_shape[0]

# Use the configured image size when the model leaves its spatial dims dynamic.
if input_shape[1] is None or input_shape[2] is None:
    input_size = cfg.IMG_SIZE
else:
    # Axes 2 and 1 are swapped for cv2.resize; presumably (width, height)
    # from a (batch, height, width, channels) input — TODO confirm.
    input_size = (input_shape[2], input_shape[1])
print("MobileNet input size:", input_size)


In [None]:
# Sampling limits for the frame-level evaluation set.
FRAME_STRIDE = 2
MAX_FRAMES_PER_VIDEO = 8
MAX_TEST_VIDEOS = 60

# Build the frame dataset from the first MAX_TEST_VIDEOS rows of the test split.
X_frames, y_frames = build_frame_dataset(
    test_df,
    size=input_size,
    frame_stride=FRAME_STRIDE,
    max_frames_per_video=MAX_FRAMES_PER_VIDEO,
    max_videos=MAX_TEST_VIDEOS,
)

print("Frame samples:", X_frames.shape)


In [None]:
from sklearn.metrics import accuracy_score


def evaluate_keras(model, xs, ys, batch_size=32):
    """Measure accuracy and average latency of a Keras model.

    The wall-clock time of a single ``model.predict`` call over *xs* is
    divided by ``len(ys)`` to obtain the per-sample latency.

    Returns:
        Dict with ``accuracy`` and ``ms_per_sample``.
    """
    t0 = time.perf_counter()
    scores = model.predict(xs, batch_size=batch_size, verbose=0)
    seconds = time.perf_counter() - t0
    predicted = np.argmax(scores, axis=1)
    return {
        "accuracy": accuracy_score(ys, predicted),
        "ms_per_sample": (seconds / len(ys)) * 1000.0,
    }


# Baseline: the float Keras model on the sampled test frames.
mobilenet_float_metrics = evaluate_keras(mobilenet_model, X_frames, y_frames)
print("MobileNet float metrics:", mobilenet_float_metrics)


## MobileNetV2: post-training int8 quantization

Set `DISABLE_PER_CHANNEL=True` to use ARTPEC-8-style per-tensor quantization instead of the default per-channel quantization.


In [None]:
from typing import Iterable

# When True, the converter is asked to disable per-channel quantization.
DISABLE_PER_CHANNEL = False


def representative_dataset(xs, max_samples=200) -> Iterable[list[np.ndarray]]:
    """Yield calibration samples for post-training quantization.

    At most *max_samples* items are taken from the front of *xs*; each one is
    yielded as a single-element list holding a float32 batch of size one.
    """
    limit = min(len(xs), max_samples)
    yield from ([xs[pos : pos + 1].astype(np.float32)] for pos in range(limit))


def convert_int8_tflite(model, xs, output_path: Path, allow_fallback=False, disable_per_channel=False):
    """Convert a Keras model to TFLite with post-training quantization.

    Full-integer quantization (int8 inputs and outputs, calibrated with
    ``representative_dataset(xs)``) is attempted first. When that fails and
    *allow_fallback* is set, a dynamic-range model is produced instead and
    written to the same *output_path*.

    Returns:
        Tuple ``(output_path, kind)`` where *kind* is ``"int8"`` or
        ``"dynamic"``.

    Raises:
        Exception: whatever the int8 conversion raised, when *allow_fallback*
            is false.
    """
    full_int8 = tf.lite.TFLiteConverter.from_keras_model(model)
    full_int8.optimizations = [tf.lite.Optimize.DEFAULT]
    full_int8.representative_dataset = lambda: representative_dataset(xs)
    full_int8.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    full_int8.inference_input_type = full_int8.inference_output_type = tf.int8
    if disable_per_channel:
        full_int8._experimental_disable_per_channel = True
    try:
        output_path.write_bytes(full_int8.convert())
        return output_path, "int8"
    except Exception as exc:
        if not allow_fallback:
            raise
        print("Int8 conversion failed, falling back to dynamic range:", exc)
        # Fresh converter with default optimizations only.
        dynamic_range = tf.lite.TFLiteConverter.from_keras_model(model)
        dynamic_range.optimizations = [tf.lite.Optimize.DEFAULT]
        output_path.write_bytes(dynamic_range.convert())
        return output_path, "dynamic"


# Export the MobileNetV2 model, using the evaluation frames for calibration.
# With allow_fallback=True the file may hold a dynamic-range model instead;
# `mobile_quant_type` records which variant was written.
mobilenet_tflite_path = OUTPUT_DIR / "mobilenetv2_int8.tflite"
mobile_path, mobile_quant_type = convert_int8_tflite(
    mobilenet_model,
    X_frames,
    mobilenet_tflite_path,
    allow_fallback=True,
    disable_per_channel=DISABLE_PER_CHANNEL,
)
print("Saved TFLite:", mobile_path, "quant:", mobile_quant_type)


In [None]:
import numpy as np


def _quantize_input(x, input_detail):
    """Convert a float array into the dtype expected by a TFLite input.

    Args:
        x: Input array in real (float) units.
        input_detail: One entry of ``interpreter.get_input_details()``; its
            ``dtype`` and ``quantization`` (scale, zero point) are used.

    Returns:
        ``x`` as float32 for float inputs. For quantized inputs
        ``round(x / scale + zero_point)``, saturated to the value range of the
        target integer dtype. When the scale is 0 (no quantization
        parameters), ``x`` is simply cast to the target dtype.
    """
    dtype = input_detail["dtype"]
    if dtype == np.float32:
        return x.astype(np.float32)
    scale, zero_point = input_detail.get("quantization", (0.0, 0))
    if scale == 0:
        return x.astype(dtype)
    q = x / scale + zero_point
    if np.issubdtype(dtype, np.integer):
        # Round and saturate for every integer type (not only int8/uint8), so
        # out-of-range values clamp instead of wrapping around on the cast.
        limits = np.iinfo(dtype)
        return np.clip(np.round(q), limits.min, limits.max).astype(dtype)
    return q.astype(dtype)


def _dequantize_output(y, output_detail):
    """Convert a TFLite output tensor back to float32 real units.

    Float outputs, and outputs whose quantization scale is 0, are only cast
    to float32. Otherwise the result is ``(y - zero_point) * scale``.
    """
    is_float = output_detail["dtype"] == np.float32
    # A scale of 0 doubles as the "nothing to undo" marker below.
    scale, zero_point = (0.0, 0) if is_float else output_detail.get("quantization", (0.0, 0))
    as_float = y.astype(np.float32)
    if scale == 0:
        return as_float
    return (as_float - zero_point) * scale


def evaluate_tflite(tflite_path: Path, xs, ys):
    """Measure accuracy and average latency of a TFLite model.

    Samples are fed one at a time through a ``tf.lite.Interpreter``. Inputs
    are quantized with ``_quantize_input`` and outputs converted back with
    ``_dequantize_output``; both conversions are included in the timed span.

    Returns:
        Dict with ``accuracy`` and ``ms_per_sample``.
    """
    interpreter = tf.lite.Interpreter(model_path=str(tflite_path))
    interpreter.allocate_tensors()
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]

    def _run_one(sample):
        # One full round trip: quantize, invoke, read back, dequantize.
        interpreter.set_tensor(input_detail["index"], _quantize_input(sample, input_detail))
        interpreter.invoke()
        return _dequantize_output(interpreter.get_tensor(output_detail["index"]), output_detail)

    t0 = time.perf_counter()
    outputs = [_run_one(xs[pos : pos + 1]) for pos in range(len(xs))]
    seconds = time.perf_counter() - t0

    predicted = np.argmax(np.concatenate(outputs, axis=0), axis=1)
    return {
        "accuracy": accuracy_score(ys, predicted),
        "ms_per_sample": (seconds / len(ys)) * 1000.0,
    }


def summarize_tflite(path: Path):
    """Print the input/output dtype and quantization parameters of a TFLite file.

    Afterwards the TFLite ``Analyzer`` is run on the file; any failure to
    import or run it is reported instead of raised.
    """
    model_file = str(path)
    interpreter = tf.lite.Interpreter(model_path=model_file)
    interpreter.allocate_tensors()
    first_input = interpreter.get_input_details()[0]
    first_output = interpreter.get_output_details()[0]
    print("Input dtype:", first_input["dtype"], "quant:", first_input.get("quantization"))
    print("Output dtype:", first_output["dtype"], "quant:", first_output.get("quantization"))
    try:
        from tensorflow.lite.experimental import Analyzer
        report = Analyzer.analyze(model_path=model_file, show_details=False)
        print(report)
    except Exception as exc:
        print("Analyzer unavailable:", exc)


# Evaluate the exported model on the same frames as the float baseline.
mobilenet_tflite_metrics = evaluate_tflite(mobilenet_tflite_path, X_frames, y_frames)
print("MobileNet TFLite metrics:", mobilenet_tflite_metrics)

summarize_tflite(mobilenet_tflite_path)


## Axis Model Zoo compatibility notes

Summary from Axis Model Zoo documentation and model lists:

- The Axis Model Zoo repository includes multiple quantized TFLite models (for example, MobileNetV2, SSD MobileNet, and QAT SSDLite MobileDet) targeting ARTPEC-7/8/9 chips, which indicates that INT8 TFLite is supported on those cameras.
- The README lists performance benchmarks per chip and links to TFLite models (many are quantized or QAT), implying the expected on-device format.
- The larod-client workflow described in the repo uses .tflite (and some .bin) files and selects chip targets such as A9-DLPU, A8-DLPU, A7-GPU, or CPU, which is how Axis validates model compatibility.

Links used:
- https://github.com/AxisCommunications/axis-model-zoo
- https://github.com/AxisCommunications/axis-model-zoo/blob/main/README.md
- https://github.com/AxisCommunications/axis-model-zoo/tree/main/scripts/auto-test-framework/larod-test

Practical implication for this project:
- An INT8 TFLite model composed of built-in ops has the best chance of running on Axis cameras.
- If conversion requires SELECT_TF_OPS, the model may not be compatible with the LAROD runtime on camera. Validate with larod-client on the target chip.


## LSTM/GRU: load model + evaluate float


In [None]:
# File name of the trained sequence model; located via resolve_model_path().
LSTM_MODEL_NAME = "lstm_final.keras"

lstm_path = resolve_model_path(LSTM_MODEL_NAME)
print("LSTM model path:", lstm_path)

# NOTE(review): safe_mode=False relaxes Keras' deserialization safety checks —
# only load model files from a trusted source.
lstm_model = tf.keras.models.load_model(
    lstm_path,
    custom_objects=CUSTOM_OBJECTS,
    compile=False,
    safe_mode=False,
)

# Multi-input models report a list of shapes; use the first input.
lstm_input_shape = lstm_model.input_shape
if isinstance(lstm_input_shape, list):
    lstm_input_shape = lstm_input_shape[0]

# The sequence model must take a rank-5 input (batch + sequence + 3 image axes).
if len(lstm_input_shape) < 5:
    raise ValueError(f"Expected sequence input shape, got {lstm_input_shape}")

# Use the configured length when the model leaves the sequence axis dynamic.
sequence_length = lstm_input_shape[1] or cfg.SEQUENCE_LENGTH

# Use the configured image size when the spatial dims are dynamic.
if lstm_input_shape[2] is None or lstm_input_shape[3] is None:
    seq_input_size = cfg.IMG_SIZE
else:
    # Axes 3 and 2 are swapped for cv2.resize; presumably (width, height) — TODO confirm.
    seq_input_size = (lstm_input_shape[3], lstm_input_shape[2])

print("LSTM sequence length:", sequence_length)
print("LSTM input size:", seq_input_size)


In [None]:
# Sampling limits for the sequence-level evaluation set.
SEQ_FRAME_STRIDE = 2
MAX_SEQUENCES_PER_VIDEO = 3
MAX_TEST_VIDEOS_LSTM = 40

# Build the sequence dataset from the first MAX_TEST_VIDEOS_LSTM rows of the test split.
X_seq, y_seq = build_sequence_dataset(
    test_df,
    size=seq_input_size,
    sequence_length=sequence_length,
    frame_stride=SEQ_FRAME_STRIDE,
    max_sequences_per_video=MAX_SEQUENCES_PER_VIDEO,
    max_videos=MAX_TEST_VIDEOS_LSTM,
)

print("Sequence samples:", X_seq.shape)


In [None]:
# Baseline: the float sequence model, evaluated with a batch size of 8.
lstm_float_metrics = evaluate_keras(lstm_model, X_seq, y_seq, batch_size=8)
print("LSTM float metrics:", lstm_float_metrics)


## LSTM/GRU: post-training quantization

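Note: Sequence models often rely on ops that are not supported on the DLPU.
If full INT8 conversion fails, the notebook falls back to dynamic-range quantization. The fallback model is still written to `lstm_int8.tflite`, so check the printed `quant` type to see which variant was actually produced.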


In [None]:
# With allow_fallback=True the file may hold a dynamic-range model instead;
# `lstm_quant_type` records which variant was written.
lstm_tflite_path = OUTPUT_DIR / "lstm_int8.tflite"

# Honor the notebook-wide DISABLE_PER_CHANNEL switch here as well, so both
# exported models use the same quantization granularity.
lstm_path_out, lstm_quant_type = convert_int8_tflite(
    lstm_model,
    X_seq,
    lstm_tflite_path,
    allow_fallback=True,
    disable_per_channel=DISABLE_PER_CHANNEL,
)
print("Saved LSTM TFLite:", lstm_path_out, "quant:", lstm_quant_type)


In [None]:
# Evaluate the exported sequence model on the same sequences as the float baseline.
lstm_tflite_metrics = evaluate_tflite(lstm_tflite_path, X_seq, y_seq)
print("LSTM TFLite metrics:", lstm_tflite_metrics)

summarize_tflite(lstm_tflite_path)


## Comparison summary


In [None]:
# One row per (model, variant); the TFLite rows carry the quantization type
# that the conversion actually produced ("int8" or "dynamic").
summary = pd.DataFrame([
    {"model": "MobileNetV2", "type": "float", **mobilenet_float_metrics},
    {"model": "MobileNetV2", "type": f"tflite_{mobile_quant_type}", **mobilenet_tflite_metrics},
    {"model": "LSTM/GRU", "type": "float", **lstm_float_metrics},
    {"model": "LSTM/GRU", "type": f"tflite_{lstm_quant_type}", **lstm_tflite_metrics},
])
# Bare expression so the notebook renders the table as the cell output.
summary
