In [5]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt
import time
import copy
import cv2
import mediapipe as mp
import pickle

# ===== SETUP =====

# Performance optimization: turn on cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True

# Report the PyTorch build and whether the Apple MPS backend can be used.
print(f"PyTorch version: {torch.__version__}")
print(f"MPS available: {torch.backends.mps.is_available()}")
print(f"MPS built: {torch.backends.mps.is_built()}")

# Choose the compute device: MPS (M1/M2) first, then CUDA (NVIDIA), otherwise CPU.
if not (torch.backends.mps.is_available() or torch.cuda.is_available()):
    device = torch.device("cpu")
    print("Using CPU")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print(f"Using MPS device: {device}")
else:
    device = torch.device("cuda")
    print(f"Using CUDA device: {device}")

# Helper function to free memory
def empty_cache():
    """Release cached accelerator memory held by PyTorch.

    Only one backend is cleared per call: MPS when it is available,
    otherwise CUDA when it is available. On a CPU-only machine this is a no-op.
    """
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ===== FACIAL LANDMARK EXTRACTION =====

# Shared MediaPipe Face Mesh detector (single face, static-image mode),
# created with more forgiving detection parameters.
mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(
    min_detection_confidence=0.3,  # Lowered from 0.5 to improve detection rate
    max_num_faces=1,
    static_image_mode=True,
)

# Add logging function
def log_detection_failure(image_path=None):
    """Print a face-detection failure message.

    Args:
        image_path: Optional path of the offending file. When it is falsy
            (None or an empty string) a generic message is printed instead.
    """
    if not image_path:
        print("Face detection failed")
        return
    print(f"Face detection failed for image: {image_path}")

class FacialActionUnitExtractor:
    """Derive a 13-element facial feature vector from MediaPipe Face Mesh landmarks.

    Layout of the vector returned by ``extract_aus``:
        0  AU6  cheek raiser            7  AU15 lip corner depressor
        1  AU12 lip corner puller       8  AU20 lip stretcher
        2  AU4  brow lowerer            9  smile asymmetry
        3  AU1  inner brow raiser       10 eye openness
        4  AU2  outer brow raiser       11 facial mobility
        5  AU7  lid tightener           12 mouth corner resting position
        6  AU9  nose wrinkler

    Landmark detection is delegated to the module-level ``face_mesh`` instance.
    """

    # 9 core AUs + 4 PD-specific measurements
    NUM_FEATURES = 13

    def __init__(self):
        # Key AU landmarks for PD detection based on the paper
        self.au_landmarks = {
            'AU1': [10, 338],   # Inner brow raiser
            'AU2': [65, 295],   # Outer brow raiser
            'AU4': [9, 337],    # Brow lowerer - key for PD
            'AU6': [117, 346],  # Cheek raiser - key for PD
            'AU7': [159, 386],  # Lid tightener
            'AU9': [129, 358],  # Nose wrinkler
            'AU12': [61, 291],  # Lip corner puller (smile) - key for PD
            'AU15': [61, 291],  # Lip corner depressor
            'AU20': [0, 267]    # Lip stretcher
        }

        # PD-specific facial measurements
        self.pd_measurements = [
            'smile_symmetry',       # Asymmetry in smile - PD indicator
            'blink_rate',           # Reduced blink rate - PD indicator
            'facial_mobility',      # Reduced overall mobility - PD indicator
            'expression_transition' # Slow transitions - PD indicator
        ]

    def extract_aus(self, image, image_path=None):
        """Extract facial action units with focus on PD-relevant features.

        Args:
            image: BGR image array; it is converted to RGB before being handed
                to MediaPipe. ``None`` is treated as a detection failure.
            image_path: Optional path, used only for failure log messages.

        Returns:
            Tuple ``(aus, detected)``. ``aus`` is a NumPy array of length 13
            (see the class docstring for the layout). ``detected`` is True only
            when a face was found and every feature was computed; on any
            failure ``aus`` is all zeros and ``detected`` is False.
        """
        if image is None:
            # An unreadable image is a failure, not a crash inside cvtColor.
            if image_path:
                log_detection_failure(image_path)
            return np.zeros(self.NUM_FEATURES), False

        # Convert image to RGB for MediaPipe
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = rgb_image.shape[:2]  # Image dimensions for correct scaling

        # Process the image to find facial landmarks
        results = face_mesh.process(rgb_image)

        # Check if face was detected
        if not results.multi_face_landmarks:
            if image_path:
                log_detection_failure(image_path)
            return np.zeros(self.NUM_FEATURES), False

        landmarks = results.multi_face_landmarks[0].landmark

        try:
            # Landmark x/y are scaled to pixels; z is kept as reported.
            # NOTE(review): z is therefore on a different scale than x/y inside
            # the distances below - confirm whether it should be scaled too
            # (changing it would alter every feature the models were trained on).
            points = np.array([(lm.x * w, lm.y * h, lm.z) for lm in landmarks])

            def distance(first, second):
                """Euclidean distance between two landmarks, by index."""
                return np.linalg.norm(points[first] - points[second])

            # Filled locally so that a failure part-way through can never leak
            # half-populated values to the caller.
            aus = np.zeros(self.NUM_FEATURES)

            # 1. AU6 (Cheek Raiser) - key for PD detection
            aus[0] = (distance(117, 123) + distance(346, 352)) / 2

            # 2. AU12 (Lip Corner Puller) - key for PD detection
            mouth_left = points[61]
            mouth_right = points[291]
            mouth_top = points[13]
            mouth_bottom = points[14]
            mouth_center = (mouth_top + mouth_bottom) / 2
            mouth_corner_height = (mouth_left[1] + mouth_right[1]) / 2
            aus[1] = mouth_center[1] - mouth_corner_height

            # 3. AU4 (Brow Lowerer) - key for PD detection
            aus[2] = (distance(9, 107) + distance(337, 336)) / 2

            # 4. AU1 (Inner Brow Raiser)
            aus[3] = distance(10, 338)

            # 5. AU2 (Outer Brow Raiser)
            aus[4] = distance(65, 295)

            # 6. AU7 (Lid Tightener): mean gap between landmark pairs 159/145 and 386/374
            eyelid_gap = (distance(159, 145) + distance(386, 374)) / 2
            aus[5] = eyelid_gap

            # 7. AU9 (Nose Wrinkler)
            aus[6] = distance(129, 358)

            # 8. AU15 (Lip Corner Depressor)
            aus[7] = distance(61, 291)

            # 9. AU20 (Lip Stretcher)
            aus[8] = distance(0, 267)

            # PD-specific measurements
            # 10. Smile symmetry - PD often has asymmetrical facial expressions
            left_smile = np.linalg.norm(mouth_left - mouth_top)
            right_smile = np.linalg.norm(mouth_right - mouth_top)
            epsilon = 1e-6  # prevents division by zero
            aus[9] = abs(left_smile - right_smile) / (max(left_smile, right_smile) + epsilon)

            # 11. Blink rate approximation (eye openness). This is the same
            # measurement as AU7 above; both slots are kept so the feature
            # vector stays 13 wide.
            aus[10] = eyelid_gap

            # 12. Overall facial mobility (average movement potential).
            # aus[0] and aus[2] are distances (never negative); aus[1] can be
            # negative, hence the abs().
            aus[11] = (aus[0] + abs(aus[1]) + aus[2]) / 3

            # 13. Mouth corner resting position (hypomimia indicator)
            aus[12] = mouth_corner_height

            return aus, True  # Return AUs and True for successful detection

        except Exception as e:
            print(f"Error extracting AUs: {e}")
            if image_path:
                log_detection_failure(image_path)
            # Always hand back clean zeros on failure
            return np.zeros(self.NUM_FEATURES), False

# ===== DATASET PREPARATION =====

class PDDataset(Dataset):
    """Dataset of facial AU features for Parkinson's Disease detection.

    Each item is a triple ``(sequence, variances, label)``:
        sequence:  float tensor of shape (seq_length, 13) with per-frame AU features
        variances: float tensor of shape (13,) with the per-feature variance
                   over the sequence (plus a small epsilon)
        label:     float tensor of shape (1,) holding the ``has_pd`` value

    The annotations CSV must contain a ``has_pd`` column and either a
    ``video_file`` or an ``image_file`` column (``video_file`` wins when both
    are present).
    """

    # Width of the AU feature vector produced by FacialActionUnitExtractor
    NUM_FEATURES = 13

    # Human-readable names, in feature-vector order
    FEATURE_NAMES = (
        "AU6 (Cheek Raiser)",
        "AU12 (Lip Corner Puller)",
        "AU4 (Brow Lowerer)",
        "AU1 (Inner Brow Raiser)",
        "AU2 (Outer Brow Raiser)",
        "AU7 (Lid Tightener)",
        "AU9 (Nose Wrinkler)",
        "AU15 (Lip Corner Depressor)",
        "AU20 (Lip Stretcher)",
        "Smile Asymmetry",
        "Eye Openness",
        "Facial Mobility",
        "Mouth Corner Rest",
    )

    def __init__(self, root_dir, annotations_file, transform=None, seq_length=30, target_size=(224, 224),
                inspect_features=False, balance_classes=True):
        """
        Dataset for Parkinson's Disease detection

        Args:
            root_dir: Directory with all the images/videos
            annotations_file: Path to CSV file with annotations
            transform: Optional transform; stored on the instance but not
                applied anywhere in this class
            seq_length: Number of frames to extract
            target_size: Target size for frames
            inspect_features: Whether to save feature statistics for inspection
            balance_classes: Whether to balance classes in the dataset
        """
        self.root_dir = root_dir
        self.transform = transform
        self.seq_length = seq_length
        self.target_size = target_size
        self.au_extractor = FacialActionUnitExtractor()
        self.inspect_features = inspect_features
        self.face_detection_failures = 0
        self.total_processed = 0

        # Read annotations, optionally balance, then shuffle deterministically
        annotations = pd.read_csv(annotations_file)
        if balance_classes:
            annotations = self._balance_classes(annotations)
        self.annotations = annotations.sample(frac=1, random_state=42).reset_index(drop=True)

        # Report class distribution
        pd_count = sum(self.annotations['has_pd'])
        total = len(self.annotations)
        print(f"Dataset loaded: {total} samples")
        if total:
            print(f"Class distribution: PD={pd_count} ({pd_count/total*100:.2f}%), Non-PD={total-pd_count} ({(total-pd_count)/total*100:.2f}%)")
        else:
            # Avoid a ZeroDivisionError when the CSV (or the balanced subset) is empty
            print("Class distribution: dataset is empty")

        # For feature inspection
        if inspect_features:
            self.pd_features = []
            self.non_pd_features = []

    @staticmethod
    def _balance_classes(annotations):
        """Return annotations with equal PD / non-PD counts by downsampling the majority.

        The larger class is sampled WITHOUT replacement down to the size of the
        smaller one, so no row is duplicated. When both classes already have
        the same size the frame is returned untouched.
        """
        pd_samples = annotations[annotations['has_pd'] == 1]
        non_pd_samples = annotations[annotations['has_pd'] == 0]

        if len(pd_samples) == len(non_pd_samples):
            return annotations

        target = min(len(pd_samples), len(non_pd_samples))
        if len(pd_samples) > target:
            pd_samples = pd_samples.sample(n=target, random_state=42)
        else:
            non_pd_samples = non_pd_samples.sample(n=target, random_state=42)
        return pd.concat([pd_samples, non_pd_samples])

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load the file for row ``idx`` and return ``(sequence, variances, label)``."""
        row = self.annotations.iloc[idx]

        # Get label (PD or non-PD)
        has_pd = row.get('has_pd', 0)

        # Handle both image and video datasets
        if 'video_file' in self.annotations.columns:
            file_path = os.path.join(self.root_dir, row['video_file'])
            sequence, variances, detection_succeeded = self._process_video(file_path)
        else:
            # Images get a simulated sequence
            file_path = os.path.join(self.root_dir, row['image_file'])
            sequence, variances, detection_succeeded = self._process_image(file_path)

        # Track detection failures
        self.total_processed += 1
        if not detection_succeeded:
            self.face_detection_failures += 1
            if self.total_processed % 100 == 0:
                print(f"Face detection failures: {self.face_detection_failures}/{self.total_processed} ({self.face_detection_failures/self.total_processed*100:.2f}%)")

        # Store features for inspection if needed and if detection succeeded
        if self.inspect_features and detection_succeeded:
            bucket = self.pd_features if has_pd == 1 else self.non_pd_features
            bucket.append(variances.numpy())

        label_tensor = torch.FloatTensor([has_pd])
        return sequence, variances, label_tensor

    def _failure_output(self):
        """All-zero ``(sequence, variances, False)`` triple used when a file cannot be processed."""
        return (
            torch.zeros((self.seq_length, self.NUM_FEATURES)),
            torch.zeros(self.NUM_FEATURES),
            False,
        )

    def _select_frame_indices(self, frame_count):
        """Return ``seq_length`` evenly spaced frame indices in ``[0, frame_count - 1]``.

        Indices repeat when the video has fewer frames than ``seq_length``.
        """
        return np.linspace(0, frame_count - 1, self.seq_length, dtype=int)

    @staticmethod
    def _sequence_to_tensors(au_sequence):
        """Convert a list of per-frame AU vectors into (sequence, variances) float tensors."""
        # Variance is the key feature according to the paper; the small epsilon
        # keeps the values strictly positive.
        epsilon = 1e-6
        au_variances = np.var(au_sequence, axis=0) + epsilon
        sequence_np = np.array(au_sequence)
        return torch.FloatTensor(sequence_np), torch.FloatTensor(au_variances)

    def _process_video(self, video_path):
        """Process a video file to extract AU sequence and variances.

        Returns ``(sequence, variances, detection_succeeded)``; detection counts
        as succeeded when a face was found in more than half of the sampled frames.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                print(f"Could not open video file: {video_path}")
                return self._failure_output()

            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if frame_count <= 0:
                # Without a positive frame count the sampling range would run
                # from 0 to a negative index; treat the file as unusable.
                print(f"Video reports no frames: {video_path}")
                return self._failure_output()

            au_sequence = []
            detection_count = 0

            for frame_idx in self._select_frame_indices(frame_count):
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
                ret, frame = cap.read()
                if ret:
                    frame = cv2.resize(frame, self.target_size)
                    aus, detected = self.au_extractor.extract_aus(frame, video_path)
                    au_sequence.append(aus)
                    if detected:
                        detection_count += 1
                elif au_sequence:
                    # Unreadable frame: repeat the previous features
                    au_sequence.append(au_sequence[-1])
                else:
                    au_sequence.append(np.zeros(self.NUM_FEATURES))
        finally:
            # Released on every path, including exceptions
            cap.release()

        # Require more than 50% of frames with detected faces
        detection_succeeded = detection_count / self.seq_length > 0.5

        au_sequence_tensor, au_variances_tensor = self._sequence_to_tensors(au_sequence)
        return au_sequence_tensor, au_variances_tensor, detection_succeeded

    def _process_image(self, image_path):
        """Process an image file with simulated sequence for temporal features.

        The single AU vector is replicated ``seq_length`` times with Gaussian
        noise whose standard deviation shrinks linearly over the sequence.
        """
        img = cv2.imread(image_path)
        if img is None:
            print(f"Could not find image: {image_path}")
            return self._failure_output()

        img = cv2.resize(img, self.target_size)
        aus, detected = self.au_extractor.extract_aus(img, image_path)
        if not detected:
            return self._failure_output()

        # Generate a sequence with subtle variations to simulate a video
        au_sequence = []
        for i in range(self.seq_length):
            # Variations decrease as frames progress (simulating expression change)
            decay_factor = (self.seq_length - i) / self.seq_length
            variation = np.random.normal(0, 0.02 * decay_factor, size=aus.shape)
            au_sequence.append(aus + variation)

        au_sequence_tensor, au_variances_tensor = self._sequence_to_tensors(au_sequence)
        return au_sequence_tensor, au_variances_tensor, True

    def get_feature_statistics(self):
        """Get statistics of features for PD and non-PD samples.

        Only items already fetched through ``__getitem__`` (with successful
        detection) contribute.

        Returns:
            Dict with keys ``feature_names``, ``pd_means``, ``non_pd_means``,
            ``pd_std``, ``non_pd_std``, ``differences`` and ``t_statistics``.
            All numeric entries are zero arrays when either class has no
            collected features.

        Raises:
            ValueError: if the dataset was built with ``inspect_features=False``.
        """
        if not self.inspect_features:
            raise ValueError("Feature inspection was not enabled for this dataset")

        feature_names = list(self.FEATURE_NAMES)

        # Handle case with no valid features
        if not self.pd_features or not self.non_pd_features:
            print("Warning: No valid features collected. Face detection likely failed for all images.")
            zeros = np.zeros(self.NUM_FEATURES)
            return {
                "feature_names": feature_names,
                "pd_means": zeros.copy(),
                "non_pd_means": zeros.copy(),
                "pd_std": zeros.copy(),
                "non_pd_std": zeros.copy(),
                "differences": zeros.copy(),
                "t_statistics": zeros.copy(),
            }

        pd_features = np.array(self.pd_features)
        non_pd_features = np.array(self.non_pd_features)

        pd_means = np.mean(pd_features, axis=0)
        non_pd_means = np.mean(non_pd_features, axis=0)

        pd_std = np.std(pd_features, axis=0)
        non_pd_std = np.std(non_pd_features, axis=0)

        # Epsilon prevents division by zero
        epsilon = 1e-6
        t_statistics = (pd_means - non_pd_means) / np.sqrt(
            pd_std**2/(len(pd_features) + epsilon) +
            non_pd_std**2/(len(non_pd_features) + epsilon) +
            epsilon
        )

        return {
            "feature_names": feature_names,
            "pd_means": pd_means,
            "non_pd_means": non_pd_means,
            "pd_std": pd_std,
            "non_pd_std": non_pd_std,
            "differences": pd_means - non_pd_means,
            "t_statistics": t_statistics,
        }

# ===== NEW FILTERING FUNCTION =====

def filter_dataset(root_dir, annotations_file, output_file, target_size=(224, 224)):
    """Create a new dataset with only files where face detection succeeds.

    For video datasets (a ``video_file`` column is present) only the middle
    frame of each video is checked; otherwise the ``image_file`` column is used
    and the image itself is checked.

    Args:
        root_dir: Directory that the file names in the CSV are relative to.
        annotations_file: Path to the CSV file with annotations.
        output_file: Path the filtered CSV is written to.
        target_size: Size frames are resized to before detection.

    Returns:
        DataFrame with the rows that passed the check. It keeps the original
        row labels and columns, so the written CSV always has a header, even
        when no row passes.
    """
    print(f"Filtering dataset: {annotations_file}")

    annotations = pd.read_csv(annotations_file)
    au_extractor = FacialActionUnitExtractor()

    is_video = 'video_file' in annotations.columns
    file_column = 'video_file' if is_video else 'image_file'
    total = len(annotations)
    valid_positions = []

    # Positional counter (not the index label) drives selection and progress
    for position, (_, row) in enumerate(annotations.iterrows()):
        file_path = os.path.join(root_dir, row[file_column])
        frame = None

        if is_video:
            cap = cv2.VideoCapture(file_path)
            try:
                if cap.isOpened():
                    # Check middle frame
                    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count // 2)
                    ret, frame = cap.read()
                    if not ret:
                        frame = None
                else:
                    print(f"Could not open video file: {file_path}")
            finally:
                # Released on every path, including exceptions
                cap.release()
        else:
            frame = cv2.imread(file_path)
            if frame is None:
                print(f"Could not find image: {file_path}")

        if frame is not None:
            frame = cv2.resize(frame, target_size)
            _, detected = au_extractor.extract_aus(frame, file_path)
            if detected:
                valid_positions.append(position)

        # Report progress (also for rows that could not be read)
        if (position + 1) % 100 == 0:
            print(f"Processed {position + 1}/{total}, valid: {len(valid_positions)}")

    # Selecting by position keeps columns and dtypes intact, so an empty result
    # still writes a CSV with a header that pd.read_csv can load.
    valid_df = annotations.iloc[valid_positions]
    valid_df.to_csv(output_file, index=False)

    # Report results
    print(f"Filtered dataset created: {output_file}")
    print(f"Original dataset: {total} samples")
    if total:
        print(f"Filtered dataset: {len(valid_df)} samples ({len(valid_df)/total*100:.2f}%)")
    else:
        # Avoid a ZeroDivisionError for an empty annotations file
        print(f"Filtered dataset: {len(valid_df)} samples")

    return valid_df

# ===== EMOTION RECOGNITION INTEGRATION =====

def pd_from_emotions(emotion_probs):
    """Estimate PD likelihood from emotion probabilities

    Args:
        emotion_probs: Seven emotion probabilities in the order
            [surprise, fear, disgust, happy, sad, angry, neutral]. Any
            array-like holding exactly seven values is accepted, including a
            single-sample batch of shape (1, 7).

    Returns:
        Estimated PD probability as a Python float in [0, 1]

    Raises:
        ValueError: if ``emotion_probs`` does not contain exactly seven values.
    """
    # PD is characterized by reduced emotional expressivity and increased neutral expression
    # This is a simplified heuristic - you should tune these weights based on clinical data

    # Weight for each emotion (higher weight = more important for PD detection)
    weights = np.array([0.1, 0.1, 0.1, 0.3, 0.1, 0.1, 0.2])

    # PD indicators (for each emotion: low value = PD indication, high value = healthy indication)
    # Flip for neutral emotion where high values indicate PD
    indicators = np.array([
        1.0,  # surprise - reduced in PD
        1.0,  # fear - reduced in PD
        1.0,  # disgust - reduced in PD
        1.0,  # happy - significantly reduced in PD
        0.8,  # sad - slightly reduced in PD
        0.8,  # angry - slightly reduced in PD
        -1.0  # neutral - increased in PD (note the negative sign)
    ])

    # Flatten so a (1, 7) model output behaves like a plain 7-vector
    probs = np.asarray(emotion_probs, dtype=float).reshape(-1)
    if probs.size != weights.size:
        raise ValueError(
            f"emotion_probs must contain exactly {weights.size} values, got {probs.size}"
        )

    # Weighted expressivity score
    expressivity = float(np.sum(weights * indicators * probs))

    # Map the score (nominally within [-0.3, 0.3]) onto 0-1 and invert it
    pd_indicator = 1.0 - (expressivity + 0.3) / 0.6

    # Clip to valid probability range
    return max(0.0, min(1.0, pd_indicator))

# ===== COMBINED FEATURE EXTRACTION =====

class CombinedPDDetector:
    """Class that combines emotion recognition with facial AUs for PD detection"""
    
    def __init__(self, emotion_model_path, pd_model_paths=None):
        """
        Initialize the detector with pre-trained models

        Loading is best-effort: errors are printed rather than raised. A PD
        model is added to ``self.pd_models`` only after its weights loaded
        successfully, so a failed load never leaves an untrained model in use.

        Args:
            emotion_model_path: Path to your pre-trained emotion recognition model
                (currently unused: emotion model loading is still a placeholder)
            pd_model_paths: Dictionary of paths to PD-specific models
                (e.g., {'rf': 'rf_model.pkl', 'simple': 'simple_pd_model.pth'});
                the recognised keys are 'rf', 'simple' and 'temporal'
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else 
                                  "mps" if torch.backends.mps.is_available() else 
                                  "cpu")
        
        # Load your emotion recognition model
        try:
            # Placeholder - replace with your actual emotion model loading code
            # self.emotion_model = torch.load(emotion_model_path, map_location=self.device)
            # self.emotion_model.eval()
            self.emotion_model = None  # Placeholder
            if self.emotion_model is None:
                # Do not claim success while the placeholder is still in place
                print("Emotion model not loaded (placeholder); emotion-based predictions are disabled")
            else:
                print("Emotion model loaded successfully")
        except Exception as e:
            print(f"Error loading emotion model: {e}")
            self.emotion_model = None
        
        # Load PD-specific models if provided
        self.pd_models = {}
        if pd_model_paths:
            try:
                # Load Random Forest if available
                if 'rf' in pd_model_paths:
                    # SECURITY: pickle.load can execute arbitrary code from the
                    # file - only load model files from a trusted source.
                    with open(pd_model_paths['rf'], 'rb') as f:
                        self.pd_models['rf'] = pickle.load(f)
                
                # Load Simple Model if available
                if 'simple' in pd_model_paths:
                    simple_model = SimplePDModel(input_size=13)
                    simple_model.load_state_dict(
                        torch.load(pd_model_paths['simple'], map_location=self.device)
                    )
                    simple_model.to(self.device)
                    simple_model.eval()
                    # Registered only once the weights are in place
                    self.pd_models['simple'] = simple_model
                
                # Load Temporal Model if available
                if 'temporal' in pd_model_paths:
                    temporal_model = TemporalPDModel(input_size=13, hidden_size=64, num_layers=2)
                    temporal_model.load_state_dict(
                        torch.load(pd_model_paths['temporal'], map_location=self.device)
                    )
                    temporal_model.to(self.device)
                    temporal_model.eval()
                    # Registered only once the weights are in place
                    self.pd_models['temporal'] = temporal_model
                
                print(f"Loaded {len(self.pd_models)} PD models")
            except Exception as e:
                print(f"Error loading PD models: {e}")
        
        # Initialize feature extractors
        self.au_extractor = FacialActionUnitExtractor()
    
    def detect_from_image(self, image_path, use_emotion=True, use_aus=True):
        """
        Perform PD detection on a single image

        Args:
            image_path: Path to input image
            use_emotion: Whether to use emotion recognition features
                (ignored while no emotion model is loaded)
            use_aus: Whether to use facial action unit features

        Returns:
            Dictionary with detection results. When the image cannot be read
            it holds only an "error" key. Otherwise it holds "image_path",
            "pd_probability" (mean of all entries in "model_predictions",
            0.0 when there are none), "emotion_features", "au_features" and
            "model_predictions", plus "pd_likelihood" ("High"/"Medium"/"Low")
            when at least one prediction exists. If no face is detected while
            ``use_aus`` is set, "error" is added and the dictionary is returned
            immediately, before any averaging.
        """
        seq_length = 30          # 30 frames like in training
        variance_epsilon = 1e-6  # same offset PDDataset adds to its variance features

        # Load image
        img = cv2.imread(image_path)
        if img is None:
            return {"error": f"Could not load image: {image_path}"}

        img = cv2.resize(img, (224, 224))

        # Create results dictionary
        results = {
            "image_path": image_path,
            "pd_probability": 0.0,
            "emotion_features": None,
            "au_features": None,
            "model_predictions": {}
        }

        # Extract emotion features if requested
        if use_emotion and self.emotion_model:
            try:
                # Convert image for emotion model
                rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                # This is a placeholder - replace with your actual emotion model preprocessing
                # emotion_input = preprocess_for_emotion_model(rgb_img)
                # emotion_input = emotion_input.to(self.device)

                # Get emotion probabilities
                # with torch.no_grad():
                #     emotion_probs = self.emotion_model(emotion_input)
                #     emotion_probs = emotion_probs.cpu().numpy()

                # Placeholder emotion probabilities
                emotion_probs = np.array([0.1, 0.1, 0.1, 0.3, 0.1, 0.1, 0.2])

                # Calculate PD probability from emotions
                pd_from_emotion = pd_from_emotions(emotion_probs)

                results["emotion_features"] = emotion_probs
                results["model_predictions"]["emotion_based"] = pd_from_emotion
            except Exception as e:
                print(f"Error processing emotions: {e}")

        # Extract AU features if requested
        if use_aus:
            try:
                # Extract AUs
                aus, detected = self.au_extractor.extract_aus(img, image_path)

                if not detected:
                    results["error"] = "Face not detected in image"
                    return results

                # Generate sequence with variations
                au_sequence = []
                for i in range(seq_length):
                    decay_factor = (seq_length - i) / seq_length
                    variation = np.random.normal(0, 0.02 * decay_factor, size=aus.shape)
                    au_sequence.append(aus + variation)

                # Calculate variances with the same epsilon as the training
                # features, so inference inputs match what the models saw.
                au_variances = np.var(au_sequence, axis=0) + variance_epsilon

                results["au_features"] = au_variances

                # Make predictions with PD models if available
                if self.pd_models:
                    # Random Forest prediction
                    if 'rf' in self.pd_models:
                        rf_pred = self.pd_models['rf'].predict_proba([au_variances])[0, 1]
                        # Plain float, like the other predictions
                        results["model_predictions"]["rf"] = float(rf_pred)

                    # Simple Model prediction
                    if 'simple' in self.pd_models:
                        with torch.no_grad():
                            # NOTE(review): passed without a batch dimension -
                            # confirm SimplePDModel accepts a 1-D input.
                            simple_input = torch.FloatTensor(au_variances).to(self.device)
                            simple_pred = self.pd_models['simple'](simple_input).item()
                            results["model_predictions"]["simple"] = simple_pred

                    # Temporal Model prediction
                    if 'temporal' in self.pd_models:
                        with torch.no_grad():
                            sequence_np = np.array(au_sequence)
                            temporal_input = torch.FloatTensor(sequence_np).unsqueeze(0).to(self.device)
                            temporal_pred = self.pd_models['temporal'](temporal_input).item()
                            results["model_predictions"]["temporal"] = temporal_pred
            except Exception as e:
                print(f"Error processing AUs: {e}")

        # Calculate final probability (ensemble)
        predictions = results["model_predictions"]
        if predictions:
            # Average all available predictions
            pd_probability = sum(predictions.values()) / len(predictions)
            results["pd_probability"] = pd_probability

            # Add likelihood category
            if pd_probability > 0.7:
                results["pd_likelihood"] = "High"
            elif pd_probability > 0.3:
                results["pd_likelihood"] = "Medium"
            else:
                results["pd_likelihood"] = "Low"

        return results
    
    def detect_from_video(self, video_path, use_emotion=True, use_aus=True):
        """
        Perform PD detection on a video

        At most 30 frames, evenly spaced over the clip, are sampled. Every frame
        that can be decoded is resized to 224x224 and passed to the AU extractor
        and/or the emotion branch. The emotion branch currently emits a fixed
        placeholder probability vector (the real model call is commented out).
        All available per-model predictions are averaged into one PD probability.

        Args:
            video_path: Path to input video
            use_emotion: Whether to use emotion recognition features
            use_aus: Whether to use facial action unit features

        Returns:
            Dictionary with detection results. On success it holds
            "video_path", "pd_probability", "frames_processed" (number of frames
            actually decoded), "face_detection_rate" (face detections divided by
            decoded frames), "emotion_features", "au_features",
            "model_predictions" and, when at least one prediction was produced,
            "pd_likelihood" ("High" / "Medium" / "Low"). If the video cannot be
            opened, only {"error": ...} is returned.
        """
        # Create results dictionary
        results = {
            "video_path": video_path,
            "pd_probability": 0.0,
            "frames_processed": 0,
            "face_detection_rate": 0.0,
            "emotion_features": None,
            "au_features": None,
            "model_predictions": {}
        }

        emotion_probs_list = []
        au_sequence = []
        face_detection_count = 0
        frames_read = 0

        cap = cv2.VideoCapture(video_path)
        try:
            # Check if video exists / is readable
            if not cap.isOpened():
                return {"error": f"Could not open video: {video_path}"}

            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Select frames to process. Plain Python lists are used on purpose:
            # a NumPy array here cannot be truth-tested and would make any
            # `if indices` check raise ValueError for videos longer than 30 frames.
            if frame_count > 30:
                indices = np.linspace(0, frame_count - 1, 30, dtype=int).tolist()
            else:
                indices = list(range(frame_count))

            # Process frames
            for i in indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, i)
                ret, frame = cap.read()

                if not ret:
                    # Frame could not be decoded; it is not counted as processed
                    continue

                frames_read += 1

                # Resize frame
                frame = cv2.resize(frame, (224, 224))

                # Extract emotions if requested
                if use_emotion and self.emotion_model:
                    try:
                        # Convert frame for emotion model
                        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        # This is a placeholder - replace with your actual emotion model preprocessing
                        # emotion_input = preprocess_for_emotion_model(rgb_frame)
                        # emotion_input = emotion_input.to(self.device)

                        # Get emotion probabilities
                        # with torch.no_grad():
                        #     emotion_probs = self.emotion_model(emotion_input)
                        #     emotion_probs = emotion_probs.cpu().numpy()

                        # Placeholder emotion probabilities
                        emotion_probs = np.array([0.1, 0.1, 0.1, 0.3, 0.1, 0.1, 0.2])
                        emotion_probs_list.append(emotion_probs)
                    except Exception as e:
                        print(f"Error processing emotions for frame {i}: {e}")

                # Extract AUs if requested
                if use_aus:
                    try:
                        aus, detected = self.au_extractor.extract_aus(frame)

                        if detected:
                            face_detection_count += 1

                        # The AU vector is kept even when no face was detected,
                        # so the sequence has one entry per decoded frame
                        au_sequence.append(aus)
                    except Exception as e:
                        print(f"Error processing AUs for frame {i}: {e}")
                        # Keep the sequence aligned with a zero vector of 13 AU features
                        au_sequence.append(np.zeros(13))
        finally:
            # Always free the capture handle, even if frame processing raised
            cap.release()

        # Update results using the frames that were really decoded
        results["frames_processed"] = frames_read
        results["face_detection_rate"] = face_detection_count / frames_read if frames_read else 0.0

        # Process emotion features
        if emotion_probs_list:
            # Average emotion probabilities across frames
            avg_emotion_probs = np.mean(emotion_probs_list, axis=0)
            results["emotion_features"] = avg_emotion_probs

            # Calculate PD probability from emotions
            pd_from_emotion = pd_from_emotions(avg_emotion_probs)
            results["model_predictions"]["emotion_based"] = pd_from_emotion

        # Process AU features
        if au_sequence:
            # Per-feature variance across the sampled frames
            au_variances = np.var(au_sequence, axis=0)
            results["au_features"] = au_variances

            # Make predictions with PD models if available
            if self.pd_models:
                # Random Forest prediction
                if 'rf' in self.pd_models:
                    rf_pred = self.pd_models['rf'].predict_proba([au_variances])[0, 1]
                    results["model_predictions"]["rf"] = rf_pred

                # Simple Model prediction
                if 'simple' in self.pd_models:
                    with torch.no_grad():
                        simple_input = torch.FloatTensor(au_variances).to(self.device)
                        simple_pred = self.pd_models['simple'](simple_input).item()
                        results["model_predictions"]["simple"] = simple_pred

                # Temporal Model prediction
                if 'temporal' in self.pd_models:
                    with torch.no_grad():
                        sequence_np = np.array(au_sequence)
                        # unsqueeze(0) adds a leading batch dimension of size 1
                        temporal_input = torch.FloatTensor(sequence_np).unsqueeze(0).to(self.device)
                        temporal_pred = self.pd_models['temporal'](temporal_input).item()
                        results["model_predictions"]["temporal"] = temporal_pred

        # Calculate final probability (ensemble)
        if results["model_predictions"]:
            # Average all available predictions
            pd_probability = sum(results["model_predictions"].values()) / len(results["model_predictions"])
            results["pd_probability"] = pd_probability

            # Add likelihood category
            if pd_probability > 0.7:
                results["pd_likelihood"] = "High"
            elif pd_probability > 0.3:
                results["pd_likelihood"] = "Medium"
            else:
                results["pd_likelihood"] = "Low"

        return results

# ===== MAIN FUNCTION WITH DETECTION =====

def _run_pd_epoch(model, loader, criterion, input_index, run_device, optimizer=None):
    """
    Run one full pass of `loader` through `model` and return the mean batch loss.

    Args:
        model: Model to train or evaluate
        loader: Iterable yielding (sequences, variances, labels) batches
        criterion: Loss callable invoked as criterion(outputs, labels)
        input_index: Position of the model input inside each batch
            (0 = sequences, 1 = variances)
        run_device: Device the batch tensors are moved to
        optimizer: When given, the pass trains the model with it; when None the
            pass only evaluates (eval mode, gradients disabled)

    Returns:
        Mean loss over the batches, or None if the loader produced no batches
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    num_batches = 0
    with torch.set_grad_enabled(training):
        for batch in loader:
            inputs = batch[input_index].to(run_device)
            labels = batch[2].to(run_device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches else None


def main():
    """
    Train the PD models, save them and build the combined detector.

    Steps: filter the annotation file to images with detectable faces, make a
    stratified 80/20 train/val split, print per-feature PD vs non-PD statistics,
    train the simple and temporal models for 10 epochs (validation loss drives
    the ReduceLROnPlateau schedulers), save both state dicts, and run the
    combined detector on `test_image.jpg` if that file exists.

    Returns:
        The `run_live_detection` function, which runs webcam-based detection
        when called.
    """
    # Set paths
    data_dir = '/path/to/your/data'
    emotion_model_path = '/path/to/your/emotion_model.pth'

    # Filter original dataset to only include images with detectable faces
    # This is crucial for preventing errors later
    filter_dataset(
        root_dir=data_dir,
        annotations_file=os.path.join(data_dir, 'pd_annotations_balanced.csv'),
        output_file=os.path.join(data_dir, 'pd_annotations_filtered.csv')
    )

    # Split into train/val sets (80/20 split)
    annotations = pd.read_csv(os.path.join(data_dir, 'pd_annotations_filtered.csv'))
    from sklearn.model_selection import train_test_split
    train_df, val_df = train_test_split(annotations, test_size=0.2, stratify=annotations['has_pd'], random_state=42)

    # Save train/val splits
    train_df.to_csv(os.path.join(data_dir, 'pd_train_filtered.csv'), index=False)
    val_df.to_csv(os.path.join(data_dir, 'pd_val_filtered.csv'), index=False)

    # Create datasets with feature inspection enabled
    train_dataset = PDDataset(
        root_dir=data_dir,
        annotations_file=os.path.join(data_dir, 'pd_train_filtered.csv'),
        inspect_features=True,
        balance_classes=True
    )

    val_dataset = PDDataset(
        root_dir=data_dir,
        annotations_file=os.path.join(data_dir, 'pd_val_filtered.csv'),
        inspect_features=True,
        balance_classes=True
    )

    # Validate features - see if they can distinguish PD from non-PD
    feature_stats = train_dataset.get_feature_statistics()

    # Print feature differences
    print("\nFeature Differences (PD vs non-PD):")
    print("-" * 50)
    for i, name in enumerate(feature_stats["feature_names"]):
        diff = feature_stats["differences"][i]
        t_stat = feature_stats["t_statistics"][i]
        pd_val = feature_stats["pd_means"][i]
        non_pd_val = feature_stats["non_pd_means"][i]

        # Two-sided critical values of the normal approximation
        significance = ""
        if abs(t_stat) > 2.58:
            significance = "*** (p<0.01)"
        elif abs(t_stat) > 1.96:
            significance = "** (p<0.05)"
        elif abs(t_stat) > 1.65:
            significance = "* (p<0.1)"

        print(f"{name}: PD={pd_val:.5f}, Non-PD={non_pd_val:.5f}, Diff={diff:.5f}, t={t_stat:.2f} {significance}")

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

    # Define models
    simple_model = SimplePDModel(input_size=13).to(device)
    temporal_model = TemporalPDModel(input_size=13, hidden_size=64, num_layers=2).to(device)

    # Define loss function
    criterion = F.binary_cross_entropy

    # Optimizer / scheduler for the simple model
    simple_optimizer = optim.AdamW(simple_model.parameters(), lr=0.001, weight_decay=0.01)
    simple_scheduler = optim.lr_scheduler.ReduceLROnPlateau(simple_optimizer, 'min', patience=3, factor=0.5)

    # Optimizer / scheduler for the temporal model
    temporal_optimizer = optim.AdamW(temporal_model.parameters(), lr=0.001, weight_decay=0.01)
    temporal_scheduler = optim.lr_scheduler.ReduceLROnPlateau(temporal_optimizer, 'min', patience=3, factor=0.5)

    def _fmt_loss(value):
        """Format a mean loss, tolerating an empty loader (None)."""
        return "n/a" if value is None else f"{value:.4f}"

    # Train models (simplified training loop)
    num_epochs = 10
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')

        # Batches are (sequences, variances, labels): the simple model consumes
        # the variances (index 1), the temporal model the sequences (index 0)
        simple_train_loss = _run_pd_epoch(simple_model, train_loader, criterion, 1, device, simple_optimizer)
        temporal_train_loss = _run_pd_epoch(temporal_model, train_loader, criterion, 0, device, temporal_optimizer)

        # Validation pass: without it the ReduceLROnPlateau schedulers never
        # receive a metric and the learning rate is never reduced
        simple_val_loss = _run_pd_epoch(simple_model, val_loader, criterion, 1, device)
        temporal_val_loss = _run_pd_epoch(temporal_model, val_loader, criterion, 0, device)

        if simple_val_loss is not None:
            simple_scheduler.step(simple_val_loss)
        if temporal_val_loss is not None:
            temporal_scheduler.step(temporal_val_loss)

        print(f"  simple   - train loss: {_fmt_loss(simple_train_loss)}, val loss: {_fmt_loss(simple_val_loss)}")
        print(f"  temporal - train loss: {_fmt_loss(temporal_train_loss)}, val loss: {_fmt_loss(temporal_val_loss)}")

    # Save models
    torch.save(simple_model.state_dict(), 'simple_pd_model.pth')
    torch.save(temporal_model.state_dict(), 'temporal_pd_model.pth')

    print("Training complete. Models saved.")

    # Create combined detector
    detector = CombinedPDDetector(
        emotion_model_path=emotion_model_path,
        pd_model_paths={
            'simple': 'simple_pd_model.pth',
            'temporal': 'temporal_pd_model.pth'
        }
    )

    # Test on an image
    test_image = os.path.join(data_dir, 'test_image.jpg')
    if os.path.exists(test_image):
        results = detector.detect_from_image(test_image)
        if "error" in results:
            # detect_from_video reports failures as {"error": ...}; guard here in
            # case detect_from_image does the same, so the report does not KeyError
            print(f"\nPD detection failed: {results['error']}")
        else:
            print("\nPD Detection Results:")
            print(f"PD Probability: {results['pd_probability']:.4f}")
            if 'pd_likelihood' in results:
                print(f"PD Likelihood: {results['pd_likelihood']}")
            print("Model Predictions:")
            for model_name, pred in results.get("model_predictions", {}).items():
                print(f"  {model_name}: {pred:.4f}")

    # Create live video detection function
    def run_live_detection():
        """Run live PD detection using webcam"""
        # Initialize combined detector
        detector = CombinedPDDetector(
            emotion_model_path=emotion_model_path,
            pd_model_paths={
                'simple': 'simple_pd_model.pth',
                'temporal': 'temporal_pd_model.pth'
            }
        )

        # Open webcam
        cap = cv2.VideoCapture(0)
        try:
            if not cap.isOpened():
                print("Could not open webcam")
                return

            # Rolling window of the most recent AU vectors (temporal features)
            frame_buffer = []
            buffer_size = 30  # Same as in training

            print("PD Detection started. Press 'q' to quit.")

            # FPS is measured over ~1 second windows across every displayed frame;
            # the last measured value stays on screen until the next update
            fps_window_start = time.time()
            fps_frame_count = 0
            fps_value = None

            while True:
                # Capture frame
                ret, frame = cap.read()
                if not ret:
                    print("Failed to capture frame")
                    break

                # Resize for processing
                process_frame = cv2.resize(frame, (224, 224))

                # Extract AUs
                aus, detected = detector.au_extractor.extract_aus(process_frame)

                # Add to buffer, dropping the oldest entry once it is full
                frame_buffer.append(aus)
                if len(frame_buffer) > buffer_size:
                    frame_buffer.pop(0)

                # If buffer is full and face detected, make prediction
                if len(frame_buffer) == buffer_size and detected:
                    # Calculate variances
                    au_variances = np.var(frame_buffer, axis=0)

                    # Make predictions
                    predictions = {}

                    # 1. Random Forest (if available)
                    if 'rf' in detector.pd_models:
                        rf_pred = detector.pd_models['rf'].predict_proba([au_variances])[0, 1]
                        predictions['rf'] = rf_pred

                    # 2. Simple Model
                    if 'simple' in detector.pd_models:
                        with torch.no_grad():
                            simple_input = torch.FloatTensor(au_variances).to(device)
                            simple_pred = detector.pd_models['simple'](simple_input).item()
                            predictions['simple'] = simple_pred

                    # 3. Temporal Model
                    if 'temporal' in detector.pd_models:
                        with torch.no_grad():
                            sequence_np = np.array(frame_buffer)
                            temporal_input = torch.FloatTensor(sequence_np).unsqueeze(0).to(device)
                            temporal_pred = detector.pd_models['temporal'](temporal_input).item()
                            predictions['temporal'] = temporal_pred

                    # Ensemble prediction
                    if predictions:
                        ensemble_pred = sum(predictions.values()) / len(predictions)

                        # Determine likelihood category (colors are BGR)
                        if ensemble_pred > 0.7:
                            likelihood = "High"
                            color = (0, 0, 255)  # Red
                        elif ensemble_pred > 0.3:
                            likelihood = "Medium"
                            color = (0, 165, 255)  # Orange
                        else:
                            likelihood = "Low"
                            color = (0, 255, 0)  # Green

                        # Display prediction on frame
                        cv2.putText(frame, f"PD Probability: {ensemble_pred:.2f} ({likelihood})",
                                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

                        # Display key features
                        features = {
                            'Cheek Raiser': au_variances[0],
                            'Lip Corner Puller': au_variances[1],
                            'Brow Lowerer': au_variances[2],
                            'Smile Asymmetry': au_variances[9],
                            'Facial Mobility': au_variances[11]
                        }

                        y_pos = 60
                        for name, value in features.items():
                            # Normalize feature value for display
                            norm_value = min(1.0, max(0.0, value / 0.05))  # Adjust denominator as needed
                            feature_text = f"{name}: {value:.4f}"
                            cv2.putText(frame, feature_text, (10, y_pos),
                                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

                            # Draw bar representation
                            bar_length = int(150 * norm_value)
                            cv2.rectangle(frame, (200, y_pos-15), (200+bar_length, y_pos-5), color, -1)

                            y_pos += 25
                else:
                    # Display status when face not detected or buffer not full
                    if not detected:
                        cv2.putText(frame, "Face not detected", (10, 30),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
                    elif len(frame_buffer) < buffer_size:
                        cv2.putText(frame, f"Buffering: {len(frame_buffer)}/{buffer_size}", (10, 30),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

                # Update the FPS estimate once per second, draw it on every frame
                fps_frame_count += 1
                elapsed_time = time.time() - fps_window_start
                if elapsed_time >= 1.0:
                    fps_value = fps_frame_count / elapsed_time
                    fps_frame_count = 0
                    fps_window_start = time.time()
                if fps_value is not None:
                    cv2.putText(frame, f"FPS: {fps_value:.1f}", (frame.shape[1]-120, 30),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

                # Show frame
                cv2.imshow('Parkinson\'s Disease Detection', frame)

                # Check for exit
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Release resources even if the loop raised
            cap.release()
            cv2.destroyAllWindows()

    # Add the function to be callable
    print("\nTo run live detection, call run_live_detection()")

    # Return the live detection function
    return run_live_detection

if __name__ == "__main__":
    # main() runs the full filter/train/save pipeline and returns the webcam
    # detection function; the reference is kept so it can be called afterwards.
    run_live_detection = main()
    # Uncomment to run live detection immediately
    # run_live_detection()

PyTorch version: 2.6.0
MPS available: True
MPS built: True
Using MPS device: mps
Filtering dataset: /path/to/your/data/pd_annotations_balanced.csv


I0000 00:00:1744855914.330521 1984146 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 88.1), renderer: Apple M2
W0000 00:00:1744855914.331939 1996382 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1744855914.335970 1996385 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


FileNotFoundError: [Errno 2] No such file or directory: '/path/to/your/data/pd_annotations_balanced.csv'

In [None]:
# Second cell - Run the live detector
# NOTE(review): assumes a script named pd_live_detector.py exists in the kernel's
# working directory; its contents are not shown in this notebook - confirm.
%run pd_live_detector.py