In [1]:
import json

from chimera.model.expert_encoder.configuration_sci_encoder import SciEncoderConfig, init_config_from_meta_config
from chimera.model.expert_encoder.modeling_sci_encoder import SciEncoder
from transformers import Pix2StructVisionModel, Pix2StructVisionConfig, Pix2StructConfig, Pix2StructForConditionalGeneration, Pix2StructImageProcessor
from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor
from chimera.model.kosmos2_5.modeling_kosmos2_5 import Kosmos2_5Config, Kosmos2_5ForConditionalGeneration
from chimera.model.kosmos2_5 import Kosmos2_5VisionModel, Kosmos2_5VisionConfig, Kosmos2_5ImageProcessor
# NOTE(review): later cells use InternVLChatConfig, InternVLChatModel, AutoTokenizer and the
# IMG_/QUAD_/REF_/BOX_ *_TOKEN constants, but no active import provides them. They presumably
# came from this commented-out star import; restore it (or import those names explicitly)
# before Restart & Run All.
# from chimera.train.internvl_chat_finetune import *
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch

# Checkpoint / dataset locations, built from shared roots so that moving the
# workspace only requires editing one constant.
WORKSPACE_ROOT = '/mnt/workspace'
HF_CKP_ROOT = f'{WORKSPACE_ROOT}/pengtianshuo_hf_ckp'

internvl_path = f'{HF_CKP_ROOT}/OpenGVLab/InternVL2-8B'
p2s_path = f'{HF_CKP_ROOT}/solva_modules/chart_p2s'
clip_path = f'{WORKSPACE_ROOT}/pengtianshuo_hf_dataset/MAVIS/CLIP-Math/Arxiv-ViT-L-14-336'
kosmos_path = f'{HF_CKP_ROOT}/solva_modules/table_kosmos'


In [3]:
# One image processor per expert-encoder checkpoint (the roles below follow the
# checkpoint directory names: chart_p2s, CLIP-Math, table_kosmos).
p2s_processor = Pix2StructImageProcessor.from_pretrained(p2s_path)       # chart
clip_processor = CLIPImageProcessor.from_pretrained(clip_path)           # math
kosmos_processor = Kosmos2_5ImageProcessor.from_pretrained(kosmos_path)  # table

In [4]:
# Compare how each processor bounds its input resolution.
# Pix2StructImageProcessor exposes `max_patches` (the Kosmos-2.5 processor presumably does too).
# CLIPImageProcessor has no `max_patches` attribute, so reading it raised AttributeError;
# CLIP instead resizes/crops to a fixed `crop_size`.
# Each value is read with getattr so a missing attribute shows up as None rather than
# stopping Run All.
{
    'kosmos_max_patches': getattr(kosmos_processor, 'max_patches', None),
    'p2s_max_patches': getattr(p2s_processor, 'max_patches', None),
    'clip_crop_size': getattr(clip_processor, 'crop_size', None),
}

In [None]:
# Display the three loaded processors (their reprs list the loaded settings).
(kosmos_processor, clip_processor, p2s_processor)

In [None]:
# Dummy input: a blank white 224x224 RGB image, batched three times.
# The list holds three references to the same Image object.
image = Image.new('RGB', (224, 224), (255, 255, 255))
images = [image] * 3

In [5]:
# Sanity check: the processor loaded from kosmos_path must be the Kosmos-2.5 one.
# An assertion makes Run All fail loudly if the checkpoint ever changes.
assert kosmos_processor.image_processor_type == 'Kosmos2_5ImageProcessor', kosmos_processor.image_processor_type
kosmos_processor.image_processor_type

'Kosmos2_5ImageProcessor'

In [None]:
# Push the same batch through every processor, asking for PyTorch tensors.
p2s_output = p2s_processor(images, return_tensors="pt")
clip_output = clip_processor(images, return_tensors="pt")
kosmos_output = kosmos_processor(images, return_tensors="pt")

In [None]:
# Kosmos-2.5 output: shapes of the flattened patches and their attention mask.
(kosmos_output.flattened_patches.shape, kosmos_output.attention_mask.shape)

In [None]:
# CLIP output: shape of the pixel tensor.
clip_output['pixel_values'].shape

In [None]:
# Pix2Struct output: shapes of the flattened patches and their attention mask.
(p2s_output['flattened_patches'].shape, p2s_output['attention_mask'].shape)

In [None]:
# Load the checkpoint's config and override it to mirror the fine-tuning setup.
# NOTE(review): InternVLChatConfig is not imported by any visible cell; it presumably came from
# the commented-out star import of chimera.train.internvl_chat_finetune. Restore an explicit
# import before running on a fresh kernel.
vlm_config = InternVLChatConfig.from_pretrained(internvl_path)
vlm_config.vision_config.drop_path_rate = 0.0

# InternLM2 configs take the public `attn_implementation` field; other LLMs (e.g. LLaMA) take
# the private `_attn_implementation`.
attn_field = (
    'attn_implementation'
    if vlm_config.llm_config.model_type == 'internlm2'
    else '_attn_implementation'
)
setattr(vlm_config.llm_config, attn_field, 'flash_attention_2')

# Chat template and dynamic-tiling options, applied in the same order as before.
for field, value in {
    'template': "internlm2-chat",
    'select_layer': -1,
    'dynamic_image_size': True,
    'use_thumbnail': True,
    'ps_version': 'v2',
    'min_dynamic_patch': 1,
    'max_dynamic_patch': 6,
}.items():
    setattr(vlm_config, field, value)

In [None]:
# Build the SciEncoder config from the meta-config JSON (a list of per-encoder entries).
SCI_ENCODER_META_CONFIG_PATH = '/mnt/workspace/Solva/internvl_chat/shell/sci_encoder_config/solva_table_math_chart.json'

# `with` closes the file; the original `open(...).read()` leaked the handle.
# JSON is UTF-8 by spec, so don't depend on the locale's default encoding.
with open(SCI_ENCODER_META_CONFIG_PATH, encoding='utf-8') as meta_config_file:
    sci_encoder_meta_config = json.load(meta_config_file)

for encoder_entry in sci_encoder_meta_config:
    # The "model_name_or_path" and "model_type" keys of each entry are only used during
    # training; they are removed once the model has been loaded.
    encoder_entry["config"] = init_config_from_meta_config(encoder_entry).to_dict()
    # Move the checkpoint path out of the entry and into its serialized config.
    name_or_path = encoder_entry.pop("model_name_or_path")
    encoder_entry["config"]['_name_or_path'] = name_or_path

sci_config = SciEncoderConfig(sci_encoder_meta_config, vlm_config.llm_config.hidden_size)
# Load each encoder separately, from its own checkpoint path.
sci_config.separate_load = True
vlm_config.sci_encoder_config = sci_config

In [None]:
# Load the InternVL2 weights (bfloat16) with the patched config, and its slow tokenizer.
# NOTE(review): InternVLChatModel and AutoTokenizer are not imported by any visible cell;
# they presumably came from the commented-out star import of chimera.train.internvl_chat_finetune.
model = InternVLChatModel.from_pretrained(
    internvl_path,
    config=vlm_config,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(
    internvl_path,
    trust_remote_code=True,
    use_fast=False,
    add_eos_token=False,
)

In [None]:
# Register InternVL's special tokens plus one "<DOMAIN_k_CONTEXT>" token per scientific domain,
# then tell the model which ids mark image and domain context.
# NOTE(review): the IMG_/QUAD_/REF_/BOX_ *_TOKEN constants are not defined in any visible cell;
# they presumably came from the commented-out star import of chimera.train.internvl_chat_finetune.

# NOTE(review): hard-coded; presumably one domain per expert encoder in sci_encoder_meta_config
# (table / math / chart) — confirm, or derive it from that list.
NUM_SCI_DOMAINS = 3
sci_token_list = [f"<DOMAIN_{domain_idx}_CONTEXT>" for domain_idx in range(NUM_SCI_DOMAINS)]

token_list = [
    IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN,
    QUAD_START_TOKEN, QUAD_END_TOKEN,
    REF_START_TOKEN, REF_END_TOKEN,
    BOX_START_TOKEN, BOX_END_TOKEN,
    *sci_token_list,
]

num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
# NOTE(review): num_new_tokens is never used. If it is > 0, the LLM's embedding matrix is
# presumably too small for the new ids and must be resized before any forward pass (InternVL's
# fine-tuning script does this right after add_tokens) — confirm.

img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
model.img_context_token_id = img_context_token_id

# For a list of tokens this yields one id per domain token, in domain order.
sci_context_token_id = tokenizer.convert_tokens_to_ids(sci_token_list)
model.set_domain_context_token_ids(sci_context_token_id)