# Load the model and tokenizer

In [1]:
from schemallm import SchemaLLM
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model weights and tokenizer come from the same pretrained repository id.
MODEL_ID = "TheBloke/Llama-2-7B-Chat-GPTQ"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Pair the model with its tokenizer; the helpers further down call
# schema_llm.generate(prompt=..., schema=...) on this object.
schema_llm = SchemaLLM(model=model, tokenizer=tokenizer)

  from .autonotebook import tqdm as notebook_tqdm
CUDA extension not installed.
CUDA extension not installed.


# Create helpers for choosing a function and building its arguments

In [2]:
import inspect
from typing import List, Literal, Callable
from pydantic import BaseModel, create_model


class Function(BaseModel):
    """A callable tool the LLM may pick, plus the text the LLM is shown about it.

    ``choose_function`` lists each tool by ``func.__name__`` with its
    ``func_description``; ``build_function_argument`` appends
    ``arg_description`` to the prompt that asks for the call's arguments.
    """

    # The Python callable to invoke; its ``__name__`` is what the LLM sees and answers with.
    func: Callable
    # Short description of what ``func`` does, shown in the function-selection prompt.
    func_description: str
    # Free-text description of ``func``'s arguments, shown in the argument-building prompt.
    arg_description: str


def choose_function(situation: str, functions: List["Function"], *, llm=None) -> "Function":
    """Ask the LLM to pick the most suitable of *functions* for *situation*.

    The prompt lists every candidate by ``func.__name__`` together with its
    ``func_description``. The LLM output is constrained by a schema whose only
    field, ``chosen_function``, is a ``Literal`` over the candidate names, so
    the answer can be mapped straight back to a ``Function``.

    Args:
        situation: Prompt text explaining the situation / task.
        functions: Candidate tools. Their ``func.__name__`` values must be
            unique, because the name is the only thing the LLM answers with.
        llm: Object with a ``generate(prompt=..., schema=...)`` method.
            Defaults to the module-level ``schema_llm``.

    Returns:
        The ``Function`` whose name the LLM chose.

    Raises:
        ValueError: If *functions* is empty, two candidates share a name, or
            the LLM answers with a name that is not one of the candidates.
    """
    if not functions:
        raise ValueError("choose_function() needs at least one candidate function.")

    # name -> Function, in the caller's order (dicts preserve insertion order,
    # so the prompt lists candidates exactly as given).
    by_name = {}
    for candidate in functions:
        name = candidate.func.__name__
        if name in by_name:
            # The LLM could never select the second one, so fail loudly.
            raise ValueError(f"Duplicate function name {name!r}: candidate names must be unique.")
        by_name[name] = candidate

    if llm is None:
        # Fall back to the global created in the model-loading cell.
        llm = schema_llm

    # build input text
    prompt = situation + "\n\nChoose one of following functions:\n"
    prompt += "\n".join(f"{name}\n  description: {f.func_description}" for name, f in by_name.items())
    # build output schema: `chosen_function` must be exactly one of the names
    func_choose_output_schema = create_model(
        "TmpModel", chosen_function=(Literal[tuple(by_name)], ...)
    ).schema()
    # choose function
    func_choose_output = llm.generate(prompt=prompt, schema=func_choose_output_schema)
    chosen_function_name = func_choose_output["chosen_function"]
    try:
        return by_name[chosen_function_name]
    except KeyError:
        raise ValueError(
            f"LLM chose unknown function {chosen_function_name!r}; expected one of {list(by_name)}."
        ) from None


def build_function_argument(situation: str, function: "Function", *, llm=None) -> dict:
    """Ask the LLM to generate keyword arguments for ``function.func``.

    A temporary pydantic model is built from the callable's parameter
    annotations; its schema constrains the LLM output so the result can be
    passed as ``function.func(**kwargs)``.

    Args:
        situation: Prompt text explaining the situation / task.
        function: The chosen tool; ``function.arg_description`` is appended
            to the prompt to describe its arguments.
        llm: Object with a ``generate(prompt=..., schema=...)`` method.
            Defaults to the module-level ``schema_llm``.

    Returns:
        The argument dict produced by ``llm.generate`` for the schema —
        presumably parameter name -> value (TODO confirm against SchemaLLM).

    Raises:
        TypeError: If a named parameter has no annotation, or is
            positional-only and therefore cannot be supplied via ``**kwargs``.
    """
    func_name = function.func.__name__
    # build arguments schema.
    # inspect.signature (unlike getfullargspec().args) omits the bound ``self``
    # of methods, follows functools.wraps, and includes keyword-only parameters.
    schema_annotation = {}
    for param in inspect.signature(function.func).parameters.values():
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue  # *args / **kwargs have no fixed name for the LLM to fill in
        if param.kind == inspect.Parameter.POSITIONAL_ONLY:
            raise TypeError(
                f"Argument {param.name} of function {func_name} is positional-only "
                "and cannot be passed by keyword."
            )
        if param.annotation is inspect.Parameter.empty:
            raise TypeError(f"Argument {param.name} of function {func_name} is not annotated.")
        # `...` marks the field as required in the generated schema.
        schema_annotation[param.name] = (param.annotation, ...)
    kwargs_schema = create_model("TmpModel", **schema_annotation).schema()

    if llm is None:
        # Fall back to the global created in the model-loading cell.
        llm = schema_llm

    # build arguments for function
    prompt = (
        situation
        + f"\n\npython function {func_name} is chosen. Continue to build arguments for the function, arguments description:\n"
        + function.arg_description
    )
    return llm.generate(prompt=prompt, schema=kwargs_schema)

# Define the situation and the available tools

In [3]:
# The task the LLM should solve by choosing and calling one of the tools below.
situation = "I need today's stock price of APPLE"


def google_search(text: str):
    """Stand-in Google search tool: only prints the query it would search for."""
    print("search {} in google".format(text))


def hackernews_search(text: str):
    """Stand-in Hacker News search tool: only prints the query it would search for."""
    print("search {} in hackernews".format(text))


# One row per available tool: (callable, what it does, how to fill its arguments).
_TOOL_SPECS = (
    (
        google_search,
        "search via google search engine",
        "- text: string, text used for search in google",
    ),
    (
        hackernews_search,
        "search in hackernews forum",
        "- text: string, text used for search in hackernews",
    ),
)
functions = [
    Function(func=tool, func_description=tool_desc, arg_description=args_desc)
    for tool, tool_desc, args_desc in _TOOL_SPECS
]

# Run the task and call the chosen function

In [4]:
# Step 1: let the LLM pick which tool fits the situation.
func = choose_function(situation=situation, functions=functions)
# Step 2: let the LLM fill in that tool's keyword arguments.
kwargs = build_function_argument(situation=situation, function=func)
# Step 3: invoke the chosen callable with the generated arguments.
func.func(**kwargs)



search Apple Inc. stock price in google
