# Generation

In [1]:
import json
import os
import chromadb
from dotenv import load_dotenv, find_dotenv
from flashrank import Ranker
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_cohere import CohereRerank
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.document_compressors.flashrank_rerank import FlashrankRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain


# Load the .env file located by find_dotenv() before reading any settings.
load_dotenv(find_dotenv())

# Secrets and paths come from the environment; each value is None when the variable is unset.
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
DATABASE_PATH = os.environ.get('DATABASE_PATH')
COHERE_API_KEY = os.environ.get('COHERE_API_KEY')

# Embedding model name; get_vectorstore also uses it as a directory name below DATABASE_PATH.
EMBEDDING_MODEL = 'text-embedding-ada-002'

In [None]:
# Load the seed question set and use its first question as the running example.
# The file is opened in a `with` block so the handle is closed deterministically, and it is
# decoded as UTF-8 explicitly so the result does not depend on the platform's default encoding.
with open("../../Source/Questions/questions_seed_1.json", encoding="utf-8") as questions_file:
    questions = json.load(questions_file)

question = questions[0]["question"]
question

In [15]:
def pretty_output_text(text: str, words_per_line: int = 10) -> str:
    """ Re-wrap text so that a line break follows every `words_per_line` words.

    Existing line breaks are kept: each newline-separated part of `text` is wrapped on its
    own. Words are the space-separated tokens of a part; every token is emitted followed by
    a single space, and every part is terminated by a newline.

    Args:
        text (str): Text to wrap
        words_per_line (int): Number of words after which a line break is inserted

    Returns:
        str: The wrapped text
    """
    fragments = []

    for paragraph in text.split('\n'):
        tokens = paragraph.split(' ')
        token_count = len(tokens)

        for position, token in enumerate(tokens, start=1):
            fragments.append(token + ' ')
            # Break after a full line, except when the paragraph ends here anyway.
            if position % words_per_line == 0 and position != token_count:
                fragments.append('\n')

        fragments.append('\n')

    return ''.join(fragments)


def get_privat_path(source: str, skip: int = 8) -> str:
    """ Shorten a backslash-separated source path by dropping its leading components.

    Args:
        source (str): Path whose components are separated by backslashes
        skip (int): Number of leading path components to drop. Defaults to 8.

    Returns:
        str: The remaining components joined with the local path separator. If the path
            has no components left after dropping `skip` of them (or contains no
            backslashes at all), only its final component is returned.
    """
    components = source.split('\\')
    remaining = components[skip:]

    if not remaining:
        # os.path.join() needs at least one argument; fall back to the last component
        # instead of raising a TypeError for short paths.
        return components[-1]

    return os.path.join(*remaining)


def pretty_output_llm_answer(chain_output: dict, line_len: int=1_000_000, print_privat_path=True) -> None:
    """ Print the question, the wrapped LLM answer and the list of source documents.

    Args:
        chain_output (dict): Output of the retrieval chain. Must contain "answer" (str) and
            "context" (documents exposing a `metadata` dict with a 'source' entry). If it
            contains "input", that value is printed as the question.
        line_len (int): Words per printed line of the answer, passed to pretty_output_text
        print_privat_path (bool): If True, each source is shortened with get_privat_path.
            If False, the full source is printed followed by the document's 'page_number'.
    """
    # Prefer the question stored with the answer so the header always matches the output;
    # only fall back to the notebook-level `question` for dicts without an "input" entry.
    asked_question = chain_output["input"] if "input" in chain_output else question

    print(f"Question: {asked_question}")
    print("-" * 150, end='\n\n')

    documents = chain_output["context"]
    if print_privat_path:
        sources = " \n".join(f"{get_privat_path(doc.metadata['source'])}" for doc in documents)
    else:
        sources = " \n".join(f"{doc.metadata['source']} - {doc.metadata['page_number']}" for doc in documents)

    full_answer = f"{pretty_output_text(chain_output['answer'], line_len)}\n\nQuellen:\n{sources}"

    print(full_answer)

In [27]:
# Chat model shared by the answer chain and the multi-query retriever; temperature 0.0 is set explicitly.
LLM = ChatOpenAI(model="gpt-4o-mini", temperature=0.0, api_key=OPENAI_API_KEY)

# Prompt for the answer-generation step. The German system message tells the model to answer
# only from the supplied context (inserted at {context}), to say so when the context is
# insufficient, and to end with a notice that the answer comes from a chatbot and must be
# reviewed by a qualified person. The human message carries the question (inserted at {input}).
PROMPT = ChatPromptTemplate([
    ("system", """Du bist ein Assistant einer öffentlichen Behörde und deine Aufgabe ist es, Fragen nur auf Basis des bereitgestellten Kontexts zu beantworten.

- Wenn die Frage anhand des gegebenen Kontexts beantwortet werden kann, beantworte sie unter Einbeziehung relevanter Paragrafen, Gesetze oder Vorschriften, die im Kontext erwähnt werden.
- Wenn die Frage im Kontext nicht eindeutig beantwortet werden kann oder keine ausreichenden Informationen vorliegen, gib an, dass du die Frage nicht beantworten kannst.
- Achte besonders darauf, dass du keine Informationen hinzufügst, die nicht im Kontext enthalten sind.

Am Ende deiner Antwort weise bitte darauf hin, dass du ein ChatBot bist und die Antwort unbedingt von einer qualifizierten Person überprüft werden sollte.

<kontext>
{context}
</kontext>"""),
    ("human", "Frage: {input}")
])


def get_vectorstore(type: str):
    """ Open the persisted Chroma collection for the given chunking type.

    The database directory is DATABASE_PATH/Unstructured/<type>/<EMBEDDING_MODEL>; the
    collection is not created if it does not exist.

    Args:
        type (str): "basic" or "by_title"

    Raises:
        ValueError: If type is neither "basic" nor "by_title"
    """
    collection_names = {
        "basic": 'collection_1500',
        "by_title": 'collection_1800',
    }
    if type not in collection_names:
        raise ValueError(f"Invalid type: {type}")

    persist_dir = os.path.join(DATABASE_PATH, "Unstructured", type, EMBEDDING_MODEL)
    embeddings = OpenAIEmbeddings(model=EMBEDDING_MODEL, api_key=OPENAI_API_KEY)

    return Chroma(
        collection_name=collection_names[type],
        client=chromadb.PersistentClient(persist_dir),
        embedding_function=embeddings,
        create_collection_if_not_exists=False
    )


def get_retriever(type: str, k: int, mq: bool=False):
    """ Build a similarity retriever on top of the vectorstore for `type`.

    Args:
        type (str): "basic" or "by_title"
        k (int): Number of documents to retrieve
        mq (bool): If True, the similarity retriever is wrapped in a MultiQueryRetriever
            built from the module-level LLM. Defaults to False.
    """
    similarity_retriever = get_vectorstore(type).as_retriever(
        search_type='similarity',
        search_kwargs={'k': k},
    )

    if mq:
        return MultiQueryRetriever.from_llm(
            retriever=similarity_retriever,
            llm=LLM,
        )

    return similarity_retriever


def get_reranker(k: int, type: str):
    """ Build the document compressor used for reranking.

    Args:
        k (int): Passed to the reranker as `top_n`
        type (str): "flashrank" or "cohere"

    Raises:
        ValueError: If type is neither "flashrank" nor "cohere"
    """
    if type == "cohere":
        return CohereRerank(
            model='rerank-multilingual-v3.0',
            top_n=k,
            cohere_api_key=COHERE_API_KEY,
        )

    if type != "flashrank":
        raise ValueError(f"Invalid type: {type}")

    flashrank_client = Ranker(model_name='rank-T5-flan', max_length=4096)
    return FlashrankRerank(client=flashrank_client, top_n=k)


def process(question: str, k: int, type: str, rerank_k: int=3, rerank_type: str=None, mq: bool=False):
    """ Answer a question with the retrieval chain and return the chain's output.

    Args:
        question (str): Question to answer; passed to the chain as "input"
        k (int): Number of documents to retrieve
        type (str): "basic" or "by_title"
        rerank_k (int): `top_n` for the reranker. Defaults to 3.
        rerank_type (str): "flashrank" or "cohere". When falsy, no reranking is applied.
            Defaults to None.
        mq (bool, optional): Whether to use MultiQueryRetriever. Defaults to False.
    """
    document_retriever = get_retriever(type, k, mq)

    if rerank_type:
        # Wrap the base retriever so its results pass through the reranker.
        document_retriever = ContextualCompressionRetriever(
            base_retriever=document_retriever,
            base_compressor=get_reranker(rerank_k, rerank_type),
        )

    answer_chain = create_stuff_documents_chain(llm=LLM, prompt=PROMPT)
    rag_chain = create_retrieval_chain(
        retriever=document_retriever,
        combine_docs_chain=answer_chain,
    )

    return rag_chain.invoke({"input": question})

In [28]:
# Retrieve 5 "basic" chunks via multi-query, rerank with Cohere (top_n=2), then answer.
answer_dict = process(
    question=question,
    k=5,
    type="basic",
    rerank_k=2,
    rerank_type="cohere",
    mq=True,
)

In [None]:
# Print the answer wrapped at 10 words per line, with shortened source paths.
pretty_output_llm_answer(chain_output=answer_dict, line_len=10, print_privat_path=True)