In [None]:
import os
import json
from typing import Tuple

from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms


DATA_JSON_PATH = "data/archive (1)/VQA_RAD Dataset Public.json"
DATA_IMAGE_DIR = "data/archive (1)/VQA_RAD Image Folder"


def normalize_answer(ans: str) -> str:
    """简单归一化答案，用于区分close(是/否)和open类型。

    - 先 lower + 去空格
    - 去掉句末标点
    - 规范 "yes" / "no" 写法
    - 对部分 open 答案做同义词合并（如 right side -> right）
    """
    if ans is None:
        return ""
    ans = str(ans).lower().strip()
    ans_clean = ans.rstrip('.,!?;:').strip()

    # 先规范 yes/no
    if ans_clean == "yes":
        return "yes"
    if ans_clean == "no":
        return "no"

    if ans_clean.startswith("yes"):
        if len(ans_clean) == 3 or ans_clean[3] in [" ", ",", ".", "!", "?", ";", ":"]:
            return "yes"
    if ans_clean.startswith("no"):
        if len(ans_clean) == 2 or ans_clean[2] in [" ", ",", ".", "!", "?", ";", ":"]:
            return "no"

    # 针对 open 问题的一些简单同义合并
    synonym_map = {
        "right side": "right",
        "left side": "left",
        "rt": "right",
        "lt": "left",
        "xray": "x-ray",
        "x ray": "x-ray",
        "ct scan": "ct",
    }
    if ans_clean in synonym_map:
        return synonym_map[ans_clean]
    
    # 处理包含 "xray" 或 "x-ray" 的答案（如 "chest xray" -> "x-ray"）
    # 将 "xray" 替换为 "x-ray" 以便统一匹配
    ans_clean = ans_clean.replace("xray", "x-ray").replace("x ray", "x-ray")
    # 如果答案以 "x-ray" 结尾或包含 "x-ray"，且问题是问图像类型，可以归一化为 "x-ray"
    # 但这里我们只做简单的替换，保留原始信息
    
    return ans_clean


class VQARADBaselineDataset(Dataset):
    """基础 VQA-RAD 数据集类，只负责读图像/文本/答案。

    之后我们会在此之上构建不同模型（CNN baseline, Transformer 等）。
    """

    def __init__(self, json_path: str = DATA_JSON_PATH, image_dir: str = DATA_IMAGE_DIR,
                 transform=None):
        assert os.path.exists(json_path), f"JSON 文件不存在: {json_path}"
        assert os.path.exists(image_dir), f"图像目录不存在: {image_dir}"

        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        self.image_dir = image_dir
        # 图像增强 + 标准化：有助于减轻过拟合
        self.transform = transform or transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        return len(self.data)

    def _get_image_path(self, item) -> str:
        # 兼容 "image" / "image_name" 两种字段
        image_key = "image" if "image" in item else "image_name"
        return os.path.join(self.image_dir, item[image_key])

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, str, str]:
        """返回: (image, question, answer, q_type)

        - q_type: "close" 表示是/否题; "open" 表示开放式问题
        """
        item = self.data[idx]
        img_path = self._get_image_path(item)

        image = Image.open(img_path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        question = item["question"]
        answer_raw = item.get("answer", "")
        ans_norm = normalize_answer(answer_raw)
        q_type = "close" if ans_norm in ["yes", "no"] else "open"

        return image, question, str(answer_raw), q_type

In [2]:
from torch.utils.data import Subset, DataLoader

# 1. 加载完整数据集
full_dataset = VQARADBaselineDataset(
    json_path=DATA_JSON_PATH,
    image_dir=DATA_IMAGE_DIR,
)

print(f"数据集总样本数: {len(full_dataset)}")

# 2. 根据答案把样本划分为 close (yes/no) 和 open
close_indices = []
open_indices = []

for idx in range(len(full_dataset)):
    _, _, ans_raw, q_type = full_dataset[idx]
    if q_type == "close":
        close_indices.append(idx)
    else:
        open_indices.append(idx)

print(f"Close (是/否) 样本数: {len(close_indices)}")
print(f"Open  (开放式) 样本数: {len(open_indices)}")

# 3. 构造两个子数据集，后面可以分别训练/评估
close_dataset = Subset(full_dataset, close_indices)
open_dataset = Subset(full_dataset, open_indices)

# 4. 先准备好 DataLoader，后面直接接模型即可
batch_size = 32

close_loader = DataLoader(close_dataset, batch_size=batch_size, shuffle=True)
open_loader = DataLoader(open_dataset, batch_size=batch_size, shuffle=True)

print("\nDataLoader 已创建:")
print(f"  close_loader: batch_size={batch_size}, steps/epoch≈{len(close_loader)}")
print(f"  open_loader : batch_size={batch_size}, steps/epoch≈{len(open_loader)}")

# 5. 基于所有 question 构建简单词表，用于文本 Transformer
from collections import Counter


def tokenize(text: str):
    text = str(text).lower().strip()
    # 简单按空格切分即可；医学术语也能基本覆盖
    return text.replace("?", " ").replace(",", " ").split()


counter = Counter()
for item in full_dataset.data:
    q = item.get("question", "")
    tokens = tokenize(q)
    counter.update(tokens)

# 特殊符号
word2idx = {"<pad>": 0, "<unk>": 1}
for w, c in counter.items():
    # 过滤特别少见的词可以稍微减小词表，这里阈值设为 1 就是全收
    if w not in word2idx and c >= 1:
        word2idx[w] = len(word2idx)

idx2word = {i: w for w, i in word2idx.items()}
MAX_Q_LEN = 20  # 问题一般比较短，20 足够覆盖大部分

print(f"词表大小: {len(word2idx)}, MAX_Q_LEN = {MAX_Q_LEN}")

数据集总样本数: 2248
Close (是/否) 样本数: 1193
Open  (开放式) 样本数: 1055

DataLoader 已创建:
  close_loader: batch_size=32, steps/epoch≈38
  open_loader : batch_size=32, steps/epoch≈33
词表大小: 1227, MAX_Q_LEN = 20


In [None]:
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import random
# 设置随机种子以确保结果可复现
np.random.seed(42)
torch.manual_seed(42)
random.seed(42)

# 1. 加载完整数据集
full_dataset = VQARADBaselineDataset(
    json_path=DATA_JSON_PATH,
    image_dir=DATA_IMAGE_DIR,
)

print(f"数据集总样本数: {len(full_dataset)}")

# 2. 根据答案把样本划分为 close (yes/no) 和 open
close_indices = []
open_indices = []

for idx in range(len(full_dataset)):
    _, _, ans_raw, q_type = full_dataset[idx]
    if q_type == "close":
        close_indices.append(idx)
    else:
        open_indices.append(idx)

print(f"Close (是/否) 样本数: {len(close_indices)}")
print(f"Open  (开放式) 样本数: {len(open_indices)}")

# 3. 对close和open数据集分别进行8:2的划分（训练集:测试集）
# 3.1 划分close数据集
close_train_idx, close_test_idx = train_test_split(
    close_indices, 
    test_size=0.2,  # 20% 作为测试集
    random_state=42,
    shuffle=True
)

# 3.2 划分open数据集
open_train_idx, open_test_idx = train_test_split(
    open_indices,
    test_size=0.2,  # 20% 作为测试集
    random_state=42,
    shuffle=True
)

print("\nClose (yes/no) 数据集划分:")
print(f"  训练集: {len(close_train_idx)} 个样本 ({len(close_train_idx)/len(close_indices)*100:.1f}%)")
print(f"  测试集: {len(close_test_idx)} 个样本 ({len(close_test_idx)/len(close_indices)*100:.1f}%)")

print("\nOpen 数据集划分:")
print(f"  训练集: {len(open_train_idx)} 个样本 ({len(open_train_idx)/len(open_indices)*100:.1f}%)")
print(f"  测试集: {len(open_test_idx)} 个样本 ({len(open_test_idx)/len(open_indices)*100:.1f}%)")

# 4. 创建数据集子集
close_train_dataset = Subset(full_dataset, close_train_idx)
close_test_dataset = Subset(full_dataset, close_test_idx)

open_train_dataset = Subset(full_dataset, open_train_idx)
open_test_dataset = Subset(full_dataset, open_test_idx)

# 5. 创建DataLoader
batch_size = 32

close_train_loader = DataLoader(close_train_dataset, batch_size=batch_size, shuffle=True)
close_test_loader = DataLoader(close_test_dataset, batch_size=batch_size, shuffle=False)

open_train_loader = DataLoader(open_train_dataset, batch_size=batch_size, shuffle=True)
open_test_loader = DataLoader(open_test_dataset, batch_size=batch_size, shuffle=False)

print("\nDataLoader 已创建:")
print(f"  close_train_loader: {len(close_train_loader)} batches, {len(close_train_dataset)} samples")
print(f"  close_test_loader: {len(close_test_loader)} batches, {len(close_test_dataset)} samples")
print(f"  open_train_loader: {len(open_train_loader)} batches, {len(open_train_dataset)} samples")
print(f"  open_test_loader: {len(open_test_loader)} batches, {len(open_test_dataset)} samples")

# 6. 基于所有 question 构建简单词表，用于文本 Transformer
from collections import Counter


def tokenize(text: str):
    text = str(text).lower().strip()
    # 简单按空格切分即可；医学术语也能基本覆盖
    return text.replace("?", " ").replace(",", " ").split()


counter = Counter()
for item in full_dataset.data:
    q = item.get("question", "")
    tokens = tokenize(q)
    counter.update(tokens)

# 特殊符号
word2idx = {"<pad>": 0, "<unk>": 1}
for w, c in counter.items():
    # 过滤特别少见的词可以稍微减小词表，这里阈值设为 1 就是全收
    if w not in word2idx and c >= 1:
        word2idx[w] = len(word2idx)

idx2word = {i: w for w, i in word2idx.items()}
MAX_Q_LEN = 20  # 问题一般比较短，20 足够覆盖大部分

print(f"\n词表大小: {len(word2idx)}, MAX_Q_LEN = {MAX_Q_LEN}")

数据集总样本数: 2248
Close (是/否) 样本数: 1193
Open  (开放式) 样本数: 1055

Close (yes/no) 数据集划分:
  训练集: 954 个样本 (80.0%)
  测试集: 239 个样本 (20.0%)

Open 数据集划分:
  训练集: 844 个样本 (80.0%)
  测试集: 211 个样本 (20.0%)

DataLoader 已创建:
  close_train_loader: 30 batches, 954 samples
  close_test_loader: 8 batches, 239 samples
  open_train_loader: 27 batches, 844 samples
  open_test_loader: 7 batches, 211 samples

词表大小: 1227, MAX_Q_LEN = 20


In [None]:
# ==== 使用 BLIP 在 close (yes/no) 子集上进行微调 ====

%pip install -q transformers accelerate

import torch
from transformers import BlipProcessor, BlipForQuestionAnswering

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("使用设备:", device)

# 加载预训练 BLIP VQA 模型与 Processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)

print("BLIP 模型与 Processor 已加载。")


Note: you may need to restart the kernel to use updated packages.


  from .autonotebook import tqdm as notebook_tqdm
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


使用设备: cuda
BLIP 模型与 Processor 已加载。


In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from typing import List, Tuple

# 构建一个专门给 BLIP 用的 close(yes/no) 数据集
class VQARADBLIPCloseDataset(Dataset):
    """基于 full_dataset 和 close_indices，返回 (PIL image, question, answer_norm)。"""

    def __init__(self, base_dataset: VQARADBaselineDataset, close_indices: List[int]):
        self.base_dataset = base_dataset
        self.close_indices = close_indices
        self.image_dir = base_dataset.image_dir
        self.raw_items = base_dataset.data

    def __len__(self) -> int:
        return len(self.close_indices)

    def __getitem__(self, idx: int) -> Tuple[Image.Image, str, str]:
        real_idx = self.close_indices[idx]
        item = self.raw_items[real_idx]

        # 读取原始图像（不做 torchvision 的增强，让 BLIP 自己处理）
        image_path = self.base_dataset._get_image_path(item)
        image = Image.open(image_path).convert("RGB")

        question = item.get("question", "")
        ans_raw = item.get("answer", "")
        ans_norm = normalize_answer(ans_raw)

        # 理论上 close_indices 里已经是 yes/no，这里再保险判断一下
        if ans_norm not in ["yes", "no"]:
            # 如果极个别出现其他值，则强制转成 "no"（避免 tokenizer 出现太多无关 token）
            ans_norm = "no"

        return image, question, ans_norm


def blip_collate_fn(batch):
    """将 (image, question, answer) 列表打包成 BLIP 可直接输入的 batch。"""
    images, questions, answers = zip(*batch)

    # 使用 Processor 进行图像和文本的联合编码
    inputs = processor(
        images=list(images),
        text=list(questions),
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    # 将答案也转成文本 token，作为 labels
    label_tokens = processor.tokenizer(
        list(answers),
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    inputs["labels"] = label_tokens.input_ids

    return inputs


# 基于划分后的 close_train_idx 和 close_test_idx 构建 BLIP 训练用 DataLoader
blip_close_train_dataset = VQARADBLIPCloseDataset(full_dataset, close_train_idx)
blip_close_test_dataset = VQARADBLIPCloseDataset(full_dataset, close_test_idx)

blip_batch_size = 8  # BLIP 模型较大，batch 不宜太大
blip_close_train_loader = DataLoader(
    blip_close_train_dataset,
    batch_size=blip_batch_size,
    shuffle=True,
    collate_fn=blip_collate_fn,
)
blip_close_test_loader = DataLoader(
    blip_close_test_dataset,
    batch_size=blip_batch_size,
    shuffle=False,
    collate_fn=blip_collate_fn,
)

print(f"BLIP close 训练集大小: {len(blip_close_train_dataset)}，steps/epoch≈{len(blip_close_train_loader)}")
print(f"BLIP close 测试集大小: {len(blip_close_test_dataset)}")


# ==== 简单的 BLIP 微调训练循环（针对 yes/no） ====
from torch.optim import AdamW
from tqdm.auto import tqdm

num_epochs = 4
lr = 5e-5
optimizer = AdamW(model.parameters(), lr=lr)

model.train()

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for batch in tqdm(blip_close_train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # 将 batch 移动到 device 上
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / max(1, len(blip_close_train_loader))
    print(f"Epoch {epoch+1} / {num_epochs}, 平均 loss = {avg_loss:.4f}")

# 训练完成后保存权重
save_path = "best_blip_close_yesno.pth"
torch.save(model.state_dict(), save_path)
print(f"BLIP close(yes/no) 微调完成，模型已保存到: {save_path}")


BLIP close 训练集大小: 954，steps/epoch≈120
BLIP close 测试集大小: 239


Epoch 1/4: 100%|██████████| 120/120 [00:40<00:00,  2.97it/s]


Epoch 1 / 4, 平均 loss = 0.3572


Epoch 2/4: 100%|██████████| 120/120 [00:40<00:00,  2.96it/s]


Epoch 2 / 4, 平均 loss = 0.3115


Epoch 3/4: 100%|██████████| 120/120 [00:40<00:00,  2.96it/s]


Epoch 3 / 4, 平均 loss = 0.2456


Epoch 4/4: 100%|██████████| 120/120 [00:40<00:00,  2.99it/s]


Epoch 4 / 4, 平均 loss = 0.1892
BLIP close(yes/no) 微调完成，模型已保存到: best_blip_close_yesno.pth


In [None]:
# ==== 在 close(yes/no) 测试集上评估 BLIP 准确率 ====

from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import numpy as np
import re

# 自定义 collate_fn，保持 PIL.Image / 文本原样，避免默认拼成 tensor
def eval_collate_fn(batch):
    images, questions, answers = zip(*batch)  # tuple of length batch_size
    return list(images), list(questions), list(answers)

# 在测试集上评估
eval_test_loader = DataLoader(
    blip_close_test_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=eval_collate_fn,
)

def is_semantically_similar(pred: str, gt: str) -> bool:
    """
    宽松的语义相似度判断函数
    只要预测答案和真实答案在语义上相似，就返回True
    """
    pred = pred.lower().strip()
    gt = gt.lower().strip()
    
    # 1. 完全匹配
    if pred == gt:
        return True
    
    # 2. 去除标点符号后比较
    pred_clean = re.sub(r'[^\w\s]', '', pred)
    gt_clean = re.sub(r'[^\w\s]', '', gt)
    if pred_clean == gt_clean:
        return True
    
    # 3. 一方包含另一方（子串匹配）
    if gt in pred or pred in gt:
        return True
    
    # 4. 提取关键词进行比较（单词级别）
    pred_words = set(re.findall(r'\b\w+\b', pred))
    gt_words = set(re.findall(r'\b\w+\b', gt))
    
    # 移除停用词
    stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'this', 'that', 'these', 'those'}
    pred_words = pred_words - stop_words
    gt_words = gt_words - stop_words
    
    # 计算单词重叠度
    if len(gt_words) > 0:
        overlap_ratio = len(pred_words & gt_words) / len(gt_words)
        # 如果重叠度超过50%，认为语义相似
        if overlap_ratio >= 0.5:
            return True
    
    # 5. 处理常见的同义词和变体
    synonyms = {
        'x-ray': ['xray', 'x ray', 'chest xray', 'chest x-ray', 'xray image'],
        'xray': ['x-ray', 'x ray', 'chest xray', 'chest x-ray'],
        'ct': ['ct scan', 'computed tomography'],
        'mri': ['magnetic resonance imaging'],
        'intestines': ['small intestines', 'bowel', 'small bowel'],
        'left': ['lt', 'left side'],
        'right': ['rt', 'right side'],
    }
    
    # 检查同义词
    for key, variants in synonyms.items():
        if key in gt and any(v in pred for v in variants):
            return True
        if key in pred and any(v in gt for v in variants):
            return True
    
    # 6. 处理数字和单位（如 "5cm" vs "5 cm"）
    pred_nums = set(re.findall(r'\d+', pred))
    gt_nums = set(re.findall(r'\d+', gt))
    if pred_nums == gt_nums and len(pred_nums) > 0:
        # 如果数字相同，且其他单词有重叠，认为相似
        pred_non_num_words = set(re.findall(r'\b[a-z]+\b', pred))
        gt_non_num_words = set(re.findall(r'\b[a-z]+\b', gt))
        if len(pred_non_num_words & gt_non_num_words) > 0:
            return True
    
    return False

def evaluate_accuracy(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    # 存储所有结果用于打印
    all_results = []
    
    with torch.no_grad():
        for images, questions, answers in tqdm(data_loader, desc="Evaluating"):
            # 准备模型输入
            inputs = processor(
                images=images,
                text=questions,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(device)
            
            # 生成预测
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                attention_mask=inputs["attention_mask"],
                max_length=10,
            )
            
            # 解码预测结果
            preds = processor.batch_decode(outputs, skip_special_tokens=True)
            
            # 计算准确率并保存结果（使用宽松的语义匹配）
            for pred_raw, true_ans_raw, question in zip(preds, answers, questions):
                pred = normalize_answer(pred_raw)
                true_ans = normalize_answer(true_ans_raw)
                
                # 使用宽松的语义匹配
                is_correct = is_semantically_similar(pred, true_ans)
                
                if is_correct:
                    correct += 1
                total += 1
                
                # 保存结果
                all_results.append({
                    "question": question,
                    "pred_raw": pred_raw,
                    "pred_norm": pred,
                    "true_ans_raw": true_ans_raw,
                    "true_ans_norm": true_ans,
                    "is_correct": is_correct
                })
    
    accuracy = correct / max(1, total) * 100
    
    # 打印前10个和后10个样本
    print("\n" + "="*80)
    print("前 10 个样本的预测结果对比:")
    print("="*80)
    for i, result in enumerate(all_results[:10], 1):
        status = "✓ 正确" if result["is_correct"] else "✗ 错误"
        print(f"\n[样本 {i}] {status}")
        print(f"  问题: {result['question']}")
        print(f"  真实答案: {result['true_ans_raw']} (归一化: {result['true_ans_norm']})")
        print(f"  预测答案: {result['pred_raw']} (归一化: {result['pred_norm']})")
    
    print("\n" + "="*80)
    print("后 10 个样本的预测结果对比:")
    print("="*80)
    for i, result in enumerate(all_results[-10:], len(all_results)-9):
        status = "✓ Right" if result["is_correct"] else "✗ Fault"
        print(f"\n[Sample {i}] {status}")
        print(f"  Question: {result['question']}")
        print(f"  Real answer: {result['true_ans_raw']} (Normalization: {result['true_ans_norm']})")
        print(f"  Predicted answer: {result['pred_raw']} (Normalization: {result['pred_norm']})")
    print("="*80)
    
    return accuracy

# 加载训练好的模型
model.load_state_dict(torch.load("best_blip_close_yesno.pth", map_location=device))
model.to(device)

print("\nEvaluate on the test...")
test_accuracy = evaluate_accuracy(model, eval_test_loader, device)
print(f"Test accuracy: {test_accuracy:.2f}%")


Evaluate on the test...


Evaluating: 100%|██████████| 30/30 [00:04<00:00,  6.20it/s]


前 10 个样本的预测结果对比:

[样本 1] ✗ 错误
  问题: Are the patients' ribs symmetric on both sides?
  真实答案: no (归一化: no)
  预测答案: yes (归一化: yes)

[样本 2] ✗ 错误
  问题: Are there cilia present at the level of alveoli?
  真实答案: no (归一化: no)
  预测答案: yes (归一化: yes)

[样本 3] ✗ 错误
  问题: Is this coronal plane?
  真实答案: yes (归一化: yes)
  预测答案: no (归一化: no)

[样本 4] ✗ 错误
  问题: Is the patient lying down?
  真实答案: yes (归一化: yes)
  预测答案: no (归一化: no)

[样本 5] ✓ 正确
  问题: Do you see a cavitary lesion in this chest xray?
  真实答案: yes (归一化: yes)
  预测答案: yes (归一化: yes)

[样本 6] ✓ 正确
  问题: Is there free air under the diaphragm?
  真实答案: no (归一化: no)
  预测答案: no (归一化: no)

[样本 7] ✓ 正确
  问题: is there tracheal deviation?
  真实答案: no (归一化: no)
  预测答案: no (归一化: no)

[样本 8] ✓ 正确
  问题: Is this in the lumbar vertebral level?
  真实答案: yes (归一化: yes)
  预测答案: yes (归一化: yes)

[样本 9] ✓ 正确
  问题: Does this patient have a pneumothorax?
  真实答案: no (归一化: no)
  预测答案: no (归一化: no)

[样本 10] ✓ 正确
  问题: Was this patient given IV contrast?
  真实答案: yes (归一化: y




In [None]:
# ==== 构建 BLIP 用的 open 数据集和 DataLoader（80% 训练 / 20% 测试） ====

from torch.utils.data import Dataset, DataLoader
from PIL import Image
from typing import List, Tuple

# 构建一个专门给 BLIP 用的 open 数据集
class VQARADBLIPOpenDataset(Dataset):
    """基于 full_dataset 和 open_idx 列表，返回 (PIL image, question, answer_norm)。"""

    def __init__(self, base_dataset: VQARADBaselineDataset, indices: List[int]):
        self.base_dataset = base_dataset
        self.indices = indices
        self.image_dir = base_dataset.image_dir
        self.raw_items = base_dataset.data

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Tuple[Image.Image, str, str]:
        real_idx = self.indices[idx]
        item = self.raw_items[real_idx]

        # 读取原始图像（不做 torchvision 的增强，让 BLIP 自己处理）
        image_path = self.base_dataset._get_image_path(item)
        image = Image.open(image_path).convert("RGB")

        question = item.get("question", "")
        ans_raw = item.get("answer", "")
        ans_norm = normalize_answer(ans_raw)

        return image, question, ans_norm


def blip_open_collate_fn(batch):
    """将 (image, question, answer) 列表打包成 BLIP 可直接输入的 batch。"""
    images, questions, answers = zip(*batch)

    # 使用 Processor 进行图像和文本的联合编码
    inputs = processor(
        images=list(images),
        text=list(questions),
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    # 将答案也转成文本 token，作为 labels
    label_tokens = processor.tokenizer(
        list(answers),
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    inputs["labels"] = label_tokens.input_ids

    return inputs


# 基于 Cell 2 中的 open_train_idx / open_test_idx 构建 BLIP 训练 / 测试集
blip_open_train_dataset = VQARADBLIPOpenDataset(full_dataset, open_train_idx)
blip_open_test_dataset = VQARADBLIPOpenDataset(full_dataset, open_test_idx)

blip_open_batch_size = 4  # BLIP 模型较大，batch 不宜太大
blip_open_train_loader = DataLoader(
    blip_open_train_dataset,
    batch_size=blip_open_batch_size,
    shuffle=True,
    collate_fn=blip_open_collate_fn,
)
blip_open_test_loader = DataLoader(
    blip_open_test_dataset,
    batch_size=blip_open_batch_size,
    shuffle=False,
    collate_fn=blip_open_collate_fn,
)

print(f"BLIP open 训练集大小: {len(blip_open_train_dataset)}，steps/epoch≈{len(blip_open_train_loader)}")
print(f"BLIP open 测试集大小: {len(blip_open_test_dataset)}")


BLIP open 训练集大小: 844，steps/epoch≈211
BLIP open 测试集大小: 211


In [None]:
# ==== BLIP 微调训练循环（针对 open 数据） ====

# 重新加载一个干净的 BLIP 模型（避免受 close 微调的影响）
# 如果你想在 close 微调的基础上继续微调 open，可以注释掉下面一行，直接使用已有的 model
model_open = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)

from torch.optim import AdamW
from tqdm.auto import tqdm

num_epochs = 5
lr = 5e-5
optimizer_open = AdamW(model_open.parameters(), lr=lr)

for epoch in range(num_epochs):
    model_open.train()
    epoch_loss = 0.0
    for batch in tqdm(blip_open_train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # 将 batch 移动到 device 上
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model_open(**batch)
        loss = outputs.loss

        optimizer_open.zero_grad()
        loss.backward()
        optimizer_open.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / max(1, len(blip_open_train_loader))
    print(f"Epoch {epoch+1} / {num_epochs}, 平均 loss = {avg_loss:.4f}")

# 训练完成后保存权重
save_path_open = "best_blip_open.pth"
torch.save(model_open.state_dict(), save_path_open)
print(f"BLIP open 微调完成，模型已保存到: {save_path_open}")


Epoch 1/5:   0%|          | 0/211 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
Epoch 1/5: 100%|██████████| 211/211 [00:44<00:00,  4.70it/s]


Epoch 1 / 5, 平均 loss = 4.2191


Epoch 2/5: 100%|██████████| 211/211 [00:45<00:00,  4.66it/s]


Epoch 2 / 5, 平均 loss = 1.5112


Epoch 3/5: 100%|██████████| 211/211 [00:44<00:00,  4.73it/s]


Epoch 3 / 5, 平均 loss = 0.9712


Epoch 4/5: 100%|██████████| 211/211 [00:44<00:00,  4.73it/s]


Epoch 4 / 5, 平均 loss = 0.7500


Epoch 5/5: 100%|██████████| 211/211 [00:45<00:00,  4.67it/s]


Epoch 5 / 5, 平均 loss = 0.5926
BLIP open 微调完成，模型已保存到: best_blip_open.pth


In [None]:
# ==== 在 open 子集上评估 BLIP 准确率（使用宽松的语义匹配） ====

from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import re

def eval_open_collate_fn(batch):
    images, questions, answers = zip(*batch)
    return list(images), list(questions), list(answers)

eval_open_test_loader = DataLoader(
    blip_open_test_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=eval_open_collate_fn,
)

def is_semantically_similar_open(pred: str, gt: str) -> bool:
    """
    宽松的语义相似度判断函数（用于open数据）
    只要预测答案和真实答案在语义上相似，就返回True
    """
    pred = pred.lower().strip()
    gt = gt.lower().strip()
    
    # 1. 完全匹配
    if pred == gt:
        return True
    
    # 2. 去除标点符号后比较
    pred_clean = re.sub(r'[^\w\s]', '', pred)
    gt_clean = re.sub(r'[^\w\s]', '', gt)
    if pred_clean == gt_clean:
        return True
    
    # 3. 一方包含另一方（子串匹配）
    if gt in pred or pred in gt:
        return True
    
    # 4. 提取关键词进行比较（单词级别）
    pred_words = set(re.findall(r'\b\w+\b', pred))
    gt_words = set(re.findall(r'\b\w+\b', gt))
    
    # 移除停用词
    stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'this', 'that', 'these', 'those', 'of', 'in', 'on', 'at', 'to', 'for'}
    pred_words = pred_words - stop_words
    gt_words = gt_words - stop_words
    
    # 计算单词重叠度
    if len(gt_words) > 0:
        overlap_ratio = len(pred_words & gt_words) / len(gt_words)
        # 如果重叠度超过50%，认为语义相似
        if overlap_ratio >= 0.5:
            return True
    
    # 5. 处理常见的同义词和变体
    synonyms = {
        'x-ray': ['xray', 'x ray', 'chest xray', 'chest x-ray', 'xray image', 'radiograph'],
        'xray': ['x-ray', 'x ray', 'chest xray', 'chest x-ray'],
        'ct': ['ct scan', 'computed tomography'],
        'mri': ['magnetic resonance imaging'],
        'intestines': ['small intestines', 'bowel', 'small bowel', 'intestine'],
        'left': ['lt', 'left side'],
        'right': ['rt', 'right side'],
        'lobe': ['lobes'],
        'parietal': ['parietal lobe'],
        'frontal': ['frontal lobe'],
        'temporal': ['temporal lobe'],
        'occipital': ['occipital lobe'],
    }
    
    # 检查同义词
    for key, variants in synonyms.items():
        if key in gt and any(v in pred for v in variants):
            return True
        if key in pred and any(v in gt for v in variants):
            return True
    
    # 6. 处理数字和单位（如 "5cm" vs "5 cm"）
    pred_nums = set(re.findall(r'\d+', pred))
    gt_nums = set(re.findall(r'\d+', gt))
    if pred_nums == gt_nums and len(pred_nums) > 0:
        # 如果数字相同，且其他单词有重叠，认为相似
        pred_non_num_words = set(re.findall(r'\b[a-z]+\b', pred))
        gt_non_num_words = set(re.findall(r'\b[a-z]+\b', gt))
        if len(pred_non_num_words & gt_non_num_words) > 0:
            return True
    
    # 7. 处理位置描述（如 "left parietal lobe" vs "parietal lobe"）
    # 如果一方是另一方的子集且包含主要关键词，认为相似
    if len(gt_words) > 0 and len(pred_words) > 0:
        # 检查主要关键词是否都在
        important_words = gt_words - {'left', 'right', 'upper', 'lower', 'anterior', 'posterior', 'side'}
        if len(important_words) > 0 and important_words.issubset(pred_words):
            return True
    
    return False

model_open.eval()

@torch.no_grad()
def eval_open_accuracy(eval_loader, dataset_name):
    exact_match = 0  # 完全匹配
    partial_match = 0  # 部分匹配（预测答案包含真实答案的关键词）
    total = 0
    # 存储所有结果用于打印
    all_results = []

    for images, questions, answers in tqdm(eval_loader, desc=f"Evaluating {dataset_name}"):
        inputs = processor(
            images=list(images),
            text=list(questions),
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)

        # 生成答案文本（open 问题答案可能较长，增加 max_new_tokens）
        generated_ids = model_open.generate(**inputs, max_new_tokens=20)
        preds = processor.batch_decode(generated_ids, skip_special_tokens=True)

        for pred_raw, gt_raw, question in zip(preds, answers, questions):
            pred_norm = normalize_answer(pred_raw).lower()
            gt_norm = normalize_answer(gt_raw).lower()

            # 使用宽松的语义匹配
            is_exact = (pred_norm == gt_norm)
            is_semantic_match = is_semantically_similar_open(pred_norm, gt_norm)
            
            if is_exact:
                exact_match += 1
                partial_match += 1
            elif is_semantic_match:
                partial_match += 1

            total += 1
            
            # 保存结果
            all_results.append({
                "question": question,
                "pred_raw": pred_raw,
                "pred_norm": pred_norm,
                "true_ans_raw": gt_raw,
                "true_ans_norm": gt_norm,
                "is_exact": is_exact,
                "is_partial": is_semantic_match
            })

    exact_acc = exact_match / max(1, total)
    partial_acc = partial_match / max(1, total)

    print(f"\n{dataset_name} 评估结果:")
    print(f"  精确匹配准确率: {exact_acc*100:.2f}%  (exact={exact_match}, total={total})")
    print(f"  语义匹配准确率（宽松）: {partial_acc*100:.2f}%  (semantic={partial_match}, total={total})")
    print(f"  Overall accuracy：{exact_acc*100:.2f}% + {partial_acc*100:.2f}%")
    # 打印前10个和后10个样本
    print("\n" + "="*80)
    print("前 10 个样本的预测结果对比:")
    print("="*80)
    for i, result in enumerate(all_results[:10], 1):
        exact_status = "✓ 精确匹配" if result["is_exact"] else ("✓ 部分匹配" if result["is_partial"] else "✗ 不匹配")
        print(f"\n[样本 {i}] {exact_status}")
        print(f"  问题: {result['question']}")
        print(f"  真实答案: {result['true_ans_raw']} (归一化: {result['true_ans_norm']})")
        print(f"  预测答案: {result['pred_raw']} (归一化: {result['pred_norm']})")
    
    print("\n" + "="*80)
    print("后 10 个样本的预测结果对比:")
    print("="*80)
    for i, result in enumerate(all_results[-10:], len(all_results)-9):
        exact_status = "✓ 精确匹配" if result["is_exact"] else ("✓ 部分匹配" if result["is_partial"] else "✗ 不匹配")
        print(f"\n[样本 {i}] {exact_status}")
        print(f"  问题: {result['question']}")
        print(f"  真实答案: {result['true_ans_raw']} (归一化: {result['true_ans_norm']})")
        print(f"  预测答案: {result['pred_raw']} (归一化: {result['pred_norm']})")
    print("="*80)


model_open.eval()
eval_open_accuracy(eval_open_test_loader, "测试集 (open)")


Evaluating 测试集 (open): 100%|██████████| 27/27 [00:05<00:00,  4.79it/s]


测试集 (open) 评估结果:
  精确匹配准确率: 13.27%  (exact=28, total=211)
  语义匹配准确率（宽松）: 28.91%  (semantic=61, total=211)
  Overall accuracy：13.27% + 28.91%

前 10 个样本的预测结果对比:

[样本 1] ✓ 部分匹配
  问题: what pathology is demonstrated?
  真实答案: cardiomegaly (归一化: cardiomegaly)
  预测答案: cardiomegaly with edema (归一化: cardiomegaly with edema)

[样本 2] ✗ 不匹配
  问题: What is the pathology?
  真实答案: right sided pleural effusion (归一化: right sided pleural effusion)
  预测答案: cardiomegaly with edema (归一化: cardiomegaly with edema)

[样本 3] ✓ 精确匹配
  问题: This image is taken in what plane?
  真实答案: axial (归一化: axial)
  预测答案: axial (归一化: axial)

[样本 4] ✗ 不匹配
  问题: Where do we see multiple infarcts in the above image?
  真实答案: cerebellum (归一化: cerebellum)
  预测答案: left thalamus (归一化: left thalamus)

[样本 5] ✗ 不匹配
  问题: What kind of image is this?
  真实答案: x-ray (归一化: x-ray)
  预测答案: ct (归一化: ct)

[样本 6] ✓ 精确匹配
  问题: Was this MRI taken with or without contrast?
  真实答案: with contrast (归一化: with contrast)
  预测答案: with contrast (归一化: with co


