In [1]:
import gc

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from evaluation import evaluate_model
from prompt_data import template, action_list, object_list, answer_example, json_example, xml_example, question_example

torch.cuda.empty_cache()

In [2]:
# Return cached, unused CUDA blocks to the driver before loading the checkpoint.
torch.cuda.empty_cache()
torch_dtype = torch.float16
attn_implementation = "eager"

# 4-bit NF4 quantization with double quantization; matmuls are computed in
# `torch_dtype` (float16).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

base_model = "meta-llama/Llama-3.2-1B-Instruct"

# Load the quantized model, letting `device_map="auto"` place the weights.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

# Load the tokenizer. `trust_remote_code` is intentionally not set: it only
# allows executing code shipped in the Hub repo, which this checkpoint's
# tokenizer should not need.
tokenizer = AutoTokenizer.from_pretrained(base_model)

# '[PAD]' is registered as the padding token. When it is new to the vocabulary
# its id lies past the end of the model's embedding table, so the table must be
# grown or any padded batch would index out of range.
num_added_tokens = tokenizer.add_special_tokens({'pad_token': '[PAD]'})
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))

# Keep showing the number of added tokens as the cell output.
num_added_tokens

1

In [3]:
# Chat-style prompt with three positional slots, in order: system message,
# user message, assistant reply.
data_prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{}<|eot_id|>
<|start_header_id|>user<|end_header_id|>
{}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
{}
"""


def formatting_prompt(examples):
    """Render a batch of examples into prompt strings.

    Args:
        examples: Mapping with parallel sequences under the keys
            "instruction", "input" and "output".

    Returns:
        A dict ``{"text": [...]}`` holding one ``data_prompt`` rendering per
        (instruction, input, output) triple. Triples are paired with ``zip``,
        so extra items in a longer sequence are ignored.
    """
    triples = zip(examples["instruction"], examples["input"], examples["output"])
    rendered = [
        data_prompt.format(system_msg, user_msg, reply)
        for system_msg, user_msg, reply in triples
    ]
    return {"text": rendered}

In [4]:
# Build one system prompt per output format. Both share the question/answer
# example, the action list and the object list; only the format name and the
# format-specific example differ.
json_system, xml_system = (
    template.format(
        format_type=format_name,
        example="\n".join((question_example, answer_example, format_example)),
        available_actions=action_list,
        object_list=object_list,
    )
    for format_name, format_example in (("JSON", json_example), ("XML", xml_example))
)

In [5]:
# Run the query dataset through the model currently bound to `model`, using the
# JSON system prompt and validation_type="json". The result is indexed by
# 'score' in the summary cell further down.
json_1b = evaluate_model(
    model=model,
    tokenizer=tokenizer,
    formatting_prompt=formatting_prompt,
    validation_type="json",
    query_file="./query_dataset.json",
    instruction=json_system,
    action_list=action_list,
)

100%|██████████| 50/50 [07:14<00:00,  8.68s/it]


In [6]:
# Same evaluation as the JSON run on the same model, but with the XML system
# prompt and validation_type="xml".
xml_1b = evaluate_model(
    model=model,
    tokenizer=tokenizer,
    formatting_prompt=formatting_prompt,
    validation_type="xml",
    query_file="./query_dataset.json",
    instruction=xml_system,
    action_list=action_list,
)

100%|██████████| 50/50 [09:03<00:00, 10.86s/it]


In [7]:
# Drop the references to any previously loaded model/tokenizer first:
# `empty_cache()` can only release memory that no live tensor still occupies,
# so without this the old checkpoint stays on the GPU while the new one loads.
model = tokenizer = None
gc.collect()
torch.cuda.empty_cache()
torch_dtype = torch.float16
attn_implementation = "eager"

# 4-bit NF4 quantization with double quantization; matmuls are computed in
# `torch_dtype` (float16).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

base_model = "meta-llama/Llama-3.2-3B-Instruct"

# Load the quantized model, letting `device_map="auto"` place the weights.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

# Load the tokenizer. `trust_remote_code` is intentionally not set: it only
# allows executing code shipped in the Hub repo, which this checkpoint's
# tokenizer should not need.
tokenizer = AutoTokenizer.from_pretrained(base_model)

# '[PAD]' is registered as the padding token. When it is new to the vocabulary
# its id lies past the end of the model's embedding table, so the table must be
# grown or any padded batch would index out of range.
num_added_tokens = tokenizer.add_special_tokens({'pad_token': '[PAD]'})
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))

# Keep showing the number of added tokens as the cell output.
num_added_tokens

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

1

In [8]:
# Run the query dataset through the model currently bound to `model`, using the
# JSON system prompt and validation_type="json". The result is a dict with at
# least 'score' and 'plans' (see the display of `json_3b` below).
json_3b = evaluate_model(
    model=model,
    tokenizer=tokenizer,
    formatting_prompt=formatting_prompt,
    validation_type="json",
    query_file="./query_dataset.json",
    instruction=json_system,
    action_list=action_list,
)

100%|██████████| 50/50 [06:28<00:00,  7.76s/it]


In [9]:
# Same evaluation as the JSON run on the same model, but with the XML system
# prompt and validation_type="xml".
xml_3b = evaluate_model(
    model=model,
    tokenizer=tokenizer,
    formatting_prompt=formatting_prompt,
    validation_type="xml",
    query_file="./query_dataset.json",
    instruction=xml_system,
    action_list=action_list,
)

100%|██████████| 50/50 [05:09<00:00,  6.20s/it]


In [11]:
# Summarize the four runs, one "<format> <size> : <score>" line per run.
print(
    f"xml 3b : {xml_3b['score']}",
    f"xml 1b : {xml_1b['score']}",
    f"json 3b : {json_3b['score']}",
    f"json 1b : {json_1b['score']}",
    sep="\n",
)

xml 3b : 0.88
xml 1b : 0.38
json 3b : 0.82
json 1b : 0.36


In [13]:
# Display the full JSON-run result: a dict with 'score' and 'plans', where
# 'plans' groups [query, model response] pairs under keys such as 'no plan'.
json_3b

{'score': 0.82,
 'plans': {'no plan': [['Locate the flashlight and report its position.',
    '\n\nThe flashlight is located at position (8.7, 3.6, 0.4).<|eot_id|>'],
   ['Locate the smartphone and tell me its X, Y, and Z coordinates.',
    "\n\nThe smartphone is located at position {'x': 7.9, 'y': 6.8, 'z': 0.3}.<|eot_id|>"],
   ["Request assistance with the message: 'Help needed to identify the camping stove.'",
    "\n\nHello! I'm GoatBrain, here to help. I've identified the camping stove as an object with a position at (18.1, 7.9, 1.4). Would you like more information about it?<|eot_id|>"],
   ['Check the current battery level.',
    '\n\nYour current battery level is 87%.<|eot_id|>'],
   ['Locate the compass and report its X, Y, and Z coordinates.',
    "\n\nYou've asked me to locate the compass. I will do that now.\n\nThe compass is located at position {'x': 6.8, 'y': 2.5, 'z': 0.3}.<|eot_id|>"],
   ['Check the current battery level of the robot.',
    '\n\nThe current battery le