In [None]:
import os
import json
from typing import Tuple

from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from sklearn.model_selection import train_test_split
import numpy as np
import random
from tqdm.auto import tqdm
from torch.optim import AdamW
from collections import Counter

DATA_JSON_PATH = "data/archive (1)/VQA_RAD Dataset Public.json"
DATA_IMAGE_DIR = "data/archive (1)/VQA_RAD Image Folder"


def normalize_answer(ans: str) -> str:
    """Normalise a raw answer so closed (yes/no) and open answers can be told apart.

    Steps, in order:

    1. ``None`` becomes the empty string; any other value is stringified,
       lower-cased and stripped of surrounding whitespace.
    2. Trailing punctuation (``.,!?;:``) is removed.
    3. An answer that is exactly "yes"/"no", or that starts with "yes"/"no"
       immediately followed by a space or punctuation mark, collapses to
       "yes"/"no".
    4. A few open-ended synonyms are merged on exact match
       (e.g. "right side" -> "right").
    5. Otherwise every "xray" / "x ray" substring is rewritten to "x-ray".
    """
    if ans is None:
        return ""

    cleaned = str(ans).lower().strip().rstrip('.,!?;:').strip()

    # Closed answers: the keyword must be the whole answer or be followed by a
    # separator, so words such as "not" or "yesterday" are left untouched.
    separators = " ,.!?;:"
    for keyword in ("yes", "no"):
        if cleaned.startswith(keyword):
            remainder = cleaned[len(keyword):]
            if not remainder or remainder[0] in separators:
                return keyword

    # Exact-match synonym merging for open answers.
    synonyms = {
        "right side": "right",
        "left side": "left",
        "rt": "right",
        "lt": "left",
        "xray": "x-ray",
        "x ray": "x-ray",
        "ct scan": "ct",
    }
    merged = synonyms.get(cleaned)
    if merged is not None:
        return merged

    # Harmonise the spelling of "x-ray" inside longer answers.
    return cleaned.replace("xray", "x-ray").replace("x ray", "x-ray")


class VQARADBaselineDataset(Dataset):
    """Base VQA-RAD dataset: loads the image, question text and answer of each sample.

    Models (CNN baseline, Transformer, ...) are built on top of this class; it
    performs no tokenisation or label encoding itself.
    """

    def __init__(self, json_path: str = DATA_JSON_PATH, image_dir: str = DATA_IMAGE_DIR,
                 transform=None):
        """Load the annotation JSON and remember where the images live.

        Args:
            json_path: Path to the annotation JSON file.
            image_dir: Directory containing the image files named in the JSON.
            transform: Callable applied to each PIL image. When omitted (or
                falsy), a default augmentation + normalisation pipeline is used.

        Raises:
            FileNotFoundError: If ``json_path`` or ``image_dir`` does not exist.
        """
        # Raise explicitly instead of asserting: assertions are stripped
        # when Python runs with -O, which would silently skip the check.
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"JSON 文件不存在: {json_path}")
        if not os.path.exists(image_dir):
            raise FileNotFoundError(f"图像目录不存在: {image_dir}")

        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        self.image_dir = image_dir
        # Default pipeline: augmentation + normalisation, intended to reduce overfitting.
        # NOTE(review): the random crop/flip is also applied when this dataset
        # object is used for evaluation — pass a deterministic `transform` there.
        # NOTE(review): horizontal flips mirror left/right anatomy, which may
        # conflict with laterality answers such as "left"/"right" — confirm
        # this augmentation is intended.
        self.transform = transform or transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        """Return the number of samples in the annotation file."""
        return len(self.data)

    def _get_image_path(self, item) -> str:
        """Return the image path for a record, accepting "image" or "image_name" keys."""
        image_key = "image" if "image" in item else "image_name"
        return os.path.join(self.image_dir, item[image_key])

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, str, str]:
        """Return ``(image, question, answer, q_type)`` for sample ``idx``.

        ``answer`` is the raw answer converted to ``str`` (not normalised).
        ``q_type`` is "close" when the normalised answer is "yes"/"no",
        otherwise "open".
        """
        item = self.data[idx]
        img_path = self._get_image_path(item)

        # Image.open is lazy and keeps the file handle open; convert() forces
        # the pixel data to load, and the context manager then closes the file.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        question = item["question"]
        answer_raw = item.get("answer", "")
        ans_norm = normalize_answer(answer_raw)
        q_type = "close" if ans_norm in ["yes", "no"] else "open"

        return image, question, str(answer_raw), q_type

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import random

# Seed every RNG in use so the run is reproducible
np.random.seed(42)
torch.manual_seed(42)
random.seed(42)

# 1. Load the full dataset (default transform = training-time augmentation)
full_dataset = VQARADBaselineDataset(
    json_path=DATA_JSON_PATH,
    image_dir=DATA_IMAGE_DIR,
)

print(f"数据集总样本数: {len(full_dataset)}")

# 2. Partition sample indices into close (yes/no) and open by their answer.
#    Only the JSON metadata is consulted: indexing full_dataset[idx] here would
#    decode and augment every image just to read its question type. The rule
#    is the same one __getitem__ uses to derive q_type.
close_indices = []
open_indices = []

for idx, item in enumerate(full_dataset.data):
    if normalize_answer(item.get("answer", "")) in ("yes", "no"):
        close_indices.append(idx)
    else:
        open_indices.append(idx)

print(f"Close (是/否) 样本数: {len(close_indices)}")
print(f"Open  (开放式) 样本数: {len(open_indices)}")

# 3. Split close and open indices 80/20 (train/test), each with a fixed seed
close_train_idx, close_test_idx = train_test_split(
    close_indices,
    test_size=0.2,  # 20% held out for testing
    random_state=42,
    shuffle=True
)

open_train_idx, open_test_idx = train_test_split(
    open_indices,
    test_size=0.2,  # 20% held out for testing
    random_state=42,
    shuffle=True
)

print("\nClose (yes/no) 数据集划分:")
print(f"  训练集: {len(close_train_idx)} 个样本 ({len(close_train_idx)/len(close_indices)*100:.1f}%)")
print(f"  测试集: {len(close_test_idx)} 个样本 ({len(close_test_idx)/len(close_indices)*100:.1f}%)")

print("\nOpen 数据集划分:")
print(f"  训练集: {len(open_train_idx)} 个样本 ({len(open_train_idx)/len(open_indices)*100:.1f}%)")
print(f"  测试集: {len(open_test_idx)} 个样本 ({len(open_test_idx)/len(open_indices)*100:.1f}%)")

# 4. Build the subsets.
#    Test subsets read from a second dataset object with a deterministic
#    transform. Sharing the training pipeline (random crop + flip) made test
#    accuracy vary between passes over the same checkpoint. Both objects load
#    the same JSON, so the indices refer to the same samples.
eval_dataset = VQARADBaselineDataset(
    json_path=DATA_JSON_PATH,
    image_dir=DATA_IMAGE_DIR,
    transform=transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
)

close_train_dataset = Subset(full_dataset, close_train_idx)
close_test_dataset = Subset(eval_dataset, close_test_idx)

open_train_dataset = Subset(full_dataset, open_train_idx)
open_test_dataset = Subset(eval_dataset, open_test_idx)

# 5. 基于所有 question 构建简单词表，用于文本 Transformer
def tokenize(text: str):
    """Split ``text`` into lower-cased, whitespace-separated tokens.

    "?" and "," are treated as separators; any other punctuation stays
    attached to its token. Non-string input is stringified first.
    """
    lowered = str(text).lower().strip()
    # Map the two separator characters to spaces in a single pass
    separators_to_space = {ord("?"): " ", ord(","): " "}
    return lowered.translate(separators_to_space).split()


# Token frequencies over the questions of every sample in the dataset
counter = Counter(
    token
    for record in full_dataset.data
    for token in tokenize(record.get("question", ""))
)

# Ids 0 and 1 are reserved for the special tokens; the remaining words get
# consecutive ids in first-seen order.
word2idx = {"<pad>": 0, "<unk>": 1}
for word, freq in counter.items():
    # A minimum frequency of 1 keeps every word; raise it to prune rare words.
    if freq < 1 or word in word2idx:
        continue
    word2idx[word] = len(word2idx)

idx2word = dict(zip(word2idx.values(), word2idx.keys()))
MAX_Q_LEN = 20  # questions are padded / truncated to this many tokens

print(f"\n词表大小: {len(word2idx)}, MAX_Q_LEN = {MAX_Q_LEN}")

数据集总样本数: 2248
Close (是/否) 样本数: 1193
Open  (开放式) 样本数: 1055

Close (yes/no) 数据集划分:
  训练集: 954 个样本 (80.0%)
  测试集: 239 个样本 (20.0%)

Open 数据集划分:
  训练集: 844 个样本 (80.0%)
  测试集: 211 个样本 (20.0%)

词表大小: 1227, MAX_Q_LEN = 20


In [None]:
import torch.nn.functional as F
from torchvision import models

# ==== CNN+Transformer 模型定义 ====

class CNNTransformerVQA(nn.Module):
    """VQA model combining a CNN image encoder with a Transformer text encoder.

    A ResNet-18 backbone produces one feature vector per image, a Transformer
    encoder processes the question tokens, and the two vectors are
    concatenated and passed through an MLP to predict the answer class.

    NOTE(review): no positional encoding is added to the token embeddings, so
    together with mean pooling the text branch appears to be insensitive to
    word order — confirm whether that is intended.
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, num_heads=8,
                 num_layers=2, num_classes=2, dropout=0.3):
        """
        Args:
            vocab_size: Size of the question vocabulary.
            embed_dim: Word-embedding dimension.
            hidden_dim: Transformer model dimension (also the fused feature size).
            num_heads: Number of attention heads.
            num_layers: Number of Transformer encoder layers.
            num_classes: Number of output classes (2 for the close task, the
                answer-vocabulary size for the open task).
            dropout: Dropout probability used throughout the model.
        """
        super().__init__()

        # 1. Image encoder: pretrained ResNet-18.
        #    Newer torchvision exposes a weights enum; older releases lack the
        #    attribute, so fall back to the legacy flag. Only AttributeError is
        #    caught so that download failures or interrupts are not swallowed
        #    and retried under a different API.
        try:
            weights = models.ResNet18_Weights.DEFAULT
        except AttributeError:
            resnet = models.resnet18(pretrained=True)
        else:
            resnet = models.resnet18(weights=weights)
        # Drop the final fully connected layer, keep the feature extractor
        self.cnn = nn.Sequential(*list(resnet.children())[:-1])  # output: [B, 512, 1, 1]
        self.image_proj = nn.Linear(512, hidden_dim)  # project to hidden_dim

        # 2. Text encoder: embedding -> linear projection -> Transformer encoder
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.word_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Projects word embeddings from embed_dim to hidden_dim
        self.text_proj = nn.Linear(embed_dim, hidden_dim)

        # 3. Multimodal fusion MLP over the concatenated image/text features
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # 4. Classification head
        self.classifier = nn.Linear(hidden_dim, num_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, images, questions):
        """
        Args:
            images: [B, 3, H, W] image tensor (224x224 in this notebook).
            questions: [B, seq_len] question token ids, with 0 as padding.

        Returns:
            logits: [B, num_classes] classification logits.
        """
        batch_size = images.size(0)

        # 1. Image features
        img_features = self.cnn(images)  # [B, 512, 1, 1]
        img_features = img_features.view(batch_size, -1)  # [B, 512]
        img_features = self.image_proj(img_features)  # [B, hidden_dim]
        img_features = self.dropout(img_features)

        # 2. Text features
        q_embed = self.word_embed(questions)  # [B, seq_len, embed_dim]
        q_embed = self.text_proj(q_embed)  # [B, seq_len, hidden_dim]

        # True marks padding positions, which attention must ignore
        padding_mask = (questions == 0)  # [B, seq_len]
        q_features = self.transformer(q_embed, src_key_padding_mask=padding_mask)  # [B, seq_len, hidden_dim]

        # Mean-pool over the non-padding tokens. masked_fill is used instead of
        # multiplying by a 0/1 mask: a fully padded question can make attention
        # emit NaN, and NaN * 0 is still NaN, whereas masked_fill overwrites it.
        # Such a row then pools to zeros (0 / 1e-8).
        q_features = q_features.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        token_counts = (~padding_mask).sum(dim=1, keepdim=True).to(q_features.dtype)  # [B, 1]
        q_features = q_features.sum(dim=1) / (token_counts + 1e-8)  # [B, hidden_dim]
        q_features = self.dropout(q_features)

        # 3. Multimodal fusion
        combined = torch.cat([img_features, q_features], dim=1)  # [B, hidden_dim * 2]
        fused = self.fusion(combined)  # [B, hidden_dim]

        # 4. Classification
        logits = self.classifier(fused)  # [B, num_classes]

        return logits


# 辅助函数：将问题文本转换为 token ids
def encode_question(question, word2idx, max_len=MAX_Q_LEN):
    """Encode a question as a fixed-length LongTensor of token ids.

    Tokens missing from ``word2idx`` map to the "<unk>" id. The id sequence is
    truncated to ``max_len`` and right-padded with the "<pad>" id when shorter.
    """
    ids = [word2idx.get(tok, word2idx["<unk>"]) for tok in tokenize(question)][:max_len]
    shortfall = max_len - len(ids)
    if shortfall > 0:
        ids.extend([word2idx["<pad>"]] * shortfall)
    return torch.tensor(ids, dtype=torch.long)


# 自定义 collate 函数
def collate_fn(batch):
    """Collate ``(image, question, answer, q_type)`` samples into one batch.

    Returns the stacked image tensor, the stacked question token ids (encoded
    with the module-level ``word2idx`` and ``MAX_Q_LEN``), and the answers and
    question types as plain lists.
    """
    images, questions, answers, q_types = zip(*batch)

    # Images are already tensors, so they can be stacked directly
    image_batch = torch.stack(images)

    # Encode every question to a fixed-length id tensor, then stack
    question_batch = torch.stack(
        [encode_question(q, word2idx, MAX_Q_LEN) for q in questions]
    )

    return image_batch, question_batch, list(answers), list(q_types)


print("✓ CNN+Transformer 模型已定义")

✓ CNN+Transformer 模型已定义


In [None]:
# ==== Train the closed-ended (yes/no) task ====

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# DataLoaders: only the training split is shuffled
batch_size = 32
close_train_loader = DataLoader(
    close_train_dataset,
    collate_fn=collate_fn,
    batch_size=batch_size,
    shuffle=True,
)
close_test_loader = DataLoader(
    close_test_dataset,
    collate_fn=collate_fn,
    batch_size=batch_size,
    shuffle=False,
)

print(f"Close 训练集: {len(close_train_dataset)} 个样本, {len(close_train_loader)} batches")
print(f"Close 测试集: {len(close_test_dataset)} 个样本, {len(close_test_loader)} batches")

# Binary classifier: class 1 = "yes", class 0 = everything else ("no")
model_close = CNNTransformerVQA(
    vocab_size=len(word2idx),
    embed_dim=256,
    hidden_dim=512,
    num_heads=8,
    num_layers=2,
    num_classes=2,
    dropout=0.3,
).to(device)

# Loss and optimiser
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model_close.parameters(), lr=1e-4, weight_decay=1e-5)

num_epochs = 10
print(f"\n开始训练 Close (Yes/No) 任务，共 {num_epochs} 个 epoch...")


def _close_labels(answers):
    """Return a LongTensor on ``device``: 1 where the normalised answer is "yes", else 0."""
    flags = [int(normalize_answer(raw) == "yes") for raw in answers]
    return torch.tensor(flags, dtype=torch.long).to(device)


best_test_acc = 0.0

for epoch in range(num_epochs):
    epoch_tag = f"Epoch {epoch+1}/{num_epochs}"

    # --- training pass ---
    model_close.train()
    train_loss, train_correct, train_total = 0.0, 0, 0

    for images, question_ids, answers, q_types in tqdm(close_train_loader, desc=epoch_tag):
        images = images.to(device)
        question_ids = question_ids.to(device)
        labels = _close_labels(answers)

        optimizer.zero_grad()
        logits = model_close(images, question_ids)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Running statistics for this epoch
        train_loss += loss.item()
        predicted = logits.argmax(dim=1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

    train_acc = 100.0 * train_correct / train_total
    avg_train_loss = train_loss / len(close_train_loader)

    # --- evaluation pass on the test split ---
    model_close.eval()
    test_correct, test_total = 0, 0

    with torch.no_grad():
        for images, question_ids, answers, q_types in close_test_loader:
            images = images.to(device)
            question_ids = question_ids.to(device)
            labels = _close_labels(answers)

            logits = model_close(images, question_ids)
            predicted = logits.argmax(dim=1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_acc = 100.0 * test_correct / test_total

    print(f"{epoch_tag}:")
    print(f"  训练 Loss: {avg_train_loss:.4f}, 训练准确率: {train_acc:.2f}%")
    print(f"  测试准确率: {test_acc:.2f}%")

    # Checkpoint whenever test accuracy improves on the best seen so far
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        torch.save(model_close.state_dict(), "cnn_transformer_best_model.pth")
        print(f"  ✓ 保存最佳模型 (测试准确率: {test_acc:.2f}%)")
    print()

print(f"训练完成！最佳测试准确率: {best_test_acc:.2f}%")
print(f"模型已保存到: cnn_transformer_best_model.pth")

使用设备: cuda
Close 训练集: 954 个样本, 30 batches
Close 测试集: 239 个样本, 8 batches

开始训练 Close (Yes/No) 任务，共 10 个 epoch...


Epoch 1/10: 100%|██████████| 30/30 [00:04<00:00,  6.22it/s]
  output = torch._nested_tensor_from_mask(


Epoch 1/10:
  训练 Loss: 0.6799, 训练准确率: 55.66%
  测试准确率: 60.25%
  ✓ 保存最佳模型 (测试准确率: 60.25%)



Epoch 2/10: 100%|██████████| 30/30 [00:04<00:00,  6.22it/s]


Epoch 2/10:
  训练 Loss: 0.6252, 训练准确率: 63.63%
  测试准确率: 65.69%
  ✓ 保存最佳模型 (测试准确率: 65.69%)



Epoch 3/10: 100%|██████████| 30/30 [00:04<00:00,  6.23it/s]


Epoch 3/10:
  训练 Loss: 0.5478, 训练准确率: 71.38%
  测试准确率: 63.18%



Epoch 4/10: 100%|██████████| 30/30 [00:04<00:00,  6.26it/s]


Epoch 4/10:
  训练 Loss: 0.4789, 训练准确率: 78.62%
  测试准确率: 68.20%
  ✓ 保存最佳模型 (测试准确率: 68.20%)



Epoch 5/10: 100%|██████████| 30/30 [00:04<00:00,  6.23it/s]


Epoch 5/10:
  训练 Loss: 0.3867, 训练准确率: 83.65%
  测试准确率: 66.11%



Epoch 6/10: 100%|██████████| 30/30 [00:04<00:00,  6.27it/s]


Epoch 6/10:
  训练 Loss: 0.3619, 训练准确率: 86.27%
  测试准确率: 65.69%



Epoch 7/10: 100%|██████████| 30/30 [00:04<00:00,  6.19it/s]


Epoch 7/10:
  训练 Loss: 0.2931, 训练准确率: 87.53%
  测试准确率: 70.29%
  ✓ 保存最佳模型 (测试准确率: 70.29%)



Epoch 8/10: 100%|██████████| 30/30 [00:04<00:00,  6.19it/s]


Epoch 8/10:
  训练 Loss: 0.2537, 训练准确率: 89.73%
  测试准确率: 64.44%



Epoch 9/10: 100%|██████████| 30/30 [00:04<00:00,  6.20it/s]


Epoch 9/10:
  训练 Loss: 0.2346, 训练准确率: 90.88%
  测试准确率: 67.36%



Epoch 10/10: 100%|██████████| 30/30 [00:04<00:00,  6.20it/s]


Epoch 10/10:
  训练 Loss: 0.2132, 训练准确率: 92.03%
  测试准确率: 69.46%

训练完成！最佳测试准确率: 70.29%
模型已保存到: cnn_transformer_best_model.pth


In [None]:
# ==== Evaluate the closed-ended (yes/no) task ====

# Restore the best checkpoint written by the training cell
model_close.load_state_dict(torch.load("cnn_transformer_best_model.pth", map_location=device))
model_close.eval()

print("评估 Close (Yes/No) 任务...")
print("=" * 80)


def _print_predictions(title, results, first_number):
    """Print a titled list of prediction records, numbered from ``first_number``."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)
    for i, result in enumerate(results, first_number):
        status = "✓ 正确" if result["is_correct"] else "✗ 错误"
        print(f"\n[样本 {i}] {status}")
        print(f"  真实答案: {result['true_answer']} (归一化: {result['true_answer_norm']})")
        print(f"  预测答案: {result['pred_answer']} (置信度: {result['confidence']:.3f})")


test_correct = 0
test_total = 0
all_results = []

with torch.no_grad():
    for images, question_ids, answers, q_types in tqdm(close_test_loader, desc="评估中"):
        images = images.to(device)
        question_ids = question_ids.to(device)

        # Labels: 1 where the normalised answer is "yes", otherwise 0
        answers_norm = [normalize_answer(ans) for ans in answers]
        labels = torch.tensor(
            [1 if ans_norm == "yes" else 0 for ans_norm in answers_norm],
            dtype=torch.long,
        ).to(device)

        # Forward pass
        logits = model_close(images, question_ids)
        probs = F.softmax(logits, dim=1)
        predicted = logits.argmax(dim=1)

        # Accuracy counters
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()

        # Per-sample records for inspection. Tensors are moved to Python
        # lists once per batch instead of calling .item() for every sample.
        confidences = probs.gather(1, predicted.unsqueeze(1)).squeeze(1).tolist()
        for ans, ans_norm, pred_label, true_label, confidence in zip(
            answers, answers_norm, predicted.tolist(), labels.tolist(), confidences
        ):
            all_results.append({
                "question": "",  # placeholder: collate_fn does not pass question text through
                "true_answer": ans,
                "true_answer_norm": ans_norm,
                "pred_answer": "yes" if pred_label == 1 else "no",
                "is_correct": pred_label == true_label,
                "confidence": confidence,
            })

# Guard against an empty test loader instead of dividing by zero
if test_total > 0:
    test_acc = 100.0 * test_correct / test_total
    print(f"\n测试准确率: {test_acc:.2f}% ({test_correct}/{test_total})")
else:
    test_acc = 0.0
    print("\n测试集为空，无法计算准确率")

# Show the first and last 10 samples. The tail is numbered from its true
# position, which stays correct when there are fewer than 10 results.
_print_predictions("前 10 个样本的预测结果:", all_results[:10], 1)

tail_results = all_results[-10:]
_print_predictions(
    "后 10 个样本的预测结果:",
    tail_results,
    len(all_results) - len(tail_results) + 1,
)
print("=" * 80)

评估 Close (Yes/No) 任务...


评估中: 100%|██████████| 8/8 [00:01<00:00,  7.91it/s]


测试准确率: 69.87% (167/239)

前 10 个样本的预测结果:

[样本 1] ✓ 正确
  真实答案: No (归一化: no)
  预测答案: no (置信度: 0.891)

[样本 2] ✓ 正确
  真实答案: no (归一化: no)
  预测答案: no (置信度: 0.948)

[样本 3] ✓ 正确
  真实答案: yes (归一化: yes)
  预测答案: yes (置信度: 0.878)

[样本 4] ✗ 错误
  真实答案: yes (归一化: yes)
  预测答案: no (置信度: 0.964)

[样本 5] ✓ 正确
  真实答案: yes (归一化: yes)
  预测答案: yes (置信度: 0.882)

[样本 6] ✓ 正确
  真实答案: No (归一化: no)
  预测答案: no (置信度: 0.997)

[样本 7] ✓ 正确
  真实答案: no (归一化: no)
  预测答案: no (置信度: 0.916)

[样本 8] ✗ 错误
  真实答案: yes (归一化: yes)
  预测答案: no (置信度: 0.907)

[样本 9] ✓ 正确
  真实答案: No (归一化: no)
  预测答案: no (置信度: 0.954)

[样本 10] ✓ 正确
  真实答案: Yes (归一化: yes)
  预测答案: yes (置信度: 0.998)

后 10 个样本的预测结果:

[样本 230] ✗ 错误
  真实答案: Yes (归一化: yes)
  预测答案: no (置信度: 0.883)

[样本 231] ✓ 正确
  真实答案: Yes (归一化: yes)
  预测答案: yes (置信度: 0.955)

[样本 232] ✓ 正确
  真实答案: No (归一化: no)
  预测答案: no (置信度: 0.997)

[样本 233] ✗ 错误
  真实答案: No (归一化: no)
  预测答案: yes (置信度: 0.831)

[样本 234] ✓ 正确
  真实答案: Yes (归一化: yes)
  预测答案: yes (置信度: 0.996)

[样本 235] ✓ 正确
  真实答案: No (归一化: no)
  预测




In [None]:
# ==== Train the open-ended task (classification over the training answer vocabulary) ====

# Answer vocabulary: every distinct normalised answer in the open training split.
# Answers are read straight from the JSON metadata; going through
# full_dataset[idx] would decode and augment each image only to discard it.
# str(...) mirrors the string conversion __getitem__ applies to the raw answer.
answer_vocab = set()
for idx in open_train_idx:
    ans_norm = normalize_answer(str(full_dataset.data[idx].get("answer", "")))
    if ans_norm and ans_norm not in ("yes", "no"):  # skip empty and yes/no answers
        answer_vocab.add(ans_norm)

answer_vocab = sorted(answer_vocab)
answer_vocab_size = len(answer_vocab)
answer2idx = {ans: idx for idx, ans in enumerate(answer_vocab)}
idx2answer = {idx: ans for ans, idx in answer2idx.items()}

print(f"Open 任务答案词表大小: {answer_vocab_size}")
print(f"前10个答案示例: {answer_vocab[:10]}")

# DataLoaders: only the training split is shuffled
open_train_loader = DataLoader(
    open_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn
)
open_test_loader = DataLoader(
    open_test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn
)

print(f"\nOpen 训练集: {len(open_train_dataset)} 个样本, {len(open_train_loader)} batches")
print(f"Open 测试集: {len(open_test_dataset)} 个样本, {len(open_test_loader)} batches")

# Multi-class classifier with one class per answer in the vocabulary
model_open = CNNTransformerVQA(
    vocab_size=len(word2idx),
    embed_dim=256,
    hidden_dim=512,
    num_heads=8,
    num_layers=2,
    num_classes=answer_vocab_size,
    dropout=0.3
).to(device)

# Loss and optimiser
criterion_open = nn.CrossEntropyLoss()
optimizer_open = AdamW(model_open.parameters(), lr=1e-4, weight_decay=1e-5)

num_epochs_open = 10
print(f"\n开始训练 Open 任务，共 {num_epochs_open} 个 epoch...")


def _open_labels(answers):
    """Return ``(labels, valid_indices)`` for the answers found in ``answer2idx``.

    ``labels`` holds the class ids and ``valid_indices`` the batch positions
    they came from; answers outside the vocabulary are dropped.
    """
    labels, valid_indices = [], []
    for position, raw in enumerate(answers):
        class_id = answer2idx.get(normalize_answer(raw))
        if class_id is not None:
            labels.append(class_id)
            valid_indices.append(position)
    return labels, valid_indices


best_test_acc_open = 0.0

for epoch in range(num_epochs_open):
    # --- training pass ---
    model_open.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    train_batches = 0  # batches that actually contributed a loss value

    for images, question_ids, answers, q_types in tqdm(open_train_loader, desc=f"Epoch {epoch+1}/{num_epochs_open}"):
        images = images.to(device)
        question_ids = question_ids.to(device)

        labels, valid_indices = _open_labels(answers)
        if not labels:  # nothing in this batch maps to a known answer
            continue

        # Keep only the samples with an in-vocabulary answer
        images = images[valid_indices]
        question_ids = question_ids[valid_indices]
        labels = torch.tensor(labels, dtype=torch.long).to(device)

        optimizer_open.zero_grad()
        logits = model_open(images, question_ids)
        loss = criterion_open(logits, labels)

        loss.backward()
        optimizer_open.step()

        # Running statistics for this epoch
        train_loss += loss.item()
        train_batches += 1
        predicted = logits.argmax(dim=1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

    train_acc = 100.0 * train_correct / train_total if train_total > 0 else 0.0
    # Average over processed batches only, so skipped batches do not deflate the loss
    avg_train_loss = train_loss / train_batches if train_batches > 0 else 0.0

    # --- evaluation pass on the test split ---
    # NOTE: only test samples whose answer appears in the training vocabulary
    # are counted here; unseen answers are excluded from this accuracy.
    model_open.eval()
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for images, question_ids, answers, q_types in open_test_loader:
            images = images.to(device)
            question_ids = question_ids.to(device)

            labels, valid_indices = _open_labels(answers)
            if not labels:
                continue

            images = images[valid_indices]
            question_ids = question_ids[valid_indices]
            labels = torch.tensor(labels, dtype=torch.long).to(device)

            logits = model_open(images, question_ids)
            predicted = logits.argmax(dim=1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_acc = 100.0 * test_correct / test_total if test_total > 0 else 0.0

    print(f"Epoch {epoch+1}/{num_epochs_open}:")
    print(f"  训练 Loss: {avg_train_loss:.4f}, 训练准确率: {train_acc:.2f}%")
    print(f"  测试准确率: {test_acc:.2f}%")

    # Checkpoint (weights + answer vocabulary) whenever test accuracy improves
    if test_acc > best_test_acc_open:
        best_test_acc_open = test_acc
        torch.save({
            'model_state_dict': model_open.state_dict(),
            'answer_vocab': answer_vocab,
            'answer2idx': answer2idx,
            'idx2answer': idx2answer
        }, "cnn_transformer_open_best_model.pth")
        print(f"  ✓ 保存最佳模型 (测试准确率: {test_acc:.2f}%)")
    print()

print(f"训练完成！最佳测试准确率: {best_test_acc_open:.2f}%")
print(f"模型已保存到: cnn_transformer_open_best_model.pth")

Open 任务答案词表大小: 446
前10个答案示例: ['10-20 minutes', '12', '2', '2.5cm x 1.7cm x 1.6cm', '3.4 cm', '4', '4th ventricle', '5%', '5.6cm focal, predominantly hypodense', '5cm']

Open 训练集: 844 个样本, 27 batches
Open 测试集: 211 个样本, 7 batches

开始训练 Open 任务，共 10 个 epoch...


Epoch 1/10: 100%|██████████| 27/27 [00:04<00:00,  6.59it/s]


Epoch 1/10:
  训练 Loss: 6.0842, 训练准确率: 2.13%
  测试准确率: 6.21%
  ✓ 保存最佳模型 (测试准确率: 6.21%)



Epoch 2/10: 100%|██████████| 27/27 [00:04<00:00,  6.68it/s]


Epoch 2/10:
  训练 Loss: 5.9251, 训练准确率: 4.62%
  测试准确率: 6.21%



Epoch 3/10: 100%|██████████| 27/27 [00:04<00:00,  6.60it/s]


Epoch 3/10:
  训练 Loss: 5.7168, 训练准确率: 4.74%
  测试准确率: 11.03%
  ✓ 保存最佳模型 (测试准确率: 11.03%)



Epoch 4/10: 100%|██████████| 27/27 [00:04<00:00,  6.71it/s]


Epoch 4/10:
  训练 Loss: 5.4894, 训练准确率: 6.28%
  测试准确率: 12.41%
  ✓ 保存最佳模型 (测试准确率: 12.41%)



Epoch 5/10: 100%|██████████| 27/27 [00:04<00:00,  6.70it/s]


Epoch 5/10:
  训练 Loss: 5.2716, 训练准确率: 7.94%
  测试准确率: 15.17%
  ✓ 保存最佳模型 (测试准确率: 15.17%)



Epoch 6/10: 100%|██████████| 27/27 [00:04<00:00,  6.67it/s]


Epoch 6/10:
  训练 Loss: 5.0917, 训练准确率: 10.43%
  测试准确率: 15.86%
  ✓ 保存最佳模型 (测试准确率: 15.86%)



Epoch 7/10: 100%|██████████| 27/27 [00:04<00:00,  6.69it/s]


Epoch 7/10:
  训练 Loss: 4.8286, 训练准确率: 11.37%
  测试准确率: 15.86%



Epoch 8/10: 100%|██████████| 27/27 [00:03<00:00,  6.84it/s]


Epoch 8/10:
  训练 Loss: 4.6464, 训练准确率: 13.74%
  测试准确率: 15.17%



Epoch 9/10: 100%|██████████| 27/27 [00:03<00:00,  6.85it/s]


Epoch 9/10:
  训练 Loss: 4.4197, 训练准确率: 15.88%
  测试准确率: 17.24%
  ✓ 保存最佳模型 (测试准确率: 17.24%)



Epoch 10/10: 100%|██████████| 27/27 [00:03<00:00,  6.76it/s]


Epoch 10/10:
  训练 Loss: 4.2214, 训练准确率: 18.84%
  测试准确率: 16.55%

训练完成！最佳测试准确率: 17.24%
模型已保存到: cnn_transformer_open_best_model.pth


In [None]:
# ==== Evaluate the open-ended task ====

# Restore the best checkpoint together with the answer vocabulary it was trained on
checkpoint = torch.load("cnn_transformer_open_best_model.pth", map_location=device)
model_open.load_state_dict(checkpoint['model_state_dict'])
answer_vocab = checkpoint['answer_vocab']
answer2idx = checkpoint['answer2idx']
idx2answer = checkpoint['idx2answer']

model_open.eval()

print("评估 Open 任务...")
print("=" * 80)


def _print_predictions(title, results, first_number):
    """Print a titled list of prediction records, numbered from ``first_number``."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)
    for i, result in enumerate(results, first_number):
        status = "✓ 正确" if result["is_correct"] else "✗ 错误"
        print(f"\n[样本 {i}] {status}")
        print(f"  真实答案: {result['true_answer']} (归一化: {result['true_answer_norm']})")
        print(f"  预测答案: {result['pred_answer']} (置信度: {result['confidence']:.3f})")


test_correct = 0
test_total = 0
all_results = []

with torch.no_grad():
    for images, question_ids, answers, q_types in tqdm(open_test_loader, desc="评估中"):
        images = images.to(device)
        question_ids = question_ids.to(device)

        # Forward pass
        logits = model_open(images, question_ids)
        probs = F.softmax(logits, dim=1)
        predicted = logits.argmax(dim=1)

        # Move predictions and confidences to Python lists once per batch
        pred_indices = predicted.tolist()
        confidences = probs.gather(1, predicted.unsqueeze(1)).squeeze(1).tolist()

        for raw_answer, pred_idx, confidence in zip(answers, pred_indices, confidences):
            ans_norm = normalize_answer(raw_answer)
            pred_ans = idx2answer.get(pred_idx, "<unknown>")

            # Exact match on the normalised answer
            is_correct = (pred_ans == ans_norm)

            # In-vocabulary accuracy: only answers present in the training vocabulary
            if ans_norm in answer2idx:
                test_total += 1
                if is_correct:
                    test_correct += 1

            all_results.append({
                "true_answer": raw_answer,
                "true_answer_norm": ans_norm,
                "pred_answer": pred_ans,
                "is_correct": is_correct,
                "confidence": confidence
            })

if test_total > 0:
    test_acc = 100.0 * test_correct / test_total
    print(f"\n测试准确率: {test_acc:.2f}% ({test_correct}/{test_total})")
else:
    print("\n没有有效的测试样本（答案不在答案词表中）")

# The figure above leaves out test answers never seen in training, which the
# classifier cannot predict. Also report accuracy over every test sample so
# the excluded cases are visible.
if all_results:
    overall_correct = sum(1 for result in all_results if result["is_correct"])
    overall_acc = 100.0 * overall_correct / len(all_results)
    print(f"整体准确率(含词表外答案): {overall_acc:.2f}% ({overall_correct}/{len(all_results)})")

# Show the first and last 10 samples. The tail is numbered from its true
# position, which stays correct when there are fewer than 10 results.
_print_predictions("前 10 个样本的预测结果:", all_results[:10], 1)

tail_results = all_results[-10:]
_print_predictions(
    "后 10 个样本的预测结果:",
    tail_results,
    len(all_results) - len(tail_results) + 1,
)
print("=" * 80)

评估 Open 任务...


评估中: 100%|██████████| 7/7 [00:00<00:00,  8.30it/s]


测试准确率: 17.24% (25/145)

前 10 个样本的预测结果:

[样本 1] ✗ 错误
  真实答案: cardiomegaly (归一化: cardiomegaly)
  预测答案: female (置信度: 0.017)

[样本 2] ✗ 错误
  真实答案: right sided pleural effusion (归一化: right sided pleural effusion)
  预测答案: female (置信度: 0.038)

[样本 3] ✓ 正确
  真实答案: axial (归一化: axial)
  预测答案: axial (置信度: 0.959)

[样本 4] ✗ 错误
  真实答案: Cerebellum (归一化: cerebellum)
  预测答案: brain (置信度: 0.093)

[样本 5] ✗ 错误
  真实答案: x-ray (归一化: x-ray)
  预测答案: chest x-ray (置信度: 0.190)

[样本 6] ✗ 错误
  真实答案: with contrast (归一化: with contrast)
  预测答案: mri (置信度: 0.046)

[样本 7] ✗ 错误
  真实答案: The small intestines (归一化: the small intestines)
  预测答案: diffuse (置信度: 0.005)

[样本 8] ✗ 错误
  真实答案: Not sure (归一化: not sure)
  预测答案: fat (置信度: 0.023)

[样本 9] ✗ 错误
  真实答案: Left Parietal lobe (归一化: left parietal lobe)
  预测答案: right (置信度: 0.057)

[样本 10] ✗ 错误
  真实答案: Left rectus abdominus (归一化: left rectus abdominus)
  预测答案: right temporal lobe (置信度: 0.021)

后 10 个样本的预测结果:

[样本 202] ✓ 正确
  真实答案: Right side (归一化: right)
  预测答案: right (置信度: 0.651)


