# 步骤 2.3-3：改用 Mean-Pooling 文本特征 + LR Schedule (屯-v3)

在 `step_2_3-3_屯-mean_pooling.ipynb` 的结构上做最小改动：
- 文本使用基于 attention_mask 的 mean-pooling；图像使用 ResNet50 特征
- 保留 InfoNCE 损失、AMP 与流式训练；投影头为单层线性映射（未改用两层投影，以与前一版保持最小差异）
- 新增：学习率调度器（warmup + cosine），以进一步降低 loss 并提升稳定性


In [1]:
import os

# Environment setup (same as the baseline; point these at your local mirror/cache).
# These must be set BEFORE transformers / timm / huggingface_hub are imported:
# huggingface_hub reads some of them (e.g. HF_ENDPOINT) only once, at import time,
# so assigning them after those imports has no effect for this kernel.
cache_dir = "/mnt/d/HuggingFaceModels/"
os.environ['TORCH_HOME'] = cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ['https_proxy'] = 'http://127.0.0.1:7890'
os.environ['http_proxy'] = 'http://127.0.0.1:7890'
os.environ['all_proxy'] = 'socks5://127.0.0.1:7890'
os.environ["WANDB_DISABLED"] = "true"
# NOTE(review): an empty CA bundle is a known trick that disables TLS certificate
# verification in some `requests` versions; only keep it behind a trusted local proxy.
os.environ['CURL_CA_BUNDLE'] = ""
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

import sys
import math
import importlib
import importlib.util  # `import importlib` alone does not guarantee the `util` submodule is loaded
from typing import List, Dict, Tuple
import torch
import timm
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import torch.optim.lr_scheduler as lr_scheduler

# Optional: FAISS-accelerated retrieval. Import defensively so that a broken install
# (e.g. a NumPy version mismatch) only disables FAISS instead of aborting the notebook.
HAS_FAISS = False
faiss = None
try:
    spec = importlib.util.find_spec('faiss')
    if spec is not None:
        faiss = importlib.import_module('faiss')
        HAS_FAISS = True
except Exception as e:  # not BaseException: let KeyboardInterrupt / SystemExit propagate
    print(f'Faiss unavailable: {e.__class__.__name__}')
    HAS_FAISS = False

# Data loader from plan1_1/data_loader.py (path is relative to the notebook's CWD).
sys.path.append(os.path.abspath(os.path.join('.', 'Multimodal_Retrieval', 'plan1_1')))
from data_loader import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
# Location of this experiment's trainable-parameter checkpoint (written and read below).
checkpoint_name = 'step_2_3_3_v3.pth'
save_dir = '/mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights'
save_path = os.path.join(save_dir, checkpoint_name)

## 模型与特征模块
文本 mean-pooling 与 ResNet50 图像特征，投影到共享空间。


In [3]:
class TextFeatureExtractor:
    """BERT text encoder returning attention-mask mean-pooled token features.

    Tokenizer and model are loaded with ``local_files_only=True``, so the
    weights must already exist under ``cache_dir``.
    """

    def __init__(self, model_name='bert-base-chinese', device='cpu', cache_dir=None):
        self.device = device
        load_kwargs = dict(cache_dir=cache_dir, local_files_only=True)
        self.tokenizer = BertTokenizer.from_pretrained(model_name, **load_kwargs)
        self.model = BertModel.from_pretrained(model_name, **load_kwargs).to(device)
        self.model.eval()
        self.out_dim = 768

    def encode_with_grad(self, texts: List[str]) -> torch.Tensor:
        """Encode ``texts`` into ``[B, out_dim]`` mean-pooled features.

        No ``torch.no_grad()`` is applied here, so gradients flow into any
        BERT parameters with ``requires_grad=True``; callers decide.
        A single string is treated as a batch of one. Inputs are truncated
        to 128 tokens. Padding positions are excluded from the mean via the
        attention mask; if the tokenizer returns no mask, the first ([CLS])
        token embedding is returned instead.
        """
        if not texts:
            return torch.empty((0, self.out_dim), dtype=torch.float32, device=self.device)
        batch = [texts] if isinstance(texts, str) else texts
        encoded = self.tokenizer(batch, return_tensors='pt', padding=True, truncation=True, max_length=128).to(self.device)
        hidden = self.model(**encoded).last_hidden_state  # [B, L, 768]
        mask = encoded.get('attention_mask', None)
        if mask is None:
            return hidden[:, 0, :]
        weights = mask.unsqueeze(-1).type_as(hidden)       # [B, L, 1]
        token_count = weights.sum(dim=1).clamp(min=1)      # [B, 1]; avoid division by zero
        return (hidden * weights).sum(dim=1) / token_count

class ImageFeatureExtractor:
    """timm CNN backbone (``resnet50`` by default) with its classifier removed.

    ``num_classes=0`` makes the timm model return pooled features; ``out_dim``
    is fixed at 2048, which matches ResNet-50 (other ``model_name`` values
    may need a different ``out_dim``).
    """

    def __init__(self, model_name='resnet50', device='cpu', cache_dir=None):
        self.device = device
        backbone = timm.create_model(model_name, pretrained=True, num_classes=0, cache_dir=cache_dir)
        self.model = backbone.to(device)
        self.model.eval()
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])
        self.out_dim = 2048

    def encode_with_grad(self, images: List[Image.Image]) -> torch.Tensor:
        """Encode PIL images into ``[B, out_dim]`` features.

        Images are converted to RGB, resized to 224x224 and normalised.
        No ``torch.no_grad()`` is applied, so gradients reach any unfrozen
        backbone parameters.
        """
        if not images:
            return torch.empty((0, self.out_dim), dtype=torch.float32, device=self.device)
        pixel_batch = torch.stack([self.transform(picture.convert('RGB')) for picture in images])
        return self.model(pixel_batch.to(self.device))

class FeatureFusion:
    """Maps text (768-d) and image (2048-d) features into a shared space.

    With ``fusion_method='projection'`` each modality gets its own
    ``nn.Linear`` to ``projection_dim``; any other value is a pass-through
    and no projector attributes are created.
    """

    def __init__(self, fusion_method='projection', projection_dim=512, device=None):
        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        if fusion_method != 'projection':
            return
        # Text projector is created first (keeps the same RNG consumption order).
        self.text_projector = torch.nn.Linear(768, projection_dim).to(self.device)
        self.image_projector = torch.nn.Linear(2048, projection_dim).to(self.device)

    def fuse_text_features(self, text_features: torch.Tensor) -> torch.Tensor:
        """Project text features, or return them unchanged when not projecting."""
        if self.fusion_method != 'projection':
            return text_features
        return self.text_projector(text_features)

    def fuse_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project image features, or return them unchanged when not projecting."""
        if self.fusion_method != 'projection':
            return image_features
        return self.image_projector(image_features)

class SimilarityCalculator:
    """Pairwise similarity between two 2-D batches of feature vectors.

    ``similarity_type='cosine'`` L2-normalises both inputs first; any other
    value falls back to a raw dot product.
    """

    def __init__(self, similarity_type='cosine'):
        self.similarity_type = similarity_type

    def normalize_features(self, features: torch.Tensor) -> torch.Tensor:
        """Return ``features`` scaled to unit L2 norm along dim 1."""
        return torch.nn.functional.normalize(features, p=2, dim=1)

    def calculate_similarity(self, text_features: torch.Tensor, image_features: torch.Tensor) -> torch.Tensor:
        """Return the ``[N_text, N_image]`` similarity matrix."""
        lhs, rhs = text_features, image_features
        if self.similarity_type == 'cosine':
            lhs = self.normalize_features(lhs)
            rhs = self.normalize_features(rhs)
        return torch.mm(lhs, rhs.t())

class CrossModalRetrievalModel:
    """Text-image retrieval model: backbone extractors + projection heads + similarity.

    Backbone features are mapped into a shared space by ``FeatureFusion``,
    optionally L2-normalised (``normalize_features``), and scored by
    ``SimilarityCalculator``.
    """

    def __init__(self, text_extractor, image_extractor, fusion_method='projection', projection_dim=512, similarity_type='cosine', normalize_features=True, device=None):
        self.text_extractor = text_extractor
        self.image_extractor = image_extractor
        self.fusion = FeatureFusion(fusion_method, projection_dim, device)
        self.sim = SimilarityCalculator(similarity_type)
        self.normalize_features = normalize_features
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """L2-normalise rows of ``x`` when ``normalize_features`` is enabled."""
        return torch.nn.functional.normalize(x, p=2, dim=1) if self.normalize_features else x

    def extract_and_fuse_text_features(self, texts: List[str]) -> torch.Tensor:
        """Encode ``texts`` and map them into the shared space.

        Only the backbone runs under ``torch.no_grad()``; the projection is
        outside it, so gradients can still reach ``fusion.text_projector``.
        """
        with torch.no_grad():
            t = self.text_extractor.encode_with_grad(texts)
        return self._norm(self.fusion.fuse_text_features(t))

    def extract_and_fuse_image_features(self, images: List[Image.Image]) -> torch.Tensor:
        """Encode ``images`` and map them into the shared space (backbone under no_grad)."""
        with torch.no_grad():
            i = self.image_extractor.encode_with_grad(images)
        return self._norm(self.fusion.fuse_image_features(i))

    def build_image_index(self, images_dict: Dict[str, Image.Image], batch_size: int = 32) -> Dict[str, torch.Tensor]:
        """Embed every non-``None`` image in ``images_dict`` and return ``{img_id: cpu_feature}``.

        Args:
            images_dict: image id -> PIL image; ``None`` entries are skipped.
            batch_size: number of ids encoded per forward pass.
        """
        feats = {}
        keys = list(images_dict.keys())
        # The index is inference-only (features are detached below), so do not build
        # an autograd graph through the projection layer for every batch.
        with torch.no_grad():
            for start in range(0, len(keys), batch_size):
                valid_ids = [k for k in keys[start:start + batch_size] if images_dict[k] is not None]
                if not valid_ids:
                    continue
                batch_feats = self.extract_and_fuse_image_features([images_dict[k] for k in valid_ids])
                for img_id, feat in zip(valid_ids, batch_feats):
                    feats[img_id] = feat.detach().cpu()
        return feats

def info_nce_loss(text_feats: torch.Tensor, image_feats: torch.Tensor, temp: float) -> torch.Tensor:
    """Symmetric InfoNCE loss over a batch of aligned (text, image) pairs.

    Row ``k`` of ``text_feats`` is the positive for row ``k`` of
    ``image_feats``; every other row in the batch acts as a negative.
    The text->image and image->text cross-entropies are averaged.
    """
    logits = torch.mm(text_feats, image_feats.t()).div(temp)
    targets = torch.arange(logits.size(0), device=logits.device)
    ce = torch.nn.functional.cross_entropy
    return 0.5 * (ce(logits, targets) + ce(logits.t(), targets))


## 顶层解冻与优化器分组
仅解冻顶层（BERT 最后两层 + pooler、ResNet50 的 layer4），并对其使用比投影层更小的学习率以保持训练稳定；新增学习率调度器构造函数。


In [4]:
def unfreeze_text_top_layers(text_extractor: TextFeatureExtractor, last_n_layers: int = 2):
    """Freeze BERT, then re-enable gradients for its last ``last_n_layers`` encoder layers and pooler.

    Note: each unfrozen module is switched to ``train()``, but the final
    ``model.eval()`` call puts the whole BERT (including those layers) back
    into eval mode, so dropout stays off during fine-tuning; only
    ``requires_grad`` differs between frozen and unfrozen parts.
    """
    bert = text_extractor.model
    bert.requires_grad_(False)
    layers = bert.encoder.layer
    n_total = len(layers)
    for idx in range(n_total - last_n_layers, n_total):
        layers[idx].requires_grad_(True)
        layers[idx].train()
    pooler = getattr(bert, 'pooler', None)
    if pooler is not None:
        pooler.requires_grad_(True)
        pooler.train()
    bert.eval()

def unfreeze_image_top_block(image_extractor: ImageFeatureExtractor, unfreeze_layer4: bool = True):
    """Freeze the image backbone, optionally re-enabling gradients for ``layer4``.

    As on the text side, the trailing ``eval()`` returns every module
    (``layer4`` included) to eval mode, so BatchNorm keeps using its running
    statistics while ``layer4`` is fine-tuned.
    """
    backbone = image_extractor.model
    backbone.requires_grad_(False)
    if unfreeze_layer4 and hasattr(backbone, 'layer4'):
        top_block = backbone.layer4
        top_block.requires_grad_(True)
        top_block.train()
    backbone.eval()

def build_optimizer(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                   lr_proj: float = 1e-3, lr_text_top: float = 5e-5, lr_img_top: float = 1e-4, weight_decay: float = 1e-4,
                   last_n_layers: int = 2):
    """Build an Adam optimizer with three param groups (order matters for checkpoints).

    Groups:
        0. both projection heads: ``lr_proj``, ``weight_decay``;
        1. BERT's last ``last_n_layers`` encoder layers + pooler: ``lr_text_top``, no decay;
        2. ResNet ``layer4``: ``lr_img_top``, no decay.
    Backbone parameters are only included if ``requires_grad`` is set, so call
    the ``unfreeze_*`` helpers first. Note that ``torch.optim.Adam``'s
    ``weight_decay`` is coupled L2 regularisation, not AdamW-style decay.

    Args:
        last_n_layers: must match the value passed to ``unfreeze_text_top_layers``;
            defaults to 2 (the previously hard-coded value).
    """
    proj_params = list(model.fusion.text_projector.parameters()) + list(model.fusion.image_projector.parameters())

    enc_layers = text_extractor.model.encoder.layer
    # Explicit start index instead of layer[-n:], which would select ALL layers for n == 0.
    first_top = max(0, len(enc_layers) - last_n_layers)
    text_top_params = []
    for idx in range(first_top, len(enc_layers)):
        text_top_params += list(enc_layers[idx].parameters())
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        text_top_params += list(text_extractor.model.pooler.parameters())

    img_top_params = []
    if hasattr(image_extractor.model, 'layer4'):
        img_top_params += list(image_extractor.model.layer4.parameters())

    params = [
        {'params': proj_params, 'lr': lr_proj, 'weight_decay': weight_decay},
        {'params': [p for p in text_top_params if p.requires_grad], 'lr': lr_text_top, 'weight_decay': 0.0},
        {'params': [p for p in img_top_params if p.requires_grad], 'lr': lr_img_top, 'weight_decay': 0.0},
    ]
    optimizer = torch.optim.Adam(params)
    return optimizer

def build_scheduler(optimizer: torch.optim.Optimizer, num_steps: int, warmup_ratio: float = 0.1, min_lr_ratio: float = 0.1):
    """Linear warmup followed by cosine decay, as a ``LambdaLR`` multiplier.

    For the first ``warmup_steps = max(1, int(num_steps * warmup_ratio))``
    steps the multiplier rises linearly to 1.0. It then follows a cosine curve
    from 1.0 down to exactly ``min_lr_ratio`` at ``num_steps``, and stays there
    if the scheduler is stepped further.

    Args:
        optimizer: optimizer whose param groups are scaled.
        num_steps: total number of ``scheduler.step()`` calls planned.
        warmup_ratio: fraction of ``num_steps`` used for warmup.
        min_lr_ratio: final multiplier relative to each group's initial LR.
    """
    warmup_steps = max(1, int(num_steps * warmup_ratio))
    decay_steps = max(1, num_steps - warmup_steps)

    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            return float(current_step + 1) / float(warmup_steps)
        # Clamp so stepping past num_steps does not make the cosine rise again.
        progress = min(1.0, float(current_step - warmup_steps) / float(decay_steps))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        # Rescale the cosine so it actually ends at min_lr_ratio (instead of
        # decaying toward 0 and being clipped into a plateau).
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine

    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)


## 数据加载与训练参数
保持与基线一致的查询数据加载；图片按批次流式以控制显存。


In [5]:
loader = DataLoader()
train_df = loader.load_queries(split='train')
valid_df = loader.load_queries(split='valid')

# Streaming / training hyper-parameters (adjust as needed).
# A quick smoke-test configuration is: 500 images per chunk, 1 chunk,
# 1 epoch per chunk, step batch size 32.
train_image_batch_size = 5000  # images per streamed chunk ("big batch")
max_train_batches = 30         # number of chunks to stream
epochs_per_batch = 5           # epochs over each chunk
train_step_batch_size = 32     # mini-batch size used inside a chunk

# Mixed precision, InfoNCE temperature and LR-schedule settings.
use_amp = True
temperature = 0.07
use_lr_scheduler = True
warmup_ratio = 0.1
min_lr_ratio = 0.1


2025-11-09 14:09:38,657 - INFO - 初始化数据加载器，数据目录: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval
2025-11-09 14:09:38,660 - INFO - 加载train查询数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_train_queries.jsonl
加载train查询数据: 248786it [00:03, 79533.17it/s] 
2025-11-09 14:09:41,872 - INFO - 成功加载train查询数据，共248786条
2025-11-09 14:09:41,887 - INFO - 加载valid查询数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_valid_queries.jsonl
加载valid查询数据: 5008it [00:00, 145910.12it/s]
2025-11-09 14:09:41,927 - INFO - 成功加载valid查询数据，共5008条


## 初始化模型并执行顶层解冻


In [6]:
# Build the extractors and the retrieval model on the target device.
image_extractor = ImageFeatureExtractor(device=device, cache_dir=cache_dir)
text_extractor = TextFeatureExtractor(device=device, cache_dir=cache_dir)
model = CrossModalRetrievalModel(
    text_extractor, image_extractor,
    fusion_method='projection', projection_dim=512, similarity_type='cosine', normalize_features=True, device=device
)

# Freeze both backbones except BERT's last 2 encoder layers (+ pooler) and ResNet's layer4.
unfreeze_text_top_layers(text_extractor, last_n_layers=2)
unfreeze_image_top_block(image_extractor, unfreeze_layer4=True)

optim = build_optimizer(model, text_extractor, image_extractor,
                        lr_proj=1e-3, lr_text_top=5e-5, lr_img_top=1e-4, weight_decay=1e-4)
# torch.cuda.amp.GradScaler emits a deprecation FutureWarning; torch.amp.GradScaler('cuda', ...)
# is the drop-in replacement with the same API.
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda' and use_amp))
print('Optim groups:', len(optim.param_groups))


2025-11-09 14:09:42,233 - INFO - Loading pretrained weights from Hugging Face hub (timm/resnet50.a1_in1k)
2025-11-09 14:09:42,627 - INFO - [timm/resnet50.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


Optim groups: 3


  scaler = GradScaler(enabled=(device.type == 'cuda' and use_amp))


## 训练循环（流式）
在每个大批次内重新构建 warmup+cosine 调度器（因此每个大批次都会重新 warmup），并在每步优化后调用 scheduler.step()。


In [None]:
def build_batch_pairs(train_df, img_dict: Dict[str, Image.Image]) -> List[Tuple[str, Image.Image, str]]:
    """Pair each training query with the first of its items whose image is in ``img_dict``.

    Args:
        train_df: query table; requires ``query_text`` and ``item_ids`` (list)
            columns, otherwise no pairs are produced.
        img_dict: image id (str) -> image for the current streamed chunk;
            ``None`` values count as missing.

    Returns:
        ``(query_text, image, image_id)`` tuples, at most one per query.
        NOTE(review): several queries may resolve to the same image, so a
        mini-batch can contain duplicate images that InfoNCE treats as negatives.
    """
    pairs = []
    if 'item_ids' not in train_df.columns or 'query_text' not in train_df.columns:
        return pairs
    # Zipping the two columns avoids DataFrame.iterrows(), which builds a Series per
    # row; this function runs over the full query table once per streamed chunk.
    for q, ids in zip(train_df['query_text'], train_df['item_ids']):
        if not q or not isinstance(ids, list) or not ids:
            continue
        for iid in ids:
            sid = str(iid)
            img = img_dict.get(sid)
            if img is not None:
                pairs.append((q, img, sid))
                break
    return pairs

def print_group_lrs(optim: torch.optim.Optimizer):
    """Print the current learning rate of every param group, in group order."""
    print('Current LRs: {}'.format([group['lr'] for group in optim.param_groups]))

def train_one_batch(pairs: List[Tuple[str, Image.Image, str]], epochs: int, step_bs: int):
    """Train on one streamed chunk of ``(query, image, image_id)`` pairs.

    Uses the notebook-level ``model``, ``text_extractor``, ``image_extractor``,
    ``optim``, ``scaler``, ``device`` and the training hyper-parameters.
    A fresh warmup+cosine scheduler is built on every call, so warmup restarts
    for each chunk. Mini-batches are consecutive slices of ``pairs`` (no
    shuffling between epochs).

    Args:
        pairs: training pairs for this chunk.
        epochs: number of passes over ``pairs``.
        step_bs: mini-batch size.
    """
    model.fusion.text_projector.train()
    model.fusion.image_projector.train()

    steps_per_epoch = max(1, (len(pairs) + step_bs - 1) // step_bs)
    total_steps = steps_per_epoch * epochs
    scheduler = None
    if use_lr_scheduler:
        scheduler = build_scheduler(optim, num_steps=total_steps, warmup_ratio=warmup_ratio, min_lr_ratio=min_lr_ratio)

    amp_enabled = use_amp and device.type == 'cuda'
    # Only the projection heads are clipped (as before); unfrozen backbone layers are not.
    proj_params = list(model.fusion.text_projector.parameters()) + list(model.fusion.image_projector.parameters())

    for e in range(epochs):
        running_loss = 0.0
        steps = 0
        print(f'Epoch {e+1}/{epochs} start')
        print_group_lrs(optim)
        for s in range(0, len(pairs), step_bs):
            batch = pairs[s:s+step_bs]
            if not batch:
                continue
            texts = [t for (t, _, _) in batch]
            imgs = [im for (_, im, _) in batch]

            optim.zero_grad()
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast; with
            # enabled=False it is a no-op, so one forward path serves both modes.
            with torch.autocast(device_type=device.type, enabled=amp_enabled):
                t_feats = text_extractor.encode_with_grad(texts)
                i_feats = image_extractor.encode_with_grad(imgs)
                t_proj = model._norm(model.fusion.fuse_text_features(t_feats))
                i_proj = model._norm(model.fusion.fuse_image_features(i_feats))
                loss = info_nce_loss(t_proj, i_proj, temperature)

            if amp_enabled:
                scaler.scale(loss).backward()
                # Gradients are still multiplied by the loss scale at this point; unscale
                # them first so max_norm is compared against the true gradient norm.
                scaler.unscale_(optim)
                torch.nn.utils.clip_grad_norm_(proj_params, max_norm=5.0)
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(proj_params, max_norm=5.0)
                optim.step()
            if scheduler is not None:
                scheduler.step()

            running_loss += loss.item()
            steps += 1
            if steps % 100 == 0:
                print_group_lrs(optim)
        print(f"Epoch {e+1}/{epochs}: avg loss={running_loss/max(steps,1):.4f}")

# 流式加载与训练
# Stream the training images chunk by chunk and train on each chunk's usable pairs.
batch_idx = 0
for image_batch in loader.load_images_batch(split='train', batch_size=train_image_batch_size, max_batches=max_train_batches):
    batch_idx += 1
    img_map = {entry['img_id']: entry['image'] for entry in image_batch}
    pairs = build_batch_pairs(train_df, img_map)
    print(f"Batch {batch_idx}: images={len(img_map)}, usable_pairs={len(pairs)}")
    if pairs:
        train_one_batch(pairs, epochs=epochs_per_batch, step_bs=train_step_batch_size)
    # Cleanup runs whether or not the chunk produced any pairs.
    del img_map
    if device.type == 'cuda':
        torch.cuda.empty_cache()


## 保存：投影层 + 已解冻顶层 + 优化器
保存 BERT 的最后 2 层 + pooler、ResNet50 的 layer4、投影层以及优化器状态（不包含学习率调度器与 GradScaler 的状态）。


In [None]:
def save_unfreeze_checkpoint(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                             optimizer: torch.optim.Optimizer, save_path: str, last_n_layers: int):
    """Save only the trainable parts of the model plus the optimizer state.

    Stored: both projection heads, BERT's last ``last_n_layers`` encoder
    layers (keyed ``encoder_layer_{i}``) and pooler, ResNet ``layer4``, the
    optimizer ``state_dict``, ``projection_dim`` and ``last_n_layers``.
    Frozen backbone weights are not saved; they come from the pretrained
    checkpoints when the model is rebuilt.

    Args:
        save_path: destination file; its parent directory is created if needed.
        last_n_layers: number of top BERT encoder layers to store.
    """
    parent_dir = os.path.dirname(save_path)
    if parent_dir:  # a bare filename has no directory part, and os.makedirs('') would raise
        os.makedirs(parent_dir, exist_ok=True)
    ckpt = {
        'projection_dim': model.fusion.projection_dim,
        'last_n_layers': last_n_layers,
        'fusion': {
            'text_projector': model.fusion.text_projector.state_dict(),
            'image_projector': model.fusion.image_projector.state_dict(),
        },
        'text_unfrozen': {},
        'image_unfrozen': {},
        'optimizer': optimizer.state_dict(),
    }
    enc = text_extractor.model.encoder
    total_layers = len(enc.layer)
    start_idx = max(0, total_layers - last_n_layers)
    for i in range(start_idx, total_layers):
        ckpt['text_unfrozen'][f'encoder_layer_{i}'] = enc.layer[i].state_dict()
    if hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        ckpt['text_unfrozen']['pooler'] = text_extractor.model.pooler.state_dict()
    if hasattr(image_extractor.model, 'layer4'):
        ckpt['image_unfrozen']['layer4'] = image_extractor.model.layer4.state_dict()
    torch.save(ckpt, save_path)
    print(f"Checkpoint saved to: {save_path}")

# Save once after training finishes.
save_unfreeze_checkpoint(model, text_extractor, image_extractor, optimizer=optim, save_path=save_path, last_n_layers=2)


## 加载：恢复解冻顶层与投影层权重，继续训练
加载后会自动再次执行顶层解冻，并恢复优化器状态（如提供）。


In [8]:
def load_unfreeze_checkpoint(model: CrossModalRetrievalModel, text_extractor: TextFeatureExtractor, image_extractor: ImageFeatureExtractor,
                             optimizer: torch.optim.Optimizer, load_path: str):
    """Restore a checkpoint written by ``save_unfreeze_checkpoint``.

    Loads both projection heads, re-applies the freeze/unfreeze pattern with
    the saved ``last_n_layers`` (default 2), then restores the saved BERT
    encoder layers, pooler and ResNet ``layer4``. If ``optimizer`` is not
    ``None`` and the checkpoint contains optimizer state, that state is
    restored too. The optimizer must have the same param-group layout as the
    one that was saved, otherwise ``load_state_dict`` raises.

    Args:
        optimizer: optimizer to restore into, or ``None`` to skip.
        load_path: checkpoint file path.
    """
    # The checkpoint only holds state dicts and plain Python values, so refuse to
    # unpickle arbitrary objects (torch.load is pickle-based).
    ckpt = torch.load(load_path, map_location='cpu', weights_only=True)
    model.fusion.text_projector.load_state_dict(ckpt['fusion']['text_projector'])
    model.fusion.image_projector.load_state_dict(ckpt['fusion']['image_projector'])
    ln = ckpt.get('last_n_layers', 2)
    unfreeze_text_top_layers(text_extractor, last_n_layers=ln)
    unfreeze_image_top_block(image_extractor, unfreeze_layer4=True)
    enc = text_extractor.model.encoder
    for k, v in ckpt['text_unfrozen'].items():
        if k.startswith('encoder_layer_'):
            idx = int(k.split('_')[-1])
            if 0 <= idx < len(enc.layer):
                enc.layer[idx].load_state_dict(v)
    if 'pooler' in ckpt['text_unfrozen'] and hasattr(text_extractor.model, 'pooler') and text_extractor.model.pooler is not None:
        text_extractor.model.pooler.load_state_dict(ckpt['text_unfrozen']['pooler'])
    if 'layer4' in ckpt['image_unfrozen'] and hasattr(image_extractor.model, 'layer4'):
        image_extractor.model.layer4.load_state_dict(ckpt['image_unfrozen']['layer4'])
    if optimizer is not None and 'optimizer' in ckpt:
        optimizer.load_state_dict(ckpt['optimizer'])
    print(f"Checkpoint loaded from: {load_path}")

# 测试加载
# load_unfreeze_checkpoint(model, text_extractor, image_extractor, optim, save_path)


Checkpoint loaded from: /mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights/step_2_3_3_v3.pth


## 验证评估：Recall@1/5/10 与 MeanRecall
- 基于验证集构建图像索引
- 对每条查询计算相似度并统计召回（优先使用 FAISS；不可用则 Torch 回退）


In [9]:
# Validation images as an img_id -> image mapping (loader arguments as in the baseline).
valid_imgs = loader.create_img_id_to_image_dict(
    split='valid',
    max_samples=50000,
    batch_size=3000,
    max_batches=10
)

# Validation queries: (query_text, [ground-truth item ids as str]), skipping empty ones.
valid_queries = []
if 'item_ids' in valid_df.columns:
    for _, record in valid_df.iterrows():
        query = record.get('query_text', None)
        raw_ids = record.get('item_ids', [])
        gt_ids = [str(i) for i in raw_ids] if isinstance(raw_ids, list) else []
        if query and gt_ids:
            valid_queries.append((query, gt_ids))
print(f'Usable valid queries: {len(valid_queries)}')

2025-11-09 14:10:20,098 - INFO - 批量加载valid图片数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_valid_imgs.tsv
实际加载valid图片数据: 100%|████████████████████████████████████████████████████| 29806/29806 [00:13<00:00, 2159.72it/s]
2025-11-09 14:10:50,144 - INFO - 成功创建valid图片映射字典，共29806张图片


Usable valid queries: 5008


In [10]:
# Embed all validation images; keep an id list aligned with the stacked feature matrix.
image_index = model.build_image_index(valid_imgs, batch_size=32)
all_image_ids = list(image_index.keys())
if all_image_ids:
    all_image_feats = torch.stack([image_index[i] for i in all_image_ids])
else:
    all_image_feats = torch.empty((0, 512))

# Optional exact inner-product FAISS index (features are already L2-normalised,
# so inner product equals cosine similarity).
faiss_index = None
if HAS_FAISS and all_image_feats.size(0) > 0:
    feats_np = all_image_feats.detach().cpu().numpy().astype('float32')
    d = all_image_feats.size(1)
    faiss_index = faiss.IndexFlatIP(d)
    faiss_index.add(feats_np)

# The Torch fallback search runs on the compute device.
all_image_feats = all_image_feats.to(device)

def compute_recall_at_k(k_values, queries):
    """Compute Recall@k for text->image retrieval over the validation index.

    Uses the notebook-level ``model``, ``all_image_ids``, ``all_image_feats``
    and ``faiss_index`` (FAISS when available, otherwise Torch top-k).
    A query counts as a hit at ``k`` if any of its ground-truth ids appears
    among the top ``k`` retrieved images.

    Args:
        k_values: non-empty iterable of cut-offs, e.g. ``[1, 5, 10]``.
        queries: ``(query_text, [gt_image_id, ...])`` pairs.

    Returns:
        ``({k: recall}, number_of_evaluated_queries)``.
    """
    recalls = {k: 0 for k in k_values}
    total = 0
    num_images = all_image_feats.size(0)
    if num_images == 0:
        return {k: 0.0 for k in k_values}, total
    # Never ask for more neighbours than exist: Torch topk would raise, and FAISS
    # would pad with -1 (which would silently index the last image id).
    search_k = min(max(k_values), num_images)
    with torch.no_grad():  # evaluation only; no graph through the projector
        for q_text, gt_ids in tqdm(queries, desc='Evaluate'):
            q_feat = model.extract_and_fuse_text_features([q_text])
            if faiss_index is not None:
                q_np = q_feat.detach().cpu().numpy().astype('float32')
                _, I = faiss_index.search(q_np, search_k)
                top_idx = [i for i in I[0].tolist() if i >= 0]
            else:
                sims = model.sim.calculate_similarity(q_feat, all_image_feats)
                top_idx = torch.topk(sims[0], k=search_k).indices.tolist()
            top_ids = [all_image_ids[i] for i in top_idx]
            total += 1
            gt_set = set(gt_ids)
            for k in k_values:
                if not gt_set.isdisjoint(top_ids[:k]):
                    recalls[k] += 1
    return {k: (recalls[k] / total if total > 0 else 0.0) for k in k_values}, total

# Evaluate and report Recall@1/5/10 plus their mean.
rec, total_q = compute_recall_at_k([1, 5, 10], valid_queries)
r1, r5, r10 = (rec.get(k, 0) for k in (1, 5, 10))
mean_recall = (r1 + r5 + r10) / 3 if total_q > 0 else 0.0
print(f'Recall@1={r1:.4f}, Recall@5={r5:.4f}, Recall@10={r10:.4f}, MeanRecall={mean_recall:.4f} (N={total_q})')


Evaluate: 100%|████████████████████████████████████████████████████████████████████| 5008/5008 [00:33<00:00, 147.99it/s]

Recall@1=0.0294, Recall@5=0.1020, Recall@10=0.1745, MeanRecall=0.1020 (N=5008)



