# ☕ BaristaBench: Local GPU, Zero API Keys

## Architecture: Qwen2.5-14B → Validator → Deterministic Price Engine

**Our edges over the starter notebooks:**
1. **14B model** vs their 1.5B/8B — a much larger model, which copes better with messy corrections
2. **Correct JSON schema** — the starters use the WRONG output format
3. **10 curated few-shot examples** vs their 3 random ones — covers the main edge cases
4. **Full menu in prompt** — the starters truncate it to 500 chars (!)
5. **Deterministic price engine** — verified 100% on all 500 training rows. The LLM never does math.
6. **Validation + fuzzy matching** — catches and fixes model formatting mistakes
7. **Retry on failure** — malformed JSON gets a second chance

**Accelerator:** GPU T4 x2 (32GB VRAM total) — runs Qwen2.5-14B in float16 with no quantization loss (the loader falls back to 4-bit NF4 quantization when less than 30 GB of VRAM is detected).


## Libraries & Data Loading

In [None]:
# =============================================================
# IMPORTS & DATA LOADING
# =============================================================
!pip install -q -U transformers accelerate bitsandbytes

import pandas as pd
import numpy as np
import json
import re
import os
import time
import gc
import warnings
warnings.filterwarnings('ignore')

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
from difflib import get_close_matches

# --- GPU Check ---
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
for i in range(torch.cuda.device_count()):
    name = torch.cuda.get_device_name(i)
    mem = torch.cuda.get_device_properties(i).total_mem / 1e9
    print(f"  GPU {i}: {name} ({mem:.1f} GB)")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TOTAL_VRAM = sum(torch.cuda.get_device_properties(i).total_mem for i in range(torch.cuda.device_count())) / 1e9
print(f"\nTotal VRAM: {TOTAL_VRAM:.1f} GB")

# --- Model Configuration ---
# Qwen2.5-14B-Instruct: Best balance of accuracy + speed for T4x2
# Fits in float16 across 2x T4 (28GB model < 32GB VRAM)
# NOTE(review): that leaves only a few GB for activations / KV cache. If
# device_map="auto" cannot place every layer on the GPUs it offloads some to
# CPU/disk; this is detected and reported right after loading.
MODEL_ID = "Qwen/Qwen2.5-14B-Instruct"

# If you pre-loaded the model as a Kaggle dataset, use that path instead:
# MODEL_ID = "/kaggle/input/qwen2.5-14b-instruct/transformers/default/1"

# For faster runtime (slightly less accurate), switch to 7B:
# MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"

print(f"\nLoading model: {MODEL_ID}")
print("This may take 5-10 minutes on first run...")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Generation pads with the EOS id, so reuse EOS as the pad token.
    tokenizer.pad_token = tokenizer.eos_token

# Load model: float16 sharded across the GPUs when at least 30 GB of VRAM is
# visible, otherwise 4-bit NF4 quantization so it fits on smaller cards.
if TOTAL_VRAM >= 30:
    print("Loading in float16 (full precision, best quality)...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
else:
    print("Loading in 4-bit quantization (fits smaller GPUs)...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

model.eval()
print("✅ Model loaded successfully")

device_map = getattr(model, "hf_device_map", None)
print(f"Model device map: {device_map if device_map is not None else 'single device'}")

# Layers placed on "cpu"/"disk" still work but make generation dramatically
# slower; surface that instead of letting a multi-hour run silently crawl.
offloaded = sorted({str(d) for d in (device_map or {}).values() if str(d) in ("cpu", "disk")})
if offloaded:
    print(f"⚠️ Some layers were offloaded to {', '.join(offloaded)} — generation will be very slow. "
          "Consider the 4-bit path or the 7B model.")

# --- Load Data ---
# All competition files live side by side in one Kaggle input directory.
DATA_DIR = "/kaggle/input/barista-bench"

train_df, test_df, sample_sub = (
    pd.read_csv(os.path.join(DATA_DIR, csv_name))
    for csv_name in ("train.csv", "test.csv", "sample_submission.csv")
)

# The raw menu markdown is kept around for reference.
with open(os.path.join(DATA_DIR, "menu.md"), "r") as menu_file:
    MENU_TEXT = menu_file.read()

print(f"\nTrain: {len(train_df)} rows | Test: {len(test_df)} rows")


## Exploratory Data Analysis

In [None]:
# =============================================================
# EXPLORATORY DATA ANALYSIS
# =============================================================

# Per-order features derived from the gold JSON labels.
train_df['parsed'] = train_df['expected_json'].apply(json.loads)
train_df['num_items'] = train_df['parsed'].apply(lambda x: len(x['items']))
train_df['total_qty'] = train_df['parsed'].apply(lambda x: sum(i['quantity'] for i in x['items']))
train_df['has_mods'] = train_df['parsed'].apply(lambda x: any(len(i['modifiers']) > 0 for i in x['items']))
train_df['order_len'] = train_df['order'].str.len()

print("=" * 60)
print("DATASET OVERVIEW")
print("=" * 60)
print(f"Avg order length: {train_df['order_len'].mean():.0f} chars | Max: {train_df['order_len'].max()}")
print(f"Orders with modifiers: {train_df['has_mods'].mean()*100:.1f}%")

print(f"\n--- Items per Order ---")
for n, c in train_df['num_items'].value_counts().sort_index().items():
    print(f"  {n} items: {c} orders ({c/len(train_df)*100:.0f}%)")

# Correction patterns — this is where most models fail
print(f"\n--- Correction Patterns (Key Challenge) ---")
orders_lower = train_df['order'].fillna('').str.lower()  # lower-case once, reuse everywhere
corrections = {
    'scratch that':     r'scratch that',
    'actually nevermind': r'actually nevermind',
    'cancel that':      r'cancel that',
    'remove that':      r'remove that',
    'bump/make/change qty': r'bump that|make (?:it|that)|change that to',
}
total_corrections = 0  # sum of per-pattern hits; an order matching two patterns counts twice
for label, pat in corrections.items():
    count = orders_lower.str.contains(pat, regex=True).sum()
    total_corrections += count
    print(f"  {label:25s}: {count:3d} ({count/len(train_df)*100:.1f}%)")

# An order "has a correction" if it matches ANY pattern from the breakdown above
# or one of the looser keywords (e.g. a bare "nevermind"), so the total can never
# be smaller than a single category (e.g. "make it N" orders are included).
correction_keywords = ['scratch', 'cancel', 'nevermind', 'remove that', 'bump that', 'change that']
any_breakdown_pattern = '|'.join(f'(?:{p})' for p in corrections.values())
orders_with_corrections = (
    orders_lower.str.contains(any_breakdown_pattern, regex=True)
    | orders_lower.apply(lambda x: any(k in x for k in correction_keywords))
)
print(f"\n  Total orders with corrections: {orders_with_corrections.sum()} ({orders_with_corrections.mean()*100:.0f}%)")

# Empty orders (everything cancelled)
empty = train_df[train_df['num_items'] == 0]
print(f"\n--- Full Cancellations: {len(empty)} orders ---")
for _, r in empty.head(3).iterrows():
    print(f"  \"{r['order'][:80]}...\"")

# Unique vocabulary
all_items = [item for p in train_df['parsed'] for item in p['items']]
print(f"\n--- Vocabulary ---")
print(f"  Unique item names:  {len(set(i['name'] for i in all_items))}")
print(f"  Unique modifiers:   {len(set(m for i in all_items for m in i['modifiers']))}")
print(f"  Unique sizes:       {len(set(i['size'] for i in all_items if i['size']))}")
# Orders in which the same item name appears in more than one entry.
n_dup_name_orders = sum(
    1 for p in train_df['parsed']
    if len(p['items']) != len({i['name'] for i in p['items']})
)
print(f"  Orders w/ duplicate item names: {n_dup_name_orders}")


## Deterministic Pricing Engine (100% Verified)

This is our biggest competitive advantage: the LLM never calculates prices.
We calculate them ourselves — verified perfect on all 500 training rows.


In [None]:
# =============================================================
# DETERMINISTIC PRICING ENGINE
# =============================================================

# Base price in dollars for every menu item, keyed by its canonical name.
MENU_PRICES = {
    "Espresso": 3.00,
    "Americano": 3.50,
    "Drip Coffee": 2.50,
    "Latte": 4.50,
    "Cappuccino": 4.50,
    "Flat White": 4.75,
    "Mocha": 5.00,
    "Caramel Macchiato": 5.25,
    "Cold Brew": 4.25,
    "Iced Coffee": 3.00,
    "Frappe (Coffee)": 5.50,
    "Frappe (Mocha)": 5.75,
    "Strawberry Smoothie": 6.00,
    "Chai Latte": 4.75,
    "Matcha Latte": 5.25,
    "Earl Grey Tea": 3.00,
    "Green Tea": 3.00,
    "Hot Chocolate": 4.00,
    # Food items (see FOOD_ITEMS) — never size-adjusted.
    "Butter Croissant": 3.50,
    "Blueberry Muffin": 3.75,
    "Bagel": 2.50,
    "Avocado Toast": 7.00,
    "Bacon Gouda Sandwich": 5.50,
}

# Amount added to a drink's base price per size; "Tall" is the zero point.
SIZE_ADJUSTMENTS = {
    "Short": -0.50,
    "Tall": 0.00,
    "Grande": 0.50,
    "Venti": 1.00,
    "Trenta": 1.50,
}

# Per-unit surcharge for each modifier (several are free).
MODIFIER_COSTS = {
    "Oat Milk": 0.80,
    "Almond Milk": 0.60,
    "Soy Milk": 0.60,
    "Coconut Milk": 0.70,
    "Breve": 0.80,
    "Skim Milk": 0.00,
    "Vanilla Syrup": 0.50,
    "Caramel Syrup": 0.50,
    "Hazelnut Syrup": 0.50,
    "Peppermint Syrup": 0.50,
    "Sugar Free Vanilla": 0.50,
    "Classic Syrup": 0.00,
    "Extra Shot": 1.00,
    "Whip Cream": 0.50,
    "No Whip": 0.00,
    "Cold Foam": 1.25,
    "Caramel Drizzle": 0.50,
    "Extra Hot": 0.00,
    "Light Ice": 0.00,
    "No Ice": 0.00,
}

# Items that carry no size (their size is always null in the output).
FOOD_ITEMS = {
    "Butter Croissant",
    "Blueberry Muffin",
    "Bagel",
    "Avocado Toast",
    "Bacon Gouda Sandwich",
}

# Closed vocabularies used by the validator.
VALID_NAMES = set(MENU_PRICES)
VALID_SIZES = {*SIZE_ADJUSTMENTS, None}
VALID_MODIFIERS = set(MODIFIER_COSTS)


def calculate_price(items: list) -> float:
    """Return the order total for a list of item dicts, rounded to 2 decimals.

    Each entry is priced per unit as: base menu price, plus the size adjustment
    (skipped for FOOD_ITEMS), plus the sum of its modifier surcharges; that unit
    price is multiplied by ``quantity`` (default 1). Names, sizes or modifiers
    missing from the price tables contribute 0.
    """
    running_total = 0.0
    for entry in items:
        item_name = entry["name"]
        unit_price = MENU_PRICES.get(item_name, 0)
        if item_name not in FOOD_ITEMS:
            unit_price = unit_price + SIZE_ADJUSTMENTS.get(entry.get("size"), 0)
        surcharge = sum(MODIFIER_COSTS.get(mod, 0) for mod in entry.get("modifiers", []))
        unit_price = unit_price + surcharge
        # Accumulate sequentially (not via sum()) to keep the exact float order.
        running_total += unit_price * entry.get("quantity", 1)
    return round(running_total, 2)


# --- Validate on entire training set ---
# Re-price every gold order from its own items. Any disagreement with the
# labelled total_price means the price tables above are wrong somewhere.
mismatches = 0
for row_id, raw_json in zip(train_df['id'], train_df['expected_json']):
    parsed = json.loads(raw_json)
    calc = calculate_price(parsed['items'])
    if abs(calc - parsed['total_price']) > 0.001:
        mismatches += 1
        print(f"MISMATCH ID {row_id}: calc={calc}, expected={parsed['total_price']}")

print(f"Pricing Engine: {len(train_df) - mismatches}/{len(train_df)} correct")
if mismatches == 0:
    print("✅ PERFECT — 100% accuracy on all training rows")
else:
    print(f"❌ {mismatches} training rows disagree with the price tables — fix them before inference")


## Output Validation & Post-Processing

Critical safety net: fuzzy-matches names/sizes/modifiers, strips hallucinations, recalculates price.


In [None]:
# =============================================================
# OUTPUT VALIDATION & POST-PROCESSING
# =============================================================

# Case-insensitive lookups: lower-cased spelling -> canonical string.
NAME_LOOKUP = {canonical.lower(): canonical for canonical in VALID_NAMES}
SIZE_LOOKUP = {canonical.lower(): canonical for canonical in SIZE_ADJUSTMENTS}
MOD_LOOKUP = {canonical.lower(): canonical for canonical in VALID_MODIFIERS}

# Hand-written spellings (already lower-cased) the model is likely to emit for
# item names, mapped to the canonical menu name.
NAME_ALIASES = {
    "frappe coffee": "Frappe (Coffee)",
    "frappe mocha": "Frappe (Mocha)",
    "coffee frappe": "Frappe (Coffee)",
    "mocha frappe": "Frappe (Mocha)",
    "frappe(coffee)": "Frappe (Coffee)",
    "frappe(mocha)": "Frappe (Mocha)",
    "frappe (coffee)": "Frappe (Coffee)",
    "frappe (mocha)": "Frappe (Mocha)",
    "caramel macchiatto": "Caramel Macchiato",
    "bacon gouda": "Bacon Gouda Sandwich",
    "earl grey": "Earl Grey Tea",
    "hot choco": "Hot Chocolate",
}

# Same idea for modifiers.
MOD_ALIASES = {
    "oat": "Oat Milk",
    "almond": "Almond Milk",
    "soy": "Soy Milk",
    "coconut": "Coconut Milk",
    "half & half": "Breve",
    "half and half": "Breve",
    "sugar free vanilla syrup": "Sugar Free Vanilla",
    "sf vanilla": "Sugar Free Vanilla",
    "whipped cream": "Whip Cream",
    "whip": "Whip Cream",
    "no whipped cream": "No Whip",
}


def normalize_name(name: str) -> str:
    """Map a model-produced item name onto a canonical menu name.

    Resolution order: exact match, case-insensitive match (NAME_LOOKUP), alias
    (NAME_ALIASES), then difflib fuzzy match with ratio >= 0.7. Anything that
    cannot be resolved is returned unchanged so the caller can drop it.

    Non-string input (e.g. ``None`` from a JSON ``null``) is also returned
    unchanged instead of raising on ``.lower()``. Raising here would abort the
    whole order in the caller's retry loop.
    """
    if not isinstance(name, str):
        return name
    if name in VALID_NAMES:
        return name
    low = name.lower().strip()
    if low in NAME_LOOKUP:
        return NAME_LOOKUP[low]
    if low in NAME_ALIASES:
        return NAME_ALIASES[low]
    matches = get_close_matches(low, NAME_LOOKUP.keys(), n=1, cutoff=0.7)
    if matches:
        return NAME_LOOKUP[matches[0]]
    return name


def normalize_size(size) -> str | None:
    """Return the canonical size string for ``size``, or None.

    ``None`` and the strings "null"/"none"/"" (any case) mean "no size". Other
    values are stringified and stripped, then matched exactly or
    case-insensitively against SIZE_ADJUSTMENTS; unknown sizes give None.
    """
    if size is None:
        return None
    text = str(size)
    if text.lower() in ("null", "none", ""):
        return None
    text = text.strip()
    if text in SIZE_ADJUSTMENTS:
        return text
    return SIZE_LOOKUP.get(text.lower())


def normalize_modifier(mod: str) -> str:
    """Map a model-produced modifier onto a canonical modifier string.

    Resolution order: exact match, case-insensitive match (MOD_LOOKUP), alias
    (MOD_ALIASES), then difflib fuzzy match with ratio >= 0.7. Unresolvable
    input is returned unchanged; the caller filters anything outside
    VALID_MODIFIERS.

    Non-string input (e.g. ``None`` from a JSON ``null``) is also returned
    unchanged instead of raising on ``.lower()``.
    """
    if not isinstance(mod, str):
        return mod
    if mod in VALID_MODIFIERS:
        return mod
    low = mod.lower().strip()
    if low in MOD_LOOKUP:
        return MOD_LOOKUP[low]
    if low in MOD_ALIASES:
        return MOD_ALIASES[low]
    matches = get_close_matches(low, MOD_LOOKUP.keys(), n=1, cutoff=0.7)
    if matches:
        return MOD_LOOKUP[matches[0]]
    return mod


def validate_and_fix(parsed: dict) -> dict:
    """Coerce a model-produced order into the canonical schema and re-price it.

    For each entry of ``parsed["items"]``:
      * the name is canonicalised with ``normalize_name``; entries whose name is
        missing, not a string, or not on the menu are dropped;
      * the size is canonicalised with ``normalize_size``. Food items always
        get ``None``. Drinks with no recognisable size get ``"Tall"``, the
        SYSTEM_PROMPT default for drinks without a stated size;
      * the quantity is coerced to an ``int`` >= 1, falling back to 1;
      * modifiers are canonicalised with ``normalize_modifier`` and anything
        outside ``VALID_MODIFIERS`` is discarded.

    Malformed shapes are tolerated rather than raising, so one odd field cannot
    void a whole order. These include ``parsed`` or an item not being a dict,
    ``items`` not being a list, and ``modifiers`` given as a lone string or
    ``null``. The model's ``total_price`` is ignored and recomputed with
    ``calculate_price``.

    Returns:
        ``{"items": [...], "total_price": float}``
    """
    items = parsed.get("items") if isinstance(parsed, dict) else None
    if not isinstance(items, list):
        items = []

    fixed_items = []
    for item in items:
        if not isinstance(item, dict):
            continue

        raw_name = item.get("name")
        if not isinstance(raw_name, str):
            continue
        name = normalize_name(raw_name)
        if name not in VALID_NAMES:
            continue  # hallucinated / unmatchable item

        if name in FOOD_ITEMS:
            size = None
        else:
            size = normalize_size(item.get("size")) or "Tall"

        qty = item.get("quantity", 1)
        # bool is a subclass of int; convert it so True doesn't leak into JSON.
        if isinstance(qty, bool) or not isinstance(qty, int):
            try:
                qty = int(qty)
            except (TypeError, ValueError, OverflowError):
                qty = 1
        if qty < 1:
            qty = 1

        raw_mods = item.get("modifiers")
        if isinstance(raw_mods, str):
            raw_mods = [raw_mods]  # a lone string would otherwise be iterated char by char
        elif not isinstance(raw_mods, list):
            raw_mods = []
        mods = [normalize_modifier(m) for m in raw_mods if isinstance(m, str)]
        valid_mods = [m for m in mods if m in VALID_MODIFIERS]

        fixed_items.append({
            "name": name, "size": size,
            "quantity": qty, "modifiers": valid_mods,
        })

    return {"items": fixed_items, "total_price": calculate_price(fixed_items)}


# Quick test: a lower-case name/modifier and an upper-case size must be
# canonicalised, and the model's bogus total_price replaced by the engine's own.
test = {"items": [{"name": "latte", "size": "VENTI", "quantity": 2, "modifiers": ["oat milk"]}], "total_price": 999}
result = validate_and_fix(test)
print(f"Validation test: {test} → {result}")

# (4.50 Latte + 1.00 Venti + 0.80 Oat Milk) * 2 = 12.60
expected_items = [{"name": "Latte", "size": "Venti", "quantity": 2, "modifiers": ["Oat Milk"]}]
assert result["items"] == expected_items, f"Validator items wrong: {result['items']}"
assert abs(result["total_price"] - 12.60) < 1e-9, f"Validator price wrong: {result['total_price']}"
print("✅ Validator working correctly")


## System Prompt & Few-Shot Examples

10 hand-curated examples covering: simple orders, multi-item, filler words, modifier removal,
item cancellation, quantity changes, full cancellation, modifier replacement chains, duplicate items.

The prompt gives the LLM **exact vocabulary** and lets it handle natural language corrections freely.


In [None]:
# =============================================================
# SYSTEM PROMPT & FEW-SHOT EXAMPLES
# =============================================================

# NOTE(review): the duplicate-item rule below follows few-shot example 6, where
# two separately ordered Bagels are labelled as one Bagel with quantity 2. The
# earlier wording ("Same item ordered again later = separate item") contradicted
# that label. Confirm against train labels that identical drinks merge the same way.
SYSTEM_PROMPT = """You are a coffee shop POS system. Parse the customer's spoken order into JSON.

VALID ITEM NAMES (use exact strings):
"Espresso", "Americano", "Drip Coffee", "Latte", "Cappuccino", "Flat White", "Mocha", "Caramel Macchiato", "Cold Brew", "Iced Coffee", "Frappe (Coffee)", "Frappe (Mocha)", "Strawberry Smoothie", "Chai Latte", "Matcha Latte", "Earl Grey Tea", "Green Tea", "Hot Chocolate", "Butter Croissant", "Blueberry Muffin", "Bagel", "Avocado Toast", "Bacon Gouda Sandwich"

VALID SIZES: "Short", "Tall", "Grande", "Venti", "Trenta", or null for food.
If no size is stated for a drink, use "Tall".

VALID MODIFIERS (exact strings):
"Oat Milk", "Almond Milk", "Soy Milk", "Coconut Milk", "Breve", "Skim Milk", "Vanilla Syrup", "Caramel Syrup", "Hazelnut Syrup", "Peppermint Syrup", "Sugar Free Vanilla", "Classic Syrup", "Extra Shot", "Whip Cream", "No Whip", "Cold Foam", "Caramel Drizzle", "Extra Hot", "Light Ice", "No Ice"

RULES:
- Only list modifiers the customer explicitly asks for. Never add defaults.
- "hold the ice"/"no ice" = "No Ice". "hold the whip"/"no whip" = "No Whip".
- Resolve all corrections: "scratch that", "cancel that", "remove that", "nevermind" = undo the relevant item or modifier based on context.
- Quantity words: single/a/one=1, double/couple/pair/two=2, triple/three/a few=3, four=4, five=5.
- "make it N"/"bump that to N"/"change that to N" = update quantity.
- Same drink with a different size or different modifiers = separate items. The exact same item (same name, size and modifiers) ordered again later = ONE item with the quantities added together.
- If everything is cancelled: {"items": [], "total_price": 0.0}
- Ignore filler: "like", "um", "uh", "literally", "you know".

OUTPUT: Return ONLY valid JSON. Set total_price to 0.0 (I calculate it).
{"items": [{"name": "...", "size": "...", "quantity": N, "modifiers": [...]}, ...], "total_price": 0.0}"""


# 10 curated few-shot examples: (customer order, expected assistant JSON reply)
FEW_SHOT_EXAMPLES = [
    ("Grab me four avocado toasts.",
     '''{"items": [{"name": "Avocado Toast", "size": null, "quantity": 4, "modifiers": []}], "total_price": 0.0}'''),

    ("Could I have single trenta mocha plus peppermint syrup.",
     '''{"items": [{"name": "Mocha", "size": "Trenta", "quantity": 1, "modifiers": ["Peppermint Syrup"]}], "total_price": 0.0}'''),

    ("I'm craving couple of VENTI frappe (mocha)s include Caramel Drizzle.",
     '''{"items": [{"name": "Frappe (Mocha)", "size": "Venti", "quantity": 2, "modifiers": ["Caramel Drizzle"]}], "total_price": 0.0}'''),

    ("Could you get me three SHORT FRAPPE (COFFEE) and hold um the ice, and also one literally tall mocha and two tall Mochas include breve.",
     '''{"items": [{"name": "Frappe (Coffee)", "size": "Short", "quantity": 3, "modifiers": ["No Ice"]}, {"name": "Mocha", "size": "Tall", "quantity": 1, "modifiers": []}, {"name": "Mocha", "size": "Tall", "quantity": 2, "modifiers": ["Breve"]}], "total_price": 0.0}'''),

    ("Lemme get one tall Strawberry Smoothie include caramel drizzle - remove uh that. Next, I need a like venti drip coffee and extra hot. Oh, and add three trenta chai latte include Sugar Free Vanilla... scratch that one Sugar Free Vanilla. caramel drizzle and make sure no whip. Oh, and add double short mochas.",
     '''{"items": [{"name": "Drip Coffee", "size": "Venti", "quantity": 1, "modifiers": ["Extra Hot"]}, {"name": "Chai Latte", "size": "Trenta", "quantity": 3, "modifiers": ["Caramel Drizzle", "No Whip"]}, {"name": "Mocha", "size": "Short", "quantity": 2, "modifiers": []}], "total_price": 0.0}'''),

    ("Could you get me a Bagel. Oh, and add one TALL Latte plus um vanilla syrup. remove that vanilla syrup. Extra Shot. Also single bagel.",
     '''{"items": [{"name": "Bagel", "size": null, "quantity": 2, "modifiers": []}, {"name": "Latte", "size": "Tall", "quantity": 1, "modifiers": ["Extra Shot"]}], "total_price": 0.0}'''),

    ("Gimme single BUTTER actually CROISSANT. no, make it four. Also single trenta green tea and single blueberry muffin - actually nevermind.",
     '''{"items": [{"name": "Butter Croissant", "size": null, "quantity": 4, "modifiers": []}, {"name": "Green Tea", "size": "Trenta", "quantity": 1, "modifiers": []}], "total_price": 0.0}'''),

    ("Hook me up with three venti frappe (mocha) plus Soy Milk plus classic syrup, and also double grande Green Teas include peppermint syrup... wait, change that to three.",
     '''{"items": [{"name": "Frappe (Mocha)", "size": "Venti", "quantity": 3, "modifiers": ["Soy Milk", "Classic Syrup"]}, {"name": "Green Tea", "size": "Grande", "quantity": 3, "modifiers": ["Peppermint Syrup"]}], "total_price": 0.0}'''),

    ("Start me off with couple of Short FRAPPE (MOCHA)s include LIGHT ICE add CARAMEL SYRUP. Then give me a venti cold brew add breve, wait cancel that breve. coconut milk plus Hazelnut Syrup... wait cancel that Hazelnut Syrup... soy milk add almond milk. Also a Tall iced coffee.",
     '''{"items": [{"name": "Frappe (Mocha)", "size": "Short", "quantity": 2, "modifiers": ["Light Ice", "Caramel Syrup"]}, {"name": "Cold Brew", "size": "Venti", "quantity": 1, "modifiers": ["Coconut Milk", "Soy Milk", "Almond Milk"]}, {"name": "Iced Coffee", "size": "Tall", "quantity": 1, "modifiers": []}], "total_price": 0.0}'''),

    ("Start me off with a few grande caramel macchiatos. wait cancel that.",
     '''{"items": [], "total_price": 0.0}'''),
]

print(f"System prompt: {len(SYSTEM_PROMPT)} chars")
print(f"Few-shot examples: {len(FEW_SHOT_EXAMPLES)} pairs")
print("\nExample categories covered:")
print("  1. Simple food order")
print("  2. Single drink + modifier")
print("  3. Quantity word (couple) + modifier")
print("  4. Multi-item + duplicate names + filler words")
print("  5. Item cancellation (remove that) + modifier cancellation (scratch that one)")
print("  6. Modifier removal + same food merged")
print("  7. Quantity change (make it N) + full item cancel (actually nevermind)")
print("  8. Quantity change (change that to N)")
print("  9. Modifier cancel + replacement chain")
print("  10. Full order cancellation")


## Inference Engine

Processes orders through the LLM with:
- Chat template formatting for Qwen
- JSON extraction with regex fallback
- Validation + price recalculation
- Retry on failure (up to `retries` attempts in total, then an empty order as the fallback)


In [None]:
# =============================================================
# INFERENCE ENGINE
# =============================================================

def build_messages(order_text: str) -> list:
    """Assemble the chat transcript sent to the model for one order.

    Layout: a system turn with SYSTEM_PROMPT, then each few-shot example as a
    user/assistant pair, then the new order as the final user turn. Every user
    turn uses the same ``Customer order: "<text>"`` framing.
    """
    def as_user_turn(text):
        return {"role": "user", "content": f'Customer order: "{text}"'}

    shot_turns = []
    for example_order, example_reply in FEW_SHOT_EXAMPLES:
        shot_turns.extend([
            as_user_turn(example_order),
            {"role": "assistant", "content": example_reply},
        ])

    return [{"role": "system", "content": SYSTEM_PROMPT}, *shot_turns, as_user_turn(order_text)]


def extract_json(text: str) -> dict | None:
    """Extract JSON from model output, handling markdown fences and extra text."""
    text = text.strip()

    # Remove markdown code fences
    text = re.sub(r'^```(?:json)?\s*', '', text)
    text = re.sub(r'\s*```$', '', text)
    text = text.strip()

    # Try direct parse first
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Try to find JSON object in the text
    # Look for the outermost { ... }
    brace_depth = 0
    start = None
    for i, ch in enumerate(text):
        if ch == '{':
            if brace_depth == 0:
                start = i
            brace_depth += 1
        elif ch == '}':
            brace_depth -= 1
            if brace_depth == 0 and start is not None:
                try:
                    return json.loads(text[start:i+1])
                except json.JSONDecodeError:
                    start = None

    return None


def parse_order(order_text: str, retries: int = 2) -> dict:
    """Parse one spoken order into a validated, priced order dict.

    Pipeline: build the few-shot chat prompt, generate with the model, extract
    the JSON object from the reply, then normalise and re-price it with
    ``validate_and_fix``.

    ``retries`` is the total number of attempts. Attempt 1 decodes greedily.
    Greedy decoding of the same prompt reproduces the same tokens, so later
    attempts switch to low-temperature sampling; a retry can then actually
    produce a different answer instead of repeating the same malformed one.

    Returns:
        The validated order, or ``{"items": [], "total_price": 0.0}`` if every
        attempt fails.
    """
    for attempt in range(retries):
        inputs = outputs = None
        try:
            messages = build_messages(order_text)
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # No truncation: right-side truncation would cut off the customer's
            # order and the generation prompt, which sit at the END of the text.
            inputs = tokenizer(text, return_tensors="pt").to(model.device)

            gen_kwargs = {"max_new_tokens": 512, "pad_token_id": tokenizer.eos_token_id}
            if attempt == 0:
                gen_kwargs.update(do_sample=False, temperature=None, top_p=None)  # deterministic
            else:
                gen_kwargs.update(do_sample=True, temperature=0.3, top_p=0.9)

            with torch.no_grad():
                outputs = model.generate(**inputs, **gen_kwargs)

            # Decode only the newly generated tokens.
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            raw_output = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

            parsed = extract_json(raw_output)
            if parsed is not None:
                return validate_and_fix(parsed)
            print(f"  [Attempt {attempt+1}] Could not extract JSON from: {raw_output[:150]}")

        except Exception as e:  # boundary: one bad order must not kill a multi-hour run
            print(f"  [Attempt {attempt+1}] Error: {e}")
        finally:
            # Drop GPU tensor references on every path (success, failure, error).
            del inputs, outputs

        torch.cuda.empty_cache()

    # Final fallback: empty order
    return {"items": [], "total_price": 0.0}


# --- Quick sanity check ---
# One call through parse_order (prompt -> model -> JSON -> validator) before
# committing to the long evaluation/inference loops below.
print("Running sanity check...")
test_order = "Gimme two venti lattes with oat milk and a bagel."
result = parse_order(test_order)
pretty_result = json.dumps(result, indent=2)
print(f'\nTest: "{test_order}"')
print(f"Result: {pretty_result}")


## Training Set Evaluation

Run the full pipeline on 500 training rows to measure accuracy before touching test.
**Skip this cell for final submission to save ~1-2 hours of runtime.**


In [None]:
# =============================================================
# TRAINING SET EVALUATION (skip for final submission to save time)
# =============================================================

def exact_match(predicted: dict, expected: dict) -> bool:
    """True when two orders agree on price and on every item, position by position.

    Prices are compared after rounding to cents. Items are compared in order on
    name, size and quantity; their modifier lists are compared as sorted lists,
    so modifier order does not matter.
    """
    if round(predicted["total_price"], 2) != round(expected["total_price"], 2):
        return False

    pred_items, gold_items = predicted["items"], expected["items"]
    if len(pred_items) != len(gold_items):
        return False

    return all(
        p["name"] == g["name"]
        and p["size"] == g["size"]
        and p["quantity"] == g["quantity"]
        and sorted(p.get("modifiers", [])) == sorted(g.get("modifiers", []))
        for p, g in zip(pred_items, gold_items)
    )


print("Evaluating on training set (500 rows)...")
print("=" * 60)

# Run the full pipeline on every training row and keep the misses for analysis.
correct = 0
errors = []

for idx, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Train eval"):
    predicted = parse_order(row["order"])
    expected = json.loads(row["expected_json"])

    if not exact_match(predicted, expected):
        errors.append({"id": row["id"], "order": row["order"],
                       "predicted": predicted, "expected": expected})
        continue
    correct += 1

accuracy = correct / len(train_df) * 100
banner = "=" * 60
print(f"\n{banner}")
print(f"TRAINING ACCURACY: {correct}/{len(train_df)} = {accuracy:.1f}%")
print(banner)

# Error breakdown: item-count mismatches vs. everything else.
if errors:
    item_count_err = sum(
        len(e['predicted']['items']) != len(e['expected']['items']) for e in errors
    )
    print(f"\nError breakdown ({len(errors)} total):")
    print(f"  Wrong number of items: {item_count_err}")
    print(f"  Other (name/size/qty/mod): {len(errors) - item_count_err}")
    print(f"\nFirst 10 errors:")
    for err in errors[:10]:
        gold, pred = err['expected'], err['predicted']
        print(f"\n  ID {err['id']}: \"{err['order'][:100]}...\"")
        print(f"    Expected items: {[i['name'] for i in gold['items']]}")
        print(f"    Got items:      {[i['name'] for i in pred['items']]}")
        if pred['total_price'] != gold['total_price']:
            print(f"    Price: expected={gold['total_price']}, got={pred['total_price']}")


## Cell 8 — Test Set Inference & Submission

In [None]:
# =============================================================
# TEST SET INFERENCE & SUBMISSION
# =============================================================

print(f"Running inference on {len(test_df)} test rows...")
print("Estimated time: 5-9 hours on T4x2 with 14B model")
print("=" * 60)

# Partial results are written here periodically so a crash or session timeout
# late in a multi-hour run does not lose all finished predictions.
CHECKPOINT_PATH = "submission_partial.csv"

start_time = time.time()
test_predictions = []
n_test = len(test_df)

# enumerate() gives the 0-based position; iterrows() yields index *labels*,
# which only equal positions when test_df has a default RangeIndex.
for pos, (idx, row) in enumerate(tqdm(test_df.iterrows(), total=n_test, desc="Test inference")):
    predicted = parse_order(row["order"])
    test_predictions.append({
        "id": row["id"],
        "predicted_json": json.dumps(predicted),
    })
    done = pos + 1

    # Periodic memory cleanup
    if done % 100 == 0:
        torch.cuda.empty_cache()
        gc.collect()

    # Progress report + on-disk checkpoint every 500 rows
    if done % 500 == 0:
        elapsed = time.time() - start_time
        rate = done / elapsed
        remaining = (n_test - done) / rate
        pd.DataFrame(test_predictions, columns=["id", "predicted_json"]).to_csv(CHECKPOINT_PATH, index=False)
        print(f"  [{done}/{n_test}] {elapsed/60:.0f}min elapsed, ~{remaining/60:.0f}min remaining "
              f"(checkpoint: {CHECKPOINT_PATH})")

elapsed = time.time() - start_time
print(f"\n✅ Inference complete in {elapsed/3600:.1f} hours ({elapsed/60:.0f} min)")

# Build submission (explicit columns so the schema holds even with zero rows)
submission = pd.DataFrame(test_predictions, columns=["id", "predicted_json"])
assert list(submission.columns) == ["id", "predicted_json"]
assert len(submission) == len(test_df)

submission.to_csv("submission.csv", index=False)
print(f"✅ Saved: submission.csv ({len(submission)} rows)")

# Quick stats
prices = submission['predicted_json'].apply(lambda x: json.loads(x)['total_price'])
n_items = submission['predicted_json'].apply(lambda x: len(json.loads(x)['items']))
print(f"\nSubmission stats:")
print(f"  Avg price: ${prices.mean():.2f} | Median: ${prices.median():.2f}")
print(f"  Price range: ${prices.min():.2f} - ${prices.max():.2f}")
print(f"  Zero-price (cancelled): {(prices == 0).sum()}")
print(f"  Avg items per order: {n_items.mean():.1f}")


## Cell 9 — Error Analysis (Development)

In [None]:
# =============================================================
# ERROR ANALYSIS — Run after training eval to understand failure modes
# =============================================================

# `errors` is created by the training-eval cell. Check for it explicitly rather
# than wrapping the whole analysis in `except NameError`, which would also hide
# genuine bugs (typos) inside the analysis itself.
if 'errors' not in globals():
    print("Run training eval (Cell 7) first to generate error data.")
elif not errors:
    print("🎉 No errors to analyze!")
else:
    item_count_errors = 0
    name_errors = 0
    qty_errors = 0
    size_errors = 0
    mod_errors = 0
    price_only_errors = 0  # every item field matched, so only total_price differed

    for err in errors:
        pred, exp = err["predicted"]["items"], err["expected"]["items"]
        if len(pred) != len(exp):
            item_count_errors += 1
            continue

        # Per-item counters: an order with two wrong items counts twice.
        field_mismatch = False
        for pi, ei in zip(pred, exp):
            if pi["name"] != ei["name"]:
                name_errors += 1
                field_mismatch = True
            if pi["quantity"] != ei["quantity"]:
                qty_errors += 1
                field_mismatch = True
            if pi["size"] != ei["size"]:
                size_errors += 1
                field_mismatch = True
            if sorted(pi.get("modifiers", [])) != sorted(ei.get("modifiers", [])):
                mod_errors += 1
                field_mismatch = True
        if not field_mismatch:
            price_only_errors += 1

    print(f"Detailed Error Breakdown:")
    print(f"  Item count wrong:  {item_count_errors}")
    print(f"  Name wrong:        {name_errors}")
    print(f"  Quantity wrong:    {qty_errors}")
    print(f"  Size wrong:        {size_errors}")
    print(f"  Modifiers wrong:   {mod_errors}")
    print(f"  Price only wrong:  {price_only_errors}")
