# Combo 04: Encoder-Based GAN Inversion

**Configuration:**
- **Resolution:** 1024×1024 generator output (loss and saved images use 256×256 after `face_pool`)
- **Latent Space:** W+
- **Loss:** LPIPS (Perceptual)
- **Initialization:** e4e Encoder
- **Optimization:** Adam (300 steps, LR=0.01)

This notebook demonstrates encoder-based initialization for GAN inversion using the [Encoder for Editing (e4e)](https://github.com/omertov/encoder4editing) framework.
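
For orientation, W+ assigns one 512-dimensional style vector to each generator layer. The sketch below computes the layer count, assuming the standard StyleGAN2 layout of `2 * log2(resolution) - 2` style inputs. `n_styles` is a hypothetical helper name used only for illustration, not a function from this project.

```python
import math

def n_styles(resolution: int) -> int:
    """Number of W+ style vectors for a StyleGAN2 generator of the given output size.

    Assumes the standard layout: 2 * log2(resolution) - 2 style inputs.
    """
    log_res = int(math.log2(resolution))
    if resolution < 4 or 2 ** log_res != resolution:
        raise ValueError("resolution must be a power of two >= 4")
    return 2 * log_res - 2

# Shape of the latent optimized in this notebook: [batch, n_styles, 512]
latent_shape = (1, n_styles(1024), 512)
print(latent_shape)  # (1, 18, 512)
```

Under the same assumption, a 128×128 generator would have 12 style vectors.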

---

## How It Works

1. **Setup:** Clone the GAN_Inversion project and the e4e repository (into `external/encoder4editing`)
2. **Encoder Init:** Use the e4e encoder to predict an initial latent code from the target image
3. **Optimization:** Refine the latent code with Adam and the LPIPS loss
4. **Results:** Save to Google Drive for comparison with Combos 1-3
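
To see why step 2 matters, here is a toy sketch in pure Python of a hand-written Adam update on a quadratic loss, run once from a distant start and once from a start near the optimum. The target vector and both starting points are made-up stand-ins (for the "true" latent, a mean/zero initialization, and an encoder prediction). This is not the notebook's actual optimization, only an illustration of the step budget.

```python
import math

def adam_minimize(w, target, steps=300, lr=0.01, betas=(0.9, 0.999), eps=1e-8):
    """Minimize sum((w - target)^2) with a hand-written Adam update; returns (w, loss)."""
    w = list(w)
    m = [0.0] * len(w)  # first-moment estimate
    v = [0.0] * len(w)  # second-moment estimate
    b1, b2 = betas
    for t in range(1, steps + 1):
        grad = [2.0 * (wi - ti) for wi, ti in zip(w, target)]
        for i, g in enumerate(grad):
            m[i] = b1 * m[i] + (1 - b1) * g
            v[i] = b2 * v[i] + (1 - b2) * g * g
            m_hat = m[i] / (1 - b1 ** t)  # bias correction
            v_hat = v[i] / (1 - b2 ** t)
            w[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    loss = sum((wi - ti) ** 2 for wi, ti in zip(w, target))
    return w, loss

target = [4.0, -5.0, 3.5]      # stand-in for the "true" latent
cold_start = [0.0, 0.0, 0.0]   # stand-in for mean/zero initialization
warm_start = [4.3, -5.2, 3.6]  # stand-in for an encoder prediction near the target

_, cold_loss = adam_minimize(cold_start, target)
_, warm_loss = adam_minimize(warm_start, target)
print(f"cold start: {cold_loss:.4f}  warm start: {warm_loss:.4f}")
```

With a learning rate of 0.01, Adam moves each coordinate by roughly 0.01 per step, so 300 steps cannot cover the distance from the cold start, while the warm start only needs a short refinement.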

---

## Instructions

1. **Run all cells** (Runtime → Run all)
2. **Wait** for setup (cloning project, installing dependencies, downloading checkpoint)
3. **Results** will be saved to `MyDrive/GAN_Inversion_Results/combo_04/`

The notebook will automatically:
- Clone the project from `https://github.com/assafzimand/GAN_Inversion`
- Clone the e4e repository into `external/encoder4editing`
- Download the e4e checkpoint (~350 MB)
- Load sample images from `data/samples/`
- Save all results to your Google Drive



## 1. Setup: Mount Google Drive


In [None]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# Create output directory
output_base = '/content/drive/MyDrive/GAN_Inversion_Results/combo_04'
os.makedirs(output_base, exist_ok=True)

print("Google Drive mounted successfully!")
print(f"Results will be saved to: {output_base}")


## 2. Clone Repositories

Clones the GAN_Inversion project and the e4e encoder repository.


In [None]:
import sys
from pathlib import Path
import os

# Clone the main project (absolute path keeps the check valid if this cell is re-run after the chdir)
project_dir = Path('/content/GAN_Inversion')
if not project_dir.exists():
    print("Cloning GAN_Inversion project...")
    !git clone https://github.com/assafzimand/GAN_Inversion.git {project_dir}
    print("✓ Main project cloned!")
else:
    print("✓ Main project already exists")

# Change to project directory
os.chdir(project_dir)
print(f"✓ Working directory: {os.getcwd()}")

# Clone e4e repository into external/ (if not present)
os.makedirs('external', exist_ok=True)
if not Path('external/encoder4editing').exists():
    print("\nCloning e4e repository into external/...")
    !git clone https://github.com/omertov/encoder4editing.git external/encoder4editing
    print("✓ e4e repository cloned!")
else:
    print("✓ e4e repository already exists")

# Add e4e to path
sys.path.insert(0, '/content/GAN_Inversion/external/encoder4editing')
print("✓ e4e added to Python path")


## 3. Install Dependencies


In [None]:
# Install only the required dependencies (Colab already has torch/torchvision)
print("Installing dependencies...")
%pip install -q lpips gdown Pillow numpy matplotlib PyYAML ninja

# Verify torch is available
import torch
print(f"\n✓ Using PyTorch {torch.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
print("✓ All dependencies ready!")

# Note: e4e will compile CUDA extensions on first import (may take 1-2 minutes)
print("\nNote: e4e CUDA operations will compile on first use (this is normal)")


## 4. Download e4e Checkpoint


In [None]:
import gdown

# Create checkpoints directory in combo_04/
os.makedirs('combo_04/checkpoints', exist_ok=True)

# Download e4e FFHQ encoder checkpoint
checkpoint_path = 'combo_04/checkpoints/e4e_ffhq_encode.pt'

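if not os.path.exists(checkpoint_path):
    print("Downloading e4e checkpoint (may take a few minutes)...")
    url = 'https://drive.google.com/uc?id=1cUv_reLE6k3604or78EranS7XzuVMWeO'
    gdown.download(url, checkpoint_path, quiet=False)
    # Confirm the file actually arrived before reporting success
    if not os.path.exists(checkpoint_path):
        raise RuntimeError("Checkpoint download failed; check the Google Drive link or quota")
    print("✓ Checkpoint downloaded successfully!")
else:
    print("✓ Checkpoint already exists")

print(f"Checkpoint ready at: {checkpoint_path}")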


## 5. Load Target Images

Loading sample FFHQ images from `data/samples/` (included in the repo).
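
If your own samples use other extensions or upper-case suffixes, a stricter listing helper may be useful. The sketch below is one way to do it with the standard library. `list_images` and `IMAGE_EXTENSIONS` are hypothetical names, and the accepted extensions are an assumption you can adjust.

```python
from pathlib import Path
import tempfile

IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}  # assumed set; extend as needed

def list_images(folder):
    """Return image paths in `folder`, matched case-insensitively and sorted by name."""
    return sorted(
        p for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )

# Usage on a throwaway directory (file names are illustrative only)
with tempfile.TemporaryDirectory() as tmp:
    for name in ["b.JPG", "a.png", "c.jpeg", "notes.txt"]:
        (Path(tmp) / name).touch()
    found = [p.name for p in list_images(tmp)]

print(found)  # ['a.png', 'b.JPG', 'c.jpeg']
```

Sorting the paths keeps the processing order stable between runs, which makes results easier to compare.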


In [None]:
from PIL import Image
from pathlib import Path

# Load images from data/samples/ (sorted so the processing order is reproducible)
samples_dir = Path('data/samples')
image_paths = sorted(list(samples_dir.glob('*.png')) + list(samples_dir.glob('*.jpg')))

print(f"Found {len(image_paths)} images in data/samples/\n")

# Load and resize images
images = []
image_names = []

for img_path in image_paths:
    img = Image.open(img_path).convert('RGB')
    # Resize to 1024x1024 if needed
    if img.size != (1024, 1024):
        img = img.resize((1024, 1024), Image.BICUBIC)
    images.append(img)
    image_names.append(img_path.name)

print(f"Loaded {len(images)} image(s):")
for name in image_names:
    print(f"  - {name}")


## 5.5. Clear CUDA Extension Cache

Deleting any previously compiled extensions forces a fresh build, which avoids stale-cache errors when e4e's CUDA operations compile.


In [None]:
import shutil
from pathlib import Path

# Clear torch extensions cache
cache_dir = Path.home() / '.cache' / 'torch_extensions'
if cache_dir.exists():
    print("Clearing torch extensions cache...")
    shutil.rmtree(cache_dir)
    print("✓ Cache cleared")
else:
    print("✓ No cache to clear")

print("\ne4e CUDA operations will compile on first import (takes 1-2 minutes)...")


## 6. Load e4e Model

Uses e4e's `setup_model()` helper (from `utils/model_utils.py`) to load the complete pSp model (encoder + decoder).
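
The inversion loop in section 7 compares images at 256×256 by passing the 1024×1024 decoder output through `model.face_pool`. Assuming `face_pool` is an adaptive average pool to 256×256 (an assumption about the e4e code, worth checking in `models/psp.py`), each output pixel is the mean of a non-overlapping 4×4 block. The pure-Python sketch below shows that operation on a tiny grid. `average_pool` is a hypothetical helper, not part of e4e.

```python
def average_pool(image, factor):
    """Downsample a 2D grid (list of rows) by averaging non-overlapping factor x factor blocks."""
    height, width = len(image), len(image[0])
    if height % factor or width % factor:
        raise ValueError("image size must be divisible by factor")
    pooled = []
    for top in range(0, height, factor):
        row = []
        for left in range(0, width, factor):
            block = [image[top + dy][left + dx]
                     for dy in range(factor) for dx in range(factor)]
            row.append(sum(block) / len(block))
        pooled.append(row)
    return pooled

# 1024 -> 256 corresponds to factor 4; a tiny 4x8 grid keeps the example readable
grid = [[float(x + 8 * y) for x in range(8)] for y in range(4)]
small = average_pool(grid, 4)
print(small)  # [[13.5, 17.5]]
```

Because the loss only sees the pooled image, detail finer than a 4×4 block in the 1024×1024 output is not directly constrained by the optimization.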


In [None]:
import torch
from utils.model_utils import setup_model

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

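# Load e4e model
print("\nLoading e4e model...")
model, opts = setup_model(checkpoint_path, device)
model.eval()
# Freeze the weights: only the latent is optimized, so parameter gradients are not needed
model.requires_grad_(False)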

print("Model loaded successfully!")
print(f"  - Encoder type: {opts.encoder_type}")
print(f"  - StyleGAN size: {opts.stylegan_size}")
print(f"  - Model has encoder: {hasattr(model, 'encoder')}")
print(f"  - Model has decoder: {hasattr(model, 'decoder')}")
print(f"  - Model has latent_avg: {hasattr(model, 'latent_avg')}")


## 7. Run Inversion: Encoder Init + Optimization

This is our **custom optimization loop** using:
- **Encoder initialization** (from e4e)
- **LPIPS loss** (perceptual)
- **Adam optimizer** (300 steps, LR=0.01)
- **e4e decoder** for generation
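
The loop stores a few snapshots per image. The sketch below mirrors its bookkeeping to show which step indices end up as keys in `intermediates`. `snapshot_steps` is a hypothetical helper written for illustration only.

```python
def snapshot_steps(steps, save_interval):
    """Step indices stored in `intermediates`, mirroring the loop's bookkeeping."""
    saved = {0}  # encoder output, stored before the first optimizer step
    saved.update(s for s in range(1, steps) if s % save_interval == 0)
    saved.add(steps - 1)  # final result, keyed by the last 0-based step index
    return sorted(saved)

print(snapshot_steps(300, 100))  # [0, 100, 200, 299]
```

Entry 0 is the raw encoder reconstruction, while the later entries are rendered after the optimizer update of that step.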


In [None]:
import torch.nn.functional as F
import lpips
import numpy as np
from torch.optim import Adam
from torchvision import transforms
import time

# Configuration (Combo 4: encoder-based initialization)
CONFIG = {
    'steps': 300,
    'learning_rate': 0.01,
    'betas': (0.9, 0.999),
    'log_interval': 50,
    'save_interval': 100,  # Snapshots at steps 100 and 200 (step 0 is the encoder output; the final step is always saved)
}

# Initialize LPIPS loss
lpips_loss_fn = lpips.LPIPS(net='alex').to(device)
lpips_loss_fn.eval()

# Image preprocessing (PIL to tensor)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Results storage
results = []

print("="*80)
print(f"Starting Combo 4: Processing {len(images)} image(s)")
print("="*80)

# Process all images
for idx, (img, img_name) in enumerate(zip(images, image_names)):
    print(f"\n[{idx+1}/{len(images)}] Processing: {img_name}")
    print("-"*80)
    
    # Convert image to tensor [1, 3, 1024, 1024]
    target_1024 = transform(img).unsqueeze(0).to(device)
    
    # Resize to 256×256 (what e4e encoder expects)
    target_256 = F.interpolate(target_1024, size=(256, 256), mode='bilinear', align_corners=False)
    print(f"  Resized input: {target_1024.shape} → {target_256.shape}")
    
    # Step 1: Encoder initialization (using e4e's full API with resize=True)
    print("Step 1: Encoder initialization...")
    
    with torch.no_grad():
        # Use their full forward API (handles encoder + latent_avg + decoder + face_pool)
        initial_image, initial_latent = model(target_256, resize=True, randomize_noise=False, return_latents=True)
    
    print(f"  Output latent shape: {initial_latent.shape}")
    print(f"  Output image shape: {initial_image.shape}")
    print(f"  Output image stats: min={initial_image.min():.3f}, max={initial_image.max():.3f}, mean={initial_image.mean():.3f}")
    
    # Step 2: Optimization
    print(f"Step 2: Optimizing with Adam (LR={CONFIG['learning_rate']}, Steps={CONFIG['steps']})...")
    
    # Make latent optimizable
    latent = initial_latent.detach().clone().requires_grad_(True)
    
    # Create optimizer
    optimizer = Adam([latent], lr=CONFIG['learning_rate'], betas=CONFIG['betas'])
    
    # Track history
    loss_history = []
    intermediates = {0: initial_image.cpu()}  # Store encoder output as Step 0
    print(f"  Stored Step 0 = encoder output (not optimization step)")
    
    start_time = time.time()
    
    # Optimization loop
    for step in range(CONFIG['steps']):
        optimizer.zero_grad()
        
        # Generate image from latent (decoder outputs 1024×1024, face_pool → 256×256)
        generated, _ = model.decoder([latent], input_is_latent=True, randomize_noise=False, return_latents=True)
        generated = model.face_pool(generated)  # Apply face_pool to get 256×256
        
        # Compute LPIPS loss (comparing 256×256 images)
        loss = lpips_loss_fn(generated, target_256).mean()
        
        # Backward
        loss.backward()
        optimizer.step()
        
        # Track
        loss_history.append(loss.item())
        
        # Save intermediates (skip step 0 since we already have encoder output)
        if step > 0 and step % CONFIG['save_interval'] == 0:
            with torch.no_grad():
                intermediate, _ = model.decoder([latent], input_is_latent=True, randomize_noise=False, return_latents=True)
                intermediate = model.face_pool(intermediate)
                intermediates[step] = intermediate.cpu()
        
        # Log
        if (step + 1) % CONFIG['log_interval'] == 0 or step == 0:
            elapsed = time.time() - start_time
            print(f"  Step [{step+1}/{CONFIG['steps']}] Loss: {loss.item():.6f} Time: {elapsed:.2f}s")
    
    # Final generation
    with torch.no_grad():
        final_output, _ = model.decoder([latent], input_is_latent=True, randomize_noise=False, return_latents=True)
        final_output = model.face_pool(final_output)
        intermediates[CONFIG['steps']-1] = final_output.cpu()
    
    total_time = time.time() - start_time
    
    print(f"\nCompleted in {total_time:.2f}s")
    print(f"Final loss: {loss_history[-1]:.6f}")
    print(f"Initial loss: {loss_history[0]:.6f}")
    print(f"Improvement: {loss_history[0] - loss_history[-1]:.6f}")
    
    # Store results
    results.append({
        'name': img_name,
        'target': target_256.cpu(),  # Store 256×256 target for visualization
        'initial_latent': initial_latent.detach().cpu(),
        'final_latent': latent.detach().cpu(),
        'final_output': final_output.cpu(),
        'loss_history': loss_history,
        'intermediates': intermediates,
        'time': total_time
    })

print("\n" + "="*80)
print(f"✓ Combo 4 complete! Processed {len(results)} image(s)")
print("="*80)


## 8. Visualize Results

Creating evolution panels showing the progression from encoder init to final optimization.
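
The tensors in this notebook live in [-1, 1] because of `Normalize(0.5, 0.5)`, so they must be mapped back to [0, 255] before plotting. The scalar sketch below shows the arithmetic in both directions. `to_uint8` and `to_normalized` are hypothetical names, and `tensor_to_image` in the next cell applies the same formula to whole arrays.

```python
def to_uint8(value):
    """Map a value in [-1, 1] to an integer in [0, 255], clipping out-of-range inputs."""
    scaled = (value * 0.5 + 0.5) * 255
    return int(min(max(scaled, 0.0), 255.0))  # int() truncates, like astype(np.uint8)

def to_normalized(pixel):
    """Inverse direction: map a pixel in [0, 255] to [-1, 1]."""
    return (pixel / 255 - 0.5) / 0.5

print(to_uint8(-1.0), to_uint8(0.0), to_uint8(1.0), to_uint8(1.7))  # 0 127 255 255
print(to_normalized(255), to_normalized(0))  # 1.0 -1.0
```

Clipping matters because generator outputs can slightly exceed [-1, 1].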


In [None]:
import matplotlib.pyplot as plt
from IPython.display import display

def tensor_to_image(tensor):
    """Convert tensor [-1, 1] to numpy image [0, 255]"""
    # Handle any tensor shape, ensure it's [C, H, W]
    img = tensor.cpu().squeeze()  # Remove all size-1 dimensions
    
    # If still has batch dimension, take first item
    while img.dim() > 3:
        img = img[0]
    
    # Should now be [C, H, W], convert to [H, W, C]
    if img.dim() == 3:
        img = img.permute(1, 2, 0)
    
    img = img.numpy()
    img = (img * 0.5 + 0.5) * 255  # Denormalize from [-1, 1] to [0, 255]
    img = np.clip(img, 0, 255).astype(np.uint8)
    return img

# Visualize each result
for result in results:
    img_name = result['name']
    target = result['target']
    intermediates = result['intermediates']
    loss_history = result['loss_history']
    
    # Get intermediate steps (0, 100, 200, final)
    steps_to_show = sorted(intermediates.keys())
    
    # Create evolution panel
    num_images = len(steps_to_show) + 1  # +1 for original
    fig, axes = plt.subplots(1, num_images, figsize=(4*num_images, 4))
    
    # Original
    axes[0].imshow(tensor_to_image(target))
    axes[0].set_title('Original', fontsize=12, fontweight='bold')
    axes[0].axis('off')
    
    # Intermediates
    for idx, step in enumerate(steps_to_show):
        axes[idx+1].imshow(tensor_to_image(intermediates[step]))
        axes[idx+1].set_title(f'Step {step}', fontsize=12)
        axes[idx+1].axis('off')
    
    # Add metrics to title
    final_loss = loss_history[-1]
    initial_loss = loss_history[0]
    improvement = initial_loss - final_loss
    
    fig.suptitle(
        f'{img_name} | LPIPS: {final_loss:.4f} | Improvement: {improvement:.4f}',
        fontsize=14,
        fontweight='bold'
    )
    
    plt.tight_layout()
    plt.show()  # Render once; calling display(fig) as well would draw the figure twice
    plt.close(fig)  # Clean up
    
    # Plot loss curve
    fig2 = plt.figure(figsize=(10, 4))
    plt.plot(loss_history, linewidth=2)
    plt.xlabel('Step', fontsize=12)
    plt.ylabel('LPIPS Loss', fontsize=12)
    plt.title(f'Loss Curve - {img_name}', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()  # Render once; calling display(fig2) as well would draw the figure twice
    plt.close(fig2)  # Clean up
    
    print(f"✓ Visualized: {img_name}\n")


## 9. Save Results to Google Drive


In [None]:
import json
from datetime import datetime

# Create timestamped output directory
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
output_dir = os.path.join(output_base, f'run_{timestamp}')
os.makedirs(output_dir, exist_ok=True)

print(f"Saving results to: {output_dir}\n")

for result in results:
    img_name = Path(result['name']).stem  # File name without its extension
    
    # Create subdirectory for this image
    img_dir = os.path.join(output_dir, img_name)
    os.makedirs(img_dir, exist_ok=True)
    
    # Save final reconstruction
    final_img = tensor_to_image(result['final_output'])
    final_img_pil = Image.fromarray(final_img)
    final_img_pil.save(os.path.join(img_dir, f'{img_name}_reconstruction.png'))
    
    # Save evolution panel
    steps_to_show = sorted(result['intermediates'].keys())
    num_images = len(steps_to_show) + 1
    fig, axes = plt.subplots(1, num_images, figsize=(4*num_images, 4))
    
    axes[0].imshow(tensor_to_image(result['target']))
    axes[0].set_title('Original', fontsize=12, fontweight='bold')
    axes[0].axis('off')
    
    for idx, step in enumerate(steps_to_show):
        axes[idx+1].imshow(tensor_to_image(result['intermediates'][step]))
        axes[idx+1].set_title(f'Step {step}', fontsize=12)
        axes[idx+1].axis('off')
    
    final_loss = result['loss_history'][-1]
    initial_loss = result['loss_history'][0]
    improvement = initial_loss - final_loss
    
    fig.suptitle(
        f'{img_name} | LPIPS: {final_loss:.4f} | Improvement: {improvement:.4f}',
        fontsize=14,
        fontweight='bold'
    )
    
    plt.tight_layout()
    plt.savefig(os.path.join(img_dir, f'{img_name}_evolution.png'), dpi=150, bbox_inches='tight')
    plt.close()
    
    # Save loss curve
    plt.figure(figsize=(10, 4))
    plt.plot(result['loss_history'], linewidth=2)
    plt.xlabel('Step', fontsize=12)
    plt.ylabel('LPIPS Loss', fontsize=12)
    plt.title(f'Loss Curve - {img_name}', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(img_dir, f'{img_name}_loss_curve.png'), dpi=150, bbox_inches='tight')
    plt.close()
    
    # Save metrics
    metrics = {
        'image_name': result['name'],
        'initial_loss': float(initial_loss),
        'final_loss': float(final_loss),
        'improvement': float(improvement),
        'optimization_time': result['time'],
        'num_steps': len(result['loss_history'])
    }
    
    with open(os.path.join(img_dir, f'{img_name}_metrics.json'), 'w') as f:
        json.dump(metrics, f, indent=2)
    
    # Save loss history
    with open(os.path.join(img_dir, f'{img_name}_loss_history.json'), 'w') as f:
        json.dump(result['loss_history'], f)
    
    print(f"✓ Saved: {img_name}")

# Save summary
summary = {
    'timestamp': timestamp,
    'num_images': len(results),
    'config': CONFIG,
    'results': [
        {
            'name': r['name'],
            'initial_loss': float(r['loss_history'][0]),
            'final_loss': float(r['loss_history'][-1]),
            'time': r['time']
        }
        for r in results
    ]
}

with open(os.path.join(output_dir, 'summary.json'), 'w') as f:
    json.dump(summary, f, indent=2)

print(f"\n{'='*80}")
print(f"All results saved to:")
print(f"  {output_dir}")
print(f"{'='*80}")


## ✅ Complete!

### Summary

You've successfully run **Combo 04: Encoder-Based GAN Inversion** with:
- ✅ e4e encoder for initialization
- ✅ LPIPS perceptual loss
- ✅ Adam optimization (300 steps)
- ✅ 1024×1024 generator output (compared and saved at 256×256)

### Results Location

All results have been saved to your Google Drive:
- **Path:** `MyDrive/GAN_Inversion_Results/combo_04/run_YYYYMMDD_HHMMSS/`
- **Contains:** Reconstructions, evolution panels, loss curves, metrics

### Next Steps

1. **Download results** from your Google Drive
2. **Compare with Combos 1-3** (128×128, random/mean init)
3. **Analyze metrics** to see the benefit of encoder initialization
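
For step 3, a run's `summary.json` can be read back with the standard library. The sketch below assumes the file layout written in section 9 (`results` entries with `name`, `initial_loss`, `final_loss`, `time`). `summarize_run` is a hypothetical helper, and the numbers in the usage example are placeholders, not measured results.

```python
import json
import tempfile
from pathlib import Path

def summarize_run(summary_path):
    """Return per-image improvement and the mean improvement for one run's summary.json."""
    summary = json.loads(Path(summary_path).read_text())
    improvements = {
        r["name"]: r["initial_loss"] - r["final_loss"] for r in summary["results"]
    }
    mean_improvement = sum(improvements.values()) / len(improvements)
    return improvements, mean_improvement

# Usage with made-up numbers (placeholders, not measured results)
example = {
    "timestamp": "20240101_000000",
    "num_images": 2,
    "results": [
        {"name": "a.png", "initial_loss": 0.5, "final_loss": 0.25, "time": 1.0},
        {"name": "b.png", "initial_loss": 0.75, "final_loss": 0.25, "time": 1.0},
    ],
}
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "summary.json"
    path.write_text(json.dumps(example))
    improvements, mean_improvement = summarize_run(path)

print(improvements)      # {'a.png': 0.25, 'b.png': 0.5}
print(mean_improvement)  # 0.375
```

Pointing `summarize_run` at the `summary.json` inside a `run_YYYYMMDD_HHMMSS` folder gives the same breakdown for a real run.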

### Key Observations

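- **Encoder init** should start from a lower LPIPS than random or mean-latent initialization; the `initial_loss` in each `*_metrics.json` records this starting point for comparison with Combos 1-3
- **Optimization** further refines the reconstruction; the `improvement` metric (initial minus final LPIPS) quantifies by how much
- **LPIPS loss** compares deep-feature activations (AlexNet here) rather than raw pixels, so it favors perceptual similarity over pixel-exact matching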

---

**Thank you for using Combo 04!** 🎉
