# 步骤 2.3-3：改用 Mean-Pooling 文本特征 + AMP/FAISS (屯)

在 `step_2_3-2_坤-解冻轻量微调-v2_amp_faiss.ipynb` 的基础上，
将文本特征从 [CLS] 向量改为基于 `attention_mask` 的 mean-pooling，
以提高句向量稳定性。保留 AMP 加速与 FAISS 检索加速，并保持训练/评估流程一致。

xmk: pooling的方法还有很多，见： https://blog.csdn.net/fengdu78/article/details/128059894 


In [1]:
import os
import sys
import math
from typing import List, Dict, Tuple
import torch
import timm
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import importlib

# Optional: FAISS-accelerated retrieval. The import is guarded so that a broken
# install (e.g. a NumPy version mismatch) falls back to the Torch search path
# instead of aborting the notebook.
import importlib.util  # `import importlib` alone does not guarantee the `util` submodule is loaded

HAS_FAISS = False
faiss = None
try:
    spec = importlib.util.find_spec('faiss')
    if spec is not None:
        faiss = importlib.import_module('faiss')
        HAS_FAISS = True
except Exception as e:  # deliberately not BaseException: Ctrl-C / SystemExit must still propagate
    print(f'Faiss unavailable: {e.__class__.__name__}')
    faiss = None
    HAS_FAISS = False

# Environment configuration (same as the baseline; adjust the cache path,
# proxy and mirror for the local machine).
cache_dir = "/mnt/d/HuggingFaceModels/"
os.environ.update({
    # Model cache location and library noise control.
    'TORCH_HOME': cache_dir,
    "TOKENIZERS_PARALLELISM": "false",
    "TRANSFORMERS_NO_ADVISORY_WARNINGS": "1",
    "WANDB_DISABLED": "true",
    # Local proxy used for all outbound traffic.
    'https_proxy': 'http://127.0.0.1:7890',
    'http_proxy': 'http://127.0.0.1:7890',
    'all_proxy': 'socks5://127.0.0.1:7890',
    # NOTE(review): an empty CURL_CA_BUNDLE is commonly used to bypass TLS
    # certificate verification -- confirm this is intended.
    'CURL_CA_BUNDLE': "",
    # Hugging Face mirror endpoint.
    'HF_ENDPOINT': "https://hf-mirror.com",
})

# Make plan1_1/data_loader.py importable, then import the project data loader.
sys.path.append(
    os.path.abspath(os.path.join('.', 'Multimodal_Retrieval', 'plan1_1'))
)
from data_loader import DataLoader

# Prefer the GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


## 模型与特征模块
文本特征使用基于 `attention_mask` 的 mean-pooling（排除 padding），相比仅用 [CLS] 更稳健。训练时允许梯度。


In [2]:
class TextFeatureExtractor:
    """BERT text encoder that returns attention-mask-aware mean-pooled vectors."""

    def __init__(self, model_name='bert-base-chinese', device='cpu', cache_dir=None):
        """Load tokenizer and model from local files only and switch the model to eval."""
        load_kwargs = dict(cache_dir=cache_dir, local_files_only=True)
        self.device = device
        self.tokenizer = BertTokenizer.from_pretrained(model_name, **load_kwargs)
        self.model = BertModel.from_pretrained(model_name, **load_kwargs).to(device)
        # Eval by default; training code switches individual sub-modules to train().
        self.model.eval()

    def encode_with_grad(self, texts: List[str]) -> torch.Tensor:
        """Encode ``texts`` without disabling autograd.

        A single string is treated as a one-element batch. An empty input
        yields an empty ``(0, 768)`` float32 tensor on ``self.device``.
        Otherwise the last hidden states are averaged over the tokens whose
        attention mask is set, so padding does not contribute.
        """
        if not texts:
            return torch.empty((0, 768), dtype=torch.float32, device=self.device)

        batch_texts = [texts] if isinstance(texts, str) else texts
        encoded = self.tokenizer(
            batch_texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=128,
        ).to(self.device)
        hidden = self.model(**encoded).last_hidden_state  # [B, L, H]

        attn = encoded.get('attention_mask', None)
        if attn is None:
            # Fallback: without a mask, use the first ([CLS]) token.
            return hidden[:, 0, :]

        weights = attn.unsqueeze(-1).type_as(hidden)     # [B, L, 1]
        # clamp keeps the division safe if a row has no unmasked token.
        token_counts = weights.sum(dim=1).clamp(min=1)   # [B, 1]
        return (hidden * weights).sum(dim=1) / token_counts

class ImageFeatureExtractor:
    """timm image backbone (classifier removed via ``num_classes=0``) plus its preprocessing."""

    def __init__(self, model_name='resnet50', device='cpu', cache_dir=None):
        """Create the pretrained backbone in eval mode and the resize/normalize pipeline."""
        self.device = device
        backbone = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=0,
            cache_dir=cache_dir,
        )
        self.model = backbone.to(device)
        self.model.eval()

        preprocessing = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(preprocessing)

    def encode_with_grad(self, images: List[Image.Image]) -> torch.Tensor:
        """Encode PIL images without disabling autograd.

        An empty input yields an empty ``(0, 2048)`` float32 tensor on
        ``self.device``. Every image is converted to RGB before preprocessing.
        """
        if not images:
            return torch.empty((0, 2048), dtype=torch.float32, device=self.device)
        pixel_batch = torch.stack([self.transform(img.convert('RGB')) for img in images])
        return self.model(pixel_batch.to(self.device))

class FeatureFusion:
    """Maps text and image features into a shared space.

    With ``fusion_method='projection'`` each modality gets its own linear
    projection to ``projection_dim``; any other value passes features through
    unchanged (no projector attributes are created in that case).

    Args:
        fusion_method: ``'projection'`` to enable the linear projectors.
        projection_dim: Output width of both projectors.
        device: Device for the projectors; defaults to CUDA when available,
            otherwise CPU.
        text_dim: Input width of the text projector. Defaults to 768, the value
            that was previously hard-coded.
        image_dim: Input width of the image projector. Defaults to 2048, the
            value that was previously hard-coded.
    """

    def __init__(self, fusion_method='projection', projection_dim=512, device=None,
                 *, text_dim=768, image_dim=2048):
        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.text_dim = text_dim
        self.image_dim = image_dim
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if fusion_method == 'projection':
            self.text_projector = torch.nn.Linear(text_dim, projection_dim).to(self.device)
            self.image_projector = torch.nn.Linear(image_dim, projection_dim).to(self.device)

    def fuse_text_features(self, text_features: torch.Tensor) -> torch.Tensor:
        """Project text features, or return them unchanged when projection is off."""
        return self.text_projector(text_features) if self.fusion_method == 'projection' else text_features

    def fuse_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project image features, or return them unchanged when projection is off."""
        return self.image_projector(image_features) if self.fusion_method == 'projection' else image_features

class SimilarityCalculator:
    """Pairwise text-to-image similarity: cosine when configured, raw dot product otherwise."""

    def __init__(self, similarity_type='cosine'):
        self.similarity_type = similarity_type

    def normalize_features(self, features: torch.Tensor) -> torch.Tensor:
        """L2-normalize each row of ``features``."""
        return torch.nn.functional.normalize(features, p=2, dim=1)

    def calculate_similarity(self, text_features: torch.Tensor, image_features: torch.Tensor) -> torch.Tensor:
        """Return the ``[num_texts, num_images]`` similarity matrix."""
        if self.similarity_type != 'cosine':
            # Any non-cosine setting uses the plain inner product.
            return torch.mm(text_features, image_features.t())
        unit_text = self.normalize_features(text_features)
        unit_image = self.normalize_features(image_features)
        return torch.mm(unit_text, unit_image.t())

class CrossModalRetrievalModel:
    """Text-to-image retrieval model: two feature extractors plus a fusion head.

    Args:
        text_extractor: Object exposing ``encode_with_grad(texts)``.
        image_extractor: Object exposing ``encode_with_grad(images)``.
        fusion_method: Forwarded to ``FeatureFusion``.
        projection_dim: Forwarded to ``FeatureFusion``.
        similarity_type: Forwarded to ``SimilarityCalculator``.
        normalize_features: When true, fused features are L2-normalized.
        device: Device for the fusion head; defaults to CUDA when available.
    """

    def __init__(self, text_extractor, image_extractor, fusion_method='projection', projection_dim=512, similarity_type='cosine', normalize_features=True, device=None):
        self.text_extractor = text_extractor
        self.image_extractor = image_extractor
        self.fusion = FeatureFusion(fusion_method, projection_dim, device)
        self.sim = SimilarityCalculator(similarity_type)
        self.normalize_features = normalize_features
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """L2-normalize rows of ``x`` when normalization is enabled; otherwise return ``x``."""
        return torch.nn.functional.normalize(x, p=2, dim=1) if self.normalize_features else x

    def extract_and_fuse_text_features(self, texts: List[str]) -> torch.Tensor:
        """Encode and fuse texts.

        Only the backbone forward pass runs under ``no_grad``; the fusion head
        is applied with whatever autograd mode the caller has active.
        """
        with torch.no_grad():
            t = self.text_extractor.encode_with_grad(texts)
        return self._norm(self.fusion.fuse_text_features(t))

    def extract_and_fuse_image_features(self, images: List[Image.Image]) -> torch.Tensor:
        """Encode and fuse images; same autograd behaviour as the text variant."""
        with torch.no_grad():
            i = self.image_extractor.encode_with_grad(images)
        return self._norm(self.fusion.fuse_image_features(i))

    def build_image_index(self, images_dict: Dict[str, Image.Image], batch_size: int = 32) -> Dict[str, torch.Tensor]:
        """Compute fused features for every non-``None`` image.

        Args:
            images_dict: Mapping from image id to image; ``None`` values are skipped.
            batch_size: Number of dictionary entries processed per forward pass.

        Returns:
            Mapping from image id to its detached CPU feature vector.
        """
        feats = {}
        keys = list(images_dict.keys())
        # Index building is inference only: the results are detached anyway, so
        # run the projection under no_grad to avoid building an unused graph.
        with torch.no_grad():
            for s in range(0, len(keys), batch_size):
                valid_ids = [k for k in keys[s:s+batch_size] if images_dict[k] is not None]
                if not valid_ids:
                    continue
                bf = self.extract_and_fuse_image_features([images_dict[k] for k in valid_ids])
                for j, img_id in enumerate(valid_ids):
                    feats[img_id] = bf[j].detach().cpu()
        return feats

def info_nce_loss(text_feats: torch.Tensor, image_feats: torch.Tensor, temp: float) -> torch.Tensor:
    """Symmetric InfoNCE loss over an aligned batch.

    Row ``i`` of ``text_feats`` is the positive for row ``i`` of
    ``image_feats``; all other rows in the batch act as negatives. The result
    is the mean of the text-to-image and image-to-text cross-entropies computed
    on the similarity matrix divided by ``temp``.
    """
    sim = torch.mm(text_feats, image_feats.t()) / temp
    targets = torch.arange(sim.size(0), device=sim.device)
    cross_entropy = torch.nn.functional.cross_entropy
    text_to_image = cross_entropy(sim, targets)
    image_to_text = cross_entropy(sim.t(), targets)
    return (text_to_image + image_to_text) * 0.5


## 顶层解冻与优化器分组
仅解冻顶层，降低学习率，控制训练稳定性。


In [3]:
def unfreeze_text_top_layers(text_extractor: TextFeatureExtractor, last_n_layers: int = 2):
    """Freeze the text model, then re-enable gradients on its top layers.

    All parameters are frozen first. Gradients are then re-enabled for the last
    ``last_n_layers`` encoder layers and, when present, the pooler. The whole
    model is left in eval mode on return; the training loop is responsible for
    switching the trainable sub-modules to train mode.

    Args:
        text_extractor: Extractor whose ``model`` has ``encoder.layer`` and
            optionally ``pooler``.
        last_n_layers: Number of top encoder layers to unfreeze. Values larger
            than the layer count unfreeze every layer.
    """
    bert = text_extractor.model
    for p in bert.parameters():
        p.requires_grad = False

    layers = bert.encoder.layer
    total_layers = len(layers)
    # Clamp so an oversized request cannot produce out-of-range negative indices
    # (same clamping as save_unfreeze_checkpoint).
    start_idx = max(0, total_layers - last_n_layers)
    for i in range(start_idx, total_layers):
        for p in layers[i].parameters():
            p.requires_grad = True

    if hasattr(bert, 'pooler') and bert.pooler is not None:
        for p in bert.pooler.parameters():
            p.requires_grad = True

    # The earlier per-layer .train() calls were removed: this eval() reset them
    # immediately, so they never had any effect.
    bert.eval()

def unfreeze_image_top_block(image_extractor: ImageFeatureExtractor, unfreeze_layer4: bool = True):
    """Freeze the image model, then optionally re-enable gradients on ``layer4``.

    The whole model is left in eval mode on return; the training loop is
    responsible for switching ``layer4`` to train mode.

    Args:
        image_extractor: Extractor whose ``model`` may expose a ``layer4`` block.
        unfreeze_layer4: When true and ``layer4`` exists, its parameters are
            made trainable.
    """
    backbone = image_extractor.model
    for p in backbone.parameters():
        p.requires_grad = False
    if unfreeze_layer4 and hasattr(backbone, 'layer4'):
        for p in backbone.layer4.parameters():
            p.requires_grad = True
    # The earlier layer4.train() call was removed: this eval() reset it
    # immediately, so it never had any effect.
    backbone.eval()

def build_optimizer(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                   lr_proj: float = 1e-3, lr_text_top: float = 5e-5, lr_img_top: float = 1e-4, weight_decay: float = 1e-4):
    """Build an Adam optimizer with three parameter groups.

    Group 0: both projection heads (``lr_proj``, ``weight_decay``).
    Group 1: trainable parameters of the last two text encoder layers and the
    pooler, if present (``lr_text_top``, no weight decay).
    Group 2: trainable parameters of the image ``layer4``, if present
    (``lr_img_top``, no weight decay).
    """
    def _trainable(modules):
        # Keep only parameters that currently have gradients enabled.
        return [p for mod in modules for p in mod.parameters() if p.requires_grad]

    fusion = model.fusion
    projector_params = list(fusion.text_projector.parameters()) + list(fusion.image_projector.parameters())

    bert = text_extractor.model
    text_modules = list(bert.encoder.layer[-2:])
    pooler = getattr(bert, 'pooler', None)
    if pooler is not None:
        text_modules.append(pooler)

    backbone = image_extractor.model
    image_modules = [backbone.layer4] if hasattr(backbone, 'layer4') else []

    groups = [
        {'params': projector_params, 'lr': lr_proj, 'weight_decay': weight_decay},
        {'params': _trainable(text_modules), 'lr': lr_text_top, 'weight_decay': 0.0},
        {'params': _trainable(image_modules), 'lr': lr_img_top, 'weight_decay': 0.0},
    ]
    return torch.optim.Adam(groups)


## 数据加载与训练参数
保持与基线一致的查询数据加载；图片按批次流式加载，以控制内存占用。下方被注释掉的是便于快速试跑的小参数，实际训练使用其后的较大参数。


In [4]:
# Load the query tables for the train and validation splits.
loader = DataLoader()
train_df = loader.load_queries(split='train')
valid_df = loader.load_queries(split='valid')

# # Training / streaming parameters (small defaults for a quick trial run)
# train_image_batch_size = 500  # number of images in one large batch
# max_train_batches = 1         # how many large batches to load
# epochs_per_batch = 1          # training epochs per large batch
# train_step_batch_size = 32    # mini-batch size used while training

# Training / streaming parameters actually used for this run (adjust as needed).
train_image_batch_size = 15000 ## images per streamed large batch
max_train_batches = 10 ## maximum number of large batches to load
epochs_per_batch = 5 ## epochs trained on each large batch
train_step_batch_size = 32 ## mini-batch size inside a large batch

# Mixed precision is requested here; it only takes effect when running on CUDA.
use_amp = True
# Temperature passed to info_nce_loss (logits are divided by it).
temperature = 0.07


2025-11-08 18:51:44,129 - INFO - 初始化数据加载器，数据目录: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval
2025-11-08 18:51:44,133 - INFO - 加载train查询数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_train_queries.jsonl
加载train查询数据: 248786it [00:01, 213593.85it/s]
2025-11-08 18:51:45,384 - INFO - 成功加载train查询数据，共248786条
2025-11-08 18:51:45,396 - INFO - 加载valid查询数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_valid_queries.jsonl
加载valid查询数据: 5008it [00:00, 200130.29it/s]
2025-11-08 18:51:45,426 - INFO - 成功加载valid查询数据，共5008条


## 初始化模型并执行顶层解冻
解冻文本编码器最后 2 层与 pooler；解冻图像编码器的 `layer4`。注意：本版文本特征取自 `last_hidden_state` 的 mean-pooling，不经过 pooler，因此 pooler 虽被解冻，但不会获得梯度。


In [5]:
# Build both feature extractors and wrap them in the retrieval model
# (linear projection to 512 dims, L2-normalized outputs).
image_extractor = ImageFeatureExtractor(device=device, cache_dir=cache_dir)
text_extractor = TextFeatureExtractor(device=device, cache_dir=cache_dir)
model = CrossModalRetrievalModel(
    text_extractor, image_extractor, 
    fusion_method='projection', projection_dim=512, similarity_type='cosine', normalize_features=True, device=device
)

# Freeze everything except the last two text encoder layers (+ pooler) and the image layer4.
unfreeze_text_top_layers(text_extractor, last_n_layers=2)
unfreeze_image_top_block(image_extractor, unfreeze_layer4=True)

# The optimizer must be built after unfreezing: it only picks up backbone
# parameters whose requires_grad flag is already set.
optim = build_optimizer(model, text_extractor, image_extractor,
                        lr_proj=1e-3, lr_text_top=5e-5, lr_img_top=1e-4, weight_decay=1e-4)
# Gradient scaling is only enabled for CUDA + AMP.
# NOTE(review): the captured output shows a warning on this line -- presumably
# the torch.cuda.amp.GradScaler deprecation; confirm before migrating.
scaler = GradScaler(enabled=(device.type == 'cuda' and use_amp))
print('Optim groups:', len(optim.param_groups))


2025-11-08 18:51:45,678 - INFO - Loading pretrained weights from Hugging Face hub (timm/resnet50.a1_in1k)
2025-11-08 18:51:46,466 - INFO - [timm/resnet50.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


Optim groups: 3


  scaler = GradScaler(enabled=(device.type == 'cuda' and use_amp))


## 训练循环：按批次流式构建配对并微调顶层
仅使用配对中的第一张可用图片；文本与图像编码器顶层参与反向传播。


In [6]:
def build_batch_pairs(train_df, img_dict: Dict[str, Image.Image]) -> List[Tuple[str, Image.Image, str]]:
    """Pair each query with the first of its images that is available in ``img_dict``.

    Args:
        train_df: Query table with ``query_text`` and ``item_ids`` columns;
            ``item_ids`` entries must be non-empty lists to be considered.
        img_dict: Mapping from image id (string) to image; ``None`` values
            count as unavailable.

    Returns:
        ``(query_text, image, image_id)`` tuples, at most one per row, in row
        order. Empty when either required column is missing.
    """
    columns = train_df.columns
    if 'item_ids' not in columns or 'query_text' not in columns:
        return []

    pairs = []
    # Iterate the two columns directly: iterrows() builds a Series per row,
    # which is slow on a large table that is rescanned for every image batch.
    for q, ids in zip(train_df['query_text'], train_df['item_ids']):
        if not q or not isinstance(ids, list) or not ids:
            continue
        for iid in ids:
            sid = str(iid)
            img = img_dict.get(sid)
            if img is not None:
                # Only the first available image of a query is used.
                pairs.append((q, img, sid))
                break
    return pairs

def train_one_batch(pairs: List[Tuple[str, Image.Image, str]], epochs: int, step_bs: int):
    """Fine-tune the projection heads and the unfrozen top layers on ``pairs``.

    Relies on the notebook-level ``model``, ``text_extractor``,
    ``image_extractor``, ``optim``, ``scaler``, ``use_amp``, ``device`` and
    ``temperature``. The trainable sub-modules are switched to train mode for
    the duration of the call and restored to their previous mode afterwards
    (also on error), so a later evaluation does not run with train-mode
    behaviour still active.

    Args:
        pairs: ``(query_text, image, image_id)`` tuples. They are consumed in
            the given order, identically in every epoch.
        epochs: Number of passes over ``pairs``.
        step_bs: Number of pairs per optimizer step.
    """
    text_projector = model.fusion.text_projector
    image_projector = model.fusion.image_projector
    bert = text_extractor.model

    trainable_modules = [
        text_projector,
        image_projector,
        bert.encoder.layer[-1],
        bert.encoder.layer[-2],
    ]
    if hasattr(bert, 'pooler') and bert.pooler is not None:
        trainable_modules.append(bert.pooler)
    if hasattr(image_extractor.model, 'layer4'):
        trainable_modules.append(image_extractor.model.layer4)

    # Loop-invariant: only the projection heads are gradient-clipped.
    # NOTE(review): the unfrozen backbone layers are not clipped -- confirm intended.
    clip_params = list(text_projector.parameters()) + list(image_projector.parameters())
    amp_enabled = use_amp and device.type == 'cuda'

    def _batch_loss(texts, imgs):
        # Forward both modalities with gradients and compute the contrastive loss.
        t_feats = text_extractor.encode_with_grad(texts)
        i_feats = image_extractor.encode_with_grad(imgs)
        t_proj = model._norm(model.fusion.fuse_text_features(t_feats))
        i_proj = model._norm(model.fusion.fuse_image_features(i_feats))
        return info_nce_loss(t_proj, i_proj, temp=temperature)

    previous_modes = [m.training for m in trainable_modules]
    for m in trainable_modules:
        m.train()
    try:
        for e in range(epochs):
            running_loss = 0.0
            steps = 0
            # NOTE(review): pairs are not shuffled between epochs, so every
            # epoch sees identical mini-batches -- confirm intended.
            for s in range(0, len(pairs), step_bs):
                batch = pairs[s:s+step_bs]
                texts = [t for (t, _, _) in batch]
                imgs = [im for (_, im, _) in batch]

                optim.zero_grad()
                if amp_enabled:
                    with autocast(enabled=True):
                        loss = _batch_loss(texts, imgs)
                    scaler.scale(loss).backward()
                    # Unscale first so the clip threshold applies to real gradient values.
                    scaler.unscale_(optim)
                    torch.nn.utils.clip_grad_norm_(clip_params, max_norm=5.0)
                    scaler.step(optim)
                    scaler.update()
                else:
                    loss = _batch_loss(texts, imgs)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(clip_params, max_norm=5.0)
                    optim.step()
                running_loss += loss.item()
                steps += 1
            print(f"Epoch {e+1}/{epochs}: avg loss={running_loss/max(steps,1):.4f}")
    finally:
        for m, was_training in zip(trainable_modules, previous_modes):
            m.train(was_training)

# Stream image batches and fine-tune on each one in turn.
batch_idx = 0
for image_batch in loader.load_images_batch(split='train', batch_size=train_image_batch_size, max_batches=max_train_batches):
    batch_idx += 1
    img_map = {item['img_id']: item['image'] for item in image_batch}
    pairs = build_batch_pairs(train_df, img_map)
    print(f"Batch {batch_idx}: images={len(img_map)}, usable_pairs={len(pairs)}")

    # Batches without any usable (query, image) pair are skipped.
    if pairs:
        train_one_batch(pairs, epochs=epochs_per_batch, step_bs=train_step_batch_size)

    # Same cleanup whether or not the batch was trained on.
    del img_map
    if device.type == 'cuda':
        torch.cuda.empty_cache()


2025-11-08 18:51:52,626 - INFO - 批量加载train图片数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_train_imgs.tsv
实际加载train图片数据:  11%|█████▊                                             | 14784/129380 [00:06<00:47, 2393.49it/s]

Batch 1: images=15000, usable_pairs=29127


  with autocast(enabled=True):
实际加载train图片数据:  11%|█████▊                                             | 14784/129380 [00:19<00:47, 2393.49it/s]

Epoch 1/5: avg loss=1.4942
Epoch 2/5: avg loss=0.6086
Epoch 3/5: avg loss=0.2039
Epoch 4/5: avg loss=0.0802


实际加载train图片数据:  12%|█████▉                                             | 15000/129380 [03:15<7:42:02,  4.13it/s]

Epoch 5/5: avg loss=0.0463


实际加载train图片数据:  23%|███████████▋                                       | 29797/129380 [03:20<00:33, 2979.10it/s]

Batch 2: images=15000, usable_pairs=28909


实际加载train图片数据:  23%|███████████▋                                       | 29797/129380 [03:39<00:33, 2979.10it/s]

Epoch 1/5: avg loss=1.2430
Epoch 2/5: avg loss=0.3849
Epoch 3/5: avg loss=0.1131
Epoch 4/5: avg loss=0.0518


实际加载train图片数据:  23%|███████████▊                                       | 30000/129380 [06:15<5:16:27,  5.23it/s]

Epoch 5/5: avg loss=0.0532


实际加载train图片数据:  35%|█████████████████▋                                 | 44728/129380 [06:20<00:28, 2962.52it/s]

Batch 3: images=15000, usable_pairs=28595


实际加载train图片数据:  35%|█████████████████▋                                 | 44728/129380 [06:39<00:28, 2962.52it/s]

Epoch 1/5: avg loss=1.1801
Epoch 2/5: avg loss=0.3575
Epoch 3/5: avg loss=0.1046
Epoch 4/5: avg loss=0.0507


实际加载train图片数据:  35%|█████████████████▋                                 | 45000/129380 [09:14<4:09:34,  5.63it/s]

Epoch 5/5: avg loss=0.0519


实际加载train图片数据:  46%|███████████████████████▌                           | 59713/129380 [09:18<00:21, 3178.72it/s]

Batch 4: images=15000, usable_pairs=29155


实际加载train图片数据:  46%|███████████████████████▌                           | 59713/129380 [09:29<00:21, 3178.72it/s]

Epoch 1/5: avg loss=1.1550
Epoch 2/5: avg loss=0.3363
Epoch 3/5: avg loss=0.0963
Epoch 4/5: avg loss=0.0470


实际加载train图片数据:  46%|███████████████████████▋                           | 60000/129380 [12:15<3:15:42,  5.91it/s]

Epoch 5/5: avg loss=0.0512


实际加载train图片数据:  58%|█████████████████████████████▍                     | 74705/129380 [12:20<00:16, 3247.99it/s]

Batch 5: images=15000, usable_pairs=29347


实际加载train图片数据:  58%|█████████████████████████████▍                     | 74705/129380 [12:39<00:16, 3247.99it/s]

Epoch 1/5: avg loss=1.1236
Epoch 2/5: avg loss=0.3077
Epoch 3/5: avg loss=0.0886
Epoch 4/5: avg loss=0.0439
Epoch 5/5: avg loss=0.0400


实际加载train图片数据:  69%|███████████████████████████████████▍               | 89790/129380 [15:25<00:11, 3312.79it/s]

Batch 6: images=15000, usable_pairs=29191


实际加载train图片数据:  69%|███████████████████████████████████▍               | 89790/129380 [15:39<00:11, 3312.79it/s]

Epoch 1/5: avg loss=1.0916
Epoch 2/5: avg loss=0.2929
Epoch 3/5: avg loss=0.0875
Epoch 4/5: avg loss=0.0422


实际加载train图片数据:  70%|███████████████████████████████████▍               | 90000/129380 [18:24<1:59:46,  5.48it/s]

Epoch 5/5: avg loss=0.0404


实际加载train图片数据:  81%|████████████████████████████████████████▍         | 104776/129380 [18:29<00:07, 3284.49it/s]

Batch 7: images=15000, usable_pairs=28504


实际加载train图片数据:  81%|████████████████████████████████████████▍         | 104776/129380 [18:49<00:07, 3284.49it/s]

Epoch 1/5: avg loss=1.0757
Epoch 2/5: avg loss=0.2729
Epoch 3/5: avg loss=0.0775
Epoch 4/5: avg loss=0.0374


实际加载train图片数据:  81%|████████████████████████████████████████▌         | 105000/129380 [21:23<1:09:55,  5.81it/s]

Epoch 5/5: avg loss=0.0332


实际加载train图片数据:  93%|██████████████████████████████████████████████▎   | 119830/129380 [21:28<00:03, 3148.71it/s]

Batch 8: images=15000, usable_pairs=28894


实际加载train图片数据:  93%|██████████████████████████████████████████████▎   | 119830/129380 [21:39<00:03, 3148.71it/s]

Epoch 1/5: avg loss=1.0543
Epoch 2/5: avg loss=0.2620
Epoch 3/5: avg loss=0.0742
Epoch 4/5: avg loss=0.0368
Epoch 5/5: avg loss=0.0318


实际加载train图片数据: 100%|████████████████████████████████████████████████████| 129380/129380 [24:26<00:00, 88.22it/s]


Batch 9: images=9380, usable_pairs=18432
Epoch 1/5: avg loss=1.0653
Epoch 2/5: avg loss=0.1987
Epoch 3/5: avg loss=0.0567
Epoch 4/5: avg loss=0.0323
Epoch 5/5: avg loss=0.0251


## 保存：投影层 + 已解冻顶层 + 优化器
保存 BERT 的最后2层 + pooler、ResNet50 的 layer4、投影层、优化器。


In [7]:
# Checkpoint location for this notebook's mean-pooling variant.
save_dir = '/mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights'
save_path = os.path.join(save_dir, 'step_2_3_3_mean_pool_checkpoint.pth')

def save_unfreeze_checkpoint(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                             optimizer: torch.optim.Optimizer, save_path: str, last_n_layers: int):
    """Save the projection heads, the unfrozen top layers and the optimizer state.

    Only the last ``last_n_layers`` text encoder layers, the pooler (if any)
    and the image ``layer4`` (if any) are stored, not the full backbones.

    Args:
        model: Retrieval model whose ``fusion`` holds the two projectors.
        text_extractor: Extractor whose ``model`` has ``encoder.layer``.
        image_extractor: Extractor whose ``model`` may expose ``layer4``.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        save_path: Target file; its parent directory is created if needed.
        last_n_layers: Number of top text encoder layers to store.
    """
    # dirname is '' for a bare filename, and os.makedirs('') raises.
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    ckpt = {
        'projection_dim': model.fusion.projection_dim,
        'last_n_layers': last_n_layers,
        'fusion': {
            'text_projector': model.fusion.text_projector.state_dict(),
            'image_projector': model.fusion.image_projector.state_dict(),
        },
        'text_unfrozen': {},
        'image_unfrozen': {},
        'optimizer': optimizer.state_dict(),
    }

    enc = text_extractor.model.encoder
    total_layers = len(enc.layer)
    start_idx = max(0, total_layers - last_n_layers)
    for i in range(start_idx, total_layers):
        # Keys carry the absolute layer index so loading can target the same layer.
        ckpt['text_unfrozen'][f'encoder_layer_{i}'] = enc.layer[i].state_dict()
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        ckpt['text_unfrozen']['pooler'] = text_extractor.model.pooler.state_dict()
    if hasattr(image_extractor.model, 'layer4'):
        ckpt['image_unfrozen']['layer4'] = image_extractor.model.layer4.state_dict()

    torch.save(ckpt, save_path)
    print(f"Checkpoint saved to: {save_path}")

# Save once after training; the final argument is last_n_layers and matches
# the value used when the text layers were unfrozen.
save_unfreeze_checkpoint(model, text_extractor, image_extractor, optim, save_path, 2)


Checkpoint saved to: /mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights/step_2_3_3_mean_pool_checkpoint.pth


## 加载：恢复解冻顶层与投影层权重，继续训练
加载后会自动再次执行顶层解冻，并恢复优化器状态（如提供）。


In [8]:
def load_unfreeze_checkpoint(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                             optimizer: torch.optim.Optimizer, load_path: str):
    """Restore projection heads, unfrozen top layers and optionally the optimizer.

    The checkpoint is read onto the CPU. Before the layer weights are restored,
    the top layers are unfrozen again using the ``last_n_layers`` value stored
    in the checkpoint (2 if absent). Optimizer state is restored only when an
    optimizer is passed and the checkpoint contains one.
    """
    state = torch.load(load_path, map_location='cpu')

    fusion_state = state['fusion']
    model.fusion.text_projector.load_state_dict(fusion_state['text_projector'])
    model.fusion.image_projector.load_state_dict(fusion_state['image_projector'])

    # Re-apply the freeze/unfreeze layout so requires_grad matches the checkpoint.
    unfreeze_text_top_layers(text_extractor, last_n_layers=state.get('last_n_layers', 2))
    unfreeze_image_top_block(image_extractor, unfreeze_layer4=True)

    layers = text_extractor.model.encoder.layer
    for key, layer_state in state['text_unfrozen'].items():
        if not key.startswith('encoder_layer_'):
            continue
        layer_idx = int(key.split('_')[-1])
        # Indices outside the current encoder are silently ignored.
        if 0 <= layer_idx < len(layers):
            layers[layer_idx].load_state_dict(layer_state)

    if 'pooler' in state['text_unfrozen'] and hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        text_extractor.model.pooler.load_state_dict(state['text_unfrozen']['pooler'])

    if 'layer4' in state['image_unfrozen'] and hasattr(image_extractor.model, 'layer4'):
        image_extractor.model.layer4.load_state_dict(state['image_unfrozen']['layer4'])

    if optimizer is not None and 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    print(f"Checkpoint loaded from: {load_path}")

# Round-trip test of the checkpoint. Side effect: the unfreeze_* helpers called
# inside finish with model.eval(), so both backbones are in eval mode afterwards.
load_unfreeze_checkpoint(model, text_extractor, image_extractor, optim, save_path)


Checkpoint loaded from: /mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights/step_2_3_3_mean_pool_checkpoint.pth


## 验证评估：Recall@1/5/10 与 MeanRecall
- 基于验证集构建图像索引
- 对每条查询计算相似度并统计召回（优先使用 FAISS；不可用则 Torch 回退）


In [9]:
# Load the validation images into an {img_id: image} dictionary.
valid_imgs = loader.create_img_id_to_image_dict(
    split='valid', 
    max_samples=50000,
    batch_size=3000,
    max_batches=10
)

# Keep only queries that have both text and at least one ground-truth id;
# ids are converted to strings before they are compared with index keys.
valid_queries = []
if 'item_ids' in valid_df.columns:
    for _, row in valid_df.iterrows():
        q = row.get('query_text', None)
        ids = [str(i) for i in row.get('item_ids', [])] if isinstance(row.get('item_ids', []), list) else []
        if q and ids:
            valid_queries.append((q, ids))
print(f'Usable valid queries: {len(valid_queries)}')

# Encode every validation image once; row i of all_image_feats belongs to all_image_ids[i].
image_index = model.build_image_index(valid_imgs, batch_size=32)
all_image_ids = list(image_index.keys())
all_image_feats = torch.stack([image_index[i] for i in all_image_ids]) if all_image_ids else torch.empty((0, 512))
faiss_index = None
if HAS_FAISS and all_image_feats.size(0) > 0:
    d = all_image_feats.size(1)
    # Inner-product index; the model was built with normalize_features=True, so
    # the stored features are L2-normalized.
    faiss_index = faiss.IndexFlatIP(d)
    feats_np = all_image_feats.detach().cpu().numpy().astype('float32')
    faiss_index.add(feats_np)

# Device copy used by the Torch fallback search path.
all_image_feats = all_image_feats.to(device)

def compute_recall_at_k(k_values, queries):
    """Compute Recall@k of text-to-image retrieval over ``queries``.

    Relies on the notebook-level ``model``, ``all_image_feats``,
    ``all_image_ids``, ``faiss_index`` and ``tqdm``. A query counts as a hit at
    ``k`` when any of its ground-truth ids is among the top-``k`` results.

    Args:
        k_values: Cut-offs to evaluate, e.g. ``[1, 5, 10]``.
        queries: Iterable of ``(query_text, ground_truth_ids)``.

    Returns:
        ``({k: recall}, number_of_evaluated_queries)``. With an empty image
        index every recall is 0.0 and the count is 0.
    """
    num_images = all_image_feats.size(0)
    if num_images == 0:
        return {k: 0.0 for k in k_values}, 0

    recalls = {k: 0 for k in k_values}
    total = 0
    # Never request more neighbours than there are images: torch.topk would
    # raise, and FAISS pads missing results with -1.
    top_k = min(max(k_values), num_images)

    for q_text, gt_ids in tqdm(queries, desc='Evaluate'):
        q_feat = model.extract_and_fuse_text_features([q_text])
        if faiss_index is not None:
            q_np = q_feat.detach().cpu().numpy().astype('float32')
            _, I = faiss_index.search(q_np, top_k)
            # Drop -1 padding so it cannot be read as "last image" by Python indexing.
            top_idx = [i for i in I[0].tolist() if i >= 0]
        else:
            sims = model.sim.calculate_similarity(q_feat, all_image_feats)
            _, idx = torch.topk(sims[0], k=top_k)
            top_idx = idx.tolist()
        top_ids = [all_image_ids[i] for i in top_idx]

        total += 1
        gt = set(gt_ids)  # built once per query instead of once per candidate check
        for k in k_values:
            if not gt.isdisjoint(top_ids[:k]):
                recalls[k] += 1

    return {k: (recalls[k] / total if total > 0 else 0.0) for k in k_values}, total

# Evaluate Recall@1/5/10; MeanRecall is the arithmetic mean of the three values.
rec, total_q = compute_recall_at_k([1,5,10], valid_queries)
mean_recall = (rec.get(1,0)+rec.get(5,0)+rec.get(10,0))/3 if total_q>0 else 0.0
print(f'Recall@1={rec.get(1,0):.4f}, Recall@5={rec.get(5,0):.4f}, Recall@10={rec.get(10,0):.4f}, MeanRecall={mean_recall:.4f} (N={total_q})')


2025-11-08 19:18:27,294 - INFO - 批量加载valid图片数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_valid_imgs.tsv
实际加载valid图片数据: 100%|████████████████████████████████████████████████████| 29806/29806 [00:12<00:00, 2365.20it/s]
2025-11-08 19:18:42,754 - INFO - 成功创建valid图片映射字典，共29806张图片


Usable valid queries: 5008


Evaluate: 100%|████████████████████████████████████████████████████████████████████| 5008/5008 [00:39<00:00, 125.58it/s]

Recall@1=0.0571, Recall@5=0.1905, Recall@10=0.2927, MeanRecall=0.1801 (N=5008)





In [10]:
## 训练5轮：
## Recall@1=0.0571, Recall@5=0.1905, Recall@10=0.2927, MeanRecall=0.1801 (N=5008)

In [11]:
## 上一版本测试结果：
## Recall@1=0.0056, Recall@5=0.0202, Recall@10=0.0375, MeanRecall=0.0211 (N=5008)
## 上一版实际跑结果：
## Recall@1=0.0487, Recall@5=0.1753, Recall@10=0.2708, MeanRecall=0.1649 (N=5008)