# CLIP 模型直接微调（对比学习）

目标：使用电商数据中的图片与文本对，直接在 CLIP 上进行对比学习微调，并保存微调后的权重供后续检索使用。
- 使用项目现有 `DataLoader` 逐批加载图片与查询文本，构造图文对。
- 采用 CLIP 的对比损失（图像-文本相互对齐），支持 AMP 降内存与加速。
- 提供资源友好的默认参数，适配 32GB RAM / 16GB VRAM。


In [1]:
import os
import math
import json
from typing import List

import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader as TorchDataLoader
from torchvision import transforms
from PIL import Image

from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

from transformers import CLIPProcessor, CLIPModel, BertTokenizer
from transformers import get_cosine_schedule_with_warmup

# 环境与缓存设置
cache_dir = "/mnt/d/HuggingFaceModels/"
os.environ['TORCH_HOME'] = cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ['https_proxy'] = 'http://127.0.0.1:7890'
os.environ['http_proxy'] = 'http://127.0.0.1:7890'
os.environ['all_proxy'] = 'socks5://127.0.0.1:7890'
os.environ["WANDB_DISABLED"] = "true"
os.environ['CURL_CA_BUNDLE'] = ""
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

import sys
sys.path.append(os.path.abspath(os.path.join('.', 'Multimodal_Retrieval', 'plan1_1')))
from data_loader import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

save_dir = '/mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights'
os.makedirs(save_dir, exist_ok=True)
clip_save_path = os.path.join(save_dir, 'ft_clip_vit_base_patch32.pth')


  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
import time
while True:
    if os.path.exists("ft_rsn50_brt.finishflag"):
        break
    time.sleep(5)

In [3]:
class CLIPPairDataset(Dataset):
    def __init__(self, items, queries_texts, image_transform=None):
        self.items = items
        self.texts = queries_texts
        self.tf = image_transform
        self.n = min(len(self.items), len(self.texts))
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        it = self.items[idx]
        img = it.get('image', None)
        text = self.texts[idx]
        if img is None:
            img = Image.new('RGB', (224, 224), color=(128,128,128))
        if self.tf is not None:
            img = self.tf(img)
        return img, text

def build_clip_processor_model():
    processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32', cache_dir=cache_dir, local_files_only=True)
    model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32', cache_dir=cache_dir, local_files_only=True).to(device)
    return processor, model

def collate_fn(batch):
    # 此 collate 不再使用，训练循环里会定义基于 CLIPProcessor 的版本
    imgs = [b[0] for b in batch]
    texts = [b[1] for b in batch]
    return {'pixel_values': torch.stack(imgs, dim=0), 'texts': texts}


In [4]:
def train_clip(loader: DataLoader, device, epochs=1, img_batch_size=800, inner_bs=64, lr=5e-5, max_batches=2):
    # 使用 CLIPProcessor 进行文本编码，移除 BERT tokenizer
    processor, model = build_clip_processor_model()
    scaler = GradScaler(enabled=(device.type=='cuda'))
    optim = AdamW(model.parameters(), lr=lr)
    # 估算总训练步数用于LR调度（近似即可，有助稳定下降）
    total_batches_est = max(1, max_batches or 1)
    steps_per_batch_est = max(1, img_batch_size // inner_bs)
    total_steps_est = total_batches_est * steps_per_batch_est
    warmup_steps = max(1, int(0.2 * total_steps_est))
    scheduler = get_cosine_schedule_with_warmup(optim, num_warmup_steps=warmup_steps, num_training_steps=total_steps_est)
    image_tf = None

    model.train()
    # 仅加载一次训练查询，并在各批次循环中重复利用
    train_df = loader.load_queries(split='train')
    train_texts = [t for t in train_df.get('query_text', []) if isinstance(t, str)]
    if len(train_texts) == 0:
        return model
    text_offset = 0

    outer = loader.load_images_batch(split='train', batch_size=img_batch_size, max_batches=max_batches)
    for batch_idx, img_batch in enumerate(outer):
        # 基于 item_ids 构造当前批次的真实图文正样本对
        # 统一键为字符串，避免类型不一致导致无法匹配
        img_map = {str(it['img_id']): it['image'] for it in img_batch if it.get('image', None) is not None}
        paired_items = []
        paired_texts = []
        if 'item_ids' in train_df.columns:
            for _, row in train_df.iterrows():
                q = row.get('query_text', None)
                ids = row.get('item_ids', [])
                # 兼容字符串形式的列表，如 '[123, 456]' 或 '123,456'
                if isinstance(ids, str):
                    parsed = None
                    try:
                        parsed = json.loads(ids)
                    except Exception:
                        try:
                            parsed = [s.strip() for s in ids.split(',') if s.strip()]
                        except Exception:
                            parsed = None
                    ids = parsed if isinstance(parsed, list) else []
                if not q or not isinstance(ids, list) or not ids:
                    continue
                chosen_img = None
                for iid in ids:
                    sid = str(iid)
                    if sid in img_map:
                        chosen_img = img_map[sid]
                        break
                if chosen_img is not None:
                    paired_items.append({'image': chosen_img})
                    paired_texts.append(q)
        if len(paired_items) == 0:
            continue
        ds = CLIPPairDataset(paired_items, paired_texts, image_tf)
        def collate_clip(batch):
            imgs = [b[0] for b in batch]
            texts = [b[1] for b in batch]
            inputs = processor(text=texts, images=imgs, return_tensors='pt', padding=True, truncation=True)
            return {
                'pixel_values': inputs['pixel_values'],
                'input_ids': inputs['input_ids'],
                'attention_mask': inputs.get('attention_mask', None)
            }
        dl = TorchDataLoader(ds, batch_size=inner_bs, shuffle=True, num_workers=2, pin_memory=(device.type=='cuda'), collate_fn=collate_clip)
        running_loss = 0.0
        steps = 0
        # pbar = tqdm(dl, desc=f'Train batch {batch_idx+1}', leave=False)
        for i, batch in enumerate(dl, 1):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            with autocast(enabled=(device.type=='cuda')):
                outputs = model(**batch)
                # 直接使用 logits 计算对比损失
                logits_per_image = outputs.logits_per_image
                logits_per_text = outputs.logits_per_text
                # 跳过过小微批（对比学习至少需要 >1 的批量用于负样本）
                bs = logits_per_image.size(0)
                if bs < 2:
                    continue
                targets = torch.arange(bs, device=device)
                # 标签平滑可缓解偶发的损失尖峰
                ce = nn.CrossEntropyLoss(label_smoothing=0.1)
                loss_i = ce(logits_per_image, targets)
                loss_t = ce(logits_per_text, targets)
                loss = (loss_i + loss_t) / 2.0
            optim.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optim)
            scaler.update()
            scheduler.step()
            # 累计与显示损失
            running_loss += float(loss.detach().item())
            steps += 1
            avg = running_loss / max(1, steps)
            # pbar.set_postfix({'loss': f'{loss.item():.4f}', 'avg': f'{avg:.4f}'})
        current_lr = optim.param_groups[0]['lr']
        print(f'Batch {batch_idx+1}: avg loss={running_loss/max(1, steps):.4f} | lr={current_lr:.6f}')
        if device.type == 'cuda':
            torch.cuda.empty_cache()
    return model


In [5]:
loader = DataLoader()
clip_model = train_clip(loader, device=device, img_batch_size=2000, max_batches=100, inner_bs=16, epochs=3, lr=1e-4)
torch.save(clip_model.state_dict(), clip_save_path)
print(f'CLIP fine-tuned and saved to: {clip_save_path}')


2025-11-09 21:53:09,376 - INFO - 初始化数据加载器，数据目录: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
  scaler = GradScaler(enabled=(device.type=='cuda'))
2025-11-09 21:53:14,020 - INFO - 加载train查询数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_train_queries.jsonl
加载train查询数据: 248786it [00:00, 263029.12it/s]
2025-11-09 21:53:15,027 - INFO - 成功加载train查询数据，共248786条
2025-11-09 21:53:15,047 - INFO - 批量加载train图片数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_train_imgs.tsv
  with autocast(enabled=(device.type=='cuda')):
实际加载train图片数据:   2%|▊                                                   | 2000/12938

Batch 1: avg loss=2.4462 | lr=0.000010


实际加载train图片数据:   3%|█▌                                                  | 4000/129380 [00:44<1:07:32, 30.94it/s]

Batch 2: avg loss=1.9444 | lr=0.000020


实际加载train图片数据:   5%|██▌                                                   | 6000/129380 [01:06<57:54, 35.51it/s]

Batch 3: avg loss=1.9106 | lr=0.000029


实际加载train图片数据:   6%|███▎                                                  | 8000/129380 [01:29<57:22, 35.26it/s]

Batch 4: avg loss=1.9356 | lr=0.000039


实际加载train图片数据:   8%|████                                                 | 10000/129380 [01:49<52:36, 37.82it/s]

Batch 5: avg loss=2.0269 | lr=0.000049


实际加载train图片数据:   9%|████▉                                                | 12000/129380 [02:10<54:01, 36.21it/s]

Batch 6: avg loss=2.1035 | lr=0.000059


实际加载train图片数据:  11%|█████▋                                               | 14000/129380 [02:30<50:30, 38.07it/s]

Batch 7: avg loss=2.3169 | lr=0.000069


实际加载train图片数据:  12%|██████▌                                              | 16000/129380 [02:50<46:48, 40.38it/s]

Batch 8: avg loss=2.3365 | lr=0.000078


实际加载train图片数据:  14%|███████▎                                             | 18000/129380 [03:12<48:54, 37.96it/s]

Batch 9: avg loss=2.3229 | lr=0.000088


实际加载train图片数据:  15%|████████▏                                            | 20000/129380 [03:34<40:43, 44.76it/s]

Batch 10: avg loss=2.1658 | lr=0.000097


实际加载train图片数据:  17%|█████████                                            | 22000/129380 [03:54<47:22, 37.78it/s]

Batch 11: avg loss=2.1579 | lr=0.000100


实际加载train图片数据:  19%|█████████▊                                           | 24000/129380 [04:14<44:30, 39.46it/s]

Batch 12: avg loss=2.1948 | lr=0.000100


实际加载train图片数据:  20%|██████████▋                                          | 26000/129380 [04:36<48:26, 35.57it/s]

Batch 13: avg loss=2.2168 | lr=0.000099


实际加载train图片数据:  22%|███████████▍                                         | 28000/129380 [04:55<41:59, 40.23it/s]

Batch 14: avg loss=2.2234 | lr=0.000098


实际加载train图片数据:  23%|████████████▎                                        | 30000/129380 [05:19<50:18, 32.92it/s]

Batch 15: avg loss=2.2096 | lr=0.000097


实际加载train图片数据:  25%|█████████████                                        | 32000/129380 [05:42<47:33, 34.12it/s]

Batch 16: avg loss=2.1719 | lr=0.000095


实际加载train图片数据:  26%|█████████████▉                                       | 34000/129380 [06:04<47:03, 33.79it/s]

Batch 17: avg loss=2.1394 | lr=0.000094


实际加载train图片数据:  28%|██████████████▋                                      | 36000/129380 [06:27<46:12, 33.68it/s]

Batch 18: avg loss=2.0526 | lr=0.000092


实际加载train图片数据:  29%|███████████████▌                                     | 38000/129380 [06:47<40:00, 38.06it/s]

Batch 19: avg loss=2.0629 | lr=0.000090


实际加载train图片数据:  31%|████████████████▍                                    | 40000/129380 [07:08<40:12, 37.04it/s]

Batch 20: avg loss=2.0088 | lr=0.000087


实际加载train图片数据:  32%|█████████████████▏                                   | 42000/129380 [07:30<40:21, 36.08it/s]

Batch 21: avg loss=1.9842 | lr=0.000085


实际加载train图片数据:  34%|██████████████████                                   | 44000/129380 [07:50<38:39, 36.81it/s]

Batch 22: avg loss=1.9719 | lr=0.000082


实际加载train图片数据:  36%|██████████████████▊                                  | 46000/129380 [08:11<31:23, 44.26it/s]

Batch 23: avg loss=1.9370 | lr=0.000079


实际加载train图片数据:  37%|███████████████████▋                                 | 48000/129380 [08:33<38:20, 35.37it/s]

Batch 24: avg loss=1.9204 | lr=0.000075


实际加载train图片数据:  39%|████████████████████▍                                | 50000/129380 [08:53<33:32, 39.45it/s]

Batch 25: avg loss=1.9596 | lr=0.000072


实际加载train图片数据:  40%|█████████████████████▎                               | 52000/129380 [09:14<36:09, 35.66it/s]

Batch 26: avg loss=1.8317 | lr=0.000068


实际加载train图片数据:  42%|██████████████████████                               | 54000/129380 [09:35<33:32, 37.46it/s]

Batch 27: avg loss=1.8901 | lr=0.000065


实际加载train图片数据:  43%|██████████████████████▉                              | 56000/129380 [09:56<32:48, 37.28it/s]

Batch 28: avg loss=1.9172 | lr=0.000061


实际加载train图片数据:  45%|███████████████████████▊                             | 58000/129380 [10:16<30:40, 38.78it/s]

Batch 29: avg loss=1.8597 | lr=0.000058


实际加载train图片数据:  46%|████████████████████████▌                            | 60000/129380 [10:35<28:59, 39.87it/s]

Batch 30: avg loss=1.8542 | lr=0.000054


实际加载train图片数据:  48%|█████████████████████████▍                           | 62000/129380 [10:56<31:03, 36.16it/s]

Batch 31: avg loss=1.8326 | lr=0.000050


实际加载train图片数据:  49%|██████████████████████████▏                          | 64000/129380 [11:20<34:02, 32.02it/s]

Batch 32: avg loss=1.8425 | lr=0.000046


实际加载train图片数据:  51%|███████████████████████████                          | 66000/129380 [11:42<30:41, 34.42it/s]

Batch 33: avg loss=1.7960 | lr=0.000042


实际加载train图片数据:  53%|███████████████████████████▊                         | 68000/129380 [12:05<29:48, 34.31it/s]

Batch 34: avg loss=1.7470 | lr=0.000039


实际加载train图片数据:  54%|████████████████████████████▋                        | 70000/129380 [12:29<31:24, 31.50it/s]

Batch 35: avg loss=1.7178 | lr=0.000035


实际加载train图片数据:  56%|█████████████████████████████▍                       | 72000/129380 [12:50<23:46, 40.23it/s]

Batch 36: avg loss=1.7270 | lr=0.000031


实际加载train图片数据:  57%|██████████████████████████████▎                      | 74000/129380 [13:11<26:27, 34.88it/s]

Batch 37: avg loss=1.7048 | lr=0.000028


实际加载train图片数据:  59%|███████████████████████████████▏                     | 76000/129380 [13:34<25:56, 34.30it/s]

Batch 38: avg loss=1.6888 | lr=0.000024


实际加载train图片数据:  60%|███████████████████████████████▉                     | 78000/129380 [13:53<22:18, 38.38it/s]

Batch 39: avg loss=1.7002 | lr=0.000021


实际加载train图片数据:  62%|████████████████████████████████▊                    | 80000/129380 [14:14<21:31, 38.23it/s]

Batch 40: avg loss=1.6862 | lr=0.000018


实际加载train图片数据:  63%|█████████████████████████████████▌                   | 82000/129380 [14:34<16:54, 46.69it/s]

Batch 41: avg loss=1.6354 | lr=0.000015


实际加载train图片数据:  65%|██████████████████████████████████▍                  | 84000/129380 [14:56<20:31, 36.85it/s]

Batch 42: avg loss=1.6409 | lr=0.000013


实际加载train图片数据:  66%|███████████████████████████████████▏                 | 86000/129380 [15:18<21:50, 33.11it/s]

Batch 43: avg loss=1.6673 | lr=0.000010


实际加载train图片数据:  68%|████████████████████████████████████                 | 88000/129380 [15:38<20:09, 34.20it/s]

Batch 44: avg loss=1.6074 | lr=0.000008


实际加载train图片数据:  69%|████████████████████████████████████                | 89879/129380 [15:39<01:53, 348.46it/s]

Batch 45: avg loss=1.7069 | lr=0.000006


实际加载train图片数据:  71%|█████████████████████████████████████▋               | 92000/129380 [16:19<14:00, 44.45it/s]

Batch 46: avg loss=1.7099 | lr=0.000004


实际加载train图片数据:  73%|██████████████████████████████████████▌              | 94149/129380 [16:38<11:58, 49.03it/s]

Batch 47: avg loss=1.7001 | lr=0.000003


实际加载train图片数据:  74%|███████████████████████████████████████▎             | 96000/129380 [16:59<16:38, 33.42it/s]

Batch 48: avg loss=1.6542 | lr=0.000002


实际加载train图片数据:  76%|████████████████████████████████████████▏            | 98000/129380 [17:24<17:26, 29.98it/s]

Batch 49: avg loss=1.6896 | lr=0.000001


实际加载train图片数据:  77%|████████████████████████████████████████▏           | 100000/129380 [17:45<13:58, 35.03it/s]

Batch 50: avg loss=1.6758 | lr=0.000000


实际加载train图片数据:  79%|████████████████████████████████████████▉           | 102000/129380 [18:08<13:39, 33.39it/s]

Batch 51: avg loss=1.7446 | lr=0.000000


实际加载train图片数据:  80%|█████████████████████████████████████████▊          | 104000/129380 [18:32<10:15, 41.20it/s]

Batch 52: avg loss=1.7369 | lr=0.000000


实际加载train图片数据:  82%|██████████████████████████████████████████▌         | 106000/129380 [18:52<10:13, 38.09it/s]

Batch 53: avg loss=1.7845 | lr=0.000000


实际加载train图片数据:  83%|███████████████████████████████████████████▍        | 108000/129380 [19:14<10:26, 34.12it/s]

Batch 54: avg loss=1.7020 | lr=0.000001


实际加载train图片数据:  85%|████████████████████████████████████████████▏       | 110000/129380 [19:36<09:34, 33.74it/s]

Batch 55: avg loss=1.7022 | lr=0.000002


实际加载train图片数据:  87%|█████████████████████████████████████████████       | 112000/129380 [19:56<07:57, 36.38it/s]

Batch 56: avg loss=1.6692 | lr=0.000003


实际加载train图片数据:  88%|█████████████████████████████████████████████▊      | 114000/129380 [20:18<07:10, 35.77it/s]

Batch 57: avg loss=1.6478 | lr=0.000004


实际加载train图片数据:  90%|██████████████████████████████████████████████▌     | 116000/129380 [20:38<06:01, 36.98it/s]

Batch 58: avg loss=1.6099 | lr=0.000006


实际加载train图片数据:  91%|███████████████████████████████████████████████▍    | 118000/129380 [20:58<04:46, 39.73it/s]

Batch 59: avg loss=1.6193 | lr=0.000008


实际加载train图片数据:  93%|████████████████████████████████████████████████▏   | 120000/129380 [21:20<04:27, 35.08it/s]

Batch 60: avg loss=1.6082 | lr=0.000010


实际加载train图片数据:  94%|█████████████████████████████████████████████████   | 122000/129380 [21:41<03:28, 35.42it/s]

Batch 61: avg loss=1.5765 | lr=0.000012


实际加载train图片数据:  96%|█████████████████████████████████████████████████▊  | 124000/129380 [22:01<02:24, 37.35it/s]

Batch 62: avg loss=1.6109 | lr=0.000015


实际加载train图片数据:  97%|██████████████████████████████████████████████████▋ | 126000/129380 [22:23<01:32, 36.69it/s]

Batch 63: avg loss=1.5576 | lr=0.000018


实际加载train图片数据:  99%|███████████████████████████████████████████████████▍| 128000/129380 [22:43<00:35, 38.52it/s]

Batch 64: avg loss=1.5802 | lr=0.000021


实际加载train图片数据: 100%|████████████████████████████████████████████████████| 129380/129380 [22:43<00:00, 94.88it/s]


Batch 65: avg loss=1.5692 | lr=0.000023
CLIP fine-tuned and saved to: /mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights/ft_clip_vit_base_patch32.pth


## 验证评估：Recall@1/5/10 与 MeanRecall
- 基于验证集构建图像索引
- 对每条查询计算相似度并统计召回（优先使用 FAISS；不可用则 Torch 回退）


In [6]:
import importlib, numpy as np
HAS_FAISS = False
faiss = None
try:
    spec = importlib.util.find_spec('faiss')
    if spec is not None:
        faiss = importlib.import_module('faiss')
        HAS_FAISS = True
except BaseException as e:
    print(f'Faiss unavailable: {e.__class__.__name__}')
    HAS_FAISS = False

# 加载验证集查询
valid_df = loader.load_queries(split='valid')
valid_queries = []
if 'item_ids' in valid_df.columns:
    for _, row in valid_df.iterrows():
        q = row.get('query_text', None)
        ids = [str(i) for i in row.get('item_ids', [])] if isinstance(row.get('item_ids', []), list) else []
        if q and ids:
            valid_queries.append((q, ids))
print(f'Usable valid queries: {len(valid_queries)}')

# 构建验证集图像字典（可选限制最大样本数）
valid_imgs_max_samples = None  # 可改为例如 30000 以降低内存
valid_imgs = loader.create_img_id_to_image_dict(split='valid', max_samples=valid_imgs_max_samples)

# 使用 CLIP 计算图像特征并构建索引
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32', cache_dir=cache_dir, local_files_only=True)
clip_model.eval()
def l2_normalize(x):
    return torch.nn.functional.normalize(x, p=2, dim=-1)

image_index = {}
all_ids = list(valid_imgs.keys())
batch_size = 64
for s in tqdm(range(0, len(all_ids), batch_size), desc='Build image index'):
    batch_ids = all_ids[s:s+batch_size]
    batch_imgs = [valid_imgs[i] for i in batch_ids if valid_imgs[i] is not None]
    valid_batch_ids = [i for i in batch_ids if valid_imgs[i] is not None]
    if not batch_imgs:
        continue
    inputs = processor(images=batch_imgs, return_tensors='pt')
    pixel_values = inputs['pixel_values'].to(device)
    with torch.no_grad():
        feats = clip_model.get_image_features(pixel_values)
        feats = l2_normalize(feats).detach().cpu()
    for j, img_id in enumerate(valid_batch_ids):
        image_index[img_id] = feats[j]

all_image_ids = list(image_index.keys())
all_image_feats = torch.stack([image_index[i] for i in all_image_ids]) if all_image_ids else torch.empty((0, clip_model.visual_projection.out_features if hasattr(clip_model, 'visual_projection') else 512))
faiss_index = None
if HAS_FAISS and all_image_feats.size(0) > 0:
    d = all_image_feats.size(1)
    faiss_index = faiss.IndexFlatIP(d)
    faiss_index.add(all_image_feats.numpy().astype('float32'))

all_image_feats = all_image_feats.to(device)

def compute_recall_at_k(k_values, queries):
    recalls = {k: 0 for k in k_values}
    total = 0
    for q_text, gt_ids in tqdm(queries, desc='Evaluate'):
        if all_image_feats.size(0) == 0:
            continue
        inputs = processor(text=[q_text], return_tensors='pt', padding=True)
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs.get('attention_mask', None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        with torch.no_grad():
            q_feat = clip_model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
            q_feat = l2_normalize(q_feat)
        if faiss_index is not None:
            q_np = q_feat.detach().cpu().numpy().astype('float32')
            _, I = faiss_index.search(q_np, max(k_values))
            top_idx = I[0].tolist()
            top_ids = [all_image_ids[i] for i in top_idx]
        else:
            sims = (q_feat @ all_image_feats.t()).squeeze(0)
            _, top_idx = torch.topk(sims, k=max(k_values))
            top_ids = [all_image_ids[i] for i in top_idx.tolist()]
        total += 1
        for k in k_values:
            if any(g in set(top_ids[:k]) for g in gt_ids):
                recalls[k] += 1
    return {k: (recalls[k] / max(1, total)) for k in k_values}, total

k_values = [1, 5, 10]
recall_scores, N = compute_recall_at_k(k_values, valid_queries)
mean_recall = sum(recall_scores.values()) / len(k_values)
print(f"Recall@1={recall_scores[1]:.4f}, Recall@5={recall_scores[5]:.4f}, Recall@10={recall_scores[10]:.4f}, MeanRecall={mean_recall:.4f} (N={N})")


2025-11-09 22:16:26,338 - INFO - Loading faiss with AVX2 support.
2025-11-09 22:16:26,408 - INFO - Successfully loaded faiss with AVX2 support.
2025-11-09 22:16:26,413 - INFO - 加载valid查询数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_valid_queries.jsonl
加载valid查询数据: 5008it [00:00, 349031.66it/s]
2025-11-09 22:16:26,433 - INFO - 成功加载valid查询数据，共5008条


Usable valid queries: 5008


2025-11-09 22:16:26,511 - INFO - 批量加载valid图片数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_valid_imgs.tsv
实际加载valid图片数据: 100%|████████████████████████████████████████████████████| 29806/29806 [00:14<00:00, 2053.30it/s]
2025-11-09 22:16:43,723 - INFO - 成功创建valid图片映射字典，共29806张图片
Build image index: 100%|██████████████████████████████████████████████████████████████| 466/466 [00:42<00:00, 11.08it/s]
Evaluate: 100%|████████████████████████████████████████████████████████████████████| 5008/5008 [00:43<00:00, 115.74it/s]

Recall@1=0.0132, Recall@5=0.0537, Recall@10=0.0931, MeanRecall=0.0533 (N=5008)





In [7]:
with open("ft_clip.finishflag", "w") as f:
    f.write("finish")

import IPython
def kill_current_kernel():
    '''杀死当前的kernel释放内存空间。'''
    IPython.Application.instance().kernel.do_shutdown(True)
kill_current_kernel()