In [1]:
!pip install "transformers==4.34.0" "datasets==2.13.0" "peft==0.6.2" "accelerate==0.23.0" "bitsandbytes==0.41.1" "trl==0.4.7" "safetensors>=0.3.1" ipywidgets wandb --upgrade
!python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
!pip install ninja packaging
!MAX_JOBS=1 pip install flash-attn --no-build-isolation

Defaulting to user installation because normal site-packages is not writeable
Collecting peft==0.6.2
  Using cached peft-0.6.2-py3-none-any.whl (174 kB)
Installing collected packages: peft
  Attempting uninstall: peft
    Found existing installation: peft 0.4.0
    Uninstalling peft-0.4.0:
      Successfully uninstalled peft-0.4.0
Successfully installed peft-0.6.2
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [4]:
# Put the directory that contains the notebook's working directory at the front of
# the import search path (presumably so the project package imported later resolves).
import os
import sys

project_dir = os.getcwd()
parent_dir = os.path.split(project_dir)[0]  # head of the split == dirname
sys.path[:0] = [parent_dir]  # prepend, so it takes priority over other entries

In [5]:
# Quantised base LLM and its tokenizer
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "SebastianS/llama-7-chat-instruction-int4-fc-op_glaive-sft"

# 4-bit loading: NF4 quantisation type, double quantisation enabled,
# bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Model first, then tokenizer; `device_map="auto"` leaves device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", use_cache=False, quantization_config=bnb_config
)
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(model_id)
# The end-of-sequence token doubles as the padding token; padding goes on the right.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
from gaitor_function_calling.evaluation.evaluation_utils import FunctionCallingMetric, compute_perplexity, get_logits_and_labels
from gaitor_function_calling.data.prompting_utils import INSTRUCTION, json_arguments_from_prompt, generate_prediction, build_prompt

def generate_prompt(functions, prompt):
    """Assemble the input/target payload for one user message and build its prompt.

    Args:
        functions: Function schemas attached to the user turn.
        prompt: Text used as the content of the user message.

    Returns:
        The result of ``build_prompt`` called with the assembled payload and
        ``INSTRUCTION``.
    """
    # Request side: a single user turn together with the functions passed in.
    user_turn = {
        "chatgptMessage": {"role": "user", "content": prompt},
        "functions": functions,
    }

    # Target side: a fixed assistant function call whose name and arguments are
    # literal "..." fillers, plus one fixed example function schema.
    filler_call = {
        "name": "...",
        "arguments": "{\n    \"...\": \"...\"\n}",
    }
    country_property = {
        "type": "string",
        "description": "The country for which to fetch news",
    }
    headlines_schema = {
        "name": "get_news_headlines",
        "description": "Get the latest news headlines",
        "parameters": {
            "type": "object",
            "properties": {"country": country_property},
        },
    }
    target_turn = {
        "chatgptMessage": {"role": "assistant", "function_call": filler_call},
        "functions": [headlines_schema],
    }

    payload = {"input": [user_turn], "target": target_turn}
    return build_prompt(payload, INSTRUCTION)
    
# Example run: offer the model a single `find_fruit` function and a sentence that
# mentions a fruit, then parse the arguments out of the generated text.
prompt = generate_prompt([
    {
        "name": "find_fruit",
        "description": "extract the fruit from the text",
        "parameters": {
            "type": "object",
            "properties": {
                "fruit": {
                    "type": "string",
                    "description": "fruit found"
                }
            }
        }
    }
], "I ate a banana ice cream sundae!")

generated_str = generate_prediction(prompt, model, tokenizer, INSTRUCTION)

# Only the first of the three returned values is used here. The other two get
# explicit throwaway names rather than `_` / `__`: binding those names in the
# notebook namespace stops IPython from updating its output-history shortcuts.
generated_arguments, _unused_second, _unused_third = json_arguments_from_prompt(
    prompt,
    generated_str,
    INSTRUCTION
)
generated_arguments

{'fruit': 'banana'}