In [3]:
from typing import List, Tuple, Union
from sentence_transformers import SentenceTransformer, util
import torch

class Encoder:
    """
    Thin wrapper around a SentenceTransformer model that turns text into embeddings.

    Attributes:
        device: 'cuda' when a GPU was requested and is available, otherwise 'cpu'.
        model: The loaded SentenceTransformer instance, placed on ``device``.
    """
    def __init__(self, model_name: str = 'cointegrated/rubert-tiny2', use_gpu: bool = False):
        """
        Loads the model and selects the device it runs on.

        Args:
            model_name: Name or path handed to SentenceTransformer.
            use_gpu: Request CUDA; silently falls back to CPU when CUDA is unavailable.

        Raises:
            ValueError: If ``model_name`` is empty.
            RuntimeError: If the model cannot be loaded (original error attached as cause).
        """
        if not model_name:
            raise ValueError("Model name cannot be empty.")

        # Device selection is kept outside the try block so that only genuine
        # model-loading problems are reported as "Failed to load model".
        self.device = 'cuda' if (use_gpu and torch.cuda.is_available()) else 'cpu'
        try:
            self.model = SentenceTransformer(model_name, device=self.device)
        except Exception as e:
            raise RuntimeError(f"Failed to load model '{model_name}': {e}") from e

    def encode(self, data: Union[List[str], str]) -> torch.Tensor:
        """
        Encodes one text or a list of texts into embeddings.

        Args:
            data: A single string or a list of strings. A single string is wrapped
                in a one-element list before being passed to the model.

        Returns:
            Whatever the model returns for ``convert_to_tensor=True``.

        Raises:
            RuntimeError: If the underlying model call fails (original error attached as cause).
        """
        texts = [data] if isinstance(data, str) else data
        try:
            return self.model.encode(texts, convert_to_tensor=True, device=self.device)
        except Exception as e:
            raise RuntimeError(f"Encoding failed: {e}") from e

class RAG:
    """
    Retrieval-Augmented Generation (RAG) pipeline.

    Documents are embedded once by ``fit``; ``retrieve`` ranks them by cosine
    similarity to a query; ``run`` chains retrieval with the (placeholder)
    generation step.

    Attributes:
        encoder: The Encoder used for both documents and queries.
        documents: The documents indexed by the last successful ``fit`` call.
        doc_embeddings: Embeddings of ``documents``, or None before ``fit``.
    """
    def __init__(self, encoder: Encoder):
        """
        Initializes the RAG class with the given encoder.

        Raises:
            ValueError: If ``encoder`` is not an Encoder instance.
        """
        if not isinstance(encoder, Encoder):
            raise ValueError("The encoder must be an instance of Encoder.")
        self.encoder = encoder
        self.documents: List[str] = []
        self.doc_embeddings = None

    def fit(self, documents: List[str]):
        """
        Encodes the documents and stores them together with their embeddings.

        The index is replaced only after encoding succeeds, so a failed call
        leaves the previously fitted state untouched.

        Args:
            documents: Non-empty list of texts. A single string is treated as a
                one-document list.

        Raises:
            ValueError: If ``documents`` is empty.
            RuntimeError: If encoding fails (original error attached as cause).
        """
        if not documents:
            raise ValueError("Document list cannot be empty.")
        # Copy so that later mutation of the caller's list cannot desynchronise
        # self.documents from self.doc_embeddings.
        docs = [documents] if isinstance(documents, str) else list(documents)
        try:
            embeddings = self.encoder.encode(docs)
        except Exception as e:
            raise RuntimeError(f"Document encoding failed: {e}") from e
        self.documents = docs
        self.doc_embeddings = embeddings

    def retrieve(self, query: str, retrieval_limit: int = 5, similarity_threshold: float = 0.5) -> Tuple[List[int], List[str]]:
        """
        Retrieves the documents most similar to the query.

        Args:
            query: The query text.
            retrieval_limit: How many top-scoring documents to consider (1 to 10,
                and no more than the number of fitted documents).
            similarity_threshold: Minimum cosine similarity (0 to 1) a document
                must reach to be returned.

        Returns:
            A tuple ``(indices, documents)`` of the selected documents, in the
            order produced by ``torch.topk``. May contain fewer than
            ``retrieval_limit`` items, or none, because of the threshold.

        Raises:
            ValueError: If ``fit`` has not been called or an argument is out of range.
            RuntimeError: If encoding or scoring fails (original error attached as cause).
        """
        if self.doc_embeddings is None:
            raise ValueError("You must call fit() before retrieve().")
        if not (1 <= retrieval_limit <= 10):
            raise ValueError("retrieval_limit must be between 1 and 10.")
        if retrieval_limit > len(self.documents):
            raise ValueError("retrieval_limit cannot exceed number of documents.")
        if not (0.0 <= similarity_threshold <= 1.0):
            raise ValueError("similarity_threshold must be between 0 and 1.")

        try:
            # Encode the query into an embedding.
            query_embedding = self.encoder.encode(query)
            # Cosine similarity of the query against every document; row 0 is
            # the single query, giving one score per document.
            cosine_scores = util.cos_sim(query_embedding, self.doc_embeddings)[0]
            # Pick the top-N scoring documents.
            top_results = torch.topk(cosine_scores, k=retrieval_limit)

            ranked = zip(top_results.indices.tolist(), top_results.values.tolist())
            filtered_indices = [
                idx for idx, score in ranked
                if score >= similarity_threshold
            ]
            retrieved_docs = [self.documents[idx] for idx in filtered_indices]
            return filtered_indices, retrieved_docs
        except Exception as e:
            raise RuntimeError(f"Retrieval failed: {e}") from e

    def _create_prompt_template(self, query: str, retrieved_docs: List[str]) -> str:
        """
        Builds the generation prompt: instructions, numbered documents, then the query.

        Args:
            query: The user's query.
            retrieved_docs: Documents to list in the prompt, numbered from 1.

        Returns:
            The prompt text, ending with ``Answer:``.
        """
        parts = [
            "Instructions: Based on the relevant documents, generate a comprehensive response to the user's query.\n\n",
            "Relevant Documents:\n",
        ]
        parts.extend(
            f"Document {number}: {doc}\n"
            for number, doc in enumerate(retrieved_docs, start=1)
        )
        parts.append(f"\nUser Query: {query}\n")
        parts.append("Answer:")
        return "".join(parts)

    def _generate(self, query: str, retrieved_docs: List[str]) -> str:
        """
        Placeholder for text generation logic.

        Builds the prompt but does not send it to any model yet; the returned
        text is a fixed simulated response that mentions the query.
        """
        # NOTE: the prompt is not consumed yet; it is built here so that a real
        # generator call can be dropped in at this point.
        prompt = self._create_prompt_template(query, retrieved_docs)

        generated_response = f"(Simulated Response based on documents and query: '{query}')"
        return generated_response

    def run(self, query: str, retrieval_limit: int = 5, similarity_threshold: float = 0.5) -> str:
        """
        Runs the full RAG pipeline: retrieve, then generate.

        Args:
            query: The user's query.
            retrieval_limit: Upper bound on the number of documents to retrieve.
                It is capped at the number of fitted documents, so the default
                also works for corpora smaller than 5 documents.
            similarity_threshold: Passed through to ``retrieve``.

        Returns:
            The generated response.

        Raises:
            ValueError: If ``fit`` has not been called or an argument is out of range.
            RuntimeError: If retrieval fails.
        """
        # retrieve() rejects a limit larger than the corpus; cap it here so the
        # pipeline does not fail merely because few documents were indexed.
        limit = min(retrieval_limit, len(self.documents))
        _, retrieved_docs = self.retrieve(
            query,
            retrieval_limit=limit,
            similarity_threshold=similarity_threshold,
        )
        return self._generate(query, retrieved_docs)

In [5]:
# Small in-memory corpus used to demonstrate retrieval.
documents = [
    "Machine learning is a method of data analysis that automates analytical model building.",
    "Artificial intelligence is intelligence demonstrated by machines, in contrast to the natural intelligence displayed by humans.",
    "Natural language processing is a subfield of linguistics, computer science, and artificial intelligence concerned with the interactions between computers and human language.",
    "Deep learning is a class of machine learning algorithms that uses multiple layers to progressively extract higher-level features from the raw input.",
]

# Build the pipeline on the default encoder and index the corpus.
encoder = Encoder()
rag = RAG(encoder)
rag.fit(documents)

query = "Tell me about deep learning."
# Consider the two best-scoring documents and keep those scoring at least 0.6.
result_indices, result_documents = rag.retrieve(
    query,
    retrieval_limit=2,
    similarity_threshold=0.6,
)

print(
    f'Result indices: {result_indices}',
    f'Result documents: {result_documents}',
    sep='\n',
)

# >> Output:
# >> Result indices: [3, 0]
# >> Result documents:
# >> >> 'Deep learning is a class of machine learning algorithms that uses multiple layers to progressively extract higher-level features from the raw input.',
# >> >> 'Machine learning is a method of data analysis that automates analytical model building.'


Result indices: [3, 0]
Result documents: ['Deep learning is a class of machine learning algorithms that uses multiple layers to progressively extract higher-level features from the raw input.', 'Machine learning is a method of data analysis that automates analytical model building.']
