<center>
    <p style="text-align:center">
        <img alt="phoenix logo" src="https://storage.googleapis.com/arize-assets/phoenix/assets/phoenix-logo-light.svg" width="200"/>
        <br>
        <a href="https://docs.arize.com/phoenix/">Docs</a>
        |
        <a href="https://github.com/Arize-ai/phoenix">GitHub</a>
        |
        <a href="https://join.slack.com/t/arize-ai/shared_invite/zt-1px8dcmlf-fmThhDFD_V_48oU7ALan4Q">Community</a>
    </p>
</center>
<h1 align="center">Relevance Classification Evals</h1>

Arize provides tooling to evaluate LLM applications, including tools to determine whether documents retrieved by retrieval-augmented generation (RAG) applications are relevant or irrelevant. These relevance labels are then used to measure the quality of each retrieval with ranking metrics such as precision@k. To determine whether each retrieved document is relevant to its query, our approach is straightforward: ask an LLM.
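As a minimal sketch of how per-document relevance labels feed a ranking metric, precision@k is the share of the top k retrieved documents that are judged relevant. The `precision_at_k` helper below is hypothetical and not part of Phoenix. It assumes labels are listed in rank order and divides by k, so a query with fewer than k retrieved documents counts the missing slots as irrelevant.

```python
from typing import Sequence


def precision_at_k(relevance: Sequence[bool], k: int) -> float:
    """Fraction of the top-k retrieved documents labeled relevant.

    `relevance` holds one boolean per retrieved document in rank order
    (index 0 is the highest-ranked document).
    """
    if k <= 0:
        raise ValueError("k must be positive")
    top_k = relevance[:k]
    # Divide by k (not len(top_k)): missing documents count as misses.
    return sum(top_k) / k


# Hypothetical labels for one query's five retrieved documents.
labels = [True, False, True, True, False]
print(precision_at_k(labels, 3))  # 2 of the top 3 are relevant
```

Dividing by `len(top_k)` instead is another common convention; it only differs when fewer than k documents are retrieved.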

The purpose of this notebook is:

- to evaluate the performance of an LLM-assisted approach to relevance classification against information retrieval datasets with ground-truth relevance labels,
- to provide an experimental framework for users to iterate and improve on the default classification template.

## Install Dependencies and Import Libraries

In [None]:
!pip install -qq "arize-phoenix[experimental]==0.0.33rc6" ipython matplotlib "openai<1" pycm scikit-learn tenacity tqdm

In [None]:
import json
import os
from getpass import getpass
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import openai
from phoenix.experimental.evals import (
    PromptTemplate,
    download_benchmark_dataset,
)
from pycm import ConfusionMatrix
from sklearn.metrics import classification_report
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
from tqdm import tqdm

## Download Benchmark Dataset

We'll evaluate the evaluation system, which consists of an LLM with its settings plus an evaluation prompt template, against benchmark datasets of queries and retrieved documents with ground-truth relevance labels. Currently supported datasets include:

- "wiki_qa-train"
- "ms_marco-v1.1-train"

In [None]:
df = download_benchmark_dataset(
    task="binary-relevance-classification", dataset_name="wiki_qa-train"
).sample(n=500, random_state=42)
df.head()

In [None]:
if not (openai_api_key := os.getenv("OPENAI_API_KEY")):
    openai_api_key = getpass("🔑 Enter your OpenAI API key: ")
openai.api_key = openai_api_key
os.environ["OPENAI_API_KEY"] = openai_api_key

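With the API key set, the notebook calls OpenAI's chat completions API directly; the LLM is selected with `model_name` in the classification cell below.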

## Run Relevance Classifications

Run relevance classifications against a subset of the data.

In [None]:
# Rename columns to match the prompt template variables {query} and {reference}.
df = df.rename(
    columns={
        "query_text": "query",
        "document_text": "reference",
    },
)
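The renaming matters because `openai_functions_classify` checks that every template variable appears as a key in each record before formatting the prompt. The sketch below illustrates that contract using plain `str.format` as a hypothetical stand-in for Phoenix's `PromptTemplate`. Phoenix's own placeholder parsing is not shown in this notebook and may differ. `template_variables` and `fill_template` are illustrative helpers, and the record is made up.

```python
# Hypothetical stand-in for PromptTemplate: str.format fills `{query}` and
# `{reference}` placeholders from a record's keys; extra keys are ignored.
import string
from typing import Any, Dict, List

TEMPLATE = "[Question]: {query}\n[Reference text]: {reference}"


def template_variables(template: str) -> List[str]:
    """Return the placeholder names in a str.format-style template."""
    return [field for _, field, _, _ in string.Formatter().parse(template) if field]


def fill_template(template: str, record: Dict[str, Any]) -> str:
    """Format the template, failing loudly if a placeholder has no matching key."""
    missing = [name for name in template_variables(template) if name not in record]
    if missing:
        raise KeyError(f"record is missing template variables: {missing}")
    return template.format(**record)


record = {"query": "What is a glacier cave?", "reference": "A cave formed within glacier ice.", "relevant": True}
print(fill_template(TEMPLATE, record))
```

Before the rename, a record would carry `query_text` and `document_text` instead, and the missing-variable check would fail.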

In [None]:
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def openai_functions_classify(
    record: Dict[str, Any],
    prompt_template: PromptTemplate,
    classes: List[str],
    model_name: str,
    function_name: str,
    function_description: str,
    argument_name: str,
    argument_description: str,
    *,
    system_message: Optional[str] = None,
    require_explanation: bool = False,
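) -> Tuple[Optional[str], Optional[str]]:
    """Classify `record` by forcing the model to call a function whose argument is one of `classes`.

    Returns a `(label, explanation)` tuple; both are None if the response can't be parsed.
    `explanation` is typically None unless `require_explanation=True`. Exceptions raised by
    the API call (e.g., rate limits) trigger the `@retry` decorator: randomized exponential
    backoff for up to 6 attempts.
    """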
    if not all(variable_name in record for variable_name in prompt_template.variables):
        raise ValueError("All prompt template variables must be present as keys in record.")

    user_message_content = prompt_template.format(
        {variable_name: record[variable_name] for variable_name in prompt_template.variables}
    )
    messages = [{"role": "user", "content": user_message_content}]
    if system_message:
        messages.insert(0, {"role": "system", "content": system_message})
    argument_data = {
        argument_name: {
            "type": "string",
            "description": argument_description,
            "enum": classes,
        },
    }
    if require_explanation:
        argument_data["explanation"] = {
            "type": "string",
            "description": "A brief explanation of your reasoning for your answer.",
        }
    functions = [
        {
            "name": function_name,
            "description": function_description,
            "parameters": {
                "type": "object",
                "properties": argument_data,
                "required": [argument_name],
            },
        }
    ]
    response = openai.ChatCompletion.create(
        model=model_name,
        messages=messages,
        functions=functions,
        function_call={"name": function_name},
    )
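    try:
        response_message = response["choices"][0]["message"]
        function_call = response_message["function_call"]
        if function_call["name"] != function_name:
            return None, None
        # The model returns the function arguments as a JSON-encoded string.
        function_arguments = json.loads(function_call["arguments"])
        return function_arguments[argument_name], function_arguments.get("explanation")
    except (KeyError, IndexError, TypeError, json.JSONDecodeError):
        # Treat malformed or unexpected responses as missing classifications.
        return None, None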
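
To see what the function above sends and parses without calling the API, here is an offline sketch. `build_function_schema` and `parse_function_call` are hypothetical helpers, and the mock message is fabricated. Its shape follows the legacy (pre-1.0) OpenAI function-calling format this notebook uses: the assistant message carries a `function_call` with a `name` and a JSON-encoded `arguments` string. For brevity the schema always includes `explanation`, unlike the notebook, which adds it only when `require_explanation=True`.

```python
import json
from typing import Any, Dict, List, Optional, Tuple


def build_function_schema(name: str, argument_name: str, classes: List[str]) -> Dict[str, Any]:
    """Function definition whose one required string argument is restricted to `classes`."""
    return {
        "name": name,
        "description": "Record whether a reference text is relevant to a question.",
        "parameters": {
            "type": "object",
            "properties": {
                argument_name: {"type": "string", "enum": classes},
                "explanation": {"type": "string"},
            },
            "required": [argument_name],
        },
    }


def parse_function_call(
    message: Dict[str, Any], name: str, argument_name: str
) -> Tuple[Optional[str], Optional[str]]:
    """Extract (label, explanation) from an assistant message, or (None, None) if malformed."""
    call = message.get("function_call") or {}
    if call.get("name") != name:
        return None, None
    try:
        # `arguments` arrives as a JSON-encoded string, not a dict.
        arguments = json.loads(call.get("arguments") or "")
    except json.JSONDecodeError:
        return None, None
    if not isinstance(arguments, dict):
        return None, None
    return arguments.get(argument_name), arguments.get("explanation")


# Hypothetical assistant message shaped like a forced function call.
mock_message = {
    "role": "assistant",
    "content": None,
    "function_call": {
        "name": "relevance",
        "arguments": json.dumps({"relevant": "irrelevant", "explanation": "Different topic."}),
    },
}
schema = build_function_schema("relevance", "relevant", ["relevant", "irrelevant"])
print(schema["parameters"]["required"])  # ['relevant']
print(parse_function_call(mock_message, "relevance", "relevant"))  # ('irrelevant', 'Different topic.')
```

The `enum` constrains the label the model is asked to produce. Forcing `function_call={"name": ...}` is what makes the output reliably machine-parseable instead of free text.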

In [None]:
prompt_template_string = """You are comparing a reference text to a question and trying to determine if the reference text contains information relevant to answering the question. Here is the data:
    [BEGIN DATA]
    ************
    [Question]: {query}
    ************
    [Reference text]: {reference}
    [END DATA]

Compare the question above to the reference text. You must determine whether the reference text contains information that can answer the question. Please focus on whether the very specific question can be answered by the information in the reference text."""
prompt_template = PromptTemplate(prompt_template_string)

model_name = "gpt-4"

relevance_classifications = []
explanations = []
for record in tqdm(df.to_dict(orient="records")):
    relevance_classification, explanation = openai_functions_classify(
        record=record,
        prompt_template=prompt_template,
        classes=["relevant", "irrelevant"],
        model_name=model_name,
        function_name="relevance",
        function_description="A function to record whether a reference text is relevant to a question.",
        argument_name="relevant",
        argument_description="A string indicating whether the reference text is relevant to the question.",
        require_explanation=True,
    )
    relevance_classifications.append(relevance_classification)
    explanations.append(explanation)

## Evaluate Classifications

Evaluate the predictions against human-labeled ground-truth relevance labels.
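For intuition about the per-class numbers `classification_report` prints, here is a hypothetical from-scratch version of precision, recall, and F1 for a single class. It is a sketch rather than a replacement for scikit-learn. Like scikit-learn's default, it reports 0.0 when a denominator is zero. Unlike the evaluation cell, it keeps unparsable predictions (`None`) and counts them as misses. The toy labels are made up.

```python
from typing import Dict, List, Optional


def per_class_metrics(
    true_labels: List[str], predicted_labels: List[Optional[str]], label: str
) -> Dict[str, float]:
    """Precision, recall, and F1 for one class, treating `label` as the positive class.

    A None prediction never equals `label`, so it counts as a false negative
    whenever the true label is `label`.
    """
    pairs = list(zip(true_labels, predicted_labels))
    tp = sum(1 for t, p in pairs if t == label and p == label)
    fp = sum(1 for t, p in pairs if t != label and p == label)
    fn = sum(1 for t, p in pairs if t == label and p != label)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# Hypothetical toy labels (not from the benchmark).
y_true = ["relevant", "relevant", "irrelevant", "irrelevant", "relevant"]
y_pred = ["relevant", "irrelevant", "irrelevant", "relevant", None]
print(per_class_metrics(y_true, y_pred, "relevant"))
print(per_class_metrics(y_true, y_pred, "irrelevant"))
```

Here the "relevant" class has one true positive, one false positive, and two false negatives (one of them the unparsable response), so its precision is 0.5 and its recall is 1/3.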

In [None]:
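true_labels = df["relevant"].map({True: "relevant", False: "irrelevant"}).tolist()
classes = ["relevant", "irrelevant"]

# openai_functions_classify returns None for unparsable responses. Mixing None with
# string labels can make scikit-learn and pycm raise, so score only parsed predictions.
scored_pairs = [
    (true_label, predicted_label)
    for true_label, predicted_label in zip(true_labels, relevance_classifications)
    if predicted_label is not None
]
print(f"Unparsable responses excluded: {len(true_labels) - len(scored_pairs)}")
true_labels = [true_label for true_label, _ in scored_pairs]
predicted_labels = [predicted_label for _, predicted_label in scored_pairs]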

print(classification_report(true_labels, predicted_labels, labels=classes))
confusion_matrix = ConfusionMatrix(
    actual_vector=true_labels, predict_vector=predicted_labels, classes=classes
)
confusion_matrix.plot(
    cmap=plt.colormaps["Blues"],
    number_label=True,
    normalized=True,
)
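
The `normalized=True` plot shows proportions rather than counts. Our reading of pycm's documentation is that it normalizes each row, meaning each actual class, by its total. This is an assumption, so check it against your pycm version. Under that reading, the diagonal of the normalized matrix equals each class's recall. The hypothetical helper below computes a row-normalized matrix by hand on made-up labels. It skips pairs whose labels fall outside `classes`.

```python
from typing import Dict, List


def normalized_confusion_matrix(
    true_labels: List[str], predicted_labels: List[str], classes: List[str]
) -> Dict[str, Dict[str, float]]:
    """Row-normalized confusion matrix: matrix[actual][predicted] = share of that actual class."""
    counts = {actual: {predicted: 0 for predicted in classes} for actual in classes}
    for actual, predicted in zip(true_labels, predicted_labels):
        # Pairs with labels outside `classes` are skipped.
        if actual in counts and predicted in counts[actual]:
            counts[actual][predicted] += 1
    matrix = {}
    for actual, row in counts.items():
        total = sum(row.values())
        # An actual class with no examples gets an all-zero row.
        matrix[actual] = {p: (c / total if total else 0.0) for p, c in row.items()}
    return matrix


# Hypothetical toy labels (not from the benchmark).
y_true = ["relevant", "relevant", "irrelevant", "irrelevant"]
y_pred = ["relevant", "irrelevant", "irrelevant", "irrelevant"]
print(normalized_confusion_matrix(y_true, y_pred, ["relevant", "irrelevant"]))
```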