In [6]:
# ============================
# Shopee Inference (Image + Text Fusion)
# ============================

import os, gc, math, random, sys, time
import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
import timm
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModel

# ------------- Config -------------
# Fixed seed applied to every RNG used below (Python, NumPy, torch CPU and CUDA).
SEED = 42
os.environ["PYTHONHASHSEED"] = str(SEED)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Competition inputs.
DATA_DIR = '/kaggle/input/shopee-product-matching'
TEST_CSV = os.path.join(DATA_DIR, 'test.csv')
TEST_IMG_DIR = os.path.join(DATA_DIR, 'test_images')

# Candidate checkpoint locations, tried in order; the first existing file wins.
# Point the first entry at your own dataset if the weights were uploaded separately,
# e.g. /kaggle/input/shopee-multimodal-weights/embedding_extractor_mix.pth
CANDIDATE_WEIGHT_PATHS = [
    '/kaggle/input/shopee-octob-inference/embedding_extractor_mix.pth',  # <- example
    '/kaggle/working/embedding_extractor_mix.pth',  # when training ran in this same notebook
    '/kaggle/input/shopee-octob-inference/embedding_extractor.pth',
    '/kaggle/working/embedding_extractor.pth',
]

IMAGE_SIZE = 224  # must match the value used for training
BATCH_IMG = 64
BATCH_TXT = 256
KQ = 100  # number of neighbours retrieved per item before pruning
K_CAP = 50  # maximum group size (competition rule)

# ------------- Data -------------
# Test metadata table read from TEST_CSV; the code below reads its
# 'image', 'posting_id' and 'title' columns.
test = pd.read_csv(TEST_CSV)

class ShopeeTestDataset(Dataset):
    """Map-style dataset yielding one test listing per index.

    Each item is a dict with:
      * 'image'      - the RGB image loaded from ``img_root/<row['image']>``,
                       passed through ``transform`` when one is given;
      * 'posting_id' - the row's ``posting_id`` value;
      * 'title'      - the row's title as ``str`` (empty string when missing).
    """

    def __init__(self, df, img_root, transform):
        # Positional indexing is used in __getitem__, so drop the original index.
        self.df = df.reset_index(drop=True)
        self.img_root = img_root
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        picture = Image.open(os.path.join(self.img_root, record['image'])).convert('RGB')
        if self.transform is not None:
            picture = self.transform(picture)

        raw_title = record['title']
        title = '' if pd.isna(raw_title) else str(raw_title)

        return {
            'image': picture,
            'posting_id': record['posting_id'],
            'title': title,
        }

# Deterministic test-time preprocessing: resize to a square of IMAGE_SIZE,
# convert to a tensor, then normalise per channel.
tfm_test = v2.Compose([
    v2.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
    v2.ToTensor(),
    v2.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

test_ds = ShopeeTestDataset(df=test, img_root=TEST_IMG_DIR, transform=tfm_test)

# shuffle=False keeps the batches in the same order as the rows of `test`.
test_dl = DataLoader(
    dataset=test_ds,
    batch_size=BATCH_IMG,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)

# ------------- Models (must match training) -------------
# Image branch: timm backbone (default eca_nfnet_l1) + Linear(->emb_dim) + BatchNorm
def build_image_branch(backbone_name='eca_nfnet_l1', emb_dim=512):
    """Create the image backbone and its embedding head, both on `device` in eval mode.

    The backbone is built without pretrained weights and without a classifier
    (``num_classes=0``); weights are expected to be loaded by the caller.

    Returns:
        (backbone, head, feat_dim) where ``feat_dim`` is ``backbone.num_features``,
        the input width of the head's bias-free linear layer.
    """
    backbone = timm.create_model(backbone_name, pretrained=False, num_classes=0)
    backbone = backbone.to(device).eval()

    feat_dim = backbone.num_features
    projection = nn.Linear(feat_dim, emb_dim, bias=False)
    head = nn.Sequential(projection, nn.BatchNorm1d(emb_dim))
    head = head.to(device).eval()

    return backbone, head, feat_dim

# Text branch: xlm-roberta-base + mean-pooling + Linear(->512)+BN
#TEXT_MODEL_NAME = 'xlm-roberta-base'
#text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME, use_fast=True)
#from huggingface_hub import snapshot_download
#LOCAL_XLMR_DIR = snapshot_download(repo_id="xlm-roberta-base")
#TEXT_MODEL_NAME='xlm-roberta-base'
import os
from transformers import AutoTokenizer, XLMRobertaConfig, XLMRobertaModel
from torch import nn

# Directories searched, in order, for the tokenizer/config saved during training.
LOCAL_XLMR_DIRS = [
    "/kaggle/input/shopee-token/xlmr_tokenizer",  # <-- point this at your own dataset/path
    "/kaggle/input/xlmr_tokenizer",
    "/kaggle/working/xlmr_tokenizer",
]

def try_load_local_xlmr(sd_txt, emb_dim):
    """Build the text branch from files stored in one of LOCAL_XLMR_DIRS.

    A directory is considered only when it contains both ``tokenizer.json`` and
    ``config.json``. The XLM-R model is instantiated from the config alone and
    then filled from ``sd_txt['txt_backbone']``; the Linear+BatchNorm head is
    filled from ``sd_txt['txt_embed_head']``. Both are placed on `device` in
    eval mode.

    Returns:
        (tokenizer, txt_model, txt_head) from the first directory that loads
        successfully, or (None, None, None) when no directory qualifies or every
        attempt fails (each failure is printed and the next directory is tried).
    """
    for local_dir in LOCAL_XLMR_DIRS:
        required_files = (
            os.path.join(local_dir, "tokenizer.json"),
            os.path.join(local_dir, "config.json"),
        )
        if not all(os.path.exists(path) for path in required_files):
            continue

        try:
            tokenizer = AutoTokenizer.from_pretrained(local_dir, use_fast=True, local_files_only=True)
            cfg = XLMRobertaConfig.from_pretrained(local_dir, local_files_only=True)

            txt_model = XLMRobertaModel(cfg).to(device).eval()
            txt_model.load_state_dict(sd_txt['txt_backbone'])

            projection = nn.Linear(cfg.hidden_size, emb_dim, bias=False)
            txt_head = nn.Sequential(projection, nn.BatchNorm1d(emb_dim)).to(device).eval()
            txt_head.load_state_dict(sd_txt['txt_embed_head'])

            print(f"[TEXT] XLM-R loaded offline from {local_dir}")
            return tokenizer, txt_model, txt_head
        except Exception as err:
            # Best-effort: report the problem and fall through to the next directory.
            print(f"[TEXT] Failed loading XLM-R from {local_dir}: {err}")

    return None, None, None

def load_first_existing(paths):
    """Return the first entry of `paths` that exists on disk.

    Raises:
        FileNotFoundError: when none of the given paths exists.
    """
    found = next((candidate for candidate in paths if os.path.exists(candidate)), None)
    if found is not None:
        return found
    raise FileNotFoundError(
        "Weights not found. Provide a valid path in CANDIDATE_WEIGHT_PATHS "
        "or attach a dataset with embedding_extractor_mix.pth / embedding_extractor.pth"
    )

weights_path = load_first_existing(CANDIDATE_WEIGHT_PATHS)
print("Using weights:", weights_path)

# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
pkg = torch.load(weights_path, map_location='cpu')

# Contents of the weights package: per-branch state dicts plus fusion hyper-parameters
# (defaults are used when a key is absent).
sd = pkg["state_dict"]
have_text_weights = all(key in sd for key in ("txt_backbone", "txt_embed_head"))
emb_dim = int(pkg.get("emb_dim", 512))
alpha = float(pkg.get("fusion_alpha", 0.70))
tau = float(pkg.get("fusion_tau", 0.50))

# The text branch stays (None, None, None) unless the package carries text weights
# and the tokenizer/config are found locally.
text_tokenizer = txt_model = txt_head = None
if have_text_weights:
    text_tokenizer, txt_model, txt_head = try_load_local_xlmr(sd, emb_dim)




@torch.no_grad()
def mean_pooling(last_hidden_state, attention_mask):
    """Average token vectors along dim 1, counting only positions where the mask is set.

    The mask gets a trailing axis so it broadcasts over the hidden dimension.
    The divisor is clamped to at least 1e-6, so a row whose mask is all zeros
    yields zeros instead of dividing by zero.
    """
    weights = attention_mask.unsqueeze(-1).float()
    token_total = torch.sum(last_hidden_state * weights, dim=1)
    token_count = torch.clamp(weights.sum(dim=1), min=1e-6)
    return token_total / token_count



# ------------- Load Weights -------------
# `pkg` was already read from `weights_path` above.

# Detect package type: mix (image + text) or image-only.
have_text = False
if 'state_dict' not in pkg:
    raise ValueError("Invalid weights package: missing 'state_dict'")
sd = pkg['state_dict']
have_text = all(key in sd for key in ('txt_backbone', 'txt_embed_head'))

alpha = float(pkg.get('fusion_alpha', 0.70))
tau = float(pkg.get('fusion_tau', 0.50))
print(f"Fusion params -> alpha={alpha:.2f}, tau={tau:.2f}")
emb_dim = int(pkg.get('emb_dim', 512))

# Build the image branch and load its weights. The backbone name is taken from
# 'backbone_name_img', then 'backbone_name', then falls back to 'eca_nfnet_l1'.
img_backbone, img_head, _ = build_image_branch(
    backbone_name=str(
        pkg.get('backbone_name_img', pkg.get('backbone_name', 'eca_nfnet_l1'))
    ),
    emb_dim=emb_dim,
)
img_backbone.load_state_dict(sd['img_backbone'])
img_head.load_state_dict(sd['img_embed_head'])
img_backbone.eval()
img_head.eval()

# The text branch is constructed earlier via try_load_local_xlmr, not here.


# ------------- Embedding builders -------------
@torch.no_grad()
def build_image_embeddings(loader):
    """Embed every batch of `loader` with the image backbone and head.

    Each batch must provide 'image' and 'posting_id' entries. Embeddings are
    L2-normalised on the model device, gathered on the CPU, re-normalised in
    float32 and finally cast to float16.

    Returns:
        (embeddings, posting_ids): a float16 array with one row per item and the
        posting ids in the order the loader produced them.
    """
    chunks, posting_ids = [], []
    for batch in tqdm(loader, desc="Embed/img"):
        images = batch['image'].to(device, non_blocking=True)
        feats = img_head(img_backbone(images))  # backbone features -> embedding
        feats = nn.functional.normalize(feats, dim=1)
        chunks.append(feats.cpu())
        posting_ids.extend(batch['posting_id'])

    stacked = torch.cat(chunks, dim=0).numpy().astype('float32')
    row_norms = np.linalg.norm(stacked, axis=1, keepdims=True) + 1e-8
    stacked /= row_norms
    return stacked.astype('float16'), posting_ids

@torch.no_grad()
def build_text_embeddings(df, batch_size=BATCH_TXT, max_len=64):
    """Embed the 'title' column of `df` with the text branch.

    Missing titles become empty strings. Titles are tokenised in slices of
    `batch_size` (padded, truncated to `max_len`), encoded by `txt_model`,
    mean-pooled over the attention mask, projected by `txt_head` and
    L2-normalised. The gathered result is re-normalised in float32 and
    returned as a float16 array with one row per title.
    """
    titles = df['title'].fillna('').astype(str).tolist()

    chunks = []
    for start in tqdm(range(0, len(titles), batch_size), desc="Embed/txt"):
        encoded = text_tokenizer(
            titles[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=max_len,
            return_tensors='pt',
        )
        encoded = {name: tensor.to(device, non_blocking=True) for name, tensor in encoded.items()}
        hidden = txt_model(**encoded).last_hidden_state
        pooled = mean_pooling(hidden, encoded['attention_mask'])
        projected = nn.functional.normalize(txt_head(pooled), dim=1)
        chunks.append(projected.cpu())

    stacked = torch.cat(chunks, dim=0).numpy().astype('float32')
    row_norms = np.linalg.norm(stacked, axis=1, keepdims=True) + 1e-8
    stacked /= row_norms
    return stacked.astype('float16')

# ------------- Retrieval utils -------------
def topk_chunked_cos(embs_f16: np.ndarray, K: int, qbs: int = 128):
    """
    Memory-bounded top-K search by dot product, processed `qbs` query rows at a time.

    The dot product equals cosine similarity when the rows of `embs_f16` are
    L2-normalised. The whole matrix is moved (as float32) to CUDA when
    available, otherwise it stays on the CPU. K is clipped to the number of rows.

    Returns (sims, idxs), each of shape (N, K): float32 similarities sorted in
    descending order and the matching int32 row indices.
    """
    n_rows, _ = embs_f16.shape
    on_cuda = torch.cuda.is_available()
    device_t = torch.device('cuda' if on_cuda else 'cpu')
    bank = torch.from_numpy(embs_f16.astype('float32', copy=False)).to(device_t, non_blocking=True)
    K = min(K, n_rows)

    sim_chunks, idx_chunks = [], []
    for lo in tqdm(range(0, n_rows, qbs), desc="TopK (torch-chunk)"):
        scores = bank[lo:lo + qbs] @ bank.T  # (chunk, N) similarity matrix
        top_vals, top_ids = torch.topk(scores, k=K, dim=1, largest=True, sorted=True)
        idx_chunks.append(top_ids.cpu().numpy().astype('int32'))
        sim_chunks.append(top_vals.cpu().numpy().astype('float32'))
        # Drop the chunk-sized tensors before the next iteration allocates new ones.
        del scores, top_vals, top_ids
        if on_cuda:
            torch.cuda.empty_cache()

    idxs = np.vstack(idx_chunks)
    sims = np.vstack(sim_chunks)
    del bank
    return sims, idxs

def build_preds_fused_mutual(ids, idxs_img, sims_img, idxs_txt, sims_txt,
                             alpha=0.7, tau=0.50, K_cap=50):
    """
    Build match groups from image (and optionally text) neighbour lists.

    For item ``i`` the candidate set is the union of its image and text
    neighbours. Each candidate ``j`` gets a fused score

        s = alpha * (sim_img + 1) / 2 + (1 - alpha) * (sim_txt + 1) / 2

    where a similarity missing from a neighbour list counts as -1 (mapped to 0).
    ``j`` is kept when ``s >= tau`` and ``i`` is also in ``j``'s candidate set
    (mutual check). Kept candidates are ordered by descending score, truncated
    to ``K_cap`` and the item itself is added if absent.

    Args:
        ids: sequence of item identifiers; position ``i`` corresponds to row
            ``i`` of the neighbour arrays.
        idxs_img, sims_img: per-item neighbour indices and similarities (image).
        idxs_txt, sims_txt: same for text; pass ``None`` for either one to run
            image-only. In that mode the text term is 0, so the score is
            ``alpha * (sim_img + 1) / 2`` and ``tau`` applies to that value.
        alpha: weight of the image term.
        tau: minimum fused score for a candidate to be kept.
        K_cap: maximum number of ids per group.

    Returns:
        dict mapping each id to the set of matched ids.
    """
    n_items = len(ids)
    has_text = idxs_txt is not None and sims_txt is not None

    # Candidate sets hold plain Python ints so membership tests and indexing
    # do not depend on the index dtype of the inputs.
    cand_sets = []
    for i in range(n_items):
        cands = {int(j) for j in idxs_img[i]}
        if has_text:
            cands.update(int(j) for j in idxs_txt[i])
        cand_sets.append(cands)

    out = {}
    for i in range(n_items):
        # Maps are built per row (not all up front) to keep peak memory low.
        map_img = {int(j): float(s) for j, s in zip(idxs_img[i], sims_img[i])}
        map_txt = {}
        if has_text:
            map_txt = {int(j): float(s) for j, s in zip(idxs_txt[i], sims_txt[i])}

        fused = []
        for j in cand_sets[i]:
            si = (map_img.get(j, -1.0) + 1.0) / 2.0
            st = (map_txt.get(j, -1.0) + 1.0) / 2.0 if has_text else 0.0
            s = alpha * si + (1.0 - alpha) * st
            # mutual check: i must also be a candidate of j
            if s >= tau and i in cand_sets[j]:
                fused.append((j, s))

        # Ties are broken by index so truncation is reproducible across runs.
        fused.sort(key=lambda item: (-item[1], item[0]))
        keep = [ids[j] for j, _ in fused[:K_cap]]
        if ids[i] not in keep:
            keep = [ids[i]] + keep
        out[ids[i]] = set(keep[:K_cap])
    return out

# ------------- Run inference -------------
# 1) image embeddings
img_embs, test_ids = build_image_embeddings(test_dl)

# 2) text embeddings. `have_text` only says the checkpoint contains text weights;
#    the tokenizer/model/head can still be None when no local XLM-R files were
#    found, so check them explicitly instead of crashing on a None tokenizer.
text_branch_ready = have_text and all(
    part is not None for part in (text_tokenizer, txt_model, txt_head)
)
if text_branch_ready:
    txt_embs = build_text_embeddings(test, batch_size=BATCH_TXT, max_len=64)
else:
    if have_text:
        print("[TEXT] Text weights found but the text branch was not loaded; "
              "continuing with image embeddings only")
    txt_embs = None

# 3) top-K candidates
sims_img, idxs_img = topk_chunked_cos(img_embs, K=KQ, qbs=128)
if txt_embs is not None:
    sims_txt, idxs_txt = topk_chunked_cos(txt_embs, K=KQ, qbs=256)
else:
    sims_txt, idxs_txt = None, None

# 4) fused mutual
preds_test = build_preds_fused_mutual(test_ids, idxs_img, sims_img, idxs_txt, sims_txt,
                                      alpha=alpha, tau=tau, K_cap=K_CAP)

# 5) submission. Each group is a set, so sort it to make the written
#    'matches' strings identical from run to run.
matches_col = [' '.join(sorted(preds_test[pid])) for pid in test_ids]
sub = pd.DataFrame({'posting_id': test_ids, 'matches': matches_col})
out_path = '/kaggle/working/submission.csv'
sub.to_csv(out_path, index=False)
print("Saved:", out_path)

# sanity check: average predicted group size
avg_group_size = float(np.mean([len(v) for v in preds_test.values()]))
print(f"Avg predicted group size: {avg_group_size:.2f}")

# ------------- Cleanup -------------
del img_embs, txt_embs, sims_img, idxs_img, sims_txt, idxs_txt
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()




Using weights: /kaggle/input/shopee-octob-inference/embedding_extractor_mix.pth
[TEXT] XLM-R loaded offline from /kaggle/input/shopee-token/xlmr_tokenizer
Fusion params -> alpha=0.90, tau=0.74


Embed/img: 100%|██████████| 1/1 [00:01<00:00,  1.24s/it]
Embed/txt: 100%|██████████| 1/1 [00:00<00:00,  2.80it/s]
TopK (torch-chunk): 100%|██████████| 1/1 [00:00<00:00, 11.86it/s]
TopK (torch-chunk): 100%|██████████| 1/1 [00:00<00:00, 1376.54it/s]


Saved: /kaggle/working/submission.csv
Avg predicted group size: 1.00
