In [1]:
import os
import sys
import math
from typing import List, Dict, Tuple
import torch
import timm
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import importlib

# 可选：Faiss 加速检索（安全导入，避免 NumPy 版本问题导致报错中断）
HAS_FAISS = False
faiss = None
try:
    spec = importlib.util.find_spec('faiss')
    if spec is not None:
        faiss = importlib.import_module('faiss')
        HAS_FAISS = True
except BaseException as e:
    print(f'Faiss unavailable: {e.__class__.__name__}')
    HAS_FAISS = False

# 设置环境变量（与基线一致，按需修改为本地镜像/缓存）
cache_dir = "/mnt/d/HuggingFaceModels/"
os.environ['TORCH_HOME'] = cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ['https_proxy'] = 'http://127.0.0.1:7890'
os.environ['http_proxy'] = 'http://127.0.0.1:7890'
os.environ['all_proxy'] = 'socks5://127.0.0.1:7890'
os.environ["WANDB_DISABLED"] = "true"
os.environ['CURL_CA_BUNDLE'] = ""
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

# 导入数据加载器（使用 plan1_1/data_loader.py）
sys.path.append(os.path.abspath(os.path.join('.', 'Multimodal_Retrieval', 'plan1_1')))
from data_loader import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
save_dir = '/mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights'
save_path = os.path.join(save_dir, 'step_3_1_1.pth')

In [3]:
import time
while True:
    if os.path.exists("step_2_3-3_屯-mean_pooling-v2_cp1.finishflag"):
        break
    time.sleep(5)

## 模型与特征模块
文本特征使用基于 `attention_mask` 的 mean-pooling（排除 padding），相比仅用 [CLS] 更稳健。训练时允许梯度。


In [4]:
import torch.nn.functional as F
from torch import nn
# 1. 优化版两层MLP投影头（核心组件）
class OptimizedMLPProjector(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.1, use_bn=True):
        super().__init__()
        self.layer1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim) if use_bn else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.layer2 = nn.Linear(hidden_dim, out_dim)
        # Kaiming初始化
        self._init_weights()
    
    def _init_weights(self):
        nn.init.kaiming_normal_(self.layer1.weight, mode='fan_in', nonlinearity='relu')
        nn.init.zeros_(self.layer1.bias)
        nn.init.kaiming_normal_(self.layer2.weight, mode='fan_in', nonlinearity='linear')
        nn.init.zeros_(self.layer2.bias)
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.bn1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.layer2(x)
        return x

In [5]:
class TextFeatureExtractor:
    def __init__(self, model_name='bert-base-chinese', device='cpu', cache_dir=None):
        self.device = device
        self.tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=True)
        self.model = BertModel.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=True).to(device)
        # 默认 eval，训练时将对子模块单独切换 train
        self.model.eval()
        
    def encode_with_grad(self, texts: List[str]) -> torch.Tensor:
        if not texts:
            return torch.empty((0, 768), dtype=torch.float32, device=self.device)
        if isinstance(texts, str):
            texts = [texts]
        inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=128).to(self.device)
        outputs = self.model(**inputs)
        token_embeddings = outputs.last_hidden_state  # [B, L, 768]
        attention_mask = inputs.get('attention_mask', None)
        if attention_mask is None:
            # 兜底：无 mask 时退化为 CLS
            return token_embeddings[:, 0, :]
        mask = attention_mask.unsqueeze(-1).type_as(token_embeddings)  # [B, L, 1]
        summed = (token_embeddings * mask).sum(dim=1)  # [B, 768]
        lengths = mask.sum(dim=1).clamp(min=1)        # [B, 1]
        mean_pooled = summed / lengths
        return mean_pooled

# class ImageFeatureExtractor:
#     def __init__(self, model_name='resnet50', device='cpu', cache_dir=None):
#         self.device = device
#         self.model = timm.create_model(
#             model_name, pretrained=True, num_classes=0,
#             cache_dir=cache_dir
#         ).to(device)
#         self.model.eval()
#         self.transform = transforms.Compose([
#             transforms.Resize((224, 224)),
#             transforms.ToTensor(),
#             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#         ])
        
#     def encode_with_grad(self, images: List[Image.Image]) -> torch.Tensor:
#         if not images:
#             return torch.empty((0, 2048), dtype=torch.float32, device=self.device)
#         tensors = torch.stack([self.transform(img.convert('RGB')) for img in images]).to(self.device)
#         feats = self.model(tensors)
#         return feats
# image_extractor = ImageFeatureExtractor(device=device, cache_dir=cache_dir)

from safetensors.torch import load_file
class ImageFeatureExtractor:
    '''
    改进版，使得timm不要每次都去连huggingface。
    '''
    def __init__(self, model_name='resnet50', device='cpu', weights_path=None, cache_dir=None):
        self.device = device
        self.model = timm.create_model(model_name, pretrained=False, num_classes=0, cache_dir=cache_dir)

        if weights_path is not None:
            if weights_path.endswith('.safetensors'):
                state_dict = load_file(weights_path)
            else:
                state_dict = torch.load(weights_path, map_location='cpu')
            self.model.load_state_dict(state_dict, strict=False)

        self.model = self.model.to(device)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def encode_with_grad(self, images: List[Image.Image]) -> torch.Tensor:
        if not images:
            in_dim = getattr(self.model, 'num_features', 2048)
            return torch.empty((0, in_dim), dtype=torch.float32, device=self.device)
        tensors = torch.stack([self.transform(img.convert('RGB')) for img in images]).to(self.device)
        feats = self.model(tensors)
        return feats
        
class FeatureFusion:
    # 类作用：将原始文本/图像特征投影到共同的子空间（projection_dim）
    # 参数:
    # - fusion_method: 融合方式，当前支持 'projection'
    # - projection_dim: 目标投影维度
    # - device: 设备
    # - hidden_dim: 两层 MLP 的中间隐藏层维度
    # - dropout: Dropout 概率
    # - text_in_dim/image_in_dim: 输入维度（默认文本768/图像2048）
    def __init__(self, fusion_method='projection', projection_dim=512, device=None, hidden_dim=1024, dropout=0.1, text_in_dim=768, image_in_dim=2048):
        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.hidden_dim = hidden_dim
        self.dropout_p = dropout
        if fusion_method == 'projection':
            # self.text_projector = torch.nn.Sequential(
            #     torch.nn.Linear(text_in_dim, hidden_dim),
            #     torch.nn.GELU(),
            #     torch.nn.Dropout(p=dropout),
            #     torch.nn.Linear(hidden_dim, projection_dim)
            # ).to(self.device)
            # self.image_projector = torch.nn.Sequential(
            #     torch.nn.Linear(image_in_dim, hidden_dim),
            #     torch.nn.GELU(),
            #     torch.nn.Dropout(p=dropout),
            #     torch.nn.Linear(hidden_dim, projection_dim)
            # ).to(self.device)
            
            # self.text_projector = OptimizedMLPProjector(text_in_dim, hidden_dim, projection_dim, dropout=dropout).to(self.device)
            # self.image_projector = OptimizedMLPProjector(image_in_dim, hidden_dim, projection_dim, dropout=dropout).to(self.device)

            self.text_projector = torch.nn.Linear(text_in_dim, projection_dim).to(self.device)
            self.image_projector = torch.nn.Linear(image_in_dim, projection_dim).to(self.device)
            
    def fuse_text_features(self, text_features: torch.Tensor) -> torch.Tensor:
        return self.text_projector(text_features) if self.fusion_method == 'projection' else text_features
    def fuse_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        return self.image_projector(image_features) if self.fusion_method == 'projection' else image_features

class SimilarityCalculator:
    def __init__(self, similarity_type='cosine'):
        self.similarity_type = similarity_type
    def normalize_features(self, features: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.normalize(features, p=2, dim=1)
    def calculate_similarity(self, text_features: torch.Tensor, image_features: torch.Tensor) -> torch.Tensor:
        if self.similarity_type == 'cosine':
            t_n = self.normalize_features(text_features)
            i_n = self.normalize_features(image_features)
            return torch.mm(t_n, i_n.t())
        return torch.mm(text_features, image_features.t())

class CrossModalRetrievalModel:
    def __init__(self, text_extractor, image_extractor, fusion_method='projection', projection_dim=512, similarity_type='cosine', normalize_features=True, device=None):
        self.text_extractor = text_extractor
        self.image_extractor = image_extractor
        self.fusion = FeatureFusion(fusion_method, projection_dim, device)
        self.sim = SimilarityCalculator(similarity_type)
        self.normalize_features = normalize_features
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.normalize(x, p=2, dim=1) if self.normalize_features else x
    def extract_and_fuse_text_features(self, texts: List[str]) -> torch.Tensor:
        # 评估阶段禁用 BN/Dropout 的训练行为
        self.fusion.text_projector.eval()
        with torch.no_grad():
            t = self.text_extractor.encode_with_grad(texts)
        return self._norm(self.fusion.fuse_text_features(t))
    def extract_and_fuse_image_features(self, images: List[Image.Image]) -> torch.Tensor:
        # 评估阶段禁用 BN/Dropout 的训练行为
        self.fusion.image_projector.eval()
        with torch.no_grad():
            i = self.image_extractor.encode_with_grad(images)
        return self._norm(self.fusion.fuse_image_features(i))
    def build_image_index(self, images_dict: Dict[str, Image.Image], batch_size: int = 32) -> Dict[str, torch.Tensor]:
        feats = {}
        keys = list(images_dict.keys())
        for s in range(0, len(keys), batch_size):
            batch_ids = keys[s:s+batch_size]
            batch_imgs = [images_dict[k] for k in batch_ids if images_dict[k] is not None]
            valid_ids = [k for k in batch_ids if images_dict[k] is not None]
            if not batch_imgs:
                continue
            bf = self.extract_and_fuse_image_features(batch_imgs)
            for j, img_id in enumerate(valid_ids):
                feats[img_id] = bf[j].detach().cpu()
        return feats

def info_nce_loss(text_feats: torch.Tensor, image_feats: torch.Tensor, temp: float) -> torch.Tensor:
    logits = torch.mm(text_feats, image_feats.t()) / temp
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_t = torch.nn.functional.cross_entropy(logits, labels)
    loss_i = torch.nn.functional.cross_entropy(logits.t(), labels)
    return (loss_t + loss_i) * 0.5


## 顶层解冻与优化器分组
仅解冻顶层，降低学习率，控制训练稳定性。


In [6]:

def unfreeze_text_top_layers(text_extractor: TextFeatureExtractor, last_n_layers: int = 2):
    for p in text_extractor.model.parameters():
        p.requires_grad = False
    enc = text_extractor.model.encoder
    total_layers = len(enc.layer)
    for i in range(total_layers - last_n_layers, total_layers):
        for p in enc.layer[i].parameters():
            p.requires_grad = True
        enc.layer[i].train()
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        for p in text_extractor.model.pooler.parameters():
            p.requires_grad = True
        text_extractor.model.pooler.train()
    text_extractor.model.eval()

def unfreeze_image_top_block(image_extractor: ImageFeatureExtractor, unfreeze_layer4: bool = True):
    for p in image_extractor.model.parameters():
        p.requires_grad = False
    if unfreeze_layer4 and hasattr(image_extractor.model, 'layer4'):
        for p in image_extractor.model.layer4.parameters():
            p.requires_grad = True
        image_extractor.model.layer4.train()
    image_extractor.model.eval()

def build_optimizer(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                   lr_proj: float = 1e-3, lr_text_top: float = 5e-5, lr_img_top: float = 1e-4, weight_decay: float = 1e-4):
    params = []
    params.append({
        'params': list(model.fusion.text_projector.parameters()) + list(model.fusion.image_projector.parameters()),
        'lr': lr_proj,
        'weight_decay': weight_decay
    })
    text_top_params = []
    enc = text_extractor.model.encoder
    for mod in enc.layer[-2:]:
        text_top_params += list(mod.parameters())
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        text_top_params += list(text_extractor.model.pooler.parameters())
    params.append({
        'params': [p for p in text_top_params if p.requires_grad],
        'lr': lr_text_top,
        'weight_decay': 0.0
    })
    img_top_params = []
    if hasattr(image_extractor.model, 'layer4'):
        img_top_params += list(image_extractor.model.layer4.parameters())
    params.append({
        'params': [p for p in img_top_params if p.requires_grad],
        'lr': lr_img_top,
        'weight_decay': 0.0
    })
    optimizer = torch.optim.Adam(params)
    return optimizer

# LLRD（Layer-wise LR Decay）优化器构建：为BERT顶层设置逐层衰减学习率
def build_llrd_optimizer(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                         lr_proj: float = 1e-3, lr_text_max: float = 5e-5, lr_img_top: float = 1e-4, decay: float = 0.9,
                         last_n_layers: int = 2, weight_decay: float = 1e-4):
    params = []
    # 投影层（文本/图像）
    params.append({
        'params': list(model.fusion.text_projector.parameters()) + list(model.fusion.image_projector.parameters()),
        'lr': lr_proj,
        'weight_decay': weight_decay
    })
    # 文本顶层：逐层衰减（最顶层lr=lr_text_max，其次乘以decay）
    enc = text_extractor.model.encoder
    total_layers = len(enc.layer)
    start_idx = max(0, total_layers - last_n_layers)
    # 从顶层到次顶层设置lr
    order = 0
    for i in range(total_layers - 1, start_idx - 1, -1):
        group_lr = lr_text_max * (decay ** order)
        params.append({
            'params': [p for p in enc.layer[i].parameters() if p.requires_grad],
            'lr': group_lr,
            'weight_decay': 0.0
        })
        order += 1
    # pooler（若存在），使用与最顶层一致的lr
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        params.append({
            'params': [p for p in text_extractor.model.pooler.parameters() if p.requires_grad],
            'lr': lr_text_max,
            'weight_decay': 0.0
        })
    # 图像顶层（layer4）
    if hasattr(image_extractor.model, 'layer4'):
        params.append({
            'params': [p for p in image_extractor.model.layer4.parameters() if p.requires_grad],
            'lr': lr_img_top,
            'weight_decay': 0.0
        })
    optimizer = torch.optim.Adam(params)
    return optimizer

# Warmup + Cosine 学习率调度器
def build_warmup_cosine_scheduler(optimizer: torch.optim.Optimizer, warmup_ratio: float, min_lr_ratio: float, total_steps: int):
    warmup_steps = max(1, int(total_steps * max(0.0, min(warmup_ratio, 0.5))))
    min_ratio = max(0.0, min(min_lr_ratio, 1.0))
    def lr_lambda(current_step: int):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        # Cosine decay from 1.0 -> min_ratio
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_ratio + (1.0 - min_ratio) * cosine
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


## 数据加载与训练参数
保持与基线一致的查询数据加载；图片按批次流式以控制显存。默认使用较小批次以便顺利运行。


In [7]:

loader = DataLoader()
train_df = loader.load_queries(split='train')
valid_df = loader.load_queries(split='valid')

# # 训练与流式参数（默认较小，确保顺利运行；可按需增大）
# train_image_batch_size = 500
# max_train_batches = 1
# epochs_per_batch = 1
# train_step_batch_size = 32
# valid_imgs_max_samples = 100

# 训练与流式参数（按需调整）：实际用
train_image_batch_size = 15000 ## 一个大batch有这么多图片样本。
max_train_batches = 10 ## 总共加载多少个大batch。
epochs_per_batch = 5 ## 每个大batch训练几个epoch。
train_step_batch_size = 32 ## 每个大batch里面训练的时候的小batch_size是多少。
valid_imgs_max_samples = 30000

use_amp = True
temperature = 0.07

# 微调与调度参数
last_n_layers = 2  # 顶层解冻层数
warmup_ratio = 0.1  # 预热比例
min_lr_ratio = 0.1  # 余弦最低学习率相对比例
use_grad_checkpoint = False  # 可选：启用BERT梯度检查点以降低显存


2025-11-09 16:47:23,214 - INFO - 初始化数据加载器，数据目录: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval
2025-11-09 16:47:23,216 - INFO - 加载train查询数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_train_queries.jsonl
加载train查询数据: 248786it [00:00, 285464.12it/s]
2025-11-09 16:47:24,159 - INFO - 成功加载train查询数据，共248786条
2025-11-09 16:47:24,170 - INFO - 加载valid查询数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_valid_queries.jsonl
加载valid查询数据: 5008it [00:00, 329729.29it/s]
2025-11-09 16:47:24,190 - INFO - 成功加载valid查询数据，共5008条


## 初始化模型并执行顶层解冻
解冻文本最后2层与池化；解冻图像 `layer4`。


In [8]:
image_extractor = ImageFeatureExtractor(
    device=device, 
    weights_path="/mnt/d/HuggingFaceModels/models--timm--resnet50.a1_in1k/snapshots/767268603ca0cb0bfe326fa87277f19c419566ef/model.safetensors"
)

text_extractor = TextFeatureExtractor(device=device, cache_dir=cache_dir)
model = CrossModalRetrievalModel(
    text_extractor, image_extractor, 
    fusion_method='projection', projection_dim=512, similarity_type='cosine', normalize_features=True, device=device
)

# 可选：启用BERT梯度检查点以降低显存
if use_grad_checkpoint and hasattr(text_extractor.model, 'gradient_checkpointing_enable'):
    text_extractor.model.gradient_checkpointing_enable()

unfreeze_text_top_layers(text_extractor, last_n_layers=last_n_layers)
unfreeze_image_top_block(image_extractor, unfreeze_layer4=True)

# 使用LLRD优化器：为文本顶层设置逐层衰减的学习率
optim = build_llrd_optimizer(model, text_extractor, image_extractor,
                             lr_proj=1e-3, lr_text_max=5e-5, lr_img_top=1e-4, decay=0.9,
                             last_n_layers=last_n_layers, weight_decay=1e-4)
scaler = GradScaler(enabled=(device.type == 'cuda' and use_amp))
print('Optim groups:', len(optim.param_groups))


Optim groups: 5


  scaler = GradScaler(enabled=(device.type == 'cuda' and use_amp))


## 训练循环：按批次流式构建配对并微调顶层
仅使用配对中的第一张可用图片；文本与图像编码器顶层参与反向传播。


In [9]:

def build_batch_pairs(train_df, img_dict: Dict[str, Image.Image]) -> List[Tuple[str, Image.Image, str]]:
    pairs = []
    if 'item_ids' in train_df.columns:
        for _, row in train_df.iterrows():
            q = row.get('query_text', None)
            ids = row.get('item_ids', [])
            if not q or not isinstance(ids, list) or not ids:
                continue
            chosen_img = None
            chosen_id = None
            for iid in ids:
                sid = str(iid)
                if sid in img_dict and img_dict[sid] is not None:
                    chosen_img = img_dict[sid]
                    chosen_id = sid
                    break
            if chosen_img is not None:
                pairs.append((q, chosen_img, chosen_id))
    return pairs

def train_one_batch(pairs: List[Tuple[str, Image.Image, str]], epochs: int, step_bs: int):
    model.fusion.text_projector.train()
    model.fusion.image_projector.train()
    text_extractor.model.encoder.layer[-1].train()
    text_extractor.model.encoder.layer[-2].train()
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        text_extractor.model.pooler.train()
    if hasattr(image_extractor.model, 'layer4'):
        image_extractor.model.layer4.train()

    # 为当前大batch构建 Warmup+Cosine 学习率调度器（按总steps）
    steps_per_epoch = math.ceil(len(pairs) / max(1, step_bs))
    total_steps = epochs * max(1, steps_per_epoch)
    scheduler = build_warmup_cosine_scheduler(optim, warmup_ratio=warmup_ratio, min_lr_ratio=min_lr_ratio, total_steps=total_steps)

    for e in range(epochs):
        running_loss = 0.0
        steps = 0
        for s in range(0, len(pairs), step_bs):
            batch = pairs[s:s+step_bs]
            if not batch:
                continue
            texts = [t for (t, _, _) in batch]
            imgs = [im for (_, im, _) in batch]

            optim.zero_grad()
            if use_amp and device.type == 'cuda':
                with autocast(enabled=True):
                    t_feats = text_extractor.encode_with_grad(texts)
                    i_feats = image_extractor.encode_with_grad(imgs)
                    t_proj = model._norm(model.fusion.fuse_text_features(t_feats))
                    i_proj = model._norm(model.fusion.fuse_image_features(i_feats))
                    loss = info_nce_loss(t_proj, i_proj, temp=temperature)
                scaler.scale(loss).backward()
                scaler.unscale_(optim)
                torch.nn.utils.clip_grad_norm_(
                    list(model.fusion.text_projector.parameters()) + list(model.fusion.image_projector.parameters()),
                    max_norm=5.0
                )
                scaler.step(optim)
                scaler.update()
                scheduler.step()
            else:
                t_feats = text_extractor.encode_with_grad(texts)
                i_feats = image_extractor.encode_with_grad(imgs)
                t_proj = model._norm(model.fusion.fuse_text_features(t_feats))
                i_proj = model._norm(model.fusion.fuse_image_features(i_feats))
                loss = info_nce_loss(t_proj, i_proj, temp=temperature)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    list(model.fusion.text_projector.parameters()) + list(model.fusion.image_projector.parameters()),
                    max_norm=5.0
                )
                optim.step()
                scheduler.step()
            running_loss += loss.item()
            steps += 1
            if (steps % 100) == 0:
                print('Current LRs:', [pg['lr'] for pg in optim.param_groups])
        print(f"Epoch {e+1}/{epochs}: avg loss={running_loss/max(steps,1):.4f}")

# 流式加载图片与训练
batch_idx = 0
for image_batch in loader.load_images_batch(split='train', batch_size=train_image_batch_size, max_batches=max_train_batches):
    batch_idx += 1
    img_map = {item['img_id']: item['image'] for item in image_batch}
    pairs = build_batch_pairs(train_df, img_map)
    print(f"Batch {batch_idx}: images={len(img_map)}, usable_pairs={len(pairs)}")
    if not pairs:
        del img_map
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        continue
    train_one_batch(pairs, epochs=epochs_per_batch, step_bs=train_step_batch_size)
    del img_map
    if device.type == 'cuda':
        torch.cuda.empty_cache()


2025-11-09 16:47:25,827 - INFO - 批量加载train图片数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_train_imgs.tsv
实际加载train图片数据:  11%|█████▊                                             | 14809/129380 [00:06<00:42, 2723.60it/s]

Batch 1: images=15000, usable_pairs=29127


  with autocast(enabled=True):


Current LRs: [0.00021978021978021978, 1.0989010989010989e-05, 9.89010989010989e-06, 1.0989010989010989e-05, 2.1978021978021977e-05]


实际加载train图片数据:  11%|█████▊                                             | 14809/129380 [00:17<00:42, 2723.60it/s]

Current LRs: [0.00043956043956043956, 2.1978021978021977e-05, 1.978021978021978e-05, 2.1978021978021977e-05, 4.3956043956043955e-05]
Current LRs: [0.0006593406593406593, 3.296703296703297e-05, 2.9670329670329673e-05, 3.296703296703297e-05, 6.593406593406594e-05]
Current LRs: [0.0008791208791208791, 4.3956043956043955e-05, 3.956043956043956e-05, 4.3956043956043955e-05, 8.791208791208791e-05]
Current LRs: [0.0009997325167765264, 4.9986625838826325e-05, 4.498796325494369e-05, 4.9986625838826325e-05, 9.997325167765265e-05]
Current LRs: [0.0009972253784704896, 4.986126892352448e-05, 4.4875142031172034e-05, 4.986126892352448e-05, 9.972253784704895e-05]
Current LRs: [0.000992093743812234, 4.96046871906117e-05, 4.464421847155053e-05, 4.96046871906117e-05, 9.92093743812234e-05]
Current LRs: [0.0009843677272744393, 4.921838636372196e-05, 4.429654772734977e-05, 4.921838636372196e-05, 9.843677272744393e-05]
Current LRs: [0.0009740926681942644, 4.870463340971322e-05, 4.38341700687419e-05, 4.8704633

实际加载train图片数据:  12%|█████▉                                             | 15000/129380 [03:08<6:41:39,  4.75it/s]

Epoch 5/5: avg loss=0.0571


实际加载train图片数据:  23%|███████████▊                                       | 29830/129380 [03:13<00:36, 2734.47it/s]

Batch 2: images=15000, usable_pairs=28909
Current LRs: [0.00022123893805309737, 1.1061946902654869e-05, 9.955752212389382e-06, 1.1061946902654869e-05, 2.2123893805309738e-05]
Current LRs: [0.00044247787610619474, 2.2123893805309738e-05, 1.9911504424778764e-05, 2.2123893805309738e-05, 4.4247787610619477e-05]


实际加载train图片数据:  23%|███████████▊                                       | 29830/129380 [03:27<00:36, 2734.47it/s]

Current LRs: [0.0006637168141592921, 3.3185840707964604e-05, 2.9867256637168145e-05, 3.3185840707964604e-05, 6.637168141592921e-05]
Current LRs: [0.0008849557522123895, 4.4247787610619477e-05, 3.982300884955753e-05, 4.4247787610619477e-05, 8.849557522123895e-05]
Current LRs: [0.000999690861483414, 4.9984543074170704e-05, 4.498608876675363e-05, 4.9984543074170704e-05, 9.996908614834141e-05]
Current LRs: [0.0009970638991512201, 4.9853194957561005e-05, 4.486787546180491e-05, 4.9853194957561005e-05, 9.970638991512201e-05]
Current LRs: [0.0009917719712445245, 4.958859856222623e-05, 4.462973870600361e-05, 4.958859856222623e-05, 9.917719712445247e-05]
Current LRs: [0.00098384662315476, 4.9192331157738004e-05, 4.4273098041964204e-05, 4.9192331157738004e-05, 9.838466231547601e-05]
Current LRs: [0.0009733350981951647, 4.866675490975823e-05, 4.380007941878241e-05, 4.866675490975823e-05, 9.733350981951646e-05]
Epoch 1/5: avg loss=1.2562
Current LRs: [0.0009597272139897383, 4.798636069948692e-05, 4

实际加载train图片数据:  35%|█████████████████▋                                 | 44964/129380 [06:08<00:23, 3587.52it/s]

Batch 3: images=15000, usable_pairs=28595
Current LRs: [0.00022371364653243848, 1.1185682326621925e-05, 1.0067114093959732e-05, 1.1185682326621925e-05, 2.237136465324385e-05]
Current LRs: [0.00044742729306487697, 2.237136465324385e-05, 2.0134228187919465e-05, 2.237136465324385e-05, 4.47427293064877e-05]
Current LRs: [0.0006711409395973155, 3.3557046979865775e-05, 3.02013422818792e-05, 3.3557046979865775e-05, 6.711409395973155e-05]


实际加载train图片数据:  35%|█████████████████▋                                 | 44964/129380 [06:27<00:23, 3587.52it/s]

Current LRs: [0.0008948545861297539, 4.47427293064877e-05, 4.026845637583893e-05, 4.47427293064877e-05, 8.94854586129754e-05]
Current LRs: [0.0009996146352894524, 4.998073176447262e-05, 4.498265858802535e-05, 4.998073176447262e-05, 9.996146352894524e-05]
Current LRs: [0.0009967918965714344, 4.983959482857172e-05, 4.485563534571455e-05, 4.983959482857172e-05, 9.967918965714344e-05]
Current LRs: [0.0009912459282684686, 4.9562296413423426e-05, 4.460606677208108e-05, 4.9562296413423426e-05, 9.912459282684685e-05]
Current LRs: [0.0009830105334666796, 4.9150526673333975e-05, 4.423547400600058e-05, 4.9150526673333975e-05, 9.830105334666795e-05]
Epoch 1/5: avg loss=1.1742
Current LRs: [0.0009595663234991816, 4.797831617495908e-05, 4.3180484557463175e-05, 4.797831617495908e-05, 9.595663234991816e-05]
Current LRs: [0.0009437746306750542, 4.718873153375271e-05, 4.246985838037744e-05, 4.718873153375271e-05, 9.437746306750542e-05]
Current LRs: [0.0009255828524145411, 4.627914262072706e-05, 4.165122

实际加载train图片数据:  35%|█████████████████▋                                 | 45000/129380 [08:56<4:27:02,  5.27it/s]

Epoch 5/5: avg loss=0.0398


实际加载train图片数据:  46%|███████████████████████▌                           | 59714/129380 [09:00<00:19, 3592.19it/s]

Batch 4: images=15000, usable_pairs=29155
Current LRs: [0.0002192982456140351, 1.0964912280701754e-05, 9.868421052631579e-06, 1.0964912280701754e-05, 2.1929824561403507e-05]
Current LRs: [0.0004385964912280702, 2.1929824561403507e-05, 1.9736842105263158e-05, 2.1929824561403507e-05, 4.3859649122807014e-05]
Current LRs: [0.0006578947368421054, 3.289473684210527e-05, 2.9605263157894742e-05, 3.289473684210527e-05, 6.578947368421054e-05]


实际加载train图片数据:  46%|███████████████████████▌                           | 59714/129380 [09:19<00:19, 3592.19it/s]

Current LRs: [0.0008771929824561404, 4.3859649122807014e-05, 3.9473684210526316e-05, 4.3859649122807014e-05, 8.771929824561403e-05]
Current LRs: [0.000999744769921361, 4.998723849606805e-05, 4.498851464646125e-05, 4.998723849606805e-05, 9.99744769921361e-05]
Current LRs: [0.0009972688047930772, 4.986344023965386e-05, 4.4877096215688477e-05, 4.986344023965386e-05, 9.972688047930772e-05]
Current LRs: [0.0009921731999528022, 4.9608659997640114e-05, 4.46477939978761e-05, 4.9608659997640114e-05, 9.921731999528023e-05]
Current LRs: [0.0009844878002022593, 4.9224390010112965e-05, 4.430195100910167e-05, 4.9224390010112965e-05, 9.844878002022593e-05]
Current LRs: [0.0009742576186927644, 4.871288093463822e-05, 4.38415928411744e-05, 4.871288093463822e-05, 9.742576186927644e-05]
Epoch 1/5: avg loss=1.1431
Current LRs: [0.0009598531721142239, 4.799265860571119e-05, 4.319339274514008e-05, 4.799265860571119e-05, 9.598531721142238e-05]
Current LRs: [0.0009444441673792107, 4.722220836896054e-05, 4.2499

实际加载train图片数据:  46%|███████████████████████▋                           | 60000/129380 [11:50<2:54:03,  6.64it/s]

Epoch 5/5: avg loss=0.0368


实际加载train图片数据:  58%|█████████████████████████████▌                     | 74937/129380 [11:54<00:15, 3577.86it/s]

Batch 5: images=15000, usable_pairs=29347
Current LRs: [0.0002178649237472767, 1.0893246187363835e-05, 9.803921568627453e-06, 1.0893246187363835e-05, 2.178649237472767e-05]
Current LRs: [0.0004357298474945534, 2.178649237472767e-05, 1.9607843137254906e-05, 2.178649237472767e-05, 4.357298474945534e-05]


实际加载train图片数据:  58%|█████████████████████████████▌                     | 74937/129380 [12:10<00:15, 3577.86it/s]

Current LRs: [0.0006535947712418301, 3.2679738562091506e-05, 2.9411764705882354e-05, 3.2679738562091506e-05, 6.535947712418301e-05]
Current LRs: [0.0008714596949891068, 4.357298474945534e-05, 3.921568627450981e-05, 4.357298474945534e-05, 8.714596949891068e-05]
Current LRs: [0.0009997812719901233, 4.998906359950617e-05, 4.499015723955555e-05, 4.998906359950617e-05, 9.997812719901233e-05]
Current LRs: [0.0009974153965040112, 4.9870769825200563e-05, 4.488369284268051e-05, 4.9870769825200563e-05, 9.974153965040113e-05]
Current LRs: [0.0009924631492012948, 4.962315746006474e-05, 4.4660841714058265e-05, 4.962315746006474e-05, 9.924631492012948e-05]
Current LRs: [0.0009849531575117686, 4.924765787558843e-05, 4.4322892088029587e-05, 4.924765787558843e-05, 9.849531575117686e-05]
Current LRs: [0.0009749288344046789, 4.874644172023395e-05, 4.387179754821055e-05, 4.874644172023395e-05, 9.74928834404679e-05]
Epoch 1/5: avg loss=1.1183
Current LRs: [0.0009599460792587228, 4.799730396293614e-05, 4.31

实际加载train图片数据:  58%|█████████████████████████████▌                     | 75000/129380 [14:47<2:51:33,  5.28it/s]

Epoch 5/5: avg loss=0.0351


实际加载train图片数据:  70%|███████████████████████████████████▍               | 89939/129380 [14:51<00:11, 3459.69it/s]

Batch 6: images=15000, usable_pairs=29191
Current LRs: [0.0002192982456140351, 1.0964912280701754e-05, 9.868421052631579e-06, 1.0964912280701754e-05, 2.1929824561403507e-05]


实际加载train图片数据:  70%|███████████████████████████████████▍               | 89939/129380 [15:01<00:11, 3459.69it/s]

Current LRs: [0.0004385964912280702, 2.1929824561403507e-05, 1.9736842105263158e-05, 2.1929824561403507e-05, 4.3859649122807014e-05]
Current LRs: [0.0006578947368421054, 3.289473684210527e-05, 2.9605263157894742e-05, 3.289473684210527e-05, 6.578947368421054e-05]
Current LRs: [0.0008771929824561404, 4.3859649122807014e-05, 3.9473684210526316e-05, 4.3859649122807014e-05, 8.771929824561403e-05]
Current LRs: [0.0009997453906337855, 4.998726953168927e-05, 4.4988542578520345e-05, 4.998726953168927e-05, 9.997453906337855e-05]
Current LRs: [0.0009972754408986632, 4.986377204493316e-05, 4.487739484043985e-05, 4.986377204493316e-05, 9.972754408986632e-05]
Current LRs: [0.00099219218106604, 4.9609609053302006e-05, 4.46486481479718e-05, 4.9609609053302006e-05, 9.921921810660401e-05]
Current LRs: [0.0009845253112558735, 4.922626556279368e-05, 4.430363900651431e-05, 4.922626556279368e-05, 9.845253112558736e-05]
Current LRs: [0.0009743196269263065, 4.871598134631532e-05, 4.3844383211683796e-05, 4.871

实际加载train图片数据:  70%|███████████████████████████████████▍               | 90000/129380 [17:42<2:06:29,  5.19it/s]

Epoch 5/5: avg loss=0.0342


实际加载train图片数据:  81%|████████████████████████████████████████▌         | 104916/129380 [17:47<00:06, 3502.48it/s]

Batch 7: images=15000, usable_pairs=28504
Current LRs: [0.00022471910112359551, 1.1235955056179776e-05, 1.0112359550561798e-05, 1.1235955056179776e-05, 2.2471910112359552e-05]
Current LRs: [0.00044943820224719103, 2.2471910112359552e-05, 2.0224719101123596e-05, 2.2471910112359552e-05, 4.4943820224719104e-05]


实际加载train图片数据:  81%|████████████████████████████████████████▌         | 104916/129380 [18:02<00:06, 3502.48it/s]

Current LRs: [0.0006741573033707865, 3.370786516853933e-05, 3.0337078651685396e-05, 3.370786516853933e-05, 6.741573033707866e-05]
Current LRs: [0.0008988764044943821, 4.4943820224719104e-05, 4.044943820224719e-05, 4.4943820224719104e-05, 8.988764044943821e-05]
Current LRs: [0.000999582312286147, 4.997911561430736e-05, 4.498120405287662e-05, 4.997911561430736e-05, 9.995823122861471e-05]
Current LRs: [0.0009966862238090929, 4.983431119045465e-05, 4.485088007140918e-05, 4.983431119045465e-05, 9.96686223809093e-05]
Current LRs: [0.0009910498775867961, 4.955249387933981e-05, 4.459724449140583e-05, 4.955249387933981e-05, 9.910498775867962e-05]
Current LRs: [0.0009827078505561919, 4.9135392527809596e-05, 4.422185327502864e-05, 4.9135392527809596e-05, 9.827078505561919e-05]
Epoch 1/5: avg loss=1.0651
Current LRs: [0.0009594541395226342, 4.797270697613171e-05, 4.317543627851854e-05, 4.797270697613171e-05, 9.594541395226342e-05]
Current LRs: [0.0009435883633158757, 4.7179418165793784e-05, 4.2461

实际加载train图片数据:  81%|████████████████████████████████████████▌         | 105000/129380 [20:34<1:14:52,  5.43it/s]

Epoch 5/5: avg loss=0.0321


实际加载train图片数据:  93%|██████████████████████████████████████████████▎   | 119925/129380 [20:38<00:02, 3641.24it/s]

Batch 8: images=15000, usable_pairs=28894
Current LRs: [0.00022172949002217298, 1.1086474501108649e-05, 9.977827050997784e-06, 1.1086474501108649e-05, 2.2172949002217298e-05]
Current LRs: [0.00044345898004434595, 2.2172949002217298e-05, 1.9955654101995567e-05, 2.2172949002217298e-05, 4.4345898004434597e-05]


实际加载train图片数据:  93%|██████████████████████████████████████████████▎   | 119925/129380 [20:52<00:02, 3641.24it/s]

Current LRs: [0.0006651884700665188, 3.325942350332594e-05, 2.993348115299335e-05, 3.325942350332594e-05, 6.651884700665188e-05]
Current LRs: [0.0008869179600886919, 4.4345898004434597e-05, 3.9911308203991135e-05, 4.4345898004434597e-05, 8.869179600886919e-05]
Current LRs: [0.0009996772136967632, 4.9983860684838164e-05, 4.4985474616354346e-05, 4.9983860684838164e-05, 9.996772136967633e-05]
Current LRs: [0.0009970182775954116, 4.985091387977058e-05, 4.486582249179352e-05, 4.985091387977058e-05, 9.970182775954115e-05]
Current LRs: [0.0009916894024438153, 4.958447012219076e-05, 4.462602310997169e-05, 4.958447012219076e-05, 9.916894024438152e-05]
Current LRs: [0.000983722416407788, 4.9186120820389405e-05, 4.4267508738350466e-05, 4.9186120820389405e-05, 9.837224164077881e-05]
Current LRs: [0.0009731649044926787, 4.8658245224633935e-05, 4.3792420702170546e-05, 4.8658245224633935e-05, 9.731649044926787e-05]
Epoch 1/5: avg loss=1.0352
Current LRs: [0.0009596491049974418, 4.798245524987209e-05,

实际加载train图片数据: 100%|████████████████████████████████████████████████████| 129380/129380 [23:29<00:00, 91.76it/s]


Batch 9: images=9380, usable_pairs=18432
Current LRs: [0.00034722222222222224, 1.736111111111111e-05, 1.5625e-05, 1.736111111111111e-05, 3.472222222222222e-05]
Current LRs: [0.0006944444444444445, 3.472222222222222e-05, 3.125e-05, 3.472222222222222e-05, 6.944444444444444e-05]
Current LRs: [0.0009999524043671985, 4.999762021835992e-05, 4.499785819652393e-05, 4.999762021835992e-05, 9.999524043671984e-05]
Current LRs: [0.0009958601815141803, 4.979300907570902e-05, 4.4813708168138116e-05, 4.979300907570902e-05, 9.958601815141804e-05]
Current LRs: [0.0009852261668953826, 4.9261308344769126e-05, 4.433517751029221e-05, 4.9261308344769126e-05, 9.852261668953825e-05]
Epoch 1/5: avg loss=1.0325
Current LRs: [0.0009511508380951625, 4.7557541904758126e-05, 4.280178771428231e-05, 4.7557541904758126e-05, 9.511508380951625e-05]
Current LRs: [0.0009235541937120755, 4.6177709685603776e-05, 4.15599387170434e-05, 4.6177709685603776e-05, 9.235541937120755e-05]
Current LRs: [0.0008904766522903719, 4.452383

## 保存：投影层 + 已解冻顶层 + 优化器
保存 BERT 的最后2层 + pooler、ResNet50 的 layer4、投影层、优化器。


In [10]:
def save_unfreeze_checkpoint(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                             optimizer: torch.optim.Optimizer, save_path: str, last_n_layers: int):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    ckpt = {
        'projection_dim': model.fusion.projection_dim,
        'last_n_layers': last_n_layers,
        'fusion': {
            'text_projector': model.fusion.text_projector.state_dict(),
            'image_projector': model.fusion.image_projector.state_dict(),
        },
        'text_unfrozen': {},
        'image_unfrozen': {},
        'optimizer': optimizer.state_dict(),
    }
    enc = text_extractor.model.encoder
    total_layers = len(enc.layer)
    start_idx = max(0, total_layers - last_n_layers)
    for i in range(start_idx, total_layers):
        ckpt['text_unfrozen'][f'encoder_layer_{i}'] = enc.layer[i].state_dict()
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        ckpt['text_unfrozen']['pooler'] = text_extractor.model.pooler.state_dict()
    if hasattr(image_extractor.model, 'layer4'):
        ckpt['image_unfrozen']['layer4'] = image_extractor.model.layer4.state_dict()
    torch.save(ckpt, save_path)
    print(f"Checkpoint saved to: {save_path}")

# 保存一次
save_unfreeze_checkpoint(model, text_extractor, image_extractor, optim, save_path, last_n_layers)


Checkpoint saved to: /mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights/step_3_1_1.pth


## 加载：恢复解冻顶层与投影层权重，继续训练
加载后会自动再次执行顶层解冻，并恢复优化器状态（如提供）。


In [11]:
def load_unfreeze_checkpoint(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                             optimizer: torch.optim.Optimizer, load_path: str):
    ckpt = torch.load(load_path, map_location='cpu')
    model.fusion.text_projector.load_state_dict(ckpt['fusion']['text_projector'])
    model.fusion.image_projector.load_state_dict(ckpt['fusion']['image_projector'])
    ln = ckpt.get('last_n_layers', 2)
    unfreeze_text_top_layers(text_extractor, last_n_layers=ln)
    unfreeze_image_top_block(image_extractor, unfreeze_layer4=True)
    enc = text_extractor.model.encoder
    for k, v in ckpt['text_unfrozen'].items():
        if k.startswith('encoder_layer_'):
            idx = int(k.split('_')[-1])
            if 0 <= idx < len(enc.layer):
                enc.layer[idx].load_state_dict(v)
    if 'pooler' in ckpt['text_unfrozen'] and hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        text_extractor.model.pooler.load_state_dict(ckpt['text_unfrozen']['pooler'])
    if 'layer4' in ckpt['image_unfrozen'] and hasattr(image_extractor.model, 'layer4'):
        image_extractor.model.layer4.load_state_dict(ckpt['image_unfrozen']['layer4'])
    if optimizer is not None and 'optimizer' in ckpt:
        optimizer.load_state_dict(ckpt['optimizer'])
    print(f"Checkpoint loaded from: {load_path}")

# 测试加载
# load_unfreeze_checkpoint(model, text_extractor, image_extractor, optim, save_path)


## 验证评估：Recall@1/5/10 与 MeanRecall
- 基于验证集构建图像索引
- 对每条查询计算相似度并统计召回（优先使用 FAISS；不可用则 Torch 回退）


In [12]:
valid_imgs = loader.create_img_id_to_image_dict(
    split='valid', 
    max_samples=valid_imgs_max_samples,
    # batch_size=3000,
    # max_batches=10
)

valid_queries = []
if 'item_ids' in valid_df.columns:
    for _, row in valid_df.iterrows():
        q = row.get('query_text', None)
        ids = [str(i) for i in row.get('item_ids', [])] if isinstance(row.get('item_ids', []), list) else []
        if q and ids:
            valid_queries.append((q, ids))
print(f'Usable valid queries: {len(valid_queries)}')

image_index = model.build_image_index(valid_imgs, batch_size=32)
all_image_ids = list(image_index.keys())
all_image_feats = torch.stack([image_index[i] for i in all_image_ids]) if all_image_ids else torch.empty((0, 512))
faiss_index = None
if HAS_FAISS and all_image_feats.size(0) > 0:
    d = all_image_feats.size(1)
    faiss_index = faiss.IndexFlatIP(d)
    feats_np = all_image_feats.detach().cpu().numpy().astype('float32')
    faiss_index.add(feats_np)

all_image_feats = all_image_feats.to(device)

def compute_recall_at_k(k_values, queries):
    recalls = {k: 0 for k in k_values}
    total = 0
    for q_text, gt_ids in tqdm(queries, desc='Evaluate'):
        if all_image_feats.size(0) == 0:
            continue
        q_feat = model.extract_and_fuse_text_features([q_text])
        if faiss_index is not None:
            q_np = q_feat.detach().cpu().numpy().astype('float32')
            _, I = faiss_index.search(q_np, max(k_values))
            top_idx = I[0].tolist()
            top_ids = [all_image_ids[i] for i in top_idx]
        else:
            sims = model.sim.calculate_similarity(q_feat, all_image_feats)
            _, top_idx = torch.topk(sims[0], k=max(k_values))
            top_ids = [all_image_ids[i] for i in top_idx.tolist()]
        total += 1
        for k in k_values:
            if any(g in set(top_ids[:k]) for g in gt_ids):
                recalls[k] += 1
    return {k: (recalls[k] / total if total > 0 else 0.0) for k in k_values}, total

rec, total_q = compute_recall_at_k([1,5,10], valid_queries)
mean_recall = (rec.get(1,0)+rec.get(5,0)+rec.get(10,0))/3 if total_q>0 else 0.0
print(f'Recall@1={rec.get(1,0):.4f}, Recall@5={rec.get(5,0):.4f}, Recall@10={rec.get(10,0):.4f}, MeanRecall={mean_recall:.4f} (N={total_q})')


2025-11-09 17:12:57,142 - INFO - 批量加载valid图片数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_valid_imgs.tsv
实际加载valid图片数据: 100%|████████████████████████████████████████████████████| 29806/29806 [00:14<00:00, 2007.07it/s]
2025-11-09 17:13:14,562 - INFO - 成功创建valid图片映射字典，共29806张图片


Usable valid queries: 5008


Evaluate: 100%|████████████████████████████████████████████████████████████████████| 5008/5008 [00:33<00:00, 149.39it/s]

Recall@1=0.0587, Recall@5=0.1979, Recall@10=0.3071, MeanRecall=0.1879 (N=5008)





In [13]:
with open("step_3_1-1_乾-基线.finishflag", "w") as f:
    f.write("finish")

import IPython
def kill_current_kernel():
    '''杀死当前的kernel释放内存空间。'''
    IPython.Application.instance().kernel.do_shutdown(True)
kill_current_kernel()