In [1]:
# ================== Shopee: ViT + DeBERTa + Text Decoder (ArcFace) ==================
## ================== Shopee: ViT + DeBERTa + Sub-Center ArcFace (fixed training) ==================
# ================== Shopee: ViT + DeBERTa + Sub-Center ArcFace + Fusion (fixed) ==================
import os, math, random, gc, sys
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# torchvision v2 transforms (с фолбэком на старый API)
try:
    from torchvision.transforms import v2
except Exception:
    from torchvision import transforms as v2

import timm
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModel

# ----------------- Seeds / Determinism -----------------
seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed); np.random.seed(seed)
torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ----------------- Config -----------------
DATA_DIR = '/kaggle/input/shopee-product-matching'
TRAIN_CSV = os.path.join(DATA_DIR, 'train.csv')
TEST_CSV  = os.path.join(DATA_DIR, 'test.csv')
TRAIN_IMG_DIR = os.path.join(DATA_DIR, 'train_images')
TEST_IMG_DIR  = os.path.join(DATA_DIR, 'test_images')

IMSIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 4
EPOCHS = 5

# Warmup/AMP/Clip
FREEZE_EPOCHS = 1              # столько эпох backbone заморожен (стабилизируем центры)
AMP_OFF_STEPS = 1000           # первые N шагов без AMP, чтобы «поймать» градиенты
CLIP_GRAD_NORM = None          # либо 10.0; по умолчанию без клипа

# Retrieval / fusion
KQ = 100
MUTUAL_AT_VAL = True
TEXT_MAX_LEN = 64
TEXT_ARC_W_FINAL = 0.20        # вес текстовой ветки в классификационном лоссе (подмешиваем нежно)

# image / text encoders
IMG_BACKBONE = 'vit_base_patch16_224'
TEXT_MODEL_NAME = 'xlm-roberta-large'

# Arc schedules (мягкие!)
S_START, S_END = 16.0, 45.0
M_END = 0.30
SUB_K = 3   # sub-centers per class

# ----------------- Data -----------------
train = pd.read_csv(TRAIN_CSV)
test  = pd.read_csv(TEST_CSV)

# id mapping
label2id = {lg: i for i, lg in enumerate(sorted(train['label_group'].unique()))}
id2label = {i: lg for lg, i in label2id.items()}
train['class_id'] = train['label_group'].map(label2id)

# простая random split (стратификация по всем классам в Shopee небезопасна из-за синглтонов)
train_df, val_df = train_test_split(train, test_size=0.2, random_state=seed, shuffle=True)

# ----------------- Dataset -----------------
class ShopeeDataset(Dataset):
    def __init__(self, df, img_root, transform, train=True):
        self.df = df.reset_index(drop=True)
        self.img_root = img_root
        self.transform = transform
        self.train = train
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # картинки в датасете уже .jpg
        img_path = os.path.join(self.img_root, str(row['image']))
        image = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        sample = {'image': image,
                  'posting_id': row['posting_id'],
                  'title': row.get('title', '')}
        if self.train:
            sample['label'] = torch.tensor(int(row['class_id']), dtype=torch.long)
        return sample

# ----------------- Transforms -----------------
transforms_train = v2.Compose([
    v2.Resize(256, antialias=True),
    v2.RandomResizedCrop(IMSIZE, scale=(0.8, 1.0), antialias=True),
    v2.RandomHorizontalFlip(),
    v2.ToTensor(),
    v2.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

transforms_eval = v2.Compose([
    v2.Resize(256, antialias=True),
    v2.CenterCrop(IMSIZE),
    v2.ToTensor(),
    v2.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

train_ds = ShopeeDataset(train_df, TRAIN_IMG_DIR, transforms_train, train=True)
val_ds   = ShopeeDataset(val_df,   TRAIN_IMG_DIR, transforms_eval,  train=True)
test_ds  = ShopeeDataset(test,     TEST_IMG_DIR,  transforms_eval,  train=False)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True)

# ----------------- Models -----------------
# Image backbone (ViT)
img_backbone = timm.create_model(IMG_BACKBONE, pretrained=True, num_classes=0).to(device)
feat_dim = img_backbone.num_features   # ViT-B/16 -> 768

embedding_head = nn.Sequential(
    nn.Linear(feat_dim, 512, bias=False),
    nn.BatchNorm1d(512)
).to(device)

# Text encoder (DeBERTa v3) + projection head
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME, use_fast=True)
text_encoder   = AutoModel.from_pretrained(TEXT_MODEL_NAME).to(device)
for p in text_encoder.parameters():
    p.requires_grad = False
text_hidden = text_encoder.config.hidden_size  # обычно 768

text_head = nn.Sequential(
    nn.Linear(text_hidden, 512, bias=False),
    nn.BatchNorm1d(512)
).to(device)

def mean_pooling(last_hidden_state, attention_mask):
    mask = attention_mask.unsqueeze(-1).float()
    summed = (last_hidden_state * mask).sum(dim=1)
    denom = mask.sum(dim=1).clamp(min=1e-6)
    return summed / denom

# ----------------- Sub-Center ArcFace -----------------
class SubCenterArcFace(nn.Module):
    """
    Sub-Center ArcFace: для каждого класса C у нас K под-центров.
    Веса: (C*K, D). Косинус -> reshape (B, C, K) -> max по K -> (B, C),
    затем ArcFace margin на true-класс.
    """
    def __init__(self, in_features, out_classes, k_sub=3, s=30.0, m=0.5, easy_margin=False):
        super().__init__()
        self.in_features  = in_features
        self.out_classes  = out_classes
        self.k_sub        = k_sub
        self.s, self.m    = s, m
        self.easy_margin  = easy_margin

        self.weight = nn.Parameter(torch.FloatTensor(out_classes * k_sub, in_features))
        nn.init.xavier_uniform_(self.weight)

        self._refresh_trig()

    def _refresh_trig(self):
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.th    = math.cos(math.pi - self.m)
        self.mm    = math.sin(math.pi - self.m) * self.m

    def set_margin(self, m):
        self.m = float(m)
        self._refresh_trig()

    def forward(self, emb, labels):
        # emb: (B, D), предполагаем L2-нормированные
        W = F.normalize(self.weight)                       # (C*K, D)
        cosine_all = F.linear(emb, W)                      # (B, C*K)
        cosine_all = cosine_all.view(emb.size(0), self.out_classes, self.k_sub)
        cosine, _  = torch.max(cosine_all, dim=2)          # (B, C)

        # ArcFace угол
        sine = torch.sqrt(torch.clamp(1.0 - cosine**2, min=1e-6))
        phi  = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1,1), 1.0)

        logits = one_hot * phi + (1.0 - one_hot) * cosine
        logits = logits * self.s
        return logits

NUM_CLASSES = train['class_id'].nunique()
arcface = SubCenterArcFace(512, NUM_CLASSES, k_sub=SUB_K, s=S_START, m=0.0, easy_margin=False).to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.05)

# ----------------- Param groups (BN/LayerNorm без WD!) -----------------
def split_params_by_wd(module):
    wd, no_wd = [], []
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
            for p in m.parameters(recurse=False):
                if p.requires_grad: no_wd.append(p)
        else:
            for p in m.parameters(recurse=False):
                if p.requires_grad:
                    if p.ndim == 1:  # bias/scale
                        no_wd.append(p)
                    else:
                        wd.append(p)
    return wd, no_wd

wd_backbone, no_wd_backbone = split_params_by_wd(img_backbone)
wd_head,     no_wd_head     = split_params_by_wd(embedding_head)
wd_text,     no_wd_text     = split_params_by_wd(text_head)
wd_arc,      no_wd_arc      = split_params_by_wd(arcface)

param_groups = [
    {'params': wd_backbone, 'lr': 2e-4, 'weight_decay': 0.05},
    {'params': no_wd_backbone, 'lr': 2e-4, 'weight_decay': 0.0},
    {'params': wd_head, 'lr': 1e-3, 'weight_decay': 0.05},
    {'params': no_wd_head, 'lr': 1e-3, 'weight_decay': 0.0},
    {'params': wd_text, 'lr': 1e-3, 'weight_decay': 0.05},
    {'params': no_wd_text, 'lr': 1e-3, 'weight_decay': 0.0},
    {'params': wd_arc, 'lr': 1e-3, 'weight_decay': 0.05},
    {'params': no_wd_arc, 'lr': 1e-3, 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(param_groups)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))

# ----------------- S/M/Text-weight schedules -----------------
def s_at_epoch(ep, total=EPOCHS):
    if total==1: return S_END
    return S_START + (S_END - S_START) * (ep-1) / (total-1)

def m_at_epoch(ep, total=EPOCHS):
    if total==1: return M_END
    return (M_END * (ep-1)) / (total-1)

def text_weight_at_epoch(ep, w_end=TEXT_ARC_W_FINAL, total=EPOCHS):
    if total==1: return w_end
    return (w_end * (ep-1)) / (total-1)

def set_backbone_trainable(flag: bool):
    for p in img_backbone.parameters():
        p.requires_grad = flag

# ----------------- Train -----------------
global_step = 0
for epoch in range(1, EPOCHS+1):
    # schedules
    arcface.set_margin(m_at_epoch(epoch))
    arcface.s = s_at_epoch(epoch)
    TW = text_weight_at_epoch(epoch)

    # warmup freeze
    set_backbone_trainable(epoch > FREEZE_EPOCHS)
    img_backbone.train(epoch > FREEZE_EPOCHS)
    embedding_head.train(); text_head.train(); arcface.train()

    running = 0.0
    pbar = tqdm(train_loader, desc=f"train {epoch}/{EPOCHS} (s={arcface.s:.1f}, m={arcface.m:.2f}, tw={TW:.2f})", leave=False)
    for batch in pbar:
        optimizer.zero_grad(set_to_none=True)

        use_amp = (device.type=='cuda') and (global_step >= AMP_OFF_STEPS)

        with torch.cuda.amp.autocast(enabled=use_amp):
            X = batch['image'].to(device, non_blocking=True)
            y = batch['label'].to(device, non_blocking=True)

            # image branch
            f_img = img_backbone(X)
            e_img = F.normalize(embedding_head(f_img), dim=1)
            logits_img = arcface(e_img, y)
            loss_img = criterion(logits_img, y)

            # text branch (encoder frozen)
            titles = batch['title']
            tok = text_tokenizer(list(titles), padding=True, truncation=True,
                                 max_length=TEXT_MAX_LEN, return_tensors='pt')
            tok = {k: v.to(device, non_blocking=True) for k,v in tok.items()}
            with torch.no_grad():
                out_txt = text_encoder(**tok)
                pooled  = mean_pooling(out_txt.last_hidden_state, tok['attention_mask'])
            e_txt = F.normalize(text_head(pooled), dim=1)
            logits_txt = arcface(e_txt, y)
            loss_txt = criterion(logits_txt, y)

            loss = loss_img + TW * loss_txt

        if use_amp:
            scaler.scale(loss).backward()
            if CLIP_GRAD_NORM:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(arcface.parameters(), CLIP_GRAD_NORM)
                torch.nn.utils.clip_grad_norm_(embedding_head.parameters(), CLIP_GRAD_NORM)
                if epoch > FREEZE_EPOCHS:
                    torch.nn.utils.clip_grad_norm_(img_backbone.parameters(), CLIP_GRAD_NORM)
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward()
            if CLIP_GRAD_NORM:
                torch.nn.utils.clip_grad_norm_(arcface.parameters(), CLIP_GRAD_NORM)
                torch.nn.utils.clip_grad_norm_(embedding_head.parameters(), CLIP_GRAD_NORM)
                if epoch > FREEZE_EPOCHS:
                    torch.nn.utils.clip_grad_norm_(img_backbone.parameters(), CLIP_GRAD_NORM)
            optimizer.step()

        running += loss.item() * X.size(0)
        global_step += 1
        if global_step % 100 == 0:
            try: gnorm = arcface.weight.grad.norm().item()
            except: gnorm = float('nan')
            pbar.set_postfix(loss=running/((pbar.n+1)*X.size(0)), g_arc=f"{gnorm:.2f}")

    scheduler.step()

# ----------------- Validation (embeddings + fusion) -----------------
@torch.no_grad()
def build_img_embs(loader):
    img_backbone.eval(); embedding_head.eval()
    embs, ids, labels = [], [], []
    for b in tqdm(loader, desc="Embed/val(img)"):
        x = b['image'].to(device, non_blocking=True)
        e = F.normalize(embedding_head(img_backbone(x)), dim=1)
        embs.append(e.cpu())
        ids.extend(b['posting_id'])
        labels.extend(b['label'].cpu().numpy().tolist())
    embs = torch.cat(embs, dim=0).numpy().astype('float32')
    embs = embs / (np.linalg.norm(embs, axis=1, keepdims=True) + 1e-8)
    return embs, ids, np.array(labels)

@torch.no_grad()
def build_text_embs_for_df(df, batch_size=512, max_len=TEXT_MAX_LEN):
    text_encoder.eval(); text_head.eval()
    titles = df['title'].fillna('').astype(str).tolist()
    outs = []
    for i in tqdm(range(0, len(titles), batch_size), desc="Embed/val(txt)"):
        b = titles[i:i+batch_size]
        tok = text_tokenizer(b, padding=True, truncation=True,
                             max_length=max_len, return_tensors='pt')
        tok = {k: v.to(device, non_blocking=True) for k,v in tok.items()}
        out = text_encoder(**tok)
        pooled = mean_pooling(out.last_hidden_state, tok['attention_mask'])
        e = F.normalize(text_head(pooled), dim=1)
        outs.append(e.cpu())
    embs = torch.cat(outs, dim=0).numpy().astype('float32')
    embs = embs / (np.linalg.norm(embs, axis=1, keepdims=True) + 1e-8)
    return embs

def topk_chunked_cos(embs_f32: np.ndarray, K: int, qbs: int = 128):
    N, D = embs_f32.shape
    device_t = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    db = torch.from_numpy(embs_f32.astype('float32', copy=False)).to(device_t, non_blocking=True)
    K = min(K, N)
    idxs_list, sims_list = [], []
    for start in tqdm(range(0, N, qbs), desc="TopK (torch-chunk)"):
        q = db[start:start+qbs]
        S = torch.matmul(q, db.T)                      # косинус (L2-норм)
        vals, ids = torch.topk(S, k=K, dim=1, largest=True, sorted=True)
        idxs_list.append(ids.cpu().numpy().astype('int32'))
        sims_list.append(vals.cpu().numpy().astype('float32'))
        del S, vals, ids
        if device_t.type == 'cuda':
            torch.cuda.empty_cache()
    idxs = np.vstack(idxs_list); sims = np.vstack(sims_list)
    del db
    return sims, idxs

def build_preds_fused_union(ids, idxs_img, sims_img, idxs_txt, sims_txt,
                            tau, alpha=0.7, K_cap=50, mutual=True):
    N = len(ids)
    out = {}
    for i in range(N):
        cand_idx = set(idxs_img[i]).union(set(idxs_txt[i]))
        map_img = {int(j): float(s) for j, s in zip(idxs_img[i], sims_img[i])}
        map_txt = {int(j): float(s) for j, s in zip(idxs_txt[i], sims_txt[i])}
        fused = []
        for j in cand_idx:
            si = (map_img.get(int(j), 0.0) + 1.0) / 2.0
            st = (map_txt.get(int(j), 0.0) + 1.0) / 2.0
            s  = alpha*si + (1.0-alpha)*st
            fused.append((j, s))
        fused.sort(key=lambda x: -x[1])
        keep = []
        for j, s in fused:
            if s < tau: 
                continue
            if (not mutual) or (np.any(idxs_img[j] == i)):
                keep.append(ids[j])
        if ids[i] not in keep:
            keep = [ids[i]] + keep
        out[ids[i]] = set(keep[:50])
    return out

def f1_matches(ids, labels, preds):
    truth={}
    for pid, g in zip(ids, labels):
        truth.setdefault(g, set()).add(pid)
    f1s=[]
    for pid, g in zip(ids, labels):
        T = truth[g]; P = preds[pid]
        inter = len(T & P); denom = len(T)+len(P)
        f1s.append(2*inter/denom if denom>0 else 0.0)
    return float(np.mean(f1s))

# --- build embeddings ---
val_img, val_ids, val_labels = build_img_embs(val_loader)
val_txt = build_text_embs_for_df(val_ds.df, batch_size=512, max_len=TEXT_MAX_LEN)

# --- topK ---
sims_img, idxs_img = topk_chunked_cos(val_img, K=KQ, qbs=128)
sims_txt, idxs_txt = topk_chunked_cos(val_txt, K=KQ, qbs=256)

# --- grid search for (alpha, tau), MUTUAL=True ---
alphas = np.linspace(0.4, 0.9, 6)
taus   = np.linspace(0.20, 0.80, 31)
best_f1, best_tau, best_alpha, best_preds = -1.0, None, None, None
for a in alphas:
    for t in taus:
        preds = build_preds_fused_union(val_ids, idxs_img, sims_img, idxs_txt, sims_txt,
                                        tau=float(t), alpha=float(a), K_cap=50, mutual=MUTUAL_AT_VAL)
        f1 = f1_matches(val_ids, val_labels, preds)
        if f1 > best_f1:
            best_f1, best_tau, best_alpha, best_preds = f1, float(t), float(a), preds
print(f"[VAL FUSION] Best F1={best_f1:.4f} at tau={best_tau:.2f}, alpha={best_alpha:.2f}")
print(f"[VAL] Avg predicted group size: {np.mean([len(v) for v in best_preds.values()]):.2f}")

# ----------------- Save -----------------
SAVE_DIR = '/kaggle/working'
os.makedirs(SAVE_DIR, exist_ok=True)

ckpt = {
    'backbone_name': IMG_BACKBONE,
    'text_model_name': TEXT_MODEL_NAME,
    'feat_dim': int(feat_dim),
    'emb_dim': 512,
    'num_classes': NUM_CLASSES,
    'arcface_type': 'subcenter',
    'arcface_cfg': {'s': float(arcface.s), 'm': float(arcface.m), 'k_sub': int(SUB_K), 'easy_margin': False},
    'state_dict': {
        'backbone': img_backbone.state_dict(),
        'embedding_head': embedding_head.state_dict(),
        'text_head': text_head.state_dict(),
        'arcface': arcface.state_dict(),
    },
    'label2id': label2id,
    'best_tau': float(best_tau),
    'best_alpha': float(best_alpha),
    'mutual_used': bool(MUTUAL_AT_VAL),
    'val_f1_fusion': float(best_f1),
    'epoch': EPOCHS,
}
torch.save(ckpt, os.path.join(SAVE_DIR, 'vit_deberta_subarcface_ckpt.pth'))
print('[SAVE] Full checkpoint ->', os.path.join(SAVE_DIR, 'vit_deberta_subarcface_ckpt.pth'))

embed_pkg = {
    'backbone_name': IMG_BACKBONE,
    'text_model_name': TEXT_MODEL_NAME,
    'feat_dim': int(feat_dim),
    'emb_dim': 512,
    'state_dict': {
        'backbone': img_backbone.state_dict(),
        'embedding_head': embedding_head.state_dict(),
        'text_head': text_head.state_dict(),
    },
    'best_tau': float(best_tau),
    'best_alpha': float(best_alpha),
    'mutual_used': bool(MUTUAL_AT_VAL),
}
torch.save(embed_pkg, os.path.join(SAVE_DIR, 'embedding_extractor_vit_deberta_subarcface.pth'))
print('[SAVE] Embedding extractor ->', os.path.join(SAVE_DIR, 'embedding_extractor_vit_deberta_subarcface.pth'))

# ================== Inference / Submission ==================
@torch.no_grad()
def embed_test_images(ds, batch_size=64):
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=True)
    img_backbone.eval(); embedding_head.eval()
    embs, ids = [], []
    for b in tqdm(loader, desc="Embed/test(img)"):
        x = b['image'].to(device, non_blocking=True)
        e = F.normalize(embedding_head(img_backbone(x)), dim=1)
        embs.append(e.cpu())
        ids.extend(b['posting_id'])
    embs = torch.cat(embs, dim=0).numpy().astype('float32')
    embs = embs / (np.linalg.norm(embs, axis=1, keepdims=True) + 1e-8)
    return embs, ids

@torch.no_grad__()
def embed_test_text(df, batch_size=512, max_len=TEXT_MAX_LEN):
    text_encoder.eval(); text_head.eval()
    titles = df['title'].fillna('').astype(str).tolist()
    outs=[]
    for i in tqdm(range(0, len(titles), batch_size), desc="Embed/test(txt)"):
        b = titles[i:i+batch_size]
        tok = text_tokenizer(b, padding=True, truncation=True,
                             max_length=max_len, return_tensors='pt')
        tok = {k: v.to(device, non_blocking=True) for k,v in tok.items()}
        out = text_encoder(**tok)
        pooled = mean_pooling(out.last_hidden_state, tok['attention_mask'])
        e = F.normalize(text_head(pooled), dim=1)
        outs.append(e.cpu())
    embs = torch.cat(outs, dim=0).numpy().astype('float32')
    embs = embs / (np.linalg.norm(embs, axis=1, keepdims=True) + 1e-8)
    return embs

def build_faiss_or_torch_topk(embs_f32, K):
    N, D = embs_f32.shape
    try:
        import faiss
        xb = embs_f32.astype('float32', copy=False)
        index = faiss.IndexFlatIP(D)
        index.add(xb)
        def search(q, k):
            return index.search(q.astype('float32', copy=False), k)
        use_faiss = True
    except Exception:
        use_faiss = False
        device_t = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        db = torch.from_numpy(embs_f32.astype('float32', copy=False)).to(device_t, non_blocking=True)
        def search(q, k):
            q_t = torch.from_numpy(q.astype('float32', copy=False)).to(device_t, non_blocking=True)
            S = torch.matmul(q_t, db.T)
            vals, ids = torch.topk(S, k=k, dim=1, largest=True, sorted=True)
            return vals.cpu().numpy().astype('float32'), ids.cpu().numpy().astype('int64')
    return search, use_faiss

def predict_fused(ids, img_embs, txt_embs, alpha, tau, mutual=True, K_cap=50, ksearch=KQ):
    # precompute topK in each modality (self-sim includes self на позиции 0)
    sims_img, idxs_img = topk_chunked_cos(img_embs, K=ksearch, qbs=128)
    sims_txt, idxs_txt = topk_chunked_cos(txt_embs, K=ksearch, qbs=256)
    preds = build_preds_fused_union(ids, idxs_img, sims_img, idxs_txt, sims_txt,
                                    tau=float(tau), alpha=float(alpha), K_cap=K_cap, mutual=mutual)
    # convert to submission format
    posting_ids = []
    matches = []
    for pid in ids:
        posting_ids.append(pid)
        matches.append(" ".join(list(preds[pid])))
    return pd.DataFrame({'posting_id': posting_ids, 'matches': matches})

# --- Build test embeddings ---
test_img_embs, test_ids = embed_test_images(test_ds, batch_size=64)
test_txt_embs = embed_test_text(test_ds.df, batch_size=512, max_len=TEXT_MAX_LEN)

# --- Predict with best_alpha / best_tau found on val ---
sub_df = predict_fused(test_ids, test_img_embs, test_txt_embs,
                       alpha=best_alpha, tau=best_tau, mutual=MUTUAL_AT_VAL, K_cap=50, ksearch=KQ)

OUT_PATH = '/kaggle/working/submission.csv'
sub_df.to_csv(OUT_PATH, index=False)
print("[SAVE] submission ->", OUT_PATH)
display(sub_df.head())






model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/616 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

2025-09-28 17:30:16.197481: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759080616.362711      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759080616.413840      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

  scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))
  with torch.cuda.amp.autocast(enabled=use_amp):
train 2/5 (s=23.2, m=0.07, tw=0.05):   0%|          | 0/856 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment varia

[VAL FUSION] Best F1=0.8985 at tau=0.76, alpha=0.80
[VAL] Avg predicted group size: 1.70
[SAVE] Full checkpoint -> /kaggle/working/vit_deberta_subarcface_ckpt.pth
[SAVE] Embedding extractor -> /kaggle/working/embedding_extractor_vit_deberta_subarcface.pth


AttributeError: module 'torch' has no attribute 'no_grad__'