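# T5 tokenization + T5 embedding

Tokenize the objectRelSingle_pilot1 captions with the T5 v1.1 XXL tokenizer, sanity-check T5 encoder embeddings, and cache T5 embeddings for a set of validation prompts.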

In [None]:
import os
import sys
from os.path import join

import numpy as np
import torch
import torch as th
import torch.nn as nn
from tqdm.notebook import tqdm, trange
from transformers import T5Tokenizer, T5EncoderModel

In [2]:
# Root of the objectRelSingle pilot dataset; each modality lives in its own subfolder.
dataset_root = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/objectRelSingle_pilot1"
caption_dir, image_dir, img_feat_dir, text_feat_dir = (
    join(dataset_root, subfolder)
    for subfolder in (
        "captions",
        "images",
        "img_vae_features_128resolution",
        "caption_feature_wmask",
    )
)

In [3]:
# Local T5 v1.1 XXL checkpoint used by PixArt; the tokenizer is reused to tokenize captions below.
# NOTE(review): `encoder` is not referenced by any later cell shown here (later cells use
# T5EmbeddingEncoder / T5Embedder instead), so this XXL load may be removable.
T5_path = join(
    "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/pretrained_models",
    "t5_ckpts",
    "t5-v1_1-xxl",
)
tokenizer = T5Tokenizer.from_pretrained(T5_path)
encoder = T5EncoderModel.from_pretrained(T5_path)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
model_max_length = 20  # token budget per caption (pad/truncate); reused for T5Embedder below
n_captions = 10000  # captions are read from caption_dir/0.txt ... caption_dir/{n_captions - 1}.txt

# Tokenize every caption with the T5 tokenizer. Each call returns input ids and an
# attention mask padded/truncated to exactly model_max_length tokens.
corpus = []  # raw caption strings, in file order
input_ids_col = []
attention_mask_col = []
for i in trange(n_captions):
    # `with` closes each file deterministically instead of relying on garbage collection.
    with open(join(caption_dir, f"{i}.txt"), encoding="utf-8") as f:
        text = f.read()
    corpus.append(text)
    text_tokens_and_mask = tokenizer(
        text,
        max_length=model_max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt'
    )
    input_ids_col.append(text_tokens_and_mask['input_ids'])
    attention_mask_col.append(text_tokens_and_mask['attention_mask'])

# Stack into (n_captions, model_max_length) tensors. The masks are kept so downstream
# encoding can ignore padding positions.
input_ids_tsr = th.cat(input_ids_col, dim=0)
attention_mask_tsr = th.cat(attention_mask_col, dim=0)

  0%|          | 0/10000 [00:00<?, ?it/s]

In [8]:
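# NOTE: commented-out draft of the prompt-embedding cache; the live version is the
# validation-prompt caching cell further below. Kept for reference only.
# import os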
# @torch.no_grad()
# def save_prompt_embeddings(tokenizer, text_encoder, validation_prompts, prompt_cache_dir="output/tmp/prompt_cache", 
#                            device="cuda", max_length=20, t5_path=None, recompute=False):
#     """Save T5 text embeddings for a list of prompts to cache directory.
    
#     Args:
#         validation_prompts (list): List of text prompts to encode
#         prompt_cache_dir (str): Directory to save embeddings
#         device (str): Device to run encoding on
#         max_length (int): Max sequence length for tokenization
#         t5_path (str): Path to T5 model. If None, uses default path
#     """
#     if t5_path is None:
#         t5_path = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/pretrained_models/t5_ckpts/t5-v1_1-xxl"
    
#     result_col = []
#     os.makedirs(prompt_cache_dir, exist_ok=True)

#     # Load models
#     print(f"Loading text encoder and tokenizer from {t5_path} ...")
#     # tokenizer = T5Tokenizer.from_pretrained(t5_path)
#     # text_encoder = T5EncoderModel.from_pretrained(t5_path).to(device)
#     # text_encoder = text_encoder.to(device)

#     # Save unconditioned embedding
#     uncond = tokenizer("", max_length=max_length, padding="max_length", truncation=True, return_tensors="pt").to(device)
#     uncond_prompt_embeds = text_encoder(uncond.input_ids, attention_mask=uncond.attention_mask)[0]
#     torch.save({'caption_embeds': uncond_prompt_embeds, 'emb_mask': uncond.attention_mask, 'prompt': ''}, 
#                join(prompt_cache_dir,f'uncond_{max_length}token.pth'))
#     result_col.append({'prompt': '', 'caption_embeds': uncond_prompt_embeds, 'emb_mask': uncond.attention_mask})

#     print("Preparing Visualization prompt embeddings...")
#     print(f"Saving visualizate prompt text embedding at {prompt_cache_dir}")
    
#     for prompt in validation_prompts:
#         if os.path.exists(join(prompt_cache_dir,f'{prompt}_{max_length}token.pth')) and not recompute:
#             result_col.append(torch.load(join(prompt_cache_dir,f'{prompt}_{max_length}token.pth')))
#             continue
#         print(f"Mapping {prompt}...")
#         caption_token = tokenizer(prompt, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt").to(device)
#         caption_emb = text_encoder(caption_token.input_ids, attention_mask=caption_token.attention_mask)[0]
#         torch.save({'caption_embeds': caption_emb, 'emb_mask': caption_token.attention_mask, 'prompt': prompt}, 
#                     join(prompt_cache_dir,f'{prompt}_{max_length}token.pth'))
#         result_col.append({'prompt': prompt, 'caption_embeds': caption_emb, 'emb_mask': caption_token.attention_mask})
#     print("Done!")
#     # garbage collection
#     del tokenizer, text_encoder
#     torch.cuda.empty_cache()
#     return result_col


In [5]:
class T5EmbeddingEncoder(nn.Module):
    """Frozen T5 encoder (loaded in bfloat16) mapping prompts or token ids to per-token embeddings.

    Calling the module, or ``encode``, returns ``(embeddings, attention_mask)``.
    ``embeddings`` is the encoder's ``last_hidden_state``, and ``attention_mask`` is the
    mask that was actually fed to the encoder.

    Args:
        model_name: name or local path passed to ``from_pretrained`` for both the
            tokenizer and the encoder.
        device: device the encoder weights are moved to after loading.
    """

    def __init__(self, model_name="t5-base", device="cuda", ):
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.encoder = T5EncoderModel.from_pretrained(model_name).to(torch.bfloat16)
        self.encoder.eval()
        self.encoder.to(device)

    @property
    def device(self):
        """Device of the encoder weights; inputs are moved here before encoding."""
        # Derived from the weights so it stays correct after nn.Module.to(...)/.cuda()/.cpu().
        # A separately stored attribute would drift, e.g. `.to(torch.float64)` used to
        # store a dtype here and then cast the token ids to float.
        return next(self.encoder.parameters()).device

    @device.setter
    def device(self, device):
        # Backward compatibility for code that assigns `obj.device = ...`: move the weights.
        self.encoder.to(device)

    def forward(self, input_ids, attention_mask=None):
        """Alias for ``encode`` so ``nn.Module.__call__`` dispatches here (and runs hooks)."""
        return self.encode(input_ids, attention_mask)

    def encode(self, input_ids, attention_mask=None):
        """Encode prompts or token ids with the frozen T5 encoder.

        Args:
            input_ids: a prompt string, a list/tuple of prompt strings, or a tensor of
                token ids (moved to ``self.device``).
            attention_mask: optional mask for tensor input. It is ignored for text input,
                where the tokenizer builds one. If None for tensor input, it is derived
                from positions that differ from ``tokenizer.pad_token_id``.

        Returns:
            ``(embeddings, attention_mask)``, with embeddings shaped
            [batch_size, seq_len, hidden_dim].

        Raises:
            ValueError: if an empty list/tuple of prompts is given.
        """
        if isinstance(input_ids, str):
            input_ids = [input_ids]
        if isinstance(input_ids, (list, tuple)) and len(input_ids) == 0:
            raise ValueError("encode() needs at least one prompt")

        if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], str):
            # Raw text prompts: pad to the longest prompt in the batch.
            tokens = self.tokenizer(input_ids, return_tensors="pt", padding=True, truncation=True)
            input_ids = tokens["input_ids"].to(self.device)
            attention_mask = tokens["attention_mask"].to(self.device)
        else:
            input_ids = input_ids.to(self.device)
            if attention_mask is None:
                # Derived from input_ids, so it is already on the right device.
                attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
            else:
                attention_mask = attention_mask.to(self.device)

        with torch.no_grad():
            outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state, attention_mask

# Smoke-test the wrapper on the first tokenized caption; the call returns
# (embeddings, attention_mask) and only the embeddings are kept.
text_encoder = T5EmbeddingEncoder(device="cuda")
text_emb = text_encoder(input_ids_tsr[:1])[0]

In [6]:
# Per-token embeddings of the first caption from the T5EmbeddingEncoder smoke test.
text_emb

tensor([[[-0.0060,  0.0064,  0.0045,  ...,  0.0049, -0.0029,  0.0041],
         [-0.5352, -0.0132,  0.0356,  ...,  0.1289,  0.3086, -0.5703],
         [-0.0815, -0.0069, -0.1289,  ..., -0.1123,  0.2930, -0.2500],
         ...,
         [-0.2441, -0.2676,  0.0574,  ...,  0.0977, -0.0337, -0.4316],
         [-0.2559, -0.2969,  0.0396,  ...,  0.0957, -0.0234, -0.4414],
         [-0.2754, -0.3301,  0.0544,  ...,  0.0811, -0.0155, -0.4355]]],
       device='cuda:0', dtype=torch.bfloat16)

In [10]:
import sys

# Make the PixArt-alpha repo importable. The membership check keeps re-runs of this
# cell from appending the same sys.path entry again and again.
pixart_repo = "/n/home12/hjkim/Github/DiffusionObjectRelation/PixArt-alpha"
if pixart_repo not in sys.path:
    sys.path.append(pixart_repo)
from diffusion.model.t5 import T5Embedder

# PixArt's T5 text embedder, loaded from the local checkpoint cache with the same token
# budget used for tokenization above. join() avoids a doubled "/" because
# pretrain_path already ends with one.
pretrain_path = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/pretrained_models/"
t5 = T5Embedder(device="cuda", local_cache=True, cache_dir=join(pretrain_path, "t5_ckpts"), model_max_length=model_max_length)



/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/pretrained_models//t5_ckpts/t5-v1_1-xxl




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [61]:
# Reference caption embedded with the PixArt T5 embedder loaded above. The result
# (caption_emb, emb_mask) is compared against a cached embedding in later cells.
caption = ["triangle is above red circle"]
caption_emb, emb_mask = t5.get_text_embeddings(caption)

In [13]:
# Cache T5 embeddings for the unconditional ("") prompt and every validation prompt.
# Each entry is saved as <prompt>_<max_length>token.pth in prompt_cache_dir and also
# collected in result_col as {'prompt', 'caption_embeds', 'emb_mask'}.
validation_prompts = [
    "triangle",
    "blue triangle",
    "red square",
    "square",
    "circle",
    "blue circle",
    "triangle is to the upper left of square", 
    "triangle is to the left of square", 
    "triangle is to the left of triangle", 
    "circle is below red square",
    "red circle is to the left of blue square",
    "blue square is to the right of red circle",
    "red circle is above square",
    "triangle is above red circle",
    ]
max_length = 20
result_col = []
recompute = True  # set False to reuse embeddings already cached on disk
device = "cuda"
prompt_cache_dir = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/prompt_cache_t5emb"
os.makedirs(prompt_cache_dir, exist_ok=True)

pretrain_path = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/pretrained_models/"
# Drop the reference to any embedder loaded by an earlier cell before loading another
# T5-XXL copy. This lets it be freed (if nothing else references it), so two copies
# need not be resident at once.
t5 = None
t5 = T5Embedder(device=device, local_cache=True, cache_dir=join(pretrain_path, "t5_ckpts"), model_max_length=max_length)

# Save unconditioned embedding
uncond_prompt_embeds, uncond_attention_mask = t5.get_text_embeddings([""])
torch.save({'caption_embeds': uncond_prompt_embeds, 'emb_mask': uncond_attention_mask, 'prompt': ''}, 
            join(prompt_cache_dir, f'uncond_{max_length}token.pth'))
result_col.append({'prompt': '', 'caption_embeds': uncond_prompt_embeds, 'emb_mask': uncond_attention_mask})

print("Preparing Visualization prompt embeddings...")
print(f"Saving visualizate prompt text embedding at {prompt_cache_dir}")

for prompt in validation_prompts:
    cache_path = join(prompt_cache_dir, f'{prompt}_{max_length}token.pth')
    if os.path.exists(cache_path) and not recompute:
        # The cache holds only tensors, strings and dicts, so the restricted unpickler suffices.
        result_col.append(torch.load(cache_path, weights_only=True))
        continue
    print(f"Mapping {prompt}...")
    # Loop-local names so this cell does not overwrite caption_emb from the reference-caption
    # cell, which later cells compare against the cached embedding.
    prompt_emb, prompt_mask = t5.get_text_embeddings([prompt])
    torch.save({'caption_embeds': prompt_emb, 'emb_mask': prompt_mask, 'prompt': prompt}, cache_path)
    result_col.append({'prompt': prompt, 'caption_embeds': prompt_emb, 'emb_mask': prompt_mask})
print("Done!")

/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/pretrained_models//t5_ckpts/t5-v1_1-xxl


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Preparing Visualization prompt embeddings...
Saving visualizate prompt text embedding at /n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/prompt_cache_t5emb
Mapping triangle...
Mapping blue triangle...
Mapping red square...
Mapping square...
Mapping circle...
Mapping blue cirle...
Mapping triangle is to the upper left of square...
Mapping triangle is to the left of square...
Mapping triangle is to the left of triangle...
Mapping circle is below red square...
Mapping red circle is to the left of blue square...
Mapping blue square is to the right of red circle...
Mapping red circle is above square...
Mapping triangle is above red circle...
Done!


In [62]:
# Shape of the reference caption's attention mask.
emb_mask.size()

torch.Size([1, 20])

In [63]:
# Shape of the reference caption's T5 embeddings.
caption_emb.size()

torch.Size([1, 20, 4096])

In [64]:
# Load the cached embedding of the reference caption explicitly. `embedding` was
# previously only defined as a loop variable leaked from a commented-out loop in a
# later cell, so this cell failed on a fresh kernel.
caption_embeddings = torch.load(join(prompt_cache_dir, "caption_embeddings_list.pth"), weights_only=True)
embedding = next((entry for entry in reversed(caption_embeddings) if entry["prompt"] == caption[0]), None)
assert embedding is not None, f"no cached embedding for prompt {caption[0]!r}"
embedding["caption_embeds"].shape

torch.Size([1, 20, 4096])

In [65]:
# Freshly computed T5Embedder embeddings of the reference caption (compared with the cached copy below).
caption_emb

tensor([[[ 0.0356, -0.0014,  0.0184,  ...,  0.2559, -0.0903,  0.0052],
         [ 0.0723,  0.0737,  0.0354,  ...,  0.0579, -0.1177, -0.0776],
         [ 0.0845, -0.0737, -0.1328,  ...,  0.2051, -0.2041,  0.1089],
         ...,
         [-0.1084,  0.2100, -0.0830,  ..., -0.0542,  0.0327,  0.0664],
         [-0.1348,  0.1826, -0.1064,  ..., -0.0049,  0.0618, -0.0110],
         [-0.0752,  0.2207, -0.1543,  ..., -0.0413,  0.0056, -0.0293]]],
       device='cuda:0', dtype=torch.bfloat16)

In [74]:
# Loose agreement check (rtol = atol = 0.1) between the fresh bf16 embedding and the cached one.
torch.allclose(caption_emb, embedding["caption_embeds"], rtol=0.1, atol=0.1)

True

In [68]:
# Attention mask returned by T5Embedder for the reference caption (1s presumably mark real tokens, 0s padding).
emb_mask

tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
       device='cuda:0')

In [67]:
# Cached entry for the reference caption: a dict with 'prompt', 'caption_embeds' and 'emb_mask'.
embedding

{'prompt': 'triangle is above red circle',
 'caption_embeds': tensor([[[ 0.0386,  0.0007,  0.0167,  ...,  0.2578, -0.0903,  0.0042],
          [ 0.0723,  0.0742,  0.0349,  ...,  0.0586, -0.1196, -0.0781],
          [ 0.0859, -0.0747, -0.1338,  ...,  0.2041, -0.2061,  0.1104],
          ...,
          [-0.1069,  0.2070, -0.0820,  ..., -0.0559,  0.0347,  0.0674],
          [-0.1299,  0.1836, -0.1099,  ..., -0.0120,  0.0574, -0.0200],
          [-0.0781,  0.2246, -0.1523,  ..., -0.0378,  0.0069, -0.0378]]],
        device='cuda:0', dtype=torch.bfloat16),
 'emb_mask': tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        device='cuda:0')}

In [36]:
# Shape of the cached attention mask for the reference caption.
embedding["emb_mask"].size()

torch.Size([1, 20])

In [32]:
# Load the list of cached caption embeddings; entries appear to be dicts with 'prompt',
# 'caption_embeds' and 'emb_mask'. Passing weights_only explicitly avoids the torch.load
# warning seen in the output and restricts unpickling to tensors and plain containers.
prompt_cache_dir1 = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/PixArt/output/prompt_cache_t5emb"
caption_embeddings = torch.load(join(prompt_cache_dir1, "caption_embeddings_list.pth"), weights_only=True)
# for i, embedding in enumerate(caption_embeddings):
#     print(f"{i}: {embedding['prompt']} | token num:{embedding['emb_mask'].sum()}")

  caption_embeddings = torch.load(join(prompt_cache_dir1, "caption_embeddings_list.pth"))
