## Generate embeddings
red __ blue (__ is one of 8 relations)\
blue __ red (__ is one of 8 relations)\
A language env is needed to run this.

In [1]:
import os
import torch
import numpy as np
from os.path import join
import torch.nn as nn
import torch as th
from tqdm.notebook import tqdm, trange
from transformers import T5Tokenizer, T5EncoderModel

# Root folder of the objectRel pilot dataset (random word embeddings + positional embeddings variant).
dataset_root = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/objectRel_pilot_rndembposemb"
# Sub-folders of the dataset root. In the code visible in this notebook only
# `text_feat_dir` is read (it holds "word_embedding_dict.pt").
caption_dir, image_dir, img_feat_dir, text_feat_dir = (
    join(dataset_root, subdir)
    for subdir in (
        "captions",
        "images",
        "img_vae_features_128resolution",
        "caption_feature_wmask",
    )
)

@torch.no_grad()
def save_prompt_embeddings_randemb(tokenizer, text_encoder, validation_prompts, prompt_cache_dir="output/tmp/prompt_cache", 
                           device="cuda", max_length=20, t5_path=None, recompute=False):
    """Encode prompts with the given tokenizer/encoder and cache each result as a .pth file.

    The empty ("unconditional") prompt is always encoded and written to
    ``uncond_{max_length}token.pth``. Every prompt in ``validation_prompts`` is
    written to ``{prompt}_{max_length}token.pth``.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, max_length=..., padding="max_length",
            truncation=True, return_tensors="pt")``; the result must offer ``.to(device)``,
            ``.input_ids`` and ``.attention_mask``.
        text_encoder: Object offering ``.to(device)`` and callable as
            ``text_encoder(input_ids, attention_mask=...)``; element ``[0]`` of its
            return value is stored as the caption embedding.
        validation_prompts (list): Text prompts to encode.
        prompt_cache_dir (str): Directory the .pth files are written to (created if missing).
        device (str): Device the encoder and the token tensors are moved to.
        max_length (int): Padded/truncated token length; also part of the file names.
        t5_path (str): Only used in the log message. If None, a default path is shown.
        recompute (bool): When False, a prompt whose cache file already exists is
            loaded from disk instead of being encoded again.

    Returns:
        list: One dict per entry, unconditional prompt first. Freshly encoded
        entries have keys 'prompt', 'caption_embeds', 'emb_mask'; entries read
        back from the cache are whatever ``torch.load`` returns for that file.
    """
    if t5_path is None:
        t5_path = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/pretrained_models/t5_ckpts/t5-v1_1-xxl"

    def cache_file(stem):
        # File name encodes both the prompt (or "uncond") and the token length.
        return join(prompt_cache_dir, f'{stem}_{max_length}token.pth')

    os.makedirs(prompt_cache_dir, exist_ok=True)
    # The models are supplied by the caller; nothing is loaded here despite the message.
    print(f"Loading text encoder and tokenizer from {t5_path} ...")
    encoder = text_encoder.to(device)

    def embed(text):
        tokens = tokenizer(text, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt").to(device)
        hidden = encoder(tokens.input_ids, attention_mask=tokens.attention_mask)[0]
        return hidden, tokens.attention_mask

    # Unconditional (empty prompt) embedding is always recomputed and saved.
    uncond_embeds, uncond_mask = embed("")
    torch.save({'caption_embeds': uncond_embeds, 'emb_mask': uncond_mask, 'prompt': ''}, cache_file('uncond'))
    collected = [{'prompt': '', 'caption_embeds': uncond_embeds, 'emb_mask': uncond_mask}]

    print("Preparing Visualization prompt embeddings...")
    print(f"Saving visualizate prompt text embedding at {prompt_cache_dir}")
    for prompt in validation_prompts:
        target = cache_file(prompt)
        if not recompute and os.path.exists(target):
            # Reuse the cached entry as-is.
            collected.append(torch.load(target))
            continue
        print(f"Mapping {prompt}...")
        prompt_embeds, prompt_mask = embed(prompt)
        torch.save({'caption_embeds': prompt_embeds, 'emb_mask': prompt_mask, 'prompt': prompt}, target)
        collected.append({'prompt': prompt, 'caption_embeds': prompt_embeds, 'emb_mask': prompt_mask})
    print("Done!")
    # Release cached GPU memory; the caller's tokenizer/encoder objects stay alive.
    torch.cuda.empty_cache()
    return collected


def get_positional_encodings(seq_len, d_model, device='cpu'):
    """
    Generate sinusoidal positional encodings for a sequence.

    Even columns hold ``sin(pos * w_k)`` and odd columns hold ``cos(pos * w_k)``
    with ``w_k = 10000 ** (-2k / d_model)``. Odd ``d_model`` is supported: the
    last (even-indexed) column then carries a sine with no matching cosine.

    Args:
        seq_len (int): Length of the sequence.
        d_model (int): Dimension of the model (embedding size).
        device (str): Device to place the tensor on ('cpu' or 'cuda').

    Returns:
        torch.Tensor: Positional encodings of shape (seq_len, d_model).
    """
    position = th.arange(seq_len, dtype=th.float, device=device).unsqueeze(1)
    div_term = th.exp(th.arange(0, d_model, 2, dtype=th.float, device=device) *
                         -(th.log(th.tensor(10000.0)) / d_model))
    # (seq_len, ceil(d_model / 2)) table of angles shared by the sin and cos columns.
    angles = position * div_term
    wpe = th.zeros(seq_len, d_model, device=device)
    wpe[:, 0::2] = th.sin(angles)
    # There are only floor(d_model / 2) odd columns; trimming keeps odd d_model
    # from failing with a shape mismatch and is a no-op for even d_model.
    wpe[:, 1::2] = th.cos(angles[:, : d_model // 2])
    return wpe

# Create text encoder class
class RandomEmbeddingEncoder_wPosEmb(nn.Module):
    """Text "encoder" that looks token ids up in a fixed embedding table and adds positional encodings.

    The table is indexed through ``input_ids2dict_ids`` (tokenizer id -> table row).
    ``__call__`` and ``to`` are overridden on purpose: the tensors are plain
    attributes rather than registered buffers, so ``nn.Module.to`` would not move
    them, and calling the module goes straight to ``encode`` (bypassing
    ``nn.Module`` forward hooks).
    """

    def __init__(self, embedding_dict=None, input_ids2dict_ids=None, dict_ids2input_ids=None, max_seq_len=20, embed_dim=4096, wpe_scale=1):
        """
        Args:
            embedding_dict: 2-D tensor of word embeddings (its second dimension must
                equal ``embed_dim``). If None, the table and both id mappings are read
                from ``word_embedding_dict.pt`` under ``text_feat_dir``.
            input_ids2dict_ids: Mapping from tokenizer input id to row of ``embedding_dict``.
            dict_ids2input_ids: Inverse mapping; stored but not used by this class.
            max_seq_len (int): Number of positions in the positional-encoding table.
            embed_dim (int): Embedding dimension.
            wpe_scale: Multiplier applied to the positional encodings.
        """
        super().__init__()
        if embedding_dict is None:
            # Read the checkpoint once instead of once per field.
            saved = th.load(join(text_feat_dir, "word_embedding_dict.pt"))
            embedding_dict = saved["embedding_dict"]
            input_ids2dict_ids = saved["input_ids2dict_ids"]
            dict_ids2input_ids = saved["dict_ids2input_ids"]
        self.embedding_dict = embedding_dict
        self.input_ids2dict_ids = input_ids2dict_ids
        self.dict_ids2input_ids = dict_ids2input_ids
        self.max_seq_len = max_seq_len
        self.embed_dim = embed_dim
        # Build the positional table on the same device as the embedding table so the
        # encoder also works on CPU-only hosts; `.to(device)` moves both together.
        self.wpe = get_positional_encodings(self.max_seq_len, self.embed_dim, device=self.embedding_dict.device) * wpe_scale
        assert self.wpe.shape == (self.max_seq_len, self.embed_dim)
        assert self.embed_dim == self.embedding_dict.shape[1]
        
    def __call__(self, input_ids, attention_mask=None):
        """Alias for :meth:`encode`."""
        return self.encode(input_ids, attention_mask)
    
    def encode(self, input_ids, attention_mask=None):
        """Convert input ids to embeddings.

        Args:
            input_ids: Tensor (or nested list) of tokenizer ids whose last dimension
                is the sequence; both ``(seq,)`` and ``(batch, seq)`` are accepted.
            attention_mask: Returned unchanged.

        Returns:
            tuple: ``(embeddings, attention_mask)`` where ``embeddings`` has the shape
            of ``input_ids`` plus a trailing embedding dimension.

        Raises:
            KeyError: If an id is missing from ``input_ids2dict_ids``.
        """
        if isinstance(input_ids, list):
            input_ids = th.tensor(input_ids)
        # map the input_ids to dict ids 
        table_rows = [self.input_ids2dict_ids[token_id] for token_id in input_ids.reshape(-1).tolist()]
        indices = th.tensor(table_rows, dtype=th.long, device=self.embedding_dict.device).reshape(input_ids.shape)
        embeddings = self.embedding_dict[indices]
        # add positional encoding; the sequence axis is the last axis of input_ids
        # (embeddings.shape[1] would be the embedding dim for 1-D input)
        seq_len = input_ids.shape[-1]
        embeddings = embeddings + self.wpe[:seq_len, :]
        return embeddings, attention_mask
    
    def to(self, device):
        """Move the embedding table and positional encodings to ``device``; returns self."""
        self.embedding_dict = self.embedding_dict.to(device)
        self.wpe = self.wpe.to(device)
        return self

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
T5_path = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/pretrained_models/t5_ckpts/t5-v1_1-xxl"
# Load the tokenizer once; it is the only T5 component the random-embedding encoder below uses.
tokenizer = T5Tokenizer.from_pretrained(T5_path, )#subfolder="tokenizer")
# NOTE(review): `encoder` is not referenced again in this cell (the random-embedding
# encoder is used instead). Kept in case other cells rely on it; drop it if not.
encoder = T5EncoderModel.from_pretrained(T5_path)

relations = [
    "above",
    "below",
    "to the left of",
    "to the right of",
    "to the upper left of",
    "to the upper right of",
    "to the lower left of",
    "to the lower right of",
]
# visualize prompts: every relation in both directions (red -> blue, then blue -> red)
visualize_prompts = [f"red is {relation} blue" for relation in relations] + [f"blue is {relation} red" for relation in relations]

print(visualize_prompts)

prompt_cache_dir = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/red_blue_8_position_rndembposemb"
# Random word-embedding table plus the tokenizer-id <-> table-row mappings.
rnd_encoding = th.load(join(text_feat_dir, "word_embedding_dict.pt"))
rndpos_encoder = RandomEmbeddingEncoder_wPosEmb(rnd_encoding["embedding_dict"], 
                                              rnd_encoding["input_ids2dict_ids"], 
                                              rnd_encoding["dict_ids2input_ids"], 
                                              max_seq_len=20, embed_dim=4096,
                                              wpe_scale=1/6).to("cuda")
# recompute=True: overwrite any per-prompt cache files already present in prompt_cache_dir.
caption_embeddings = save_prompt_embeddings_randemb(tokenizer, rndpos_encoder, 
    visualize_prompts, prompt_cache_dir, device="cuda", max_length=20, t5_path=T5_path, recompute=True)
for i, embedding in enumerate(caption_embeddings):
    print(f"{i}: {embedding['prompt']} | token num:{embedding['emb_mask'].sum()}")
# Also store the whole list (unconditional entry first) in a single file.
torch.save(caption_embeddings, join(prompt_cache_dir, "caption_embeddings_list.pth"))

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.23it/s]


['red is above blue', 'red is below blue', 'red is to the left of blue', 'red is to the right of blue', 'red is to the upper left of blue', 'red is to the upper right of blue', 'red is to the lower left of blue', 'red is to the lower right of blue', 'blue is above red', 'blue is below red', 'blue is to the left of red', 'blue is to the right of red', 'blue is to the upper left of red', 'blue is to the upper right of red', 'blue is to the lower left of red', 'blue is to the lower right of red']
Loading text encoder and tokenizer from /n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/pretrained_models/t5_ckpts/t5-v1_1-xxl ...
Preparing Visualization prompt embeddings...
Saving visualizate prompt text embedding at /n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/red_blue_8_position_rndembposemb
Mapping red is above blue...
Mapping red is below blue...
Mapping red is to the left of blue...
Mapping red is to the right of

  rnd_encoding = th.load(join(text_feat_dir, "word_embedding_dict.pt"))


In [8]:
# # check if the new embedding is the same as the old one. Yes.
# new_embedding = th.load(join("/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/red_blue_8_position_rndembposemb", "caption_embeddings_list.pth"))
# old_embedding = th.load(join("/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/objectRel_pilot_rndembposemb", "caption_embeddings_list.pth"))
# for i in range(len(new_embedding)):
#     print(f"{i}: {new_embedding[i]['prompt']} | token num:{new_embedding[i]['emb_mask'].sum()} | {new_embedding[i]['caption_embeds'][0, :10]}")
#     print(f"{i}: {old_embedding[i]['prompt']} | token num:{old_embedding[i]['emb_mask'].sum()} | {old_embedding[i]['caption_embeds'][0, :10]}")
#     print("-"*100)

0:  | token num:1 | tensor([[-0.0609,  0.3837,  0.2152,  ...,  0.2301,  0.0813,  0.3842],
        [ 0.3661,  0.2643,  0.2454,  ...,  0.1502,  0.1104,  0.1653],
        [ 0.3774,  0.1049,  0.2577,  ...,  0.1502,  0.1104,  0.1653],
        ...,
        [ 0.3353,  0.2999,  0.2111,  ...,  0.1502,  0.1105,  0.1653],
        [ 0.3907,  0.1500,  0.2712,  ...,  0.1502,  0.1105,  0.1653],
        [ 0.2945,  0.0224,  0.1803,  ...,  0.1502,  0.1105,  0.1653]],
       device='cuda:0')
0:  | token num:1 | tensor([[-0.0609,  0.3837,  0.2152,  ...,  0.2301,  0.0813,  0.3842],
        [ 0.3661,  0.2643,  0.2454,  ...,  0.1502,  0.1104,  0.1653],
        [ 0.3774,  0.1049,  0.2577,  ...,  0.1502,  0.1104,  0.1653],
        ...,
        [ 0.3353,  0.2999,  0.2111,  ...,  0.1502,  0.1105,  0.1653],
        [ 0.3907,  0.1500,  0.2712,  ...,  0.1502,  0.1105,  0.1653],
        [ 0.2945,  0.0224,  0.1803,  ...,  0.1502,  0.1105,  0.1653]],
       device='cuda:0')
--------------------------------------------

  new_embedding = th.load(join("/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/red_blue_8_position_rndembposemb", "caption_embeddings_list.pth"))
  old_embedding = th.load(join("/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/objectRel_pilot_rndembposemb", "caption_embeddings_list.pth"))


## get embedding using the prompt
Not used. Use Binxu's code in objectRel_cv2_annot_embed.ipynb instead.

In [None]:
import os
from os.path import join 
import torch
import sys
sys.path.append("/n/home13/xupan/Projects/DiffusionObjectRelation/DiffusionObjectRelation/PixArt-alpha")
from diffusion import IDDPM
# from diffusion.data.builder import build_dataset, build_dataloader, set_data_root
from diffusion.model.builder import build_model
from diffusion.utils.misc import set_random_seed, read_config, init_random_seed, DebugUnderflowOverflow
sys.path.append("/n/home13/xupan/Projects/DiffusionObjectRelation/DiffusionObjectRelation/utils")
from pixart_utils import state_dict_convert
from image_utils import pil_images_to_grid
from diffusers import AutoencoderKL, Transformer2DModel, PixArtAlphaPipeline, DPMSolverMultistepScheduler

# Add a new hook to get the embedding based on Binxu's code
# subclass a new pipeline from PixArtAlphaPipeline
from typing import Callable, List, Optional, Tuple, Union
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import retrieve_timesteps
from collections import defaultdict
# from diffusers.pipelines.pixart_alpha import EXAMPLE_DOC_STRING, ImagePipelineOutput
class PixArtAlphaPipeline_hookembedding(PixArtAlphaPipeline):
    """PixArtAlphaPipeline that records the inputs of its transformer blocks at one denoising step.

    During ``__call__`` forward hooks are attached to ``self.transformer.transformer_blocks``
    for exactly one timestep (chosen by ``embedding_when``). Each hook stores the block's
    first positional input as a NumPy array in ``self.embedding["blockNN"]``.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load the pipeline as the parent does and initialise the hook bookkeeping."""
        pipeline = super().from_pretrained(*args, **kwargs)
        pipeline.hook_handles = []
        pipeline.embedding = defaultdict(list)
        return pipeline

    def clear_embedding(self):
        """Discard everything recorded so far (a fresh dict is created, old ones stay valid)."""
        self.embedding = defaultdict(list)
    
    def hook_forger(self, key: str):
        """Create a forward hook that stores the block's input hidden states under ``key``."""
        def hook(module, input, output):
            hidden_states = input[0]
            # With classifier-free guidance the batch is [unconditional, conditional]
            # (see the torch.cat calls in __call__), so the prompt-conditioned
            # activations are the SECOND half. Without guidance the batch is not
            # doubled and is kept whole. Hooks installed outside __call__ assume a
            # doubled batch.
            if getattr(self, "_hook_batch_is_doubled", True):
                hidden_states = hidden_states.chunk(2)[1]
            self.embedding[key].append(hidden_states.detach().cpu().numpy())
        return hook
    
    def setup_embedding_hooks(self, embedding_layer: int = None):
        """Attach recording hooks to every transformer block, or only to block ``embedding_layer``."""
        if not hasattr(self, "hook_handles"):
            self.hook_handles = []
        for block_idx, block in enumerate(self.transformer.transformer_blocks):
            if embedding_layer is None or block_idx == embedding_layer:
                self.hook_handles.append(block.register_forward_hook(self.hook_forger(f"block{block_idx:02d}")))

    def cleanup_embedding_hooks(self):
        """Remove all hooks"""
        # getattr: the pipeline may have been built without going through from_pretrained.
        for handle in getattr(self, "hook_handles", []):
            handle.remove()
        self.hook_handles = []
    
    @torch.no_grad()
    def __call__(
        self,
        embedding_when: float = 0.6, # relative time step to hook the embedding. 0.6 means 60% of the way through the diffusion process.
        embedding_layer: int = None, # if None, hook all layers
        prompt: Union[str, List[str]] = None,
        negative_prompt: str = "",
        num_inference_steps: int = 20,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 4.5,
        num_images_per_prompt: Optional[int] = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        clean_caption: bool = True,
        use_resolution_binning: bool = True,
        max_sequence_length: int = 120,
        return_sample_pred_traj: bool = False,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """
        Run generation and record transformer-block inputs at one denoising step.

        Args:
            embedding_when (`float`, defaults to 0.6):
                Relative position of the step at which block inputs are recorded:
                step index ``int(embedding_when * num_inference_steps)``, capped at the last step.
            embedding_layer (`int`, *optional*):
                Index of the single transformer block to record. If `None`, all blocks are recorded.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass
                `prompt_embeds` instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is not greater than `1`).
            num_inference_steps (`int`, *optional*, defaults to 20):
                The number of denoising steps.
            timesteps (`List[int]`, *optional*):
                Custom timesteps forwarded to `retrieve_timesteps`.
            sigmas (`List[float]`, *optional*):
                Custom sigmas forwarded to `retrieve_timesteps`.
            guidance_scale (`float`, *optional*, defaults to 4.5):
                Classifier-free guidance weight. Guidance is enabled when `guidance_scale > 1`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            eta (`float`, *optional*, defaults to 0.0):
                Forwarded to `prepare_extra_step_kwargs`.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of torch generator(s) to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents. If not provided, they are produced by `prepare_latents`
                using the supplied `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. If not provided, they are generated from `prompt`.
            prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. If not provided, they are generated from
                `negative_prompt`.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format handed to `image_processor.postprocess`. `"latent"` skips VAE decoding
                and returns the final latents.
            return_dict (`bool`, *optional*, defaults to `True`):
                If `False`, return the tuple `(self.embedding, image)`.
            callback (`Callable`, *optional*):
                Called as `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called.
            clean_caption (`bool`, *optional*, defaults to `True`):
                Forwarded to `encode_prompt`.
            use_resolution_binning (`bool` defaults to `True`):
                Accepted for signature compatibility; resolution binning is disabled in this
                subclass, so the value is ignored.
            max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.
            return_sample_pred_traj (`bool`, defaults to `False`):
                Only consulted when `return_dict` is `True`: additionally return the per-step
                predictions, latents and timesteps.
            **kwargs: Extra keyword arguments (e.g. the deprecated `mask_feature`) are accepted and ignored.

        Returns:
            - `return_dict=False`: `(self.embedding, image)`.
            - `return_dict=True` and `return_sample_pred_traj=True`:
              `(ImagePipelineOutput, pred_traj, latents_traj, t_traj)`.
            - otherwise: `ImagePipelineOutput`; the recorded inputs stay available as `self.embedding`.
        """
        # Start every call without leftover recordings or hooks from a previous call.
        self.clear_embedding()
        self.cleanup_embedding_hooks()

        # 1. Check inputs. Raise error if not correct
        # (resolution binning is intentionally not applied; see `use_resolution_binning`)
        height = height or self.transformer.config.sample_size * self.vae_scale_factor
        width = width or self.transformer.config.sample_size * self.vae_scale_factor

        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_steps,
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
        )

        # 2. Determine the batch size from the prompt(s) or the pre-computed embeddings
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt,
            do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            clean_caption=clean_caption,
            max_sequence_length=max_sequence_length,
        )
        if do_classifier_free_guidance:
            # Batch layout from here on: [unconditional half, conditional half].
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
        # Tells the hooks whether the batch they see is doubled.
        self._hook_batch_is_doubled = do_classifier_free_guidance

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )
        # Timestep at which the block inputs are recorded. The index is capped so that
        # embedding_when=1.0 selects the last step instead of raising IndexError.
        hook_index = min(int(embedding_when * num_inference_steps), len(timesteps) - 1)
        hook_timestep = timesteps[hook_index]

        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.1 Prepare micro-conditions.
        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
        if self.transformer.config.sample_size == 128:
            resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
            aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
            resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
            aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)

            if do_classifier_free_guidance:
                resolution = torch.cat([resolution, resolution], dim=0)
                aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

            added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

        # 7. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        pred_traj = []
        latents_traj = []
        t_traj = []
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                current_timestep = t
                if not torch.is_tensor(current_timestep):
                    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                    # This would be a good case for the `match` statement (Python 3.10+)
                    is_mps = latent_model_input.device.type == "mps"
                    if isinstance(current_timestep, float):
                        dtype = torch.float32 if is_mps else torch.float64
                    else:
                        dtype = torch.int32 if is_mps else torch.int64
                    current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
                elif len(current_timestep.shape) == 0:
                    current_timestep = current_timestep[None].to(latent_model_input.device)
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                current_timestep = current_timestep.expand(latent_model_input.shape[0])

                # Hooks are only active for the forward pass of the selected timestep.
                record_this_step = bool(t == hook_timestep)
                if record_this_step:
                    self.setup_embedding_hooks(embedding_layer)

                try:
                    # predict noise model_output
                    noise_pred = self.transformer(
                        latent_model_input,
                        encoder_hidden_states=prompt_embeds,
                        encoder_attention_mask=prompt_attention_mask,
                        timestep=current_timestep,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]
                finally:
                    # Detach the hooks even if the forward pass raised. Generation
                    # continues afterwards (no early stop).
                    if record_this_step:
                        self.cleanup_embedding_hooks()

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # learned sigma: keep only the first half of the output channels
                if self.transformer.config.out_channels // 2 == latent_channels:
                    noise_pred = noise_pred.chunk(2, dim=1)[0]

                latents_traj.append(latents)
                pred_traj.append(noise_pred)
                # compute previous image: x_t -> x_t-1
                if num_inference_steps == 1:
                    # For DMD one step sampling: https://arxiv.org/abs/2311.18828
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
                else:
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                t_traj.append(t)
                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        latents_traj.append(latents)
        if not output_type == "latent":
            # Decode with this pipeline's own VAE (previously read notebook globals
            # `pipeline` / `weight_dtype`) and honour the requested output_type.
            image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)
        else:
            image = latents

        # Models are not offloaded here (maybe_free_model_hooks is not called).

        if not return_dict:
            return (self.embedding, image,)
        if return_sample_pred_traj:
            return ImagePipelineOutput(images=image), pred_traj, latents_traj, t_traj
        return ImagePipelineOutput(images=image)
    

@torch.inference_mode()
def get_embeddings(pipeline, validation_prompts, prompt_cache_dir, embedding_when=0.6, embedding_layer=None, max_length=120, weight_dtype=torch.float16,
                   num_inference_steps=14, guidance_scale=4.5, num_images_per_prompt=25, device="cuda"):
    """Run the hooked pipeline on cached prompt embeddings and collect what it records.

    For every prompt that has a cache file ``{prompt}_{max_length}token.pth`` in
    ``prompt_cache_dir``, the pipeline is called with ``return_dict=False`` and
    its two outputs are collected. Prompts without a cache file are skipped
    (a message is printed), so the returned lists can be shorter than
    ``validation_prompts``.

    Args:
        pipeline: Pipeline object offering ``.to(device)``, ``.set_progress_bar_config``
            and a call returning ``(embedding, images)`` when ``return_dict=False``.
        validation_prompts (list | None): Prompts to process; a built-in list is used if None.
        prompt_cache_dir (str): Directory holding the cached prompt embeddings,
            including ``uncond_{max_length}token.pth``.
        embedding_when (float): Forwarded to the pipeline.
        embedding_layer (int | None): Forwarded to the pipeline.
        max_length (int): Token length used in the cache file names.
        weight_dtype: Unused; kept for signature compatibility.
        num_inference_steps (int): Forwarded to the pipeline.
        guidance_scale (float): Forwarded to the pipeline.
        num_images_per_prompt (int): Forwarded to the pipeline.
        device (str): Device the pipeline and the cached tensors are moved to.

    Returns:
        tuple: ``(embeddings, images)``, two lists with one entry per processed prompt.
    """
    if validation_prompts is None:
        validation_prompts = [
            "triangle is to the upper left of square", 
            "blue triangle is to the upper left of red square", 
            "triangle is above and to the right of square", 
            "blue circle is above and to the right of blue square", 
            "triangle is to the left of square", 
            "triangle is to the left of triangle", 
            "circle is below red square",
            "red circle is to the left of blue square",
            "blue square is to the right of red circle",
            "red circle is above square",
            "triangle is above red circle",
            "red is above blue",
            "red is to the left of red",
            "blue triangle is above red triangle", 
            "blue circle is above blue square", 
        ]
    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=True)
    # No torch.Generator is created: the pipeline call below does not pass one,
    # so sampling uses the global RNG state.

    # The unconditional embedding serves as the negative prompt for every call.
    uncond_data = torch.load(f'{prompt_cache_dir}/uncond_{max_length}token.pth', map_location='cpu')
    uncond_prompt_embeds = uncond_data['caption_embeds'].to(device)
    uncond_prompt_attention_mask = uncond_data['emb_mask'].to(device)

    embeddings = []
    images = []

    for prompt in validation_prompts:
        cache_path = f'{prompt_cache_dir}/{prompt}_{max_length}token.pth'
        if not os.path.exists(cache_path):
            # Make the skip visible; otherwise results cannot be matched to prompts.
            print(f"Skipping '{prompt}': no cached embedding at {cache_path}")
            continue
        embed = torch.load(cache_path, map_location='cpu')
        caption_embs, emb_masks = embed['caption_embeds'].to(device), embed['emb_mask'].to(device)
        output = pipeline(
            embedding_when=embedding_when,
            embedding_layer=embedding_layer,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images_per_prompt,
            guidance_scale=guidance_scale,
            prompt_embeds=caption_embs,
            prompt_attention_mask=emb_masks,
            negative_prompt=None,
            negative_prompt_embeds=uncond_prompt_embeds,
            negative_prompt_attention_mask=uncond_prompt_attention_mask,
            use_resolution_binning=False, # need this for smaller images like ours. 
            return_sample_pred_traj=False,
            return_dict=False,
            output_type="pil",
        )
    
        # return_dict=False -> (recorded embeddings, images)
        embeddings.append(output[0])
        images.append(output[1])

    return embeddings, images


In [None]:
# Directory path for the embedding datasets; presumably where the datasets used for
# SAE training are stored (inferred from the folder name, TODO confirm).
embedding_dataset_dir = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/embedding_datasets_for_SAE"