In [None]:
import os
import warnings
from typing import Dict, List, Optional

import pandas as pd
from datasets import Dataset
from huggingface_hub import hf_hub_download
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.llms import VLLM, LlamaCpp
# Milvus is referenced by CustomRAGPipeline when vectorstore_name == 'MILVUS'.
from langchain_community.vectorstores import Milvus
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
from ragas import evaluate
from ragas.metrics import (
    answer_correctness,
    answer_relevancy,
    context_precision,
    context_recall,
    context_utilization,
    faithfulness,
)
from tqdm import tqdm
from vllm import LLM, SamplingParams

warnings.filterwarnings("ignore")

In [None]:
# Experiment settings shared by the rest of the notebook.
config = {
    # LLM to run: llama3.1-8b-q4 / gemma-2-9b-it-simpo-q4 / tlite-q4
    "model_name": "llama3.1-8b-q4",
    # Short name of the embedding model: e5l (multilingual-e5-large)
    "embed_model_name_short": "e5l",
    "chunk_size": 512,
    "chunk_overlap": 128,
    # Inference backend: VLLM, LLamaCpp, Ollama
    "llm_framework": "VLLM",
    # Vector database: MILVUS / FAISS
    "vectorstore_name": "MILVUS",
    "retriever_type": None,
    "reranker_type": None,
    "chain_type": "stuff",
}

# Per-model settings; exactly one of these is merged into `config`
# depending on config["model_name"].
llama_config = {
    "repo_id": "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF",
    "filename": "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
    "tokenizer": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4",
}

# Note: this entry also overrides "llm_framework" when it is merged.
gemma_config = {
    "repo_id": "mannix/gemma2-9b-simpo",
    "llm_framework": "Ollama",
}

tlite_config = {
    "repo_id": "mradermacher/saiga_tlite_8b-GGUF",
    "filename": "saiga_tlite_8b.Q4_K_M.gguf",
    "tokenizer": "IlyaGusev/saiga_tlite_8b",
}

def update_config_with_model(config, llama_config, gemma_config, tlite_config):
    """Merge the settings of the selected model into ``config`` in place.

    Args:
        config: Main settings dict. Must contain ``'model_name'`` and
            ``'embed_model_name_short'``; it is mutated by this call.
        llama_config: Settings merged when ``model_name == 'llama3.1-8b-q4'``.
        gemma_config: Settings merged when ``model_name == 'gemma-2-9b-it-simpo-q4'``.
        tlite_config: Settings merged when ``model_name == 'tlite-q4'``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names, or if
            the embedding short name is unknown and ``config`` does not already
            carry an ``'embedding_model'`` entry.
    """
    model_configs = {
        'llama3.1-8b-q4': llama_config,
        'gemma-2-9b-it-simpo-q4': gemma_config,
        'tlite-q4': tlite_config,
    }
    model_name = config['model_name']
    if model_name not in model_configs:
        # The exception used to be constructed but never raised, so a typo in
        # model_name went unnoticed until a later KeyError.
        raise ValueError('Incorrect model_name: choose from llama3.1-8b-q4, gemma-2-9b-it-simpo-q4, or tlite-q4')
    config.update(model_configs[model_name])

    if config['embed_model_name_short'] == 'e5l':
        config['embedding_model'] = "intfloat/multilingual-e5-large"
    elif 'embedding_model' not in config:
        # Without this entry the pipeline constructor fails on config['embedding_model'].
        raise ValueError(
            f"Incorrect embed_model_name_short: {config['embed_model_name_short']!r}; "
            "choose 'e5l' or set config['embedding_model'] explicitly"
        )


# Merge the selected model's settings (and the embedding model id) into `config` in place.
update_config_with_model(
    config=config,
    llama_config=llama_config,
    gemma_config=gemma_config,
    tlite_config=tlite_config,
)

class CustomRAGPipeline:
    """Retrieval-augmented QA pipeline: embeddings, vector store, LLM and a RetrievalQA chain.

    Typical use::

        pipeline = CustomRAGPipeline(documents_path="corpus.txt", config=config)
        pipeline.load_and_process_documents()
        pipeline.setup_qa_chain(prompt_text)
        answer = pipeline.query("...")

    Config keys read by this class: ``embedding_model``, ``embed_model_name_short``,
    ``vectorstore_name`` (``'FAISS'`` or ``'MILVUS'``), ``chunk_size``, ``chunk_overlap``,
    ``llm_framework`` (``'VLLM'``, ``'LLamaCpp'`` or ``'Ollama'``), ``chain_type`` and,
    depending on the framework, ``repo_id``, ``filename`` and ``tokenizer``.
    """

    def __init__(self,
                 documents_path: str,
                 config: dict,
                 recalc_embedding: bool = False,
                 ):
        """Create the embedding model, try to reuse a stored index, and load the LLM.

        Args:
            documents_path: Path of the text file to index.
            config: Settings dict (see the class docstring for the keys used).
            recalc_embedding: When True, any index already stored on disk is ignored
                so that ``load_and_process_documents`` rebuilds it.
        """
        self.config = config
        self.documents_path = documents_path
        self.embedding_model = self.config['embedding_model']

        self.vectorstore = None
        self.qa_chain = None

        self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model)

        self.vectorstore_path = self._build_vectorstore_path()

        if not recalc_embedding:
            self.vectorstore = self._load_cached_vectorstore()

        # Stays None for an unrecognised framework; setup_qa_chain reports that clearly.
        self.llm = self._load_llm()

    def _build_vectorstore_path(self) -> str:
        """Return the on-disk name of the index, derived from the settings that shape it."""
        # The short embedding name is used: the config has no 'embed_model' key, and the
        # full model id contains '/', which would turn the name into a nested path.
        return '_'.join([self.config['embed_model_name_short'],
                         self.config['vectorstore_name'],
                         str(self.config['chunk_size']),
                         str(self.config['chunk_overlap'])])

    def _load_cached_vectorstore(self):
        """Return the vector store persisted under ``vectorstore_path``, or None if absent."""
        store_name = self.config['vectorstore_name']
        if store_name == 'FAISS' and os.path.exists(self.vectorstore_path):
            # SECURITY: allow_dangerous_deserialization opts in to unpickling the saved
            # index; only load index directories produced by this notebook.
            return FAISS.load_local(
                self.vectorstore_path,
                self.embeddings,
                allow_dangerous_deserialization=True,
            )
        if store_name == 'MILVUS' and os.path.isfile(f"{self.vectorstore_path}.db"):
            return Milvus(
                self.embeddings,
                connection_args={"uri": f"./{self.vectorstore_path}.db"},
                collection_name="RAG",
            )
        return None

    def _load_llm(self):
        """Instantiate the LLM for ``config['llm_framework']``; None if it is not recognised."""
        loaders = {
            'VLLM': self.load_vllm_model,
            'LLamaCpp': self.load_llama_cpp_model,
            'Ollama': self.load_ollama_model,
        }
        loader = loaders.get(self.config['llm_framework'])
        return loader() if loader is not None else None

    def load_vllm_model(self):
        """Download the model file from the Hugging Face Hub and wrap it in LangChain's VLLM."""
        repo_id = self.config['repo_id']
        filename = self.config['filename']
        # NOTE(review): the tokenizer id is read from config but never handed to VLLM
        # below - confirm whether it should be passed through vllm_kwargs.
        tokenizer = self.config['tokenizer']
        model_path = hf_hub_download(repo_id, filename=filename)

        # NOTE(review): quantization is fixed to "awq" while the model files configured
        # in this notebook are .gguf - confirm this matches the model format.
        vllm_llm = VLLM(model=model_path,
                        vllm_kwargs={"quantization": "awq",
                                     'max_model_len': 13000,
                                     'gpu_memory_utilization': 0.75},
                        temperature=0.75,
                        stop=["<|eot_id|>"]
                        )

        return vllm_llm

    def load_llama_cpp_model(self):
        """Download the model file from the Hugging Face Hub and load it with LlamaCpp."""
        repo_id = self.config['repo_id']
        filename = self.config['filename']
        model_path = hf_hub_download(repo_id, filename=filename)

        # Initialise the LlamaCpp model.
        llama_cpp_llm = LlamaCpp(model_path=model_path,
                                temperature=0.8,
                                top_p=0.95,
                                top_k=30,
                                max_tokens=64,
                                n_ctx=13000,
                                n_parts=-1,
                                n_gpu_layers=64,
                                n_threads=8,
                                frequency_penalty=1.1,
                                verbose=True,
                                stop=["<|eot_id|>"],  # stop generating at the end-of-turn token
                                )

        return llama_cpp_llm

    def load_ollama_model(self):
        """Create an Ollama-backed LLM for the model named in ``config['repo_id']``."""
        # NOTE(review): OllamaLLM is not imported anywhere in the visible notebook; it must
        # be in scope before this method is called (presumably from langchain_ollama -
        # confirm). Also confirm that the class accepts `max_tokens`.
        return OllamaLLM(model = self.config['repo_id'],
                         temperature=0.8,
                         top_p=0.95,
                         top_k=30,
                         max_tokens=512,
                         stop=["<|eot_id|>"])

    def load_and_process_documents(self):
        """Split the source document into chunks and build the vector store.

        Does nothing when a stored index was already loaded by the constructor.

        Raises:
            ValueError: If ``config['vectorstore_name']`` is neither ``'FAISS'`` nor ``'MILVUS'``.
        """
        if self.vectorstore is not None:
            return

        store_name = self.config['vectorstore_name']
        if store_name not in ('FAISS', 'MILVUS'):
            # Fail before the expensive load/split instead of silently building nothing.
            raise ValueError(f"Unsupported vectorstore_name: {store_name!r}; choose FAISS or MILVUS")

        # Load the document from the configured path.
        loader = TextLoader(self.documents_path)
        documents = loader.load()

        # Split on single spaces into chunks of chunk_size characters with chunk_overlap overlap.
        text_splitter = CharacterTextSplitter(
            separator=" ",
            chunk_size=self.config['chunk_size'],
            chunk_overlap=self.config['chunk_overlap'],
            length_function=len,
            is_separator_regex=False,
        )
        texts = text_splitter.split_documents(documents)

        if store_name == 'FAISS':
            self.vectorstore = FAISS.from_documents(texts, self.embeddings)
            self.vectorstore.save_local(self.vectorstore_path)
        else:
            # Keep the returned store: without the assignment self.vectorstore stayed None
            # and setup_qa_chain() could not create a retriever.
            # drop_old=True replaces a collection left in an existing .db file (the
            # recalc_embedding=True case) instead of appending duplicate chunks to it.
            self.vectorstore = Milvus.from_documents(
                texts,
                self.embeddings,
                collection_name="RAG",
                connection_args={"uri": f"./{self.vectorstore_path}.db"},
                drop_old=True,
            )

    def setup_qa_chain(self, custom_prompt: Optional[str] = None):
        """Build the RetrievalQA chain on top of the vector store and the LLM.

        Args:
            custom_prompt: Prompt template text using the ``{context}`` and ``{question}``
                placeholders. When None, the chain's default prompt is used.

        Raises:
            ValueError: If the vector store or the LLM has not been initialised.
        """
        if self.vectorstore is None:
            raise ValueError("Vector store not set up. Call load_and_process_documents() first.")
        if getattr(self, 'llm', None) is None:
            raise ValueError(
                f"LLM not loaded: unsupported llm_framework {self.config['llm_framework']!r}; "
                "choose VLLM, LLamaCpp or Ollama"
            )

        retriever = self.vectorstore.as_retriever()

        chain_type_kwargs = {}
        if custom_prompt is not None:
            chain_type_kwargs["prompt"] = PromptTemplate(
                input_variables=["context", "question"],
                template=custom_prompt,
            )

        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type=self.config['chain_type'],
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs=chain_type_kwargs,
        )

    def query(self, question: str) -> Dict:
        """Run the QA chain for ``question`` and return the chain's output dict.

        Raises:
            ValueError: If ``setup_qa_chain`` has not been called yet.
        """
        if not self.qa_chain:
            raise ValueError("QA chain not set up. Call setup_qa_chain() first.")

        # invoke() replaces the deprecated direct call of the chain object.
        return self.qa_chain.invoke({"query": question})
                    
                    
# Usage example
if __name__ == "__main__":
    # Create the pipeline for the corpus file using the notebook-wide settings.
    rag_pipeline = CustomRAGPipeline(config=config, documents_path="hmao_npa.txt")

    # Chunk and index the corpus (skipped when a stored index was already loaded).
    rag_pipeline.load_and_process_documents()

In [None]:
# System turn: answering instructions followed by the retrieved-context placeholder.
# The first line deliberately keeps its trailing space before the newline.
system_prompt = (
    "Use the following pieces of context to answer the question at the end. \n"
    "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
    "Think step by step. Give full answer. Answer only in Russian. "
    "If context doesnt match the answer, say that you do not know the answer.\n"
    "{context}"
)

# User turn: the question placeholder and the cue for the answer.
user_prompt = "Question: {question}\nAnswer:"

# Full chat-formatted prompt. The leading empty item reproduces the newline that
# precedes <|begin_of_text|>; {context} and {question} are left for PromptTemplate.
custom_prompt = "\n".join([
    "",
    "<|begin_of_text|>",
    "<|start_header_id|>system<|end_header_id|>",
    system_prompt,
    "<|eot_id|>",
    "<|start_header_id|>user<|end_header_id|>",
    user_prompt,
    "<|eot_id|>",
    "<|start_header_id|>assistant<|end_header_id|>",
])

# Build the QA chain around the chat-formatted prompt.
rag_pipeline.setup_qa_chain(custom_prompt)

if __name__ == "__main__":
    # Alternative sample questions for the same corpus (disabled):
    # result = rag_pipeline.query("Какой герб изображен на бланках и штампах Комитета по средствам массовой информации и полиграфии Ханты-Мансийского автономного округа?")
    # print(result['result'])
    # result = rag_pipeline.query("Что такое должностной оклад и как он рассчитывается?")
    # result = rag_pipeline.query("Какие мероприятия проводит Департамент охраны окружающей среды и экологической безопасности автономного округа в 2010 году?")
    result = rag_pipeline.query(
        'Когда юридические лица и ИП должны сообщать об аварийных выбросах?'
    )
    # Show only the generated answer, not the retrieved source documents.
    print(result['result'])

In [None]:
# Helper for building the RAGAS evaluation dataset; temporary, remove later.

def create_ragas_dataset(rag_pipeline, eval_dataset):
    """Run the pipeline on every evaluation question and collect RAGAS-style records.

    Args:
        rag_pipeline: Object whose ``query(question)`` returns a dict with a
            ``"result"`` string and a ``"source_documents"`` list of objects
            exposing ``page_content``.
        eval_dataset: pandas DataFrame with ``"question"`` and ``"ground_truth"`` columns.

    Returns:
        A ``datasets.Dataset`` with the columns ``question``, ``answer``,
        ``contexts`` and ``ground_truth``, one row per input row.
    """
    records = []
    # iterrows() is a generator with no length, so the total is passed explicitly;
    # otherwise the progress bar cannot show completion or an ETA.
    for _, row in tqdm(eval_dataset.iterrows(), total=len(eval_dataset)):
        answer = rag_pipeline.query(row["question"])
        records.append(
            {
                "question": row["question"],
                "answer": answer["result"],
                "contexts": [doc.page_content for doc in answer["source_documents"]],
                "ground_truth": row["ground_truth"],
            }
        )
    return Dataset.from_pandas(pd.DataFrame(records))

# Load the evaluation questions and their reference answers from the spreadsheet.
eval_dataset = pd.read_excel('v2_ragas_npa_dataset_firstPart.xlsx')

# Keep 25 randomly chosen rows per evolution_type, with a fixed seed for repeatability.
eval_dataset = (
    eval_dataset
    .groupby('evolution_type', group_keys=False)
    .apply(lambda group: group.sample(n=25, random_state=42))
    .copy()
)

# Run the pipeline on the sampled questions and collect the RAGAS-style records.
eval_df = create_ragas_dataset(rag_pipeline, eval_dataset)

In [None]:
# Persist the generated evaluation dataset to disk.
# NOTE(review): the file name says "t_lite" while config['model_name'] defaults to
# 'llama3.1-8b-q4' in this notebook - confirm the name matches the model evaluated.
eval_df.save_to_disk('eval_df_baseline_t_lite.hf')