# YouTube Frame Extractor - Advanced Analysis

This notebook demonstrates advanced analysis techniques using the YouTube Frame Extractor package. Building on the basics covered in the quickstart guide, we'll explore more sophisticated methods for extracting insights from video frames.

## Advanced Topics Covered

1. **Advanced VLM Analysis**: Fine-grained content matching and scoring
2. **Object Detection**: Identifying and tracking objects across frames
3. **Temporal Analysis**: Analyzing changes over time
4. **Advanced Batch Processing**: Working with multiple videos efficiently
5. **Data Visualization**: Creating insightful visualizations of your analysis
6. **Custom Model Integration**: Using your own models for frame analysis

## 1. Setup and Configuration

First, let's set up our environment and import the required modules:

In [None]:
# Add the parent directory to the path for importing the package
import sys
import os
from pathlib import Path

# Assumes the working directory is examples/notebooks, so the project root is two levels up
project_root = Path().absolute().parent.parent
sys.path.insert(0, str(project_root))

# Verify we can import the package
try:
    from src.youtube_frame_extractor.extractors.browser import BrowserExtractor
    from src.youtube_frame_extractor.extractors.download import DownloadExtractor
    from src.youtube_frame_extractor.analysis.vlm import VLMAnalyzer
    print("✅ Successfully imported YouTube Frame Extractor package")
except ImportError as e:
    print(f"❌ Error importing package: {str(e)}")
    print("Please make sure you're running this notebook from the examples/notebooks directory")
    raise

In [None]:
# Import additional libraries for advanced analysis
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image, ImageDraw, ImageFont
import cv2
import time
import json
import logging
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.notebook import tqdm
from IPython.display import display, HTML, clear_output
import torch
import torchvision

# Set up plotting
plt.style.use('ggplot')
sns.set(style="whitegrid")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('advanced_analysis')

# Suppress unnecessary warnings
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

# Create output directory for extracted frames
output_dir = Path("./advanced_output")
output_dir.mkdir(exist_ok=True)

print(f"Output will be saved to: {output_dir.absolute()}")

## 2. Utility Functions

Let's define some advanced utility functions for display and analysis:

In [None]:
def display_frames(frames, max_frames=6, figsize=(15, 10), title="Extracted Frames", annotations=None):
    """Display a grid of extracted frames with optional annotations.
    
    Args:
        frames: List of frame dictionaries
        max_frames: Maximum number of frames to display
        figsize: Figure size as (width, height)
        title: Title for the overall figure
        annotations: Optional dict mapping frame indices to annotation text
    """
    num_frames = min(max_frames, len(frames))
    if num_frames == 0:
        print("No frames to display")
        return
    
    # Calculate grid dimensions
    cols = min(3, num_frames)
    rows = (num_frames + cols - 1) // cols
    
    plt.figure(figsize=figsize)
    plt.suptitle(title, fontsize=16)
    
    for i in range(num_frames):
        plt.subplot(rows, cols, i + 1)
        
        # Get the frame image
        if 'frame' in frames[i] and frames[i]['frame'] is not None:
            img = frames[i]['frame']
        elif 'path' in frames[i] and os.path.exists(frames[i]['path']):
            img = Image.open(frames[i]['path'])
        else:
            plt.text(0.5, 0.5, "Image not available", ha='center', va='center')
            plt.axis('off')
            continue
        
        if isinstance(img, Image.Image):
            img = np.array(img)
        
        plt.imshow(img)
        
        subtitle = f"Frame {i+1}"
        if 'time' in frames[i]:
            subtitle += f" | Time: {frames[i]['time']:.2f}s"
        if 'similarity' in frames[i]:
            subtitle += f" | Score: {frames[i]['similarity']:.2f}"
        
        if annotations and i in annotations:
            subtitle += f"\n{annotations[i]}"
            
        plt.title(subtitle)
        plt.axis('off')
    
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.show()

def display_video_info(video_id):
    """Display an embedded YouTube player for the given video ID."""
    embed_html = f"""
    <div style="width:560px;">
        <h3>YouTube Video: {video_id}</h3>
        <iframe width="560" height="315" src="https://www.youtube.com/embed/{video_id}"
                frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media;
                gyroscope; picture-in-picture" allowfullscreen>
        </iframe>
    </div>
    """
    display(HTML(embed_html))

def plot_similarity_timeline(frames, title="Similarity Scores Over Time", figsize=(12, 6)):
    """Plot similarity scores over time from frame data.
    
    Args:
        frames: List of frame dictionaries with 'time' and 'similarity' keys
        title: Plot title
        figsize: Figure size as (width, height)
    """
    times = [frame.get('time', i) for i, frame in enumerate(frames)]
    scores = [frame.get('similarity', 0) for frame in frames]
    
    if not times or not scores:
        print("No data available for timeline plot")
        return
    
    plt.figure(figsize=figsize)
    plt.plot(times, scores, '-o', linewidth=2, markersize=8)
    plt.grid(True, alpha=0.3)
    plt.title(title, fontsize=14)
    plt.xlabel("Time (seconds)", fontsize=12)
    plt.ylabel("Similarity Score", fontsize=12)
    
    thresholds = {frame.get('threshold', None) for frame in frames if 'threshold' in frame}
    if len(thresholds) == 1 and None not in thresholds:
        threshold = thresholds.pop()
        plt.axhline(y=threshold, color='r', linestyle='--', alpha=0.7, label=f"Threshold ({threshold})")
        plt.legend()
    
    plt.tight_layout()
    plt.show()
    
    # Single-row heatmap: one cell per frame, colored by its similarity score
    plt.figure(figsize=(figsize[0], figsize[1]//2))
    try:
        heatmap_data = pd.DataFrame({'similarity': scores}).T
        sns.heatmap(heatmap_data, cmap="YlGnBu", cbar_kws={'label': 'Similarity Score'})
        plt.xlabel("Frame Index")
        plt.title("Similarity Score per Frame")
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"Error plotting heatmap: {str(e)}")

def apply_text_overlay(image, text, position=(10, 10), font_size=20, color=(255, 255, 255), bg_color=(0, 0, 0, 128)):
    """Apply text overlay to an image.
    
    Args:
        image: PIL Image or numpy array
        text: Text to overlay
        position: (x, y) position for text
        font_size: Text font size
        color: Text color as RGB tuple
        bg_color: Background color as RGBA tuple
    Returns:
        PIL Image with text overlay
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image.astype('uint8'))
    annotated_image = image.copy()
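    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except IOError:
        # Arial is missing on many systems; fall back to PIL's built-in bitmap font
        font = ImageFont.load_default()
    draw = ImageDraw.Draw(annotated_image, 'RGBA')
    # textsize() was removed in Pillow 10; measure the text with textbbox() instead
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_width, text_height = right - left, bottom - top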
    draw.rectangle([
        position[0], position[1], position[0] + text_width + 10, position[1] + text_height + 10
    ], fill=bg_color)
    draw.text((position[0] + 5, position[1] + 5), text, font=font, fill=color)
    return annotated_image
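Several functions in this notebook repeat the same lookup: use the in-memory `'frame'` if present, otherwise open the file at `'path'`. A shared helper could centralize that logic and normalize the result to one image format. The sketch below is hypothetical (`load_frame_image` is not part of the `youtube_frame_extractor` package) and assumes frame dictionaries carry the `'frame'`/`'path'` keys used throughout this notebook and that array frames hold 0-255 values.

```python
import os

import numpy as np


def load_frame_image(frame):
    """Return a frame dict's image as an RGB uint8 NumPy array, or None.

    Hypothetical helper mirroring the 'frame' / 'path' lookup used in this
    notebook; it is not an API of the youtube_frame_extractor package.
    """
    img = frame.get('frame')
    if img is None:
        path = frame.get('path')
        if not path or not os.path.exists(path):
            return None
        from PIL import Image  # imported lazily; only needed for on-disk frames
        img = Image.open(path)
    if hasattr(img, 'convert'):  # PIL image: normalize any mode (RGBA, L, P) to RGB
        img = img.convert('RGB')
    arr = np.asarray(img)
    if arr.ndim == 2:  # grayscale array: replicate into three channels
        arr = np.stack([arr] * 3, axis=-1)
    elif arr.ndim == 3 and arr.shape[2] == 4:  # drop the alpha channel
        arr = arr[:, :, :3]
    return arr.astype(np.uint8)  # assumes values already lie in 0-255
```

Callers can then write `img = load_frame_image(frame)` and `continue` when it returns `None`, instead of repeating the two-branch lookup.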

## 3. Advanced VLM Analysis

Now let's explore more sophisticated VLM-based frame analysis, including:
- Multi-query analysis
- Concept comparison
- Fine-grained visual attribute detection

In [None]:
# Define a function for multi-query VLM analysis
def analyze_frames_with_multiple_queries(frames, queries, vlm_analyzer=None):
    """Analyze frames with multiple queries and return similarity scores for each.
    
    Args:
        frames: List of frame dictionaries
        queries: List of query strings
        vlm_analyzer: VLMAnalyzer instance (will be created if None)
    Returns:
        DataFrame with similarity scores for each query and frame
    """
    if not frames or not queries:
        print("No frames or queries provided")
        return None
    if vlm_analyzer is None:
        try:
            vlm_analyzer = VLMAnalyzer(model_name="openai/clip-vit-base-patch16")
        except Exception as e:
            print(f"Error initializing VLM analyzer: {str(e)}")
            return None
    results = { 'frame_index': [], 'time': [] }
    for query in queries:
        results[f"score_{query.replace(' ', '_')}"] = []
    for i, frame in enumerate(frames):
        results['frame_index'].append(i)
        results['time'].append(frame.get('time', i))
        if 'frame' in frame and frame['frame'] is not None:
            image = frame['frame']
        elif 'path' in frame and os.path.exists(frame['path']):
            image = Image.open(frame['path'])
        else:
            for query in queries:
                results[f"score_{query.replace(' ', '_')}"].append(0.0)
            continue
        for query in queries:
            try:
                similarity = vlm_analyzer.calculate_similarity(image, query)
                results[f"score_{query.replace(' ', '_')}"].append(float(similarity))
            except Exception as e:
                print(f"Error calculating similarity for query '{query}': {str(e)}")
                results[f"score_{query.replace(' ', '_')}"].append(0.0)
    return pd.DataFrame(results)

def visualize_multi_query_results(df, queries, figsize=(12, 8)):
    """Visualize multi-query analysis results.
    Args:
        df: DataFrame with similarity scores
        queries: List of original query strings
        figsize: Figure size as (width, height)
    """
    if df is None or df.empty:
        print("No data to visualize")
        return
    plt.figure(figsize=figsize)
    for query in queries:
        col_name = f"score_{query.replace(' ', '_')}"
        if col_name in df.columns:
            plt.plot(df['time'], df[col_name], '-o', label=query, linewidth=2, markersize=6)
    plt.grid(True, alpha=0.3)
    plt.title("Multi-Query Analysis Results", fontsize=14)
    plt.xlabel("Time (seconds)", fontsize=12)
    plt.ylabel("Similarity Score", fontsize=12)
    plt.legend()
    plt.tight_layout()
    plt.show()
    plt.figure(figsize=(figsize[0], figsize[1]//2))
    heatmap_data = df[[f"score_{query.replace(' ', '_')}" for query in queries]].copy()
    heatmap_data.columns = queries
    sns.heatmap(heatmap_data.T, cmap="YlGnBu", cbar_kws={'label': 'Similarity Score'})
    plt.xlabel("Frame Index")
    plt.ylabel("Query")
    plt.title("Similarity Scores Across All Frames and Queries")
    plt.tight_layout()
    plt.show()
    print("Summary Statistics:")
    for query in queries:
        col_name = f"score_{query.replace(' ', '_')}"
        if col_name in df.columns:
            max_score = df[col_name].max()
            max_frame = df.loc[df[col_name].idxmax(), 'frame_index']
            max_time = df.loc[df[col_name].idxmax(), 'time']
            print(f"- '{query}': Max score {max_score:.3f} at frame {int(max_frame)} (time: {max_time:.2f}s)")

In [None]:
# Set up our video for analysis
# A nature documentary, chosen to match the nature-themed queries below; any video ID can be substituted
video_id = "nLrrOcXX2kw"

# Display the video for reference
display_video_info(video_id)

In [None]:
# Extract frames for analysis
try:
    download_extractor = DownloadExtractor(output_dir=str(output_dir / "multi_query"))
    frames = download_extractor.extract_frames(
        video_id=video_id,
        frame_rate=0.1,  # One frame every 10 seconds
        max_frames=20    # Up to 20 frames
    )
    print(f"Successfully extracted {len(frames)} frames for multi-query analysis")
    display_frames(frames[:6], title="Sample Frames for Analysis")
except Exception as e:
    print(f"Error extracting frames: {str(e)}")
    frames = []

In [None]:
# Define multiple queries for analysis
nature_queries = [
    "forest landscape", 
    "wild animals", 
    "underwater scene",
    "mountain vista",
    "birds flying"
]

# Initialize VLM analyzer and perform multi-query analysis
try:
    vlm_analyzer = VLMAnalyzer(model_name="openai/clip-vit-base-patch16")
    print("✅ VLM analyzer initialized successfully")
    print("\nAnalyzing frames with multiple queries...")
    multi_query_results = analyze_frames_with_multiple_queries(
        frames=frames,
        queries=nature_queries,
        vlm_analyzer=vlm_analyzer
    )
    print("\nMulti-query analysis complete!")
    visualize_multi_query_results(multi_query_results, nature_queries)
except Exception as e:
    print(f"❌ Error in multi-query analysis: {str(e)}")
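CLIP-style models score a frame against a query as the cosine similarity between their image and text embeddings, and raw scores for different queries tend to sit in a narrow band, which makes "which concept dominates this frame?" hard to read off directly. One way to compare concepts is a per-frame softmax across queries, scaled by a temperature; CLIP's trained logit scale is about 100. The sketch below assumes that `VLMAnalyzer.calculate_similarity` returns a cosine similarity, which the package documentation would need to confirm; both function names are illustrative.

```python
import numpy as np


def cosine_similarity_matrix(image_emb, text_emb):
    """Cosine similarity between every image embedding (rows) and text embedding (rows)."""
    image_emb = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    text_emb = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    return image_emb @ text_emb.T  # shape: (n_frames, n_queries)


def concept_probabilities(scores, scale=100.0):
    """Softmax over queries for each frame; `scale` plays the role of CLIP's logit scale."""
    logits = scale * np.asarray(scores, dtype=np.float64)
    logits -= logits.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)
```

With the multi-query results above, the input could be `multi_query_results[[f"score_{q.replace(' ', '_')}" for q in nature_queries]].to_numpy()`; each output row then sums to 1, and a small raw gap (0.30 vs 0.25) becomes a decisive preference.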

### 3.1 Creating Frame Montages Based on Content

Let's demonstrate how to create content-specific montages based on VLM analysis:

In [None]:
def create_content_montage(frames, query, vlm_analyzer, threshold=0.3, max_frames=6, cols=3, frame_size=(320, 180)):
    """Create a montage of frames that match a specific content query.
    Args:
        frames: List of frame dictionaries
        query: Content query to match
        vlm_analyzer: VLM analyzer instance
        threshold: Minimum similarity score to include frame
        max_frames: Maximum number of frames to include
        cols: Number of columns in the montage
        frame_size: Size to resize each frame to (width, height)
    Returns:
        PIL Image containing the montage
    """
    if not frames:
        print("No frames provided for montage creation")
        return None
    scored_frames = []
    for frame in frames:
        if 'frame' in frame and frame['frame'] is not None:
            image = frame['frame']
        elif 'path' in frame and os.path.exists(frame['path']):
            image = Image.open(frame['path'])
        else:
            continue
        try:
            similarity = vlm_analyzer.calculate_similarity(image, query)
            scored_frames.append({
                'frame': image,
                'similarity': float(similarity),
                'time': frame.get('time', 0)
            })
        except Exception as e:
            print(f"Error calculating similarity: {str(e)}")
    scored_frames.sort(key=lambda x: x['similarity'], reverse=True)
    matching_frames = [f for f in scored_frames if f['similarity'] >= threshold]
    if not matching_frames:
        print(f"No frames matched the query '{query}' with threshold {threshold}")
        return None
    matching_frames = matching_frames[:max_frames]
    num_frames = len(matching_frames)
    rows = (num_frames + cols - 1) // cols
    montage_width = cols * frame_size[0]
    montage_height = rows * frame_size[1]
    montage = Image.new('RGB', (montage_width, montage_height))
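    for i, frame_data in enumerate(matching_frames):
        img = frame_data['frame']
        # Frames may be NumPy arrays or PIL images; normalize to an RGB PIL image
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.asarray(img).astype('uint8'))
        img = img.convert('RGB').resize(frame_size, Image.LANCZOS)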
        text = f"Score: {frame_data['similarity']:.2f} | Time: {frame_data['time']:.1f}s"
        img = apply_text_overlay(img, text)
        row = i // cols
        col = i % cols
        x = col * frame_size[0]
        y = row * frame_size[1]
        montage.paste(img, (x, y))
    return montage

# Create and display montages for each query
if 'vlm_analyzer' in locals() and frames:
    for query in nature_queries:
        try:
            print(f"Creating montage for query: '{query}'")
            montage = create_content_montage(
                frames=frames,
                query=query,
                vlm_analyzer=vlm_analyzer,
                threshold=0.25,
                max_frames=6
            )
            if montage is not None:
                plt.figure(figsize=(15, 8))
                plt.imshow(montage)
                plt.title(f"Content Montage: '{query}'")
                plt.axis('off')
                plt.show()
                montage_path = output_dir / f"montage_{query.replace(' ', '_')}.jpg"
                montage.save(montage_path)
                print(f"Montage saved to {montage_path}\n")
        except Exception as e:
            print(f"Error creating montage for '{query}': {str(e)}\n")
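The montage layout in `create_content_montage` is plain grid arithmetic: tile `i` goes to row `i // cols` and column `i % cols`, and the canvas needs `ceil(n / cols)` rows. The hypothetical helper below isolates that calculation. Unlike the original, it shrinks the column count when there are fewer frames than columns, so the canvas has no empty black tiles.

```python
def montage_layout(num_frames, cols=3, frame_size=(320, 180)):
    """Return the montage canvas size and the top-left paste offset of each tile."""
    width, height = frame_size
    # Use fewer columns when there are fewer frames than columns
    cols = max(1, min(cols, num_frames))
    rows = -(-num_frames // cols)  # ceiling division
    offsets = [((i % cols) * width, (i // cols) * height) for i in range(num_frames)]
    return (cols * width, rows * height), offsets
```

A caller would create `Image.new('RGB', size)` from the returned size and paste each resized tile at its offset.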

## 4. Object Detection and Tracking

Let's detect objects in each frame with a pre-trained Faster R-CNN model and track how often each object class appears across frames:

In [None]:
from functools import lru_cache


@lru_cache(maxsize=1)
def get_detection_model():
    """Load the pre-trained Faster R-CNN model once and reuse it across calls."""
    # weights="DEFAULT" replaces the deprecated pretrained=True (torchvision >= 0.13)
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    model.eval()
    return model


def detect_objects(image, confidence_threshold=0.5):
    """Detect objects in an image using a pre-trained Faster R-CNN model.
    Args:
        image: PIL Image or numpy array
        confidence_threshold: Minimum confidence score for detections
    Returns:
        List of detected objects with class, confidence, and bounding box
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image.astype('uint8'))
    image = image.convert('RGB')  # the model expects 3-channel input
    try:
        model = get_detection_model()
    except Exception as e:
        print(f"Error loading object detection model: {str(e)}")
        return []
    COCO_CLASSES = [
        '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
        'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
        'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
        'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
        'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
        'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
        'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
        'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
        'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
        'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]
    # Run inference once; gradients are not needed for prediction
    image_tensor = torchvision.transforms.functional.to_tensor(image)
    with torch.no_grad():
        predictions = model([image_tensor])
    boxes = predictions[0]['boxes'].cpu().numpy()
    scores = predictions[0]['scores'].cpu().numpy()
    labels = predictions[0]['labels'].cpu().numpy()
    detections = []
    for box, score, label in zip(boxes, scores, labels):
        if score >= confidence_threshold:
            detections.append({
                'class': COCO_CLASSES[label],
                'confidence': float(score),
                'box': box.astype(int).tolist()  # [x1, y1, x2, y2] in pixels
            })
    return detections

def visualize_detections(image, detections):
    """Draw bounding boxes and labels for detected objects.
    Args:
        image: PIL Image or numpy array
        detections: List of detection dictionaries
    Returns:
        PIL Image with detection visualizations
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image.astype('uint8'))
    vis_image = image.copy()
    draw = ImageDraw.Draw(vis_image)
    colors = [
        (255, 0, 0),    
        (0, 255, 0),    
        (0, 0, 255),    
        (255, 255, 0),  
        (255, 0, 255),  
        (0, 255, 255)   
    ]
    for i, det in enumerate(detections):
        box = det['box']
        class_name = det['class']
        confidence = det['confidence']
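        # Python's str hash is randomized per process, so colors are consistent only within a session
        color_idx = hash(class_name) % len(colors)
        color = colors[color_idx]
        draw.rectangle([(box[0], box[1]), (box[2], box[3])], outline=color, width=3)
        label_text = f"{class_name} {confidence:.2f}"
        # textsize() was removed in Pillow 10; measure the label with textbbox() instead
        left, top, right, bottom = draw.textbbox((0, 0), label_text)
        text_width, text_height = right - left, bottom - top
        # Keep the label inside the image when the box touches the top edge
        label_top = max(0, box[1] - text_height - 4)
        draw.rectangle([
            (box[0], label_top), (box[0] + text_width + 4, label_top + text_height + 4)
        ], fill=color)
        draw.text((box[0] + 2, label_top + 2), label_text, fill=(255, 255, 255))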
    return vis_image

def detect_and_track_objects_across_frames(frames, confidence_threshold=0.5):
    """Detect objects across multiple frames and track their occurrences.
    Args:
        frames: List of frame dictionaries
        confidence_threshold: Minimum confidence score for detections
    Returns:
        Tuple of (processed frames, object occurrence data)
    """
    if not frames:
        return [], {}
    processed_frames = []
    object_occurrences = {}
    for i, frame in enumerate(frames):
        if 'frame' in frame and frame['frame'] is not None:
            image = frame['frame']
        elif 'path' in frame and os.path.exists(frame['path']):
            image = Image.open(frame['path'])
        else:
            continue
        try:
            detections = detect_objects(image, confidence_threshold)
            for det in detections:
                class_name = det['class']
                if class_name not in object_occurrences:
                    object_occurrences[class_name] = {'count': 0, 'frames': [], 'confidences': []}
                object_occurrences[class_name]['count'] += 1
                object_occurrences[class_name]['frames'].append(i)
                object_occurrences[class_name]['confidences'].append(det['confidence'])
            vis_image = visualize_detections(image, detections)
            processed_frame = frame.copy()
            processed_frame['frame'] = vis_image
            processed_frame['detections'] = detections
            processed_frame['object_count'] = len(detections)
            processed_frames.append(processed_frame)
        except Exception as e:
            print(f"Error processing frame {i}: {str(e)}")
            processed_frames.append(frame)
    return processed_frames, object_occurrences

def visualize_object_occurrences(object_occurrences, min_count=2, figsize=(12, 8)):
    """Visualize object occurrences across frames.
    Args:
        object_occurrences: Dictionary of object occurrence data
        min_count: Minimum count to include in visualization
        figsize: Figure size as (width, height)
    """
    if not object_occurrences:
        print("No object occurrence data to visualize")
        return
    filtered_objects = {k: v for k, v in object_occurrences.items() if v['count'] >= min_count}
    if not filtered_objects:
        print(f"No objects detected in at least {min_count} frames")
        return
    sorted_objects = sorted(filtered_objects.items(), key=lambda x: x[1]['count'], reverse=True)
    objects = [x[0] for x in sorted_objects]
    counts = [x[1]['count'] for x in sorted_objects]
    plt.figure(figsize=figsize)
    bars = plt.bar(objects, counts, color='steelblue')
    plt.xticks(rotation=45, ha='right')
    plt.title("Object Detection Frequency", fontsize=14)
    plt.xlabel("Object Class", fontsize=12)
    plt.ylabel("Number of Occurrences", fontsize=12)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.1, f'{int(height)}', ha='center', va='bottom')
    plt.show()

In [None]:
# Run object detection on our frames
if 'frames' in locals() and frames:
    try:
        print("Detecting objects across frames...")
        processed_frames, object_occurrences = detect_and_track_objects_across_frames(
            frames=frames[:8],  # Limit to 8 frames for demo
            confidence_threshold=0.4
        )
        print(f"Detected objects in {len(processed_frames)} frames")
        display_frames(
            processed_frames, 
            title="Frames with Object Detection",
            annotations={i: f"Objects: {frame.get('object_count', 0)}" for i, frame in enumerate(processed_frames)}
        )
        visualize_object_occurrences(object_occurrences)
        print("\nDetailed Object Occurrence Data:")
        for obj, data in sorted(object_occurrences.items(), key=lambda x: x[1]['count'], reverse=True):
            if data['count'] >= 2:
                avg_conf = sum(data['confidences']) / len(data['confidences'])
                print(f"- {obj}: {data['count']} occurrences, avg confidence: {avg_conf:.3f}")
                print(f"  Appears in frames: {', '.join(str(f) for f in data['frames'])}")
    except Exception as e:
        print(f"Error in object detection: {str(e)}")
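`detect_and_track_objects_across_frames` counts how often each class appears, but it does not link individual objects from one frame to the next. A common, simple way to do that (a technique swapped in here, not something the package provides) is greedy intersection-over-union (IoU) matching between consecutive frames: same-class boxes that overlap enough inherit the earlier box's track ID. The sketch below assumes the detection dicts produced by `detect_objects` (`'class'` and `'box'` as `[x1, y1, x2, y2]`). With frames sampled 10 seconds apart (`frame_rate=0.1`), objects can move far between samples, so expect useful links only with denser sampling.

```python
def box_iou(a, b):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def link_detections(frames_detections, iou_threshold=0.3):
    """Greedily assign track IDs to detections across consecutive frames.

    frames_detections: one list of detection dicts per frame.
    Returns a parallel list of lists of integer track IDs.
    """
    next_id = 0
    prev = []  # (track_id, detection) pairs from the previous frame
    all_ids = []
    for detections in frames_detections:
        # Candidate matches between same-class boxes, highest IoU first
        candidates = sorted(
            ((box_iou(p['box'], d['box']), pi, di)
             for pi, (_, p) in enumerate(prev)
             for di, d in enumerate(detections)
             if p['class'] == d['class']),
            reverse=True,
        )
        ids = [None] * len(detections)
        used_prev = set()
        for iou, pi, di in candidates:
            if iou < iou_threshold:
                break  # remaining candidates overlap even less
            if pi in used_prev or ids[di] is not None:
                continue  # each box can be matched at most once
            ids[di] = prev[pi][0]
            used_prev.add(pi)
        for di in range(len(detections)):
            if ids[di] is None:  # unmatched detection starts a new track
                ids[di] = next_id
                next_id += 1
        prev = list(zip(ids, detections))
        all_ids.append(ids)
    return all_ids
```

For example, `link_detections([f.get('detections', []) for f in processed_frames])` would return one list of track IDs per processed frame.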

## 5. Temporal Analysis

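Let's analyze how content changes over time by measuring the difference between consecutive frames and flagging unusually large jumps as potential scene changes: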
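The `'mse'` method averages the squared pixel differences of two grayscale frames. One pitfall is worth isolating first: decoded frames are `uint8`, and NumPy `uint8` arithmetic wraps modulo 256 in both the subtraction and the squaring, so any per-pixel difference of 16 or more is silently corrupted. The sketch below (with an illustrative `frame_mse` helper) casts to float before subtracting and shows the naive result next to the correct one.

```python
import numpy as np


def frame_mse(prev_gray, curr_gray):
    """Mean squared error between two equally sized grayscale frames."""
    if prev_gray.shape != curr_gray.shape:
        raise ValueError(f"shape mismatch: {prev_gray.shape} vs {curr_gray.shape}")
    # Cast before subtracting so neither the difference nor its square wraps around
    diff = prev_gray.astype(np.float64) - curr_gray.astype(np.float64)
    return float(np.mean(diff ** 2))


a = np.array([[0, 100]], dtype=np.uint8)
b = np.array([[20, 50]], dtype=np.uint8)
naive = float(np.square(np.subtract(a, b)).mean())  # uint8 math: 400 -> 144, 2500 -> 196
print(naive, frame_mse(a, b))  # prints: 170.0 1450.0
```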

In [None]:
def calculate_frame_differences(frames, method='mse'):
    """Calculate differences between consecutive frames.
    Args:
        frames: List of frame dictionaries
        method: Difference calculation method ('mse', 'ssim', or 'histogram')
    Returns:
        List of dictionaries with frame differences
    """
    if not frames or len(frames) < 2:
        print("Not enough frames for difference calculation")
        return []
    differences = []
    for i in range(1, len(frames)):
        prev_frame = None
        if 'frame' in frames[i-1] and frames[i-1]['frame'] is not None:
            prev_frame = frames[i-1]['frame']
        elif 'path' in frames[i-1] and os.path.exists(frames[i-1]['path']):
            prev_frame = Image.open(frames[i-1]['path'])
        curr_frame = None
        if 'frame' in frames[i] and frames[i]['frame'] is not None:
            curr_frame = frames[i]['frame']
        elif 'path' in frames[i] and os.path.exists(frames[i]['path']):
            curr_frame = Image.open(frames[i]['path'])
        if prev_frame is None or curr_frame is None:
            continue
        if isinstance(prev_frame, Image.Image):
            prev_frame = np.array(prev_frame)
        if isinstance(curr_frame, Image.Image):
            curr_frame = np.array(curr_frame)
        prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_RGB2GRAY)
        curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_RGB2GRAY)
        diff_value = 0
        diff_image = None
        if method == 'mse':
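            # Cast to float first: uint8 subtraction and squaring wrap around modulo 256
            diff = np.square(prev_gray.astype(np.float64) - curr_gray.astype(np.float64)).mean()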
            diff_value = float(diff)
            diff_image = cv2.absdiff(prev_gray, curr_gray)
            diff_image = cv2.cvtColor(diff_image, cv2.COLOR_GRAY2RGB)
        elif method == 'ssim':
            from skimage.metrics import structural_similarity as ssim
            score, diff_image = ssim(prev_gray, curr_gray, full=True)
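            diff_value = 1.0 - score
            # The SSIM map lies in [-1, 1] and is high where frames are *similar*;
            # clip and invert it so bright pixels mark changes, as with the other methods
            diff_image = ((1.0 - np.clip(diff_image, 0, 1)) * 255).astype("uint8")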
            diff_image = cv2.cvtColor(diff_image, cv2.COLOR_GRAY2RGB)
        elif method == 'histogram':
            prev_hist = cv2.calcHist([prev_gray], [0], None, [256], [0, 256])
            curr_hist = cv2.calcHist([curr_gray], [0], None, [256], [0, 256])
            cv2.normalize(prev_hist, prev_hist, 0, 1, cv2.NORM_MINMAX)
            cv2.normalize(curr_hist, curr_hist, 0, 1, cv2.NORM_MINMAX)
            diff_value = cv2.compareHist(prev_hist, curr_hist, cv2.HISTCMP_BHATTACHARYYA)
            diff_image = cv2.absdiff(prev_frame, curr_frame)
        difference = {
            'prev_index': i-1,
            'curr_index': i,
            'prev_time': frames[i-1].get('time', i-1),
            'curr_time': frames[i].get('time', i),
            'difference': diff_value,
            'diff_image': Image.fromarray(diff_image) if diff_image is not None else None
        }
        differences.append(difference)
    return differences

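def detect_scene_changes(differences, threshold=None):
    """Detect potential scene changes based on frame differences.
    Args:
        differences: List of frame difference dictionaries
        threshold: Difference value above which a change counts as a scene change.
            If None, uses mean + 2 * std of the differences, which adapts to the
            metric's scale (MSE spans a far wider range than the histogram distance)
    Returns:
        List of indices (into `differences`) where scene changes occur
    """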
    if not differences:
        return []
    diff_values = [d['difference'] for d in differences]
    if threshold is None:
        mean_diff = np.mean(diff_values)
        std_diff = np.std(diff_values)
        threshold = mean_diff + 2 * std_diff
    scene_changes = []
    for i, diff in enumerate(differences):
        if diff['difference'] > threshold:
            scene_changes.append(i)
    return scene_changes

def visualize_frame_differences(differences, frames, scene_changes=None, figsize=(12, 6)):
    """Visualize frame differences and scene changes.
    Args:
        differences: List of frame difference dictionaries
        frames: The frame dictionaries the differences were computed from
        scene_changes: List of indices where scene changes occur
        figsize: Figure size as (width, height)
    """
    if not differences:
        print("No differences to visualize")
        return

    def load_image(frame):
        # Prefer the in-memory frame; fall back to the saved file
        if frame.get('frame') is not None:
            return frame['frame']
        if 'path' in frame and os.path.exists(frame['path']):
            return Image.open(frame['path'])
        return None

    times = [(d['prev_time'] + d['curr_time']) / 2 for d in differences]
    diff_values = [d['difference'] for d in differences]
    valid_changes = [sc for sc in (scene_changes or []) if 0 <= sc < len(differences)]
    plt.figure(figsize=figsize)
    plt.plot(times, diff_values, '-o', linewidth=2, markersize=6, label="Frame Difference")
    for j, sc in enumerate(valid_changes):
        # Label only the first marker so the legend shows a single "Scene Change" entry
        plt.axvline(x=times[sc], color='r', linestyle='--', alpha=0.7,
                    label="Scene Change" if j == 0 else "_nolegend_")
    plt.grid(True, alpha=0.3)
    plt.title("Frame Differences Over Time", fontsize=14)
    plt.xlabel("Time (seconds)", fontsize=12)
    plt.ylabel("Difference Value", fontsize=12)
    if valid_changes:
        plt.legend()
    plt.tight_layout()
    plt.show()
    if not valid_changes:
        return
    # Before / after / difference panels for at most the first three scene changes
    shown = valid_changes[:3]
    fig, axes = plt.subplots(len(shown), 3, figsize=(15, 5 * len(shown)), squeeze=False)
    for i, sc in enumerate(shown):
        diff = differences[sc]
        panels = [
            (load_image(frames[diff['prev_index']]), f"Before (time: {diff['prev_time']:.2f}s)"),
            (load_image(frames[diff['curr_index']]), f"After (time: {diff['curr_time']:.2f}s)"),
            (diff['diff_image'], f"Difference (value: {diff['difference']:.4f})"),
        ]
        for ax, (img, panel_title) in zip(axes[i], panels):
            ax.axis('off')
            if img is not None:
                ax.imshow(np.array(img))
                ax.set_title(panel_title)
    plt.tight_layout()
    plt.suptitle("Scene Change Visualization", fontsize=16, y=1.02)
    plt.show()

In [None]:
# Analyze frame differences and detect scene changes
if 'frames' in locals() and len(frames) >= 2:
    try:
        print("Calculating frame differences...")
        differences = calculate_frame_differences(frames, method='mse')
        print(f"Calculated {len(differences)} frame differences")
        diff_values = [d['difference'] for d in differences]
        mean_diff = np.mean(diff_values)
        std_diff = np.std(diff_values)
        threshold = mean_diff + 1.5 * std_diff
        print(f"Using threshold {threshold:.4f} for scene change detection")
        scene_changes = detect_scene_changes(differences, threshold=threshold)
        print(f"Detected {len(scene_changes)} potential scene changes")
        visualize_frame_differences(differences, frames, scene_changes)
    except Exception as e:
        print(f"Error analyzing frame differences: {str(e)}")

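The thresholding rule used above flags any frame difference larger than the mean plus 1.5 standard deviations. You can check that rule on synthetic frames without downloading anything. The sketch below is a standalone NumPy reimplementation. `frame_mse` and `find_scene_changes` are illustrative names, not the package's `calculate_frame_differences` or `detect_scene_changes`, and the sketch assumes every frame is a same-sized `uint8` array.

```python
import numpy as np

def frame_mse(a: np.ndarray, b: np.ndarray) -> float:
    """Mean squared error between two frames, with pixel values scaled to [0, 1]."""
    # Cast before subtracting so uint8 arithmetic cannot wrap around
    a = a.astype(np.float64) / 255.0
    b = b.astype(np.float64) / 255.0
    return float(np.mean((a - b) ** 2))

def find_scene_changes(frames, k=1.5):
    """Return (differences, change_indices, threshold).

    A value i in change_indices means the cut falls between frames[i] and frames[i + 1].
    """
    diffs = np.array([frame_mse(frames[i], frames[i + 1]) for i in range(len(frames) - 1)])
    # Same rule as the cell above: population mean + k * population std
    threshold = diffs.mean() + k * diffs.std()
    changes = [int(i) for i in np.flatnonzero(diffs > threshold)]
    return diffs, changes, threshold

# Synthetic clip: five noisy dark frames followed by five noisy bright frames (one hard cut)
rng = np.random.default_rng(0)
dark = [rng.integers(0, 20, size=(48, 64, 3), dtype=np.uint8) for _ in range(5)]
bright = [rng.integers(200, 255, size=(48, 64, 3), dtype=np.uint8) for _ in range(5)]
diffs, changes, threshold = find_scene_changes(dark + bright)
print(changes)  # [4]
```

The per-frame noise gives small but non-zero differences. The cut produces a value far above the others, so only index 4 clears the threshold. With real footage, a gradual fade spreads its difference over several frames and may stay below a `k = 1.5` threshold.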
## 6. Advanced Batch Processing

Now let's demonstrate advanced batch processing with progress tracking and parallel execution:

In [None]:
def process_video(video_id, query, output_dir, frame_rate=0.1, max_frames=10):
    """Process a single video with VLM-based frame extraction and analysis.
    Args:
        video_id: YouTube video ID
        query: Natural language query for content matching
        output_dir: Directory to save output
        frame_rate: Frames per second to extract
        max_frames: Maximum number of frames to extract
    Returns:
        Dictionary with results
    """
    try:
        video_dir = os.path.join(output_dir, video_id)
        os.makedirs(video_dir, exist_ok=True)
        extractor = DownloadExtractor(output_dir=video_dir)
        frames = extractor.extract_frames(
            video_id=video_id,
            frame_rate=frame_rate,
            max_frames=max_frames
        )
        vlm_analyzer = VLMAnalyzer(model_name="openai/clip-vit-base-patch16")
        scored_frames = []
        for frame in frames:
            if 'frame' in frame and frame['frame'] is not None:
                image = frame['frame']
            elif 'path' in frame and os.path.exists(frame['path']):
                image = Image.open(frame['path'])
            else:
                continue
            similarity = vlm_analyzer.calculate_similarity(image, query)
            scored_frame = frame.copy()
            scored_frame['similarity'] = float(similarity)
            scored_frame['query'] = query
            scored_frames.append(scored_frame)
        scored_frames.sort(key=lambda x: x.get('similarity', 0), reverse=True)
        montage = create_content_montage(
            frames=frames,
            query=query,
            vlm_analyzer=vlm_analyzer,
            threshold=0.2,
            max_frames=6
        )
        montage_path = os.path.join(video_dir, f"{video_id}_{query.replace(' ', '_')}_montage.jpg")
        if montage is not None:
            montage.save(montage_path)
        metadata = {
            'video_id': video_id,
            'query': query,
            'frame_count': len(frames),
            'top_score': scored_frames[0]['similarity'] if scored_frames else 0,
            'average_score': sum(f['similarity'] for f in scored_frames) / len(scored_frames) if scored_frames else 0,
            'montage_path': montage_path if montage is not None else None
        }
        metadata_path = os.path.join(video_dir, f"{video_id}_{query.replace(' ', '_')}_metadata.json")
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        return {
            'video_id': video_id,
            'success': True,
            'frames': scored_frames,
            'metadata': metadata,
            'montage_path': montage_path if montage is not None else None
        }
    except Exception as e:
        error_message = str(e)
        print(f"Error processing video {video_id}: {error_message}")
        return {
            'video_id': video_id,
            'success': False,
            'error': error_message,
            'frames': [],
            'metadata': None,
            'montage_path': None
        }

def batch_process_videos(video_ids, query, output_dir, frame_rate=0.1, max_frames=10, max_workers=3):
    """Process multiple videos in parallel.
    Args:
        video_ids: List of YouTube video IDs
        query: Natural language query for content matching
        output_dir: Directory to save output
        frame_rate: Frames per second to extract
        max_frames: Maximum number of frames to extract
        max_workers: Maximum number of concurrent workers
    Returns:
        Dictionary with results for each video
    """
    if not video_ids:
        print("No video IDs provided")
        return {}
    os.makedirs(output_dir, exist_ok=True)
    results = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_video = {executor.submit(process_video, video_id, query, output_dir, frame_rate, max_frames): video_id for video_id in video_ids}
        for future in tqdm(as_completed(future_to_video), total=len(video_ids), desc="Processing videos"):
            vid = future_to_video[future]
            try:
                result = future.result()
                results[vid] = result
            except Exception as e:
                print(f"Error processing video {vid}: {str(e)}")
                results[vid] = {
                    'video_id': vid,
                    'success': False,
                    'error': str(e),
                    'frames': [],
                    'metadata': None,
                    'montage_path': None
                }
    return results

def generate_batch_report(results, output_dir):
    """Generate a comprehensive report from batch processing results.
    Args:
        results: Dictionary with results for each video
        output_dir: Directory to save the report
    Returns:
        Path to the generated report
    """
    if not results:
        print("No results to include in report")
        return None
    report_dir = os.path.join(output_dir, "report")
    os.makedirs(report_dir, exist_ok=True)
    # Take the query from the first result that has metadata (failed runs store None)
    query = next((r['metadata']['query'] for r in results.values() if r.get('metadata')), "Unknown")
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    report_path = os.path.join(report_dir, f"batch_report_{query.replace(' ', '_')}_{timestamp}.html")
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(f"""
        <!DOCTYPE html>
        <html>
        <head>
            <meta charset="utf-8">
            <title>Batch Processing Report</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 20px; }}
                h1, h2, h3 {{ color: #2c3e50; }}
                .video-card {{ border: 1px solid #ddd; margin: 20px 0; padding: 15px; border-radius: 5px; }}
                .video-header {{ display: flex; justify-content: space-between; align-items: center; }}
                .success {{ color: green; }}
                .failure {{ color: red; }}
                table {{ border-collapse: collapse; width: 100%; }}
                th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
                th {{ background-color: #f2f2f2; }}
                .montage {{ max-width: 100%; height: auto; margin: 10px 0; border: 1px solid #ddd; }}
                .summary {{ background-color: #f8f9fa; padding: 15px; border-radius: 5px; margin-bottom: 20px; }}
            </style>
        </head>
        <body>
            <h1>YouTube Frame Extractor Batch Processing Report</h1>
            <div class="summary">
                <h2>Summary</h2>
                <p>Query: <strong>{query}</strong></p>
                <p>Date: {time.strftime("%Y-%m-%d %H:%M:%S")}</p>
                <p>Total videos processed: {len(results)}</p>
                <p>Successful: {sum(1 for r in results.values() if r['success'])}</p>
                <p>Failed: {sum(1 for r in results.values() if not r['success'])}</p>
            </div>
        """)
        f.write("""
            <h2>Results Overview</h2>
            <table>
                <tr>
                    <th>Video ID</th>
                    <th>Status</th>
                    <th>Frames</th>
                    <th>Top Score</th>
                    <th>Avg Score</th>
                </tr>
        """)
        for video_id, result in results.items():
            status = "Success" if result['success'] else f"Failed: {result.get('error', 'Unknown error')}"
            frame_count = len(result.get('frames', []))
            # Failed runs store metadata=None, so fall back to an empty dict
            metadata = result.get('metadata') or {}
            top_score = metadata.get('top_score')
            avg_score = metadata.get('average_score')
            top_score_str = f"{top_score:.3f}" if top_score is not None else 'N/A'
            avg_score_str = f"{avg_score:.3f}" if avg_score is not None else 'N/A'
            f.write(f"""
                <tr>
                    <td>{video_id}</td>
                    <td class="{'success' if result['success'] else 'failure'}">{status}</td>
                    <td>{frame_count}</td>
                    <td>{top_score_str}</td>
                    <td>{avg_score_str}</td>
                </tr>
            """)
        f.write("</table>")
        f.write("<h2>Detailed Results</h2>")
        for video_id, result in results.items():
            f.write(f"""
                <div class="video-card">
                    <div class="video-header">
                        <h3>Video: {video_id}</h3>
                        <span class="{'success' if result['success'] else 'failure'}">
                            {"Success" if result['success'] else "Failed"}
                        </span>
                    </div>
            """)
            if result['success']:
                f.write(f"""
                    <iframe width="560" height="315" src="https://www.youtube.com/embed/{video_id}" 
                            frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; 
                            gyroscope; picture-in-picture" allowfullscreen></iframe>
                """)
                if result.get('montage_path') and os.path.exists(result['montage_path']):
                    montage_filename = os.path.basename(result['montage_path'])
                    report_montage_path = os.path.join(report_dir, montage_filename)
                    import shutil
                    shutil.copy2(result['montage_path'], report_montage_path)
                    f.write(f"""
                        <h4>Top Matching Frames</h4>
                        <img src="{montage_filename}" alt="Frame montage" class="montage">
                    """)
                if result.get('metadata'):
                    f.write(f"""
                        <h4>Metadata</h4>
                        <table>
                            <tr><th>Frame Count</th><td>{result['metadata'].get('frame_count', 'N/A')}</td></tr>
                            <tr><th>Top Score</th><td>{result['metadata'].get('top_score', 'N/A')}</td></tr>
                            <tr><th>Average Score</th><td>{result['metadata'].get('average_score', 'N/A')}</td></tr>
                        </table>
                    """)
            else:
                f.write(f"<p>Error: {result.get('error', 'Unknown error')}</p>")
            f.write("</div>")
        f.write("""
        </body>
        </html>
        """)
    return report_path

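`batch_process_videos` follows the standard `concurrent.futures` pattern. It submits one task per video, keeps a future-to-ID map, and collects results with `as_completed`, which yields futures in completion order rather than submission order. The sketch below isolates that pattern with a hypothetical `fake_process_video` worker that does no downloading or model loading, so the error path is easy to see. An exception raised inside a worker is re-raised by `future.result()` in the calling thread.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def fake_process_video(video_id: str) -> dict:
    """Hypothetical stand-in for process_video: sleeps briefly and fails for one ID."""
    time.sleep(0.01)
    if video_id == "bad_id":
        raise RuntimeError("download failed")
    return {'video_id': video_id, 'success': True}

def run_batch(video_ids, worker, max_workers=3):
    """Run worker(video_id) concurrently and return a dict keyed by video ID."""
    results = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Map each future back to the ID it was submitted for
        future_to_id = {executor.submit(worker, vid): vid for vid in video_ids}
        for future in as_completed(future_to_id):
            vid = future_to_id[future]
            try:
                results[vid] = future.result()  # re-raises the worker's exception, if any
            except Exception as exc:
                results[vid] = {'video_id': vid, 'success': False, 'error': str(exc)}
    return results

results = run_batch(["a", "bad_id", "c"], fake_process_video, max_workers=2)
print(sorted(results))  # ['a', 'bad_id', 'c']
```

Because results are keyed by ID, completion order does not matter. One failing video also does not abort the rest of the batch.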

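The overview table in `generate_batch_report` has to handle two awkward inputs. Failed runs have `metadata` set to `None`, and error messages may contain HTML-special characters such as `<` or `&`. A minimal sketch of one defensive way to render a row uses the standard-library `html.escape`. `render_summary_row` is a hypothetical helper, not part of the package.

```python
import html

def render_summary_row(video_id: str, result: dict) -> str:
    """Render one <tr> of the overview table (hypothetical helper)."""
    metadata = result.get('metadata') or {}  # failed runs store None

    def fmt(value):
        # Show scores to three decimals; anything missing becomes N/A
        return f"{value:.3f}" if isinstance(value, (int, float)) else "N/A"

    if result.get('success'):
        status, css = "Success", "success"
    else:
        status, css = f"Failed: {result.get('error', 'Unknown error')}", "failure"
    cells = [
        f"<td>{html.escape(video_id)}</td>",
        f'<td class="{css}">{html.escape(status)}</td>',  # escape <, >, & in error text
        f"<td>{len(result.get('frames', []))}</td>",
        f"<td>{fmt(metadata.get('top_score'))}</td>",
        f"<td>{fmt(metadata.get('average_score'))}</td>",
    ]
    return "<tr>" + "".join(cells) + "</tr>"

ok = render_summary_row("abc", {'success': True, 'frames': [1, 2],
                                'metadata': {'top_score': 0.31234, 'average_score': 0.25}})
bad = render_summary_row("xyz", {'success': False, 'error': "HTTP <403>",
                                 'frames': [], 'metadata': None})
print(bad)
```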
In [None]:
# Define a list of YouTube video IDs for batch processing
batch_video_ids = [
    "nLrrOcXX2kw",  # Nature documentary
    "eDiSYp_o_8Q",  # Another nature video
    "dQw4w9WgXcQ"   # Rick Astley (as a control)
]

# Define the query
batch_query = "animals in the wild"

# Create output directory for batch processing
batch_output_dir = output_dir / "batch_processing"
batch_output_dir.mkdir(parents=True, exist_ok=True)

# Batch processing downloads several videos, so it is disabled by default.
# Set RUN_BATCH = True to run it.
RUN_BATCH = False

if RUN_BATCH:
    try:
        print(f"Starting batch processing of {len(batch_video_ids)} videos with query: '{batch_query}'")
        batch_results = batch_process_videos(
            video_ids=batch_video_ids,
            query=batch_query,
            output_dir=str(batch_output_dir),
            frame_rate=0.1,
            max_frames=10,
            max_workers=2
        )
        print(f"Batch processing complete for {len(batch_results)} videos")
        report_path = generate_batch_report(batch_results, str(batch_output_dir))
        if report_path:
            print(f"Generated batch report: {report_path}")
        print("\nResults Summary:")
        for video_id, result in batch_results.items():
            status = "Success" if result['success'] else "Failed"
            frame_count = len(result.get('frames', []))
            print(f"- {video_id}: {status}, {frame_count} frames")
            if result['success'] and result.get('metadata'):
                top_score = result['metadata'].get('top_score', 0.0)
                avg_score = result['metadata'].get('average_score', 0.0)
                print(f"  Top score: {top_score:.3f}, Avg score: {avg_score:.3f}")
                if result.get('montage_path') and os.path.exists(result['montage_path']):
                    montage = Image.open(result['montage_path'])
                    plt.figure(figsize=(12, 6))
                    plt.imshow(montage)
                    plt.title(f"Video {video_id} - Content Montage for '{batch_query}'")
                    plt.axis('off')
                    plt.show()
    except Exception as e:
        print(f"Error in batch processing: {str(e)}")
else:
    print("Batch processing skipped; set RUN_BATCH = True to run it.")

## 7. Custom Model Integration

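Any object with an `analyze_frame(image)` method that returns a dictionary can be plugged into the frame pipeline. The example below does not run a learned model. It computes simple image statistics instead: per-channel statistics, brightness, contrast, and edge density.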

In [None]:
class CustomFrameAnalyzer:
    """A custom frame analyzer for demonstration purposes.
    This example shows how to integrate a custom analysis model with the YouTube Frame Extractor framework.
    """
    def __init__(self):
        # Load any required models or resources here
        self.model_loaded = True
        print("Custom analyzer initialized")
    def analyze_frame(self, image):
        """Analyze a single frame.
        Args:
            image: PIL Image or numpy array
        Returns:
            Dictionary with analysis results
        """
        if isinstance(image, Image.Image):
            # Normalize palette, RGBA and grayscale PIL images to 3-channel RGB
            image = np.array(image.convert('RGB'))
        channels = cv2.split(image) if len(image.shape) > 2 else [image]
        channel_stats = []
        for i, channel in enumerate(channels):
            mean = np.mean(channel)
            std = np.std(channel)
            min_val = np.min(channel)
            max_val = np.max(channel)
            channel_stats.append({
                'channel': i,
                'mean': float(mean),
                'std': float(std),
                'min': int(min_val),
                'max': int(max_val),
                'contrast': float(max_val - min_val)
            })
        if len(channels) >= 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            # Single-channel input is already grayscale; drop a trailing channel axis if present
            gray = image if image.ndim == 2 else image[:, :, 0]
        brightness = float(np.mean(gray))
        contrast = float(np.std(gray))
        # Canny expects an 8-bit single-channel image
        edges = cv2.Canny(gray.astype(np.uint8), 100, 200)
        edge_density = float(np.count_nonzero(edges) / edges.size)
        return {
            'channel_stats': channel_stats,
            'brightness': brightness,
            'contrast': contrast,
            'edge_density': edge_density,
            'complexity_score': contrast * edge_density
        }

def analyze_frames_with_custom_model(frames):
    """Apply custom analysis to a list of frames.
    Args:
        frames: List of frame dictionaries
    Returns:
        List of frames with added analysis results
    """
    if not frames:
        return []
    analyzer = CustomFrameAnalyzer()
    analyzed_frames = []
    for frame in frames:
        if 'frame' in frame and frame['frame'] is not None:
            image = frame['frame']
        elif 'path' in frame and os.path.exists(frame['path']):
            image = Image.open(frame['path'])
        else:
            analyzed_frames.append(frame)
            continue
        analysis_results = analyzer.analyze_frame(image)
        analyzed_frame = frame.copy()
        analyzed_frame['custom_analysis'] = analysis_results
        analyzed_frames.append(analyzed_frame)
    return analyzed_frames

def visualize_custom_analysis(frames):
    """Visualize the results of custom frame analysis.
    Args:
        frames: List of frames with custom analysis results
    """
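    # Frames without image data are passed through unanalyzed, so keep only analyzed ones
    frames = [frame for frame in (frames or []) if 'custom_analysis' in frame]
    if not frames:
        print("No custom analysis results to visualize")
        return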
    times = [frame.get('time', i) for i, frame in enumerate(frames)]
    brightness = [frame['custom_analysis']['brightness'] for frame in frames]
    contrast = [frame['custom_analysis']['contrast'] for frame in frames]

    # Plot brightness over time
    plt.figure(figsize=(12,6))
    plt.plot(times, brightness, '-o', label='Brightness')
    plt.title("Brightness Over Time")
    plt.xlabel("Time (s)")
    plt.ylabel("Brightness")
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Plot contrast over time
    plt.figure(figsize=(12,6))
    plt.plot(times, contrast, '-o', label='Contrast', color='orange')
    plt.title("Contrast Over Time")
    plt.xlabel("Time (s)")
    plt.ylabel("Contrast")
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Scatter plot of brightness vs contrast
    plt.figure(figsize=(8,6))
    plt.scatter(brightness, contrast, c=times, cmap='viridis')
    plt.colorbar(label='Time (s)')
    plt.title("Brightness vs Contrast")
    plt.xlabel("Brightness")
    plt.ylabel("Contrast")
    plt.tight_layout()
    plt.show()

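`analyze_frames_with_custom_model` is hard-wired to `CustomFrameAnalyzer`, but it only relies on an `analyze_frame(image)` method that returns a dictionary. The sketch below makes that contract explicit with `typing.Protocol`, so any analyzer can be passed in. `FrameAnalyzer`, `MeanBrightnessAnalyzer`, and `apply_analyzer` are hypothetical names. For brevity, the sketch only handles in-memory frames and omits loading from `path`.

```python
from typing import Any, Dict, List, Protocol

import numpy as np

class FrameAnalyzer(Protocol):
    """Structural interface: any object with analyze_frame(image) -> dict fits."""
    def analyze_frame(self, image: np.ndarray) -> Dict[str, Any]: ...

class MeanBrightnessAnalyzer:
    """Toy analyzer (hypothetical): reports the mean pixel value of a frame."""
    def analyze_frame(self, image: np.ndarray) -> Dict[str, Any]:
        return {'brightness': float(np.mean(image))}

def apply_analyzer(frames: List[dict], analyzer: FrameAnalyzer, key: str) -> List[dict]:
    """Attach analyzer output under `key` for every frame that carries pixel data."""
    out = []
    for frame in frames:
        image = frame.get('frame')
        if image is None:
            out.append(frame)  # pass frames without pixels through unchanged
            continue
        enriched = dict(frame)  # shallow copy, like frame.copy() in the notebook
        enriched[key] = analyzer.analyze_frame(np.asarray(image))
        out.append(enriched)
    return out

frames_demo = [
    {'time': 0.0, 'frame': np.zeros((4, 4, 3), dtype=np.uint8)},
    {'time': 1.0, 'frame': np.full((4, 4, 3), 200, dtype=np.uint8)},
    {'time': 2.0, 'frame': None},
]
analyzed = apply_analyzer(frames_demo, MeanBrightnessAnalyzer(), key='brightness_analysis')
```

`CustomFrameAnalyzer` above already satisfies this protocol. With this design, you swap in a different analyzer at the call site and leave the loop unchanged.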
In [None]:
# Run custom frame analysis on the extracted frames
try:
    print("Performing custom analysis on extracted frames...")
    custom_analyzed_frames = analyze_frames_with_custom_model(frames)
    print("Custom analysis complete!")
    display_frames(custom_analyzed_frames, title="Frames with Custom Analysis")
    visualize_custom_analysis(custom_analyzed_frames)
except Exception as e:
    print(f"Error during custom frame analysis: {str(e)}")

## 8. Cleanup

Finally, release open resources, such as the browser driver used by the browser extractor, and summarize what this advanced analysis session covered.

In [None]:
# Cleanup resources
try:
    if 'browser_extractor' in locals() and browser_extractor._driver is not None:
        browser_extractor._driver.quit()
        print("Browser extractor cleaned up")
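    # At notebook top level these variables live in globals(); writing to locals() is not reliable in general
    for var in ['frames', 'multi_query_results', 'custom_analyzed_frames']:
        if var in globals():
            globals()[var] = None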
    print("Cleanup complete")
except Exception as e:
    print(f"Error during cleanup: {str(e)}")

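The cleanup cell only runs if execution reaches the end of the notebook. A context manager built on `try`/`finally` releases the browser driver even when extraction fails partway. Like the cleanup cell, this sketch assumes the driver exposes a `quit()` method, as Selenium's WebDriver does. `quitting` and `FakeDriver` are hypothetical names used only for illustration.

```python
from contextlib import contextmanager

@contextmanager
def quitting(driver):
    """Yield a driver-like object and always call its quit() afterwards."""
    try:
        yield driver
    finally:
        if driver is not None:
            driver.quit()

class FakeDriver:
    """Hypothetical stand-in for a Selenium WebDriver."""
    def __init__(self):
        self.closed = False

    def quit(self):
        self.closed = True

driver = FakeDriver()
try:
    with quitting(driver):
        raise RuntimeError("extraction failed mid-way")
except RuntimeError:
    pass
print(driver.closed)  # True
```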
## Summary

In this notebook, we covered advanced analysis techniques using the YouTube Frame Extractor package:

- **Advanced VLM Analysis:** Scoring frames against multiple natural-language queries and building montages of the best matches.
- **Object Detection & Tracking:** Detecting objects in frames and tracking their occurrences across the video.
- **Temporal Analysis:** Measuring differences between consecutive frames to detect scene changes and visualize how content varies over time.
- **Advanced Batch Processing:** Processing multiple videos in parallel with a thread pool and summarizing the results in an HTML report (the run cell is disabled by default).
- **Custom Model Integration:** Plugging in a custom analyzer that computes image statistics such as brightness, contrast, and edge density, and plotting these metrics.

Together, these techniques extend the quickstart workflow with richer visualizations and more detailed per-frame analytics.