Skip to content

Commit

Permalink
Merge pull request #915 from arc53/feat/retrievers-class
Browse files Browse the repository at this point in the history
Update application files and fix LLM models, create new retriever class
  • Loading branch information
dartpain committed Apr 9, 2024
2 parents 7d2b8cb + 8d7a134 commit a37b922
Show file tree
Hide file tree
Showing 21 changed files with 475 additions and 267 deletions.
266 changes: 88 additions & 178 deletions application/api/answer/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@

from pymongo import MongoClient
from bson.objectid import ObjectId
from transformers import GPT2TokenizerFast



from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.error import bad_request


Expand Down Expand Up @@ -62,9 +61,6 @@ async def async_generate(chain, question, chat_history):
return result


def count_tokens(string):
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
return len(tokenizer(string)['input_ids'])


def run_async_chain(chain, question, chat_history):
Expand Down Expand Up @@ -104,61 +100,11 @@ def get_vectorstore(data):
def is_azure_configured():
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME


def complete_stream(question, docsearch, chat_history, prompt_id, conversation_id, chunks=2):
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
if prompt_id == 'default':
prompt = chat_combine_template
elif prompt_id == 'creative':
prompt = chat_combine_creative
elif prompt_id == 'strict':
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]

if chunks == 0:
docs = []
else:
docs = docsearch.search(question, k=chunks)
if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]]
# join all page_content together with a newline
docs_together = "\n".join([doc.page_content for doc in docs])
p_chat_combine = prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
source_log_docs = []
for doc in docs:
if doc.metadata:
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
else:
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})

if len(chat_history) > 1:
tokens_current_history = 0
# count tokens in history
chat_history.reverse()
for i in chat_history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "system", "content": i["response"]})
messages_combine.append({"role": "user", "content": question})

response_full = ""
completion = llm.gen_stream(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
messages=messages_combine)
for line in completion:
data = json.dumps({"answer": str(line)})
response_full += str(line)
yield f"data: {data}\n\n"

# save conversation to database
if conversation_id is not None:
def save_conversation(conversation_id, question, response, source_log_docs, llm):
if conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{"$push": {"queries": {"prompt": question, "response": response_full, "sources": source_log_docs}}},
{"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}},
)

else:
Expand All @@ -168,19 +114,50 @@ def complete_stream(question, docsearch, chat_history, prompt_id, conversation_i
"words, respond ONLY with the summary, use the same "
"language as the system \n\nUser: " + question + "\n\n" +
"AI: " +
response_full},
response},
{"role": "user", "content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system"}]

completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
completion = llm.gen(model=gpt_model,
messages=messages_summary, max_tokens=30)
conversation_id = conversations_collection.insert_one(
{"user": "local",
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [{"prompt": question, "response": response_full, "sources": source_log_docs}]}
"queries": [{"prompt": question, "response": response, "sources": source_log_docs}]}
).inserted_id
return conversation_id

def get_prompt(prompt_id):
if prompt_id == 'default':
prompt = chat_combine_template
elif prompt_id == 'creative':
prompt = chat_combine_creative
elif prompt_id == 'strict':
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
return prompt


def complete_stream(question, retriever, conversation_id):


response_full = ""
source_log_docs = []
answer = retriever.gen()
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
data = json.dumps(line)
yield f"data: {data}\n\n"
elif "source" in line:
source_log_docs.append(line["source"])


llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
conversation_id = save_conversation(conversation_id, question, response_full, source_log_docs, llm)

# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
Expand Down Expand Up @@ -213,25 +190,31 @@ def stream():
chunks = int(data["chunks"])
else:
chunks = 2

prompt = get_prompt(prompt_id)

# check if active_docs is set

if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
vectorstore = get_vectorstore({"active_docs": data_key["source"]})
source = {"active_docs": data_key["source"]}
elif "active_docs" in data:
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
source = {"active_docs": data["active_docs"]}
else:
vectorstore = ""
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
source = {}

if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
retriever_name = "classic"
else:
retriever_name = source['active_docs']

retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
)

return Response(
complete_stream(question, docsearch,
chat_history=history,
prompt_id=prompt_id,
conversation_id=conversation_id,
chunks=chunks), mimetype="text/event-stream"
)
complete_stream(question=question, retriever=retriever,
conversation_id=conversation_id), mimetype="text/event-stream")


@answer.route("/api/answer", methods=["POST"])
Expand All @@ -255,110 +238,40 @@ def api_answer():
chunks = int(data["chunks"])
else:
chunks = 2

if prompt_id == 'default':
prompt = chat_combine_template
elif prompt_id == 'creative':
prompt = chat_combine_creative
elif prompt_id == 'strict':
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]

prompt = get_prompt(prompt_id)

# use try and except to check for exception
try:
# check if the vectorstore is set
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
vectorstore = get_vectorstore({"active_docs": data_key["source"]})
source = {"active_docs": data_key["source"]}
else:
vectorstore = get_vectorstore(data)
# loading the index and the store and the prompt template
# Note if you have used other embeddings than OpenAI, you need to change the embeddings
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
source = {data}

if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
retriever_name = "classic"
else:
retriever_name = source['active_docs']

retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
source=source, chat_history=history, prompt=prompt, chunks=chunks, gpt_model=gpt_model
)
source_log_docs = []
response_full = ""
for line in retriever.gen():
if "source" in line:
source_log_docs.append(line["source"])
elif "answer" in line:
response_full += line["answer"]

llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)


result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = save_conversation(conversation_id, question, response_full, source_log_docs, llm)


if chunks == 0:
docs = []
else:
docs = docsearch.search(question, k=chunks)
# join all page_content together with a newline
docs_together = "\n".join([doc.page_content for doc in docs])
p_chat_combine = prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
source_log_docs = []
for doc in docs:
if doc.metadata:
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
else:
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
# join all page_content together with a newline


if len(history) > 1:
tokens_current_history = 0
# count tokens in history
history.reverse()
for i in history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "system", "content": i["response"]})
messages_combine.append({"role": "user", "content": question})


completion = llm.gen(model=gpt_model, engine=settings.AZURE_DEPLOYMENT_NAME,
messages=messages_combine)


result = {"answer": completion, "sources": source_log_docs}
logger.debug(result)

# generate conversationId
if conversation_id is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{"$push": {"queries": {"prompt": question,
"response": result["answer"], "sources": result['sources']}}},
)

else:
# create new conversation
# generate summary
messages_summary = [
{"role": "assistant", "content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the system \n\n"
"User: " + question + "\n\n" + "AI: " + result["answer"]},
{"role": "user", "content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the system"}
]

completion = llm.gen(
model=gpt_model,
engine=settings.AZURE_DEPLOYMENT_NAME,
messages=messages_summary,
max_tokens=30
)
conversation_id = conversations_collection.insert_one(
{"user": "local",
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [{"prompt": question, "response": result["answer"], "sources": source_log_docs}]}
).inserted_id

result["conversation_id"] = str(conversation_id)

# mock result
# result = {
# "answer": "The answer is 42",
# "sources": ["https://en.wikipedia.org/wiki/42_(number)", "https://en.wikipedia.org/wiki/42_(number)"]
# }
return result
except Exception as e:
# print whole traceback
Expand All @@ -375,27 +288,24 @@ def api_search():

if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
vectorstore = data_key["source"]
source = {"active_docs": data_key["source"]}
elif "active_docs" in data:
vectorstore = get_vectorstore({"active_docs": data["active_docs"]})
source = {"active_docs": data["active_docs"]}
else:
vectorstore = ""
source = {}
if 'chunks' in data:
chunks = int(data["chunks"])
else:
chunks = 2
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, vectorstore, settings.EMBEDDINGS_KEY)
if chunks == 0:
docs = []

if source["active_docs"].split("/")[0] == "default" or source["active_docs"].split("/")[0] == "local":
retriever_name = "classic"
else:
docs = docsearch.search(question, k=chunks)
retriever_name = source['active_docs']

source_log_docs = []
for doc in docs:
if doc.metadata:
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
else:
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
#yield f"data:{data}\n\n"
return source_log_docs
retriever = RetrieverCreator.create_retriever(retriever_name, question=question,
source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model
)
docs = retriever.search()
return docs

Loading

0 comments on commit a37b922

Please sign in to comment.