In [15]:
from functools import lru_cache

from langchain.prompts import PromptTemplate
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
import chainlit as cl

In [16]:
# Folder holding the saved FAISS index that qa_bot() loads.
db_faiss_path = 'vectorstores/db_faiss'


In [17]:
custom_propmt_template = """
Use the following pieces of information to answet the user's question.
If you don't know the answer, please just say that you don't know the answer, don't try to make up an answer.

Context: {}
Questions: {question}

Only returns the helpful answer below and nothing else.
Helpful answer:
"""

In [18]:
def set_custom_propmt():
    """
    Build the prompt template used by the retrieval-QA chain.

    Returns:
        PromptTemplate: wraps ``custom_propmt_template`` and declares the two
        variables it is formatted with, ``context`` and ``question``.
    """
    # The PromptTemplate field is `input_variables`; the previous misspelt
    # keyword `imput_variables` is not a PromptTemplate field.
    prompt = PromptTemplate(template=custom_propmt_template,
                            input_variables=['context', 'question'])
    return prompt

In [20]:
def load_llm(model="llama-2-7b-chat.ggmlv3.q4_1.bin", model_type="llama",
             max_new_tokens=512, temperature=0.5):
    """
    Load the local GGML model through LangChain's CTransformers wrapper.

    Args:
        model: path or name of the model file (defaults to the Llama-2 7B chat
            q4_1 GGML file this notebook was written for).
        model_type: ctransformers model family.
        max_new_tokens: upper bound on generated tokens per answer.
        temperature: sampling temperature.

    Returns:
        The CTransformers LLM instance.
    """
    # CTransformers takes generation settings through its `config` dict.
    # Passed as top-level keyword arguments they are not CTransformers fields,
    # so they never reach the model.
    llm = CTransformers(
        model=model,
        model_type=model_type,
        config={"max_new_tokens": max_new_tokens, "temperature": temperature},
    )
    return llm

In [21]:
def retrieval_qa_chain(llm, propmt, db, k=2):
    """
    Build a "stuff"-type RetrievalQA chain over a vector store.

    Args:
        llm: language model that writes the answer.
        propmt: prompt template for the stuff chain. The parameter name is kept
            unchanged for callers that pass it by keyword.
        db: vector store exposing ``as_retriever`` (qa_bot passes a FAISS store).
        k: number of documents retrieved per query (default 2, as before).

    Returns:
        A RetrievalQA chain that also returns its source documents.
    """
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': k}),
        return_source_documents=True,
        # chain_type_kwargs is forwarded to the stuff-chain loader, which reads
        # the custom template from the 'prompt' key. It must be the parameter
        # actually received here, not an undefined global name.
        chain_type_kwargs={'prompt': propmt},
    )
    return qa_chain

In [22]:
def qa_bot():
    """
    Assemble the full retrieval-QA pipeline.

    Loads the sentence-transformers embedding model on CPU, opens the FAISS
    index stored at ``db_faiss_path``, loads the LLM and wires everything into
    a RetrievalQA chain using the custom prompt.

    Returns:
        The chain built by ``retrieval_qa_chain``.
    """
    embeddings = HuggingFaceBgeEmbeddings(
        # Canonical Hugging Face Hub id (mixed-case "MiniLM").
        # NOTE(review): this must be the same model that built the index on
        # disk - confirm against the ingestion script.
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
        # By default the BGE wrapper prepends a BGE-specific instruction to
        # every query. all-MiniLM-L6-v2 is not a BGE model, so embed queries
        # verbatim, the same way documents are embedded.
        query_instruction='',
    )
    db = FAISS.load_local(db_faiss_path, embeddings)
    llm = load_llm()
    qa_prompt = set_custom_propmt()
    qa = retrieval_qa_chain(llm, qa_prompt, db)

    return qa
    

In [23]:
@lru_cache(maxsize=1)
def _cached_qa_bot():
    """Build the QA chain on first use and return that same chain afterwards."""
    return qa_bot()


def final_result(query):
    """
    Answer a single question with the retrieval-QA chain.

    The chain (LLM, embedding model, FAISS index) is built once, on the first
    call, and reused. Previously every query reloaded all models from disk.
    Call ``_cached_qa_bot.cache_clear()`` to force a rebuild, e.g. after the
    index on disk changes.

    Args:
        query: the user's question.

    Returns:
        The chain's output dict for ``{'query': query}``.
    """
    qa_result = _cached_qa_bot()
    response = qa_result({'query': query})
    return response

Chainlit code: chat-session handlers (`on_chat_start` builds the QA chain, `on_message` answers each query)

In [24]:
@cl.on_chat_start
async def start():
    """
    Chainlit hook that runs when a chat session opens.

    Tells the user the bot is starting, builds the QA chain, replaces the status
    text with a greeting, and stores the chain in the user session under
    "chain" for the message handler.
    """
    # Send the status message *before* loading so it is visible during the
    # (slow) model load, instead of appearing only once loading has finished.
    msg = cl.Message(content="Starting the bot ....")
    await msg.send()
    # qa_bot() loads models synchronously; run it in a worker thread so the
    # event loop is not blocked while it works.
    chain = await cl.make_async(qa_bot)()
    msg.content = "Hi, welcome to the bot. Type your query."
    await msg.update()
    cl.user_session.set("chain", chain)


@cl.on_message
async def main(message):
    """
    Chainlit hook that runs for every incoming user message.

    Runs the session's QA chain on the message text while streaming tokens to
    the UI. It then appends the source documents (or a "no sources" note) and
    shows the final answer once.
    """
    chain = cl.user_session.get("chain")
    # Newer Chainlit versions pass a cl.Message; older ones pass the raw text.
    query = getattr(message, "content", message)
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # Stream from the first token rather than waiting for a "FINAL ANSWER"
    # prefix, which the prompt never asks the model to produce.
    cb.answer_reached = True
    res = await chain.acall(query, callbacks=[cb])
    answer = res["result"]
    sources = res["source_documents"]

    if sources:
        answer += "\nSources: " + str(sources)
    else:
        answer += "\nNo Sources found"

    # If the answer was already streamed into a message, update that message
    # rather than posting the same answer a second time.
    if getattr(cb, "has_streamed_final_answer", False):
        cb.final_stream.content = answer
        await cb.final_stream.update()
    else:
        await cl.Message(content=answer).send()