# VideoSHAP

### Understanding What Vision Models See in Videos

---

**VideoSHAP** uses game-theoretic Shapley values to reveal which objects in a video are most important for a model's response.

Given a video and a question, VideoSHAP:
1. Segments and tracks objects across frames
2. Tests how removing each object affects the model's answer
3. Computes fair importance scores using Shapley values
4. Visualizes the scores as an importance heatmap over the video

> **Analogy:** If your video is a movie scene, VideoSHAP tells you which actors are crucial for understanding the plot.
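
Step 3 relies on the Shapley value. For a set $N$ of $n$ tracked objects and a value function $v$ that scores a coalition $S \subseteq N$ of visible objects, object $i$ receives

$$\phi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!\,(n-|S|-1)!}{n!}\,\bigl(v(S \cup \{i\}) - v(S)\bigr)$$

The sketch below computes this exactly by enumerating every coalition. The value function `toy_value` and the object names are hypothetical stand-ins. In VideoSHAP the value presumably comes from comparing the model's answer with some objects hidden against its answer on the full video, but that exact definition is an assumption here.

```python
from itertools import combinations
from math import factorial
from typing import Callable, Dict, FrozenSet, Sequence


def exact_shapley(
    players: Sequence[str], value: Callable[[FrozenSet[str]], float]
) -> Dict[str, float]:
    """Exact Shapley values by enumerating every coalition (2^(n-1) per player)."""
    n = len(players)
    phi = {p: 0.0 for p in players}
    for p in players:
        others = [q for q in players if q != p]
        for size in range(n):  # coalitions of size 0..n-1 drawn from the other players
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for coalition in combinations(others, size):
                s = frozenset(coalition)
                phi[p] += weight * (value(s | {p}) - value(s))  # weighted marginal contribution
    return phi


def toy_value(s: FrozenSet[str]) -> float:
    """Hypothetical game: the bird matters most, the two cats matter more together."""
    score = 0.6 if "bird" in s else 0.0
    score += 0.1 * sum(1 for obj in s if obj.startswith("cat"))
    if {"cat_1", "cat_2"} <= s:
        score += 0.2  # interaction bonus, split evenly between the two cats
    return score


values = exact_shapley(["cat_1", "cat_2", "bird"], toy_value)
print(values)
```

Exact enumeration needs $2^{n-1}$ coalitions per object, which is why a cap such as `max_combinations` matters once a video contains more than a handful of objects.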

## Setup

In [None]:
import sys
from pathlib import Path

# Add the parent directory to sys.path so the local token_shap package can be imported
parent_dir = Path().resolve().parent
if str(parent_dir) not in sys.path:
    sys.path.insert(0, str(parent_dir))

from IPython.display import Image as IPyImage, display

from token_shap.video_shap import (
    VideoSHAP,
    SAM3VideoSegmentationModel,
    VideoBlackoutManipulator,
    GeminiVideoModel,
)
from token_shap.base import OpenAIEmbeddings

In [None]:
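import os

# API keys: read from environment variables, or paste them in place of "..."
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "...")
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "...")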

## Initialize Components

VideoSHAP requires four components:

| Component | Purpose | Model Used |
|-----------|---------|------------|
| **Segmentation** | Track objects across frames | SAM 3 |
| **VLM** | Answer questions about video | Gemini 2.5 Pro |
| **Manipulator** | Hide objects to test importance | Blackout (bbox) |
| **Vectorizer** | Measure response similarity | OpenAI Embeddings |
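
The vectorizer turns each answer into an embedding so that answers can be compared numerically. A common choice, assumed here rather than confirmed for VideoSHAP, is cosine similarity between the embedding of an answer produced with some objects hidden and the embedding of the answer on the unmodified video. The vectors below are made-up toy embeddings; real `text-embedding-3-large` vectors have thousands of dimensions.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two embedding vectors (1.0 = same direction)."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # treat an all-zero vector as dissimilar to everything
    return dot / (norm_a * norm_b)


# Toy embeddings (hypothetical): full-video answer vs. answers with one object hidden.
full_answer = [0.9, 0.1, 0.4]
answer_without_minor_object = [0.88, 0.12, 0.41]
answer_without_key_object = [0.1, 0.9, 0.2]

print(cosine_similarity(full_answer, answer_without_minor_object))  # close to 1
print(cosine_similarity(full_answer, answer_without_key_object))    # much lower
```

A large drop in similarity when an object is hidden suggests that object shaped the answer.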

In [None]:
# Segmentation: SAM3 for video object tracking
sam3_model = SAM3VideoSegmentationModel(
    model_name="facebook/sam3",
    device="cuda",
)

# VLM: Gemini for video understanding
vlm_model = GeminiVideoModel(
    model_name="gemini-2.5-pro",
    api_key=GOOGLE_API_KEY,
    temperature=0.1,
)

# Manipulator: Blackout objects by bounding box
manipulator = VideoBlackoutManipulator(
    mask_type="bbox",
    preserve_overlapping=True,
)

# Vectorizer: OpenAI embeddings for similarity
vectorizer = OpenAIEmbeddings(
    api_key=OPENAI_API_KEY,
    model="text-embedding-3-large",
)
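
With `mask_type="bbox"`, hiding an object means blacking out its bounding box in every frame where it appears. The NumPy sketch below shows that operation on a single frame. Coordinates are assumed to be `(x0, y0, x1, y1)` in pixels with exclusive upper bounds, and the `keep_mask` argument is a hypothetical reading of `preserve_overlapping=True` (pixels of objects that should stay visible are restored inside the box); the library's actual behavior may differ.

```python
import numpy as np


def blackout_bbox(frame, bbox, keep_mask=None):
    """Return a copy of an H x W x 3 frame with the box (x0, y0, x1, y1) set to black.

    keep_mask (hypothetical): optional H x W boolean array marking pixels of
    other objects that should remain visible even where they overlap the box.
    """
    x0, y0, x1, y1 = bbox
    out = frame.copy()
    out[y0:y1, x0:x1] = 0  # rows are y, columns are x; upper bounds exclusive
    if keep_mask is not None:
        out[keep_mask] = frame[keep_mask]  # restore pixels of objects kept visible
    return out


frame = np.full((4, 4, 3), 255, dtype=np.uint8)  # tiny all-white frame
keep = np.zeros((4, 4), dtype=bool)
keep[1, 1] = True  # pretend another object occupies pixel (row 1, col 1)

hidden = blackout_bbox(frame, (1, 1, 3, 3), keep_mask=keep)
print(hidden[..., 0])
```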

In [None]:
# Create VideoSHAP analyzer
video_shap = VideoSHAP(
    model=vlm_model,
    segmentation_model=sam3_model,
    manipulator=manipulator,
    vectorizer=vectorizer,
)

---

## Example 1: Cats Watching a Flying Object

A simple scene with cats tracking something flying across the frame.

**Question:** *"Describe to me the object flying in the video."*

In [None]:
results_df, shapley_values = video_shap.analyze(
    video_path="../videos/cats.mp4",
    prompt="Describe to me the object flying in the video.",
    text_prompts=["animal"],
    target_fps=8,
    max_combinations=20,
)

In [None]:
# Create a side-by-side GIF: original frames next to their importance heatmap, frame by frame
gif_path = video_shap.create_side_by_side_gif(
    output_path="../images/cats.gif",
    heatmap_opacity=0.5,
    background_opacity=0.3,
)

display(IPyImage(filename=gif_path))

In [None]:
plot = video_shap.plot_importance_ranking()

---

## Example 2: Birthday Party

A more complex scene with multiple people, a cake, and various objects.

**Question:** *"Describe the birthday boy to me."*

In [None]:
results_df, shapley_values = video_shap.analyze(
    video_path="../videos/birthday.mp4",
    prompt="Describe the birthday boy to me.",
    text_prompts=["person", "cake", "gift", "balloon"],
    target_fps=8,
    max_combinations=20,
)

In [None]:
# Create a side-by-side GIF: original frames next to their importance heatmap, frame by frame
gif_path = video_shap.create_side_by_side_gif(
    output_path="../images/birthday.gif",
    heatmap_opacity=0.5,
    background_opacity=0.3,
)

display(IPyImage(filename=gif_path))

In [None]:
plot = video_shap.plot_importance_ranking()

---

## Exploring Results

After analysis, the `video_shap` object keeps the raw data from its most recent `analyze()` call (here, the birthday party):

In [None]:
# View responses for different object combinations
video_shap.results_df.head()

In [None]:
# Print Shapley values (importance scores)
print("Object Importance Scores:\n")
for obj, value in sorted(video_shap.shapley_values.items(), key=lambda x: x[1], reverse=True):
    print(f"  {obj}: {value:.3f}")
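
Raw Shapley values are easier to compare across videos or questions when expressed as shares. The helper below, a hypothetical post-processing step rather than part of VideoSHAP, divides each value by the sum of absolute values, so the magnitudes add up to 1 while negative contributions keep their sign. The object labels in the example are made up.

```python
from typing import Dict


def importance_shares(values: Dict[str, float]) -> Dict[str, float]:
    """Scale Shapley values so their absolute values sum to 1 (signs preserved)."""
    total = sum(abs(v) for v in values.values())
    if total == 0:
        return {obj: 0.0 for obj in values}
    return {obj: v / total for obj, v in values.items()}


example = {"person_1": 0.30, "cake_1": 0.10, "balloon_1": -0.10}  # invented values
for obj, share in sorted(importance_shares(example).items(), key=lambda kv: kv[1], reverse=True):
    print(f"  {obj}: {share:+.0%}")
```

With a real run, pass `video_shap.shapley_values` instead of `example`.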

---

## Reference

### Key Parameters

**`analyze()`**

| Parameter | Description | Default |
|-----------|-------------|---------|
| `video_path` | Path to input video | *required* |
| `prompt` | Question to ask about the video | *required* |
| `text_prompts` | Object categories to detect | *required* |
| `target_fps` | Frames per second to process | 8 |
| `max_combinations` | Max object subsets to test | 20 |
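
`max_combinations` caps how many object subsets are evaluated, so with more than a few objects the Shapley values have to be estimated rather than computed exactly. One standard estimator, shown below as an illustration and not necessarily the one VideoSHAP uses, samples random orderings of the objects and averages each object's marginal contribution. A cache counts how many distinct subsets were actually evaluated, which is the quantity a cap like `max_combinations` would bound. The value function is a hypothetical additive toy.

```python
import random
from typing import Callable, Dict, FrozenSet, Sequence, Tuple


def permutation_shapley(
    players: Sequence[str],
    value: Callable[[FrozenSet[str]], float],
    num_permutations: int,
    seed: int = 0,
) -> Tuple[Dict[str, float], int]:
    """Monte Carlo Shapley estimate; returns (estimates, distinct subsets evaluated)."""
    rng = random.Random(seed)
    cache: Dict[FrozenSet[str], float] = {}

    def v(subset: FrozenSet[str]) -> float:
        if subset not in cache:  # each distinct subset costs one model query
            cache[subset] = value(subset)
        return cache[subset]

    totals = {p: 0.0 for p in players}
    order = list(players)
    for _ in range(num_permutations):
        rng.shuffle(order)
        coalition: FrozenSet[str] = frozenset()
        previous = v(coalition)
        for p in order:  # add objects one at a time in random order
            coalition = coalition | {p}
            current = v(coalition)
            totals[p] += current - previous
            previous = current
    return {p: t / num_permutations for p, t in totals.items()}, len(cache)


weights = {"person": 0.5, "cake": 0.3, "gift": 0.2}  # hypothetical additive game
estimates, evaluated = permutation_shapley(
    list(weights), lambda s: sum(weights[o] for o in s), num_permutations=10
)
print(estimates, evaluated)
```

For an additive game every ordering yields the same marginal contributions, so the estimate matches the weights up to floating-point rounding. Real answers involve interactions between objects, and the estimate then improves as more orderings (and subsets) are evaluated.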

**`create_side_by_side_gif()`**

| Parameter | Description | Default |
|-----------|-------------|---------|
| `output_path` | Output GIF path | *required* |
| `heatmap_opacity` | Opacity of the heatmap colors blended over tracked objects | 0.5 |
| `background_opacity` | Visibility of regions outside tracked objects | 0.3 |
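
The two opacity parameters control how the heatmap is composited onto each frame. The sketch below is one plausible compositing rule, assumed rather than taken from VideoSHAP: inside tracked objects the frame is mixed with a heatmap color at `heatmap_opacity`, and everything outside the objects is dimmed to `background_opacity` of its original brightness. The per-pixel heatmap color (for example, derived from an object's Shapley value through a colormap) is passed in precomputed.

```python
import numpy as np


def composite_heatmap(frame, heat_rgb, object_mask, heatmap_opacity=0.5, background_opacity=0.3):
    """Blend heat_rgb into frame inside object_mask; dim everything outside it.

    frame, heat_rgb: H x W x 3 uint8 arrays; object_mask: H x W boolean array.
    """
    frame_f = frame.astype(np.float32)
    heat_f = heat_rgb.astype(np.float32)
    out = background_opacity * frame_f  # start with a dimmed copy of the whole frame
    out[object_mask] = (
        (1.0 - heatmap_opacity) * frame_f[object_mask] + heatmap_opacity * heat_f[object_mask]
    )  # object pixels: linear mix of original pixel and heatmap color
    return np.clip(out, 0, 255).astype(np.uint8)


frame = np.full((2, 2, 3), 200, dtype=np.uint8)
heat = np.zeros((2, 2, 3), dtype=np.uint8)
heat[..., 0] = 255  # pure red heatmap color
mask = np.array([[True, False], [False, False]])

blended = composite_heatmap(frame, heat, mask)
print(blended[..., 0])
```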

### Tips

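- **Better coverage:** List every object category relevant to the question in `text_prompts`; objects that are not detected receive no importance score
- **Faster analysis:** Reduce `max_combinations` or `target_fps`, at the cost of coarser estimates or missed motion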
- **More detail:** Increase `target_fps` for fast-moving scenes
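
The `target_fps` trade-off comes from how many frames each run has to process: each kept frame is presumably tracked once and then masked for every tested subset. The sketch below assumes evenly spaced sampling (the actual sampling in VideoSHAP is not documented here) and shows which frames survive from one second of 30 fps video at a target of 8 fps. `sample_frame_indices` is a hypothetical helper.

```python
from typing import List


def sample_frame_indices(num_frames: int, source_fps: float, target_fps: float) -> List[int]:
    """Indices of evenly spaced frames to keep when downsampling source_fps -> target_fps."""
    if source_fps <= 0 or target_fps <= 0:
        raise ValueError("frame rates must be positive")
    step = max(source_fps / target_fps, 1.0)  # never upsample: at most every frame is kept
    indices = []
    i = 0
    while int(i * step) < num_frames:  # int() floors non-negative floats
        indices.append(int(i * step))
        i += 1
    return indices


one_second = sample_frame_indices(num_frames=30, source_fps=30, target_fps=8)
print(one_second)  # [0, 3, 7, 11, 15, 18, 22, 26]
```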

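---

## Example 3: One Video, Several Questions

The same scene can be analyzed with different questions to see how object importance shifts with what is asked.

In [None]: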
from IPython.display import Image, display

questions = [
    "Describe the character who pours the drink.",
    "What is the man on the far left wearing?",
    "In one word, who is the weirdest in the video?",
]

text_prompts = ["person", "alien"]

for i, q in enumerate(questions):
    print(f"\n=== Question {i+1}: {q} ===")

    results_df, shapley_values = video_shap.analyze(
        video_path="../videos/alien.mp4",
        prompt=q,
        text_prompts=text_prompts,
        target_fps=8,
        max_combinations=20,
    )

    gif_path = video_shap.create_side_by_side_gif(
        output_path=f"../images/alien_q{i+1}.gif",
        heatmap_opacity=0.5,
        background_opacity=0.3,
    )

    display(Image(filename=gif_path))
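
Because `results_df` and `shapley_values` are reassigned on every loop iteration, comparing questions means keeping a copy of each run's scores, for example `all_scores[q] = dict(shapley_values)` inside the loop (a hypothetical addition). The helper below then reports the top-ranked objects per question. The question keys, object labels, and numbers are invented placeholders, not results.

```python
from typing import Dict, List


def top_objects(all_scores: Dict[str, Dict[str, float]], k: int = 1) -> Dict[str, List[str]]:
    """For each question, the k objects with the highest Shapley values."""
    ranked = {}
    for question, scores in all_scores.items():
        order = sorted(scores, key=scores.get, reverse=True)  # object names, highest score first
        ranked[question] = order[:k]
    return ranked


# Invented placeholder scores standing in for two runs of the loop above.
all_scores = {
    "q1": {"obj_a": 0.6, "obj_b": 0.1},
    "q2": {"obj_a": 0.05, "obj_b": 0.7},
}
for question, objs in top_objects(all_scores).items():
    print(f"{question} -> {objs}")
```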