In [42]:
import pandas as pd

from tfidf import TfidfResearch
from bm25 import BM25Research
from faiss_index import FaissResearch
from qdrant_index import QdrantResearch

from sqlalchemy import text
from utils import engine
from tqdm.autonotebook import tqdm

from eval import evaluate
import numpy as np
from utils import clean_text, STOP_RU

In [None]:
def load_queries_gt(csv_path: str) -> tuple[dict[str, list[str]], list[str]]:
    df = pd.read_csv(csv_path)
    q_gt: dict[str, list[str]] = {}
    for _, row in df.iterrows():
        q = row["query"]
        art_id  = str(row["article_id"]) 
        q_gt.setdefault(q, []).append(str(art_id))

    all_ids = df["article_id"].astype(str).unique().tolist()
    return q_gt, all_ids

def queries_by_diff(df: pd.DataFrame) -> dict[str, dict[str, list[str]]]:
    groups = {}
    for diff in ("easy", "medium", "hard"):
        sub = df[df["difficulty"] == diff]
        qgt = {q: [str(doc)] for q, doc in zip(sub["query"], sub["article_id"])}
        groups[diff] = qgt
    # общий набор
    groups["all"] = {q: [str(doc)] for q, doc in zip(df["query"], df["article_id"])}
    return groups

In [49]:
df_synth = load_queries_gt("synthetic_queries.csv")
queries_gt, article_ids = df_synth
q_groups = queries_by_diff(df_synth)
print(f"Total requests: {len(queries_gt)}, Documents in corpus: {len(article_ids)}")

TypeError: tuple indices must be integers or slices, not str

In [50]:
sql = text("""
    SELECT id::text,
           COALESCE(title, '')   AS title,
           COALESCE(anons, '')   AS anons,
           COALESCE(body, '')    AS body
      FROM public.tmp_news
     WHERE id = ANY(:ids);
""")

df_docs = pd.read_sql(sql, engine, params={"ids": article_ids})


df_docs["full_text"] = (
    df_docs["title"] + ". " +
    df_docs["anons"] + ". " +
    df_docs["body"]
)
texts = df_docs["full_text"].tolist()
ids   = df_docs["id"].tolist()

print("News example:", texts[0][:150], "…")

News example: Официальной позиции Русской православной церкви по вопросу о подлинности царских останков пока нет. Об этом заявил Патриарх Московский и всея Руси Кир …


Base Experiments

In [52]:
tfidf_backend = TfidfResearch(max_features=30_000)
tfidf_backend.index(texts, ids)

metrics_by_level = {
    lvl: evaluate(tfidf_backend, qgt, top_k=10)
    for lvl, qgt in q_groups.items()
}

metrics_by_level

{'easy': {'Precision@10': 0.09795918367346937,
  'Recall@10': 0.9795918367346939,
  'MRR': 0.9642857142857143},
 'medium': {'Precision@10': 0.09595959595959594,
  'Recall@10': 0.9595959595959596,
  'MRR': 0.9287878787878788},
 'hard': {'Precision@10': 0.09099999999999998,
  'Recall@10': 0.91,
  'MRR': 0.8089285714285714},
 'all': {'Precision@10': 0.09494949494949494,
  'Recall@10': 0.9494949494949495,
  'MRR': 0.9001443001443001}}

In [None]:
bm25_backend = BM25Research()
bm25_backend.index(texts, ids)

metrics_by_level = {
    lvl: evaluate(bm25_backend, qgt, top_k=10)
    for lvl, qgt in q_groups.items()
}

metrics_by_level

{'easy': {'Precision@10': 0.09795918367346937,
  'Recall@10': 0.9795918367346939,
  'MRR': 0.9693877551020408},
 'medium': {'Precision@10': 0.09595959595959594,
  'Recall@10': 0.9595959595959596,
  'MRR': 0.9301346801346803},
 'hard': {'Precision@10': 0.09099999999999998,
  'Recall@10': 0.91,
  'MRR': 0.8142619047619047},
 'all': {'Precision@10': 0.09494949494949494,
  'Recall@10': 0.9494949494949495,
  'MRR': 0.9040724707391373}}

In [54]:
faiss_frida = FaissResearch(model_name="sergeyzh/BERTA", embed_dim=768)
faiss_frida.index(texts, ids)

metrics_by_level = {
    lvl: evaluate(faiss_frida, qgt, top_k=10)
    for lvl, qgt in q_groups.items()
}

metrics_by_level

{'easy': {'Precision@10': 0.09897959183673469,
  'Recall@10': 0.9897959183673469,
  'MRR': 0.9413265306122449},
 'medium': {'Precision@10': 0.09898989898989898,
  'Recall@10': 0.98989898989899,
  'MRR': 0.957912457912458},
 'hard': {'Precision@10': 0.09899999999999999,
  'Recall@10': 0.99,
  'MRR': 0.9633333333333333},
 'all': {'Precision@10': 0.09898989898989899,
  'Recall@10': 0.98989898989899,
  'MRR': 0.9542648709315374}}

In [55]:
qdrant_backend = QdrantResearch(
    collection_name="news_research_synth",
    model_name="sergeyzh/BERTA",
    embed_dim=768,
)
qdrant_backend.index(texts, ids)

metrics_by_level = {
    lvl: evaluate(qdrant_backend, qgt, top_k=10)
    for lvl, qgt in q_groups.items()
}

metrics_by_level

{'easy': {'Precision@10': 0.09897959183673469,
  'Recall@10': 0.9897959183673469,
  'MRR': 0.9447278911564627},
 'medium': {'Precision@10': 0.09898989898989898,
  'Recall@10': 0.98989898989899,
  'MRR': 0.9562289562289563},
 'hard': {'Precision@10': 0.09899999999999999,
  'Recall@10': 0.99,
  'MRR': 0.9683333333333333},
 'all': {'Precision@10': 0.09898989898989899,
  'Recall@10': 0.98989898989899,
  'MRR': 0.9565095398428731}}

*Additional experiments*

Lemmatise

In [29]:
import inspect
if not hasattr(inspect, "getargspec"):
    inspect.getargspec = inspect.getfullargspec      # <-- добавляем

# теперь можно спокойно импортировать MorphAnalyzer
from pymorphy2 import MorphAnalyzer
import nltk, re
nltk.download("punkt", quiet=True)

morph = MorphAnalyzer()

def lemmatize_text(txt: str) -> str:
    tokens = [tok for tok in nltk.word_tokenize(txt.lower()) if tok.isalpha()]
    lemmas = [
        morph.parse(tok)[0].normal_form
        for tok in tokens
        if len(tok) > 2
    ]
    return " ".join(lemmas)

# Применяем лемматизацию ко всем текстам (может быть долго, идёт прогресс-баром)
texts_lemma = [lemmatize_text(t) for t in tqdm(texts, desc="pymorphy lemmatize")]
queries_gt_lemma: dict[str, list[str]] = {}

for q, ids_list in tqdm(queries_gt.items(), desc="lemmatize queries"):
    lem_q = lemmatize_text(q)                  # ← ваша функция из шага 5.1
    # объединяем id, если лемматизированный ключ уже встречался
    queries_gt_lemma.setdefault(lem_q, []).extend(ids_list)

# (необязательно) уберём дубликаты id в каждом списке
for k in queries_gt_lemma:
    queries_gt_lemma[k] = list(set(queries_gt_lemma[k]))

pymorphy lemmatize: 100%|██████████| 100/100 [00:01<00:00, 53.15it/s]
lemmatize queries: 100%|██████████| 297/297 [00:00<00:00, 1296.42it/s]


In [None]:
tfidf_lemma = TfidfResearch(max_features=30_000)
tfidf_lemma.index(texts_lemma, ids)
metrics_tfidf_lemma = evaluate(tfidf_lemma, queries_gt_lemma, top_k=10)
print("TF-IDF + Pymorphy:", metrics_tfidf_lemma)

TF-IDF + Natasha-лемматизация: {'Precision@10': 0.09797979797979796, 'Recall@10': 0.9696969696969697, 'MRR': 0.9414141414141415}


In [None]:
bm25_backend = BM25Research()
bm25_backend.index(texts_lemma, ids)
metrics_bm25_lemma = evaluate(bm25_backend, queries_gt_lemma, top_k=10)
print("BM-25 + Pymorphy:", metrics_bm25_lemma)

BM-25 + Natasha-лемматизация: {'Precision@10': 0.09797979797979796, 'Recall@10': 0.9696969696969697, 'MRR': 0.9486531986531986}


In [33]:
faiss_frida = FaissResearch(model_name="sergeyzh/BERTA", embed_dim=768)
faiss_frida.index(texts_lemma, ids)
metrics_bm25_lemma = evaluate(faiss_frida, queries_gt_lemma, top_k=10)
print("FAISS + Pymorphy:", metrics_bm25_lemma)

FAISS + Pymorphy: {'Precision@10': 0.09999999999999999, 'Recall@10': 0.98989898989899, 'MRR': 0.9562289562289562}


Spellchecker

In [35]:
from spellchecker import SpellChecker
spell = SpellChecker(language="ru")

def spell_correct(txt: str) -> str:
    toks  = txt.split()
    fixed = [spell.correction(t) or t for t in toks]
    return " ".join(fixed)


queries_gt_spell = {spell_correct(q): gts for q, gts in tqdm(queries_gt.items())}

bm25_spell = BM25Research()
bm25_spell.index(texts, ids)
metrics_bm25_spell = evaluate(bm25_spell, queries_gt_spell, top_k=10)
print("BM25 + spell-коррекция:", metrics_bm25_spell)


  0%|          | 0/297 [00:00<?, ?it/s]

100%|██████████| 297/297 [04:13<00:00,  1.17it/s]

BM25 + spell-коррекция: {'Precision@10': 0.09528619528619528, 'Recall@10': 0.9427609427609428, 'MRR': 0.8620490620490622}





Hybrid search

In [39]:
faiss_full = FaissResearch(model_name="sergeyzh/BERTA")
faiss_full.index(texts, ids)                 # ← долго, но один раз!
doc_embs      = faiss_full.embeddings        # ndarray shape (N, 768)
id2pos: dict  = {doc_id: i for i, doc_id in enumerate(ids)}


In [None]:
# 2. Быстрая гибрид-функция (без пересоздания индекса)
# ──────────────────────────────────────────────────────────────
def hybrid_search_fast(query: str, N: int = 100, top_k: int = 10):
    # 2.1 топ-N TF-IDF (как раньше)
    top_tfidf = tfidf_lemma.search(query, top_k=N)
    sub_idx   = np.array([id2pos[d] for d, _ in top_tfidf], dtype=np.int64)

    # 2.2 эмбеддинг запроса (одна API-вызов)
    q_vec = faiss_full._get_embeddings([query])      # shape (1, 768)

    # 2.3 косинус-скоры: q_vec (1×768) · doc_embs[sub_idx].T (768×N)
    scores = np.squeeze(q_vec @ doc_embs[sub_idx].T)

    # 2.4 сортируем
    order  = np.argsort(scores)[::-1][:top_k]
    result = [(ids[sub_idx[i]], float(scores[i])) for i in order]
    return result


class HybridFastBackend:
    def __init__(self, N=100, top_k=10):
        self.N = N
        self.top_k = top_k
    def index(self, *args, **kwargs):
        pass                               # ничего не нужно — всё уже посчитано
    def search(self, q, top_k=10):
        return hybrid_search_fast(q, N=self.N, top_k=top_k)

In [44]:

hybrid_fast = HybridFastBackend(N=100)
metrics_hybrid_fast = evaluate(hybrid_fast, queries_gt, top_k=10)
print("Hybrid-FAST:", metrics_hybrid_fast)

Hybrid-FAST: {'Precision@10': 0.09999999999999999, 'Recall@10': 0.98989898989899, 'MRR': 0.9598765432098765}


Differentiated evaluation

In [None]:
df_synth = pd.read_csv("synthetic_queries.csv")



q_groups = queries_by_diff(df_synth)

# ──────────────────────────────────────────────────────────────
# 2. Backend (любой). Возьмём оптимизированный гибрид
# ──────────────────────────────────────────────────────────────
hybrid_fast = HybridFastBackend(N=50)

# ──────────────────────────────────────────────────────────────
# 3. Считаем метрики по каждому уровню сложности
# ──────────────────────────────────────────────────────────────
metrics_by_level = {
    lvl: evaluate(hybrid_fast, qgt, top_k=10)
    for lvl, qgt in q_groups.items()
}

ValueError: Unknown format code 'f' for object of type 'str'

<pandas.io.formats.style.Styler at 0x744529cce990>

In [46]:
metrics_by_level

{'easy': {'Precision@10': 0.09081632653061222,
  'Recall@10': 0.9081632653061225,
  'MRR': 0.8928571428571429},
 'medium': {'Precision@10': 0.08989898989898988,
  'Recall@10': 0.898989898989899,
  'MRR': 0.8888888888888888},
 'hard': {'Precision@10': 0.086, 'Recall@10': 0.86, 'MRR': 0.85},
 'all': {'Precision@10': 0.08888888888888888,
  'Recall@10': 0.8888888888888888,
  'MRR': 0.877104377104377}}