In [1]:
import io
import os
import json
import requests
from PIL import Image
from typing import List, Dict, Union, Any

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer, AutoModelForCausalLM,AutoProcessor, AutoModel, \
    Trainer, TrainingArguments, DataCollatorWithPadding
from transformers.modeling_outputs import CausalLMOutputWithPast

In [2]:
text_tokenizers = AutoTokenizer.from_pretrained("./base_models/llm_model_qwen2.5_1.5b")

In [31]:
class MLLMConfig(PretrainedConfig):
    model_type = "mllm"
    def __init__(
        self,
        llm_model_path = '/mnt/bn/brench-lq1/mllm_self_training/mllm_building/base_models/llm_model_qwen2.5_1.5b',
        vision_model_path = '/mnt/bn/brench-lq1/mllm_self_training/mllm_building/base_models/vision_model_siglip_16_224',
        image_pad_num = 81,
        freeze_vision_model = False,
        **kwargs
    ):
        self.llm_model_path = llm_model_path
        self.vision_model_path = vision_model_path
        self.image_pad_num = image_pad_num
        self.freeze_vision_model = freeze_vision_model
        super().__init__(**kwargs)

In [30]:
class MLLM(PreTrainedModel):
    config_class = MLLMConfig
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.vision_model = AutoModel.from_pretrained(self.config.vision_model_path)
        self.processor = AutoProcessor.from_pretrained(self.config.vision_model_path)
        self.llm_model = AutoModelForCausalLM.from_pretrained(self.config.llm_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_path)
        self.linear1 = nn.Linear(self.vision_model.config.vision_config.hidden_size*9, self.llm_model.config.hidden_size)
        self.linear2 = nn.Linear(self.llm_model.config.hidden_size, self.llm_model.config.hidden_size)
        if self.config.freeze_vision_model:
            for param in self.vision_model.parameters():
                param.requires_grad = False
        for param in self.llm_model.parameters():
            param.requires_grad = False
    
    def forward(
        self,
        input_ids,
        pixel_values,
        labels,
        attention_mask=None  
    ):
        """
        模型的前向传播函数。

        参数:
        - input_ids (torch.Tensor): 输入的 token 索引张量。
        - attention_mask (torch.Tensor): 注意力掩码张量。
        - pixel_values (torch.Tensor): 图像像素值张量。
        - labels (torch.Tensor): 标签张量。

        返回:
        - CausalLMOutputWithPast: 包含 logits 和损失的输出对象。
        """
        # 通过视觉模型获取图像的嵌入表示
        vision_embedding = self.vision_model.vision_model(pixel_values=pixel_values).last_hidden_state
        # 通过语言模型获取文本的嵌入表示
        text_embedding = self.llm_model.get_input_embeddings()(input_ids)

        batch_size, image_tokens, embedding_dim_size = vision_embedding.shape
        vision_embedding = vision_embedding.view(batch_size,-1,embedding_dim_size*9)
        # 对视觉嵌入进行线性变换和激活函数处理，得到图像特征
        image_features = self.linear2(F.silu(self.linear1(vision_embedding)))
        text_embedding = text_embedding.to(image_features.dtype)

        # 将图像特征与文本嵌入合并
        inputs_embedding = self.merge_input_ids_with_image_features(image_features, text_embedding, input_ids)
        # 将合并后的嵌入输入到语言模型中，获取输出
        outputs = self.llm_model(inputs_embeds=inputs_embedding, attention_mask=attention_mask)
        # 从输出中提取 logits
        logits = outputs[0]
        # 初始化损失为 None
        loss = None
        # 如果有标签，则计算损失
        if labels is not None:
            # 创建交叉熵损失函数
            loss_func = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
            # 计算 logits 和标签之间的交叉熵损失
            loss = loss_func(
                logits.view(-1, logits.size(-1)), labels.view(-1).to(logits.device)
            )
        # 返回包含 logits 和损失的输出对象
        return CausalLMOutputWithPast(logits=logits,loss=loss)
    
    
    def merge_input_ids_with_image_features(self, image_features, inputs_embedding, input_ids):
        num_images, num_image_pathes, embedding_dim_size = image_features.shape
        batch_indices, image_indices = torch.where(input_ids == self.tokenizer('<|image_pad|>')['input_ids'][0])
        inputs_embedding[batch_indices, image_indices]  = image_features.view(-1, embedding_dim_size)
        return inputs_embedding

        

In [5]:
system_prompt = {
        "role": "system",
        "content": "你叫Flash,你是为一位专门为Brench服务的多模态AI助手"
}
class PretrainedDataset(Dataset):
    def __init__(self, images_path, annotations_path, config):
        self.config = config
        self.images_path = images_path
        self.annotations_path = annotations_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_path)
        self.processor = AutoProcessor.from_pretrained(self.config.vision_model_path)

        with open(self.annotations_path, 'r', encoding='utf-8') as f:
            self.processor_data = json.load(f)
    
    def __len__(self):
        return len(self.processor_data)
    
    def __getitem__(self, index):
        data_sample = self.processor_data[index]
        try:
            image_file_name = data_sample['image']
            conversations = data_sample['conversations']
            image = Image.open(os.path.join(self.images_path, image_file_name)).convert('RGB')
            pixel_values = self.processor(text=None, images=image)['pixel_values']
            user_prompt = {
                "role": "user",
                "content": conversations[0]['value']
            }
            query_text = [system_prompt, user_prompt]
            query_input = self.tokenizer.apply_chat_template(
                query_text,
                tokenize=False,
                add_generation_prompt=True
            ).replace('<image>','<|image_pad|>'*self.config.image_pad_num)
            response_text = conversations[1]['value'] + self.tokenizer.eos_token
            query_input_ids = self.tokenizer(query_input)['input_ids']
            response_input_ids = self.tokenizer(response_text)['input_ids']
            input_ids = query_input_ids + response_input_ids
            labels = [self.tokenizer.pad_token_id] * len(query_input_ids) + response_input_ids
            input_ids = input_ids[:-1]
            labels = labels[1:]
        except:
            default_image = Image.new('RGB',(224,224),color='white')
            pixel_values = self.processor(text=None, images=default_image)['pixel_values']
            user_prompt = {
                "role": "user",
                "content":"这张图片描述的内容是什么\n<image>"
            }
            query_text = [system_prompt, user_prompt]
            query_input = self.tokenizer.apply_chat_template(
                query_text,
                tokenize=False,
                add_generation_prompt=True
            ).replace('<image>','<|image_pad|>'*self.config.image_pad_num)
            response_text = "图片内容为空，无法生成相关的回复\n" + self.tokenizer.eos_token
            query_input_ids = self.tokenizer(query_input)['input_ids']
            response_input_ids = self.tokenizer(response_text)['input_ids']
            input_ids = query_input_ids + response_input_ids
            labels = [self.tokenizer.pad_token_id] * len(query_input_ids) + response_input_ids
            input_ids = input_ids[:-1]
            labels = labels[1:]
        return {
            'input_ids': input_ids,
            'labels': labels,
            'pixel_values': pixel_values
        }
            
class  DatasetCollator: 
    def __init__(self, config):
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_path)
    
    def __call__(self, features: List[Dict[str,Any]])->Dict[str,torch.Tensor]:
        max_length = max(len(feature['input_ids']) for feature in features)
        input_ids = []
        labels = []
        pixel_values = []
        for feature in features:
            input_ids.append(feature['input_ids'] + [self.tokenizer.pad_token_id] * (max_length - len(feature['input_ids'])))
            labels.append(feature['labels'] + [self.tokenizer.pad_token_id] * (max_length - len(feature['labels'])))
            pixel_values.append(feature['pixel_values'])
        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
            'pixel_values': torch.cat(pixel_values, dim=0)
        }


In [6]:
config = MLLMConfig()

In [7]:
images_path = '/mnt/bn/brench-lq1/mllm_self_training/mllm_building/pretrained_data/LLaVA-CC3M-Pretrain-595K/images'
annotations_path = '/mnt/bn/brench-lq1/mllm_self_training/mllm_building/pretrained_data/LLaVA-CC3M-Pretrain-595K/chat-translated.json'

In [8]:
pretrained_dataset =  PretrainedDataset(images_path,annotations_path,config)

In [9]:
pretrained_dataset[0].keys()

dict_keys(['input_ids', 'labels', 'pixel_values'])

In [34]:
llm_model_path = '/mnt/bn/brench-lq1/mllm_self_training/mllm_building/base_models/llm_model_qwen2.5_1.5b'
vision_model_path = '/mnt/bn/brench-lq1/mllm_self_training/mllm_building/base_models/vision_model_siglip_14_384'
image_pad_num = 81
config = MLLMConfig(
    llm_model_path = llm_model_path,
    vision_model_path = vision_model_path,
    image_pad_num = image_pad_num
)
config

MLLMConfig {
  "freeze_vision_model": false,
  "image_pad_num": 81,
  "llm_model_path": "/mnt/bn/brench-lq1/mllm_self_training/mllm_building/base_models/llm_model_qwen2.5_1.5b",
  "model_type": "mllm",
  "transformers_version": "4.46.3",
  "vision_model_path": "/mnt/bn/brench-lq1/mllm_self_training/mllm_building/base_models/vision_model_siglip_14_384"
}

In [35]:
model = MLLM(config).cuda()
print(model)
print(f'模型参数量为：{sum(p.numel() for p in model.parameters() if p.requires_grad)}')

MLLM(
  (vision_model): SiglipModel(
    (text_model): SiglipTextTransformer(
      (embeddings): SiglipTextEmbeddings(
        (token_embedding): Embedding(32000, 1152)
        (position_embedding): Embedding(64, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (self_attn): SiglipSdpaAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features=1152, out_features=4304, bias=True)
              (fc2): Linear(in_features=4304, out_fe

In [36]:
images_path = '/mnt/bn/brench-lq1/mllm_self_training/mllm_building/pretrained_data/LLaVA-CC3M-Pretrain-595K/images'
annotations_path = '/mnt/bn/brench-lq1/mllm_self_training/mllm_building/pretrained_data/LLaVA-CC3M-Pretrain-595K/chat-translated.json'
output_dir = '/mnt/bn/brench-lq1/mllm_self_training/mllm_building/pretrained_model_save' 


args = TrainingArguments(
    output_dir=output_dir,
    do_train=True,
    per_device_train_batch_size=8,
    learning_rate=1e-4,
    num_train_epochs=5,
    save_steps=200,
    save_total_limit=2,
    fp16=True,
    gradient_accumulation_steps=8,
    logging_steps=1,
    report_to='tensorboard',
    dataloader_pin_memory=True,
    dataloader_num_workers=16
)


In [None]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=PretrainedDataset(images_path, annotations_path, config),
    data_collator=DatasetCollator(config) 
)

trainer.train(resume_from_checkpoint=False)
trainer.save_model(output_dir)
trainer.save_state()

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
Detected kernel version 5.4.143, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork 

Step,Training Loss
1,6.7006
2,6.5428
3,6.6177
4,6.6588
5,6.6107


NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
