# Segmented Facial Anonymization Ablation Study

This notebook runs an ablation study of segmented facial anonymization. It compares three facial-feature masks against three anonymization operators (blur, mosaic and diffusion).

In [None]:
import torch
from transformers import CLIPImageProcessor, CLIPVisionModel
from diffusers import AutoencoderKL, DDPMScheduler
from diffusers.utils import load_image, make_image_grid
import face_alignment
from PIL import Image
import numpy as np
import cv2

from src.diffusers.models.referencenet.referencenet_unet_2d_condition import ReferenceNetModel
from src.diffusers.models.referencenet.unet_2d_condition import UNet2DConditionModel
from src.diffusers.pipelines.referencenet.pipeline_referencenet import StableDiffusionReferenceNetPipeline

from utils.segmented_anonymization import anonymize_faces_segmented
from utils.segmentation import get_mask_from_landmarks, visualize_mask, get_segmented_regions
from utils.extractor import get_transform_mat, FaceType

## Load Models

In [None]:
face_model_id = "hkung/face-anon-simple"
clip_model_id = "openai/clip-vit-large-patch14"
sd_model_id = "stabilityai/stable-diffusion-2-1"

# Pick the compute device and matching precision: float16 on CUDA, float32 on CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

# Every component below is loaded with the same safetensors option.
_load_kwargs = {"use_safetensors": True}

# Face-anonymization components (UNet + the two ReferenceNets) come from the face model repo.
unet = UNet2DConditionModel.from_pretrained(face_model_id, subfolder="unet", **_load_kwargs)
referencenet = ReferenceNetModel.from_pretrained(face_model_id, subfolder="referencenet", **_load_kwargs)
conditioning_referencenet = ReferenceNetModel.from_pretrained(
    face_model_id, subfolder="conditioning_referencenet", **_load_kwargs
)

# VAE and noise scheduler come from the Stable Diffusion 2.1 repo.
vae = AutoencoderKL.from_pretrained(sd_model_id, subfolder="vae", **_load_kwargs)
scheduler = DDPMScheduler.from_pretrained(sd_model_id, subfolder="scheduler", **_load_kwargs)

# CLIP image processor and vision encoder.
feature_extractor = CLIPImageProcessor.from_pretrained(clip_model_id, **_load_kwargs)
image_encoder = CLIPVisionModel.from_pretrained(clip_model_id, **_load_kwargs)

# Assemble the pipeline and move it to the chosen device/precision in one expression.
pipe = StableDiffusionReferenceNetPipeline(
    unet=unet,
    referencenet=referencenet,
    conditioning_referencenet=conditioning_referencenet,
    vae=vae,
    feature_extractor=feature_extractor,
    image_encoder=image_encoder,
    scheduler=scheduler,
).to(device, dtype=dtype)

# Initialize Face Alignment (2-D landmarks, SFD face detector) on the same device.
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, face_detector="sfd", device=device)

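## Define Masks and Operators

The mask and operator definitions are in the code cell that follows the mask-visualization section below. If you run cells individually, run that definition cell before visualizing the masks.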

## Visualize Segmentation Masks

This section shows how to visualize the facial segmentation masks before anonymization is applied.


In [None]:
# Load the test image.
image_path = "my_dataset/test/14795.png"
original_image = load_image(image_path)

# Force a 3-channel RGB array. PIL's RGBA -> RGB conversion drops the alpha
# channel (what the previous `[:, :, :3]` slice did). It also handles
# grayscale/palette images, whose 2-D arrays made `image_np.shape[2]` raise.
image_np = np.array(original_image.convert("RGB"))

# `masks` is defined in the mask-definition cell further down this notebook.
# When cells are executed top to bottom it does not exist yet, so fall back
# to the same definitions here (keep the two in sync).
try:
    masks
except NameError:
    _mask1 = ['eyes', 'mouth', 'nostrils']
    _mask2 = _mask1 + ['eyebrows']
    _mask3 = _mask2 + ['lips', 'teeth']
    masks = {
        "Mask 1 (Eyes+Mouth+Nostrils)": _mask1,
        "Mask 2 (+Eyebrows)": _mask2,
        "Mask 3 (+Lips+Teeth)": _mask3,
    }

# Detect facial landmarks (one entry per detected face; may be None when no face is found).
preds = fa.get_landmarks(image_np)

face_image_size = 512

if preds is None or len(preds) == 0:
    print("Nenhum rosto detectado na imagem.")
else:
    # Visualize every mask for each detected face.
    for idx, landmarks in enumerate(preds):
        # Affine matrix mapping the source image onto a square aligned face crop.
        mat = get_transform_mat(landmarks, face_image_size, FaceType.WHOLE_FACE)

        # Extract the aligned face. `flags` must be passed by keyword: the 4th
        # positional parameter of cv2.warpAffine is `dst`, so passing
        # INTER_LANCZOS4 positionally does not set the interpolation mode.
        face_aligned = cv2.warpAffine(
            image_np,
            mat,
            (face_image_size, face_image_size),
            flags=cv2.INTER_LANCZOS4,
            borderValue=(255, 255, 255),
        )

        # Map the landmarks into the aligned crop's coordinate space.
        pts = np.array([landmarks], dtype=np.float32)
        aligned_landmarks = cv2.transform(pts, mat)[0]

        print(f"Visualizando máscaras para rosto {idx + 1}:")
        visualization_images = []

        for mask_name, features in masks.items():
            # Build the mask for this feature set in aligned-face coordinates.
            mask = get_mask_from_landmarks(
                aligned_landmarks,
                (face_image_size, face_image_size),
                features,
                dilate_radius=3,
                smooth_edges=True,
            )

            # Overlay the mask on the aligned face for visual inspection.
            vis_image = visualize_mask(face_aligned, mask, alpha=0.5)
            visualization_images.append(Image.fromarray(vis_image))
            print(f"  - {mask_name}: {features}")

        # Show all mask overlays for this face in a single row.
        if visualization_images:
            grid = make_image_grid(
                visualization_images, rows=1, cols=len(visualization_images)
            )
            display(grid)
            print(f"\nMáscaras visualizadas para rosto {idx + 1}\n")


In [None]:
# Feature masks for the ablation. Each mask extends the previous one.

# (i) eyes + mouth + nostrils
mask1 = ['eyes', 'mouth', 'nostrils']

# (ii) mask (i) + eyebrows
mask2 = [*mask1, 'eyebrows']

# (iii) mask (ii) + lip contour + teeth.
# Depending on how the segmentation helpers define 'mouth', these regions may
# already be covered by it. They are listed explicitly so the requested
# regions are included either way.
mask3 = [*mask2, 'lips', 'teeth']

_mask_labels = (
    "Mask 1 (Eyes+Mouth+Nostrils)",
    "Mask 2 (+Eyebrows)",
    "Mask 3 (+Lips+Teeth)",
)
masks = dict(zip(_mask_labels, (mask1, mask2, mask3)))

# Anonymization operators compared for every mask.
operators = ['blur', 'mosaic', 'diffusion']

## Run Ablation

In [None]:
# Load test image
image_path = "my_dataset/test/14795.png"  # Using an example from the repo
original_image = load_image(image_path)

seed = 42
# Seed the global torch RNG as before, for any code path that draws from it.
torch.manual_seed(seed)

results = []
labels = []  # one caption per result, in the same row-major order as `results`

for mask_name, features in masks.items():
    for op in operators:
        print(f"Processing {mask_name} with {op}...")
        # Use a fresh, identically seeded generator for every (mask, operator)
        # pair so each run starts from the same noise. With one shared generator,
        # a run's noise would depend on how many draws earlier runs consumed,
        # which confounds the comparison. (Assumes anonymize_faces_segmented
        # forwards `generator` to the diffusion pipeline.)
        generator = torch.Generator().manual_seed(seed)
        anon_image = anonymize_faces_segmented(
            image=original_image,
            face_alignment_model=fa,
            mask_features=features,
            operator_type=op,
            pipe=pipe,
            generator=generator,
            # Operator specific params
            kernel_size=(31, 31),  # For blur
            block_size=15,  # For mosaic
            num_inference_steps=30,  # For diffusion (faster for demo)
            guidance_scale=4.0,
            anonymization_degree=1.25,
        )
        results.append(anon_image)
        labels.append(f"{mask_name}\n{op}")

# Display results: one row per mask, one column per operator.
grid = make_image_grid(results, rows=len(masks), cols=len(operators))
grid.save("ablation_results.png")
grid