# Handwash Training (Colab-friendly)

End-to-end pipeline in notebook form. Mirrors `scripts/run_colab_pipeline.sh` but lets you inspect and tweak each stage, preview augmentations, and monitor TensorBoard.

In [None]:
# 0) Dependencies (TensorFlow is preinstalled on Colab)
!pip install -q --no-cache-dir scikit-learn pandas numpy opencv-python-headless matplotlib seaborn tqdm requests nbformat

In [None]:
# 1) Paths, Drive, and imports
import os, sys, pathlib, json, shutil, time
from typing import List

from google.colab import drive, output

# Adjust these if you cloned elsewhere
PROJECT_ROOT = pathlib.Path('/content/edgeWash').resolve()
# DATA_DIR can be overridden through the DATA_DIR environment variable.
DATA_DIR = pathlib.Path(os.environ.get('DATA_DIR', '/content/handwash_data')).resolve()
RAW_DIR = DATA_DIR / 'raw'
PROCESSED_DIR = DATA_DIR / 'processed'
MODELS_DIR = PROJECT_ROOT / 'models'
CHECKPOINTS_DIR = PROJECT_ROOT / 'checkpoints'
LOGS_DIR = PROJECT_ROOT / 'logs'

# Mount Drive to save checkpoints/logs there (optional)
# NOTE: none of the directories above live under /content/drive; repoint
# MODELS_DIR / CHECKPOINTS_DIR / LOGS_DIR if the artifacts must persist on Drive.
drive.mount('/content/drive')

# Make training modules importable (in this kernel and in child processes).
TRAINING_DIR = str(PROJECT_ROOT / 'training')

# Re-running the cell must not pile up duplicate entries, but the training
# directory should still end up first so its modules win name lookups.
if TRAINING_DIR in sys.path:
    sys.path.remove(TRAINING_DIR)
sys.path.insert(0, TRAINING_DIR)

# Drop empty entries: an empty PYTHONPATH component (e.g. the trailing ':' left
# by naive concatenation with an unset variable) is interpreted as the current
# working directory by child interpreters.
_pythonpath_entries = [
    entry
    for entry in os.environ.get('PYTHONPATH', '').split(os.pathsep)
    if entry and entry != TRAINING_DIR
]
os.environ['PYTHONPATH'] = os.pathsep.join([TRAINING_DIR] + _pythonpath_entries)

import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import download_datasets
import preprocess_data
import train
import evaluate
from config import AUGMENT_MULTIPLIER, SEQUENCE_LENGTH, IMG_SIZE
from data_generators import create_frame_generators, create_sequence_generators

# Last expression of the cell: displays the resolved project and data roots.
PROJECT_ROOT, DATA_DIR

## 2) Run configuration
Tune the batch sizes to keep the GPU busy; increase them until you approach the GPU memory limit.

In [None]:
# Datasets to process and model architectures to train, in order.
DATASETS = ['kaggle', 'pskus', 'metc', 'synthetic_blender_rozakar']
MODELS = ['mobilenetv2', 'resnet50', 'efficientnetb0', 'lstm', 'gru', '3d_cnn']
EPOCHS = 20
# Despite the name, batch_for() applies this to every frame model
# (mobilenetv2, resnet50, efficientnetb0).
BATCH_MOBILENET = 128
# Applied by batch_for() to the sequence models (lstm, gru).
BATCH_SEQUENCE = 64
LEARNING_RATE = 1e-4
# NOTE(review): in the cells visible here this value is only recorded in
# `config` below; confirm that it actually reaches the training code.
AUGMENT_MULT = 5  # override if needed
# When True, the preprocessing cell is skipped and existing CSVs are reused.
USE_EXISTING_PROCESSED = False

# Ensure dirs
for directory in [DATA_DIR, RAW_DIR, PROCESSED_DIR, MODELS_DIR, CHECKPOINTS_DIR, LOGS_DIR]:
    directory.mkdir(parents=True, exist_ok=True)

# Snapshot of the effective settings, shown below for the run record.
config = {
    'datasets': DATASETS,
    'models': MODELS,
    'epochs': EPOCHS,
    'batch_mobilenet': BATCH_MOBILENET,
    'batch_sequence': BATCH_SEQUENCE,
    'augment_multiplier': AUGMENT_MULT,
    'lr': LEARNING_RATE,
    'data_dir': str(DATA_DIR)
}
# print() so the indentation is rendered; a bare string expression would be
# displayed as its repr with escaped "\n" sequences on a single line.
print(json.dumps(config, indent=2))

## 3) TensorBoard (start first)
Run the cell, then open the window below.

In [None]:
import subprocess
import os
import signal

# Reuse a server started by a previous run of this cell: launching a second one
# on the same port would fail and orphan the first process handle.
tb_proc = globals().get('tb_proc')
if tb_proc is None or tb_proc.poll() is not None:
    # The child process receives its own copy of the log file descriptor, so the
    # notebook-side handle can be closed as soon as the process is spawned.
    with open(LOGS_DIR / 'tensorboard.out', 'w') as tb_log:
        tb_proc = subprocess.Popen([
            'tensorboard', '--logdir', str(LOGS_DIR), '--host', '0.0.0.0', '--port', '6006', '--load_fast=false'
        ], stdout=tb_log, stderr=subprocess.STDOUT)
output.serve_kernel_port_as_window(6006)
print('TensorBoard PID', tb_proc.pid)

## 4) Download datasets (skips if already present)
Progress bars and warnings appear in the cell output.

In [None]:
# Lookup table: dataset key -> function that fetches that dataset.
downloaders = {
    'kaggle': download_datasets.download_kaggle_dataset,
    'pskus': download_datasets.download_pskus_dataset,
    'metc': download_datasets.download_metc_dataset,
    'synthetic_blender_rozakar': download_datasets.download_synthetic_blender_rozakar,
}

# Fetch every configured dataset; a falsy return value is reported but does
# not stop the remaining downloads.
for dataset_name in DATASETS:
    print(f"\n=== {dataset_name}: download ===")
    if downloaders[dataset_name]():
        continue
    print(f"WARNING: download failed for {dataset_name}; continue or fix before training.")

# Summarise what verify_datasets() reports.
print('\nVerification:')
verification_report = download_datasets.verify_datasets()
print(json.dumps(verification_report, indent=2))

## 5) Preprocess (per dataset)
This extracts frames/sequences and writes train/val/test CSVs. Set `USE_EXISTING_PROCESSED=True` to skip.

In [None]:
if USE_EXISTING_PROCESSED:
    print('Skipping preprocess: USE_EXISTING_PROCESSED=True')
else:
    # One preprocessing call per dataset, with only that dataset's flag enabled.
    for dataset_name in DATASETS:
        print(f"\n=== {dataset_name}: preprocess ===")
        summary = preprocess_data.preprocess_all_datasets(
            use_kaggle=(dataset_name == 'kaggle'),
            use_pskus=(dataset_name == 'pskus'),
            use_metc=(dataset_name == 'metc'),
            use_synthetic_blender_rozakar=(dataset_name == 'synthetic_blender_rozakar')
        )
        # Values are stringified so the summary is always JSON-serialisable.
        printable = {key: str(value) for key, value in summary.items()}
        print(json.dumps(printable, indent=2))

## 6) Augmentation preview
Shows one batch of augmented training frames produced by the on-the-fly augmentations (flip/rotate/zoom/shift/brightness/shadow); the unaugmented originals are not displayed.

In [None]:
train_csv = PROCESSED_DIR / 'train.csv'
if not train_csv.exists():
    raise FileNotFoundError('train.csv missing; run preprocessing first')

# Fixed-seed sample of 16 rows (with replacement, so small CSVs still work);
# the same frame is passed as train/val/test because only the first generator
# is used for the preview.
train_df = pd.read_csv(train_csv).sample(16, replace=True, random_state=42)
preview_gen, _val_gen, _test_gen = create_frame_generators(
    train_df, train_df, train_df, batch_size=8, augment_multiplier=2
)
batch_imgs, _labels = preview_gen[0]

# 2 x 4 grid: one panel per image of the first batch.
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for idx, ax in enumerate(axes.flat):
    ax.imshow(batch_imgs[idx])
    ax.axis('off')
fig.suptitle('Augmented samples (frame models)')
fig.tight_layout()
plt.show()

## 7) Train per dataset and model
The batch size is chosen per model type (see `batch_for`); final models, checkpoints, and logs are expected under `models/`, `checkpoints/`, and `logs/`.

In [None]:
def batch_for(model):
    """Return the training batch size for a model type (case-insensitive).

    Frame models (mobilenetv2, resnet50, efficientnetb0) use BATCH_MOBILENET,
    recurrent models (lstm, gru) use BATCH_SEQUENCE and the 3D CNN uses a
    fixed batch of 12.

    Raises:
        ValueError: if the lower-cased model name is not one of the above.
    """
    key = model.lower()
    if key == '3d_cnn':
        return 12
    if key in ('lstm', 'gru'):
        return BATCH_SEQUENCE
    if key in ('mobilenetv2', 'resnet50', 'efficientnetb0'):
        return BATCH_MOBILENET
    raise ValueError(key)

def _ensure_splits(dataset_name):
    """Make sure train.csv and val.csv exist before training on a dataset.

    The cleanup at the end of each dataset iteration removes the split CSVs,
    so they have to be regenerated for the next dataset. Regeneration uses the
    same single-dataset flags as the preprocessing cell and is only attempted
    when USE_EXISTING_PROCESSED is False.

    Returns:
        True if both CSVs exist once the function returns, False otherwise.
    """
    train_csv = PROCESSED_DIR / 'train.csv'
    val_csv = PROCESSED_DIR / 'val.csv'
    if train_csv.exists() and val_csv.exists():
        # NOTE(review): these are whatever the last preprocessing run wrote;
        # confirm they belong to `dataset_name` on the first iteration.
        return True
    if USE_EXISTING_PROCESSED:
        # The user asked to reuse existing files; nothing to regenerate from.
        return False
    print(f"{dataset_name}: split CSVs missing; re-running preprocessing")
    preprocess_data.preprocess_all_datasets(
        use_kaggle=dataset_name == 'kaggle',
        use_pskus=dataset_name == 'pskus',
        use_metc=dataset_name == 'metc',
        use_synthetic_blender_rozakar=dataset_name == 'synthetic_blender_rozakar'
    )
    return train_csv.exists() and val_csv.exists()


for name in DATASETS:
    print(f"\n=== {name}: training ===")
    try:
        splits_ready = _ensure_splits(name)
    except Exception as exc:
        print(f"{name}: preprocessing FAILED -> {type(exc).__name__}: {exc}")
        splits_ready = False
    if not splits_ready:
        # Say so once instead of letting every model fail on a missing file.
        print(f"{name}: SKIPPED -> train.csv/val.csv not available in {PROCESSED_DIR}")
        continue

    for model in MODELS:
        try:
            res = train.train_model(
                model_type=model,
                train_csv=PROCESSED_DIR / 'train.csv',
                val_csv=PROCESSED_DIR / 'val.csv',
                batch_size=batch_for(model),
                epochs=EPOCHS,
                learning_rate=LEARNING_RATE
            )
            best_epoch = int(res['best_epoch'])
            print(json.dumps({
                'model': model,
                # best_epoch is used as a 0-based index into the history below;
                # report it 1-based for readability.
                'best_epoch': best_epoch + 1,
                'best_val_acc': float(res['history']['val_accuracy'][best_epoch]),
                'final_model': str(res['final_model_path'])
            }, indent=2))
        except Exception as exc:
            # Best effort: one failing model must not abort the whole sweep,
            # but include the exception type so the cause is identifiable.
            print(f"{model}: FAILED -> {type(exc).__name__}: {exc}")

    # Optional cleanup to save space after each dataset
    if name != DATASETS[-1]:
        shutil.rmtree(RAW_DIR / name, ignore_errors=True)
        # Pre-existing splits cannot be regenerated when preprocessing is
        # disabled, so only delete them when they will be rebuilt.
        if not USE_EXISTING_PROCESSED:
            for f in ['train.csv', 'val.csv', 'frames.csv']:
                try:
                    os.remove(PROCESSED_DIR / f)
                except FileNotFoundError:
                    pass

## 8) Evaluate on test set (if still present)
Runs evaluation for any final models that were saved.

In [None]:
test_csv = PROCESSED_DIR / 'test.csv'
if not test_csv.exists():
    print('test.csv not found; skipping evaluation')
else:
    for model in MODELS:
        # Only models whose final checkpoint was saved can be evaluated.
        model_path = PROJECT_ROOT / 'models' / f'{model.lower()}_final.keras'
        if model_path.exists():
            try:
                res = evaluate.evaluate_model(
                    model_path=model_path,
                    test_csv=test_csv,
                    model_type=model,
                    batch_size=batch_for(model),
                    save_results=True
                )
                # Keep only plain scalar/string entries so the summary can be
                # dumped as JSON.
                scalar_metrics = {
                    key: value
                    for key, value in res.items()
                    if isinstance(value, (float, int, str))
                }
                print(model, json.dumps(scalar_metrics, indent=2))
            except Exception as exc:
                print(f"eval {model}: FAILED -> {exc}")

## 9) Stop TensorBoard when done

In [None]:
# Stops the TensorBoard process started earlier in this notebook (`tb_proc`).
try:
    # Going through the Popen handle (rather than os.kill on a raw PID) is a
    # no-op for a process that already exited and cannot hit a recycled PID.
    if tb_proc.poll() is None:
        tb_proc.terminate()  # SIGTERM on POSIX
        try:
            # wait() reaps the child so it does not linger as a zombie.
            tb_proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            tb_proc.kill()
            tb_proc.wait()
    print('TensorBoard stopped')
except Exception as exc:
    # Also reached when tb_proc was never created (NameError).
    print('TensorBoard stop error', exc)