# Disentangled Representation Learning

Steps performed:
1. Load Stable Diffusion model
2. Generate or encode images
3. Extract race vector
4. Generate counterfactuals
5. Evaluate results
6. Create visualization grid

## Setup

In [None]:
import sys
from pathlib import Path

sys.path.insert(0, '..')

import torch
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

from src.models.stable_diffusion import StableDiffusionWrapper
from src.latent.vector_discovery import RaceVectorExtractor
from src.latent.vector_discovery import VectorAnalyzer
from src.latent.manipulator import LatentManipulator
from src.metrics.evaluator import CounterfactualEvaluator
from src.visualization.grid_generator import CounterfactualGridGenerator

%matplotlib inline
%load_ext autoreload
%autoreload 2

## 1. Load Model

In [None]:
# Use the GPU when torch can see one; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
print(f"Using device: {device}")

# float16 on the GPU, float32 on the CPU.
model_dtype = torch.float16 if device == "cuda" else torch.float32
model = StableDiffusionWrapper(device=device, dtype=model_dtype, enable_xformers=True)

## 2. Generate Test Images

We generate two images per group from prompts describing light and dark skin tones; their latents are used in the next section to extract the race vector.

In [None]:
def _generate_group(prompts, label, seed_start, num_inference_steps=30):
    """Generate one image per prompt and collect the images with their latents.

    Args:
        prompts: Text prompts; one image is generated for each.
        label: Group name, used only in the progress message (e.g. "light skin").
        seed_start: Seed for the first prompt; prompt ``k`` uses ``seed_start + k``.
        num_inference_steps: Forwarded unchanged to ``model.generate_from_prompt``.

    Returns:
        ``(images, latents)``: two lists in prompt order.
    """
    images, latents = [], []
    for idx, prompt in enumerate(prompts):
        print(f"Generating {label} image {idx + 1}/{len(prompts)}...")
        img, lat = model.generate_from_prompt(
            prompt,
            seed=seed_start + idx,
            num_inference_steps=num_inference_steps,
        )
        images.append(img)
        latents.append(lat)
    return images, latents


# Generate images with light skin tone
prompts_light = [
    "portrait photo of a person with light skin tone, professional headshot, neutral background",
    "photo of a person with fair complexion, studio lighting",
]
light_images, light_latents = _generate_group(prompts_light, "light skin", seed_start=42)

# Generate images with dark skin tone (separate seed range from the light group)
prompts_dark = [
    "portrait photo of a person with dark skin tone, professional headshot, neutral background",
    "photo of a person with deep complexion, studio lighting",
]
dark_images, dark_latents = _generate_group(prompts_dark, "dark skin", seed_start=1042)

print("Done!")

In [None]:
# Visualize generated images: one row per group, one column per generated image.
# The grid is sized from the lists, so changing the number of prompts above
# neither raises IndexError nor silently drops images.
image_rows = [("Light", light_images), ("Dark", dark_images)]
n_cols = max(1, max(len(images) for _, images in image_rows))

# squeeze=False keeps `axes` two-dimensional even when there is a single column.
fig, axes = plt.subplots(
    len(image_rows), n_cols, figsize=(5 * n_cols, 10), squeeze=False
)

for row_axes, (group_name, images) in zip(axes, image_rows):
    for col, ax in enumerate(row_axes):
        if col < len(images):
            ax.imshow(images[col])
            ax.set_title(f"{group_name} {col + 1}")
        # Hide ticks and frame on every panel, including unused ones.
        ax.axis('off')

plt.tight_layout()
plt.show()

## 3. Extract Race Vector

Compute the average difference between light and dark skin latent codes.

In [None]:
# Build the extractor on the active device and derive the direction from the
# light/dark latent lists; normalize=True is passed through to the extractor.
extractor = RaceVectorExtractor(device=device)
race_vector = extractor.extract_from_pairs(light_latents, dark_latents, normalize=True)

# Report the shape and norm of the extracted vector.
print(f"Race vector shape: {race_vector.shape}")
vector_norm = race_vector.norm().item()
print(f"Race vector norm: {vector_norm:.4f}")

In [None]:
# Analyze vector properties
analyzer = VectorAnalyzer(device=device)
analysis = analyzer.analyze_spatial_pattern(race_vector)

# Visualize spatial heatmap.
# detach() drops any autograd tracking (Tensor.numpy() refuses tensors that
# require grad) and float() upcasts half-precision values before plotting.
# Assumes 'spatial_heatmap' is a torch tensor (the original code already
# called .cpu() on it).
heatmap = analysis['spatial_heatmap'].detach().float().cpu().numpy()

fig, ax = plt.subplots(figsize=(8, 6))
heatmap_image = ax.imshow(heatmap, cmap='hot')
fig.colorbar(heatmap_image, ax=ax, label='Magnitude')
ax.set_title('Race Vector Spatial Activation Pattern')
ax.set_xlabel('Width')
ax.set_ylabel('Height')
plt.show()

print(f"Total magnitude: {analysis['total_magnitude']:.4f}")

## 4. Generate Counterfactuals

Apply the race vector to a new image at different magnitudes.

In [None]:
# Generate the base image (and its latent) that the counterfactuals start from.
print("Generating base image...")
base_prompt = "portrait photo of a person, professional headshot, neutral background, high quality"
base_image, base_latent = model.generate_from_prompt(
    base_prompt, seed=999, num_inference_steps=30
)

# Show the base image without axes.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(base_image)
ax.set_title("Base Image")
ax.axis('off')
plt.show()

In [None]:
# Shift the base latent along the race vector once per alpha, then decode.
manipulator = LatentManipulator(device=device)
alphas = [-2.0, -1.0, 0.0, 1.0, 2.0]

print("Generating counterfactuals...")
counterfactual_latents = manipulator.generate_counterfactuals(base_latent, race_vector, alphas)

# Decode every manipulated latent back to an image, preserving order.
counterfactual_images = [model.decode_latent(latent) for latent in counterfactual_latents]

print("Done!")

In [None]:
# Visualize counterfactuals: one panel per (image, alpha) pair.
# The panel count follows the data instead of a hard-coded 5, so editing
# `alphas` in the previous cell does not break this figure.
panels = list(zip(counterfactual_images, alphas))
n_panels = max(1, len(panels))

# squeeze=False keeps `axes` two-dimensional even for a single panel.
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)

for ax in axes[0]:
    ax.axis('off')

for ax, (img, alpha) in zip(axes[0], panels):
    ax.imshow(img)
    ax.set_title(f"α = {alpha:.1f}")

plt.tight_layout()
plt.show()

## 5. Evaluate Results

Measure identity preservation and disentanglement.

In [None]:
evaluator = CounterfactualEvaluator(device=device)

# Evaluate each counterfactual against the base image. Entries whose alpha is
# (numerically) zero are skipped; the same |alpha| >= 0.01 filter is applied
# by the cells that consume `results`, so the two must stay in sync.
print("Evaluating counterfactuals...\n")

results = []
for cf_image, alpha in zip(counterfactual_images, alphas):
    if abs(alpha) < 0.01:  # Skip original
        continue

    print(f"\nEvaluating α = {alpha:.1f}")
    print("-" * 60)

    results.append(
        evaluator.evaluate_pair(
            base_image,
            cf_image,
            verbose=True,
        )
    )

In [None]:
# Visualize metrics
def _plot_metric(ax, frame, column, ylabel, title, threshold=None):
    """Plot one metric against alpha on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        frame: DataFrame holding an ``'alpha'`` column and, optionally, ``column``.
        column: Name of the metric column to plot.
        ylabel: Y-axis label.
        title: Panel title.
        threshold: If given, drawn as a dashed red reference line with a legend.

    When ``column`` is absent or contains only missing values, the panel shows
    a short "not available" note instead of raising or staying blank.
    """
    has_data = column in frame.columns and frame[column].notna().any()
    if not has_data:
        ax.text(
            0.5, 0.5, f"'{column}' not available",
            ha='center', va='center', transform=ax.transAxes,
        )
        ax.set_title(title)
        ax.axis('off')
        return

    ax.plot(frame['alpha'], frame[column], 'o-')
    if threshold is not None:
        ax.axhline(y=threshold, color='r', linestyle='--', label='Threshold')
        ax.legend()
    ax.set_xlabel('Alpha')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)


# One row per evaluated counterfactual. The alpha filter must match the one
# used by the evaluation loop, so a mismatch is reported explicitly instead of
# surfacing later as a confusing KeyError.
evaluated_alphas = [a for a in alphas if abs(a) >= 0.01]
if len(evaluated_alphas) != len(results):
    raise ValueError(
        f"Expected one result per non-zero alpha ({len(evaluated_alphas)}), "
        f"got {len(results)}; re-run the evaluation cell."
    )

df = pd.DataFrame([r.to_dict() for r in results])
df['alpha'] = evaluated_alphas

# (column, y-label, title, threshold) for each panel, in reading order.
panel_specs = [
    ('face_similarity', 'Face Similarity', 'Identity Preservation', 0.85),
    ('landmark_rmse', 'Landmark RMSE (px)', 'Facial Geometry Preservation', 5.0),
    ('background_ssim', 'Background SSIM', 'Background Preservation', 0.90),
    ('overall_score', 'Overall Score', 'Overall Disentanglement Quality', None),
]

# Plot metrics
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
for ax, (column, ylabel, title, threshold) in zip(axes.flat, panel_specs):
    _plot_metric(ax, df, column, ylabel, title, threshold)

plt.tight_layout()
plt.show()

## 6. Create Visualization Grid

In [None]:
generator = CounterfactualGridGenerator()

# Create grid (excluding α=0). Images and labels are derived from the same
# (image, alpha) pairs so they cannot drift out of alignment.
non_zero_pairs = [
    (img, a) for img, a in zip(counterfactual_images, alphas) if abs(a) >= 0.01
]
cf_images_no_orig = [img for img, _ in non_zero_pairs]
labels = [f"α={a:.1f}" for _, a in non_zero_pairs]
metrics_list = [r.to_dict() for r in results]

grid = generator.generate_grid(
    base_image,
    cf_images_no_orig,
    labels=labels,
    metrics=metrics_list,
    title="Disentangled Race Vector Demonstration",
)

# Display
plt.figure(figsize=(15, 8))
plt.imshow(grid)
plt.axis('off')
plt.show()

# Save. Create the output directory first so a fresh checkout does not fail,
# and report the path that was actually written.
output_path = Path('../experiments/results/demo_grid.png')
output_path.parent.mkdir(parents=True, exist_ok=True)
grid.save(str(output_path))
print(f"Grid saved to: {output_path}")

## 7. Summary

Completed steps:
1. Extracted a race vector from paired examples
2. Applied the vector to generate counterfactuals
3. Evaluated identity preservation metrics
4. Created visualization grids

To do:
- Repeat the pipeline with real images
- Optimize the race vector
- Run full experiments (scope still to be specified)
- Run ablation studies