# 数据加载

## 环境包导入

In [12]:
# 首先导入基础包
import os
import json
import torch
import torch.nn as nn

from torch.utils.data import Dataset


# 1. 导入所需的包
import os
import json
from PIL import Image
import torchvision.transforms as transforms
from transformers import (
    AutoProcessor,
    LlavaForConditionalGeneration,
    BitsAndBytesConfig
)
from modelscope import snapshot_download
from custom_llava import CustomLlavaModel
from micl_model import MICLModel


In [2]:
class MICLDataset(Dataset):
    """Paired query/target image-text dataset for MICL.

    Every sample couples a "query" note with a "target" note. A note consists
    of an image path, a topic (keywords) string and a comment string.
    """

    def __init__(self, json_path, image_dir, transform=None):
        """Load the pair list and set up the image transform.

        Args:
            json_path: Path to the JSON annotation file.
            image_dir: Directory containing the images named in the JSON.
            transform: Callable applied to each loaded RGB PIL image. When
                omitted, images are resized to 224x224, converted to a tensor
                and normalized with the per-channel mean/std given below.
        """
        self.data = self.load_data(json_path, image_dir)
        self.transform = transform if transform else transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def load_data(json_path, image_dir):
        """Read the annotation file and build the list of usable pairs.

        Args:
            json_path: Path to a JSON file holding a list of records, each
                with a 'query' and a 'target' object that provide
                'image_name', 'keywords' and 'comment'.
            image_dir: Directory the image names are resolved against.

        Returns:
            A list of ``{'query': [note], 'target': [note]}`` dicts where each
            note is ``{'image': path, 'topic': str, 'comment': str}``.

        Pairs whose query or target image does not exist on disk are skipped.
        Keeping them (with an empty list on one side) would make every
        consumer that indexes ``['query'][0]`` / ``['target'][0]`` fail with
        an IndexError.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            records = json.load(f)

        def build_note(raw):
            # Return the note dict, or None when its image file is missing.
            image_path = os.path.join(image_dir, raw['image_name'])
            if not os.path.exists(image_path):
                return None
            return {
                'image': image_path,
                'topic': raw['keywords'],
                'comment': raw['comment'],
            }

        dataset = []
        for record in records:
            query = build_note(record['query'])
            target = build_note(record['target'])
            if query is None or target is None:
                continue
            dataset.append({'query': [query], 'target': [target]})
        return dataset

    def __len__(self):
        """Return the number of usable query/target pairs."""
        return len(self.data)

    def _load_image(self, path):
        """Open the image at ``path`` as RGB and run it through the transform."""
        # The stored value is a file path; the transform needs a loaded image,
        # so the file is opened first (same approach as DataProcessor).
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, idx):
        """Return the transformed images and the texts of pair ``idx``."""
        item = self.data[idx]
        query = item['query'][0]
        target = item['target'][0]

        return {
            'query_image': self._load_image(query['image']),
            'query_topic': query['topic'],
            'query_comment': query['comment'],
            'target_image': self._load_image(target['image']),
            'target_topic': target['topic'],
            'target_comment': target['comment'],
        }


In [7]:
# Load the raw query/target pair list. This calls the static loader directly,
# so `dataset` is a plain list of dicts rather than a MICLDataset instance.
json_path = 'data/caption_key3_sim_bey25.json'
image_dir = '/root/course/llava/data/img200/'
dataset = MICLDataset.load_data(json_path, image_dir)


# Preview the first pair (notebook display expression).
dataset[:1]

[{'query': [{'image': '/root/course/llava/data/img200/1000344755.jpg',
    'topic': 'man blue,standing stair,cleaning windows',
    'comment': 'Someone in a blue shirt and hat is standing on stair and leaning against a window . A man in a blue shirt is standing on a ladder cleaning a window . A man on a ladder cleans the window of a tall building . man in blue shirt and jeans on ladder cleaning windows a man on a ladder cleans a window'}],
  'target': [{'image': '/root/course/llava/data/img200/1044798682.jpg',
    'topic': 'man stands,scaffolding,cleaning windows',
    'comment': 'A man wearing a hat and a white shirt is cleaning windows . A man in a white shirt stands high up on scaffolding A man stands on boards on top of a huge ladder . Man works on top of scaffolding . A guy works on a building .'}]}]

# 数据处理

In [5]:
from PIL import Image
class DataProcessor(nn.Module):
    """
    Map-style wrapper that turns MICL query/target pairs into model inputs.

    Each pair is rendered into a text prompt plus an image and passed through
    the supplied multimodal processor; `collate_fn` then pads and stacks the
    per-sample tensors into a batch on the configured device.
    """

    def __init__(self, dataset, processor, device='cuda'):
        """
        Args:
            dataset: Indexable collection of pairs shaped like
                {'query': [note], 'target': [note]}, each note providing
                'image', 'topic' and 'comment'.
            processor: Multimodal processor called with images/text; it is
                mutated here (see below).
            device: Device that collated batches are moved to.
        """
        super().__init__()
        self.dataset = dataset
        self.device = device
        # Processor settings pinned by the original author; the original
        # note says LLaVA uses a 14x14 patch size by default.
        processor.patch_size = 14
        processor.num_additional_image_tokens = 1
        processor.vision_feature_select_strategy = "default"
        self.processor = processor

    def __len__(self):
        """Number of pairs in the wrapped dataset."""
        return len(self.dataset)

    def _encode_note(self, note):
        """Build the prompt for one note and run it through the processor."""
        image = Image.open(note['image']).convert('RGB')
        prompt = f"Note content: {{'topic': '{note['topic']}', 'content': '{note['comment']}', 'image': <image>}}, Compress this note into one word："
        encoded = self.processor(
            images=image,
            text=prompt,
            return_tensors="pt",
            padding=True,
        )
        # The processor returns tensors with a leading batch axis of size 1;
        # drop it so samples can be padded and stacked later.
        for key in list(encoded.keys()):
            value = encoded[key]
            if torch.is_tensor(value):
                encoded[key] = value.squeeze(0)
        return encoded

    def __getitem__(self, idx):
        """Return the processed query and target inputs of pair `idx`."""
        pair = self.dataset[idx]
        return {
            'query_inputs': self._encode_note(pair['query'][0]),
            'target_inputs': self._encode_note(pair['target'][0]),
        }

    def collate_fn(self, batch):
        """
        Merge a list of samples into one batch.

        Args:
            batch: List of dicts as produced by `__getitem__`.
        Returns:
            Dict with 'query_inputs' and 'target_inputs', each holding stacked
            'input_ids', 'attention_mask' and 'pixel_values' tensors on
            `self.device`. Token sequences are right-padded with zeros to the
            longest sequence of that side within the batch.
        """
        pad = torch.nn.utils.rnn.pad_sequence
        collated = {}
        for side in ('query_inputs', 'target_inputs'):
            entries = [sample[side] for sample in batch]
            collated[side] = {
                'input_ids': pad(
                    [entry['input_ids'] for entry in entries],
                    batch_first=True,
                    padding_value=0,
                ).to(self.device),
                'attention_mask': pad(
                    [entry['attention_mask'] for entry in entries],
                    batch_first=True,
                    padding_value=0,
                ).to(self.device),
                # Pixel tensors are stacked as-is, without padding.
                'pixel_values': torch.stack(
                    [entry['pixel_values'] for entry in entries]
                ).to(self.device),
            }
        return collated

# 模型本地导入

In [14]:

# 1. Fetch the LLaVA-1.5-7B checkpoint via ModelScope and get its local path.
model_dir = snapshot_download('swift/llava-1.5-7b-hf')
# 2. Configure GPU-memory related settings
torch.cuda.empty_cache()  # release cached GPU memory
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

# 3. Quantization settings (aggressive: 4-bit)
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # enable 4-bit quantization
    bnb_4bit_quant_type="nf4",            # use the NF4 quantization type
    bnb_4bit_use_double_quant=True,       # enable double quantization
    bnb_4bit_compute_dtype=torch.float16,  # compute in float16 rather than bfloat16
    llm_int8_enable_fp32_cpu_offload=True  # enable CPU offload
)

# 4. Load the quantized model.
# NOTE(review): the original comment said gradient checkpointing is enabled
# here, but no such option is set in this call — confirm whether it is
# enabled inside CustomLlavaModel or is simply missing.
model = CustomLlavaModel.from_pretrained(
    model_dir,
    quantization_config=nf4_config,
    device_map="auto",
    torch_dtype=torch.float16,            # use float16
)


# 5. Load the matching processor
llava_processor = AutoProcessor.from_pretrained(
    model_dir,
    use_fast=True,
)



Downloading Model to directory: /root/.cache/modelscope/hub/swift/llava-1.5-7b-hf


2025-02-14 18:11:54,269 - modelscope - INFO - Target directory already exists, skipping creation.
Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.45s/it]


## Some tests on sample data


In [None]:
# NOTE(review): processor objects may not implement __len__; if this raises
# TypeError, len(llava_processor.tokenizer) is probably what was intended —
# confirm.
len(llava_processor)

In [None]:
sample.keys()

In [None]:
sample['query_inputs'].keys()

In [51]:
# Keep the query-side token ids of `sample` for inspection in the next cells.
# NOTE(review): `sample` is not defined in the visible cells — presumably one
# item of a DataProcessor instance; confirm.
ss=sample['query_inputs']['input_ids']

In [None]:
ss[-1]

In [None]:
ss[600:610]

# 开始训练-测试

## micl

In [None]:
# MICL training cell: builds the DataProcessor and DataLoader, then optimizes
# `model` with the loss returned by `model.micl_loss`.
# NOTE(review): `_utils` is imported but not used anywhere in this cell.
from torch.utils.data import _utils
# Training hyper-parameters
batch_size = 1
num_epochs = 10
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Create the data processor instance
data_processor = DataProcessor(
    dataset=dataset,
    processor=llava_processor,
    device=model.device  # use the model's device
)
# Batch the samples with a DataLoader
train_loader = torch.utils.data.DataLoader(
    data_processor,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_processor.collate_fn,
)

# Training loop
for epoch in range(num_epochs):
    total_loss = 0
    for batch in train_loader:
        # Fetch the collated inputs of both sides
        query_inputs = batch['query_inputs']
        target_inputs = batch['target_inputs']

        # Extract embeddings; extract_embedding returns two values per side
        # (presumably image-side and text-side embeddings — confirm against
        # CustomLlavaModel).
        q_i, q_t = model.extract_embedding(
            input_ids=query_inputs['input_ids'],          # dict-style access
            pixel_values=query_inputs['pixel_values'],    # dict-style access
            attention_mask=query_inputs['attention_mask'] # dict-style access
        )

        t_i, t_t = model.extract_embedding(
            input_ids=target_inputs['input_ids'],         # dict-style access
            pixel_values=target_inputs['pixel_values'],   # dict-style access
            attention_mask=target_inputs['attention_mask']# dict-style access
        )

        # Compute the MICL loss from the four embeddings
        loss = model.micl_loss(q_i, q_t, t_i, t_t)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report the mean loss over the batches of this epoch
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")



## 简易版：提取文本/图像的表征向量

In [8]:
# Embedding inspection cell: same setup as the training cell, but the loop
# only extracts embeddings and prints the shape of `q_i`.
# NOTE(review): `_utils` is imported and `optimizer` is created, but neither
# is used in the visible part of this cell.
from torch.utils.data import _utils
# Training hyper-parameters
batch_size = 1
num_epochs = 10
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Create the data processor instance
data_processor = DataProcessor(
    dataset=dataset,
    processor=llava_processor,
    device=model.device  # use the model's device
)
# Batch the samples with a DataLoader
train_loader = torch.utils.data.DataLoader(
    data_processor,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_processor.collate_fn,
)

# Loop over the data (the whole loader is traversed once per epoch)
for epoch in range(num_epochs):
    total_loss = 0
    for batch in train_loader:
        # Fetch the collated inputs of both sides
        query_inputs = batch['query_inputs']
        target_inputs = batch['target_inputs']

        # Extract embeddings for the query and target sides
        q_i, q_t = model.extract_embedding(
            input_ids=query_inputs['input_ids'],          # dict-style access
            pixel_values=query_inputs['pixel_values'],    # dict-style access
            attention_mask=query_inputs['attention_mask'] # dict-style access
        )

        t_i, t_t = model.extract_embedding(
            input_ids=target_inputs['input_ids'],         # dict-style access
            pixel_values=target_inputs['pixel_values'],   # dict-style access
            attention_mask=target_inputs['attention_mask']# dict-style access
        )

        # The recorded output below shows torch.Size([1, 4096]) per batch.
        print("q_i shape:", q_i.shape)
        # ... remaining training code ...

q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape: torch.Size([1, 4096])
q_i shape:

In [None]:
h.keys()