# StyleGAN Practical Lab: Image Editing and Latent Space Manipulation

## Lab Objectives

- Understand the basic functioning of StyleGAN
- Generate realistic images using pre-trained models
- Project real images to latent space
- Find semantic directions in latent space
- Apply directed edits to images

## 1. Initial Setup and Dependencies

### Install required libraries

In [None]:
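# Install dependencies
!pip install torch torchvision
!pip install requests pillow numpy matplotlib
!pip install opencv-python
!pip install click tqdm pyspng ninja imageio-ffmpeg==0.4.3

# The StyleGAN2-ADA repository provides the dnnlib and torch_utils modules
# that the pickled model needs in order to be loaded
!git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git

import os
os.chdir('/content/stylegan2-ada-pytorch')  # Colab path; work inside the repo so those modules are importable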

### Required imports

In [None]:
import pickle
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import requests
import io
from google.colab import files  # Colab-only: used to upload images
import cv2

# Configure device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# CPU Warning
if device.type == 'cpu':
    print("⚠️  WARNING: Running on CPU. Expect slower performance!")
    print("   - Image generation: ~10-30 seconds per image")
    print("   - Projection: ~20-60 minutes")
    print("   - In Colab, switch to a GPU runtime (Runtime > Change runtime type)")

## 2. Loading Pre-trained Model

### Download the StyleGAN2-ADA FFHQ face model

In [None]:
# Download the pre-trained StyleGAN2-ADA FFHQ generator.
# This checkpoint generates 1024x1024 images; the lab only resizes outputs to 256x256 for display.
model_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl'
model_path = 'ffhq.pkl'

if not os.path.exists(model_path):
    print("Downloading StyleGAN2-FFHQ model...")
    # Stream to disk instead of holding the whole file in memory
    with requests.get(model_url, stream=True) as response:
        response.raise_for_status()
        with open(model_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    print("Model downloaded successfully!")

### Load the generator

In [None]:
# Load the generator ('G_ema' is the exponential-moving-average copy used for inference).
# Unpickling requires the dnnlib and torch_utils modules from the cloned repository.
with open(model_path, 'rb') as f:
    G = pickle.load(f)['G_ema'].to(device)
G.eval()

# The network always synthesizes at its native resolution;
# the helper functions below only resize outputs to 256x256 for display.
display_resolution = 256

print(f"Native model resolution: {G.img_resolution}x{G.img_resolution}")
print(f"Display resolution: {display_resolution}x{display_resolution}")
print(f"Z latent space dimension: {G.z_dim}")
print(f"W latent space dimension: {G.w_dim}")
print(f"Number of per-layer w vectors (num_ws): {G.mapping.num_ws}")

**Key Concepts**:

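- **Z Space**: The input latent space; z is sampled from a standard Gaussian (512-D for this model)
- **W Space**: The intermediate latent space produced by the mapping network (Z → W); it is more disentangled than Z, so attributes tend to vary along simpler, more linear directions
- **Per-layer w vectors**: `G.mapping` returns one copy of w per synthesis layer, with shape `[N, num_ws, w_dim]`; letting these copies differ gives the extended space W+, which is what the projection in Section 4 optimizes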
- **StyleGAN2 vs StyleGAN1**: Better quality, fewer artifacts, improved architecture
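To make the Z → W data flow concrete, here is a toy NumPy sketch of the mapping network's structure: normalize z, pass it through 8 fully connected layers with leaky ReLU, then copy the resulting w once per synthesis layer. The weights are random and the sizes are shrunk (assumptions for illustration), so this is not the trained network and reuses none of its parameters.

```python
import numpy as np

def pixel_norm(z, eps=1e-8):
    """Rescale each latent vector to unit root-mean-square, as StyleGAN does before the MLP."""
    return z / np.sqrt(np.mean(z ** 2, axis=1, keepdims=True) + eps)

def leaky_relu(x, slope=0.2):
    return np.where(x >= 0, x, slope * x)

def toy_mapping(z, layers, num_ws):
    """Toy mapping network: z [N, z_dim] -> ws [N, num_ws, w_dim] (random, untrained weights)."""
    x = pixel_norm(z)
    for weight, bias in layers:
        x = leaky_relu(x @ weight + bias)
    # Broadcast the single w to every synthesis layer, matching the shape G.mapping returns
    return np.repeat(x[:, None, :], num_ws, axis=1)

rng = np.random.default_rng(0)
z_dim = w_dim = 32             # toy size; the FFHQ model uses 512
num_fc_layers, num_ws = 8, 18  # 8 FC layers; a 1024x1024 generator consumes 18 ws
layers = [(rng.normal(size=(z_dim if i == 0 else w_dim, w_dim)) / np.sqrt(w_dim), np.zeros(w_dim))
          for i in range(num_fc_layers)]

z = rng.normal(size=(4, z_dim))
ws = toy_mapping(z, layers, num_ws)
print(ws.shape)  # (4, 18, 32)
```

For a sampled z all `num_ws` rows are identical; optimizing them separately, as the projection does, leaves W for W+.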

## 3. Basic Image Generation

### Helper function to display images

In [None]:
def tensor_to_pil(tensor, target_size=256):
    """Convert a [3, H, W] tensor with values in [-1, 1] to a PIL image of size target_size x target_size."""
    tensor = (tensor.detach() + 1) * 127.5  # From [-1, 1] to [0, 255]
    tensor = tensor.clamp(0, 255).to(torch.uint8)
    tensor = tensor.permute(1, 2, 0).cpu().numpy()
    img = Image.fromarray(tensor)
    
    # Resize for consistent display
    if img.size != (target_size, target_size):
        img = img.resize((target_size, target_size), Image.Resampling.LANCZOS)
    
    return img

def show_images(images, titles=None, figsize=(15, 5)):
    """Display multiple images"""
    n = len(images)
    fig, axes = plt.subplots(1, n, figsize=figsize)
    if n == 1:
        axes = [axes]
    
    for i, (img, ax) in enumerate(zip(images, axes)):
        ax.imshow(img)
        ax.axis('off')
        if titles:
            ax.set_title(titles[i])
    plt.tight_layout()
    plt.show()

### Generate random faces

In [None]:
# Generate random vectors in Z space
num_samples = 4
z = torch.randn([num_samples, G.z_dim]).to(device)

print("Generating random faces...")
# Generate images
with torch.no_grad():
    # Map from Z to W
    w = G.mapping(z, None)
    # Generate images
    imgs = G.synthesis(w)

# Convert and display
pil_images = [tensor_to_pil(img) for img in imgs]
show_images(pil_images, titles=[f'Image {i+1}' for i in range(num_samples)])
print("✅ Generation complete!")

**Reflection Question**: Why does StyleGAN use two latent spaces (Z and W) instead of just one?

## 4. Real Image Projection to Latent Space

### Load input image

In [None]:
def load_image(use_upload=True, target_size=256):
    """Load an RGB image and return ([1, 3, H, W] tensor in [-1, 1], resized PIL image)."""
    if use_upload:
        # Option 1: upload your own image (Colab only).
        # Projection works best on a centered, roughly frontal face crop, similar to FFHQ's aligned faces.
        print("Upload your image (preferably a cropped, centered face; any size):")
        uploaded = files.upload()
        image_path = list(uploaded.keys())[0]
    else:
        # Option 2: download an example image
        example_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/5/50/Vd-Orig.png/256px-Vd-Orig.png"
        response = requests.get(example_url)
        response.raise_for_status()
        image_path = "example_image.png"
        with open(image_path, 'wb') as f:
            f.write(response.content)

    # Load, convert to RGB, and resize
    image = Image.open(image_path).convert('RGB')
    image = image.resize((target_size, target_size), Image.Resampling.LANCZOS)

    # Convert to a [1, 3, H, W] tensor in [-1, 1]
    image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
    image_tensor = (image_tensor / 127.5) - 1
    image_tensor = image_tensor.unsqueeze(0).to(device)

    return image_tensor, image

# Load an image (set use_upload=False to download the example image instead)
target_tensor, target_pil = load_image(use_upload=True)
plt.figure(figsize=(5, 5))
plt.imshow(target_pil)
plt.title("Target Image")
plt.axis('off')
plt.show()
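`load_image` and `tensor_to_pil` rely on the same affine map between 8-bit pixels and the generator's [-1, 1] range. A NumPy sketch of both directions follows; the explicit rounding is an addition (the lab's `tensor_to_pil` casts without rounding) that makes every gray level survive the round trip.

```python
import numpy as np

def to_model_range(pixels):
    """uint8 pixels in [0, 255] -> floats in [-1, 1] (same formula as load_image)."""
    return pixels.astype(np.float32) / 127.5 - 1.0

def to_pixel_range(x):
    """Floats in [-1, 1] -> uint8 pixels in [0, 255], rounding to the nearest level."""
    return np.clip(np.rint((x + 1.0) * 127.5), 0, 255).astype(np.uint8)

levels = np.arange(256, dtype=np.uint8)
restored = to_pixel_range(to_model_range(levels))
print("round trip exact:", np.array_equal(restored, levels))
```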

### Projection function (inversion)

In [None]:
def project_image(G, target, num_steps=300, lr=0.01):
    """Project an image into StyleGAN's extended latent space W+.

    One w vector per synthesis layer is optimized independently, starting from the
    average latent w_avg, by minimizing pixel-wise MSE at the target's resolution.
    """
    G.eval().requires_grad_(False)  # only w is optimized; skip gradients for G's weights

    # Initialize every per-layer w at the average latent: shape [1, num_ws, w_dim]
    w_avg = G.mapping.w_avg.unsqueeze(0).unsqueeze(1).repeat([1, G.mapping.num_ws, 1])
    w = w_avg.clone().detach().requires_grad_(True)

    optimizer = torch.optim.Adam([w], lr=lr)
    loss_fn = torch.nn.MSELoss()
    losses = []

    print(f"Starting projection for {num_steps} steps...")
    for step in range(num_steps):
        # Forward pass with constant noise so the objective does not change between steps
        synth_img = G.synthesis(w, noise_mode='const')

        # Downsample the synthesized image to the target size
        # (cheaper than upsampling the 256x256 target to the generator's native resolution)
        if synth_img.shape[2:] != target.shape[2:]:
            synth_img = F.interpolate(synth_img, size=target.shape[2:], mode='area')

        loss = loss_fn(synth_img, target)

        # Backward pass and update of w
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        if step % 50 == 0:
            print(f"Step {step}: Loss = {loss.item():.4f}")

    print("✅ Projection complete!")
    return w.detach(), losses

# Project the image (fewer steps on CPU, where each step is much slower)
cpu_steps = 300 if device.type == 'cpu' else 500
print(f"Projecting image to latent space ({cpu_steps} steps)...")
if device.type == 'cpu':
    print("⏰ This may take 20-60 minutes on CPU...")

projected_w, losses = project_image(G, target_tensor, num_steps=cpu_steps)
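`project_image` is ordinary gradient-based optimization: start from the average latent and move w downhill on a reconstruction loss. The toy sketch below keeps only that idea. It replaces the synthesis network with a fixed random linear map `A` and Adam with plain gradient descent (both assumptions for illustration), so the recovered latent can be checked exactly.

```python
import numpy as np

def project_linear(A, target, w_init, num_steps=500, lr=0.1):
    """Gradient descent on mean((A @ w - target)**2), starting from w_init.

    A stands in for the synthesis network; w_init plays the role of w_avg.
    """
    w = w_init.copy()
    losses = []
    for _ in range(num_steps):
        residual = A @ w - target
        losses.append(float(np.mean(residual ** 2)))
        grad = (2.0 / target.size) * (A.T @ residual)  # gradient of the mean squared error
        w -= lr * grad
    return w, losses

rng = np.random.default_rng(0)
A = rng.normal(size=(64, 16))  # toy "generator": 16-D latent -> 64 "pixels"
w_true = rng.normal(size=16)
target = A @ w_true

w_hat, losses = project_linear(A, target, w_init=np.zeros(16))
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.2e}")
```

With a real, nonlinear generator the loss surface is not convex, so the result depends on initialization, learning rate, and the choice of loss; the official StyleGAN2-ADA projector, for instance, uses a VGG16-based perceptual distance rather than plain pixel MSE.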

### Compare projection result

In [None]:

# Generate projected image (constant noise for a deterministic result)
with torch.no_grad():
    projected_img = G.synthesis(projected_w, noise_mode='const')

# Show comparison
original = tensor_to_pil(target_tensor[0])
reconstructed = tensor_to_pil(projected_img[0])

show_images([original, reconstructed], 
           titles=['Original', 'Reconstructed'], 
           figsize=(10, 5))

# Show loss curve
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.title('Loss during projection')
plt.xlabel('Iteration')
plt.ylabel('MSE Loss')
plt.show()

print(f"Final loss: {losses[-1]:.4f}")
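The MSE printed above is measured on images scaled to [-1, 1], so its raw value is hard to judge. One common way to read it, offered here as a suggestion rather than part of the original lab, is peak signal-to-noise ratio with a peak-to-peak range of 2:

```python
import numpy as np

def psnr_from_mse(mse, data_range=2.0):
    """PSNR in dB; data_range=2.0 because the images live in [-1, 1]."""
    return 10.0 * np.log10(data_range ** 2 / mse)

for mse in (0.1, 0.04, 0.01):
    print(f"MSE {mse:.2f} -> PSNR {psnr_from_mse(mse):.1f} dB")
```

Each tenfold reduction in MSE adds 10 dB, which makes it easier to compare projections run with different step counts.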

## 5. Latent Space Exploration

### Generate (placeholder) semantic directions

In [None]:
# NOTE: the directions below are random placeholders, not learned semantic directions.
# The keys 'age', 'smile' and 'gender' are only labels: moving along them changes the
# image, but not necessarily the named attribute.
# Real directions come from methods such as InterFaceGAN (supervised, needs attribute
# labels) or SeFa (unsupervised; each direction must be interpreted by inspection).
def get_semantic_directions(scale=0.1, seed=0):
    """Return random placeholder directions shaped like G.mapping's output: [1, num_ws, w_dim]."""
    gen = torch.Generator().manual_seed(seed)  # fixed seed so runs are reproducible
    shape = (1, G.mapping.num_ws, G.w_dim)

    directions = {}
    for name in ['age', 'smile', 'gender']:
        directions[name] = (torch.randn(shape, generator=gen) * scale).to(device)

    return directions

print("Loading semantic directions...")
semantic_directions = get_semantic_directions()
print("✅ Semantic directions loaded:", list(semantic_directions.keys()))
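The directions above are random vectors scaled by 0.1, so they carry no semantic meaning. One unsupervised route to meaningful directions, in the spirit of GANSpace (a technique not used in the original lab, named here as an alternative to SeFa), is PCA over many w vectors sampled from the mapping network. The top principal axes often correspond to visible changes, but which attribute each one controls has to be found by rendering edits along it. A NumPy sketch, checked on synthetic data with a planted high-variance axis:

```python
import numpy as np

def pca_directions(w_samples, num_components=3):
    """Principal axes of sampled w vectors (GANSpace-style sketch).

    w_samples: [N, w_dim] array, e.g. G.mapping(z, None)[:, 0].cpu().numpy() under torch.no_grad()
    Returns unit-norm directions [num_components, w_dim] and the variance along each.
    """
    centered = w_samples - w_samples.mean(axis=0, keepdims=True)
    # Rows of vt are principal axes, sorted by decreasing singular value
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    variances = s ** 2 / (len(w_samples) - 1)
    return vt[:num_components], variances[:num_components]

# Synthetic check: plant one high-variance axis and confirm PCA recovers it
rng = np.random.default_rng(0)
w_dim, n = 16, 2000
planted = np.zeros(w_dim)
planted[3] = 1.0
w_samples = 0.1 * rng.normal(size=(n, w_dim)) + 3.0 * rng.normal(size=(n, 1)) * planted

dirs, var = pca_directions(w_samples)
print(f"|cos(PC1, planted axis)| = {abs(dirs[0] @ planted):.4f}")
```

The sign of each component is arbitrary. A component `dirs[k]` can be turned into a lab-style direction with `torch.from_numpy(dirs[k]).float().to(device).view(1, 1, -1)`, which broadcasts over all `num_ws` layers.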

### Function to apply edits

In [None]:
def apply_edit(w, direction, strength=1.0):
    """Apply a semantic edit to a latent vector"""
    return w + direction * strength

def show_edit_progression(w_base, direction, strengths=(-2, -1, 0, 1, 2), title="Edit"):
    """Render w_base edited along `direction` at each strength and show the results side by side."""
    images = []
    titles = []

    print(f"Generating {title} progression...")
    for strength in strengths:
        w_edited = apply_edit(w_base, direction, strength)

        with torch.no_grad():
            # Constant noise so that only the edit changes between frames
            img = G.synthesis(w_edited, noise_mode='const')
      
        images.append(tensor_to_pil(img[0]))
        titles.append(f'{title}: {strength}')
  
    show_images(images, titles, figsize=(20, 4))
    print("✅ Edit progression complete!")
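`apply_edit` shifts all `num_ws` per-layer vectors by the same amount. Because each of them feeds a different synthesis layer, an edit can also be restricted to a subset. In the original StyleGAN paper's style-mixing study, styles at coarse resolutions (4x4 to 8x8) governed pose, general hair style and face shape, middle ones (16x16 to 32x32) smaller facial features and hair style, and fine ones (64x64 and up) color scheme and microstructure. The NumPy sketch below shows only the mechanics; which indices correspond to which resolution depends on the generator, so `layers=range(0, 4)` is an assumption.

```python
import numpy as np

def apply_edit_to_layers(ws, direction, strength, layers):
    """Add strength * direction only to the selected per-layer w vectors.

    ws:        [1, num_ws, w_dim] latent (as returned by G.mapping or project_image)
    direction: [w_dim] vector shared by the edited layers
    layers:    indices of the synthesis layers to edit (a hypothetical choice per edit)
    """
    edited = ws.copy()
    edited[:, layers, :] += strength * direction
    return edited

num_ws, w_dim = 18, 8  # toy w_dim; the FFHQ model uses num_ws=18, w_dim=512
ws = np.zeros((1, num_ws, w_dim))
direction = np.ones(w_dim)
coarse_only = apply_edit_to_layers(ws, direction, 2.0, layers=list(range(0, 4)))
print(np.abs(coarse_only).sum(axis=2)[0])  # nonzero only for layers 0-3
```

The same slice assignment works on the lab's PyTorch tensors (use `.clone()` instead of `.copy()`).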

## 6. Directed Image Editing
### Apply different edits to your projected image

In [None]:
# Use the previously projected image
base_w = projected_w

# Apply "age" edit
print("=== Edit: Age ===")
show_edit_progression(base_w, semantic_directions['age'], 
                     strengths=[-1.5, -0.75, 0, 0.75, 1.5], 
                     title="Age")

# Apply "smile" edit
print("\n=== Edit: Smile ===")
show_edit_progression(base_w, semantic_directions['smile'], 
                     strengths=[-1.5, -0.75, 0, 0.75, 1.5], 
                     title="Smile")

# Apply "gender" edit
print("\n=== Edit: Gender ===")
show_edit_progression(base_w, semantic_directions['gender'], 
                     strengths=[-1.5, -0.75, 0, 0.75, 1.5], 
                     title="Gender")
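Because the lab's 'age', 'smile' and 'gender' vectors are random, the progressions above move along arbitrary axes. A supervised direction needs latents labeled for the attribute. InterFaceGAN fits a linear SVM on them and uses the normal of the separating hyperplane; the sketch below swaps in a simpler estimate, the difference between class means, and checks it on synthetic data where the true axis is known. How real samples get labeled (an attribute classifier or manual tagging) is left open.

```python
import numpy as np

def mean_difference_direction(w_pos, w_neg):
    """Unit vector from the mean of negative latents to the mean of positive ones.

    w_pos, w_neg: [N, w_dim] arrays of latents whose images have / lack the attribute.
    """
    d = w_pos.mean(axis=0) - w_neg.mean(axis=0)
    return d / np.linalg.norm(d)

# Synthetic check: the "attribute" is the sign of the projection on a hidden axis
rng = np.random.default_rng(0)
w_dim, n = 16, 4000
true_axis = rng.normal(size=w_dim)
true_axis /= np.linalg.norm(true_axis)
w = rng.normal(size=(n, w_dim))
has_attribute = w @ true_axis > 0

direction = mean_difference_direction(w[has_attribute], w[~has_attribute])
print(f"cosine with true axis: {direction @ true_axis:.3f}")
```

To use such a direction with `show_edit_progression`, convert it to a torch tensor on `device` with shape `[1, 1, w_dim]`.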

### Combination of multiple edits

In [None]:
def combine_edits(w_base, edits_dict):
    """Combine multiple edits"""
    w_edited = w_base.clone()
  
    for direction_name, strength in edits_dict.items():
        if direction_name in semantic_directions:
            w_edited = apply_edit(w_edited, semantic_directions[direction_name], strength)
  
    return w_edited

# Example: Combine multiple edits
edit_combinations = [
    {'age': 0, 'smile': 0, 'gender': 0},  # Original
    {'age': 1, 'smile': 0.5, 'gender': 0},  # Older and smiling
    {'age': -1, 'smile': -0.5, 'gender': 0.5},  # Younger, less smiling, gender +0.5
    {'age': 0.5, 'smile': 1, 'gender': -0.5},  # Custom combination
]

images = []
titles = []

print("Generating edit combinations...")
for i, edits in enumerate(edit_combinations):
    w_combined = combine_edits(base_w, edits)
  
    with torch.no_grad():
        img = G.synthesis(w_combined, noise_mode='const')
  
    images.append(tensor_to_pil(img[0]))
    title = "Original" if i == 0 else f"Combo {i}"
    titles.append(title)

show_images(images, titles, figsize=(16, 4))
print("✅ Edit combinations complete!")
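`combine_edits` simply adds directions, which assumes they do not interfere. When two learned directions are correlated, moving along one also shifts the other attribute. InterFaceGAN's conditional manipulation handles this by removing from the primary direction its component along the direction that should stay fixed. A NumPy sketch (flatten the lab's `[1, num_ws, w_dim]` tensors to vectors first):

```python
import numpy as np

def remove_component(primary, condition):
    """Project `primary` onto the subspace orthogonal to `condition`, then re-normalize."""
    c = condition / np.linalg.norm(condition)
    d = primary - (primary @ c) * c
    return d / np.linalg.norm(d)

primary = np.array([1.0, 1.0, 0.0])
condition = np.array([2.0, 0.0, 0.0])
print(remove_component(primary, condition))  # [0. 1. 0.]
```

Editing along the result leaves the projection onto `condition` unchanged, at the cost of a weaker effect on the primary attribute when the two directions were strongly correlated.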

## 7. Latent Space Interpolation
### Interpolation between two faces

In [None]:
# Generate two random faces
z1 = torch.randn([1, G.z_dim]).to(device)
z2 = torch.randn([1, G.z_dim]).to(device)

with torch.no_grad():
    w1 = G.mapping(z1, None)
    w2 = G.mapping(z2, None)

# Interpolation
def interpolate_w(w1, w2, num_steps=5):
    """Interpolate between two w vectors"""
    alphas = np.linspace(0, 1, num_steps)
    interpolated = []
  
    for alpha in alphas:
        w_interp = w1 * (1 - alpha) + w2 * alpha
        interpolated.append(w_interp)
  
    return interpolated, alphas

w_interpolated, alphas = interpolate_w(w1, w2, num_steps=7)

# Generate interpolated images
images = []
print("Generating interpolation sequence...")
for w_interp in w_interpolated:
    with torch.no_grad():
        img = G.synthesis(w_interp, noise_mode='const')
    images.append(tensor_to_pil(img[0]))

titles = [f'α={alpha:.2f}' for alpha in alphas]
show_images(images, titles, figsize=(21, 3))
print("✅ Interpolation complete!")
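The lab interpolates linearly in W, the usual choice there. Samples in Z are Gaussian and concentrate near a sphere of radius about √512 ≈ 22.6, so a straight line between two of them passes through atypically short latents; for that reason the StyleGAN paper's perceptual-path-length measurement uses spherical interpolation (slerp) in Z and linear interpolation in W. A NumPy sketch of both (helper names are illustrative):

```python
import numpy as np

def lerp(a, b, t):
    """Linear interpolation, as interpolate_w does in W space."""
    return (1 - t) * a + t * b

def slerp(a, b, t):
    """Spherical linear interpolation between vectors a and b."""
    a_unit, b_unit = a / np.linalg.norm(a), b / np.linalg.norm(b)
    omega = np.arccos(np.clip(a_unit @ b_unit, -1.0, 1.0))  # angle between a and b
    if np.isclose(np.sin(omega), 0.0):
        return lerp(a, b, t)  # (anti)parallel vectors: slerp is undefined, fall back
    return (np.sin((1 - t) * omega) * a + np.sin(t * omega) * b) / np.sin(omega)

rng = np.random.default_rng(0)
z1, z2 = rng.normal(size=512), rng.normal(size=512)
print(f"|z1|             = {np.linalg.norm(z1):.1f}")
print(f"|lerp midpoint|  = {np.linalg.norm(lerp(z1, z2, 0.5)):.1f}")
print(f"|slerp midpoint| = {np.linalg.norm(slerp(z1, z2, 0.5)):.1f}")
```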

## 8. Additional Experiments
### Free exploration of latent space

In [None]:
# Random walk through latent space, starting from a given w (rerun the cell for a new walk)
def random_walk_in_latent_space(start_w, num_steps=5, step_size=0.3):
    """Random walk in latent space"""
    current_w = start_w.clone()
    path = [current_w.clone()]
  
    for _ in range(num_steps):
        # Random step
        noise = torch.randn_like(current_w) * step_size
        current_w = current_w + noise
        path.append(current_w.clone())
  
    return path

# Start from your projected image
walk_path = random_walk_in_latent_space(projected_w, num_steps=6, step_size=0.2)

# Generate images from the walk
walk_images = []
print("Generating random walk sequence...")
for w in walk_path:
    with torch.no_grad():
        img = G.synthesis(w, noise_mode='const')
    walk_images.append(tensor_to_pil(img[0]))

show_images(walk_images, titles=[f'Step {i}' for i in range(len(walk_images))], figsize=(21, 3))
print("✅ Random walk complete!")

### Save results

In [None]:
# Save your favorite edited image
def save_result(image_tensor, filename="stylegan_result.png", size=256):
    """Save the first image of a batch; pass size=G.img_resolution to keep full resolution."""
    pil_img = tensor_to_pil(image_tensor[0], target_size=size)
    pil_img.save(filename)
    print(f"Image saved as: {filename}")
    return pil_img

# Example: Save a specific edit
favorite_edit = combine_edits(base_w, {'age': 0.5, 'smile': 1.0, 'gender': 0})
with torch.no_grad():
    favorite_img = G.synthesis(favorite_edit, noise_mode='const')

save_result(favorite_img, "my_edited_face.png")

## 9. Important Theoretical Concepts

### How does StyleGAN work?

**Key Architecture:**

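- **Mapping Network**: Z → W (8 fully connected layers)
- **Synthesis Network**: W → Image; each w sets per-layer "styles" (via AdaIN in StyleGAN1, via weight modulation/demodulation in StyleGAN2)
- **Discriminator**: Judges image realism; used only during training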

**Advantages of W Space:**

- More **disentangled** (independent features)
- Better **linear interpolation**
- More **semantic control**
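In StyleGAN2 the style enters each convolution through its weights rather than through AdaIN on the activations. A per-sample style vector (an affine transform of w, not modeled here) scales the kernel's input channels, and demodulation rescales each output channel's kernel to unit L2 norm. A NumPy sketch of that weight computation for a single sample:

```python
import numpy as np

def modulate_demodulate(weight, style, demodulate=True, eps=1e-8):
    """StyleGAN2-style weight modulation for one sample.

    weight: [out_ch, in_ch, kh, kw] convolution kernel.
    style:  [in_ch] per-input-channel scales (in StyleGAN2, an affine function of w).
    """
    w = weight * style[None, :, None, None]  # modulate: scale each input channel
    if demodulate:
        # Demodulate: normalize each output channel's kernel to unit L2 norm
        norm = np.sqrt(np.sum(w ** 2, axis=(1, 2, 3), keepdims=True) + eps)
        w = w / norm
    return w

rng = np.random.default_rng(0)
kernel = rng.normal(size=(8, 4, 3, 3))
style = rng.uniform(0.5, 2.0, size=4)
mod = modulate_demodulate(kernel, style)
print(np.linalg.norm(mod.reshape(8, -1), axis=1).round(6))  # unit norm per output channel
```

Under the assumption of independent, unit-variance input activations, demodulation restores unit output variance, which is what lets StyleGAN2 drop the explicit instance normalization behind StyleGAN1's droplet artifacts.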

### Differences between versions:
```python
# StyleGAN1 (2019): style-based generator with AdaIN; characteristic "droplet" artifacts
# StyleGAN2 (2020): weight demodulation replaces AdaIN, removing those artifacts
# StyleGAN3 (2021): alias-free design, equivariant to translation and rotation

print("This lab uses StyleGAN2 (the StyleGAN2-ADA PyTorch release and its FFHQ model)")
```

## 10. Challenges and Exercises
### Exercise 1: Experiment with your own directions

In [None]:
# Create your own custom semantic direction
# Hint: Combine existing directions with different weights

custom_direction = (semantic_directions['age'] * 0.5 + 
                   semantic_directions['smile'] * 0.3)

show_edit_progression(base_w, custom_direction, 
                     strengths=[-2, -1, 0, 1, 2], 
                     title="Custom Direction")

### Exercise 2: Find the "average face"

In [None]:
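# Generate multiple faces and average them in W space
# (with only 10 samples this is a rough estimate of G.mapping.w_avg, the average tracked during training)
num_faces = 10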
z_batch = torch.randn([num_faces, G.z_dim]).to(device)

with torch.no_grad():
    w_batch = G.mapping(z_batch, None)
  
# Calculate average face
w_average = w_batch.mean(dim=0, keepdim=True)

with torch.no_grad():
    avg_face = G.synthesis(w_average)

plt.figure(figsize=(5, 5))
plt.imshow(tensor_to_pil(avg_face[0]))
plt.title("Average Face")
plt.axis('off')
plt.show()
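A related idea is the truncation trick, which trades diversity for typicality by pulling each w toward the average latent: w' = w_avg + ψ(w − w_avg). In stylegan2-ada-pytorch the mapping network applies it itself when called as `G.mapping(z, None, truncation_psi=0.7)`, using `G.mapping.w_avg`, the same running average that `project_image` starts from. The arithmetic in NumPy, with random vectors standing in for real latents:

```python
import numpy as np

def truncate(w, w_avg, psi=0.7):
    """Truncation trick: move w toward w_avg; psi=1 keeps w, psi=0 returns w_avg."""
    return w_avg + psi * (w - w_avg)

rng = np.random.default_rng(0)
w_avg = rng.normal(size=512)
w = w_avg + rng.normal(size=512)
for psi in (1.0, 0.7, 0.0):
    dist = np.linalg.norm(truncate(w, w_avg, psi) - w_avg)
    print(f"psi={psi}: distance from w_avg = {dist:.2f}")
```

At ψ = 0 every sample collapses to the average face; values around 0.5 to 0.7 are a common compromise between image quality and variety.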

## Conclusions and Reflections

### What we have learned:

1. **StyleGAN** generates high-quality images using structured latent spaces
2. **Image projection** allows us to edit real photos
3. **W space** is more interpretable than Z for semantic edits
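4. **Semantic directions** enable controlled, predictable edits when they are learned (e.g. with InterFaceGAN or SeFa); the random placeholders used in this lab only demonstrate the editing mechanics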

### Real-world applications:

- **Digital art** and creativity
- **Advanced photo editing**
- **Synthetic dataset** generation
- **Research** in visual representation

### Ethical considerations:

- **Deepfakes** and media manipulation
- **Bias** in training data
- **Consent** in image usage