In [7]:
from typing import List, Optional
import faiss
import numpy as np
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch
from torch import Tensor
import torch.nn.functional as F
import os
import ijson

In [None]:
PATH = "arxiv-metadata-s.json"
MODEL = "Qwen/Qwen3-Embedding-0.6B"

tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)

lens = []
with open(PATH, "r", encoding="utf-8") as f:
    for obj in ijson.items(f, "item"):
        title = (obj.get("title") or "").strip()
        abstract = (obj.get("abstract") or "").strip()
        text = (title + "\n" + abstract).strip()
        ids = tokenizer(text, add_special_tokens=True, truncation=False)["input_ids"]
        lens.append(len(ids))

arr = np.array(lens, dtype=np.int32)
print("count:", arr.size)
print("mean tokens:", float(arr.mean()))
for p in [50, 90, 95, 99, 99.5, 99.9]:
    print(f"p{p}:", float(np.percentile(arr, p)))
print("max:", int(arr.max()))


KeyboardInterrupt: 

In [None]:
class RAG:

    def __init__(
        self,
        embedder_name: str = "Qwen/Qwen3-Embedding-0.6B",
        reranker_name: str = "Qwen/Qwen3-Reranker-0.6B",
        chunk_size: int = 500,
        chunk_overlap: int = 125,
        device: Optional[str] = None,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.emb_tokenizer = AutoTokenizer.from_pretrained(embedder_name)
        self.embedder = AutoModel.from_pretrained(embedder_name).to(self.device)
        self.embedder.eval()

        self.rr_tokenizer = AutoTokenizer.from_pretrained(reranker_name, padding_side='left')
        self.reranker = AutoModelForCausalLM.from_pretrained(reranker_name).to(self.device)
        self.reranker.eval()

        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap,)
        self.index = None
        self.doc_store = []

        self.max_length = 1024
        self.token_false_id = self.rr_tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.rr_tokenizer.convert_tokens_to_ids("yes")
        prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.prefix_tokens = self.rr_tokenizer.encode(prefix, add_special_tokens=False)
        self.suffix_tokens = self.rr_tokenizer.encode(suffix, add_special_tokens=False)

    def _generate_embeddings(self, texts: List[str]) -> np.ndarray:
        inputs = self.emb_tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=self.max_length,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.embedder(**inputs)
        embeddings = self.last_token_pool(outputs.last_hidden_state,
                                          inputs.attention_mask).cpu()
        return F.normalize(embeddings, p=2, dim=1).numpy()

    @staticmethod
    def last_token_pool(last_hidden_states: Tensor,
                        attention_mask: Tensor) -> Tensor:
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[
                torch.arange(batch_size, device=last_hidden_states.device),
                sequence_lengths] 

    def load_and_process_arxiv_json(self, file_path: str, split: bool = False) -> List[Document]:
        ext = os.path.splitext(file_path)[1].lower()
        if ext != ".json":
            raise ValueError(f"Expected .json file, got: {ext}")

        docs: List[Document] = []
        with open(file_path, "r", encoding="utf-8") as f:
            for obj in ijson.items(f, "item"):
                arxiv_id = obj.get("id")
                title = (obj.get("title") or "").strip()
                abstract = (obj.get("abstract") or "").strip()
                text = (title + "\n" + abstract).strip()
                meta = {
                    "id": arxiv_id,
                    "title": title,
                    "categories": obj.get("categories"),
                    "doi": obj.get("doi"),
                    "journal_ref": obj.get("journal-ref"),
                    "update_date": obj.get("update_date"),
                }

                docs.append(Document(page_content=text, metadata=meta))

        return self.text_splitter.split_documents(docs) if split else docs

    def build_index(self, file_path: str, batch_size: int = 64) -> None:
        all_docs = self.load_and_process_arxiv_json(file_path, split=False)
        self.doc_store = all_docs
        embs = []
        for i in range(0, len(all_docs), batch_size):
            batch_texts = [d.page_content for d in all_docs[i:i + batch_size]]
            embs.append(self._generate_embeddings(batch_texts))
        embeddings = np.concatenate(embs, axis=0).astype("float32")
        self.index = faiss.IndexFlatIP(embeddings.shape[1])
        self.index.add(embeddings)

    @staticmethod
    def get_detailed_instruct(task_description: str, query: str):
        return f'Instruct: {task_description}\nQuery:{query}'

    @staticmethod
    def format_reranker_instruction(query, doc, instruction=None):
        if instruction is None:
            instruction = 'Given a web search query, retrieve relevant passages that answer the query'
        output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
            instruction=instruction, query=query, doc=doc)
        return output

    def process_inputs(self, pairs):
        """Обработка данных для реранкера"""
        inputs = self.rr_tokenizer(pairs,
                                   padding=False,
                                   truncation='longest_first',
                                   return_attention_mask=False,
                                   max_length=self.max_length -
                                   len(self.prefix_tokens) -
                                   len(self.suffix_tokens))
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][
                i] = self.prefix_tokens + ele + self.suffix_tokens
        inputs = self.rr_tokenizer.pad(inputs,
                                       padding=True,
                                       return_tensors="pt",
                                       max_length=self.max_length)

        # переносим тензоры на девайс ранжирующей модели
        for key in inputs:
            inputs[key] = inputs[key].to(self.device)
        return inputs

    def search(self,
               query: str,
               k: int = 5,
               task: str = None):
        if self.index is None:
            raise ValueError("Index not initialized")

        if task is None:
            task = 'Given a web search query, retrieve relevant passages that answer the query'

        query_embedding = self._generate_embeddings([query])
        distances, indices = self.index.search(query_embedding, k)
        return distances, indices         

    @torch.no_grad()
    def compute_logits(self, inputs):
        batch_scores = self.reranker(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, self.token_true_id]
        false_vector = batch_scores[:, self.token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    def rerank(self, query: str, documents: List[str], batch_size=4):
        pairs = []
        for d in documents:
            pairs.append(self.format_reranker_instruction(query, d))

        scores = []
        for i in range(0, len(pairs), batch_size):
            inputs = self.process_inputs(pairs[i:i + batch_size])
            sc = self.compute_logits(inputs)
            scores.extend(sc)
        return scores            

In [9]:
q = "Keldysh formalism Andreev current heavy fermions"

k = 5
rag = RAG(device="cuda")
rag.build_index("./arxiv-metadata-s.json")

D, I = rag.search(q, k=k)
candidates = [rag.doc_store[i].page_content for i in I[0]]

for c in candidates:
    print(c[:800])  # чтобы не печатать всё
    print("-#" * 20)
    print()

Loading weights: 100%|██████████| 310/310 [00:00<00:00, 1517.06it/s, Materializing param=norm.weight]                              
Loading weights: 100%|██████████| 310/310 [00:00<00:00, 1338.27it/s, Materializing param=model.norm.weight]                              


TypeError: Got unsupported ScalarType BFloat16