In [None]:
!pip install faiss-gpu-cu12 xformers torch torchvision transformers  sentence_transformers -U

In [None]:
"""
Builds a dataset of triplets (q, pos, [neg_1 ... neg_k]), k = CONFIG["hard_neg_count"],
for splits where `positive` is a single string.
"""

import random

import datasets
import faiss
import numpy as np
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig
from transformers import AutoTokenizer
from datasets import Dataset, DatasetDict
from huggingface_hub import HfApi, create_repo



# ==========================
# Config
# ==========================
# Central settings table read by every function in this script.
CONFIG = {
    # "search" selects the search_query / search_document prompts;
    # any other value makes prompt_name() return "clustering".
    "prefix_type":          "search",
    # Window of nearest-neighbour ranks (0-based, end inclusive) that
    # create_triplets() draws negative candidates from.
    "hard_neg_rank_start":  10,
    "hard_neg_rank_end":    50,
    # Negatives required per triplet; queries with fewer candidates are skipped.
    "hard_neg_count":       5,
    # Mining stops once this many triplets have been collected.
    "max_triplets":         100_000,
    # A progress line is printed every `log_every` examples.
    "log_every":            500,
    # Token limit that main() passes to create_triplets(); longer queries
    # and documents are skipped.
    "max_tokens":           1024,
    # Default number of triplets sampled by evaluate_hard_negative_quality().
    "sample_size":          5_000,
    # Default margin: a negative is kept only if
    # sim(q, neg) < sim(q, pos) - margin_tau.
    "margin_tau":           0.01,

    # Source dataset passed to datasets.load_dataset() and the columns/split used.
    "dataset_path":         "zloelias/lenta-ru",
    "query_column":         "title",
    "positive_column":      "text",        # <-- a single string
    "train_split":          "train",

    # Hub credentials and destination used by push_dataset().
    # NOTE(review): "<token>" looks like a placeholder -- avoid committing a real token here.
    "hf_token":             "<token>",
    "target_repo":          "Alexator26/lenta-ru-triplets"
}

# ==========================
# Model & tokenizer
# ==========================
# Both objects are loaded once at module level from the same checkpoint:
# `tokenizer` is used by count_tokens() for length filtering,
# `model` produces every embedding via model.encode().
tokenizer = AutoTokenizer.from_pretrained("deepvk/USER2-base")
model     = SentenceTransformer("deepvk/USER2-base")


# ========= helpers =========
def count_tokens(text: str) -> int:
    """Return how many token ids the module-level tokenizer produces for *text*.

    Truncation is switched off, so the length of the whole text is measured.
    """
    token_ids = tokenizer.encode(text, truncation=False)
    return len(token_ids)



def prompt_name(is_query: bool) -> str:
    """Pick the prompt name passed to ``model.encode`` for a query or a document.

    With ``CONFIG["prefix_type"] == "search"`` queries get "search_query" and
    documents get "search_document"; any other prefix type yields "clustering"
    regardless of *is_query*.

    Note from the original author: "search" was always used during mining.
    """
    if CONFIG["prefix_type"] != "search":
        return "clustering"
    if is_query:
        return "search_query"
    return "search_document"


def create_triplets(
        split,
        q_field: str,
        p_field: str,
        *,
        max_tokens : int   = 500,
        margin_tau : float = CONFIG['margin_tau']         # keep negatives with pos_sim - neg_sim > tau
    ):
    """
    Mine (query, positive, [negatives...]) triplets with margin-filtered hard negatives.

    Every usable query is embedded and its nearest documents are looked up in a
    FAISS inner-product index built over the unique positive documents.
    Negatives are drawn at random from a configured window of neighbour ranks,
    keeping only documents that satisfy

        sim(q, neg) < sim(q, pos) - margin_tau.

    Settings read from CONFIG:
        * hard_neg_rank_start / hard_neg_rank_end - window of FAISS ranks
          (0-based, end inclusive) that candidates are taken from,
        * hard_neg_count - negatives per triplet; a query with fewer
          admissible candidates is skipped,
        * max_triplets   - mining stops once this many triplets are collected
          (before the final de-duplication),
        * log_every      - a progress line is printed every N examples.

    Args:
        split: iterable of dict-like rows supporting ``.get`` and ``[]``.
        q_field: name of the field holding the query text.
        p_field: name of the field holding the positive document (one string).
        max_tokens: queries and documents with more tokens than this
            (as measured by ``count_tokens``) are skipped.
        margin_tau: minimum similarity gap between positive and negative.

    Returns:
        A list of dicts with keys "query", "positive" and "negatives"
        (a list of document strings), de-duplicated. Empty when no usable
        example or document remains.
    """
    # Loop-invariant settings, read once.
    rank_start   = CONFIG["hard_neg_rank_start"]
    rank_end     = CONFIG["hard_neg_rank_end"]
    neg_count    = CONFIG["hard_neg_count"]
    max_triplets = CONFIG["max_triplets"]
    log_every    = CONFIG["log_every"]

    # ------------------------------------------------------------------ #
    # 0. Filtering and shuffling
    # ------------------------------------------------------------------ #
    examples = [ex for ex in split if ex.get(q_field) and ex.get(p_field)]
    random.shuffle(examples)
    if not examples:
        return []

    # ------------------------------------------------------------------ #
    # 1. Collect unique positive documents
    # ------------------------------------------------------------------ #
    unique_docs, doc2idx, ex2doc_idx = [], {}, {}
    for ex_id, ex in enumerate(tqdm(examples, desc="Collect uniques")):
        doc = ex[p_field]
        if count_tokens(doc) > max_tokens:
            continue

        # each distinct document is indexed only once
        if doc not in doc2idx:
            doc2idx[doc] = len(unique_docs)
            unique_docs.append(doc)

        ex2doc_idx[ex_id] = doc2idx[doc]

    if not unique_docs:
        return []

    # ------------------------------------------------------------------ #
    # 2. FAISS index over the documents
    # ------------------------------------------------------------------ #
    doc_embs = model.encode(
        unique_docs,
        prompt_name       = prompt_name(is_query=False),
        convert_to_numpy  = True,
        show_progress_bar = True
    ).astype("float32")

    faiss.normalize_L2(doc_embs)                    # unit vectors: inner product == cosine
    index = faiss.IndexFlatIP(doc_embs.shape[1])
    index.add(doc_embs)

    # Never request more neighbours than there are documents: FAISS pads the
    # missing slots with id -1, which would otherwise be used as a valid
    # (negative) Python index and silently select the last document.
    k = min(rank_end + 1, len(unique_docs))

    # ------------------------------------------------------------------ #
    # 3. Mining with margin-filtered negatives
    # ------------------------------------------------------------------ #
    triplets = []

    for ex_id, ex in enumerate(tqdm(examples, desc="Mining triplets")):
        if len(triplets) >= max_triplets:
            break
        if ex_id and ex_id % log_every == 0:
            print(f"   {ex_id}/{len(examples)} ‚Üí {len(triplets)} triplets")

        # Cheap lookup first: examples whose document was filtered out in
        # step 1 have no positive and are skipped without tokenizing the query.
        pos_idx = ex2doc_idx.get(ex_id)
        if pos_idx is None:
            continue

        q_text = ex[q_field]
        if count_tokens(q_text) > max_tokens:
            continue

        pos_text = unique_docs[pos_idx]

        # ---------- query embedding -------------------------------------
        q_emb = model.encode(
            q_text,
            prompt_name       = prompt_name(is_query=True),
            convert_to_numpy  = True,
            show_progress_bar = False
        ).astype("float32")

        # FAISS normalisation and search operate on 2-D arrays
        q_emb_2d = q_emb.reshape(1, -1)
        faiss.normalize_L2(q_emb_2d)               # in-place
        q_emb    = q_emb_2d[0]                     # back to 1-D

        # ---------- similarity with the positive ------------------------
        pos_sim = float(np.dot(q_emb, doc_embs[pos_idx]))

        # ---------- neighbour search ------------------------------------
        _, neigh = index.search(q_emb_2d, k)

        # ---------- candidate selection ---------------------------------
        candidates = []
        for i in neigh[0][rank_start : rank_end + 1]:
            if i < 0:                              # defensive: FAISS "no result" padding
                continue
            if i == pos_idx:                       # never use the positive as a negative
                continue

            sim_i = float(np.dot(q_emb, doc_embs[i]))
            if sim_i < pos_sim - margin_tau:       # margin condition
                candidates.append(i)

        random.shuffle(candidates)
        neg_idxs = candidates[:neg_count]

        # a triplet is emitted only with the full number of negatives
        if len(neg_idxs) < neg_count:
            continue

        # ---------- store the triplet -----------------------------------
        triplets.append({
            "query"    : q_text,
            "positive" : pos_text,
            "negatives": [unique_docs[i]
                          for i in neg_idxs]
        })

    # ------------------------------------------------------------------ #
    # 4. Final de-duplication
    # ------------------------------------------------------------------ #
    unique_triplets = {
        (t["query"], t["positive"], tuple(t["negatives"])): t
        for t in triplets
    }

    return list(unique_triplets.values())

# =========  quality check: are negatives really hard?  =========
def evaluate_hard_negative_quality(triplets, sample_size=CONFIG['sample_size']):
    """
    Probe a random sample of mined triplets and report similarity statistics.

    The sample is re-embedded with the module-level model and checked for
    (a) positives being closer to the query than every negative, and
    (b) negatives still being close enough to count as "hard".

    Args:
        triplets: sequence of dicts with keys "query", "positive" and
            "negatives" (a list of strings).
        sample_size: maximum number of triplets to sample.

    Returns:
        A dict with the keys "sampled_triplets", "mean_pos_sim",
        "mean_neg_sim", "mean_margin_pos_vs_hardest" and
        "triplets_with_harder_negative", suitable for a dataset card or
        for printing.

    Raises:
        ValueError: if the sample is empty (no triplets, or sample_size == 0).
    """
    # 1) ------------------  sampling  ------------------
    sample = random.sample(triplets, k=min(sample_size, len(triplets)))
    if not sample:
        # Fail early with a clear message instead of feeding empty batches
        # into the encoder and the numpy math below.
        raise ValueError("Cannot evaluate hard-negative quality on an empty sample")

    q_texts   = [t["query"]    for t in sample]
    pos_texts = [t["positive"] for t in sample]
    # negatives are flattened; the loop below walks them back per triplet
    neg_texts = [n for t in sample for n in t["negatives"]]

    # 2) ------------------  embeddings  ----------------
    q_embs   = model.encode(q_texts,
                            prompt_name=prompt_name(True),
                            convert_to_numpy=True, show_progress_bar=True)

    pos_embs = model.encode(pos_texts,
                            prompt_name=prompt_name(False),
                            convert_to_numpy=True, show_progress_bar=True)

    neg_embs = model.encode(neg_texts,
                            prompt_name=prompt_name(False),
                            convert_to_numpy=True, show_progress_bar=True)

    # cosine normalisation (in place)
    q_embs   /= np.linalg.norm(q_embs,   axis=1, keepdims=True)
    pos_embs /= np.linalg.norm(pos_embs, axis=1, keepdims=True)
    neg_embs /= np.linalg.norm(neg_embs, axis=1, keepdims=True)

    # 3) ------------------  similarity math -------------
    pos_sims = (q_embs * pos_embs).sum(axis=1)

    neg_sims, margins = [], []
    idx, bad_cnt = 0, 0                                      # "bad" == negative >= positive
    # total is the real sample length, which may be smaller than sample_size
    for i, t in tqdm(enumerate(sample), desc='evaluate_hard_negative_quality', total=len(sample)):
        n = len(t["negatives"])
        sims = q_embs[i] @ neg_embs[idx: idx + n].T

        hardest = sims.max()
        neg_sims.extend(sims.tolist())

        if hardest >= pos_sims[i]:
            bad_cnt += 1
        margins.append(pos_sims[i] - hardest)

        idx += n

    stats = {
        "sampled_triplets":              len(sample),
        "mean_pos_sim":                 float(np.mean(pos_sims)),
        "mean_neg_sim":                 float(np.mean(neg_sims)),
        "mean_margin_pos_vs_hardest":   float(np.mean(margins)),
        "triplets_with_harder_negative": bad_cnt
    }
    return stats


def push_dataset(dsdict, readme_fragment):
    """
    Publish the dataset and its README to the Hub repo named in CONFIG.

    Creates the private dataset repository CONFIG["target_repo"] if needed,
    pushes *dsdict* to it, writes *readme_fragment* to a local ``README.md``
    and uploads that file as the repository's README.

    Args:
        dsdict: object exposing ``push_to_hub`` (main() passes a DatasetDict).
        readme_fragment: Markdown text to store as README.md.
    """
    api = HfApi()

    # exist_ok=True keeps re-runs from failing once the repo already exists.
    create_repo(
        repo_id=CONFIG["target_repo"],
        repo_type="dataset",
        private=True,
        token=CONFIG["hf_token"],
        exist_ok=True,
    )

    dsdict.push_to_hub(
        repo_id=CONFIG["target_repo"],
        token=CONFIG["hf_token"],
        private=True,
    )

    # Explicit UTF-8: the card contains non-ASCII characters and the platform
    # default encoding is not guaranteed to handle them.
    with open("README.md", "w", encoding="utf-8") as f:
        f.write(readme_fragment)

    # NOTE(review): this upload presumably replaces any README that
    # push_to_hub generated for the repo -- confirm that is intended.
    api.upload_file(
        path_or_fileobj="README.md",
        path_in_repo="README.md",
        repo_id=CONFIG["target_repo"],
        repo_type="dataset",
        token=CONFIG["hf_token"]
    )


# ========= run & push =========
def main():
    """
    Run the whole pipeline: load, mine, evaluate, publish.

    Loads the dataset named in CONFIG, mines triplets from the configured
    split, wraps them in a DatasetDict, computes the hard-negative quality
    statistics and pushes dataset plus README to the Hub.

    Raises:
        RuntimeError: if mining produced no triplets.
    """
    import textwrap  # local import: only needed to un-indent the README text

    raw_ds = datasets.load_dataset(CONFIG["dataset_path"])
    split  = raw_ds[CONFIG["train_split"]]

    triplets = create_triplets(
        split,
        CONFIG["query_column"],
        CONFIG["positive_column"],
        max_tokens=CONFIG["max_tokens"]
    )

    if not triplets:
        raise RuntimeError("Triplets list empty!")

    train_ds = Dataset.from_list(triplets)
    dsdict   = DatasetDict({"train": train_ds})
    print(dsdict)

    # ================== call the checker right before push ==================
    stats = evaluate_hard_negative_quality(triplets)

    # The literal is indented to match the function body; dedent strips that
    # common indent so the Markdown is not rendered as an indented code block
    # (which would break the heading and the table).
    readme_fragment = textwrap.dedent(f"""
    ### üîé Hard‚Äënegative sanity check
    Randomly inspected {stats['sampled_triplets']:,} triplets with `deepvk/USER2-base`.

    | metric | value |
    | --- | --- |
    | mean cos‚Äësim(query, **positive**) | **{stats['mean_pos_sim']:.4f}** |
    | mean cos‚Äësim(query, negatives)   | {stats['mean_neg_sim']:.4f} |
    | mean margin = pos ‚Äì hardest_neg  | {stats['mean_margin_pos_vs_hardest']:.4f} |
    | bad cases (neg ‚â• pos)            | {stats['triplets_with_harder_negative']}/{stats['sampled_triplets']} |

    Lower margin ‚áí harder negatives.
    Ideally the last line should be 0.
    """)

    print(readme_fragment)
    push_dataset(dsdict, readme_fragment)



# Entry point: run the pipeline only when this module is executed as the
# main program, not when it is imported.
if __name__ == "__main__":
    main()
