# 步骤 2.3：基线模型训练与验证

本 Notebook 基于 `plan1.md` 的 “#### 2.3 基线模型训练与验证” 规划，并复用 `step_2_2-特征融合与匹配机制` 中的实现思路：
- 使用文本与图像特征提取器（BERT + ResNet50）
- 使用特征融合（投影到共享空间）与相似度计算（余弦相似度）
- 构建对比学习（InfoNCE）训练循环，优化投影层
- 在验证集上评估 Recall@1/5/10 并报告 MeanRecall


## 1. 环境准备
- 统一设置本地缓存目录 `/mnt/d/HuggingFaceModels`，仅从本地加载
- 导入依赖与数据加载组件


In [1]:
# !pip install -q transformers torch timm torchvision tqdm

In [2]:
import os
import sys
import math
from typing import List, Dict

# Cache configuration must run BEFORE the third-party imports below.
# huggingface_hub resolves HF_HOME / HUGGINGFACE_HUB_CACHE into module-level
# constants when it is first imported (which happens as a side effect of
# importing transformers or timm), so assigning these variables afterwards
# does not redirect hub-backed downloads/caches.
cache_dir = "/mnt/d/HuggingFaceModels"
os.environ['TORCH_HOME'] = cache_dir
os.environ['HF_HOME'] = cache_dir
os.environ['HUGGINGFACE_HUB_CACHE'] = cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

import torch  # noqa: E402  (deliberately imported after the env setup)
import timm  # noqa: E402
from torchvision import transforms  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import BertTokenizer, BertModel  # noqa: E402
from tqdm import tqdm  # noqa: E402

# Make the notebook's working directory importable so the project-local
# data_loader module can be found.
module_path = os.path.abspath(os.path.join('.'))
if module_path not in sys.path:
    sys.path.append(module_path)
from data_loader import DataLoader  # noqa: E402

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


## 2. 特征提取器与融合/相似度模块
与 `step_2_2` 保持一致：
- 文本：`bert-base-chinese` 的 [CLS] 输出（768维）
- 图像：`resnet50` 的全局特征（2048维）
- 融合：线性投影到共享空间（默认512维）


In [3]:
class TextFeatureExtractor:
    """Text encoder that returns the first-token hidden state of a BERT model.

    The tokenizer and model are loaded with ``local_files_only=True`` and the
    model is switched to eval mode; inference runs under ``torch.no_grad()``.
    """

    def __init__(self, model_name='bert-base-chinese', device='cpu', cache_dir=None):
        """Load the tokenizer and model for ``model_name`` and move the model to ``device``."""
        load_kwargs = {'cache_dir': cache_dir, 'local_files_only': True}
        self.device = device
        self.tokenizer = BertTokenizer.from_pretrained(model_name, **load_kwargs)
        bert = BertModel.from_pretrained(model_name, **load_kwargs)
        self.model = bert.to(device)
        self.model.eval()

    def extract_text_features(self, texts: List[str]) -> torch.Tensor:
        """Encode ``texts`` and return one feature row per input text.

        A falsy input (empty list, empty string, ``None``) yields an empty
        ``(0, 768)`` float32 tensor on ``self.device``. A single string is
        treated as a one-element batch. Inputs are padded and truncated to at
        most 128 tokens.
        """
        if not texts:
            return torch.empty((0, 768), dtype=torch.float32, device=self.device)
        batch = [texts] if isinstance(texts, str) else texts
        encoded = self.tokenizer(
            batch,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=128,
        )
        encoded = encoded.to(self.device)
        with torch.no_grad():
            hidden_states = self.model(**encoded).last_hidden_state
        # Position 0 of every sequence is used as the sentence representation.
        return hidden_states[:, 0, :]

    def extract_features(self, texts):
        """Alias of :meth:`extract_text_features` (shared extractor interface)."""
        return self.extract_text_features(texts)

class ImageFeatureExtractor:
    """Image encoder built from a timm backbone with its classifier removed.

    Images are converted to RGB, resized to 224x224, turned into tensors and
    normalized with the mean/std constants below before being fed to the
    model. Inference runs in eval mode under ``torch.no_grad()``.
    """

    def __init__(self, model_name='resnet50', device='cpu', cache_dir=None):
        """Create the pretrained backbone on ``device`` and the preprocessing pipeline.

        ``cache_dir`` is accepted for signature parity with the text extractor
        but is not read here.
        """
        self.device = device
        # Original author's note: timm is expected to pick the weights up from
        # the TORCH_HOME cache, so they must already be present to avoid a
        # download. NOTE(review): the lookup location depends on the timm
        # version -- confirm.
        backbone = timm.create_model(model_name, pretrained=True, num_classes=0)
        self.model = backbone.to(device)
        self.model.eval()
        preprocessing_steps = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(preprocessing_steps)

    def extract_image_features(self, images: List[Image.Image]) -> torch.Tensor:
        """Return one backbone feature row per image.

        An empty/falsy input yields an empty ``(0, 2048)`` float32 tensor on
        ``self.device``.
        """
        if not images:
            return torch.empty((0, 2048), dtype=torch.float32, device=self.device)
        preprocessed = [self.transform(picture.convert('RGB')) for picture in images]
        pixel_batch = torch.stack(preprocessed).to(self.device)
        with torch.no_grad():
            return self.model(pixel_batch)

    def extract_features(self, images):
        """Alias of :meth:`extract_image_features` (shared extractor interface)."""
        return self.extract_image_features(images)

class FeatureFusion:
    """Maps text and image features into a shared space.

    With ``fusion_method='projection'`` two independent linear layers project
    text features (``text_dim`` wide) and image features (``image_dim`` wide)
    to ``projection_dim``. Any other ``fusion_method`` passes features through
    unchanged and creates no projection layers.

    Args:
        fusion_method: ``'projection'`` to enable the linear projections.
        projection_dim: Width of the shared space.
        device: Device for the projection layers; defaults to CUDA when
            available, otherwise CPU.
        text_dim: Width of the incoming text features (default 768).
        image_dim: Width of the incoming image features (default 2048).
    """

    def __init__(self, fusion_method='projection', projection_dim=512, device=None,
                 *, text_dim=768, image_dim=2048):
        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.text_dim = text_dim
        self.image_dim = image_dim
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        if fusion_method == 'projection':
            # Input widths are parameters so that extractors other than the
            # default 768-d text / 2048-d image pair can be used.
            self.text_projector = torch.nn.Linear(text_dim, projection_dim).to(self.device)
            self.image_projector = torch.nn.Linear(image_dim, projection_dim).to(self.device)

    def fuse_text_features(self, text_features: torch.Tensor) -> torch.Tensor:
        """Project text features, or return them untouched when projection is off."""
        if self.fusion_method == 'projection':
            return self.text_projector(text_features)
        return text_features

    def fuse_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project image features, or return them untouched when projection is off."""
        if self.fusion_method == 'projection':
            return self.image_projector(image_features)
        return image_features

class SimilarityCalculator:
    """Pairwise text-to-image similarity.

    ``similarity_type='cosine'`` L2-normalizes both inputs before the matrix
    product; any other value yields the raw dot-product matrix.
    """

    def __init__(self, similarity_type='cosine'):
        self.similarity_type = similarity_type

    def normalize_features(self, features: torch.Tensor) -> torch.Tensor:
        """L2-normalize each row (``dim=1``) of ``features``."""
        return torch.nn.functional.normalize(features, p=2, dim=1)

    def calculate_similarity(self, text_features: torch.Tensor, image_features: torch.Tensor) -> torch.Tensor:
        """Return the matrix product of ``text_features`` with ``image_features`` transposed.

        Row ``r`` / column ``c`` of the result scores text ``r`` against image ``c``.
        """
        lhs, rhs = text_features, image_features
        if self.similarity_type == 'cosine':
            lhs = self.normalize_features(lhs)
            rhs = self.normalize_features(rhs)
        return torch.mm(lhs, rhs.t())

class CrossModalRetrievalModel:
    """Text-to-image retrieval model: extractors + fusion + similarity.

    Text and image features come from the supplied extractors, are passed
    through a ``FeatureFusion`` instance and, when ``normalize_features`` is
    true, are L2-normalized row-wise.

    Args:
        text_extractor: Object exposing ``extract_features(texts)``.
        image_extractor: Object exposing ``extract_features(images)``.
        fusion_method: Forwarded to ``FeatureFusion``.
        projection_dim: Forwarded to ``FeatureFusion``.
        similarity_type: Forwarded to ``SimilarityCalculator``.
        normalize_features: Whether fused features are L2-normalized.
        device: Forwarded to ``FeatureFusion``; also stored on the model
            (defaults to CUDA when available, otherwise CPU).
    """

    def __init__(self, text_extractor, image_extractor, fusion_method='projection', projection_dim=512, similarity_type='cosine', normalize_features=True, device=None):
        self.text_extractor = text_extractor
        self.image_extractor = image_extractor
        self.fusion = FeatureFusion(fusion_method, projection_dim, device)
        self.sim = SimilarityCalculator(similarity_type)
        self.normalize_features = normalize_features
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """L2-normalize rows of ``x`` when normalization is enabled."""
        return torch.nn.functional.normalize(x, p=2, dim=1) if self.normalize_features else x

    def extract_and_fuse_text_features(self, texts: List[str]) -> torch.Tensor:
        """Extract, fuse and (optionally) normalize features for ``texts``."""
        t = self.text_extractor.extract_features(texts)
        return self._norm(self.fusion.fuse_text_features(t))

    def extract_and_fuse_image_features(self, images: List[Image.Image]) -> torch.Tensor:
        """Extract, fuse and (optionally) normalize features for ``images``."""
        i = self.image_extractor.extract_features(images)
        return self._norm(self.fusion.fuse_image_features(i))

    def build_image_index(self, images_dict: Dict[str, Image.Image], batch_size=32) -> Dict[str, torch.Tensor]:
        """Encode every non-``None`` image and return ``{image_id: cpu_feature}``.

        Images are processed in chunks of ``batch_size`` keys; entries whose
        value is ``None`` are skipped. The returned tensors are detached and
        live on the CPU.
        """
        feats = {}
        keys = list(images_dict.keys())
        # Indexing is inference only: disabling autograd avoids building a
        # graph through the fusion layers for every batch.
        with torch.no_grad():
            for start in range(0, len(keys), batch_size):
                # Filter ids and images together so they cannot fall out of step.
                present = [
                    (k, images_dict[k])
                    for k in keys[start:start + batch_size]
                    if images_dict[k] is not None
                ]
                if not present:
                    continue
                valid_ids, batch_imgs = zip(*present)
                bf = self.extract_and_fuse_image_features(list(batch_imgs))
                for img_id, vec in zip(valid_ids, bf):
                    feats[img_id] = vec.detach().cpu()
        return feats


## 3. 数据准备：构建训练配对与验证索引
- 按 query 的 `item_ids` 选择对应图片
- 跳过缺失或未能解码的图片


In [None]:
# Load the query tables for both splits.
loader = DataLoader()
train_df = loader.load_queries(split='train')
valid_df = loader.load_queries(split='valid')

# Load a capped number of images for training and validation.
train_imgs = loader.create_img_id_to_image_dict(split='train', max_samples=500)
valid_imgs = loader.create_img_id_to_image_dict(split='valid', max_samples=500)
print(f'Train queries: {len(train_df)}, Train images: {len(train_imgs)}')
print(f'Valid queries: {len(valid_df)}, Valid images: {len(valid_imgs)}')

# Build (text, image, image_id) training triples: each query is paired with
# the first of its item ids whose image was loaded successfully.
train_pairs = []
if 'item_ids' in train_df.columns:
    for _, row in train_df.iterrows():
        query = row.get('query_text', None)
        item_ids = row.get('item_ids', [])
        if not (query and item_ids):
            continue
        first_available = next(
            (
                sid
                for sid in map(str, item_ids)
                if sid in train_imgs and train_imgs[sid] is not None
            ),
            None,
        )
        if first_available is not None:
            train_pairs.append((query, train_imgs[first_available], first_available))
print(f'Usable train pairs: {len(train_pairs)}')

# Validation: keep (query_text, [item_id, ...]) for queries whose item_ids is
# a non-empty list; ids are stringified.
valid_queries = []
if 'item_ids' in valid_df.columns:
    for _, row in valid_df.iterrows():
        query = row.get('query_text', None)
        raw_ids = row.get('item_ids', [])
        gt_ids = [str(i) for i in raw_ids] if isinstance(raw_ids, list) else []
        if query and gt_ids:
            valid_queries.append((query, gt_ids))
print(f'Usable valid queries: {len(valid_queries)}')


## 4. 训练：对比学习优化投影层（InfoNCE）
- 仅优化 `FeatureFusion` 的投影参数
- logits = sim(text, image) / temperature；label = 对角匹配（第 i 条文本对应第 i 张图片）；损失为 文本→图像 与 图像→文本 两个方向交叉熵的平均值


In [None]:
# Build the retrieval model from the two feature extractors.
text_extractor = TextFeatureExtractor(device=device, cache_dir=cache_dir)
image_extractor = ImageFeatureExtractor(device=device, cache_dir=cache_dir)
model = CrossModalRetrievalModel(
    text_extractor,
    image_extractor,
    fusion_method='projection',
    projection_dim=512,
    similarity_type='cosine',
    normalize_features=True,
    device=device,
)

# Training hyperparameters.
temperature = 0.07
epochs = 2
batch_size = 16

# Only the two projection layers are handed to the optimizer.
projection_params = [
    *model.fusion.text_projector.parameters(),
    *model.fusion.image_projector.parameters(),
]
optim = torch.optim.Adam(projection_params, lr=1e-3, weight_decay=1e-4)

def info_nce_loss(text_feats: torch.Tensor, image_feats: torch.Tensor, temp: float) -> torch.Tensor:
    """Symmetric InfoNCE loss over a batch of aligned text/image features.

    Row ``i`` of ``text_feats`` is treated as the positive for row ``i`` of
    ``image_feats``. The text-image product matrix is divided by ``temp`` and
    scored with cross-entropy in both directions; the two losses are averaged.
    """
    cross_entropy = torch.nn.functional.cross_entropy
    text_to_image = torch.mm(text_feats, image_feats.t()) / temp
    # The matching pair for every row sits on the diagonal.
    targets = torch.arange(text_to_image.size(0), device=text_to_image.device)
    text_loss = cross_entropy(text_to_image, targets)
    image_loss = cross_entropy(text_to_image.t(), targets)
    return (text_loss + image_loss) / 2

# Mini-batch construction for training.
def batch_iter(pairs, bs):
    """Yield consecutive slices of ``pairs`` holding at most ``bs`` items each."""
    yield from (pairs[offset:offset + bs] for offset in range(0, len(pairs), bs))

# Known up front, so the progress bar can show a total.
num_batches = math.ceil(len(train_pairs) / batch_size)

for ep in range(epochs):
    model.fusion.text_projector.train()
    model.fusion.image_projector.train()
    epoch_loss = 0.0
    steps = 0
    progress = tqdm(
        batch_iter(train_pairs, batch_size),
        total=num_batches,
        desc=f'Epoch {ep+1}/{epochs}',
    )
    for batch in progress:
        # A batch with a single pair has no in-batch negatives: its InfoNCE
        # loss is identically zero, so it would only dilute the reported
        # average and trigger an optimizer step with no loss signal.
        # Empty batches are skipped for the same reason.
        if len(batch) < 2:
            continue
        texts = [b[0] for b in batch]
        imgs = [b[1] for b in batch]
        t_feats = model.extract_and_fuse_text_features(texts)
        i_feats = model.extract_and_fuse_image_features(imgs)
        optim.zero_grad()
        loss = info_nce_loss(t_feats, i_feats, temperature)
        loss.backward()
        optim.step()
        epoch_loss += loss.item()
        steps += 1
    print(f'Epoch {ep+1}: avg loss={epoch_loss/max(1,steps):.4f}')

# Switch the projection layers to eval mode for evaluation (this does not
# freeze their parameters).
model.fusion.text_projector.eval()
model.fusion.image_projector.eval()


## 5. 验证评估：Recall@1/5/10 与 MeanRecall
- 基于验证集构建图像索引
- 对每条查询计算相似度并统计召回


In [None]:
# Build the validation image index and stack it into one feature matrix whose
# row order matches all_image_ids.
image_index = model.build_image_index(valid_imgs, batch_size=32)
all_image_ids = list(image_index.keys())
if all_image_ids:
    all_image_feats = torch.stack([image_index[img_id] for img_id in all_image_ids]).to(device)
else:
    # No image could be indexed: keep an empty matrix with the projection width.
    all_image_feats = torch.empty((0, 512), device=device)

def compute_recall_at_k(k_values, queries):
    """Compute Recall@k for text queries against the global image index.

    Reads the module-level ``model``, ``all_image_feats`` and
    ``all_image_ids``. A query counts as a hit at ``k`` when at least one of
    its ground-truth ids appears among its ``k`` most similar images.

    Args:
        k_values: Iterable of cut-offs, e.g. ``[1, 5, 10]``.
        queries: Iterable of ``(query_text, ground_truth_ids)`` pairs.

    Returns:
        ``({k: recall}, total)`` where ``total`` is the number of evaluated
        queries. With an empty image index no query is evaluated and every
        recall is ``0.0``.
    """
    recalls = {k: 0 for k in k_values}
    total = 0
    num_images = all_image_feats.size(0)
    if num_images == 0:
        return {k: 0.0 for k in k_values}, total

    # torch.topk raises when k exceeds the number of candidates, so cap the
    # request at the index size (small validation subsets may hold fewer
    # images than the largest cut-off).
    top_k = min(max(k_values, default=0), num_images)

    # Evaluation is inference only; skip autograd bookkeeping.
    with torch.no_grad():
        for q_text, gt_ids in tqdm(queries, desc='Evaluate'):
            q_feat = model.extract_and_fuse_text_features([q_text])
            sims = model.sim.calculate_similarity(q_feat, all_image_feats)
            top_idx = torch.topk(sims[0], k=top_k).indices
            top_ids = [all_image_ids[i] for i in top_idx.tolist()]
            # Built once per query instead of once per ground-truth id.
            gt_set = set(gt_ids)
            total += 1
            for k in k_values:
                if not gt_set.isdisjoint(top_ids[:k]):
                    recalls[k] += 1
    return {k: (recalls[k] / total if total > 0 else 0.0) for k in k_values}, total

# Evaluate at the three cut-offs and average them into MeanRecall.
rec, total_q = compute_recall_at_k([1,5,10], valid_queries)
if total_q > 0:
    mean_recall = sum(rec.get(k, 0) for k in (1, 5, 10)) / 3
else:
    mean_recall = 0.0
print(f'Recall@1={rec.get(1,0):.4f}, Recall@5={rec.get(5,0):.4f}, Recall@10={rec.get(10,0):.4f}, MeanRecall={mean_recall:.4f} (N={total_q})')


## 6. 保存投影层权重
- 便于后续复现与继续训练


In [None]:
# Persist both projection layers plus the projection width in one checkpoint.
save_dir = 'weights'
save_path = os.path.join(save_dir, 'step_2_3_projection.pth')
os.makedirs(save_dir, exist_ok=True)

fusion = model.fusion
checkpoint = {
    'text_projector': fusion.text_projector.state_dict(),
    'image_projector': fusion.image_projector.state_dict(),
    'projection_dim': fusion.projection_dim,
}
torch.save(checkpoint, save_path)
print(f'Saved projection weights to: {save_path}')


## 7. 注意事项
- 若本地 `timm` 的 `resnet50` 权重不存在，需先手动缓存到 `TORCH_HOME` 目录；否则可能尝试联网下载。
- 若验证集的 `item_ids` 与图片索引不匹配或样本较少，评估指标可能为0或不稳定。
- 训练循环演示为轻量版本（epochs=2）；实际训练可适当增大 epochs 与样本规模。
