In [1]:
from pdata import PersonalizedMMUDataset, PersonalizedT2IDataset, get_personalized_mmu_dataloader, get_personalized_t2i_dataloader
from lightning.pytorch.utilities import CombinedLoader

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from PIL import Image

from models import Showo, MAGVITv2, get_mask_chedule
from training.prompting_utils import UniversalPrompting, create_attention_mask_predict_next, create_attention_mask_for_mmu
from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter
from transformers import AutoTokenizer
from models.clip_encoder import CLIPVisionTower
from transformers import CLIPImageProcessor
from llava.llava import conversation as conversation_lib

conversation_lib.default_conversation = conversation_lib.conv_templates["phi1.5"]

import os
from omegaconf import DictConfig, ListConfig, OmegaConf
config = OmegaConf.load('configs/showo_demo.yaml')
# device setup
device = torch.device("cuda:7")

  from .autonotebook import tqdm as notebook_tqdm


[2025-02-04 18:31:46,622] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/arc/miniconda3/envs/showo/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status




In [2]:
# show o tokenizer setup and adding special tokens to universal prompting
# llm model : 'microsoft/phi-1_5'
tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side ="left")
uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
                                       special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
                                       ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)

# setting up the magvit-v2, for t2i
vq_model = MAGVITv2.from_pretrained(config.model.vq_model.vq_model_name).to(device)
# vq_model.requires_grad_(False)
# vq_model.eval()

# setting up vision tower: clip-vit only for mmu
# vision_tower_name =config.clip_path
# vision_tower = CLIPVisionTower(vision_tower_name).to(device)
# clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)

# setting up the showo model 
model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
# model.eval()

# setting up the parameters
temperature = 0.8  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 1  # retain only the top_k most likely tokens, clamp others to have 0 probability
# LLAVA_SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. " \
#                 "The assistant gives helpful, detailed, and polite answers to the user's questions."
# LLAVA_SYSTEM_PROMPT_LEN = 28

Working with z of shape (1, 13, 16, 16) = 3328 dimensions.
Look-up free quantizer with codebook size: 8192


The config attributes {'mask_token_id': 58497} were passed to Showo, but are not expected and will be ignored. Please verify your config.json configuration file.
  if self.w_clip_vit:


attention implementation:  sdpa


In [3]:
# print(model.showo.get_input_embeddings())
model.showo.get_input_embeddings().num_embeddings
model.showo.get_input_embeddings().num_embeddings - len(tokenizer)
model.showo.get_input_embeddings().weight.data.shape
model.showo.lm_head.weight.shape
model.showo.lm_head.bias.shape

torch.Size([58498])

In [None]:
data_root = "/home/arc/full_mcdata"
concept = "dunpai"

In [5]:
nums_new_token_i = 16

#################################
new_tokens = [f"<{concept}>"] + [f"<token_{i}>" for i in range(nums_new_token_i)]
num_new_tokens = len(new_tokens)  # 17

# 已知的原始参数
# 文本 token 数量（ID 0-50304）
original_text_vocab_size = len(tokenizer)  
# Image token 数量（原 ID 50305-58497）
original_image_vocab_size = model.showo.get_input_embeddings().num_embeddings - len(tokenizer)

original_total_vocab = original_text_vocab_size + original_image_vocab_size  # 58498

# 新的参数
new_text_vocab_size = original_text_vocab_size + num_new_tokens  # 50305 + 17 = 50322
new_total_vocab = original_total_vocab + num_new_tokens          # 58498 + 17 = 58515

# ------------------------------
# Step 1: 修改 Tokenizer 的词汇表
# ------------------------------

# 添加新 token 到 50305-50321 的位置
num_new_tokens = tokenizer.add_tokens(new_tokens)
new_token_ids = tokenizer.convert_tokens_to_ids(new_tokens)
print("新 token ID:", new_token_ids)  # 应输出 50305-50321

# ------------------------------
# Step 2: 调整模型的权重
# ------------------------------
with torch.no_grad():
    # 获取嵌入层权重
    embeddings = model.showo.get_input_embeddings().weight.data
    
    # 扩展嵌入层（58498 -> 58515）
    model.showo.resize_token_embeddings(new_total_vocab)
    # new_embeddings = model.showo.get_input_embeddings().weight.data

    # 将原 Image Token 权重后移 17 位
    original_image_weights = embeddings[original_text_vocab_size:original_total_vocab].clone()
    model.showo.get_input_embeddings().weight.data[new_text_vocab_size:new_total_vocab] = original_image_weights
    
    # 初始化新 token 的权重（用原文本最后 17 个 token）
    new_text_weights = embeddings[original_text_vocab_size - num_new_tokens : original_text_vocab_size].clone()
    model.showo.get_input_embeddings().weight.data[original_text_vocab_size : new_text_vocab_size] = new_text_weights
    # print(model.showo.lm_head.weight.data.shape[1])
    # 处理 lm_head（假设与嵌入层共享权重）
    if model.showo.lm_head.weight.data.shape[0] == new_total_vocab:
        # 扩展 lm_head 权重
        lm_head = model.showo.lm_head
        new_lm_head = torch.nn.Linear(
            lm_head.in_features, 
            new_total_vocab, 
            bias=hasattr(lm_head, 'bias')
        )
        new_lm_head.weight.data = lm_head.weight.data.clone()
        new_lm_head.weight.data[new_text_vocab_size:new_total_vocab] = lm_head.weight.data[original_text_vocab_size:original_total_vocab]
        new_lm_head.weight.data[original_text_vocab_size:new_text_vocab_size] = lm_head.weight.data[original_text_vocab_size - num_new_tokens : original_text_vocab_size]
        if hasattr(lm_head, 'bias'):
            new_lm_head.bias.data = lm_head.bias.data.clone()
            new_lm_head.bias.data[new_text_vocab_size:new_total_vocab] = lm_head.bias.data[original_text_vocab_size:original_total_vocab]
            new_lm_head.bias.data[original_text_vocab_size:new_text_vocab_size] = lm_head.bias.data[original_text_vocab_size - num_new_tokens : original_text_vocab_size]
        
        model.showo.lm_head = new_lm_head
    else:
        raise ValueError("lm_head weights do not match the input embeddings!")

index_no_updates = torch.ones((new_total_vocab,), dtype=torch.bool)
index_no_updates[new_token_ids] = False
# ------------------------------
# 验证
# ------------------------------
# 检查新 token 的 ID
print("新增文本 token ID:", [tokenizer.convert_tokens_to_ids(t) for t in new_tokens])  # 应输出 50305-50321

# 检查一个原 Image Token 的新 ID
sample_image_token = tokenizer.convert_ids_to_tokens(original_text_vocab_size)  # 原 ID 50305
print(f"Concept Token '{sample_image_token}' 的新 ID:", tokenizer.convert_tokens_to_ids(sample_image_token))  # 应输出 50322

# 检查嵌入层形状
print("嵌入层大小:", model.showo.get_input_embeddings().weight.shape)  # 应显示 torch.Size([58515, 2048])

# 检查 index_no_updates 中 True 的位置和数量，True 应该是 new token ids
print("index_no_updates 中 False 的位置:", torch.nonzero(~index_no_updates).squeeze())  # 应输出 50305-50321
print("index_no_updates 中 True 的数量:", torch.sum(index_no_updates))  # 应输出 58498

with torch.no_grad():
    orig_embeds = model.showo.get_input_embeddings().weight.data.clone()
    orig_lm_head_weight = model.showo.lm_head.weight.data.clone()
    orig_lm_head_bias = model.showo.lm_head.bias.data.clone()

新 token ID: [50305, 50306, 50307, 50308, 50309, 50310, 50311, 50312, 50313, 50314, 50315, 50316, 50317, 50318, 50319, 50320, 50321]
新增文本 token ID: [50305, 50306, 50307, 50308, 50309, 50310, 50311, 50312, 50313, 50314, 50315, 50316, 50317, 50318, 50319, 50320, 50321]
Concept Token '<dunpai>' 的新 ID: 50305
嵌入层大小: torch.Size([58515, 2048])
index_no_updates 中 False 的位置: tensor([50305, 50306, 50307, 50308, 50309, 50310, 50311, 50312, 50313, 50314,
        50315, 50316, 50317, 50318, 50319, 50320, 50321])
index_no_updates 中 True 的数量: tensor(58498)


In [6]:
uni_prompting.sptids_dict

{'<|soi|>': tensor([50296]),
 '<|eoi|>': tensor([50297]),
 '<|sov|>': tensor([50298]),
 '<|eov|>': tensor([50299]),
 '<|t2i|>': tensor([50300]),
 '<|mmu|>': tensor([50301]),
 '<|t2v|>': tensor([50302]),
 '<|v2v|>': tensor([50303]),
 '<|lvg|>': tensor([50304]),
 '<|sot|>': tensor([50256]),
 '<|eot|>': tensor([50256]),
 '<|pad|>': tensor([50295])}

In [7]:
vq_model.requires_grad_ = False
vq_model.eval()
model.train()
for names, p in model.named_parameters():
    if "embed_tokens" not in names and "lm_head" not in names:
        p.requires_grad = False
    else:
        p.requires_grad = True

trainable_params = [model.showo.get_input_embeddings().weight, model.showo.lm_head.weight, model.showo.lm_head.bias]
optimizer = torch.optim.AdamW(
            trainable_params, # for optimize the embeddings and the head
            lr=1e-2,
            betas=(0.9, 0.999),
            weight_decay=1e-2,
            eps=1e-08,
        )
for names, p in model.named_parameters():
    if p.requires_grad:
        print(f"{names} requires_grad") # embed_token, lm_head会更新

showo.model.embed_tokens.weight requires_grad
showo.lm_head.weight requires_grad
showo.lm_head.bias requires_grad


In [8]:
mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))
mask_id = model.mask_token_id
mask_dtype = model.showo.model.embed_tokens.weight.dtype

  mask_id = model.mask_token_id


In [9]:

# t2i_dataset = PersonalizedT2IDataset(data_root, concept)
# t2i_dataloader = DataLoader(t2i_dataset, batch_size=5, shuffle=True, num_workers=10, pin_memory=True)
mmu_dataloader = get_personalized_mmu_dataloader(data_root, concept, tokenizer, batch_size=5, num_workers=0, max_length=128)
t2i_dataloader = get_personalized_t2i_dataloader(data_root, concept, tokenizer, batch_size=5, num_workers=0, max_length=128)


iterables = {
    'mmu_flow': mmu_dataloader,
    't2i_flow': t2i_dataloader
}


combined_dataloader = CombinedLoader(iterables, mode="max_size_cycle")

# Before adding the new tokens, the vocab size is 58498
# vocab size = 58498 = 50295  llm vocabsize
#                    + 10     <|soi|> <|eoi|> <|sov|> <|eov|> <|t2i|> <|mmu|> <|t2v|> <|v2v|> <|lvg|> <|pad|>
#                    + 8192   vq model codebook size
#                    + 1      mask token (token id == 58497)
from typing import Union


uni_prompting.sptids_dict
# {'<|soi|>': tensor([50296]),
#  '<|eoi|>': tensor([50297]),
#  '<|sov|>': tensor([50298]),
#  '<|eov|>': tensor([50299]),
#  '<|t2i|>': tensor([50300]),
#  '<|mmu|>': tensor([50301]),
#  '<|t2v|>': tensor([50302]),
#  '<|v2v|>': tensor([50303]),
#  '<|lvg|>': tensor([50304]),
#  '<|sot|>': tensor([50256]),
#  '<|eot|>': tensor([50256]),
#  '<|pad|>': tensor([50295])}

# uni_prompting.text_tokenizer == tokenizer
def prepare_inputs_and_labels(
        pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor],
        texts: Union[str, str],
        min_masking_rate: float = 0.0,
        is_train: bool = True,
):

    image_tokens = vq_model.get_code(pixel_values_or_image_ids)
    image_tokens = image_tokens + len(uni_prompting.text_tokenizer)

    # create MLM mask and labels
    input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens(
        image_tokens,
        mask_id,
        config,
        mask_schedule=mask_schedule,
        is_train=is_train,
    )
    input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i')

    return input_ids, labels, mask_prob, image_tokens

Formatting llava instruction data


In [10]:
list_combined_dataloader = list(combined_dataloader)
one_batch_mmu = list_combined_dataloader[0][0]['mmu_flow']
one_batch_t2i = list_combined_dataloader[0][0]['t2i_flow']

# one_batch_mmu = next(iter(mmu_dataloader))

In [12]:
model.output_size = new_total_vocab
for epoch in range(0, 100):
    print(f"Epoch {epoch}")
    loss_list = []
    loss_t2i_list = []
    loss_mmu_list = []
    for batch, batch_idx, dataloader_idx in tqdm(list_combined_dataloader):
        batch_size_mmu = batch["mmu_flow"]["images"].shape[0]
        batch_size_t2i = batch["t2i_flow"]["images"].shape[0]
        
        # t2i format
        pixel_values, texts = batch["t2i_flow"]["images"], batch["t2i_flow"]["conditions"]
        pixel_values = pixel_values.to(device)
        input_ids, labels, mask_prob, image_tokens_ori = prepare_inputs_and_labels(pixel_values, texts, is_train=True)
        attention_mask = create_attention_mask_predict_next(input_ids,
                                                                pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                                soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                                eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                                rm_pad_in_image=True,
                                                                return_inverse_mask=True)
        attention_mask = attention_mask.to(mask_dtype)
        
        # mmu format
        pixel_values_mmu, input_ids_mmu, labels_mmu = (batch["mmu_flow"]["images"],
                                                      batch["mmu_flow"]["input_ids"],
                                                      batch["mmu_flow"]["labels"])
        pixel_values_mmu = pixel_values_mmu.to(device, non_blocking=True)
        input_ids_mmu = input_ids_mmu.to(device, non_blocking=True)
        image_tokens_mmu = vq_model.get_code(pixel_values_mmu)
        image_tokens_mmu = image_tokens_mmu + len(uni_prompting.text_tokenizer)
        
        input_ids_mmu = torch.cat([
                    (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(
                        device),
                    (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(
                        device),
                    image_tokens_mmu,
                    (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(
                        device),
                    input_ids_mmu,
                ], dim=1).long()

        labels_mmu = torch.cat([
                    (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(device),
                    (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(device),
                    torch.ones_like(image_tokens_mmu) * uni_prompting.ignore_id,
                    (torch.ones(input_ids_mmu.shape[0], 1) * uni_prompting.ignore_id).to(device),
                    labels_mmu.to(device)
                ], dim=1).long()
        
        
        attention_mask_mmu = create_attention_mask_for_mmu(input_ids_mmu.to(input_ids.device),
                                                               eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']))
        attention_mask_mmu = attention_mask_mmu.to(mask_dtype)
        attention_mask = torch.cat([attention_mask, attention_mask_mmu], dim=0)
        input_ids = torch.cat((input_ids, input_ids_mmu.to(input_ids.device)), dim=0)
        labels = torch.cat((labels, labels_mmu.to(input_ids.device)), dim=0)
        
        optimizer.zero_grad()
        
        logits, loss_t2i, loss_lm, loss_mmu = model(
                    input_ids=input_ids,
                    input_embeddings=None,
                    attention_mask=attention_mask,
                    labels=labels,
                    label_smoothing=0.0,
                    batch_size_t2i=batch_size_t2i,
                    batch_size_lm=0,
                    batch_size_mmu=batch_size_mmu,
                    max_seq_length=128,
                )
        loss = 0.8 * loss_t2i + 0.2 * loss_mmu
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
        loss_t2i_list.append(loss_t2i.item())
        loss_mmu_list.append(loss_mmu.item())
        # tqdm.set_postfix(loss=loss.item(), loss_t2i=loss_t2i.item(), loss_mmu=loss_mmu.item())
        # tqdm.write(f"loss: {loss.item()}, loss_t2i: {loss_t2i.item()}, loss_mmu: {loss_mmu.item()}")
        # 恢复原始权重
        with torch.no_grad():
            model.showo.get_input_embeddings().weight.data[index_no_updates] = orig_embeds[index_no_updates]
            model.showo.lm_head.weight.data[index_no_updates] = orig_lm_head_weight[index_no_updates]
            model.showo.lm_head.bias.data[index_no_updates] = orig_lm_head_bias[index_no_updates]
    print(f"Epoch {epoch} loss: {np.mean(loss_list)}, loss_t2i: {np.mean(loss_t2i_list)}, loss_mmu: {np.mean(loss_mmu_list)}")
    
        

Epoch 0


  0%|          | 0/49 [00:00<?, ?it/s]

100%|██████████| 49/49 [00:28<00:00,  1.69it/s]


Epoch 0 loss: 7.42177457225566, loss_t2i: 8.953611062497508, loss_mmu: 1.2944279726670713
Epoch 1


100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 1 loss: 6.626295547096097, loss_t2i: 7.999885296335026, loss_mmu: 1.1319361937289336
Epoch 2


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 2 loss: 6.107476789124158, loss_t2i: 7.363431599675392, loss_mmu: 1.083656797603685
Epoch 3


100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 3 loss: 5.945915611422792, loss_t2i: 7.1792658397129605, loss_mmu: 1.0125141003910376
Epoch 4


100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 4 loss: 5.800265599270256, loss_t2i: 7.012139320373535, loss_mmu: 0.9527703918972794
Epoch 5


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 5 loss: 5.576078487902271, loss_t2i: 6.74552617754255, loss_mmu: 0.8982872555450517
Epoch 6


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 6 loss: 5.683408416047389, loss_t2i: 6.892803318646489, loss_mmu: 0.845828506107233
Epoch 7


100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 7 loss: 5.770050652173101, loss_t2i: 7.010568054354921, loss_mmu: 0.8079803698525136
Epoch 8


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 8 loss: 5.880341695279491, loss_t2i: 7.158044766406624, loss_mmu: 0.7695289059561126
Epoch 9


100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 9 loss: 5.581008327250578, loss_t2i: 6.791729567002277, loss_mmu: 0.7381228831957798
Epoch 10


100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 10 loss: 5.746380849760406, loss_t2i: 7.005443611923529, loss_mmu: 0.7101292087107288
Epoch 11


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 11 loss: 5.586092754286163, loss_t2i: 6.807965512178382, loss_mmu: 0.6986015451197721
Epoch 12


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 12 loss: 5.578917498491248, loss_t2i: 6.793343349378937, loss_mmu: 0.7212135095377358
Epoch 13


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 13 loss: 5.272494914580364, loss_t2i: 6.424329232196419, loss_mmu: 0.6651570854746566
Epoch 14


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 14 loss: 5.512755447504472, loss_t2i: 6.730252212407637, loss_mmu: 0.6427677994479939
Epoch 15


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 15 loss: 5.778004626838529, loss_t2i: 7.0554337404212175, loss_mmu: 0.6682877464562046
Epoch 16


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 16 loss: 5.431083567288457, loss_t2i: 6.6346889417998645, loss_mmu: 0.6166616179505173
Epoch 17


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 17 loss: 5.20013867592325, loss_t2i: 6.353850715014399, loss_mmu: 0.5852901227011973
Epoch 18


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 18 loss: 5.325779978109866, loss_t2i: 6.509501758886843, loss_mmu: 0.5908922994015168
Epoch 19


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 19 loss: 5.28897145816258, loss_t2i: 6.460553801789576, loss_mmu: 0.6026416741767708
Epoch 20


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 20 loss: 5.161755556962928, loss_t2i: 6.3089103747387325, loss_mmu: 0.5731358259003989
Epoch 21


100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 21 loss: 5.39027291414689, loss_t2i: 6.599544077503438, loss_mmu: 0.5531877048161565
Epoch 22


100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 22 loss: 5.02228256147735, loss_t2i: 6.1404388583436305, loss_mmu: 0.5496569292581811
Epoch 23


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 23 loss: 4.782850888310646, loss_t2i: 5.846313593338947, loss_mmu: 0.5289996795508326
Epoch 24


100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 24 loss: 5.532112136179087, loss_t2i: 6.784745279623538, loss_mmu: 0.5215791626548281
Epoch 25


100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 25 loss: 5.104142845893393, loss_t2i: 6.249997241156442, loss_mmu: 0.5207248514099997
Epoch 26


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 26 loss: 5.267490586455987, loss_t2i: 6.44642481511953, loss_mmu: 0.5517531009961147
Epoch 27


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 27 loss: 5.4702416488102505, loss_t2i: 6.705808041047077, loss_mmu: 0.5279758408361551
Epoch 28


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 28 loss: 5.096083261528793, loss_t2i: 6.2449073207621675, loss_mmu: 0.5007866926643313
Epoch 29


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 29 loss: 4.986432639919982, loss_t2i: 6.107424380827923, loss_mmu: 0.5024653247424534
Epoch 30


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 30 loss: 4.697487524577549, loss_t2i: 5.750976771724467, loss_mmu: 0.4835299893605466
Epoch 31


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 31 loss: 4.985864001877454, loss_t2i: 6.111887153314084, loss_mmu: 0.4817708995269269
Epoch 32


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 32 loss: 4.8310001747948785, loss_t2i: 5.922718926351898, loss_mmu: 0.46412474753297106
Epoch 33


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 33 loss: 4.956420662451763, loss_t2i: 6.0810732598207435, loss_mmu: 0.45780979606265926
Epoch 34


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 34 loss: 4.911873394129228, loss_t2i: 6.0274273473389295, loss_mmu: 0.4496574552387607
Epoch 35


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 35 loss: 4.7986326996161015, loss_t2i: 5.881988223718137, loss_mmu: 0.4652102207194786
Epoch 36


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 36 loss: 5.023519861454866, loss_t2i: 6.168167688408676, loss_mmu: 0.44492798222571
Epoch 37


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 37 loss: 4.904122162838371, loss_t2i: 6.021979755284835, loss_mmu: 0.4326912963724866
Epoch 38


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 38 loss: 5.147906215823427, loss_t2i: 6.327619713179919, loss_mmu: 0.42905166638748987
Epoch 39


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 39 loss: 5.111512008978396, loss_t2i: 6.280119185545007, loss_mmu: 0.4370830788904307
Epoch 40


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 40 loss: 4.962916982417204, loss_t2i: 6.099900707906606, loss_mmu: 0.4149817898109251
Epoch 41


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 41 loss: 4.824662442110022, loss_t2i: 5.927432872811142, loss_mmu: 0.41358012059817506
Epoch 42


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 42 loss: 4.871770450047085, loss_t2i: 5.986421565620267, loss_mmu: 0.4131657547336452
Epoch 43


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 43 loss: 4.761569835701767, loss_t2i: 5.849269798823765, loss_mmu: 0.4107695161840137
Epoch 44


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 44 loss: 4.97635513909009, loss_t2i: 6.117890304448653, loss_mmu: 0.41021399344412646
Epoch 45


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 45 loss: 4.734630672299132, loss_t2i: 5.814119168690273, loss_mmu: 0.41667636073365505
Epoch 46


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 46 loss: 4.77523946762085, loss_t2i: 5.863052421686601, loss_mmu: 0.42398709636561727
Epoch 47


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 47 loss: 4.56616168606038, loss_t2i: 5.6045852923879815, loss_mmu: 0.41246664227575675
Epoch 48


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 48 loss: 5.169183648362452, loss_t2i: 6.361976205086221, loss_mmu: 0.398013097519169
Epoch 49


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 49 loss: 4.536058114499462, loss_t2i: 5.569990488947655, loss_mmu: 0.40032825801445515
Epoch 50


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 50 loss: 4.921446634798634, loss_t2i: 6.052781280206174, loss_mmu: 0.3961076393571435
Epoch 51


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 51 loss: 4.94081998844536, loss_t2i: 6.0795356546129495, loss_mmu: 0.3859569726093691
Epoch 52


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 52 loss: 4.815396153196996, loss_t2i: 5.923952409199306, loss_mmu: 0.38117062665370044
Epoch 53


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 53 loss: 4.8338497901449395, loss_t2i: 5.9474581553011525, loss_mmu: 0.3794160808379553
Epoch 54


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 54 loss: 4.710834858368854, loss_t2i: 5.79451507451583, loss_mmu: 0.37611351952869065
Epoch 55


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 55 loss: 4.76106503058453, loss_t2i: 5.858552572678547, loss_mmu: 0.371114543430051
Epoch 56


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 56 loss: 4.478070220168756, loss_t2i: 5.504282469652137, loss_mmu: 0.3732208450685959
Epoch 57


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 57 loss: 4.810152409028034, loss_t2i: 5.919104814529419, loss_mmu: 0.37434219257259854
Epoch 58


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 58 loss: 4.727627294404166, loss_t2i: 5.816508986512009, loss_mmu: 0.37210016005805563
Epoch 59


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 59 loss: 4.60643356187003, loss_t2i: 5.667305746857001, loss_mmu: 0.36294434957054195
Epoch 60


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 60 loss: 4.438920310565403, loss_t2i: 5.458986160706501, loss_mmu: 0.35865633307519007
Epoch 61


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 61 loss: 4.572150459094924, loss_t2i: 5.625573046353399, loss_mmu: 0.35845979067439937
Epoch 62


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 62 loss: 4.94317795305836, loss_t2i: 6.088279091582006, loss_mmu: 0.3627732271442608
Epoch 63


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 63 loss: 4.294743773888569, loss_t2i: 5.280028452678603, loss_mmu: 0.3536046597422386
Epoch 64


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 64 loss: 4.593328862774129, loss_t2i: 5.653890680293648, loss_mmu: 0.35108124502763455
Epoch 65


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 65 loss: 4.697448501781541, loss_t2i: 5.783927859092246, loss_mmu: 0.3515307250983861
Epoch 66


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 66 loss: 4.76004061893541, loss_t2i: 5.862751488782922, loss_mmu: 0.34919684745219287
Epoch 67


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 67 loss: 4.815482957022531, loss_t2i: 5.932624330326003, loss_mmu: 0.3469172368320275
Epoch 68


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 68 loss: 4.666931921122026, loss_t2i: 5.747735763082699, loss_mmu: 0.3437162933453005
Epoch 69


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 69 loss: 4.513328250573606, loss_t2i: 5.556664787993139, loss_mmu: 0.33998184453467933
Epoch 70


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 70 loss: 4.624139484094114, loss_t2i: 5.695582204935502, loss_mmu: 0.3383681932274176
Epoch 71


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 71 loss: 4.49113543666139, loss_t2i: 5.529593686668241, loss_mmu: 0.33730206212827135
Epoch 72


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 72 loss: 4.265298699846073, loss_t2i: 5.247157057937311, loss_mmu: 0.337864776807172
Epoch 73


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 73 loss: 4.6437845205774115, loss_t2i: 5.721313902309963, loss_mmu: 0.33366657101682257
Epoch 74


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 74 loss: 4.5015217698350245, loss_t2i: 5.543292918983771, loss_mmu: 0.3344368308174367
Epoch 75


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 75 loss: 4.047284688268389, loss_t2i: 4.97664306115131, loss_mmu: 0.32985081614888445
Epoch 76


100%|██████████| 49/49 [00:25<00:00,  1.88it/s]


Epoch 76 loss: 4.541884750736003, loss_t2i: 5.593594857624599, loss_mmu: 0.3350440637797725
Epoch 77


100%|██████████| 49/49 [00:25<00:00,  1.89it/s]


Epoch 77 loss: 4.418182908272256, loss_t2i: 5.440972089767456, loss_mmu: 0.3270257309991486
Epoch 78


100%|██████████| 49/49 [00:25<00:00,  1.89it/s]


Epoch 78 loss: 4.781453541346958, loss_t2i: 5.894612541004103, loss_mmu: 0.32881719284519856
Epoch 79


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 79 loss: 4.415105325835092, loss_t2i: 5.436251839812921, loss_mmu: 0.330518835053152
Epoch 80


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 80 loss: 4.443418590389952, loss_t2i: 5.47255958586323, loss_mmu: 0.326854352022008
Epoch 81


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 81 loss: 4.406244577193747, loss_t2i: 5.424393344898613, loss_mmu: 0.33364905561415514
Epoch 82


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 82 loss: 4.633331796344446, loss_t2i: 5.709937392448892, loss_mmu: 0.32690890817617885
Epoch 83


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 83 loss: 4.518559356125033, loss_t2i: 5.5662966942300605, loss_mmu: 0.3276095322656388
Epoch 84


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 84 loss: 4.278078366299065, loss_t2i: 5.26545022215162, loss_mmu: 0.32859066364412404
Epoch 85


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 85 loss: 4.550484744869933, loss_t2i: 5.606214859047714, loss_mmu: 0.32756379581227596
Epoch 86


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 86 loss: 4.619091158010522, loss_t2i: 5.693014261673908, loss_mmu: 0.3233983812435549
Epoch 87


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 87 loss: 4.3964575626412215, loss_t2i: 5.414811212189344, loss_mmu: 0.323042630236976
Epoch 88


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 88 loss: 4.675721037144563, loss_t2i: 5.763980515149175, loss_mmu: 0.3226828175996031
Epoch 89


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 89 loss: 4.589533197636507, loss_t2i: 5.658124456600267, loss_mmu: 0.31516770554744467
Epoch 90


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 90 loss: 4.10764970098223, loss_t2i: 5.054940846501564, loss_mmu: 0.3184846647235812
Epoch 91


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 91 loss: 4.319661617279053, loss_t2i: 5.321067588669913, loss_mmu: 0.3140374291688204
Epoch 92


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 92 loss: 4.50840767792293, loss_t2i: 5.557981821955467, loss_mmu: 0.3101107891725034
Epoch 93


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 93 loss: 4.640713346247771, loss_t2i: 5.722834752530468, loss_mmu: 0.3122273313679865
Epoch 94


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 94 loss: 4.315993566902316, loss_t2i: 5.316696483261731, loss_mmu: 0.31318148236949833
Epoch 95


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 95 loss: 4.419954796226657, loss_t2i: 5.446438108171735, loss_mmu: 0.3140210105220274
Epoch 96


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 96 loss: 4.566065121670158, loss_t2i: 5.630015694365209, loss_mmu: 0.3102625949042184
Epoch 97


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 97 loss: 4.3650728877709835, loss_t2i: 5.378911261655847, loss_mmu: 0.30971904068577044
Epoch 98


100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 98 loss: 4.578340861262108, loss_t2i: 5.645318262431086, loss_mmu: 0.31043085246822055
Epoch 99


100%|██████████| 49/49 [00:26<00:00,  1.88it/s]

Epoch 99 loss: 4.276808767902608, loss_t2i: 5.268557061954421, loss_mmu: 0.30981513105180797





tensor([50301, 50296, 50424, 50872, 51385, 51136, 50904, 50617, 52184, 51160,
        52201, 55256, 51161, 55191, 54775, 54521, 55265, 54745, 50940, 51326,
        50361, 50425, 54282, 50616, 50693, 54285, 50696, 50644, 50793, 50793,
        55400, 51160, 54680, 55257, 53000, 54266, 53434, 52915, 54007, 53498,
        50425, 51214, 51208, 51716, 50440, 50661, 50648, 50661, 50665, 54745,
        52914, 57539, 53191, 53451, 52930, 53962, 52942, 52918, 52986, 54278,
        51144, 50438, 51188, 50376, 52130, 50646, 50868, 52674, 50831, 56715,
        53765, 52738, 50498, 50371, 53938, 50563, 50946, 50930, 51938, 51187,
        50898, 50851, 51380, 51915, 53434, 52523, 52411, 53939, 51401, 53683,
        50329, 52382, 52915, 51363, 50846, 51350, 52435, 50342, 58049, 55858,
        51610, 50498, 52890, 55492, 55162, 57550, 57645, 55678, 57633, 57612,
        57732, 55793, 57868, 57788, 54390, 55508, 51860, 54089, 51353, 56126,
        57689, 56021, 51416, 54329, 55499, 52473, 57702, 56985, 