# Handwash inference + post-training scaling (Kaggle)


In [None]:
# Kaggle runtimes usually include the required packages.
# If something is missing, uncomment and run:
# !pip -q install opencv-python-headless tqdm requests


In [None]:
import os
import sys
import math
import random
from pathlib import Path
import importlib.util
from types import SimpleNamespace

import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from IPython.display import Video, display

# True when the KAGGLE_KERNEL_RUN_TYPE environment variable is set to a
# non-empty value; used below to switch between Kaggle and local paths.
IS_KAGGLE = bool(os.environ.get("KAGGLE_KERNEL_RUN_TYPE"))

def find_repo_root(start=None):
    """Locate the directory that holds ``inference/config.py`` or ``training/config.py``.

    Search order:
      1. ``start`` (defaults to the current working directory) and each of
         its parents, nearest first.
      2. When running on Kaggle, the immediate children of ``/kaggle/input``.

    Returns the first matching directory, or ``start`` itself when nothing
    matches.
    """

    def _has_config(folder):
        # "inference" is checked before "training", as in the lookup order above.
        return any(
            (folder / package / "config.py").exists()
            for package in ("inference", "training")
        )

    origin = Path(start) if start is not None else Path.cwd()
    for folder in (origin, *origin.parents):
        if _has_config(folder):
            return folder

    kaggle_inputs = Path("/kaggle/input")
    if IS_KAGGLE and kaggle_inputs.exists():
        for dataset_dir in kaggle_inputs.iterdir():
            if _has_config(dataset_dir):
                return dataset_dir
    return origin


def _load_module(path: Path, name: str):
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def load_cfg(repo_root: Path):
    """Return the project configuration object.

    ``<repo_root>/inference/config.py`` is preferred, then
    ``<repo_root>/training/config.py``; whichever exists is loaded as a
    module named ``cfg``. When neither file exists (standalone Kaggle
    notebook), a ``SimpleNamespace`` with built-in defaults is returned.
    """
    for package in ("inference", "training"):
        config_file = repo_root / package / "config.py"
        if config_file.exists():
            return _load_module(config_file, "cfg")

    # Fallback for standalone Kaggle notebook
    augmentation = {
        "rotation_range": 15,
        "width_shift_range": 0.1,
        "height_shift_range": 0.1,
        "shear_range": 0.1,
        "zoom_range": 0.1,
        "horizontal_flip": True,
        "mid_flip": True,
        "brightness_range": (0.9, 1.1),
        "reverse_sequence": True,
        "fill_mode": "nearest",
    }
    class_names = [
        "Other",
        "Step1_PalmToPalm",
        "Step2_PalmOverDorsum",
        "Step3_InterlacedFingers",
        "Step4_BackOfFingers",
        "Step5_ThumbRub",
        "Step6_Fingertips",
    ]
    # Folder-name -> class-id lookup: "0".."6", then "step1".."step6", then "other".
    folder_mapping = {str(idx): idx for idx in range(7)}
    folder_mapping.update({f"step{idx}": idx for idx in range(1, 7)})
    folder_mapping["other"] = 0

    return SimpleNamespace(
        RANDOM_SEED=42,
        IMG_SIZE=(224, 224),
        VAL_RATIO=0.15,
        TEST_RATIO=0.15,
        ENABLE_SHADOW_AUG=True,
        AUGMENTATION_CONFIG=augmentation,
        CLASS_NAMES=class_names,
        KAGGLE_CLASS_MAPPING=folder_mapping,
        SEQUENCE_LENGTH=16,
        BATCH_SIZE=32,
    )


# Resolve the project root and load its configuration (or the built-in fallback).
REPO_ROOT = find_repo_root()
cfg = load_cfg(REPO_ROOT)

# Seed NumPy's and Python's global RNGs from the configured seed.
np.random.seed(cfg.RANDOM_SEED)
random.seed(cfg.RANDOM_SEED)

# On Kaggle everything is written under /kaggle/working; otherwise under the repo root.
WORK_DIR = Path("/kaggle/working") if IS_KAGGLE else REPO_ROOT
DATA_ROOT = WORK_DIR / "handwash_data"
RAW_DIR = DATA_ROOT / "raw"
PROCESSED_DIR = DATA_ROOT / "processed"
RAW_DIR.mkdir(parents=True, exist_ok=True)
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)

# Directory for artifacts produced by this notebook.
OUTPUT_DIR = WORK_DIR / "handwash_outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# User config
MODEL_NAME = "lstm_final.keras"  # change to "mobilenetv2_final.keras" if needed
MODEL_PATH = None  # set to explicit path if needed


def resolve_model_path(model_name, explicit_path=None):
    """Find the model file to load.

    Args:
        model_name: File name searched for recursively under the search roots.
        explicit_path: Optional path that wins whenever it exists.

    Returns:
        ``Path(explicit_path)`` when it exists; otherwise the most recently
        modified match for ``model_name`` under WORK_DIR, REPO_ROOT,
        REPO_ROOT/models, REPO_ROOT/Runs (plus /kaggle/input on Kaggle);
        otherwise ``Path(model_name)``, which may not exist.
    """
    if explicit_path:
        path = Path(explicit_path)
        if path.exists():
            return path
        # Do not fall back silently: a typo here would otherwise load an
        # unrelated model without any hint.
        print(f"Warning: explicit model path {path} does not exist; searching for {model_name} instead.")

    search_roots = [WORK_DIR, REPO_ROOT, REPO_ROOT / "models", REPO_ROOT / "Runs"]
    if IS_KAGGLE:
        search_roots.append(Path("/kaggle/input"))

    # The roots overlap (e.g. REPO_ROOT contains REPO_ROOT/models), so collect
    # matches in a set to avoid stat-ing the same file several times.
    candidates = {
        found
        for base in search_roots
        if base.exists()
        for found in base.rglob(model_name)
    }
    if candidates:
        # Newest file wins; the path string makes ties deterministic.
        return max(candidates, key=lambda p: (p.stat().st_mtime, str(p)))
    return Path(model_name)


# Replace the user-supplied MODEL_PATH (possibly None) with the resolved path.
MODEL_PATH = resolve_model_path(MODEL_NAME, MODEL_PATH)

# Echo the effective environment so the notebook output records which paths were used.
print("Kaggle:", IS_KAGGLE)
print("Repo root:", REPO_ROOT)
print("Data root:", DATA_ROOT)
print("Output dir:", OUTPUT_DIR)
print("Model name:", MODEL_NAME)
print("Model path:", MODEL_PATH)
# resolve_model_path can return a non-existent path; warn here rather than fail.
if not MODEL_PATH.exists():
    print("Warning: model path does not exist. Set MODEL_PATH or upload the model to /kaggle/input.")


In [None]:
def _sample_aug_params() -> dict:
    """Draw one random set of augmentation parameters from ``cfg.AUGMENTATION_CONFIG``.

    Returns a dict with the keys ``hflip``, ``angle``, ``zoom``, ``shear``,
    ``tx``, ``ty``, ``brightness``, ``contrast``, ``gamma``, ``shadow`` and
    ``reverse_sequence``. Transforms that are not configured keep their
    neutral value (``0``/``1.0``/``None``/``False``).

    Randomness comes from NumPy's global RNG (``np.random``).
    """
    aug_cfg = cfg.AUGMENTATION_CONFIG

    hflip_enabled = any(
        aug_cfg.get(key, False)
        for key in ("horizontal_flip", "mid_flip", "hflip")
    )
    params = {
        "hflip": hflip_enabled and np.random.rand() > 0.5,
        "angle": 0.0,
        "zoom": 1.0,
        "shear": 0.0,
        "tx": 0,
        "ty": 0,
        "brightness": None,
        "contrast": None,
        "gamma": None,
        "shadow": False,
        "reverse_sequence": aug_cfg.get("reverse_sequence", False) and np.random.rand() > 0.5,
    }

    rotation = aug_cfg.get("rotation_range", 0)
    if rotation > 0:
        params["angle"] = np.random.uniform(-rotation, rotation)

    zoom = aug_cfg.get("zoom_range", 0)
    if zoom > 0:
        params["zoom"] = np.random.uniform(1 - zoom, 1 + zoom)

    shear = aug_cfg.get("shear_range", 0)
    if shear > 0:
        params["shear"] = np.random.uniform(-shear, shear)

    # Read both shift ranges with a default of 0: a config that defines only
    # one of them used to raise KeyError on the missing key.
    width_shift = aug_cfg.get("width_shift_range", 0)
    height_shift = aug_cfg.get("height_shift_range", 0)
    if width_shift > 0 or height_shift > 0:
        params["tx"] = int(np.random.uniform(-width_shift, width_shift) * cfg.IMG_SIZE[0])
        params["ty"] = int(np.random.uniform(-height_shift, height_shift) * cfg.IMG_SIZE[1])

    # A range that is present but empty/None is treated as "disabled", the
    # same truthiness test build_aug_variants applies to these keys.
    for name in ("brightness", "contrast", "gamma"):
        value_range = aug_cfg.get(f"{name}_range")
        if value_range:
            params[name] = np.random.uniform(*value_range)

    if cfg.ENABLE_SHADOW_AUG and np.random.rand() < 0.5:
        params["shadow"] = True

    return params


def _apply_aug(img: np.ndarray, params: dict) -> np.ndarray:
    if params.get("hflip"):
        img = cv2.flip(img, 1)

    angle = params.get("angle", 0.0)
    if angle:
        h, w = img.shape[:2]
        M = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)
        img = cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)

    zoom = params.get("zoom", 1.0)
    if zoom != 1.0:
        h, w = img.shape[:2]
        new_h, new_w = int(h * zoom), int(w * zoom)
        img_resized = cv2.resize(img, (new_w, new_h))
        if zoom > 1:
            start_y = (new_h - h) // 2
            start_x = (new_w - w) // 2
            img = img_resized[start_y:start_y + h, start_x:start_x + w]
        else:
            pad_h = (h - new_h) // 2
            pad_w = (w - new_w) // 2
            img = cv2.copyMakeBorder(
                img_resized,
                pad_h, h - new_h - pad_h,
                pad_w, w - new_w - pad_w,
                cv2.BORDER_REFLECT,
            )

    tx, ty = params.get("tx", 0), params.get("ty", 0)
    if tx or ty:
        h, w = img.shape[:2]
        M = np.float32([[1, 0, tx], [0, 1, ty]])
        img = cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)

    shear = params.get("shear", 0.0)
    if shear:
        h, w = img.shape[:2]
        M = np.float32([[1, shear, 0], [0, 1, 0]])
        img = cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)

    brightness = params.get("brightness")
    if brightness is not None:
        img = np.clip(img.astype(np.float32) * brightness, 0, 255).astype(np.uint8)

    contrast = params.get("contrast")
    if contrast is not None:
        img = np.clip(128 + contrast * (img.astype(np.float32) - 128), 0, 255).astype(np.uint8)

    gamma = params.get("gamma")
    if gamma is not None:
        img = np.clip(((img.astype(np.float32) / 255.0) ** gamma) * 255.0, 0, 255).astype(np.uint8)

    if params.get("shadow"):
        h, w = img.shape[:2]
        x1, y1 = np.random.randint(0, w), 0
        x2, y2 = np.random.randint(0, w), h
        mask = np.zeros((h, w), dtype=np.uint8)
        cv2.fillPoly(mask, [np.array([[x1, y1], [x2, y2], [0, h], [w, h]])], 255)
        shadow_intensity = np.random.uniform(0.5, 0.9)
        shadow = np.stack([mask] * 3, axis=-1)
        img = np.where(shadow > 0, (img * shadow_intensity).astype(np.uint8), img)

    return img


In [None]:
import tarfile
import requests
from tqdm import tqdm

# Source archive of the 6-class handwash dataset and where it is stored locally.
KAGGLE_URL = "https://github.com/atiselsts/data/raw/master/kaggle-dataset-6classes.tar"
KAGGLE_DIR = RAW_DIR / "kaggle"
# Directory ensure_kaggle_dataset returns after extracting the archive into KAGGLE_DIR.
KAGGLE_EXTRACTED = KAGGLE_DIR / "kaggle-dataset-6classes"


def download_with_progress(url, dest):
    """Stream ``url`` to the file ``dest`` while showing a tqdm progress bar.

    If ``dest`` already exists the download is skipped. Data is first written
    to ``<dest>.part`` and only renamed to ``dest`` once the transfer
    completes, so an interrupted download never leaves a truncated file that
    a later run would mistake for a finished one.

    Args:
        url: HTTP(S) URL to fetch.
        dest: ``pathlib.Path`` of the target file; parent directories are created.

    Raises:
        requests.HTTPError: on a non-success HTTP status (via ``raise_for_status``).
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    if dest.exists():
        print("skip", dest)
        return

    partial = dest.with_name(dest.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            # A missing content-length header yields total=0.
            total = int(r.headers.get("content-length", 0))
            with open(partial, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=dest.name) as pbar:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))
        partial.replace(dest)
    finally:
        # After a successful rename the temp file is gone; on any failure
        # (including a notebook interrupt) remove the incomplete data.
        if partial.exists():
            partial.unlink()


def extract_tar(tar_path, out_dir):
    """Extract the archive ``tar_path`` into ``out_dir`` and delete the archive.

    Args:
        tar_path: ``pathlib.Path`` of the tar archive.
        out_dir: ``pathlib.Path`` of the destination directory (created if missing).
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path) as tfp:
        # The archive comes from the network. Where the interpreter supports
        # extraction filters, use the "data" filter so members cannot escape
        # ``out_dir`` (absolute paths, "..", links pointing outside).
        if hasattr(tarfile, "data_filter"):
            tfp.extractall(out_dir, filter="data")
        else:
            tfp.extractall(out_dir)
    try:
        tar_path.unlink()
    except FileNotFoundError:
        # The archive is already gone; nothing left to clean up.
        pass


def _find_kaggle_input_dataset():
    """Return the first ``kaggle-dataset-6classes`` directory under /kaggle/input.

    Returns ``None`` when not running on Kaggle, when /kaggle/input does not
    exist, or when no such directory is found.
    """
    input_root = Path("/kaggle/input")
    if IS_KAGGLE and input_root.exists():
        matching_dirs = (
            hit for hit in input_root.rglob("kaggle-dataset-6classes") if hit.is_dir()
        )
        return next(matching_dirs, None)
    return None


def ensure_kaggle_dataset():
    """Return a directory containing the dataset, downloading it if required.

    Preference order: a dataset attached as Kaggle input, then a non-empty
    previously extracted copy at ``KAGGLE_EXTRACTED``, then a fresh download
    of ``KAGGLE_URL`` extracted into ``KAGGLE_DIR``.
    """
    attached = _find_kaggle_input_dataset()
    if attached:
        print("Using Kaggle input dataset:", attached)
        return attached

    already_extracted = KAGGLE_EXTRACTED.exists() and any(KAGGLE_EXTRACTED.iterdir())
    if already_extracted:
        print("Kaggle dataset already present:", KAGGLE_EXTRACTED)
        return KAGGLE_EXTRACTED

    KAGGLE_DIR.mkdir(parents=True, exist_ok=True)
    archive = KAGGLE_DIR / "kaggle-dataset-6classes.tar"
    download_with_progress(KAGGLE_URL, archive)
    print("Extracting kaggle dataset...")
    extract_tar(archive, KAGGLE_DIR)
    return KAGGLE_EXTRACTED


# Directory holding the dataset; this call may download and extract the archive.
kaggle_root = ensure_kaggle_dataset()


In [None]:
# Lower-case file suffixes that scan_kaggle_videos treats as video files.
VIDEO_EXTS = (".mp4", ".avi", ".mov", ".mkv")

def kaggle_class_id_from_folder(name):
    """Map a dataset folder name to a class id.

    Resolution order:
      1. The lower-cased name is looked up in ``cfg.KAGGLE_CLASS_MAPPING``.
      2. Otherwise all decimal digits in the name are concatenated and used
         as the id, provided it indexes ``cfg.CLASS_NAMES``.
      3. Otherwise ``0`` is returned.
    """
    name_lower = name.lower()
    if name_lower in cfg.KAGGLE_CLASS_MAPPING:
        return int(cfg.KAGGLE_CLASS_MAPPING[name_lower])
    # str.isdecimal instead of str.isdigit: isdigit also accepts characters
    # such as superscript digits, which int() rejects with ValueError.
    digits = "".join(ch for ch in name_lower if ch.isdecimal())
    if digits:
        class_id = int(digits)
        if 0 <= class_id < len(cfg.CLASS_NAMES):
            return class_id
    return 0

def scan_kaggle_videos(root):
    """Index the videos stored one level below ``root``, one folder per class.

    Each sub-directory of ``root`` is mapped to a class id through
    ``kaggle_class_id_from_folder``; files whose suffix is in ``VIDEO_EXTS``
    become rows. Returns a DataFrame with the columns ``video_path``,
    ``class_id`` and ``class_name`` (no columns at all when nothing is found).
    """
    rows = []
    class_dirs = (entry for entry in sorted(root.iterdir()) if entry.is_dir())
    for folder in class_dirs:
        label_id = kaggle_class_id_from_folder(folder.name)
        label_name = cfg.CLASS_NAMES[label_id]
        rows.extend(
            {
                "video_path": str(clip),
                "class_id": label_id,
                "class_name": label_name,
            }
            for clip in sorted(folder.iterdir())
            if clip.suffix.lower() in VIDEO_EXTS
        )
    return pd.DataFrame(rows)

# Build the video index and stop early when the dataset folder has no videos.
videos_df = scan_kaggle_videos(kaggle_root)
if videos_df.empty:
    raise RuntimeError(f"No videos found under {kaggle_root}")
display(videos_df.head())


In [None]:
# Two-stage stratified split at video level: first carve off the combined
# val+test share, then divide that remainder between val and test.
train_df, temp_df = train_test_split(
    videos_df,
    test_size=(cfg.VAL_RATIO + cfg.TEST_RATIO),
    stratify=videos_df["class_id"],
    random_state=cfg.RANDOM_SEED,
)
# Share of the held-out remainder that goes to validation.
val_size = cfg.VAL_RATIO / (cfg.VAL_RATIO + cfg.TEST_RATIO)
val_df, test_df = train_test_split(
    temp_df,
    test_size=(1.0 - val_size),
    stratify=temp_df["class_id"],
    random_state=cfg.RANDOM_SEED,
)
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

print("Videos:", len(videos_df))
print("Train:", len(train_df), "Val:", len(val_df), "Test:", len(test_df))
# NOTE(review): this displays videos_df again, which was already shown right
# after scanning; possibly one of the split frames was intended. Confirm.
display(videos_df.head())


In [None]:
def plot_class_distribution(df, title):
    """Draw a bar chart of the number of rows per class in ``df``.

    Every class in ``cfg.CLASS_NAMES`` gets a bar, including classes with no
    rows (height 0). ``df`` must have a ``class_id`` column.
    """
    class_ids = list(range(len(cfg.CLASS_NAMES)))
    per_class = df["class_id"].value_counts().reindex(class_ids, fill_value=0)
    tick_labels = list(cfg.CLASS_NAMES)

    plt.figure(figsize=(10, 4))
    plt.bar(tick_labels, per_class.values)
    plt.title(title)
    plt.xlabel("Class")
    plt.ylabel("Count")
    plt.xticks(rotation=30, ha="right")
    plt.tight_layout()
    plt.show()

# Class balance of the full dataset and of each split.
plot_class_distribution(videos_df, "Kaggle WHO6 - All Videos")
plot_class_distribution(train_df, "Kaggle WHO6 - Train Videos")
plot_class_distribution(val_df, "Kaggle WHO6 - Val Videos")
plot_class_distribution(test_df, "Kaggle WHO6 - Test Videos")


In [None]:
def show_videos_by_class(df, max_per_class=1):
    """Embed up to ``max_per_class`` sample videos for each class present in ``df``.

    Classes are visited in ascending ``class_id`` order; samples are drawn
    with ``cfg.RANDOM_SEED`` so repeated runs show the same clips.
    """
    present_ids = sorted(df["class_id"].unique().tolist())
    for class_id in present_ids:
        rows = df[df["class_id"] == class_id]
        if rows.empty:
            continue
        label = cfg.CLASS_NAMES[int(class_id)]
        n_show = min(max_per_class, len(rows))
        for clip in rows.sample(n_show, random_state=cfg.RANDOM_SEED).itertuples():
            print(f"Class {class_id} | {label} | {clip.video_path}")
            display(Video(clip.video_path, embed=True, width=320))

# One example clip per class from the train and test splits.
show_videos_by_class(train_df, max_per_class=1)
show_videos_by_class(test_df, max_per_class=1)


In [None]:
def show_sample_videos(df, n=2):
    """Embed up to ``n`` videos sampled from ``df`` (seeded with ``cfg.RANDOM_SEED``)."""
    n_show = min(n, len(df))
    picked = df.sample(n_show, random_state=cfg.RANDOM_SEED)
    for clip in picked.itertuples():
        print(f"{clip.class_name} | {clip.video_path}")
        display(Video(clip.video_path, embed=True, width=320))


# Preview two seeded-random training videos.
show_sample_videos(train_df, n=2)


In [None]:
def sample_frames_from_video(video_path, num_frames=8, frame_stride=10):
    """Read up to ``num_frames`` RGB frames from a video, keeping every ``frame_stride``-th.

    Args:
        video_path: Path of the video file.
        num_frames: Maximum number of frames to return.
        frame_stride: Keep frames whose index is a multiple of this value;
            values below 1 are treated as 1.

    Returns:
        List of RGB frames; shorter than ``num_frames`` (possibly empty) when
        the video ends early or cannot be read.
    """
    # Guard against stride 0, which would raise ZeroDivisionError below.
    frame_stride = max(1, int(frame_stride))
    cap = cv2.VideoCapture(str(video_path))
    frames = []
    try:
        idx = 0
        while len(frames) < num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if idx % frame_stride == 0:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            idx += 1
    finally:
        # Release the capture even when decoding or conversion raises.
        cap.release()
    return frames


def show_augmented_frames(video_path, num_frames=8, frame_stride=10, consistent=True):
    """Plot a grid of augmented frames sampled from ``video_path``.

    With ``consistent=True`` one parameter set is sampled and reused for all
    frames; otherwise a new set is sampled for every frame. Prints a message
    and returns when no frame could be read.
    """
    frames = sample_frames_from_video(video_path, num_frames=num_frames, frame_stride=frame_stride)
    if not frames:
        print("No frames for", video_path)
        return

    shared_params = _sample_aug_params() if consistent else None
    n_cols = 4
    n_rows = int(math.ceil(len(frames) / n_cols))
    plt.figure(figsize=(n_cols * 3, n_rows * 3))
    for slot, frame in enumerate(frames, start=1):
        frame_params = shared_params if consistent else _sample_aug_params()
        shown = _apply_aug(frame, frame_params) if frame_params else frame
        plt.subplot(n_rows, n_cols, slot)
        plt.imshow(shown)
        plt.axis("off")
    plt.suptitle("Augmented frames")
    plt.tight_layout()
    plt.show()


# Visualise augmentation on the first training video with one shared parameter set.
sample_video = train_df.iloc[0]["video_path"]
show_augmented_frames(sample_video, num_frames=8, frame_stride=10, consistent=True)


In [None]:
# Extra symbols handed to Keras when deserialising the model.
custom_objects = {}

# Best effort: register MobileNetV2's preprocess_input under the name
# "preprocess_input". Any failure is ignored here; the model-loading cell
# below retries the import if loading fails because of this symbol.
try:
    from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_v2_preprocess
    custom_objects["preprocess_input"] = mobilenet_v2_preprocess
except Exception:
    pass


def load_model_safe(path, custom_objects):
    """Load a Keras model without compiling it.

    ``safe_mode=False`` is passed first; if ``load_model`` raises a
    ``TypeError`` whose message mentions ``safe_mode``, the call is repeated
    without that argument. Any other ``TypeError`` propagates.

    NOTE(review): ``safe_mode=False`` relaxes Keras' deserialisation checks,
    so only load model files from trusted sources.
    """
    load = tf.keras.models.load_model
    common = {"custom_objects": custom_objects, "compile": False}
    try:
        return load(path, safe_mode=False, **common)
    except TypeError as exc:
        if "safe_mode" not in str(exc):
            raise
    return load(path, **common)


# Load the model. If loading fails with an error that mentions
# "preprocess_input" and the symbol was not registered yet, register
# MobileNetV2's preprocess_input and retry once; otherwise re-raise.
try:
    model = load_model_safe(MODEL_PATH, custom_objects)
except (ValueError, TypeError) as exc:
    if "preprocess_input" in str(exc) and "preprocess_input" not in custom_objects:
        from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_v2_preprocess
        custom_objects["preprocess_input"] = mobilenet_v2_preprocess
        model = load_model_safe(MODEL_PATH, custom_objects)
    else:
        raise

model.summary()


In [None]:
# Inference settings derived from cfg, with defaults for attributes it may lack.
IMG_SIZE = tuple(cfg.IMG_SIZE)  # passed as the size argument of cv2.resize
SEQUENCE_LENGTH = int(getattr(cfg, "SEQUENCE_LENGTH", 16))
SEQUENCE_STRIDE = 1
# Sequence batches use a quarter of the frame batch size, but at least 1.
SEQUENCE_BATCH_SIZE = max(1, int(getattr(cfg, "BATCH_SIZE", 32) // 4))
CLASS_NAMES = list(cfg.CLASS_NAMES)


def preprocess_frame(frame_rgb, img_size=IMG_SIZE):
    """Resize a frame to ``img_size`` and return it as float32 divided by 255."""
    as_float = cv2.resize(frame_rgb, img_size).astype(np.float32)
    return as_float / 255.0


def collect_video_frames(
    video_path,
    frame_stride=1,
    max_frames=None,
    augment=False,
    consistent_aug=True,
    augment_params=None,
):
    """Decode a video into model inputs, optionally augmenting every kept frame.

    Args:
        video_path: Path of the video file.
        frame_stride: Keep frames whose index is a multiple of this value;
            values below 1 are treated as 1.
        max_frames: Stop after this many kept frames (``None``/0 = no limit).
        augment: Apply randomly sampled augmentation.
        consistent_aug: With ``augment=True``, sample parameters once for the
            whole video; otherwise sample new parameters for every frame.
        augment_params: Explicit parameter dict; when given it is applied to
            every frame regardless of ``augment``/``consistent_aug``.

    Returns:
        ``(fps, inputs, frames_bgr, timestamps)``: the FPS used for the
        timestamps (30.0 when the container reports none), a float32 array of
        preprocessed frames stacked on axis 0 (zero-length when nothing was
        read), the kept (possibly augmented) frames as BGR images, and the
        float32 timestamps in seconds.

    Raises:
        RuntimeError: if the video cannot be opened.
    """
    # Guard against stride 0, which would raise ZeroDivisionError below.
    frame_stride = max(1, int(frame_stride))

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Failed to open video: {video_path}")

    inputs = []
    frames_bgr = []
    timestamps = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # ``not fps > 0`` also catches NaN, which ``fps <= 0`` lets through.
        if fps is None or not fps > 0:
            fps = 30.0

        apply_aug = augment or augment_params is not None
        if augment_params is not None:
            shared_params = augment_params
        elif augment and consistent_aug:
            shared_params = _sample_aug_params()
        else:
            shared_params = None

        idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if idx % frame_stride != 0:
                idx += 1
                continue

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if apply_aug:
                # No shared set means per-frame sampling was requested.
                params = shared_params if shared_params is not None else _sample_aug_params()
                frame_rgb = _apply_aug(frame_rgb, params)

            inputs.append(preprocess_frame(frame_rgb))
            frames_bgr.append(cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
            # Timestamps use the original frame index, not the kept-frame count.
            timestamps.append(idx / fps)

            idx += 1
            if max_frames and len(inputs) >= max_frames:
                break
    finally:
        # Release the capture even when decoding or augmentation raises.
        cap.release()

    if inputs:
        inputs_arr = np.stack(inputs, axis=0)
    else:
        inputs_arr = np.zeros((0, *IMG_SIZE, 3), dtype=np.float32)
    return float(fps), inputs_arr, frames_bgr, np.array(timestamps, dtype=np.float32)


In [None]:
def predict_batches(model, inputs, batch_size=32):
    """Run ``model.predict`` over ``inputs`` in slices of ``batch_size`` and concatenate.

    Returns an empty ``(0, len(CLASS_NAMES))`` float32 array when ``inputs``
    is empty.
    """
    total = len(inputs)
    if total == 0:
        return np.zeros((0, len(CLASS_NAMES)), dtype=np.float32)
    outputs = [
        model.predict(inputs[lo:lo + batch_size], verbose=0)
        for lo in range(0, total, batch_size)
    ]
    return np.concatenate(outputs, axis=0)



def _get_model_input_shape(model):
    shape = model.input_shape
    # Keras returns a tuple for single-input models and a list for multi-input.
    if isinstance(shape, (list, tuple)) and shape and isinstance(shape[0], (list, tuple)):
        shape = shape[0]
    return shape


def is_sequence_model(model):
    """Return True when the model's (first) input shape has five dimensions."""
    shape = _get_model_input_shape(model)
    if shape is None:
        return False
    return len(shape) == 5


def get_sequence_length(model, fallback=SEQUENCE_LENGTH):
    """Return dimension 1 of the model's input shape, or ``fallback`` when it is unset.

    ``fallback`` is used when the shape is ``None``, has fewer than two
    dimensions, or its second entry is falsy (``None`` or 0).
    """
    shape = _get_model_input_shape(model)
    has_dim = shape is not None and len(shape) >= 2
    time_dim = shape[1] if has_dim else None
    return int(time_dim) if time_dim else int(fallback)


def predict_sequence_probs(model, inputs, seq_len=None, stride=1, batch_size=8):
    """Turn sliding-window predictions into per-frame class probabilities.

    Windows of ``seq_len`` consecutive frames are fed to ``model.predict``;
    each window's prediction is credited to every frame it covers, and each
    frame's result is the mean over the windows covering it.

    Every frame is covered by at least one window: ``stride`` is capped at
    ``seq_len`` (a larger stride would leave gaps) and a final window ending
    on the last frame is added when the stride does not land there. Without
    this, uncovered frames came back as all-zero rows.

    Windows are stacked one batch at a time, so peak memory is bounded by
    ``batch_size`` windows rather than by the clip length.

    Args:
        model: Object with ``predict(batch, verbose=0)`` returning one row per window.
        inputs: Array of frames, frame index on axis 0.
        seq_len: Window length; defaults to ``get_sequence_length(model)``.
        stride: Step between window starts (minimum 1, maximum ``seq_len``).
        batch_size: Number of windows per ``predict`` call (minimum 1).

    Returns:
        float32 array of shape ``(len(inputs), num_classes)``.
    """
    n_frames = len(inputs)
    if n_frames == 0:
        return np.zeros((0, len(CLASS_NAMES)), dtype=np.float32)

    seq_len = int(seq_len or get_sequence_length(model))
    stride = min(max(1, int(stride)), seq_len)
    batch_size = max(1, int(batch_size))

    if n_frames < seq_len:
        # Clip shorter than one window: repeat the last frame to fill it.
        pad = np.repeat(inputs[-1][None, ...], seq_len - n_frames, axis=0)
        frames = np.concatenate([inputs, pad], axis=0)
        starts = [0]
    else:
        frames = inputs
        last_start = n_frames - seq_len
        starts = list(range(0, last_start + 1, stride))
        if starts[-1] != last_start:
            # Make sure the trailing frames are covered too.
            starts.append(last_start)

    probs = None
    counts = np.zeros((n_frames, 1), dtype=np.float32)
    for i in range(0, len(starts), batch_size):
        batch_starts = starts[i:i + batch_size]
        batch = np.stack([frames[s:s + seq_len] for s in batch_starts], axis=0)
        batch_preds = np.asarray(model.predict(batch, verbose=0))
        if probs is None:
            probs = np.zeros((n_frames, batch_preds.shape[1]), dtype=np.float32)
        for pred, start in zip(batch_preds, batch_starts):
            # Clamp to n_frames so padded positions are not credited.
            end = min(start + seq_len, n_frames)
            probs[start:end] += pred
            counts[start:end] += 1.0

    return probs / np.clip(counts, 1e-6, None)


def predict_model_probs(model, inputs, batch_size=32, sequence_length=None, sequence_stride=1, sequence_batch_size=None):
    """Return per-frame probabilities, choosing frame-wise or sequence inference.

    Sequence models (see ``is_sequence_model``) go through
    ``predict_sequence_probs``; all others through ``predict_batches``.
    ``sequence_length`` and ``sequence_batch_size`` default to the model's
    own sequence length and to ``batch_size`` respectively.
    """
    if not is_sequence_model(model):
        return predict_batches(model, inputs, batch_size=batch_size)
    return predict_sequence_probs(
        model,
        inputs,
        seq_len=sequence_length or get_sequence_length(model),
        stride=sequence_stride,
        batch_size=sequence_batch_size or batch_size,
    )

def smooth_probs_moving_avg(probs, window=7):
    """Smooth each class column with a centred moving average and renormalise rows.

    Args:
        probs: 2-D array, one row per frame and one column per class.
        window: Averaging window in frames. Values <= 1 disable smoothing;
            values larger than the number of frames are clamped to it.

    Returns:
        Array of the same shape as ``probs`` whose rows are rescaled to sum
        to 1, or ``probs`` itself when no smoothing is applied.
    """
    if window <= 1 or len(probs) == 0:
        return probs
    # np.convolve(mode="same") returns max(len(signal), len(kernel)) samples,
    # so a window longer than the clip would not fit the output column.
    window = min(int(window), len(probs))
    kernel = np.ones(window, dtype=np.float32) / float(window)
    smoothed = np.zeros_like(probs)
    for c in range(probs.shape[1]):
        smoothed[:, c] = np.convolve(probs[:, c], kernel, mode="same")
    # The convolution zero-pads at the clip edges, which lowers row sums
    # there; renormalise so every row sums to 1 again.
    row_sums = smoothed.sum(axis=1, keepdims=True)
    smoothed = smoothed / np.clip(row_sums, 1e-6, None)
    return smoothed


def smooth_preds_majority(pred_ids, window=7, num_classes=None):
    """Replace each predicted id with the most frequent id in a centred window.

    The window spans ``window // 2`` frames on each side and is truncated at
    the sequence ends. Ties go to the lowest class id. ``pred_ids`` is
    returned unchanged when ``window <= 1`` or the sequence is empty;
    otherwise an int32 array is returned.
    """
    n = len(pred_ids)
    if window <= 1 or n == 0:
        return pred_ids
    n_classes = num_classes or len(CLASS_NAMES)
    radius = window // 2
    voted = [
        int(np.argmax(np.bincount(
            pred_ids[max(0, centre - radius):min(n, centre + radius + 1)],
            minlength=n_classes,
        )))
        for centre in range(n)
    ]
    return np.array(voted, dtype=np.int32)


def apply_smoothing(probs, method="moving_avg", window=7):
    """Smooth per-frame predictions and return ``(probabilities, predicted_ids)``.

    Methods:
      * ``"none"``: probabilities unchanged, ids are the row-wise argmax.
      * ``"moving_avg"``: probabilities are moving-averaged, ids are the
        argmax of the smoothed values.
      * ``"majority"``: probabilities unchanged, argmax ids are majority-voted.

    Raises:
        ValueError: for any other ``method``.
    """

    def _top_ids(p):
        return np.argmax(p, axis=1) if len(p) else np.array([], dtype=np.int32)

    if method == "moving_avg":
        averaged = smooth_probs_moving_avg(probs, window=window)
        return averaged, _top_ids(averaged)
    if method == "majority":
        raw_ids = _top_ids(probs)
        n_classes = probs.shape[1] if len(probs) else len(CLASS_NAMES)
        return probs, smooth_preds_majority(raw_ids, window=window, num_classes=n_classes)
    if method == "none":
        return probs, _top_ids(probs)
    raise ValueError("Unknown smoothing method: " + str(method))


In [None]:
from sklearn.metrics import accuracy_score, classification_report


def build_segments(class_ids, timestamps, fps, frame_stride, class_names):
    """Group consecutive frames with the same class id into segments.

    Each segment is a dict with ``class_id``, ``class_name``,
    ``start_frame``/``end_frame`` (inclusive indices into ``class_ids``),
    ``start_s``, ``end_s`` and ``duration_s``. One sampled frame is taken to
    last ``frame_stride / fps`` seconds (0 when ``fps`` is falsy).
    """
    n = len(class_ids)
    if n == 0:
        return []
    step_s = (frame_stride / fps) if fps else 0.0

    def _segment(first, last):
        cid = int(class_ids[first])
        return {
            "class_id": cid,
            "class_name": class_names[cid],
            "start_frame": int(first),
            "end_frame": int(last),
            "start_s": float(timestamps[first]),
            "end_s": float(timestamps[last] + step_s),
            "duration_s": float((last - first + 1) * step_s),
        }

    # A new run starts wherever the class id differs from the previous frame.
    run_starts = [0] + [i for i in range(1, n) if class_ids[i] != class_ids[i - 1]]
    run_ends = [nxt - 1 for nxt in run_starts[1:]] + [n - 1]
    return [_segment(first, last) for first, last in zip(run_starts, run_ends)]


def summarize_durations(segments, class_names):
    """Sum segment durations per class.

    Returns ``(totals, total_wash)``: ``totals`` maps every class name to its
    summed ``duration_s``; ``total_wash`` adds up all classes except the one
    named "other" (case-insensitive).
    """
    totals = dict.fromkeys(class_names, 0.0)
    for segment in segments:
        totals[segment["class_name"]] += segment["duration_s"]
    wash_names = [name for name in class_names if name.lower() != "other"]
    return totals, sum(totals[name] for name in wash_names)


def build_timeline_df(timestamps, probs_raw, pred_ids, class_names, gt_ids=None, probs_smoothed=None):
    """Assemble a per-frame table of timestamps, predictions and probabilities.

    Columns: ``timestamp_s``, ``pred_class_id``, ``pred_class_name``,
    optionally ``gt_class_id``/``gt_class_name`` (when ``gt_ids`` is given),
    then ``prob_raw_<name>`` and ``prob_smooth_<name>`` for every class.
    ``probs_smoothed`` defaults to ``probs_raw``.
    """
    smooth = probs_raw if probs_smoothed is None else probs_smoothed
    columns = {
        "timestamp_s": timestamps,
        "pred_class_id": pred_ids,
        "pred_class_name": [class_names[k] for k in pred_ids],
    }
    if gt_ids is not None:
        columns.update(
            gt_class_id=gt_ids,
            gt_class_name=[class_names[k] for k in gt_ids],
        )
    for col, label in enumerate(class_names):
        columns[f"prob_raw_{label}"] = probs_raw[:, col] if len(probs_raw) else []
        columns[f"prob_smooth_{label}"] = smooth[:, col] if len(smooth) else []
    return pd.DataFrame(columns)


def plot_time_series(timestamps, gt_ids, pred_ids, class_names):
    """Plot predicted (and, when provided, ground-truth) class ids over time as step lines."""
    fig, ax = plt.subplots(figsize=(12, 3))
    has_ground_truth = gt_ids is not None and len(gt_ids)
    if has_ground_truth:
        ax.step(timestamps, gt_ids, where="post", label="GT", linewidth=2)
    ax.step(timestamps, pred_ids, where="post", label="Pred", linewidth=2)

    # One y tick per class, labelled with the class name minus the "Step" prefix.
    tick_labels = [name.replace("Step", "") for name in class_names]
    ax.set_yticks(range(len(class_names)))
    ax.set_yticklabels(tick_labels)
    ax.set_xlabel("Time (s)")
    ax.legend(loc="upper right")
    ax.grid(True, axis="x", alpha=0.3)
    plt.tight_layout()
    plt.show()


def compress_sequence(seq):
    """Collapse runs of equal consecutive ids into one entry each, as Python ints."""
    collapsed = []
    previous = None
    for item in seq:
        if previous is not None and item == previous:
            continue
        collapsed.append(int(item))
        previous = item
    return collapsed


def edit_distance(seq_a, seq_b):
    """Return the Levenshtein distance (insert/delete/substitute, cost 1 each) between two sequences."""
    # Only the previous DP row is needed; it starts as the cost of building
    # each prefix of seq_b from an empty sequence.
    prev_row = list(range(len(seq_b) + 1))
    for i, a in enumerate(seq_a, start=1):
        row = [i]
        for j, b in enumerate(seq_b, start=1):
            substitute = prev_row[j - 1] + (0 if a == b else 1)
            delete = prev_row[j] + 1
            insert = row[j - 1] + 1
            row.append(min(delete, insert, substitute))
        prev_row = row
    return prev_row[-1]


def dtw_distance(seq_a, seq_b):
    """Return the dynamic-time-warping distance with a 0/1 mismatch cost.

    Returns ``inf`` when either sequence is empty.
    """
    n = len(seq_b)
    if len(seq_a) == 0 or n == 0:
        return float("inf")

    # Row 0 of the DP table: only the origin is reachable.
    prev_row = np.full(n + 1, np.inf, dtype=np.float32)
    prev_row[0] = 0.0
    for a in seq_a:
        row = np.full(n + 1, np.inf, dtype=np.float32)
        for j, b in enumerate(seq_b, start=1):
            step = 0.0 if a == b else 1.0
            row[j] = step + min(prev_row[j], row[j - 1], prev_row[j - 1])
        prev_row = row
    return float(prev_row[n])


def temporal_iou(gt_ids, pred_ids, num_classes):
    """Return a dict mapping each class id to its frame-level IoU.

    For each class, IoU is the number of frames labelled with it in both
    sequences divided by the number labelled with it in either; ``nan`` when
    the class appears in neither.
    """
    truth = np.asarray(gt_ids)
    guess = np.asarray(pred_ids)
    scores = {}
    for cls in range(num_classes):
        in_truth = truth == cls
        in_guess = guess == cls
        overlap = np.sum(in_truth & in_guess)
        covered = np.sum(in_truth | in_guess)
        scores[cls] = float("nan") if covered <= 0 else overlap / covered
    return scores


def evaluate_alignment(gt_ids, pred_ids, class_names):
    """Compare predicted and ground-truth per-frame class ids.

    Returns an empty dict when ``gt_ids`` is ``None`` or empty; otherwise a
    dict with ``frame_accuracy``, ``edit_distance`` (over run-collapsed
    sequences), ``dtw_distance`` (over the full sequences),
    ``temporal_iou_mean``, ``temporal_iou_by_class`` (``None`` for classes
    absent from both sequences) and ``classification_report``.
    """
    if gt_ids is None or len(gt_ids) == 0:
        return {}

    frame_accuracy = float(accuracy_score(gt_ids, pred_ids))
    collapsed_edit = int(edit_distance(compress_sequence(gt_ids), compress_sequence(pred_ids)))
    warp_cost = float(dtw_distance(gt_ids, pred_ids))
    ious = temporal_iou(gt_ids, pred_ids, len(class_names))
    iou_by_class = {
        class_names[cls]: (None if np.isnan(score) else float(score))
        for cls, score in ious.items()
    }
    report = classification_report(
        gt_ids,
        pred_ids,
        labels=list(range(len(class_names))),
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )
    return {
        "frame_accuracy": frame_accuracy,
        "edit_distance": collapsed_edit,
        "dtw_distance": warp_cost,
        "temporal_iou_mean": float(np.nanmean(list(ious.values()))),
        "temporal_iou_by_class": iou_by_class,
        "classification_report": report,
    }


In [None]:
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, top_k_accuracy_score


def _base_aug_params():
    return {
        "hflip": False,
        "angle": 0.0,
        "zoom": 1.0,
        "shear": 0.0,
        "tx": 0,
        "ty": 0,
        "brightness": None,
        "contrast": None,
        "gamma": None,
        "shadow": False,
        "reverse_sequence": False,
    }


def build_aug_variants():
    """List the deterministic augmentation variants to evaluate.

    Returns ``(name, params)`` pairs starting with ``("none", None)``. Each
    further entry enables exactly one augmentation that is active in
    ``cfg.AUGMENTATION_CONFIG`` (or ``cfg.ENABLE_SHADOW_AUG``), set to the
    upper end of its configured range.
    """
    aug = cfg.AUGMENTATION_CONFIG

    def _variant(label, **overrides):
        # Start from the all-off parameters and switch on one transform.
        params = _base_aug_params()
        params.update(overrides)
        return (label, params)

    variants = [("none", None)]

    if aug.get("rotation_range", 0) > 0:
        variants.append(_variant("rotation", angle=float(aug["rotation_range"])))
    if aug.get("zoom_range", 0) > 0:
        variants.append(_variant("zoom", zoom=1 + float(aug["zoom_range"])))
    if aug.get("shear_range", 0) > 0:
        variants.append(_variant("shear", shear=float(aug["shear_range"])))
    if aug.get("width_shift_range", 0) > 0 or aug.get("height_shift_range", 0) > 0:
        variants.append(_variant(
            "shift",
            tx=int(float(aug.get("width_shift_range", 0)) * IMG_SIZE[0]),
            ty=int(float(aug.get("height_shift_range", 0)) * IMG_SIZE[1]),
        ))
    if any(aug.get(flag, False) for flag in ("horizontal_flip", "hflip", "mid_flip")):
        variants.append(_variant("hflip", hflip=True))
    for label in ("brightness", "contrast", "gamma"):
        bounds = aug.get(f"{label}_range")
        if bounds:
            variants.append(_variant(label, **{label: float(bounds[1])}))
    if cfg.ENABLE_SHADOW_AUG:
        variants.append(_variant("shadow", shadow=True))

    return variants


def _safe_name(text):
    return "".join(ch if ch.isalnum() else "_" for ch in text)


def plot_confusion_matrix(cm, class_names, title, out_path=None):
    """Show a confusion matrix as a heat map and optionally save it to ``out_path``."""
    ticks = np.arange(len(class_names))

    plt.figure(figsize=(7, 6))
    plt.imshow(cm, interpolation="nearest", cmap="Blues")
    plt.title(title)
    plt.colorbar()
    plt.xticks(ticks, class_names, rotation=45, ha="right")
    plt.yticks(ticks, class_names)
    plt.ylabel("True")
    plt.xlabel("Pred")
    plt.tight_layout()
    # Save before show(): the figure is still current at this point.
    if out_path:
        plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.show()


def compute_metrics(y_true, y_pred, y_prob):
    """Compute classification metrics over every class in CLASS_NAMES.

    Args:
        y_true: Ground-truth class ids, one per sample.
        y_pred: Predicted class ids aligned with ``y_true``.
        y_prob: Per-sample class scores used for top-2 accuracy; when it is
            empty the top-2 accuracy is reported as 0.0.

    Returns:
        A dict with float entries for accuracy, weighted and macro
        precision/recall/F1, top-2 accuracy, plus the ``classification_report``
        dictionary under ``"report"``.
    """
    class_ids = list(range(len(CLASS_NAMES)))

    def _prf(average):
        # Precision, recall and F1 for one averaging mode; support is unused.
        precision, recall, f1, _support = precision_recall_fscore_support(
            y_true, y_pred, average=average, zero_division=0, labels=class_ids
        )
        return float(precision), float(recall), float(f1)

    accuracy = float(accuracy_score(y_true, y_pred))
    weighted_scores = _prf("weighted")
    macro_scores = _prf("macro")

    if len(y_prob):
        top2 = top_k_accuracy_score(y_true, y_prob, k=2, labels=class_ids)
    else:
        top2 = 0.0

    report = classification_report(
        y_true, y_pred, labels=class_ids, target_names=CLASS_NAMES, output_dict=True, zero_division=0
    )

    # Insertion order is kept identical to the column order of the metrics CSV.
    metrics = {"accuracy": accuracy}
    metrics.update(zip(("precision_weighted", "recall_weighted", "f1_weighted"), weighted_scores))
    metrics.update(zip(("precision_macro", "recall_macro", "f1_macro"), macro_scores))
    metrics["top2_accuracy"] = float(top2)
    metrics["report"] = report
    return metrics


def _save_misclassified(frames_bgr, gt_ids, pred_ids, out_root, split_name, aug_name, video_id):
    for idx, (gt, pred) in enumerate(zip(gt_ids, pred_ids)):
        if int(gt) == int(pred):
            continue
        gt_name = _safe_name(CLASS_NAMES[int(gt)])
        pred_name = _safe_name(CLASS_NAMES[int(pred)])
        folder = out_root / split_name / aug_name / f"gt_{gt}_{gt_name}" / f"pred_{pred}_{pred_name}" / _safe_name(video_id)
        folder.mkdir(parents=True, exist_ok=True)
        frame_path = folder / f"frame_{idx:06d}.jpg"
        cv2.imwrite(str(frame_path), frames_bgr[idx])


def evaluate_split(df, split_name, aug_name, aug_params, output_root, frame_stride=2, batch_size=32,
                   save_misclassified_flag=True, max_videos=None, max_frames=None):
    """Evaluate the global ``model`` frame by frame on one split and one augmentation.

    Every video row contributes one prediction per returned probability row,
    all labelled with the row's ``class_id``. A confusion-matrix PNG, a
    one-row metrics CSV and a per-class report CSV are written to
    ``output_root / "validation"``.

    Args:
        df: DataFrame with ``video_path`` and ``class_id`` columns.
        split_name: Split label used in file names and the metrics row.
        aug_name: Augmentation label used in file names and the metrics row.
        aug_params: Augmentation parameters; ``None`` disables augmentation.
        output_root: Base directory for all outputs.
        frame_stride: Forwarded to ``collect_video_frames``.
        batch_size: Forwarded to ``predict_model_probs``.
        save_misclassified_flag: When true, misclassified frames are saved
            under ``output_root / "misclassified"``.
        max_videos: Optional cap on the number of videos (random sample
            seeded with ``cfg.RANDOM_SEED``).
        max_frames: Forwarded to ``collect_video_frames``.

    Returns:
        ``None`` when no predictions were produced, otherwise a dict with the
        confusion matrix, the metrics dict and the two CSV paths.
    """
    if max_videos is None:
        videos = df
    else:
        videos = df.sample(min(max_videos, len(df)), random_state=cfg.RANDOM_SEED)

    true_labels, predicted_labels, prob_chunks = [], [], []
    for record in videos.itertuples():
        _fps, inputs, frames_bgr, _timestamps = collect_video_frames(
            record.video_path,
            frame_stride=frame_stride,
            max_frames=max_frames,
            augment=aug_params is not None,
            consistent_aug=True,
            augment_params=aug_params,
        )
        if len(inputs) == 0:
            continue

        video_probs = predict_model_probs(
            model,
            inputs,
            batch_size=batch_size,
            sequence_length=SEQUENCE_LENGTH,
            sequence_stride=SEQUENCE_STRIDE,
            sequence_batch_size=SEQUENCE_BATCH_SIZE,
        )
        video_preds = np.argmax(video_probs, axis=1)
        # The whole video carries a single label, repeated once per prediction.
        video_truth = np.full(len(video_preds), int(record.class_id), dtype=np.int32)

        true_labels.extend(video_truth.tolist())
        predicted_labels.extend(video_preds.tolist())
        prob_chunks.append(video_probs)

        if save_misclassified_flag:
            _save_misclassified(
                frames_bgr,
                video_truth,
                video_preds,
                output_root / "misclassified",
                split_name,
                aug_name,
                Path(record.video_path).stem,
            )

    if not true_labels:
        return None

    y_true = np.array(true_labels)
    y_pred = np.array(predicted_labels)
    if prob_chunks:
        y_prob = np.concatenate(prob_chunks, axis=0)
    else:
        y_prob = np.zeros((0, len(CLASS_NAMES)))

    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(CLASS_NAMES))))
    metrics = compute_metrics(y_true, y_pred, y_prob)

    out_dir = output_root / "validation"
    out_dir.mkdir(parents=True, exist_ok=True)
    tag = f"{split_name}_{aug_name}"

    cm_path = out_dir / f"cm_{tag}.png"
    plot_confusion_matrix(cm, CLASS_NAMES, f"{split_name} | {aug_name}", out_path=cm_path)

    # Scalar metrics go into a one-row CSV; the nested report gets its own file.
    metrics_path = out_dir / f"metrics_{tag}.csv"
    scalar_metrics = {key: value for key, value in metrics.items() if key != "report"}
    summary_row = {"split": split_name, "augmentation": aug_name, **scalar_metrics}
    pd.DataFrame([summary_row]).to_csv(metrics_path, index=False)

    report_path = out_dir / f"report_{tag}.csv"
    pd.DataFrame(metrics["report"]).transpose().to_csv(report_path, index=True)

    return {
        "confusion_matrix": cm,
        "metrics": metrics,
        "metrics_path": metrics_path,
        "report_path": report_path,
    }


## Validation: Confusion Matrices + Misclassified Frames
Run this after the model is loaded. It iterates over every video in each split, once per augmentation variant.


In [None]:
# Splits to evaluate; the key is used in log lines and output file names.
VALIDATION_SPLITS = {
    "train": train_df,
    "val": val_df,
    "test": test_df,
}
# Sequence of (aug_name, aug_params) pairs, unpacked in the loop below.
AUG_VARIANTS = build_aug_variants()
SAVE_MISCLASSIFIED = True
FRAME_STRIDE_VALID = 2
BATCH_SIZE_VALID = 32
# NOTE: MAX_VIDEOS_PER_SPLIT is reassigned by a later cell for the demo runs.
MAX_VIDEOS_PER_SPLIT = None  # None evaluates every video; set an int to cap
MAX_FRAMES_PER_VIDEO = None  # None uses every frame; set an int to cap

# Results keyed by (split_name, aug_name); pairs with no predictions are omitted.
validation_results = {}
for split_name, df in VALIDATION_SPLITS.items():
    if df.empty:
        print(f"Skipping {split_name}: no videos")
        continue
    for aug_name, aug_params in AUG_VARIANTS:
        print(f"Evaluating {split_name} | {aug_name}")
        res = evaluate_split(
            df,
            split_name=split_name,
            aug_name=aug_name,
            aug_params=aug_params,
            output_root=OUTPUT_DIR,
            frame_stride=FRAME_STRIDE_VALID,
            batch_size=BATCH_SIZE_VALID,
            save_misclassified_flag=SAVE_MISCLASSIFIED,
            max_videos=MAX_VIDEOS_PER_SPLIT,
            max_frames=MAX_FRAMES_PER_VIDEO,
        )
        # evaluate_split returns None when no predictions were produced.
        if res is not None:
            validation_results[(split_name, aug_name)] = res


In [None]:
def draw_text_block(
    frame,
    lines,
    origin=(10, 30),
    line_height=22,
    font_scale=0.6,
    thickness=1,
    text_color=(255, 255, 255),
    bg_color=(0, 0, 0),
):
    """Draw a column of text lines onto *frame* in place, each on a filled box.

    Args:
        frame: Image drawn on directly with OpenCV.
        lines: Strings to render top to bottom; falsy entries only advance
            the cursor, leaving a blank row.
        origin: ``(x, y)`` of the first line's text origin.
        line_height: Vertical distance between consecutive lines.
        font_scale: Scale passed to OpenCV text functions.
        thickness: Stroke thickness passed to OpenCV text functions.
        text_color: Colour of the text.
        bg_color: Colour of the filled rectangle behind each line.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    left, cursor_y = origin
    for text in lines:
        if text:
            (width, height), baseline = cv2.getTextSize(text, font, font_scale, thickness)
            # Background box: the text extent padded by 4 px on every side.
            box_top_left = (left - 4, cursor_y - height - 4)
            box_bottom_right = (left + width + 4, cursor_y + baseline + 4)
            cv2.rectangle(frame, box_top_left, box_bottom_right, bg_color, -1)
            cv2.putText(frame, text, (left, cursor_y), font, font_scale, text_color, thickness, cv2.LINE_AA)
        cursor_y += line_height


def write_annotated_video(
    frames_bgr,
    fps_out,
    timestamps,
    pred_ids,
    probs,
    class_names,
    gt_ids=None,
    summary=None,
    out_path=None,
    summary_seconds=3,
):
    """Write an MP4 with per-frame prediction overlays and an optional summary.

    Each output frame shows the timestamp, predicted class, optional
    ground-truth class and a running "wash total" on the left, plus the
    per-class probabilities on the right. The running total accumulates the
    median timestamp delta for every frame whose predicted class is not
    named "other" (case-insensitive).

    The input frames are never modified: each one is copied before drawing,
    and the summary screen is rendered on a clean copy of the last frame.

    Args:
        frames_bgr: Sequence (list or array stack) of frames to annotate.
        fps_out: Frame rate of the written video.
        timestamps: Per-frame timestamps in seconds.
        pred_ids: Per-frame predicted class ids.
        probs: 2-D array indexed as ``probs[frame, class]``.
        class_names: Class labels indexed by class id.
        gt_ids: Optional per-frame ground-truth class ids.
        summary: Optional dict with ``total_wash_s``, ``steps_done`` and
            ``per_class_s`` entries shown on the closing summary screen.
        out_path: Destination path of the video.
        summary_seconds: Duration of the summary screen; ``<= 0`` disables it.

    Returns:
        ``out_path``, or ``None`` when there are no frames.
    """
    # len() instead of truthiness so an ndarray stack of frames also works.
    if frames_bgr is None or len(frames_bgr) == 0:
        return None
    h, w = frames_bgr[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(out_path), fourcc, fps_out, (w, h))

    try:
        frame_dt = float(np.median(np.diff(timestamps))) if len(timestamps) > 1 else (1.0 / fps_out if fps_out else 0.0)
        running = {i: 0.0 for i in range(len(class_names))}
        # Loop invariants: which classes count as washing, and the right column.
        wash_ids = [j for j, name in enumerate(class_names) if name.lower() != "other"]
        x_right = max(10, w - 320)

        for i, source_frame in enumerate(frames_bgr):
            # Draw on a copy so the caller's frames stay pristine.
            frame = source_frame.copy()
            pred_id = int(pred_ids[i])
            running[pred_id] += frame_dt

            header_lines = [f"t={timestamps[i]:.2f}s", f"Pred: {class_names[pred_id]}"]
            if gt_ids is not None:
                header_lines.append(f"GT: {class_names[int(gt_ids[i])]}")
            wash_total = sum(running[j] for j in wash_ids)
            header_lines.append(f"Wash total: {wash_total:.1f}s")
            draw_text_block(frame, header_lines, origin=(10, 30))

            prob_lines = []
            for j, name in enumerate(class_names):
                marker = ">" if j == pred_id else " "
                prob_lines.append(f"{marker} {name}: {probs[i, j]:.2f}")
            draw_text_block(frame, prob_lines, origin=(x_right, 30))

            writer.write(frame)

        if summary and summary_seconds > 0:
            summary_lines = ["Summary", f"Total wash time: {summary['total_wash_s']:.1f}s"]
            for name in class_names:
                done_tag = "done" if summary['steps_done'].get(name, False) else "miss"
                summary_lines.append(f"{name}: {summary['per_class_s'][name]:.1f}s ({done_tag})")
            # The last input frame is still unannotated, so the summary does
            # not overlap the per-frame overlay.
            base = frames_bgr[-1]
            num_frames = int(max(1, fps_out * summary_seconds))
            for _ in range(num_frames):
                frame = base.copy()
                draw_text_block(frame, summary_lines, origin=(10, 30))
                writer.write(frame)
    finally:
        # Always finalize the container, even if annotation fails midway.
        writer.release()
    return out_path


In [None]:
def run_video_pipeline(
    model,
    video_path,
    gt_class_id=None,
    out_dir=OUTPUT_DIR,
    augment=False,
    consistent_aug=True,
    frame_stride=1,
    batch_size=32,
    sequence_length=None,
    sequence_stride=SEQUENCE_STRIDE,
    sequence_batch_size=SEQUENCE_BATCH_SIZE,
    smoothing_method="moving_avg",
    smoothing_window=7,
    min_step_seconds=1.0,
    max_frames=None,
):
    """Run inference on one video and write an annotated copy of it.

    Steps: extract frames, predict class probabilities, smooth them, build
    segments and per-class durations, write the annotated MP4 into
    ``out_dir``, then assemble a per-frame timeline and (when ground truth is
    given) alignment metrics.

    Args:
        model: Model handed to ``predict_model_probs``.
        video_path: Path of the input video.
        gt_class_id: ``None``, a single class id applied to every frame, or a
            list/tuple/ndarray with one id per prediction.
        out_dir: Directory for the annotated video (created if missing).
        augment, consistent_aug, frame_stride, max_frames: Forwarded to
            ``collect_video_frames``.
        batch_size, sequence_length, sequence_stride, sequence_batch_size:
            Forwarded to ``predict_model_probs``.
        smoothing_method, smoothing_window: Forwarded to ``apply_smoothing``.
        min_step_seconds: Minimum accumulated duration for a class to be
            marked as done in ``summary["steps_done"]``.

    Returns:
        An empty dict when no frames could be extracted, otherwise a dict
        with the video paths, timeline, segments, summary, metrics, fps and
        frame stride.

    Raises:
        ValueError: If a per-frame ``gt_class_id`` sequence does not match the
            number of predictions.
    """
    fps, inputs, frames_bgr, timestamps = collect_video_frames(
        video_path,
        frame_stride=frame_stride,
        max_frames=max_frames,
        augment=augment,
        consistent_aug=consistent_aug,
    )
    if len(inputs) == 0:
        print("No frames extracted from", video_path)
        return {}

    raw_probs = predict_model_probs(
        model,
        inputs,
        batch_size=batch_size,
        sequence_length=sequence_length,
        sequence_stride=sequence_stride,
        sequence_batch_size=sequence_batch_size,
    )
    smooth_probs, pred_ids = apply_smoothing(raw_probs, method=smoothing_method, window=smoothing_window)

    # Normalise the ground truth to one id per prediction (or None).
    if gt_class_id is None:
        gt_ids = None
    elif isinstance(gt_class_id, (list, tuple, np.ndarray)):
        gt_ids = np.asarray(gt_class_id, dtype=np.int32)
        if len(gt_ids) != len(pred_ids):
            raise ValueError("gt_class_id length does not match predictions")
    else:
        gt_ids = np.full(len(pred_ids), int(gt_class_id), dtype=np.int32)

    segments = build_segments(pred_ids, timestamps, fps, frame_stride, CLASS_NAMES)
    totals, total_wash = summarize_durations(segments, CLASS_NAMES)
    steps_done = {label: (totals[label] >= min_step_seconds) for label in CLASS_NAMES}
    summary = {
        "per_class_s": totals,
        "total_wash_s": total_wash,
        "steps_done": steps_done,
    }

    out_dir.mkdir(parents=True, exist_ok=True)
    if augment:
        suffix = "aug"
    else:
        suffix = "raw"
    out_path = out_dir / f"{Path(video_path).stem}_{suffix}_pred.mp4"

    # Only every frame_stride-th frame is kept, so playback rate is scaled down.
    if frame_stride:
        fps_out = fps / frame_stride
    else:
        fps_out = fps

    output_video = write_annotated_video(
        frames_bgr,
        fps_out,
        timestamps,
        pred_ids,
        smooth_probs,
        CLASS_NAMES,
        gt_ids=gt_ids,
        summary=summary,
        out_path=out_path,
    )

    timeline_df = build_timeline_df(timestamps, raw_probs, pred_ids, CLASS_NAMES, gt_ids, probs_smoothed=smooth_probs)
    if gt_ids is not None:
        metrics = evaluate_alignment(gt_ids, pred_ids, CLASS_NAMES)
    else:
        metrics = {}

    return {
        "video_path": str(video_path),
        "output_path": str(output_video) if output_video else None,
        "timeline": timeline_df,
        "segments": segments,
        "summary": summary,
        "metrics": metrics,
        "fps": fps,
        "frame_stride": frame_stride,
    }


In [None]:
# Settings shared by this demo and by run_split() in the next cell.
FRAME_STRIDE = 2
BATCH_SIZE = 32
SMOOTHING_METHOD = "moving_avg"  # "none", "moving_avg", or "majority"
SMOOTHING_WINDOW = 7
# Caps the number of videos sampled per split by run_split(); this rebinds
# the name set in the validation cell above.
MAX_VIDEOS_PER_SPLIT = 2

# Smoke test: run the full pipeline on one reproducibly sampled training video.
sample_row = train_df.sample(1, random_state=cfg.RANDOM_SEED).iloc[0]
result = run_video_pipeline(
    model,
    sample_row["video_path"],
    gt_class_id=sample_row["class_id"],
    augment=False,
    frame_stride=FRAME_STRIDE,
    batch_size=BATCH_SIZE,
    smoothing_method=SMOOTHING_METHOD,
    smoothing_window=SMOOTHING_WINDOW,
)

# run_video_pipeline returns an empty dict when no frames were extracted.
if result:
    print("Output video:", result["output_path"])
    display(Video(result["output_path"], embed=True, width=360))
    timeline = result["timeline"]
    plot_time_series(
        timeline["timestamp_s"].values,
        # Ground truth is plotted only when the timeline carries that column.
        timeline["gt_class_id"].values if "gt_class_id" in timeline else None,
        timeline["pred_class_id"].values,
        CLASS_NAMES,
    )
    print("Summary:", result["summary"])
    print("Steps done:", result["summary"].get("steps_done"))
    if result.get("metrics"):
        print("Frame accuracy:", result["metrics"].get("frame_accuracy"))
        print("Edit distance:", result["metrics"].get("edit_distance"))
        print("DTW distance:", result["metrics"].get("dtw_distance"))


In [None]:
def run_split(df, split_name, augment):
    """Run the video pipeline on a small random sample of one split.

    At most ``MAX_VIDEOS_PER_SPLIT`` videos are sampled (seeded with
    ``cfg.RANDOM_SEED``). For each video the annotated clip is displayed, the
    prediction timeline is plotted and the metrics and summary are printed.

    Args:
        df: DataFrame with ``video_path``, ``class_id`` and ``class_name``.
        split_name: Label used in the progress line.
        augment: Whether to run the pipeline with augmentation enabled.

    Returns:
        List of non-empty result dicts from ``run_video_pipeline``.
    """
    sample_size = min(MAX_VIDEOS_PER_SPLIT, len(df))
    chosen = df.sample(sample_size, random_state=cfg.RANDOM_SEED)
    metric_labels = (
        ("Frame accuracy:", "frame_accuracy"),
        ("Edit distance:", "edit_distance"),
        ("DTW distance:", "dtw_distance"),
        ("Mean temporal IoU:", "temporal_iou_mean"),
    )

    collected = []
    for video in chosen.itertuples():
        print(f"[{split_name}] {video.video_path} ({video.class_name}) augment={augment}")
        outcome = run_video_pipeline(
            model,
            video.video_path,
            gt_class_id=video.class_id,
            augment=augment,
            frame_stride=FRAME_STRIDE,
            batch_size=BATCH_SIZE,
            smoothing_method=SMOOTHING_METHOD,
            smoothing_window=SMOOTHING_WINDOW,
        )
        if not outcome:
            # Empty dict: no frames could be extracted from this video.
            continue
        collected.append(outcome)

        clip_path = outcome.get("output_path")
        if clip_path:
            display(Video(clip_path, embed=True, width=360))

        frame_table = outcome["timeline"]
        if "gt_class_id" in frame_table:
            gt_series = frame_table["gt_class_id"].values
        else:
            gt_series = None
        plot_time_series(
            frame_table["timestamp_s"].values,
            gt_series,
            frame_table["pred_class_id"].values,
            CLASS_NAMES,
        )

        scores = outcome.get("metrics")
        if scores:
            for label, key in metric_labels:
                print(label, scores.get(key))

        video_summary = outcome["summary"]
        print("Summary:", video_summary)
        print("Steps done:", video_summary.get("steps_done"))
    return collected


# Collect demo results for the train and val splits, each run once without
# and once with augmentation.
all_results = []
for split_name, df in [("train", train_df), ("val", val_df)]:
    for augment in (False, True):
        all_results.extend(run_split(df, split_name, augment))


In [None]:
# Pool the per-frame labels of every demo run into one overall report.
if all_results:
    all_gt = []
    all_pred = []
    for res in all_results:
        timeline = res.get("timeline")
        # Runs whose timeline has no ground-truth column cannot be scored.
        if timeline is None or "gt_class_id" not in timeline:
            continue
        all_gt.extend(timeline["gt_class_id"].tolist())
        all_pred.extend(timeline["pred_class_id"].tolist())

    if all_gt:
        print("Overall frame accuracy:", accuracy_score(all_gt, all_pred))
        # labels= pins the report rows to CLASS_NAMES order even when some
        # classes are absent from the pooled frames.
        report = classification_report(all_gt, all_pred, labels=list(range(len(CLASS_NAMES))), target_names=CLASS_NAMES, zero_division=0)
        print(report)


## Ideas to make single-frame inference more robust on time series
- Sliding-window majority vote over predicted class IDs (temporal smoothing)
- Moving average or exponential smoothing over logits/probabilities before argmax
- Hysteresis thresholds: require a class to persist for N frames before switching
- Enforce minimum duration per step (drop very short segments as noise)
- Use a lightweight temporal model on top of frame features (GRU/LSTM/TCN)
- Apply a hidden Markov model (HMM) with a transition prior for valid step order
- Confidence calibration (temperature scaling) to reduce jitter
- Train a change-point detector to explicitly segment step boundaries
- Use optical-flow-based motion gating to suppress static false positives
- Self-training: refine temporal consistency by re-labeling with smoothing
