In [None]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_together.embeddings import TogetherEmbeddings

from langchain.schema import format_document
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage, get_buffer_string
from langchain_core.runnables import RunnableParallel

from langchain.prompts.prompt import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.memory import ChatMessageHistory

from langchain.retrievers import MultiQueryRetriever

from helpers import get_vectorstore, save_response_to_markdown_file, read_sample

In [None]:
from models import get_together_nous_mix, get_together_quen
# Language model shared by every chain built in this notebook.
ACTIVE_LLM = get_together_nous_mix()
def get_retriever(filename, context_size = "8k"):
    """Return a retriever backed by the stored embeddings for ``filename``.

    Args:
        filename: Name of the source document. Its last four characters
            (e.g. a ``.pdf`` suffix) are dropped to derive the embeddings
            path ``"<stem>-embeddings"``.
        context_size: Context-length tag interpolated into the Together
            embedding model name.

    Returns:
        The result of ``as_retriever()`` on the vector store returned by
        ``get_vectorstore``.
    """
    embedder = TogetherEmbeddings(
        model=f"togethercomputer/m2-bert-80M-{context_size}-retrieval"
    )
    store = get_vectorstore(
        embedder=embedder,
        local_vector_path=f"{filename[:-4]}-embeddings",
    )
    return store.as_retriever()

In [None]:
from os.path import exists

PDF_FILENAME = "yang.pdf"

# Fail fast if the embeddings directory for this document is missing.
_embeddings_dir = f"vector-dbs/{PDF_FILENAME[:-4]}-embeddings"
assert exists(_embeddings_dir), "Embeddings not found. Run DocQuery first."

retriever = get_retriever(PDF_FILENAME, context_size="8k")

In [None]:
# BACKGROUND_OG = "Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language."
# Prompt that rewrites a follow-up (given the chat history) as a standalone question.
standalone_template = (
    "Given the following conversation and a follow up question, rephrase the "
    "follow up question to be a standalone question, in its original language.\n"
    "\n"
    "Chat History:\n"
    "{chat_history}\n"
    "Follow Up Input: {question}\n"
    "Standalone question:"
)

CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(template=standalone_template)

# Template for the final answer: retrieved context followed by the question.
rag_template = (
    "Answer the question based only on the following context:\n"
    "{context}\n"
    "\n"
    "Question: {question}\n"
)

# Chat prompt = fixed system message, then the RAG template as the human turn.
ANSWER_PROMPT = ChatPromptTemplate.from_messages(
    [
        SystemMessage(
            content="You are a precise, autoregressive question-answering system."
        ),
        ("human", rag_template),
    ]
)

# Per-document prompt: render only the document's page content.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")

def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
    """Format every document with ``document_prompt`` and join the results.

    Args:
        docs: Iterable of documents accepted by ``format_document``.
        document_prompt: Prompt used to render each document.
        document_separator: String placed between rendered documents.

    Returns:
        A single string containing all rendered documents.
    """
    return document_separator.join(
        format_document(doc, document_prompt) for doc in docs
    )

In [None]:
#"""# _inputs = RunnableParallel(
#     standalone_question=RunnablePassthrough.assign(
#         chat_history=lambda x: get_buffer_string(x["chat_history"])
#     )
#     | CONDENSE_QUESTION_PROMPT
#     | ACTIVE_LLM
#     | StrOutputParser(),
# )
# _context = {
#     "context": itemgetter("standalone_question") | retriever | _combine_documents,
#     "question": lambda x: x["standalone_question"],
# }
# conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ACTIVE_LLM
# def get_conversational_qa_response(question: str, chat_history: list = None):
#     conversational_chat_history = chat_history
#     if conversational_chat_history is None:
#         conversational_chat_history = ChatMessageHistory()
#     response = conversational_qa_chain.invoke(
#         {
#             "question": question,
#             "chat_history": conversational_chat_history,
#         }
#     )
#     return response

# history = ChatMessageHistory()
# history.add_user_message("Only answer Yes or No. Does this paper discuss chromosomes?")
# history.add_ai_message("No")

# query = "Tell me more the main ideas it does discuss."
# response = get_conversational_qa_response(question = query, chat_history = history.messages)"""

In [None]:
# Create the conversation buffer and seed it with one hand-written exchange.
memory = ConversationBufferMemory(
    input_key="question",
    output_key="answer",
    return_messages=True,
)
chat_context = (
    {"question": "Only answer Yes or No. Does this paper discuss chromosomes?"},
    {"answer": "Yes"},
)
memory.save_context(chat_context[0], chat_context[1])

In [None]:
# Inspect the buffered conversation as a string (the cell's last expression is displayed).
memory.buffer_as_str

In [None]:
# Start from an empty buffer: this rebinds `memory`, so anything saved into the
# previously created buffer is not visible to the chains built below.
memory = ConversationBufferMemory(
    return_messages=True, input_key="question", output_key="answer"
)
# Adds a "chat_history" key holding the stored messages. ANSWER_PROMPT only has
# "context" and "question" placeholders, so this chain carries the history
# along without rendering it into the prompt.
loaded_memory_reg = RunnablePassthrough.assign(
    chat_history=RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
)
regular_chain = (
    {
        # Join the retrieved documents' page contents into one text block.
        # Passing the retriever output straight through would interpolate the
        # repr of the Document list (metadata included) into the prompt.
        "context": retriever | _combine_documents,
        "question": RunnablePassthrough(),
    }
    | loaded_memory_reg
    | ANSWER_PROMPT
    | ACTIVE_LLM
    | StrOutputParser()
)
sample_prompt = read_sample()
# `response` is a plain string because the chain ends in StrOutputParser.
response = regular_chain.invoke(sample_prompt)

In [None]:
# Store the single-turn exchange in the conversation buffer.
inputs = {"question": sample_prompt}
memory.save_context(inputs=inputs, outputs={"answer": response})

In [None]:
# Step 1 - load memory.
# Adds a "chat_history" key (the stored messages) to the input mapping.
loaded_memory = RunnablePassthrough.assign(
    chat_history=RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
)

# Step 2 - condense chat history + follow-up into a standalone question.
standalone_question = {
    "standalone_question": {
        "question": itemgetter("question"),
        "chat_history": lambda x: get_buffer_string(x["chat_history"]),
    }
    | CONDENSE_QUESTION_PROMPT
    | ACTIVE_LLM
    | StrOutputParser(),
}

# Step 3 - retrieve documents for the standalone question.
retrieved_documents = {
    "docs": itemgetter("standalone_question") | retriever,
    "question": itemgetter("standalone_question"),
}

# Step 4 - build the inputs expected by the answer prompt.
final_inputs = {
    "context": lambda x: _combine_documents(x["docs"]),
    "question": lambda x: x["question"],
}

# Step 5 - generate the answer and pass the retrieved documents through.
answer = {
    "answer": final_inputs | ANSWER_PROMPT | ACTIVE_LLM,
    "docs": itemgetter("docs"),
}

# Full pipeline: memory -> standalone question -> retrieval -> answer.
final_chain = loaded_memory | standalone_question | retrieved_documents | answer

In [None]:
# BACKUP 
# QUESTION = f"""Given the following instructions, help me create specific steps to test and adjust the retrieval threshold.
# {read_sample()}
# """
# inputs = {"question": QUESTION}
# result = final_chain.invoke(inputs)
# memory.save_context(inputs, {"answer": result["answer"].content})

In [None]:
# QUESTION = f"""Given the following instructions, help me create specific steps to test and adjust the retrieval threshold.
# {read_sample()}
# """
QUESTION = read_sample()
inputs = {"question": QUESTION}
result = final_chain.invoke(inputs)
# Save the answer produced by *this* invocation. `result["answer"]` is the
# model output of final_chain; `response` is the plain string returned earlier
# by regular_chain and has no `.content` attribute.
memory.save_context(inputs, {"answer": result["answer"].content})

In [None]:
# save_response_to_markdown_file(response.content, "response.md")
# NOTE: `response` is the string produced by regular_chain (which ends in
# StrOutputParser), not the `result` mapping returned by final_chain.
save_response_to_markdown_file(response, "response.md")

In [None]:
def print_history_from_memory(chat_memory):
    """Print the conversation held in ``chat_memory``.

    Args:
        chat_memory: Object exposing a ``buffer_as_str`` string attribute,
            such as the ``ConversationBufferMemory`` used in this notebook.

    The history is followed by one extra blank line, exactly as the earlier
    line-by-line implementation produced.
    """
    # Splitting on "\n" and re-appending "\n" to every piece yields the
    # original string plus one trailing newline, so print it directly.
    print(chat_memory.buffer_as_str, end="\n\n")

# Show the conversation stored in `memory` so far.
print_history_from_memory(memory)

In [None]:
# NEW_QUESTION = f"""Given the following instructions, help me create specific steps to test and adjust the retrieval threshold.
# {read_sample()}
# """
# Ask again through the conversational chain, then append the exchange to memory.
NEW_QUESTION = read_sample()
inputs = {"question": NEW_QUESTION}
result = final_chain.invoke(inputs)
memory.save_context(
    inputs=inputs,
    outputs={"answer": result["answer"].content},
)