In [1]:
!pip install -q -U transformers datasets accelerate peft trl bitsandbytes scipy

In [2]:
import torch
import os
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel, get_peft_model
from trl import SFTTrainer

from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
model_id = "google/gemma-3-4b-it"

new_model_name = "gemma-3-4b-it-small-json-16bit"

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
eos_token = tokenizer.eos_token

print(f"EOS Token: {eos_token}")

EOS Token: <eos>


In [7]:
print(f"Sample Data Point (check EOS):\n{train_dataset[0]['formatted_text']}")

Sample Data Point (check EOS):
Convert the following food order into JSON format using this structure:
[
  { "customizations": ["Customization 1", "Customization 2"], "name": "Item Name" },
  { "customizations": [], "name": "Another Item" }
]:
Get me a lemonade, a lemon meringue pie, and add extra toppings of chocolate chips, berries, and nuts.
[{"customizations":[],"name":"Lemonade"},{"customizations":[],"name":"Lemon Meringue Pie"},{"customizations":["Chocolate Chips","Berries","Nuts"],"name":"Extra Topping"}]<eos>


In [8]:
import pandas as pd

df = train_dataset.to_pandas()
num_duplicates = df.duplicated().sum()
num_duplicates

np.int64(0)

In [6]:
from datasets import load_dataset
import json

dataset_id = "iTzMiNOS/voice-orders-small-clean-12k"
split_name = "train"

print(f"Loading dataset '{dataset_id}' (split: '{split_name}')...")
dataset = load_dataset(dataset_id, split=split_name)

columns_to_keep = ["transcribed_text", "items", "speaker"]
dataset = dataset.remove_columns([col for col in dataset.column_names if col not in columns_to_keep])

dataset = dataset.select(range(min(1200, len(dataset))))

print(f"Dataset loaded: {dataset}")

prefix = """Convert the following food order into JSON format using this structure:
[
  { "customizations": ["Customization 1", "Customization 2"], "name": "Item Name" },
  { "customizations": [], "name": "Another Item" }
]:
"""

print("Splitting dataset into train and validation sets...")
train_val_split = dataset.train_test_split(test_size=0.1, shuffle=True, seed=42)
train_data = train_val_split['train']
validation_data = train_val_split['test']

print(f"Train set size: {len(train_data)}")
print(f"Validation set size: {len(validation_data)}")

def format_data_for_sft(example):
    text = example["transcribed_text"]
    items_data = example["items"]
    json_string = json.dumps(items_data, separators=(',', ':'))
    if 'tokenizer' not in globals():
        raise NameError("Tokenizer not found. Please run the tokenizer loading cell (Cell 5) first.")
    eos = tokenizer.eos_token
    formatted_string = f"{prefix}{text}\n{json_string}{eos}"
    return {"formatted_text": formatted_string}

if 'tokenizer' in globals():
    print("Applying formatting function to the datasets...")
    train_dataset = train_data.map(format_data_for_sft, remove_columns=train_data.column_names)
    validation_dataset = validation_data.map(format_data_for_sft, remove_columns=validation_data.column_names) # Process validation data too!
    print("Dataset formatting complete.")
    print(f"Train dataset features: {train_dataset.features}")
    print(f"Validation dataset features: {validation_dataset.features}")
else:
    print("WARNING: Tokenizer not loaded yet. Re-run Cell 5 and this cell's mapping part.")
    train_dataset = None
    validation_dataset = None

Loading dataset 'iTzMiNOS/voice-orders-small-clean-12k' (split: 'train')...
Dataset loaded: Dataset({
    features: ['transcribed_text', 'speaker', 'items'],
    num_rows: 1200
})
Splitting dataset into train and validation sets...
Train set size: 1080
Validation set size: 120
Applying formatting function to the datasets...
Dataset formatting complete.
Train dataset features: {'formatted_text': Value(dtype='string', id=None)}
Validation dataset features: {'formatted_text': Value(dtype='string', id=None)}


In [9]:
validation_data

Dataset({
    features: ['transcribed_text', 'speaker', 'items'],
    num_rows: 120
})

In [10]:
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
  print("GPU supports bfloat16, using torch.bfloat16 for model loading and training.")
  model_dtype = torch.bfloat16
else:
  print("GPU does not support bfloat16, using torch.float16 for model loading and training.")
  model_dtype = torch.float16

GPU supports bfloat16, using torch.bfloat16 for model loading and training.


In [11]:
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=model_dtype,
    device_map="auto",
)

print(f"Base model loaded in {model_dtype}.")
print(model)

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

Base model loaded in torch.bfloat16.
Gemma3ForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(4096, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (self_attn): SiglipAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
     

In [12]:
peft_config = LoraConfig(
    lora_alpha=128,
    lora_dropout=0.05,
    r=32,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

In [13]:
from trl import SFTConfig

sft_config = SFTConfig(
     output_dir=new_model_name,
     num_train_epochs=3,
     per_device_train_batch_size=2,
     gradient_accumulation_steps=4,
     optim="paged_adamw_8bit",
     eval_strategy="steps",
     eval_steps=100,
     logging_steps=10,
     learning_rate=5e-5,
     weight_decay=0.001,
     fp16=(model_dtype==torch.float16),
     bf16=(model_dtype==torch.bfloat16),
     max_grad_norm=0.3,
     max_steps=-1,
     warmup_ratio=0.03,
     group_by_length=True,
     lr_scheduler_type="cosine",
     report_to="tensorboard",
     save_strategy="steps",
     save_steps=100,
     save_total_limit=2,
     load_best_model_at_end=True,
     metric_for_best_model="eval_loss",
     greater_is_better=False,
     max_seq_length=1024,
     packing=False,
     dataset_text_field="formatted_text",
     push_to_hub=True,
     hub_model_id=f"iTzMiNOS/{new_model_name}",
)

print("SFTConfig configured.")

SFTConfig configured.


In [14]:
if 'train_dataset' not in globals() or 'validation_dataset' not in globals() or train_dataset is None or validation_dataset is None:
     raise ValueError("Training or Validation dataset not found or not processed. Please ensure Cell 4 ran correctly after loading the tokenizer.")
if 'model' not in globals():
      raise ValueError("Model not loaded.")
if 'tokenizer' not in globals():
      raise ValueError("Tokenizer not loaded.")
if 'sft_config' not in globals():
      raise ValueError("SFTConfig not defined.")
if 'peft_config' not in globals():
      raise ValueError("peft_config not defined.")

tokenizer.chat_template = None
print("Attempting SFTTrainer initialization using SFTConfig...")
try:
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        peft_config=peft_config,
        processing_class=tokenizer,
    )
    print("SFTTrainer initialized successfully using SFTConfig!")

except TypeError as e:
    print(f"SFTTrainer STILL FAILED with TypeError: {e}")
    print("Check if 'tokenizer' is now the unexpected argument.")
except Exception as e:
    print(f"SFTTrainer FAILED with another error: {e}")

Attempting SFTTrainer initialization using SFTConfig...


Converting train dataset to ChatML:   0%|          | 0/1080 [00:00<?, ? examples/s]

Adding EOS to train dataset:   0%|          | 0/1080 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/1080 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/1080 [00:00<?, ? examples/s]

Converting eval dataset to ChatML:   0%|          | 0/120 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/120 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/120 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/120 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


SFTTrainer initialized successfully using SFTConfig!


In [15]:
print("Starting fine-tuning...")
trainer.train()
print("Fine-tuning finished.")

Starting fine-tuning...


It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.


Step,Training Loss,Validation Loss
100,0.8726,0.274509
200,0.7238,0.226427
300,0.606,0.215358
400,0.5876,0.213853


Fine-tuning finished.


In [16]:
print(f"Saving fine-tuned adapters to ./{new_model_name}")
trainer.model.save_pretrained(new_model_name)
tokenizer.save_pretrained(new_model_name)
print("Adapters and tokenizer saved.")

# Optional: Clean up memory
# del model
# del trainer
# torch.cuda.empty_cache()

Saving fine-tuned adapters to ./gemma-3-4b-it-small-json-16bit
Adapters and tokenizer saved.


In [11]:
# Cell: Inference and Evaluation (Batched & GPU Ensured & Fence Cleaning)

import json
import pandas as pd
from datasets import Dataset
from transformers import (
    pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
    logging
)
from peft import PeftModel
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from difflib import SequenceMatcher
import torch
import gc
import os
import math
import numpy as np
import re # Import regex for cleaning fences

# --- Configuration ---
base_model_id = "google/gemma-3-4b-it"
adapter_model_id = f"iTzMiNOS/{new_model_name}"
prefix = """Convert the following food order into JSON format using this structure:
[
  { "customizations": ["Customization 1", "Customization 2"], "name": "Item Name" },
  { "customizations": [], "name": "Another Item" }
]:
"""
inference_batch_size = 16

if not torch.cuda.is_available():
     raise SystemError("CUDA is not available. This script requires a GPU.")
else:
     device_name = torch.cuda.get_device_name(0)
     print(f"CUDA is available. Using GPU: {device_name}")
     if torch.cuda.get_device_capability(0)[0] >= 8:
         print("GPU supports bfloat16, using torch.bfloat16 for inference.")
         model_dtype_inference = torch.bfloat16
     else:
         print("GPU does not support bfloat16, using torch.float16 for inference.")
         model_dtype_inference = torch.float16
     device = 0

# --- Memory Cleanup ---
print("Cleaning up memory before loading...")
gc.collect()
torch.cuda.empty_cache()
print("CUDA cache cleared.")

# --- Load Tokenizer ---
print(f"Loading tokenizer from {base_model_id}...")
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
eos_token = tokenizer.eos_token
assert tokenizer.pad_token_id is not None, "Tokenizer pad_token_id is not set!"
print("Tokenizer loaded.")

# --- Load Base Model ---
print(f"Loading base model ({base_model_id}) in {model_dtype_inference}")
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=model_dtype_inference,
    device_map="auto",
    trust_remote_code=True,
)
print(f"Base model loaded. Device map: {model.hf_device_map}")

# --- Load LoRA Adapter ---
print(f"Loading LoRA adapter ({adapter_model_id}) onto the base model...")
try:
    model = PeftModel.from_pretrained(model, adapter_model_id)
    print("LoRA adapter loaded successfully.")
    print("Attempting to merge LoRA adapter...")
    try:
        model = model.merge_and_unload()
        print("LoRA adapter merged successfully.")
    except Exception as e:
        print(f"⚠️ Could not merge LoRA adapter: {e}. Proceeding with PEFT model.")
except Exception as e:
     print(f"❌ Failed to load LoRA adapter: {e}")
     raise e

# --- Build the Inference Pipeline ---
logging.set_verbosity(logging.CRITICAL)
print("Building text-generation pipeline...")
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
if hasattr(pipe, 'device'): print(f"Pipeline device: {pipe.device}")
else: print("Pipeline device managed by model's device_map.")

# --- Load Validation Data ---
if 'validation_data' not in globals():
     print("validation_data not found, attempting reload...")
     from datasets import load_dataset
     dataset_id = "iTzMiNOS/voice-orders-small-clean-12k"
     split_name = "train"
     dataset = load_dataset(dataset_id, split=split_name)
     columns_to_keep = ["transcribed_text", "items", "speaker"]
     dataset = dataset.remove_columns([col for col in dataset.column_names if col not in columns_to_keep])
     dataset = dataset.select(range(min(1200, len(dataset))))
     train_val_split = dataset.train_test_split(test_size=0.1, shuffle=True, seed=42)
     validation_data = train_val_split['test']
     print("Reloaded and split dataset.")

df = validation_data.to_pandas()
print(f"Loaded validation data with {len(df)} rows.")

# --- Convert Numpy arrays in 'items' column ---
def deep_convert(obj):
    if isinstance(obj, dict):
        return {k: deep_convert(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [deep_convert(v) for v in obj]
    elif isinstance(obj, np.ndarray):
        return deep_convert(obj.tolist())
    else:
        return obj

df['items'] = df['items'].apply(deep_convert)
print("Conversion complete.")


# --- ***** UPDATED: Function for Robust JSON Parsing (Handles Fences) ***** ---
def parse_json_robustly(generated_text):
    """Attempts to extract and parse JSON, handling optional markdown fences."""
    json_str = None
    try:
        # Find the start of the first list or object
        first_bracket = generated_text.find('[')
        first_brace = generated_text.find('{')

        start_index = -1
        if first_bracket != -1 and (first_brace == -1 or first_bracket < first_brace):
            start_index = first_bracket
            start_char = '['
            end_char = ']'
        elif first_brace != -1:
            start_index = first_brace
            start_char = '{'
            end_char = '}'
        else:
             # Neither bracket nor brace found - maybe it's ONLY fences?
             # Try finding fences directly if no brace/bracket
             fence_match = re.search(r"`{3}(json)?\s*([\[\{])", generated_text)
             if fence_match:
                 start_index = fence_match.end() -1 # Start at the brace/bracket
                 start_char = fence_match.group(2)
                 end_char = ']' if start_char == '[' else '}'
             else:
                # Give up if no structure found
                # print(f"Debug: No JSON start ('[' or '{{') found. Output: {generated_text}")
                return None

        # Find the corresponding closing character using balancing
        open_count = 0
        end_index = -1
        # Check if start_index is valid before proceeding
        if start_index >= 0 and start_index < len(generated_text):
            for i in range(start_index, len(generated_text)):
                if generated_text[i] == start_char:
                    open_count += 1
                elif generated_text[i] == end_char:
                    open_count -= 1
                if open_count == 0:
                    end_index = i
                    break
        else:
             # Handle invalid start_index if fence logic above failed unusually
             print(f"Debug: Invalid start_index {start_index}. Output: {generated_text}")
             return None


        if end_index == -1:
            # print(f"Debug: No matching closing bracket/brace. Output: {generated_text}")
            return None

        # Extract the potential JSON substring
        json_str = generated_text[start_index : end_index + 1]

        # Clean leading/trailing whitespace that might remain
        json_str = json_str.strip()

        # --- Attempt to parse the extracted & cleaned string ---
        json_data = json.loads(json_str)
        return json_data

    except json.JSONDecodeError as e:
        # Add logging for parse failures, include the string attempted
        print(f"Warning: Could not parse JSON: {e}.")
        print(f"Attempted to parse (after extraction): '{json_str}'")
        # print(f"Original Generated Text: {generated_text}") # Uncomment for deeper debugging
        return None
    except Exception as e:
        # Catch any other unexpected errors during parsing
        print(f"Warning: Unexpected error during JSON parsing: {e}.")
        print(f"Attempted to parse (after extraction): '{json_str}'")
        # print(f"Original Generated Text: {generated_text}") # Uncomment for deeper debugging
        return None
# --- End Updated Function ---


# --- Apply Inference with Batching (Unchanged) ---
print(f"Running batched inference (batch size: {inference_batch_size})...")
all_prompts = [f"{prefix}{text}" for text in df['transcribed_text']]
all_results = []
num_batches = math.ceil(len(all_prompts) / inference_batch_size)
for i in tqdm(range(0, len(all_prompts), inference_batch_size), desc="Inference Batches", total=num_batches):
    batch_prompts = all_prompts[i:i+inference_batch_size]
    try:
        batch_outputs = pipe(batch_prompts, max_new_tokens=500, return_full_text=False, pad_token_id=tokenizer.eos_token_id, batch_size=len(batch_prompts))
        for output_list in batch_outputs:
            if output_list and isinstance(output_list, list):
                 generated_text = output_list[0]["generated_text"].strip()
                 parsed_json = parse_json_robustly(generated_text) # Use updated parser
                 all_results.append(parsed_json)
            else: print(f"Warning: Unexpected output format: {output_list}"); all_results.append(None)
    except Exception as e:
        print(f"\n--- ERROR during batch {i // inference_batch_size + 1} --- Error: {e}")
        all_results.extend([None] * len(batch_prompts))
if len(all_results) != len(all_prompts):
     print(f"Warning: Result count mismatch! Padding with None.")
     all_results.extend([None] * (len(all_prompts) - len(all_results)))
df['predicted_items'] = all_results
print("Inference complete.")

def to_lower(obj):
    if isinstance(obj, str):
        return obj.lower()  # Convert strings to lowercase
    elif isinstance(obj, dict):
        return {k: to_lower(v) for k, v in obj.items()}  # Apply recursively for dictionaries
    elif isinstance(obj, list):
        return [to_lower(v) for v in obj]  # Apply recursively for lists
    else:
        return obj

# --- Comparison Metric (Unchanged) ---
def similarity_score(pred, target):
    if pred is None or target is None:
        return 0.0
    try:
        # Convert both the prediction and target to lowercase
        pred = to_lower(pred)
        target = to_lower(target)

        # Convert the structures into strings
        pred_str = json.dumps(pred, sort_keys=True, separators=(',', ':'))
        target_str = json.dumps(target, sort_keys=True, separators=(',', ':'))

        return SequenceMatcher(None, pred_str, target_str).ratio()
    except Exception as e:
        print(f"Error calculating similarity: Pred={pred}, Target={target}, Error={e}")
        return 0.0
# --- Calculate Metrics (Unchanged) ---
print("Calculating metrics...")
df['similarity'] = df.apply(lambda row: similarity_score(row['predicted_items'], row['items']), axis=1)
df['exact_match'] = df.apply(lambda row:
                             row['predicted_items'] is not None and
                             row['items'] is not None and
                             to_lower(row['predicted_items']) == to_lower(row['items']),
                             axis=1)

average_similarity = df['similarity'].mean()
exact_match_accuracy = df['exact_match'].mean()

print("\n--- Evaluation Results ---")
print(f"🔍 Average Similarity Score: {average_similarity:.4f}")
print(f"✅ Exact Match Accuracy: {exact_match_accuracy:.2%}")

# --- Display Mismatches (Unchanged) ---
print("\n--- Low Similarity Examples (< 0.8) ---")
low_sim_df = df[df['similarity'] < 0.8][['transcribed_text', 'items', 'predicted_items', 'similarity']]
print(low_sim_df.to_string())

CUDA is available. Using GPU: NVIDIA L4
GPU supports bfloat16, using torch.bfloat16 for inference.
Cleaning up memory before loading...
CUDA cache cleared.
Loading tokenizer from google/gemma-3-4b-it...
Tokenizer loaded.
Loading base model (google/gemma-3-4b-it) in torch.bfloat16


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Base model loaded. Device map: {'': 0}
Loading LoRA adapter (iTzMiNOS/gemma-3-4b-it-small-json-16bit) onto the base model...
LoRA adapter loaded successfully.
Attempting to merge LoRA adapter...
LoRA adapter merged successfully.
Building text-generation pipeline...
Pipeline device: cuda:0
Loaded validation data with 120 rows.
Conversion complete.
Running batched inference (batch size: 16)...


Inference Batches: 100%|██████████| 8/8 [00:53<00:00,  6.73s/it]

Inference complete.
Calculating metrics...

--- Evaluation Results ---
🔍 Average Similarity Score: 0.9832
✅ Exact Match Accuracy: 82.50%

--- Low Similarity Examples (< 0.8) ---
                                                                                                                                            transcribed_text                                                                                                                                                                                                                                                                               items                                                                                                                                                                                                                                             predicted_items  similarity
21                  Ordering a Caesar Salad with grilled shrimp, a Greek Salad with grilled chicken, an apple pie, and some e




In [12]:
df.head(5)

Unnamed: 0,transcribed_text,speaker,items,predicted_items,similarity,exact_match
0,I'll have some chicken and vegetarian spring r...,af_bella,"[{'customizations': ['Chicken', 'Vegetarian'],...","[{'customizations': ['Chicken', 'Vegetarian'],...",1.0,True
1,"Please order a chocolate cake, an apple pie, a...",af_bella,"[{'customizations': [], 'name': 'Chocolate Cak...","[{'customizations': [], 'name': 'Chocolate Cak...",1.0,True
2,"Please add extra ranch dipping sauce, extra ch...",af_bella,"[{'customizations': ['Ranch'], 'name': 'Extra ...","[{'customizations': ['Ranch'], 'name': 'Extra ...",1.0,True
3,"Can I get a lemonade, a chocolate, and extra w...",af_bella,"[{'customizations': [], 'name': 'Lemonade'}, {...","[{'customizations': [], 'name': 'Lemonade'}, {...",1.0,True
4,"I'll have a bottle of red wine, a slice of che...",af_bella,"[{'customizations': ['Red'], 'name': 'Wine'}, ...","[{'customizations': ['Red', 'White'], 'name': ...",0.972222,False


In [13]:
df['predicted_items'].isnull().sum()

np.int64(0)

In [14]:
input_text = "Fish and Chips with crispy coating, a slice of Chocolate Cake, sweetened Iced Tea, Onion Rings with BBQ sauce, extra Gravy for my fries, an extra scoop of ice cream, Soft Drinks including Fanta and Sprite, Grilled Chicken Breast with BBQ and garlic butter, Lemon Meringue Pie for dessert, and flavored syrups of vanilla and caramel."

In [16]:
res_pred = pipe(prefix + input_text, max_new_tokens=500, return_full_text=False, pad_token_id=tokenizer.eos_token_id)

In [17]:
similarity_score(res_pred, [
        {
            "name": "Fish and Chips",
            "customizations": [
                "Crispy"
            ]
        },
        {
            "name": "Chocolate Cake",
            "customizations": []
        },
        {
            "name": "Iced Tea",
            "customizations": [
                "Sweetened"
            ]
        },
        {
            "name": "Onion Rings",
            "customizations": [
                "BBQ"
            ]
        },
        {
            "name": "Extra Sauce",
            "customizations": [
                "Gravy"
            ]
        },
        {
            "name": "Extra Ice Cream Scoop",
            "customizations": []
        },
        {
            "name": "Soft Drinks",
            "customizations": [
                "Fanta",
                "Sprite"
            ]
        },
        {
            "name": "Grilled Chicken Breast",
            "customizations": [
                "BBQ",
                "Garlic Butter"
            ]
        },
        {
            "name": "Lemon Meringue Pie",
            "customizations": []
        },
        {
            "name": "Flavored Syrups",
            "customizations": [
                "Vanilla",
                "Caramel"
            ]
        }
    ])

0.3699421965317919