## 下载数据集

In [None]:
%%bash
# Run this cell through bash: the lines below are shell commands, not Python, and
# the exported variable must be visible to both `hf download` calls in the same shell.
# Route Hugging Face Hub requests through the hf-mirror.com endpoint.
export HF_ENDPOINT=https://hf-mirror.com
# NOTE(review): later cells read from the relative path "datasets/LLaVA-CC3M-Pretrain-595K";
# confirm that this --local-dir resolves to the same directory from the notebook's working dir.
hf download liuhaotian/LLaVA-CC3M-Pretrain-595K --repo-type=dataset --local-dir ../../../../datasets/LLaVA-CC3M-Pretrain-595K
hf download CaptionEmporium/TextOCR-GPT4o --repo-type=dataset --local-dir ../../../../datasets/TextOCR-GPT4o

## 查看数据集

In [1]:
from dataclasses import dataclass
from PIL import Image
from torch.utils.data import Dataset
from pathlib import Path
import pandas as pd
import torch
from typing import List
from transformers import AutoProcessor,LlavaProcessor


In [2]:
# Locate the local copy of the dataset and load its conversation index for inspection.
data_dir = "datasets/LLaVA-CC3M-Pretrain-595K"

chat_file = Path(data_dir) / "chat.json"
chat_data = pd.read_json(chat_file)
# Optional inspection snippets (uncomment one and make it the last line to display it):
# chat_data.shape
# chat_data.head(20)
# chat_data.iloc[10]["conversations"][1]["value"]

## 构建数据集类

### 1、构建llava读取数据集的类

In [3]:
class LlavaDataset(Dataset):
    """Map-style dataset over a directory holding ``chat.json`` and an ``images`` folder.

    Indexing yields a ``(question, answer, image_path)`` triple built from the
    first two entries of a record's ``conversations`` list and its ``image`` field.
    """

    def __init__(self, data_dir) -> None:
        """Load every record of ``<data_dir>/chat.json`` into memory."""
        super().__init__()
        records, image_root = self.build_dataset(data_dir=data_dir)
        self.chat_data = records
        self.image_dir = image_root

    def build_dataset(self, data_dir) -> tuple[list, Path]:
        """Return the chat records (list of dicts) and the path of the image folder."""
        root = Path(data_dir)
        frame = pd.read_json(path_or_buf=root / "chat.json")
        return frame.to_dict("records"), root / "images"

    def __len__(self):
        """Number of chat records."""
        return len(self.chat_data)

    def __getitem__(self, index) -> tuple[str, str, Path]:
        """Return ``(first-turn text, second-turn text, image path)`` for one record."""
        record = self.chat_data[index]
        turns = record["conversations"]
        question = turns[0]["value"]
        answer = turns[1]["value"]
        # The image field is a file name relative to the images folder.
        picture = self.image_dir.joinpath(record.get("image"))
        return question, answer, picture

In [4]:
# Build the dataset from the local dataset directory and display one sample
# (a (question, answer, image path) tuple) as a quick sanity check.
data_dir = "datasets/LLaVA-CC3M-Pretrain-595K"
test_llavadataset = LlavaDataset(data_dir)
test_llavadataset[123]

('Relay a brief, clear account of the picture shown.\n<image>',
 'e and person vintage initials logo symbol .',
 PosixPath('datasets/LLaVA-CC3M-Pretrain-595K/images/GCC_train_001015674.jpg'))

### 2、单个输入处理

In [5]:
# Load the processor stored alongside the local LLaVA checkpoint; it exposes the
# `.tokenizer` used below and is called directly on text + image.
llava_model_path = "models/llava_clip-L-14-336_Qwen1.5-1.8B"
llava_process = LlavaProcessor.from_pretrained(llava_model_path,use_fast=True)

In [None]:
# Tokenize the literal string "a_text" with the same arguments build_qaimage uses
# for the answer text, and display the resulting input_ids tensor.
llava_process.tokenizer(
    "a_text",
    return_tensors="pt",
    padding="longest",
    truncation=True
)['input_ids']

In [7]:
# Keep one (question, answer, image path) sample around for the single-input demo.
test_data = test_llavadataset[123]
test_data

('Relay a brief, clear account of the picture shown.\n<image>',
 'e and person vintage initials logo symbol .',
 PosixPath('datasets/LLaVA-CC3M-Pretrain-595K/images/GCC_train_001015674.jpg'))

In [8]:
@dataclass
class QaImagaOutput:
    """Tensors produced by ``build_qaimage`` for one question/answer/image sample."""

    # "input_ids" returned by the processor for the chat-templated question prompt.
    q_input_ids: torch.Tensor
    # "pixel_values" returned by the processor for the sample's image.
    pixel_values: torch.Tensor
    # "input_ids" returned by the tokenizer for the answer text alone.
    a_input_ids: torch.Tensor


def build_qaimage(
    processor: LlavaProcessor, q_text: str, a_text: str, image_path: Path
) -> "QaImagaOutput":
    """Turn one (question, answer, image) sample into model-ready tensors.

    Args:
        processor: LLaVA processor; its tokenizer supplies the chat template.
        q_text: User question text (expected to carry the image placeholder).
        a_text: Target answer text.
        image_path: Path of the image file to load.

    Returns:
        QaImagaOutput holding the prompt token ids, the image pixel values and
        the answer token ids, each as returned with ``return_tensors="pt"``.
    """
    # Qwen chat template: a fixed system message followed by the user question.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": q_text},
    ]
    prompt = processor.tokenizer.apply_chat_template(
        messages,
        tokenize=False,  # return the rendered string rather than token ids
        add_generation_prompt=True,  # append the assistant-turn prefix to the prompt
    )

    # Open inside a context manager so the file handle is released right away
    # (this runs once per sample inside the collator), and normalise to RGB so
    # grayscale / palette / CMYK files always yield a 3-channel image.
    # convert() returns a fully loaded copy, so it stays valid after the file closes.
    with Image.open(image_path) as image_file:
        raw_image = image_file.convert("RGB")

    inputs = processor(text=prompt, images=raw_image, return_tensors="pt")

    # The answer is tokenized on its own; the collator appends EOS and builds labels.
    a_input_ids = processor.tokenizer(
        a_text, return_tensors="pt", padding="longest", truncation=True
    )["input_ids"]

    return QaImagaOutput(
        q_input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        a_input_ids=a_input_ids,
    )

In [9]:
# Run the single-sample pipeline on the sample's question, answer and image path.
c = build_qaimage(llava_process, test_data[0], test_data[1], test_data[2])

In [10]:
# Shape of the prompt token ids returned by the processor for this sample.
c.q_input_ids.shape

torch.Size([1, 607])

In [11]:
# Preview the EOS id reshaped to (1, -1), the form the collator concatenates onto a sequence.
# NOTE(review): the documented tokenizer attribute is `eos_token_id`; the plural
# spelling resolves in this environment (see recorded output) - confirm on other versions.
torch.tensor(llava_process.tokenizer.eos_token_ids).reshape(1,-1)

tensor([[151645]])

### 3、多batch整理

In [12]:
class TrainLlavaModelCollator:
    """Collate ``(question, answer, image_path)`` samples into one training batch.

    Each sample is expanded with ``build_qaimage``; its prompt ids, answer ids and
    an EOS id are concatenated into one sequence. Sequences are left-padded to the
    longest one in the batch. Labels carry ``ignore_index`` on prompt and padding
    positions so only the answer tokens and the EOS contribute to the loss.
    """

    def __init__(self, processor: "LlavaProcessor", IGNORE_INDEX: int = -100):
        """Store the processor and resolve the special-token ids used for padding.

        Args:
            processor: LLaVA processor exposing a ``tokenizer`` attribute.
            IGNORE_INDEX: Label value for positions excluded from the loss.

        Raises:
            ValueError: If the tokenizer defines no EOS token.
        """
        self.processor = processor
        self.ignore_index = IGNORE_INDEX
        tokenizer = self.processor.tokenizer
        # Read the documented singular attributes; the attribute names on this
        # class are kept as they were for backward compatibility.
        self.eos_token_ids = tokenizer.eos_token_id
        if self.eos_token_ids is None:
            raise ValueError("tokenizer must define an eos token")
        # Fall back to EOS when no pad token is configured. This is safe because
        # the attention mask is built from padding lengths, not from token values.
        pad_id = tokenizer.pad_token_id
        self.pad_token_ids = self.eos_token_ids if pad_id is None else pad_id

    @staticmethod
    def _left_pad(row: torch.Tensor, pad_len: int, fill_value: int) -> torch.Tensor:
        """Prepend ``pad_len`` columns of ``fill_value`` to a 2-D tensor."""
        padding = torch.full(
            (row.shape[0], pad_len), fill_value, dtype=row.dtype, device=row.device
        )
        return torch.cat([padding, row], dim=1)

    def convert_one_piece(self, q_input_ids: torch.Tensor, a_input_ids: torch.Tensor):
        """Build ``(input_ids, labels)`` for one sample.

        ``input_ids`` is prompt + answer + EOS. ``labels`` has the same length,
        with every prompt position replaced by ``ignore_index``.
        """
        eos = torch.tensor(
            self.eos_token_ids, dtype=q_input_ids.dtype, device=q_input_ids.device
        ).reshape(1, -1)
        input_ids = torch.cat([q_input_ids, a_input_ids, eos], dim=1)
        # Labels decide which positions are scored: the prompt is masked out.
        labels = torch.cat(
            [
                torch.full_like(q_input_ids, fill_value=self.ignore_index),
                a_input_ids,
                eos,
            ],
            dim=1,
        )
        return input_ids, labels

    def __call__(self, features: List):
        """Collate a list of samples into a dict of batched tensors.

        Args:
            features: Non-empty list of indexable samples where items 0, 1 and 2
                are the question text, the answer text and the image path.

        Returns:
            Dict with ``input_ids``, ``labels``, ``pixel_values`` and
            ``attention_mask``; per-sample tensors are concatenated along dim 0.

        Raises:
            ValueError: If ``features`` is empty.
        """
        if not features:
            raise ValueError("features must contain at least one sample")

        input_ids_list = []
        labels_list = []
        pixel_values_list = []
        for feature in features:
            qaimage_output: QaImagaOutput = build_qaimage(
                self.processor, feature[0], feature[1], feature[2]
            )
            sample_ids, sample_labels = self.convert_one_piece(
                qaimage_output.q_input_ids, qaimage_output.a_input_ids
            )
            input_ids_list.append(sample_ids)
            labels_list.append(sample_labels)
            pixel_values_list.append(qaimage_output.pixel_values)

        # Align every sequence to the longest one in the batch (left padding).
        max_input_len = max(ids.shape[1] for ids in input_ids_list)
        padded_ids = []
        padded_labels = []
        masks = []
        for sample_ids, sample_labels in zip(input_ids_list, labels_list):
            pad_len = max_input_len - sample_ids.shape[1]
            padded_ids.append(self._left_pad(sample_ids, pad_len, self.pad_token_ids))
            padded_labels.append(
                self._left_pad(sample_labels, pad_len, self.ignore_index)
            )
            # Mask by position rather than by token value, so a real token that
            # shares the pad id (for example EOS when pad == eos) stays attended.
            masks.append(self._left_pad(torch.ones_like(sample_ids), pad_len, 0))

        return {
            "input_ids": torch.cat(padded_ids, dim=0),
            "labels": torch.cat(padded_labels, dim=0),
            "pixel_values": torch.cat(pixel_values_list, dim=0),
            "attention_mask": torch.cat(masks, dim=0),
        }

In [None]:
# Display the tokenizer's pad and EOS ids, the two values the collator reads in __init__.
# NOTE(review): the documented attribute names are `pad_token_id` / `eos_token_id`;
# confirm the plural spellings resolve on the transformers version in use.
llava_process.tokenizer.pad_token_ids,llava_process.tokenizer.eos_token_ids

In [14]:
# Smoke-test the collator on two samples and display the resulting attention mask.
tlmc = TrainLlavaModelCollator(processor=llava_process, IGNORE_INDEX=-100)
sample_pair = [test_llavadataset[idx] for idx in (23, 100)]
d = tlmc(sample_pair)
d["attention_mask"]
# Alternative single-sample check:
# d = tlmc.convert_one_piece(c.q_input_ids,c.a_input_ids)
# d[0].shape

tensor([[0, 0, 0,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])

## 测试处理的数据，模型是否能够处理

### 1、模型初始化

In [13]:
from transformers import AutoProcessor,LlavaProcessor,LlavaForConditionalGeneration



# Load the LLaVA checkpoint from the local model directory in bfloat16,
# placing it on the first CUDA device.
llava_model_path = "models/llava_clip-L-14-336_Qwen1.5-1.8B"
llava_model = LlavaForConditionalGeneration.from_pretrained(llava_model_path,dtype=torch.bfloat16,device_map="cuda:0")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### 2、测试多batch数据

In [15]:
# Move every tensor of the collated batch onto the model's device, then run a
# forward pass; the batch includes labels, so the output carries a loss.
target_device = llava_model.device
for tk, batch_tensor in list(d.items()):
    d[tk] = batch_tensor.to(target_device)
model_output = llava_model(**d)

In [17]:
# Loss from the forward pass above (computed because the batch included "labels").
model_output.loss

tensor(22.0397, device='cuda:0', grad_fn=<NllLossBackward0>)