In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
import cv2
import face_recognition
from PIL import Image
import os
import glob
from datetime import datetime
from tqdm import tqdm
import numpy as np
import pandas as pd
import argparse

In [2]:
# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================
ANALYSIS_OUTPUT_ROOT = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/flywheel"
MODEL_PATH = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/sup_training/V29_20250710_082807"
# Centroid file written by the upstream centroid-building script; it lives in the
# analysis output root (same path as before, derived instead of repeated).
CENTROIDS_PATH = os.path.join(ANALYSIS_OUTPUT_ROOT, "emotion_centroids.pt")

# Threshold for the relevance "gatekeeper". NOTE: this is a cosine SIMILARITY
# threshold, not a distance. A face is kept only when its embedding's similarity
# to the closest emotion centroid is >= this value, so a LARGER value makes the
# filter stricter (cosine similarity lies in [-1, 1]). Tune this value later.
RELEVANCE_THRESHOLD = 0.6

os.makedirs(ANALYSIS_OUTPUT_ROOT, exist_ok=True)

In [3]:
# ==============================================================================
# 2. HELPER FUNCTIONS
# ==============================================================================

def get_next_version(base_dir):
    """Return the next free version tag (e.g. ``"V4"``) for ``base_dir``.

    Looks at sub-directories of ``base_dir`` named ``V<number>_<anything>``,
    finds the largest ``<number>``, and returns ``"V"`` plus that number + 1.
    Plain files, and entries whose text between the leading ``V`` and the first
    ``_`` is not all digits, are ignored. With no matching directories the
    result is ``"V1"``.
    """
    highest = 0
    for entry_path in glob.glob(os.path.join(base_dir, "V*_*")):
        if not os.path.isdir(entry_path):
            continue
        entry_name = os.path.basename(entry_path)
        if not entry_name.startswith("V") or "_" not in entry_name:
            continue
        number_text = entry_name[1:].split("_")[0]
        if number_text.isdigit():
            highest = max(highest, int(number_text))
    return f"V{highest + 1}"

In [4]:
# ==============================================================================
# 3. CORE PROCESSING FUNCTION (with Two-Stage Analysis)
# ==============================================================================

# Function processes a video with a two-stage approach:
    # 1. Relevance Check: Determines if a face is showing a clear emotion.
    # 2. Emotion Classification: If relevant, classifies the emotion.
def analyze_video_with_relevance_gate(
    video_path,
    save_dir,
    model,
    embedding_model,
    processor,
    device,
    centroids,
    relevance_threshold,
    process_every_n_frames=1
):
    
    if not os.path.exists(video_path):
        print(f"‚ùå Error: Video file not found at {video_path}")
        return []

    video_capture = cv2.VideoCapture(video_path)
    total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    print(f"‚úÖ Opened video: {os.path.basename(video_path)} ({total_frames} frames at {fps:.2f} fps)")

    all_results_log = []
    pbar = tqdm(total=total_frames, desc="Processing video frames")

    for frame_count in range(total_frames):
        ret, frame = video_capture.read()
        if not ret:
            break
            
        if frame_count % process_every_n_frames == 0:
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            face_locations = face_recognition.face_locations(rgb_frame)
            
            if face_locations:
                for i, (top, right, bottom, left) in enumerate(face_locations):
                    face_image_pil = Image.fromarray(rgb_frame[top:bottom, left:right])
                    inputs = processor(images=face_image_pil, return_tensors="pt").to(device)

                    # --- Stage 1: Relevance Detector ---
                    with torch.no_grad():
                        # Get the feature embedding for the current face
                        embedding = embedding_model(**inputs).logits.squeeze()
                    
                    # Calculate the distance to all known emotion centroids
                    distances = {
                        label_id: torch.nn.functional.cosine_similarity(embedding, centroid, dim=0).item()
                        for label_id, centroid in centroids.items()
                    }
                    
                    # Find the highest similarity score (closest distance)
                    max_similarity = max(distances.values())
                    
                    # --- The Gatekeeper ---
                    # Only proceed if the face is similar enough to a known emotion
                    if max_similarity >= relevance_threshold:
                        # --- Stage 2: Emotion Classifier ---
                        with torch.no_grad():
                            logits = model(**inputs).logits
                        
                        probabilities = F.softmax(logits, dim=1).squeeze()
                        top_confidence, top_pred_idx = torch.max(probabilities, dim=0)
                        top_pred_label = model.config.id2label[top_pred_idx.item()]
                        entropy = -torch.sum(probabilities * torch.log(probabilities + 1e-9)).item()

                        log_entry = {
                            "timestamp_seconds": frame_count / fps,
                            "frame_number": frame_count,
                            "face_index": i,
                            "is_relevant": True,
                            "max_similarity": max_similarity,
                            "predicted_label": top_pred_label,
                            "confidence": top_confidence.item(),
                            "entropy": entropy
                        }
                        all_results_log.append(log_entry)
        pbar.update(1)
        
    pbar.close()
    video_capture.release()
    
    print(f"‚úÖ Video processing complete. Logged {len(all_results_log)} relevant emotional events.")
    return all_results_log

In [9]:
# ==============================================================================
# 4. MAIN EXECUTION BLOCK
# ==============================================================================

# --- Setup Dynamic Save Directory ---
VERSION = get_next_version(ANALYSIS_OUTPUT_ROOT)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
VERSION_TAG = f"{VERSION}_{timestamp}"
SAVE_DIR = os.path.join(ANALYSIS_OUTPUT_ROOT, VERSION_TAG)
os.makedirs(SAVE_DIR, exist_ok=True)

print(f"üìÅ Created analysis output directory: {SAVE_DIR}")

# --- Load Models, Processor, and Centroids ---
print(f"\n--- Loading assets ---")
# Prefer CUDA, then Apple MPS, then CPU. On a Mac this resolves exactly as before.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# 1. Load the original model for the final classification
classification_model = AutoModelForImageClassification.from_pretrained(MODEL_PATH).to(device).eval()

# 2. Create a second model instance specifically for generating embeddings.
#    Replacing the `classifier` head with Identity makes `.logits` return the
#    features that would have been fed to the head.
embedding_model = AutoModelForImageClassification.from_pretrained(MODEL_PATH)
# nn.Module happily registers a brand-new attribute, so on a model whose head is
# not named `classifier` the assignment below would be a silent no-op and the
# "embeddings" would really be class logits. Fail loudly instead.
if not hasattr(embedding_model, "classifier"):
    raise AttributeError(
        f"Model at {MODEL_PATH} has no `classifier` head to replace; cannot build the embedding model."
    )
embedding_model.classifier = nn.Identity()
embedding_model.to(device).eval()

# 3. Load the processor
processor = AutoImageProcessor.from_pretrained(MODEL_PATH)

# 4. Load the pre-calculated emotion centroids from their path.
# NOTE(review): torch.load unpickles the file; only load centroid files you produced
# yourself. Consider weights_only=True once the centroid file's contents are confirmed
# to be plain tensors with simple keys.
emotion_centroids = torch.load(CENTROIDS_PATH, map_location=device)
print(f"‚úÖ Models, processor, and {len(emotion_centroids)} centroids loaded onto {device}.")


# --- Run the Analysis ---
# Define the path to your video file directly here.
video_to_process = "/Users/natalyagrokh/AI/ml_expressions/img_expressions/flywheel/sample_vids/StreetQs.mp4"

# Call the function with all required arguments
analysis_log = analyze_video_with_relevance_gate(
    video_path=video_to_process,
    save_dir=SAVE_DIR,
    model=classification_model,
    embedding_model=embedding_model,
    processor=processor,
    device=device,
    centroids=emotion_centroids,
    relevance_threshold=RELEVANCE_THRESHOLD,
    process_every_n_frames=1
)

# --- Save Results to CSV ---
if analysis_log:
    log_df = pd.DataFrame(analysis_log)
    csv_path = os.path.join(SAVE_DIR, "filtered_emotion_log.csv")
    log_df.to_csv(csv_path, index=False)
    print(f"\n‚úÖ Successfully saved filtered analysis to: {csv_path}")
else:
    print("\n‚ö†Ô∏è No relevant emotional events were detected, so no log file was created.")

print(f"\n--- Summary ---")
print(f"Total relevant faces analyzed: {len(analysis_log)}")

üìÅ Created analysis output directory: /Users/natalyagrokh/AI/ml_expressions/img_expressions/flywheel/V3_20250714_093845

--- Loading assets ---
‚úÖ Models, processor, and 10 centroids loaded onto mps.
‚úÖ Opened video: StreetQs.mp4 (5657 frames at 30.00 fps)


Processing video frames: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5657/5657 [24:45<00:00,  3.81it/s]


‚úÖ Video processing complete. Logged 5152 relevant emotional events.

‚úÖ Successfully saved filtered analysis to: /Users/natalyagrokh/AI/ml_expressions/img_expressions/flywheel/V3_20250714_093845/filtered_emotion_log.csv

--- Summary ---
Total relevant faces analyzed: 5152
