In [1]:
from typing import List, Optional
import os
import faiss
import numpy as np
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch
from torch import Tensor
import torch.nn.functional as F
import ijson
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
PATH = "arxiv-metadata-s.json"
MODEL = "Qwen/Qwen3-Embedding-0.6B"

tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)

lens = []
with open(PATH, "r", encoding="utf-8") as f:
    for obj in ijson.items(f, "item"):
        title = (obj.get("title") or "").strip()
        abstract = (obj.get("abstract") or "").strip()
        text = (title + "\n" + abstract).strip()
        ids = tokenizer(text, add_special_tokens=True, truncation=False)["input_ids"]
        lens.append(len(ids))

arr = np.array(lens, dtype=np.int32)
print("count:", arr.size)
print("mean tokens:", float(arr.mean()))
for p in [50, 90, 95, 99, 99.5, 99.9]:
    print(f"p{p}:", float(np.percentile(arr, p)))
print("max:", int(arr.max()))


In [5]:
import pickle

class RAG:

    def __init__(
        self,
        embedder_name: str = "Qwen/Qwen3-Embedding-0.6B",
        reranker_name: str = "Qwen/Qwen3-Reranker-0.6B",
        chunk_size: int = 500,
        chunk_overlap: int = 125,
        device: Optional[str] = None,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.emb_tokenizer = AutoTokenizer.from_pretrained(embedder_name)
        self.embedder = AutoModel.from_pretrained(embedder_name).to(self.device)
        self.embedder.eval()

        self.rr_tokenizer = AutoTokenizer.from_pretrained(reranker_name, padding_side='left')
        self.reranker = AutoModelForCausalLM.from_pretrained(reranker_name).to(self.device)
        self.reranker.eval()

        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap,)
        self.index = None
        self.doc_store = []

        self.max_length = 1024
        self.token_false_id = self.rr_tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.rr_tokenizer.convert_tokens_to_ids("yes")
        prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.prefix_tokens = self.rr_tokenizer.encode(prefix, add_special_tokens=False)
        self.suffix_tokens = self.rr_tokenizer.encode(suffix, add_special_tokens=False)

    def _generate_embeddings(self, texts: List[str]) -> np.ndarray:
        inputs = self.emb_tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=self.max_length,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.embedder(**inputs)

        embeddings = self.last_token_pool(outputs.last_hidden_state, inputs.attention_mask)
        embeddings = embeddings.float().cpu()  # <-- ключевая строка

        return F.normalize(embeddings, p=2, dim=1).numpy()


    @staticmethod
    def last_token_pool(last_hidden_states: Tensor,
                        attention_mask: Tensor) -> Tensor:
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[
                torch.arange(batch_size, device=last_hidden_states.device),
                sequence_lengths] 

    def load_and_process_arxiv_json(self, file_path: str, split: bool = False) -> List[Document]:
        ext = os.path.splitext(file_path)[1].lower()
        if ext != ".json":
            raise ValueError(f"Expected .json file, got: {ext}")

        docs: List[Document] = []
        with open(file_path, "r", encoding="utf-8") as f:
            for obj in ijson.items(f, "item"):
                arxiv_id = obj.get("id")
                title = (obj.get("title") or "").strip()
                abstract = (obj.get("abstract") or "").strip()
                text = (title + "\n" + abstract).strip()
                meta = {
                    "id": arxiv_id,
                    "title": title,
                    "categories": obj.get("categories"),
                    "doi": obj.get("doi"),
                    "journal_ref": obj.get("journal-ref"),
                    "update_date": obj.get("update_date"),
                }

                docs.append(Document(page_content=text, metadata=meta))

        return self.text_splitter.split_documents(docs) if split else docs

    def build_index(self, file_path: str, batch_size: int = 64) -> None:
        all_docs = self.load_and_process_arxiv_json(file_path, split=False)
        self.doc_store = all_docs
        embs = []
        for i in tqdm(range(0, len(all_docs), batch_size), desc="Embedding corpus", unit="batch"):
            batch_texts = [d.page_content for d in all_docs[i:i + batch_size]]
            embs.append(self._generate_embeddings(batch_texts))
        embeddings = np.concatenate(embs, axis=0).astype("float32")
        self.index = faiss.IndexFlatIP(embeddings.shape[1])
        self.index.add(embeddings)

    @staticmethod
    def get_detailed_instruct(task_description: str, query: str):
        return f'Instruct: {task_description}\nQuery:{query}'

    @staticmethod
    def format_reranker_instruction(query, doc, instruction=None):
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
            instruction=instruction, query=query, doc=doc)
        return output

    def process_inputs(self, pairs):
        """Обработка данных для реранкера"""
        inputs = self.rr_tokenizer(pairs,
                                   padding=False,
                                   truncation='longest_first',
                                   return_attention_mask=False,
                                   max_length=self.max_length -
                                   len(self.prefix_tokens) -
                                   len(self.suffix_tokens))
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][
                i] = self.prefix_tokens + ele + self.suffix_tokens
        inputs = self.rr_tokenizer.pad(inputs,
                                       padding=True,
                                       return_tensors="pt",
                                       max_length=self.max_length)

        # переносим тензоры на девайс ранжирующей модели
        for key in inputs:
            inputs[key] = inputs[key].to(self.device)
        return inputs

    def search(self,
               query: str,
               k: int = 5,
               task: str = None):
        if self.index is None:
            raise ValueError("Index not initialized")

        if task is None:
            task = 'Given a web search query, retrieve relevant passages that answer the query'

        query_embedding = self._generate_embeddings([query])
        distances, indices = self.index.search(query_embedding, k)
        return distances, indices         

    @torch.no_grad()
    def compute_logits(self, inputs):
        batch_scores = self.reranker(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    def rerank(self, query: str, documents: List[str], batch_size=4):
        pairs = []
        for d in documents:
            pairs.append(self.format_reranker_instruction(query, d))

        scores = []
        for i in range(0, len(pairs), batch_size):
            inputs = self.process_inputs(pairs[i:i + batch_size])
            sc = self.compute_logits(inputs)
            scores.extend(sc)
        return scores
    
def save(self, prefix: str) -> None:
    if self.index is None:
        raise ValueError("Index not initialized")
    faiss.write_index(self.index, prefix + ".index")
    with open(prefix + ".docstore.pkl", "wb") as f:
        pickle.dump(self.doc_store, f)

def load(self, prefix: str) -> None:
    self.index = faiss.read_index(prefix + ".index")
    with open(prefix + ".docstore.pkl", "rb") as f:
        self.doc_store = pickle.load(f)    

In [4]:
q = "Keldysh formalism Andreev current heavy fermions"

k = 5
rag = RAG(device="cuda")
rag.build_index("./arxiv-metadata-s.json")

D, I = rag.search(q, k=k)
candidates = [rag.doc_store[i].page_content for i in I[0]]

for c in candidates:
    print(c[:800])
    print("-#" * 20)
    print()

Embedding corpus: 100%|██████████| 1535/1535 [49:12<00:00,  1.92s/batch]


Towards a Microscopic Theory for Metallic Heavy-Fermion Point Contacts
The bias-dependent resistance R(V) of NS-junctions is calculated using the
Keldysh formalism in all orders of the transfer matrix element. We present a
compact and simple formula for the Andreev current, that results from the
coupling of electrons and holes on the normal side via the anomalous Green's
function on the superconducting side. Using simple BCS Nambu-Green's functions
the well known Blonder-Tinkam-Klapwijk theory can be recovered. Incorporating
the energy-dependent quasi-particle lifetime of the heavy fermions strongly
reduces the Andreev-reflection signal.
-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#

Chiral Dynamics and Heavy-Fermion Formalism in Nuclei: I. Exchange Axial
  Currents
Chiral perturbation theory in heavy-fermion formalism is developed for
meson-exchange currents in nuclei and applied to nuclear axial- charge
transitions. Calculation is performed to the next-to-leading order in chiral
expansion

In [15]:
import numpy as np

ntotal = rag.index.ntotal
dim = rag.index.d
bytes_index = ntotal * dim * 4  # float32 = 4 bytes

print("faiss dim:", dim)
print("index vectors:", ntotal)
print("approx faiss vectors size (MB):", bytes_index / (1024**2))


faiss dim: 1024
index vectors: 98213
approx faiss vectors size (MB): 383.64453125


In [7]:
import faiss
import pickle

prefix = "arxiv_rag_qwen3"

# 1) сохранить FAISS индекс
faiss.write_index(rag.index, prefix + ".index")

# 2) сохранить doc_store (список Document)
with open(prefix + ".docstore.pkl", "wb") as f:
    pickle.dump(rag.doc_store, f)

print("Saved:", prefix + ".index", "and", prefix + ".docstore.pkl")


Saved: arxiv_rag_qwen3.index and arxiv_rag_qwen3.docstore.pkl


In [13]:
import time

def profile_query(rag, query: str, retrieve_k: int = 50, rr_batch_size: int = 4):
    if rag.index is None:
        raise ValueError("Index not initialized")

    t0 = time.perf_counter()
    q_emb = rag._generate_embeddings([query])
    t1 = time.perf_counter()

    D, I = rag.index.search(q_emb.astype("float32"), retrieve_k)
    t2 = time.perf_counter()

    idxs = [int(x) for x in I[0]]
    cand_docs = [rag.doc_store[i] for i in idxs]
    cand_texts = [d.page_content for d in cand_docs]
    t3 = time.perf_counter()

    scores = rag.rerank(query, cand_texts, batch_size=rr_batch_size)
    t4 = time.perf_counter()

    ranked = sorted(zip(idxs, scores), key=lambda x: x[1], reverse=True)[:5]
    t5 = time.perf_counter()

    return {
        "query_embed_s": t1 - t0,
        "faiss_search_s": t2 - t1,
        "gather_candidates_s": t3 - t2,
        "rerank_s": t4 - t3,
        "sort_top5_s": t5 - t4,
        "total_s": t5 - t0,
    }

In [14]:
profile = profile_query(rag, "attention mechanism in transformers", retrieve_k=50, rr_batch_size=4)
print(profile)

{'query_embed_s': 0.03334781900048256, 'faiss_search_s': 0.0300979189996724, 'gather_candidates_s': 9.762999980011955e-05, 'rerank_s': 1.567501509000067, 'sort_top5_s': 2.010000025620684e-05, 'total_s': 1.6310649770002783}


In [10]:
import pandas as pd
import numpy as np
import time

def mrr_at_5(rag, test_csv_path: str, retrieve_k: int = 50, rr_batch_size: int = 4, limit: int | None = None):
    df = pd.read_csv(test_csv_path)
    if limit is not None:
        df = df.head(limit)

    mrrs = []
    t_search = 0.0
    t_rerank = 0.0

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Evaluating MRR@5", unit="query"):
        q = row["query"]
        gold_id = row["id"]

        t0 = time.perf_counter()
        _, I = rag.search(q, k=retrieve_k)   # FAISS topK
        t1 = time.perf_counter()

        cand_idxs = [int(x) for x in I[0]]
        cand_docs = [rag.doc_store[i] for i in cand_idxs]
        cand_texts = [d.page_content for d in cand_docs]

        scores = rag.rerank(q, cand_texts, batch_size=rr_batch_size)
        t2 = time.perf_counter()

        # сортируем по score desc и берём топ-5
        ranked = sorted(zip(cand_docs, scores), key=lambda x: x[1], reverse=True)[:5]
        ranked_ids = [d.metadata.get("id") for d, _ in ranked]

        # reciprocal rank
        rr = 0.0
        for rank, rid in enumerate(ranked_ids, start=1):
            if rid == gold_id:
                rr = 1.0 / rank
                break
        mrrs.append(rr)

        t_search += (t1 - t0)
        t_rerank += (t2 - t1)

    return {
        "n": len(df),
        "MRR@5": float(np.mean(mrrs)),
        "avg_faiss_search_s": t_search / len(df),
        "avg_rerank_s": t_rerank / len(df),
        "avg_total_s": (t_search + t_rerank) / len(df),
    }


In [11]:
res = mrr_at_5(rag, "test_sample.csv", retrieve_k=50, rr_batch_size=4, limit=50)
print(res)


Evaluating MRR@5:   0%|          | 0/50 [00:00<?, ?query/s]

Evaluating MRR@5: 100%|██████████| 50/50 [01:31<00:00,  1.83s/query]

{'n': 50, 'MRR@5': 0.9866666666666666, 'avg_faiss_search_s': 0.06016224915994826, 'avg_rerank_s': 1.767476377959938, 'avg_total_s': 1.8276386271198861}





In [None]:
q = input("Введите запрос: ").strip()

retrieve_k = 50   # сколько брать из FAISS до rerank
final_k = 5

# 1) retrieve
_, I = rag.search(q, k=retrieve_k)
idxs = [int(x) for x in I[0]]
docs = [rag.doc_store[i] for i in idxs]
texts = [d.page_content for d in docs]

# 2) rerank
scores = rag.rerank(q, texts, batch_size=4)

# 3) top-k после rerank
ranked = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)[:final_k]

print("\nTOP результаты:")
for rank, (doc, sc) in enumerate(ranked, start=1):
    print(f"\n#{rank}  score={sc:.4f}  id={doc.metadata.get('id')}")
    print(doc.page_content[:1200])  # чтобы не печатать слишком много
