# 创建一个Llava模型

## 下载模型

In [None]:
export HF_ENDPOINT=https://hf-mirror.com
hf download openai/clip-vit-large-patch14-336 --local-dir ../../../../models/CLIP/clip-vit-large-patch14-336
hf download Qwen/Qwen1.5-4B-Chat --local-dir ../../../models/Qwen/Qwen1.5-4B-Chat

## Llava初始化

### 将clip模型和llm模型的config拿出来，初始化一个Llava model

In [1]:
from transformers import (
    LlavaForConditionalGeneration,LlavaConfig,LlavaProcessor,
    CLIPVisionModel,CLIPVisionConfig,CLIPImageProcessor,
    Qwen2ForCausalLM,Qwen2Config,Qwen2Tokenizer
    )
import torch
clip_model_path = "./models/clip-vit-large-patch14-336"
qwen_model_path = "./models/Qwen1.5-1.8B-Chat"

In [2]:
# 加载视觉模型
vision_model = CLIPVisionModel.from_pretrained(clip_model_path)
vision_config = CLIPVisionConfig.from_pretrained(clip_model_path)
image_processor = CLIPImageProcessor.from_pretrained(clip_model_path)

# 加载语言模型
qwen_model = Qwen2ForCausalLM.from_pretrained(qwen_model_path)
qwen_config = Qwen2Config.from_pretrained(qwen_model_path)
qwen_tokenizer = Qwen2Tokenizer.from_pretrained(qwen_model_path)

# 添加特殊token
qwen_tokenizer.add_special_tokens({"additional_special_tokens":qwen_tokenizer.additional_special_tokens+["<image>"]})

llava_config = LlavaConfig(
    vision_config=vision_config,
    text_config=qwen_config,
    ignore_index=-100,
    image_token_index=151646
)
# 使用从配置初始化 Llava 模型结构
llava_model = LlavaForConditionalGeneration(llava_config)
llava_processor = LlavaProcessor(
    image_processor=image_processor,
    tokenizer=qwen_tokenizer,
    num_additional_image_tokens=1,
    patch_size=vision_config.patch_size,
    vision_feature_select_strategy="default" # 需要指定，否则tokens的维度与image_feature的维度对不上
)

### 复制权重

In [3]:
llava_model.model.vision_tower = vision_model
llava_model.model.language_model = qwen_model.model

### 查看权重是否赋值成功

In [4]:
print(llava_model.model.language_model.embed_tokens.weight.data[0,:2])
print(qwen_model.model.embed_tokens.weight.data[0,:2])

tensor([-0.0271,  0.0280])
tensor([-0.0271,  0.0280])


### 复制pad_token_id

In [5]:
llava_model.config.pad_token_id = qwen_tokenizer.pad_token_id
llava_model.config.pad_token_id

151643

### 保存模型和相关的配置文件

In [6]:
llava_model.save_pretrained("models/llava_clip-L-14-336_Qwen1.5-1.8B")
llava_processor.save_pretrained("models/llava_clip-L-14-336_Qwen1.5-1.8B")

['models/llava_clip-L-14-336_Qwen1.5-1.8B/processor_config.json']

## 测试

In [7]:
from transformers import AutoProcessor,LlavaForConditionalGeneration,LlavaProcessor
import torch


llava_mode_path = "models/llava_clip-L-14-336_Qwen1.5-1.8B"

llava_processor = LlavaProcessor.from_pretrained(llava_mode_path, use_fast=True)
llava_model = LlavaForConditionalGeneration.from_pretrained(
    llava_mode_path, device_map="cpu"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:

from PIL import Image

prompt_text = "<image>\nWhat are these?"


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt_text},
]
prompt = llava_processor.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)


image_path = "test.jpg"
image = Image.open(image_path)


inputs = llava_processor(text=prompt, images=image, return_tensors="pt")

# Generate
generate_ids = llava_model.generate(**inputs, max_new_tokens=15)
llava_processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
for tk in inputs.keys():
    inputs[tk] = inputs[tk].to(llava_model.device)
generate_ids = llava_model.generate(**inputs, max_new_tokens=20)
gen_text = llava_processor.batch_decode(
    generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)[0]

print(gen_text)

Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><image><