# Notebook 27: Structured Output & Logit Biasing

## Inference Engineering Course

---

## Overview

LLMs generate free-form text by default, but production systems often need **structured output** -- JSON, function calls, or specific formats. Unstructured output leads to parsing failures, downstream errors, and fragile pipelines.

### The Problem

```
Prompt: "Extract the name and age from: 'John is 25 years old'"

Unstructured output (unreliable):     Structured output (reliable):
"The name is John and he is 25"       {"name": "John", "age": 25}
"Name: John, Age: 25"                 {"name": "John", "age": 25}
"John (25)"                            {"name": "John", "age": 25}
```
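
To make the fragility concrete, here is a minimal sketch that parses the three free-form replies above with a single regular expression. The pattern and the helper name `parse_free_form` are illustrative assumptions, written with only the second phrasing in mind, which is roughly what happens when a parser is built against a handful of observed outputs.

```python
import json
import re

# The three free-form replies from the example above.
free_form_replies = [
    "The name is John and he is 25",
    "Name: John, Age: 25",
    "John (25)",
]

# Hypothetical pattern that only anticipates the "Name: ..., Age: ..." phrasing.
pattern = re.compile(r"Name:\s*(\w+),\s*Age:\s*(\d+)")

def parse_free_form(text):
    """Return a dict on a match, or None when the phrasing is unexpected."""
    match = pattern.search(text)
    if match is None:
        return None
    return {"name": match.group(1), "age": int(match.group(2))}

regex_results = [parse_free_form(reply) for reply in free_form_replies]

# The structured reply needs no phrasing-specific pattern at all.
structured_result = json.loads('{"name": "John", "age": 25}')

for reply, parsed in zip(free_form_replies, regex_results):
    print(f"{reply!r:40} -> {parsed}")
print("structured ->", structured_result)
```

Only one of the three phrasings is recovered, while the JSON reply parses with a single standard-library call.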

### What You'll Learn

| Topic | Description |
|-------|-------------|
| JSON Mode | Forcing models to output valid JSON |
| Logit Biasing | Manipulating token probabilities |
| Function Calling | Building tool-use patterns |
| Grammar Constraints | Constraining generation with formal grammars |
| Schema Validation | Validating LLM outputs against schemas |
| Reliability Testing | Comparing structured vs unstructured reliability |

### Prerequisites
- Basic understanding of LLM token generation
- Python and JSON knowledge
- No GPU required (CPU is sufficient)

In [None]:
# ============================================================
# Install dependencies
# ============================================================
!pip install jsonschema matplotlib numpy pandas -q

import json
import re
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

print("Dependencies loaded!")

---

## Section 1: Understanding Structured Output Methods

There are several approaches to getting structured output from LLMs, ranging from simple to sophisticated:

```
Approach         Reliability  Flexibility  Complexity
─────────────────────────────────────────────────────
Prompt engineering    Low        High        Low
Output parsing        Medium     Medium      Medium
JSON mode (API)       High       Medium      Low
Logit biasing         High       Low         Medium
Grammar constraints   Very High  Medium      High
Function calling      Very High  High        Medium
```

Let's explore each approach.
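
As a taste of the "output parsing" row, the sketch below pulls the first JSON object out of a chatty reply using only the standard library. The helper name `extract_first_json_object` and the sample reply are assumptions made for illustration; this is one possible approach, not a standard recipe.

```python
import json

def extract_first_json_object(text):
    """Return the first complete JSON object found in `text`, or None."""
    decoder = json.JSONDecoder()
    for start, char in enumerate(text):
        if char != "{":
            continue
        try:
            # raw_decode parses one JSON value starting at `start`
            # and ignores whatever follows it.
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            continue  # Not a valid object here; try the next '{'
        if isinstance(obj, dict):
            return obj
    return None

chatty_reply = 'Sure! Here is the data: {"name": "John", "age": 25}. Anything else?'
print(extract_first_json_object(chatty_reply))
print(extract_first_json_object("Sorry, I could not find a name."))
```

Parsing of this kind recovers some otherwise unusable replies, but it still depends on the model emitting a well-formed object somewhere in its answer, which is why the table rates it only "Medium".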

In [None]:
# ============================================================
# Section 1a: JSON Mode with OpenAI-compatible APIs
# ============================================================

# This shows how you would use JSON mode with an API.
# We simulate the API calls for Colab compatibility.

def simulate_json_mode_request():
    """
    Demonstrates the API call structure for JSON mode.
    In production, replace with actual API call.
    """
    
    # ---- How you'd call OpenAI with JSON mode ----
    api_request = {
        "model": "gpt-4o",  # Example name; use a model that supports JSON mode (plain gpt-4 may reject response_format)
        "response_format": {"type": "json_object"},  # <-- JSON mode!
        "messages": [
            {
                "role": "system",
                "content": "You are a data extraction assistant. Always respond with valid JSON."
            },
            {
                "role": "user", 
                "content": "Extract entities from: 'Apple Inc. reported $394B revenue in 2023, up 2% from 2022.'"
            }
        ]
    }
    
    # ---- Simulated response ----
    simulated_response = {
        "company": "Apple Inc.",
        "revenue": {
            "amount": 394000000000,
            "currency": "USD",
            "year": 2023
        },
        "growth": {
            "percentage": 2,
            "compared_to": 2022
        }
    }
    
    return api_request, simulated_response


request, response = simulate_json_mode_request()

print("API Request Structure:")
print("=" * 50)
print(json.dumps(request, indent=2))
print("\nSimulated JSON Response:")
print("=" * 50)
print(json.dumps(response, indent=2))

print("\n\nKey Points:")
print('1. Set response_format: {"type": "json_object"}')
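print("2. Mention 'JSON' in the system or user prompt (OpenAI's API rejects JSON-mode requests without it)")
print("3. The API returns syntactically valid JSON, unless the reply is truncated (e.g., by max_tokens)")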
print("4. Does NOT guarantee schema compliance -- you still need validation!")

---

## Section 2: Logit Biasing - Forcing Specific Tokens

**Logit biasing** lets you increase or decrease the probability of specific tokens during generation. This is powerful for:

- Forcing the model to use certain keywords
- Preventing the model from generating certain content
- Guiding the model toward a specific format

### How It Works

At each generation step, the model produces logits (scores) for every token in its vocabulary. Logit bias adds a fixed value to specific token logits before the softmax:

$$P(\text{token}_i) = \frac{e^{z_i + \text{bias}_i}}{\sum_j e^{z_j + \text{bias}_j}}$$

- **Positive bias (+5 to +100)**: Makes the token much more likely; near +100 it is effectively forced
- **Negative bias (-5 to -100)**: Makes the token much less likely
- **-100**: Effectively bans the token (OpenAI's API accepts values from -100 to 100)
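
One consequence of the formula is that a bias of `b` on a token multiplies its odds against any unbiased token by `e^b`. The sketch below checks this numerically on a toy three-token vocabulary; the logits and bias values are made-up numbers chosen only for illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy three-token vocabulary; the logits are illustrative values.
logits = [2.0, 1.0, 0.5]
bias = [0.0, 3.0, 0.0]  # boost only token 1

before = softmax(logits)
after = softmax([z + b for z, b in zip(logits, bias)])

# Odds of token 1 against token 0, before and after the bias.
odds_before = before[1] / before[0]
odds_after = after[1] / after[0]

print(f"P(token 1): {before[1]:.3f} -> {after[1]:.3f}")
print(f"Odds grew by a factor of {odds_after / odds_before:.3f} (e^3 = {math.exp(3):.3f})")

# A bias of -100 multiplies the odds by e^-100, which is why it acts as a ban.
banned = softmax([z + b for z, b in zip(logits, [0.0, -100.0, 0.0])])
print(f"P(token 1) with -100 bias: {banned[1]:.3e}")
```

Because the effect is multiplicative on the odds, a bias of a few units already changes the ranking of tokens, and a bias of -100 leaves a probability far below anything a sampler would ever pick.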

In [None]:
# ============================================================
# Demonstrate logit biasing mechanics
# ============================================================

import numpy as np
import matplotlib.pyplot as plt

def softmax(logits):
    """Compute softmax probabilities."""
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / exp_logits.sum()

# Simulated model output logits for next token
tokens = ['"name"', '"age"', '"city"', 'The', 'Hello', '{', '[', 'null', 'true', '42']
original_logits = np.array([2.1, 1.8, 1.5, 3.2, 2.8, 1.0, 0.5, -0.5, -1.0, 0.3])

# Apply different biasing strategies
# Strategy 1: Boost JSON-related tokens
json_bias = np.array([5, 5, 5, -10, -10, 10, 5, 2, 2, 2])
biased_logits = original_logits + json_bias

# Strategy 2: Ban specific tokens
ban_bias = np.array([0, 0, 0, -100, -100, 0, 0, 0, 0, 0])
banned_logits = original_logits + ban_bias

# Compute probabilities
original_probs = softmax(original_logits)
biased_probs = softmax(biased_logits)
banned_probs = softmax(banned_logits)

# Visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Original
colors_orig = ['#4CAF50' if t.startswith('"') or t in ['{', '['] else '#90CAF9' for t in tokens]
axes[0].barh(tokens, original_probs, color=colors_orig, alpha=0.8)
axes[0].set_xlabel('Probability', fontsize=12)
axes[0].set_title('Original Probabilities', fontsize=13, fontweight='bold')
axes[0].invert_yaxis()
for i, v in enumerate(original_probs):
    axes[0].text(v + 0.01, i, f'{v:.3f}', va='center', fontsize=10)

# With JSON bias
colors_bias = ['#4CAF50' if p > 0.1 else '#FFCDD2' for p in biased_probs]
axes[1].barh(tokens, biased_probs, color=colors_bias, alpha=0.8)
axes[1].set_xlabel('Probability', fontsize=12)
axes[1].set_title('With JSON Logit Bias\n(Boost JSON tokens)', fontsize=13, fontweight='bold')
axes[1].invert_yaxis()
for i, v in enumerate(biased_probs):
    axes[1].text(v + 0.01, i, f'{v:.3f}', va='center', fontsize=10)

# With banned tokens
colors_ban = ['#F44336' if p < 0.001 else '#4CAF50' for p in banned_probs]
axes[2].barh(tokens, banned_probs, color=colors_ban, alpha=0.8)
axes[2].set_xlabel('Probability', fontsize=12)
axes[2].set_title('With Token Banning\n(Ban: The, Hello)', fontsize=13, fontweight='bold')
axes[2].invert_yaxis()
for i, v in enumerate(banned_probs):
    axes[2].text(v + 0.01, i, f'{v:.3f}', va='center', fontsize=10)

plt.tight_layout()
plt.savefig('logit_biasing.png', dpi=150, bbox_inches='tight')
plt.show()

print("Observation: Logit biasing dramatically shifts the probability distribution.")
print(f"Most likely token changed from '{tokens[np.argmax(original_probs)]}' to '{tokens[np.argmax(biased_probs)]}'")

In [None]:
# ============================================================
# Implementing a logit bias processor
# ============================================================

class LogitBiasProcessor:
    """
    A processor that applies logit biases during generation.
    Compatible with HuggingFace's LogitsProcessor interface.
    """
    
    def __init__(self, bias_dict: dict):
        """
        Args:
            bias_dict: {token_id: bias_value}
                       Positive values increase probability,
                       negative values decrease probability.
        """
        self.bias_dict = bias_dict
    
    def __call__(self, input_ids, scores):
        """Apply biases to logits."""
        for token_id, bias in self.bias_dict.items():
            if 0 <= token_id < scores.shape[-1]:
                scores[:, token_id] += bias
        return scores
    
    @staticmethod
    def from_tokens(tokenizer, token_biases: dict):
        """
        Create from token strings instead of IDs.
        
        Args:
            tokenizer: HuggingFace tokenizer
            token_biases: {"token_string": bias_value}
        """
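        bias_dict = {}
        for token, bias in token_biases.items():
            token_ids = tokenizer.encode(token, add_special_tokens=False)
            if len(token_ids) != 1:
                # A multi-token string cannot be biased as a unit; biasing each
                # piece would also shift unrelated words that share those pieces.
                print(f"Skipping {token!r}: encodes to {len(token_ids)} tokens")
                continue
            bias_dict[token_ids[0]] = bias
        return LogitBiasProcessor(bias_dict)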

# Demo: Using the API's logit_bias parameter
print("Example API request with logit_bias:")
print(json.dumps({
    "model": "gpt-4",
    "messages": [{"role": "user", "content": "Pick a color:"}],
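    # Token IDs below are placeholders; look up real IDs with the model's tokenizer.
    "logit_bias": {
        "2654": 10,    # Placeholder ID standing in for "blue" -> boost
        "12567": -100,  # Placeholder ID standing in for "red" -> ban
    },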
    "max_tokens": 10
}, indent=2))

print("\nCommon use cases for logit biasing:")
print("  1. Force JSON: Boost '{', '[', '\"' tokens at the start")
print("  2. Prevent repetition: Decrease recently-used tokens")
print("  3. Language control: Ban tokens from other languages")
print("  4. Safety: Ban specific sensitive words")

---

## Section 3: Function Calling Pattern

Function calling allows LLMs to invoke external tools by generating structured function call specifications. This is the foundation of **agentic AI** systems.

### Architecture

```
User Query → LLM → Function Call (JSON) → Execute Function → LLM → Response
               ↑                                               ↑
          Tool definitions                              Function results
```
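
The loop in the diagram can be sketched in a few lines. Everything here is a stand-in: `fake_llm` replaces a real model call, `get_time` is a hypothetical tool with canned data, and the message format is a simplified assumption rather than any provider's actual wire format.

```python
import json

def get_time(city):
    """Hypothetical tool; returns canned data instead of calling a real service."""
    return {"city": city, "time": "09:30"}

TOOLS = {"get_time": get_time}

def fake_llm(messages):
    """Stand-in for a model call (assumption: no real API is contacted).

    It requests a tool on the first step and answers once a tool result exists.
    """
    tool_messages = [m for m in messages if m["role"] == "tool"]
    if not tool_messages:
        return {"function_call": {"name": "get_time", "arguments": {"city": "London"}}}
    result = json.loads(tool_messages[-1]["content"])
    return {"content": f"It is {result['time']} in {result['city']}."}

def run_turn(user_query, max_steps=3):
    """User query -> model -> tool call -> tool result -> model -> answer."""
    messages = [{"role": "user", "content": user_query}]
    for _ in range(max_steps):
        reply = fake_llm(messages)
        call = reply.get("function_call")
        if call is None:
            return reply["content"]  # Final natural-language answer
        tool = TOOLS.get(call["name"])
        if tool is None:
            result = {"error": f"Unknown function: {call['name']}"}
        else:
            result = tool(**call["arguments"])
        # Feed the tool result back so the model can use it on the next step.
        messages.append({"role": "tool", "name": call["name"], "content": json.dumps(result)})
    return "Stopped: too many tool calls."

answer = run_turn("What time is it in London?")
print(answer)
```

The `max_steps` cap is a design choice worth keeping in real systems, since a model that keeps requesting tools would otherwise loop indefinitely.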

In [None]:
# ============================================================
# Implement a function calling system
# ============================================================

# Step 1: Define available tools
tool_definitions = [
    {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "City name, e.g., 'San Francisco'"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "Temperature unit"
                }
            },
            "required": ["location"]
        }
    },
    {
        "name": "search_products",
        "description": "Search for products in the catalog",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query"
                },
                "max_price": {
                    "type": "number",
                    "description": "Maximum price filter"
                },
                "category": {
                    "type": "string",
                    "enum": ["electronics", "clothing", "books", "home"]
                }
            },
            "required": ["query"]
        }
    },
    {
        "name": "calculate",
        "description": "Perform mathematical calculations",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {
                    "type": "string",
                    "description": "Mathematical expression to evaluate"
                }
            },
            "required": ["expression"]
        }
    }
]

# Step 2: Implement the actual functions
def get_weather(location, unit="celsius"):
    """Simulated weather API."""
    weather_data = {
        "San Francisco": {"temp_c": 18, "condition": "Foggy", "humidity": 75},
        "New York": {"temp_c": 25, "condition": "Sunny", "humidity": 50},
        "London": {"temp_c": 14, "condition": "Rainy", "humidity": 85},
    }
    data = weather_data.get(location, {"temp_c": 20, "condition": "Clear", "humidity": 60})
    if unit == "fahrenheit":
        data["temp_f"] = data["temp_c"] * 9/5 + 32
    return data

def search_products(query, max_price=None, category=None):
    """Simulated product search."""
    products = [
        {"name": "Wireless Headphones", "price": 79.99, "category": "electronics"},
        {"name": "USB-C Hub", "price": 34.99, "category": "electronics"},
        {"name": "Python Cookbook", "price": 45.00, "category": "books"},
        {"name": "Standing Desk Mat", "price": 59.99, "category": "home"},
    ]
    results = [p for p in products if query.lower() in p['name'].lower()]
    if max_price is not None:  # Explicit check so that max_price=0 still filters
        results = [p for p in results if p['price'] <= max_price]
    if category:
        results = [p for p in results if p['category'] == category]
    return results

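def calculate(expression):
    """Evaluate simple arithmetic; a demo-grade guard, not a hardened sandbox."""
    try:
        # Allow only digits, whitespace, and basic arithmetic characters
        allowed = set('0123456789+-*/().% ')
        if not expression or not set(expression) <= allowed:
            return {"error": "Invalid expression"}
        if '**' in expression:
            # Exponentiation can produce huge numbers and stall the process
            return {"error": "Exponentiation is not allowed"}
        # No builtins and no names are exposed to the evaluated expression
        return {"result": eval(expression, {"__builtins__": {}}, {})}
    except Exception as e:
        return {"error": str(e)}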

# Step 3: Function dispatcher
FUNCTIONS = {
    "get_weather": get_weather,
    "search_products": search_products,
    "calculate": calculate,
}

def execute_function_call(function_call: dict) -> dict:
    """Execute a function call from the LLM."""
    name = function_call["name"]
    args = function_call.get("arguments", {})
    
    if name not in FUNCTIONS:
        # Same shape as the other error path so callers can check "status"
        return {"status": "error", "message": f"Unknown function: {name}"}
    
    try:
        result = FUNCTIONS[name](**args)
        return {"status": "success", "result": result}
    except Exception as e:
        return {"status": "error", "message": str(e)}


# Demo: Simulate LLM generating function calls
simulated_function_calls = [
    {
        "user_query": "What's the weather in San Francisco?",
        "function_call": {
            "name": "get_weather",
            "arguments": {"location": "San Francisco", "unit": "fahrenheit"}
        }
    },
    {
        "user_query": "Find me headphones under $100",
        "function_call": {
            "name": "search_products",
            "arguments": {"query": "Headphones", "max_price": 100}
        }
    },
    {
        "user_query": "What is 15% of 340?",
        "function_call": {
            "name": "calculate",
            "arguments": {"expression": "340 * 0.15"}
        }
    }
]

print("Function Calling Demo")
print("=" * 60)

for item in simulated_function_calls:
    print(f"\nUser: {item['user_query']}")
    print(f"LLM generates: {json.dumps(item['function_call'])}")
    result = execute_function_call(item['function_call'])
    print(f"Function result: {json.dumps(result, indent=2)}")
    print("-" * 60)

---

## Section 4: Grammar-Constrained Generation

Grammar-constrained generation uses formal grammar rules (like BNF or regex) to **guarantee** the output conforms to a specific structure. At each token generation step, only tokens that are valid according to the grammar are allowed.

### How It Works

```
Grammar: JSON object with specific fields
         → { "name": <string>, "age": <number> }

Step 1: Only '{' is allowed → generates '{'
Step 2: Only '"name"' is allowed → generates '"name"'
Step 3: Only ':' is allowed → generates ':'
Step 4: Only '"' is allowed → generates '"'
Step 5: Any string chars allowed → generates 'John'
...
```
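
Mechanically, "only X is allowed" usually means masking: the logits of every token the grammar forbids are set to negative infinity before the softmax, so those tokens receive probability zero. The sketch below illustrates this on a five-token toy vocabulary with made-up logits; the function names are illustrative, and real libraries derive the allowed set from a compiled grammar rather than a hand-written set of IDs.

```python
import math

def masked_softmax(logits, allowed_ids):
    """Softmax in which tokens outside `allowed_ids` get probability 0."""
    if not allowed_ids:
        raise ValueError("the grammar must allow at least one token")
    masked = [z if i in allowed_ids else -math.inf for i, z in enumerate(logits)]
    peak = max(masked)
    exps = [math.exp(z - peak) for z in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

vocab = ["{", '"name"', ":", "The", "Hello"]
logits = [1.0, 2.1, 0.2, 3.2, 2.8]  # Illustrative scores; "The" is the raw favourite

def pick(logits, allowed_ids):
    """Greedy choice among the tokens the grammar currently allows."""
    probs = masked_softmax(logits, allowed_ids)
    best = max(range(len(probs)), key=probs.__getitem__)
    return vocab[best], probs

unconstrained_choice = vocab[max(range(len(logits)), key=logits.__getitem__)]
first_token, first_probs = pick(logits, {0})     # Start state: only '{' is allowed
mixed_token, mixed_probs = pick(logits, {0, 2})  # A state allowing '{' or ':'

print("Unconstrained:", unconstrained_choice)
print("Constrained to '{':", first_token, first_probs)
print("Constrained to '{' or ':':", mixed_token, [round(p, 3) for p in mixed_probs])
```

When more than one token is allowed, the model's own preferences still decide among them; the grammar only removes options, it never picks the value.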

In [None]:
# ============================================================
# Implement a simple grammar-constrained generator
# ============================================================

class SimpleGrammarConstraint:
    """
    A simplified grammar constraint engine for JSON generation.
    Demonstrates the concept of grammar-constrained decoding.
    
    In production, use libraries like:
    - lm-format-enforcer
    - guidance
    - outlines
    """
    
    def __init__(self, schema: dict):
        self.schema = schema
        self.state = 'start'
        self.field_index = 0
        self.fields = list(schema.get('properties', {}).keys())
    
    def get_allowed_tokens(self) -> list:
        """Return list of allowed token patterns in current state."""
        if self.state == 'start':
            return ['{']  # Must start with opening brace
        elif self.state == 'field_name':
            if self.field_index < len(self.fields):
                return [f'"{self.fields[self.field_index]}"']
            return ['}']  # No more fields
        elif self.state == 'colon':
            return [':']
        elif self.state == 'value':
            field = self.fields[self.field_index]
            field_type = self.schema['properties'][field].get('type', 'string')
            if field_type == 'string':
                return ['"<string>"']  # Any string
            elif field_type == 'number' or field_type == 'integer':
                return ['<number>']  # Any number
            elif field_type == 'boolean':
                return ['true', 'false']
            return ['<any>']
        elif self.state == 'separator':
            if self.field_index < len(self.fields) - 1:
                return [',']  # More fields to come
            return ['}']  # Last field
        return []
    
    def advance(self, token):
        """Advance the state machine."""
        if self.state == 'start':
            self.state = 'field_name'
        elif self.state == 'field_name':
            # A schema with no remaining fields closes the object right away
            self.state = 'done' if token == '}' else 'colon'
        elif self.state == 'colon':
            self.state = 'value'
        elif self.state == 'value':
            self.state = 'separator'
        elif self.state == 'separator':
            if token == ',':
                self.field_index += 1
                self.state = 'field_name'
            elif token == '}':
                self.state = 'done'


# Demo: Walk through constrained generation
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
        "is_student": {"type": "boolean"}
    },
    "required": ["name", "age", "is_student"]
}

constraint = SimpleGrammarConstraint(schema)

# Simulate the step-by-step generation
simulated_tokens = ['{', '"name"', ':', '"Alice"', ',', 
                    '"age"', ':', '25', ',', 
                    '"is_student"', ':', 'true', '}']

print("Grammar-Constrained Generation Step-by-Step:")
print("=" * 60)
print(f"Schema: {json.dumps(schema, indent=2)}")
print("\n" + "-" * 60)

generated = ""
for token in simulated_tokens:
    allowed = constraint.get_allowed_tokens()
    print(f"  State: {constraint.state:15s} | Allowed: {str(allowed):30s} | Generated: {token}")
    constraint.advance(token)
    generated += token + (" " if token in [',', ':'] else "")

print(f"\nFinal output: {generated}")
print("\nWith grammar constraints, every completed output is valid JSON by construction.")

---

## Section 5: JSON Schema Validation for LLM Outputs

Even with JSON mode, the output might not conform to your expected schema. Validation is essential.
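
The gap between "parses as JSON" and "matches the schema" is easy to see with a tiny hand-rolled check. The mini-schema `EXPECTED_FIELDS` and the helper `check_fields` below are hypothetical and cover only required keys and basic types; they are meant to show the distinction, not to replace a real JSON Schema validator.

```python
import json

# Hypothetical mini-schema: field name -> expected Python type.
EXPECTED_FIELDS = {"name": str, "age": int}

def check_fields(raw):
    """Return a list of problems; an empty list means the reply passed."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as exc:
        return [f"invalid JSON: {exc.msg}"]
    if not isinstance(data, dict):
        return ["top-level value is not an object"]
    problems = []
    for field, expected_type in EXPECTED_FIELDS.items():
        if field not in data:
            problems.append(f"missing field: {field}")
        # bool is a subclass of int in Python, so reject it explicitly
        # (fine here because this mini-schema has no boolean fields).
        elif isinstance(data[field], bool) or not isinstance(data[field], expected_type):
            problems.append(f"wrong type for {field}: {type(data[field]).__name__}")
    return problems

for reply in ['{"name": "John", "age": 25}',
              '{"name": "John", "age": "25"}',
              '{"name": "John"}']:
    print(reply, "->", check_fields(reply) or "OK")
```

All three replies are valid JSON, yet only the first one is usable downstream.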

In [None]:
# ============================================================
# Implement comprehensive JSON schema validator
# ============================================================

import jsonschema
from jsonschema import validate, ValidationError

class LLMOutputValidator:
    """
    Validates LLM outputs against JSON schemas.
    Provides detailed error messages and auto-repair hints.
    """
    
    def __init__(self, schema: dict):
        self.schema = schema
        self.validation_history = []
    
    def validate(self, output: str) -> dict:
        """Validate an LLM output string against the schema."""
        result = {
            'is_valid_json': False,
            'is_schema_valid': False,
            'parsed': None,
            'errors': [],
            'warnings': [],
        }
        
        # Step 1: Check if it's valid JSON
        try:
            parsed = json.loads(output)
            result['is_valid_json'] = True
            result['parsed'] = parsed
        except json.JSONDecodeError as e:
            # Try to recover a JSON value embedded in surrounding text.
            # raw_decode copes with nested braces, which a flat regex cannot.
            decoder = json.JSONDecoder()
            for start, char in enumerate(output):
                if char != '{':
                    continue
                try:
                    parsed, _ = decoder.raw_decode(output, start)
                except json.JSONDecodeError:
                    continue
                result['parsed'] = parsed
                result['is_valid_json'] = True
                result['warnings'].append("JSON extracted from surrounding text")
                break
            
            if not result['is_valid_json']:
                result['errors'].append(f"Invalid JSON: {str(e)}")
                self.validation_history.append(result)
                return result
        
        # Step 2: Validate against schema
        try:
            validate(instance=result['parsed'], schema=self.schema)
            result['is_schema_valid'] = True
        except ValidationError as e:
            result['errors'].append(f"Schema validation error: {e.message}")
            result['errors'].append(f"  Path: {'.'.join(str(p) for p in e.path)}")
            result['errors'].append(f"  Schema rule: {e.schema}")
        
        # Step 3: Additional quality checks (only meaningful for JSON objects)
        if isinstance(result['parsed'], dict):
            # Flag empty strings and nulls, which often signal a failed extraction
            for key, value in result['parsed'].items():
                if isinstance(value, str) and value.strip() == '':
                    result['warnings'].append(f"Empty string for field '{key}'")
                if value is None:
                    result['warnings'].append(f"Null value for field '{key}'")
        
        self.validation_history.append(result)
        return result
    
    def get_success_rate(self) -> float:
        """Get the validation success rate."""
        if not self.validation_history:
            return 0.0
        valid = sum(1 for r in self.validation_history if r['is_schema_valid'])
        return valid / len(self.validation_history)


# ---- Define a real-world schema ----
product_review_schema = {
    "type": "object",
    "properties": {
        "product_name": {"type": "string", "minLength": 1},
        "rating": {"type": "integer", "minimum": 1, "maximum": 5},
        "sentiment": {"type": "string", "enum": ["positive", "negative", "neutral"]},
        "key_points": {
            "type": "array",
            "items": {"type": "string"},
            "minItems": 1,
            "maxItems": 5
        },
        "recommend": {"type": "boolean"}
    },
    "required": ["product_name", "rating", "sentiment", "key_points", "recommend"]
}

validator = LLMOutputValidator(product_review_schema)

# Test with various outputs
test_outputs = [
    # Good output
    '{"product_name": "Wireless Mouse", "rating": 4, "sentiment": "positive", "key_points": ["Comfortable grip", "Long battery life"], "recommend": true}',
    
    # Invalid JSON (unquoted key)
    '{product_name: "Keyboard", "rating": 3}',
    
    # Valid JSON but schema violation (rating out of range)
    '{"product_name": "Monitor", "rating": 11, "sentiment": "positive", "key_points": ["Great display"], "recommend": true}',
    
    # Missing required fields
    '{"product_name": "Laptop", "rating": 5}',
    
    # Wrong enum value
    '{"product_name": "Tablet", "rating": 3, "sentiment": "amazing", "key_points": ["Nice screen"], "recommend": true}',
    
    # JSON embedded in text
    'Sure! Here is the review analysis: {"product_name": "Speaker", "rating": 4, "sentiment": "positive", "key_points": ["Great sound"], "recommend": true} I hope this helps!',
]

print("JSON Schema Validation Results")
print("=" * 70)

for i, output in enumerate(test_outputs):
    result = validator.validate(output)
    status = "PASS" if result['is_schema_valid'] else "FAIL"
    
    print(f"\nTest {i+1}: [{status}]")
    print(f"  Input: {output[:80]}..." if len(output) > 80 else f"  Input: {output}")
    print(f"  Valid JSON: {result['is_valid_json']}")
    print(f"  Schema valid: {result['is_schema_valid']}")
    if result['errors']:
        for err in result['errors']:
            print(f"  ERROR: {err}")
    if result['warnings']:
        for warn in result['warnings']:
            print(f"  WARNING: {warn}")

print(f"\nOverall success rate: {validator.get_success_rate():.0%}")

---

## Section 6: Comparing Structured vs Unstructured Output Reliability

Let's quantify how much structured output methods improve reliability compared to relying on free-form text parsing.
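
One way to read a success rate `p` is through the retries it implies. Assuming each attempt succeeds independently with the same probability `p` (a simplification), the number of attempts follows a geometric distribution, so the expected number of retries is `1/p - 1`. The sketch below checks that formula against a small simulation; the function names and the value `p = 0.8` are illustrative choices.

```python
import random

def expected_retries(p):
    """Analytic expectation: mean attempts is 1/p, so retries are 1/p - 1."""
    return 1.0 / p - 1.0

def simulate_retries(p, n_requests, rng):
    """Average number of failed attempts before the first success."""
    total_retries = 0
    for _ in range(n_requests):
        while rng.random() >= p:  # This attempt failed; try again
            total_retries += 1
    return total_retries / n_requests

rng = random.Random(0)  # Fixed seed so the run is repeatable
p = 0.8
analytic = expected_retries(p)
simulated = simulate_retries(p, 20_000, rng)

print(f"Analytic retries per request:  {analytic:.3f}")
print(f"Simulated retries per request: {simulated:.3f}")
```

The relationship is nonlinear: moving from 40% to 80% success removes far more retries than moving from 80% to 97%, which matters when comparing methods by cost.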

In [None]:
# ============================================================
# Simulate reliability comparison across methods
# ============================================================

np.random.seed(42)

n_trials = 1000

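# Assumed success rates for different methods (illustrative values, loosely based
# on observations from production systems; nothing is measured in this notebook).
# Each metric is sampled independently, so per-trial results are not nested.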

methods = {
    'Free-form\n(Regex Parse)': {
        'valid_json': np.random.binomial(1, 0.60, n_trials),
        'schema_valid': np.random.binomial(1, 0.45, n_trials),
        'fully_correct': np.random.binomial(1, 0.40, n_trials),
    },
    'Prompt\nEngineering': {
        'valid_json': np.random.binomial(1, 0.82, n_trials),
        'schema_valid': np.random.binomial(1, 0.70, n_trials),
        'fully_correct': np.random.binomial(1, 0.65, n_trials),
    },
    'JSON Mode\n(API)': {
        'valid_json': np.random.binomial(1, 0.99, n_trials),
        'schema_valid': np.random.binomial(1, 0.88, n_trials),
        'fully_correct': np.random.binomial(1, 0.85, n_trials),
    },
    'JSON Mode +\nSchema Prompt': {
        'valid_json': np.random.binomial(1, 0.99, n_trials),
        'schema_valid': np.random.binomial(1, 0.94, n_trials),
        'fully_correct': np.random.binomial(1, 0.91, n_trials),
    },
    'Grammar\nConstrained': {
        'valid_json': np.ones(n_trials),  # Always valid JSON
        'schema_valid': np.ones(n_trials),  # Always schema-valid
        'fully_correct': np.random.binomial(1, 0.97, n_trials),
    },
    'Function\nCalling': {
        'valid_json': np.random.binomial(1, 0.99, n_trials),
        'schema_valid': np.random.binomial(1, 0.96, n_trials),
        'fully_correct': np.random.binomial(1, 0.93, n_trials),
    },
}

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(18, 7))

# Plot 1: Grouped bar chart
method_names = list(methods.keys())
x = np.arange(len(method_names))
width = 0.25

metrics = ['valid_json', 'schema_valid', 'fully_correct']
metric_labels = ['Valid JSON', 'Schema Compliant', 'Fully Correct']
colors = ['#2196F3', '#FF9800', '#4CAF50']

for i, (metric, label, color) in enumerate(zip(metrics, metric_labels, colors)):
    rates = [methods[m][metric].mean() * 100 for m in method_names]
    bars = axes[0].bar(x + i * width, rates, width, label=label, color=color, alpha=0.85)

axes[0].set_xlabel('Method', fontsize=12)
axes[0].set_ylabel('Success Rate (%)', fontsize=12)
axes[0].set_title('Structured Output Method Reliability\n(1000 simulated trials each)', 
                  fontsize=14, fontweight='bold')
axes[0].set_xticks(x + width)
axes[0].set_xticklabels(method_names, fontsize=9)
axes[0].legend(fontsize=10)
axes[0].set_ylim(0, 110)
axes[0].grid(True, alpha=0.3, axis='y')

# Plot 2: Cost of failures (retries needed)
avg_retries = []
for method_name in method_names:
    success_rate = methods[method_name]['fully_correct'].mean()
    if success_rate > 0:
        # Expected retries = 1/success_rate - 1
        avg_retries.append(max(0, 1/success_rate - 1))
    else:
        avg_retries.append(float('inf'))

retry_colors = ['#F44336' if r > 0.3 else '#FF9800' if r > 0.1 else '#4CAF50' for r in avg_retries]
bars = axes[1].bar(method_names, avg_retries, color=retry_colors, alpha=0.8)

for bar, val in zip(bars, avg_retries):
    axes[1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,
                f'{val:.2f}', ha='center', fontweight='bold', fontsize=10)

axes[1].set_xlabel('Method', fontsize=12)
axes[1].set_ylabel('Average Retries per Request', fontsize=12)
axes[1].set_title('Cost of Failures\n(Expected retries to get correct output)', 
                  fontsize=14, fontweight='bold')
axes[1].tick_params(axis='x', rotation=0, labelsize=9)
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('reliability_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nKey Insight: Under these assumed rates, grammar-constrained generation is the most reliable method.")
print("JSON mode + schema prompting is a practical sweet spot for most use cases.")

In [None]:
# ============================================================
# Cost analysis: Retries vs structured output overhead
# ============================================================

# Calculate total cost per correct output
base_cost_per_request = 0.003  # $0.003 per request (example)
structured_overhead = {
    'Free-form\n(Regex Parse)': 0,
    'Prompt\nEngineering': 0.0005,  # Longer prompts
    'JSON Mode\n(API)': 0.0002,     # Minimal overhead
    'JSON Mode +\nSchema Prompt': 0.001,  # Schema in prompt
    'Grammar\nConstrained': 0.0008,  # Processing overhead
    'Function\nCalling': 0.001,      # Tool definitions
}

total_costs = []
for method_name in method_names:
    success_rate = methods[method_name]['fully_correct'].mean()
    overhead = structured_overhead[method_name]
    cost_per_request = base_cost_per_request + overhead
    expected_attempts = 1 / success_rate if success_rate > 0 else 10
    total_cost = cost_per_request * expected_attempts
    total_costs.append(total_cost * 1000)  # Scale to thousandths of a dollar for plotting

fig, ax = plt.subplots(figsize=(12, 6))
colors = plt.cm.RdYlGn_r(np.linspace(0.2, 0.8, len(method_names)))
bars = ax.bar(method_names, total_costs, color=colors, alpha=0.85, edgecolor='white', linewidth=2)

for bar, cost in zip(bars, total_costs):
    ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.05,
            f'${cost/1000:.4f}', ha='center', fontweight='bold', fontsize=10)

ax.set_xlabel('Method', fontsize=12)
ax.set_ylabel('Cost per Correct Output (thousandths of $)', fontsize=12)
ax.set_title('True Cost per Correct Output (Including Retries)',
            fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')
ax.tick_params(axis='x', labelsize=9)

plt.tight_layout()
plt.savefig('cost_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

---

## Section 7: Building a Robust Output Pipeline

In production, you need a pipeline that combines multiple strategies for maximum reliability.
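
Before the full class, here is a stripped-down sketch of the core idea: validate, retry with a more explicit prompt, and fall back to a default when retries run out. The names `validate`, `generate_with_retries`, and `REQUIRED` are hypothetical, and `call_model` is assumed to be any callable that maps a prompt string to a reply string; a scripted stand-in is used in place of a real model.

```python
import json

REQUIRED = {"name", "age"}  # Hypothetical required fields

def validate(raw):
    """Return (data, error); exactly one of the two is None."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as exc:
        return None, f"invalid JSON: {exc.msg}"
    if not isinstance(data, dict):
        return None, "expected a JSON object"
    missing = sorted(REQUIRED - set(data))
    if missing:
        return None, f"missing fields: {', '.join(missing)}"
    return data, None

def generate_with_retries(call_model, prompt, max_retries=3, fallback=None):
    """Call the model up to `max_retries` times, then return the fallback."""
    current_prompt = prompt
    for attempt in range(1, max_retries + 1):
        data, error = validate(call_model(current_prompt))
        if error is None:
            return {"status": "success", "data": data, "attempts": attempt}
        # Make the next attempt more explicit about what went wrong.
        current_prompt = f"{prompt}\nPrevious reply was rejected ({error}). Reply with JSON only."
    return {"status": "fallback", "data": fallback, "attempts": max_retries}

# Scripted stand-in for a model: fails twice, then complies.
scripted = iter(['Name: John', '{"name": "John"}', '{"name": "John", "age": 25}'])
outcome = generate_with_retries(lambda _prompt: next(scripted), "Extract name and age.")
print(outcome)
```

Feeding the validation error back into the retry prompt is a design choice: it gives the model a specific correction to make instead of repeating the same request unchanged.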

In [None]:
# ============================================================
# Production-grade structured output pipeline
# ============================================================

class StructuredOutputPipeline:
    """
    A production pipeline for reliable structured output.
    
    Strategy:
    1. Use JSON mode or function calling
    2. Validate against schema
    3. Attempt auto-repair if validation fails
    4. Retry with more explicit instructions
    5. Fall back to default values as last resort
    """
    
    def __init__(self, schema: dict, max_retries: int = 3):
        self.schema = schema
        self.max_retries = max_retries
        self.validator = LLMOutputValidator(schema)
        self.stats = {'attempts': 0, 'successes': 0, 'retries': 0, 'repairs': 0}
    
    def process(self, llm_output: str) -> dict:
        """Process LLM output with validation and repair."""
        self.stats['attempts'] += 1
        
        # Step 1: Validate
        result = self.validator.validate(llm_output)
        
        if result['is_schema_valid']:
            self.stats['successes'] += 1
            return {'status': 'success', 'data': result['parsed'], 'method': 'direct'}
        
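        # Step 2: Try auto-repair (only non-empty JSON objects can be repaired field by field)
        if result['is_valid_json'] and isinstance(result['parsed'], dict) and result['parsed']: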
            repaired = self._auto_repair(result['parsed'])
            repair_result = self.validator.validate(json.dumps(repaired))
            if repair_result['is_schema_valid']:
                self.stats['repairs'] += 1
                self.stats['successes'] += 1
                return {'status': 'repaired', 'data': repaired, 'method': 'auto_repair'}
        
        # Step 3: Return with error info
        return {
            'status': 'failed',
            'errors': result['errors'],
            'method': 'none',
            'retry_prompt': self._generate_retry_prompt(result)
        }
    
    def _auto_repair(self, parsed: dict) -> dict:
        """Attempt to auto-repair common issues."""
        repaired = dict(parsed)
        properties = self.schema.get('properties', {})
        required = self.schema.get('required', [])
        
        for field in required:
            if field not in repaired:
                # Add placeholder defaults for missing required fields.
                # Note: placeholders satisfy the schema but are not real data.
                field_type = properties.get(field, {}).get('type', 'string')
                defaults = {
                    'string': 'unknown',
                    'integer': 0,
                    'number': 0.0,
                    'boolean': False,
                    'array': [],
                    'object': {},
                }
                repaired[field] = defaults.get(field_type, None)
        
        # Fix type mismatches
        for field, field_schema in properties.items():
            if field in repaired:
                expected_type = field_schema.get('type')
                value = repaired[field]
                
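                if expected_type == 'integer' and isinstance(value, str):
                    try:
                        repaired[field] = int(value)
                    except ValueError:
                        pass  # Leave as-is; validation will report the mismatch
                elif expected_type == 'number' and isinstance(value, str):
                    try:
                        repaired[field] = float(value)
                    except ValueError:
                        pass  # Leave as-is; validation will report the mismatch
                elif expected_type == 'string' and not isinstance(value, str):
                    repaired[field] = str(value)
                elif expected_type == 'boolean' and isinstance(value, str):
                    # Any string other than these is treated as False
                    repaired[field] = value.lower() in ('true', 'yes', '1')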
        
        # Fix enum values
        for field, field_schema in properties.items():
            if field in repaired and 'enum' in field_schema:
                if repaired[field] not in field_schema['enum']:
                    # Try case-insensitive match (str() guards non-string enum values)
                    for valid in field_schema['enum']:
                        if str(repaired[field]).lower() == str(valid).lower():
                            repaired[field] = valid
                            break
                    else:
                        # No match: fall back to the first allowed value
                        repaired[field] = field_schema['enum'][0]
        
        return repaired
    
    def _generate_retry_prompt(self, result: dict) -> str:
        """Generate a more explicit prompt for retry."""
        errors = result.get('errors', [])
        return (f"Your previous response had errors: {'; '.join(errors)}. "
                f"Please respond with valid JSON matching this schema: "
                f"{json.dumps(self.schema)}")
    
    def report(self):
        """Print pipeline statistics."""
        print("\nPipeline Statistics:")
        print(f"  Total attempts: {self.stats['attempts']}")
        print(f"  Direct successes: {self.stats['successes'] - self.stats['repairs']}")
        print(f"  Auto-repairs: {self.stats['repairs']}")
        print(f"  Total success rate: {self.stats['successes']/max(1,self.stats['attempts']):.1%}")


# Demo the pipeline
pipeline = StructuredOutputPipeline(product_review_schema)

test_cases = [
    # Perfect output
    '{"product_name": "Phone Case", "rating": 4, "sentiment": "positive", "key_points": ["Durable", "Nice color"], "recommend": true}',
    # Missing fields (repairable)
    '{"product_name": "Laptop Stand", "rating": 5}',
    # Enum value with wrong casing (repairable)
    '{"product_name": "Charger", "rating": 3, "sentiment": "POSITIVE", "key_points": ["Fast charging"], "recommend": true}',
    # Type mismatch (repairable)
    '{"product_name": "Cable", "rating": "4", "sentiment": "neutral", "key_points": ["Good length"], "recommend": "yes"}',
    # Invalid JSON
    'The review analysis shows rating of 3',
]

print("Structured Output Pipeline Demo")
print("=" * 60)

for i, output in enumerate(test_cases, start=1):
    result = pipeline.process(output)
    print(f"\nTest {i}: [{result['status'].upper()}] via {result['method']}")
    # Truncate long inputs so each test prints on one line
    preview = output if len(output) <= 70 else output[:70] + "..."
    print(f"  Input: {preview}")
    if result['status'] != 'failed':
        print(f"  Output: {json.dumps(result['data'])}")
    else:
        print(f"  Errors: {result['errors']}")

pipeline.report()

---

## Summary & Key Takeaways

| Concept | Key Insight |
|---------|-------------|
| **JSON Mode** | Easy to use but doesn't guarantee schema compliance |
| **Logit Biasing** | Fine-grained token control; good for simple constraints |
| **Function Calling** | Best for tool-use patterns; arguments arrive structured but should still be validated |
| **Grammar Constraints** | Highest reliability; guarantees structural validity, not correct values |
| **Schema Validation** | Always validate; never trust raw LLM output in production |
| **Auto-Repair** | Can fix common issues without retrying (saves cost), but default values can hide missing data |
| **Pipeline Approach** | Combine methods for maximum reliability |

### Production Best Practices

1. **Always validate** LLM outputs before downstream processing
2. **Use JSON mode** together with an explicit schema in the prompt
3. **Implement auto-repair** for common failure modes
4. **Retry with more context** (the validation errors) if validation and repair both fail
5. **Monitor success rates** and alert on degradation
6. **Use grammar constraints** for critical, high-volume pipelines
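
Practices 1, 4, and 5 can be combined in a small retry loop. The sketch below is a minimal illustration under stated assumptions, not the notebook's pipeline: `call_llm` and `validate` are hypothetical callables supplied by the caller (for example, a wrapper around your model client and a schema check that returns a list of error strings), and `fake_llm` and `validate_review` exist only to make the example runnable.

```python
import json
from typing import Callable, List, Optional, Tuple


def generate_with_retries(
    prompt: str,
    call_llm: Callable[[str], str],        # hypothetical: prompt -> raw model text
    validate: Callable[[str], List[str]],  # hypothetical: raw text -> list of errors
    max_retries: int = 3,
) -> Tuple[Optional[dict], int]:
    """Return (parsed_output, attempts_used); parsed_output is None on failure."""
    current_prompt = prompt
    for attempt in range(1, max_retries + 2):  # first try plus max_retries retries
        raw = call_llm(current_prompt)
        errors = validate(raw)
        if not errors:
            return json.loads(raw), attempt
        # Retry with more context: tell the model exactly what was wrong.
        current_prompt = (
            f"{prompt}\n\nYour previous response had errors: {'; '.join(errors)}. "
            "Respond with valid JSON only."
        )
    return None, max_retries + 1


def validate_review(raw: str) -> List[str]:
    """Toy validator: requires a JSON object with an integer 'rating' field."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as exc:
        return [f"invalid JSON: {exc.msg}"]
    if not isinstance(data, dict):
        return ["top-level value must be an object"]
    if not isinstance(data.get("rating"), int):
        return ["'rating' must be an integer"]
    return []


# Stand-in for a real model: fails once, then returns valid JSON.
_responses = iter(["The rating is 4", '{"rating": 4}'])


def fake_llm(prompt: str) -> str:
    return next(_responses)


result, attempts = generate_with_retries("Rate this review.", fake_llm, validate_review)
print(result, attempts)
```

Logging `attempts` for every request gives the success-rate signal that practice 5 asks you to monitor: a rising average number of attempts is an early sign of degradation.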

---

## Exercises

### Exercise 1: Custom Schema Validator
Create a validator for a more complex schema (e.g., a multi-item invoice with line items, taxes, and totals).
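
A JSON Schema can check each field's type and range, but not that the fields agree with one another. As one possible starting point for such a cross-field rule, the sketch below recomputes an invoice total from its line items and tax rate and compares it with the reported value. The function name `check_invoice_total`, the tolerance, and the sample invoice are illustrative assumptions, not part of the exercise.

```python
import math


def check_invoice_total(invoice: dict, tolerance: float = 0.01) -> list:
    """Return a list of error strings; an empty list means the total is consistent."""
    subtotal = sum(
        item["quantity"] * item["unit_price"] for item in invoice["line_items"]
    )
    expected = subtotal * (1 + invoice["tax_rate"])
    # Compare with a tolerance because the amounts are floating-point numbers.
    if not math.isclose(invoice["total"], expected, abs_tol=tolerance):
        return [f"total {invoice['total']} does not match expected {expected:.2f}"]
    return []


# Hypothetical sample data for illustration only.
sample_invoice = {
    "line_items": [
        {"description": "Widget", "quantity": 2, "unit_price": 10.0},
        {"description": "Gadget", "quantity": 1, "unit_price": 5.0},
    ],
    "tax_rate": 0.1,
    "total": 27.5,
}

print(check_invoice_total(sample_invoice))                   # consistent total
print(check_invoice_total({**sample_invoice, "total": 30}))  # inconsistent total
```

Run this check only after schema validation has passed, so that the fields it reads are known to exist and to have numeric types.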

### Exercise 2: Logit Bias Experiment
Using Hugging Face `transformers`, implement logit biasing to force a model to always start its responses with specific words.

### Exercise 3: Function Calling Router
Build a system that classifies user intents and routes to the appropriate function. Include fallback handling.

### Exercise 4: Reliability Benchmark
Using a real LLM API, benchmark the actual reliability of free-form vs JSON mode vs function calling on 100 diverse prompts.

In [None]:
# ============================================================
# Exercise 1 Starter: Complex Invoice Schema
# ============================================================

invoice_schema = {
    "type": "object",
    "properties": {
        "invoice_number": {"type": "string", "pattern": "^INV-[0-9]{6}$"},
        "customer": {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "email": {"type": "string", "format": "email"}
            },
            "required": ["name", "email"]
        },
        "line_items": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "description": {"type": "string"},
                    "quantity": {"type": "integer", "minimum": 1},
                    "unit_price": {"type": "number", "minimum": 0}
                },
                "required": ["description", "quantity", "unit_price"]
            },
            "minItems": 1
        },
        "tax_rate": {"type": "number", "minimum": 0, "maximum": 1},
        "total": {"type": "number"}
    },
    "required": ["invoice_number", "customer", "line_items", "tax_rate", "total"]
}

print("Invoice Schema:")
print(json.dumps(invoice_schema, indent=2))

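# Your task: Build a validator and auto-repair system for this schema
# Hint: Add validation that total = sum(quantity * unit_price) * (1 + tax_rate),
#       comparing with a small tolerance because the values are floats
# Note: jsonschema only enforces "format": "email" when a FormatChecker is passed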

print("\nComplete the exercise by building a validator for this schema!")