In [None]:
import os
import glob
import cv2
import numpy as np
import tensorflow as tf
from collections import deque
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Input, TimeDistributed, LSTM
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, LambdaCallback
import matplotlib.pyplot as plt

# 3. Parameters & Paths
# Root directory containing the KTH/, HMDB51/ and UAV-Gesture/ subfolders.
base_path = "/content/drive/MyDrive/drone_datasets"

# Frames per clip, and frames between the starts of consecutive clips
# (equal values => non-overlapping clips).
clip_length, clip_stride = 8, 8

batch_size = 4

# Shuffle-buffer size and number of (unbatched) clips held out for validation.
shuffle_buffer, validation_size = 3500, 3000

# 4. Clip generator (streaming)
def clip_generator(directory_path, label, clip_length=8, clip_stride=8, frame_size=(224, 224)):
    """Stream fixed-length, preprocessed RGB clips from every video under a directory.

    Files matching ``*.avi``, ``*.mp4`` or ``*.mov`` are found recursively and
    visited in sorted order. Each frame is converted BGR -> RGB, resized to
    ``frame_size`` and pushed into a sliding window. Once the window holds
    ``clip_length`` frames, a clip is emitted, and then another every
    ``clip_stride`` frames.

    Args:
        directory_path: Root directory searched recursively for videos.
        label: Value yielded unchanged alongside every clip.
        clip_length: Number of frames per clip (>= 1).
        clip_stride: Frames between the starts of consecutive clips (>= 1);
            ``clip_stride == clip_length`` gives non-overlapping clips.
        frame_size: ``(width, height)`` passed to ``cv2.resize``.

    Yields:
        ``(clip, label)``. ``clip`` is a float32 array of shape
        ``(clip_length, height, width, 3)`` passed through MobileNetV2's
        ``preprocess_input``.

    Raises:
        ValueError: On first iteration, if ``clip_length`` or ``clip_stride``
            is smaller than 1.
    """
    if clip_length < 1 or clip_stride < 1:
        raise ValueError(
            f"clip_length and clip_stride must be >= 1, got {clip_length} and {clip_stride}"
        )

    patterns = [os.path.join(directory_path, '**', ext) for ext in ['*.avi', '*.mp4', '*.mov']]
    for pattern in patterns:
        # Sorted so the clip order (and any seeded sampling/shuffling
        # downstream) does not depend on the filesystem's listing order.
        for video_path in sorted(glob.glob(pattern, recursive=True)):
            print(f"Processing: {video_path}")
            cap = cv2.VideoCapture(video_path)
            # try/finally: if the consumer stops early, the generator is closed
            # with GeneratorExit at the ``yield``. Without this, the capture
            # would never be released.
            try:
                if not cap.isOpened():
                    print(f"Could not open: {video_path}")
                    continue
                buffer = deque(maxlen=clip_length)  # sliding window of recent frames
                frame_idx = 0
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = cv2.resize(frame, frame_size)
                    buffer.append(frame)
                    frame_idx += 1
                    # Emit once the window is full, then every clip_stride frames.
                    if frame_idx >= clip_length and (frame_idx - clip_length) % clip_stride == 0:
                        clip = np.array(buffer, dtype='float32')
                        yield preprocess_input(clip), label
            finally:
                cap.release()

# 5. Weighted clip generator (streams with sample weight)
def weighted_clip_generator(directory_path, label, weight, clip_length=8, clip_stride=8):
    """Wrap ``clip_generator`` and attach a constant weight to every clip.

    Yields ``(clip, label, weight)`` triples. The same ``weight`` accompanies
    every clip from ``directory_path`` and becomes the third element of the
    tf.data tuples built from this generator.
    """
    source = clip_generator(directory_path, label, clip_length, clip_stride)
    yield from ((frames, tag, weight) for frames, tag in source)

# 6. Build weighted tf.data.Datasets for each source
# Each entry is (directory, label, sample weight): label 1 = "wave",
# 0 = "not_wave". UAV-Gesture clips are weighted 5x relative to KTH/HMDB51.
dirs_info = [
    (os.path.join(base_path, "KTH",       "wave"),      1, 1.0),
    (os.path.join(base_path, "KTH",       "not_wave"),  0, 1.0),
    (os.path.join(base_path, "HMDB51",    "wave"),      1, 1.0),
    (os.path.join(base_path, "HMDB51",    "not_wave"),  0, 1.0),
    (os.path.join(base_path, "UAV-Gesture","wave"),     1, 5.0),
    (os.path.join(base_path, "UAV-Gesture","not_wave"), 0, 5.0),
]

# glob() on a missing directory silently matches nothing. A typo or an
# unmounted Drive would otherwise just produce empty sources with no warning.
missing_dirs = [p for p, _, _ in dirs_info if not os.path.isdir(p)]
for missing in missing_dirs:
    print(f"WARNING: directory not found, this source will be empty: {missing}")
if len(missing_dirs) == len(dirs_info):
    raise FileNotFoundError(f"None of the dataset directories exist under {base_path!r}")

datasets = []
for path, label, weight in dirs_info:
    ds = tf.data.Dataset.from_generator(
        # Default arguments bind this iteration's values; a plain closure
        # would see only the last (path, label, weight) of the loop.
        lambda p=path, l=label, w=weight: weighted_clip_generator(p, l, w, clip_length, clip_stride),
        output_signature=(
            tf.TensorSpec((clip_length, 224, 224, 3), tf.float32),  # clip
            tf.TensorSpec((), tf.int32),                            # label
            tf.TensorSpec((), tf.float32),                          # sample weight
        )
    )
    datasets.append(ds)

# 7. Interleave and a one-time, seeded shuffle
mixed_ds = tf.data.Dataset.sample_from_datasets(datasets, seed=42)
# Both splits below are views of this one stream, and every pass over them
# (each training epoch, each validation pass, evaluate()) re-runs it from the
# start. The default shuffle(reshuffle_each_iteration=True) with no seed
# produces a different permutation on each pass, so take()/skip() would cut
# different clips each time and validation clips would leak into training.
# A fixed seed with reshuffle_each_iteration=False keeps the order identical
# across passes (given the upstream generators are deterministic).
dshuffled = mixed_ds.shuffle(shuffle_buffer, seed=42, reshuffle_each_iteration=False)

# 8. Split into validation and training
# NOTE(review): this is a clip-level split. Clips cut from the same video are
# near-duplicates (and overlap when clip_stride < clip_length), so one video
# can still contribute clips to both sets. A video-level split would avoid
# that. Also note skip() decodes and discards validation_size clips on every
# training epoch.
val_ds   = dshuffled.take(validation_size)
train_raw = dshuffled.skip(validation_size)

# 9. Per-epoch shuffle + batch + prefetch
# Only the training split is reshuffled between epochs.
train_ds = (
    train_raw
    .shuffle(shuffle_buffer, reshuffle_each_iteration=True)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# 10. Model definition
# A frozen ImageNet MobileNetV2 is applied to every frame and pooled to one
# feature vector per frame. An LSTM aggregates those over the clip, and a
# sigmoid unit outputs the probability of label 1 ("wave").
cnn_base = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
cnn_base.trainable = False  # only the LSTM and Dense head receive weight updates

inputs = Input((clip_length, 224, 224, 3))
x = TimeDistributed(cnn_base)(inputs)             # same CNN applied to each frame
x = TimeDistributed(GlobalAveragePooling2D())(x)  # one feature vector per frame
x = LSTM(64, dropout=0.3)(x)
outputs = Dense(1, activation='sigmoid')(x)

model_seq = Model(inputs, outputs)
# Compiled in the default graph mode. Eagerly executing the TimeDistributed
# MobileNetV2 on every step is far slower. Pass run_eagerly=True only when
# stepping through train_step with a debugger.
model_seq.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[
        tf.keras.metrics.BinaryAccuracy(name='accuracy'),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.AUC(name='auc')
    ],
)

model_seq.summary()

# 11. Training with callbacks
def print_callback(epoch, logs):
    """Print a one-line, 1-based epoch summary of training and validation logs.

    Validation entries missing from *logs* are printed as 0.
    """
    train_part = f"loss={logs['loss']:.4f}, acc={logs['accuracy']:.4f}, "
    val_part = (
        f"val_loss={logs.get('val_loss', 0):.4f}, "
        f"val_acc={logs.get('val_accuracy', 0):.4f}"
    )
    print(f"Epoch {epoch + 1}: " + train_part + val_part)


# NOTE: with epochs=1, EarlyStopping's patience=3 can never trigger; it only
# matters once epochs is raised.
early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
print_cb = LambdaCallback(on_epoch_end=print_callback)

history = model_seq.fit(
    train_ds,
    validation_data=val_ds,
    epochs=1,
    verbose=1,
    callbacks=[early_stop, print_cb],
)

# 12. Plotting results
# Epochs are numbered from 1 on the x-axis. Markers keep short runs visible:
# with a single epoch each series is one point, and a line without markers
# draws nothing for a single point.
epochs_run = range(1, len(history.history['loss']) + 1)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_run, history.history['loss'], marker='o', label='train_loss')
plt.plot(epochs_run, history.history['val_loss'], marker='o', label='val_loss')
plt.title('Loss'); plt.xlabel('Epoch'); plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs_run, history.history['accuracy'], marker='o', label='train_acc')
plt.plot(epochs_run, history.history['val_accuracy'], marker='o', label='val_acc')
plt.title('Accuracy'); plt.xlabel('Epoch'); plt.legend()
plt.show()

# 13. Final evaluation & save
# return_dict=True pairs every value with its metric name. Zipping the list
# form with model.metrics_names can mislabel or drop entries; under Keras 3,
# metrics_names may be just ['loss', 'compile_metrics'].
eval_metrics = model_seq.evaluate(val_ds, return_dict=True)
results = list(eval_metrics.values())  # list form kept for any later cells
for name, val in eval_metrics.items():
    print(f"{name}: {val:.4f}")

model_save_path = 'wave_sequence_model_one_epoch.h5'
model_seq.save(model_save_path)
print(f"Model saved as {model_save_path}")
#