# Disentangled Representation Learning

Steps performed:
1. Load Stable Diffusion model
2. Generate or encode images
3. Extract race vector
4. Generate counterfactuals
5. Evaluate results
6. Create visualization grid

In [None]:
# ── Colab / Local Setup ──────────────────────────────────────────────────────
# Run this cell first. It clones the repo and installs dependencies on Colab,
# and is a no-op when running locally.

import os
import subprocess
import sys

REPO = "Isolating-Race-Vectors-in-Latent-Space"
REPO_URL = f"https://github.com/Arnavsharma2/{REPO}.git"

# Detect Colab by inspecting the running IPython shell. Outside IPython
# (or if anything goes wrong) we simply treat the run as local.
try:
    from IPython import get_ipython
    IN_COLAB = "google.colab" in str(get_ipython())
except Exception:
    IN_COLAB = False

if IN_COLAB:
    # Idempotent: when this cell is re-run we are already inside the checkout,
    # so neither clone nor chdir again (that would create a nested clone).
    if os.path.basename(os.getcwd()) != REPO:
        if not os.path.exists(REPO):
            # check=True: without the checkout nothing below can work.
            subprocess.run(["git", "clone", REPO_URL], check=True)
        os.chdir(REPO)

    # Install into the kernel's own interpreter; report (don't hide) failures.
    pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"]
    )
    if pip_result.returncode != 0:
        print(f"WARNING: pip install failed with exit code {pip_result.returncode}")

    try:
        gpu_name = subprocess.run(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
            capture_output=True,
            text=True,
        ).stdout.strip()
    except OSError:
        # nvidia-smi is not installed on CPU-only runtimes.
        gpu_name = ""
    print("Colab setup complete. GPU available:", gpu_name or "None detected")
else:
    # Running locally — if the kernel started in notebooks/, move up to the repo root.
    if os.path.basename(os.getcwd()) == "notebooks":
        os.chdir(os.path.dirname(os.getcwd()))
    print("Running locally.")

# Make `src` importable from the repo root regardless of where the kernel started.
repo_root = os.getcwd()
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

print("Working directory:", os.getcwd())

import sys
sys.path.insert(0, '..')

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from src.models.stable_diffusion import StableDiffusionWrapper
from src.latent.vector_discovery import RaceVectorExtractor, VectorAnalyzer
from src.metrics.evaluator import CounterfactualEvaluator
from src.visualization.grid_generator import CounterfactualGridGenerator

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from src.models.stable_diffusion import StableDiffusionWrapper
from src.latent.vector_discovery import RaceVectorExtractor
from src.latent.manipulator import LatentManipulator
from src.metrics.evaluator import CounterfactualEvaluator
from src.visualization.grid_generator import CounterfactualGridGenerator

%matplotlib inline
%load_ext autoreload
%autoreload 2

## 1. Load Model

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"
print(f"Using device: {device}")

# Half precision is only requested on CUDA; the CPU path uses float32.
model_dtype = torch.float16 if device == "cuda" else torch.float32
model = StableDiffusionWrapper(device=device, dtype=model_dtype, enable_xformers=True)

Using device: cpu
Loading Stable Diffusion XL from stabilityai/stable-diffusion-xl-base-1.0...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


✓ Stable Diffusion loaded successfully!


## 2. Generate Test Images

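We need a few images per skin-tone group to extract the race vector. You can either encode real photos (the upload and photo-loading cells below) or generate synthetic portraits with Stable Diffusion (the generation cell after them). Note that the generation cell overwrites the image and latent lists produced by the photo-loading cell.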

In [None]:
# ── Photo Upload (Colab only) ─────────────────────────────────────────────────
# Skip this cell when running locally — just put your photos directly in
#   data/photos/light_skin/  and  data/photos/dark_skin/
#
# On Colab: run this cell to upload photos via the file picker.
# Need at least 3 photos per group for a reliable race vector.

from pathlib import Path

# (folder name, heading used in the upload prompt, label used in the summary)
PHOTO_GROUPS = [
    ("light_skin", "LIGHT SKIN", "light-skin"),
    ("dark_skin", "DARK SKIN", "dark-skin"),
]
UPLOAD_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # same formats the loading cell accepts
MIN_PHOTOS_PER_GROUP = 3  # fewer than this gives an unreliable race vector

photo_root = Path("data/photos")
if not IN_COLAB and not photo_root.is_dir() and Path("../data/photos").is_dir():
    # Kernel was started inside notebooks/ rather than the repo root.
    photo_root = Path("../data/photos")

if IN_COLAB:
    from google.colab import files

    uploads = {}
    for step, (folder, heading, label) in enumerate(PHOTO_GROUPS, start=1):
        target_dir = photo_root / folder
        target_dir.mkdir(parents=True, exist_ok=True)

        print(f"── Step {step}: Upload {heading} photos ──")
        uploads[folder] = files.upload()
        for name, data in uploads[folder].items():
            # Keep only the file name so an odd upload name cannot escape target_dir.
            (target_dir / Path(name).name).write_bytes(data)
        print(f"Saved {len(uploads[folder])} {label} photo(s).\n")

    # Preserve the names the original cell exposed.
    uploaded_light, uploaded_dark = uploads["light_skin"], uploads["dark_skin"]
else:
    print("Local run — photos should already be in data/photos/light_skin/ and dark_skin/")

# Report what is on disk (image files only) and flag groups that are too small.
for folder, _heading, label in PHOTO_GROUPS:
    photo_count = sum(
        1
        for path in (photo_root / folder).glob("*")
        if path.is_file() and path.suffix.lower() in UPLOAD_IMAGE_SUFFIXES
    )
    print(f"{label}: {photo_count} image file(s) in {photo_root / folder}")
    if photo_count < MIN_PHOTOS_PER_GROUP:
        print(f"  ⚠️  Need at least {MIN_PHOTOS_PER_GROUP} photos for a reliable race vector.")

In [None]:
# USING REAL PHOTOS - This cell loads your downloaded photos
from pathlib import Path
from PIL import Image

MAX_PHOTOS_PER_GROUP = 10            # cap on how many photos are encoded per group
PHOTO_SIZE = (512, 512)              # every photo is resized to this before encoding
PHOTO_SUFFIXES = {".jpg", ".jpeg", ".png"}  # matched case-insensitively

# Resolve the photo root from either the repo root (after the setup cell has
# changed directory, and on Colab) or from inside notebooks/.
_photo_root = Path("data/photos")
if not _photo_root.is_dir():
    _photo_root = Path("../data/photos")

LIGHT_PHOTOS_DIR = _photo_root / "light_skin"
DARK_PHOTOS_DIR = _photo_root / "dark_skin"


def load_and_encode_photos(photo_dir, label, limit=MAX_PHOTOS_PER_GROUP):
    """Load up to `limit` photos from `photo_dir` and encode each with `model`.

    Args:
        photo_dir: directory (pathlib.Path) searched non-recursively for images.
        label: group name used in the progress messages ("light" / "dark").
        limit: maximum number of photos (in sorted order) to encode.

    Returns:
        (images, latents): the resized RGB PIL images and, for each, the value
        returned by `model.encode_image`.
    """
    photo_files = sorted(
        path
        for path in photo_dir.glob("*")
        if path.is_file() and path.suffix.lower() in PHOTO_SUFFIXES
    )
    print(f"Found {len(photo_files)} {label} skin photos")

    images, latents = [], []
    for img_path in photo_files[:limit]:
        print(f"Encoding {img_path.name}...")
        # The context manager closes the file handle; convert/resize return copies.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB").resize(PHOTO_SIZE, Image.LANCZOS)
        images.append(img)
        latents.append(model.encode_image(img))
    return images, latents


light_images, light_latents = load_and_encode_photos(LIGHT_PHOTOS_DIR, "light")
dark_images, dark_latents = load_and_encode_photos(DARK_PHOTOS_DIR, "dark")

print(f"\n✓ Loaded {len(light_images)} light and {len(dark_images)} dark photos")

if len(light_images) == 0 or len(dark_images) == 0:
    print("\n⚠️  WARNING: No photos loaded!")
    print(f"   Light skin directory: {LIGHT_PHOTOS_DIR.absolute()}")
    print(f"   Dark skin directory: {DARK_PHOTOS_DIR.absolute()}")
    print("   Make sure you have photos in both folders!")

Found 3 light skin photos
Encoding image copy 2.png...
Encoding image copy.png...
Encoding image.png...
Found 3 dark skin photos
Encoding image copy 2.png...
Encoding image copy.png...
Encoding image.png...

✓ Loaded 3 light and 3 dark photos


In [5]:
# Generate synthetic training images for both skin-tone groups.
# Using VERY SPECIFIC prompts with extreme contrasts for better vector extraction.
#
# NOTE: this cell replaces light_images / light_latents / dark_images /
# dark_latents, including any values produced by the photo-loading cell.

GENERATION_STEPS = 50
GENERATION_GUIDANCE = 8.0  # Higher guidance for stronger adherence to prompt

prompts_light = [
    "portrait photo of a caucasian person with extremely pale white fair skin, light complexion, neutral expression, studio lighting, plain white background, professional photography",
    "headshot photograph of a person with very light fair skin tone and pale complexion, neutral face, even lighting, simple backdrop, high detail",
    "professional portrait of a light-skinned person with fair pale complexion, centered, neutral background, sharp focus",
]
negative_prompt_light = "dark skin, tan skin, brown skin, tanned, multiple people, accessories, jewelry, glasses, shadows"

prompts_dark = [
    "portrait photo of an african person with extremely dark deep black skin, very dark complexion, neutral expression, studio lighting, plain white background, professional photography",
    "headshot photograph of a person with very dark deep skin tone and black complexion, neutral face, even lighting, simple backdrop, high detail",
    "professional portrait of a dark-skinned person with deep black complexion, centered, neutral background, sharp focus",
]
negative_prompt_dark = "light skin, pale skin, fair skin, white skin, caucasian, multiple people, accessories, jewelry, glasses, shadows"


def generate_group(prompts, negative_prompt, base_seed, label):
    """Generate one image per prompt and return (images, latents).

    Prompt number i (0-based) uses seed `base_seed + i`, so the run is repeatable.
    """
    images, latents = [], []
    for i, prompt in enumerate(prompts):
        print(f"Generating {label} skin image {i+1}/{len(prompts)}...")
        img, lat = model.generate_from_prompt(
            prompt,
            negative_prompt=negative_prompt,
            seed=base_seed + i,
            num_inference_steps=GENERATION_STEPS,
            guidance_scale=GENERATION_GUIDANCE,
        )
        images.append(img)
        latents.append(lat)
    return images, latents


# Generate into temporaries and publish all four lists only once both groups
# have finished: interrupting a long run then leaves the previous lists intact
# instead of an empty/partial light group paired with a stale dark group.
_generated_light = generate_group(prompts_light, negative_prompt_light, 42, "light")
_generated_dark = generate_group(prompts_dark, negative_prompt_dark, 1042, "dark")

light_images, light_latents = _generated_light
dark_images, dark_latents = _generated_dark

print("Done!")
print()
print("IMPORTANT: Visually inspect the images below!")
print("- Light images should have VERY pale/white skin")
print("- Dark images should have VERY dark/black skin")
print("- If they look similar, the race vector will not work!")

Generating light skin image 1/3...


  0%|          | 0/50 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Visualize the training images: light group on the top row, dark on the bottom.
# The grid width follows the data, so it works for any number of images per
# group (the photo-loading cell may provide fewer or more than three).
image_rows = [("Light Skin", light_images), ("Dark Skin", dark_images)]
n_cols = max(len(images) for _, images in image_rows)

if n_cols == 0:
    print("No images to display — run the photo-loading or generation cell first.")
else:
    # squeeze=False keeps `axes` two-dimensional even when n_cols == 1.
    fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)

    for row, (label, images) in enumerate(image_rows):
        for col in range(n_cols):
            ax = axes[row, col]
            ax.axis('off')
            if col < len(images):  # leave the panel blank if this group is shorter
                ax.imshow(images[col])
                ax.set_title(f"{label} {col+1}")

    plt.tight_layout()
    plt.show()

    print("These images will be used to extract the race vector.")

In [None]:
# DIAGNOSTIC: Check if latents are actually different
MIN_LATENT_DIFF_NORM = 10.0  # below this the two groups are considered too similar

print("=" * 60)
print("DIAGNOSTIC: Checking latent differences")
print("=" * 60)

if not light_latents or not dark_latents:
    # torch.stack raises on an empty list, so report the real problem instead.
    print("⚠️  WARNING: No latents available for one or both groups.")
    print("   Load photos or generate images before running this diagnostic.")
else:
    # Compute average latent for each group. Cast to float32 so the statistics
    # are not computed in half precision when the model runs in float16.
    avg_light = torch.stack(light_latents).float().mean(dim=0)
    avg_dark = torch.stack(dark_latents).float().mean(dim=0)

    # Compute raw difference (dark minus light)
    raw_diff = avg_dark - avg_light
    diff_norm = raw_diff.norm().item()

    print(f"Average light latent: mean={avg_light.mean().item():.4f}, std={avg_light.std().item():.4f}")
    print(f"Average dark latent:  mean={avg_dark.mean().item():.4f}, std={avg_dark.std().item():.4f}")
    print(f"Raw difference:       mean={raw_diff.mean().item():.4f}, std={raw_diff.std().item():.4f}")
    print(f"Difference magnitude: {diff_norm:.4f}")
    print()

    # Check if the difference is meaningful
    if diff_norm < MIN_LATENT_DIFF_NORM:
        print("⚠️  WARNING: Latent differences are very small!")
        print("   This means the generated images are too similar.")
        print("   The race vector will NOT work properly.")
        print()
        print("   SOLUTIONS:")
        print("   1. Check the generated images - do they look different?")
        print("   2. Try more extreme prompts (see suggestions below)")
        print("   3. Use real photos instead of generated images")
    else:
        print("✓ Latent differences look reasonable")

print("=" * 60)

In [None]:
# DIAGNOSTIC: Compute average skin color from images
import numpy as np

def get_center_region_color(img):
    """Return the mean colour of the central half of an image.

    The image is converted to an array and cropped to the middle 50% of its
    height and width (the region where the face is expected); the result is
    the mean over both spatial axes, i.e. one value per channel.
    """
    pixels = np.asarray(img)
    n_rows, n_cols = pixels.shape[0], pixels.shape[1]

    # Middle band along each axis: from one quarter to three quarters.
    row_band = slice(n_rows // 4, 3 * n_rows // 4)
    col_band = slice(n_cols // 4, 3 * n_cols // 4)

    return pixels[row_band, col_band].mean(axis=(0, 1))

banner = "=" * 60
print(banner)
print("VISUAL DIAGNOSTIC: Average skin colors")
print(banner)

# One centre-region colour per image, then one mean RGB triple per group.
light_colors = list(map(get_center_region_color, light_images))
dark_colors = list(map(get_center_region_color, dark_images))

avg_light_color = np.mean(light_colors, axis=0)
avg_dark_color = np.mean(dark_colors, axis=0)

for line in (
    f"Light skin avg RGB: {avg_light_color}",
    f"Dark skin avg RGB:  {avg_dark_color}",
    f"Difference:         {avg_light_color - avg_dark_color}",
    "",
):
    print(line)

# Brightness here is simply the mean of the averaged channels.
light_brightness = avg_light_color.mean()
dark_brightness = avg_dark_color.mean()
brightness_diff = light_brightness - dark_brightness

for line in (
    f"Light brightness: {light_brightness:.1f}",
    f"Dark brightness:  {dark_brightness:.1f}",
    f"Difference:       {brightness_diff:.1f}",
    "",
):
    print(line)

# Grade the contrast: under 20 is critical, under 40 is weak, otherwise good.
if abs(brightness_diff) < 20:
    verdict = [
        "⚠️  CRITICAL: Images are TOO SIMILAR in brightness!",
        "   Difference < 20 means race vector will be weak/broken",
        "",
        "   ACTION REQUIRED:",
        "   1. Look at the images above - do they actually look different?",
        "   2. If not different enough, regenerate with more extreme prompts",
        "   3. OR use real photographs instead of generated images",
    ]
elif abs(brightness_diff) < 40:
    verdict = [
        "⚠️  Warning: Images have small brightness difference",
        "   Race vector may be weak. Ideal difference: 50-100+",
    ]
else:
    verdict = [
        f"✓ Good brightness contrast! ({abs(brightness_diff):.1f})",
        "  Race vector should work well",
    ]

for line in verdict:
    print(line)

print(banner)

## 3. Extract Race Vector

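Compute the race vector from the average difference between the light-skin and dark-skin latent codes. In the cells below, positive α values push an image toward darker skin and negative values toward lighter skin.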

### Key Improvements to Prevent "Black Halo/Fog" Artifacts

To ensure proper disentanglement and avoid background artifacts:

1. **Balanced Spatial Masking**: 
   - `edge_weight=0.0` ensures NO changes to background
   - `radius=0.8` creates a smooth, natural transition (larger = smoother)
   - Gaussian falloff prevents visible circle boundaries
   - **Key tradeoff**: Larger radius = smoother transition but may affect some background
   
2. **Proportionally Smaller Alphas**: 
   - Since the mask concentrates the vector, we need smaller alpha values
   - Using `±0.8` instead of `±2.0` for natural results
   
3. **Better Prompts**: Using specific prompts focused ONLY on skin tone differences

4. **Negative Prompts**: Explicitly excluding unwanted variations

5. **More Training Examples**: Using 3 examples per group for robust extraction

**Understanding the Radius Parameter**:
- **Small (0.3-0.5)**: Very tight focus, may create visible circle, needs tiny alphas
- **Medium (0.6-0.8)**: Balanced - smooth transition, good for most cases ✓
- **Large (0.9-1.2)**: Very smooth but may affect background edges

**If you see artifacts**:
- **Black circle/ring**: Radius too small OR alphas too large
  → Try `radius=0.8` with `alphas = [-0.8, -0.4, 0.0, 0.4, 0.8]`
- **Black fog**: Radius too large OR edge_weight > 0
  → Try `radius=0.6` with `edge_weight=0.0`
- **No visible change**: Alphas too small
  → Increase gradually: try `[-1.0, -0.5, 0.0, 0.5, 1.0]`

In [None]:
# --- Latent sanity check -----------------------------------------------------
# Summary statistics for both groups; computed in float32. Tolerates the lists
# being undefined or empty so the message below is shown instead of a traceback.
_light = globals().get("light_latents") or []
_dark = globals().get("dark_latents") or []
if _light and _dark:
    l_stack = torch.stack(_light).float()
    d_stack = torch.stack(_dark).float()
    print(f'Light Latents: Mean={l_stack.mean():.4f}, Std={l_stack.std():.4f}, Range=[{l_stack.min():.4f}, {l_stack.max():.4f}]')
    print(f'Dark Latents:  Mean={d_stack.mean():.4f}, Std={d_stack.std():.4f}, Range=[{d_stack.min():.4f}, {d_stack.max():.4f}]')
    if d_stack.shape == l_stack.shape:
        diff = d_stack - l_stack
    else:
        # Groups of different size cannot be subtracted pairwise; compare means.
        diff = d_stack.mean(dim=0) - l_stack.mean(dim=0)
    print(f'Raw Difference: Mean={diff.mean():.4f}, Norm={diff.norm():.4f}')
else:
    print('WARNING: Latents not found or empty.')

# --- Extraction ----------------------------------------------------------------
INITIAL_MASK_RADIUS = 0.8   # LARGER radius for smoother, more natural transition
MIN_VECTOR_NORM = 1.0       # outside [MIN, MAX] the raw magnitude is treated as suspect
MAX_VECTOR_NORM = 1000.0

extractor = RaceVectorExtractor(device=device)

# Create spatial mask to focus on face region (center) and minimize background effects.
# The spatial size is the last two dimensions, whether the latent is 3-D or 4-D.
latent_shape = light_latents[0].shape
h, w = latent_shape[-2], latent_shape[-1]

spatial_mask = extractor.create_center_mask(
    height=h,
    width=w,
    center_weight=1.0,
    edge_weight=0.0,      # Zero weight at edges - NO background changes
    falloff='gaussian',   # Smooth falloff (NOT hard - avoids visible circles)
    radius=INITIAL_MASK_RADIUS,
)

print(f"Created spatial mask with shape: {spatial_mask.shape}")
print(f"Mask: center={spatial_mask[h//2, w//2].item():.4f}, edge={spatial_mask[0, 0].item():.4f}")

race_vector = extractor.extract_from_pairs(
    light_latents,
    dark_latents,
    normalize=False,  # Preserve magnitude for better control
    spatial_mask=spatial_mask,
)

print(f"Race vector shape: {race_vector.shape}")
print(f"Race vector norm: {race_vector.norm().item():.4f}")
print(f"✓ Using larger radius ({INITIAL_MASK_RADIUS}) with smooth Gaussian falloff")

# --- Magnitude check -------------------------------------------------------------
# Performed once: if the unnormalized vector is implausibly small or large,
# re-extract with normalization enabled.
vector_norm = race_vector.norm().item()
print(f'Race Vector: Norm={vector_norm:.4f}, Mean={race_vector.mean().item():.4f}')
if vector_norm < MIN_VECTOR_NORM or vector_norm > MAX_VECTOR_NORM:
    print('⚠️  Vector magnitude suspect. Force-enabling normalization for stability.')
    race_vector = extractor.extract_from_pairs(
        light_latents, dark_latents, normalize=True, spatial_mask=spatial_mask
    )
    print(f'New Normalized Vector Norm: {race_vector.norm().item():.4f}')


In [None]:
# Side-by-side view: the spatial mask (left) and the race vector's spatial
# activation pattern after masking (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
mask_ax, vector_ax = axes[0], axes[1]

# Left panel: spatial mask
mask_im = mask_ax.imshow(spatial_mask.cpu().numpy(), cmap='hot')
mask_ax.set_title('Spatial Mask\n(High weight = center/face, Low weight = edges/background)')
mask_ax.set_xlabel('Width')
mask_ax.set_ylabel('Height')
plt.colorbar(mask_im, ax=mask_ax, label='Weight')

# Analyze vector properties
from src.latent.vector_discovery import VectorAnalyzer

analyzer = VectorAnalyzer(device=device)
analysis = analyzer.analyze_spatial_pattern(race_vector)

# Right panel: the 'spatial_heatmap' entry of the analysis result
heatmap_im = vector_ax.imshow(analysis['spatial_heatmap'].cpu().numpy(), cmap='hot')
vector_ax.set_title('Race Vector Spatial Activation Pattern\n(After masking)')
vector_ax.set_xlabel('Width')
vector_ax.set_ylabel('Height')
plt.colorbar(heatmap_im, ax=vector_ax, label='Magnitude')

plt.tight_layout()
plt.show()

total_magnitude = analysis['total_magnitude']
print(f"Total magnitude: {total_magnitude:.4f}")
print("Center-focused: The mask ensures changes concentrate on the face region")

# Generate counterfactuals using steered denoising
# Each image is generated from the same seed as the base image, with the race
# vector injected at every denoising step. This avoids the "misty" VAE
# artifacts that come from directly adding a vector to a final latent.

alphas = [-4, -2, 0, 2, 4]
BASE_PROMPT = "portrait photo of a person, professional headshot, neutral background, high quality, detailed face"
BASE_SEED = 999
BASE_NEGATIVE_PROMPT = "multiple people, accessories, jewelry, glasses"

# The α = 0 slot reuses the unsteered base image. On a fresh kernel the cell
# that creates `base_image` has not run yet, so generate it here (same prompt,
# seed and sampler settings) instead of failing with a NameError.
if "base_image" not in globals():
    print("Generating base image...")
    base_image, base_latent = model.generate_from_prompt(
        BASE_PROMPT,
        negative_prompt=BASE_NEGATIVE_PROMPT,
        seed=BASE_SEED,
        num_inference_steps=50,
        guidance_scale=7.5,
    )

print("Generating counterfactuals (steered denoising)...")
print("Alpha values:", alphas)
print("Negative = lighter skin, Positive = darker skin\n")

counterfactual_images = []

for alpha in alphas:
    print(f"  α = {alpha:+.1f} ...", end=" ", flush=True)

    if abs(alpha) < 0.01:
        # α = 0 means "no steering": use the base image as-is.
        counterfactual_images.append(base_image)
        print("(base image, skipped)")
        continue

    img, _ = model.generate_steered(
        prompt=BASE_PROMPT,
        race_vector=race_vector,
        alpha=alpha,
        seed=BASE_SEED,
        negative_prompt=BASE_NEGATIVE_PROMPT,
        num_inference_steps=50,
        guidance_scale=7.5,
    )
    counterfactual_images.append(img)
    print("done")

print(f"\nGenerated {len(counterfactual_images)} counterfactual images.")
print("Tip — if the effect is too subtle: try alphas = [-6, -3, 0, 3, 6]")
print("Tip — if the face changes too much: try alphas = [-2, -1, 0, 1, 2]")

In [None]:
# Produce the unsteered reference portrait that counterfactuals are compared to.
print("Generating base image...")
base_generation_kwargs = dict(
    negative_prompt="multiple people, accessories, jewelry, glasses",
    seed=999,
    num_inference_steps=50,  # Same as training images
    guidance_scale=7.5,
)
base_image, base_latent = model.generate_from_prompt(
    "portrait photo of a person, professional headshot, neutral background, high quality, detailed face",
    **base_generation_kwargs,
)

# Show the result on its own square figure with the axes hidden.
base_fig, base_ax = plt.subplots(figsize=(6, 6))
base_ax.imshow(base_image)
base_ax.set_title("Base Image")
base_ax.axis('off')
plt.show()

In [None]:
# ── Tuning Cell ───────────────────────────────────────────────────────────────
# Adjust alpha range and mask settings here if results aren't right.
# Run this cell to regenerate with new settings, then re-run the viz cell.

ALPHA_RANGE  = 4    # symmetric: generates [-ALPHA_RANGE, ..., 0, ..., +ALPHA_RANGE]
MASK_RADIUS  = 1.0  # 0.6–1.2 recommended; larger = smoother spatial transition
EDGE_WEIGHT  = 0.3  # 0.0 = pure center mask, 0.3 = soft blend at edges

print("=" * 60)
print("TUNING PARAMETERS")
print(f"  Alpha range : ±{ALPHA_RANGE}")
print(f"  Mask radius : {MASK_RADIUS}")
print(f"  Edge weight : {EDGE_WEIGHT}")
print("=" * 60)

# Re-extract race vector with updated mask.
# The spatial size is the last two dimensions of a latent.
latent_shape = light_latents[0].shape
h, w = latent_shape[-2], latent_shape[-1]

spatial_mask = extractor.create_center_mask(
    height=h, width=w,
    center_weight=1.0,
    edge_weight=EDGE_WEIGHT,
    falloff="gaussian",
    radius=MASK_RADIUS,
)

race_vector = extractor.extract_from_pairs(
    light_latents, dark_latents,
    normalize=False,
    spatial_mask=spatial_mask,
)
print(f"Re-extracted vector (norm: {race_vector.norm().item():.4f})")

# Re-generate counterfactuals with updated alpha range.
# True division keeps the list symmetric for odd or fractional ranges; floor
# division (`//`) rounds the negative half down and the positive half toward
# zero, e.g. ALPHA_RANGE = 3 would give [-3, -2, 0, 1, 3].
half_range = ALPHA_RANGE / 2
alphas = [-ALPHA_RANGE, -half_range, 0, half_range, ALPHA_RANGE]
counterfactual_images = []

for alpha in alphas:
    print(f"  α = {alpha:+.1f} ...", end=" ", flush=True)
    if abs(alpha) < 0.01:
        # α = 0 is the unsteered base image; no generation needed.
        counterfactual_images.append(base_image)
        print("(base)")
        continue
    img, _ = model.generate_steered(
        prompt=BASE_PROMPT,
        race_vector=race_vector,
        alpha=alpha,
        seed=BASE_SEED,
        negative_prompt="multiple people, accessories, jewelry, glasses",
        num_inference_steps=50,
        guidance_scale=7.5,
    )
    counterfactual_images.append(img)
    print("done")

print("\nDone. Run the visualization cell below to see results.")

In [None]:
### Quick Fix: If You See Black Fog/Halos
#
# Opt-in cell: if counterfactuals show artifacts, set APPLY_HARD_MASK_FIX = True
# and run this cell to re-extract with stronger masking. It is disabled by
# default so that "Run All" does not silently replace `race_vector` and
# `alphas` while `counterfactual_images` still holds images generated with the
# previous settings (which would mislabel the plots and metrics below).
APPLY_HARD_MASK_FIX = False

if APPLY_HARD_MASK_FIX:
    # Re-extract with hard mask (sharp cutoff)
    latent_shape = light_latents[0].shape
    h, w = latent_shape[-2], latent_shape[-1]

    spatial_mask_hard = extractor.create_center_mask(
        height=h, width=w,
        center_weight=1.0, edge_weight=0.0,
        falloff='hard',  # Sharp cutoff instead of gaussian
        radius=0.4,      # Tighter focus than the radii used in the other cells
    )
    race_vector = extractor.extract_from_pairs(
        light_latents, dark_latents,
        normalize=False, spatial_mask=spatial_mask_hard,
    )
    print(f"Re-extracted with hard mask: {race_vector.norm().item():.4f}")

    # Regenerate the counterfactuals with smaller alphas so the images and
    # `alphas` stay in step with the new vector.
    alphas = [-1.0, -0.5, 0.0, 0.5, 1.0]
    counterfactual_images = []
    for alpha in alphas:
        if abs(alpha) < 0.01:
            counterfactual_images.append(base_image)  # α = 0 is the base image
            continue
        img, _ = model.generate_steered(
            prompt=BASE_PROMPT,
            race_vector=race_vector,
            alpha=alpha,
            seed=BASE_SEED,
            negative_prompt="multiple people, accessories, jewelry, glasses",
            num_inference_steps=50,
            guidance_scale=7.5,
        )
        counterfactual_images.append(img)
    print(f"Re-generated {len(counterfactual_images)} counterfactual images with the hard mask.")
else:
    print("Hard-mask quick fix disabled (set APPLY_HARD_MASK_FIX = True to apply it).")


In [None]:
# TUNING CELL: Adjust these parameters if you see artifacts
# This cell is ready to run - just modify the values below.
# With the default values nothing is recomputed.

# Current settings (modify as needed):
USE_HARD_MASK = False  # Set to True for sharp cutoff (may create visible circle)
MASK_RADIUS = 0.8      # 0.6-0.8 recommended, larger = smoother
ALPHA_SCALE = 0.8      # Scale factor for alphas (0.5 = half strength, 1.5 = 1.5x strength)

print("=" * 60)
print("TUNING PARAMETERS")
print("=" * 60)

# Only re-extract when at least one setting differs from its default.
if USE_HARD_MASK or MASK_RADIUS != 0.8 or ALPHA_SCALE != 0.8:
    print(f"Re-extracting with custom settings...")
    print(f"  Mask type: {'HARD' if USE_HARD_MASK else 'Gaussian'}")
    print(f"  Radius: {MASK_RADIUS}")
    print(f"  Alpha scale: {ALPHA_SCALE}")
    print()

    # The spatial size is the last two dimensions of a latent.
    latent_shape = light_latents[0].shape
    h, w = latent_shape[-2], latent_shape[-1]

    spatial_mask = extractor.create_center_mask(
        height=h, width=w,
        center_weight=1.0,
        edge_weight=0.0,
        falloff='hard' if USE_HARD_MASK else 'gaussian',
        radius=MASK_RADIUS,
    )

    # Re-extract vector
    race_vector = extractor.extract_from_pairs(
        light_latents, dark_latents,
        normalize=False,
        spatial_mask=spatial_mask,
    )
    print(f"✓ Re-extracted vector (norm: {race_vector.norm().item():.4f})")

    # Use scaled alphas
    base_alphas = [-0.8, -0.4, 0.0, 0.4, 0.8]
    alphas = [a * ALPHA_SCALE for a in base_alphas]
    print(f"✓ Using alphas: {[f'{a:.2f}' for a in alphas]}")

    # Re-generate counterfactuals
    if "manipulator" in globals():
        # Direct latent manipulation, when a manipulator instance is available.
        counterfactual_latents = manipulator.generate_counterfactuals(base_latent, race_vector, alphas)
        counterfactual_images = [model.decode_latent(lat) for lat in counterfactual_latents]
    else:
        # No `manipulator` is created anywhere in the notebook's visible cells,
        # so fall back to the steered-denoising path the other cells use
        # rather than failing with a NameError.
        counterfactual_images = []
        for alpha in alphas:
            if abs(alpha) < 0.01:
                counterfactual_images.append(base_image)  # α = 0 is the base image
                continue
            img, _ = model.generate_steered(
                prompt=BASE_PROMPT,
                race_vector=race_vector,
                alpha=alpha,
                seed=BASE_SEED,
                negative_prompt="multiple people, accessories, jewelry, glasses",
                num_inference_steps=50,
                guidance_scale=7.5,
            )
            counterfactual_images.append(img)

    print(f"✓ Re-generated counterfactuals")
    print()
    print("Next: Run the visualization cell below to see results")
    print("=" * 60)
else:
    print("Using default settings (no changes)")
    print("To tune, modify the values at the top of this cell")
    print("=" * 60)

In [None]:
# Visualize counterfactuals: one panel per (image, alpha) pair.
if len(counterfactual_images) != len(alphas):
    # zip() below would silently drop the surplus; a mismatch means the images
    # and the alpha list come from different runs, so the titles may be wrong.
    print(
        f"⚠️  {len(counterfactual_images)} images but {len(alphas)} alphas — "
        "re-run the generation cell so they match."
    )

panels = list(zip(counterfactual_images, alphas))

if not panels:
    print("No counterfactual images to display.")
else:
    # squeeze=False keeps `axes` two-dimensional even for a single panel.
    fig, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4), squeeze=False)

    for ax, (img, alpha) in zip(axes[0], panels):
        ax.imshow(img)
        ax.set_title(f"α = {alpha:.1f}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

## 5. Evaluate Results

Measure identity preservation and disentanglement.

In [None]:
evaluator = CounterfactualEvaluator(device=device)

# Compare every steered image against the base image; the α≈0 entry is the
# base image itself and is left out.
print("Evaluating counterfactuals...\n")

results = []
for cf_image, alpha in zip(counterfactual_images, alphas):
    is_original = abs(alpha) < 0.01
    if not is_original:
        print(f"\nEvaluating α = {alpha:.1f}")
        print("-" * 60)

        results.append(
            evaluator.evaluate_pair(base_image, cf_image, verbose=True)
        )

In [None]:
# Plot each evaluation metric against alpha in a 2x2 grid.
import pandas as pd

# One row per evaluated counterfactual; alphas exclude the skipped α≈0 entry.
df = pd.DataFrame([r.to_dict() for r in results])
df['alpha'] = [a for a in alphas if abs(a) >= 0.01]

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# (axis, column, y-label, title, threshold line or None, skip when column is all-NaN)
metric_panels = [
    (axes[0, 0], 'face_similarity', 'Face Similarity', 'Identity Preservation', 0.85, False),
    (axes[0, 1], 'landmark_rmse', 'Landmark RMSE (px)', 'Facial Geometry Preservation', 5.0, True),
    (axes[1, 0], 'background_ssim', 'Background SSIM', 'Background Preservation', 0.90, True),
    (axes[1, 1], 'overall_score', 'Overall Score', 'Overall Disentanglement Quality', None, False),
]

for ax, column, y_label, panel_title, threshold, skip_if_empty in metric_panels:
    if skip_if_empty and not df[column].notna().any():
        continue  # nothing measured for this metric; leave the panel empty

    ax.plot(df['alpha'], df[column], 'o-')
    if threshold is not None:
        ax.axhline(y=threshold, color='r', linestyle='--', label='Threshold')
    ax.set_xlabel('Alpha')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    if threshold is not None:
        ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.show()

## 6. Create Visualization Grid

In [None]:
from pathlib import Path

generator = CounterfactualGridGenerator()

# Create grid (excluding α=0, which is the base image itself)
cf_images_no_orig = [img for img, a in zip(counterfactual_images, alphas) if abs(a) >= 0.01]
labels = [f"α={a:.1f}" for a in alphas if abs(a) >= 0.01]
metrics_list = [r.to_dict() for r in results]

grid = generator.generate_grid(
    base_image,
    cf_images_no_orig,
    labels=labels,
    metrics=metrics_list,
    title="Disentangled Race Vector Demonstration",
)

# Display
plt.figure(figsize=(15, 8))
plt.imshow(grid)
plt.axis('off')
plt.show()

# Save under <repo>/experiments/results whether the working directory is the
# repo root (`src/` is visible) or the notebooks/ folder, creating the
# directory if needed.
experiments_dir = Path("experiments") if Path("src").is_dir() else Path("../experiments")
grid_path = experiments_dir / "results" / "demo_grid.png"
grid_path.parent.mkdir(parents=True, exist_ok=True)
grid.save(grid_path)
print(f"Grid saved to: {grid_path.resolve()}")