In [None]:
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import pickle
import json
import numpy as np
from numpy.linalg import norm
import pandas as pd
from joblib import delayed, Parallel
import warnings
from sklearn.metrics.pairwise import cosine_similarity
import os
# NOTE: this silences *all* warnings for the rest of the notebook, including NumPy
# RuntimeWarnings such as 0/0 divisions inside cosin_sim.
warnings.filterwarnings('ignore')
np.random.seed(1337)

# Configure the Hugging Face tokenizer *before* the model (and its tokenizer) is created,
# so tokenizer parallelism is disabled from the start; worker processes are started later
# in this notebook by joblib's Parallel.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
model = SentenceTransformer('e5_large/', device='cuda:1')  # or 'e5-large-v2'

In [None]:
def cosin_sim(a, b):
    """Cosine similarity of `a` and `b`, computed along their last axis.

    The inputs broadcast against each other. A single vector `a` of shape (d,)
    compared with a matrix `b` of shape (n, d) gives n similarities. Two vectors
    give a scalar. A zero-norm input yields nan (0/0) instead of raising.
    """
    numerator = np.sum(a * b, axis=-1)
    denominator = norm(a, axis=-1) * norm(b, axis=-1)
    return numerator / denominator

In [None]:
# Load the generated questions and the passage chunks. Files are closed deterministically.
# NOTE: pickle.load can execute arbitrary code; only load pickles produced by this project.
with open('data/generated_data/q_res_single.pkl', 'rb') as f:
    queries = pickle.load(f)
with open('data/generated_data/passages.json', 'r', encoding='utf-8') as f:
    passages = json.load(f)

In [None]:
# Keep only the raw text of each passage, in the same order as `passages`.
passages_texts = []
for passage in passages:
    passages_texts.append(passage['page_content'])

In [None]:
# Pair every generated question with the passage it was generated for.
# `queries` and `passages_texts` are zipped by position; an empty/None entry
# in `queries` contributes no pairs.
positive_dataset = [
    (question, passage_text)
    for questions, passage_text in zip(queries, passages_texts)
    if questions
    for question in questions
]

len(positive_dataset)

In [None]:
### Filter out the pairs that the LLM rejected
# e_res[i] is used as a truthy "keep" flag for positive_dataset[i]: the two sequences
# are matched purely by position, so refuse to proceed if their lengths differ.
# This also catches re-running this cell on an already-filtered positive_dataset.
with open('data/generated_data/e_res_single.pkl', 'rb') as f:
    e_res = pickle.load(f)

assert len(e_res) == len(positive_dataset), (
    f'e_res has {len(e_res)} flags but positive_dataset has {len(positive_dataset)} pairs; '
    'they must be aligned one-to-one (was this cell already run?)'
)
positive_dataset = [pair for pair, keep in zip(positive_dataset, e_res) if keep]
len(positive_dataset)

In [None]:
### Embed the questions and the passage chunks
# Questions are deduplicated: if one question appears in several pairs, the last pair wins.
# NOTE(review): E5 checkpoints are usually fed 'query: ' / 'passage: ' prefixed inputs; none
# are added here, and the thresholds below were presumably tuned on these unprefixed
# embeddings. Confirm before changing either.
query_to_text = dict(positive_dataset)
queries = list(query_to_text)
queries_embs = model.encode(queries, show_progress_bar=True)
query_to_emb = dict(zip(queries, queries_embs))
doc_embs = model.encode(passages_texts, show_progress_bar=True)
text_to_emb = dict(zip(passages_texts, doc_embs))

In [None]:
### Filter out weak positives
# A question is dropped when the cosine similarity between its embedding and its paired
# passage's embedding is below `threshold`. query_to_text / query_to_emb are pruned in
# place, then queries / queries_embs are rebuilt from the survivors in the same order.
print('queries before clean:', len(queries))
threshold = 0.83
rejected = [
    q for q in queries
    if cosin_sim(query_to_emb[q], text_to_emb[query_to_text[q]]) < threshold
]
for q in rejected:
    del query_to_text[q]
    del query_to_emb[q]
queries = list(query_to_text)
queries_embs = list(map(query_to_emb.__getitem__, queries))
print('queries after clean:', len(query_to_text))

In [None]:
# Question x passage similarity matrix: row i follows `queries[i]`, column j follows
# `passages_texts[j]` (the order doc_embs was encoded in).
similarities = cosine_similarity(X=queries_embs, Y=doc_embs)

### Negatives mining

In [None]:
# Similarity margin between a negative and the positive. Not referenced by the visible code
# (the filter inside get_potentials is commented out).
MARGINE = 3e-2

In [None]:
def get_potentials(query_emb, doc_emb, margin=None, all_doc_embs=None):
    """Return the indices of all potential negative passages for a question.

    Args:
        query_emb: embedding of the question.
        doc_emb: embedding of the question's gold (positive) passage.
        margin: when None (the default, matching the previous behaviour in which the
            similarity filter was commented out), every passage index is returned,
            the gold passage included. Otherwise only passages satisfying
            ``sim(query, passage) + margin < sim(query, gold)`` are kept,
            e.g. ``margin=MARGINE``.
        all_doc_embs: passage embedding matrix to search; defaults to the
            notebook-global ``doc_embs``.

    Returns:
        list[int]: row indices into ``all_doc_embs``.
    """
    if all_doc_embs is None:
        all_doc_embs = doc_embs
    if margin is None:
        # No similarity filter, so skip the (previously unused) similarity computation.
        return list(range(len(all_doc_embs)))
    all_similarities = cosin_sim(query_emb, all_doc_embs)
    positive_similarity = cosin_sim(query_emb, doc_emb)
    keep = all_similarities + margin < positive_similarity
    return np.flatnonzero(keep).tolist()


In [None]:
# One candidate-index list per question (same order as `queries`), computed in 8 workers.
potentials = Parallel(n_jobs=8)(
    delayed(get_potentials)(query_to_emb[q], text_to_emb[query_to_text[q]])
    for q in tqdm(queries)
)

In [None]:
# Look up each question's candidate negatives by question text.
query_to_potentials = dict(zip(queries, potentials))

In [None]:
### Collect soft negatives: random candidate passages other than the question's own positive
num_negatives = 1

# Every index holding a given passage text (passages may repeat), so the positive, and any
# exact duplicate of it, can be excluded from the negatives.
text_to_indices = {}
for idx, text in enumerate(passages_texts):
    text_to_indices.setdefault(text, []).append(idx)

soft_negatives = []
for query in tqdm(queries):
    positive_text = query_to_text[query]
    candidates = np.asarray(query_to_potentials[query], dtype=int)
    candidates = candidates[~np.isin(candidates, text_to_indices.get(positive_text, []))]
    if candidates.size == 0:
        continue  # nothing to sample from for this question
    # Without replacement, so num_negatives > 1 never repeats the same negative.
    negative_indexes = np.random.choice(
        candidates, min(num_negatives, candidates.size), replace=False
    )
    for index in negative_indexes:
        soft_negatives.append((query, positive_text, passages_texts[index]))

In [None]:
### Collect hard negatives: candidates highly similar to the question, excluding its positive
num_negatives = 1
hard_negatives_threshold = 0.75

# Every index holding a given passage text (passages may repeat). Needed because each kept
# positive has similarity >= the positive-filter threshold, which is above
# hard_negatives_threshold, so the positive itself would otherwise always be a candidate.
text_to_indices = {}
for idx, text in enumerate(passages_texts):
    text_to_indices.setdefault(text, []).append(idx)

hard_negatives = []
for i, query in enumerate(tqdm(queries)):
    positive_text = query_to_text[query]
    candidates = np.asarray(query_to_potentials[query], dtype=int)
    # `similarities` rows follow `queries`, columns follow `passages_texts`.
    candidates = candidates[similarities[i, candidates] > hard_negatives_threshold]
    candidates = candidates[~np.isin(candidates, text_to_indices.get(positive_text, []))]
    if candidates.size == 0:
        continue  # no sufficiently similar non-positive passage: skip this question
    # Without replacement, so num_negatives > 1 never repeats the same negative.
    negative_indexes = np.random.choice(
        candidates, min(num_negatives, candidates.size), replace=False
    )
    for index in negative_indexes:
        hard_negatives.append((query, positive_text, passages_texts[index]))

In [None]:
# Number of (question, positive, negative) triplets of each kind.
tuple(map(len, (soft_negatives, hard_negatives)))

In [None]:
# Persist the triplet lists; `with` guarantees each file is flushed and closed.
with open('data/soft_negatives_single_83_0_1_75.pkl', 'wb') as f:
    pickle.dump(soft_negatives, f)
with open('data/hard_negatives_single_83_0_1_75.pkl', 'wb') as f:
    pickle.dump(hard_negatives, f)