# 步骤 2.3-4：两层MLP投影头 + 可学习温度 + AMP/FAISS (蒙)

依据 `2.3的改进方案.md` 的第三优先：
- 将投影头升级为两层 MLP（Linear → GELU → Dropout → Linear）
- 引入可学习温度（logit_scale），替代固定常数 temperature
- 保留 AMP 训练加速与 FAISS 检索回退逻辑，保证可运行且稳定

在 `step_2_3-3_屯-mean_pooling.ipynb` 基础上实现，文本特征使用 attention_mask 的 mean-pooling。

In [1]:
import os
import sys
import math
import json
import numpy as np
from typing import List, Dict, Tuple
import torch
import timm
from torchvision import transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

# AMP
from torch.cuda.amp import autocast, GradScaler

# 安全导入 FAISS（不可用则回退，并打印清晰日志）
HAS_FAISS = False
faiss = None
try:
    import importlib.util
    spec = importlib.util.find_spec('faiss')
    if spec is not None:
        faiss = __import__('faiss')
        HAS_FAISS = True
        print('FAISS available: using accelerated index.')
    else:
        print('FAISS not found: falling back to torch similarity.')
except Exception as e:
    print(f'FAISS import failed: {e.__class__.__name__}. Fallback to torch.')
    HAS_FAISS = False

# 环境与缓存
# os.environ["WANDB_DISABLED"] = "true"
# os.environ['CURL_CA_BUNDLE'] = ""
# os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cache_dir = "/mnt/d/HuggingFaceModels/"

os.environ['TORCH_HOME'] = cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ['https_proxy'] = 'http://127.0.0.1:7890'
os.environ['http_proxy'] = 'http://127.0.0.1:7890'
os.environ['all_proxy'] = 'socks5://127.0.0.1:7890'
os.environ["WANDB_DISABLED"] = "true"
os.environ['CURL_CA_BUNDLE'] = ""
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

# 导入数据加载器
sys.path.append(os.path.abspath(os.path.join('.', 'Multimodal_Retrieval', 'plan1_1')))
from data_loader import DataLoader

print(f'Using device: {device}')

  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


FAISS available: using accelerated index.
Using device: cuda


In [2]:
## !pip install numpy==1.25.2
## 只有用1.25.2的numpy才能用得了faiss。
## 如果出了什么问题，我们就换回2.3.4版本的numpy。

## 模型模块
- 文本特征：attention_mask mean-pooling
- 投影头：两层 MLP（含 GELU + Dropout）
- 可学习温度：logit_scale 参数，训练中联合优化

In [3]:
class TextFeatureExtractor:
    def __init__(self, model_name='bert-base-chinese', device='cpu', cache_dir=None):
        self.device = device
        # 优先本地加载，失败则远程镜像加载
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=True)
            self.model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=True).to(device)
        except Exception:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=False)
            self.model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=False).to(device)
        self.model.eval()

    def encode_with_grad(self, texts: List[str]) -> torch.Tensor:
        if not texts:
            return torch.empty((0, 768), dtype=torch.float32, device=self.device)
        inputs = self.tokenizer(
            texts, padding=True, truncation=True, max_length=32, return_tensors='pt'
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        token_embeddings = outputs.last_hidden_state  # [B, L, 768]
        attention_mask = inputs['attention_mask']     # [B, L]
        # mean-pooling with attention mask
        mask = attention_mask.unsqueeze(-1).type_as(token_embeddings)  # [B, L, 1]
        summed = (token_embeddings * mask).sum(dim=1)  # [B, 768]
        lengths = mask.sum(dim=1).clamp(min=1)        # [B, 1]
        mean_pooled = summed / lengths
        return mean_pooled

# from safetensors.torch import load_file
# class ImageFeatureExtractor:
#     '''
#     改进版，使得timm不要每次都去连huggingface。
#     '''
#     def __init__(self, model_name='resnet50', device='cpu', weights_path=None, cache_dir=None):
#         self.device = device
#         self.model = timm.create_model(model_name, pretrained=False, num_classes=0, cache_dir=cache_dir)

#         if weights_path is not None:
#             if weights_path.endswith('.safetensors'):
#                 state_dict = load_file(weights_path)
#             else:
#                 state_dict = torch.load(weights_path, map_location='cpu')
#             self.model.load_state_dict(state_dict, strict=False)

#         self.model = self.model.to(device)
#         self.model.eval()

#         self.transform = transforms.Compose([
#             transforms.Resize((224, 224)),
#             transforms.ToTensor(),
#             transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                  std=[0.229, 0.224, 0.225])
#         ])

#     def encode_with_grad(self, images: List[Image.Image]) -> torch.Tensor:
#         if not images:
#             in_dim = getattr(self.model, 'num_features', 2048)
#             return torch.empty((0, in_dim), dtype=torch.float32, device=self.device)
#         tensors = torch.stack([self.transform(img.convert('RGB')) for img in images]).to(self.device)
#         feats = self.model(tensors)
#         return feats
# image_extractor = ImageFeatureExtractor(
#     device=device, 
#     weights_path="/mnt/d/HuggingFaceModels/models--timm--resnet50.a1_in1k/snapshots/767268603ca0cb0bfe326fa87277f19c419566ef/model.safetensors"
# )

class ImageFeatureExtractor:
    def __init__(self, model_name='resnet50', device='cpu', cache_dir=None):
        self.device = device
        self.model = timm.create_model(
            model_name, pretrained=True, num_classes=0,
            cache_dir=cache_dir
        ).to(device)
        self.model.eval()
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
    def encode_with_grad(self, images: List[Image.Image]) -> torch.Tensor:
        if not images:
            return torch.empty((0, 2048), dtype=torch.float32, device=self.device)
        tensors = torch.stack([self.transform(img.convert('RGB')) for img in images]).to(self.device)
        feats = self.model(tensors)
        return feats

class FeatureFusion:
    def __init__(self, fusion_method='projection', projection_dim=512, device=None, hidden_dim=1024, dropout=0.1, text_in_dim=768, image_in_dim=2048):
        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.hidden_dim = hidden_dim
        self.dropout_p = dropout
        if fusion_method == 'projection':
            self.text_projector = torch.nn.Sequential(
                torch.nn.Linear(text_in_dim, hidden_dim),
                torch.nn.GELU(),
                torch.nn.Dropout(p=dropout),
                torch.nn.Linear(hidden_dim, projection_dim)
            ).to(self.device)
            self.image_projector = torch.nn.Sequential(
                torch.nn.Linear(image_in_dim, hidden_dim),
                torch.nn.GELU(),
                torch.nn.Dropout(p=dropout),
                torch.nn.Linear(hidden_dim, projection_dim)
            ).to(self.device)

    def fuse_text_features(self, text_features: torch.Tensor) -> torch.Tensor:
        return self.text_projector(text_features) if self.fusion_method == 'projection' else text_features

    def fuse_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        return self.image_projector(image_features) if self.fusion_method == 'projection' else image_features

class SimilarityCalculator:
    def __init__(self, similarity_type='cosine'):
        self.similarity_type = similarity_type
    def normalize_features(self, features: torch.Tensor) -> torch.Tensor:
        # 数值稳健化：去除 NaN/Inf 并在归一化中使用 eps
        f = torch.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
        return torch.nn.functional.normalize(f, p=2, dim=1, eps=1e-6)
    def calculate_similarity(self, text_features: torch.Tensor, image_features: torch.Tensor) -> torch.Tensor:
        if self.similarity_type == 'cosine':
            t_n = self.normalize_features(text_features)
            i_n = self.normalize_features(image_features)
            return torch.mm(t_n, i_n.t())
        return torch.mm(text_features, image_features.t())

class CrossModalRetrievalModel:
    def __init__(self, text_extractor, image_extractor, fusion_method='projection', projection_dim=512, similarity_type='cosine', normalize_features=True, device=None):
        self.text_extractor = text_extractor
        self.image_extractor = image_extractor
        # 动态获取输入维度，避免环境差异导致维度不符
        text_in_dim = getattr(text_extractor.model.config, 'hidden_size', 768)
        image_in_dim = getattr(image_extractor.model, 'num_features', 2048)
        self.fusion = FeatureFusion(fusion_method, projection_dim, device, text_in_dim=text_in_dim, image_in_dim=image_in_dim)
        self.sim = SimilarityCalculator(similarity_type)
        self.normalize_features = normalize_features
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # 可学习温度：logit_scale，取值经 exp 后再取倒数作为 temperature
        init_temp = 0.07
        self.logit_scale = torch.nn.Parameter(torch.tensor(math.log(1.0 / init_temp), dtype=torch.float32))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        if not self.normalize_features:
            return x
        x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        return torch.nn.functional.normalize(x, p=2, dim=1, eps=1e-6)

    def current_temperature(self) -> torch.Tensor:
        # 将 logit_scale 转换为正的温度，并进行合理范围裁剪
        temp = 1.0 / torch.exp(self.logit_scale)
        return torch.clamp(temp, min=1e-3, max=10.0)

    def extract_and_fuse_text_features(self, texts: List[str]) -> torch.Tensor:
        with torch.no_grad():
            t = self.text_extractor.encode_with_grad(texts)
        return self._norm(self.fusion.fuse_text_features(t))

    def extract_and_fuse_image_features(self, images: List[Image.Image]) -> torch.Tensor:
        with torch.no_grad():
            i = self.image_extractor.encode_with_grad(images)
        return self._norm(self.fusion.fuse_image_features(i))

    def build_image_index(self, images_dict: Dict[str, Image.Image], batch_size: int = 32) -> Dict[str, torch.Tensor]:
        feats = {}
        keys = list(images_dict.keys())
        for s in range(0, len(keys), batch_size):
            batch_ids = keys[s:s+batch_size]
            batch_imgs = [images_dict[k] for k in batch_ids if images_dict[k] is not None]
            valid_ids = [k for k in batch_ids if images_dict[k] is not None]
            if not batch_imgs:
                continue
            bf = self.extract_and_fuse_image_features(batch_imgs)
            for j, img_id in enumerate(valid_ids):
                feats[img_id] = bf[j].detach().cpu()
        return feats

def info_nce_loss(text_feats: torch.Tensor, image_feats: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # 数值防护：去除 NaN/Inf 并在半精度下转 float32 防溢出
    t = torch.nan_to_num(text_feats, nan=0.0, posinf=0.0, neginf=0.0)
    i = torch.nan_to_num(image_feats, nan=0.0, posinf=0.0, neginf=0.0)
    logits = torch.mm(t, i.t()).float() / temperature.float()
    # 限幅，避免极端值导致 softmax 溢出或 NaN
    logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)
    logits = torch.clamp(logits, -100.0, 100.0)
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_t = torch.nn.functional.cross_entropy(logits, labels)
    loss_i = torch.nn.functional.cross_entropy(logits.t(), labels)
    return (loss_t + loss_i) * 0.5

class MemoryBank:
    def __init__(self, dim: int, size: int = 65536, device: str = 'cpu'):
        self.size = size
        self.ptr = 0
        self.device = device
        self.bank = torch.zeros((size, dim), dtype=torch.float32, device=device)
    def add(self, feats: torch.Tensor):
        if feats is None or feats.numel() == 0:
            return
        b = feats.shape[0]
        feats = torch.nn.functional.normalize(torch.nan_to_num(feats, nan=0.0, posinf=0.0, neginf=0.0), p=2, dim=1, eps=1e-6)
        if b >= self.size:
            self.bank = feats[-self.size:].detach().to(self.device)
            self.ptr = 0
            return
        end = (self.ptr + b) % self.size
        if self.ptr + b <= self.size:
            self.bank[self.ptr:self.ptr+b] = feats.detach()
        else:
            part1 = self.size - self.ptr
            self.bank[self.ptr:] = feats[:part1].detach()
            self.bank[:end] = feats[part1:].detach()
        self.ptr = end
    def get(self, k: int) -> torch.Tensor:
        if self.bank.numel() == 0:
            return torch.empty((0, 0), device=self.device)
        valid = min(self.size, k)
        return self.bank[:valid]

def info_nce_with_memory(text_feats: torch.Tensor, image_feats: torch.Tensor, temperature: torch.Tensor, img_memory: MemoryBank = None, txt_memory: MemoryBank = None, neg_k: int = 2048) -> torch.Tensor:
    t = torch.nn.functional.normalize(torch.nan_to_num(text_feats, nan=0.0, posinf=0.0, neginf=0.0), p=2, dim=1, eps=1e-6)
    i = torch.nn.functional.normalize(torch.nan_to_num(image_feats, nan=0.0, posinf=0.0, neginf=0.0), p=2, dim=1, eps=1e-6)
    logits = torch.mm(t, i.t()).float() / temperature.float()
    logits = torch.clamp(torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4), -100.0, 100.0)
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_main = 0.5 * (torch.nn.functional.cross_entropy(logits, labels) + torch.nn.functional.cross_entropy(logits.t(), labels))
    loss_mem = 0.0
    if img_memory is not None:
        i_mem = img_memory.get(neg_k)
        if i_mem.numel() > 0:
            i_mem = torch.nn.functional.normalize(i_mem, p=2, dim=1, eps=1e-6)
            logits_t_mem = torch.mm(t, i_mem.t()).float() / temperature.float()
            logits_t_mem = torch.clamp(torch.nan_to_num(logits_t_mem, nan=0.0, posinf=1e4, neginf=-1e4), -100.0, 100.0)
            labels_t = torch.arange(logits_t_mem.size(0), device=logits_t_mem.device)
            loss_mem += torch.nn.functional.cross_entropy(logits_t_mem, labels_t)
    if txt_memory is not None:
        t_mem = txt_memory.get(neg_k)
        if t_mem.numel() > 0:
            t_mem = torch.nn.functional.normalize(t_mem, p=2, dim=1, eps=1e-6)
            logits_i_mem = torch.mm(i, t_mem.t()).float() / temperature.float()
            logits_i_mem = torch.clamp(torch.nan_to_num(logits_i_mem, nan=0.0, posinf=1e4, neginf=-1e4), -100.0, 100.0)
            labels_i = torch.arange(logits_i_mem.size(0), device=logits_i_mem.device)
            loss_mem += torch.nn.functional.cross_entropy(logits_i_mem, labels_i)
    return loss_main + 0.5 * loss_mem


## 顶层解冻与优化器分组
- 解冻 BERT 顶层（最后2层 + pooler）与 ResNet layer4
- 优化器包含：两层MLP投影头参数、已解冻的顶层参数、logit_scale 参数

In [4]:
def unfreeze_text_top_layers(text_extractor: TextFeatureExtractor, last_n_layers: int = 2):
    for p in text_extractor.model.parameters():
        p.requires_grad = False
    enc = text_extractor.model.encoder
    total_layers = len(enc.layer)
    for i in range(total_layers - last_n_layers, total_layers):
        for p in enc.layer[i].parameters():
            p.requires_grad = True
        enc.layer[i].train()
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        for p in text_extractor.model.pooler.parameters():
            p.requires_grad = True
        text_extractor.model.pooler.train()
    text_extractor.model.eval()

def unfreeze_image_top_block(image_extractor: ImageFeatureExtractor, unfreeze_layer4: bool = True):
    for p in image_extractor.model.parameters():
        p.requires_grad = False
    if unfreeze_layer4 and hasattr(image_extractor.model, 'layer4'):
        for p in image_extractor.model.layer4.parameters():
            p.requires_grad = True
        image_extractor.model.layer4.train()
    image_extractor.model.eval()

def build_optimizer(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                   lr_proj: float = 1e-3, lr_text_top: float = 5e-5, lr_img_top: float = 1e-4, lr_logit_scale: float = 1e-3, weight_decay: float = 1e-4):
    params = []
    # 两层MLP投影头
    params.append({
        'params': list(model.fusion.text_projector.parameters()) + list(model.fusion.image_projector.parameters()),
        'lr': lr_proj,
        'weight_decay': weight_decay
    })
    # 文本顶层
    text_top_params = []
    enc = text_extractor.model.encoder
    for mod in enc.layer[-2:]:
        text_top_params += list(mod.parameters())
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        text_top_params += list(text_extractor.model.pooler.parameters())
    params.append({
        'params': [p for p in text_top_params if p.requires_grad],
        'lr': lr_text_top,
        'weight_decay': 0.0
    })
    # 图像顶层
    img_top_params = []
    if hasattr(image_extractor.model, 'layer4'):
        img_top_params += list(image_extractor.model.layer4.parameters())
    params.append({
        'params': [p for p in img_top_params if p.requires_grad],
        'lr': lr_img_top,
        'weight_decay': 0.0
    })
    # 可学习温度参数
    params.append({
        'params': [model.logit_scale],
        'lr': lr_logit_scale,
        'weight_decay': 0.0
    })
    optimizer = torch.optim.Adam(params)
    return optimizer


## 数据加载与训练参数

In [5]:
loader = DataLoader()
train_df = loader.load_queries(split='train')
valid_df = loader.load_queries(split='valid')

# 训练与流式参数（默认较小，确保顺利运行；可按需增大）
train_image_batch_size = 500
max_train_batches = 1
epochs_per_batch = 1
train_step_batch_size = 32
valid_imgs_max_samples = 100

# # 训练与流式参数（按需调整）：实际用
# train_image_batch_size = 15000 ## 一个大batch有这么多图片样本。
# max_train_batches = 10 ## 总共加载多少个大batch。
# epochs_per_batch = 10 ## 每个大batch训练几个epoch。
# train_step_batch_size = 32 ## 每个大batch里面训练的时候的小batch_size是多少。
# valid_imgs_max_samples = 30000

use_amp = True


2025-11-08 14:34:51,053 - INFO - 初始化数据加载器，数据目录: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval
2025-11-08 14:34:51,054 - INFO - 加载train查询数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_train_queries.jsonl
加载train查询数据: 248786it [00:01, 244943.29it/s]
2025-11-08 14:34:52,142 - INFO - 成功加载train查询数据，共248786条
2025-11-08 14:34:52,155 - INFO - 加载valid查询数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_valid_queries.jsonl
加载valid查询数据: 5008it [00:00, 261637.89it/s]
2025-11-08 14:34:52,178 - INFO - 成功加载valid查询数据，共5008条


## 初始化模型与优化器，并进行顶层解冻

In [6]:
image_extractor = ImageFeatureExtractor(device=device, cache_dir=cache_dir)

2025-11-08 14:34:52,345 - INFO - Loading pretrained weights from Hugging Face hub (timm/resnet50.a1_in1k)
2025-11-08 14:34:53,172 - INFO - [timm/resnet50.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


In [7]:
text_extractor = TextFeatureExtractor(device=device, cache_dir=cache_dir)

In [8]:
model = CrossModalRetrievalModel(
    text_extractor, image_extractor, 
    fusion_method='projection', projection_dim=512, similarity_type='cosine', normalize_features=True, device=device
)

unfreeze_text_top_layers(text_extractor, last_n_layers=2)
unfreeze_image_top_block(image_extractor, unfreeze_layer4=True)

optim = build_optimizer(model, text_extractor, image_extractor,
                        lr_proj=1e-3, lr_text_top=5e-5, lr_img_top=1e-4, lr_logit_scale=1e-3, weight_decay=1e-4)
scaler = GradScaler(enabled=(device.type == 'cuda' and use_amp))
print('Optim groups:', len(optim.param_groups))


Optim groups: 4


  scaler = GradScaler(enabled=(device.type == 'cuda' and use_amp))


## 训练循环（流式）
- 构建 (query, image) 配对
- AMP 加速与梯度裁剪
- 使用可学习温度 model.current_temperature()

In [9]:
# 记忆队列与梯度累积参数（仅当前笔记本修改，不影响其他文件）
accum_steps = 1
use_memory_loss = False
negatives_k = 4096
memory_size = 65536
proj_dim = model.fusion.projection_dim
img_memory = MemoryBank(dim=proj_dim, size=memory_size, device=device) if use_memory_loss else None
txt_memory = MemoryBank(dim=proj_dim, size=memory_size, device=device) if use_memory_loss else None
print(f'Accumulation steps: {accum_steps}, use_memory_loss: {use_memory_loss}, negatives_k: {negatives_k}, memory_size: {memory_size}, proj_dim: {proj_dim}')

Accumulation steps: 1, use_memory_loss: False, negatives_k: 4096, memory_size: 65536, proj_dim: 512


In [10]:
def build_batch_pairs(train_df, img_dict: Dict[str, Image.Image]) -> List[Tuple[str, Image.Image, str]]:
    pairs = []
    if 'item_ids' in train_df.columns:
        for _, row in train_df.iterrows():
            q = row.get('query_text', None)
            ids = row.get('item_ids', [])
            if not q or not isinstance(ids, list) or not ids:
                continue
            chosen_img = None
            chosen_id = None
            for iid in ids:
                sid = str(iid)
                if sid in img_dict and img_dict[sid] is not None:
                    chosen_img = img_dict[sid]
                    chosen_id = sid
                    break
            if chosen_img is not None:
                pairs.append((q, chosen_img, chosen_id))
    return pairs

def train_one_batch(pairs: List[Tuple[str, Image.Image, str]], epochs: int, step_bs: int):
    model.fusion.text_projector.train()
    model.fusion.image_projector.train()
    text_extractor.model.encoder.layer[-1].train()
    text_extractor.model.encoder.layer[-2].train()
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        text_extractor.model.pooler.train()
    if hasattr(image_extractor.model, 'layer4'):
        image_extractor.model.layer4.train()

    # 本地容错参数，防止未执行初始化单元时出错
    accum_steps_local = globals().get('accum_steps', 1)
    negatives_k_local = globals().get('negatives_k', 0)
    img_memory_local = globals().get('img_memory', None)
    txt_memory_local = globals().get('txt_memory', None)

    for e in range(epochs):
        running_loss = 0.0
        steps = 0
        for s in range(0, len(pairs), step_bs):
            batch = pairs[s:s+step_bs]
            if not batch:
                continue
            texts = [t for (t, _, _) in batch]
            imgs = [im for (_, im, _) in batch]

            # 梯度累积：仅在每 accum_steps_local 次才清零
            if steps % accum_steps_local == 0:
                optim.zero_grad()
            temp = model.current_temperature()
            if use_amp and device.type == 'cuda':
                with autocast(enabled=True):
                    t_feats = text_extractor.encode_with_grad(texts)
                    i_feats = image_extractor.encode_with_grad(imgs)
                    t_proj = model._norm(model.fusion.fuse_text_features(t_feats))
                    i_proj = model._norm(model.fusion.fuse_image_features(i_feats))
                    # 先计算未缩放的原始损失，再用于统计
                    if globals().get('use_memory_loss', False) and 'info_nce_with_memory' in globals():
                        loss_raw = info_nce_with_memory(t_proj, i_proj, temperature=temp, img_memory=img_memory_local, txt_memory=txt_memory_local, neg_k=negatives_k_local)
                    else:
                        loss_raw = info_nce_loss(t_proj, i_proj, temperature=temp)
                    # 累积梯度时按步数缩放反传损失，避免有效学习率膨胀
                    loss = loss_raw / max(1, accum_steps_local)
                scaler.scale(loss).backward()
                # 梯度累积：仅在每 accum_steps_local 次执行优化器步进
                if (steps + 1) % accum_steps_local == 0:
                    scaler.unscale_(optim)
                    torch.nn.utils.clip_grad_norm_(
                        list(model.fusion.text_projector.parameters()) + list(model.fusion.image_projector.parameters()),
                        max_norm=5.0
                    )
                    scaler.step(optim)
                    scaler.update()
            else:
                t_feats = text_extractor.encode_with_grad(texts)
                i_feats = image_extractor.encode_with_grad(imgs)
                t_proj = model._norm(model.fusion.fuse_text_features(t_feats))
                i_proj = model._norm(model.fusion.fuse_image_features(i_feats))
                # 先计算未缩放的原始损失，再用于统计
                if globals().get('use_memory_loss', False) and 'info_nce_with_memory' in globals():
                    loss_raw = info_nce_with_memory(t_proj, i_proj, temperature=temp, img_memory=img_memory_local, txt_memory=txt_memory_local, neg_k=negatives_k_local)
                else:
                    loss_raw = info_nce_loss(t_proj, i_proj, temperature=temp)
                # 累积梯度时按步数缩放反传损失，避免有效学习率膨胀
                loss = loss_raw / max(1, accum_steps_local)
                loss.backward()
                # 梯度累积：仅在每 accum_steps_local 次执行优化器步进
                if (steps + 1) % accum_steps_local == 0:
                    torch.nn.utils.clip_grad_norm_(
                        list(model.fusion.text_projector.parameters()) + list(model.fusion.image_projector.parameters()),
                        max_norm=5.0
                    )
                    optim.step()
            # 更新记忆队列（使用当前 batch 的投影特征），仅在启用记忆损失时
            if globals().get('use_memory_loss', False):
                try:
                    if img_memory_local is not None:
                        img_memory_local.add(i_proj.detach())
                    if txt_memory_local is not None:
                        txt_memory_local.add(t_proj.detach())
                except Exception as _:
                    pass
            # 统计使用未缩放的原始损失，便于观察下降趋势
            running_loss += loss_raw.item()
            steps += 1
        print(f"Epoch {e+1}/{epochs}: avg loss={running_loss/max(steps,1):.4f}")

# 流式加载与训练
batch_idx = 0
for image_batch in loader.load_images_batch(split='train', batch_size=train_image_batch_size, max_batches=max_train_batches):
    batch_idx += 1
    img_map = {item['img_id']: item['image'] for item in image_batch}
    pairs = build_batch_pairs(train_df, img_map)
    print(f"Batch {batch_idx}: images={len(img_map)}, usable_pairs={len(pairs)}")
    if not pairs:
        del img_map
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        continue
    train_one_batch(pairs, epochs=epochs_per_batch, step_bs=train_step_batch_size)
    del img_map
    if device.type == 'cuda':
        torch.cuda.empty_cache()


2025-11-08 14:34:56,186 - INFO - 批量加载train图片数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_train_imgs.tsv
实际加载train图片数据:   0%|                                                     | 301/129380 [00:00<00:42, 3005.57it/s]

Batch 1: images=500, usable_pairs=997


  with autocast(enabled=True):
实际加载train图片数据:   0%|▏                                                      | 499/129380 [00:06<29:19, 73.27it/s]

Epoch 1/1: avg loss=2.3265





## 保存/加载检查点
保存两层MLP投影头、解冻顶层与优化器，并记录 logit_scale（可学习温度）。

In [11]:
save_dir = '/mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights'
save_path = os.path.join(save_dir, 'step_2_3_4_mlp_temp_checkpoint.pth')

def save_unfreeze_checkpoint(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                             optimizer: torch.optim.Optimizer, save_path: str, last_n_layers: int):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    ckpt = {
        'projection_dim': model.fusion.projection_dim,
        'last_n_layers': last_n_layers,
        'fusion': {
            'text_projector': model.fusion.text_projector.state_dict(),
            'image_projector': model.fusion.image_projector.state_dict(),
        },
        'logit_scale': model.logit_scale.detach().cpu(),
        'text_unfrozen': {},
        'image_unfrozen': {},
        'optimizer': optimizer.state_dict(),
    }
    enc = text_extractor.model.encoder
    total_layers = len(enc.layer)
    start_idx = max(0, total_layers - last_n_layers)
    for i in range(start_idx, total_layers):
        ckpt['text_unfrozen'][f'encoder_layer_{i}'] = enc.layer[i].state_dict()
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        ckpt['text_unfrozen']['pooler'] = text_extractor.model.pooler.state_dict()
    if hasattr(image_extractor.model, 'layer4'):
        ckpt['image_unfrozen']['layer4'] = image_extractor.model.layer4.state_dict()
    torch.save(ckpt, save_path)
    print(f"Checkpoint saved to: {save_path}")

def load_unfreeze_checkpoint(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                             optimizer: torch.optim.Optimizer, load_path: str):
    ckpt = torch.load(load_path, map_location='cpu')
    model.fusion.text_projector.load_state_dict(ckpt['fusion']['text_projector'])
    model.fusion.image_projector.load_state_dict(ckpt['fusion']['image_projector'])
    if 'logit_scale' in ckpt:
        model.logit_scale.data = ckpt['logit_scale'].to(model.logit_scale.device)
    ln = ckpt.get('last_n_layers', 2)
    unfreeze_text_top_layers(text_extractor, last_n_layers=ln)
    unfreeze_image_top_block(image_extractor, unfreeze_layer4=True)
    enc = text_extractor.model.encoder
    for k, v in ckpt['text_unfrozen'].items():
        if k.startswith('encoder_layer_'):
            idx = int(k.split('_')[-1])
            if 0 <= idx < len(enc.layer):
                enc.layer[idx].load_state_dict(v)
    if 'pooler' in ckpt['text_unfrozen'] and hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        text_extractor.model.pooler.load_state_dict(ckpt['text_unfrozen']['pooler'])
    if 'layer4' in ckpt['image_unfrozen'] and hasattr(image_extractor.model, 'layer4'):
        image_extractor.model.layer4.load_state_dict(ckpt['image_unfrozen']['layer4'])
    if optimizer is not None and 'optimizer' in ckpt:
        optimizer.load_state_dict(ckpt['optimizer'])
    print(f"Checkpoint loaded from: {load_path}")

# 保存一次并测试加载
save_unfreeze_checkpoint(model, text_extractor, image_extractor, optim, save_path, 2)
# load_unfreeze_checkpoint(model, text_extractor, image_extractor, optim, save_path)


Checkpoint loaded from: /mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights/step_2_3_4_mlp_temp_checkpoint.pth


## 验证评估：Recall@1/5/10 与 MeanRecall
优先使用 FAISS；不可用则回退到 Torch 相似度计算。

In [12]:
valid_imgs = loader.create_img_id_to_image_dict(
    split='valid', 
    max_samples=valid_imgs_max_samples, 
    batch_size=3000, max_batches=10
)

valid_queries = []
if 'item_ids' in valid_df.columns:
    for _, row in valid_df.iterrows():
        q = row.get('query_text', None)
        ids = [str(i) for i in row.get('item_ids', [])] if isinstance(row.get('item_ids', []), list) else []
        if q and ids:
            valid_queries.append((q, ids))
print(f'Usable valid queries: {len(valid_queries)}')

image_index = model.build_image_index(valid_imgs, batch_size=32)
all_image_ids = list(image_index.keys())
all_image_feats = torch.stack([image_index[i] for i in all_image_ids]) if all_image_ids else torch.empty((0, 512))
faiss_index = None
if HAS_FAISS and all_image_feats.size(0) > 0:
    d = all_image_feats.size(1)
    faiss_index = faiss.IndexFlatIP(d)
    feats_np = all_image_feats.detach().cpu().numpy().astype('float32')
    faiss_index.add(feats_np)

all_image_feats = all_image_feats.to(device)

def compute_recall_at_k(k_values, queries):
    recalls = {k: 0 for k in k_values}
    total = 0
    for q_text, gt_ids in tqdm(queries, desc='Evaluate'):
        if all_image_feats.size(0) == 0:
            continue
        q_feat = model.extract_and_fuse_text_features([q_text])
        if faiss_index is not None:
            q_np = q_feat.detach().cpu().numpy().astype('float32')
            _, I = faiss_index.search(q_np, max(k_values))
            top_idx = I[0].tolist()
            top_ids = [all_image_ids[i] for i in top_idx]
        else:
            sims = model.sim.calculate_similarity(q_feat, all_image_feats)
            _, top_idx = torch.topk(sims[0], k=max(k_values))
            top_ids = [all_image_ids[i] for i in top_idx.tolist()]
        total += 1
        for k in k_values:
            if any(g in set(top_ids[:k]) for g in gt_ids):
                recalls[k] += 1
    return {k: (recalls[k] / total if total > 0 else 0.0) for k in k_values}, total

rec, total_q = compute_recall_at_k([1,5,10], valid_queries)
mean_recall = (rec.get(1,0)+rec.get(5,0)+rec.get(10,0))/3 if total_q>0 else 0.0
print(f'Recall@1={rec.get(1,0):.4f}, Recall@5={rec.get(5,0):.4f}, Recall@10={rec.get(10,0):.4f}, MeanRecall={mean_recall:.4f} (N={total_q})')


2025-11-08 14:36:29,315 - INFO - 批量加载valid图片数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_valid_imgs.tsv
实际加载valid图片数据:  99%|████████████████████████████████████████████████████████▍| 99/100 [00:00<00:00, 3316.42it/s]
2025-11-08 14:36:33,367 - INFO - 成功创建valid图片映射字典，共100张图片


Usable valid queries: 5008


Evaluate: 100%|████████████████████████████████████████████████████████████████████| 5008/5008 [00:21<00:00, 237.27it/s]

Recall@1=0.0004, Recall@5=0.0018, Recall@10=0.0038, MeanRecall=0.0020 (N=5008)



