In [None]:
from transformers import LlavaForConditionalGeneration, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
import torch
from torch import nn

ORIGINAL_VIT_HIDDEN_SIZE = 1152# 示例值，请替换
CONCATENATION_FACTOR = 9      # 你说的每 9 个拼接在一起

# 计算拼接后新的输入维度
NEW_PROJECTOR_INPUT_SIZE = ORIGINAL_VIT_HIDDEN_SIZE * CONCATENATION_FACTOR

class CustomLlavaProjector(nn.Module):
    """
    自定义的投影层，以匹配你的修改。
    原始的投影层是一个简单的线性层或一个 MLP。这里我们创建一个新的。
    """
    def __init__(self, config: LlavaConfig):
        super().__init__()
        
        # 我们需要从主配置中获取 LLM 的隐藏维度
        llm_hidden_size = config.text_config.hidden_size
        
        # 核心修改：创建一个新的线性层，其输入维度是拼接后的维度
        self.linear_1 = nn.Linear(NEW_PROJECTOR_INPUT_SIZE, llm_hidden_size, bias=True)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(llm_hidden_size, llm_hidden_size, bias=True)

    def forward(self, image_features):
        # 这里的 image_features 进来之前应该已经被你处理过了（reshape + concat）
        # 如果处理逻辑在模型外部，那这里直接用就行
        # 如果处理逻辑要在模型内部，那 forward 函数也需要修改
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class LlavaForConditionalGenerationCustom(LlavaForConditionalGeneration):
    """
    我们的自定义 LLaVA 模型类
    """
    def __init__(self, config: LlavaConfig):
        # 首先调用父类的构造函数，创建出除了投影层之外的所有部分
        super().__init__(config)

        # 核心修改：用我们自定义的投影层替换掉父类中默认的投影层
        # 确保这里的属性名 `multi_modal_projector` 和原始模型中的一致
        print("INFO: Replacing the multi_modal_projector with our custom version.")
        self.multi_modal_projector = CustomLlavaProjector(config)

In [10]:
import torch
from transformers import AutoConfig, AutoProcessor
from safetensors.torch import load_file  # 1. 导入正确的函数

model_path = "model"
# 确保路径指向你的 .safetensors 文件
weights_path = f"{model_path}/model.safetensors" 

# 1. 加载处理器（这部分不变）
processor = AutoProcessor.from_pretrained(model_path)

# 2. 加载模型配置
config = AutoConfig.from_pretrained(model_path)

# 3. 使用配置实例化你的自定义模型
print("INFO: Instantiating custom model from config...")
model = LlavaForConditionalGenerationCustom(config)

# 4. 加载权重文件 (state_dict)
print(f"INFO: Loading state_dict from {weights_path}...")
# 2. 使用 safetensors 的函数来加载
state_dict = load_file(weights_path, device="cpu")

# 5. 将权重加载到模型中
print("INFO: Loading state_dict into the custom model...")
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing_keys)
print("Unexpected keys:", unexpected_keys)

INFO: Instantiating custom model from config...
INFO: Replacing the multi_modal_projector with our custom version.
INFO: Loading state_dict from model/model.safetensors...
INFO: Loading state_dict into the custom model...
Missing keys: ['model.vision_tower.vision_model.embeddings.patch_embedding.weight', 'model.vision_tower.vision_model.embeddings.patch_embedding.bias', 'model.vision_tower.vision_model.embeddings.position_embedding.weight', 'model.vision_tower.vision_model.encoder.layers.0.layer_norm1.weight', 'model.vision_tower.vision_model.encoder.layers.0.layer_norm1.bias', 'model.vision_tower.vision_model.encoder.layers.0.self_attn.k_proj.weight', 'model.vision_tower.vision_model.encoder.layers.0.self_attn.k_proj.bias', 'model.vision_tower.vision_model.encoder.layers.0.self_attn.v_proj.weight', 'model.vision_tower.vision_model.encoder.layers.0.self_attn.v_proj.bias', 'model.vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.weight', 'model.vision_tower.vision_model.encode

In [11]:
print(model)

LlavaForConditionalGenerationCustom(
  (model): LlavaModel(
    (vision_tower): SiglipVisionModel(
      (vision_model): SiglipVisionTransformer(
        (embeddings): SiglipVisionEmbeddings(
          (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
          (position_embedding): Embedding(729, 1152)
        )
        (encoder): SiglipEncoder(
          (layers): ModuleList(
            (0-26): 27 x SiglipEncoderLayer(
              (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
              (self_attn): SiglipAttention(
                (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
                (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
                (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
                (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
              )
              (layer_norm2): LayerNorm((1152,), eps=1e-06, elemen