In [2]:
# activate dev venv
!source dev/bin/activate

In [4]:
import os
os.environ['OPENAI_API_KEY'] = ""

In [5]:
from llama_index import PromptTemplate

choices = [
    "Useful for questions related to events",
    "All other questions",
]

def get_choice_str(choices):
    choices_str = "\n\n".join([f"{idx+1}. {c}" for idx, c in enumerate(choices)])
    return choices_str

choices_str = get_choice_str(choices)



In [25]:
print(choices_str)

1. Useful for questions related to events

2. All other questions


In [7]:
router_prompt0 = PromptTemplate(
    "Some choices are given below. It is provided in a numbered "
    "list (1 to {num_choices}), "
    "where each item in the list corresponds to a summary.\n"
    "---------------------\n"
    "{context_list}"
    "\n---------------------\n"
    "Using only the choices above and not prior knowledge, return the top choices "
    "(no more than {max_outputs}, but only select what is needed) that "
    "are most relevant to the question: '{query_str}'\n"
)

In [8]:
from llama_index.llms import OpenAI
llm = OpenAI(model="gpt-3.5-turbo")

In [9]:
def get_formatted_prompt(query_str):
    fmt_prompt = router_prompt0.format(
        num_choices=len(choices),
        max_outputs=2,
        context_list=choices_str,
        query_str=query_str,
    )
    return fmt_prompt

In [10]:
query_str = "Can you tell me more about the Giants vs Dodgers rivalry"
fmt_prompt = get_formatted_prompt(query_str)
response = llm.complete(fmt_prompt)

RateLimitError: You exceeded your current quota, please check your plan and billing details.

In [13]:
from dataclasses import fields
from pydantic import BaseModel
import json

In [14]:
class Answer(BaseModel):
    choice: int
    reason: str

In [16]:
print(json.dumps(Answer.schema(), indent=4))

{
    "properties": {
        "choice": {
            "title": "Choice",
            "type": "integer"
        },
        "reason": {
            "title": "Reason",
            "type": "string"
        }
    },
    "required": [
        "choice",
        "reason"
    ],
    "title": "Answer",
    "type": "object"
}


/var/folders/z6/529jfvqn3ng4d479hkxx755r0000gn/T/ipykernel_45947/3119484813.py:1: PydanticDeprecatedSince20: The `schema` method is deprecated; use `model_json_schema` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.4/migration/
  print(json.dumps(Answer.schema(), indent=4))


In [17]:
from llama_index.types import BaseOutputParser

In [18]:
FORMAT_STR = """The output should be formatted as a JSON instance that conforms to
the JSON schema below.

Here is the output schema:
{
  "type": "array",
  "items": {
    "type": "object",
    "properties": {
      "choice": {
        "type": "integer"
      },
      "reason": {
        "type": "string"
      }
    },
    "required": [
      "choice",
      "reason"
    ],
    "additionalProperties": false
  }
}
"""

In [20]:
def _escape_curly_braces(input_string: str) -> str:
    # Replace '{' with '{{' and '}' with '}}' to escape curly braces
    escaped_string = input_string.replace("{", "{{").replace("}", "}}")
    return escaped_string

In [21]:
def _marshal_output_to_json(output: str) -> str:
    output = output.strip()
    left = output.find("[")
    right = output.find("]")
    output = output[left : right + 1]
    return output

In [28]:
from typing import List

class RouterOutputParser(BaseOutputParser):
    def parse(self, output: str) -> List[Answer]:
        """Parse string."""
        json_output = _marshal_output_to_json(output)
        json_dicts = json.loads(json_output)
        answers = [Answer.from_dict(json_dict) for json_dict in json_dicts]
        return answers

    def format(self, prompt_template: str) -> str:
        return prompt_template + "\n\n" + _escape_curly_braces(FORMAT_STR)

In [29]:
output_parser = RouterOutputParser()

In [41]:
from typing import List

def route_query(query_str: str, choices: List[str], output_parser: RouterOutputParser):
    choices_str

    fmt_base_prompt = router_prompt0.format(
        num_choices=len(choices),
        max_outputs=len(choices),
        context_list=choices_str,
        query_str=query_str,
    )
    fmt_json_prompt = output_parser.format(fmt_base_prompt)

    # raw_output = llm.complete(fmt_json_prompt)
    raw_output = "[1. Useful for questions related to events]"
    # parsed = output_parser.parse(str(raw_output))
    parsed = output_parser.parse(raw_output)

    return parsed

In [42]:
parsed = route_query(query_str="Can you tell me more about the Giants vs Dodgers rivalry",
            choices=choices,
            output_parser=output_parser)

[1. Useful for questions related to events,]


JSONDecodeError: Expecting ',' delimiter: line 1 column 3 (char 2)

In [38]:
_marshal_output_to_json("[1. Useful for questions related to events]")

'[1. Useful for questions related to events]'