# In [None]:
import os
from typing import List, Dict
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.document_loaders import TextLoader
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from huggingface_hub import hf_hub_download
from vllm import LLM, SamplingParams
from langchain_community.llms import VLLM
from datasets import Dataset
from ragas.metrics import faithfulness, answer_relevancy, context_precision, context_utilization, context_recall, answer_correctness 
from ragas import evaluate
import pandas as pd
from tqdm import tqdm

from langchain_community.llms import LlamaCpp
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler

import warnings
warnings.filterwarnings("ignore")

# In [None]:
# Experiment configuration consumed by CustomRAGPipeline.
config = {
    'model': 'llama3.1-8b-quant4',
    'repo_id': 'lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF',
    'filename': 'Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf',
    'tokenizer': 'hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4',
    'embed_model': 'e5l',  # short alias, used when building names
    'embed_model_full': "intfloat/multilingual-e5-large",  # HuggingFace embedding model
    'vectorstore_name': 'MILVUS',  # vector store backend: MILVUS / FAISS
    'chunk_size': 512,
    'chunk_overlap': 128,
    'llm_framework': 'VLLM',  # VLLM or LLamaCpp
    'retriever': None,
    'chain_type': 'stuff',
}

class CustomRAGPipeline:
    """Retrieval-augmented QA pipeline: embeddings, a vector store and an LLM.

    Typical call order: construct the pipeline, call
    ``load_and_process_documents()``, then ``setup_qa_chain()``, then
    ``query()``.

    Config keys read by this class: ``embedding_model`` (or
    ``embed_model_full``), ``embed_model``, ``vectorstore_name``
    (``'FAISS'`` or ``'MILVUS'``), ``chunk_size``, ``chunk_overlap``,
    ``llm_framework`` (``'VLLM'`` or ``'LLamaCpp'``), ``repo_id``,
    ``filename`` and ``chain_type``.
    """

    # Milvus collection name shared by the write and the read path.
    _MILVUS_COLLECTION = "RAG"

    def __init__(self,
                 documents_path: str,
                 config: dict,
                 recalc_embedding: bool = False,
                 ):
        """Create the embedding model, reuse a stored index if any, load the LLM.

        Args:
            documents_path: Path of the text file to index.
            config: Pipeline settings (see the class docstring for the keys).
            recalc_embedding: When True, a previously stored vector store is
                not loaded, so ``load_and_process_documents()`` rebuilds it.

        Raises:
            KeyError: If the config names no embedding model.
        """
        self.config = config
        self.documents_path = documents_path
        self.embedding_model = self._resolve_embedding_model(config)

        self.vectorstore = None
        self.qa_chain = None
        # Stays None when 'llm_framework' is not one of the known values.
        self.llm = None

        self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model)

        self.vectorstore_path = self._build_vectorstore_path(config)

        if not recalc_embedding:
            self.vectorstore = self._load_existing_vectorstore()

        llm_framework = self.config['llm_framework']
        if llm_framework == 'VLLM':
            self.llm = self.load_vllm_model()
        elif llm_framework == 'LLamaCpp':
            self.llm = self.load_llama_cpp_model()

    @staticmethod
    def _resolve_embedding_model(config: dict) -> str:
        """Return the embedding model id from 'embedding_model' or 'embed_model_full'."""
        for key in ('embedding_model', 'embed_model_full'):
            model_name = config.get(key)
            if model_name:
                return model_name
        raise KeyError("config must define 'embedding_model' or 'embed_model_full'")

    @staticmethod
    def _build_vectorstore_path(config: dict) -> str:
        """Return the storage name: alias, backend, chunk size and overlap joined by '_'."""
        name_keys = ('embed_model', 'vectorstore_name', 'chunk_size', 'chunk_overlap')
        return '_'.join(str(config[key]) for key in name_keys)

    def _milvus_connection_args(self) -> dict:
        """Return the connection arguments pointing at the local ``.db`` file."""
        return {"uri": f"./{self.vectorstore_path}.db"}

    def _load_existing_vectorstore(self):
        """Return the vector store persisted on disk, or None when there is none."""
        backend = self.config['vectorstore_name']

        if backend == 'FAISS' and os.path.exists(self.vectorstore_path):
            # SECURITY: load_local unpickles the stored index; only point this
            # at index directories written by this pipeline.
            return FAISS.load_local(self.vectorstore_path,
                                    self.embeddings,
                                    allow_dangerous_deserialization=True)

        if backend == 'MILVUS' and os.path.isfile(f"{self.vectorstore_path}.db"):
            # Imported lazily so FAISS-only setups do not need the Milvus client.
            from langchain_community.vectorstores import Milvus

            return Milvus(
                self.embeddings,
                connection_args=self._milvus_connection_args(),
                collection_name=self._MILVUS_COLLECTION,
            )

        return None

    def load_vllm_model(self):
        """Download the model file from the HuggingFace Hub and wrap it in VLLM.

        Returns:
            The configured ``VLLM`` LangChain LLM.
        """
        repo_id = self.config['repo_id']
        filename = self.config['filename']
        model_path = hf_hub_download(repo_id, filename=filename)

        # NOTE(review): config['tokenizer'] is not passed to vLLM here, and
        # "awq" quantization is requested for what the config names as a GGUF
        # file -- confirm both are intended.
        vllm_llm = VLLM(model=model_path,
                        vllm_kwargs={"quantization": "awq",
                                     'max_model_len': 13000,
                                     'gpu_memory_utilization': 0.75},
                        temperature=0.75,
                        )

        return vllm_llm

    def load_llama_cpp_model(self):
        """Download the model file from the HuggingFace Hub and wrap it in LlamaCpp.

        Returns:
            The configured ``LlamaCpp`` LangChain LLM.
        """
        repo_id = self.config['repo_id']
        filename = self.config['filename']
        model_path = hf_hub_download(repo_id, filename=filename)

        # Initialise the LlamaCpp model.
        llama_cpp_llm = LlamaCpp(model_path=model_path,
                                 temperature=0.8,
                                 top_p=0.95,
                                 top_k=30,
                                 max_tokens=64,
                                 n_ctx=13000,
                                 n_parts=-1,
                                 n_gpu_layers=64,
                                 n_threads=8,
                                 frequency_penalty=1.1,
                                 verbose=True,
                                 stop=["<|eot_id|>"],  # stop at the EOS token
                                 )

        return llama_cpp_llm

    def load_and_process_documents(self):
        """Build and persist the vector store unless one is already loaded.

        The file at ``documents_path`` is loaded, split into chunks using the
        configured size and overlap, embedded, and stored in the backend
        selected by ``config['vectorstore_name']``.
        """
        if self.vectorstore is not None:
            return

        # Load documents from the specified path.
        documents = TextLoader(self.documents_path).load()

        # Split the documents into chunks.
        text_splitter = CharacterTextSplitter(
            separator=" ",
            chunk_size=self.config['chunk_size'],
            chunk_overlap=self.config['chunk_overlap'],
            length_function=len,
            is_separator_regex=False,
        )
        texts = text_splitter.split_documents(documents)

        backend = self.config['vectorstore_name']
        if backend == 'FAISS':
            # Create a FAISS vector store from the chunks and save it to disk.
            self.vectorstore = FAISS.from_documents(texts, self.embeddings)
            self.vectorstore.save_local(self.vectorstore_path)
        elif backend == 'MILVUS':
            # Imported lazily so FAISS-only setups do not need the Milvus client.
            from langchain_community.vectorstores import Milvus

            # Keep the returned store so setup_qa_chain() can use it.
            # drop_old=True replaces a stale collection instead of appending
            # duplicate chunks when the embeddings are recalculated.
            self.vectorstore = Milvus.from_documents(
                texts,
                self.embeddings,
                collection_name=self._MILVUS_COLLECTION,
                connection_args=self._milvus_connection_args(),
                drop_old=True,
            )

    def setup_qa_chain(self, custom_prompt: str = None):
        """Create the RetrievalQA chain on top of the vector store.

        Args:
            custom_prompt: Prompt template text using the ``context`` and
                ``question`` variables. When None, no prompt is passed and
                the chain falls back to its own default.

        Raises:
            ValueError: If the vector store or the LLM is not available.
        """
        if self.vectorstore is None:
            raise ValueError("Vector store not set up. Call load_and_process_documents() first.")
        if getattr(self, 'llm', None) is None:
            raise ValueError("LLM not loaded. Set config['llm_framework'] to 'VLLM' or 'LLamaCpp'.")

        chain_type_kwargs = {}
        if custom_prompt is not None:
            chain_type_kwargs["prompt"] = PromptTemplate(
                input_variables=["context", "question"],
                template=custom_prompt,
            )

        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type=self.config['chain_type'],
            retriever=self.vectorstore.as_retriever(),
            return_source_documents=True,
            chain_type_kwargs=chain_type_kwargs,
        )

    def query(self, question: str) -> Dict:
        """Run the QA chain for ``question`` and return the chain's output.

        Raises:
            ValueError: If ``setup_qa_chain()`` has not been called yet.
        """
        if self.qa_chain is None:
            raise ValueError("QA chain not set up. Call setup_qa_chain() first.")

        # invoke() replaces the deprecated Chain.__call__ entry point.
        return self.qa_chain.invoke({"query": question})
                    
                    
# Usage example: build the pipeline from the module-level `config` and index
# the source document. `rag_pipeline` is bound at module level.
if __name__ == "__main__":
    # Initialize the pipeline; recalc_embedding keeps its default (False).
    rag_pipeline = CustomRAGPipeline(documents_path="hmao_npa.txt", config=config)
    
    # Load and process documents (does nothing if a vector store was already loaded).
    rag_pipeline.load_and_process_documents()