# ResNet50 与 BERT 预训练模型的域内微调与保存

目的：先用电商数据中训练集的图片与查询文本，分别对预训练好的 ResNet50 与 BERT 进行域内微调（图像侧用自监督旋转预测，文本侧用 MLM），然后保存微调后的权重，供后续多模态检索任务使用。

- 图像侧：使用自监督旋转预测任务（0°/90°/180°/270°，四分类）对 ResNet50 的顶层进行轻量微调。
- 文本侧：使用 Masked Language Modeling (MLM) 在电商查询文本上微调 BERT。
- 结果：保存两个模型的微调后权重，后续检索任务可直接加载。


In [1]:
import os
import math
import json
import sys
from typing import List

# Environment and cache settings (kept consistent with the project; change them
# to a local mirror/cache as needed).
# These are assigned BEFORE importing torch/timm/transformers on purpose:
# huggingface_hub reads HF_ENDPOINT once when it is first imported, so setting
# it after `import transformers` / `import timm` has no effect.
cache_dir = "/mnt/d/HuggingFaceModels/"
os.environ['TORCH_HOME'] = cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ['https_proxy'] = 'http://127.0.0.1:7890'
os.environ['http_proxy'] = 'http://127.0.0.1:7890'
os.environ['all_proxy'] = 'socks5://127.0.0.1:7890'
os.environ["WANDB_DISABLED"] = "true"
# NOTE(review): an empty CURL_CA_BUNDLE is presumably meant to bypass TLS
# certificate verification for HTTP downloads (effect depends on the installed
# `requests` version); only keep it on a trusted network.
os.environ['CURL_CA_BUNDLE'] = ""
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader as TorchDataLoader
from torch.optim import Adam, AdamW

import timm
from torchvision import transforms
from torchvision.transforms import functional as TF
from PIL import Image
from tqdm import tqdm
from torch.amp import autocast, GradScaler

from transformers import BertTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_cosine_schedule_with_warmup

# Import the project's data loader.
sys.path.append(os.path.abspath(os.path.join('.', 'Multimodal_Retrieval', 'plan1_1')))
from data_loader import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Output locations for the fine-tuned weights.
save_dir = '/mnt/d/forCoding_data/Tianchi_MUGE/trained_models/weights'
os.makedirs(save_dir, exist_ok=True)
resnet_save_path = os.path.join(save_dir, 'ft_resnet50_rotation_backbone.pth')
bert_save_path = os.path.join(save_dir, 'ft_bert_base_chinese_mlm.pth')


  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
import time

# Block until the upstream notebook has written its finish flag, polling every 5 s.
while not os.path.exists("step_2_3-3_屯-mean_pooling-v2_cp1.finishflag"):
    time.sleep(5)

In [3]:
# Create the project data loader with its default configuration.
# NOTE(review): the original note lists batch_size=1000, image_size=(224, 224);
# presumably these are DataLoader's defaults — confirm in data_loader.py.
loader = DataLoader()

2025-11-09 21:20:37,327 - INFO - 初始化数据加载器，数据目录: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval


## 图像侧：ResNet50 旋转预测自监督微调
- 使用 timm 的 ResNet50 特征提取模式（`num_classes=0`），只微调 `layer4` 顶层。
- 在每张图像上构造 4 个旋转版本，对应标签 0/1/2/3。
- 在分类头上计算交叉熵，优化 `layer4` 与分类头。


In [4]:
from safetensors.torch import load_file

class RotationHead(nn.Module):
    """Linear classifier that maps backbone features to rotation-class logits.

    Args:
        in_dim: Size of the incoming feature vector (2048 for the ResNet50
            backbone used in this notebook).
        num_classes: Number of rotation classes (4 for 0/90/180/270 degrees).
    """

    def __init__(self, in_dim=2048, num_classes=4):
        super(RotationHead, self).__init__()
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, x):
        """Return unnormalized logits with a trailing dimension of ``num_classes``."""
        logits = self.fc(x)
        return logits

def build_resnet50_backbone(device, cache_dir=None, weights_path=None):
    """Create a headless timm ResNet50 with only ``layer4`` left trainable.

    Args:
        device: Device the model is moved to.
        cache_dir: Forwarded to ``timm.create_model``.
        weights_path: Optional checkpoint, either ``.safetensors`` or a file
            readable by ``torch.load``. Extra checkpoint keys (such as a
            classifier head, which this ``num_classes=0`` model lacks) are
            ignored and reported.

    Returns:
        The model (``num_classes=0``, so it outputs pooled features). All
        parameters are frozen except those of ``layer4`` when it exists.

    Raises:
        RuntimeError: If the checkpoint does not provide every backbone weight,
            which would otherwise silently leave parts of the network random.
    """
    model = timm.create_model('resnet50', pretrained=False, num_classes=0, cache_dir=cache_dir).to(device)
    if weights_path is not None:
        if weights_path.endswith('.safetensors'):
            state_dict = load_file(weights_path)
        else:
            state_dict = torch.load(weights_path, map_location='cpu')
        # strict=False is needed to skip keys the headless model does not have,
        # but the result must still be checked for missing backbone weights.
        result = model.load_state_dict(state_dict, strict=False)
        # BatchNorm's step counter is not a learned weight; tolerate its absence.
        missing = [k for k in result.missing_keys if not k.endswith('num_batches_tracked')]
        if missing:
            raise RuntimeError(
                f"Checkpoint {weights_path} is missing {len(missing)} backbone keys, e.g. {missing[:5]}"
            )
        if result.unexpected_keys:
            print(f"Ignored {len(result.unexpected_keys)} unexpected checkpoint keys: {result.unexpected_keys[:5]}")

    # Fine-tune only the top stage (layer4); everything else is frozen.
    for p in model.parameters():
        p.requires_grad = False
    if hasattr(model, 'layer4'):
        for p in model.layer4.parameters():
            p.requires_grad = True
    return model

class RotationBatchDataset(Dataset):
    """Expand each image into one sample per rotation angle.

    Every item of ``items`` whose ``'image'`` entry is present and not None
    contributes ``len(angles)`` samples ``(image, label, angle)``, where
    ``label`` is the index of ``angle`` in ``angles``. Other items are skipped.
    Rotation and ``transform`` are applied lazily in ``__getitem__``.
    """

    def __init__(self, items, transform, angles=(0, 90, 180, 270)):
        self.transform = transform
        self.angles = angles
        images = (entry.get('image', None) for entry in items)
        self.samples = [
            (image, label, angle)
            for image in images
            if image is not None
            for label, angle in enumerate(self.angles)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image, label, angle = self.samples[idx]
        return self.transform(TF.rotate(image, angle=angle)), label

def train_resnet_rotation(loader: DataLoader, device, epochs=1, img_batch_size=800, inner_bs=24, lr=1e-4, max_batches=2, use_amp=True, warmup_ratio=0.05):
    model = build_resnet50_backbone(device=device, cache_dir=cache_dir, 
                                   weights_path="/mnt/d/HuggingFaceModels/models--timm--resnet50.a1_in1k/snapshots/767268603ca0cb0bfe326fa87277f19c419566ef/model.safetensors"
                                   )
    head = RotationHead(in_dim=2048, num_classes=4).to(device)
    params = list(model.layer4.parameters()) + list(head.parameters()) if hasattr(model, 'layer4') else list(head.parameters())
    optim = Adam(params, lr=lr)
    # 预估调度步数：每张原图生成4个旋转样本，inner_bs为步长
    est_imgs = (img_batch_size if max_batches is None else img_batch_size * max_batches)
    total_train_samples = est_imgs * 4
    est_total_steps = max(1, total_train_samples // inner_bs)
    num_warmup_steps = max(1, int(est_total_steps * warmup_ratio))
    scheduler = get_cosine_schedule_with_warmup(optim, num_warmup_steps=num_warmup_steps, num_training_steps=est_total_steps)
    scaler = GradScaler(enabled=(use_amp and device.type=='cuda'))
    criterion = nn.CrossEntropyLoss()
    base_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    angles = [0, 90, 180, 270]

    model.train()
    outer = loader.load_images_batch(split='train', batch_size=img_batch_size, max_batches=max_batches)
    for batch_idx, img_batch in enumerate(outer):
    # for batch_idx, img_batch in enumerate(tqdm(outer, desc="ResNet rotation: image batches", total=max_batches if max_batches else None)):
        ds = RotationBatchDataset(img_batch, base_tf, angles)
        if len(ds) == 0:
            continue
        dl = TorchDataLoader(ds, batch_size=inner_bs, shuffle=True, num_workers=2, pin_memory=(device.type=='cuda'))
        running_loss = 0.0
        steps = 0
        for xb, yb in dl: # tqdm(dl, desc=f"Batch {batch_idx+1}"):
            xb = xb.to(device, non_blocking=True)
            yb = torch.tensor(yb, dtype=torch.long, device=device)
            with autocast(
                device_type = device.type, 
                enabled=(use_amp and device.type=='cuda')
            ):
                feats = model(xb)
                logits = head(feats)
                loss = criterion(logits, yb)
            optim.zero_grad(set_to_none=True)
            if use_amp and device.type=='cuda':
                scaler.scale(loss).backward()
                torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
                optim.step()
            scheduler.step()
            running_loss += float(loss.detach().item())
            steps += 1
        if steps > 0:
            current_lr = optim.param_groups[0]['lr']
            print(f"Rotation Batch {batch_idx+1}: avg loss={running_loss/steps:.4f}, lr={current_lr:.6f}")
        if device.type == 'cuda':
            torch.cuda.empty_cache()
    return model


In [None]:
# 1) Image side: self-supervised rotation-prediction fine-tuning of ResNet50,
#    then persist the backbone weights (the rotation head is not saved).
rotation_config = dict(img_batch_size=1000, max_batches=200, inner_bs=24, epochs=2, lr=1e-4, use_amp=True)
resnet_model = train_resnet_rotation(loader, device=device, **rotation_config)
torch.save(resnet_model.state_dict(), resnet_save_path)
print(f'ResNet50 backbone fine-tuned and saved to: {resnet_save_path}')

2025-11-09 21:20:37,968 - INFO - 批量加载train图片数据: /mnt/d/forCoding_data/Tianchi_MUGE/originalData/Multimodal_Retrieval/MR_train_imgs.tsv
  yb = torch.tensor(yb, dtype=torch.long, device=device)
实际加载train图片数据:   1%|▍                                                    | 1100/129380 [00:05<15:07, 141.34it/s]

Rotation Batch 1: avg loss=1.3862, lr=0.000010


实际加载train图片数据:   2%|▉                                                    | 2157/129380 [00:08<10:24, 203.79it/s]

Rotation Batch 2: avg loss=1.3856, lr=0.000020


实际加载train图片数据:   2%|█▎                                                   | 3067/129380 [00:10<08:49, 238.53it/s]

Rotation Batch 3: avg loss=1.3804, lr=0.000030


实际加载train图片数据:   3%|█▋                                                   | 4034/129380 [00:13<08:18, 251.65it/s]

Rotation Batch 4: avg loss=1.3767, lr=0.000040


实际加载train图片数据:   4%|██▏                                                  | 5207/129380 [00:16<07:57, 260.30it/s]

Rotation Batch 5: avg loss=1.3707, lr=0.000050


实际加载train图片数据:   5%|██▌                                                  | 6141/129380 [00:18<07:48, 263.08it/s]

Rotation Batch 6: avg loss=1.3635, lr=0.000060


实际加载train图片数据:   6%|██▉                                                  | 7290/129380 [00:21<06:10, 329.31it/s]

Rotation Batch 7: avg loss=1.3546, lr=0.000070


实际加载train图片数据:   6%|███▎                                                 | 8134/129380 [00:24<08:01, 252.01it/s]

Rotation Batch 8: avg loss=1.3461, lr=0.000080


实际加载train图片数据:   7%|███▋                                                 | 9022/129380 [00:27<08:12, 244.46it/s]

Rotation Batch 9: avg loss=1.3317, lr=0.000090


实际加载train图片数据:   8%|████                                                | 10174/129380 [00:30<07:38, 260.12it/s]

Rotation Batch 10: avg loss=1.3134, lr=0.000100


实际加载train图片数据:   9%|████▍                                               | 11069/129380 [00:32<08:06, 243.24it/s]

Rotation Batch 11: avg loss=1.2924, lr=0.000100


实际加载train图片数据:   9%|████▊                                               | 12000/129380 [00:36<09:20, 209.59it/s]

Rotation Batch 12: avg loss=1.2693, lr=0.000100


实际加载train图片数据:  10%|█████▎                                              | 13227/129380 [00:39<07:32, 256.94it/s]

Rotation Batch 13: avg loss=1.2572, lr=0.000100


实际加载train图片数据:  11%|█████▋                                              | 14068/129380 [00:42<08:06, 236.79it/s]

Rotation Batch 14: avg loss=1.2303, lr=0.000100


实际加载train图片数据:  12%|██████▏                                             | 15256/129380 [00:44<06:56, 274.07it/s]

Rotation Batch 15: avg loss=1.2189, lr=0.000100


实际加载train图片数据:  12%|██████▍                                             | 16128/129380 [00:47<07:48, 241.51it/s]

Rotation Batch 16: avg loss=1.2044, lr=0.000100


实际加载train图片数据:  13%|██████▉                                             | 17335/129380 [00:50<05:22, 347.78it/s]

Rotation Batch 17: avg loss=1.1767, lr=0.000100


实际加载train图片数据:  14%|███████▎                                            | 18305/129380 [00:53<05:08, 360.03it/s]

Rotation Batch 18: avg loss=1.1561, lr=0.000100


实际加载train图片数据:  15%|███████▋                                            | 19193/129380 [00:55<06:50, 268.44it/s]

Rotation Batch 19: avg loss=1.1166, lr=0.000099


实际加载train图片数据:  16%|████████                                            | 20106/129380 [00:58<07:13, 251.88it/s]

Rotation Batch 20: avg loss=1.1220, lr=0.000099


实际加载train图片数据:  16%|████████▍                                           | 21022/129380 [01:01<07:22, 244.84it/s]

Rotation Batch 21: avg loss=1.0865, lr=0.000099


实际加载train图片数据:  17%|████████▉                                           | 22248/129380 [01:04<06:38, 268.51it/s]

Rotation Batch 22: avg loss=1.0797, lr=0.000099


实际加载train图片数据:  18%|█████████▎                                          | 23155/129380 [01:07<07:09, 247.32it/s]

Rotation Batch 23: avg loss=1.0649, lr=0.000099


实际加载train图片数据:  19%|█████████▋                                          | 24009/129380 [01:10<08:48, 199.34it/s]

Rotation Batch 24: avg loss=1.0496, lr=0.000099


实际加载train图片数据:  19%|██████████                                          | 25162/129380 [01:13<07:16, 238.68it/s]

Rotation Batch 25: avg loss=1.0457, lr=0.000098


实际加载train图片数据:  20%|██████████▍                                         | 26023/129380 [01:16<07:18, 235.73it/s]

Rotation Batch 26: avg loss=1.0220, lr=0.000098


实际加载train图片数据:  21%|██████████▉                                         | 27306/129380 [01:19<04:49, 352.21it/s]

Rotation Batch 27: avg loss=1.0132, lr=0.000098


实际加载train图片数据:  22%|███████████▎                                        | 28188/129380 [01:21<06:23, 264.05it/s]

Rotation Batch 28: avg loss=1.0128, lr=0.000098


实际加载train图片数据:  23%|███████████▋                                        | 29125/129380 [01:24<06:32, 255.38it/s]

Rotation Batch 29: avg loss=1.0078, lr=0.000098


实际加载train图片数据:  23%|████████████                                        | 30039/129380 [01:27<06:43, 246.47it/s]

Rotation Batch 30: avg loss=0.9791, lr=0.000097


实际加载train图片数据:  24%|████████████▌                                       | 31181/129380 [01:30<06:25, 254.68it/s]

Rotation Batch 31: avg loss=1.0152, lr=0.000097


实际加载train图片数据:  25%|████████████▉                                       | 32092/129380 [01:32<06:35, 246.17it/s]

Rotation Batch 32: avg loss=0.9792, lr=0.000097


实际加载train图片数据:  26%|█████████████▎                                      | 33010/129380 [01:35<06:34, 244.49it/s]

Rotation Batch 33: avg loss=1.0033, lr=0.000096


实际加载train图片数据:  26%|█████████████▊                                      | 34219/129380 [01:38<06:17, 252.15it/s]

Rotation Batch 34: avg loss=0.9852, lr=0.000096


实际加载train图片数据:  27%|██████████████                                      | 35041/129380 [01:41<07:05, 221.76it/s]

Rotation Batch 35: avg loss=0.9520, lr=0.000096


实际加载train图片数据:  28%|██████████████▌                                     | 36319/129380 [01:44<04:46, 324.32it/s]

Rotation Batch 36: avg loss=0.9721, lr=0.000095


实际加载train图片数据:  29%|██████████████▉                                     | 37259/129380 [01:47<06:02, 254.07it/s]

Rotation Batch 37: avg loss=0.9594, lr=0.000095


实际加载train图片数据:  29%|███████████████▎                                    | 38134/129380 [01:50<06:36, 230.12it/s]

Rotation Batch 38: avg loss=0.9642, lr=0.000095


实际加载train图片数据:  30%|███████████████▋                                    | 39024/129380 [01:53<06:37, 227.11it/s]

Rotation Batch 39: avg loss=0.9460, lr=0.000094


实际加载train图片数据:  31%|████████████████▏                                   | 40175/129380 [01:56<06:03, 245.31it/s]

Rotation Batch 40: avg loss=0.9532, lr=0.000094


实际加载train图片数据:  32%|████████████████▍                                   | 41024/129380 [02:00<08:08, 180.86it/s]

Rotation Batch 41: avg loss=0.9264, lr=0.000094


实际加载train图片数据:  33%|████████████████▉                                   | 42143/129380 [02:03<06:41, 217.41it/s]

Rotation Batch 42: avg loss=0.9327, lr=0.000093


实际加载train图片数据:  33%|█████████████████▍                                  | 43297/129380 [02:06<04:39, 308.09it/s]

Rotation Batch 43: avg loss=0.9490, lr=0.000093


实际加载train图片数据:  34%|█████████████████▊                                  | 44192/129380 [02:09<05:43, 248.36it/s]

Rotation Batch 44: avg loss=0.9277, lr=0.000092


实际加载train图片数据:  35%|██████████████████▏                                 | 45102/129380 [02:12<06:00, 233.94it/s]

Rotation Batch 45: avg loss=0.9369, lr=0.000092


实际加载train图片数据:  36%|██████████████████▍                                 | 46000/129380 [02:15<06:01, 230.71it/s]

Rotation Batch 46: avg loss=0.9118, lr=0.000091


实际加载train图片数据:  36%|██████████████████▉                                 | 47223/129380 [02:18<05:21, 255.87it/s]

Rotation Batch 47: avg loss=0.9196, lr=0.000091


实际加载train图片数据:  37%|███████████████████▎                                | 48170/129380 [02:21<05:33, 243.77it/s]

Rotation Batch 48: avg loss=0.9096, lr=0.000090


实际加载train图片数据:  38%|███████████████████▋                                | 49046/129380 [02:24<05:54, 226.38it/s]

Rotation Batch 49: avg loss=0.9374, lr=0.000090


实际加载train图片数据:  39%|████████████████████▏                               | 50318/129380 [02:27<03:59, 330.19it/s]

Rotation Batch 50: avg loss=0.9327, lr=0.000089


实际加载train图片数据:  40%|████████████████████▌                               | 51239/129380 [02:30<05:00, 259.79it/s]

Rotation Batch 51: avg loss=0.9077, lr=0.000089


实际加载train图片数据:  40%|████████████████████▉                               | 52000/129380 [02:33<06:51, 188.09it/s]

Rotation Batch 52: avg loss=0.9389, lr=0.000088


实际加载train图片数据:  41%|█████████████████████▍                              | 53347/129380 [02:36<03:47, 333.58it/s]

Rotation Batch 53: avg loss=0.9043, lr=0.000088


实际加载train图片数据:  42%|█████████████████████▊                              | 54188/129380 [02:39<04:53, 256.59it/s]

Rotation Batch 54: avg loss=0.9050, lr=0.000087


实际加载train图片数据:  43%|██████████████████████▏                             | 55130/129380 [02:42<04:55, 251.19it/s]

Rotation Batch 55: avg loss=0.8932, lr=0.000087


实际加载train图片数据:  44%|██████████████████████▋                             | 56293/129380 [02:45<03:41, 329.29it/s]

Rotation Batch 56: avg loss=0.8979, lr=0.000086


实际加载train图片数据:  44%|██████████████████████▉                             | 57168/129380 [02:47<04:41, 256.74it/s]

Rotation Batch 57: avg loss=0.8874, lr=0.000086


实际加载train图片数据:  45%|███████████████████████▎                            | 58052/129380 [02:50<04:57, 240.05it/s]

Rotation Batch 58: avg loss=0.8961, lr=0.000085


实际加载train图片数据:  46%|███████████████████████▊                            | 59249/129380 [02:53<04:25, 264.41it/s]

Rotation Batch 59: avg loss=0.8958, lr=0.000084


实际加载train图片数据:  46%|████████████████████████▏                           | 60142/129380 [02:56<04:44, 243.26it/s]

Rotation Batch 60: avg loss=0.8892, lr=0.000084


实际加载train图片数据:  47%|████████████████████████▌                           | 61024/129380 [02:59<04:54, 232.23it/s]

Rotation Batch 61: avg loss=0.8684, lr=0.000083


实际加载train图片数据:  48%|█████████████████████████                           | 62289/129380 [03:02<03:16, 340.75it/s]

Rotation Batch 62: avg loss=0.8960, lr=0.000083


实际加载train图片数据:  49%|█████████████████████████▍                          | 63169/129380 [03:05<05:02, 219.07it/s]

Rotation Batch 63: avg loss=0.8808, lr=0.000082


实际加载train图片数据:  49%|█████████████████████████▋                          | 64038/129380 [03:08<04:48, 226.38it/s]

Rotation Batch 64: avg loss=0.8613, lr=0.000081


实际加载train图片数据:  50%|██████████████████████████▏                         | 65195/129380 [03:11<04:12, 254.03it/s]

Rotation Batch 65: avg loss=0.8888, lr=0.000081


实际加载train图片数据:  51%|██████████████████████████▌                         | 66036/129380 [03:14<04:30, 234.11it/s]

Rotation Batch 66: avg loss=0.8645, lr=0.000080


实际加载train图片数据:  52%|██████████████████████████▉                         | 67169/129380 [03:16<04:07, 250.86it/s]

Rotation Batch 67: avg loss=0.8656, lr=0.000079


实际加载train图片数据:  53%|███████████████████████████▍                        | 68349/129380 [03:19<02:58, 341.87it/s]

Rotation Batch 68: avg loss=0.8902, lr=0.000079


实际加载train图片数据:  53%|███████████████████████████▋                        | 69000/129380 [03:22<03:58, 253.04it/s]

Rotation Batch 69: avg loss=0.9016, lr=0.000078


实际加载train图片数据:  54%|████████████████████████████▏                       | 70234/129380 [03:25<03:35, 274.34it/s]

Rotation Batch 70: avg loss=0.8648, lr=0.000077


实际加载train图片数据:  55%|████████████████████████████▌                       | 71117/129380 [03:28<03:56, 246.38it/s]

Rotation Batch 71: avg loss=0.8965, lr=0.000077


实际加载train图片数据:  56%|█████████████████████████████                       | 72289/129380 [03:31<02:53, 329.34it/s]

Rotation Batch 72: avg loss=0.8901, lr=0.000076


实际加载train图片数据:  57%|█████████████████████████████▍                      | 73133/129380 [03:33<03:42, 252.77it/s]

Rotation Batch 73: avg loss=0.8655, lr=0.000075


实际加载train图片数据:  57%|█████████████████████████████▊                      | 74046/129380 [03:37<04:26, 207.62it/s]

Rotation Batch 74: avg loss=0.8638, lr=0.000074


实际加载train图片数据:  58%|██████████████████████████████▏                     | 75257/129380 [03:40<03:33, 253.18it/s]

Rotation Batch 75: avg loss=0.8889, lr=0.000074


实际加载train图片数据:  59%|██████████████████████████████▋                     | 76211/129380 [03:43<03:33, 249.38it/s]

Rotation Batch 76: avg loss=0.8693, lr=0.000073


实际加载train图片数据:  60%|██████████████████████████████▉                     | 77124/129380 [03:45<03:36, 241.42it/s]

Rotation Batch 77: avg loss=0.8514, lr=0.000072


实际加载train图片数据:  61%|███████████████████████████████▍                    | 78341/129380 [03:48<02:28, 342.79it/s]

Rotation Batch 78: avg loss=0.8527, lr=0.000071


实际加载train图片数据:  61%|███████████████████████████████▊                    | 79252/129380 [03:51<03:07, 267.06it/s]

Rotation Batch 79: avg loss=0.8325, lr=0.000071


实际加载train图片数据:  62%|████████████████████████████████▏                   | 80158/129380 [03:54<03:16, 250.48it/s]

Rotation Batch 80: avg loss=0.8754, lr=0.000070


实际加载train图片数据:  63%|████████████████████████████████▌                   | 81061/129380 [03:57<03:20, 241.04it/s]

Rotation Batch 81: avg loss=0.8649, lr=0.000069


实际加载train图片数据:  64%|█████████████████████████████████                   | 82248/129380 [04:00<03:00, 261.21it/s]

Rotation Batch 82: avg loss=0.8559, lr=0.000068


实际加载train图片数据:  64%|█████████████████████████████████▍                  | 83067/129380 [04:02<03:20, 230.62it/s]

Rotation Batch 83: avg loss=0.8640, lr=0.000068


实际加载train图片数据:  65%|█████████████████████████████████▊                  | 84219/129380 [04:05<02:56, 255.62it/s]

Rotation Batch 84: avg loss=0.8420, lr=0.000067


实际加载train图片数据:  66%|██████████████████████████████████▏                 | 85096/129380 [04:08<03:04, 240.08it/s]

Rotation Batch 85: avg loss=0.8502, lr=0.000066


实际加载train图片数据:  67%|██████████████████████████████████▋                 | 86314/129380 [04:11<02:28, 290.04it/s]

Rotation Batch 86: avg loss=0.8705, lr=0.000065


实际加载train图片数据:  67%|███████████████████████████████████                 | 87220/129380 [04:14<02:43, 257.10it/s]

Rotation Batch 87: avg loss=0.8502, lr=0.000065


实际加载train图片数据:  68%|███████████████████████████████████▍                | 88084/129380 [04:17<02:54, 236.58it/s]

Rotation Batch 88: avg loss=0.8515, lr=0.000064


实际加载train图片数据:  69%|███████████████████████████████████▉                | 89303/129380 [04:20<01:58, 338.66it/s]

Rotation Batch 89: avg loss=0.8529, lr=0.000063


实际加载train图片数据:  70%|████████████████████████████████████▏               | 90168/129380 [04:23<02:29, 261.45it/s]

Rotation Batch 90: avg loss=0.8785, lr=0.000062


实际加载train图片数据:  70%|████████████████████████████████████▌               | 91035/129380 [04:25<02:39, 240.05it/s]

Rotation Batch 91: avg loss=0.8361, lr=0.000061


实际加载train图片数据:  71%|█████████████████████████████████████               | 92200/129380 [04:28<02:22, 260.55it/s]

Rotation Batch 92: avg loss=0.8671, lr=0.000061


实际加载train图片数据:  72%|█████████████████████████████████████▍              | 93068/129380 [04:31<02:31, 240.38it/s]

Rotation Batch 93: avg loss=0.8740, lr=0.000060


实际加载train图片数据:  73%|█████████████████████████████████████▉              | 94294/129380 [04:34<01:43, 338.02it/s]

Rotation Batch 94: avg loss=0.8727, lr=0.000059


实际加载train图片数据:  74%|██████████████████████████████████████▎             | 95197/129380 [04:37<02:07, 267.17it/s]

Rotation Batch 95: avg loss=0.8839, lr=0.000058


实际加载train图片数据:  74%|██████████████████████████████████████▋             | 96137/129380 [04:39<02:10, 254.65it/s]

Rotation Batch 96: avg loss=0.8406, lr=0.000057


实际加载train图片数据:  75%|███████████████████████████████████████             | 97330/129380 [04:42<01:33, 344.38it/s]

Rotation Batch 97: avg loss=0.8351, lr=0.000056


实际加载train图片数据:  76%|███████████████████████████████████████▍            | 98088/129380 [04:46<02:19, 224.32it/s]

Rotation Batch 98: avg loss=0.8302, lr=0.000056


实际加载train图片数据:  77%|███████████████████████████████████████▊            | 99016/129380 [04:48<02:07, 237.74it/s]

Rotation Batch 99: avg loss=0.8743, lr=0.000055


实际加载train图片数据:  77%|███████████████████████████████████████▍           | 100160/129380 [04:51<01:53, 256.32it/s]

Rotation Batch 100: avg loss=0.8353, lr=0.000054


实际加载train图片数据:  78%|███████████████████████████████████████▉           | 101306/129380 [04:54<01:23, 334.93it/s]

Rotation Batch 101: avg loss=0.8067, lr=0.000053


实际加载train图片数据:  79%|████████████████████████████████████████▏          | 101925/129380 [04:54<00:43, 631.07it/s]

## 文本侧：BERT 在电商查询上的 MLM 微调
- 使用 train 划分的查询文本构造 MLM 训练语料；valid 查询只取前 8 条，每个 epoch 结束时做一次前向检查。
- 使用 `BertForMaskedLM` 与 `DataCollatorForLanguageModeling` 实现标准 MLM 训练。
- 默认微调整个 BERT；可选仅微调顶层。


In [None]:
class TextDataset(Dataset):
    """Tokenize one raw string per item into fixed-length model inputs.

    Each item is a dict of 1-D tensors (one per tokenizer output key), padded
    and truncated to ``max_length`` tokens, suitable for batching with a
    collate function such as ``DataCollatorForLanguageModeling``.
    """

    def __init__(self, texts: List[str], tokenizer: BertTokenizer, max_length=128):
        self.texts, self.tok, self.max_length = texts, tokenizer, max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tok(
            self.texts[idx],
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )
        # The tokenizer returns a batch of one; drop that leading dimension.
        return dict((name, tensor.squeeze(0)) for name, tensor in encoded.items())

def train_bert_mlm(loader: DataLoader, device, epochs=1, batch_size=64, lr=5e-5, max_samples=None, last_n_layers=None, use_amp=True, warmup_ratio=0.05):
    """Fine-tune ``bert-base-chinese`` with masked language modeling on query texts.

    Args:
        loader: Project data loader; ``load_queries(split)`` must return an
            object whose ``.get('query_text', [])`` yields the query strings.
        device: Training device.
        epochs: Number of passes over the training texts.
        batch_size: Mini-batch size.
        lr: Peak learning rate (AdamW, cosine schedule with warmup).
        max_samples: Keep only the first ``max_samples`` training texts.
        last_n_layers: If set, freeze all of ``model.bert`` except its top
            ``last_n_layers`` encoder layers (and the pooler, if any). The
            MLM head outside ``model.bert`` stays trainable.
        use_amp: Use mixed precision (with gradient scaling) on CUDA.
        warmup_ratio: Fraction of training steps used for warmup.

    Returns:
        The fine-tuned ``BertForMaskedLM`` (still on ``device``).
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', cache_dir=cache_dir, local_files_only=True)
    model = BertForMaskedLM.from_pretrained('bert-base-chinese', cache_dir=cache_dir, local_files_only=True).to(device)

    # Optional: fine-tune only the top `last_n_layers` encoder layers.
    if last_n_layers is not None:
        for p in model.bert.parameters():
            p.requires_grad = False
        layers = model.bert.encoder.layer
        for layer in layers[max(0, len(layers) - last_n_layers):]:
            for p in layer.parameters():
                p.requires_grad = True
        # NOTE(review): the MLM model is typically built without a pooler, in
        # which case this branch is a no-op.
        pooler = getattr(model.bert, 'pooler', None)
        if pooler is not None:
            for p in pooler.parameters():
                p.requires_grad = True

    # Data: only train-split queries are trained on; the first 8 valid queries
    # are used for a per-epoch validation loss.
    train_df = loader.load_queries(split='train')
    valid_df = loader.load_queries(split='valid')
    train_texts = [t for t in train_df.get('query_text', []) if isinstance(t, str)]
    valid_texts = [t for t in valid_df.get('query_text', []) if isinstance(t, str)]
    if max_samples is not None:
        train_texts = train_texts[:max_samples]
    ds_train = TextDataset(train_texts, tokenizer)
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
    dl_train = TorchDataLoader(ds_train, batch_size=batch_size, shuffle=True, collate_fn=collator, num_workers=2, pin_memory=(device.type == 'cuda'))
    valid_sample = valid_texts[:8]
    dl_valid = (
        TorchDataLoader(TextDataset(valid_sample, tokenizer), batch_size=8, shuffle=False, collate_fn=collator)
        if valid_sample else None
    )

    trainable = [p for p in model.parameters() if p.requires_grad]
    optim = AdamW(trainable, lr=lr)
    # One optimizer step per mini-batch. len(dl_train) counts the final partial
    # batch too, so the cosine schedule ends exactly at the last step.
    est_total_steps = max(1, len(dl_train) * epochs)
    num_warmup_steps = max(1, int(est_total_steps * warmup_ratio))
    scheduler = get_cosine_schedule_with_warmup(optim, num_warmup_steps=num_warmup_steps, num_training_steps=est_total_steps)
    amp_enabled = use_amp and device.type == 'cuda'
    # Disabled scaler => scale/unscale_/step/update are plain pass-throughs.
    scaler = GradScaler(enabled=amp_enabled)

    model.train()
    for e in range(epochs):
        step = 0
        running_loss = 0.0
        for batch in tqdm(dl_train, desc=f"MLM epoch {e+1}/{epochs}"):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            with autocast(device_type=device.type, enabled=amp_enabled):
                loss = model(**batch).loss
            optim.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            # Clip the true (unscaled) gradients, not the loss-scaled ones.
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)
            scaler.step(optim)
            scaler.update()
            scheduler.step()
            step += 1
            running_loss += float(loss.detach().item())
        if step > 0:
            current_lr = optim.param_groups[0]['lr']
            print(f"MLM epoch {e+1}: avg loss={running_loss/step:.4f}, lr={current_lr:.6f}")
        # Lightweight validation: MLM loss on a handful of held-out queries
        # (random masking, so the value is noisy).
        if dl_valid is not None:
            model.eval()
            with torch.no_grad():
                valid_losses = [
                    float(model(**{k: v.to(device) for k, v in vb.items()}).loss)
                    for vb in dl_valid
                ]
            model.train()
            print(f"MLM epoch {e+1}: valid loss ({len(valid_sample)} samples)={sum(valid_losses)/len(valid_losses):.4f}")
    return model


In [None]:
# 2) Text side: MLM fine-tuning of BERT, then persist the full
#    BertForMaskedLM state_dict (encoder keys carry a `bert.` prefix).
mlm_config = dict(epochs=2, batch_size=64, lr=5e-5, max_samples=300000, last_n_layers=None, use_amp=True)
bert_model = train_bert_mlm(loader, device=device, **mlm_config)
torch.save(bert_model.state_dict(), bert_save_path)
print(f'BERT MLM fine-tuned and saved to: {bert_save_path}')


## 后续使用建议
- ResNet50：后续检索任务中，使用 `timm.create_model('resnet50', num_classes=0)` 并加载上述 `state_dict`；只取特征作为图像嵌入。
- BERT：本 notebook 保存的是 `BertForMaskedLM` 的完整 `state_dict`，其中编码器权重的键名带 `bert.` 前缀，`cls.*` 为 MLM 预测头。后续检索任务中，可先加载预训练的 `BertModel`，再将去掉 `bert.` 前缀后的权重以 `strict=False` 载入（`cls.*` 直接丢弃；MLM 模型不含 pooler 权重，pooler 会保留预训练值）。
- 若需更长训练或更强微调：适当增大 `max_batches`（图像）与 `epochs`（文本），并根据显存与时间预算调整 `batch_size`。


In [None]:
import IPython

# Signal downstream notebooks that this one has finished.
with open("ft_rsn50_brt.finishflag", mode="w") as flag_file:
    flag_file.write("finish")


def kill_current_kernel():
    """Ask the running IPython kernel to shut down (restart=True) to free its memory."""
    kernel = IPython.Application.instance().kernel
    kernel.do_shutdown(True)


kill_current_kernel()