In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from PIL import Image

from models import Showo, MAGVITv2, get_mask_chedule
from training.prompting_utils import UniversalPrompting, create_attention_mask_predict_next, create_attention_mask_for_mmu
from training.utils import get_config, flatten_omega_conf, mask_or_random_replace_tokens, AverageMeter
from transformers import AutoTokenizer
from models.clip_encoder import CLIPVisionTower
from transformers import CLIPImageProcessor
from llava.llava import conversation as conversation_lib

# Make the "phi1.5" entry of conv_templates the default conversation template.
conversation_lib.default_conversation = conversation_lib.conv_templates["phi1.5"]

import os
from omegaconf import DictConfig, ListConfig, OmegaConf
config = OmegaConf.load('configs/showo_demo.yaml')
# device setup
# "cuda:7" remains the default, but the device can be overridden without editing the
# notebook by setting the SHOWO_DEVICE environment variable (e.g. "cuda:0" or "cpu").
device = torch.device(os.environ.get("SHOWO_DEVICE", "cuda:7"))
if device.type == "cuda":
    # Fail here with a readable message instead of inside a later `.to(device)` call.
    if not torch.cuda.is_available():
        raise RuntimeError(
            f"Requested device '{device}' but CUDA is not available; "
            "set SHOWO_DEVICE to a usable device."
        )
    if device.index is not None and device.index >= torch.cuda.device_count():
        raise RuntimeError(
            f"Requested device '{device}' but only {torch.cuda.device_count()} "
            "CUDA device(s) are visible; set SHOWO_DEVICE to a usable device."
        )

In [3]:
# Text tokenizer, configured to pad on the left.
tokenizer = AutoTokenizer.from_pretrained(
    config.model.showo.llm_model_path,
    padding_side="left",
)

# Prompt helper built around the tokenizer; it is given the task/modality marker
# tokens, the label id to ignore (-100) and the conditioning dropout probability.
uni_prompting = UniversalPrompting(
    tokenizer,
    max_text_len=config.dataset.preprocessing.max_seq_length,
    special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
    ignore_id=-100,
    cond_dropout_prob=config.training.cond_dropout_prob,
)

# Image tokenizer: loaded, moved to the device, frozen and put in eval mode.
vq_model = MAGVITv2.from_pretrained(config.model.vq_model.vq_model_name).to(device)
vq_model.requires_grad_(False)
vq_model.eval()

# Show-o model in eval mode; the trailing expression displays the module tree.
model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
model.eval()

Working with z of shape (1, 13, 16, 16) = 3328 dimensions.
Look-up free quantizer with codebook size: 8192


The config attributes {'mask_token_id': 58497} were passed to Showo, but are not expected and will be ignored. Please verify your config.json configuration file.
  if self.w_clip_vit:


attention implementation:  sdpa


Showo(
  (showo): PhiForCausalLM(
    (model): PhiModel(
      (embed_tokens): Embedding(58498, 2048)
      (embed_dropout): Dropout(p=0.0, inplace=False)
      (layers): ModuleList(
        (0-23): 24 x PhiDecoderLayer(
          (self_attn): PhiSdpaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (q_layernorm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (k_layernorm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (rotary_emb): PhiRotaryEmbedding()
          )
          (mlp): PhiMLP(
            (activation_fn): NewGELUActivation()
            (fc1): Linear(in_features=2048, out_features=8192, bias=True)
            (fc2): Linear(in_features=8192, out_features=2048, b

In [4]:
# Sampling controls for text generation.
# top_k: keep only the k most likely tokens at each step; all others get zero probability.
top_k = 1
# temperature: 1.0 leaves the predictions unchanged, < 1.0 makes them less random,
# > 1.0 makes them more random.
temperature = 1

In [5]:
# --- Which personalised-concept checkpoint to load ---
data_root = "/home/hpyky/full_mcdata"
concept = "dunpai"
ckpt_name = "t2i_wo_sys"
epoch2load = 90

# Checkpoint files are stored as saves/<concept>/<ckpt_name>/epoch_<N>_<part>.pt
ckpt_path = os.path.join("saves", concept, ckpt_name)
ckpt_embed_path, ckpt_lm_head_weight_path, ckpt_lm_head_bias_path = (
    os.path.join(ckpt_path, f"epoch_{epoch2load}_{part}.pt")
    for part in ("embed", "lm_head_weight", "lm_head_bias")
)

nums_new_token_i = 16

#################################
# One concept identifier token followed by `nums_new_token_i` numbered tokens.
new_tokens = [f"<{concept}>", *(f"<token_{i}>" for i in range(nums_new_token_i))]
num_new_tokens = len(new_tokens)  # 17

# Vocabulary layout before any modification.
# Text tokens occupy ids 0 .. len(tokenizer) - 1 (0-50304 here).
original_text_vocab_size = len(tokenizer)
# Total number of rows in the input embedding table (58498 here).
original_total_vocab = model.showo.get_input_embeddings().num_embeddings
# Everything beyond the text vocabulary is the image-token range (ids 50305-58497 here).
original_image_vocab_size = original_total_vocab - original_text_vocab_size

# Vocabulary layout after the new tokens are inserted directly behind the text tokens.
new_text_vocab_size = original_text_vocab_size + num_new_tokens  # 50305 + 17 = 50322
new_total_vocab = original_total_vocab + num_new_tokens          # 58498 + 17 = 58515

# ------------------------------
# Step 1: extend the tokenizer vocabulary
# ------------------------------

# Register the new tokens. They are expected to receive the ids directly after the
# original text vocabulary (50305-50321 for this checkpoint).
num_new_tokens = tokenizer.add_tokens(new_tokens)
new_token_ids = tokenizer.convert_tokens_to_ids(new_tokens)
print("新 token ID:", new_token_ids)  # expected: 50305-50321

# The weight surgery that follows writes the new-token rows to
# [original_text_vocab_size, new_text_vocab_size). If the tokenizer handed out any
# other ids (for example because this cell was run twice on the same kernel, or the
# tokens already existed), the weights would be silently misaligned, so stop here.
expected_new_token_ids = list(range(original_text_vocab_size, new_text_vocab_size))
if new_token_ids != expected_new_token_ids:
    raise ValueError(
        f"New token ids {new_token_ids} do not match the expected contiguous range "
        f"{original_text_vocab_size}-{new_text_vocab_size - 1}; restart the kernel and "
        "run the notebook from the top."
    )

# ------------------------------
# Step 2: adjust the model weights
# ------------------------------
def _load_ckpt_tensor(path, missing_message):
    """Load a checkpoint tensor from `path` onto the CPU.

    Args:
        path: file path of the saved tensor.
        missing_message: message of the ValueError raised when the file is absent.

    Returns:
        Whatever `torch.load` returns for the file (used as a tensor by the caller).

    Raises:
        ValueError: if `path` does not exist.
    """
    if not os.path.exists(path):
        raise ValueError(missing_message)
    # map_location="cpu" keeps loading independent of the device the tensor was saved
    # from; the caller moves it to the target device.
    return torch.load(path, map_location="cpu")


with torch.no_grad():
    # nn.Linear always has a `bias` attribute (it is None when built with bias=False),
    # so its presence must be tested with `is not None` rather than hasattr().
    lm_head_has_bias = model.showo.lm_head.bias is not None

    # Load every checkpoint tensor before touching the model, so that a missing file
    # cannot leave the model half-resized.
    ckpt_embed_weight = _load_ckpt_tensor(ckpt_embed_path, "Embedding weights do not exist!")
    ckpt_lm_head_weight = _load_ckpt_tensor(ckpt_lm_head_weight_path, "lm_head weights do not exist!")
    if lm_head_has_bias:
        ckpt_lm_head_bias = _load_ckpt_tensor(ckpt_lm_head_bias_path, "lm_head bias do not exist!")

    # Keep a handle on the embedding matrix as it was before resizing.
    embeddings = model.showo.get_input_embeddings().weight.data

    # Grow the embedding table (58498 -> 58515 rows).
    model.showo.resize_token_embeddings(new_total_vocab)
    input_embed_weight = model.showo.get_input_embeddings().weight.data

    # Move the original image-token rows back by `num_new_tokens` positions so that
    # the slots right after the text vocabulary are free for the new tokens.
    original_image_weights = embeddings[original_text_vocab_size:original_total_vocab].clone()
    input_embed_weight[new_text_vocab_size:new_total_vocab] = original_image_weights

    # Fill the new-token rows with the embeddings stored in the checkpoint.
    input_embed_weight[original_text_vocab_size:new_text_vocab_size] = ckpt_embed_weight.to(input_embed_weight.device)

    # lm_head: this code requires that it already has `new_total_vocab` rows at this
    # point (presumably resized by resize_token_embeddings above - TODO confirm).
    lm_head = model.showo.lm_head
    if lm_head.weight.data.shape[0] != new_total_vocab:
        raise ValueError("lm_head weights do not match the input embeddings!")

    # Build a replacement head and apply the same row layout as for the embeddings:
    # text rows unchanged, new-token rows from the checkpoint, image rows shifted back.
    new_lm_head = torch.nn.Linear(
        lm_head.in_features,
        new_total_vocab,
        bias=lm_head_has_bias,
    )
    new_lm_head.weight.data = lm_head.weight.data.clone()
    new_lm_head.weight.data[new_text_vocab_size:new_total_vocab] = lm_head.weight.data[original_text_vocab_size:original_total_vocab]
    new_lm_head.weight.data[original_text_vocab_size:new_text_vocab_size] = ckpt_lm_head_weight.to(new_lm_head.weight.device)

    if lm_head_has_bias:
        new_lm_head.bias.data = lm_head.bias.data.clone()
        new_lm_head.bias.data[new_text_vocab_size:new_total_vocab] = lm_head.bias.data[original_text_vocab_size:original_total_vocab]
        new_lm_head.bias.data[original_text_vocab_size:new_text_vocab_size] = ckpt_lm_head_bias.to(new_lm_head.weight.device)

    model.showo.lm_head = new_lm_head

# Boolean mask over the enlarged vocabulary: False at the newly added token ids,
# True at every other position.
index_no_updates = torch.ones(new_total_vocab, dtype=torch.bool)
index_no_updates[new_token_ids] = False
# ------------------------------
# Sanity checks
# ------------------------------
# Ids of the newly added tokens.
print("新增文本 token ID:", tokenizer.convert_tokens_to_ids(new_tokens))  # expected: 50305-50321

# Round-trip the first id after the original text vocabulary through the tokenizer.
# That id now belongs to the concept token, so the printed id is 50305.
sample_image_token = tokenizer.convert_ids_to_tokens(original_text_vocab_size)
print(f"Concept Token '{sample_image_token}' 的新 ID:", tokenizer.convert_tokens_to_ids(sample_image_token))

# Shape of the resized embedding table.
print("嵌入层大小:", model.showo.get_input_embeddings().weight.shape)  # expected: torch.Size([58515, 2048])

# Mask check: the False positions must be exactly the new token ids, and the number
# of True entries must equal the original vocabulary size.
print("index_no_updates 中 False 的位置:", (~index_no_updates).nonzero().squeeze())  # expected: 50305-50321
print("index_no_updates 中 True 的数量:", index_no_updates.sum())  # expected: 58498

# NOTE(review): the offset 10 presumably accounts for special tokens appended to the
# tokenizer by the prompting helper - confirm against UniversalPrompting.
config.model.showo.llm_vocab_size = len(tokenizer) - 10


新 token ID: [50305, 50306, 50307, 50308, 50309, 50310, 50311, 50312, 50313, 50314, 50315, 50316, 50317, 50318, 50319, 50320, 50321]
新增文本 token ID: [50305, 50306, 50307, 50308, 50309, 50310, 50311, 50312, 50313, 50314, 50315, 50316, 50317, 50318, 50319, 50320, 50321]
Concept Token '<dunpai>' 的新 ID: 50305
嵌入层大小: torch.Size([58515, 2048])
index_no_updates 中 False 的位置: tensor([50305, 50306, 50307, 50308, 50309, 50310, 50311, 50312, 50313, 50314,
        50315, 50316, 50317, 50318, 50319, 50320, 50321])
index_no_updates 中 True 的数量: tensor(58498)


In [6]:
# Personalised system prompt of the form "<concept> is <token_0><token_1>...<token_{n-1}>."
# The closing period is only appended when at least one numbered token is present.
system_personalized_prompt = (
    f"<{concept}> is "
    + "".join(f"<token_{i}>" for i in range(nums_new_token_i))
    + ("." if nums_new_token_i > 0 else "")
)

In [39]:
from pdata import image_transform
from PIL import Image

# The test image and the question are derived from `data_root` / `concept`, so changing
# the concept in the configuration cell also changes what is evaluated here.
image_path = os.path.join(data_root, "two_concept", "concept", "test", concept, "0.png")
question = system_personalized_prompt + "\n" + f'What is the color of the <{concept}> in the image?'
# Alternative question: system_personalized_prompt + "\n" + f"Can you see the <{concept}> in the image?"
## processing
image_ori = Image.open(image_path).convert("RGB")
# Transform the image to the configured resolution (256x256 here) and add a batch axis.
image = image_transform(image_ori, resolution = config.dataset.params.resolution).to(device)
image = image.unsqueeze(0)
print(f"image shape: {image.shape}") # torch.Size([1, 3, 256, 256])

# Pure inference: no autograd graph is needed for tokenising the image or generating.
with torch.no_grad():
    image_tokens_mmu = vq_model.get_code(image)
# Shift the image codes past the (already extended) text vocabulary.
image_tokens = image_tokens_mmu + len(uni_prompting.text_tokenizer)

batch_size = 1
input_ids = uni_prompting.text_tokenizer(['USER: \n' + question + ' ASSISTANT:'])[
    'input_ids']
input_ids = torch.tensor(input_ids).to(device)

# Sequence layout: <|mmu|> <|soi|> image tokens <|eoi|> <|sot|> text prompt
input_ids = torch.cat([
    (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|mmu|>']).to(device),
    (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),
    image_tokens,
    (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),
    (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|sot|>']).to(device),
    input_ids
], dim=1).long()

attention_mask = create_attention_mask_for_mmu(input_ids.to(device),
                                                eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']))

# NOTE(review): `temperature` is defined in an earlier cell but is not passed here -
# confirm whether mmu_generate accepts it and whether it should be forwarded.
with torch.no_grad():
    cont_toks_list = model.mmu_generate(input_ids, attention_mask=attention_mask,
                                max_new_tokens=100, top_k=top_k,
                                eot_token=uni_prompting.sptids_dict['<|eot|>'])

# reshape(1, -1) always yields a (1, num_tokens) batch; the previous `.squeeze()[None]`
# collapsed to a 1-D tensor when exactly one token had been generated.
cont_toks_list = torch.stack(cont_toks_list).reshape(1, -1)

text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)
print(text)

image shape: torch.Size([1, 3, 256, 256])
[' The color of the <dunpai> in the image is red, white, and blue.']
