# load model

In [1]:
from schemallm import SchemaLLM
from transformers import AutoModelForCausalLM, AutoTokenizer

# The model weights and the tokenizer are loaded from the same identifier.
MODEL_ID = "TheBloke/Llama-2-7B-Chat-GPTQ"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Wrapper whose generate(prompt=..., schema=...) is used by the helpers below.
schema_llm = SchemaLLM(
    model=model,
    tokenizer=tokenizer,
)

  from .autonotebook import tqdm as notebook_tqdm
CUDA extension not installed.
CUDA extension not installed.


# create functions for choosing a function and building its arguments

In [2]:
import inspect
from typing import List, Literal, Callable
from pydantic import BaseModel, create_model


# Describes one callable tool that the LLM may pick and build arguments for.
class Function(BaseModel):
    # The Python callable itself; its __name__ is what the LLM chooses between
    # and its annotated parameters define the argument schema.
    func: Callable
    # Text shown next to the function name in the function-selection prompt.
    func_description: str
    # Free-form description of the arguments, appended to the argument-building prompt.
    arg_description: str


def choose_function(situation: str, functions: List[Function]) -> Function:
    """
    Ask ``schema_llm`` which of the given functions fits the situation.

    Args:
        situation: str, prompt explaining situation
        functions: List[Function], the candidates offered to the model

    Returns:
        The entry of ``functions`` whose ``func.__name__`` equals the
        ``chosen_function`` value produced by ``schema_llm.generate``.
    """
    names = [candidate.func.__name__ for candidate in functions]
    # build input text: one "name + description" entry per candidate
    menu_entries = []
    for candidate_name, candidate in zip(names, functions):
        menu_entries.append(f"{candidate_name}\n  description: {candidate.func_description}")
    prompt = situation + "\n\nChoose one of following functions:\n" + "\n".join(menu_entries)
    # build output schema: a single required field limited to the candidate names
    choice_model = create_model("TmpModel", chosen_function=(Literal[tuple(names)], ...))
    choice_schema = choice_model.schema()
    # choose function
    generated = schema_llm.generate(prompt=prompt, schema=choice_schema)
    picked_name = generated["chosen_function"]
    # map the generated name back to its Function (first match wins)
    return functions[names.index(picked_name)]


def build_function_argument(situation: str, function: Function) -> dict:
    """
    Ask ``schema_llm`` to produce keyword arguments for ``function.func``.

    Each name in ``inspect.getfullargspec(function.func).args`` becomes a
    required field of a temporary pydantic model, typed with the parameter's
    annotation. Parameter defaults are not consulted.

    Args:
        situation: str, prompt explaining situation
        function: Function, the tool whose arguments should be generated

    Returns:
        The value returned by ``schema_llm.generate`` for the arguments schema
        (callers unpack it with ``**`` into ``function.func``).

    Raises:
        Exception: if any of the parameters has no type annotation.
    """
    target = function.func
    spec = inspect.getfullargspec(target)
    # build arguments schema
    fields = {}
    for param in spec.args:
        if param in spec.annotations:
            # ``...`` marks the field as required
            fields[param] = (spec.annotations[param], ...)
            continue
        raise Exception(f"Argument {param} of function {target.__name__} is not annotated.")
    arguments_model = create_model("TmpModel", **fields)
    arguments_schema = arguments_model.schema()
    # build arguments for function
    instruction = f"\n\npython function {target.__name__} is chosen. Continue to build arguments for the function, arguments description:\n"
    prompt = situation + instruction + function.arg_description
    return schema_llm.generate(prompt=prompt, schema=arguments_schema)

# define situation and available tools

In [3]:
def google_search(query: str, num_results: int = 10, language: Literal["en", "de"] = "en"):
    """
    Demo stand-in for a Google search tool: prints its name and the arguments it received.

    Args:
        query: the search query.
        num_results: number of results to retrieve, 10 by default.
        language: result language, either "en" or "de"; "en" by default.
            The ``Literal`` is the annotation (not the default value) so that
            the argument schema built from the signature is limited to these
            two values.
    """
    # Snapshot the arguments before any other local exists, so only parameters are printed.
    received = locals()
    print("google_search\n")
    for name, value in received.items():
        print(name, ":", value)


def hackernews_search(
    query: str,
    num_results: int = 10,
    sort_by: Literal["popularity", "publish date"] = "popularity",
):
    """
    Demo stand-in for a Hacker News search tool: prints its name and the arguments it received.

    Args:
        query: the search query.
        num_results: number of results to retrieve, 10 by default.
        sort_by: sort criterion, either "popularity" or "publish date".
            The ``Literal`` is the annotation (not the default value) so that
            the argument schema built from the signature is limited to these
            two values.
    """
    # NOTE(review): "popularity" was picked as the default because it is the
    # first allowed value; the original code gave no usable default - confirm.
    # Snapshot the arguments before any other local exists, so only parameters are printed.
    received = locals()
    print("hackernews_search\n")
    for name, value in received.items():
        print(name, ":", value)


def write_email(sender: str, recipient: str, subject: str, body: str):
    """
    Demo stand-in for an e-mail tool: prints its name and the arguments it received.

    Nothing is sent; the function only writes to standard output.
    """
    # Snapshot the arguments before any other local exists, so only parameters are printed.
    received = locals()
    print("write_email\n")
    for name, value in received.items():
        print(name, ":", value)


def database_query(table_name: str, columns: List[str], limit: int):
    """
    Demo stand-in for a database query tool: prints its name and the arguments it received.

    No database is contacted; the function only writes to standard output.
    """
    # Snapshot the arguments before any other local exists, so only parameters are printed.
    received = locals()
    print("database_query\n")
    for name, value in received.items():
        print(name, ":", value)


# Registry of the tools offered to the model. func_description is used in the
# function-selection prompt; arg_description is used in the argument-building prompt.
functions = [
    # NOTE(review): columns and limit are described as optional here, but the
    # database_query signature gives them no defaults - confirm which is intended.
    Function(
        func=database_query,
        func_description="Perform a query on a database table.",
        arg_description="""
- table_name (str): The name of the database table to query.
- columns (list, optional): A list of column names to retrieve. If None, all columns are selected.
- limit (int, optional): Limit the number of rows returned by the query.
""",
    ),
    Function(
        func=google_search,
        func_description="Perform a Google search and retrieve relevant results.",
        arg_description="""
- query (str): The search query to be executed on Google.
- num_results (int, optional): The number of search results to retrieve (default is 10).
- language (str, optional): The language in which the search results should be displayed (default is "en").
""",
    ),
    # The sort_by values listed below match the ones named in the
    # hackernews_search signature ("popularity", "publish date").
    Function(
        func=hackernews_search,
        func_description="Search for relevant articles on Hacker News based on a query.",
        arg_description="""
- query (str): The search query to find relevant articles on Hacker News.
- num_results (int, optional): The number of search results to retrieve (default is 10).
- sort_by (str, optional): The criteria for sorting the search results ('popularity' or 'publish date').
""",
    ),
    Function(
        func=write_email,
        func_description="Compose and send an email using Python.",
        arg_description="""
- sender (str): The email address of the sender.
- recipient (str): The email address of the recipient.
- subject (str): The subject of the email.
- body (str): The main content of the email.
""",
    ),
]

# run the task and call the chosen function

In [4]:
situation = "Give me list of my top customer."

# Step 1: let the model pick one of the registered tools for this situation.
func = choose_function(situation=situation, functions=functions)
# Step 2: let the model fill in the arguments of the picked tool.
kwargs = build_function_argument(situation=situation, function=func)
# Step 3: call the underlying Python callable with the generated arguments.
func.func(**kwargs)



database_query

table_name : my_table
columns : ['column1', 'column2']
limit : 100000


In [5]:
situation = "I need today's stock price of APPLE"

# Step 1: let the model pick one of the registered tools for this situation.
func = choose_function(situation=situation, functions=functions)
# Step 2: let the model fill in the arguments of the picked tool.
kwargs = build_function_argument(situation=situation, function=func)
# Step 3: call the underlying Python callable with the generated arguments.
func.func(**kwargs)

database_query

table_name : stock_prices
columns : ['date', 'open', 'high', 'low', 'close']
limit : 100000


In [6]:
situation = "What is the most popular hackernews article today?"

# Step 1: let the model pick one of the registered tools for this situation.
func = choose_function(situation=situation, functions=functions)
# Step 2: let the model fill in the arguments of the picked tool.
kwargs = build_function_argument(situation=situation, function=func)
# Step 3: call the underlying Python callable with the generated arguments.
func.func(**kwargs)

database_query

table_name : users
columns : ['id', 'name', 'email']
limit : 100000
