# StyleAligned: Zero-Shot Style Alignment among a Series of Generated Images via Attention Sharing

Authors: Borgi Alessio, Danese Francesco.

In this notebook we aim to reproduce and enhance [StyleAligned](https://arxiv.org/abs/2312.02133), a technique introduced by Google Research for achieving style consistency in large-scale Text-to-Image (T2I) generative models. While current T2I models excel at creating visually compelling images from textual descriptions, they often struggle to maintain a consistent style across multiple images. Traditional methods that address this require extensive fine-tuning and manual intervention.

StyleAligned addresses this challenge by introducing minimal Attention Sharing during the diffusion process, which aligns the style of the generated images without any optimization or fine-tuning. Using a straightforward inversion operation, the method can also apply a reference style across various generated images, while maintaining high-quality synthesis and fidelity to the provided text prompts.
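
To make the idea of Attention Sharing more concrete, here is a minimal NumPy sketch of one shared self-attention step. It is an illustration under assumptions, not the code in `sa_handler`: the function names (`adain`, `shared_attention`), the tensor shapes, and the choice of treating the first image in the batch as the style reference are ours. Each image's queries and keys are first normalized toward the reference's statistics (AdaIN), and every image then attends to the reference's keys and values in addition to its own.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def adain(x, ref, eps=1e-5):
    # Shift x's per-channel mean/std (computed over tokens) to those of ref.
    mu_x, std_x = x.mean(axis=-2, keepdims=True), x.std(axis=-2, keepdims=True)
    mu_r, std_r = ref.mean(axis=-2, keepdims=True), ref.std(axis=-2, keepdims=True)
    return (x - mu_x) / (std_x + eps) * std_r + mu_r

def shared_attention(q, k, v):
    # q, k, v: (batch, tokens, dim). Assumption: image 0 is the style reference.
    dim = q.shape[-1]
    # Normalize queries and keys toward the reference's statistics.
    q = adain(q, q[:1])
    k = adain(k, k[:1])
    # Each image sees the reference's keys/values alongside its own.
    k_ref = np.broadcast_to(k[:1], k.shape)
    v_ref = np.broadcast_to(v[:1], v.shape)
    k_all = np.concatenate([k_ref, k], axis=1)
    v_all = np.concatenate([v_ref, v], axis=1)
    scores = q @ k_all.transpose(0, 2, 1) / np.sqrt(dim)
    return softmax(scores) @ v_all

# Toy example: 3 images, 4 tokens each, 8 channels.
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(3, 4, 8)) for _ in range(3))
out = shared_attention(q, k, v)
print(out.shape)
```

Because the reference's keys and values are visible to every image in the batch, the style information they carry can propagate to all outputs, while each image's own keys and values preserve its prompt-specific content.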

### 0.1: CLONE REPOSITORY AND GIT SETUP

In the following cell, we set up the code by cloning the repository and configuring Git. The commented-out commands show how to stage, commit, and push changes.

In [None]:
# Clone the repository
!git clone https://github.com/alessioborgi/StyleAlignedDiffModels.git

# Change directory to the cloned repository
%cd StyleAlignedDiffModels
%ls

# Set up Git configuration
!git config --global user.name "Alessio Borgi"
!git config --global user.email "alessioborgi3@gmail.com"

# Stage the changes
#!git add .

# Commit the changes
#!git commit -m "Added some content to your-file.txt"

# Push the changes (requires authentication, e.g. a personal access token)
#!git push origin main

### 0.2: INSTALL AND IMPORT REQUIRED LIBRARIES

We then install and import the required libraries.

In [None]:
# Install the required packages
!pip install -r requirements.txt > /dev/null

In [None]:
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import torch
import mediapy
import sa_handler

### 4.0 DDIM SCHEDULER

In [None]:
# Initialize the DDIM scheduler and the SDXL pipeline

scheduler = DDIMScheduler(
    beta_start=0.00085, 
    beta_end=0.012, 
    beta_schedule="scaled_linear", 
    clip_sample=False,
    set_alpha_to_one=False)

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", 
    torch_dtype=torch.float16, 
    variant="fp16", 
    use_safetensors=True,
    scheduler=scheduler
).to("cuda")

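# Wrap the pipeline with the StyleAligned handler
handler = sa_handler.Handler(pipeline)

# Share attention across the batch and apply AdaIN to queries and keys only
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=False,
    share_layer_norm=False,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)

# Register the StyleAligned settings on the pipeline
handler.register(sa_args)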

### 5: RUNNING STYLE-ALIGNED WITH A SET OF PROMPTS

In [None]:
# run StyleAligned

sets_of_prompts = [
  "a toy train. macro photo. 3d game asset",
  "a toy airplane. macro photo. 3d game asset",
  "a toy bicycle. macro photo. 3d game asset",
  "a toy car. macro photo. 3d game asset",
  "a toy boat. macro photo. 3d game asset",
]
images = pipeline(sets_of_prompts).images
mediapy.show_images(images)

In [None]:
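# Memory-saving variant: generate one prompt per call.
# Note: attention is shared only within a batch, so images generated this way
# are not expected to be style-aligned with each other.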
sets_of_prompts = [
  "a toy train. macro photo. 3d game asset",
  "a toy airplane. macro photo. 3d game asset",
  "a toy bicycle. macro photo. 3d game asset",
  "a toy car. macro photo. 3d game asset",
  "a toy boat. macro photo. 3d game asset",
]
# sets_of_prompts = [
#   "a hot air balloon, simple wooden statue",
#   "a friendly robot, simple wooden statue",
#   "a bull, simple wooden statue",
# ]
images = []
for prompt in sets_of_prompts:
    # Generate image for each prompt individually
    image = pipeline([prompt]).images[0]
    images.append(image)
    # Clear CUDA cache to free memory
    torch.cuda.empty_cache()
    
    # Print Memory summary
    # print(torch.cuda.memory_summary(device=None, abbreviated=False))
    
mediapy.show_images(images)