In [1]:
!pip install -qU sentence-transformers

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m494.1/494.1 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h

In [2]:
# ============================================================
# V2.1: EMBEDDING-ONLY INTENT MODEL (ANN HARD-NEGATIVE SAFE)
# ============================================================

from __future__ import annotations

import json
import os
import random
import logging
import sys
import time
from pathlib import Path
from collections import defaultdict
from typing import Dict, List, Iterable, Tuple

import torch
from torch.utils.data import DataLoader

from sentence_transformers import (
    SentenceTransformer,
    InputExample,
    losses,
    util
)
from sentence_transformers.losses import TripletLoss

# ============================================================
# LOGGING (HARD FLUSH + STEP IDS)
# ============================================================

logging.getLogger().handlers.clear()
logger = logging.getLogger("intent-embedder-v2.1")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter("%(asctime)s | %(levelname)s | %(message)s"))
logger.addHandler(handler)
logger.propagate = False

_STEP = 0

def log_step(msg: str):
    global _STEP
    _STEP += 1
    logger.info(f"[STEP {_STEP:02d}] {msg}")
    sys.stdout.flush()
    sys.stderr.flush()

def log_kv(**kwargs):
    logger.info(" | ".join(f"{k}={v}" for k, v in kwargs.items()))
    sys.stdout.flush()
    sys.stderr.flush()

def heartbeat(tag: str):
    logger.info(f"[HEARTBEAT] {tag} | t={time.time():.2f}")
    sys.stdout.flush()

# ============================================================
# ENV / TORCH HARDENING (NO LOGIC CHANGE)
# ============================================================

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"

torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# ============================================================
# CONFIG
# ============================================================

BASE_MODEL = "paraphrase-multilingual-mpnet-base-v2"
DATA_JSON_PATH = Path("/kaggle/input/training-data-gold-silver-final/training_data_gold_silver_FINAL.json")
OUTPUT_DIR = Path("/kaggle/working/food_intent_embedder_v2_1")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MAX_SEQ_LEN = 128
BATCH_SIZE = 64

EPOCHS_GOLD = 3
EPOCHS_SILVER = 1
EPOCHS_TRIPLET = 1

LR_BASE = 2e-5
LR_SILVER = 1e-5
LR_TRIPLET = 1e-5

MAX_GOLD_PAIRS_PER_INTENT = 300
MAX_GOLD_SILVER_PAIRS_PER_INTENT = 150
MAX_TRIPLETS_PER_INTENT = 200
ANN_K = 10

random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

log_step("Torch + environment initialized")
log_kv(
    torch_version=torch.__version__,
    cuda_available=torch.cuda.is_available(),
    device=DEVICE,
)

# ============================================================
# LOAD DATA
# ============================================================

log_step("Loading dataset")

with open(DATA_JSON_PATH, "r", encoding="utf-8") as f:
    raw_data: Dict[str, Dict[str, List[str]]] = json.load(f)

gold_texts: Dict[str, List[str]] = {}
silver_texts: Dict[str, List[str]] = {}

for intent, tiers in raw_data.items():
    g = list(set(tiers.get("gold", [])))
    s = list(set(tiers.get("silver", [])))
    if g:
        gold_texts[intent] = g
    if s:
        silver_texts[intent] = s

assert set(silver_texts.keys()).issubset(gold_texts.keys())

intents = sorted(gold_texts.keys())
log_kv(
    intents=len(intents),
    gold_items=sum(len(v) for v in gold_texts.values()),
    silver_items=sum(len(v) for v in silver_texts.values()),
)

# ============================================================
# MODEL INIT
# ============================================================

log_step(f"Initializing model on {DEVICE}")
heartbeat("before_model_init")

model = SentenceTransformer(BASE_MODEL, device=DEVICE)
model.max_seq_length = MAX_SEQ_LEN

heartbeat("after_model_init")

# ============================================================
# DATA GENERATORS
# ============================================================

def gold_pair_generator() -> Iterable[InputExample]:
    for intent, texts in gold_texts.items():
        if len(texts) < 2:
            continue
        n = min(len(texts) * 2, MAX_GOLD_PAIRS_PER_INTENT)
        for _ in range(n):
            a, b = random.sample(texts, 2)
            yield InputExample(texts=[a, b])

def gold_silver_pair_generator() -> Iterable[InputExample]:
    for intent, g_texts in gold_texts.items():
        s_texts = silver_texts.get(intent)
        if not s_texts:
            continue
        n = min(len(g_texts), len(s_texts), MAX_GOLD_SILVER_PAIRS_PER_INTENT)
        for _ in range(n):
            yield InputExample(
                texts=[random.choice(g_texts), random.choice(s_texts)]
            )

# ============================================================
# STAGE 1: GOLD METRIC LEARNING
# ============================================================

log_step("Stage 1: Gold metric learning")

gold_pairs = list(gold_pair_generator())
log_kv(gold_pairs=len(gold_pairs))

gold_loader = DataLoader(
    gold_pairs,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

mnrl_loss = losses.MultipleNegativesRankingLoss(model)

heartbeat("before_stage1_fit")

model.fit(
    train_objectives=[(gold_loader, mnrl_loss)],
    epochs=EPOCHS_GOLD,
    optimizer_params={"lr": LR_BASE},
    show_progress_bar=True,
)

heartbeat("after_stage1_fit")

# ============================================================
# STAGE 2: GOLD–SILVER SMOOTHING
# ============================================================

if silver_texts:
    log_step("Stage 2: Gold–Silver smoothing")

    gs_pairs = list(gold_silver_pair_generator())
    log_kv(gs_pairs=len(gs_pairs))

    if gs_pairs:
        gs_loader = DataLoader(
            gs_pairs,
            batch_size=BATCH_SIZE,
            shuffle=True,
            drop_last=True,
            num_workers=0,
        )

        heartbeat("before_stage2_fit")

        model.fit(
            train_objectives=[(gs_loader, mnrl_loss)],
            epochs=EPOCHS_SILVER,
            optimizer_params={"lr": LR_SILVER},
            show_progress_bar=True,
        )

        heartbeat("after_stage2_fit")

# ============================================================
# STAGE 3: ANN HARD-NEGATIVE MINING
# ============================================================

log_step("Stage 3: ANN hard-negative mining")

all_texts: List[str] = []
all_labels: List[str] = []

for intent, samples in gold_texts.items():
    for s in samples:
        all_texts.append(s)
        all_labels.append(intent)

log_kv(total_texts=len(all_texts))

heartbeat("before_encode_all")

with torch.no_grad():
    embeddings = model.encode(
        all_texts,
        convert_to_tensor=True,
        normalize_embeddings=True,
        batch_size=128,
    )

heartbeat("after_encode_all")
log_kv(embedding_shape=tuple(embeddings.shape))

heartbeat("before_ann_search")

search_results = util.semantic_search(
    embeddings,
    embeddings,
    top_k=min(len(all_texts), ANN_K + 1),
    score_function=util.dot_score,
)

heartbeat("after_ann_search")
log_kv(ann_rows=len(search_results))

triplets: List[InputExample] = []
per_intent_count = defaultdict(int)

for i, neighbors in enumerate(search_results):
    anchor_text = all_texts[i]
    anchor_intent = all_labels[i]

    positive = None
    negative = None

    for hit in neighbors[1:]:
        j = hit["corpus_id"]
        if all_labels[j] == anchor_intent and positive is None:
            positive = all_texts[j]
        elif all_labels[j] != anchor_intent and negative is None:
            negative = all_texts[j]
        if positive and negative:
            break

    if positive and negative and per_intent_count[anchor_intent] < MAX_TRIPLETS_PER_INTENT:
        triplets.append(InputExample(texts=[anchor_text, positive, negative]))
        per_intent_count[anchor_intent] += 1

log_kv(triplets=len(triplets))

if triplets:
    batch_size = min(32, len(triplets))
    triplet_loader = DataLoader(
        triplets,
        batch_size=batch_size,
        shuffle=True,
        drop_last=(len(triplets) >= batch_size),
        num_workers=0,
    )

    try:
        triplet_loss = TripletLoss(
            model=model,
            distance_metric="cosine",
            margin=0.2,
        )
        log_step("TripletLoss: cosine + margin")
    except Exception as e:
        log_kv(triplet_loss_fallback=str(e))
        cosine_dist = lambda x, y: 1 - torch.nn.functional.cosine_similarity(x, y)
        triplet_loss = TripletLoss(
            model=model,
            distance_metric=cosine_dist,
        )

    heartbeat("before_stage3_fit")

    model.fit(
        train_objectives=[(triplet_loader, triplet_loss)],
        epochs=EPOCHS_TRIPLET,
        optimizer_params={"lr": LR_TRIPLET},
        show_progress_bar=True,
    )

    heartbeat("after_stage3_fit")

# ============================================================
# CENTROID BUILD
# ============================================================

log_step("Building final gold centroids")

centroid_texts = []
centroid_labels = []

for intent, texts in gold_texts.items():
    centroid_texts.extend(texts)
    centroid_labels.extend([intent] * len(texts))

heartbeat("before_centroid_encode")

with torch.no_grad():
    emb = model.encode(
        centroid_texts,
        convert_to_tensor=True,
        normalize_embeddings=True,
        batch_size=128,
    )

heartbeat("after_centroid_encode")

centroids: Dict[str, torch.Tensor] = {}
offset = 0
for intent, texts in gold_texts.items():
    n = len(texts)
    c = emb[offset:offset + n].mean(dim=0)
    centroids[intent] = util.normalize_embeddings(c.unsqueeze(0))[0]
    offset += n

intent_list = sorted(centroids.keys())
centroid_matrix = torch.stack([centroids[i] for i in intent_list])

# ============================================================
# SAVE ARTIFACTS
# ============================================================

log_step("Saving artifacts")

model.save(str(OUTPUT_DIR / "encoder"))

torch.save(
    {
        "intent_list": intent_list,
        "centroids": centroid_matrix.cpu(),
    },
    OUTPUT_DIR / "centroids.pt",
)

with open(OUTPUT_DIR / "meta.json", "w") as f:
    json.dump(
        {
            "base_model": BASE_MODEL,
            "intents": intent_list,
            "training": {
                "gold_epochs": EPOCHS_GOLD,
                "silver_epochs": EPOCHS_SILVER,
                "triplet_epochs": EPOCHS_TRIPLET,
                "losses": [
                    "MultipleNegativesRankingLoss",
                    "TripletLoss (ANN hard negatives)",
                ],
            },
        },
        f,
        indent=2,
    )

log_step("Training complete")

# ============================================================
# INFERENCE (OPEN WORLD)
# ============================================================

@torch.no_grad()
def predict(
    text: str,
    min_similarity: float = 0.45,
    min_margin: float = 0.05,
) -> Tuple[str, float, float]:

    emb = model.encode(
        [text],
        convert_to_tensor=True,
        normalize_embeddings=True,
    )

    sims = util.dot_score(emb, centroid_matrix)[0]
    top2 = torch.topk(sims, k=2)

    best_id = top2.indices[0].item()
    best_sim = top2.values[0].item()
    margin = (top2.values[0] - top2.values[1]).item()

    if best_sim < min_similarity or margin < min_margin:
        return "UNKNOWN", best_sim, margin

    return intent_list[best_id], best_sim, margin


2026-02-02 11:05:44.195141: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1770030344.395753      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770030344.451642      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770030344.941018      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770030344.941057      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770030344.941060      55 computation_placer.cc:177] computation placer alr

2026-02-02 11:06:00,694 | INFO | [STEP 01] Torch + environment initialized
2026-02-02 11:06:00,695 | INFO | torch_version=2.8.0+cu126 | cuda_available=True | device=cuda
2026-02-02 11:06:00,697 | INFO | [STEP 02] Loading dataset
2026-02-02 11:06:01,005 | INFO | intents=12 | gold_items=50930 | silver_items=33039
2026-02-02 11:06:01,007 | INFO | [STEP 03] Initializing model on cuda
2026-02-02 11:06:01,008 | INFO | [HEARTBEAT] before_model_init | t=1770030361.01


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/723 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/402 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

2026-02-02 11:06:08,145 | INFO | [HEARTBEAT] after_model_init | t=1770030368.14
2026-02-02 11:06:08,146 | INFO | [STEP 04] Stage 1: Gold metric learning
2026-02-02 11:06:08,156 | INFO | gold_pairs=3512
2026-02-02 11:06:08,158 | INFO | [HEARTBEAT] before_stage1_fit | t=1770030368.16


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss


2026-02-02 11:10:06,831 | INFO | [HEARTBEAT] after_stage1_fit | t=1770030606.83
2026-02-02 11:10:06,833 | INFO | [STEP 05] Stage 2: Gold–Silver smoothing
2026-02-02 11:10:06,836 | INFO | gs_pairs=1368
2026-02-02 11:10:06,838 | INFO | [HEARTBEAT] before_stage2_fit | t=1770030606.84


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss


2026-02-02 11:10:34,898 | INFO | [HEARTBEAT] after_stage2_fit | t=1770030634.90
2026-02-02 11:10:34,899 | INFO | [STEP 06] Stage 3: ANN hard-negative mining
2026-02-02 11:10:34,910 | INFO | total_texts=50930
2026-02-02 11:10:34,911 | INFO | [HEARTBEAT] before_encode_all | t=1770030634.91
2026-02-02 11:11:37,750 | INFO | [HEARTBEAT] after_encode_all | t=1770030697.75
2026-02-02 11:11:37,752 | INFO | embedding_shape=(50930, 768)
2026-02-02 11:11:37,753 | INFO | [HEARTBEAT] before_ann_search | t=1770030697.75
2026-02-02 11:11:39,362 | INFO | [HEARTBEAT] after_ann_search | t=1770030699.36
2026-02-02 11:11:39,363 | INFO | ann_rows=50930
2026-02-02 11:11:39,517 | INFO | triplets=2266
2026-02-02 11:11:39,519 | INFO | triplet_loss_fallback=TripletLoss.__init__() got an unexpected keyword argument 'margin'
2026-02-02 11:11:39,520 | INFO | [HEARTBEAT] before_stage3_fit | t=1770030699.52


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss


2026-02-02 11:12:53,225 | INFO | [HEARTBEAT] after_stage3_fit | t=1770030773.23
2026-02-02 11:12:53,227 | INFO | [STEP 07] Building final gold centroids
2026-02-02 11:12:53,229 | INFO | [HEARTBEAT] before_centroid_encode | t=1770030773.23
2026-02-02 11:13:56,103 | INFO | [HEARTBEAT] after_centroid_encode | t=1770030836.10
2026-02-02 11:13:56,107 | INFO | [STEP 08] Saving artifacts
2026-02-02 11:13:58,522 | INFO | [STEP 09] Training complete
