In [1]:
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.document_loaders import DirectoryLoader, PythonLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain.chains.combine_documents import create_stuff_documents_chain

load_dotenv()

True

In [2]:
# Load every Python file under the lightfm tests directory, one document per file.
loader = DirectoryLoader(
    '../../data/raw/openja/lightfm/tests',
    glob="**/*.py",
    show_progress=True,
    loader_cls=PythonLoader
)
source_docs = loader.load()

# Cut the files into 500-character chunks with no overlap so each chunk can be
# embedded on its own.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(source_docs)

# Embed the chunks and index them for similarity search.
# NOTE(review): no collection name or persist directory is given; confirm that
# re-running this cell in the same kernel does not add the chunks a second time.
vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings())

# The number of results must be passed through `search_kwargs`; a bare `k=4`
# keyword is not a retriever setting.
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

# Only these four retrieved chunks reach the model below, so its answer can
# only cover test functions that appear in them, not the whole directory.
docs = retriever.invoke("How many test functions are there?")

# define prompt and chat
prompt = ChatPromptTemplate.from_messages([
    ("system", "Analyze the test functions from the codes below:\n\n{context}"),
    MessagesPlaceholder(variable_name="messages")
])
chat = ChatOpenAI(model='gpt-4')

# combine prompt, chat and doc: the retrieved chunks fill the {context} slot
docs_chain = create_stuff_documents_chain(chat, prompt)

# Stream the answer and print each piece as it arrives.
for chunk in docs_chain.stream({
    "context": docs,
    "messages": [
        HumanMessage(content="How many test functions are there? Can you list them all?")
    ],
}):
    print(chunk, end="", flush=True)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 5318.50it/s]


There are three test functions in the provided code. They are:

1. test_basic_fetching_stackexchange
2. test_bpr_precision
3. test_bpr_precision_multithreaded

### The same workflow wrapped into functions

In [3]:
from dotenv import load_dotenv

from langchain_community.document_loaders import PythonLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langchain.chains.combine_documents import create_stuff_documents_chain

from langchain.memory import ChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage

load_dotenv()

True

In [4]:
def load_test_file(path):
    """Load a single Python source file and split it into chunks.

    Args:
        path: Location of the Python file, passed straight to ``PythonLoader``.

    Returns:
        The result of splitting the loaded document(s) with a
        ``RecursiveCharacterTextSplitter`` configured with ``chunk_size=1000``
        and ``chunk_overlap=0``.
    """
    source_docs = PythonLoader(path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    return splitter.split_documents(source_docs)

def get_ai_response(message, py_splits, history=None, chain=None):
    """Run one conversational turn about the supplied code chunks.

    Args:
        message: The user's question for this turn. It is appended to
            ``history`` before the chain is invoked.
        py_splits: Document chunks handed to the chain as ``context``.
        history: Conversation log to extend. When omitted, a new
            ``ChatMessageHistory`` is created. A supplied object is
            mutated in place.
        chain: Chain to invoke. When omitted, a stuff-documents chain around
            ``ChatOpenAI(model='gpt-4')`` is built for this call.

    Returns:
        A ``(resp, history, chain)`` tuple: the chain's reply, the history
        now holding both the question and the reply, and the chain that was
        used, so the caller can pass the last two into the next turn.
    """
    if chain is None:
        # Build the default chain: a system instruction carrying the code
        # chunks, followed by the running conversation.
        chat_template = ChatPromptTemplate.from_messages([
            (
                "system",
                "You are a coder analyzer. Please understand the code and answer the question as accurate as possible. "
                "Analyze the test functions from the codes below:\n\n{context}",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ])
        llm = ChatOpenAI(model='gpt-4')
        chain = create_stuff_documents_chain(llm, chat_template)

    history = ChatMessageHistory() if history is None else history

    # Record the question first so the chain sees it as the latest message.
    history.add_user_message(message)
    payload = {"context": py_splits, "messages": history.messages}
    resp = chain.invoke(payload)

    # Store the reply so the next turn has the full exchange.
    history.add_ai_message(resp)
    return resp, history, chain

In [5]:
# Chunk the evaluation test module once; the follow-up cells reuse `py_splits`.
py_splits = load_test_file('../../data/raw/openja/lightfm/tests/test_evaluation.py')

# First turn: no history or chain is supplied, so both are created by the call
# and kept for the follow-up questions.
resp, history, chain = get_ai_response(
    "How many functions are defined in the code? list them all",
    py_splits,
)

print(resp)

There are 12 functions defined in the code. They are as follows:

1. _generate_data
2. _precision_at_k
3. _recall_at_k
4. _auc
5. test_precision_at_k
6. test_precision_at_k_with_ties
7. test_recall_at_k
8. test_auc_score
9. test_intersections_check
10. model.predict
11. evaluation.precision_at_k
12. evaluation.recall_at_k


In [6]:
# Follow-up turn: hand the existing history and chain back in so the model
# sees the earlier exchange and no new chain is built.
# The returned chain is the one passed in, so it is rebound to `chain` rather
# than `_`: in IPython `_` is the last-output shortcut, and assigning to it
# shadows that shortcut.
resp, history, chain = get_ai_response(
    message="What is each of the functions doing?",
    py_splits=py_splits,
    history=history,
    chain=chain
)

print(resp)

1. `_generate_data`: This function generates a dataset where every user has interactions in both the training and the test set. It takes number of users, number of items, density, and test fraction as the parameters.

2. `_precision_at_k`: This function calculates precision at 'k' using a given model, ground_truth, 'k' value, and optionally train data, user features and item features. Precision at 'k' is the proportion of recommended items in the top-k set that are relevant.

3. `_recall_at_k`: This function calculates recall at 'k' using a given model, ground_truth, 'k' value, and optionally train data, user features and item features. Recall at 'k' is the proportion of relevant items found in the top-k recommendations.

4. `_auc`: This function calculates Area Under the ROC Curve (AUC) using a given model, ground_truth, and optionally train data, user features and item features.

5. `test_precision_at_k`: This function tests the 'precision_at_k' function by generating data, fitting a

In [7]:
# Third turn on the same conversation: reuse the accumulated history and the
# existing chain.
# The returned chain is the one passed in, so it is rebound to `chain` rather
# than `_`: in IPython `_` is the last-output shortcut, and assigning to it
# shadows that shortcut.
resp, history, chain = get_ai_response(
    message="Which of them are related to ML pipeline test cases?",
    py_splits=py_splits,
    history=history,
    chain=chain
)

print(resp)

The following functions are related to Machine Learning (ML) pipeline test cases:

1. `test_precision_at_k`: This function tests the precision at 'k' function, which is a common evaluation metric in recommendation systems.

2. `test_precision_at_k_with_ties`: This function tests the precision at 'k' function in a special scenario where all predictions are zero, ensuring the metric handles this edge case correctly.

3. `test_recall_at_k`: This function tests the recall at 'k' function, another common evaluation metric in recommendation systems.

4. `test_auc_score`: This function tests the Area Under the ROC Curve (AUC) function, a common evaluation metric for binary classification problems.

5. `test_intersections_check`: This function tests if the evaluation functions correctly handle situations where the training and test sets have common interactions. This is important for ensuring the evaluation metrics are being calculated correctly and the model isn't simply memorizing the traini