# StyleAligned: Zero-Shot Style Alignment among a Series of Generated Images via Attention Sharing

### **Authors**: ***Borgi Alessio***, ***Danese Francesco***

### **Abstract**
In this notebook we aim to reproduce and enhance **[StyleAligned](https://arxiv.org/abs/2312.02133)**, a novel technique introduced by **Google Research**, for achieving **Style Consistency** in large-scale Text-to-Image (T2I) generative models. While current T2I models excel in creating visually compelling images from textual descriptions, they often struggle to maintain a consistent style across multiple images. Traditional methods to address this require extensive fine-tuning and manual intervention.

**StyleAligned** addresses this challenge by introducing minimal **Attention Sharing** during the **Diffusion Process**. This ensures **Style Alignment among generated images** without any optimization or fine-tuning (**Zero-Shot Inference**). The method relies on a straightforward inversion operation to apply a reference style across several generated images, while maintaining high-quality synthesis and fidelity to the provided text prompts.

### 0: SETTINGS & IMPORTS

#### 0.1: CLONE REPOSITORY AND GIT SETUP

In the following cell, we set up the code by cloning the repository and configuring Git. The cell also includes some other useful Git commands (commented out).

In [None]:
# Clone the repository
!git clone https://github.com/alessioborgi/StyleAlignedDiffModels.git

# Change directory to the cloned repository
%cd StyleAlignedDiffModels
%ls

# Set up Git configuration
!git config --global user.name "Alessio Borgi"
!git config --global user.email "alessioborgi3@gmail.com"

# Stage the changes
#!git add .

# Commit the changes
#!git commit -m "Added some content to your-file.txt"

# Push the changes (replace 'your-token' with your actual personal access token)
#!git push origin main

#### 0.2: INSTALL AND IMPORT REQUIRED LIBRARIES

Next, we install and import the required libraries.

In [None]:
# Install the required packages
!pip install -r requirements.txt > /dev/null

In [None]:
from __future__ import annotations  # must be the first statement of the cell, otherwise Python raises a SyntaxError
import torch
import einops
import mediapy
import inversion
import sa_handler
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from typing import Callable
from dataclasses import dataclass
from diffusers.utils import load_image
from torch.nn import functional as nnf
from diffusers.models import attention_processor
from diffusers import StableDiffusionXLPipeline, DDIMScheduler

### 1: UTILS IMPLEMENTATION

#### 1.1: ADAIN MODULE

The **[Adaptive Instance Normalization (AdaIN)](https://arxiv.org/abs/1703.06868)** module is essential for **StyleAligned**. It first computes the mean $\mu_x$ and standard deviation $\sigma_x$ of the input feature tensor $x$, independently for each feature map:

$$\mu_x = \frac{1}{HW} \sum_{i=1}^{H} \sum_{j=1}^{W} x_{ij} \qquad \sigma_x = \sqrt{\frac{1}{HW} \sum_{i=1}^{H} \sum_{j=1}^{W} (x_{ij} - \mu_x)^2 + \epsilon}$$

Here, $H$ and $W$ are the height and width of the feature map, and $\epsilon$ is a small constant for numerical stability. The input features $x$ are normalized with these statistics. They are then scaled and shifted with the style's standard deviation $\sigma_y$ and mean $\mu_y$, so that they match the statistics of the style features $y$.

The transformed feature tensor is given by: $$\text{AdaIN}(x, y) = \sigma_y \left( \frac{x - \mu_x}{\sigma_x} \right) + \mu_y$$

AdaIN receives a content input $x$ and a style input $y$, and simply aligns the channel-wise mean and variance of $x$ to match those of $y$.
This lets the content adopt the style's statistical properties, which enables effective style transfer at almost no computational cost.

In the StyleAligned project, this normalization is not applied to convolutional feature maps. Instead, we embed it in the self-attention layers. AdaIN normalizes the queries $Q_t$ and keys $K_t$ of each target image with the statistics of the queries $Q_r$ and keys $K_r$ of the reference image, computed over the token dimension:

$$\hat Q_t = \text{AdaIN}(Q_t, Q_r) \;\;\;\;\;\;\;\;\;\;\;\;\;\; \hat K_t = \text{AdaIN}(K_t, K_r)$$
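
Below is a minimal NumPy sketch of the equations above. It assumes features laid out as `(tokens, channels)`, as for a single attention head; the helper name `adain_np` is hypothetical. It normalizes the target per channel over the token axis and re-scales it with the reference statistics. The output's channel-wise mean and standard deviation should therefore (up to $\epsilon$) match those of the reference.

```python
import numpy as np

def adain_np(x: np.ndarray, y: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Hypothetical helper: AdaIN(x, y) for arrays of shape (tokens, channels), statistics over tokens."""
    mu_x, sigma_x = x.mean(axis=0, keepdims=True), np.sqrt(x.var(axis=0, keepdims=True) + eps)
    mu_y, sigma_y = y.mean(axis=0, keepdims=True), np.sqrt(y.var(axis=0, keepdims=True) + eps)
    return sigma_y * (x - mu_x) / sigma_x + mu_y  # normalize x, then re-scale/shift with y's statistics

rng = np.random.default_rng(0)
q_target = rng.normal(loc=3.0, scale=0.5, size=(256, 64))      # stand-in for the target queries Q_t
q_reference = rng.normal(loc=-1.0, scale=2.0, size=(256, 64))  # stand-in for the reference queries Q_r
q_hat = adain_np(q_target, q_reference)                        # \hat Q_t = AdaIN(Q_t, Q_r)

print(np.abs(q_hat.mean(axis=0) - q_reference.mean(axis=0)).max())
print(np.abs(q_hat.std(axis=0) - q_reference.std(axis=0)).max())
```

Both printed differences are close to zero. The tiny residual in the standard deviation comes from $\epsilon$.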


In [None]:
#@title Ignore
'''
feat = torch.randn(16, 10, 1, 32)
b = feat.shape[0]
feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
print(feat_style.shape)
feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
print(feat_style.shape)
feat_style.reshape(*feat.shape).shape
'''

torch.Size([2, 1, 10, 1, 32])
torch.Size([2, 8, 10, 1, 32])


torch.Size([16, 10, 1, 32])

In [None]:
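T = torch.Tensor  # Alias for the tensor type, used in type hints to increase readability

def concat_first(feat: T, dim: int = 2, scale: float = 1.0) -> T:
    # Append the (optionally rescaled) reference features along `dim`, e.g. the token dimension -2 for keys/values
    feat_style = expand_first(feat, scale=scale)
    return torch.cat((feat, feat_style), dim=dim)

def expand_first(feat: T, scale: float = 1.0) -> T:  # any shape (batch, ...), e.g. mean/std (batch, heads, 1, channels) or keys (batch, heads, tokens, dim_head)
    b = feat.shape[0]  # Extract batch size
    # With classifier-free guidance the batch is [unconditional half | conditional half]; each half starts with the reference image
    feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)  # shape: (2, 1, *feat.shape[1:])
    if scale == 1:
        feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])  # repeat each reference b // 2 times (view, no copy)
    else:
        feat_style = feat_style.repeat(1, b // 2, *([1] * (feat.dim() - 1)))  # real copy, so that it can be rescaled
        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)  # rescale only the copies given to the target images
    return feat_style.reshape(*feat.shape)  # first half gets the reference at index 0, second half the one at index b // 2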

def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:  # computes mean and std along number of tokens dimension
    feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
    feat_mean = feat.mean(dim=-2, keepdims=True)
    return feat_mean, feat_std # output shape: (batch, heads, 1, channels)

def adain(feat: T) -> T: # Input shape: (Batch, Heads, #Tokens, Channels), #Tokens is number of "pixels" in the feature map, channels = token_dim
    feat_mean, feat_std = calc_mean_std(feat)
    feat_style_mean = expand_first(feat_mean)
    feat_style_std = expand_first(feat_std)
    feat = (feat - feat_mean) / feat_std  # normalize the feature map
    feat = feat * feat_style_std + feat_style_mean  # scale and shift the feature map (reparameterization with reference stats)
    return feat
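
Why indices `0` and `b // 2`? In `diffusers`, classifier-free guidance runs the unconditional and the conditional copies of the batch together as `[uncond | cond]`, so the reference image sits at the start of each half. The NumPy sketch below uses `expand_first_np`, a hypothetical mirror of `expand_first` for `scale=1`. Each item's feature is simply its batch index, which makes the broadcast easy to read. The sketch also shows that a batch built from a single prompt pairs each image with itself.

```python
import numpy as np

def expand_first_np(feat: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy mirror of expand_first with scale=1."""
    b = feat.shape[0]
    feat_style = np.stack((feat[0], feat[b // 2]))[:, None]                 # (2, 1, ...)
    feat_style = np.broadcast_to(feat_style, (2, b // 2) + feat.shape[1:])  # (2, b // 2, ...)
    return feat_style.reshape(feat.shape)

# Three prompts, doubled by classifier-free guidance:
# [uncond_ref, uncond_1, uncond_2, cond_ref, cond_1, cond_2]
batch = np.arange(6, dtype=float).reshape(6, 1)
print(expand_first_np(batch).ravel())   # [0. 0. 0. 3. 3. 3.]

# A single prompt: [uncond_img, cond_img], so every image is paired with itself
single = np.array([[7.0], [8.0]])
print(expand_first_np(single).ravel())  # [7. 8.]
```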


In [None]:
import torch
# create an example feature tensor of shape (Batch, h, t, c)
q_test = torch.randn(16, 10, 256, 32)
query_after_adain = adain(q_test)
query_after_adain.shape

torch.Size([16, 10, 256, 32])

In [None]:
from diffusers.models import attention_processor


In [None]:
attn = attention_processor.Attention(query_dim=32, dim_head = 64, heads = 8)
# both Q and K and V are projected from the input hidden states to a dimension that matches heads * dim_head.
# example: Projected from (Batch, tokens, 32) to (Batch, tokens, 512), and then reshaped to (Batch, tokens, 8 = heads, 64).

In [None]:
attn

Attention(
  (to_q): Linear(in_features=32, out_features=512, bias=False)
  (to_k): Linear(in_features=32, out_features=512, bias=False)
  (to_v): Linear(in_features=32, out_features=512, bias=False)
  (to_out): ModuleList(
    (0): Linear(in_features=512, out_features=32, bias=True)
    (1): Dropout(p=0.0, inplace=False)
  )
)

In [None]:
# if hidden_states comes from a previous conv layer, it has the shape (batch, channels, height, width)
# therefore we reshape it to have shape (batch, height * width, channels) creating the "tokens"
# ready to be turned into queries, keys and values and be fed to the self-attention mechanism
hidden_states = torch.randn(16, 256, 32) # shape (batch, #tokens = w*h, channels)
query = attn.to_q(hidden_states) # to_q, to_k and to_v linearly projects from "channels" to "dim_head * heads" ...
key = attn.to_k(hidden_states) # ... we'll need to reshape them back after to "divide" the heads
value = attn.to_v(hidden_states)
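
The head split that `shared_call` performs further below (`view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)`) can be checked with a short NumPy sketch. It uses the same sizes as the cells above. The random `projected` array is only a stand-in for the output of `to_q`.

```python
import numpy as np

batch, tokens, heads, dim_head = 16, 256, 8, 64
projected = np.random.default_rng(0).normal(size=(batch, tokens, heads * dim_head))  # stand-in for attn.to_q(...)

# (batch, tokens, heads * dim_head) -> (batch, tokens, heads, dim_head) -> (batch, heads, tokens, dim_head)
per_head = projected.reshape(batch, tokens, heads, dim_head).transpose(0, 2, 1, 3)
print(per_head.shape)  # (16, 8, 256, 64)

# Head h works on channels h * dim_head ... (h + 1) * dim_head - 1 of the projection
h = 3
assert np.array_equal(per_head[:, h], projected[..., h * dim_head:(h + 1) * dim_head])

# Inverse operation: merge the heads back, as done before attn.to_out
merged = per_head.transpose(0, 2, 1, 3).reshape(batch, tokens, heads * dim_head)
assert np.array_equal(merged, projected)
```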

In [None]:
class DefaultAttentionProcessor(nn.Module):

    def __init__(self):
        super().__init__()
        self.processor = attention_processor.AttnProcessor2_0() # from diffusers.models import attention_processor

    def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **kwargs):
        return self.processor(attn, hidden_states, encoder_hidden_states, attention_mask)

class SharedAttentionProcessor(DefaultAttentionProcessor):

    def shifted_scaled_dot_product_attention(self, attn: attention_processor.Attention, query: T, key: T, value: T) -> T:
        logits = torch.einsum('bhqd,bhkd->bhqk', query, key) * attn.scale
        logits[:, :, :, query.shape[2]:] += self.shared_score_shift
        probs = logits.softmax(-1)
        return torch.einsum('bhqk,bhkd->bhqd', probs, value)

    def shared_call(
            self,
            attn: attention_processor.Attention,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            **kwargs
    ):

        residual = hidden_states
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) # linear layer "channels" -> "heads * dim_heads"
        key = attn.to_k(hidden_states) # same as above
        value = attn.to_v(hidden_states) # same as above
        inner_dim = key.shape[-1] # get "heads * dim_heads" value
        head_dim = inner_dim // attn.heads # infer "dim_head" by dividing for the number of heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # shape all back to (batch, heads, tokens, dim_head)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # same as above
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # same as above
        # Adaptive Normalization of Q and K (and possibly V) with the reference image's statistics
        if self.adain_queries:
            query = adain(query)
        if self.adain_keys:
            key = adain(key)
        if self.adain_values: # usually False
            value = adain(value)
        # shared attention layer: every image also attends to the reference image's keys and values
        if self.share_attention:
            key = concat_first(key, -2, scale=self.shared_score_scale)
            value = concat_first(value, -2)
            if self.shared_score_shift != 0:
                # explicit attention, so that the shift can be added to the logits of the reference tokens
                hidden_states = self.shifted_scaled_dot_product_attention(attn, query, key, value)
            else:
                hidden_states = nnf.scaled_dot_product_attention(
                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
                )
        else:
            hidden_states = nnf.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        # hidden_states = adain(hidden_states)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

    def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **kwargs):

        hidden_states = self.shared_call(attn, hidden_states, hidden_states, attention_mask, **kwargs)

        return hidden_states

    def __init__(self, style_aligned_args: 'StyleAlignedArgs'):  # string annotation: StyleAlignedArgs is defined in a later cell
        super().__init__()
        self.share_attention = style_aligned_args.share_attention
        self.adain_queries = style_aligned_args.adain_queries
        self.adain_keys = style_aligned_args.adain_keys
        self.adain_values = style_aligned_args.adain_values
        self.full_attention_share = style_aligned_args.full_attention_share
        self.shared_score_scale = style_aligned_args.shared_score_scale
        self.shared_score_shift = style_aligned_args.shared_score_shift
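
The following NumPy sketch shows what `concat_first` and `shared_score_shift` do inside `shifted_scaled_dot_product_attention`. It covers a single head of one target image and leaves out batching and classifier-free guidance; the function names are hypothetical. The target's keys and values are concatenated with the reference's, and a shift is added to the logits of the reference tokens. A shift of $\log s$ multiplies their unnormalized attention weights by $s$.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def shared_attention(q_t, k_t, v_t, k_r, v_r, shift: float = 0.0):
    """Hypothetical helper: one head, one target image attending to [own tokens | reference tokens]."""
    n_t, d = q_t.shape
    k = np.concatenate([k_t, k_r], axis=0)  # like concat_first(key, -2): (n_t + n_r, d)
    v = np.concatenate([v_t, v_r], axis=0)  # like concat_first(value, -2)
    logits = q_t @ k.T / np.sqrt(d)         # attn.scale = 1 / sqrt(dim_head)
    logits[:, n_t:] += shift                # same slice as logits[:, :, :, query.shape[2]:]
    probs = softmax(logits)
    return probs @ v, probs[:, n_t:].sum(axis=-1)  # output and attention mass on the reference tokens

rng = np.random.default_rng(0)
n, d = 16, 8
q_t, k_t, v_t, k_r, v_r = (rng.normal(size=(n, d)) for _ in range(5))
out, mass_plain = shared_attention(q_t, k_t, v_t, k_r, v_r, shift=0.0)
_, mass_shifted = shared_attention(q_t, k_t, v_t, k_r, v_r, shift=np.log(2.0))
print(out.shape, mass_plain.mean(), mass_shifted.mean())
```

If a query puts a fraction $m$ of its attention on the reference, a shift of $\log 2$ raises that fraction to $2m / (1 + m)$.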

In [None]:
from dataclasses import dataclass

@dataclass(frozen=True)
class StyleAlignedArgs:
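    share_group_norm: bool = True
    """
    Indicates whether GroupNorm layers share their statistics with the reference image (normalizing each image's features together with the reference's).
    """

    share_layer_norm: bool = True
    """
    Indicates whether LayerNorm layers share their statistics with the reference image, as above.
    """

    share_attention: bool = True
    """
    Indicates whether self-attention layers let every image also attend to the reference image's keys and values.
    """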

    adain_queries: bool = True
    """
    Indicates whether to apply AdaIN (Adaptive Instance Normalization) to the queries.
    """

    adain_keys: bool = True
    """
    Indicates whether to apply AdaIN to the keys.
    """

    adain_values: bool = False
    """
    Indicates whether to apply AdaIN to the values.
    """

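    full_attention_share: bool = False
    """
    Indicates whether all images attend to one another (full attention sharing) instead of only to the reference image.
    SharedAttentionProcessor stores this flag, but its shared_call does not use it.
    """

    shared_score_scale: float = 1.0
    """
    Multiplier applied to the reference keys seen by the other images, which rescales their attention scores (1.0 = no rescaling).
    """

    shared_score_shift: float = 0.0
    """
    Constant added to the attention scores of the reference tokens (0.0 = no shift); a shift of log(s) multiplies their unnormalized attention weights by s.
    """

    only_self_level: float = 0.0
    """
    Fraction (0 to 1) of self-attention layers that keep plain, non-shared self-attention (0.0 = share attention in every self-attention layer).
    """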
### 4: DDIM \& PIPELINE DEFINITION
We then proceed to load the **SDXL (Stable Diffusion XL)** Model and configure the **DDIM (Denoising Diffusion Implicit Models) Scheduler**. We then configure the **Pipeline**.

#### 4.1: DDIM SCHEDULER

The **DDIM Scheduler** defines the noise schedule and the update rule that the diffusion model follows to go from pure noise to a final, coherent image (or another data form). Unlike the stochastic DDPM sampler, DDIM can sample deterministically and in far fewer steps than were used in training. This determinism is also what makes the DDIM inversion of Section 6 possible.

In particular, its parameters are:
- **beta_start (float)**: Starting value of beta, the variance of the noise schedule.
- **beta_end (float)**: Ending value of beta, the variance of the noise schedule.
- **beta_schedule (str)**: The type of schedule for beta. (Possible values for `DDIMScheduler`: "linear", "scaled_linear", "squaredcos_cap_v2".)
- **clip_sample (bool)**: If True, the predicted clean sample is clipped to [-1, 1].
- **set_alpha_to_one (bool)**: At the final denoising step there is no previous cumulative alpha. If True, it is set to 1; otherwise, the cumulative alpha of timestep 0 is used.
- **num_train_timesteps (int)**: The number of diffusion steps used during training.
- **timestep_spacing (str)**: The method used to space out the inference timesteps. (Possible values: "linspace", "leading", "trailing".)
- **prediction_type (str)**: What the model predicts. (Possible values: "epsilon" for the noise, "sample" for the clean sample, "v_prediction" for the velocity.)
- **trained_betas (array-like or None)**: Optional explicit betas that, if given, override `beta_start`, `beta_end` and `beta_schedule`.

##### 4.1.1: DIFFUSION PROCESS

The diffusion process involves adding noise to the data over a series of timesteps, which is described by the forward process:

$$ q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{\alpha_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}) $$

where:
- $\beta_t$ is the variance of the noise added at step $t$, and $\alpha_t = 1 - \beta_t$ is the corresponding scaling term.

##### 4.1.2: REVERSE PROCESS

The reverse process aims to recover the data by denoising it, and is given by:

$$ p_{\theta}(\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mu_{\theta}(\mathbf{x}_t, t), \sigma_t^2 \mathbf{I}) $$

where:
- $\mu_{\theta}(\mathbf{x}_t, t)$ is the predicted mean.
- $\sigma_t$ is the standard deviation of the noise at timestep $t$.

##### 4.1.3: BETA SCHEDULE

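The beta values are scheduled over the training timesteps $t = 0, \dots, T-1$, with $T$ = `num_train_timesteps`. The schedule can be:
- **Linear**:

$$ \beta_t = \beta_{\text{start}} + \frac{t}{T-1} \left(\beta_{\text{end}} - \beta_{\text{start}}\right) $$

- **Scaled Linear** (linear in $\sqrt{\beta}$, the schedule used by Stable Diffusion and SDXL):

$$ \beta_t = \left( \sqrt{\beta_{\text{start}}} + \frac{t}{T-1} \left(\sqrt{\beta_{\text{end}}} - \sqrt{\beta_{\text{start}}}\right) \right)^2 $$

- **Sigmoid** (available in `DDPMScheduler`, not in `DDIMScheduler`), with the sigmoid evaluated on an evenly spaced grid from $-6$ to $6$:

$$ \beta_t = \beta_{\text{start}} + (\beta_{\text{end}} - \beta_{\text{start}}) \cdot \text{sigmoid}\left(-6 + \frac{12\,t}{T-1}\right) $$

- **Squared Cosine (squaredcos\_cap\_v2)**, defined through the cumulative $\bar\alpha$ and independent of `beta_start` and `beta_end`:

$$ \beta_t = \min\left(1 - \frac{\bar\alpha\left(\frac{t+1}{T}\right)}{\bar\alpha\left(\frac{t}{T}\right)},\; 0.999\right), \qquad \bar\alpha(u) = \cos^2\left(\frac{u + 0.008}{1.008} \cdot \frac{\pi}{2}\right) $$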
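
Here is a NumPy sketch of the three schedules that `DDIMScheduler` accepts. It is written to mirror the formulas above and, to the best of our understanding, how `diffusers` builds its `betas`. The helper `make_betas` is hypothetical and is not part of `diffusers`. Its defaults are the values used in the scheduler cell below.

```python
import math
import numpy as np

def make_betas(schedule: str, T: int = 1000, beta_start: float = 0.00085, beta_end: float = 0.012) -> np.ndarray:
    """Hypothetical helper: beta_t for t = 0..T-1 under the given schedule."""
    if schedule == "linear":
        return np.linspace(beta_start, beta_end, T)
    if schedule == "scaled_linear":
        # linear in sqrt(beta), then squared
        return np.linspace(beta_start ** 0.5, beta_end ** 0.5, T) ** 2
    if schedule == "squaredcos_cap_v2":
        # cosine schedule defined through alpha_bar; beta_start / beta_end are ignored
        def alpha_bar(u: float) -> float:
            return math.cos((u + 0.008) / 1.008 * math.pi / 2) ** 2
        return np.array([min(1 - alpha_bar((t + 1) / T) / alpha_bar(t / T), 0.999) for t in range(T)])
    raise ValueError(f"unknown schedule: {schedule}")

betas = make_betas("scaled_linear")
alphas_cumprod = np.cumprod(1.0 - betas)  # \bar\alpha_t, used by the DDIM update
print(betas[0], betas[-1], alphas_cumprod[-1])
```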

##### 4.1.4: INFERENCE WITH DDIM

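During inference, the DDIM update from timestep $t$ to the previous one is:

$$ \mathbf{x}_{t-1} = \sqrt{\bar\alpha_{t-1}} \left( \frac{\mathbf{x}_t - \sqrt{1 - \bar\alpha_t}\, \mathbf{\epsilon}_{\theta}(\mathbf{x}_t, t)}{\sqrt{\bar\alpha_t}} \right) + \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\, \mathbf{\epsilon}_{\theta}(\mathbf{x}_t, t) + \sigma_t \mathbf{z}, \qquad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) $$

where:
- $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$ is the cumulative product of the per-step $\alpha_s = 1 - \beta_s$.
- $\mathbf{\epsilon}_{\theta}(\mathbf{x}_t, t)$ is the noise predicted by the model. The term in parentheses is the model's estimate of the clean sample $\mathbf{x}_0$.
- $\sigma_t$ controls the stochasticity of the step. With $\sigma_t = 0$ (`eta=0`, the default in `diffusers`) the update is deterministic, which is what makes DDIM inversion possible.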
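
Below is a minimal NumPy sketch of the deterministic DDIM update ($\sigma_t = 0$) and of its inversion, which applies the same update from $t-1$ to $t$. Here `a_t` stands for $\bar\alpha_t$. The sketch assumes a noise prediction that stays constant across steps (a hypothetical stand-in `eps`). That is the idealization behind DDIM inversion: with a real UNet the prediction changes from step to step, so the reconstruction is only approximate.

```python
import numpy as np

def ddim_step(x_t, eps, a_t, a_prev):
    """Deterministic DDIM step (sigma_t = 0) from x_t to x_{t-1}; a_t, a_prev are alpha-bar values."""
    x0_pred = (x_t - np.sqrt(1.0 - a_t) * eps) / np.sqrt(a_t)  # estimate of the clean sample
    return np.sqrt(a_prev) * x0_pred + np.sqrt(1.0 - a_prev) * eps

def ddim_invert_step(x_prev, eps, a_t, a_prev):
    """DDIM inversion from x_{t-1} to x_t: the same update with the two timesteps swapped."""
    return ddim_step(x_prev, eps, a_prev, a_t)

# Scaled-linear alpha-bar schedule and 50 "linspace"-spaced timesteps, as in the scheduler cell
betas = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)
timesteps = np.linspace(0, 999, 50).round().astype(int)
a_seq = np.concatenate([[1.0], alphas_cumprod[timesteps]])  # alpha-bar = 1 for the clean sample

rng = np.random.default_rng(0)
x0 = rng.normal(size=(4, 8, 8))  # toy clean latent
eps = rng.normal(size=x0.shape)  # hypothetical stand-in for epsilon_theta, held fixed

x = x0
for i in range(len(a_seq) - 1):            # inversion: clean -> noisy
    x = ddim_invert_step(x, eps, a_seq[i + 1], a_seq[i])
x_T = x
for i in reversed(range(len(a_seq) - 1)):  # sampling: noisy -> clean
    x = ddim_step(x, eps, a_seq[i + 1], a_seq[i])
print(np.abs(x - x0).max())  # close to zero: with a fixed eps, inversion and sampling cancel out
```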

In [None]:
scheduler_linear = DDIMScheduler(
    beta_start=0.00085,                 # Starting value of beta
    beta_end=0.012,                     # Ending value of beta
    beta_schedule="scaled_linear",      # Type of schedule for beta values
    clip_sample=False,                  # Whether to clip samples to a specified range
    set_alpha_to_one=False,             # Whether to set alpha to one at the end of the process

    num_train_timesteps=1000,           # Number of diffusion steps used during training
    timestep_spacing="linspace",        # Method to space out timesteps
    prediction_type="epsilon",          # Type of prediction model used in the scheduler
    trained_betas=None                  # Optional pre-trained beta values
)

scheduler = scheduler_linear

#### 4.2: SDXL PIPELINE DEFINITION

We then proceed to **load** the **pre-trained `StableDiffusionXLPipeline` model** with specific configurations to optimize for GPU memory usage and ensure efficient processing. Below is a breakdown of each parameter and its purpose:

- **pretrained_model_name_or_path**: The name or path of the pre-trained model to be loaded. In this example, we use `"stabilityai/stable-diffusion-xl-base-1.0"`, which is a pre-trained model available in the Stability AI repository.
- **torch_dtype**: Specifies the data type for the model's tensors. Here, `torch.float16` loads the weights in half precision, which roughly halves the memory used by the weights and speeds up computation on GPUs.
- **variant**: Indicates the model variant. `"fp16"` downloads the 16-bit floating point copy of the weights, matching the `torch_dtype` parameter.
- **use_safetensors**: Determines whether to load the weights from `.safetensors` files. Unlike pickle-based checkpoints, these cannot execute arbitrary code when loaded, which makes model loading safer.
- **scheduler**: An instance of the scheduler to be used for the diffusion process. In this example, we use a `DDIMScheduler` instance configured for efficient sampling.
- **revision**: Specifies the model version to use. The default value is `None`, which means the latest version will be used.
- **use_auth_token**: The authentication token used for accessing private models. The default value is `None`, meaning no authentication is required.
- **cache_dir**: The directory where the downloaded model will be cached. The default value is `None`, which uses the default cache directory.
- **force_download**: Forces the model to be downloaded even if it exists locally. The default value is `False`.
- **resume_download**: Resumes a partial download if available. The default value is `False`.
- **proxies**: A dictionary of proxy servers to use. The default value is `None`, meaning no proxies are used.
- **local_files_only**: Uses only local files if set to `True`. The default value is `False`.
- **device_map**: Specifies device placement for model layers. The default value is `None`, which uses the default placement.
- **max_memory**: Specifies the maximum memory allowed for each device. The default value is `None`, meaning no specific memory limit is set.

Finally, the model is moved to the GPU for faster computations using `.to("cuda")`.

Loading the model in half precision (`torch_dtype=torch.float16` and `variant="fp16"`) reduces memory usage and improves performance. This configuration is particularly useful when working with large models and limited GPU memory.

In [None]:
SDXL_Pipeline = StableDiffusionXLPipeline.from_pretrained(
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",  # The model name or path
    torch_dtype=torch.float16,            # Data type for the model's tensors
    variant="fp16",                       # Model variant for 16-bit floating point precision (Half Precision)
    use_safetensors=True,                 # Use the safetensors library for safe tensor loading
    scheduler=scheduler,                  # Scheduler instance for the diffusion process

    revision=None,                        # Model version to use, default is None
    use_auth_token=None,                  # Authentication token, None means no authentication
    cache_dir=None,                       # Directory to cache the downloaded model, None uses default
    force_download=False,                 # Force download even if the model exists locally
    resume_download=False,                # Resume a partial download if available
    proxies=None,                         # Dictionary of proxy servers to use, None means no proxies
    local_files_only=False,               # Use only local files if set to True
    device_map=None,                      # Device placement for model layers, None uses default placement
    max_memory=None                       # Maximum memory allowed for each device, None means no specific limit
).to("cuda")                              # Move the model to the GPU for faster computations

In [None]:
handler = sa_handler.Handler(SDXL_Pipeline)
sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
                                      share_layer_norm=False,
                                      share_attention=True,
                                      adain_queries=True,
                                      adain_keys=True,
                                      adain_values=False
                                     )

handler.register(sa_args)

### 5: RUNNING STYLE-ALIGNED WITH A SET OF PROMPTS WITHOUT REFERENCE IMAGE

TO RUN IF YOU HAVE ENOUGH GPU RAM

In [None]:
# run StyleAligned

sets_of_prompts = [
  "a toy train. macro photo. 3d game asset",
  "a toy airplane. macro photo. 3d game asset",
  "a toy bicycle. macro photo. 3d game asset",
  "a toy car. macro photo. 3d game asset",
  "a toy boat. macro photo. 3d game asset",
]
images = SDXL_Pipeline(sets_of_prompts,).images
mediapy.show_images(images)

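TO RUN IF YOU DON'T HAVE ENOUGH GPU RAM

Note that this loop calls the pipeline with one prompt at a time, so every batch contains a single image, which then acts as its own reference. This assumes `sa_handler` selects the reference at batch indices `0` and `b // 2`, like `expand_first` above. As a result, the images are generally not style-aligned with one another.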

In [None]:
# run StyleAligned
sets_of_prompts = [
  "a toy train. macro photo. 3d game asset",
  "a toy airplane. macro photo. 3d game asset",
  "a toy bicycle. macro photo. 3d game asset",
  "a toy car. macro photo. 3d game asset",
  "a toy boat. macro photo. 3d game asset",
]
# sets_of_prompts = [
#   "a hot hair balloon, simple wooden statue",
#   "a friendly robot, simple wooden statue",
#   "a bull, simple wooden statue",
# ]
images = []
for prompt in sets_of_prompts:
    # Generate image for each prompt individually
    image = SDXL_Pipeline([prompt]).images[0]
    images.append(image)
    # Clear CUDA cache to free memory
    torch.cuda.empty_cache()

    # Print Memory summary
    # print(torch.cuda.memory_summary(device=None, abbreviated=False))

mediapy.show_images(images)

### 6: STYLE-ALIGNED WITH REFERENCE IMAGE

Load a reference image and perform the inversion process to extract latent representations.

#### 6.1: LOADING REFERENCE IMAGE & SETTING PARAMETERS

First, we **load a Reference Image** whose style we want to "copy": the newly generated images should keep the same style as the reference image, i.e., be **Style-Aligned**.
In particular, we will define 3 **Image Parameters** here:
- **reference_style**: A short description of the reference image's style, which is also appended to every new prompt.
- **reference_prompt**: This is the reference prompt describing the reference image.
- **reference_image_path**: This is the path to the reference image.

As a second step, you will set the parameters of the **Diffusion Inversion Process** and of the attention sharing. The two style-alignment parameters act on the attention scores much like **Temperature Scaling** acts on the logits of a classifier. They tune how strongly the generated images are pulled towards the reference style, versus how much freedom (diversity) they keep.

In particular, you will set the parameters relative to:
- **num_inference_steps**: The number of inference steps to be performed during the Diffusion Inversion Process.
- **guidance_scale**: The **Guidance Scale** of **classifier-free guidance**. It controls the influence of the conditioning signal (e.g., a text prompt) during image generation. The final noise prediction is $\epsilon_{\text{uncond}} + w\,(\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})$, where $w$ is the guidance scale, so $w$ balances the model's adherence to the prompt against its natural generative tendencies. In particular:
	- **High Guidance Scale (>1)**: Extrapolates beyond the conditional prediction, increasing the influence of the prompt. The model is more likely to produce features explicitly described in the prompt, leading to more detailed and specific outputs. Very high values can over-saturate the image or introduce artifacts.
	- **Guidance Scale = 1**: The prediction equals the conditional prediction alone, i.e., no extra guidance is applied. `diffusers` pipelines skip classifier-free guidance entirely for values ≤ 1.
	- **Low Guidance Scale (<1)**: In the formula, this interpolates towards the unconditional prediction, making the images less constrained by the prompt and potentially more diverse. Very low values might produce images that are too generic and not sufficiently aligned with the prompt. Since `diffusers` skips guidance for values ≤ 1, there these values behave like 1.
- **style_alignment_score_shift**: A factor $s$ whose **logarithm** is added to the attention scores of the reference image's tokens (the notebook passes `np.log(style_alignment_score_shift)` to the handler). This multiplies the unnormalized attention weights on the reference by $s$. In particular:
	- **High Values (>1)**: **Increase** the **Style Alignment**, bringing the output image closer to the reference.
	- **Value = 1**: No shift, since $\log 1 = 0$.
	- **Low Values (<1)**: **Decrease** the **Style Alignment**, moving the output image further from the reference.
- **style_alignment_score_scale**: This parameter **scales** the attention scores of the reference tokens. More in detail:
	- **High Values (>1)**: Increase the magnitude of the scores. This makes the attention to the reference sharper (more confident) and the generation less varied.
	- **Standard Value (=1)**: No rescaling.
	- **Low Values (<1)**: Shrink the scores towards zero. This flattens the attention to the reference and leaves more room for variation in the generation process.
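
The guidance scale and the two style-alignment knobs can be illustrated with plain arithmetic. The sketch below makes two assumptions. First, it uses the standard classifier-free guidance combination written above. Second, the reference scores are multiplied by the scale and offset by the logarithm of the shift, matching the `np.log(style_alignment_score_shift)` passed to the handler. Both helpers (`cfg`, `reference_attention_share`) are hypothetical.

```python
import numpy as np

def cfg(eps_uncond: np.ndarray, eps_cond: np.ndarray, guidance_scale: float) -> np.ndarray:
    """Classifier-free guidance: start from the unconditional prediction, move towards the conditional one."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

def reference_attention_share(own_logits, ref_logits, score_shift: float = 1.0, score_scale: float = 1.0) -> float:
    """Hypothetical helper: share of one query's attention that lands on the reference tokens."""
    ref = score_scale * np.asarray(ref_logits, dtype=float) + np.log(score_shift)  # rescale, then add log(shift)
    logits = np.concatenate([np.asarray(own_logits, dtype=float), ref])
    w = np.exp(logits - logits.max())  # unnormalized softmax weights
    return float(w[len(own_logits):].sum() / w.sum())

eps_u, eps_c = np.array([0.0, 1.0]), np.array([1.0, 1.0])
print(cfg(eps_u, eps_c, 1.0))   # equals eps_c: no extra guidance
print(cfg(eps_u, eps_c, 10.0))  # the difference along the first component is amplified 10x

own, ref = [0.5, 0.1, -0.3], [0.4, 0.2]
for shift in (0.5, 1.0, 2.0):
    print("shift", shift, reference_attention_share(own, ref, score_shift=shift))  # grows with the shift
for scale in (0.5, 1.0, 2.0):
    print("scale", scale, reference_attention_share(own, ref, score_scale=scale))
```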

***Special Configuration for Famous Images***

For very famous images, it might be beneficial to suppress the **Shared Attention** to the reference image to avoid overfitting or excessive influence from the reference.

In [None]:
# Set the source style, prompt and path.
reference_style = "medieval painting"
reference_prompt = f'Man laying in a bed, {reference_style}.'
reference_image_path = './imgs/medieval-bed.jpeg'

# Setting the number of inference steps in the Diffusion Inversion Process.
num_inference_steps = 50

# Setting the Guidance Scale for the Diffusion Inversion Process.
guidance_scale = 10.0

# These are some parameters you can Adjust to Control StyleAlignment to Reference Image.
style_alignment_score_shift = 2  # higher value induces higher fidelity to the reference; set 1 for no shift (np.log(1) = 0)
style_alignment_score_scale = 1.0  # higher value sharpens the attention to the reference; set 1 for no rescale

# Famous Paintings
# style_alignment_score_shift = 1
# style_alignment_score_scale = 0.5

In [None]:
# Load the reference image and resize it to 1024x1024 pixels.
ref_image = np.array(load_image(reference_image_path).resize((1024, 1024)))

# Display the output image.
mediapy.show_image(ref_image, title="Reference Image for Style Alignment", height=256)

#### 6.2: DIFFUSION INVERSION PROCESS


In this section, we run the **Diffusion Inversion Process** in order to **align the style of generated images with a reference style**. This process involves several key steps:


1. **Set Prompts**:
   - We define a set of prompts for generating images. The first prompt refers to the reference image, while the subsequent prompts are used to generate new images.
   - The reference style is appended to each of the subsequent prompts to ensure the generated images adhere to the desired style.

2. **Configure Style Alignment Handler**:
   - We initialize a handler for the Style Aligned (SA) model and configure it using the `StyleAlignedArgs` parameters. These parameters control various aspects of the style alignment process, such as normalization and attention mechanisms.

3. **Run Diffusion Inversion**:
   - We execute the DDIM inversion process to map the reference image to its latent representation. This inversion allows us to extract latent features that can be used to guide the generation of new images in the desired style.

4. **Generate Images**:
   - Using the latent representation obtained from the inversion process, we generate new images based on the defined prompts. We sample random latent vectors of shape (number_of_prompts, 4, 128, 128) from a normal distribution. SDXL's VAE encodes a 1024×1024 image into 4 latent channels, downsampled by a factor of 8, hence 128×128. A seeded generator makes the random values reproducible, and we also ensure that the latent vectors have the same data type as the model's UNet.
   After this, we set the first latent vector to the one extracted from the reference image. The first image of the batch then follows the reference image and serves as the style source for the others.
   The latent features of the reference image are combined with the prompts to produce images that are stylistically aligned with the reference image.

5. **Display Results**:
   - Finally, we display the generated images to visualize the effect of the style alignment.

This process demonstrates how to leverage the power of diffusion models and inversion techniques to generate images with consistent and coherent styles, guided by a reference image.
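
Here is a NumPy sketch of the latent setup in step 4. `ref_trajectory` is a hypothetical stand-in for the latents produced by the inversion. `pin_reference` only illustrates a step-end callback that keeps the first latent on the reference trajectory. It is an assumption about the general pattern, not the code of the `inversion` module.

```python
import numpy as np

height = width = 1024   # resolution of the reference image and of the generated images
vae_scale_factor = 8    # SDXL's VAE downsamples each spatial side by 8...
latent_channels = 4     # ...into 4 latent channels
n_prompts = 4           # reference prompt + 3 new prompts

rng = np.random.default_rng(10)  # seeded generator -> reproducible initial latents
latents = rng.standard_normal(
    (n_prompts, latent_channels, height // vae_scale_factor, width // vae_scale_factor)
).astype(np.float16)             # match the UNet's dtype (float16 here)

# Hypothetical stand-in for the inverted reference trajectory, one latent per step,
# ordered from the noisiest latent (index 0) towards the clean one.
num_inference_steps = 50
ref_trajectory = rng.standard_normal((num_inference_steps + 1,) + latents.shape[1:]).astype(np.float16)

latents[0] = ref_trajectory[0]   # the reference slot starts from its inverted latent

def pin_reference(step: int, latents: np.ndarray) -> np.ndarray:
    """Hypothetical step-end callback: overwrite slot 0 with the next latent of the reference trajectory."""
    latents[0] = ref_trajectory[step + 1]
    return latents

print(latents.shape)  # (4, 4, 128, 128)
```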

In [None]:
# Set of prompts to generate images for. The first refers to the Reference Image. The other to generate images.
prompts = [
    reference_prompt,
    "A man working on a laptop",
    "A man eats pizza",
    "A woman playing the saxophone",
]

# Append the reference style to each of subsequent prompts for generating images with the same Style.
for i in range(1, len(prompts)):
    prompts[i] = f'{prompts[i]}, {reference_style}.'

# Configure the StyleAligned Handler using the StyleAlignedArgs.
handler = sa_handler.Handler(SDXL_Pipeline)
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=True,
    share_layer_norm=True,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
    style_alignment_score_shift=np.log(style_alignment_score_shift),
    style_alignment_score_scale=style_alignment_score_scale)
handler.register(sa_args)

In [None]:
# Execute the Diffusion Inversion Process to map the reference image to its latent representation.
DDIM_inv_result = inversion.ddim_inversion(SDXL_Pipeline, ref_image, reference_prompt, num_inference_steps, 2)

# Extract the latent representation from the Diffusion Inversion Result that can be used to guide the generation of new images in the desired style.
latent_vector_ref_img, inversion_callback = inversion.make_inversion_callback(DDIM_inv_result, offset=5)

# Create a Random Number Generator on the CPU.
rand_gen = torch.Generator(device='cpu').manual_seed(10)

# Sample one random initial latent per prompt (reproducible thanks to the seeded generator).
latents = torch.randn(len(prompts), 4, 128, 128,            # Random Latent Vectors shape
                      device='cpu',                         # Latent Vectors on CPU.
                      generator=rand_gen,                   # Random Number Generator.
                      dtype=SDXL_Pipeline.unet.dtype,).to('cuda:0') # Data Type of the Latent Vectors (same as required by the model's UNet).

# Set the first latent vector to the latent representation of the reference image extracted before.
latents[0] = latent_vector_ref_img

# Generate the images using the provided prompts and the latent vectors.
images_a = SDXL_Pipeline(
    prompts,                                 # Prompts to generate images for.
    latents=latents,                         # Latent Vectors to guide the generation of images.
    callback_on_step_end=inversion_callback, # Callback to update the latent vectors during the generation process.
    num_inference_steps=num_inference_steps, # Number of Inference Steps to generate the images.
    guidance_scale=guidance_scale).images    # Guidance Scale controlling the influence of the text prompts on the generated images.

# Display the generated images.
handler.remove()
mediapy.show_images(images_a, titles=[p[:-(len(reference_style) + 3)] for p in prompts])