Single-Step Diffusion Detail Restoration

A proof of concept for restoring degraded face renders from 3D Gaussian splatting avatars. Built by fine-tuning pix2pix-turbo (Parmar et al.), which adapts Stable Diffusion Turbo for paired image-to-image translation in a single forward pass.

The Problem

Gaussian splatting is becoming a leading approach for photorealistic avatar generation. But current splat-based avatars produce renders with visible artifacts: blurred facial features, loss of fine detail around eyes and mouth, hair strand merging, and inconsistent skin texture. These artifacts vary by facial region depending on splat density and geometric complexity.

The Approach

This project fine-tunes SD-Turbo with LoRA adapters to learn the mapping from degraded splat renders to clean reference images. A face-parsing driven degradation pipeline generates training pairs by simulating region-specific splat artifacts on CelebA-HQ photos. The model restores detail in a single inference step at interactive speeds.

The degradation pipeline is synthetic, designed as a proxy for real splat render artifacts. In production, training would use actual degraded/clean render pairs from the target avatar pipeline (similar to the approach described in the ELITE paper for avatar enhancement).

Live Demo on Hugging Face Spaces


Results

Metric           Value
LPIPS (mean)     0.205
LPIPS (median)   0.198
Test images      600 (unseen during training)
Inference time   ~80ms per image (T4 GPU)


Degradation Pipeline

The synthetic degradation simulates artifacts from Gaussian splat rendering in 14 steps. Rather than applying uniform blur, degradation is driven by a Segformer face parser that segments each image into regions.

Region-Specific Degradation

Each facial region receives degradation calibrated to how splats typically fail in that area. Anisotropic variation masks make the degradation patchy rather than uniform, simulating how individual 3D Gaussians project differently depending on position, orientation, and scale.

Per-region degradation strengths:

  • Mouth (0.9): depth discontinuities between teeth, lips, and tongue
  • Eyes (0.7): fine detail loss from competing splats in small areas
  • Hair (0.6): strand clumping from splat merging
  • Nose (0.5): moderate geometric simplification
  • Skin (0.4): texture loss while retaining broad shape
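The per-region idea can be sketched in a few lines of numpy. This is a simplified stand-in, not the actual pipeline in data/prepare_data.py: the function names, the box blur (in place of the real splat-artifact simulation), and the mask format are all illustrative assumptions. Each region blends a blurred copy into the image, weighted by its strength and a smoothed noise mask so the damage is patchy rather than uniform:

```python
import numpy as np

# Per-region degradation strengths from the list above (hypothetical
# helper; the real pipeline lives in data/prepare_data.py).
REGION_STRENGTH = {"mouth": 0.9, "eyes": 0.7, "hair": 0.6, "nose": 0.5, "skin": 0.4}

def box_blur(img, k=5):
    """Cheap separable-free box blur as a stand-in for splat blurring."""
    pad = k // 2
    padded = np.pad(img, pad, mode="edge")
    out = np.zeros_like(img, dtype=np.float64)
    for dy in range(k):
        for dx in range(k):
            out += padded[dy:dy + img.shape[0], dx:dx + img.shape[1]]
    return out / (k * k)

def degrade(img, region_masks, rng):
    """Blend a blurred copy into `img` per region, modulated by a smoothed
    noise mask so the degradation is patchy rather than uniform."""
    blurred = box_blur(img)
    out = img.astype(np.float64).copy()
    for name, mask in region_masks.items():
        strength = REGION_STRENGTH.get(name, 0.0)
        # Variation mask in [0.5, 1.0] scales the strength per pixel.
        variation = 0.5 + 0.5 * box_blur(rng.random(img.shape), k=9)
        w = strength * variation * mask
        out = (1 - w) * out + w * blurred
    return out
```

Pixels outside every region mask pass through untouched, which is what lets the mouth degrade far more aggressively than surrounding skin.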

Architecture

The model adapts Stability AI's SD-Turbo for image-to-image restoration rather than text-to-image generation.

  1. The VAE encoder (frozen) compresses the degraded 512x512 input into a 64x64 latent
  2. The UNet processes the latent at timestep t=999, treating it as heavily corrupted and applying full restoration
  3. The VAE decoder reconstructs the output with skip connections bridging encoder to decoder at four resolution levels

Only LoRA adapters and skip convolutions are trainable. The pretrained SD-Turbo weights stay frozen, preserving the model's learned image priors while learning the specific degraded-to-clean mapping.
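The adapter mechanics can be illustrated with a minimal numpy sketch (not the project's actual torch code): a LoRA layer adds a trainable low-rank update (alpha/r) * B @ A to a frozen weight, with B initialized to zero so the layer initially reproduces the pretrained mapping exactly. The same near-zero initialization idea applies to the skip convolutions above.

```python
import numpy as np

class LoRALinear:
    """Frozen weight W plus a trainable low-rank update (alpha/r) * B @ A.

    B starts at zero, so at initialization the layer computes exactly the
    frozen pretrained mapping; training only learns the residual.
    """

    def __init__(self, W, rank=16, alpha=16, rng=None):
        if rng is None:
            rng = np.random.default_rng(0)
        out_dim, in_dim = W.shape
        self.W = W                                            # frozen
        self.A = rng.standard_normal((rank, in_dim)) * 0.01   # trainable
        self.B = np.zeros((out_dim, rank))                    # trainable, zero init
        self.scale = alpha / rank

    def __call__(self, x):
        return x @ (self.W + self.scale * self.B @ self.A).T
```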

Component                 Details
Base model                SD-Turbo
LoRA rank (UNet)          16
LoRA rank (VAE decoder)   8
Skip connections          4x 1x1 conv, initialized near-zero
Skip dropout              0.3 during training
Trainable parameters      ~18.5M of ~860M total

Loss Function

Loss          Weight   Purpose
L1            0.5      Pixel-level accuracy
LPIPS         1.0      Perceptual similarity (VGG features)
Gram matrix   0.07     Texture and surface detail matching

LPIPS leads the loss balance to prioritize perceptual quality over pixel accuracy. This pushes the model toward sharper, more visually correct outputs rather than blurry averages.
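A minimal sketch of how the three terms combine, with the weights from the table. The real LPIPS term comes from a pretrained VGG-based network (see model/losses.py), so it is stubbed here as a scalar; the Gram matrix term is the standard channel-correlation statistic used for texture matching:

```python
import numpy as np

def gram_matrix(feats):
    """Channel-wise Gram matrix of a (C, H, W) feature map, normalized
    by the number of spatial positions."""
    C, H, W = feats.shape
    flat = feats.reshape(C, H * W)
    return flat @ flat.T / (H * W)

def total_loss(pred, target, lpips_value):
    """Weighted sum from the table above. `lpips_value` would come from a
    pretrained VGG-based LPIPS network, stubbed here as a plain scalar."""
    l1 = np.abs(pred - target).mean()
    gram = np.abs(gram_matrix(pred) - gram_matrix(target)).mean()
    return 0.5 * l1 + 1.0 * lpips_value + 0.07 * gram
```

With identical inputs both the L1 and Gram terms vanish, so the loss reduces to the LPIPS contribution alone.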

Training Augmentation

Applied on-the-fly to the degraded input only (ground truth unchanged):

Augmentation              Probability   Details
Catastrophic blur blobs   10%           1-4 large blur patches simulating total geometry failure
Noise boost               15%           Additional noise (strength 5-10) on top of existing degradation
Brightness shift          10%           Random shift of +/- 10 values
Extra desaturation        10%           Additional 5-15% colour washout
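A sketch of the on-the-fly scheme, using the probabilities from the table. The individual effects are simplified stand-ins (e.g. the blur blobs just flatten a disc to its mean, and desaturation pulls toward the global mean); the key property is that only the degraded input is perturbed and the clean target is never touched:

```python
import numpy as np

def augment_input(degraded, rng):
    """Randomly perturb a copy of the degraded input; the ground-truth
    target is never modified. Probabilities follow the table above."""
    img = degraded.astype(np.float64)  # astype copies; source stays intact
    h, w = img.shape[:2]
    if rng.random() < 0.10:            # catastrophic blur blobs
        for _ in range(rng.integers(1, 5)):
            cy, cx = rng.integers(0, h), rng.integers(0, w)
            r = rng.integers(h // 8, h // 4)
            ys, xs = np.ogrid[:h, :w]
            blob = (ys - cy) ** 2 + (xs - cx) ** 2 < r ** 2
            img[blob] = img[blob].mean()   # total geometry failure
    if rng.random() < 0.15:            # noise boost, strength 5-10
        img += rng.normal(0, rng.uniform(5, 10), img.shape)
    if rng.random() < 0.10:            # brightness shift of +/- 10
        img += rng.uniform(-10, 10)
    if rng.random() < 0.10:            # extra 5-15% washout toward the mean
        img += rng.uniform(0.05, 0.15) * (img.mean() - img)
    return np.clip(img, 0, 255)
```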

Training Configuration

Parameter       Value
Dataset         12,000 CelebA-HQ images (512x512)
Test set        600 images (separate, unseen)
Optimizer       AdamW
Learning rate   5e-5 with 500-step warmup, linear decay to 0
Weight decay    1e-4
Batch size      2
Precision       FP16 mixed precision
Hardware        NVIDIA L40S (48GB)
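The learning-rate schedule from the table is simple enough to state exactly: linear warmup over the first 500 steps, then linear decay to zero at the end of training (20,000 steps, matching --max_train_steps in the usage section below):

```python
def lr_at(step, base_lr=5e-5, warmup=500, max_steps=20000):
    """Linear warmup for `warmup` steps, then linear decay to zero."""
    if step < warmup:
        return base_lr * step / warmup
    return base_lr * max(0.0, (max_steps - step) / (max_steps - warmup))
```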

Scaling Experiments

Multiple training runs were tracked in Weights & Biases, systematically testing the effect of dataset size, LoRA rank, learning rate, and regularization.

Run                              Dataset   LoRA rank   Best val LPIPS
Baseline (original submission)       500   8/4         0.191
Scaled data                        2,500   8/4         0.234
Higher rank                        2,500   16/8        0.234
Lower weight decay                 2,500   16/8        0.229
5k images                          5,000   16/8        0.204
12k images                        12,000   16/8        0.205

The original 500-image submission achieved strong per-image LPIPS through memorization but did not generalize. Scaling the dataset from 500 to 5,000 images improved generalization and reduced the LPIPS floor from 0.234 to 0.204, confirming that data diversity was the primary bottleneck.


Limitations

This is a proof of concept with known constraints.

Synthetic degradation only. The face-parsing pipeline approximates Gaussian splat artifacts based on observed patterns, but it is not calibrated to any specific rendering engine. Production use would require training on real degraded/clean render pairs (as described in the ELITE paper).

CelebA-HQ as proxy. The dataset is a stand-in for actual avatar renders. It skews toward certain demographics and lighting conditions. A production system would need training data that matches the target pipeline's output distribution.

Single frame. No temporal consistency across frames. Video enhancement would require additional constraints (optical flow, temporal losses) to prevent flickering.

Resolution. Fixed at 512x512. Higher resolutions would need either a tiled approach or architectural changes.


Next Steps

If this were developed further into a production system:

  • Real render pairs: Train on actual Gaussian splat output paired with high-quality reference captures
  • Temporal consistency: Add optical flow warping and temporal losses for video sequences
  • TensorRT optimization: Export to ONNX and optimize with TensorRT for real-time inference (~30ms per frame)
  • Discriminator: Reintroduce adversarial loss from pix2pix-turbo to push sharpness with a larger dataset

Project Structure

.
├── app.py                  # Gradio demo application
├── model/
│   ├── architecture.py     # SplatEnhancer model class
│   └── losses.py           # L1 + LPIPS + Gram matrix loss
├── train.py                # Training script
├── inference.py            # Evaluation and comparison generation
├── data/
│   └── prepare_data.py     # Face-parsing degradation pipeline
├── output/
│   └── checkpoints/        # Trained model weights
└── requirements.txt

Usage

Inference

python inference.py \
    --checkpoint output/checkpoints/model_17400.pkl \
    --dataset_folder data/ \
    --output_dir results/ \
    --mixed_precision

Training

python train.py \
    --dataset_folder data/ \
    --mixed_precision \
    --max_train_steps 20000 \
    --batch_size 2 \
    --lambda_l1 0.5 \
    --lambda_lpips 1.0 \
    --lambda_gram 0.07 \
    --learning_rate 5e-5 \
    --lora_rank_unet 16 \
    --lora_rank_vae 8 \
    --adam_weight_decay 1e-4

Gradio Demo

python app.py

Acknowledgements

Architecture adapted from pix2pix-turbo. Built on Stable Diffusion Turbo by Stability AI. Face parsing via Segformer. Dataset from CelebA-HQ.
