# Exploiting the Signal-Leak Bias in Diffusion Models

[![arXiv](https://img.shields.io/badge/arXiv-2309.15842-red)](https://arxiv.org/abs/2309.15842)
[![Project Page](https://img.shields.io/badge/Project%20Page-IVRL-blue)](https://ivrl.github.io/signal-leak-bias/)
[![Proceedings](https://img.shields.io/badge/WACV%20Proceedings-CVF-blue)](https://openaccess.thecvf.com/content/WACV2024/html/Everaert_Exploiting_the_Signal-Leak_Bias_in_Diffusion_Models_WACV_2024_paper.html)


## Overview

This is the Colab version of the official implementation for our paper titled "[Exploiting the Signal-Leak Bias in Diffusion Models](https://ivrl.github.io/signal-leak-bias/)", presented at [WACV 2024](https://openaccess.thecvf.com/content/WACV2024/html/Everaert_Exploiting_the_Signal-Leak_Bias_in_Diffusion_Models_WACV_2024_paper.html) 🔥

Code is available at [https://github.com/IVRL/signal-leak-bias](https://github.com/IVRL/signal-leak-bias).

### 🔎 Research Highlights
- In the training of most diffusion models, data are never completely noised, creating a signal leakage and leading to discrepancies between training and inference processes.
- As a consequence of this signal leakage, the low-frequency / large-scale content of the generated images is mostly unchanged from the initial latents we start the generation process from, generating greyish images or images that do not match the desired style.
- We propose to exploit this signal-leak bias at inference time to gain more control over the generated images.
- We model the distribution of the signal leak present during training, to include a signal leak at inference time in the initial latents.
- ✨✨ No training required! ✨✨
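
To make the first highlight concrete, the sketch below recomputes a noise schedule and checks how much of the original image survives at the last training timestep. The schedule parameters (`scaled_linear`, `beta_start=0.00085`, `beta_end=0.012`, 1000 steps) are assumptions based on the usual Stable Diffusion scheduler configuration; they are not stated in this notebook, and `scaled_linear_alphas_cumprod` is a hypothetical helper written for illustration.

```python
import math

def scaled_linear_alphas_cumprod(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    """Cumulative product of (1 - beta_t) for a 'scaled linear' beta schedule.

    The default parameters are assumed values (typical Stable Diffusion config).
    """
    sqrt_start, sqrt_end = math.sqrt(beta_start), math.sqrt(beta_end)
    alphas_cumprod, running = [], 1.0
    for t in range(num_steps):
        # Betas are linearly spaced in sqrt-space, then squared
        sqrt_beta = sqrt_start + (sqrt_end - sqrt_start) * t / (num_steps - 1)
        running *= 1.0 - sqrt_beta ** 2
        alphas_cumprod.append(running)
    return alphas_cumprod

alphas_cumprod = scaled_linear_alphas_cumprod()
signal_weight = math.sqrt(alphas_cumprod[-1])        # weight of the image at t = T
noise_weight = math.sqrt(1.0 - alphas_cumprod[-1])   # weight of the noise at t = T
print(f"sqrt(alpha_prod_T) = {signal_weight:.4f}, sqrt(1 - alpha_prod_T) = {noise_weight:.4f}")
```

Under these assumed parameters the signal weight is small but not zero, which is the signal leak: the most-noised training samples still carry a faint copy of the image, whereas inference starts from pure noise.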

### 📃 [Exploiting the Signal-Leak Bias in Diffusion Models](https://ivrl.github.io/signal-leak-bias/)

[Martin Nicolas Everaert](https://martin-ev.github.io/) <sup>1</sup>, [Athanasios Fitsios](https://www.linkedin.com/in/athanasiosfitsios/) <sup>1,2</sup>, [Marco Bocchio](https://scholar.google.com/citations?user=KDiTxBQAAAAJ) <sup>2</sup>, [Sami Arpa](https://scholar.google.com/citations?user=84FopNgAAAAJ) <sup>2</sup>, [Sabine Süsstrunk](https://scholar.google.com/citations?user=EX3OYP4AAAAJ) <sup>1</sup>, [Radhakrishna Achanta](https://scholar.google.com/citations?user=lc2HaZwAAAAJ) <sup>1</sup>

<sup>1</sup>[School of Computer and Communication Sciences, EPFL, Switzerland](https://www.epfl.ch/labs/ivrl/) ; <sup>2</sup>[Largo.ai, Lausanne, Switzerland](https://home.largo.ai/)

**Abstract**: There is a bias in the inference pipeline of most diffusion models. This bias arises from a signal leak whose distribution deviates from the noise distribution, creating a discrepancy between training and inference processes. We demonstrate that this signal-leak bias is particularly significant when models are tuned to a specific style, causing sub-optimal style matching. Recent research tries to avoid the signal leakage during training. We instead show how we can exploit this signal-leak bias in existing diffusion models to allow more control over the generated images. This enables us to generate images with more varied brightness, and images that better match a desired style or color. By modeling the distribution of the signal leak in the spatial frequency and pixel domains, and including a signal leak in the initial latent, we generate images that better match expected results without any additional training.


## License

The implementation here is provided solely as part of the research publication "[Exploiting the Signal-Leak Bias in Diffusion Models](https://ivrl.github.io/signal-leak-bias/)", only for academic non-commercial usage. Details can be found in the [LICENSE file](https://github.com/ivrl/signal-leak-bias/blob/main/LICENSE). If the License is not suitable for your business or project, please contact Largo.ai (info@largo.ai) and EPFL-TTO (info.tto@epfl.ch) for a full commercial license.


## Citation

Please cite the paper as follows:

```
@InProceedings{Everaert_2024_WACV,
      author   = {Everaert, Martin Nicolas and Fitsios, Athanasios and Bocchio, Marco and Arpa, Sami and Süsstrunk, Sabine and Achanta, Radhakrishna},
      title    = {{E}xploiting the {S}ignal-{L}eak {B}ias in {D}iffusion {M}odels},
      booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
      month     = {January},
      year      = {2024},
      pages     = {4025-4034}
}
```


## Getting Started

### Code and development environment

Our code mainly builds on top of the code of the [🤗 Diffusers](https://huggingface.co/docs/diffusers/index) library.

Clone this repository:

In [None]:
!git clone https://github.com/IVRL/signal-leak-bias
%cd signal-leak-bias/src

In [None]:
!rm -r examples  # We will regenerate all the examples in this notebook :-)

Run the following command to install our dependencies:

In [None]:
!pip install diffusers==0.25.1
!pip install accelerate==0.26.1

Run the following command to download some images for the examples:

In [None]:
!GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/sd-dreambooth-library/nasa-space-v2-768
!git clone https://huggingface.co/sd-concepts-library/line-art
!wget https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
!unzip -q coco128

### Computing statistics of the signal leak

The provided Python file for computing statistics of the signal leak can be used, for example, as follows:

In [None]:
!python signal_leak.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1"  \
  --data_dir="coco128/images/train2017" \
  --output_dir="examples/C" \
  --resolution=768 \
  --n_components=3 \
  --statistic_type="dct+pixel" \
  --center_crop

### Inference

Once the statistics have been computed, you can use them to sample a signal-leak at inference time too, for instance as follows:

In [None]:
from signal_leak import sample_from_stats

signal_leak = sample_from_stats(path="examples/C")

Images can be generated with the sampled signal-leak in the initial latents, for instance as follows:

In [None]:
from diffusers import StableDiffusionPipeline
import torch

pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1").to("cuda")
num_inference_steps = 50

# Get the timestep T of the first reverse diffusion iteration
pipeline.scheduler.set_timesteps(num_inference_steps, device="cuda")
first_inference_timestep = pipeline.scheduler.timesteps[0].item()

# Get the values of sqrt(alpha_prod_T) and sqrt(1 - alpha_prod_T)
sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5
sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5

# Generate the initial latents, without signal leak
latents = torch.randn([1, 4, 96, 96])

# Add a signal leak in the initial latents
latents = sqrt_alpha_prod * signal_leak + sqrt_one_minus_alpha_prod * latents

# Generate image
image = pipeline(
    prompt = "An astronaut riding a horse",
    num_inference_steps = num_inference_steps,
    latents = latents,
).images[0]
display(image)

In [None]:
del image, pipeline, latents,  signal_leak, first_inference_timestep, num_inference_steps, sqrt_alpha_prod, sqrt_one_minus_alpha_prod
import torch
torch.cuda.empty_cache()
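
The mixing step used in the cell above has the same form as the forward diffusion process at timestep T, with the sampled signal leak taking the place of the image. The following NumPy sketch isolates that step. `add_signal_leak` is a hypothetical helper, and both the value of `alpha_prod` and the constant signal leak are assumptions chosen for illustration, not values computed by `signal_leak.py`.

```python
import numpy as np

def add_signal_leak(noise, signal_leak, alpha_prod):
    """Mix a signal leak into Gaussian noise with the forward-diffusion weights."""
    return np.sqrt(alpha_prod) * signal_leak + np.sqrt(1.0 - alpha_prod) * noise

rng = np.random.default_rng(0)
alpha_prod = 0.0047  # assumed order of magnitude of alphas_cumprod at the first inference timestep
noise = rng.standard_normal((4, 96, 96))

# Hypothetical signal leak: a constant offset, standing in for a uniformly brighter image
signal_leak = np.full((4, 96, 96), 2.0)

latents = add_signal_leak(noise, signal_leak, alpha_prod)

# The leak shifts the average of the latents by sqrt(alpha_prod) * 2.0
mean_shift = latents.mean() - np.sqrt(1.0 - alpha_prod) * noise.mean()
print(f"mean shift introduced by the leak: {mean_shift:.4f}")
```

The shift is small compared to the unit-variance noise, but it acts on the low-frequency content, which the denoising process tends to preserve.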

## Examples



### Improving style-tuned models

Models tuned on specific styles often produce results that do not match the styles well (see the second column of the next two tables). We argue that this is because of a discrepancy between training (which contains a signal leak whose distribution differs from the standard multivariate Gaussian) and inference (no signal leak). We fix this discrepancy by modelling the signal leak present during training and including a signal leak (see third column) at inference time too. We use a "pixel" model, that is, we estimate the mean and variance of each pixel (each spatial element of the latent encodings).

In the following two examples, we show how to fix two models:
- [sd-dreambooth-library/nasa-space-v2-768](https://huggingface.co/sd-dreambooth-library/nasa-space-v2-768) is a model tuned with [DreamBooth](https://huggingface.co/docs/diffusers/en/training/dreambooth) (Ruiz et al., 2022) on 24 images of the sky.
- [sd-concepts-library/line-art](https://huggingface.co/sd-concepts-library/line-art) is an embedding for [Stable Diffusion v1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4) trained with [Textual Inversion](https://huggingface.co/docs/diffusers/training/text_inversion) (Gal et al., 2022) on 7 images with line-art style.
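
The "pixel" model described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the code of `signal_leak.py`: the function names are hypothetical, and the random array stands in for the VAE encodings of the style images.

```python
import numpy as np

def fit_pixel_stats(latents):
    """Per-element mean and standard deviation over a stack of latents of shape (N, C, H, W)."""
    return latents.mean(axis=0), latents.std(axis=0)

def sample_pixel_model(mean, std, rng):
    """Draw one signal leak whose elements are independent Gaussians."""
    return mean + std * rng.standard_normal(mean.shape)

rng = np.random.default_rng(0)

# Hypothetical stand-in for the encodings of 24 style images (4x96x96 latents)
style_latents = 0.5 + 0.1 * rng.standard_normal((24, 4, 96, 96))

mean, std = fit_pixel_stats(style_latents)
signal_leak = sample_pixel_model(mean, std, rng)
print(signal_leak.shape, float(signal_leak.mean()))
```

Because each element is modelled independently, a handful of images is enough to estimate the statistics, which suits style sets as small as the ones used here.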



#### Example 1

In [None]:
!python signal_leak.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1" \
  --data_dir="nasa-space-v2-768/concept_images" \
  --output_dir="examples/A1/" \
  --resolution=768 \
  --statistic_type="pixel" \
  --center_crop

In [None]:
import os
import torch
from diffusers import StableDiffusionPipeline
from signal_leak import sample_from_stats

folder = "examples/A1/imgs"
path_stats = "examples/A1"

os.makedirs(folder, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pipeline
pipeline = StableDiffusionPipeline.from_pretrained(
    "sd-dreambooth-library/nasa-space-v2-768",
).to(device)
num_inference_steps = 50

# Get the timestep T of the first reverse diffusion iteration
pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
first_inference_timestep = pipeline.scheduler.timesteps[0].item()

# Get the values of sqrt(alpha_prod_T) and sqrt(1 - alpha_prod_T)
sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5
sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5

# Dimensions of the latent space, with batch_size=1
shape_latents = [
    1,
    pipeline.unet.config.in_channels,
    pipeline.unet.config.sample_size,
    pipeline.unet.config.sample_size,
]

# Utility function to visualize initial latents / signal leak
def latents_to_pil(pipeline, latents, generator):
    decoded = pipeline.vae.decode(
        latents / pipeline.vae.config.scaling_factor,
        return_dict=False,
        generator=generator,
    )[0]
    image = pipeline.image_processor.postprocess(
        decoded,
        output_type="pil",
        do_denormalize=[True],
    )[0]
    return image

# Random number generator
generator = torch.Generator(device=device)
generator = generator.manual_seed(12345)

with torch.no_grad():
    for n in range(5):

        # Generate the initial latents
        initial_latents = torch.randn(
            shape_latents, generator=generator, device=device, dtype=torch.float32
        )
        latents_to_pil(pipeline, initial_latents, generator).save(f"{folder}/latents{n}.png")


        # Generate an image WITHOUT signal leak in the initial latents
        image = pipeline(
            prompt="A very dark picture of the sky, Nasa style",
            guidance_scale=1,
            num_inference_steps=num_inference_steps,
            latents=initial_latents,
        ).images[0]
        image.save(f"{folder}/original{n}.png")

        # Generate a signal leak from computed statistics
        signal_leak = sample_from_stats(
            path=path_stats,
            dims=shape_latents,
            generator_pt=generator,
            generator_np=None,
            device=device
        )
        latents_to_pil(pipeline, signal_leak, generator).save(f"{folder}/signal_leak{n}.png")

        # Add a signal leak in the initial latents
        initial_latents = sqrt_alpha_prod * signal_leak + sqrt_one_minus_alpha_prod * initial_latents

        # Generate an image WITH signal leak in the initial latents
        image_with_signalleak = pipeline(
            prompt="A very dark picture of the sky, Nasa style",
            guidance_scale=1,
            num_inference_steps=num_inference_steps,
            latents=initial_latents,
        ).images[0]
        image_with_signalleak.save(f"{folder}/ours{n}.png")

In [None]:
del device, first_inference_timestep, generator, num_inference_steps, image, image_with_signalleak, initial_latents, pipeline, signal_leak, sqrt_alpha_prod, sqrt_one_minus_alpha_prod
import torch
torch.cuda.empty_cache()

In [None]:
from IPython.display import display, HTML
from PIL import Image
from io import BytesIO
import base64

def load_images():
    images = []
    for i in range(5):
        latents_path = f'examples/A1/imgs/latents{i}.png'
        original_path = f'examples/A1/imgs/original{i}.png'
        signal_leak_path = f'examples/A1/imgs/signal_leak{i}.png'
        ours_path = f'examples/A1/imgs/ours{i}.png'

        latents_image = Image.open(latents_path)
        original_image = Image.open(original_path)
        signal_leak_image = Image.open(signal_leak_path)
        ours_image = Image.open(ours_path)

        images.append([
            latents_image, original_image, signal_leak_image, ours_image
        ])

    return images

# Load images
loaded_images = load_images()

# Function to generate HTML code for the table
def generate_table(data, headers):
    table_code = "<table><tr>"

    # Add headers
    for header in headers:
        table_code += "<th>" + header + "</th>"
    table_code += "</tr>"

    # Add data rows
    for row in data:
        table_code += "<tr>"
        for cell in row:
            img_data = BytesIO()
            cell.save(img_data, format="PNG")
            img_data.seek(0)
            img_data_uri = f"data:image/png;base64,{base64.b64encode(img_data.read()).decode()}"
            table_code += f'<td><img src="{img_data_uri}" width="200"/></td>'
        table_code += "</tr>"
    table_code += "</table>"

    return table_code

# Data for the table
table_headers = ["Initial latents", "Generated image (original)", "+ Signal Leak", "Generated image (ours)"]

table_data = []
for i in range(5):
    table_data.append([
        loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]
    ])

# Display the table
display(HTML(generate_table(table_data, table_headers)))

#### Example 2

In [None]:
!python signal_leak.py \
  --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
  --data_dir="line-art/concept_images" \
  --output_dir="examples/A2/" \
  --resolution=512 \
  --statistic_type="pixel" \
  --center_crop

In [None]:
import os
import torch
from diffusers import StableDiffusionPipeline
from signal_leak import sample_from_stats

folder = "examples/A2/imgs"
path_stats = "examples/A2"

os.makedirs(folder, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pipeline
pipeline = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
).to(device)
pipeline.load_textual_inversion(
    "sd-concepts-library/line-art",
)
num_inference_steps = 50

# Get the timestep T of the first reverse diffusion iteration
pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
first_inference_timestep = pipeline.scheduler.timesteps[0].item()

# Get the values of sqrt(alpha_prod_T) and sqrt(1 - alpha_prod_T)
sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5
sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5

# Dimensions of the latent space, with batch_size=1
shape_latents = [
    1,
    pipeline.unet.config.in_channels,
    pipeline.unet.config.sample_size,
    pipeline.unet.config.sample_size,
]

# Utility function to visualize initial latents / signal leak
def latents_to_pil(pipeline, latents, generator):
    decoded = pipeline.vae.decode(
        latents / pipeline.vae.config.scaling_factor,
        return_dict=False,
        generator=generator,
    )[0]
    image = pipeline.image_processor.postprocess(
        decoded,
        output_type="pil",
        do_denormalize=[True],
    )[0]
    return image

# Random number generator
generator = torch.Generator(device=device)
generator = generator.manual_seed(12345)

with torch.no_grad():
    for n in range(5):

        # Generate the initial latents
        initial_latents = torch.randn(
            shape_latents, generator=generator, device=device, dtype=torch.float32
        )
        latents_to_pil(pipeline, initial_latents, generator).save(f"{folder}/latents{n}.png")


        # Generate an image WITHOUT signal leak in the initial latents
        image = pipeline(
            prompt="An astronaut riding a horse in the style of <line-art>",
            num_inference_steps=num_inference_steps,
            latents=initial_latents,
        ).images[0]
        image.save(f"{folder}/original{n}.png")

        # Generate a signal leak from computed statistics
        signal_leak = sample_from_stats(
            path=path_stats,
            dims=shape_latents,
            generator_pt=generator,
            generator_np=None,
            device=device
        )
        latents_to_pil(pipeline, signal_leak, generator).save(f"{folder}/signal_leak{n}.png")

        # Add a signal leak in the initial latents
        initial_latents = sqrt_alpha_prod * signal_leak + sqrt_one_minus_alpha_prod * initial_latents

        # Generate an image WITH signal leak in the initial latents
        image_with_signalleak = pipeline(
            prompt="An astronaut riding a horse in the style of <line-art>",
            num_inference_steps=num_inference_steps,
            latents=initial_latents,
        ).images[0]
        image_with_signalleak.save(f"{folder}/ours{n}.png")

In [None]:
del device, first_inference_timestep, generator, num_inference_steps, image, image_with_signalleak, initial_latents, pipeline, signal_leak, sqrt_alpha_prod, sqrt_one_minus_alpha_prod
import torch
torch.cuda.empty_cache()

In [None]:
from IPython.display import display, HTML
from PIL import Image
from io import BytesIO
import base64

def load_images():
    images = []
    for i in range(5):
        latents_path = f'examples/A2/imgs/latents{i}.png'
        original_path = f'examples/A2/imgs/original{i}.png'
        signal_leak_path = f'examples/A2/imgs/signal_leak{i}.png'
        ours_path = f'examples/A2/imgs/ours{i}.png'

        latents_image = Image.open(latents_path)
        original_image = Image.open(original_path)
        signal_leak_image = Image.open(signal_leak_path)
        ours_image = Image.open(ours_path)

        images.append([
            latents_image, original_image, signal_leak_image, ours_image
        ])

    return images

# Load images
loaded_images = load_images()

# Function to generate HTML code for the table
def generate_table(data, headers):
    table_code = "<table><tr>"

    # Add headers
    for header in headers:
        table_code += "<th>" + header + "</th>"
    table_code += "</tr>"

    # Add data rows
    for row in data:
        table_code += "<tr>"
        for cell in row:
            img_data = BytesIO()
            cell.save(img_data, format="PNG")
            img_data.seek(0)
            img_data_uri = f"data:image/png;base64,{base64.b64encode(img_data.read()).decode()}"
            table_code += f'<td><img src="{img_data_uri}" width="200"/></td>'
        table_code += "</tr>"
    table_code += "</table>"

    return table_code

# Data for the table
table_headers = ["Initial latents", "Generated image (original)", "+ Signal Leak", "Generated image (ours)"]

table_data = []
for i in range(5):
    table_data.append([
        loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]
    ])

# Display the table
display(HTML(generate_table(table_data, table_headers)))


### Training-free style adaptation of Stable Diffusion

The same approach as in the previous examples can be applied directly to the base diffusion model, instead of a model fine-tuned on a style. That is, we include a signal leak at inference time to bias the image generation towards the desired style.

Without our approach (see second column of the next two tables), the prompt alone is not sufficient to generate pictures in the desired style. Complementing it with a signal leak of the style (third column) generates images (last column) that better match the desired output.



#### Example 1

In [None]:
import os
import torch
from diffusers import StableDiffusionPipeline
from signal_leak import sample_from_stats

folder = "examples/B1/imgs"
path_stats = "examples/A1"

os.makedirs(folder, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pipeline
pipeline = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
).to(device)
num_inference_steps = 50

# Get the timestep T of the first reverse diffusion iteration
pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
first_inference_timestep = pipeline.scheduler.timesteps[0].item()

# Get the values of sqrt(alpha_prod_T) and sqrt(1 - alpha_prod_T)
sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5
sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5

# Dimensions of the latent space, with batch_size=1
shape_latents = [
    1,
    pipeline.unet.config.in_channels,
    pipeline.unet.config.sample_size,
    pipeline.unet.config.sample_size,
]

# Utility function to visualize initial latents / signal leak
def latents_to_pil(pipeline, latents, generator):
    decoded = pipeline.vae.decode(
        latents / pipeline.vae.config.scaling_factor,
        return_dict=False,
        generator=generator,
    )[0]
    image = pipeline.image_processor.postprocess(
        decoded,
        output_type="pil",
        do_denormalize=[True],
    )[0]
    return image

# Random number generator
generator = torch.Generator(device=device)
generator = generator.manual_seed(12345)

with torch.no_grad():
    for n in range(5):

        # Generate the initial latents
        initial_latents = torch.randn(
            shape_latents, generator=generator, device=device, dtype=torch.float32
        )
        latents_to_pil(pipeline, initial_latents, generator).save(f"{folder}/latents{n}.png")


        # Generate an image WITHOUT signal leak in the initial latents
        image = pipeline(
            prompt="A very dark picture of the sky, taken by the Nasa.",
            guidance_scale=1,
            num_inference_steps=num_inference_steps,
            latents=initial_latents,
        ).images[0]
        image.save(f"{folder}/original{n}.png")

        # Generate a signal leak from computed statistics
        signal_leak = sample_from_stats(
            path=path_stats,
            dims=shape_latents,
            generator_pt=generator,
            generator_np=None,
            device=device
        )
        latents_to_pil(pipeline, signal_leak, generator).save(f"{folder}/signal_leak{n}.png")

        # Add a signal leak in the initial latents
        initial_latents = sqrt_alpha_prod * signal_leak + sqrt_one_minus_alpha_prod * initial_latents

        # Generate an image WITH signal leak in the initial latents
        image_with_signalleak = pipeline(
            prompt="A very dark picture of the sky, taken by the Nasa.",
            guidance_scale=1,
            num_inference_steps=num_inference_steps,
            latents=initial_latents,
        ).images[0]
        image_with_signalleak.save(f"{folder}/ours{n}.png")

In [None]:
del device, first_inference_timestep, generator, num_inference_steps, image, image_with_signalleak, initial_latents, pipeline, signal_leak, sqrt_alpha_prod, sqrt_one_minus_alpha_prod
import torch
torch.cuda.empty_cache()

In [None]:
from IPython.display import display, HTML
from PIL import Image
from io import BytesIO
import base64

def load_images():
    images = []
    for i in range(5):
        latents_path = f'examples/B1/imgs/latents{i}.png'
        original_path = f'examples/B1/imgs/original{i}.png'
        signal_leak_path = f'examples/B1/imgs/signal_leak{i}.png'
        ours_path = f'examples/B1/imgs/ours{i}.png'

        latents_image = Image.open(latents_path)
        original_image = Image.open(original_path)
        signal_leak_image = Image.open(signal_leak_path)
        ours_image = Image.open(ours_path)

        images.append([
            latents_image, original_image, signal_leak_image, ours_image
        ])

    return images

# Load images
loaded_images = load_images()

# Function to generate HTML code for the table
def generate_table(data, headers):
    table_code = "<table><tr>"

    # Add headers
    for header in headers:
        table_code += "<th>" + header + "</th>"
    table_code += "</tr>"

    # Add data rows
    for row in data:
        table_code += "<tr>"
        for cell in row:
            img_data = BytesIO()
            cell.save(img_data, format="PNG")
            img_data.seek(0)
            img_data_uri = f"data:image/png;base64,{base64.b64encode(img_data.read()).decode()}"
            table_code += f'<td><img src="{img_data_uri}" width="200"/></td>'
        table_code += "</tr>"
    table_code += "</table>"

    return table_code

# Data for the table
table_headers = ["Initial latents", "Generated image (original)", "+ Signal Leak", "Generated image (ours)"]

table_data = []
for i in range(5):
    table_data.append([
        loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]
    ])

# Display the table
display(HTML(generate_table(table_data, table_headers)))


#### Example 2

In [None]:
import os
import torch
from diffusers import StableDiffusionPipeline
from signal_leak import sample_from_stats

folder = "examples/B2/imgs"
path_stats = "examples/A2"

os.makedirs(folder, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pipeline
pipeline = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
).to(device)
num_inference_steps = 50

# Get the timestep T of the first reverse diffusion iteration
pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
first_inference_timestep = pipeline.scheduler.timesteps[0].item()

# Get the values of sqrt(alpha_prod_T) and sqrt(1 - alpha_prod_T)
sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5
sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5

# Dimensions of the latent space, with batch_size=1
shape_latents = [
    1,
    pipeline.unet.config.in_channels,
    pipeline.unet.config.sample_size,
    pipeline.unet.config.sample_size,
]

# Utility function to visualize initial latents / signal leak
def latents_to_pil(pipeline, latents, generator):
    decoded = pipeline.vae.decode(
        latents / pipeline.vae.config.scaling_factor,
        return_dict=False,
        generator=generator,
    )[0]
    image = pipeline.image_processor.postprocess(
        decoded,
        output_type="pil",
        do_denormalize=[True],
    )[0]
    return image

# Random number generator
generator = torch.Generator(device=device)
generator = generator.manual_seed(12345)

with torch.no_grad():
    for n in range(5):

        # Generate the initial latents
        initial_latents = torch.randn(
            shape_latents, generator=generator, device=device, dtype=torch.float32
        )
        latents_to_pil(pipeline, initial_latents, generator).save(f"{folder}/latents{n}.png")


        # Generate an image WITHOUT signal leak in the initial latents
        image = pipeline(
            prompt="An astronaut riding a horse, in the style of line art, pastel colors.",
            num_inference_steps=num_inference_steps,
            latents=initial_latents,
        ).images[0]
        image.save(f"{folder}/original{n}.png")

        # Generate a signal leak from computed statistics
        signal_leak = sample_from_stats(
            path=path_stats,
            dims=shape_latents,
            generator_pt=generator,
            generator_np=None,
            device=device
        )
        latents_to_pil(pipeline, signal_leak, generator).save(f"{folder}/signal_leak{n}.png")

        # Add a signal leak in the initial latents
        initial_latents = sqrt_alpha_prod * signal_leak + sqrt_one_minus_alpha_prod * initial_latents

        # Generate an image WITH signal leak in the initial latents
        image_with_signalleak = pipeline(
            prompt="An astronaut riding a horse, in the style of line art, pastel colors.",
            num_inference_steps=num_inference_steps,
            latents=initial_latents,
        ).images[0]
        image_with_signalleak.save(f"{folder}/ours{n}.png")

del pipeline

In [None]:
from IPython.display import display, HTML
from PIL import Image
from io import BytesIO
import base64

def load_images():
    images = []
    for i in range(5):
        latents_path = f'examples/B2/imgs/latents{i}.png'
        original_path = f'examples/B2/imgs/original{i}.png'
        signal_leak_path = f'examples/B2/imgs/signal_leak{i}.png'
        ours_path = f'examples/B2/imgs/ours{i}.png'

        latents_image = Image.open(latents_path)
        original_image = Image.open(original_path)
        signal_leak_image = Image.open(signal_leak_path)
        ours_image = Image.open(ours_path)

        images.append([
            latents_image, original_image, signal_leak_image, ours_image
        ])

    return images

# Load images
loaded_images = load_images()

# Function to generate HTML code for the table
def generate_table(data, headers):
    table_code = "<table><tr>"

    # Add headers
    for header in headers:
        table_code += "<th>" + header + "</th>"
    table_code += "</tr>"

    # Add data rows
    for row in data:
        table_code += "<tr>"
        for cell in row:
            img_data = BytesIO()
            cell.save(img_data, format="PNG")
            img_data.seek(0)
            img_data_uri = f"data:image/png;base64,{base64.b64encode(img_data.read()).decode()}"
            table_code += f'<td><img src="{img_data_uri}" width="200"/></td>'
        table_code += "</tr>"
    table_code += "</table>"

    return table_code

# Data for the table
table_headers = ["Initial latents", "Generated image (original)", "+ Signal Leak", "Generated image (ours)"]

table_data = []
for i in range(5):
    table_data.append([
        loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]
    ])

# Display the table
display(HTML(generate_table(table_data, table_headers)))

  
### More diverse generated images

In the previous examples, the signal leak is modelled with a "pixel" model, realigning the training and inference distributions for stylized images. For *natural* images, the discrepancy between the training and inference distributions mostly lies in the frequency components: noised images during training still retain the low-frequency content (large-scale patterns, main colors) of the original images, while the initial latents during inference always have neutral, mid-range low-frequency content (e.g. a *greyish* average color). Compared to the examples above, we therefore additionally model the low-frequency content of the signal leak, using a small set of natural images.

In the next examples, we will use [this set of 128 images](https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip) from [COCO](https://cocodataset.org/).
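
The low-frequency argument above can be illustrated with a discrete cosine transform (DCT). The sketch below keeps only the lowest frequencies of a 2D array and compares pure noise with an array that has a non-zero average. It is an illustration under assumptions, not the repository's implementation: `dct_matrix` and `low_frequency_part` are hypothetical helpers, the "image-like" array is synthetic, and `n_components=3` only mirrors the value of the flag passed to `signal_leak.py`, whose exact use inside the script is not shown in this notebook.

```python
import numpy as np

def dct_matrix(n):
    """Orthonormal DCT-II matrix D, such that D @ x is the DCT of a length-n vector x."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    d = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    d[0, :] /= np.sqrt(2.0)
    return d

def low_frequency_part(x, n_components):
    """Keep only the n_components x n_components lowest DCT frequencies of a 2D array."""
    d_h, d_w = dct_matrix(x.shape[0]), dct_matrix(x.shape[1])
    coeffs = d_h @ x @ d_w.T                      # 2D DCT
    mask = np.zeros_like(coeffs)
    mask[:n_components, :n_components] = 1.0      # low frequencies sit in the top-left corner
    return d_h.T @ (coeffs * mask) @ d_w          # inverse 2D DCT

rng = np.random.default_rng(0)
noise = rng.standard_normal((96, 96))             # one channel of pure-noise initial latents
image_like = noise + 1.5                          # synthetic channel with a bright average

low_noise = low_frequency_part(noise, n_components=3)
low_image = low_frequency_part(image_like, n_components=3)
print(f"low-frequency mean, noise: {low_noise.mean():.3f}, image-like: {low_image.mean():.3f}")
```

The average of pure noise stays close to zero, while the image-like array keeps its offset in the low frequencies. Sampling those few low-frequency coefficients from statistics of natural images is what lets the initial latents start from varied brightness and colors.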


In [None]:
!rm -r examples/C
!python signal_leak.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1" \
  --data_dir="coco128/images/train2017" \
  --output_dir="examples/C/" \
  --resolution=768 \
  --n_components=3 \
  --statistic_type="dct+pixel" \
  --center_crop

In [None]:
import os
import torch
import numpy as np
from diffusers import StableDiffusionPipeline
from signal_leak import sample_from_stats

folder = "examples/C/imgs"
path_stats = "examples/C"

os.makedirs(folder, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pipeline
pipeline = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
).to(device)
num_inference_steps = 50

# Get the timestep T of the first reverse diffusion iteration
pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
first_inference_timestep = pipeline.scheduler.timesteps[0].item()

# Get the values of sqrt(alpha_prod_T) and sqrt(1 - alpha_prod_T)
sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5
sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5

# Dimensions of the latent space, with batch_size=1
shape_latents = [
    1,
    pipeline.unet.config.in_channels,
    pipeline.unet.config.sample_size,
    pipeline.unet.config.sample_size,
]

# Utility function to visualize initial latents / signal leak
def latents_to_pil(pipeline, latents, generator):
    decoded = pipeline.vae.decode(
        latents / pipeline.vae.config.scaling_factor,
        return_dict=False,
        generator=generator,
    )[0]
    image = pipeline.image_processor.postprocess(
        decoded,
        output_type="pil",
        do_denormalize=[True],
    )[0]
    return image

# Random number generator
generator = torch.Generator(device=device)
generator = generator.manual_seed(12345)
generator_np = np.random.default_rng(seed=654321)

with torch.no_grad():
    for n in range(5):

        # Generate the initial latents
        initial_latents = torch.randn(
            shape_latents, generator=generator, device=device, dtype=torch.float32
        )
        latents_to_pil(pipeline, initial_latents, generator).save(f"{folder}/latents{n}.png")


        # Generate an image WITHOUT signal leak in the initial latents
        image = pipeline(
            prompt="An astronaut riding a horse",
            num_inference_steps=num_inference_steps,
            latents=initial_latents,
        ).images[0]
        image.save(f"{folder}/original{n}.png")

        # Generate a signal leak from computed statistics
        signal_leak = sample_from_stats(
            path=path_stats,
            dims=shape_latents,
            generator_pt=generator,
            generator_np=generator_np,
            device=device
        )
        latents_to_pil(pipeline, signal_leak, generator).save(f"{folder}/signal_leak{n}.png")

        # Add a signal leak in the initial latents
        initial_latents = sqrt_alpha_prod * signal_leak + sqrt_one_minus_alpha_prod * initial_latents

        # Generate an image WITH signal leak in the initial latents
        image_with_signalleak = pipeline(
            prompt="An astronaut riding a horse",
            num_inference_steps=num_inference_steps,
            latents=initial_latents,
        ).images[0]
        image_with_signalleak.save(f"{folder}/ours{n}.png")


del pipeline

In [None]:
from IPython.display import display, HTML
from PIL import Image
from io import BytesIO
import base64

def load_images():
    images = []
    for i in range(5):
        latents_path = f'examples/C/imgs/latents{i}.png'
        original_path = f'examples/C/imgs/original{i}.png'
        signal_leak_path = f'examples/C/imgs/signal_leak{i}.png'
        ours_path = f'examples/C/imgs/ours{i}.png'

        latents_image = Image.open(latents_path)
        original_image = Image.open(original_path)
        signal_leak_image = Image.open(signal_leak_path)
        ours_image = Image.open(ours_path)

        images.append([
            latents_image, original_image, signal_leak_image, ours_image
        ])

    return images

# Load images
loaded_images = load_images()

# Function to generate HTML code for the table
def generate_table(data, headers):
    table_code = "<table><tr>"

    # Add headers
    for header in headers:
        table_code += "<th>" + header + "</th>"
    table_code += "</tr>"

    # Add data rows
    for row in data:
        table_code += "<tr>"
        for cell in row:
            img_data = BytesIO()
            cell.save(img_data, format="PNG")
            img_data.seek(0)
            img_data_uri = f"data:image/png;base64,{base64.b64encode(img_data.read()).decode()}"
            table_code += f'<td><img src="{img_data_uri}" width="200"/></td>'
        table_code += "</tr>"
    table_code += "</table>"

    return table_code

# Data for the table
table_headers = ["Initial latents", "Generated image (original)", "+ Signal Leak", "Generated image (ours)"]

table_data = []
for i in range(5):
    table_data.append([
        loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]
    ])

# Display the table
display(HTML(generate_table(table_data, table_headers)))

  
### Control on the average color

In the previous example, the signal leak given at inference time is sampled randomly from the statistics of the signal leak present at training time. It is also possible to *manually* set its low-frequency components, which gives control over the low-frequency content of the generated image, as we show in the following cells.
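
The effect of such a manual setting on the initial latents can be checked numerically before running the full pipeline. The sketch below uses NumPy arrays in place of the tensors, a hypothetical helper `mix_latents`, and an assumed value of `alpha_prod`. It illustrates that adding a constant `value` to one channel of the signal leak shifts the mean of that channel in the initial latents by `sqrt(alpha_prod) * value` and leaves the other channels untouched.

```python
import numpy as np

def mix_latents(signal_leak, noise, alpha_prod):
    """Combine a signal leak and Gaussian noise as in the forward diffusion process."""
    return alpha_prod ** 0.5 * signal_leak + (1.0 - alpha_prod) ** 0.5 * noise

rng = np.random.default_rng(0)
shape = (1, 4, 96, 96)      # batch, channels, height, width (toy latent size)
alpha_prod = 0.0047         # assumed value at the first inference timestep
channel, value = 2, 1.5     # channel to shift and constant offset to add

noise = rng.standard_normal(shape)
signal_leak = rng.standard_normal(shape)   # stand-in for a sampled signal leak

# Add a constant (zero-frequency) offset to a single channel of the signal leak.
shifted_leak = signal_leak.copy()
shifted_leak[:, channel, :, :] += value

latents = mix_latents(signal_leak, noise, alpha_prod)
shifted_latents = mix_latents(shifted_leak, noise, alpha_prod)

# Change of the per-channel mean caused by the offset.
mean_shift = (shifted_latents - latents).mean(axis=(0, 2, 3))
print(mean_shift)
```

Because the offset is scaled by `sqrt(alpha_prod)`, the shift in the initial latents is small in absolute terms, yet it acts on the low-frequency content that the generation process largely preserves.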

In [None]:
!python signal_leak.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1" \
  --data_dir="coco128/images/train2017" \
  --output_dir="examples/D/" \
  --resolution=768 \
  --n_components=1 \
  --statistic_type="dct+pixel" \
  --center_crop

In [None]:
import os
import torch
import numpy as np
from diffusers import StableDiffusionPipeline
from signal_leak import sample_from_stats

folder = "examples/D/imgs"
path_stats = "examples/D"

os.makedirs(folder, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pipeline
pipeline = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
).to(device)
num_inference_steps = 50

# Get the timestep T of the first reverse diffusion iteration
pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
first_inference_timestep = pipeline.scheduler.timesteps[0].item()

# Get the values of sqrt(alpha_prod_T) and sqrt(1 - alpha_prod_T)
sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5
sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5

# Dimensions of the latent space, with batch_size=1
shape_latents = [
    1,
    pipeline.unet.config.in_channels,
    pipeline.unet.config.sample_size,
    pipeline.unet.config.sample_size,
]

# Utility function to visualize initial latents / signal leak
def latents_to_pil(pipeline, latents, generator):
    decoded = pipeline.vae.decode(
        latents / pipeline.vae.config.scaling_factor,
        return_dict=False,
        generator=generator,
    )[0]
    image = pipeline.image_processor.postprocess(
        decoded,
        output_type="pil",
        do_denormalize=[True],
    )[0]
    return image

# Random number generator
generator = torch.Generator(device=device)
generator = generator.manual_seed(12345)


# Generate the initial latents WITHOUT signal leak (shape_latents is defined above)
initial_latents_without_signalleak = torch.randn(
    shape_latents, generator=generator, device=device, dtype=torch.float32
)

with torch.no_grad():
    for channel in range(4):
        for value in (-2, -1, 0, 1, 2):

            # Reset the seed, so that the only difference between the different initial latents is the LF components
            generator = generator.manual_seed(123456)
            generator_np = np.random.default_rng(seed=654321)

            # Generate the initial latents with signal leak
            signal_leak = sample_from_stats(
                path=path_stats,
                dims=shape_latents,
                generator_pt=generator,
                generator_np=generator_np,
                device=device,
                only_hf=True
            )
            signal_leak[:, channel, :, :] += value

            initial_latents = (
                sqrt_alpha_prod * signal_leak
                + sqrt_one_minus_alpha_prod * initial_latents_without_signalleak
            )
            # Generate an image
            image_with_signalleak = pipeline(
                prompt="An astronaut riding a horse",
                num_inference_steps=num_inference_steps,
                latents=initial_latents,
            ).images[0]
            image_with_signalleak.save(f"{folder}/{channel}_{value}.png")


In [None]:
from IPython.display import display, HTML
from PIL import Image
from io import BytesIO
import base64

def load_images():
    images = []
    for channel in range(4):

        images.append([
            Image.open(f'examples/D/imgs/{channel}_{value}.png') for value in (-2, -1, 0, 1, 2)
        ])

    return images

# Load images
loaded_images = load_images()

# Function to generate HTML code for the table
def generate_table(data, headers):
    table_code = "<table><tr>"

    # Add headers
    for header in headers:
        table_code += "<th>" + header + "</th>"
    table_code += "</tr>"

    # Add data rows
    for row in data:
        table_code += "<tr>"
        for cell in row:
            img_data = BytesIO()
            cell.save(img_data, format="PNG")
            img_data.seek(0)
            img_data_uri = f"data:image/png;base64,{base64.b64encode(img_data.read()).decode()}"
            table_code += f'<td><img src="{img_data_uri}" width="200"/></td>'
        table_code += "</tr>"
    table_code += "</table>"

    return table_code

# Data for the table
table_headers = ["-2", "-1", "0", "1", "2"]

# Display the table
display(HTML(generate_table(loaded_images, table_headers)))