In [5]:
import json
from pprint import pprint
from typing import Optional

import openai
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

In [14]:
def get_gorilla_response(
    prompt: str,
    functions: Optional[list] = None,
    model: str = "gorilla-openfunctions-v0",
    api_base: str = "http://luigi.millennium.berkeley.edu:8000/v1",
) -> Optional[str]:
    """
    Send a single-turn chat request to a Gorilla OpenFunctions endpoint.

    The request goes through the legacy ``openai.ChatCompletion`` interface,
    configured via the module-level ``openai.api_key`` / ``openai.api_base``
    settings (which this function overwrites on every call).

    Parameters:
    - prompt (str): The user's natural-language request; sent as the only
      message, with role ``"user"``.
    - functions (list | None): Function specifications forwarded to the
      endpoint. ``None`` (the default) is sent as an empty list.
    - model (str): Model name passed to the endpoint.
    - api_base (str): Base URL of the OpenAI-compatible server. Override it to
      target another deployment, e.g.
      ``"https://limcheekin-gorilla-openfunctions-v1-gguf.hf.space/v1"``.

    Returns:
    - str | None: The ``content`` of the first returned choice, or ``None`` if
      the request raised any exception (the exception, model and prompt are
      printed instead of being re-raised).
    """
    # Placeholder key; presumably the endpoint does not validate it -- TODO confirm.
    openai.api_key = "EMPTY"
    openai.api_base = api_base
    try:
        completion = openai.ChatCompletion.create(
            model=model,
            temperature=0.0,  # deterministic decoding
            messages=[{"role": "user", "content": prompt}],
            # A fresh list per call instead of a shared mutable default.
            functions=[] if functions is None else functions,
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Best-effort by design: report the failure and let the caller see None.
        print(e, model, prompt)
        return None

In [12]:
def get_prompt(user_query: str, functions: Optional[list] = None) -> str:
    """
    Generates a conversation prompt based on the user's query and a list of functions.

    Parameters:
    - user_query (str): The user's query.
    - functions (list | None): Function specifications to include in the
      prompt; they are serialized with ``json.dumps``. ``None`` or an empty
      list produces a prompt without a ``<<function>>`` section.

    Returns:
    - str: The formatted conversation prompt, ending in ``"ASSISTANT: "``.
    """
    # Truthiness covers both None and an empty list (len(None) would raise).
    if not functions:
        return f"USER: <<question>> {user_query}\nASSISTANT: "
    functions_string = json.dumps(functions)
    return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "

In [None]:
# Device setup
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs (slowly) on a machine without a GPU.
device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision is only used on the GPU; the CPU fallback keeps float32.
torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32

In [None]:
# Directory the tokenizer and model are loaded from; the path is relative,
# so it is resolved against the notebook's current working directory.
_MODEL_PATH = "../models/gorilla-openfunctions-v1/"

In [None]:
# Tokenizer and causal-LM weights are both read from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_PATH)

# Placement and precision follow the `device` / `torch_dtype` settings chosen
# in the device-setup cell.
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_PATH, torch_dtype=torch_dtype, device_map=device, low_cpu_mem_usage=True
)

In [None]:
# Text-generation pipeline around the already-loaded model and tokenizer.
# `device` is deliberately not passed here (it was left commented out);
# presumably because the model was already placed via `device_map` at load
# time -- TODO confirm.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch_dtype,
    batch_size=16,
    max_new_tokens=128,
)

In [15]:
# Example 1: a two-part ride request (an Ola cab, then an Uber cab on the same route).
query = (
    "call an ola cab from bhajanpura to nodia sector 62 in twenty two minutes. "
    "and then call a cab from uber for the same route in 10 minutes."
)

# A single function specification describing the ride-hailing call.
functions = [
    {
        "name": "Call Cab Ride",
        "api_name": "cab.ride",
        "description": (
            "Find suitable ride for customers given the location, type of ride, "
            "and the amount of time the customer is willing to wait as parameters."
        ),
        "parameters": [
            {
                "name": "start_loc",
                "default": "current_location",
                "description": "location of the starting place of the uber ride",
            },
            {
                "name": "end_loc",
                "description": "location of the ending place of the uber ride",
            },
            {
                "name": "type",
                "enum": ["normal", "plus", "premium"],
                "default": "normal",
                "description": "types of uber ride user is ordering",
            },
            # NOTE(review): "value" says integer but the default is the string "now".
            {
                "name": "time",
                "value": "integer",
                "default": "now",
                "description": "the amount of time in minutes the customer is willing to wait",
            },
            {
                "name": "platform",
                "default": "uber",
                "enum": ["uber", "ola"],
                "description": "the platform the user is ordering the ride from",
            },
        ],
    },
]

In [16]:
# Send the ride request and its function spec to the endpoint; the returned
# string is the cell's displayed value.
get_gorilla_response(query, functions=functions)

'cab.ride(start_loc="bhajanpura", end_loc="nodia-sector-62", type="normal", time=22, platform="ola")'

In [18]:
# Example 2: a request that involves two function specs (device listing and battery check).
query = "List all the devices in the house and their battery percentage."

functions = [
    # Battery lookup; its only parameter defaults to "all".
    {
        "name": "Check Battery",
        "api_name": "devices.battery",
        "description": "Check the battery of a device",
        "parameters": [
            {
                "name": "devices",
                "description": "list of devices to check the battery of",
                "value": "list",
                "default": "all",
            }
        ],
        "returns": "battery percentage of the device",
    },
    # Device listing; declares no parameters.
    {
        "name": "Get Devices",
        "api_name": "devices.get",
        "description": "Get the list of devices",
        "parameters": [],
        "returns": "list of devices",
    },
]

# Send the device query together with both function specs; the returned
# string is the cell's displayed value.
get_gorilla_response(query, functions=functions)

'devices.battery()'