In [None]:
# Dataset & collection plan for early gesture prediction (Rock/Paper/Scissors)
# ----------------------------------------------------------
# GOAL: Predict gesture class ASAP before completion.
# KEY IDEA: Train on progressive prefixes of gesture videos so model learns
# to classify from partial motion context, not only full final pose.
#
# 1. RECORDING PROTOCOL
#    - Camera: laptop/webcam, frontal view, consistent background if possible.
#    - Resolution: 640x480 or 720p (higher adds cost; downsample later).
#    - FPS: 15–30 (choose 20 for balance).
#    - Clip length: capture from neutral hand start -> final pose hold (~1.0–1.5s).
#      Total frames per clip ≈ 20–30 at 20 FPS.
#    - Classes: rock, paper, scissors (folder/class-based or filename-based labels).
#    - Variations: different users, lighting, hand orientations (side, frontal),
#      distances (near/far), left/right hand.
#
# 2. DATA VOLUME (minimum viable):
#    - Per class: 300–500 clips for baseline; scaling to 1000+ improves robustness.
#    - Total frames processed after prefix expansion rises (see below).
#
# 3. STORAGE / VM TRANSFER CONCERNS
#    - Raw video size grows quickly; prefer:
#        a. Compress: H.264 MP4, ~1–2 Mbps.
#        b. Downscale resolution early (e.g. ffmpeg scale=320:240 if quality sufficient).
#        c. Extract frames on HOST, transfer frames (.jpg) instead of video if VM IO limited.
#        d. Optionally store only first N frames (e.g. first 24) for each clip.
#    - Approx size estimate:
#        500 clips/class * 3 classes = 1500 clips.
#        Each clip 1.5s at 20 FPS -> 30 frames -> 1500*30 = 45k frames.
#        JPEG 20–40KB each ≈ 0.9–1.8GB worst-case (optimize with lower quality).
#
# 4. TRAINING STRATEGY FOR EARLY PREDICTION
#    - From each full clip produce multiple PREFIX sequences length T where
#      T in {4,6,8,10,...,MAX_FRAMES}. (Curriculum / multi-prefix sampling.)
#    - Weight earlier timesteps heavier (already implemented time_weights).
#    - OPTION: Binary “decision made” flag — stop predicting early once the confidence threshold is crossed.
#
# 5. FEED FULL VIDEOS OR CUT?
#    - For early classification we do NOT need entire finished gesture every time.
#    - We SHOULD keep the full clip to generate prefixes dynamically.
#    - During training sample a random prefix length L (e.g. uniform 4..SEQ_LEN_TARGET),
#      pad if needed, pass to model; model outputs per timestep logits.
#
# 6. LABELS FOR PREFIXES
#    - Entire prefix labeled with final gesture class (static label).
#    - Optional: add a "none" class for frames before motion begins (requires start frame annotation).
#    - If start-of-motion annotation unavailable, assume clip trimmed to begin near motion onset.
#
# 7. IMPLEMENTATION OUTLINE:
#    - VideoEarlyGestureDataset: loads video, extracts frames, caches optionally.
#    - Returns a sequence tensor (T,C,H,W) with T == MAX_SEQ_LEN (shorter prefixes are padded by repeating the last frame).
#    - On __getitem__, randomly selects prefix length L >= MIN_PREFIX.
#
# 8. OPTIONAL OPTIMIZATION:
#    - Pre-extract frames to disk: dataset_root/class_name/clip_id/frame_%03d.jpg
#    - Speeds up I/O vs decoding every epoch.
#
# 9. REAL-TIME INFERENCE:
#    - Maintain a sliding buffer of last K frames.
#    - After each new frame, run model, check confidence threshold.
#    - Reset buffer when prediction stabilized or after fixed timeout.

# ----------------------------------------------------------
# CODE: Utilities for frame extraction (offline), dataset class, and loader.

import os
import glob
import random
import math
import cv2
from typing import List, Tuple, Optional
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Shared constants: reuse the existing definitions, or redefine them here when running standalone
CLASS_MAP = {'rock': 0, 'paper': 1, 'scissors': 2}                # gesture name -> class index
INV_CLASS_MAP = {idx: name for name, idx in CLASS_MAP.items()}    # class index -> gesture name
NUM_CLASSES = len(CLASS_MAP)

FRAME_SIZE = 128          # Side length (pixels) used when resizing frames
MAX_SEQ_LEN = 16          # Upper bound of frames fed to model (increase if gesture longer)
MIN_PREFIX_LEN = 4        # Smallest prefix used in training
AUGMENT = True            # Enable per-frame augmentation for robustness
CACHE_FRAMES = True       # If True, pre-extract frames to folders (not read by the code in this section)
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Transforms
_FRAME_HW = (FRAME_SIZE, FRAME_SIZE)  # target (height, width) shared by both pipelines

# Augmenting pipeline: resize, random colour jitter / rotation / horizontal flip, then tensor.
base_video_transform = transforms.Compose([
    transforms.Resize(_FRAME_HW),
    transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.05),
    transforms.RandomRotation(degrees=15),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Deterministic pipeline: resize and tensor conversion only.
eval_video_transform = transforms.Compose([
    transforms.Resize(_FRAME_HW),
    transforms.ToTensor(),
])

def infer_label_from_filename(name: str) -> Optional[int]:
    """Return the class index of the first CLASS_MAP key contained in *name*.

    The match is a case-insensitive substring test, tried in CLASS_MAP's
    iteration order. Returns None when no class name occurs in *name*.
    """
    lowered = name.lower()
    return next((idx for key, idx in CLASS_MAP.items() if key in lowered), None)

def extract_frames_from_video(path: str, max_frames: int = MAX_SEQ_LEN, every_n: int = 1) -> List[Image.Image]:
    """Decode a video and return up to ``max_frames`` PIL RGB frames.

    Args:
        path: Video file path handed to ``cv2.VideoCapture``.
        max_frames: Maximum number of frames to return.
        every_n: Keep one frame out of every ``every_n`` decoded frames,
            starting with the first decoded frame.

    Returns:
        The kept frames in decode order. The list is empty when the video
        cannot be opened or no frame could be read.

    Raises:
        ValueError: If ``every_n`` is smaller than 1.
    """
    if every_n < 1:
        # Previously 0 surfaced as a ZeroDivisionError from the modulo below.
        raise ValueError(f"every_n must be >= 1, got {every_n}")
    cap = cv2.VideoCapture(path)
    frames: List[Image.Image] = []
    try:
        if not cap.isOpened():
            return frames
        frame_idx = 0
        while len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break  # end of stream or decode failure
            if frame_idx % every_n == 0:
                # OpenCV yields BGR arrays; convert so the PIL image is RGB.
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(rgb))
            frame_idx += 1
    finally:
        # Release the capture even if conversion raises mid-way.
        cap.release()
    return frames

def preextract_video_frames(video_root: str, output_root: str, pattern: str = "*.mp4"):
    """Pre-extract frames of each video to output_root/class/clip_basename/frame_###.jpg.

    Args:
        video_root: Directory whose direct children matching ``pattern`` are processed.
        output_root: Destination root; created if missing.
        pattern: Glob pattern (relative to ``video_root``) selecting the videos.

    Videos whose filename contains no known class name are skipped, as are
    clips whose output folder already holds at least one ``frame_*.jpg``.
    Up to ``MAX_SEQ_LEN * 2`` frames are written per clip.
    """
    os.makedirs(output_root, exist_ok=True)
    # Sorted so the processing order does not depend on filesystem enumeration.
    for vp in sorted(glob.glob(os.path.join(video_root, pattern))):
        file_name = os.path.basename(vp)
        label = infer_label_from_filename(file_name)
        if label is None:
            continue
        cls_name = INV_CLASS_MAP[label]
        out_dir = os.path.join(output_root, cls_name, os.path.splitext(file_name)[0])
        if glob.glob(os.path.join(out_dir, "frame_*.jpg")):
            continue  # already extracted
        frames = extract_frames_from_video(vp, max_frames=MAX_SEQ_LEN * 2)  # extract extra for prefix variation
        if not frames:
            # Nothing decoded: do not create the clip folder, otherwise an empty
            # directory would later be indexed as a cached clip with no frames.
            continue
        os.makedirs(out_dir, exist_ok=True)
        for i, img in enumerate(frames):
            img.save(os.path.join(out_dir, f"frame_{i:03d}.jpg"))

class VideoEarlyGestureDataset(Dataset):
    """
    Loads gesture clips for early (prefix-based) gesture classification.

    Two sample sources are indexed:
      * raw videos located directly inside ``video_root``; the label is
        inferred from the filename via ``infer_label_from_filename``:
            video_root/
              rock_clip1.mp4
              paper_clip7.avi
      * pre-extracted frame folders under ``cache_frames_root`` (when given
        and existing), one sub-directory per class name:
            cache_frames_root/
              rock/rock_clip1/frame_000.jpg ...
    A raw video is skipped when a cached folder with the same class and the
    same name (filename without extension) was indexed, so a clip present in
    both roots is not counted twice.

    ``__getitem__`` returns ``(seq_tensor, label)`` where ``seq_tensor`` has
    shape (max_seq_len, C, H, W). In training mode a random prefix length is
    drawn between ``min_prefix_len`` and ``max_seq_len`` (both capped by the
    number of available frames); otherwise the longest available prefix is
    used. Prefixes shorter than ``max_seq_len`` are padded by repeating their
    last frame.

    Raises:
        ValueError: If ``max_seq_len`` is smaller than 1.
        RuntimeError: If no sample is found at construction time, or if a
            sample yields no frames when it is loaded.
    """
    def __init__(self,
                 video_root: str,
                 max_seq_len: int = MAX_SEQ_LEN,
                 min_prefix_len: int = MIN_PREFIX_LEN,
                 cache_frames_root: Optional[str] = None,
                 training: bool = True,
                 augment: bool = AUGMENT):
        if max_seq_len < 1:
            raise ValueError(f"max_seq_len must be >= 1, got {max_seq_len}")
        self.video_root = video_root
        self.cache_root = cache_frames_root
        self.max_seq_len = max_seq_len
        self.min_prefix_len = min_prefix_len
        self.training = training
        # Augmentation is only applied in training mode.
        self.transform = base_video_transform if (training and augment) else eval_video_transform
        self.samples: List[Tuple[str, int, bool]] = []  # (path,label,is_cached)
        cached_clips = set()  # (label, clip_name) pairs already covered by the cache
        if cache_frames_root and os.path.isdir(cache_frames_root):
            # cached mode: iterate class directories
            for cls_name, cls_idx in CLASS_MAP.items():
                cls_dir = os.path.join(cache_frames_root, cls_name)
                if not os.path.isdir(cls_dir):
                    continue
                for cd in sorted(glob.glob(os.path.join(cls_dir, "*"))):
                    if not os.path.isdir(cd):
                        continue
                    if not glob.glob(os.path.join(cd, "frame_*.jpg")):
                        continue  # empty clip folder: would fail at load time
                    self.samples.append((cd, cls_idx, True))
                    cached_clips.add((cls_idx, os.path.basename(cd)))
        # raw video mode (only clips that are not already available from the cache)
        for ext in ("*.mp4", "*.avi", "*.mov", "*.mkv"):
            for vp in sorted(glob.glob(os.path.join(video_root, ext))):
                file_name = os.path.basename(vp)
                label = infer_label_from_filename(file_name)
                if label is None:
                    continue
                if (label, os.path.splitext(file_name)[0]) in cached_clips:
                    continue
                self.samples.append((vp, label, False))
        if not self.samples:
            raise RuntimeError("No video samples found.")

    def __len__(self):
        return len(self.samples)

    def _load_cached_frames(self, folder: str) -> List[Image.Image]:
        """Load up to 2*max_seq_len cached frames from *folder* as RGB images.

        Files are taken in sorted filename order. Frames that cannot be read
        are skipped (best effort); reading more files than max_seq_len leaves
        spare frames to fill in for skipped ones.
        """
        frame_files = sorted(glob.glob(os.path.join(folder, "frame_*.jpg")))
        imgs = []
        for fp in frame_files[:self.max_seq_len * 2]:  # allow extra
            try:
                # convert() returns a new in-memory image, so the file can be closed.
                with Image.open(fp) as im:
                    imgs.append(im.convert("RGB"))
            except (OSError, ValueError):
                continue  # unreadable / corrupt frame
        return imgs

    def __getitem__(self, idx: int):
        path, label, is_cached = self.samples[idx]
        if is_cached:
            frames = self._load_cached_frames(path)
        else:
            # Only the first max_seq_len frames can ever be selected below,
            # so decoding more than that would be wasted work.
            frames = extract_frames_from_video(path, max_frames=self.max_seq_len)
        if not frames:
            raise RuntimeError(f"Failed to decode frames for {path}")
        # Choose prefix length
        longest = min(self.max_seq_len, len(frames))
        if self.training:
            # Clamp the lower bound so clips shorter than min_prefix_len do not
            # produce an empty randint range, and never pick an empty prefix.
            shortest = max(1, min(self.min_prefix_len, longest))
            prefix_len = random.randint(shortest, longest)
        else:
            prefix_len = longest
        selected = frames[:prefix_len]
        # Pad to max_seq_len by repeating the last frame of the prefix.
        selected.extend([selected[-1]] * (self.max_seq_len - prefix_len))
        # NOTE(review): the transform is invoked once per frame, so random
        # augmentation parameters (flip, rotation, jitter) are presumably drawn
        # independently for each frame, padded copies included - confirm that
        # this per-frame behaviour is intended.
        seq = torch.stack([self.transform(img) for img in selected], dim=0)  # (T,C,H,W)
        return seq, label

# Example: building training / validation loaders (adjust paths)
# video_train_root = "videos/train"
# video_val_root   = "videos/val"
# Optionally pre-extract frames:
# preextract_video_frames(video_train_root, "frames_cache/train")
# preextract_video_frames(video_val_root,   "frames_cache/val")

def build_video_loaders(video_train_root: str,
                        video_val_root: str,
                        batch_size: int = 16,
                        cache_train: Optional[str] = None,
                        cache_val: Optional[str] = None,
                        num_workers: int = 0):
    """Build the training and validation DataLoaders for gesture clips.

    Args:
        video_train_root: Directory with the raw training videos.
        video_val_root: Directory with the raw validation videos.
        batch_size: Batch size used by both loaders.
        cache_train: Optional root of pre-extracted training frames.
        cache_val: Optional root of pre-extracted validation frames.
        num_workers: Worker process count passed to both DataLoaders
            (0, the default, loads in the calling process).

    Returns:
        ``(train_loader, val_loader)``. The training dataset is built with
        ``training=True`` and shuffled; the validation dataset is built with
        ``training=False`` and iterated in order.
    """
    train_ds = VideoEarlyGestureDataset(video_train_root, cache_frames_root=cache_train, training=True)
    val_ds = VideoEarlyGestureDataset(video_val_root, cache_frames_root=cache_val, training=False)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader

# EARLY METRICS EVALUATION FOR VIDEO PREFIXES
def evaluate_video_prefixes(model, loader, threshold: float = 0.6, device: torch.device = DEVICE):
    """Compute early-prediction metrics from per-timestep class probabilities.

    Args:
        model: Called as ``model(seq)``; its output is softmaxed over the last
            dimension and unpacked as (B, T, C). The model is switched to
            eval mode and left in that mode.
        loader: Iterable yielding ``(seq, label)`` batches of tensors.
            (assumes seq is (B, T, C, H, W) as produced by the dataset above -
            TODO confirm against the model's expected input)
        threshold: Minimum top-class probability for a timestep to count as confident.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A dict with, over all evaluated clips:
          'final_acc': fraction whose last-timestep prediction is correct.
          'early_cover': fraction with at least one correctly predicted timestep.
          'confident_cover': fraction with at least one confident timestep
              (correct or not).
          'confident_correct_rate': fraction with at least one timestep that
              is both correct and confident (denominator is all clips).
          'avg_first_correct_t': mean 1-based index of the first correct
              timestep, over clips that have one; NaN if there is none.
          'avg_first_confident_correct_t': same for the first timestep that
              is both correct and confident.
    """
    model.eval()
    correct_final = 0
    correct_early = 0
    confident_total = 0
    correct_confident = 0
    first_correct_ts = []
    first_confident_correct_ts = []
    total = 0
    with torch.no_grad():
        for seq, label in loader:
            seq = seq.to(device)
            label = label.to(device)
            logits_time = model(seq)  # (B,T,C)
            probs_time = torch.softmax(logits_time, dim=-1)
            batch_size, _, _ = probs_time.shape
            final_preds = probs_time[:, -1].argmax(dim=-1)
            correct_final += (final_preds == label).sum().item()
            # Move predictions and confidences to the host once per batch rather
            # than calling .item() for every (clip, timestep) pair.
            preds = probs_time.argmax(dim=-1).tolist()           # B lists of T ints
            confs = probs_time.max(dim=-1).values.tolist()       # B lists of T floats
            labels = label.tolist()
            for gt, clip_preds, clip_confs in zip(labels, preds, confs):
                earliest_correct = next(
                    (t for t, pred in enumerate(clip_preds) if pred == gt), None)
                earliest_confident_correct = next(
                    (t for t, (pred, conf) in enumerate(zip(clip_preds, clip_confs))
                     if conf >= threshold and pred == gt),
                    None)
                if earliest_correct is not None:
                    correct_early += 1
                    first_correct_ts.append(earliest_correct + 1)  # 1-based timestep
                if any(conf >= threshold for conf in clip_confs):
                    confident_total += 1
                if earliest_confident_correct is not None:
                    correct_confident += 1
                    first_confident_correct_ts.append(earliest_confident_correct + 1)
            total += batch_size
    denom = max(1, total)  # avoid division by zero on an empty loader
    avg_first_correct = (sum(first_correct_ts) / len(first_correct_ts)
                         if first_correct_ts else math.nan)
    avg_first_conf_correct = (sum(first_confident_correct_ts) / len(first_confident_correct_ts)
                              if first_confident_correct_ts else math.nan)
    return {
        'final_acc': correct_final / denom,
        'early_cover': correct_early / denom,
        'confident_cover': confident_total / denom,
        'confident_correct_rate': correct_confident / denom,
        'avg_first_correct_t': avg_first_correct,
        'avg_first_confident_correct_t': avg_first_conf_correct
    }

# REAL-TIME BUFFERED VIDEO INFERENCE (simulate with a video file)
def stream_video_early(model, video_path: str, threshold: float = 0.6,
                       max_buffer: int = MAX_SEQ_LEN, device: torch.device = DEVICE):
    """Simulate real-time early prediction by streaming a video file frame by frame.

    Each decoded frame is transformed with ``eval_video_transform`` and pushed
    into a sliding window holding the last ``max_buffer`` frames. Once the
    window holds at least ``MIN_PREFIX_LEN`` frames, the model is run on the
    whole window after every new frame, and the first window position whose
    top-class probability reaches ``threshold`` is printed.

    Args:
        model: Called as ``model(seq)`` with seq of shape (1, T, C, H, W); its
            output is indexed as (1, T, num_classes). Switched to eval mode.
        video_path: Path of the video to stream.
        threshold: Minimum top-class probability required to report a prediction.
        max_buffer: Maximum number of most recent frames kept in the window.
        device: Device the window tensor is moved to before the forward pass.

    Returns:
        None. Results are printed; a message is printed if the video cannot
        be opened. The reported ``timestep`` is the 1-based position inside
        the current window, not the absolute frame number.
    """
    from collections import deque  # local import: not among the module-level imports

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print("Failed to open:", video_path)
            return
        # deque(maxlen=...) drops the oldest frame automatically when full.
        buffer_frames = deque(maxlen=max(0, max_buffer))
        model.eval()
        frame_count = 0
        with torch.no_grad():
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_count += 1
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR
                buffer_frames.append(eval_video_transform(Image.fromarray(rgb)))  # (C,H,W)
                if len(buffer_frames) < MIN_PREFIX_LEN:
                    continue  # not enough context yet
                seq = torch.stack(tuple(buffer_frames), dim=0).unsqueeze(0).to(device)  # (1,T,C,H,W)
                logits_time = model(seq)  # (1,T,C)
                probs_time = torch.softmax(logits_time, dim=-1)[0]
                # Check earliest confident position within the window
                confs = probs_time.max(dim=-1).values.tolist()
                hit = next((t for t, conf in enumerate(confs) if conf >= threshold), None)
                if hit is not None:
                    cls_idx = int(probs_time[hit].argmax().item())
                    print(f"[frame {frame_count}] EARLY PRED timestep={hit + 1} "
                          f"class={INV_CLASS_MAP[cls_idx]} conf={confs[hit]:.2f}")
    finally:
        # Release the capture even if the transform or the model raises.
        cap.release()

# EXAMPLE USAGE (uncomment and set paths):
# train_loader, val_loader = build_video_loaders("videos/train", "videos/val",
#                                                batch_size=16,
#                                                cache_train="frames_cache/train",
#                                                cache_val="frames_cache/val")
# metrics = evaluate_video_prefixes(model_early, val_loader)
# print(metrics)
# stream_video_early(model_early, "videos/val/rock_clip23.mp4")

# ----------------------------------------------------------
# NOTES SUMMARY (in-code):
# - Keep full clips; generate prefixes dynamically for training early predictions.
# - More variation (users/backgrounds) improves generalization.
# - Data transfer to VM: compress, pre-extract frames, keep only needed frames.
# - Extend MAX_SEQ_LEN if gestures need more temporal context; adjust weighting strategy.
# - Consider adding a "neutral" class if initial frames truly have no gesture (requires annotation).
# ----------------------------------------------------------