In [None]:
#!/usr/bin/env python3
import os
import glob
import cv2
import numpy as np
import tensorflow as tf
from collections import deque
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.layers import TimeDistributed, GlobalAveragePooling2D, LSTM, Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, TensorBoard
import matplotlib.pyplot as plt
import datetime

# 1. PARAMETERS
# Root folder; the per-dataset / per-class video directories are joined onto it below.
BASE_PATH       = "/content/drive/MyDrive/drone_datasets"
# Number of consecutive frames that make up one clip (also the model's time dimension).
CLIP_LENGTH     = 8
# Frames between the starts of successive clips; equal to CLIP_LENGTH means no overlap.
CLIP_STRIDE     = 8
# Clips per batch for both the training and validation pipelines.
BATCH_SIZE      = 8
# Number of clips held in the tf.data shuffle buffer.
# NOTE(review): one clip is CLIP_LENGTH x 224 x 224 x 3 float32 (~4.8 MB), so a full
# buffer of 5000 clips needs about 24 GB of host RAM — confirm this is intended.
SHUFFLE_BUFFER  = 5000
# Number of clips taken from the shuffled stream for validation.
VAL_SIZE        = 3000
# Upper bound on training epochs; the EarlyStopping callback may stop sooner.
EPOCHS          = 300
STEPS_PER_EPOCH = 200   # adjust to roughly (total_train_clips // BATCH_SIZE)
# NOTE(review): 50 steps x BATCH_SIZE covers 400 clips, not all VAL_SIZE clips
# (VAL_SIZE // BATCH_SIZE would be 375).
VAL_STEPS       = 50    # adjust to (VAL_SIZE // BATCH_SIZE)

# 2. CLIP GENERATOR (RGB + preprocess)
def clip_generator(directory_path, label, clip_length=CLIP_LENGTH, clip_stride=CLIP_STRIDE):
    """Yield preprocessed fixed-length frame clips from every video under a directory.

    Files ending in ``.avi``, ``.mp4`` or ``.mov`` are searched for recursively
    below ``directory_path``. Every decoded frame is converted from BGR to RGB
    and resized to 224x224. A sliding window holding the most recent
    ``clip_length`` frames is emitted once it is full and then again every
    ``clip_stride`` frames.

    Args:
        directory_path: Root directory to search for video files.
        label: Value yielded unchanged alongside every clip.
        clip_length: Number of consecutive frames per clip (must be > 0).
        clip_stride: Frames between the starts of successive clips (must be > 0).

    Yields:
        ``(clip, label)`` where ``clip`` is a float32 array of shape
        ``(clip_length, 224, 224, 3)`` that has been passed through
        MobileNetV2's ``preprocess_input``.

    Raises:
        ValueError: If ``clip_length`` or ``clip_stride`` is not positive.
    """
    if clip_length <= 0 or clip_stride <= 0:
        raise ValueError(
            f"clip_length and clip_stride must be positive, "
            f"got clip_length={clip_length}, clip_stride={clip_stride}"
        )

    for ext in ('*.avi', '*.mp4', '*.mov'):
        pattern = os.path.join(directory_path, '**', ext)
        # glob returns paths in filesystem order; sort so that runs are reproducible.
        for video_path in sorted(glob.glob(pattern, recursive=True)):
            cap = cv2.VideoCapture(video_path)
            try:
                buffer = deque(maxlen=clip_length)
                frame_idx = 0
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        # End of stream, or the file could not be opened/decoded.
                        break
                    # convert to RGB & resize
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = cv2.resize(frame, (224, 224))
                    buffer.append(frame)
                    frame_idx += 1
                    # Emit when the window first fills, then every clip_stride frames.
                    if frame_idx >= clip_length and (frame_idx - clip_length) % clip_stride == 0:
                        clip = np.array(buffer, dtype='float32')
                        clip = preprocess_input(clip)
                        yield clip, label
            finally:
                # Runs even if the consumer stops iterating early or a frame
                # operation raises, so the capture handle is never leaked.
                cap.release()

# 3. DATASET CREATION
# Each corpus contributes one positive ("wave", label 1) and one negative
# ("not_wave", label 0) folder, listed corpus by corpus.
dirs_info = [
    (os.path.join(BASE_PATH, corpus, folder), class_id)
    for corpus in ("KTH", "HMDB51", "UAV-Gesture")
    for folder, class_id in (("wave", 1), ("not_wave", 0))
]

# Element layout shared by every per-directory dataset: one float32 clip of
# CLIP_LENGTH 224x224 RGB frames plus a scalar int32 label.
clip_signature = (
    tf.TensorSpec((CLIP_LENGTH, 224, 224, 3), tf.float32),
    tf.TensorSpec((), tf.int32),
)


def _dataset_for(video_dir, class_id):
    """Wrap ``clip_generator`` for one labelled directory in a tf.data.Dataset."""

    def _clips():
        # Zero-argument factory: tf.data calls it afresh for every iteration.
        return clip_generator(video_dir, class_id)

    return tf.data.Dataset.from_generator(_clips, output_signature=clip_signature)


datasets = [_dataset_for(video_dir, class_id) for video_dir, class_id in dirs_info]

# Interleave the per-directory streams, shuffle, then split into val / train.
#
# The shuffle feeding take()/skip() must produce the SAME order every time it
# is iterated. take() and skip() each iterate `shuffled` independently, so with
# reshuffle_each_iteration=True they would see different permutations: clips
# could land in both splits, and the validation set would change on every
# pass. A fixed seed with reshuffle_each_iteration=False keeps the first
# VAL_SIZE clips (validation) and the remainder (training) disjoint and stable.
mixed = tf.data.Dataset.sample_from_datasets(datasets, seed=42)
shuffled = mixed.shuffle(SHUFFLE_BUFFER, seed=42, reshuffle_each_iteration=False)

val_ds = shuffled.take(VAL_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Training is driven by steps_per_epoch, so Keras keeps pulling from a single
# iterator across epochs; repeat() stops the finite stream from running out.
# NOTE: the training order is now identical on every pass. If host memory
# allows, a train-only shuffle can be inserted after skip().
train_ds = (
    shuffled.skip(VAL_SIZE)
    .repeat()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# 4. MODEL DEFINITION
# ImageNet-pretrained MobileNetV2 without its classification head acts as the
# per-frame feature extractor.
cnn_base = MobileNetV2(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
)

# Fine-tune only the last 20 backbone layers; every earlier layer stays frozen.
first_trainable = len(cnn_base.layers) - 20
for position, layer in enumerate(cnn_base.layers):
    layer.trainable = position >= first_trainable

# Clip -> per-frame feature maps -> per-frame vectors -> LSTM summary -> probability.
inputs = Input(shape=(CLIP_LENGTH, 224, 224, 3))
frame_maps = TimeDistributed(cnn_base)(inputs)
frame_vectors = TimeDistributed(GlobalAveragePooling2D())(frame_maps)
clip_state = LSTM(64, dropout=0.3)(frame_vectors)
outputs = Dense(1, activation='sigmoid')(clip_state)

model = Model(inputs=inputs, outputs=outputs)

# Metrics reported during training and evaluation, in display order.
tracked_metrics = [
    'accuracy',
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tf.keras.metrics.AUC(name='auc'),
]
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='binary_crossentropy',
    metrics=tracked_metrics,
)

model.summary()

# 5. CALLBACKS & TENSORBOARD LOGGING
# One TensorBoard directory per run, named after the start timestamp.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join("logs/fit", run_stamp)
tb_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)

# All three training-control callbacks watch the validation loss.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
lr_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-7)
best_checkpoint = ModelCheckpoint('best_wave_model.h5', monitor='val_loss', save_best_only=True)

callbacks = [early_stop, lr_on_plateau, best_checkpoint, tb_callback]

# 6. TRAINING
# Training configuration gathered in one place and handed to fit() in one go.
fit_options = {
    "validation_data": val_ds,
    "epochs": EPOCHS,
    "steps_per_epoch": STEPS_PER_EPOCH,
    "validation_steps": VAL_STEPS,
    "callbacks": callbacks,
    "verbose": 1,
}
history = model.fit(train_ds, **fit_options)

# 7. EVALUATION
print("\nEvaluating on validation set...")
# Ask Keras for a {metric name: value} mapping rather than zipping
# model.metrics_names with a positional list: on newer Keras releases
# metrics_names no longer lists each compiled metric individually, so the zip
# would mislabel or silently drop values.
named_results = model.evaluate(val_ds, steps=VAL_STEPS, verbose=1, return_dict=True)
# Keep `results` as a plain list of metric values, as before.
results = list(named_results.values())
print("Validation metrics:")
for name, value in named_results.items():
    print(f"  {name}: {value:.4f}")

# 8. SAVE MODEL
# Persist the final weights and architecture in a single HDF5 file.
final_model_path = 'wave_sequence_model_final.h5'
model.save(final_model_path)
print(f"Final model saved as {final_model_path}")
