In [1]:
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.document_loaders import DirectoryLoader, PythonLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain.chains.combine_documents import create_stuff_documents_chain

# Read environment variables from a local .env file before any client is built.
# NOTE(review): presumably this is where OPENAI_API_KEY comes from for
# ChatOpenAI / OpenAIEmbeddings — confirm against the project's .env.
load_dotenv()

True

In [2]:
# Retrieval-augmented Q&A over a directory of Python test files:
# load -> split -> embed/index -> retrieve -> stuff into prompt -> stream answer.

# --- configuration (everything a reader might want to tune) ---
TESTS_DIR = '../../data/raw/openja/lightfm/tests'
CHUNK_SIZE = 500
CHUNK_OVERLAP = 0
TOP_K = 4
CHAT_MODEL = 'gpt-4'
RETRIEVAL_QUERY = "How many test functions are there?"

# --- load every *.py file under TESTS_DIR (recursively) ---
loader = DirectoryLoader(
    TESTS_DIR,
    glob="**/*.py",
    show_progress=True,
    loader_cls=PythonLoader,
)
source_docs = loader.load()
if not source_docs:
    # Fail early with a clear message instead of indexing an empty corpus.
    raise FileNotFoundError(f"No Python files matched '**/*.py' under {TESTS_DIR!r}")

# --- split into chunks and index them in a Chroma vector store ---
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)
all_splits = text_splitter.split_documents(source_docs)

vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings())

# The number of results is a search setting and has to go through
# `search_kwargs`; a bare `k=` keyword on as_retriever() is not a search setting.
retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})

# `docs` holds the retrieved chunks (not the raw files) that feed the prompt.
docs = retriever.invoke(RETRIEVAL_QUERY)

# --- define prompt and chat model ---
prompt = ChatPromptTemplate.from_messages([
    ("system", "Analyze the test functions from the codes below:\n\n{context}"),
    MessagesPlaceholder(variable_name="messages"),
])
chat = ChatOpenAI(model=CHAT_MODEL)

# --- combine prompt, chat model and retrieved documents ---
docs_chain = create_stuff_documents_chain(chat, prompt)

# Only the TOP_K retrieved chunks are placed in {context}, so the answer
# describes those chunks only, not necessarily every file under TESTS_DIR.
question = HumanMessage(content="How many test functions are there? Can you list them all?")
for chunk in docs_chain.stream({"context": docs, "messages": [question]}):
    print(chunk, end="", flush=True)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 5318.50it/s]


There are three test functions in the provided code. They are:

1. test_basic_fetching_stackexchange
2. test_bpr_precision
3. test_bpr_precision_multithreaded