# EdgeWash Kaggle Notebook 

This notebook downloads the Kaggle hand-wash dataset subset, preprocesses frames, optionally computes optical flow, and trains the EdgeWash CNN classifiers. Each step is heavily commented so you can adapt hyperparameters or swap architectures quickly in a Kaggle environment.

## 1. Environment setup

We install Python dependencies directly (no external files needed) and make sure `ffmpeg` is available for video processing. Kaggle images already ship with CUDA-enabled TensorFlow, so the install is fast.

*Tip:* Re-run this cell if you change dependencies.


In [None]:
%%bash
pip install -q --no-input numpy tensorflow opencv-python keras matplotlib streamlit tqdm
command -v ffmpeg || (apt-get update -y && apt-get install -y ffmpeg)
which ffmpeg


## 2. Define paths and hyperparameters

We collect every important configurable value in one place. Setting environment variables keeps the training code aligned with the repository scripts (e.g., `classify_dataset.py`).

* `USE_OPTICAL_FLOW`: toggle the two-stream model (RGB + optical flow).
* `HANDWASH_NUM_EPOCHS`, `HANDWASH_NUM_LAYERS`, etc.: the `HANDWASH_*` variables read by the training helpers.
* `DATA_ROOT`: where we download and preprocess the dataset (defaults to `/kaggle/working/edgewash_data`).

You can edit values in the dictionary before running the cell.
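
As a rough illustration of the hand-off, a downstream script could read the exported values back as sketched here. The helper name `read_handwash_setting`, its `env` parameter, and the fallback defaults are assumptions made for this example; the repository scripts may parse these variables differently.

```python
import os

def read_handwash_setting(name, default, cast=int, env=None):
    """Read a HANDWASH_* setting, falling back to a default when unset.

    Hypothetical helper for illustration; not part of the repository.
    """
    env = os.environ if env is None else env
    raw = env.get(name)
    if raw is None or raw == "":
        return default
    # Environment values are always strings, so numeric settings need a cast.
    return cast(raw)

# A plain dict stands in for os.environ so the demo has no side effects.
demo_env = {"HANDWASH_NUM_EPOCHS": "10", "HANDWASH_NN": "MobileNetV2"}
epochs_demo = read_handwash_setting("HANDWASH_NUM_EPOCHS", default=20, env=demo_env)
layers_demo = read_handwash_setting("HANDWASH_NUM_LAYERS", default=0, env=demo_env)
backbone_demo = read_handwash_setting("HANDWASH_NN", default="Xception", cast=str, env=demo_env)
print(epochs_demo, layers_demo, backbone_demo)
```

Storing everything as strings is why the export loop wraps each value in `str(...)` and the helpers convert back with `int(...)`.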

In [None]:
import os, json, pathlib

CONFIG = {
    "DATA_ROOT": "/kaggle/working/edgewash_data",
    "USE_OPTICAL_FLOW": False,  # set True to train the merged two-stream network
    "MODEL_VARIANT": "single_frame",  # options: single_frame, merged
    "HANDWASH_NN": "MobileNetV2",  # options: MobileNetV2, InceptionV3, Xception
    "HANDWASH_NUM_LAYERS": 0,
    "HANDWASH_NUM_EPOCHS": 10,
    "HANDWASH_NUM_FRAMES": 5,  # kept for compatibility with time-distributed models
    "HANDWASH_EXTRA_LAYERS": 0,
    "BATCH_SIZE": 32,
    "IMG_HEIGHT": 240,
    "IMG_WIDTH": 320,
}

# Export environment variables so downstream scripts pick them up
for key, value in CONFIG.items():
    if key.startswith("HANDWASH_"):
        os.environ[key] = str(value)

root = pathlib.Path(CONFIG["DATA_ROOT"])
root.mkdir(parents=True, exist_ok=True)
print(json.dumps(CONFIG, indent=2))
print("Environment variables applied.")


In [None]:
import gc, random
from pathlib import Path
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Layer

# GPU safety: enable memory growth if available
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(device, True)
    except Exception:
        pass

# Pull hyperparameters from CONFIG
batch_size = CONFIG["BATCH_SIZE"]
IMG_SIZE = (CONFIG["IMG_HEIGHT"], CONFIG["IMG_WIDTH"])
IMG_SHAPE = IMG_SIZE + (3,)
N_CLASSES = 7
model_name = CONFIG["HANDWASH_NN"]
num_trainable_layers = int(CONFIG["HANDWASH_NUM_LAYERS"])
num_epochs = int(CONFIG["HANDWASH_NUM_EPOCHS"])
num_frames = int(CONFIG["HANDWASH_NUM_FRAMES"])
num_extra_layers = int(CONFIG["HANDWASH_EXTRA_LAYERS"])
log_dir = os.getenv("HANDWASH_TENSORBOARD_LOGDIR", "")

# Data augmentation block reused across models
# (the non-experimental layer paths are available from TF 2.6 onward)
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.2),
])

def freeze_model(model):
    if num_trainable_layers == 0:
        for layer in model.layers:
            layer.trainable = False
        return False
    elif num_trainable_layers > 0:
        for layer in model.layers[:-num_trainable_layers]:
            layer.trainable = False
        for layer in model.layers[-num_trainable_layers:]:
            layer.trainable = True
        return True
    else:
        for layer in model.layers:
            layer.trainable = True
        return True

def get_preprocessing_function():
    if model_name == "MobileNetV2":
        return tf.keras.applications.mobilenet_v2.preprocess_input
    elif model_name == "InceptionV3":
        return tf.keras.applications.inception_v3.preprocess_input
    elif model_name == "Xception":
        return tf.keras.applications.xception.preprocess_input
    return None

class MobileNetPreprocessingLayer(Layer):
    def call(self, x):
        return (x / 127.5) - 1.0
    def compute_output_shape(self, input_shape):
        return input_shape

def get_default_model():
    if model_name == "MobileNetV2":
        base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
    elif model_name == "InceptionV3":
        base_model = tf.keras.applications.InceptionV3(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
    elif model_name == "Xception":
        base_model = tf.keras.applications.Xception(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
    else:
        raise ValueError(f"Unknown model name {model_name}")

    training = freeze_model(base_model)
    inputs = tf.keras.Input(shape=IMG_SHAPE)
    x = data_augmentation(inputs)
    x = get_preprocessing_function()(x)
    x = base_model(x, training=training)
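    # GlobalAveragePooling2D needs the 4D feature map, so it cannot follow Flatten:
    # pool when extra dense layers are requested, otherwise flatten.
    if num_extra_layers:
        x = tf.keras.layers.GlobalAveragePooling2D()(x)
    else:
        x = tf.keras.layers.Flatten()(x)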
    for _ in range(num_extra_layers):
        x = tf.keras.layers.Dense(128, activation='relu')(x)
        x = tf.keras.layers.Dropout(0.2)(x)
    outputs = tf.keras.layers.Dense(N_CLASSES, activation='softmax')(x)
    model = tf.keras.Model(inputs, outputs)
    model.summary()  # summary() prints directly and returns None
    return model

def get_merged_model():
    if model_name == "MobileNetV2":
        rgb_base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
        of_base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
    elif model_name == "InceptionV3":
        rgb_base_model = tf.keras.applications.InceptionV3(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
        of_base_model = tf.keras.applications.InceptionV3(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
    elif model_name == "Xception":
        rgb_base_model = tf.keras.applications.Xception(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
        of_base_model = tf.keras.applications.Xception(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
    else:
        raise ValueError(f"Unknown model name {model_name}")

    training = freeze_model(rgb_base_model)
    freeze_model(of_base_model)

    rgb_input = tf.keras.Input(shape=IMG_SHAPE)
    rgb_branch = data_augmentation(rgb_input)
    rgb_branch = get_preprocessing_function()(rgb_branch)
    rgb_branch = rgb_base_model(rgb_branch, training=training)
    rgb_branch = tf.keras.layers.Flatten()(rgb_branch)

    of_input = tf.keras.Input(shape=IMG_SHAPE)
    of_branch = data_augmentation(of_input)
    of_branch = get_preprocessing_function()(of_branch)
    of_branch = of_base_model(of_branch, training=training)
    of_branch = tf.keras.layers.Flatten()(of_branch)

    merged = tf.keras.layers.concatenate([rgb_branch, of_branch], axis=1)
    for _ in range(num_extra_layers):
        merged = tf.keras.layers.Dense(128, activation='relu')(merged)
        merged = tf.keras.layers.Dropout(0.2)(merged)
    merged = tf.keras.layers.Dense(N_CLASSES, activation='softmax')(merged)
    model = tf.keras.Model([rgb_input, of_input], merged)
    model.summary()  # summary() prints directly and returns None
    return model

def compute_class_weights(data_dir, class_names):
    # Weight each class by (mean frame count / class frame count) to offset imbalance
    counts = np.array(
        [len(list((Path(data_dir) / cls).glob('*.jpg'))) for cls in class_names],
        dtype=np.float64)
    if counts.size == 0:
        return {}
    avg = counts.mean()
    weights = avg / np.maximum(counts, 1.0)  # guard against empty class folders
    return {i: float(w) for i, w in enumerate(weights)}

def build_single_stream_datasets(trainval_dir, test_dir, batch_size=batch_size, seed=123):
    AUTOTUNE = tf.data.AUTOTUNE
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        trainval_dir, validation_split=0.2, subset="training", seed=seed,
        image_size=IMG_SIZE, label_mode='categorical', crop_to_aspect_ratio=False, batch_size=batch_size)
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        trainval_dir, validation_split=0.2, subset="validation", seed=seed,
        image_size=IMG_SIZE, label_mode='categorical', crop_to_aspect_ratio=False, batch_size=batch_size)
    test_ds = tf.keras.preprocessing.image_dataset_from_directory(
        test_dir, seed=seed, image_size=IMG_SIZE, label_mode='categorical',
        crop_to_aspect_ratio=False, batch_size=batch_size)
    weights_dict = compute_class_weights(Path(trainval_dir), train_ds.class_names)
    return (train_ds.prefetch(AUTOTUNE), val_ds.prefetch(AUTOTUNE), test_ds.prefetch(AUTOTUNE), weights_dict, train_ds.class_names)

def _split_indices(n, seed=123, val_frac=0.2):
    idxs = np.arange(n)
    rng = np.random.default_rng(seed)
    rng.shuffle(idxs)
    split = int(n * (1 - val_frac))
    return idxs[:split], idxs[split:]

def _load_image(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMG_SIZE)
    return img

def build_two_stream_datasets(rgb_root, of_root, batch_size=batch_size, seed=123):
    rgb_root = Path(rgb_root)
    of_root = Path(of_root)
    class_names = sorted([p.name for p in (of_root/"trainval").iterdir() if p.is_dir()])

    def gather_paths(split_root):
        rgb_paths, of_paths, labels = [], [], []
        for idx, cls in enumerate(class_names):
            of_files = sorted((split_root/cls).glob('*.jpg'))
            for of_path in of_files:
                of_paths.append(str(of_path))
                rgb_paths.append(str(Path(str(of_path)).as_posix().replace(str(of_root), str(rgb_root))))
                labels.append(idx)
        return np.array(rgb_paths), np.array(of_paths), np.array(labels, dtype=np.int32)

    rgb_trainval, of_trainval, labels_trainval = gather_paths(of_root/"trainval")
    rgb_test, of_test, labels_test = gather_paths(of_root/"test")

    train_idx, val_idx = _split_indices(len(labels_trainval), seed=seed)

    def make_dataset(rgb_list, of_list, label_list, shuffle=True):
        ds = tf.data.Dataset.from_tensor_slices((rgb_list, of_list, label_list))
        if shuffle:
            ds = ds.shuffle(buffer_size=len(label_list), seed=seed)
        def load_pair(rgb_path, of_path, label):
            rgb_img = _load_image(rgb_path)
            of_img = _load_image(of_path)
            return (rgb_img, of_img), tf.one_hot(label, depth=len(class_names))
        return ds.map(load_pair, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    train_ds = make_dataset(rgb_trainval[train_idx], of_trainval[train_idx], labels_trainval[train_idx])
    val_ds = make_dataset(rgb_trainval[val_idx], of_trainval[val_idx], labels_trainval[val_idx], shuffle=False)
    test_ds = make_dataset(rgb_test, of_test, labels_test, shuffle=False)

    weights_dict = compute_class_weights(of_root/"trainval", class_names)
    return train_ds, val_ds, test_ds, weights_dict, class_names

def fit_model(name, model, train_ds, val_ds, test_ds, weights_dict):
    es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)
    callbacks = [es]
    if len(log_dir):
        callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, profile_batch='5,10'))

    history = model.fit(train_ds, epochs=num_epochs, validation_data=val_ds, class_weight=weights_dict, callbacks=callbacks)
    model.save(name + "final-model")

    train_acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']

    plt.figure(figsize=(8, 8))
    plt.grid(True, axis="y")
    plt.subplot(2, 1, 1)
    plt.plot(train_acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.ylim([min(plt.ylim()), 1])
    plt.title('Training and Validation Accuracy')
    plt.savefig(f"accuracy-{name}.pdf", format="pdf")

    measure_performance("validation", name, model, val_ds)
    test_loss, test_accuracy = model.evaluate(test_ds)
    result_str = f'Test loss: {test_loss} accuracy: {test_accuracy}\n'
    print(result_str)
    with open(f"results-{name}.txt", "a+") as f:
        f.write(result_str)
    measure_performance("test", name, model, test_ds)


def measure_performance(ds_name, name, model, ds, num_classes=N_CLASSES):
    matrix = [[0] * num_classes for _ in range(num_classes)]
    y_predicted, y_true = [], []
    for images, labels in ds:
        preds = model.predict(images)
        for y_p, y_t in zip(preds, labels):
            y_predicted.append(int(np.argmax(y_p)))
            y_true.append(int(np.argmax(y_t)))
        gc.collect()
    for y_p, y_t in zip(y_predicted, y_true):
        matrix[y_t][y_p] += 1
    print("Confusion matrix:")
    for row in matrix:
        print(row)
    f1_scores = []
    for i in range(num_classes):
        # Rows of `matrix` are true labels, columns are predictions
        actual_total = sum(matrix[i])
        true_predictions = matrix[i][i]
        predicted_total = sum(matrix[j][i] for j in range(num_classes))
        precision = true_predictions / predicted_total if predicted_total else 0
        recall = true_predictions / actual_total if actual_total else 0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
        print(f"{i} precision={100*precision:.2f}% recall={100*recall:.2f}% f1={f1:.2f}")
        f1_scores.append(f1)
    summary = f"Average {ds_name} F1 score: {np.mean(f1_scores):.2f}\n"
    print(summary)
    with open(f"results-{name}.txt", "a+") as f:
        f.write(summary)


def evaluate(name, train_ds, val_ds, test_ds, weights_dict, model=None):
    if model is None:
        model = get_default_model()
    model.compile(optimizer='Adam', loss=tf.keras.losses.CategoricalCrossentropy(), metrics=['accuracy'])
    with open(f"results-{name}.txt", "a+") as f:
        pass
    fit_model(name, model, train_ds, val_ds, test_ds, weights_dict)
    return model


## 3. Download the Kaggle hand-wash subset

The repository ships a helper script (`dataset-kaggle/get-and-preprocess-dataset.sh`) that fetches a reorganized 7-class subset of the public Kaggle hand-wash dataset. The cell below mirrors that logic with inline Python so the notebook stays self-contained.

Artifacts created:
* `kaggle-dataset-6classes.tar` — downloaded archive
* `kaggle-dataset-6classes/` — raw videos sorted into 7 class folders

Run the cell once; it skips work if files already exist.

In [None]:
import pathlib, tarfile, urllib.request

data_root = pathlib.Path(CONFIG["DATA_ROOT"]).resolve()
raw_tar = data_root / "kaggle-dataset-6classes.tar"
raw_dir = data_root / "kaggle-dataset-6classes"

url = "https://github.com/atiselsts/data/raw/master/kaggle-dataset-6classes.tar"
if not raw_tar.exists():
    print("Downloading dataset archive...")
    urllib.request.urlretrieve(url, raw_tar)
else:
    print("Archive already present:", raw_tar)

if not raw_dir.exists():
    print("Extracting archive...")
    with tarfile.open(raw_tar, "r") as tar:
        tar.extractall(data_root)
else:
    print("Extracted directory already exists:", raw_dir)

print("Contents:", list(raw_dir.iterdir())[:3])


## 4. Frame extraction and train/validation/test split

The repository's `dataset-kaggle/separate-frames.py` script splits each class into `trainval` and `test` partitions (70/30) and extracts every video frame. We reuse the same logic here, saving both full videos and per-frame JPEGs.

Outputs (under `kaggle-dataset-6classes-preprocessed/`):
* `videos/trainval` and `videos/test` — original clips split by partition
* `frames/trainval` and `frames/test` — every decoded frame with a class label directory

If you already preprocessed once, the cell will skip the heavy work.
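
The split is drawn once per video, so all frames of a clip land in the same partition and the 70/30 ratio is only approximate. A minimal sketch of that draw, using made-up clip names and a hypothetical `assign_subset` helper:

```python
import random

def assign_subset(rng, test_fraction=0.3):
    """Draw one video's partition: 'test' with probability test_fraction."""
    return "test" if rng.random() < test_fraction else "trainval"

# Hypothetical clip names, only for the demonstration.
clips = sorted(f"clip_{i:03d}.mp4" for i in range(1000))

# Same seed and same visiting order give the same assignment.
rng = random.Random(123)
first_run = {clip: assign_subset(rng) for clip in clips}
rng = random.Random(123)
second_run = {clip: assign_subset(rng) for clip in clips}

test_share = sum(1 for s in first_run.values() if s == "test") / len(clips)
print(f"test share: {test_share:.2f}, reproducible: {first_run == second_run}")
```

Seeding alone does not pin the split: `os.listdir` returns entries in arbitrary order, so sorting the file names before drawing is what makes a rerun assign each clip to the same partition.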

In [None]:
import os, random, cv2, shutil, pathlib

random.seed(123)
input_dir = raw_dir
out_root = data_root / "kaggle-dataset-6classes-preprocessed"
videos_dir = out_root / "videos"
frames_dir = out_root / "frames"

if out_root.exists():
    print("Preprocessed data already exists at", out_root)
else:
    print("Creating frame and video splits...")
    for subset in ["trainval", "test"]:
        for base in [videos_dir, frames_dir]:
            for cls in range(7):
                (base / subset / str(cls)).mkdir(parents=True, exist_ok=True)

    for class_dir in sorted(os.listdir(input_dir)):
        src_cls_path = input_dir / class_dir
        if not src_cls_path.is_dir():
            continue
        for filename in os.listdir(src_cls_path):
            if not filename.endswith(".mp4"):
                continue
            subset = "test" if random.random() < 0.3 else "trainval"
            src = src_cls_path / filename
            video_target = videos_dir / subset / class_dir / filename
            shutil.copy2(src, video_target)

            cap = cv2.VideoCapture(str(src))
            success, frame = cap.read()
            frame_num = 0
            while success:
                frame_name = f"frame_{frame_num}_{os.path.splitext(filename)[0]}.jpg"
                frame_path = frames_dir / subset / class_dir / frame_name
                cv2.imwrite(str(frame_path), frame)
                success, frame = cap.read()
                frame_num += 1
            cap.release()
    print("Finished preprocessing!")

print("Trainval frame examples:", len(list((frames_dir/"trainval").glob("*/*.jpg"))))
print("Test frame examples:", len(list((frames_dir/"test").glob("*/*.jpg"))))


## 5. Build the training datasets

This cell turns the extracted frames into `tf.data` pipelines. The single-stream path reads frames with `image_dataset_from_directory` and holds out 20% of `trainval` for validation. The two-stream path pairs each optical-flow image with its RGB frame, so it needs the output of section 6 first.


In [None]:
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE
frames_trainval = frames_dir / "trainval"
frames_test = frames_dir / "test"

if CONFIG["MODEL_VARIANT"] == "merged" or CONFIG["USE_OPTICAL_FLOW"]:
    of_root = frames_dir.parent / "of"
    if not of_root.exists():
        raise FileNotFoundError("Optical flow not found. Run the optical flow cell with USE_OPTICAL_FLOW=True.")
    # Pass the frames root (not frames/trainval): RGB paths are derived by swapping the
    # optical-flow root for this root, and the subset folder is already part of each path.
    train_ds, val_ds, test_ds, weights, class_names = build_two_stream_datasets(frames_dir, of_root, batch_size=CONFIG["BATCH_SIZE"])
else:
    train_ds, val_ds, test_ds, weights, class_names = build_single_stream_datasets(str(frames_trainval), str(frames_test), batch_size=CONFIG["BATCH_SIZE"])

# The models apply augmentation and backbone preprocessing internally, so the datasets
# keep raw 0-255 pixel values; normalizing here as well would scale the inputs twice.
train_ds_norm, val_ds_norm, test_ds_norm = train_ds, val_ds, test_ds

print("Class names:", class_names)
print("Class weights:", weights)


## 6. (Optional) Compute optical flow for the merged two-stream network

Set `CONFIG['USE_OPTICAL_FLOW'] = True` if you want to train the RGB + optical flow model. This step computes Farnebäck dense optical flow between frames spaced by ~1/3 second (matching `calculate-optical-flow.py`).

Outputs live in `kaggle-dataset-6classes-preprocessed/of/trainval` and `/test` alongside the RGB frames.
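
To make the frame spacing concrete, the sketch below derives the frame step from a frame rate and lists which frame pairs get a flow image. It assumes clips at roughly 30 fps, which is what the default `frame_step=10` implies. The helper names `frame_step_for` and `flow_pair_indices` are hypothetical.

```python
def frame_step_for(fps, gap_seconds=1 / 3):
    """Number of frames between the two images of a flow pair (at least 1)."""
    return max(1, round(fps * gap_seconds))

def flow_pair_indices(num_frames, frame_step):
    """Index pairs (earlier, later) for which a flow image is computed."""
    return [(i, i + frame_step) for i in range(max(0, num_frames - frame_step))]

step_30fps = frame_step_for(30)
pairs_demo = flow_pair_indices(num_frames=12, frame_step=step_30fps)
print(step_30fps, pairs_demo)
```

A clip with `n` frames therefore yields `n - frame_step` flow images, and clips shorter than the step yield none. The last `frame_step` RGB frames of each clip have no optical-flow partner.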

In [None]:
import numpy as np

if CONFIG["USE_OPTICAL_FLOW"]:
    print("Optical flow enabled; proceed with extraction below.")
else:
    print("Optical flow disabled; skip to training.")


### Optical flow helper (embedded)

The repository script references a global `input_dir`; the helper below keeps everything scoped inside the notebook and mirrors the same Farnebäck settings.

In [None]:
import cv2 as cv
from tqdm import tqdm

def compute_optical_flow(dataset_root, frame_step=10):
    import pathlib, numpy as np
    videos_path = pathlib.Path(dataset_root)/"videos"
    output_root = pathlib.Path(dataset_root)/"of"
    output_root.mkdir(exist_ok=True)
    for subset in ["trainval", "test"]:
        subset_in = videos_path/subset
        subset_out = output_root/subset
        for cls_dir in subset_in.iterdir():
            if not cls_dir.is_dir():
                continue
            (subset_out/cls_dir.name).mkdir(parents=True, exist_ok=True)
            for video in tqdm(list(cls_dir.glob("*.mp4")), desc=f"{subset}-{cls_dir.name}"):
                cap = cv.VideoCapture(str(video))
                frames = []
                ret, frame = cap.read()
                while ret:
                    frames.append(cv.cvtColor(frame, cv.COLOR_BGR2GRAY))
                    ret, frame = cap.read()
                cap.release()
                for idx in range(len(frames)-frame_step):
                    flow = cv.calcOpticalFlowFarneback(frames[idx], frames[idx+frame_step], None, 0.5, 3, 15, 3, 5, 1.2, 0)
                    mag, ang = cv.cartToPolar(flow[...,0], flow[...,1])
                    # uint8 HSV image: for 8-bit input OpenCV expects H in [0, 180] and S, V in [0, 255]
                    mask = np.zeros((frames[idx].shape[0], frames[idx].shape[1], 3), dtype=np.uint8)
                    mask[...,0] = ang*180/np.pi/2
                    mask[...,1] = 255
                    mask[...,2] = cv.normalize(mag, None, 0, 255, cv.NORM_MINMAX)
                    rgb = cv.cvtColor(mask, cv.COLOR_HSV2BGR)
                    out_name = subset_out/cls_dir.name/f"frame_{idx}_{video.stem}.jpg"
                    cv.imwrite(str(out_name), rgb)
    print("Optical flow extraction complete at", output_root)

if CONFIG["USE_OPTICAL_FLOW"]:
    compute_optical_flow(out_root)


## 7. Train the classifier

We call the `evaluate` helper defined in the helper cell above, which mirrors `classify_dataset.evaluate` from the repository. The repository offers three architectures, each with its own training script; this notebook builds the first and third, selected through `CONFIG`:

* Single-frame baseline (the default, used below)
* Time-distributed GRU (`kaggle-classify-videos.py`; not built in this notebook)
* Two-stream merged network (`kaggle-classify-merged-network.py`, requires optical flow)

Training artifacts saved (`<name>` is `kaggle-single_frame` or `kaggle-merged`):
* `<name>final-model/`: SavedModel directory, e.g. `kaggle-single_framefinal-model/`
* `results-<name>.txt`: metrics and F1 scores
* `accuracy-<name>.pdf`: accuracy plot
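
The F1 scores in the results file come from a confusion matrix whose rows are true labels and whose columns are predictions. The sketch below shows how per-class precision, recall and F1 follow from such a matrix. It uses a made-up two-class matrix and a hypothetical `per_class_metrics` helper rather than the notebook's own function.

```python
def per_class_metrics(matrix):
    """Return (precision, recall, f1) per class for a square confusion matrix.

    matrix[t][p] counts samples with true label t that were predicted as p.
    """
    n = len(matrix)
    results = []
    for i in range(n):
        true_positives = matrix[i][i]
        actual_total = sum(matrix[i])                          # row sum
        predicted_total = sum(matrix[j][i] for j in range(n))  # column sum
        precision = true_positives / predicted_total if predicted_total else 0.0
        recall = true_positives / actual_total if actual_total else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        results.append((precision, recall, f1))
    return results

# Made-up counts: 10 samples per class, a few of them misclassified.
toy_matrix = [[8, 2],
              [1, 9]]
toy_metrics = per_class_metrics(toy_matrix)
for cls, (p, r, f1) in enumerate(toy_metrics):
    print(f"{cls} precision={100 * p:.2f}% recall={100 * r:.2f}% f1={f1:.2f}")
```

Precision divides by the column sum (everything predicted as the class), recall by the row sum (everything that truly belongs to it). Swapping the two leaves F1 unchanged but mislabels the printed values.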

In [None]:
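variant = "merged" if CONFIG["USE_OPTICAL_FLOW"] else CONFIG.get("MODEL_VARIANT", "single_frame")
run_name = "kaggle-" + variant

# The model builders read the global `model_name` as the backbone name,
# so make sure it holds the backbone while the model is built.
model_name = CONFIG["HANDWASH_NN"]
if variant == "merged":
    model = get_merged_model()
else:
    model = get_default_model()

trained_model = evaluate(run_name, train_ds_norm, val_ds_norm, test_ds_norm, weights_dict=weights, model=model)

# Later cells locate the saved artifacts through `model_name`
model_name = run_name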


## 8. Evaluate and visualize

After training, we can inspect the saved accuracy plot and load the metrics log. We also print a few sample predictions to verify the label ordering.
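
When checking label ordering, it helps to map the index of the largest softmax output back to a class name. A minimal sketch with made-up probabilities follows; `decode_prediction` is a hypothetical helper, and the class list assumes the seven folders `0` to `6` created during preprocessing.

```python
def decode_prediction(probabilities, class_names):
    """Return (class name, probability) for the highest-scoring class."""
    probabilities = list(probabilities)
    if len(probabilities) != len(class_names):
        raise ValueError("one probability per class is required")
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return class_names[best], probabilities[best]

# Class names are sorted as strings, so "10" would sort before "2";
# with the folders "0".."6" string order and numeric order agree.
class_names_demo = sorted(str(i) for i in range(7))
fake_probs = [0.05, 0.05, 0.6, 0.1, 0.1, 0.05, 0.05]  # made-up softmax output
label_demo, confidence_demo = decode_prediction(fake_probs, class_names_demo)
print(label_demo, confidence_demo)
```

The same index-to-name mapping applies to the one-hot labels, so a mismatch between the decoded prediction and the decoded label on clearly correct samples points to an ordering problem.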

In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import pathlib

results_file = pathlib.Path(f"results-{model_name}.txt")
print(results_file.read_text())

acc_plot = pathlib.Path(f"accuracy-{model_name}.pdf")
print("Accuracy plot saved to", acc_plot)

sample_images, sample_labels = next(iter(test_ds_norm.take(1)))
loaded_model = tf.keras.models.load_model(f"{model_name}final-model", custom_objects={"MobileNetPreprocessingLayer": MobileNetPreprocessingLayer})
preds = loaded_model.predict(sample_images)
print("Sample prediction distribution:", preds[0])
print("True label one-hot:", sample_labels[0].numpy())


## 9. Export for inference (SavedModel + TFLite)

Kaggle notebooks often deploy to mobile or lightweight environments. The following cell saves a TensorFlow Lite version compatible with the MobileNet preprocessing layer defined in `classify_dataset.py`.

In [None]:
import tensorflow as tf
from pathlib import Path

saved_dir = f"{model_name}final-model"
model = tf.keras.models.load_model(saved_dir, custom_objects={"MobileNetPreprocessingLayer": MobileNetPreprocessingLayer})
converter = tf.lite.TFLiteConverter.from_saved_model(saved_dir)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
Path(f"{model_name}.tflite").write_bytes(tflite_model)
print("TFLite model written to", f"{model_name}.tflite")


## 10. TensorBoard (optional)

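You can monitor training live by pointing `HANDWASH_TENSORBOARD_LOGDIR` at a log directory and launching TensorBoard inside the notebook. The helper cell reads that variable once into the global `log_dir`, so the block below also assigns `log_dir` directly; run it before the training cell. Uncomment the block to enable.

In [None]:
# import os, pathlib
# log_dir = str(pathlib.Path(CONFIG['DATA_ROOT']) / 'logs')  # global read by fit_model
# os.environ['HANDWASH_TENSORBOARD_LOGDIR'] = log_dir
# %load_ext tensorboard
# %tensorboard --logdir {log_dir}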
