In [None]:
# Clone the project repository into the current working directory.
!git clone https://github.com/alessioborgi/StyleAlignedDiffModels.git

# Move into the cloned repository and list its contents, so that the relative paths
# used later (requirements.txt, ./imgs/...) resolve from the repository root.
%cd StyleAlignedDiffModels
%ls

# Set the git identity used for commits.
# NOTE: `--global` writes to the global git configuration of the machine running this notebook.
!git config --global user.name "Alessio Borgi"
!git config --global user.email "alessioborgi3@gmail.com"

# The remaining git commands are intentionally left commented out; uncomment them to publish changes.

# Stage the changes
#!git add .

# Commit the changes
#!git commit -m "Added some content to your-file.txt"

# Push the changes to the `main` branch of `origin`.
# NOTE(review): the command itself carries no credentials; pushing presumably requires
# authentication (e.g. a personal access token) to be configured separately.
#!git push origin main

In [None]:
# Install the required packages
!pip install -r requirements.txt > /dev/null

In [None]:
import copy
import torch
import einops
import mediapy
import numpy as np
from PIL import Image
import torch.nn as nn
from tqdm import tqdm
from typing import Any
from typing import Callable
from dataclasses import dataclass
from __future__ import annotations
from diffusers.utils import load_image
from torch.nn import functional as nnf
from diffusers.models import attention_processor
from diffusers.image_processor import PipelineImageInput
from transformers import DPTImageProcessor, DPTForDepthEstimation
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, ControlNetModel, StableDiffusionXLControlNetPipeline

import sa_handler_try as sa_handler
# Alias for the tensor *type*, used in the annotations throughout this notebook.
# `torch.Tensor` is the class; `torch.tensor` is only the factory function and is not a type.
T = torch.Tensor
# Second alias of the same type, kept so that any existing reference to `TN` keeps working.
TN = T

In [None]:
# DDIM scheduler: scaled-linear beta schedule, no sample clipping, and alpha is not
# forced to one at the final step.
scheduler = DDIMScheduler(
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    clip_sample=False,
    set_alpha_to_one=False,
)

# Load the SDXL base pipeline (fp16 safetensors weights) with the scheduler above,
# then move it to the GPU.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    scheduler=scheduler,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipeline = pipeline.to("cuda")

In [None]:
# ---- Reference image: style label, prompt describing it, and its path ----
# 1) Medieval Painting (active choice)
reference_style = "medieval painting"
reference_prompt = f"Man laying in a bed, {reference_style}."
reference_image_path = "./imgs/medieval-bed.jpeg"

# 2) Cubism Painting (alternative: comment out the block above and uncomment this one)
# reference_style = "cubism painting"
# reference_prompt = f'Two men smoking water pipe, {reference_style}.'
# reference_image_path = './imgs/Picasso_Smoking_Water_Pipe.jpeg'


# ---- Sampling parameters ----
# Number of inference steps, passed both to the DDIM inversion and to the generation call.
num_inference_steps = 50

# Guidance scale passed to the generation pipeline call.
# (The DDIM inversion call further down passes its own literal guidance value instead.)
guidance_scale = 10.0

# ---- StyleAligned attention-score parameters ----
# 1) Normal Painting (active choice)
# Shift: higher value induces higher fidelity to the reference. np.log() of this value is what
# gets handed to the handler, so 1 (log 1 = 0) presumably means "no shift"; 0 would yield -inf.
style_alignment_score_shift = 2
# Scale: set 1 for no rescale.
# NOTE(review): the original comment was truncated ("higher value induces higher ...") — confirm
# what increases with this value.
style_alignment_score_scale = 1.0

# 2) Very Famous Paintings (alternative)
# style_alignment_score_shift = 1
# style_alignment_score_scale = 0.5

In [None]:
# Load the reference image and resize it to 1024x1024 pixels ...
ref_image = load_image(reference_image_path).resize((1024, 1024))
# ... then convert it to a NumPy array, the format expected by the inversion code below.
ref_image = np.array(ref_image)

# Show the reference image.
mediapy.show_image(
    ref_image,
    title="Reference Image for Style Alignment",
    height=256,
)

In [None]:
# Type alias for the per-step callback used during generation.
# Signature: (pipeline, step index, timestep, callback kwargs) -> updated callback kwargs.
Diff_Inversion_Process_Callback = Callable[
    [StableDiffusionXLPipeline, int, T, dict[str, T]],
    dict[str, T],
]

In [None]:
def prompt_tokenizazion_and_embedding(prompt: str, tokenizer, text_encoder, device):
    """Tokenize ``prompt`` and encode it with ``text_encoder``.

    Args:
        prompt: Text to encode.
        tokenizer: Callable tokenizer exposing ``model_max_length``. It is invoked with
            max-length padding and truncation, and its result must expose ``input_ids``.
        text_encoder: Encoder called with the token ids and ``output_hidden_states=True``.
        device: Device the token ids are moved to before encoding.

    Returns:
        A ``(embeddings, pooled)`` pair: ``embeddings`` is the second-to-last entry of the
        encoder's ``hidden_states`` and ``pooled`` is element ``0`` of the encoder output.
        For an empty prompt, both are replaced by zero tensors of the same shape and dtype.
    """
    # Tokenize, padding/truncating to the tokenizer's maximum length, and keep only the ids.
    token_ids = tokenizer(
        prompt,
        padding='max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors='pt',
    ).input_ids

    # Encoding is inference-only, so gradients are not tracked.
    with torch.no_grad():
        encoder_output = text_encoder(token_ids.to(device), output_hidden_states=True)

    pooled = encoder_output[0]
    penultimate_hidden = encoder_output.hidden_states[-2]

    if prompt != '':
        return penultimate_hidden, pooled

    # Empty prompt: the encoder still runs (to obtain the shapes), but zeros are returned.
    return torch.zeros_like(penultimate_hidden), torch.zeros_like(pooled)

In [None]:
def embeddings_ensemble(model: StableDiffusionXLPipeline, prompt: str) -> tuple[dict[str, T], T]:
    """Encode ``prompt`` with both text encoders of the pipeline and build the extra conditioning.

    Args:
        model: Pipeline providing ``tokenizer``/``text_encoder``, ``tokenizer_2``/``text_encoder_2``
            and ``_get_add_time_ids``.
        prompt: Text to encode.

    Returns:
        ``(conditioning_kwargs, prompt_embeddings)``. ``prompt_embeddings`` is the concatenation,
        along the last dimension, of the hidden-state embeddings of both encoders.
        ``conditioning_kwargs`` maps ``"text_embeds"`` to the pooled embedding of the second
        encoder and ``"time_ids"`` to the additional time ids.
    """
    device = model._execution_device

    # One pass per (tokenizer, text encoder) pair; the first encoder's pooled output is not used.
    hidden_1, _unused_pooled_1 = prompt_tokenizazion_and_embedding(
        prompt, model.tokenizer, model.text_encoder, device
    )
    hidden_2, pooled_2 = prompt_tokenizazion_and_embedding(
        prompt, model.tokenizer_2, model.text_encoder_2, device
    )

    # Time ids are built for fixed sizes in float16.
    # NOTE(review): the positional arguments are presumably (original size, crop top-left,
    # target size, dtype, projection dim) — confirm against the installed diffusers version.
    projection_dim = model.text_encoder_2.config.projection_dim
    time_ids = model._get_add_time_ids(
        (1024, 1024), (0, 0), (1024, 1024), torch.float16, projection_dim
    ).to(device)

    return (
        {"text_embeds": pooled_2, "time_ids": time_ids},
        torch.cat((hidden_1, hidden_2), dim=-1),
    )

In [None]:
def embeddings_ensemble_with_neg_conditioning(model: StableDiffusionXLPipeline, prompt: str) -> tuple[dict[str, T], T]:
    """Build batched (unconditional + conditional) embeddings for ``prompt``.

    The prompt and the empty prompt ``""`` are both encoded with ``embeddings_ensemble``;
    every resulting tensor is concatenated along the first dimension with the
    empty-prompt (unconditional) entry placed first.

    Returns:
        ``(conditioning_kwargs, prompt_embeddings)`` where ``conditioning_kwargs`` has the keys
        ``"text_embeds"`` and ``"time_ids"``.
    """
    cond_kwargs, cond_embeddings = embeddings_ensemble(model, prompt)
    uncond_kwargs, uncond_embeddings = embeddings_ensemble(model, "")

    # Unconditional first, conditional second — the same order in every tensor.
    merged_kwargs = {
        key: torch.cat((uncond_kwargs[key], cond_kwargs[key]))
        for key in ("text_embeds", "time_ids")
    }
    merged_embeddings = torch.cat((uncond_embeddings, cond_embeddings))

    return merged_kwargs, merged_embeddings

In [None]:
def image_encoding(model: StableDiffusionXLPipeline, image: np.ndarray) -> T:
    """Encode an image into scaled VAE latents.

    The VAE is temporarily switched to float32 for the encoding and is always switched
    back to the dtype it had on entry, even if the encoding raises.

    Args:
        model: Pipeline whose ``vae`` is used for the encoding.
        image: Image array whose last axis is moved to the front before encoding
            (i.e. laid out as height x width x channels). Values are divided by 255,
            so they are assumed to lie in [0, 255].

    Returns:
        The mean of the VAE's latent distribution multiplied by ``vae.config.scaling_factor``.
    """
    vae = model.vae
    # Remember the dtype on entry so it can be restored instead of assuming float16.
    original_dtype = vae.dtype

    vae.to(dtype=torch.float32)
    try:
        # [0, 255] -> [0, 1] -> [-1, 1], then channels-first with a leading batch dimension.
        pixels = torch.from_numpy(image).float() / 255.
        pixels = (pixels * 2 - 1).permute(2, 0, 1).unsqueeze(0)

        latent_img = vae.encode(pixels.to(vae.device))['latent_dist'].mean * vae.config.scaling_factor
    finally:
        # Restore the caller's precision whether or not the encoding succeeded.
        vae.to(dtype=original_dtype)

    return latent_img

In [None]:
def Denoising_next_step(model: StableDiffusionXLPipeline, model_output: T, timestep: int, sample: T) -> T:
    """Advance ``sample`` by one DDIM step from the preceding timestep up to ``timestep``.

    The source timestep is ``timestep`` minus one scheduler stride
    (``num_train_timesteps // num_inference_steps``), capped at 999. When it is negative,
    the scheduler's ``final_alpha_cumprod`` is used in place of a table lookup.

    Args:
        model: Pipeline whose ``scheduler`` provides ``alphas_cumprod``, ``final_alpha_cumprod``,
            ``num_inference_steps`` and ``config.num_train_timesteps``.
        model_output: Noise prediction for ``sample``.
        timestep: Timestep the returned sample corresponds to.
        sample: Sample at the source timestep.

    Returns:
        The sample re-expressed at ``timestep``.
    """
    scheduler = model.scheduler
    stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    source_timestep = min(timestep - stride, 999)

    # Cumulative alpha products at the source and target timesteps.
    if source_timestep >= 0:
        alpha_source = scheduler.alphas_cumprod[int(source_timestep)]
    else:
        alpha_source = scheduler.final_alpha_cumprod
    alpha_target = scheduler.alphas_cumprod[int(timestep)]

    # Estimate of the clean sample implied by the noise prediction at the source timestep.
    predicted_original = (sample - (1 - alpha_source) ** 0.5 * model_output) / alpha_source ** 0.5

    # Component pointing along the predicted noise at the target timestep.
    noise_direction = (1 - alpha_target) ** 0.5 * model_output

    return alpha_target ** 0.5 * predicted_original + noise_direction

In [None]:
def Generate_Noise_Prediction(model: StableDiffusionXLPipeline, latent: T, t: T, context: T, guidance_scale: float, added_cond_kwargs: dict[str, T]):
    """Predict the noise for ``latent`` at timestep ``t`` with guidance applied.

    The latent is duplicated along the first dimension and run through ``model.unet`` once.
    The first half of the output is treated as the unconditional prediction and the second
    half as the text-conditioned one, so ``context`` and ``added_cond_kwargs`` are expected
    to be batched in that same order.

    Returns:
        ``uncond + guidance_scale * (text - uncond)``.
    """
    unet_output = model.unet(
        torch.cat([latent, latent]),
        t,
        encoder_hidden_states=context,
        added_cond_kwargs=added_cond_kwargs,
    )

    # First chunk: unconditional prediction; second chunk: text-conditioned prediction.
    uncond_noise, text_noise = unet_output["sample"].chunk(2)

    return uncond_noise + guidance_scale * (text_noise - uncond_noise)

In [None]:
def Denoising_Process(model: StableDiffusionXLPipeline, z0, prompt, guidance_scale) -> T:
    """Run the DDIM inversion loop starting from the latent ``z0``.

    The scheduler's timesteps are visited from the last entry to the first. At each step
    the noise is predicted for the *current* latent, which is then advanced by one step;
    every intermediate latent is recorded.

    Args:
        model: Pipeline whose ``scheduler`` already has its timesteps set.
        z0: Initial latent (kept as the starting entry of the recorded sequence).
        prompt: Prompt used to condition the noise prediction.
        guidance_scale: Guidance scale forwarded to the noise prediction.

    Returns:
        All recorded latents concatenated along the first dimension and flipped, so that
        index 0 holds the latent produced by the last step and the final index holds ``z0``.
    """
    # The recorded trajectory starts with the initial latent itself.
    latent_list = [z0]

    # Text embeddings and extra conditioning, batched as (unconditional, conditional).
    conditioning_unconditioning_kwargs, prompt_embedding = embeddings_ensemble_with_neg_conditioning(model, prompt)

    # Work on a detached half-precision copy so that z0 itself is left untouched.
    latent = z0.clone().detach().half()

    timesteps = model.scheduler.timesteps
    for i in tqdm(range(model.scheduler.num_inference_steps)):
        # Walk the timesteps backwards (from the last entry towards the first).
        current_timestep = timesteps[len(timesteps) - i - 1]

        noise_prediction = Generate_Noise_Prediction(model, latent, current_timestep, prompt_embedding, guidance_scale, conditioning_unconditioning_kwargs)

        # Carry the result forward: the next iteration must start from this latent,
        # otherwise every step would be computed from the initial latent again.
        latent = Denoising_next_step(model, noise_prediction, current_timestep, latent)

        latent_list.append(latent)

    return torch.cat(latent_list).flip(0)

In [None]:
def extract_latent_and_inversion(ddim_result, offset: int = 0) -> tuple[T, Diff_Inversion_Process_Callback]:
    """Pick the starting latent from an inversion result and build the matching step callback.

    Args:
        ddim_result: Indexable sequence of latents produced by the inversion.
        offset: Index of the latent to start generation from.

    Returns:
        ``(ddim_result[offset], callback)``. The callback, invoked after step ``i``, overwrites
        the first entry of the ``'latents'`` batch with ``ddim_result[max(offset + 1, i + 1)]``
        (moved to the device and dtype of the batch) and returns the updated batch.
    """

    def callback_on_step_end(pipeline: StableDiffusionXLPipeline, i: int, t: T, callback_kwargs: dict[str, T]) -> dict[str, T]:
        latents = callback_kwargs['latents']
        # Overwrite (in place) the first batch entry with the stored inversion latent for the
        # next step. The index never drops below ``offset + 1``.
        latents[0] = ddim_result[max(offset + 1, i + 1)].to(latents.device, latents.dtype)
        return {'latents': latents}

    return ddim_result[offset], callback_on_step_end

In [None]:
@torch.no_grad()
def DDIM_Inversion_Process(model: StableDiffusionXLPipeline, x0: np.ndarray, prompt: str, num_inference_steps: int, guidance_scale: float) -> T:
    """Invert the image ``x0`` into a sequence of latents; runs with gradients disabled.

    Args:
        model: Pipeline providing the VAE, the UNet, the text encoders and the scheduler.
        x0: Image array to invert.
        prompt: Prompt describing the image, used to condition the inversion.
        num_inference_steps: Number of scheduler steps to configure and run.
        guidance_scale: Guidance scale used while inverting.

    Returns:
        The latent sequence returned by ``Denoising_Process``.
    """
    # Image -> latent through the VAE.
    z0 = image_encoding(model, x0)

    # Configure the scheduler's timesteps on the same device as the latent.
    model.scheduler.set_timesteps(num_inference_steps, device=z0.device)

    # Run the inversion loop and hand back the whole latent trajectory.
    return Denoising_Process(model, z0, prompt, guidance_scale)

In [None]:
# Subjects to render in the style of the reference image.
target_subjects = [
    "A man working on a laptop",
    "A man eats pizza",
    "A woman playing on saxophone",
]

# Prompts handed to the pipeline. The first one describes the reference image itself;
# every other prompt is a subject followed by ", <reference style>." so that all prompts
# share the same style suffix.
prompts = [reference_prompt] + [f'{subject}, {reference_style}.' for subject in target_subjects]

# Configure the StyleAligned handler: wrap the pipeline, describe the sharing options, then register them.
# NOTE(review): the exact effect of each flag is defined in the project-local `sa_handler_try`
# module, which is not shown here — confirm the semantics there.
handler = sa_handler.Handler(pipeline)
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=True,
    share_layer_norm=True,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
    # The handler receives the natural log of the user-facing shift value (log(1) == 0).
    style_alignment_score_shift=np.log(style_alignment_score_shift),
    style_alignment_score_scale=style_alignment_score_scale)
handler.register(sa_args)

In [None]:
# Invert the reference image into its sequence of latents.
# The inversion uses a fixed guidance value of 2, independent of `guidance_scale` used for generation.
DDIM_inv_result = DDIM_Inversion_Process(pipeline, ref_image, reference_prompt, num_inference_steps, 2)

# Pick the starting latent for the reference image (index 5 of the inversion result) together
# with the callback that keeps re-injecting the stored reference latents during generation.
latent_vector_ref_img, inversion_callback = extract_latent_and_inversion(DDIM_inv_result, offset=5)

# Seeded CPU random number generator, so that the initial noise is reproducible.
rand_gen = torch.Generator(device='cpu').manual_seed(31)

# One random latent per prompt, sampled on the CPU in the UNet's dtype, then moved to the GPU.
latents = torch.randn(
    len(prompts), 4, 128, 128,
    device='cpu',
    generator=rand_gen,
    dtype=pipeline.unet.dtype,
).to('cuda:0')

# The first latent belongs to the reference prompt: replace it with the inverted reference latent.
latents[0] = latent_vector_ref_img

# Generate one image per prompt.
generation_output = pipeline(
    prompts,                                  # Prompts to generate images for.
    latents=latents,                          # Initial latents (index 0 = inverted reference).
    callback_on_step_end=inversion_callback,  # Re-injects the reference latents after each step.
    num_inference_steps=num_inference_steps,  # Number of inference steps.
    guidance_scale=guidance_scale,            # Guidance scale used for generation.
)
images_a = generation_output.images

# Unregister the StyleAligned handler now that generation is done.
handler.remove()

# Display the generated images, titled with each prompt minus its ", <style>." suffix
# (the 3 extra characters are the ", " separator and the trailing ".").
style_suffix_length = len(reference_style) + 3
image_titles = [p[:-style_suffix_length] for p in prompts]
mediapy.show_images(images_a, titles=image_titles)