In [None]:
# Complete Autism Gesture Detection System
# Goal: Detect autism-related gestures (ArmFlapping, HeadBanging, Spinning) from any length video
# Approach: 3 binary classifiers + ensemble for final decision
# Cell 1: Environment Setup and Library Imports

import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import glob
import random
import json
from PIL import Image
import torch.nn.functional as F
from collections import Counter
import warnings

warnings.filterwarnings('ignore')

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Using device: {device}")

# Reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Configuration
CONFIG = {
    'DATA_FOLDER': '/home/samir/projects/clipping_ssbd_videos/ssbd_clip_segment/',
    'SEQUENCE_LENGTH': 16,  # Reduced for better performance
    'IMG_SIZE': (224, 224),
    'BATCH_SIZE': 8,
    'LEARNING_RATE': 0.0001,
    'EPOCHS': 40,
    'PATIENCE': 10,  # Early stopping
    'GESTURE_NAMES': ['ArmFlapping', 'HeadBanging', 'Spinning']
}

print("✅ Environment setup complete!")

🚀 Using device: cpu
✅ Environment setup complete!


In [46]:
# Cell 2: Advanced Video Dataset with Negative Sample Generation
class AutismGestureDataset(Dataset):
    def __init__(self, video_paths, labels, gesture_type=None, transform=None, 
                 sequence_length=16, img_size=(224, 224), mode='train'):
        """
        Dataset for binary classification of specific autism gesture
        Args:
            video_paths: List of video paths
            labels: Binary labels (1: has gesture, 0: no gesture)
            gesture_type: Specific gesture to detect (0: ArmFlapping, 1: HeadBanging, 2: Spinning)
            transform: Image transformations
            sequence_length: Number of frames to extract
            img_size: Target image size
            mode: 'train', 'val', or 'test'
        """
        self.video_paths = video_paths
        self.labels = labels
        self.gesture_type = gesture_type
        self.transform = transform
        self.sequence_length = sequence_length
        self.img_size = img_size
        self.mode = mode
        
    def __len__(self):
        return len(self.video_paths)
    
    def extract_keyframes(self, video_path):
        """Extract key frames using optical flow and motion analysis"""
        cap = cv2.VideoCapture(video_path)
        frames = []
        motion_scores = []
        
        # Read all frames first
        all_frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_resized = cv2.resize(frame_rgb, self.img_size)
            all_frames.append(frame_resized)
        cap.release()
        
        if len(all_frames) == 0:
            return None
        
        # Calculate motion between consecutive frames
        for i in range(1, len(all_frames)):
            frame1_gray = cv2.cvtColor(all_frames[i-1], cv2.COLOR_RGB2GRAY)
            frame2_gray = cv2.cvtColor(all_frames[i], cv2.COLOR_RGB2GRAY)
            
            # Calculate optical flow
            try:
                flow = cv2.calcOpticalFlowPyrLK(
                    frame1_gray, frame2_gray, 
                    np.array([[x, y] for x in range(0, frame1_gray.shape[1], 20) 
                             for y in range(0, frame1_gray.shape[0], 20)], dtype=np.float32).reshape(-1, 1, 2),
                    None
                )[0]
                # Calculate motion magnitude
                motion = np.mean(np.sqrt(np.sum(flow**2, axis=2))) if flow is not None else 0
            except:
                motion = 0
            motion_scores.append(motion)
        
        # Select frames based on motion and uniform sampling
        if len(all_frames) <= self.sequence_length:
            # For short videos: repeat frames
            selected_frames = all_frames.copy()
            while len(selected_frames) < self.sequence_length:
                selected_frames.extend(all_frames[:min(len(all_frames), 
                                                     self.sequence_length - len(selected_frames))])
            frames = selected_frames[:self.sequence_length]
        else:
            # For longer videos: smart sampling
            if self.mode == 'train' and random.random() < 0.3:
                # Random sampling for augmentation
                indices = sorted(random.sample(range(len(all_frames)), self.sequence_length))
            else:
                # Motion-based + uniform sampling
                if len(motion_scores) > 0:
                    high_motion_indices = np.argsort(motion_scores)[-self.sequence_length//2:]
                    uniform_indices = np.linspace(0, len(all_frames)-1, self.sequence_length//2, dtype=int)
                    combined_indices = sorted(set(list(high_motion_indices) + list(uniform_indices)))
                    if len(combined_indices) >= self.sequence_length:
                        indices = combined_indices[:self.sequence_length]
                    else:
                        # Fill remaining with uniform sampling
                        remaining = self.sequence_length - len(combined_indices)
                        extra_indices = np.linspace(0, len(all_frames)-1, remaining, dtype=int)
                        indices = sorted(set(list(combined_indices) + list(extra_indices)))[:self.sequence_length]
                else:
                    # Fallback to uniform sampling
                    indices = np.linspace(0, len(all_frames)-1, self.sequence_length, dtype=int)
            frames = [all_frames[i] for i in indices]
        
        return frames
    
    def __getitem__(self, idx):
        video_path = self.video_paths[idx]
        label = self.labels[idx]
        frames = self.extract_keyframes(video_path)
        
        if frames is None:
            # Return dummy data if video loading fails
            frames = [np.zeros((*self.img_size, 3), dtype=np.uint8) for _ in range(self.sequence_length)]
        
        # Apply transforms
        # Apply transforms
        transformed_frames = []
        for frame in frames:
            frame_pil = Image.fromarray(frame.astype('uint8'))
            transformed_frame = self.transform(frame_pil)
            transformed_frames.append(transformed_frame)

        # Ensure fixed-length sequence
        while len(transformed_frames) < self.sequence_length:
            transformed_frames.append(transformed_frames[-1])  # repeat last frame

        transformed_frames = transformed_frames[:self.sequence_length]  # truncate if too long

        video_tensor = torch.stack(transformed_frames)
        
        return video_tensor, torch.tensor(label, dtype=torch.long)

In [47]:
# Cell 3: Binary Gesture Classifier with Transfer Learning
class BinaryGestureClassifier(nn.Module):
    def __init__(self, sequence_length=16, dropout=0.3):
        """
        Binary classifier for specific autism gesture using transfer learning
        """
        super(BinaryGestureClassifier, self).__init__()
        
        # Pre-trained ResNet18 as feature extractor
        self.backbone = models.resnet18(pretrained=True)
        self.backbone.fc = nn.Identity()  # Remove final layer
        
        # Freeze early layers, fine-tune later layers
        for param in list(self.backbone.parameters())[:-10]:
            param.requires_grad = False
            
        self.feature_dim = 512
        self.sequence_length = sequence_length
        
        # Temporal modeling - simplified LSTM
        self.temporal = nn.LSTM(
            input_size=self.feature_dim,
            hidden_size=128,
            num_layers=1,
            batch_first=True,
            dropout=dropout,
            bidirectional=True
        )
        
        # Attention mechanism for important frame selection
        self.attention = nn.Sequential(
            nn.Linear(256, 64),  # 256 from bidirectional LSTM
            nn.Tanh(),
            nn.Linear(64, 1)
        )
        
        # Final classifier
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2)  # Binary classification
        )
        
    def forward(self, x):
        batch_size, seq_len, C, H, W = x.shape
        
        # Extract features for each frame
        x = x.view(batch_size * seq_len, C, H, W)
        features = self.backbone(x)  # (batch_size * seq_len, 512)
        features = features.view(batch_size, seq_len, self.feature_dim)
        
        # Temporal modeling
        lstm_out, _ = self.temporal(features)  # (batch_size, seq_len, 256)
        
        # Attention mechanism
        attention_weights = self.attention(lstm_out)  # (batch_size, seq_len, 1)
        attention_weights = torch.softmax(attention_weights, dim=1)
        
        # Weighted average of LSTM outputs
        attended_features = torch.sum(lstm_out * attention_weights, dim=1)  # (batch_size, 256)
        
        # Classification
        output = self.classifier(attended_features)
        return output

In [48]:
# Cell 4: Data Loading and Preprocessing Functions
def load_and_prepare_data(data_folder):
    """
    Load dataset and prepare for binary classification training
    """
    gesture_names = CONFIG['GESTURE_NAMES']
    
    # Load original data
    video_paths = []
    original_labels = []
    
    for idx, gesture_name in enumerate(gesture_names):
        gesture_folder = os.path.join(data_folder, gesture_name)
        if not os.path.exists(gesture_folder):
            print(f"❌ Warning: {gesture_folder} not found!")
            continue
        videos = glob.glob(os.path.join(gesture_folder, '*.avi'))
        print(f"📁 {gesture_name}: {len(videos)} videos")
        video_paths.extend(videos)
        original_labels.extend([idx] * len(videos))
        
    print(f"📊 Total videos: {len(video_paths)}")
    return video_paths, original_labels, gesture_names

def create_binary_datasets(video_paths, original_labels, target_gesture_idx):
    """
    Create binary dataset for specific gesture detection
    Args:
        video_paths: List of video paths
        original_labels: Original multi-class labels
        target_gesture_idx: Index of target gesture (0, 1, or 2)
    Returns:
        Binary labeled dataset
    """
    binary_labels = [1 if label == target_gesture_idx else 0 for label in original_labels]
    
    # Count positive and negative samples
    pos_count = sum(binary_labels)
    neg_count = len(binary_labels) - pos_count
    
    print(f"🎯 Gesture {CONFIG['GESTURE_NAMES'][target_gesture_idx]}:")
    print(f"   ✅ Positive samples: {pos_count}")
    print(f"   ❌ Negative samples: {neg_count}")
    
    return video_paths, binary_labels

def create_data_loaders(video_paths, labels, target_gesture_idx, test_size=0.2, val_size=0.15):
    """
    Create data loaders with proper stratification and augmentation
    """
    # Data augmentation transforms
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(degrees=5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    val_test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Stratified split
    X_temp, X_test, y_temp, y_test = train_test_split(
        video_paths, labels, test_size=test_size, stratify=labels, random_state=42
    )
    
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_size, stratify=y_temp, random_state=42
    )
    
    print(f"📈 Split sizes - Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")
    
    # Create datasets
    train_dataset = AutismGestureDataset(
        X_train, y_train, target_gesture_idx, train_transform, 
        CONFIG['SEQUENCE_LENGTH'], CONFIG['IMG_SIZE'], 'train'
    )
    
    val_dataset = AutismGestureDataset(
        X_val, y_val, target_gesture_idx, val_test_transform,
        CONFIG['SEQUENCE_LENGTH'], CONFIG['IMG_SIZE'], 'val'
    )
    
    test_dataset = AutismGestureDataset(
        X_test, y_test, target_gesture_idx, val_test_transform,
        CONFIG['SEQUENCE_LENGTH'], CONFIG['IMG_SIZE'], 'test'
    )
    
    # Weighted sampling for imbalanced classes
    class_counts = Counter(y_train)
    class_weights = {cls: 1.0/count for cls, count in class_counts.items()}
    sample_weights = [class_weights[label] for label in y_train]
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['BATCH_SIZE'], 
                             sampler=sampler, num_workers=0, pin_memory=True)
    
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['BATCH_SIZE'], 
                           shuffle=False, num_workers=0, pin_memory=True)
    
    test_loader = DataLoader(test_dataset, batch_size=CONFIG['BATCH_SIZE'], 
                            shuffle=False, num_workers=0, pin_memory=True)
    
    return train_loader, val_loader, test_loader

In [49]:
# Cell 5: Training Function with Early Stopping and Advanced Metrics
def train_binary_classifier(model, train_loader, val_loader, gesture_name, epochs=40):
    """
    Train binary classifier with advanced monitoring
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=CONFIG['LEARNING_RATE'], weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, 
                                                    patience=5, )
    
    # Tracking variables
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    train_aucs, val_aucs = [], []
    best_val_auc = 0.0
    patience_counter = 0
    best_model_state = None
    
    print(f"🚀 Training {gesture_name} classifier...")
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_preds, train_targets = [], []
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]')
        
        for videos, labels in train_pbar:
            videos, labels = videos.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(videos)
            loss = criterion(outputs, labels)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            train_loss += loss.item()
            train_preds.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().detach().numpy())
            train_targets.extend(labels.cpu().numpy())
            
            # Update progress bar
            current_acc = accuracy_score(train_targets, np.array(train_preds) > 0.5)
            train_pbar.set_postfix({'Loss': f'{loss.item():.4f}', 'Acc': f'{current_acc:.3f}'})
            
        # Calculate training metrics
        train_acc = accuracy_score(train_targets, np.array(train_preds) > 0.5)
        train_auc = roc_auc_score(train_targets, train_preds)
        avg_train_loss = train_loss / len(train_loader)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_preds, val_targets = [], []
        
        with torch.no_grad():
            for videos, labels in val_loader:
                videos, labels = videos.to(device), labels.to(device)
                outputs = model(videos)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_preds.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())
                val_targets.extend(labels.cpu().numpy())
                
        # Calculate validation metrics
        val_acc = accuracy_score(val_targets, np.array(val_preds) > 0.5)
        val_auc = roc_auc_score(val_targets, val_preds)
        avg_val_loss = val_loss / len(val_loader)
        
        # Store metrics
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        train_aucs.append(train_auc)
        val_aucs.append(val_auc)
        
        # Print epoch results
        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'  Train - Loss: {avg_train_loss:.4f}, Acc: {train_acc:.3f}, AUC: {train_auc:.3f}')
        print(f'  Val   - Loss: {avg_val_loss:.4f}, Acc: {val_acc:.3f}, AUC: {val_auc:.3f}')
        
        # Learning rate scheduling
        scheduler.step(val_auc)
        
        # Early stopping and best model saving
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_model_state = model.state_dict().copy()
            patience_counter = 0
            print(f'  🎉 New best AUC: {best_val_auc:.3f}')
        else:
            patience_counter += 1
            
        if patience_counter >= CONFIG['PATIENCE']:
            print(f'⏹️ Early stopping at epoch {epoch+1}')
            break
            
        print('-' * 60)
        
    # Load best model
    if best_model_state:
        model.load_state_dict(best_model_state)
        print(f'✅ Best model loaded with AUC: {best_val_auc:.3f}')
        
    return model, {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'train_aucs': train_aucs,
        'val_aucs': val_aucs,
        'best_val_auc': best_val_auc
    }

In [50]:
# Cell 6: Model Evaluation and Visualization Functions
def evaluate_binary_classifier(model, test_loader, gesture_name):
    """
    Comprehensive evaluation of binary classifier
    """
    model.eval()
    predictions = []
    probabilities = []
    true_labels = []
    
    with torch.no_grad():
        for videos, labels in tqdm(test_loader, desc=f'Testing {gesture_name}'):
            videos, labels = videos.to(device), labels.to(device)
            outputs = model(videos)
            probs = torch.softmax(outputs, dim=1)
            predictions.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            probabilities.extend(probs[:, 1].cpu().numpy())  # Probability of positive class
            true_labels.extend(labels.cpu().numpy())
    
    # Calculate metrics
    accuracy = accuracy_score(true_labels, predictions)
    auc_score = roc_auc_score(true_labels, probabilities)
    
    print(f"\n📊 {gesture_name} Classifier Results:")
    print(f"  🎯 Accuracy: {accuracy:.3f}")
    print(f"  📈 AUC Score: {auc_score:.3f}")
    
    # Classification report
    print(f"\n📋 Classification Report:")
    print(classification_report(true_labels, predictions, 
                              target_names=['No Gesture', gesture_name]))
    
    return {
        'accuracy': accuracy,
        'auc_score': auc_score,
        'predictions': predictions,
        'probabilities': probabilities,
        'true_labels': true_labels
    }

def plot_training_history(history, gesture_name):
    """
    Plot comprehensive training history
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'{gesture_name} Training History', fontsize=16)
    
    # Loss plot
    axes[0, 0].plot(history['train_losses'], label='Train Loss', color='blue')
    axes[0, 0].plot(history['val_losses'], label='Val Loss', color='red')
    axes[0, 0].set_title('Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True)
    
    # Accuracy plot
    axes[0, 1].plot(history['train_accs'], label='Train Acc', color='blue')
    axes[0, 1].plot(history['val_accs'], label='Val Acc', color='red')
    axes[0, 1].set_title('Accuracy')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True)
    
    # AUC plot
    axes[1, 0].plot(history['train_aucs'], label='Train AUC', color='blue')
    axes[1, 0].plot(history['val_aucs'], label='Val AUC', color='red')
    axes[1, 0].set_title('AUC Score')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('AUC')
    axes[1, 0].legend()
    axes[1, 0].grid(True)
    
    # ROC Curve (placeholder for test data)
    axes[1, 1].text(0.5, 0.5, 'ROC Curve\n(Run evaluation first)', 
                   ha='center', va='center', transform=axes[1, 1].transAxes)
    axes[1, 1].set_title('ROC Curve')
    
    plt.tight_layout()
    plt.show()

def plot_roc_curve(results, gesture_name):
    """
    Plot ROC curve for binary classifier
    """
    fpr, tpr, _ = roc_curve(results['true_labels'], results['probabilities'])
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, linewidth=2, label=f'{gesture_name} (AUC = {results["auc_score"]:.3f})')
    plt.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curve - {gesture_name}')
    plt.legend()
    plt.grid(True)
    plt.show()

In [51]:
# Cell 7: Main Training Pipeline
def main_training_pipeline():
    """
    Main pipeline to train all three binary classifiers
    """
    print("🚀 Starting Autism Gesture Detection Training Pipeline")
    print("=" * 60)
    
    # Load data
    print("📂 Loading dataset...")
    video_paths, original_labels, gesture_names = load_and_prepare_data(CONFIG['DATA_FOLDER'])
    
    # Store trained models and results
    trained_models = {}
    training_histories = {}
    evaluation_results = {}
    
    # Train binary classifier for each gesture
    for gesture_idx, gesture_name in enumerate(gesture_names):
        print(f"\n🎯 Training Binary Classifier for: {gesture_name}")
        print("=" * 50)
        
        # Prepare binary dataset
        binary_video_paths, binary_labels = create_binary_datasets(
            video_paths, original_labels, gesture_idx
        )
        
        # Create data loaders
        train_loader, val_loader, test_loader = create_data_loaders(
            binary_video_paths, binary_labels, gesture_idx
        )
        
        # Initialize model
        model = BinaryGestureClassifier(
            sequence_length=CONFIG['SEQUENCE_LENGTH']
        ).to(device)
        
        print(f"🧠 Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        
        # Train model
        trained_model, history = train_binary_classifier(
            model, train_loader, val_loader, gesture_name, CONFIG['EPOCHS']
        )
        
        # Evaluate model
        results = evaluate_binary_classifier(trained_model, test_loader, gesture_name)
        
        # Store results
        trained_models[gesture_name] = trained_model
        training_histories[gesture_name] = history
        evaluation_results[gesture_name] = results
        
        # Plot training history
        plot_training_history(history, gesture_name)
        plot_roc_curve(results, gesture_name)
        
        # Save individual model
        model_path = f'binary_classifier_{gesture_name.lower()}.pth'
        torch.save({
            'model_state_dict': trained_model.state_dict(),
            'model_config': {
                'sequence_length': CONFIG['SEQUENCE_LENGTH'],
                'gesture_name': gesture_name,
                'gesture_idx': gesture_idx
            },
            'training_history': history,
            'evaluation_results': results
        }, model_path)
        
        print(f"💾 Model saved: {model_path}")
        print(f"✅ {gesture_name} classifier training complete!")
        print("-" * 50)
        
    return trained_models, training_histories, evaluation_results

In [None]:
# Cell 8: Ensemble Prediction System for Any Length Video
class AutismGestureEnsemble:
    def __init__(self, models_dict, gesture_names, threshold=0.5):
        """
        Ensemble system combining all three binary classifiers
        Args:
            models_dict: Dictionary of trained models {gesture_name: model}
            gesture_names: List of gesture names
            threshold: Confidence threshold for positive detection
        """
        self.models = models_dict
        self.gesture_names = gesture_names
        self.threshold = threshold
        
        # Set all models to evaluation mode
        for model in self.models.values():
            model.eval()
            
    def predict_single_video(self, video_path, return_details=False):
        """
        Predict autism gestures for a single video of any length
        Args:
            video_path: Path to video file
            return_details: If True, return detailed probabilities
        Returns:
            prediction: Dict with detected gestures and confidences
        """
        # Preprocessing transform
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        # Create temporary dataset for single video
        temp_dataset = AutismGestureDataset(
            [video_path], [0], transform=transform,
            sequence_length=CONFIG['SEQUENCE_LENGTH'],
            img_size=CONFIG['IMG_SIZE'], mode='test'
        )
        
        # Get video tensor
        video_tensor, _ = temp_dataset[0]
        video_tensor = video_tensor.unsqueeze(0).to(device)  # Add batch dimension
        
        results = {}
        detected_gestures = []
        
        # Test each binary classifier
        with torch.no_grad():
            for gesture_name, model in self.models.items():
                outputs = model(video_tensor)
                probabilities = torch.softmax(outputs, dim=1)
                confidence = probabilities[0, 1].item()  # Probability of positive class
                
                results[gesture_name] = {
                    'confidence': confidence,
                    'detected': confidence > self.threshold
                }
                
                if confidence > self.threshold:
                    detected_gestures.append({
                        'gesture': gesture_name,
                        'confidence': confidence
                    })
                    
        # Sort by confidence
        detected_gestures.sort(key=lambda x: x['confidence'], reverse=True)
        
        # Overall prediction
        prediction = {
            'has_autism_gesture': len(detected_gestures) > 0,
            'detected_gestures': detected_gestures,
            'primary_gesture': detected_gestures[0]['gesture'] if detected_gestures else None,
            'max_confidence': detected_gestures[0]['confidence'] if detected_gestures else 0.0,
            'all_confidences': results
        }
        
        if return_details:
            prediction['video_path'] = video_path
            prediction['threshold'] = self.threshold
            
        return prediction
    
    def predict_batch_videos(self, video_paths, batch_size=4):
        """
        Predict autism gestures for multiple videos
        Args:
            video_paths: List of video paths
            batch_size: Processing batch size
        Returns:
            List of predictions for each video
        """
        predictions = []
        print(f"🔄 Processing {len(video_paths)} videos...")
        
        for i in tqdm(range(0, len(video_paths), batch_size)):
            batch_paths = video_paths[i:i+batch_size]
            
            for video_path in batch_paths:
                try:
                    prediction = self.predict_single_video(video_path)
                    predictions.append(prediction)
                except Exception as e:
                    print(f"❌ Error processing {video_path}: {str(e)}")
                    predictions.append({
                        'has_autism_gesture': False,
                        'detected_gestures': [],
                        'primary_gesture': None,
                        'max_confidence': 0.0,
                        'error': str(e)
                    })
                    
        return predictions
    
    def set_threshold(self, new_threshold):
        """Update confidence threshold"""
        self.threshold = new_threshold
        print(f"🎯 Threshold updated to: {new_threshold}")
aws

In [53]:
# Cell 10: Model Loading and Saving Functions
def save_ensemble_model(trained_models, save_path='autism_gesture_ensemble.pth'):
    """
    Save the complete ensemble model
    Args:
        trained_models: Dictionary of trained models
        save_path: Path to save the ensemble
    """
    ensemble_data = {
        'model_states': {},
        'gesture_names': CONFIG['GESTURE_NAMES'],
        'config': CONFIG,
        'timestamp': torch.tensor(cv2.getTickCount())
    }
    
    for gesture_name, model in trained_models.items():
        ensemble_data['model_states'][gesture_name] = model.state_dict()
        
    torch.save(ensemble_data, save_path)
    print(f"💾 Ensemble model saved: {save_path}")

def load_ensemble_model(model_path='autism_gesture_ensemble.pth', threshold=0.5):
    """
    Load complete ensemble model
    Args:
        model_path: Path to saved ensemble model
        threshold: Detection threshold
    Returns:
        AutismGestureEnsemble instance
    """
    if not os.path.exists(model_path):
        print(f"❌ Model file not found: {model_path}")
        return None
        
    # Load saved data
    ensemble_data = torch.load(model_path, map_location=device)
    gesture_names = ensemble_data['gesture_names']
    
    # Reconstruct models
    models_dict = {}
    
    for gesture_name in gesture_names:
        model = BinaryGestureClassifier(
            sequence_length=ensemble_data['config']['SEQUENCE_LENGTH']
        ).to(device)
        
        model.load_state_dict(ensemble_data['model_states'][gesture_name])
        model.eval()
        models_dict[gesture_name] = model
        print(f"✅ Loaded {gesture_name} classifier")
        
    # Create ensemble
    ensemble = AutismGestureEnsemble(models_dict, gesture_names, threshold)
    print(f"✅ Ensemble model loaded: {model_path}")
    
    return ensemble

In [None]:
# Cell 12: Export Models for Deployment
def export_models_for_deployment(trained_models, export_dir='exported_models'):
    """
    Export trained models in various formats for deployment
    Args:
        trained_models: Dictionary of trained models
        export_dir: Directory to save exported models
    """
    os.makedirs(export_dir, exist_ok=True)
    print("📦 Exporting models for deployment...")
    
    # Create dummy input for tracing
    dummy_input = torch.randn(1, CONFIG['SEQUENCE_LENGTH'], 3, 224, 224).to(device)
    
    for gesture_name, model in trained_models.items():
        model.eval()
        
        # TorchScript export
        try:
            traced_model = torch.jit.trace(model, dummy_input)
            torchscript_path = os.path.join(export_dir, f'{gesture_name.lower()}_torchscript.pt')
            traced_model.save(torchscript_path)
            print(f"✅ TorchScript saved: {torchscript_path}")
        except Exception as e:
            print(f"❌ TorchScript export failed for {gesture_name}: {str(e)}")
            
        # ONNX export
        try:
            onnx_path = os.path.join(export_dir, f'{gesture_name.lower()}_model.onnx')
            torch.onnx.export(
                model, dummy_input, onnx_path,
                export_params=True,
                opset_version=11,
                do_constant_folding=True,
                input_names=['input'],
                output_names=['output'],
                dynamic_axes={
                    'input': {0: 'batch_size'},
                    'output': {0: 'batch_size'}
                }
            )
            print(f"✅ ONNX saved: {onnx_path}")
        except Exception as e:
            print(f"❌ ONNX export failed for {gesture_name}: {str(e)}")
            
    # Save ensemble configuration
    config_path = os.path.join(export_dir, 'ensemble_config.json')
    config_data = {
        'gesture_names': CONFIG['GESTURE_NAMES'],
        'sequence_length': CONFIG['SEQUENCE_LENGTH'],
        'img_size': CONFIG['IMG_SIZE'],
        'model_architecture': 'BinaryGestureClassifier',
        'input_shape': [CONFIG['SEQUENCE_LENGTH'], 3, 224, 224],
        'output_shape': [2],
        'preprocessing': {
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'resize': [224, 224]
        }
    }
    
    with open(config_path, 'w') as f:
        json.dump(config_data, f, indent=2)
        
    print(f"✅ Configuration saved: {config_path}")

In [None]:
# Cell 13: Complete Demo and Testing Functions

def run_complete_demo():
    """
    Complete demonstration of the autism gesture detection system
    """
    print("🎬 Starting Complete Autism Gesture Detection Demo")
    print("=" * 60)
    
    # Step 1: Train models (if not already trained)
    print("🚀 Step 1: Training Models")
    trained_models, training_histories, evaluation_results = main_training_pipeline()
    
    # Step 2: Save ensemble model
    print("\n💾 Step 2: Saving Ensemble Model")
    save_ensemble_model(trained_models)
    
    # Step 3: Export for deployment
    print("\n📦 Step 3: Exporting for Deployment")
    export_models_for_deployment(trained_models)
    
    # Step 4: Load ensemble and test
    print("\n🔄 Step 4: Loading and Testing Ensemble")
    ensemble = AutismGestureEnsemble(trained_models, CONFIG['GESTURE_NAMES'], threshold=0.5)
    
    # Test on sample videos (if available)
    test_videos = glob.glob(os.path.join(CONFIG['DATA_FOLDER'], '*/*.avi'))[:10]  # First 10 videos
    
    if test_videos:
        print(f"🧪 Testing on {len(test_videos)} sample videos...")
        sample_predictions = ensemble.predict_single_video(test_videos[0])  # Only test first video
        
        print(f"\n📹 Video 1: {os.path.basename(test_videos[0])}")
        if 'error' not in sample_predictions:
            print(f"   🎯 Detected: {sample_predictions['primary_gesture'] or 'None'}")
            print(f"   📊 Confidence: {float(sample_predictions['max_confidence']):.3f}")
            print(f"   📋 All scores: {sample_predictions['all_confidences']}")
        else:
            print(f"   ❌ Error: {sample_predictions['error']}")
                
    print("\n✅ Complete demo finished!")
    return ensemble, trained_models, training_histories, evaluation_results


def test_single_video(video_path, ensemble=None, model_path='autism_gesture_ensemble.pth'):
    """
    Test a single video with the trained ensemble model
    Args:
        video_path: Path to video file
        ensemble: Pre-loaded ensemble model (optional)
        model_path: Path to saved ensemble model
    Returns:
        Prediction result dictionary
    """
    if ensemble is None:
        print("🔄 Loading ensemble model...")
        ensemble = load_ensemble_model(model_path)
        if ensemble is None:
            return None

    print(f"🔍 Analyzing video: {video_path}")
    prediction = ensemble.predict_single_video(video_path)

    print("\n📊 Analysis Results:")
    if 'error' in prediction:
        print(f"❌ Error: {prediction['error']}")
    else:
        # Get primary gesture and confidences
        primary_gesture = prediction['primary_gesture']
        max_confidence = prediction['max_confidence']
        all_confidences = prediction['all_confidences']

        print(f"🎯 Primary Gesture: {primary_gesture or 'None'}")
        print(f"📈 Confidence: {float(max_confidence):.3f}")
        print("\n📋 Detailed Confidences:")

        # Show individual gesture confidences
        for gesture_name, confidence_dict in all_confidences.items():
            confidence = confidence_dict.get('confidence', 0.0)
            print(f"  {gesture_name}: {float(confidence):.3f}")

        # New logic: Autism indication
        autism_detected = any(
            float(confidence_dict.get('confidence', 0.0)) > 0.5
            for confidence_dict in all_confidences.values()
        )
        if autism_detected:
            print("\n✅ Autism-related gesture detected!")
        else:
            print("\n🚫 No autism-related gesture detected.")

    return prediction

def interactive_cli_interface(ensemble=None, model_path='autism_gesture_ensemble.pth'):
    """
    Interactive command-line interface for using the autism gesture detection system
    """
    print("🤖 Starting Autism Gesture Detection CLI Interface")
    print("=" * 60)
    
    if ensemble is None:
        print("🔄 Loading ensemble model...")
        ensemble = load_ensemble_model(model_path)
        if ensemble is None:
            return
    
    while True:
        print("\nMENU:")
        print("1. Test single video")
        print("2. Exit")
        
        choice = input("\nEnter your choice (1-2): ")
        
        if choice == '1':
            video_path = input("Enter video path: ")
            if os.path.exists(video_path):
                test_single_video(video_path, ensemble)
            else:
                print("❌ File not found!")
                
        elif choice == '2':
            print("👋 Exiting CLI interface. Goodbye!")
            break
            
        else:
            print("❌ Invalid choice! Please enter a number between 1 and 2.")



# Example usage:
if __name__ == "__main__":
    print("🧠 Welcome to Autism Gesture Detection System")
    print("=" * 60)
    
    # Option 1: Run complete demo
    # ensemble, trained_models, histories, results = run_complete_demo()
    
    # Option 2: Load existing model and start CLI
    ensemble = load_ensemble_model()
    
    if ensemble:
        interactive_cli_interface(ensemble)
    else:
        print("❌ Failed to load ensemble model. Please check that model files exist.")

🧠 Welcome to Autism Gesture Detection System
✅ Loaded ArmFlapping classifier
✅ Loaded HeadBanging classifier
✅ Loaded Spinning classifier
✅ Ensemble model loaded: autism_gesture_ensemble.pth
🤖 Starting Autism Gesture Detection CLI Interface

MENU:
1. Test single video
2. Exit
🔍 Analyzing video: \\wsl.localhost\Ubuntu\home\samir\projects\clipping_ssbd_videos\ssbd_raw\v_Spinning_07.mp4

📊 Analysis Results:
🎯 Primary Gesture: Spinning
📈 Confidence: 0.958

📋 Detailed Confidences:
  ArmFlapping: 0.037
  HeadBanging: 0.013
  Spinning: 0.958

✅ Autism-related gesture detected!

MENU:
1. Test single video
2. Exit
🔍 Analyzing video: C:\Users\Lenovo\Downloads\CRkgy266nVA.mp4

📊 Analysis Results:
🎯 Primary Gesture: None
📈 Confidence: 0.000

📋 Detailed Confidences:
  ArmFlapping: 0.319
  HeadBanging: 0.064
  Spinning: 0.219

🚫 No autism-related gesture detected.

MENU:
1. Test single video
2. Exit
👋 Exiting CLI interface. Goodbye!
