In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import json
import torch
from torch.utils.data import Dataset
print("Torch version:", torch.__version__)
from transformers import AutoModelForCausalLM, AutoProcessor, TrainingArguments, Trainer
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from PIL import Image
from accelerate import Accelerator
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
def print_gpu():
    """Print the model name and memory statistics of every visible CUDA device.

    Device index 0 is the first device visible to this process; with
    CUDA_VISIBLE_DEVICES="4" set at the top of the file, that is physical GPU 4.
    If CUDA is not available, a notice is printed and nothing else happens.
    All sizes are reported in GiB.
    """
    if not torch.cuda.is_available():
        print("未检测到可用的 CUDA 设备")
        return
    props = torch.cuda.get_device_properties(0)
    print("当前卡型号", props.name)
    for cuda_id in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(cuda_id)
        total = props.total_memory / 1024**3
        allocated = torch.cuda.memory_allocated(cuda_id) / 1024**3
        reserved = torch.cuda.memory_reserved(cuda_id) / 1024**3
        # mem_get_info asks the driver, so it also accounts for the CUDA
        # context and other processes; total - reserved would only subtract
        # this process's caching-allocator pool and overstate free memory.
        free = torch.cuda.mem_get_info(cuda_id)[0] / 1024**3
        print(f"GPU {cuda_id}:")
        print(f"  总显存: {total:.2f} GB")
        print(f"  已分配显存: {allocated:.2f} GB")
        print(f"  已缓存显存: {reserved:.2f} GB")
        print(f"  空闲显存: {free:.2f} GB")

# The paths and hyperparameters below must be adapted to the target machine.
print_gpu()
model_path = r"/home/tsingtao/vl_finetune_train/ERNIE-4.5-VL-28B-A3B-Paddle"
data_file = r"/home/tsingtao/vl_finetune_train/vl_finetune/sft_vl_train_shuffle.jsonl"
image_dir = r"/home/tsingtao/vl_finetune_train/vl_finetune"
output_dir = "output"
num_epochs = 10
# NOTE(review): the train_model(...) call in the next cell passes batch_size=1
# explicitly, so this value is not used there.
batch_size = 2
gradient_accumulation_steps = 8
learning_rate = 3e-4

# Load the model and the processor.
# NOTE(review): a "-Paddle" checkpoint normally ships PaddlePaddle weights;
# loading it through transformers may require the PyTorch ("-PT") release -
# confirm.
# 8-bit weights; llm_int8_enable_fp32_cpu_offload lets device_map="auto" place
# modules that do not fit on the GPU on the CPU (kept in fp32).
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    # Passing load_in_8bit / llm_int8_* directly to from_pretrained is
    # deprecated; a BitsAndBytesConfig via quantization_config is the
    # supported form.
    quantization_config=quantization_config,
    device_map="auto",
)

print(model.device)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

# LoRA configuration
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
lora_config = LoraConfig(
    r=2,
    lora_alpha=8,
    lora_dropout=0.2,
    target_modules=["q_proj", "v_proj", "k_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

print_gpu()

In [None]:
%matplotlib notebook
import os
import json
import torch
import gc
from torch.utils.data import Dataset
from PIL import Image
from dataclasses import dataclass
from typing import List, Dict, Union
from transformers import Trainer, TrainingArguments
import matplotlib.pyplot as plt
import psutil
import GPUtil
import numpy as np
import threading
import time
from IPython.display import clear_output
from transformers import TrainerCallback
import traceback  

# Track GPU memory usage over time and plot it, so memory consumption during
# training can be inspected visually.
class MemoryMonitor:
    """Sample allocated CUDA memory on a background thread and plot it.

    Args:
        delay: Seconds to wait between samples.
        plot_interval: Redraw the figure every ``plot_interval`` samples.

    Call ``start()`` to begin sampling and ``stop()`` to end it; ``stop()``
    waits for the sampling thread and then draws a final plot.
    """

    def __init__(self, delay=1, plot_interval=5):
        self.delay = delay  # seconds between samples
        self.plot_interval = plot_interval
        self.keep_measuring = True
        self.gpu_memory = []  # allocated memory per sample, in GB
        self.timestamps = []  # seconds since construction, one per sample
        self.start_time = time.monotonic()
        # stop() sets this so the worker wakes immediately instead of
        # finishing a full sleep, letting stop() join it before the final plot
        # even when delay is longer than the join timeout would otherwise be.
        self._stop_event = threading.Event()
        # Guards gpu_memory/timestamps: they are appended on the worker thread
        # and read by update_plot(), which may run on another thread.
        self._lock = threading.Lock()
        self.fig, self.ax = plt.subplots(figsize=(10, 5))

    def measure_memory(self):
        """Sampling loop executed on the worker thread until stopped."""
        while self.keep_measuring and not self._stop_event.is_set():
            try:
                # memory_allocated reports memory held by live tensors (not
                # the allocator's cached/reserved pool) on the first visible
                # device.
                cuda_id = 0
                allocated = torch.cuda.memory_allocated(cuda_id) / 1024**3  # GB
                with self._lock:
                    self.gpu_memory.append(allocated)
                    self.timestamps.append(time.monotonic() - self.start_time)
                    n_samples = len(self.gpu_memory)

                # Redraw every plot_interval samples.
                # NOTE(review): this draws from the worker thread; matplotlib
                # is not thread-safe, so main-thread-only drawing may be safer.
                if n_samples % self.plot_interval == 0:
                    self.update_plot()
            except Exception as e:
                print(f"Memory monitoring error: {e}")

            # Like sleep(delay), but returns early once stop() is called.
            self._stop_event.wait(self.delay)

    def update_plot(self):
        """Redraw the figure from a consistent snapshot of the samples and display it."""
        from IPython.display import display

        # Copy under the lock so x and y always have the same length.
        with self._lock:
            timestamps = list(self.timestamps)
            gpu_memory = list(self.gpu_memory)

        # Labels are in English to avoid garbled CJK glyphs.
        with plt.ioff():
            self.ax.clear()
            self.ax.plot(timestamps, gpu_memory, label='GPU Memory (GB)')
            self.ax.set_xlabel('Time (s)')
            self.ax.set_ylabel('Memory Usage (GB)')
            self.ax.set_title('GPU Memory')
            self.ax.legend()
            self.ax.grid(True)
            self.fig.tight_layout()
            self.fig.canvas.draw_idle()

        # Show the figure in the notebook output.
        display(self.fig)

    def start(self):
        """Start sampling on a daemon thread."""
        self.thread = threading.Thread(target=self.measure_memory)
        self.thread.daemon = True
        self.thread.start()

    def stop(self):
        """Stop sampling, wait for the worker thread, and draw the final plot."""
        self.keep_measuring = False
        self._stop_event.set()
        if hasattr(self, 'thread'):
            # The worker wakes as soon as the event is set; the extra time
            # covers a redraw that may be in progress.
            self.thread.join(timeout=self.delay + 1)
        self.update_plot()

class TeaDiseaseDataset(Dataset):
    """Image + question/answer SFT samples read from a JSON Lines file.

    Each non-blank line is a JSON object. A sample uses
    ``item["image_info"][0]["image_url"]`` (resolved against ``image_dir``) as
    the image, and ``item["text_info"][0]["text"]`` /
    ``item["text_info"][1]["text"]`` as the question / answer.

    Args:
        json_file: Path to the ``.jsonl`` file.
        image_dir: Directory that image URLs are resolved against.
        processor: Multimodal processor used to encode text and images.
        max_length: Length every sample is padded/truncated to.
        max_retries: Maximum number of samples (the requested one, then the
            preceding ones, wrapping around) tried when a sample fails to load.
    """

    def __init__(self, json_file, image_dir, processor, max_length=1024, max_retries=10):
        self.processor = processor
        self.max_length = max_length
        self.max_retries = max_retries

        # Load the JSONL file, skipping blank lines (e.g. a trailing empty
        # line) instead of failing in json.loads.
        with open(json_file, 'r', encoding='utf-8') as f:
            self.data = [json.loads(line) for line in f if line.strip()]

        # Image folder
        self.base_dir = image_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded sample at ``idx``.

        If a sample cannot be built (missing image, malformed record, ...),
        the error is printed and preceding samples are tried instead (wrapping
        around), up to ``max_retries`` attempts.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no sample could be built within the retry budget.
        """
        n = len(self.data)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")
        idx %= n

        last_error = None
        # Bounded iteration instead of recursion: a failing sample 0 used to
        # retry itself forever.
        for attempt in range(min(max(1, self.max_retries), n)):
            cur = (idx - attempt) % n
            try:
                return self._build_example(self.data[cur])
            except Exception as e:
                print(f"错误：样本 {cur} 加载失败: {e!r}")
                traceback.print_exc()
                last_error = e
        raise RuntimeError(f"failed to load sample {idx} and its fallbacks") from last_error

    def _build_example(self, item):
        """Encode one raw record into model inputs with prompt/padding-masked labels."""
        # Load the image; the context manager closes the file handle.
        image_path = os.path.join(self.base_dir, item["image_info"][0]["image_url"])
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        question = item["text_info"][0]["text"]
        answer = item["text_info"][1]["text"]

        instruction = f"请根据图片回答以下问题：{question}"
        # NOTE(review): no separator or EOS token is added between/after the
        # instruction and answer - confirm this matches the expected format.
        full_text = instruction + answer

        encoding = self.processor(
            text=full_text,
            images=image,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )

        input_ids = encoding.input_ids[0]
        pad_token_id = self.processor.tokenizer.pad_token_id

        if getattr(encoding, "attention_mask", None) is not None:
            attention_mask = encoding.attention_mask[0]
        elif pad_token_id is not None:
            attention_mask = (input_ids != pad_token_id).long()
        else:
            attention_mask = torch.ones_like(input_ids)

        # Number of prompt tokens, measured by encoding the instruction alone.
        # NOTE(review): this text-only length does not count image tokens in
        # the full encoding; if the processor places them before the text,
        # part of the prompt stays unmasked - verify against the processor.
        instruction_len = self.processor(
            text=instruction,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
        ).input_ids.shape[1]

        # Train only on answer tokens: mask the prompt and the padding.
        labels = input_ids.clone()
        labels[:instruction_len] = -100
        if pad_token_id is not None:
            labels[labels == pad_token_id] = -100

        # NOTE(review): [0] assumes pixel_values has a leading batch
        # dimension; some VL processors return a flat (num_patches, dim)
        # tensor - confirm.
        pixel_values = encoding.pixel_values[0] if hasattr(encoding, "pixel_values") else None

        seq_length = input_ids.shape[0]
        position_ids = torch.arange(seq_length, dtype=torch.long)

        if hasattr(encoding, "token_type_ids"):
            token_type_ids = encoding.token_type_ids[0]
        else:
            token_type_ids = torch.zeros(seq_length, dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "pixel_values": pixel_values,
            "labels": labels,
            "token_type_ids": token_type_ids,
        }

@dataclass
class ErnieDataCollator:
    """Stack per-sample tensors (as produced by TeaDiseaseDataset) into a batch.

    For each key, the per-sample tensors are stacked along a new leading
    dimension. A key is omitted when it is absent from the first feature or
    is ``None`` in every feature.

    Raises:
        ValueError: If a key is ``None``/missing in some features but present
            in others, which would otherwise misalign the batch.
    """

    def __call__(self, features: List[Dict[str, Union[torch.Tensor, None]]]) -> Dict[str, torch.Tensor]:
        if not features:
            return {}

        batch = {}
        for key in ["input_ids", "attention_mask", "position_ids", "labels",
                    "pixel_values", "token_type_ids"]:
            if key not in features[0]:
                continue
            values = [f.get(key) for f in features]
            present = [v for v in values if v is not None]
            if not present:
                continue
            if len(present) != len(values):
                raise ValueError(
                    f"'{key}' is missing in {len(values) - len(present)} of {len(values)} features"
                )
            # NOTE(review): torch.stack requires equal shapes; if pixel_values
            # vary per image (dynamic resolution), concatenation may be needed.
            batch[key] = torch.stack(present)

        return batch

class ErnieTrainer(Trainer):
    """Trainer that builds the ERNIE-VL forward inputs explicitly.

    Falls back to a manual shifted cross-entropy loss when the model output
    has no "loss", and periodically releases cached GPU memory.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the training loss.

        Passes input_ids / attention_mask / labels (plus pixel_values,
        token_type_ids and position_ids when present) to the model. If the
        output has no "loss", computes next-token cross-entropy from the
        logits, ignoring label positions equal to -100.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        # Extract the inputs the model needs for its forward pass.
        input_dict = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "labels": inputs["labels"],
        }

        if "pixel_values" in inputs:
            input_dict["pixel_values"] = inputs["pixel_values"]

        if "token_type_ids" in inputs:
            token_type_ids = inputs["token_type_ids"]
            input_ids_length = inputs["input_ids"].shape[1]
            # NOTE(review): token_type_ids exactly as long as input_ids is
            # extended by repeating its last column, presumably because the
            # model's forward expects one extra entry - confirm against the
            # ERNIE modeling code.
            if token_type_ids.shape[1] == input_ids_length:
                last_token = token_type_ids[:, -1].unsqueeze(1)
                token_type_ids = torch.cat([token_type_ids, last_token], dim=1)
            input_dict["token_type_ids"] = token_type_ids

        if "position_ids" in inputs:
            input_dict["position_ids"] = inputs["position_ids"]

        outputs = model(**input_dict)

        if "loss" not in outputs:
            logits = outputs.logits
            shift_logits = logits[..., :-1, :].contiguous()
            # With device_map="auto" / CPU offload the logits may live on a
            # different device than the batch, so move the labels to them.
            shift_labels = inputs["labels"][..., 1:].contiguous().to(shift_logits.device)
            loss_fct = torch.nn.CrossEntropyLoss()  # ignore_index defaults to -100
            # Upcast to fp32 for a numerically stable log-softmax under fp16.
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)).float(),
                shift_labels.view(-1),
            )
            outputs = {"loss": loss, "logits": logits}

        # NOTE(review): num_items_in_batch is accepted for Trainer
        # compatibility but unused; the manual loss is a per-micro-batch mean.
        # Whether Trainer then divides by gradient_accumulation_steps depends
        # on the transformers version - confirm the effective loss scale.
        # No torch.cuda.empty_cache() here: the caching allocator already
        # frees cached blocks and retries before raising OOM, so emptying it on
        # every forward only adds overhead; training_step cleans periodically.
        return (outputs["loss"], outputs) if return_outputs else outputs["loss"]

    def training_step(self, *args, **kwargs):
        """Run the standard training step, then periodically free cached memory."""
        output = super().training_step(*args, **kwargs)

        # Clear the CUDA cache and run GC once every 10 optimizer steps to
        # limit memory pressure. global_step only advances once per optimizer
        # step (every gradient_accumulation_steps micro-batches), so remember
        # the last cleaned step to avoid repeating this on every micro-batch.
        step = self.state.global_step
        if step % 10 == 0 and getattr(self, "_last_cleanup_step", None) != step:
            self._last_cleanup_step = step
            torch.cuda.empty_cache()
            gc.collect()

        return output

class PrinterCallback(TrainerCallback):
    """Print each log dict, without ``total_flos``, on the local main process."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # logs defaults to None; nothing to print in that case.
        if logs is None:
            return
        # Removes total_flos from the dict itself, so callbacks that run
        # afterwards do not see it either.
        logs.pop("total_flos", None)
        if state.is_local_process_zero:
            print(f"Step {state.global_step}: {logs}")

def train_model(data_file, image_dir, processor, model, output_dir, 
               num_epochs=3, learning_rate=2e-5, batch_size=1, 
               gradient_accumulation_steps=4, max_length=1024):
    """Fine-tune ``model`` on the JSONL dataset with ErnieTrainer and save it.

    Args:
        data_file: Path to the training ``.jsonl`` file.
        image_dir: Directory that image URLs in ``data_file`` are resolved against.
        processor: Multimodal processor used to encode samples; saved at the end.
        model: Model to train.
        output_dir: Directory for checkpoints and final artifacts.
        num_epochs: Number of training epochs.
        learning_rate: Peak learning rate of the cosine schedule.
        batch_size: Per-device micro-batch size.
        gradient_accumulation_steps: Micro-batches accumulated per optimizer step.
        max_length: Token length each sample is padded/truncated to.

    Saves a checkpoint every epoch (keeping at most one) under ``output_dir``,
    reports metrics to TensorBoard, and writes the final model/processor to
    ``output_dir/final_model`` and ``output_dir/final_processor``. A
    MemoryMonitor samples GPU memory during the run and is always stopped,
    even if training fails.
    """
    memory_monitor = MemoryMonitor(delay=2)
    memory_monitor.start()

    try:
        dataset = TeaDiseaseDataset(data_file, image_dir, processor, max_length=max_length)
        print(f"训练集大小: {len(dataset)}")
        collate_fn = ErnieDataCollator()

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            weight_decay=0.01,
            warmup_ratio=0.03,
            lr_scheduler_type="cosine",
            save_strategy="epoch",
            logging_steps=10,
            fp16=True,
            gradient_checkpointing=False,  # must stay off: per the author, ERNIE's multimodal model does not support it
            report_to="tensorboard",
            dataloader_pin_memory=False,
            # Keep every key produced by ErnieDataCollator. With the default
            # (True) Trainer drops sample keys not named in the (base) model's
            # forward signature before collation, which could silently discard
            # inputs such as pixel_values; compute_loss picks its keys itself.
            remove_unused_columns=False,
            optim="adamw_torch",
            save_total_limit=1,
            eval_steps=None,
        )
        printer_callback = PrinterCallback()

        trainer = ErnieTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            data_collator=collate_fn,
            callbacks=[printer_callback]
        )

        print("开练...")
        trainer.train()

        model.save_pretrained(os.path.join(output_dir, "final_model"))
        processor.save_pretrained(os.path.join(output_dir, "final_processor"))
        print(f"保存于{os.path.join(output_dir, 'final_model')}")

    finally:
        memory_monitor.stop()

# Run training with the hyperparameters configured in the first cell.
train_model(
    data_file=data_file,
    image_dir=image_dir,
    processor=processor,
    model=model,
    output_dir=output_dir,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    # NOTE(review): deliberately overrides the module-level batch_size;
    # use batch_size=batch_size if the larger batch fits in memory.
    batch_size=1,
    # Gradients are accumulated over this many micro-batches, then applied
    # in one optimizer step and cleared.
    gradient_accumulation_steps=gradient_accumulation_steps,
    max_length=1024,
)