In [19]:
# !pip install rank_bm25


In [20]:
from sentence_transformers import SentenceTransformer
import torch
from tqdm import tqdm
import uuid
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import re
import json
from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize
import numpy as np

In [21]:
device = "cuda" if torch.cuda.is_available() else "cpu"
db_name = 'CASML - Generative AI Hackathon'
LLM_model_id = "Qwen/Qwen1.5-1.8B-Chat"
cross_encoder_id = "BAAI/bge-reranker-v2-m3"
embedding_model_id = "intfloat/multilingual-e5-large"

chunk_overlap = 300
chunk_size = 1000

In [22]:
# import pdfplumber
# texts = {}

# with pdfplumber.open("data/book.pdf") as pdf:
#     text = ""
#     for i, page in enumerate(pdf.pages):
#         if i < 18 or i > 642:
#             continue
#         text = page.extract_text()
#         texts[i - 18] = text

In [23]:
# with open("data/texts.json", "w", encoding="utf-8") as f:
#     json.dump(texts, f, ensure_ascii=False, indent=2)

In [24]:
with open("data/texts.json", "r", encoding="utf-8") as f:
    texts = json.load(f)

**ОЧИСТКА ТЕКСТА**

In [25]:
import re

def remove_headers_footers(text, header_patterns=None, footer_patterns=None):
    if header_patterns is None:
        header_patterns = [r'^.*Header.*$']
    if footer_patterns is None:
        footer_patterns = [r'^.*Footer.*$']

    for pattern in header_patterns + footer_patterns:
        text = re.sub(pattern, '', text, flags=re.MULTILINE)

    return text.strip()

def remove_special_characters(text, special_chars=None):
    if special_chars is None:
        special_chars = r'[^A-Za-z0-9\s\.,;:\'\"\?\!\-]'

    text = re.sub(special_chars, '', text)
    return text.strip()

def remove_repeated_substrings(text, pattern=r'\.{2,}'):
    text = re.sub(pattern, '.', text)
    return text.strip()

def remove_extra_spaces(text):
    text = re.sub(r'\n\s*\n', '\n\n', text)
    text = re.sub(r'\s+', ' ', text)

    return text.strip()

def preprocess_text(text, metadata=None, language=None):
    # optional: detect language, skip if not target
    # remove headers/footers
    text = remove_headers_footers(text)
    # fix unicode quirks
    # text = ftfy.fix_text(text)
    # remove HTML tags
    text = re.sub(r'<[^>]+>', ' ', text)
    # remove URLs/emails
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    text = re.sub(r'\S+@\S+', '', text)
    # normalize whitespace
    text = remove_extra_spaces(text)
    # remove special characters
    text = remove_special_characters(text)
    # lowercase (optional)
    text = text.lower()
    # lemmatize (optional) — using spaCy or any other
    # chunking can happen here or after
    return text.strip()

In [26]:
for page_num in texts:
    texts[page_num] = texts[page_num].replace("\n", " ")
    texts[page_num] = preprocess_text(texts[page_num])

**ЧАНКИНГ**

In [27]:
from langchain_text_splitters import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

In [None]:
#Базовый чанкинг

text_chunks = [{i: text_splitter.split_text(texts[str(i)])} for i in tqdm(range(len(texts)))]

text_chunks_numbered = []

for chunk_dict in text_chunks:
    key, values = next(iter(chunk_dict.items()))
    for chunk in values:
        text_chunks_numbered.append((key, chunk))

# Добавляем глобальный индекс
global_indexed = [(idx, key, chunk) for idx, (key, chunk) in enumerate(text_chunks_numbered)]

global_chunk_ids, page_numbers, text_chunks = zip(*global_indexed)

In [None]:
# # Не всегда лучше 

# from nltk.tokenize import sent_tokenize


# semantic_chunks = {}

# for i in tqdm(range(len(texts))):
#     text = texts[str(i)]
#     sentences = sent_tokenize(text)
#     joined_sentences = []
#     buffer = ""

#     # объединяем предложения в блоки <= chunk_size
#     for sent in sentences:
#         if len(buffer) + len(sent) + 1 <= chunk_size:
#             buffer += " " + sent
#         else:
#             joined_sentences.append(buffer.strip())
#             buffer = sent
#     if buffer:
#         joined_sentences.append(buffer.strip())

#     # если отдельный блок всё ещё длинный — применяем RecursiveCharacterTextSplitter
#     refined_chunks = []
#     for chunk in joined_sentences:
#         if len(chunk) > chunk_size:
#             refined_chunks.extend(text_splitter.split_text(chunk))
#         else:
#             refined_chunks.append(chunk)

#     semantic_chunks[i] = refined_chunks

# # выравниваем в общий список
# text_chunks_numbered = [(i, chunk) for i, chunks in semantic_chunks.items() for chunk in chunks]

# # глобальная нумерация чанков
# global_indexed = [(idx, key, chunk) for idx, (key, chunk) in enumerate(text_chunks_numbered)]
# global_chunk_ids, page_numbers, text_chunks = zip(*global_indexed)


100%|██████████| 625/625 [00:00<00:00, 2782.45it/s]


In [30]:
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer(embedding_model_id, model_kwargs={'dtype': torch.float16})

In [31]:
vectors = embedding_model.encode(text_chunks, batch_size=32, device=device, normalize_embeddings=True, show_progress_bar=True).tolist()

Batches:   0%|          | 0/76 [00:00<?, ?it/s]

In [32]:
vectors_np = np.asarray(vectors, dtype=np.float32)

In [33]:
from qdrant_client import QdrantClient, models

client = QdrantClient(":memory:")

client.create_collection(
    collection_name=db_name,
    on_disk_payload=True,
    vectors_config=models.VectorParams(
        size=1024,
        distance=models.Distance.COSINE,
        on_disk=True
    ),
)

True

In [34]:
for i in tqdm(range(len(vectors))):
    client.upsert(
        collection_name=db_name,
        points=[
            models.PointStruct(
                id=str(uuid.uuid4()),
                vector=vectors[i],
                payload={
                    'text': text_chunks[i],
                    'page': page_numbers[i] + 7,
                    'chunk_index': global_chunk_ids[i]
                }
            )
        ]
    )

100%|██████████| 2431/2431 [00:01<00:00, 1250.67it/s]


In [35]:
import json


with open('data/queries.json') as files:
    queries = json.load(files)

In [36]:
tokenizer = AutoTokenizer.from_pretrained(LLM_model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_model_id,
    dtype=torch.float16,
    device_map="auto"
)

In [37]:
def generation_pipeline(messages, max_new_tokens=512, do_sample=True, temperature=0.5, top_p=0.9) -> str:
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    gen_ids = outputs[0][inputs["input_ids"].shape[-1]:]
    answer = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
    return answer

In [38]:
rephrase_prompt = "You are a query rewriter specialized in mental health and psychology topics. Your task is to rewrite the user’s question into short, clear search queries for vector retrieval.\n\nSTRICT RULES:\n- Output EXACTLY THREE rewritten query variants.\n- Separate them each on a NEW LINE.\n- DO NOT explain, comment, justify, or describe anything.\n- DO NOT answer the question.\n- Output ONLY the rewritten queries."

In [39]:
def rephrase_query(query_text) -> list[str]:
    messages = [
        {"role": "system", "content": rephrase_prompt},
        {"role": "user", "content": f"Rephrase this question into exactly three short search queries separated by '\n'. Do not explain anything. Question: {query_text}"},
    ]
    output = generation_pipeline(messages, max_new_tokens=512, do_sample=True, temperature=0.3, top_p=0.9)
    parts = output.split("\n")
    result = [p.strip() for p in parts][:3]
    return result

In [40]:
# Генерируем перефразированные запросы

# for i in tqdm(range(len(queries))):
#     print(queries[i]['question'])
#     new_queries = rephrase_query(queries[i]['question'])
#     queries[i]['new_queries'] = new_queries

In [41]:
# with open("data/new_queries.json", "w", encoding="utf-8") as f:
#     json.dump(queries, f, ensure_ascii=False, indent=2)

In [42]:
# del model

# import gc
# gc.collect()

# torch.cuda.empty_cache()

In [43]:
with open('data/new_queries.json') as files:
    queries = json.load(files)

In [44]:
bm25_corpus = [word_tokenize(t.lower()) for t in text_chunks]
bm25 = BM25Okapi(bm25_corpus)

In [45]:
def get_candidates(queries, top_k_db=5, alpha_vec=0.9, alpha_bm25=0.1):
    results = {}

    for query in tqdm(queries):
        qid = query['query_id']
        query_texts = [query['question']] + query['new_queries']

        all_hits = []

        for qt in query_texts:
            query_vec = embedding_model.encode(qt, normalize_embeddings=True, device=device).tolist()
            query_vec_np = np.asarray(query_vec, dtype=np.float32)

            # --- Векторный поиск ---
            hits = client.search(collection_name=db_name, query_vector=query_vec, limit=top_k_db)

            # --- BM25 ---
            tokenized_query = word_tokenize(qt.lower())
            bm25_scores = bm25.get_scores(tokenized_query)
            bm25_scores = np.array(bm25_scores)
            bm25_scores = (bm25_scores - bm25_scores.min()) / (bm25_scores.max() - bm25_scores.min() + 1e-9)
            bm25_norm = np.power(bm25_scores, 0.5)

            # --- Собираем кандидатов ---
            seen_indices = set()
            for hit in hits:
                seen_indices.add(hit.payload['chunk_index'])
                all_hits.append({
                    "chunk_index": hit.payload['chunk_index'],
                    "text": hit.payload['text'],
                    "page": hit.payload['page'],
                    "score_vec": hit.score,
                    "score_bm25": bm25_norm[hit.payload['chunk_index']],
                    "source": "qdrant"
                })

            # Добавляем недостающих кандидатов из BM25-топа
            top_bm25_idx = np.argsort(bm25_scores)[-top_k_db:][::-1]
            for idx in top_bm25_idx:
                if idx not in seen_indices:
                    score_vec_direct = float(np.dot(query_vec_np, vectors_np[idx]))
                    score_vec_direct = min(max(score_vec_direct, 0.0), 1.0)
                    all_hits.append({
                        "chunk_index": idx,
                        "text": text_chunks[idx],
                        "page": page_numbers[idx] + 7,
                        "score_vec": score_vec_direct,
                        "score_bm25": bm25_norm[idx],
                        "source": "bm25"
                    })

        # --- Комбинируем по весам ---
        dedup = {}
        for item in all_hits:
            ci = item["chunk_index"]
            combined_score = alpha_vec * item["score_vec"] + alpha_bm25 * item["score_bm25"]
            if ci not in dedup or combined_score > dedup[ci]["score"]:
                dedup[ci] = {
                    "text": item["text"],
                    "page": item["page"],
                    "score": combined_score,
                    "chunk_index": ci,
                    "source": item["source"],
                    "score_vec": item["score_vec"],
                    "score_bm25": item["score_bm25"]
                }

        final_items = sorted(dedup.values(), key=lambda x: x["score"], reverse=True)[:top_k_db]

        if final_items:
            scores = [i["score"] for i in final_items]
            vec_scores = [i["score_vec"] for i in final_items]
            bm25_scores = [i["score_bm25"] for i in final_items]
            qdrant_count = sum(1 for i in final_items if i["source"] == "qdrant")
            bm25_count = sum(1 for i in final_items if i["source"] == "bm25")

            print(f"[{qid}] total={len(final_items)} | qdrant={qdrant_count}, bm25={bm25_count}")
            print(f"   combined  -> min={min(scores):.4f}, max={max(scores):.4f}")
            print(f"   qdrant(sc)-> min={min(vec_scores):.4f}, max={max(vec_scores):.4f}")
            print(f"   bm25(sc)  -> min={min(bm25_scores):.4f}, max={max(bm25_scores):.4f}")

        results[qid] = final_items

    return results


In [46]:
relevant = get_candidates(queries, top_k_db=100)

  hits = client.search(collection_name=db_name, query_vector=query_vec, limit=top_k_db)
  4%|▍         | 2/50 [00:00<00:09,  5.23it/s]

[1] total=100 | qdrant=78, bm25=22
   combined  -> min=0.7966, max=0.8509
   qdrant(sc)-> min=0.7857, max=0.8422
   bm25(sc)  -> min=0.6102, max=1.0000
[2] total=100 | qdrant=83, bm25=17
   combined  -> min=0.7860, max=0.8614
   qdrant(sc)-> min=0.7690, max=0.8460
   bm25(sc)  -> min=0.5801, max=1.0000


  8%|▊         | 4/50 [00:00<00:06,  7.29it/s]

[3] total=100 | qdrant=84, bm25=16
   combined  -> min=0.7842, max=0.8638
   qdrant(sc)-> min=0.7734, max=0.8579
   bm25(sc)  -> min=0.5785, max=1.0000
[4] total=100 | qdrant=84, bm25=16
   combined  -> min=0.7767, max=0.8769
   qdrant(sc)-> min=0.7738, max=0.8721
   bm25(sc)  -> min=0.3652, max=1.0000


 12%|█▏        | 6/50 [00:00<00:05,  8.05it/s]

[5] total=100 | qdrant=71, bm25=29
   combined  -> min=0.7877, max=0.8303
   qdrant(sc)-> min=0.7776, max=0.8252
   bm25(sc)  -> min=0.6193, max=1.0000
[6] total=100 | qdrant=82, bm25=18
   combined  -> min=0.8005, max=0.8566
   qdrant(sc)-> min=0.7959, max=0.8507
   bm25(sc)  -> min=0.6633, max=1.0000


 16%|█▌        | 8/50 [00:01<00:05,  8.20it/s]

[7] total=100 | qdrant=80, bm25=20
   combined  -> min=0.7966, max=0.8373
   qdrant(sc)-> min=0.7750, max=0.8366
   bm25(sc)  -> min=0.6756, max=1.0000
[8] total=100 | qdrant=84, bm25=16
   combined  -> min=0.7767, max=0.8780
   qdrant(sc)-> min=0.7756, max=0.8668
   bm25(sc)  -> min=0.5524, max=1.0000


 20%|██        | 10/50 [00:01<00:04,  8.41it/s]

[9] total=100 | qdrant=80, bm25=20
   combined  -> min=0.7972, max=0.8705
   qdrant(sc)-> min=0.7828, max=0.8716
   bm25(sc)  -> min=0.6329, max=1.0000
[10] total=100 | qdrant=83, bm25=17
   combined  -> min=0.7784, max=0.8809
   qdrant(sc)-> min=0.7697, max=0.8676
   bm25(sc)  -> min=0.5647, max=1.0000


 24%|██▍       | 12/50 [00:01<00:04,  8.49it/s]

[11] total=100 | qdrant=93, bm25=7
   combined  -> min=0.7948, max=0.8423
   qdrant(sc)-> min=0.7831, max=0.8380
   bm25(sc)  -> min=0.6683, max=1.0000
[12] total=100 | qdrant=89, bm25=11
   combined  -> min=0.7306, max=0.8461
   qdrant(sc)-> min=0.7314, max=0.8305
   bm25(sc)  -> min=0.4867, max=1.0000


 28%|██▊       | 14/50 [00:01<00:04,  8.62it/s]

[13] total=100 | qdrant=67, bm25=33
   combined  -> min=0.7778, max=0.8560
   qdrant(sc)-> min=0.7627, max=0.8571
   bm25(sc)  -> min=0.5843, max=1.0000
[14] total=100 | qdrant=82, bm25=18
   combined  -> min=0.7763, max=0.8364
   qdrant(sc)-> min=0.7755, max=0.8333
   bm25(sc)  -> min=0.4637, max=1.0000


 32%|███▏      | 16/50 [00:02<00:03,  8.83it/s]

[15] total=100 | qdrant=78, bm25=22
   combined  -> min=0.7892, max=0.8605
   qdrant(sc)-> min=0.7709, max=0.8536
   bm25(sc)  -> min=0.6481, max=1.0000
[16] total=100 | qdrant=82, bm25=18
   combined  -> min=0.7769, max=0.8725
   qdrant(sc)-> min=0.7767, max=0.8722
   bm25(sc)  -> min=0.3714, max=1.0000


 36%|███▌      | 18/50 [00:02<00:03,  8.54it/s]

[17] total=100 | qdrant=77, bm25=23
   combined  -> min=0.7670, max=0.8490
   qdrant(sc)-> min=0.7571, max=0.8528
   bm25(sc)  -> min=0.4844, max=1.0000
[18] total=100 | qdrant=87, bm25=13
   combined  -> min=0.7677, max=0.8924
   qdrant(sc)-> min=0.7663, max=0.8891
   bm25(sc)  -> min=0.4063, max=1.0000


 40%|████      | 20/50 [00:02<00:03,  8.70it/s]

[19] total=100 | qdrant=84, bm25=16
   combined  -> min=0.7658, max=0.8610
   qdrant(sc)-> min=0.7538, max=0.8610
   bm25(sc)  -> min=0.5151, max=1.0000
[20] total=100 | qdrant=81, bm25=19
   combined  -> min=0.7924, max=0.8767
   qdrant(sc)-> min=0.7807, max=0.8759
   bm25(sc)  -> min=0.6163, max=1.0000


 44%|████▍     | 22/50 [00:02<00:03,  8.65it/s]

[21] total=100 | qdrant=87, bm25=13
   combined  -> min=0.7870, max=0.8594
   qdrant(sc)-> min=0.7790, max=0.8511
   bm25(sc)  -> min=0.5849, max=1.0000
[22] total=100 | qdrant=75, bm25=25
   combined  -> min=0.7922, max=0.8516
   qdrant(sc)-> min=0.7792, max=0.8508
   bm25(sc)  -> min=0.5959, max=1.0000


 48%|████▊     | 24/50 [00:02<00:02,  8.84it/s]

[23] total=100 | qdrant=69, bm25=31
   combined  -> min=0.7695, max=0.8538
   qdrant(sc)-> min=0.7568, max=0.8456
   bm25(sc)  -> min=0.4558, max=1.0000
[24] total=100 | qdrant=77, bm25=23
   combined  -> min=0.7814, max=0.8551
   qdrant(sc)-> min=0.7746, max=0.8436
   bm25(sc)  -> min=0.4392, max=1.0000


 52%|█████▏    | 26/50 [00:03<00:02,  8.52it/s]

[25] total=100 | qdrant=72, bm25=28
   combined  -> min=0.7863, max=0.8655
   qdrant(sc)-> min=0.7902, max=0.8556
   bm25(sc)  -> min=0.4873, max=1.0000
[26] total=100 | qdrant=84, bm25=16
   combined  -> min=0.7852, max=0.8571
   qdrant(sc)-> min=0.7816, max=0.8640
   bm25(sc)  -> min=0.5049, max=1.0000


 56%|█████▌    | 28/50 [00:03<00:02,  8.67it/s]

[27] total=100 | qdrant=63, bm25=37
   combined  -> min=0.7623, max=0.8660
   qdrant(sc)-> min=0.7691, max=0.8512
   bm25(sc)  -> min=0.0000, max=1.0000
[28] total=100 | qdrant=81, bm25=19
   combined  -> min=0.7997, max=0.8604
   qdrant(sc)-> min=0.7814, max=0.8533
   bm25(sc)  -> min=0.5444, max=1.0000


 60%|██████    | 30/50 [00:03<00:02,  8.59it/s]

[29] total=100 | qdrant=88, bm25=12
   combined  -> min=0.7781, max=0.8650
   qdrant(sc)-> min=0.7764, max=0.8631
   bm25(sc)  -> min=0.5031, max=1.0000
[30] total=100 | qdrant=86, bm25=14
   combined  -> min=0.7953, max=0.8702
   qdrant(sc)-> min=0.7869, max=0.8618
   bm25(sc)  -> min=0.5168, max=1.0000


 64%|██████▍   | 32/50 [00:03<00:02,  8.40it/s]

[31] total=100 | qdrant=82, bm25=18
   combined  -> min=0.7850, max=0.8640
   qdrant(sc)-> min=0.7668, max=0.8604
   bm25(sc)  -> min=0.5793, max=1.0000
[32] total=100 | qdrant=85, bm25=15
   combined  -> min=0.7974, max=0.8463
   qdrant(sc)-> min=0.7822, max=0.8375
   bm25(sc)  -> min=0.6297, max=1.0000


 68%|██████▊   | 34/50 [00:04<00:01,  8.49it/s]

[33] total=100 | qdrant=82, bm25=18
   combined  -> min=0.7840, max=0.8696
   qdrant(sc)-> min=0.7798, max=0.8551
   bm25(sc)  -> min=0.5988, max=1.0000
[34] total=100 | qdrant=86, bm25=14
   combined  -> min=0.7766, max=0.8516
   qdrant(sc)-> min=0.7653, max=0.8351
   bm25(sc)  -> min=0.5680, max=1.0000


 72%|███████▏  | 36/50 [00:04<00:01,  8.45it/s]

[35] total=100 | qdrant=67, bm25=33
   combined  -> min=0.7819, max=0.8304
   qdrant(sc)-> min=0.7590, max=0.8327
   bm25(sc)  -> min=0.5987, max=1.0000
[36] total=100 | qdrant=78, bm25=22
   combined  -> min=0.7998, max=0.8413
   qdrant(sc)-> min=0.7873, max=0.8327
   bm25(sc)  -> min=0.6481, max=1.0000


 76%|███████▌  | 38/50 [00:04<00:01,  8.38it/s]

[37] total=100 | qdrant=87, bm25=13
   combined  -> min=0.8034, max=0.8557
   qdrant(sc)-> min=0.7905, max=0.8510
   bm25(sc)  -> min=0.5151, max=1.0000
[38] total=100 | qdrant=89, bm25=11
   combined  -> min=0.8041, max=0.8740
   qdrant(sc)-> min=0.7959, max=0.8654
   bm25(sc)  -> min=0.5360, max=1.0000


 80%|████████  | 40/50 [00:04<00:01,  8.31it/s]

[39] total=100 | qdrant=82, bm25=18
   combined  -> min=0.8007, max=0.8858
   qdrant(sc)-> min=0.7874, max=0.8753
   bm25(sc)  -> min=0.6242, max=1.0000
[40] total=100 | qdrant=75, bm25=25
   combined  -> min=0.7930, max=0.8561
   qdrant(sc)-> min=0.7799, max=0.8401
   bm25(sc)  -> min=0.5913, max=1.0000


 84%|████████▍ | 42/50 [00:05<00:00,  8.38it/s]

[41] total=100 | qdrant=70, bm25=30
   combined  -> min=0.8007, max=0.8672
   qdrant(sc)-> min=0.7808, max=0.8525
   bm25(sc)  -> min=0.4806, max=1.0000
[42] total=100 | qdrant=86, bm25=14
   combined  -> min=0.8008, max=0.8742
   qdrant(sc)-> min=0.7821, max=0.8651
   bm25(sc)  -> min=0.5988, max=1.0000


 88%|████████▊ | 44/50 [00:05<00:00,  8.40it/s]

[43] total=100 | qdrant=88, bm25=12
   combined  -> min=0.7741, max=0.8566
   qdrant(sc)-> min=0.7683, max=0.8408
   bm25(sc)  -> min=0.4954, max=1.0000
[44] total=100 | qdrant=79, bm25=21
   combined  -> min=0.7803, max=0.8558
   qdrant(sc)-> min=0.7686, max=0.8564
   bm25(sc)  -> min=0.5467, max=1.0000


 92%|█████████▏| 46/50 [00:05<00:00,  8.23it/s]

[45] total=100 | qdrant=79, bm25=21
   combined  -> min=0.7860, max=0.8238
   qdrant(sc)-> min=0.7778, max=0.8302
   bm25(sc)  -> min=0.5431, max=1.0000
[46] total=100 | qdrant=72, bm25=28
   combined  -> min=0.7824, max=0.8554
   qdrant(sc)-> min=0.7670, max=0.8483
   bm25(sc)  -> min=0.5411, max=1.0000


 96%|█████████▌| 48/50 [00:05<00:00,  8.36it/s]

[47] total=100 | qdrant=76, bm25=24
   combined  -> min=0.7997, max=0.8403
   qdrant(sc)-> min=0.7895, max=0.8277
   bm25(sc)  -> min=0.5898, max=1.0000
[48] total=100 | qdrant=88, bm25=12
   combined  -> min=0.7976, max=0.8506
   qdrant(sc)-> min=0.7861, max=0.8456
   bm25(sc)  -> min=0.6306, max=1.0000


100%|██████████| 50/50 [00:06<00:00,  8.27it/s]

[49] total=100 | qdrant=82, bm25=18
   combined  -> min=0.7995, max=0.8303
   qdrant(sc)-> min=0.7834, max=0.8401
   bm25(sc)  -> min=0.6795, max=1.0000
[50] total=100 | qdrant=88, bm25=12
   combined  -> min=0.7863, max=0.8791
   qdrant(sc)-> min=0.7829, max=0.8668
   bm25(sc)  -> min=0.4536, max=1.0000





In [47]:
def filter_by_db_score(relevant, score_threshold=0.0):
    filtered = {}

    for qid, items in relevant.items():
        selected = [item for item in items if item["score"] >= score_threshold]
        filtered[qid] = selected

    return filtered

# relevant = filter_by_db_score(relevant, score_threshold=0.8)


*Список моделей для Cross-Encoder*


-BAAI/bge-reranker-v2-m3

-cross‑encoder/ms‑marco‑MiniLM‑L6‑v2


In [48]:
from sentence_transformers import CrossEncoder

cross_encoder = CrossEncoder(cross_encoder_id, device=device)

In [49]:
# Реранк на основе только исходного вопроса

# def rerank_results(queries, relevant, top_k_rerank=3, rerank_threshold=None):
#     reranked = {}

#     for query in tqdm(queries):
#         qid = query['query_id']
#         q_text = query['question']

#         candidates = relevant[qid]
#         texts = [c["text"] for c in candidates]

#         if len(texts) == 0:
#             reranked[qid] = []
#             continue

#         pairs = [(q_text, t) for t in texts]
#         scores = cross_encoder.predict(pairs)  # по аналогии с твоим примером

#         scored = [
#             {**c, "rerank_score": s}
#             for c, s in zip(candidates, scores)
#         ]

#         scored = sorted(scored, key=lambda x: x["rerank_score"], reverse=True)
#         scored = scored[:top_k_rerank]

#         if rerank_threshold is not None:
#             scored = [s for s in scored if s["rerank_score"] >= rerank_threshold]

#         reranked[qid] = scored

#     return reranked


In [50]:
def rerank_results(queries, relevant, top_k_rerank=3, rerank_threshold=None):
    reranked = {}

    for query in tqdm(queries):
        qid = query["query_id"]
        query_texts = [query["question"]] + query.get("new_queries", [])
        candidates = relevant[qid]
        texts = [c["text"] for c in candidates]

        if not texts:
            reranked[qid] = []
            continue

        # Считаем средний скор по всем перефразам
        total_scores = np.zeros(len(texts))
        for q_text in query_texts:
            pairs = [(q_text, t) for t in texts]
            scores = cross_encoder.predict(pairs)
            total_scores += np.array(scores)

        avg_scores = total_scores / len(query_texts)

        scored = [
            {**c, "rerank_score": s}
            for c, s in zip(candidates, avg_scores)
        ]

        scored = sorted(scored, key=lambda x: x["rerank_score"], reverse=True)
        scored = scored[:top_k_rerank]

        if rerank_threshold is not None:
            scored = [s for s in scored if s["rerank_score"] >= rerank_threshold]

        reranked[qid] = scored

    return reranked


In [51]:
relevant_reranked = rerank_results(
    queries,
    relevant,
    top_k_rerank=6,         # регулируй
    rerank_threshold=None   # или поставь, например, 0.1
)

100%|██████████| 50/50 [02:50<00:00,  3.41s/it]


In [54]:
queries[0]

{'query_id': '1',
 'question': 'What is the scientific method in psychology?',
 'new_queries': ['What does the scientific method entail in psychology?',
  'How can I understand the scientific method used in psychology research?',
  'What are the key steps of the scientific method applied in psychological studies?']}

In [55]:
relevant_reranked["1"]

[{'text': 'thus, psychological science is empirical, based on measurable data. in general, science deals only with matter and energy, that is, those things that can be measured, and it cannot arrive at knowledge about values and morality. this is one reason why our scientific understanding of the mind is so limited, since thoughts, at least as we experience them, are neither matter nor energy. the scientific method is also a form of empiricism. anempirical methodfor acquiring knowledge is one based on observation, including experimentation, rather than a method based only on forms of logical argument or previous authorities. it was not until the late 1800s that psychology became accepted as its own academic discipline. before this time, the workings of the mind were considered under the auspices of philosophy. given that any behavior is, at its roots, biological, some areas of psychology take on aspects of a natural science like biology.',
  'page': 8,
  'score': 0.83656948781172,
  'c

In [59]:
# Очищаем память, чтобы LLM было вкусно

# del cross_encoder
# del embedding_model

import gc
gc.collect()


torch.cuda.empty_cache()

In [60]:
tokenizer = AutoTokenizer.from_pretrained(LLM_model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_model_id,
    dtype=torch.float16,
    device_map="auto"
)


In [None]:
# TODO почитать про промпты

system_prompt = (
    "You are an expert in psychology.\n\n"
    "Your task is to answer the user’s question using strictly and only the information "
    "from the retrieved context chunks provided below.\n\n"
    "Each chunk contains a label indicating which section of the original source it came from. "
    "When making each factual claim, explicitly reference the section label of the chunk it came from. "
    "If the answer cannot be fully supported by the provided chunks, respond: \"Not enough information to answer the question.\"\n\n"
    "Rules:\n"
    "1. Do not add external knowledge, assumptions, or general reasoning not grounded in the retrieved chunks.\n"
    "2. Do not merge or generalize facts unless they are directly stated in the chunks.\n"
    "3. Each factual statement must include a source label in parentheses, e.g. (section: ...).\n"
    "4. If multiple chunks contradict each other, do not resolve the contradiction; instead, state that there is a contradiction.\n\n"
    "Output format:\n"
    "- First, provide the final answer based solely on the evidence.\n"
    "- Then, list each used factual statement with its corresponding section reference.\n\n"
    "Now wait for the user’s question and the retrieved context chunks."
)

In [103]:
def llm_answer(query, context):

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "system", "content": f"Context documents:\n{context}"},
        {"role": "user", "content": query}
    ]

    answer = generation_pipeline(messages, temperature=0.7, max_new_tokens=1024)
    return answer


In [None]:
# import requests

# def llm_answer(query, context):
#     system_msg = (
#         "You are an expert in psychology. Using only the provided retrieved documents, answer the following question. Do not add any external knowledge."
#     )

#     messages = [
#         {"role": "system", "content": system_msg},
#         {"role": "system", "content": f"Context documents:\n{context}"},
#         {"role": "user", "content": query}
#     ]

#     prompt = tokenizer.apply_chat_template(
#         messages,
#         tokenize=False,
#         add_generation_prompt=True
#     )
#     prompt = "Привет, как дела?"
#     payload = {
#         "prompt": prompt,
#         "max_new_tokens": 512,
#         "do_sample": True,
#         "temperature": 0.9,
#         "top_p": 0.9
#     }

#     r = requests.post("https://shortly-pleasant-democrats-discussion.trycloudflare.com/generate", json=payload)
#     r.raise_for_status()

#     # Модель возвращает полный текст, нужно вырезать только продолжение
#     response_text = r.json()["response"]

#     print(response_text)
#     # Если нужно вырезать ввод с контекстом, делаем так:
#     if response_text.startswith(prompt):
#         response_text = response_text[len(prompt):].strip()

#     return response_text


In [None]:
# def reorder_chunks(chunks):
#     n = len(chunks)
#     if n <= 3:
#         return chunks  # просто как есть

#     first_three = chunks[:3]
#     # далее позиции считаем относительно chunks[3:]
#     odds = chunks[3::2]   # 4-й, 6-й, 8-й...
#     evens = chunks[4::2]  # 5-й, 7-й, 9-й...
#     evens = evens[::-1]   # развернуть

#     return first_three + odds + evens

# # пример формирования вывода
# def format_chunks(chunks):
#     return "\n".join(f"{i}. {c}" for i, c in enumerate(reorder_chunks(chunks), start=1))


In [104]:
from toc import toc

def find_section_for_page(toc, page):
    all_sections = []

    for chapter_data in toc.values():
        chapter_title = chapter_data["title"]
        for section_title, start in chapter_data["sections"].items():
            all_sections.append((chapter_title, section_title, start))

    # сортируем по старту секции
    all_sections.sort(key=lambda x: x[2])

    # ищем последнюю секцию, начавшуюся не позже страницы
    candidates = [(ch, sec, start) for ch, sec, start in all_sections if start <= page]
    if not candidates:
        return None

    ch, sec, _ = candidates[-1]
    return f"{ch}/{sec}"

In [105]:
def reorder_pairs(pairs):
    # pairs: список кортежей (chunk_text, page)
    n = len(pairs)
    if n <= 3:
        return pairs

    first_three = pairs[:3]
    odds = pairs[3::2]   # 4-й, 6-й, 8-й...
    evens = pairs[4::2]  # 5-й, 7-й, 9-й...
    evens = evens[::-1]

    return first_three + odds + evens

def format_pairs(pairs):
    lines = []
    for i, (chunk, page) in enumerate(pairs, start=1):
        sec = find_section_for_page(toc, page)
        if sec is None:
            sec = "Unknown"
        lines.append(f"{i}. {sec} - {chunk}")
    return "\n".join(lines)

In [106]:
outputs = []
result_pages = []

for query in tqdm(queries):
    qid = query['query_id']
    q_text = query['question']
    data = relevant_reranked[qid]

    # собираем пары (текст, страница)
    pairs = [(item["text"], item["page"]) for item in data]

    # переставляем
    reordered = reorder_pairs(pairs)

    # делаем контекст с секциями
    context = format_pairs(reordered)
    answer = llm_answer(q_text, context)

    outputs.append(answer)
    result_pages.append([p for _, p in reordered])


100%|██████████| 50/50 [10:46<00:00, 12.94s/it]


In [111]:
def merge_sequential(items, overlap: int) -> str:
    if not items:
        return ""
    items = sorted(items, key=lambda x: x["chunk_index"])

    merged_groups = []
    cur_text = items[0]["text"]
    prev_idx = items[0]["chunk_index"]

    for it in items[1:]:
        idx, t = it["chunk_index"], it["text"]
        if idx == prev_idx + 1:
            # учёт overlap: убираем дублирующийся префикс t
            max_k = min(overlap, len(cur_text), len(t))
            cut = 0
            for k in range(max_k, 0, -1):
                if cur_text[-k:] == t[:k]:
                    cut = k
                    break
            cur_text += t[cut:]
        else:
            merged_groups.append(cur_text)
            cur_text = t
        prev_idx = idx

    merged_groups.append(cur_text)
    return "\n".join(merged_groups)

In [112]:
ids = [q['query_id'] for q in queries]
references = [{"sections": [], "pages": list(set(pages))} for pages in result_pages]

for ref in references:
    for page in ref["pages"]:
        section = find_section_for_page(toc, page)
        if section:
            section = section.lower().replace(" ", "_")
            ref["sections"].append(section)

context_text = [
    merge_sequential(relevant_reranked[qid], chunk_overlap)
    for qid in ids
]

In [113]:
submission_df = pd.DataFrame({
    "ID": ids,
    "context": context_text,
    "answer": outputs,
    "references": references
})

In [114]:
submission_df.to_csv("data/submission10.csv", index=False)