- In this notebook, we'll build an LLM-powered router to route queries to submodules

- Steps:
    - Crafting an initial prompt to select a set of choices
    - Enforcing structured output (for text completion endpoints)
    - Try integrating with a native function calling endpoint.

1. Basic Router Prompt

In [1]:
from llama_index import PromptTemplate

choices = [
    "Useful for questions related to apples",
    "Useful for questions related to oranges",
]


def get_choice_str(choices):
    """Render ``choices`` as a 1-based numbered list.

    Each entry becomes ``"<n>. <choice>"`` and entries are separated by a
    blank line. An empty input yields an empty string.
    """
    numbered = []
    for position, description in enumerate(choices, start=1):
        numbered.append(f"{position}. {description}")
    return "\n\n".join(numbered)


choices_str = get_choice_str(choices)

In [2]:
# Base routing prompt. It exposes four placeholders that are filled in later:
#   {num_choices}  - how many numbered choices are listed
#   {context_list} - the numbered choice descriptions themselves
#   {max_outputs}  - upper bound on how many choices may be returned
#   {query_str}    - the user question being routed
router_prompt0 = PromptTemplate(
    "Some choices are given below. It is provided in a numbered "
    "list (1 to {num_choices}), "
    "where each item in the list corresponds to a summary.\n"
    "---------------------\n"
    "{context_list}"
    "\n---------------------\n"
    "Using only the choices above and not prior knowledge, return the top choices "
    "(no more than {max_outputs}, but only select what is needed) that "
    "are most relevant to the question: '{query_str}'\n"
)

In [3]:
from llama_index.llms import OpenAI

llm = OpenAI(model="gpt-3.5-turbo")


def get_formatted_prompt(query_str):
    """Fill the base router prompt for ``query_str``.

    Uses the notebook-level ``choices`` / ``choices_str`` and allows the
    model to return at most two choices.
    """
    prompt_vars = {
        "num_choices": len(choices),
        "max_outputs": 2,
        "context_list": choices_str,
        "query_str": query_str,
    }
    return router_prompt0.format(**prompt_vars)

In [5]:
query_str = "Can you tell me more about the amount of Vitamin C in apples"
fmt_prompt = get_formatted_prompt(query_str)

response = llm.complete(fmt_prompt)
print(str(response))

1. Useful for questions related to apples


In [6]:
query_str = "What are the health benefits of eating orange peels?"
fmt_prompt = get_formatted_prompt(query_str)

response = llm.complete(fmt_prompt)
print(str(response))

2. Useful for questions related to oranges


In [7]:
query_str = "Can you tell me more about the amount of Vitamin C in apples and oranges."
fmt_prompt = get_formatted_prompt(query_str)

response = llm.complete(fmt_prompt)
print(str(response))

1. Useful for questions related to apples
2. Useful for questions related to oranges


2. Enforce Structured Output

In [14]:
from dataclasses import fields
from pydantic import BaseModel
import json


class Answer(BaseModel):
    """Class for storing the answer and its metadata"""
    # Typed as a string here; the `Answer` model defined later in this
    # notebook (section 3) declares `choice` as an int instead.
    choice: str
    reason: str

print(json.dumps(Answer.model_json_schema(), indent=2))

{
  "description": "Class for storing the answer and its metadata",
  "properties": {
    "choice": {
      "title": "Choice",
      "type": "string"
    },
    "reason": {
      "title": "Reason",
      "type": "string"
    }
  },
  "required": [
    "choice",
    "reason"
  ],
  "title": "Answer",
  "type": "object"
}


In [15]:
from llama_index.types import BaseOutputParser

FORMAT_STR = """\
The output should be formatted as a JSON instance that conforms to 
the JSON schema below. 

Here is the output schema:
{
  "description": "Class for storing the answer and its metadata",
  "properties": {
    "choice": {
      "title": "Choice",
      "type": "string"
    },
    "reason": {
      "title": "Reason",
      "type": "string"
    }
  },
  "required": [
    "choice",
    "reason"
  ],
  "title": "Answer",
  "type": "object"
}
"""

In [16]:
# Need to escape {} in the format string for PromptTemplate
def _escape_curly_braces(input_string: str) -> str:
    # Replace '{' with '{{' and '}' with '}}'
    escaped = input_string.replace("{", "{{").replace("}", "}}")
    return escaped


# Define parsing function to extract JSON from LLM response by searching for square brackets
def _marshal_output_to_json(output: str) -> str:
    output = output.strip()
    left = output.find("[")
    right = output.find("]")
    output = output[left : right + 1]
    return output

In [17]:
from typing import List

class RouterOutputParser(BaseOutputParser):
    """Output parser that appends format instructions and parses the reply."""

    def parse(self, output: str) -> List[Answer]:
        """Parse the raw LLM text into a list of ``Answer`` objects.

        Args:
            output: Raw completion text expected to contain a JSON array of
                objects with ``choice`` and ``reason`` keys.

        Returns:
            One ``Answer`` per element of the extracted JSON array.

        Raises:
            json.JSONDecodeError: If no valid JSON array can be extracted.
            pydantic.ValidationError: If an element does not match ``Answer``.
        """
        # NOTE(review): FORMAT_STR describes a single JSON object, whereas the
        # extraction below looks for a JSON array - confirm the intended shape.
        json_output = _marshal_output_to_json(output)
        json_dicts = json.loads(json_output)
        # `Answer` is a pydantic BaseModel, which has no `from_dict`;
        # `model_validate` is the pydantic v2 constructor-from-dict and matches
        # the `model_json_schema` API already used in this notebook.
        answers = [Answer.model_validate(json_dict) for json_dict in json_dicts]
        return answers

    def format(self, prompt_template: str) -> str:
        """Append the JSON format instructions to ``prompt_template``.

        Braces in the instructions are doubled so they are not mistaken for
        template placeholders.
        """
        return prompt_template + "\n\n" + _escape_curly_braces(FORMAT_STR)


In [18]:
output_parser = RouterOutputParser()

from typing import List


def route_query(query_str: str, choices: List[str], output_parser: RouterOutputParser):
    """Ask the LLM which of ``choices`` are relevant to ``query_str``.

    Args:
        query_str: The user question to route.
        choices: Descriptions of the candidate routes, in order.
        output_parser: Parser used both to append the JSON format
            instructions to the prompt and to parse the completion.

    Returns:
        Whatever ``output_parser.parse`` returns for the raw completion.
    """
    # Build the numbered list from the `choices` argument. Previously a bare
    # `choices_str` expression here did nothing, so the notebook-level string
    # was used and the argument was silently ignored.
    choices_str = get_choice_str(choices)

    fmt_base_prompt = router_prompt0.format(
        num_choices=len(choices),
        max_outputs=len(choices),
        context_list=choices_str,
        query_str=query_str,
    )
    fmt_json_prompt = output_parser.format(fmt_base_prompt)

    raw_output = llm.complete(fmt_json_prompt)
    parsed = output_parser.parse(str(raw_output))

    return parsed

3.  Perform routing using a Function calling endpoint

In [19]:
from pydantic import Field

# NOTE: this redefines `Answer`, shadowing the string-typed model from
# section 2. Code that looks up `Answer` by name at call time (for example
# RouterOutputParser.parse) will use this class once this cell has run.
class Answer(BaseModel):
    """Represents a single choice with a reason"""
    # Integer choice number; RouterQueryEngine below subtracts 1 from it to
    # index into its list of query engines, i.e. it is treated as 1-based.
    choice: int
    reason: str

class Answers(BaseModel):
    """Represents a list of answeres"""
    # Wrapper model holding every selected `Answer`; it is passed as
    # `output_cls` to OpenAIPydanticProgram further down in this notebook.
    answers: List[Answer]

Answers.model_json_schema()

{'$defs': {'Answer': {'description': 'Represents a single choice with a reason',
   'properties': {'choice': {'title': 'Choice', 'type': 'integer'},
    'reason': {'title': 'Reason', 'type': 'string'}},
   'required': ['choice', 'reason'],
   'title': 'Answer',
   'type': 'object'}},
 'description': 'Represents a list of answeres',
 'properties': {'answers': {'items': {'$ref': '#/$defs/Answer'},
   'title': 'Answers',
   'type': 'array'}},
 'required': ['answers'],
 'title': 'Answers',
 'type': 'object'}

In [23]:
from llama_index.program import OpenAIPydanticProgram

# Pre-fill the two count placeholders so that only {context_list} and
# {query_str} remain to be supplied when the program is called.
# NOTE: both counts are fixed to len(choices) as of the moment this cell
# runs; rebinding `choices` afterwards does not update the prompt.
router_prompt1 = router_prompt0.partial_format(
    num_choices=len(choices),
    max_outputs=len(choices)
)

# Program that asks the LLM for output shaped like `Answers`; verbose=True
# prints the function call and its arguments (see the cell output below).
program = OpenAIPydanticProgram.from_defaults(
    output_cls=Answers,
    prompt=router_prompt1,
    verbose=True
)

In [24]:
query_str = "What are the health benefits of eating orange peels?"
output = program(context_list=choices_str, query_str=query_str)

Function call: Answers with args: {
  "answers": [
    {
      "choice": 2,
      "reason": "Useful for questions related to oranges"
    }
  ]
}


In [25]:
print(output)

answers=[Answer(choice=2, reason='Useful for questions related to oranges')]


4. Plug router into RAG Pipeline

- Here, we’ll use the router in a RAG pipeline
- We'll use it to dynamically decide whether to perform question-answering or summarization
  - We can easily get a question-answering query engine using top-k retrieval through the `VectorStoreIndex`
  - and summarization through the `SummaryIndex`
- Each query engine is presented as a choice to the router

In [27]:
# Load data
!mkdir data
!wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"

mkdir: cannot create directory ‘data’: File exists
--2023-09-20 19:36:05--  https://arxiv.org/pdf/2307.09288.pdf
Resolving arxiv.org (arxiv.org)... 128.84.21.199
Connecting to arxiv.org (arxiv.org)|128.84.21.199|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13661300 (13M) [application/pdf]
Saving to: ‘data/llama2.pdf’


2023-09-20 19:36:14 (1.52 MB/s) - ‘data/llama2.pdf’ saved [13661300/13661300]



In [28]:
from pathlib import Path
from llama_hub.file.pymu_pdf.base import PyMuPDFReader

In [29]:
loader = PyMuPDFReader()
documents = loader.load(file_path="./data/llama2.pdf")

In [37]:
# Define indexes
from llama_index import ServiceContext, VectorStoreIndex, SummaryIndex

# Shared settings for both indexes; documents are chunked with chunk_size=1024.
service_context = ServiceContext.from_defaults(chunk_size=1024)
# Index used for question answering over specific parts of the paper.
vector_index = VectorStoreIndex.from_documents(
    documents, service_context=service_context
)

# Index used for summarization of the paper.
summary_index = SummaryIndex.from_documents(documents, service_context=service_context)

# One query engine per index, each with default settings; these are the
# engines the router chooses between.
vector_query_engine = vector_index.as_query_engine()
summary_query_engine = summary_index.as_query_engine()

In [38]:
# Define RouterQueryEngine
from llama_index.query_engine import CustomQueryEngine, BaseQueryEngine
from llama_index.response_synthesizers import TreeSummarize


class RouterQueryEngine(CustomQueryEngine):
    """Use Pydantic Program to perform routing.

    The LLM is shown ``choice_descriptions`` as a numbered list and returns
    the 1-based numbers of the relevant entries. The query is then sent to the
    query engine at the same position in ``query_engines``; when several are
    selected, their answers are combined by ``summarizer``.
    """

    # Candidate engines; position i corresponds to choice_descriptions[i].
    query_engines: List[BaseQueryEngine]
    choice_descriptions: List[str]
    verbose: bool = False
    router_prompt: PromptTemplate
    llm: OpenAI
    # NOTE(review): `Field` is imported from `pydantic` in this notebook, which
    # is v2 (given `model_json_schema`), while CustomQueryEngine may be built on
    # llama_index's pydantic bridge. If instantiation fails with a pydantic
    # TypeError, try `from llama_index.bridge.pydantic import Field` - confirm
    # against the installed versions.
    summarizer: TreeSummarize = Field(default_factory=TreeSummarize)

    def custom_query(self, query_str):
        """Route ``query_str`` to the selected engine(s) and return the answer.

        Args:
            query_str: The user question.

        Returns:
            The single engine's response when one choice is selected,
            otherwise the summarizer's combination of all responses.

        Raises:
            ValueError: If the router selects no choice at all.
            IndexError: If a selected choice number is outside
                ``1..len(query_engines)``.
        """
        # Create program
        program = OpenAIPydanticProgram.from_defaults(
            output_cls=Answers,
            prompt=self.router_prompt,
            verbose=self.verbose,
            llm=self.llm,
        )

        # Define choices
        choices_str = get_choice_str(self.choice_descriptions)
        output = program(context_list=choices_str, query_str=query_str)

        # print choice and reason, and query underlying engine
        if self.verbose:
            print("Selected choice(s):")
            for answer in output.answers:
                print(f"Choice: {answer.choice}, Reason: {answer.reason}")

        if not output.answers:
            # Nothing to query; fail fast rather than handing the summarizer
            # an empty list of responses.
            raise ValueError("Router did not select any choice for the query.")

        # Validate every selection before issuing any query, so a bad choice
        # does not waste calls to the other engines. Without this check a
        # choice of 0 would become index -1 and silently pick the last engine.
        num_engines = len(self.query_engines)
        selected_engines = []
        for answer in output.answers:
            choice_idx = answer.choice - 1
            if not 0 <= choice_idx < num_engines:
                raise IndexError(
                    f"Router selected choice {answer.choice}, but valid "
                    f"choices are 1 to {num_engines}."
                )
            selected_engines.append(self.query_engines[choice_idx])

        # submit queries for each choice (QA)
        responses = [engine.query(query_str) for engine in selected_engines]

        if len(responses) == 1:
            return responses[0]

        # if multiple choices are picked, combine them with the summarizer
        response_strs = [str(r) for r in responses]
        return self.summarizer.get_response(query_str, response_strs)

In [40]:
# One description per query engine, listed in the same order as
# `query_engines` below: the router's 1-based choice number is used as an
# index into that list.
# NOTE: this rebinds the notebook-level `choices` used by the fruit examples.
choices = [
    "Useful for answering questions about specific sections of the Llama 2 paper",
    "Useful for questions that ask for a summary of the whole paper",
]

# NOTE(review): `router_prompt1` had its {num_choices}/{max_outputs} values
# fixed when it was created from the earlier `choices` list; confirm they
# still match the number of engines passed here.
# NOTE(review): the recorded output of this cell is a pydantic TypeError, so
# this instantiation does not currently succeed as written.
router_query_engine = RouterQueryEngine(
    query_engines=[vector_query_engine, summary_query_engine],
    choice_descriptions=choices,
    verbose=True,
    router_prompt=router_prompt1,
    llm=OpenAI(model="gpt-4"),
)

TypeError: BaseModel.validate() takes 2 positional arguments but 3 were given