In [None]:
# ==== 使用 LLaVA-Med 进行微调 ====

from transformers import (
    LlavaForConditionalGeneration,
    LlavaProcessor,
    AutoTokenizer,
    CLIPImageProcessor,
    BitsAndBytesConfig,
)
import torch

# 从本地路径加载模型
model_local_path = r"C:\Users\76949\.cache\huggingface\hub\models--microsoft--llava-med-v1.5-mistral-7b\snapshots\91bb16c122001ddc9cf1fd36ce1dae09448943a2"

print(f"从本地路径加载模型: {model_local_path}")

# 加载 processor
# 创建 image_processor 时设置正确的图像尺寸（336x336）
image_processor = CLIPImageProcessor.from_pretrained(
    "openai/clip-vit-large-patch14",
    size={"height": 336, "width": 336},  # LLaVA-Med 模型期望的尺寸
    crop_size={"height": 336, "width": 336}
)
tokenizer = AutoTokenizer.from_pretrained(model_local_path, use_fast=False)

processor = LlavaProcessor(
    tokenizer=tokenizer,
    image_processor=image_processor
)

# 补全 LLaVA-Med 必要属性
processor.patch_size = 14
processor.num_additional_image_tokens = 0
processor.vision_feature_select_strategy = "default"

# 确保 image_processor 的尺寸设置正确
processor.image_processor.size = {"height": 336, "width": 336}
processor.image_processor.crop_size = {"height": 336, "width": 336}

# 设置 tokenizer 的 model_max_length 足够大，避免截断图像 token
# LLaVA-Med 的图像 token 数量约为 576，加上文本需要更大的长度
if hasattr(processor.tokenizer, 'model_max_length'):
    # 如果 model_max_length 太小，设置为一个更大的值
    if processor.tokenizer.model_max_length < 4096:
        processor.tokenizer.model_max_length = 4096
        print(f"✓ 已设置 tokenizer.model_max_length = {processor.tokenizer.model_max_length}")

print("✓ Processor 已加载并配置完成")

# 设置 device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")

# 检测显存大小（16GB显存的智能策略）
if torch.cuda.is_available():
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"检测到 GPU 显存: {gpu_memory_gb:.1f} GB")

    # 16GB显存的策略：先尝试不使用量化，如果OOM再使用量化
    if 15.0 <= gpu_memory_gb <= 17.0:  # 16GB左右
        print("⚠️ 检测到16GB显存，将先尝试不使用量化（性能更好）")
        print("   如果出现OOM错误，请重新运行并设置 use_quantization = True")
        use_quantization = False  # 16GB可以尝试不使用量化
    elif gpu_memory_gb < 15.0:
        print("⚠️ 显存 < 15GB，建议使用量化")
        use_quantization = True
    else:
        print("✓ 显存充足，不使用量化")
        use_quantization = False
else:
    use_quantization = False
    print("CPU模式，不使用量化")

# 如果显存正好16GB，可以手动设置策略：
# use_quantization = False  # 尝试不使用量化（性能更好，但可能OOM）
# use_quantization = True   # 使用量化（更安全，但性能略低）

# 尝试加载模型
model_loaded = False

# 策略1：先尝试不使用量化（16GB显存可能可以）
if not use_quantization and torch.cuda.is_available():
    try:
        print("\n尝试不使用量化加载模型（性能更好）...")
        model = LlavaForConditionalGeneration.from_pretrained(
            model_local_path,
            dtype=torch.float16,
            device_map="auto",
        )
        print("✓ 模型已从本地路径加载完成（无量化，性能最佳）")
        model_loaded = True
    except RuntimeError as e:
        if "out of memory" in str(e).lower() or "OOM" in str(e):
            print(f"\n⚠️ 显存不足（OOM），自动切换到4-bit量化...")
            torch.cuda.empty_cache()  # 清除缓存
            use_quantization = True
        else:
            raise e

# 策略2：使用4-bit量化（如果策略1失败或显存不足）
if not model_loaded and use_quantization and torch.cuda.is_available():
    try:
        print("\n使用 4-bit 量化加载模型（节省显存）...")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )
        model = LlavaForConditionalGeneration.from_pretrained(
            model_local_path,
            dtype=torch.float16,
            quantization_config=quantization_config,
            device_map="auto",
        )
        print("✓ 模型已从本地路径加载完成（4-bit 量化，显存占用约6-8GB）")
        model_loaded = True
    except Exception as e:
        print(f"⚠️ 量化加载也失败: {e}")
        raise e

# CPU模式
if not model_loaded:
    print("\n使用 CPU 模式加载模型...")
    model = LlavaForConditionalGeneration.from_pretrained(
        model_local_path,
        dtype=torch.float32,
        device_map=None,
    )
    model = model.to(device)
    print("✓ 模型已从本地路径加载完成（CPU模式）")

# 检查并设置图像 token ID（在模型加载完成后）
# LLaVA 模型需要知道图像 token ID 来识别图像 token
image_token_id = None
if hasattr(model, 'config'):
    if hasattr(model.config, 'image_token_id'):
        image_token_id = model.config.image_token_id
        print(f"✓ 模型配置中的 image_token_id: {image_token_id}")
    else:
        # 尝试从 tokenizer 获取
        try:
            # 检查 tokenizer 是否有 image_token_id
            if hasattr(processor.tokenizer, 'convert_tokens_to_ids'):
                test_id = processor.tokenizer.convert_tokens_to_ids('<image>')
                if test_id != processor.tokenizer.unk_token_id:
                    image_token_id = test_id
                    print(f"✓ 从 tokenizer 获取 <image> token ID: {image_token_id}")
                else:
                    # 尝试常见的图像 token ID
                    # LLaVA 通常使用词汇表末尾的 token
                    vocab_size = getattr(processor.tokenizer, 'vocab_size', None)
                    if vocab_size:
                        # 使用词汇表末尾的 token（通常是预留的特殊 token）
                        image_token_id = vocab_size - 1
                        print(f"⚠️ <image> token 不在词汇表中，使用 vocab_size - 1 = {image_token_id}")
                    else:
                        image_token_id = 32000
                        print(f"⚠️ 无法获取词汇表大小，使用默认值 32000")
        except Exception as e:
            print(f"⚠️ 无法获取图像 token ID: {e}")
            # 尝试使用词汇表末尾的 token
            vocab_size = getattr(processor.tokenizer, 'vocab_size', None)
            if vocab_size:
                image_token_id = vocab_size - 1
            else:
                image_token_id = 32000

# 验证并修正图像 token ID（确保它在有效范围内）
vocab_size = getattr(processor.tokenizer, 'vocab_size', None)
if vocab_size and image_token_id is not None and image_token_id >= vocab_size:
    image_token_id = vocab_size - 1
    print(f"⚠️ 图像 token ID 超出范围，已修正为 vocab_size - 1 = {image_token_id}")

# 如果找到了 image_token_id，确保模型配置中有它
if image_token_id is not None and hasattr(model, 'config'):
    if not hasattr(model.config, 'image_token_id'):
        model.config.image_token_id = image_token_id
        print(f"✓ 已设置 model.config.image_token_id = {image_token_id}")

print("\n✓ LLaVA-Med 模型与 Processor 已加载。")

  from .autonotebook import tqdm as notebook_tqdm
  if not hasattr(np, "object"):


从本地路径加载模型: C:\Users\76949\.cache\huggingface\hub\models--microsoft--llava-med-v1.5-mistral-7b\snapshots\91bb16c122001ddc9cf1fd36ce1dae09448943a2


You are using a model of type llava_mistral to instantiate a model of type llava. This is not supported for all configurations of models and can yield errors.


✓ 已设置 tokenizer.model_max_length = 4096
✓ Processor 已加载并配置完成
使用设备: cuda
检测到 GPU 显存: 15.9 GB
⚠️ 检测到16GB显存，将先尝试不使用量化（性能更好）
   如果出现OOM错误，请重新运行并设置 use_quantization = True

尝试不使用量化加载模型（性能更好）...


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 75.91it/s]
Some weights of LlavaForConditionalGeneration were not initialized from the model checkpoint at C:\Users\76949\.cache\huggingface\hub\models--microsoft--llava-med-v1.5-mistral-7b\snapshots\91bb16c122001ddc9cf1fd36ce1dae09448943a2 and are newly initialized: ['model.language_model.embed_tokens.weight', 'model.language_model.layers.0.input_layernorm.weight', 'model.language_model.layers.0.mlp.down_proj.weight', 'model.language_model.layers.0.mlp.gate_proj.weight', 'model.language_model.layers.0.mlp.up_proj.weight', 'model.language_model.layers.0.post_attention_layernorm.weight', 'model.language_model.layers.0.self_attn.k_proj.weight', 'model.language_model.layers.0.self_attn.o_proj.weight', 'model.language_model.layers.0.self_attn.q_proj.weight', 'model.language_model.layers.0.self_attn.v_proj.weight', 'model.language_model.layers.1.input_layernorm.weight', 'model.language_model.layers.1.mlp.down_proj.weight', 'model

✓ 模型已从本地路径加载完成（无量化，性能最佳）
✓ 模型配置中的 image_token_id: 32000

✓ LLaVA-Med 模型与 Processor 已加载。


In [2]:
# ==== 步骤1: 导入必要的库和定义工具函数 ====

import os
import json
import re
from typing import Tuple, List
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
import numpy as np
import random
from tqdm.auto import tqdm
from torch.optim import AdamW

# 数据集路径
DATA_JSON_PATH = r"data\archive (1)\VQA_RAD Dataset Public.json"
DATA_IMAGE_DIR = r"data\archive (1)\VQA_RAD Image Folder"

def normalize_answer(ans: str) -> str:
    """简单归一化答案，用于区分close(是/否)和open类型。"""
    if ans is None:
        return ""
    ans = str(ans).lower().strip()
    ans_clean = ans.rstrip('.,!?;:').strip()

    # 先规范 yes/no
    if ans_clean == "yes":
        return "yes"
    if ans_clean == "no":
        return "no"

    if ans_clean.startswith("yes"):
        if len(ans_clean) == 3 or ans_clean[3] in [" ", ",", ".", "!", "?", ";", ":"]:
            return "yes"
    if ans_clean.startswith("no"):
        if len(ans_clean) == 2 or ans_clean[2] in [" ", ",", ".", "!", "?", ";", ":"]:
            return "no"

    # 针对 open 问题的一些简单同义合并
    synonym_map = {
        "right side": "right",
        "left side": "left",
        "rt": "right",
        "lt": "left",
        "xray": "x-ray",
        "x ray": "x-ray",
        "ct scan": "ct",
    }
    if ans_clean in synonym_map:
        return synonym_map[ans_clean]
    
    ans_clean = ans_clean.replace("xray", "x-ray").replace("x ray", "x-ray")
    return ans_clean

print("✓ 工具函数已定义")

✓ 工具函数已定义


In [3]:
# ==== 步骤2: 创建符合 LLaVA 标准的数据集类 ====
# LLaVA 需要原始 PIL 图像，不使用 torchvision 的 transform
# 图像处理由 processor 的 image_processor 完成

class VQARADLLaVADataset(Dataset):
    """符合 LLaVA 标准的数据集类
    
    返回原始 PIL 图像，让 processor 来处理图像预处理
    """
    
    def __init__(self, json_path: str, image_dir: str, indices: List[int] = None):
        """
        Args:
            json_path: JSON 数据文件路径
            image_dir: 图像目录路径
            indices: 可选，指定使用的数据索引列表。如果为 None，使用全部数据
        """
        assert os.path.exists(json_path), f"JSON 文件不存在: {json_path}"
        assert os.path.exists(image_dir), f"图像目录不存在: {image_dir}"
        
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        
        self.image_dir = image_dir
        self.indices = indices if indices is not None else list(range(len(self.data)))
    
    def __len__(self) -> int:
        return len(self.indices)
    
    def _get_image_path(self, item) -> str:
        """获取图像路径"""
        image_key = "image" if "image" in item else "image_name"
        return os.path.join(self.image_dir, item[image_key])
    
    def __getitem__(self, idx: int) -> Tuple[Image.Image, str, str]:
        """
        返回: (PIL Image, question, answer)
        
        - 返回原始 PIL 图像，不使用 transform
        - question: 问题文本
        - answer: 答案文本（原始答案，用于训练）
        """
        real_idx = self.indices[idx]
        item = self.data[real_idx]
        
        # 读取原始图像（PIL Image，不做任何 transform）
        image_path = self._get_image_path(item)
        image = Image.open(image_path).convert("RGB")
        
        question = item.get("question", "")
        answer = item.get("answer", "")
        
        return image, question, answer

print("✓ LLaVA 数据集类已定义")

✓ LLaVA 数据集类已定义


In [4]:
# ==== 步骤3: 划分数据集并创建 LLaVA 数据集实例 ====

# 设置随机种子
np.random.seed(42)
torch.manual_seed(42)
random.seed(42)

# 加载原始数据以进行划分
with open(DATA_JSON_PATH, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# 根据答案类型划分 close 和 open
close_indices = []
open_indices = []

for idx, item in enumerate(raw_data):
    answer_raw = item.get("answer", "")
    ans_norm = normalize_answer(answer_raw)
    if ans_norm in ["yes", "no"]:
        close_indices.append(idx)
    else:
        open_indices.append(idx)

print(f"数据集总样本数: {len(raw_data)}")
print(f"Close (是/否) 样本数: {len(close_indices)}")
print(f"Open  (开放式) 样本数: {len(open_indices)}")

# 对 close 和 open 数据集分别进行 8:2 划分
close_train_idx, close_test_idx = train_test_split(
    close_indices,
    test_size=0.2,
    random_state=42,
    shuffle=True
)

open_train_idx, open_test_idx = train_test_split(
    open_indices,
    test_size=0.2,
    random_state=42,
    shuffle=True
)

print("\nClose (yes/no) 数据集划分:")
print(f"  训练集: {len(close_train_idx)} 个样本 ({len(close_train_idx)/len(close_indices)*100:.1f}%)")
print(f"  测试集: {len(close_test_idx)} 个样本 ({len(close_test_idx)/len(close_indices)*100:.1f}%)")

print("\nOpen 数据集划分:")
print(f"  训练集: {len(open_train_idx)} 个样本 ({len(open_train_idx)/len(open_indices)*100:.1f}%)")
print(f"  测试集: {len(open_test_idx)} 个样本 ({len(open_test_idx)/len(open_indices)*100:.1f}%)")

# 创建 LLaVA 数据集实例
llava_close_train_dataset = VQARADLLaVADataset(DATA_JSON_PATH, DATA_IMAGE_DIR, close_train_idx)
llava_close_test_dataset = VQARADLLaVADataset(DATA_JSON_PATH, DATA_IMAGE_DIR, close_test_idx)

llava_open_train_dataset = VQARADLLaVADataset(DATA_JSON_PATH, DATA_IMAGE_DIR, open_train_idx)
llava_open_test_dataset = VQARADLLaVADataset(DATA_JSON_PATH, DATA_IMAGE_DIR, open_test_idx)

print("\n✓ LLaVA 数据集实例已创建")

数据集总样本数: 2248
Close (是/否) 样本数: 1193
Open  (开放式) 样本数: 1055

Close (yes/no) 数据集划分:
  训练集: 954 个样本 (80.0%)
  测试集: 239 个样本 (20.0%)

Open 数据集划分:
  训练集: 844 个样本 (80.0%)
  测试集: 211 个样本 (20.0%)

✓ LLaVA 数据集实例已创建


In [11]:
# ==== 步骤4: 实现符合 LLaVA 标准的 collate 函数 ====
# 使用 processor 的标准方法来处理数据

def get_valid_image_token_id():
    """获取并验证图像 token ID 是否在有效范围内"""
    image_token_id = None
    
    # 首先尝试从模型配置获取
    if hasattr(model, 'config') and hasattr(model.config, 'image_token_id'):
        image_token_id = model.config.image_token_id
    
    # 获取 tokenizer 的词汇表大小
    vocab_size = getattr(processor.tokenizer, 'vocab_size', None)
    if vocab_size is None:
        try:
            if hasattr(processor.tokenizer, 'get_vocab'):
                vocab = processor.tokenizer.get_vocab()
                vocab_size = len(vocab) if vocab else None
        except:
            pass
    
    # 验证并修正图像 token ID
    if vocab_size:
        if image_token_id is None or image_token_id >= vocab_size:
            # 使用词汇表末尾的 token（通常是预留的特殊 token）
            image_token_id = vocab_size - 1
    else:
        # 如果无法获取词汇表大小，使用默认值
        if image_token_id is None:
            image_token_id = 32000
    
    return image_token_id

def llava_collate_fn(batch):
    """
    LLaVA 标准的 collate 函数
    
    手动处理图像 token 插入，因为 processor 可能无法正确识别 <image> token
    """
    images, questions, answers = zip(*batch)
    
    # 获取图像 token ID（使用辅助函数确保有效性）
    image_token_id = get_valid_image_token_id()
    
    # 调试信息（仅第一次打印）
    if not hasattr(llava_collate_fn, '_debug_printed'):
        vocab_size = getattr(processor.tokenizer, 'vocab_size', 'unknown')
        print(f"调试信息: vocab_size = {vocab_size}, image_token_id = {image_token_id}")
        llava_collate_fn._debug_printed = True
    
    # 计算每个图像需要多少个 token（336/14 * 336/14 = 576）
    patch_size = getattr(processor, 'patch_size', 14)
    image_size = 336
    num_image_tokens = (image_size // patch_size) ** 2  # 576
    
    # 构建对话格式：USER: <image>\n{question}\nASSISTANT: {answer}
    conversations = []
    for q, a in zip(questions, answers):
        conversations.append(f"USER: <image>\n{q}\nASSISTANT: {a}")
    
    # 处理图像
    image_inputs = processor.image_processor(
        list(images),
        return_tensors="pt",
        size={"height": 336, "width": 336},
        do_resize=True,
        do_rescale=True,
        do_normalize=True
    )
    
    # 手动处理文本，插入图像 token
    processed_input_ids = []
    processed_attention_masks = []
    
    for conv in conversations:
        # 找到 <image> 的位置
        image_pos = conv.find('<image>')
        if image_pos != -1:
            # 分割文本
            before_image = conv[:image_pos]
            after_image = conv[image_pos + len('<image>'):]
            
            # 编码文本部分
            before_tokens = processor.tokenizer.encode(before_image, add_special_tokens=False)
            after_tokens = processor.tokenizer.encode(after_image, add_special_tokens=False)
            
            # 插入图像 token（576 个）
            image_tokens = [image_token_id] * num_image_tokens
            
            # 组合：before + image_tokens + after
            full_tokens = before_tokens + image_tokens + after_tokens
        else:
            # 如果没有 <image>，直接编码
            full_tokens = processor.tokenizer.encode(conv, add_special_tokens=False)
        
        processed_input_ids.append(full_tokens)
    
    # Padding 和截断
    # 限制最大长度避免 CUDA 超时（图像 token 576 + 文本，总共不超过 1024）
    max_length = 1024  # 减小最大长度避免超时
    max_len = max(len(ids) for ids in processed_input_ids)
    max_len = min(max_len, max_length)
    
    padded_input_ids = []
    padded_attention_masks = []
    
    for ids in processed_input_ids:
        if len(ids) > max_len:
            ids = ids[:max_len]
        
        attention_mask = [1] * len(ids) + [0] * (max_len - len(ids))
        ids = ids + [processor.tokenizer.pad_token_id] * (max_len - len(ids))
        
        padded_input_ids.append(ids)
        padded_attention_masks.append(attention_mask)
    
    # 转换为 tensor
    inputs = {
        "input_ids": torch.tensor(padded_input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(padded_attention_masks, dtype=torch.long),
        "pixel_values": image_inputs["pixel_values"],
    }
    
    # 构建 labels（只对答案部分计算 loss）
    labels = inputs["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    
    # 只对 ASSISTANT 的答案部分计算 loss
    for i, conv in enumerate(conversations):
        # 找到 ASSISTANT: 的位置
        assistant_pos = conv.find("ASSISTANT:")
        if assistant_pos != -1:
            # 获取 ASSISTANT: 之前的部分（包括 ASSISTANT:）
            prompt_before_answer = conv[:assistant_pos + len("ASSISTANT:")]
            
            # 手动编码 prompt（需要考虑图像 token）
            image_pos = prompt_before_answer.find('<image>')
            if image_pos != -1:
                before_image = prompt_before_answer[:image_pos]
                after_image = prompt_before_answer[image_pos + len('<image>'):]
                before_tokens = processor.tokenizer.encode(before_image, add_special_tokens=False)
                after_tokens = processor.tokenizer.encode(after_image, add_special_tokens=False)
                prompt_tokens = before_tokens + [image_token_id] * num_image_tokens + after_tokens
            else:
                prompt_tokens = processor.tokenizer.encode(prompt_before_answer, add_special_tokens=False)
            
            # 将 prompt 部分的 label 设为 -100
            for j in range(min(len(prompt_tokens), labels.shape[1])):
                labels[i, j] = -100
        else:
            # 如果找不到 ASSISTANT:，全部设为 -100
            labels[i, :] = -100
    
    inputs["labels"] = labels
    return inputs

print("✓ LLaVA collate 函数已定义")

✓ LLaVA collate 函数已定义


In [12]:
# ==== CUDA 状态重置工具（如果遇到 CUDA 超时错误，运行此 cell） ====

import torch
import gc

if torch.cuda.is_available():
    print("正在重置 CUDA 状态...")
    try:
        # 清除所有 CUDA 缓存
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        
        # 强制垃圾回收
        gc.collect()
        torch.cuda.empty_cache()
        
        # 重置内存统计
        try:
            torch.cuda.reset_peak_memory_stats()
        except:
            pass
        
        print("✓ CUDA 状态已重置")
        print(f"当前 GPU 显存使用: {torch.cuda.memory_allocated() / 1024**3:.2f} GB / {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    except Exception as e:
        print(f"⚠️ 重置 CUDA 状态时出错: {e}")
        print("建议：重启 Python 内核（Kernel -> Restart）")
else:
    print("未检测到 CUDA，无需重置")

正在重置 CUDA 状态...
⚠️ 重置 CUDA 状态时出错: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

建议：重启 Python 内核（Kernel -> Restart）


In [13]:
# ==== 步骤5: 创建 DataLoader ====

# LLaVA 模型较大，使用较小的 batch size（避免 CUDA 超时）
# 如果仍然超时，可以进一步减小到 1
llava_batch_size = 1  # 减小到 1 以避免 CUDA 超时

# 创建 DataLoader
llava_close_train_loader = DataLoader(
    llava_close_train_dataset,
    batch_size=llava_batch_size,
    shuffle=True,
    collate_fn=llava_collate_fn,
    num_workers=0,  # Windows 上建议设为 0
)

llava_close_test_loader = DataLoader(
    llava_close_test_dataset,
    batch_size=llava_batch_size,
    shuffle=False,
    collate_fn=llava_collate_fn,
    num_workers=0,
)

llava_open_train_loader = DataLoader(
    llava_open_train_dataset,
    batch_size=llava_batch_size,
    shuffle=True,
    collate_fn=llava_collate_fn,
    num_workers=0,
)

llava_open_test_loader = DataLoader(
    llava_open_test_dataset,
    batch_size=llava_batch_size,
    shuffle=False,
    collate_fn=llava_collate_fn,
    num_workers=0,
)

print(f"LLaVA-Med close 训练集大小: {len(llava_close_train_dataset)}，steps/epoch≈{len(llava_close_train_loader)}")
print(f"LLaVA-Med close 测试集大小: {len(llava_close_test_dataset)}")
print(f"LLaVA-Med open 训练集大小: {len(llava_open_train_dataset)}，steps/epoch≈{len(llava_open_train_loader)}")
print(f"LLaVA-Med open 测试集大小: {len(llava_open_test_dataset)}")
print("\n✓ DataLoader 已创建")

LLaVA-Med close 训练集大小: 954，steps/epoch≈954
LLaVA-Med close 测试集大小: 239
LLaVA-Med open 训练集大小: 844，steps/epoch≈844
LLaVA-Med open 测试集大小: 211

✓ DataLoader 已创建


In [14]:
# ==== 步骤6: 实现 LLaVA-Med Close 数据集训练循环 ====
# 
# ⚠️ 重要提示：
# 如果遇到 CUDA 超时错误，请：
# 1. 重启 Python 内核（Kernel -> Restart）
# 2. 重新运行所有前面的 cell
# 3. 如果问题持续，考虑使用 4-bit 量化或减小 batch size

# 设置 CUDA 超时时间（避免 CUDA 超时错误）
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # 同步模式，便于调试

# 重置 CUDA 状态（如果 GPU 卡住了）
try:
    torch.cuda.set_device(0)
    # 清除所有 CUDA 缓存
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    # 重置 CUDA 设备（如果可能）
    try:
        torch.cuda.reset_peak_memory_stats()
    except:
        pass
    print("✓ CUDA 缓存已清除，设备已同步")
except Exception as e:
    print(f"⚠️ CUDA 初始化警告: {e}")
    print("建议：如果持续出现 CUDA 超时，请重启 Python 内核")

# 启用梯度检查点以节省显存
if hasattr(model, 'gradient_checkpointing_enable'):
    model.gradient_checkpointing_enable()
    print("✓ 已启用梯度检查点以节省显存")

# 设置训练参数
num_epochs = 3
lr = 2e-5
optimizer = AdamW(model.parameters(), lr=lr)

# 训练循环
model.train()
best_loss = float('inf')

# 错误计数器（如果连续错误太多，停止训练）
consecutive_errors = 0
max_consecutive_errors = 10

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    successful_batches = 0
    
    for batch in tqdm(llava_close_train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # 在开始处理 batch 之前同步 CUDA
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        
        try:
            # 智能处理设备：获取模型各部分的设备
            try:
                input_embed_layer = model.get_input_embeddings()
                input_embed_device = next(input_embed_layer.parameters()).device
            except:
                input_embed_device = device
            
            try:
                vision_device = next(model.vision_tower.parameters()).device
            except:
                vision_device = device
            
            # 同步后再移动 tensor
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            # 将输入移动到正确的设备（逐个移动，避免一次性移动太多）
            new_batch = {}
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    if k in ['input_ids', 'attention_mask', 'labels']:
                        new_batch[k] = v.to(input_embed_device, non_blocking=False)
                    elif k == 'pixel_values':
                        new_batch[k] = v.to(vision_device, non_blocking=False)
                    else:
                        new_batch[k] = v.to(device, non_blocking=False)
                else:
                    new_batch[k] = v
            batch = new_batch
            
            # 移动完成后同步
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            # 前向传播（添加错误处理和 CUDA 同步）
            # 添加 CUDA 同步点，确保之前的操作完成
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            outputs = model(**batch)
            loss = outputs.loss
            
            # 再次同步
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            
            # 梯度裁剪
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            # 同步确保优化器步骤完成
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            
            epoch_loss += loss.item()
            successful_batches += 1
            consecutive_errors = 0  # 重置错误计数器
            
        except (RuntimeError, Exception) as e:
            error_str = str(e)
            consecutive_errors += 1
            
            if "CUDA" in error_str or "timeout" in error_str.lower() or "AcceleratorError" in error_str:
                print(f"\n⚠️ CUDA 错误 ({consecutive_errors}/{max_consecutive_errors}): {error_str[:200]}...")
                print("尝试清除缓存并跳过此 batch...")
                try:
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
                except:
                    pass
                
                # 如果连续错误太多，停止训练
                if consecutive_errors >= max_consecutive_errors:
                    print(f"\n❌ 连续 {max_consecutive_errors} 次 CUDA 错误，停止训练")
                    print("建议：请重启 Python 内核以清除 CUDA 状态，然后重新运行")
                    break
                
                # 跳过这个 batch，继续训练
                continue
            else:
                # 其他错误，打印并继续
                print(f"\n⚠️ 训练错误 ({consecutive_errors}/{max_consecutive_errors}): {error_str[:200]}...")
                try:
                    torch.cuda.empty_cache()
                except:
                    pass
                
                if consecutive_errors >= max_consecutive_errors:
                    print(f"\n❌ 连续 {max_consecutive_errors} 次错误，停止训练")
                    break
                continue
    
    # 如果因为错误退出，跳出外层循环
    if consecutive_errors >= max_consecutive_errors:
        break
    
    avg_loss = epoch_loss / len(llava_close_train_loader)
    print(f"Epoch {epoch+1} / {num_epochs}, 平均 loss = {avg_loss:.4f}")
    
    # 保存最佳模型
    if avg_loss < best_loss:
        best_loss = avg_loss
        save_path = "best_llava_close.pth"
        torch.save(model.state_dict(), save_path)
        print(f"✓ 保存最佳模型 (loss={avg_loss:.4f}) 到 {save_path}")

print(f"\n✓ LLaVA-Med close 微调完成")

⚠️ CUDA 初始化警告: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

建议：如果持续出现 CUDA 超时，请重启 Python 内核
✓ 已启用梯度检查点以节省显存


Epoch 1/3:   0%|          | 0/954 [00:00<?, ?it/s]

调试信息: vocab_size = 32000, image_token_id = 31999





AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# ==== 步骤7: 实现 LLaVA-Med Close 数据集评估函数 ====

def is_semantically_similar(pred: str, gt: str) -> bool:
    """宽松的语义相似度判断（用于 close yes/no 问题）"""
    pred = pred.lower().strip()
    gt = gt.lower().strip()
    
    # 完全匹配
    if pred == gt:
        return True
    
    # 去除标点符号后比较
    pred_clean = re.sub(r'[^\w\s]', '', pred)
    gt_clean = re.sub(r'[^\w\s]', '', gt)
    if pred_clean == gt_clean:
        return True
    
    # 提取关键词
    pred_words = set(re.findall(r'\b\w+\b', pred))
    gt_words = set(re.findall(r'\b\w+\b', gt))
    
    # 移除停用词
    stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'this', 'that'}
    pred_words = pred_words - stop_words
    gt_words = gt_words - stop_words
    
    # 计算重叠度
    if len(gt_words) > 0:
        overlap_ratio = len(pred_words & gt_words) / len(gt_words)
        if overlap_ratio >= 0.5:
            return True
    
    return False

def eval_collate_fn(batch):
    """评估用的 collate 函数（不包含答案）"""
    images, questions, answers = zip(*batch)
    
        # 获取图像 token ID（使用辅助函数确保有效性）
    image_token_id = get_valid_image_token_id()
    
    # 计算每个图像需要多少个 token
    patch_size = getattr(processor, 'patch_size', 14)
    image_size = 336
    num_image_tokens = (image_size // patch_size) ** 2  # 576
    
    # 构建 prompt（只包含问题，不包含答案）
    prompts = [f"USER: <image>\n{q}\nASSISTANT:" for q in questions]
    
    # 处理图像
    image_inputs = processor.image_processor(
        list(images),
        return_tensors="pt",
        size={"height": 336, "width": 336},
        do_resize=True,
        do_rescale=True,
        do_normalize=True
    )
    
    # 手动处理文本，插入图像 token
    processed_input_ids = []
    processed_attention_masks = []
    
    for prompt in prompts:
        image_pos = prompt.find('<image>')
        if image_pos != -1:
            before_image = prompt[:image_pos]
            after_image = prompt[image_pos + len('<image>'):]
            before_tokens = processor.tokenizer.encode(before_image, add_special_tokens=False)
            after_tokens = processor.tokenizer.encode(after_image, add_special_tokens=False)
            image_tokens = [image_token_id] * num_image_tokens
            full_tokens = before_tokens + image_tokens + after_tokens
        else:
            full_tokens = processor.tokenizer.encode(prompt, add_special_tokens=False)
        processed_input_ids.append(full_tokens)
    
    # Padding（限制最大长度避免 CUDA 超时）
    max_length = 1024  # 减小最大长度避免超时
    max_len = max(len(ids) for ids in processed_input_ids)
    max_len = min(max_len, max_length)
    
    padded_input_ids = []
    padded_attention_masks = []
    
    for ids in processed_input_ids:
        if len(ids) > max_len:
            ids = ids[:max_len]
        attention_mask = [1] * len(ids) + [0] * (max_len - len(ids))
        ids = ids + [processor.tokenizer.pad_token_id] * (max_len - len(ids))
        padded_input_ids.append(ids)
        padded_attention_masks.append(attention_mask)
    
    inputs = {
        "input_ids": torch.tensor(padded_input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(padded_attention_masks, dtype=torch.long),
        "pixel_values": image_inputs["pixel_values"],
    }
    
    return inputs, answers

def evaluate_accuracy(model, data_loader, device):
    """评估准确率"""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            inputs, true_answers = batch
            
            # 智能处理设备
            try:
                input_embed_layer = model.get_input_embeddings()
                input_embed_device = next(input_embed_layer.parameters()).device
            except:
                input_embed_device = device
            
            try:
                vision_device = next(model.vision_tower.parameters()).device
            except:
                vision_device = device
            
            # 移动到正确的设备
            inputs = {
                k: v.to(input_embed_device) if isinstance(v, torch.Tensor) and k in ['input_ids', 'attention_mask']
                else (v.to(vision_device) if isinstance(v, torch.Tensor) and k == 'pixel_values'
                else (v.to(device) if isinstance(v, torch.Tensor) else v))
                for k, v in inputs.items()
            }
            
            # 生成答案
            outputs = model.generate(
                **inputs,
                max_new_tokens=10,  # yes/no 答案很短
                do_sample=False,
            )
            
            # 解码预测结果
            preds = processor.batch_decode(outputs, skip_special_tokens=True)
            
            # 提取答案部分（ASSISTANT: 后面的内容）
            preds_clean = []
            for pred in preds:
                if "ASSISTANT:" in pred:
                    pred_answer = pred.split("ASSISTANT:")[-1].strip()
                else:
                    pred_answer = pred.strip()
                preds_clean.append(pred_answer)
            
            # 计算准确率
            for pred_raw, true_ans_raw in zip(preds_clean, true_answers):
                pred = normalize_answer(pred_raw)
                true_ans = normalize_answer(true_ans_raw)
                
                if is_semantically_similar(pred, true_ans):
                    correct += 1
                total += 1
    
    accuracy = correct / total if total > 0 else 0.0
    return accuracy, correct, total

# 创建评估用的 DataLoader
llava_close_test_loader_eval = DataLoader(
    llava_close_test_dataset,
    batch_size=llava_batch_size,
    shuffle=False,
    collate_fn=eval_collate_fn,
    num_workers=0,
)

# 评估 close 数据集
print("评估 close 测试集...")
accuracy, correct, total = evaluate_accuracy(model, llava_close_test_loader_eval, device)
print(f"\nClose 测试集准确率: {accuracy:.4f} ({correct}/{total})")
print("✓ 评估完成")

In [None]:
# ==== 步骤8: 实现 LLaVA-Med Open 数据集训练循环 ====

# 重新加载一个干净的模型用于 open 数据微调（避免受 close 微调的影响）
print("重新加载干净的 LLaVA-Med 模型用于 open 数据微调...")

from transformers import BitsAndBytesConfig

model_open_loaded = False

# 策略1：先尝试不使用量化
if not use_quantization and torch.cuda.is_available():
    try:
        model_open = LlavaForConditionalGeneration.from_pretrained(
            model_local_path,
            dtype=torch.float16,
            device_map="auto",
        )
        print("✓ 模型已加载（无量化）")
        model_open_loaded = True
    except RuntimeError as e:
        if "out of memory" in str(e).lower() or "OOM" in str(e):
            print(f"⚠️ 显存不足，切换到4-bit量化...")
            use_quantization = True

# 策略2：使用4-bit量化
if not model_open_loaded and use_quantization and torch.cuda.is_available():
    try:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )
        model_open = LlavaForConditionalGeneration.from_pretrained(
            model_local_path,
            dtype=torch.float16,
            quantization_config=quantization_config,
            device_map="auto",
        )
        print("✓ 模型已加载（4-bit 量化）")
        model_open_loaded = True
    except Exception as e:
        print(f"⚠️ 量化加载失败: {e}")
        raise e

# CPU模式
if not model_open_loaded:
    model_open = LlavaForConditionalGeneration.from_pretrained(
        model_local_path,
        dtype=torch.float32,
        device_map=None,
    )
    model_open = model_open.to(device)
    print("✓ 模型已加载（CPU模式）")

# 启用梯度检查点
if hasattr(model_open, 'gradient_checkpointing_enable'):
    model_open.gradient_checkpointing_enable()

# 训练参数
optimizer_open = AdamW(model_open.parameters(), lr=lr)

# 训练循环
model_open.train()
best_loss_open = float('inf')

for epoch in range(num_epochs):
    model_open.train()
    epoch_loss = 0.0
    
    for batch in tqdm(llava_open_train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # 智能处理设备
        try:
            input_embed_layer = model_open.get_input_embeddings()
            input_embed_device = next(input_embed_layer.parameters()).device
        except:
            input_embed_device = device
        
        try:
            vision_device = next(model_open.vision_tower.parameters()).device
        except:
            vision_device = device
        
        # 移动到正确的设备
        batch = {
            k: v.to(input_embed_device) if isinstance(v, torch.Tensor) and k in ['input_ids', 'attention_mask', 'labels']
            else (v.to(vision_device) if isinstance(v, torch.Tensor) and k == 'pixel_values'
            else (v.to(device) if isinstance(v, torch.Tensor) else v))
            for k, v in batch.items()
        }
        
        # 前向传播
        outputs = model_open(**batch)
        loss = outputs.loss
        
        # 反向传播
        optimizer_open.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model_open.parameters(), max_norm=1.0)
        optimizer_open.step()
        
        epoch_loss += loss.item()
    
    avg_loss = epoch_loss / len(llava_open_train_loader)
    print(f"Epoch {epoch+1} / {num_epochs}, 平均 loss = {avg_loss:.4f}")
    
    # 保存最佳模型
    if avg_loss < best_loss_open:
        best_loss_open = avg_loss
        save_path_open = "best_llava_open.pth"
        torch.save(model_open.state_dict(), save_path_open)
        print(f"✓ 保存最佳模型 (loss={avg_loss:.4f}) 到 {save_path_open}")

print(f"\n✓ LLaVA-Med open 微调完成")

In [None]:
# ==== 步骤9: 实现 LLaVA-Med Open 数据集评估函数 ====

def is_semantically_similar_open(pred: str, gt: str) -> bool:
    """宽松的语义相似度判断（用于 open 问题）"""
    pred = pred.lower().strip()
    gt = gt.lower().strip()
    
    # 完全匹配
    if pred == gt:
        return True
    
    # 去除标点符号后比较
    pred_clean = re.sub(r'[^\w\s]', '', pred)
    gt_clean = re.sub(r'[^\w\s]', '', gt)
    if pred_clean == gt_clean:
        return True
    
    # 一方包含另一方
    if gt in pred or pred in gt:
        return True
    
    # 提取关键词
    pred_words = set(re.findall(r'\b\w+\b', pred))
    gt_words = set(re.findall(r'\b\w+\b', gt))
    
    # 移除停用词
    stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'this', 'that', 'of', 'in', 'on', 'at', 'to', 'for'}
    pred_words = pred_words - stop_words
    gt_words = gt_words - stop_words
    
    # 计算重叠度
    if len(gt_words) > 0:
        overlap_ratio = len(pred_words & gt_words) / len(gt_words)
        if overlap_ratio >= 0.5:
            return True
    
    # 处理同义词
    synonyms = {
        'x-ray': ['xray', 'x ray', 'chest xray', 'chest x-ray'],
        'xray': ['x-ray', 'x ray'],
        'ct': ['ct scan', 'computed tomography'],
        'mri': ['magnetic resonance imaging'],
        'left': ['lt', 'left side'],
        'right': ['rt', 'right side'],
    }
    
    for key, variants in synonyms.items():
        if key in gt and any(v in pred for v in variants):
            return True
        if key in pred and any(v in gt for v in variants):
            return True
    
    return False

def eval_open_collate_fn(batch):
    """Open 数据集评估用的 collate 函数"""
    images, questions, answers = zip(*batch)
    
        # 获取图像 token ID（使用辅助函数确保有效性）
    image_token_id = get_valid_image_token_id()
    
    # 计算每个图像需要多少个 token
    patch_size = getattr(processor, 'patch_size', 14)
    image_size = 336
    num_image_tokens = (image_size // patch_size) ** 2  # 576
    
    prompts = [f"USER: <image>\n{q}\nASSISTANT:" for q in questions]
    
    # 处理图像
    image_inputs = processor.image_processor(
        list(images),
        return_tensors="pt",
        size={"height": 336, "width": 336},
        do_resize=True,
        do_rescale=True,
        do_normalize=True
    )
    
    # 手动处理文本，插入图像 token
    processed_input_ids = []
    processed_attention_masks = []
    
    for prompt in prompts:
        image_pos = prompt.find('<image>')
        if image_pos != -1:
            before_image = prompt[:image_pos]
            after_image = prompt[image_pos + len('<image>'):]
            before_tokens = processor.tokenizer.encode(before_image, add_special_tokens=False)
            after_tokens = processor.tokenizer.encode(after_image, add_special_tokens=False)
            image_tokens = [image_token_id] * num_image_tokens
            full_tokens = before_tokens + image_tokens + after_tokens
        else:
            full_tokens = processor.tokenizer.encode(prompt, add_special_tokens=False)
        processed_input_ids.append(full_tokens)
    
    # Padding（限制最大长度避免 CUDA 超时）
    max_length = 1024  # 减小最大长度避免超时
    max_len = max(len(ids) for ids in processed_input_ids)
    max_len = min(max_len, max_length)
    
    padded_input_ids = []
    padded_attention_masks = []
    
    for ids in processed_input_ids:
        if len(ids) > max_len:
            ids = ids[:max_len]
        attention_mask = [1] * len(ids) + [0] * (max_len - len(ids))
        ids = ids + [processor.tokenizer.pad_token_id] * (max_len - len(ids))
        padded_input_ids.append(ids)
        padded_attention_masks.append(attention_mask)
    
    inputs = {
        "input_ids": torch.tensor(padded_input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(padded_attention_masks, dtype=torch.long),
        "pixel_values": image_inputs["pixel_values"],
    }
    
    return inputs, answers

def evaluate_accuracy_open(model, data_loader, device):
    """评估 open 数据集的准确率"""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            inputs, true_answers = batch
            
            # 智能处理设备
            try:
                input_embed_layer = model.get_input_embeddings()
                input_embed_device = next(input_embed_layer.parameters()).device
            except:
                input_embed_device = device
            
            try:
                vision_device = next(model.vision_tower.parameters()).device
            except:
                vision_device = device
            
            # 移动到正确的设备
            inputs = {
                k: v.to(input_embed_device) if isinstance(v, torch.Tensor) and k in ['input_ids', 'attention_mask']
                else (v.to(vision_device) if isinstance(v, torch.Tensor) and k == 'pixel_values'
                else (v.to(device) if isinstance(v, torch.Tensor) else v))
                for k, v in inputs.items()
            }
            
            # 生成答案（open 问题答案可能较长）
            outputs = model.generate(
                **inputs,
                max_new_tokens=30,
                do_sample=False,
            )
            
            # 解码
            preds = processor.batch_decode(outputs, skip_special_tokens=True)
            
            # 提取答案部分
            preds_clean = []
            for pred in preds:
                if "ASSISTANT:" in pred:
                    pred_answer = pred.split("ASSISTANT:")[-1].strip()
                else:
                    pred_answer = pred.strip()
                preds_clean.append(pred_answer)
            
            # 计算准确率
            for pred_raw, true_ans_raw in zip(preds_clean, true_answers):
                pred = normalize_answer(pred_raw)
                true_ans = normalize_answer(true_ans_raw)
                
                if is_semantically_similar_open(pred, true_ans):
                    correct += 1
                total += 1
    
    accuracy = correct / total if total > 0 else 0.0
    return accuracy, correct, total

# 创建 open 测试集的 DataLoader（使用评估用的 collate 函数）
llava_open_test_loader_eval = DataLoader(
    llava_open_test_dataset,
    batch_size=llava_batch_size,
    shuffle=False,
    collate_fn=eval_open_collate_fn,
    num_workers=0,
)

# 评估 open 数据集
print("评估 open 测试集...")
accuracy_open, correct_open, total_open = evaluate_accuracy_open(model_open, llava_open_test_loader_eval, device)
print(f"\nOpen 测试集准确率: {accuracy_open:.4f} ({correct_open}/{total_open})")
print("✓ 评估完成")