# Vanish - Face Anonymization

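Detect and pixelate faces in images and videos using Florence-2 + SAM2. By default, faces detected as main speakers are left untouched and only the remaining (passerby) faces are anonymized.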

- **Florence-2**: Object detection to find human faces
- **SAM2**: Precise segmentation for clean masking
- **Pixelation**: Replace the masked pixels of each block with their average color

## Setup

In [None]:
# Check GPU
# Runs the NVIDIA CLI in the notebook's shell to show which GPU (if any) is attached.
!nvidia-smi

In [None]:
# Install dependencies
# NOTE(review): transformers, torch, opencv, Pillow and matplotlib are imported by
# later cells but are not installed here — presumably preinstalled in the runtime; confirm.
!pip install -q flash_attn timm accelerate einops supervision

# Download SAM2 checkpoints
# The file lands at models/sam2/sam2_hiera_large.pt, the path assigned to
# SAM2_CHECKPOINT in the "Load SAM2" cell.
!mkdir -p models/sam2
!wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt -P models/sam2

# Install SAM2
# Clone the repository, install it in editable mode from inside the checkout,
# then return to the original working directory so relative paths keep working.
!git clone -q https://github.com/facebookresearch/segment-anything-2.git
%cd segment-anything-2
!pip install -q -e .
%cd ..

In [None]:
# Load Florence-2
from transformers import AutoModelForCausalLM, AutoProcessor
import torch

# Single source of truth for where models and inputs live.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Half precision is only used on GPU; fall back to float32 on CPU.
FLORENCE_DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32

# Place the model on DEVICE (rather than a hard-coded "cuda") so the model and
# the inputs moved with `.to(DEVICE)` in the detection functions always agree.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large-ft",
    trust_remote_code=True,
    torch_dtype=FLORENCE_DTYPE,
).to(DEVICE)

processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large-ft",
    trust_remote_code=True
)

print("✓ Florence-2 loaded")

In [None]:
# Core functions
import cv2
import numpy as np
from pathlib import Path
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


def find_all_faces(image):
    """Find all human faces using Florence-2 object detection.

    Args:
        image: PIL image (its ``width`` and ``height`` are used to scale boxes).

    Returns:
        List of bounding boxes whose detection label is exactly "human face",
        in the format produced by ``processor.post_process_generation``.
    """
    task = "<OD>"
    # Move inputs to the model's device AND dtype: the model may be loaded in
    # float16, and feeding it float32 pixel values raises a dtype mismatch.
    # Passing a dtype to `.to` is expected to cast only floating-point tensors,
    # leaving the integer `input_ids` untouched.
    inputs = processor(text=task, images=image, return_tensors="pt").to(DEVICE, model.dtype)

    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=2048,
            do_sample=False,
        )

    # Special tokens are kept because the post-processor parses them.
    text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    results = processor.post_process_generation(
        text, task=task, image_size=(image.width, image.height)
    )

    detections = results[task]
    return [
        bbox
        for bbox, label in zip(detections["bboxes"], detections["labels"])
        if label == "human face"
    ]


def find_main_speakers(image):
    """Find main speaker faces to exclude from blurring.

    Runs Florence-2 phrase grounding with the caption
    "human face (main speaker)" and keeps the boxes labelled "human face".

    Args:
        image: PIL image (its ``width`` and ``height`` are used to scale boxes).

    Returns:
        List of bounding boxes for grounded phrases labelled exactly "human face".
    """
    task = "<CAPTION_TO_PHRASE_GROUNDING>"
    prompt = f"{task} human face (main speaker)"
    # Move inputs to the model's device AND dtype: the model may be loaded in
    # float16, and feeding it float32 pixel values raises a dtype mismatch.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(DEVICE, model.dtype)

    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=2048,
            do_sample=False,
        )

    # Special tokens are kept because the post-processor parses them.
    text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    results = processor.post_process_generation(
        text, task=task, image_size=(image.width, image.height)
    )

    grounded = results[task]
    # NOTE(review): only phrases labelled exactly "human face" are kept; if the
    # model returns a different phrase (e.g. "main speaker") those boxes are
    # dropped — confirm against real outputs.
    return [
        bbox
        for bbox, label in zip(grounded["bboxes"], grounded["labels"])
        if label == "human face"
    ]


def is_overlapping(box1, box2, threshold=0.7):
    """Check if two boxes overlap significantly.

    The intersection area is compared against the area of the SMALLER box, so
    a small box fully contained in a large one counts as overlapping.

    Args:
        box1: ``(x_min, y_min, x_max, y_max)`` of the first box.
        box2: ``(x_min, y_min, x_max, y_max)`` of the second box.
        threshold: Minimum fraction of the smaller box's area that must be
            covered by the intersection.

    Returns:
        True if the intersection covers at least ``threshold`` of the smaller
        box; False otherwise, including when either box has no positive area.
    """
    x1_min, y1_min, x1_max, y1_max = box1
    x2_min, y2_min, x2_max, y2_max = box2

    # Clamp at 0 so disjoint boxes yield an empty intersection.
    x_overlap = max(0, min(x1_max, x2_max) - max(x1_min, x2_min))
    y_overlap = max(0, min(y1_max, y2_max) - max(y1_min, y2_min))
    overlap_area = x_overlap * y_overlap

    area1 = (x1_max - x1_min) * (y1_max - y1_min)
    area2 = (x2_max - x2_min) * (y2_max - y2_min)
    smaller_area = min(area1, area2)

    # A degenerate (zero-area or inverted) box would otherwise satisfy
    # `0 >= threshold * 0` and "overlap" with every box, even disjoint ones.
    if smaller_area <= 0:
        return False

    return overlap_area >= threshold * smaller_area


def find_passerby_faces(image, exclude_speakers=True):
    """Return the face boxes that should be blurred.

    Args:
        image: Image passed through to the face and speaker detectors.
        exclude_speakers: When True, drop every detected face that overlaps a
            main-speaker box; when False, return all detected faces.

    Returns:
        List of face bounding boxes to blur.
    """
    candidates = find_all_faces(image)

    if exclude_speakers:
        speaker_boxes = find_main_speakers(image)
        passersby = []
        for candidate in candidates:
            # Stop at the first speaker box this face overlaps.
            for speaker_box in speaker_boxes:
                if is_overlapping(candidate, speaker_box):
                    break
            else:
                passersby.append(candidate)
        return passersby

    return candidates


def pixelate_region(image, masks, pixel_size=10):
    """Apply pixelation to masked regions.

    The image is divided into a grid of ``pixel_size`` x ``pixel_size`` blocks
    anchored at the top-left corner. Within each block, the pixels covered by
    any mask are replaced by their per-channel average (truncated toward
    zero); unmasked pixels are left unchanged.

    Args:
        image: Array of shape ``(H, W)`` or ``(H, W, C)``. It is NOT modified.
        masks: Array of shape ``(N, H, W)`` (or a single ``(H, W)`` mask);
            non-zero entries mark pixels to pixelate.
        pixel_size: Side length of the pixelation blocks, at least 1.

    Returns:
        A new array with the same shape and dtype as ``image``.

    Raises:
        ValueError: If ``pixel_size`` is smaller than 1 or the mask size does
            not match the image size.
    """
    if pixel_size < 1:
        raise ValueError(f"pixel_size must be >= 1, got {pixel_size}")

    masks = np.asarray(masks).astype(bool)
    if masks.ndim == 2:
        masks = masks[np.newaxis, ...]

    # Union of all masks, computed once instead of once per block.
    combined = masks.any(axis=0)

    height, width = image.shape[:2]
    if combined.shape != (height, width):
        raise ValueError(
            f"mask shape {combined.shape} does not match image size {(height, width)}"
        )

    # All writes go to the copy so the caller's array is left untouched.
    result = image.copy()
    if not combined.any():
        return result

    # Only visit grid blocks that can intersect the mask's bounding box.
    rows = np.flatnonzero(combined.any(axis=1))
    cols = np.flatnonzero(combined.any(axis=0))
    y_first = (int(rows[0]) // pixel_size) * pixel_size
    x_first = (int(cols[0]) // pixel_size) * pixel_size
    y_stop = int(rows[-1]) + 1
    x_stop = int(cols[-1]) + 1

    for y in range(y_first, y_stop, pixel_size):
        y_end = min(y + pixel_size, height)
        for x in range(x_first, x_stop, pixel_size):
            x_end = min(x + pixel_size, width)
            block_mask = combined[y:y_end, x:x_end]
            if not block_mask.any():
                continue

            # View into `result`; blocks never overlap, so the values read
            # here are still the original image values.
            block = result[y:y_end, x:x_end]
            mean_color = block[block_mask].mean(axis=0)
            block[block_mask] = np.trunc(mean_color).astype(result.dtype)

    return result


print("✓ Functions defined")

In [None]:
# Load SAM2
# Checkpoint path matches the location the setup cell downloads the weights to.
SAM2_CHECKPOINT = "models/sam2/sam2_hiera_large.pt"
# NOTE(review): config name is resolved by the sam2 package; presumably it must
# correspond to the "large" checkpoint above — confirm against the installed version.
SAM2_CONFIG = "sam2_hiera_l.yaml"

# Build the model on the same DEVICE used for Florence-2 and wrap it in the
# single-image predictor used by vanish() and vanish_video().
sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=DEVICE, apply_postprocessing=False)
sam2_predictor = SAM2ImagePredictor(sam2_model)

print("✓ SAM2 loaded")

In [None]:
# Main vanish function
import matplotlib.pyplot as plt

def vanish(image_path, pixel_size=10, exclude_speakers=True, show=True):
    """
    Pixelate the faces found in a single image file.

    Args:
        image_path: Path to input image
        pixel_size: Size of pixelation blocks
        exclude_speakers: If True, main speakers are left unblurred
        show: If True, plot the original next to the result

    Returns:
        Pixelated image as numpy array (the unmodified image when no face
        needs blurring)
    """
    image = Image.open(image_path).convert("RGB")

    # Boxes of the faces that must be hidden
    faces = find_passerby_faces(image, exclude_speakers=exclude_speakers)
    print(f"Found {len(faces)} face(s) to blur")

    if len(faces) == 0:
        return np.array(image)

    # One SAM2 mask per face box
    sam2_predictor.set_image(image)
    masks, _scores, _logits = sam2_predictor.predict(box=faces, multimask_output=False)
    masks = np.squeeze(masks)
    if masks.ndim == 2:
        # A single face squeezes down to (H, W); restore the leading mask axis
        masks = masks[np.newaxis, ...]

    # Pixelate the masked pixels
    result = pixelate_region(np.array(image), masks, pixel_size)

    if show:
        fig, axes = plt.subplots(1, 2, figsize=(14, 7))
        panels = ((image, "Original"), (result, "Vanished"))
        for axis, (picture, title) in zip(axes, panels):
            axis.imshow(picture)
            axis.set_title(title)
            axis.axis("off")
        plt.tight_layout()
        plt.show()

    return result


print("✓ vanish() ready")

## Process an Image

In [None]:
# Process an image
# result = vanish("your_image.jpg", pixel_size=10)

# To blur ALL faces (including main speakers):
# result = vanish("your_image.jpg", pixel_size=10, exclude_speakers=False)

# Save result:
# Image.fromarray(result).save("output.jpg")

## Process a Video

In [None]:
def vanish_video(input_path, output_path, pixel_size=10, exclude_speakers=True):
    """Process a video frame by frame.

    Every frame is run through face detection, SAM2 segmentation and
    pixelation, then written to ``output_path`` with the 'mp4v' codec at the
    input's frame rate and size.

    Args:
        input_path: Path of the video to read.
        output_path: Path of the video to write.
        pixel_size: Size of pixelation blocks.
        exclude_speakers: If True, don't blur main speakers.

    Raises:
        IOError: If the input video cannot be opened or the output writer
            cannot be created.
    """
    cap = cv2.VideoCapture(str(input_path))
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open input video: {input_path}")

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
        if not out.isOpened():
            raise IOError(f"Could not open output video for writing: {output_path}")

        print(f"Processing {total} frames...")

        frame_num = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_num += 1
            print(f"\rFrame {frame_num}/{total}", end="")

            # OpenCV delivers BGR; the detectors are fed RGB PIL images.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_frame = Image.fromarray(frame_rgb)

            # Find and segment faces
            faces = find_passerby_faces(pil_frame, exclude_speakers=exclude_speakers)

            if faces:
                sam2_predictor.set_image(pil_frame)
                masks, _, _ = sam2_predictor.predict(box=faces, multimask_output=False)
                masks = np.squeeze(masks)
                if masks.ndim == 2:
                    # A single face squeezes down to (H, W); restore the mask axis.
                    masks = np.expand_dims(masks, axis=0)
                frame_rgb = pixelate_region(frame_rgb, masks, pixel_size)

            # Convert back to BGR and write
            out.write(cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
    finally:
        # Release handles even if detection or segmentation raises mid-video,
        # so the partially written output is finalized.
        cap.release()
        if out is not None:
            out.release()

    print(f"\n✓ Saved to {output_path}")


# Usage:
# vanish_video("input.mp4", "output.mp4", pixel_size=10)