In [1]:
!pip install -qU sentencepiece accelerate bitsandbytes

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m270.9/270.9 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 MB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import json
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig

In [4]:
def get_prompt(user_query: str, functions: Optional[list] = None) -> str:
    """
    Build a Gorilla OpenFunctions prompt for a user query.

    The prompt starts with "USER: <<question>> " followed by the query. If any
    functions are supplied, " <<function>> " and their JSON encoding follow.
    The prompt always ends with a newline and "ASSISTANT: " so the model
    continues with its answer.

    Parameters:
    - user_query (str): The user's natural-language request.
    - functions (list | None): Function specifications (JSON-serialisable
      dicts) the model may call. None or an empty list produces a prompt
      without a <<function>> section.

    Returns:
    - str: The formatted conversation prompt.

    Raises:
    - TypeError: if `functions` contains values json.dumps cannot serialise.
    """
    # None replaces the former mutable `[]` default; both mean "no functions".
    if not functions:
        return f"USER: <<question>> {user_query}\nASSISTANT: "
    functions_string = json.dumps(functions)
    return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "

In [5]:
# Device setup: prefer the first CUDA GPU in half precision, otherwise run on
# the CPU in full precision. `device` is only reported here; model placement
# itself is left to device_map="auto" when the model is loaded.
if torch.cuda.is_available():
    device: str = "cuda:0"
    torch_dtype = torch.float16
else:
    device: str = "cpu"
    torch_dtype = torch.float32

print(f"Device: {device}")

Device: cuda:0


In [6]:
# Quantization setup: 4-bit NF4 weights, with double quantization of the
# quantization constants for a little extra memory saving.
# The compute dtype follows `torch_dtype` from the device-setup cell, which is
# also what the model is loaded with. Non-quantized modules and the 4-bit
# matmuls therefore share one precision (float16 on CUDA) instead of mixing
# float16 with a hard-coded bfloat16. bfloat16 is also not natively supported
# on pre-Ampere GPUs.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
)

In [7]:
# Model and tokenizer setup: download the Gorilla OpenFunctions v1 checkpoint,
# quantize it on load with `bnb_config`, and let device_map="auto" decide
# where the weights go.
model_id : str = "gorilla-llm/gorilla-openfunctions-v1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
)

tokenizer_config.json:   0%|          | 0.00/935 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/657 [00:00<?, ?B/s]

pytorch_model.bin.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

pytorch_model-00002-of-00002.bin:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/183 [00:00<?, ?B/s]

In [8]:
# Pipeline setup: wrap the already-loaded quantized model and its tokenizer in
# a text-generation pipeline. Each call generates at most 128 new tokens;
# batch_size only comes into play when several prompts are passed at once.
pipe = pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model=model,
    torch_dtype=torch_dtype,
    batch_size=16,
    max_new_tokens=128,
)

In [27]:
# Example usage: a single Uber ride-booking function. Each entry in
# "parameters" names one argument and may restrict it to an "enum" of values.
query: str = 'Call me an Uber ride type "Plus" in Berkeley at zipcode 94704 in 10 minutes'
functions = [
    {
        "name": "Uber Carpool",
        "api_name": "uber.ride",
        "description": (
            "Find suitable ride for customers given the location, type of ride, "
            "and the amount of time the customer is willing to wait as parameters"
        ),
        "parameters": [
            {
                "name": "loc",
                "description": "Location of the starting place of the Uber ride",
            },
            {
                "name": "type",
                "enum": ["plus", "comfort", "black"],
                "description": "Types of Uber ride user is ordering",
            },
            {
                "name": "time",
                "description": "The amount of time in minutes the customer is willing to wait",
            },
        ],
    },
]

In [28]:
# Generate prompt and obtain model output.
# `pipe` returns a list with one dict per prompt. Its 'generated_text' holds the
# prompt followed by the completion (as the recorded output's "ASSISTANT:"
# prefix shows), so the last line is the ASSISTANT turn with the function call.
# NOTE(review): if the model emits more than one line, only the final line is shown.
prompt = get_prompt(query, functions=functions)
output = pipe(prompt)
print(output[0]['generated_text'].splitlines()[-1])

ASSISTANT: uber.ride(loc="94704", type="plus", time=10)


In [29]:
# Second example: a single Uber Eats ordering function. The callable's
# identifier is stored under "api_name", the same key the other function specs
# in this notebook use, so every example follows one prompt schema.
query = "I want to order five burgers and six chicken wings from McDonald's."
function_documentation = {
    "name" : "Order Food on Uber",
    "api_name": "uber.eat.order",
    "description": "Order food on uber eat, specifying items and their quantities",
    "parameters": [
        {
            "name": "restaurants",
            "description": "The chosen restaurant"
        },
        {
            "name": "items",
            "description": "List of selected items"
        },
        {
            "name": "quantities",
            "description": "Quantities corresponding to the chosen items"
        }
    ]
}

In [30]:
# Prompt with the single Uber Eats spec (wrapped in a list, since get_prompt
# takes a list of functions) and show the final line of prompt + completion.
# NOTE(review): the returned call is not guaranteed to be valid Python; e.g.
# `quantities=5, 6` was observed for this example, so validate before use.
prompt= get_prompt(query, functions=[function_documentation])
output = pipe(prompt)
print(output[0]['generated_text'].splitlines()[-1])

ASSISTANT: uber.eat.order(restaurants="McDonald's", items=["burger", "chicken wing"], quantities=5, 6)


In [31]:
# AWS CLI example (`aws lexv2-models list-exports`). "api_name" and
# "api_arguments" identify the call and the arguments the question needs; the
# remaining keys add reference text about the command's output and its full
# option list.
query = "I want to list the exports for my bot with the bot id \"my-bot-id\" and the bot version \"v2\"."
functions = [
    {
        "domain": "Cloud Infrastructure",
        "framework": "aws",
        "functionality": "Lists the exports for a bot, bot locale, or custom vocabulary. Exports are kept in the list for 7 days.",
        "api_name": "aws.lexv2-models.list-exports",
        "api_arguments": [
            {
                "name": "bot-id",
                "description": "\nThe unique identifier that Amazon Lex assigned to the bot."
            },
            {
                "name": "bot-version",
                "description": "\nThe version of the bot to list exports for."
            }
        ],
        "python_environment_requirements": [
            "aws"
        ],
        "example_code": [],
        "output": {
            "botId -> (string)": "\nThe unique identifier assigned to the bot by Amazon Lex.",
            "botVersion -> (string)": "\nThe version of the bot that was exported.",
            # One string, written as two adjacent literals (implicit concatenation).
            "exportSummaries -> (list)": (
                "\nSummary information for the exports that meet the filter criteria specified in the request. The length of the list is specified in the maxResults parameter. If there are more exports available, the nextToken field contains a token to get the next page of results.\n(structure)\n\nProvides summary information about an export in an export list.\nexportId -> (string)\n\nThe unique identifier that Amazon Lex assigned to the export.\nresourceSpecification -> (structure)\n\nInformation about the bot or bot locale that was exported.\nbotExportSpecification -> (structure)\n\nParameters for exporting a bot.\nbotId -> (string)\n\nThe identifier of the bot assigned by Amazon Lex.\nbotVersion -> (string)\n\nThe version of the bot that was exported. This will be either DRAFT or the version number.\n\nbotLocaleExportSpecification -> (structure)\n\nParameters for exporting a bot locale.\nbotId -> (string)\n\nThe identifier of the bot to create the locale for.\nbotVersion -> (string)\n\nThe version of the bot to export.\nlocaleId -> (string)\n\nThe identifier of the language and locale to export. The string must match one of the locales in the bot.\n\ncustomVocabularyExportSpecification -> (structure)\n\nThe parameters required to export a custom vocabulary.\nbotId -> (string)\n\nThe identifier of the bot that contains the custom vocabulary to export.\nbotVersion -> (string)\n\nThe version of the bot that contains the custom vocabulary to export.\nlocaleId -> (string)\n\nThe locale of the bot that contains the custom vocabulary to export.\n\ntestSetExportSpecification -> (structure)\n\nSpecifications for the test set that is exported as a resource.\ntestSetId -> (string)\n\nThe unique identifier of the test set.\n\n\nfileFormat -> (string)\n\nThe file format used in the export files.\nexportStatus -> (string)\n\nThe status of the export. "
                "When the status is Completed the export is ready to download.\ncreationDateTime -> (timestamp)\n\nThe date and time that the export was created.\nlastUpdatedDateTime -> (timestamp)\n\nThe date and time that the export was last updated.\n\n"
            ),
            "nextToken -> (string)": "\nA token that indicates whether there are more results to return in a response to the ListExports operation. If the nextToken field is present, you send the contents as the nextToken parameter of a ListExports operation request to get the next page of results.",
            "localeId -> (string)": "\nThe locale specified in the request."
        },
        "api_arguments_all": {
            "--bot-id ": "\nThe unique identifier that Amazon Lex assigned to the bot.",
            "--bot-version ": "\nThe version of the bot to list exports for.",
            "--sort-by ": "\nDetermines the field that the list of exports is sorted by. You can sort by the LastUpdatedDateTime field in ascending or descending order.\nattribute -> (string)\n\nThe export field to use for sorting.\norder -> (string)\n\nThe order to sort the list.\n",
            "--filters ": "\nProvides the specification of a filter used to limit the exports in the response to only those that match the filter specification. You can only specify one filter and one string to filter on.\n(structure)\n\nFilters the response form the ListExports operation\nname -> (string)\n\nThe name of the field to use for filtering.\nvalues -> (list)\n\nThe values to use to filter the response. The values must be Bot , BotLocale , or CustomVocabulary .\n(string)\n\noperator -> (string)\n\nThe operator to use for the filter. Specify EQ when the ListExports operation should return only resource types that equal the specified value. Specify CO when the ListExports operation should return resource types that contain the specified value.\n\n",
            "--max-results ": "\nThe maximum number of exports to return in each page of results. If there are fewer results than the max page size, only the actual number of results are returned.",
            "--next-token ": "\nIf the response from the ListExports operation contains more results that specified in the maxResults parameter, a token is returned in the response.\nUse the returned token in the nextToken parameter of a ListExports request to return the next page of results. For a complete set of results, call the ListExports operation until the nextToken returned in the response is null.\n",
            # U+2019 (right single quotation mark). The text used to hold
            # \u00e2\u0080\u0099, the UTF-8 bytes of U+2019 mis-decoded as Latin-1.
            "--locale-id ": "\nSpecifies the resources that should be exported. If you don\u2019t specify a resource type in the filters parameter, both bot locales and custom vocabularies are exported."
        }
    }
]

In [32]:
# Prompt with the AWS spec list and show the final line of prompt + completion
# (the ASSISTANT turn holding the model's call).
# NOTE(review): the argument names here contain hyphens (bot-id, bot-version),
# so the returned call is CLI-style text rather than an executable Python call.
prompt= get_prompt(query, functions=functions)
output = pipe(prompt)
print(output[0]['generated_text'].splitlines()[-1])

ASSISTANT:  aws.lexv2-models.list-exports(bot-id="my-bot-id", bot-version="v2")
