# 步骤 2.3-4：两层MLP投影头 + 可学习温度 + AMP/FAISS (蒙)

依据 `2.3的改进方案.md` 的第三优先：
- 将投影头升级为两层 MLP（Linear → GELU → Dropout → Linear）
- 引入可学习温度（logit_scale），替代固定常数 temperature
- 保留 AMP 训练加速与 FAISS 检索回退逻辑，保证可运行且稳定

在 `step_2_3-3_屯-mean_pooling.ipynb` 基础上实现，文本特征使用 attention_mask 的 mean-pooling。

发现loss到0.04就不好下降了。做一些改动：
* 优化器改用adamw。
* projector看上去优化了一些。

In [1]:
import os
import sys
import math
from typing import List, Dict, Tuple
import torch
import timm
from torchvision import transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

# AMP
from torch.cuda.amp import autocast, GradScaler

# 安全导入 FAISS（不可用则回退）
HAS_FAISS = False
try:
    import faiss
    HAS_FAISS = True
except Exception:
    HAS_FAISS = False

# 环境与缓存
# os.environ["WANDB_DISABLED"] = "true"
# os.environ['CURL_CA_BUNDLE'] = ""
# os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cache_dir = "/mnt/d/HuggingFaceModels/"

os.environ['TORCH_HOME'] = cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ['https_proxy'] = 'http://127.0.0.1:7890'
os.environ['http_proxy'] = 'http://127.0.0.1:7890'
os.environ['all_proxy'] = 'socks5://127.0.0.1:7890'
os.environ["WANDB_DISABLED"] = "true"
os.environ['CURL_CA_BUNDLE'] = ""
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

# 导入数据加载器
sys.path.append(os.path.abspath(os.path.join('.', 'Multimodal_Retrieval', 'plan1_1')))
from data_loader import DataLoader

print(f'Using device: {device}')


  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


## 模型模块
- 文本特征：attention_mask mean-pooling
- 投影头：两层 MLP（含 GELU + Dropout）
- 可学习温度：logit_scale 参数，训练中联合优化

In [2]:
import torch.nn.functional as F
from torch import nn
# 1. 优化版两层MLP投影头（核心组件）
class OptimizedMLPProjector(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.1, use_bn=True):
        super().__init__()
        self.layer1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim) if use_bn else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.layer2 = nn.Linear(hidden_dim, out_dim)
        # Kaiming初始化
        self._init_weights()
    
    def _init_weights(self):
        nn.init.kaiming_normal_(self.layer1.weight, mode='fan_in', nonlinearity='relu')
        nn.init.zeros_(self.layer1.bias)
        nn.init.kaiming_normal_(self.layer2.weight, mode='fan_in', nonlinearity='linear')
        nn.init.zeros_(self.layer2.bias)
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.bn1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.layer2(x)
        return x

In [3]:
class TextFeatureExtractor:
    # 类作用：封装 BERT 文本编码，支持带梯度的特征提取
    # 参数:
    # - model_name: 预训练文本模型名称（默认 'bert-base-chinese'）
    # - device: 设备（'cpu' 或 'cuda'），用于放置模型与张量
    # - cache_dir: 本地模型缓存目录，优先本地加载，失败再远程
    def __init__(self, model_name='bert-base-chinese', device='cpu', cache_dir=None):
        self.device = device
        # 优先本地加载，失败则远程镜像加载
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=True)
            self.model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=True).to(device)
        except Exception:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=False)
            self.model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=False).to(device)
        self.model.eval()

    def encode_with_grad(self, texts: List[str]) -> torch.Tensor:
        # 函数作用：对一批文本进行编码并做 attention 掩码的 mean-pooling
        # 参数：texts 文本字符串列表
        # 返回：形状 [B, 768] 的句向量（带注意力掩码的均值池化）
        if not texts:
            return torch.empty((0, 768), dtype=torch.float32, device=self.device)
        inputs = self.tokenizer(
            texts, padding=True, truncation=True, max_length=32, return_tensors='pt'
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        token_embeddings = outputs.last_hidden_state  # [B, L, 768]
        attention_mask = inputs['attention_mask']     # [B, L]
        # mean-pooling with attention mask
        mask = attention_mask.unsqueeze(-1).type_as(token_embeddings)  # [B, L, 1]
        summed = (token_embeddings * mask).sum(dim=1)  # [B, 768]
        lengths = mask.sum(dim=1).clamp(min=1)        # [B, 1]
        mean_pooled = summed / lengths
        return mean_pooled

# from safetensors.torch import load_file
# class ImageFeatureExtractor:
#     '''
#     改进版，使得timm不要每次都去连huggingface。
#     '''
#     def __init__(self, model_name='resnet50', device='cpu', weights_path=None, cache_dir=None):
#         self.device = device
#         self.model = timm.create_model(model_name, pretrained=False, num_classes=0, cache_dir=cache_dir)

#         if weights_path is not None:
#             if weights_path.endswith('.safetensors'):
#                 state_dict = load_file(weights_path)
#             else:
#                 state_dict = torch.load(weights_path, map_location='cpu')
#             self.model.load_state_dict(state_dict, strict=False)

#         self.model = self.model.to(device)
#         self.model.eval()

#         self.transform = transforms.Compose([
#             transforms.Resize((224, 224)),
#             transforms.ToTensor(),
#             transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                  std=[0.229, 0.224, 0.225])
#         ])

#     def encode_with_grad(self, images: List[Image.Image]) -> torch.Tensor:
#         if not images:
#             in_dim = getattr(self.model, 'num_features', 2048)
#             return torch.empty((0, in_dim), dtype=torch.float32, device=self.device)
#         tensors = torch.stack([self.transform(img.convert('RGB')) for img in images]).to(self.device)
#         feats = self.model(tensors)
#         return feats
# image_extractor = ImageFeatureExtractor(
#     device=device, 
#     weights_path="/mnt/d/HuggingFaceModels/models--timm--resnet50.a1_in1k/snapshots/767268603ca0cb0bfe326fa87277f19c419566ef/model.safetensors"
# )

class ImageFeatureExtractor:
    # 类作用：使用 timm 的图像模型提取图像特征，num_classes=0 输出特征向量
    # 参数:
    # - model_name: timm 模型名称（默认 'resnet50'）
    # - device: 设备（'cpu' 或 'cuda'）
    # - cache_dir: timm 模型缓存目录
    def __init__(self, model_name='resnet50', device='cpu', cache_dir=None):
        self.device = device
        self.model = timm.create_model(
            model_name, pretrained=True, num_classes=0,
            cache_dir=cache_dir
        ).to(device)
        self.model.eval()
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
    def encode_with_grad(self, images: List[Image.Image]) -> torch.Tensor:
        # 函数作用：对图像列表进行预处理并前向提取特征
        # 参数：images PIL.Image 列表
        # 返回：形状 [B, 2048] 的图像特征（以 ResNet50 为例）
        if not images:
            return torch.empty((0, 2048), dtype=torch.float32, device=self.device)
        tensors = torch.stack([self.transform(img.convert('RGB')) for img in images]).to(self.device)
        feats = self.model(tensors)
        return feats

class FeatureFusion:
    # 类作用：将原始文本/图像特征投影到共同的子空间（projection_dim）
    # 参数:
    # - fusion_method: 融合方式，当前支持 'projection'
    # - projection_dim: 目标投影维度
    # - device: 设备
    # - hidden_dim: 两层 MLP 的中间隐藏层维度
    # - dropout: Dropout 概率
    # - text_in_dim/image_in_dim: 输入维度（默认文本768/图像2048）
    def __init__(self, fusion_method='projection', projection_dim=512, device=None, hidden_dim=1024, dropout=0.1, text_in_dim=768, image_in_dim=2048):
        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.hidden_dim = hidden_dim
        self.dropout_p = dropout
        if fusion_method == 'projection':
            # self.text_projector = torch.nn.Sequential(
            #     torch.nn.Linear(text_in_dim, hidden_dim),
            #     torch.nn.GELU(),
            #     torch.nn.Dropout(p=dropout),
            #     torch.nn.Linear(hidden_dim, projection_dim)
            # ).to(self.device)
            # self.image_projector = torch.nn.Sequential(
            #     torch.nn.Linear(image_in_dim, hidden_dim),
            #     torch.nn.GELU(),
            #     torch.nn.Dropout(p=dropout),
            #     torch.nn.Linear(hidden_dim, projection_dim)
            # ).to(self.device)
            self.text_projector = OptimizedMLPProjector(text_in_dim, hidden_dim, projection_dim, dropout=dropout).to(self.device)
            self.image_projector = OptimizedMLPProjector(image_in_dim, hidden_dim, projection_dim, dropout=dropout).to(self.device)
            

    def fuse_text_features(self, text_features: torch.Tensor) -> torch.Tensor:
        # 将文本特征通过两层 MLP 投影到 projection_dim，并返回
        return self.text_projector(text_features) if self.fusion_method == 'projection' else text_features

    def fuse_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        # 将图像特征通过两层 MLP 投影到 projection_dim，并返回
        return self.image_projector(image_features) if self.fusion_method == 'projection' else image_features

class SimilarityCalculator:
    # 类作用：提供特征归一化与相似度计算（默认余弦相似度）
    def __init__(self, similarity_type='cosine'):
        self.similarity_type = similarity_type
    def normalize_features(self, features: torch.Tensor) -> torch.Tensor:
        # 功能：数值稳健化（去 NaN/Inf）后做 L2 归一化
        # 数值稳健化：去除 NaN/Inf 并在归一化中使用 eps
        f = torch.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
        return torch.nn.functional.normalize(f, p=2, dim=1, eps=1e-6)
    def calculate_similarity(self, text_features: torch.Tensor, image_features: torch.Tensor) -> torch.Tensor:
        # 功能：计算文本与图像特征之间的相似度矩阵
        # 当 similarity_type='cosine' 时，先归一化再矩阵乘法；否则直接点积
        if self.similarity_type == 'cosine':
            t_n = self.normalize_features(text_features)
            i_n = self.normalize_features(image_features)
            return torch.mm(t_n, i_n.t())
        return torch.mm(text_features, image_features.t())

class CrossModalRetrievalModel:
    # 类作用：跨模态检索模型，封装提取、投影、相似度与可学习温度
    # 参数:
    # - text_extractor/image_extractor: 文本/图像特征提取器实例
    # - fusion_method/projection_dim/similarity_type: 融合与相似度配置
    # - normalize_features: 是否在融合后做归一化
    # - device: 模型设备
    def __init__(self, text_extractor, image_extractor, fusion_method='projection', projection_dim=512, similarity_type='cosine', normalize_features=True, device=None):
        self.text_extractor = text_extractor
        self.image_extractor = image_extractor
        # 动态获取输入维度，避免环境差异导致维度不符
        text_in_dim = getattr(text_extractor.model.config, 'hidden_size', 768)
        image_in_dim = getattr(image_extractor.model, 'num_features', 2048)
        self.fusion = FeatureFusion(fusion_method, projection_dim, device, text_in_dim=text_in_dim, image_in_dim=image_in_dim)
        self.sim = SimilarityCalculator(similarity_type)
        self.normalize_features = normalize_features
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # 可学习温度：logit_scale，取值经 exp 后再取倒数作为 temperature
        init_temp = 0.07
        self.logit_scale = torch.nn.Parameter(torch.tensor(math.log(1.0 / init_temp), dtype=torch.float32))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # 将特征数值稳健化并做 L2 归一化（可关闭）
        if not self.normalize_features:
            return x
        x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        return torch.nn.functional.normalize(x, p=2, dim=1, eps=1e-6)

    def current_temperature(self) -> torch.Tensor:
        # 将可学习参数 logit_scale 转换为正温度（1/exp），并裁剪范围
        # 将 logit_scale 转换为正的温度，并进行合理范围裁剪
        temp = 1.0 / torch.exp(self.logit_scale)
        return torch.clamp(temp, min=1e-3, max=10.0)

    def extract_and_fuse_text_features(self, texts: List[str]) -> torch.Tensor:
        # 用文本提取器编码文本，随后通过 MLP 投影并归一化
        # 评估阶段需禁用 BN/Dropout 的训练行为
        self.fusion.text_projector.eval()
        with torch.no_grad():
            t = self.text_extractor.encode_with_grad(texts)
        return self._norm(self.fusion.fuse_text_features(t))

    def extract_and_fuse_image_features(self, images: List[Image.Image]) -> torch.Tensor:
        # 用图像提取器编码图像，随后通过 MLP 投影并归一化
        # 评估阶段需禁用 BN/Dropout 的训练行为
        self.fusion.image_projector.eval()
        with torch.no_grad():
            i = self.image_extractor.encode_with_grad(images)
        return self._norm(self.fusion.fuse_image_features(i))

    def build_image_index(self, images_dict: Dict[str, Image.Image], batch_size: int = 32) -> Dict[str, torch.Tensor]:
        # 功能：批量构建图像检索索引（id -> 特征），按 batch 处理
        # 参数：images_dict 图像字典（id -> PIL.Image）；batch_size 批大小
        # 返回：每个图像 id 对应融合后的特征张量（在 CPU 上）
        feats = {}
        keys = list(images_dict.keys())
        for s in range(0, len(keys), batch_size):
            batch_ids = keys[s:s+batch_size]
            batch_imgs = [images_dict[k] for k in batch_ids if images_dict[k] is not None]
            valid_ids = [k for k in batch_ids if images_dict[k] is not None]
            if not batch_imgs:
                continue
            bf = self.extract_and_fuse_image_features(batch_imgs)
            for j, img_id in enumerate(valid_ids):
                feats[img_id] = bf[j].detach().cpu()
        return feats

def info_nce_loss(text_feats: torch.Tensor, image_feats: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # 损失作用：对齐文本-图像，通过双向对比学习的 InfoNCE 损失
    # 参数:
    # - text_feats: 文本投影后的特征 [B, D]
    # - image_feats: 图像投影后的特征 [B, D]
    # - temperature: 温度（正值），用于缩放 logits 稳定训练
    # 原理：计算 t·i^T 的相似度，分别以行/列为正样本做交叉熵，最后取平均
    # 数值防护：去除 NaN/Inf 并在半精度下转 float32 防溢出
    t = torch.nan_to_num(text_feats, nan=0.0, posinf=0.0, neginf=0.0)
    i = torch.nan_to_num(image_feats, nan=0.0, posinf=0.0, neginf=0.0)
    logits = torch.mm(t, i.t()).float() / temperature.float()
    # 限幅，避免极端值导致 softmax 溢出或 NaN
    logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)
    logits = torch.clamp(logits, -100.0, 100.0)
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_t = torch.nn.functional.cross_entropy(logits, labels)
    loss_i = torch.nn.functional.cross_entropy(logits.t(), labels)
    return (loss_t + loss_i) * 0.5


## 顶层解冻与优化器分组
- 解冻 BERT 顶层（最后2层 + pooler）与 ResNet layer4
- 优化器包含：两层MLP投影头参数、已解冻的顶层参数、logit_scale 参数

In [4]:
def unfreeze_text_top_layers(text_extractor: TextFeatureExtractor, last_n_layers: int = 2):
    """
    冻结 BERT 主体，仅解冻最后 N 层与 pooler，以便轻量微调上层表示。
    参数：
    - text_extractor: 文本特征提取器实例
    - last_n_layers: 需要解冻的最后层数（默认2）
    原理：设置 requires_grad=True 并将对应模块置为 train()；其余部分仍保持 eval()。
    """
    """
    功能：冻结 BERT 主体参数，仅解冻最后 N 层与 pooler 以进行轻量微调。
    参数：
    - text_extractor: 文本特征提取器实例
    - last_n_layers: 解冻的最后层数（默认2）
    过程：设置 requires_grad 并将解冻层设为 train()，其余仍 eval()。
    """
    for p in text_extractor.model.parameters():
        p.requires_grad = False
    enc = text_extractor.model.encoder
    total_layers = len(enc.layer)
    for i in range(total_layers - last_n_layers, total_layers):
        for p in enc.layer[i].parameters():
            p.requires_grad = True
        enc.layer[i].train()
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        for p in text_extractor.model.pooler.parameters():
            p.requires_grad = True
        text_extractor.model.pooler.train()
    text_extractor.model.eval()

def unfreeze_image_top_block(image_extractor: ImageFeatureExtractor, unfreeze_layer4: bool = True):
    """
    冻结图像模型主体，仅解冻顶层 block（如 ResNet 的 layer4）以微调。
    参数：
    - image_extractor: 图像特征提取器实例
    - unfreeze_layer4: 是否解冻 layer4（默认 True）
    原理：设置 requires_grad=True 并将对应模块置为 train()；其余部分仍保持 eval()。
    """
    """
    功能：冻结图像模型参数，仅解冻最顶层 block（如 ResNet layer4）。
    参数：
    - image_extractor: 图像特征提取器实例
    - unfreeze_layer4: 是否解冻 layer4（默认 True）
    过程：设置 requires_grad 并将解冻层设为 train()，其余仍 eval()。
    """
    for p in image_extractor.model.parameters():
        p.requires_grad = False
    if unfreeze_layer4 and hasattr(image_extractor.model, 'layer4'):
        for p in image_extractor.model.layer4.parameters():
            p.requires_grad = True
        image_extractor.model.layer4.train()
    image_extractor.model.eval()

def build_optimizer(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                   lr_proj: float = 1e-3, lr_text_top: float = 5e-5, lr_img_top: float = 1e-4, lr_logit_scale: float = 1e-3, weight_decay: float = 1e-4):
    """
    构建分组 Adam 优化器，为不同模块设置差异化学习率与权重衰减。
    参数：
    - model: 跨模态检索模型（含两层 MLP 与可学习温度参数 logit_scale）
    - text_extractor/image_extractor: 提供已解冻顶层参数的提取器
    - lr_proj/lr_text_top/lr_img_top/lr_logit_scale: 各参数组学习率
    - weight_decay: 投影头参数的权重衰减（其他组设为 0）
    返回：torch.optim.Adam 实例
    原理：参数组包括：两层 MLP、文本顶层（最后两层+pooler）、图像 top block、logit_scale。
    """
    """
    功能：按模块分组构建 Adam 优化器，设置不同学习率与权重衰减。
    参数：
    - model: 跨模态检索模型（包含两层MLP与 logit_scale）
    - text_extractor/image_extractor: 提供顶层可训练参数的提取器
    - lr_proj/lr_text_top/lr_img_top/lr_logit_scale: 各参数组的学习率
    - weight_decay: 投影头参数的权重衰减
    返回：torch.optim.Adam 实例
    原理：将投影头、文本顶层、图像顶层、logit_scale 分别作为参数组，以便差异化优化。
    """
    params = []
    # 两层MLP投影头
    params.append({
        'params': list(model.fusion.text_projector.parameters()) + list(model.fusion.image_projector.parameters()),
        'lr': lr_proj,
        'weight_decay': weight_decay
    })
    # 文本顶层
    text_top_params = []
    enc = text_extractor.model.encoder
    for mod in enc.layer[-2:]:
        text_top_params += list(mod.parameters())
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        text_top_params += list(text_extractor.model.pooler.parameters())
    params.append({
        'params': [p for p in text_top_params if p.requires_grad],
        'lr': lr_text_top,
        'weight_decay': 0.0
    })
    # 图像顶层
    img_top_params = []
    if hasattr(image_extractor.model, 'layer4'):
        img_top_params += list(image_extractor.model.layer4.parameters())
    params.append({
        'params': [p for p in img_top_params if p.requires_grad],
        'lr': lr_img_top,
        'weight_decay': 0.0
    })
    # 可学习温度参数
    params.append({
        'params': [model.logit_scale],
        'lr': lr_logit_scale,
        'weight_decay': 0.0
    })
    optimizer = torch.optim.AdamW(params)
    return optimizer


## 数据加载与训练参数

In [5]:
loader = DataLoader()
train_df = loader.load_queries(split='train')
valid_df = loader.load_queries(split='valid')

# 训练与流式参数（默认较小，确保顺利运行；可按需增大）
train_image_batch_size = 500
max_train_batches = 1
epochs_per_batch = 1
train_step_batch_size = 32
valid_imgs_max_samples = 100

# # 训练与流式参数（按需调整）：实际用
# train_image_batch_size = 15000 ## 一个大batch有这么多图片样本。
# max_train_batches = 10 ## 总共加载多少个大batch。
# epochs_per_batch = 10 ## 每个大batch训练几个epoch。
# train_step_batch_size = 32 ## 每个大batch里面训练的时候的小batch_size是多少。
# valid_imgs_max_samples = 30000

use_amp = True


2025-11-08 17:10:07,009 - INFO - 初始化数据加载器，数据目录: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval
2025-11-08 17:10:07,011 - INFO - 加载train查询数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_train_queries.jsonl
加载train查询数据: 248786it [00:01, 222502.29it/s]
2025-11-08 17:10:08,198 - INFO - 成功加载train查询数据，共248786条
2025-11-08 17:10:08,209 - INFO - 加载valid查询数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_valid_queries.jsonl
加载valid查询数据: 5008it [00:00, 302466.30it/s]
2025-11-08 17:10:08,231 - INFO - 成功加载valid查询数据，共5008条


## 初始化模型与优化器，并进行顶层解冻

In [6]:
image_extractor = ImageFeatureExtractor(device=device, cache_dir=cache_dir)

2025-11-08 17:10:08,368 - INFO - Loading pretrained weights from Hugging Face hub (timm/resnet50.a1_in1k)
2025-11-08 17:10:08,949 - INFO - [timm/resnet50.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


In [7]:
text_extractor = TextFeatureExtractor(device=device, cache_dir=cache_dir)

In [8]:
model = CrossModalRetrievalModel(
    text_extractor, image_extractor, 
    fusion_method='projection', projection_dim=512, similarity_type='cosine', normalize_features=True, device=device
)

unfreeze_text_top_layers(text_extractor, last_n_layers=2)
unfreeze_image_top_block(image_extractor, unfreeze_layer4=True)

optim = build_optimizer(model, text_extractor, image_extractor,
                        lr_proj=1e-3, lr_text_top=5e-5, lr_img_top=1e-4, lr_logit_scale=1e-3, weight_decay=1e-4)
scaler = GradScaler(enabled=(device.type == 'cuda' and use_amp))
print('Optim groups:', len(optim.param_groups))


Optim groups: 4


  scaler = GradScaler(enabled=(device.type == 'cuda' and use_amp))


## 训练循环（流式）
- 构建 (query, image) 配对
- AMP 加速与梯度裁剪
- 使用可学习温度 model.current_temperature()

In [9]:
def build_batch_pairs(train_df, img_dict: Dict[str, Image.Image]) -> List[Tuple[str, Image.Image, str]]:
    """
    功能：从训练 DataFrame 按 query_text 和 item_ids 选择第一张可用图片，形成 (文本, 图像, 图像id) 三元组列表。
    参数：
    - train_df: 包含列 'query_text' 与 'item_ids' 的 DataFrame，其中 'item_ids' 为候选图片 id 列表
    - img_dict: 图片字典（id -> PIL.Image），不可用或缺失的 id 映射为 None
    返回：List[Tuple[str, Image.Image, str]]
    原理：遍历每行，从 item_ids 中挑第一张在 img_dict 中存在的图片，作为该 query 的正样本。
    """
    pairs = []
    if 'item_ids' in train_df.columns:
        for _, row in train_df.iterrows():
            q = row.get('query_text', None)
            ids = row.get('item_ids', [])
            if not q or not isinstance(ids, list) or not ids:
                continue
            chosen_img = None
            chosen_id = None
            for iid in ids:
                sid = str(iid)
                if sid in img_dict and img_dict[sid] is not None:
                    chosen_img = img_dict[sid]
                    chosen_id = sid
                    break
            if chosen_img is not None:
                pairs.append((q, chosen_img, chosen_id))
    return pairs

def train_one_batch(pairs: List[Tuple[str, Image.Image, str]], epochs: int, step_bs: int):
    """
    功能：对一个 (文本, 图像) 对列表进行多轮小批训练。
    参数：
    - pairs: 训练三元组列表 (text, image, image_id)
    - epochs: 在该批数据上迭代的轮数
    - step_bs: 每个优化步的小批大小
    过程：
    - 训练模式开启指定的投影头与顶层模块
    - 每个小批：编码文本与图像 -> 通过两层 MLP 投影 -> 计算 InfoNCE 损失
    - AMP 分支使用 GradScaler 缩放与步进；非 AMP 分支直接 backward + step
    - 对投影头参数进行梯度裁剪（max_norm=5.0）以提升稳定性
    """
    model.fusion.text_projector.train()
    model.fusion.image_projector.train()
    text_extractor.model.encoder.layer[-1].train()
    text_extractor.model.encoder.layer[-2].train()
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        text_extractor.model.pooler.train()
    if hasattr(image_extractor.model, 'layer4'):
        image_extractor.model.layer4.train()

    for e in range(epochs):
        running_loss = 0.0
        steps = 0
        for s in range(0, len(pairs), step_bs):
            batch = pairs[s:s+step_bs]
            if not batch:
                continue
            texts = [t for (t, _, _) in batch]
            imgs = [im for (_, im, _) in batch]

            optim.zero_grad()
            temp = model.current_temperature()
            if use_amp and device.type == 'cuda':
                with autocast(enabled=True):
                    t_feats = text_extractor.encode_with_grad(texts)
                    i_feats = image_extractor.encode_with_grad(imgs)
                    t_proj = model._norm(model.fusion.fuse_text_features(t_feats))
                    i_proj = model._norm(model.fusion.fuse_image_features(i_feats))
                    loss = info_nce_loss(t_proj, i_proj, temperature=temp)
                scaler.scale(loss).backward()
                scaler.unscale_(optim)
                torch.nn.utils.clip_grad_norm_(
                    list(model.fusion.text_projector.parameters()) + list(model.fusion.image_projector.parameters()),
                    max_norm=5.0
                )
                scaler.step(optim)
                scaler.update()
            else:
                t_feats = text_extractor.encode_with_grad(texts)
                i_feats = image_extractor.encode_with_grad(imgs)
                t_proj = model._norm(model.fusion.fuse_text_features(t_feats))
                i_proj = model._norm(model.fusion.fuse_image_features(i_feats))
                loss = info_nce_loss(t_proj, i_proj, temperature=temp)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    list(model.fusion.text_projector.parameters()) + list(model.fusion.image_projector.parameters()),
                    max_norm=5.0
                )
                optim.step()
            running_loss += loss.item()
            steps += 1
        print(f"Epoch {e+1}/{epochs}: avg loss={running_loss/max(steps,1):.4f}")

# 流式加载与训练
batch_idx = 0
for image_batch in loader.load_images_batch(split='train', batch_size=train_image_batch_size, max_batches=max_train_batches):
    batch_idx += 1
    img_map = {item['img_id']: item['image'] for item in image_batch}
    pairs = build_batch_pairs(train_df, img_map)
    print(f"Batch {batch_idx}: images={len(img_map)}, usable_pairs={len(pairs)}")
    if not pairs:
        del img_map
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        continue
    train_one_batch(pairs, epochs=epochs_per_batch, step_bs=train_step_batch_size)
    del img_map
    if device.type == 'cuda':
        torch.cuda.empty_cache()


2025-11-08 17:10:11,546 - INFO - 批量加载train图片数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_train_imgs.tsv
实际加载train图片数据:   0%|                                                     | 260/129380 [00:00<00:50, 2579.77it/s]

Batch 1: images=500, usable_pairs=997


  with autocast(enabled=True):
实际加载train图片数据:   0%|▏                                                      | 499/129380 [00:06<29:13, 73.51it/s]

Epoch 1/1: avg loss=2.0210





## 保存/加载检查点
保存两层MLP投影头、解冻顶层与优化器，并记录 logit_scale（可学习温度）。

In [10]:
save_dir = '/mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights'
save_path = os.path.join(save_dir, 'step_2_3_4_mlp_temp_checkpoint.pth')

def save_unfreeze_checkpoint(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                             optimizer: torch.optim.Optimizer, save_path: str, last_n_layers: int):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    ckpt = {
        'projection_dim': model.fusion.projection_dim,
        'last_n_layers': last_n_layers,
        'fusion': {
            'text_projector': model.fusion.text_projector.state_dict(),
            'image_projector': model.fusion.image_projector.state_dict(),
        },
        'logit_scale': model.logit_scale.detach().cpu(),
        'text_unfrozen': {},
        'image_unfrozen': {},
        'optimizer': optimizer.state_dict(),
    }
    enc = text_extractor.model.encoder
    total_layers = len(enc.layer)
    start_idx = max(0, total_layers - last_n_layers)
    for i in range(start_idx, total_layers):
        ckpt['text_unfrozen'][f'encoder_layer_{i}'] = enc.layer[i].state_dict()
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        ckpt['text_unfrozen']['pooler'] = text_extractor.model.pooler.state_dict()
    if hasattr(image_extractor.model, 'layer4'):
        ckpt['image_unfrozen']['layer4'] = image_extractor.model.layer4.state_dict()
    torch.save(ckpt, save_path)
    print(f"Checkpoint saved to: {save_path}")

def load_unfreeze_checkpoint(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                             optimizer: torch.optim.Optimizer, load_path: str):
    ckpt = torch.load(load_path, map_location='cpu')
    model.fusion.text_projector.load_state_dict(ckpt['fusion']['text_projector'])
    model.fusion.image_projector.load_state_dict(ckpt['fusion']['image_projector'])
    if 'logit_scale' in ckpt:
        model.logit_scale.data = ckpt['logit_scale'].to(model.logit_scale.device)
    ln = ckpt.get('last_n_layers', 2)
    unfreeze_text_top_layers(text_extractor, last_n_layers=ln)
    unfreeze_image_top_block(image_extractor, unfreeze_layer4=True)
    enc = text_extractor.model.encoder
    for k, v in ckpt['text_unfrozen'].items():
        if k.startswith('encoder_layer_'):
            idx = int(k.split('_')[-1])
            if 0 <= idx < len(enc.layer):
                enc.layer[idx].load_state_dict(v)
    if 'pooler' in ckpt['text_unfrozen'] and hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        text_extractor.model.pooler.load_state_dict(ckpt['text_unfrozen']['pooler'])
    if 'layer4' in ckpt['image_unfrozen'] and hasattr(image_extractor.model, 'layer4'):
        image_extractor.model.layer4.load_state_dict(ckpt['image_unfrozen']['layer4'])
    if optimizer is not None and 'optimizer' in ckpt:
        optimizer.load_state_dict(ckpt['optimizer'])
    print(f"Checkpoint loaded from: {load_path}")

# 保存一次并测试加载
save_unfreeze_checkpoint(model, text_extractor, image_extractor, optim, save_path, 2)
# load_unfreeze_checkpoint(model, text_extractor, image_extractor, optim, save_path)


Checkpoint saved to: /mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights/step_2_3_4_mlp_temp_checkpoint.pth


## 验证评估：Recall@1/5/10 与 MeanRecall
优先使用 FAISS；不可用则回退到 Torch 相似度计算。

In [11]:
valid_imgs = loader.create_img_id_to_image_dict(
    split='valid', 
    max_samples=valid_imgs_max_samples, 
    batch_size=3000, max_batches=10
)

valid_queries = []
if 'item_ids' in valid_df.columns:
    for _, row in valid_df.iterrows():
        q = row.get('query_text', None)
        ids = [str(i) for i in row.get('item_ids', [])] if isinstance(row.get('item_ids', []), list) else []
        if q and ids:
            valid_queries.append((q, ids))
print(f'Usable valid queries: {len(valid_queries)}')

image_index = model.build_image_index(valid_imgs, batch_size=32)
all_image_ids = list(image_index.keys())
all_image_feats = torch.stack([image_index[i] for i in all_image_ids]) if all_image_ids else torch.empty((0, 512))
faiss_index = None
if HAS_FAISS and all_image_feats.size(0) > 0:
    d = all_image_feats.size(1)
    faiss_index = faiss.IndexFlatIP(d)
    feats_np = all_image_feats.detach().cpu().numpy().astype('float32')
    faiss_index.add(feats_np)

all_image_feats = all_image_feats.to(device)

def compute_recall_at_k(k_values, queries):
    recalls = {k: 0 for k in k_values}
    total = 0
    for q_text, gt_ids in tqdm(queries, desc='Evaluate'):
        if all_image_feats.size(0) == 0:
            continue
        q_feat = model.extract_and_fuse_text_features([q_text])
        if faiss_index is not None:
            q_np = q_feat.detach().cpu().numpy().astype('float32')
            _, I = faiss_index.search(q_np, max(k_values))
            top_idx = I[0].tolist()
            top_ids = [all_image_ids[i] for i in top_idx]
        else:
            sims = model.sim.calculate_similarity(q_feat, all_image_feats)
            _, top_idx = torch.topk(sims[0], k=max(k_values))
            top_ids = [all_image_ids[i] for i in top_idx.tolist()]
        total += 1
        for k in k_values:
            if any(g in set(top_ids[:k]) for g in gt_ids):
                recalls[k] += 1
    return {k: (recalls[k] / total if total > 0 else 0.0) for k in k_values}, total

rec, total_q = compute_recall_at_k([1,5,10], valid_queries)
mean_recall = (rec.get(1,0)+rec.get(5,0)+rec.get(10,0))/3 if total_q>0 else 0.0
print(f'Recall@1={rec.get(1,0):.4f}, Recall@5={rec.get(5,0):.4f}, Recall@10={rec.get(10,0):.4f}, MeanRecall={mean_recall:.4f} (N={total_q})')


2025-11-08 17:10:38,681 - INFO - 批量加载valid图片数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_valid_imgs.tsv
实际加载valid图片数据:  99%|████████████████████████████████████████████████████████▍| 99/100 [00:00<00:00, 2207.32it/s]
2025-11-08 17:10:43,482 - INFO - 成功创建valid图片映射字典，共100张图片


Usable valid queries: 5008


Evaluate: 100%|████████████████████████████████████████████████████████████████████| 5008/5008 [00:36<00:00, 137.87it/s]

Recall@1=0.0020, Recall@5=0.0074, Recall@10=0.0104, MeanRecall=0.0066 (N=5008)



