In [None]:
%load_ext dotenv
%dotenv .env

In [None]:
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts.prompt import PromptTemplate

import os

In [None]:
# Open the persisted Chroma vector store whose location comes from .env.
persist_directory = os.environ["CHROMADB_FOLDER"]
# Chroma does not fail on a missing persist path; it starts a new, empty store
# there, and every question would then be answered without retrieved context.
# Fail fast with an actionable message instead.
if not os.path.isdir(persist_directory):
    raise FileNotFoundError(
        f"CHROMADB_FOLDER={persist_directory!r} is not an existing directory; "
        "build the vector store first or fix the path in .env"
    )

embedding = OpenAIEmbeddings()
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)

In [None]:
# Build the conversational QA chain over the vector store.
# The model name is read from .env (OPENAI_MODEL); temperature 0.0 asks for the
# most deterministic answers, and max_tokens caps the length of each answer.
model = os.environ["OPENAI_MODEL"]

llm = ChatOpenAI(
    model_name=model,
    max_tokens=256,
    temperature=0.0,
)

# return_source_documents=True makes the result include "source_documents",
# which `ask` uses to report where an answer came from.
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=vectordb.as_retriever(),
    chain_type="stuff",
    return_source_documents=True,
)

In [None]:
# (question, answer) pairs from this session; passed to the chain on every call.
chat_history = []


def get_sources(documents):
    """Return the distinct ``"source"`` metadata values of retrieved documents.

    Args:
        documents: Iterable of retrieved documents, each exposing a ``metadata``
            dict (like the ``source_documents`` returned by the QA chain).

    Returns:
        A set of source identifiers. Documents without a ``"source"`` key are
        skipped rather than contributing ``None``.
    """
    # The metadata key is "source" (singular); looking up "sources" never
    # matched, so every document used to contribute None.
    found = (doc.metadata.get("source") for doc in documents)
    return {source for source in found if source is not None}


def ask(query: str, chain=None):
    """Ask ``query`` within the running conversation and print the exchange.

    Args:
        query: The user's question.
        chain: Optional chain to call instead of the notebook's ``qa``. It is
            called with ``{"question": ..., "chat_history": ...}`` and must
            return a dict containing ``"answer"`` and ``"source_documents"``.

    Side effects:
        Appends ``(query, answer)`` to the module-level ``chat_history``, which
        is passed to the chain on the next call, and prints the question,
        answer, and sources.
    """
    if chain is None:
        chain = qa
    result = chain({"question": query, "chat_history": chat_history})
    answer = result["answer"]
    # Sort so the printed order is stable across runs (set order is not).
    source_names = sorted(str(s) for s in get_sources(result["source_documents"]))
    sources = ", ".join(source_names) or "(none)"

    chat_history.append((query, answer))
    print(f"Question: {query}")
    print(f"Answer:   {answer}")
    print(f"Sources:  {sources}")

In [None]:
# Example question; the answer and its sources are printed by `ask`.
ask("What is multi-head attention?")