In [None]:
!pip install -qq "arize-phoenix[eval,llama-index]" "openai>=1" pyvis datasets pycm requests

In [None]:
import os
from getpass import getpass

# Prefer a key already present in the environment; prompt (without echo) only if missing.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    openai_api_key = getpass("🔑 Enter your OpenAI API key: ")
os.environ["OPENAI_API_KEY"] = openai_api_key

In [None]:
# Standard library imports
import re

# Third-party library imports
from typing import Any, Dict, List, Optional, Set, Tuple, cast

import nest_asyncio
import pandas as pd

# Local module imports
from llama_index.core import SQLDatabase
from llama_index.core.agent import (
    AgentChatResponse,
    AgentRunner,
    QueryPipelineAgentWorker,
    ReActChatFormatter,
    Task,
)
from llama_index.core.agent.react.output_parser import ReActOutputParser
from llama_index.core.agent.react.types import (
    ObservationReasoningStep,
    ResponseReasoningStep,
)
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.core.query_pipeline import (
    AgentFnComponent,
    AgentInputComponent,
    CustomAgentComponent,
    QueryComponent,
    QueryPipeline,
    ToolRunnerComponent,
)
from llama_index.core.tools import BaseTool, QueryEngineTool
from llama_index.llms.openai import OpenAI
from sqlalchemy import create_engine
from tqdm import tqdm

import phoenix as px

# Notebook-wide settings:
# - nest_asyncio lets library code drive asyncio from inside Jupyter's already-running event loop;
# - widen pandas cell display to 1000 characters so SQL and explanations are readable.
nest_asyncio.apply()
pd.set_option("display.max_colwidth", 1000)

In [None]:
import os
import tempfile
import zipfile
from io import BytesIO

import matplotlib.pyplot as plt
import pandas as pd
import requests

In [None]:
from pycm import ConfusionMatrix
from sklearn.metrics import classification_report

from phoenix.evals import (
    SQL_GEN_EVAL_PROMPT_RAILS_MAP,
    SQL_GEN_EVAL_PROMPT_TEMPLATE,
    OpenAIModel,
    llm_classify,
)

# Download the Chinook sample database into a fresh temporary directory and
# expose it to LlamaIndex through a SQLAlchemy engine.
temp_dir = tempfile.mkdtemp()
url = "https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip"
# A timeout keeps a stalled download from hanging the notebook forever, and
# raise_for_status() reports HTTP errors directly instead of letting an HTML
# error page surface later as a confusing zipfile.BadZipFile.
download = requests.get(url, timeout=60)
download.raise_for_status()
with zipfile.ZipFile(BytesIO(download.content), "r") as f:
    f.extractall(temp_dir)
db_path = os.path.join(temp_dir, "chinook.db")
# SQLite silently creates an empty database when the file is missing, which
# would make every later query fail in confusing ways; fail loudly instead.
if not os.path.exists(db_path):
    raise FileNotFoundError(f"chinook.db not found in archive extracted to {temp_dir}")
engine = create_engine(f"sqlite:///{db_path}")
sql_database = SQLDatabase(engine)

In [None]:
# Start the local Phoenix app and display it inline.
session = px.launch_app()
session.view()

In [None]:
from openinference.instrumentation.llama_index import LlamaIndexInstrumentor

from phoenix.otel import register

# Export OpenTelemetry traces to the local Phoenix collector and instrument
# LlamaIndex so its operations are traced through that provider.
PHOENIX_TRACES_ENDPOINT = "http://127.0.0.1:6006/v1/traces"
tracer_provider = register(endpoint=PHOENIX_TRACES_ENDPOINT)
LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider, skip_dep_check=True)

In [None]:
# Natural-language-to-SQL engine limited to three Chinook tables, wrapped as a
# tool the ReAct agent can call by the name "sql_tool".
SQL_TABLES = ["albums", "tracks", "artists"]
sql_query_engine = NLSQLTableQueryEngine(sql_database=sql_database, tables=SQL_TABLES, verbose=True)
sql_tool = QueryEngineTool.from_defaults(
    query_engine=sql_query_engine,
    name="sql_tool",
    description="Useful for translating a natural language query into a SQL query",
)

In [None]:
# Natural-language evaluation questions for the SQL agent (50 in total).
# The hand-labelled `ground_truth` list assigned later has one entry per
# question in this exact order — keep the two in sync when editing either.
# NOTE(review): several questions ask about attributes such as release dates,
# ratings, play counts, genres or "instrumental" tracks, which may not be
# answerable from the tables exposed to the engine (albums, tracks, artists);
# presumably this deliberately probes how SQL generation handles such
# questions — confirm.
questions = [
    "What is the name of the artist with ID 5?",
    "List all tracks in the album with ID 3.",
    "How many tracks does the artist named 'Aerosmith' have?",
    "Find the oldest song in the database.",
    "What is the duration of the track with ID 10?",
    "List the names of all albums released in 2020.",
    "How many artists are in the database?",
    "Which artist has the most tracks in the database?",
    "List all tracks in the 'Pop' genre.",
    "What is the average length of tracks in the database?",
    "Find the most recent track.",
    "List the top 5 longest tracks in the database.",
    "Which album has the highest number of tracks?",
    "List all artists who have released more than 3 albums.",
    "What is the shortest track in the database?",
    "Find all albums released by 'The Beatles'.",
    "How many tracks are in the 'Rock' genre?",
    "List the names of all tracks released before 2000.",
    "What is the total duration of the album with ID 7?",
    "Find the artist who released the album 'Thriller'.",
    "List the names of all albums by 'Pink Floyd'.",
    "How many albums have been released between 1990 and 2000?",
    "What genres are covered by the artist 'David Bowie'?",
    "List the top 10 most played tracks.",
    "Which artist has the longest total track duration in the database?",
    "Find all tracks with a duration longer than 5 minutes.",
    "How many tracks does each album contain on average?",
    "List all albums sorted by release date.",
    "Which artist's albums have the highest average ratings?",
    "Find the total duration of all tracks by 'Michael Jackson'.",
    "How many tracks in the database are instrumental?",
    "List the names of all tracks by artists with the name starting with 'J'.",
    "What is the most common genre in the database?",
    "Find the average album length in minutes.",
    "How many artists have only one album in the database?",
    "List all tracks from the album with the most number of tracks.",
    "Which artist has released the most albums?",
    "Find the total number of tracks produced by 'Eminem'.",
    "How many albums in the database have no tracks?",
    "List the name and duration of the longest track in each album.",
    "What is the average number of tracks per album?",
    "Find all albums that have more than 10 tracks.",
    "How many tracks in the database are longer than the average track length?",
    "List the albums released by the artist with the most albums.",
    "Which year has the highest number of album releases?",
    "Find the total playtime of all tracks in the 'Jazz' genre.",
    "How many artists have names longer than 10 characters?",
    "List all song genres found in the database.",
    "What is the average track length of the tracks?",
    "How many albums were released on average by each artist?",
]

print(questions)

In [None]:
## Agent Input Component
## This is the component that produces agent inputs to the rest of the components
## Can also put initialization logic here.


def agent_input_fn(task: Task, state: Dict[str, Any]) -> Dict[str, Any]:
    """Agent input function.

    First component of the agent pipeline. On the first run for a task it
    creates `state["current_reasoning"]` and seeds it with the user's input as
    an observation; that reasoning list is how the question reaches the ReAct
    prompt.

    Args:
        task: The agent task; `task.input` is the user's message.
        state: Per-task state dict shared by the pipeline components.

    Returns:
        A Dictionary of output keys and values. If you are specifying
        src_key when defining links between this component and other
        components, make sure the src_key matches the specified output_key.

    """
    if "current_reasoning" not in state:
        # Seed the reasoning trace only once per task. `state` is reused across
        # the steps of a task (hence this key check), so appending on every run
        # would re-insert the question after each tool observation and clutter
        # the ReAct history the LLM sees.
        state["current_reasoning"] = [ObservationReasoningStep(observation=task.input)]
    return {"input": task.input}


agent_input_component = AgentInputComponent(fn=agent_input_fn)

In [None]:
## define prompt function


def react_prompt_fn(
    task: Task, state: Dict[str, Any], input: str, tools: List[BaseTool]
) -> List[ChatMessage]:
    """Build the ReAct chat messages for the next LLM call.

    The prompt combines the tool descriptions, the prior conversation (the
    task's memory followed by the agent-level `state["memory"]`), and the
    reasoning steps recorded so far in `state["current_reasoning"]`.

    `input` is accepted because the upstream agent-input component emits it,
    but it is not used here: the user's question reaches the prompt through
    `current_reasoning`, where the agent-input component records it.
    """
    history = task.memory.get() + state["memory"].get_all()
    formatter = ReActChatFormatter()
    return formatter.format(
        tools,
        chat_history=history,
        current_reasoning=state["current_reasoning"],
    )


# `tools` is bound up front, so the pipeline only supplies the remaining arguments.
react_prompt_component = AgentFnComponent(
    fn=react_prompt_fn,
    partial_dict={"tools": [sql_tool]},
)

In [None]:
## Agent Output Component
## Process reasoning step/tool outputs, and return agent response


def finalize_fn(
    task: Task,
    state: Dict[str, Any],
    reasoning_step: Any,
    is_done: bool = False,
    tool_output: Optional[Any] = None,
) -> Tuple[AgentChatResponse, bool]:
    """Record the latest reasoning step and turn it into the agent's output.

    Appends `reasoning_step` to `state["current_reasoning"]`. When
    `tool_output` is given, it also appends an observation holding
    `str(tool_output)`. The response text comes from the last recorded step:
    its `.response` if it is a `ResponseReasoningStep`, otherwise its
    `get_content()`. When `is_done` is true, the user's input and that
    response are stored in `state["memory"]`.

    Args:
        task: The running agent task; `task.input` is the user's message.
        state: Per-task state holding "current_reasoning" and "memory".
        reasoning_step: Step parsed from the LLM output.
        is_done: Whether this step finishes the task.
        tool_output: Result of running the requested tool, if any.

    Returns:
        `(AgentChatResponse, is_done)` — the return shape expected from the
        last component of a pipeline passed to `QueryPipelineAgentWorker`.
    """
    steps = state["current_reasoning"]
    steps.append(reasoning_step)
    if tool_output is not None:
        steps.append(ObservationReasoningStep(observation=str(tool_output)))

    last_step = steps[-1]
    if isinstance(last_step, ResponseReasoningStep):
        response_str = last_step.response
    else:
        response_str = last_step.get_content()

    if is_done:
        # "memory" is a reserved key in `state`; persist this finished turn there.
        memory = state["memory"]
        memory.put(ChatMessage(content=task.input, role=MessageRole.USER))
        memory.put(ChatMessage(content=response_str, role=MessageRole.ASSISTANT))

    return AgentChatResponse(response=response_str), is_done


class OutputAgentComponent(CustomAgentComponent):
    """Output agent component.

    Parses the LLM reply as a ReAct reasoning step. A final step is finalized
    as done. Otherwise the tool named by the step is run with the parsed input,
    and the step is finalized as not done with that tool output attached.
    """

    # Executes the tool named in an action step.
    tool_runner_component: ToolRunnerComponent
    # Parses raw ReAct-formatted LLM text into a reasoning step.
    output_parser: ReActOutputParser

    def __init__(self, tools, **kwargs):
        # Build both collaborators here so callers only have to pass the tools.
        super().__init__(
            tool_runner_component=ToolRunnerComponent(tools),
            output_parser=ReActOutputParser(),
            **kwargs,
        )

    def _run_component(self, **kwargs: Any) -> Any:
        """Process one LLM reply; returns {"output": (AgentChatResponse, is_done)}."""
        llm_reply = kwargs["chat_response"]
        task, state = kwargs["task"], kwargs["state"]
        step = self.output_parser.parse(llm_reply.message.content)

        if step.is_done:
            finalized = finalize_fn(task, state, step, is_done=True)
        else:
            tool_result = self.tool_runner_component.run_component(
                tool_name=step.action,
                tool_input=step.action_input,
            )
            finalized = finalize_fn(task, state, step, is_done=False, tool_output=tool_result)
        return {"output": finalized}

    @property
    def _input_keys(self) -> Set[str]:
        # Required input: the LLM's chat response.
        return {"chat_response"}

    @property
    def _optional_input_keys(self) -> Set[str]:
        # Declared as optional inputs; `_run_component` does not read them.
        return {"is_done", "tool_output"}

    @property
    def _output_keys(self) -> Set[str]:
        return {"output"}

    @property
    def sub_query_components(self) -> List[QueryComponent]:
        # Expose the tool runner as a child component of this one.
        return [self.tool_runner_component]


# This component can only run the tools passed here.
react_output_component = OutputAgentComponent([sql_tool])

In [None]:
# Linear ReAct pipeline: agent input -> prompt -> LLM -> parse/run tool/finalize.
# The dict's insertion order is also the chain order.
pipeline_modules = {
    "agent_input": agent_input_component,
    "react_prompt": react_prompt_component,
    "llm": OpenAI(model="gpt-4o"),
    "react_output": react_output_component,
}
qp = QueryPipeline(modules=pipeline_modules, verbose=True)
qp.add_chain(list(pipeline_modules))

In [None]:
# Smoke test: run one full chat turn through an agent built on the pipeline.
agent_worker = QueryPipelineAgentWorker(qp)
agent = AgentRunner(agent_worker)
question_text = "Is Aerosmith in this database?"
response = agent.chat(question_text)
print(str(response))

In [None]:
import ast

# A Python string literal in either quote style, allowing backslash escapes.
# `str(step_output)` embeds reprs of the tool output, and repr() switches from
# '...' to "..." whenever the text itself contains a single quote — e.g. SQL
# like `WHERE Name = 'Aerosmith'`, or an answer containing an apostrophe. The
# previous single-quote-only patterns silently produced None for those rows.
PY_STRING_LITERAL = r"""('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*")"""
SQL_QUERY_PATTERN = re.compile(r"'sql_query': " + PY_STRING_LITERAL)
RESPONSE_PATTERN = re.compile(r"response=" + PY_STRING_LITERAL)


def extract_string_literal(pattern: "re.Pattern[str]", text: str) -> Optional[str]:
    """Return the first string literal captured by `pattern` in `text`, unescaped.

    Returns None when the pattern does not occur. Escape sequences are decoded
    with `ast.literal_eval` (which only evaluates literals). If decoding fails,
    the raw body between the quotes is returned instead.
    """
    match = pattern.search(text)
    if match is None:
        return None
    literal = match.group(1)
    try:
        return ast.literal_eval(literal)
    except (ValueError, SyntaxError):
        return literal[1:-1]


all_sql_queries = []
all_ans = []
for question in tqdm(questions):
    # A fresh worker/agent per question — presumably so no chat memory carries
    # over between questions.
    agent_worker = QueryPipelineAgentWorker(qp)
    agent = AgentRunner(agent_worker)
    task = agent.create_task(question)
    # Run just the first step by hand. When the agent calls sql_tool, the raw
    # step output includes the repr of the tool result (generated SQL and
    # answer), which is parsed below.
    step_output = agent.run_step(task.task_id)
    ans = str(step_output)

    sql_query = extract_string_literal(SQL_QUERY_PATTERN, ans)
    if sql_query is None:
        # No SQL found (e.g. the agent answered without calling sql_tool);
        # show the raw output for inspection.
        print(ans)
    response = extract_string_literal(RESPONSE_PATTERN, ans)

    print("SQL Query:", sql_query)
    print("Response:", response)
    all_ans.append(response)
    all_sql_queries.append(sql_query)

In [None]:
# One row per question: the agent's generated SQL and its answer.
df = pd.DataFrame(
    {
        "question": questions,
        "query_gen": all_sql_queries,
        "response": all_ans,
    }
)
df.head(3)

In [None]:
# Show the LLM-judge prompt used below. Its template variables are presumably
# filled from the same-named columns of `df` (question, query_gen, response) —
# confirm against the printed template.
print(SQL_GEN_EVAL_PROMPT_TEMPLATE)

In [None]:
# Normalise column names to those the evaluation template expects. With `df`
# as built above this is a no-op (the columns are already "question" and
# "query_gen"); it only matters for frames that use the older
# "query"/"sql_query" names. Rebind rather than use `inplace=True`, which
# pandas discourages.
df = df.rename(columns={"query": "question", "sql_query": "query_gen"})

In [None]:
# LLM-as-judge: label each generated query with one of the template's allowed
# labels ("rails"), and ask for an explanation per row.
# Temperature 0 minimises sampling variance in the judge.
model = OpenAIModel(temperature=0.0, model="gpt-4o")
rails = list(SQL_GEN_EVAL_PROMPT_RAILS_MAP.values())
relevance_classifications = llm_classify(
    dataframe=df,
    template=SQL_GEN_EVAL_PROMPT_TEMPLATE,
    model=model,
    rails=rails,
    provide_explanation=True,
)

In [None]:
# Pull the judge's labels and explanations out as plain Python lists.
labels, explanation = (
    relevance_classifications[column].tolist() for column in ("label", "explanation")
)

In [None]:
# Attach the judge output to the evaluation frame (lists are assigned by position).
for column, values in (("label", labels), ("explanation", explanation)):
    df[column] = values

In [None]:
# Peek at the first two labelled rows.
df.head(2)

In [None]:
# True where the judge labelled the generated SQL "correct".
boolean_classifications = (df["label"] == "correct").tolist()

In [None]:
# Store the judge's verdict as a boolean column.
df = df.assign(is_correct=boolean_classifications)

In [None]:
# Hand-assigned correctness labels for the generated SQL: one per question,
# in the same order as `questions` (10 per row below).
# NOTE(review): these labels were judged against one particular agent run.
# Re-running the agent can produce different SQL, so they may no longer match
# the current `query_gen` column.
ground_truth_labels = [
    True, True, False, False, True, False, True, True, False, True,  # 1-10
    False, True, True, True, True, False, False, False, True, False,  # 11-20
    False, False, False, True, True, True, True, True, False, False,  # 21-30
    False, False, True, False, True, False, True, False, True, True,  # 31-40
    True, True, True, True, True, False, True, False, True, False,  # 41-50
]
df["ground_truth"] = ground_truth_labels
true_labels = df["ground_truth"]

In [None]:
# Compare the LLM judge's verdicts against the hand labels.
report = classification_report(true_labels, boolean_classifications, labels=[True, False])
print(report)

confusion_matrix = ConfusionMatrix(
    actual_vector=true_labels.tolist(),
    predict_vector=boolean_classifications,
)
confusion_matrix.plot(
    cmap=plt.colormaps["Blues"],
    normalized=True,
    number_label=True,
)