# ELLA: Equip Diffusion Models with LLM for Enhanced Semantic Alignment 

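Make sure to download the model from the Hugging Face model page `https://huggingface.co/QQGYLab/ELLA/blob/main/ella-sd1.5-tsc-t5xl.safetensors` (direct-download link: `https://huggingface.co/QQGYLab/ELLA/resolve/main/ella-sd1.5-tsc-t5xl.safetensors`)  
and save it as `/ella/models/ella_path/ella-sd1.5-tsc-t5xl.safetensors`. Note that the loading cell below reads it from `ella_path = "../ella/models/ella_path/ella-sd1.5-tsc-t5xl.safetensors"` (relative to the notebook's working directory), so update `ella_path` if you save the file elsewhere.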

### Install and Import Dependencies and Define Classes 


In [None]:
!pip install safetensors 
# NOTE(review): ella_gen.py is not shown in this notebook. It runs here, before
# torch/diffusers/torchvision/transformers/accelerate/sentencepiece are
# installed by the lines below, so if it needs any of them it will fail on a
# fresh environment. Consider moving it after the installs (or dropping it if
# the cells below replace it).
!python ella_gen.py 
!pip install torch 
!pip install diffusers 
!pip install torchvision
!pip install transformers
!pip install accelerate 
!pip install sentencepiece

In [2]:
from pathlib import Path
from typing import Any, Optional, Union
import os
import sys
import safetensors.torch
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
from torchvision.utils import save_image

from model import ELLA, T5TextEmbedder


class ELLAProxyUNet(torch.nn.Module):
    """Stand-in for a diffusers UNet that routes text embeddings through ELLA.

    The wrapped UNet's ``config``, ``dtype`` and ``device`` are copied onto the
    proxy (presumably so pipeline code that reads those attributes from
    ``pipe.unet`` keeps working).

    On every forward pass the incoming ``encoder_hidden_states`` are first
    passed, together with the current ``timestep``, through ``ella``. The
    result is then given to the real UNet as its ``encoder_hidden_states``.
    All other arguments are forwarded unchanged.

    ``flexible_max_length_workaround`` may be set to a list of per-row token
    counts. When it is set, row ``k`` of the batch is truncated to its first
    ``lengths[k]`` tokens and run through ``ella`` on its own. The per-row
    outputs are then concatenated along dim 0, which requires them to agree
    on every other dimension.
    """

    def __init__(self, ella, unet):
        super().__init__()
        self.ella = ella
        self.unet = unet
        # Mirror selected attributes of the wrapped UNet on the proxy.
        self.config = unet.config
        self.dtype = unet.dtype
        self.device = unet.device
        # None -> run ELLA on the whole batch; list[int] -> per-row lengths.
        self.flexible_max_length_workaround = None

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[dict[str, Any]] = None,
        added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """Condition the text embeddings with ELLA, then call the wrapped UNet."""
        row_lengths = self.flexible_max_length_workaround
        if row_lengths is None:
            conditioned = self.ella(encoder_hidden_states, timestep)
        else:
            # Strip each row's padding before ELLA sees it, one row at a time.
            conditioned = torch.cat(
                [
                    self.ella(encoder_hidden_states[row : row + 1, :length], timestep)
                    for row, length in enumerate(row_lengths)
                ],
                dim=0,
            )

        return self.unet(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=conditioned,
            class_labels=class_labels,
            timestep_cond=timestep_cond,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
            down_intrablock_additional_residuals=down_intrablock_additional_residuals,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=return_dict,
        )


### Define helper functions 

In [3]:

def generate_image_with_flexible_max_length(
    pipe, t5_encoder, prompt, negative_prompt=None, fixed_negative=False, output_type="pt", **pipe_kwargs
):
    """Generate images with ELLA using unpadded (natural-length) T5 prompt embeddings.

    Prompts are encoded with ``max_length=None`` instead of a fixed 256 tokens.
    The positive and negative embeddings are then zero-padded along the token
    axis to a common length so the pipeline can batch them. The real
    (unpadded) lengths are handed to the ELLA proxy through
    ``pipe.unet.flexible_max_length_workaround``, and the proxy slices each row
    back to that length before running ELLA.

    ``pipe.unet`` is expected to be an ``ELLAProxyUNet`` (see
    ``load_ella_for_pipe``).

    Args:
        pipe: diffusers pipeline whose ``unet`` is the ELLA proxy.
        t5_encoder: callable ``(list[str], max_length=...) -> Tensor`` returning
            token embeddings. Presumably shaped (batch, tokens, dim); the
            padding below relies on dim 1 being the token axis.
        prompt: a prompt string or a list of prompt strings.
        negative_prompt: ``None``/empty (use ``""``), one string shared by
            every prompt, or a list with one negative prompt per prompt.
        fixed_negative: if True, encode the negative prompt with
            ``max_length=256`` instead of its natural length.
        output_type: forwarded to the pipeline.
        **pipe_kwargs: forwarded to the pipeline (guidance_scale, generator, ...).

    Returns:
        The pipeline output's ``images``.

    Raises:
        ValueError: if a list ``negative_prompt`` does not have one entry per prompt.
    """
    device = pipe.device
    dtype = pipe.dtype
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    # Normalize the negative prompt to exactly one string per prompt.
    if not negative_prompt:
        negative_prompt = [""] * batch_size
    elif isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * batch_size
    else:
        negative_prompt = list(negative_prompt)
        if len(negative_prompt) != batch_size:
            raise ValueError(
                f"negative_prompt has {len(negative_prompt)} entries but "
                f"{batch_size} prompt(s) were given"
            )

    prompt_embeds = t5_encoder(prompt, max_length=None).to(device, dtype)
    negative_prompt_embeds = t5_encoder(
        negative_prompt, max_length=256 if fixed_negative else None
    ).to(device, dtype)

    # True (unpadded) lengths, captured before padding.
    prompt_len = prompt_embeds.size(1)
    negative_len = negative_prompt_embeds.size(1)
    max_length = max(prompt_len, negative_len)

    # Zero-pad the token axis (dim 1) so both embeddings share a shape.
    # F.pad's spec runs from the last dim backwards: (last_lo, last_hi, dim1_lo, dim1_hi).
    prompt_embeds = torch.nn.functional.pad(
        prompt_embeds, (0, 0, 0, max_length - prompt_len)
    )
    negative_prompt_embeds = torch.nn.functional.pad(
        negative_prompt_embeds, (0, 0, 0, max_length - negative_len)
    )

    # One true length per row of the batch the UNet sees. The order
    # [negatives..., positives...] matches how diffusers' StableDiffusionPipeline
    # concatenates embeddings for classifier-free guidance. This assumes CFG is
    # active (guidance_scale > 1) -- TODO confirm for other settings.
    pipe.unet.flexible_max_length_workaround = (
        [negative_len] * batch_size + [prompt_len] * batch_size
    )
    try:
        images = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            **pipe_kwargs,
            output_type=output_type,
        ).images
    finally:
        # Always clear the per-call lengths, even if the pipeline raised, so a
        # later call (e.g. the fixed-length path) never sees stale values.
        pipe.unet.flexible_max_length_workaround = None
    return images

def load_ella(filename, device, dtype):
    """Build an ``ELLA`` module, load its weights from a safetensors file, and move it.

    Args:
        filename: path to the ELLA ``.safetensors`` checkpoint.
        device: device to move the module to.
        dtype: dtype to cast the module's floating-point parameters/buffers to.

    Returns:
        The loaded ``ELLA`` module (the same object that was moved in place).
    """
    model = ELLA()
    # strict=True: checkpoint keys and model keys must match exactly.
    safetensors.torch.load_model(model, filename, strict=True)
    # nn.Module.to() moves in place and returns the module itself.
    return model.to(device, dtype=dtype)

def load_ella_for_pipe(pipe, ella):
    """Wrap ``pipe.unet`` in an ``ELLAProxyUNet`` so text embeddings pass through ``ella``.

    Idempotent: if ``pipe.unet`` is already an ``ELLAProxyUNet``, its ``ella``
    module is replaced in place instead of wrapping the proxy a second time.
    Double wrapping would mean a single ``offload_ella_for_pipe`` call (which
    unwraps one level) leaves ELLA active.
    """
    if isinstance(pipe.unet, ELLAProxyUNet):
        pipe.unet.ella = ella
        return
    pipe.unet = ELLAProxyUNet(ella, pipe.unet)

def offload_ella_for_pipe(pipe):
    """Undo ``load_ella_for_pipe`` by restoring the UNet wrapped by the ELLA proxy.

    This is a no-op when ``pipe.unet`` has no ``unet`` attribute (i.e. it is not
    wrapped), so it is safe to call defensively, e.g. from cleanup code.
    """
    inner = getattr(pipe.unet, "unet", None)
    if inner is not None:
        pipe.unet = inner

def generate_image_with_fixed_max_length(
    pipe, t5_encoder, prompt, negative_prompt=None, output_type="pt", **pipe_kwargs
):
    """Generate images with ELLA using T5 embeddings encoded at ``max_length=256``.

    Unlike the flexible-length variant, prompts and negative prompts are both
    encoded with ``max_length=256``, presumably padding/truncating each to 256
    tokens (see ``T5TextEmbedder``). No per-row lengths are passed to the
    ELLA proxy.

    Args:
        pipe: diffusers pipeline (normally with its ``unet`` wrapped by
            ``load_ella_for_pipe``).
        t5_encoder: callable ``(list[str], max_length=...) -> Tensor``.
        prompt: a prompt string or a list of prompt strings.
        negative_prompt: ``None``/empty (use ``""``), one string shared by
            every prompt, or a list with one negative prompt per prompt.
        output_type: forwarded to the pipeline.
        **pipe_kwargs: forwarded to the pipeline (guidance_scale, generator, ...).

    Returns:
        The pipeline output's ``images``.

    Raises:
        ValueError: if a list ``negative_prompt`` does not have one entry per prompt.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    # Normalize the negative prompt to exactly one string per prompt.
    if not negative_prompt:
        negative_prompt = [""] * batch_size
    elif isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * batch_size
    else:
        negative_prompt = list(negative_prompt)
        if len(negative_prompt) != batch_size:
            raise ValueError(
                f"negative_prompt has {len(negative_prompt)} entries but "
                f"{batch_size} prompt(s) were given"
            )

    prompt_embeds = t5_encoder(prompt, max_length=256).to(pipe.device, pipe.dtype)
    negative_prompt_embeds = t5_encoder(negative_prompt, max_length=256).to(
        pipe.device, pipe.dtype
    )

    return pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        **pipe_kwargs,
        output_type=output_type,
    ).images


### Load models and set up the pipeline

In [6]:

# Input/output locations -- edit these to match your setup.
save_folder = Path("output_images")
ella_path = "../ella/models/ella_path/ella-sd1.5-tsc-t5xl.safetensors"  # Update this path to the actual ELLA model path
prompt_file = "prompts.txt"  # Ensure this file exists and is correct

# Create the output directory (no error if it already exists).
save_folder.mkdir(exist_ok=True)

# Fail fast on missing inputs, before spending time loading the large models.
if not Path(ella_path).exists():
    raise FileNotFoundError(f"ELLA model file not found at {ella_path}")
if not Path(prompt_file).exists():
    raise FileNotFoundError(f"Prompt file not found at {prompt_file}")

# Stable Diffusion 1.5 in fp16 on the GPU, safety checker disabled,
# sampled with DPM-Solver++ (multistep).
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
).to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# ELLA connector (matched to the pipeline's device/dtype) and the T5 text encoder.
ella = load_ella(ella_path, pipe.device, pipe.dtype)
t5_encoder = T5TextEmbedder().to(pipe.device, dtype=torch.float16)


Loading pipeline components...: 100%|██████████| 5/5 [00:01<00:00,  4.52it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.00it/s]
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Define functions to generate and save images

In [7]:

def generate_and_save_images(
    prompts,
    negative_prompt=None,
    *,
    batch_size=1,
    size=512,
    seed=1001,
    guidance_scale=11,
    num_inference_steps=70,
):
    """Render every prompt three ways and save a side-by-side comparison grid.

    For prompt number ``i`` this writes ``save_folder / f"{i:03d}.png"``. Each
    row of the grid is one sample, showing left to right:
      1. plain Stable Diffusion (CLIP text encoder),
      2. ELLA with fixed-length (256-token) T5 embeddings,
      3. ELLA with flexible-length T5 embeddings.
    All three runs use freshly seeded generators with the same seeds.

    Relies on the module-level ``pipe``, ``ella``, ``t5_encoder`` and
    ``save_folder``.

    Args:
        prompts: iterable of prompt strings.
        negative_prompt: optional negative prompt string used for all runs.
        batch_size: images per prompt; sample ``j`` uses seed ``seed + j``.
        size: output height and width in pixels.
        seed: base seed.
        guidance_scale: classifier-free guidance scale.
        num_inference_steps: number of denoising steps.
    """

    def make_generators():
        # New generators for every run so each method starts from identical noise.
        return [
            torch.Generator(device="cuda").manual_seed(seed + j)
            for j in range(batch_size)
        ]

    common_kwargs = dict(
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        height=size,
        width=size,
    )
    # The plain pipeline needs None (not [None]) when there is no negative prompt;
    # a [None] list would be forwarded to the tokenizer as a non-string entry.
    baseline_negative = (
        None if negative_prompt is None else [negative_prompt] * batch_size
    )

    for i, prompt in enumerate(prompts):
        print(f'Generating images for prompt: {prompt}')
        prompt_list = [prompt] * batch_size

        load_ella_for_pipe(pipe, ella)
        try:
            image_flexible = generate_image_with_flexible_max_length(
                pipe,
                t5_encoder,
                prompt_list,
                negative_prompt=negative_prompt,
                generator=make_generators(),
                **common_kwargs,
            )
            image_fixed = generate_image_with_fixed_max_length(
                pipe,
                t5_encoder,
                prompt_list,
                negative_prompt=negative_prompt,
                generator=make_generators(),
                **common_kwargs,
            )
        finally:
            # Always restore the plain UNet, even if an ELLA run fails, so the
            # baseline below (and later cells) never go through ELLA by accident.
            offload_ella_for_pipe(pipe)

        image_ori = pipe(
            prompt_list,
            negative_prompt=baseline_negative,
            output_type="pt",
            generator=make_generators(),
            **common_kwargs,
        ).images

        # Interleave per sample -> rows of [original, fixed, flexible]; for a
        # batch of 1 this equals concatenating the three tensors.
        grid = torch.stack([image_ori, image_fixed, image_flexible], dim=1).flatten(0, 1)
        save_image(grid, save_folder / f"{i:03d}.png", nrow=3)


### Load prompts and generate images 

In [8]:

# Read prompts from file: one prompt per line, surrounding whitespace stripped.
# Blank lines are skipped so they don't trigger an empty-prompt generation run.
with open(prompt_file, 'r', encoding='utf-8') as f:
    prompts = [text for text in (line.strip() for line in f) if text]

# Example of a negative prompt.
# NOTE: the "((word))" / "(word:1.5)" emphasis syntax is an AUTOMATIC1111 WebUI
# convention. Neither the T5 encoder nor diffusers' standard
# StableDiffusionPipeline interprets it, so it is passed along as plain text.
# The CLIP tokenizer used by the baseline run also truncates to 77 tokens, so
# only the beginning of this list reaches the baseline model.
negative_prompt = """(((deformed))), blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), 
    fused fingers, messy drawing, broken legs censor, censored, censor_bar, multiple breasts, (mutated hands and fingers:1.5), (long body :1.3), 
    (mutation, poorly drawn :1.2), black-white, bad anatomy, liquid body, liquid tongue, disfigured, malformed, mutated, anatomical nonsense, 
    text font ui, error, malformed hands, long neck, blurred, lowers, low res, bad anatomy, bad proportions, bad shadow, uncoordinated body, 
    unnatural body, fused breasts, bad breasts, huge breasts, poorly drawn breasts, extra breasts, liquid breasts, heavy breasts, missing breasts, 
    huge haunch, huge thighs, huge calf, bad hands, fused hand, missing hand, disappearing arms, disappearing thigh, disappearing calf, disappearing legs, 
    fused ears, bad ears, poorly drawn ears, extra ears, liquid ears, heavy ears, missing ears, old photo, low res, black and white, black and white filter, 
    colorless, deformed teeth, bad teeth, poorly drawn teeth, extra teeth, liquid teeth, malformed teeth, missing teeth,"""

# Generate and save images
generate_and_save_images(prompts, negative_prompt=negative_prompt)


Generating images for prompt: Crocodile in a sweater


100%|██████████| 70/70 [00:04<00:00, 14.78it/s]
100%|██████████| 70/70 [00:04<00:00, 15.94it/s]
100%|██████████| 70/70 [00:04<00:00, 16.25it/s]


Generating images for prompt: a large, textured green crocodile lying comfortably on a patch of grass with a cute, knitted orange sweater enveloping its scaly body. Around its neck, the sweater features a whimsical pattern of blue and yellow stripes. In the background, a smooth, grey rock partially obscures the view of a small pond with lily pads floating on the surface.


100%|██████████| 70/70 [00:04<00:00, 15.59it/s]
100%|██████████| 70/70 [00:04<00:00, 15.85it/s]
100%|██████████| 70/70 [00:04<00:00, 16.12it/s]


Generating images for prompt: A red book and a yellow vase.


100%|██████████| 70/70 [00:04<00:00, 15.59it/s]
100%|██████████| 70/70 [00:04<00:00, 15.71it/s]
100%|██████████| 70/70 [00:04<00:00, 15.95it/s]


Generating images for prompt: A vivid red book with a smooth, matte cover lies next to a glossy yellow vase. The vase, with a slightly curved silhouette, stands on a dark wood table with a noticeable grain pattern. The book appears slightly worn at the edges, suggesting frequent use, while the vase holds a fresh array of multicolored wildflowers.


100%|██████████| 70/70 [00:04<00:00, 15.39it/s]
100%|██████████| 70/70 [00:04<00:00, 15.54it/s]
100%|██████████| 70/70 [00:04<00:00, 15.79it/s]


Generating images for prompt: a racoon holding a shiny red apple over its head


100%|██████████| 70/70 [00:04<00:00, 15.29it/s]
100%|██████████| 70/70 [00:04<00:00, 15.47it/s]
100%|██████████| 70/70 [00:04<00:00, 15.69it/s]


Generating images for prompt: a mischievous raccoon standing on its hind legs, holding a bright red apple aloft in its furry paws. the apple shines brightly against the backdrop of a dense forest, with leaves rustling in the gentle breeze. a few scattered rocks can be seen on the ground beneath the raccoon's feet, while a gnarled tree trunk stands nearby.


100%|██████████| 70/70 [00:04<00:00, 15.15it/s]
100%|██████████| 70/70 [00:04<00:00, 15.30it/s]
100%|██████████| 70/70 [00:04<00:00, 15.62it/s]


Generating images for prompt: a chinese man wearing a white shirt and a checkered headscarf, holds a large falcon near his shoulder. the falcon has dark feathers with a distinctive beak. the background consists of a clear sky and a fence, suggesting an outdoor setting, possibly a desert or arid region


100%|██████████| 70/70 [00:04<00:00, 14.84it/s]
100%|██████████| 70/70 [00:04<00:00, 15.40it/s]
100%|██████████| 70/70 [00:04<00:00, 15.70it/s]


Generating images for prompt: A close-up photo of a wombat wearing a red backpack and raising both arms in the air. Mount Rushmore is in the background


100%|██████████| 70/70 [00:04<00:00, 15.24it/s]
100%|██████████| 70/70 [00:04<00:00, 15.49it/s]
100%|██████████| 70/70 [00:04<00:00, 15.77it/s]


Generating images for prompt: An oil painting of a man in a factory looking at a cat wearing a top hat


100%|██████████| 70/70 [00:04<00:00, 15.29it/s]
100%|██████████| 70/70 [00:04<00:00, 15.50it/s]
100%|██████████| 70/70 [00:04<00:00, 15.81it/s]


Generating images for prompt: A young woman with long brown hair, wearing a white t-shirt, holding multiple colorful shopping bags in her hands. She has a big smile on her face and her mouth is open wide, showing her teeth. Her right hand is raised in a fist, as if she is celebrating or cheering. The background is black, making the woman and her bags the focal point of the image. . A young woman with long brown hair, wearing a white t-shirt, holds multiple colorful shopping bags in her hands. She flashes a big smile on her face, her mouth open wide, showing her teeth. With her right hand raised in a fist, she celebrates or cheers. The black background makes the woman and her bags the focal point of the image. high resolution, photorealistic, golden hours


100%|██████████| 70/70 [00:04<00:00, 15.31it/s]
100%|██████████| 70/70 [00:04<00:00, 15.53it/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (161 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['the image.. a young woman with long brown hair, wearing a white t - shirt, holds multiple colorful shopping bags in her hands. she flashes a big smile on her face, her mouth open wide, showing her teeth. with her right hand raised in a fist, she celebrates or cheers. the black background makes the woman and her bags the focal point of the image. high resolution, photorealistic, golden hours']
100%|██████████| 70/70 [00:04<00:00, 15.85it/s]
