# StyleAligned: Zero-Shot Style Alignment among a Series of Generated Images via Attention Sharing

### **Authors**: ***Borgi Alessio***, ***Danese Francesco***

### **Abstract**
In this notebook we aim to reproduce and enhance **[StyleAligned](https://arxiv.org/abs/2312.02133)**, a novel technique introduced by **Google Research**, for achieving **Style Consistency** in large-scale Text-to-Image (T2I) generative models. While current T2I models excel in creating visually compelling images from textual descriptions, they often struggle to maintain a consistent style across multiple images. Traditional methods to address this require extensive fine-tuning and manual intervention. 

**StyleAligned** addresses this challenge by introducing minimal **Attention Sharing** during the **Diffusion Process**, ensuring **Style Alignment among generated images** without the need for optimization or fine-tuning (**Zero-Shot Inference**). The method can also apply a reference style to newly generated images through a straightforward inversion operation, while maintaining high-quality synthesis and fidelity to the provided text prompts.

### 0: SETTINGS & IMPORTS

#### 0.1: CLONE REPOSITORY AND GIT SETUP

In the following cell, we set up the code by cloning the repository and configuring Git. The commented-out commands at the end are optional helpers for staging, committing, and pushing changes.

In [None]:
# Clone the repository
!git clone https://github.com/alessioborgi/StyleAlignedDiffModels.git

# Change directory to the cloned repository
%cd StyleAlignedDiffModels
%ls

# Set up Git configuration
!git config --global user.name "Alessio Borgi"
!git config --global user.email "alessioborgi3@gmail.com"

# Stage the changes
#!git add .

# Commit the changes
#!git commit -m "Added some content to your-file.txt"

# Push the changes (replace 'your-token' with your actual personal access token)
#!git push origin main

#### 0.2: INSTALL AND IMPORT REQUIRED LIBRARIES

We then install and import the required libraries.

In [None]:
# Install the required packages
!pip install -r requirements.txt > /dev/null

In [None]:
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import torch
import mediapy
import sa_handler
from diffusers.utils import load_image
import inversion
import numpy as np

### 4: DDIM \& PIPELINE DEFINITION
We then proceed to load the **SDXL (Stable Diffusion XL)** Model and configure the **DDIM (Denoising Diffusion Implicit Models) Scheduler**. We then configure the **Pipeline**. 

#### 4.1: DDIM SCHEDULER

The **DDIM Scheduler** is the component used in diffusion models for generating high-quality samples from noise. It controls the denoising process by defining a schedule for adding and removing noise to and from the data. The scheduler is essential in determining how the model transitions from pure noise to a final, coherent image or other data form.

In particular, its parameters are:
- **beta_start (float)**: Starting value of beta, the per-step noise variance of the schedule.
- **beta_end (float)**: Ending value of beta.
- **beta_schedule (str)**: The type of schedule for beta. (Possible values: "linear", "scaled_linear", "squaredcos_cap_v2". A "sigmoid" schedule is offered by `DDPMScheduler` and may not be accepted by `DDIMScheduler`.)
- **clip_sample (bool)**: If True, the predicted denoised sample is clipped to [-1, 1].
- **set_alpha_to_one (bool)**: If True, the cumulative alpha product used as the "previous" value at the final denoising step is fixed to 1; otherwise the value at step 0 is used.
- **num_train_timesteps (int)**: The number of diffusion steps used during training.
- **timestep_spacing (str)**: The method to space out timesteps. (Possible values: "linspace", "leading", "trailing").
- **prediction_type (str)**: The quantity the model predicts. (Possible values: "epsilon", "sample", "v_prediction").
- **trained_betas (array or list of floats, or None)**: Optional explicit beta values passed directly to the scheduler, bypassing `beta_start` and `beta_end`.

##### 4.1.1: Diffusion Process

The diffusion process involves adding noise to the data over a series of timesteps, which is described by the forward process:

$$ q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{\alpha_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}) $$

where:
- $\beta_t$ is the noise variance added at step $t$.
- $\alpha_t = 1 - \beta_t$ is the corresponding scaling term.
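The forward process above can be simulated on toy data. The sketch below is a pure-Python illustration, not the notebook's pipeline: the constant schedule `betas = [0.02] * 500` and the 1000-value "image" are assumptions chosen only to show that repeated noising drives the data towards a standard normal distribution.

```python
import math
import random

def forward_step(x_prev, beta_t, rng):
    """Sample x_t ~ N(sqrt(alpha_t) * x_prev, beta_t * I) for a list of floats."""
    alpha_t = 1.0 - beta_t
    return [
        math.sqrt(alpha_t) * x + math.sqrt(beta_t) * rng.gauss(0.0, 1.0)
        for x in x_prev
    ]

rng = random.Random(0)
x = [1.0] * 1000        # toy "image": 1000 values, all equal to 1
betas = [0.02] * 500    # constant toy schedule (illustrative assumption)

for beta_t in betas:
    x = forward_step(x, beta_t, rng)

# After many steps the signal is almost gone: mean near 0, variance near 1.
mean = sum(x) / len(x)
var = sum((v - mean) ** 2 for v in x) / len(x)
print(f"mean={mean:.3f}, var={var:.3f}")
```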

##### 4.1.2: Reverse Process

The reverse process aims to recover the data by denoising it, and is given by:

$$ p_{\theta}(\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mu_{\theta}(\mathbf{x}_t, t), \sigma_t^2 \mathbf{I}) $$

where:
- $\mu_{\theta}(\mathbf{x}_t, t)$ is the predicted mean.
- $\sigma_t$ is the standard deviation of the noise at timestep $t$.

##### 4.1.3: Beta Schedule

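The beta values are scheduled over the training timesteps $t = 0, \dots, T-1$, where $T$ is `num_train_timesteps`. In the `diffusers` implementation the schedules are:
- **Linear**: evenly spaced values from `beta_start` to `beta_end`.

$$ \beta_t = \beta_{\text{start}} + \frac{t}{T-1} \left(\beta_{\text{end}} - \beta_{\text{start}}\right) $$

- **Scaled Linear**: evenly spaced in $\sqrt{\beta}$ and then squared (the schedule used by Stable Diffusion).

$$ \beta_t = \left( \sqrt{\beta_{\text{start}}} + \frac{t}{T-1} \left(\sqrt{\beta_{\text{end}}} - \sqrt{\beta_{\text{start}}}\right) \right)^2 $$

- **Sigmoid**: a sigmoid ramp between the two endpoints, with $s_t$ evenly spaced in $[-6, 6]$. This option is offered by `DDPMScheduler` and may not be accepted by `DDIMScheduler`.

$$ \beta_t = \beta_{\text{start}} + (\beta_{\text{end}} - \beta_{\text{start}}) \cdot \text{sigmoid}(s_t) $$

- **Squared Cosine (squaredcos\_cap\_v2)**: defined through the cumulative product $\bar{\alpha}$ rather than through `beta_start` and `beta_end`, which are ignored, and capped at $0.999$.

$$ \bar{\alpha}(u) = \cos^2\left(\frac{u + 0.008}{1.008} \cdot \frac{\pi}{2}\right), \qquad \beta_t = \min\left(1 - \frac{\bar{\alpha}\left(\frac{t+1}{T}\right)}{\bar{\alpha}\left(\frac{t}{T}\right)},\ 0.999\right) $$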
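To make the schedules concrete, here is a minimal pure-Python sketch of how the `diffusers` schedulers build the linear, scaled-linear, and squared-cosine betas, as far as we understand them. The helper names (`linear_betas`, `scaled_linear_betas`, `squaredcos_betas`, `alphas_cumprod`) are hypothetical and do not belong to any library.

```python
import math

def linear_betas(beta_start, beta_end, T):
    """Evenly spaced betas from beta_start to beta_end (T values)."""
    step = (beta_end - beta_start) / (T - 1)
    return [beta_start + t * step for t in range(T)]

def scaled_linear_betas(beta_start, beta_end, T):
    """Evenly spaced in sqrt(beta), then squared."""
    s, e = math.sqrt(beta_start), math.sqrt(beta_end)
    return [(s + t * (e - s) / (T - 1)) ** 2 for t in range(T)]

def squaredcos_betas(T, max_beta=0.999):
    """Cosine schedule defined through the cumulative product alpha_bar."""
    def alpha_bar(u):
        return math.cos((u + 0.008) / 1.008 * math.pi / 2) ** 2
    return [
        min(1.0 - alpha_bar((t + 1) / T) / alpha_bar(t / T), max_beta)
        for t in range(T)
    ]

def alphas_cumprod(betas):
    """Cumulative product of alpha_t = 1 - beta_t."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out

# Same values as the scheduler configured in this notebook.
betas = scaled_linear_betas(0.00085, 0.012, 1000)
alpha_bars = alphas_cumprod(betas)
print(betas[0], betas[-1], alpha_bars[-1])
```

The cumulative product `alpha_bars` is the quantity that the DDIM update works with: it starts close to 1 (almost no noise) and decays towards 0 (almost pure noise).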

##### 4.1.4: Inference with DDIM

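During inference, the denoising process can be described as:

$$ \mathbf{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \left( \frac{\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\, \mathbf{\epsilon}_{\theta}(\mathbf{x}_t, t)}{\sqrt{\bar{\alpha}_t}} \right) + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2}\, \mathbf{\epsilon}_{\theta}(\mathbf{x}_t, t) + \sigma_t \mathbf{z} $$

where:
- $\bar{\alpha}_t = \prod_{s \le t} \alpha_s$ is the cumulative product of the per-step scaling terms $\alpha_s = 1 - \beta_s$.
- $\mathbf{\epsilon}_{\theta}(\mathbf{x}_t, t)$ is the noise predicted by the model.
- $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ is fresh Gaussian noise.
- $\sigma_t$ is the standard deviation for the timestep $t$. Setting $\sigma_t = 0$ gives the deterministic DDIM sampler, which is what makes the inversion used later in this notebook possible.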
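As an illustration, the deterministic case ($\sigma_t = 0$) of the update above can be written in a few lines of pure Python. This is a sketch on plain lists, not the `DDIMScheduler.step` implementation; `abar_t` and `abar_prev` stand for the cumulative products of $(1 - \beta)$ at the current and previous timestep, and the numeric values are arbitrary.

```python
import math

def ddim_step(x_t, eps, abar_t, abar_prev):
    """Deterministic DDIM update (sigma_t = 0), applied elementwise."""
    out = []
    for x, e in zip(x_t, eps):
        # Estimate the clean sample from the current noisy sample.
        x0_pred = (x - math.sqrt(1.0 - abar_t) * e) / math.sqrt(abar_t)
        # Re-noise the estimate to the previous (less noisy) level.
        out.append(math.sqrt(abar_prev) * x0_pred + math.sqrt(1.0 - abar_prev) * e)
    return out

# Toy check: if the "model" predicts the true noise, one step lands exactly
# on the sample that the same noise would produce at the previous level.
x0 = [0.5, -1.0, 2.0]
eps = [0.3, -0.7, 1.1]
abar_t, abar_prev = 0.4, 0.7

x_t = [math.sqrt(abar_t) * a + math.sqrt(1.0 - abar_t) * e for a, e in zip(x0, eps)]
x_prev = ddim_step(x_t, eps, abar_t, abar_prev)
expected = [math.sqrt(abar_prev) * a + math.sqrt(1.0 - abar_prev) * e for a, e in zip(x0, eps)]
```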

In [None]:
scheduler_linear = DDIMScheduler(
    beta_start=0.00085,                 # Starting value of beta
    beta_end=0.012,                     # Ending value of beta
    beta_schedule="scaled_linear",      # Type of schedule for beta values
    clip_sample=False,                  # Whether to clip samples to a specified range
    set_alpha_to_one=False,             # Whether to set alpha to one at the end of the process
    
    num_train_timesteps=1000,           # Number of diffusion steps used during training
    timestep_spacing="linspace",        # Method to space out timesteps
    prediction_type="epsilon",          # Type of prediction model used in the scheduler
    trained_betas=None                  # Optional pre-trained beta values
)

scheduler = scheduler_linear

#### 4.2: SDXL PIPELINE DEFINITION

We then proceed to **load** the **pre-trained `StableDiffusionXLPipeline` model** with specific configurations to optimize for GPU memory usage and ensure efficient processing. Below is a breakdown of each parameter and its purpose:

- **pretrained_model_name_or_path**: The name or path of the pre-trained model to be loaded. In this example, we use `"stabilityai/stable-diffusion-xl-base-1.0"`, which is a pre-trained model available in the Stability AI repository.
- **torch_dtype**: Specifies the data type for the model's tensors. Here, `torch.float16` loads the weights in half precision, which helps reduce memory usage and improve computation speed.
- **variant**: Indicates which weight files to load. `"fp16"` selects the checkpoint files stored in 16-bit floating point, matching the `torch_dtype` parameter.
- **use_safetensors**: Determines whether to use the `safetensors` library for safe tensor loading. Setting this to `True` ensures safer model loading.
- **scheduler**: An instance of the scheduler to be used for the diffusion process. In this example, we use a `DDIMScheduler` instance configured for efficient sampling.
- **revision**: Specifies the model version to use. The default value is `None`, which means the latest version will be used.
- **use_auth_token**: The authentication token used for accessing private models. The default value is `None`, meaning no authentication is required.
- **cache_dir**: The directory where the downloaded model will be cached. The default value is `None`, which uses the default cache directory.
- **force_download**: Forces the model to be downloaded even if it exists locally. The default value is `False`.
- **resume_download**: Resumes a partial download if available. The default value is `False`.
- **proxies**: A dictionary of proxy servers to use. The default value is `None`, meaning no proxies are used.
- **local_files_only**: Uses only local files if set to `True`. The default value is `False`.
- **device_map**: Specifies device placement for model layers. The default value is `None`, which uses the default placement.
- **max_memory**: Specifies the maximum memory allowed for each device. The default value is `None`, meaning no specific memory limit is set.

Finally, the model is moved to the GPU for faster computations using `.to("cuda")`.

Loading the weights in half precision (`torch_dtype=torch.float16` and `variant="fp16"`) reduces memory usage and typically improves performance on GPUs. This configuration is particularly useful when working with large models and limited GPU memory.
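A quick back-of-the-envelope calculation shows why the data type matters. The sketch below only counts the memory needed to store the weights; the figure of roughly 2.6 billion parameters for the SDXL UNet is an approximate value assumed for illustration, and activations, the text encoders, and the VAE add to the total.

```python
def weight_memory_gib(num_params, bytes_per_param):
    """Memory needed to store the weights alone, in GiB."""
    return num_params * bytes_per_param / 1024 ** 3

unet_params = 2.6e9  # assumed approximate parameter count of the SDXL UNet

fp32 = weight_memory_gib(unet_params, 4)  # float32: 4 bytes per parameter
fp16 = weight_memory_gib(unet_params, 2)  # float16: 2 bytes per parameter
print(f"fp32: {fp32:.1f} GiB, fp16: {fp16:.1f} GiB")
```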

In [None]:
SDXL_Pipeline = StableDiffusionXLPipeline.from_pretrained(
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",  # The model name or path
    torch_dtype=torch.float16,            # Data type for the model's tensors
    variant="fp16",                       # Load the weight files stored in 16-bit floating point (half precision)
    use_safetensors=True,                 # Use the safetensors library for safe tensor loading
    scheduler=scheduler,                  # Scheduler instance for the diffusion process
    
    revision=None,                        # Model version to use, default is None
    use_auth_token=None,                  # Authentication token, None means no authentication
    cache_dir=None,                       # Directory to cache the downloaded model, None uses default
    force_download=False,                 # Force download even if the model exists locally
    resume_download=False,                # Resume a partial download if available
    proxies=None,                         # Dictionary of proxy servers to use, None means no proxies
    local_files_only=False,               # Use only local files if set to True
    device_map=None,                      # Device placement for model layers, None uses default placement
    max_memory=None                       # Maximum memory allowed for each device, None means no specific limit
).to("cuda")                              # Move the model to the GPU for faster computations

In [None]:
handler = sa_handler.Handler(SDXL_Pipeline)
sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
                                      share_layer_norm=False,
                                      share_attention=True,
                                      adain_queries=True,
                                      adain_keys=True,
                                      adain_values=False
                                     )

handler.register(sa_args)
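The `adain_queries` and `adain_keys` flags refer to Adaptive Instance Normalization (AdaIN), which StyleAligned applies to the queries and keys of each target image using the statistics of the reference image. The sketch below shows the operation on a single feature channel in pure Python. It is an illustration under simplifying assumptions and not the `sa_handler` code, which works on batched tensors; the function names are hypothetical.

```python
import math

def mean_std(values, eps=1e-5):
    """Mean and (eps-stabilised) standard deviation of a list of floats."""
    m = sum(values) / len(values)
    var = sum((v - m) ** 2 for v in values) / len(values)
    return m, math.sqrt(var + eps)

def adain(target, reference):
    """Shift and scale `target` so its mean/std match those of `reference`."""
    mt, st = mean_std(target)
    mr, sr = mean_std(reference)
    return [sr * (v - mt) / st + mr for v in target]

reference_feats = [2.0, 4.0, 6.0, 8.0]  # toy features of the reference image
target_feats = [-1.0, 0.0, 1.0]         # toy features of a target image

aligned = adain(target_feats, reference_feats)
m, s = mean_std(aligned)
ref_m, ref_s = mean_std(reference_feats)
print(aligned, m, s)
```

After the operation the target features keep their relative ordering but share the reference's mean and spread, which is what lets the shared attention layers compare them on equal footing.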

### 5: RUNNING STYLE-ALIGNED WITH A SET OF PROMPTS WITHOUT REFERENCE IMAGE

TO RUN IF YOU HAVE ENOUGH GPU RAM 

In [None]:
# run StyleAligned

sets_of_prompts = [
  "a toy train. macro photo. 3d game asset",
  "a toy airplane. macro photo. 3d game asset",
  "a toy bicycle. macro photo. 3d game asset",
  "a toy car. macro photo. 3d game asset",
  "a toy boat. macro photo. 3d game asset",
]
images = SDXL_Pipeline(sets_of_prompts).images
mediapy.show_images(images)

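TO RUN IF YOU DON'T HAVE ENOUGH GPU RAM

Note: attention is shared only among images generated in the same batch, so generating one prompt per call saves memory but the resulting images are not style-aligned with each other.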

In [None]:
# run StyleAligned
sets_of_prompts = [
  "a toy train. macro photo. 3d game asset",
  "a toy airplane. macro photo. 3d game asset",
  "a toy bicycle. macro photo. 3d game asset",
  "a toy car. macro photo. 3d game asset",
  "a toy boat. macro photo. 3d game asset",
]
# sets_of_prompts = [
#   "a hot air balloon, simple wooden statue",
#   "a friendly robot, simple wooden statue",
#   "a bull, simple wooden statue",
# ]
images = []
for prompt in sets_of_prompts:
    # Generate image for each prompt individually
    image = SDXL_Pipeline([prompt]).images[0]
    images.append(image)
    # Clear CUDA cache to free memory
    torch.cuda.empty_cache()
    
    # Print Memory summary
    # print(torch.cuda.memory_summary(device=None, abbreviated=False))
    
mediapy.show_images(images)

### 6: STYLE-ALIGNED WITH REFERENCE IMAGE 

#### 6.1: LOADING AND INVERTING REFERENCE IMAGE  
Load a reference image and perform the inversion process to extract latent representations.

In [None]:
src_style = "medieval painting"
src_prompt = f'Man laying in a bed, {src_style}.'
image_path = './imgs/medieval-bed.jpeg'

num_inference_steps = 50
x0 = np.array(load_image(image_path).resize((1024, 1024)))
# The last argument (2) is the guidance scale used during inversion
zts = inversion.ddim_inversion(SDXL_Pipeline, x0, src_prompt, num_inference_steps, 2)
mediapy.show_image(x0, title="input reference image", height=256)
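The idea behind the inversion cell above is that the deterministic DDIM update can be run in the opposite direction, from a clean latent towards noise, and the intermediate latents (here `zts`) can be stored. The pure-Python sketch below illustrates this on a scalar. It is not the repository's `inversion.ddim_inversion`; `toy_eps` is a hypothetical stand-in for the UNet, and the `abars` values are arbitrary.

```python
import math

def ddim_move(x, eps, abar_from, abar_to):
    """Deterministic DDIM map between two noise levels for a scalar x."""
    x0_pred = (x - math.sqrt(1.0 - abar_from) * eps) / math.sqrt(abar_from)
    return math.sqrt(abar_to) * x0_pred + math.sqrt(1.0 - abar_to) * eps

def toy_eps(x, abar):
    """Hypothetical stand-in for the UNet noise prediction (a constant)."""
    return 0.25

abars = [0.99, 0.9, 0.7, 0.4, 0.1]  # toy alpha_bar values, low noise to high noise

# Inversion: walk from the clean sample towards noise, storing every latent.
z = 1.5
trajectory = [z]
for a_from, a_to in zip(abars[:-1], abars[1:]):
    z = ddim_move(z, toy_eps(z, a_from), a_from, a_to)
    trajectory.append(z)

# Sampling: start from the last latent and walk back towards the clean sample.
x = trajectory[-1]
for a_from, a_to in zip(reversed(abars[1:]), reversed(abars[:-1])):
    x = ddim_move(x, toy_eps(x, a_from), a_from, a_to)

print(trajectory[0], x)
```

With this constant toy predictor the round trip is exact. With a real UNet the prediction depends on the latent, so the reconstruction is only approximate, which is one reason to keep the whole stored trajectory rather than only its final latent.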

In [None]:
prompts = [
    src_prompt,
    "A man working on a laptop",
    "A man eats pizza",
    "A woman playing the saxophone",
]

# Some parameters you can adjust to control fidelity to the reference
shared_score_shift = np.log(2)  # higher value induces higher fidelity, set 0 for no shift
shared_score_scale = 1.0  # higher value induces higher fidelity, set 1 for no rescale

# For very famous images, consider suppressing attention to the reference; here is a configuration example:
# shared_score_shift = np.log(1)
# shared_score_scale = 0.5

for i in range(1, len(prompts)):
    prompts[i] = f'{prompts[i]}, {src_style}.'

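# Alias: the rest of this cell refers to the SDXL pipeline loaded above as `pipeline`
pipeline = SDXL_Pipeline
handler = sa_handler.Handler(pipeline)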
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=True, 
    share_layer_norm=True, 
    share_attention=True,
    adain_queries=True, 
    adain_keys=True, 
    adain_values=False,
    shared_score_shift=shared_score_shift, 
    shared_score_scale=shared_score_scale)
handler.register(sa_args)

zT, inversion_callback = inversion.make_inversion_callback(zts, offset=5)

g_cpu = torch.Generator(device='cpu')
g_cpu.manual_seed(10)

latents = torch.randn(len(prompts), 4, 128, 128, device='cpu', generator=g_cpu,
                      dtype=pipeline.unet.dtype,).to('cuda:0')
latents[0] = zT

images_a = pipeline(prompts, latents=latents,
                    callback_on_step_end=inversion_callback,
                    num_inference_steps=num_inference_steps, guidance_scale=10.0).images

handler.remove()
mediapy.show_images(images_a, titles=[p[:-(len(src_style) + 3)] for p in prompts])
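The effect of `shared_score_shift` and `shared_score_scale` in the cell above can be illustrated with a toy softmax. The sketch assumes that the scale multiplies and the shift is added to the attention logits of the reference keys before normalisation; this is our reading of the parameters and not a copy of the `sa_handler` code. The function names are hypothetical.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def shared_attention_weights(own_logits, ref_logits, shift=0.0, scale=1.0):
    """Attention weights over [own keys, reference keys] with adjusted reference logits."""
    adjusted = [scale * l + shift for l in ref_logits]
    weights = softmax(own_logits + adjusted)
    n = len(own_logits)
    return weights[:n], weights[n:]

own = [0.0, 0.0]  # toy logits towards the image's own keys
ref = [0.0, 0.0]  # toy logits towards the reference keys

_, ref_plain = shared_attention_weights(own, ref)
_, ref_shifted = shared_attention_weights(own, ref, shift=math.log(2))

ref_mass_plain = sum(ref_plain)      # share of attention on the reference, no shift
ref_mass_shifted = sum(ref_shifted)  # share of attention on the reference, shift = log(2)
print(ref_mass_plain, ref_mass_shifted)
```

A shift of `log(2)` doubles the unnormalised weight of every reference key, so in this toy case the reference's share of attention grows from one half to two thirds.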