# Goals (目標)
 1. Implement pydantic schema validation for LLM JSON outputs
 2. Build automatic retry and JSON-repair mechanisms (target: 90%+ success rate)
 3. Handle common JSON parsing errors and malformed structures
 4. Create robust tool-calling with graceful degradation
 5. Demonstrate JSON repair strategies for Chinese LLM outputs

# Prerequisites (前置需求)
 - nb20_function_calling_format.ipynb completed
 - Basic understanding of pydantic BaseModel
 - JSON parsing and error handling concepts

In [None]:
# Schema Validation & Retry Mechanisms
# nb27_schema_validation_retry.ipynb
# 結構化輸出驗證與自動修復機制

# Cell1:  Shared Cache Bootstrap
import os, pathlib, torch
import sys
from datetime import datetime

# Shared cache configuration (copy into every notebook): point each
# Hugging Face / torch cache environment variable at a subdirectory of
# AI_CACHE_ROOT and make sure that directory exists before anything loads.
AI_CACHE_ROOT = os.getenv("AI_CACHE_ROOT", "../ai_warehouse/cache")

_cache_subdirs = {
    "HF_HOME": "hf",
    "TRANSFORMERS_CACHE": "hf/transformers",
    "HF_DATASETS_CACHE": "hf/datasets",
    "HUGGINGFACE_HUB_CACHE": "hf/hub",
    "TORCH_HOME": "torch",
}
for _env_name, _subdir in _cache_subdirs.items():
    _cache_path = f"{AI_CACHE_ROOT}/{_subdir}"
    os.environ[_env_name] = _cache_path
    pathlib.Path(_cache_path).mkdir(parents=True, exist_ok=True)
print("[Cache]", AI_CACHE_ROOT, "| GPU:", torch.cuda.is_available())

In [None]:
# ============================================================================
# Cell 2: Dependencies and Imports
# ============================================================================

import json
import re
import time
from typing import Dict, List, Any, Optional, Union
from dataclasses import dataclass
from pydantic import BaseModel, Field, ValidationError, validator
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging

# Module-level logger used by the cells below; its records propagate to the
# root logger.
logger = logging.getLogger(name=__name__)
# Give the root logger a handler at INFO so logger.info(...) output is shown.
logging.basicConfig(level=logging.INFO)

In [None]:
# ============================================================================
# Cell 3: Tool Schema Definitions (Pydantic Models)
# ============================================================================


class CalculatorArgs(BaseModel):
    """Safe calculator tool arguments"""

    # NOTE: the class docstring above is interpolated into the LLM prompt
    # (LLMWithRetry uses TOOL_SCHEMAS[tool].__doc__), so edit it deliberately.
    expression: str = Field(
        ...,
        description="Mathematical expression to evaluate",
        min_length=1,
        max_length=200,
    )

    @validator("expression")
    def validate_expression(cls, v):
        """Whitelist arithmetic characters, then return the stripped expression.

        ``min_length`` is enforced on the raw value before this validator runs,
        so a whitespace-only string such as "   " satisfies it yet strips to "".
        That case is rejected here so an empty expression never reaches the
        calculator.

        Raises:
            ValueError: on a character outside the whitelist, or when the
                expression is empty after stripping.
        """
        # Allow only safe characters for math expressions ("e" permits
        # scientific notation such as 1e5).
        allowed_chars = set("0123456789+-*/().,e ")
        if not all(c in allowed_chars for c in v):
            raise ValueError("Expression contains invalid characters")
        stripped = v.strip()
        if not stripped:
            raise ValueError("Expression must not be empty or whitespace-only")
        return stripped


class WebSearchArgs(BaseModel):
    """Web search tool arguments"""

    # Free-text search string, 1 to 500 characters (required).
    query: str = Field(
        ...,
        min_length=1,
        max_length=500,
        description="Search query",
    )
    # Number of hits to return; validated to the range 1..20, default 5.
    max_results: int = Field(
        default=5,
        le=20,
        ge=1,
        description="Number of results to return",
    )


class FileReadArgs(BaseModel):
    """File reading tool arguments"""

    # Path of the file to read; must be a non-empty string (required).
    filepath: str = Field(..., min_length=1, description="Path to file to read")
    # Cap on lines returned; validated to the range 1..1000, default 100.
    max_lines: int = Field(
        default=100,
        le=1000,
        ge=1,
        description="Maximum lines to read",
    )


class ToolCall(BaseModel):
    """Complete tool call structure"""

    # Name of the tool to invoke (looked up in TOOL_SCHEMAS by the validator).
    tool: str = Field(default=..., description="Tool name to call")
    # Raw arguments; the per-tool schema validates them separately.
    args: Dict[str, Any] = Field(default=..., description="Tool arguments")
    # Optional free-text justification emitted by the model.
    reasoning: Optional[str] = Field(default=None, description="Why this tool is needed")


# Tool registry: maps each tool name the LLM may emit to the pydantic model
# that validates that tool's `args`.
TOOL_SCHEMAS = dict(
    calculator=CalculatorArgs,
    web_search=WebSearchArgs,
    file_read=FileReadArgs,
)

In [None]:
# ============================================================================
# Cell 4: JSON Repair Utilities
# ============================================================================


class JSONRepair:
    """Utilities for repairing common JSON formatting issues in LLM output.

    All methods are static and side-effect free. ``attempt_repair`` is the
    entry point: it tries several strategies and returns the first result
    that parses to a JSON *object* (dict), or None.
    """

    # ```json {...} ``` or ``` {...} ``` fenced block; the tag is optional.
    _FENCED_OBJECT = re.compile(
        r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL | re.IGNORECASE
    )
    # Trailing comma before a closing brace/bracket: {"a": 1,} -> {"a": 1}
    _TRAILING_COMMA = re.compile(r",(\s*[}\]])")
    # Bare identifier used as a key. It must directly follow "{" or ",", so
    # "word:" sequences inside string values (URLs, times) are left alone.
    # [^\W\d] = letter or underscore, including CJK letters.
    _UNQUOTED_KEY = re.compile(r"([{,]\s*)([^\W\d]\w*)(\s*):")
    # Unquoted scalar value terminated by "," or "}".
    _BARE_VALUE = re.compile(r':\s*([^",{\[\]}\s]+)(?=\s*[,}])')
    # Bare tokens that are already valid JSON scalars and must stay unquoted.
    _JSON_SCALAR = re.compile(
        r"-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?|true|false|null"
    )

    @staticmethod
    def _find_balanced_object(text: str) -> Optional[str]:
        """Return the first brace-balanced ``{...}`` span in *text*, or None.

        Braces inside double-quoted strings (honouring backslash escapes) are
        ignored, so a value such as "}{" does not end the object early. If the
        object starting at one "{" never closes, the scan restarts at the next
        "{".
        """
        start = text.find("{")
        while start != -1:
            depth = 0
            in_string = False
            escaped = False
            for pos in range(start, len(text)):
                ch = text[pos]
                if in_string:
                    if escaped:
                        escaped = False
                    elif ch == "\\":
                        escaped = True
                    elif ch == '"':
                        in_string = False
                elif ch == '"':
                    in_string = True
                elif ch == "{":
                    depth += 1
                elif ch == "}":
                    depth -= 1
                    if depth == 0:
                        return text[start : pos + 1]
            start = text.find("{", start + 1)
        return None

    @staticmethod
    def extract_json_from_text(text: str) -> str:
        """Extract the JSON object from text that may contain extra content.

        Order: a fenced markdown block first, then the first brace-balanced
        object. The balanced scan matters because a simple ``\\{[^{}]*\\}``
        regex would return the innermost nested object (e.g. just ``args``).
        Falls back to the stripped input when no object is found.
        """
        fenced = JSONRepair._FENCED_OBJECT.search(text)
        if fenced:
            return fenced.group(1)
        balanced = JSONRepair._find_balanced_object(text)
        if balanced is not None:
            return balanced
        return text.strip()

    @staticmethod
    def _quote_bare_value(match: "re.Match") -> str:
        """Quote an unquoted value unless it is already a JSON number/literal."""
        value = match.group(1)
        if JSONRepair._JSON_SCALAR.fullmatch(value):
            return match.group(0)
        return f': "{value}"'

    @staticmethod
    def fix_common_issues(json_str: str) -> str:
        """Fix common JSON formatting issues (heuristic, best-effort).

        Steps, in order:
        1. Remove trailing commas.
        2. Quote bare keys.
        3. Convert single quotes to double quotes. This is global, so an
           apostrophe inside a value can still break parsing.
        4. Collapse whitespace runs.
        5. Quote bare non-JSON values such as ``2+3``.
        """
        json_str = JSONRepair._TRAILING_COMMA.sub(r"\1", json_str)
        json_str = JSONRepair._UNQUOTED_KEY.sub(r'\1"\2"\3:', json_str)
        json_str = json_str.replace("'", '"')
        json_str = re.sub(r"\s+", " ", json_str).strip()
        json_str = JSONRepair._BARE_VALUE.sub(JSONRepair._quote_bare_value, json_str)
        return json_str

    @staticmethod
    def attempt_repair(broken_json: str) -> Optional[Dict]:
        """Attempt multiple repair strategies and return the first JSON object.

        Returns:
            The parsed dict, or None if no strategy yields a JSON object.
            Results that parse to lists or scalars are rejected, since a tool
            call must be an object.
        """
        repair_strategies = [
            lambda x: x,  # Try as-is first
            JSONRepair.extract_json_from_text,
            JSONRepair.fix_common_issues,
            lambda x: JSONRepair.fix_common_issues(
                JSONRepair.extract_json_from_text(x)
            ),
        ]

        for strategy in repair_strategies:
            try:
                parsed = json.loads(strategy(broken_json))
            except (json.JSONDecodeError, TypeError, AttributeError):
                continue
            if isinstance(parsed, dict):
                return parsed

        return None

In [None]:
# ============================================================================
# Cell 5: Validation and Retry Logic
# ============================================================================


@dataclass
class ValidationResult:
    """Result of a validation attempt.

    ``errors`` is always a list after construction (empty when nothing went
    wrong), so callers can iterate it without a None check.
    """

    success: bool  # True when a schema-valid tool call was produced
    data: Optional[Dict] = None  # validated tool call as a dict, on success
    # Per-attempt error messages. The default is None because a mutable []
    # default is not allowed on a dataclass field; __post_init__ replaces it
    # with a fresh list.
    errors: Optional[List[str]] = None
    repaired: bool = False  # True when JSON repair was needed to succeed
    attempts: int = 1  # number of parse/repair attempts used

    def __post_init__(self):
        # Give every instance its own list so results never share state.
        if self.errors is None:
            self.errors = []


class SchemaValidator:
    """Handles schema validation with retry and repair.

    Each call makes at most two distinct passes over the raw LLM output:

    1. a strict ``json.loads`` of the text as-is;
    2. if that fails (bad JSON *or* a schema error) and repair is enabled,
       one pass through ``JSONRepair.attempt_repair``.

    Repair is deterministic, so running it again on the same text can never
    change the outcome. ``max_retries`` is therefore an upper bound on the
    number of passes rather than a fixed count.
    """

    def __init__(
        self,
        max_retries: int = 3,
        repair_enabled: bool = True,
        allow_unknown_tools: bool = False,
    ):
        """
        Args:
            max_retries: Upper bound on validation passes. Values above 2 have
                no further effect, 1 disables the repair pass, and values < 1
                make every call fail.
            repair_enabled: Global switch for the JSON-repair pass.
            allow_unknown_tools: When False (the default), a tool name that is
                not in TOOL_SCHEMAS is a validation failure. Otherwise such a
                call would be reported as valid with completely unchecked args.
        """
        self.max_retries = max_retries
        self.repair_enabled = repair_enabled
        self.allow_unknown_tools = allow_unknown_tools
        # Repair counters: "attempts" = repair passes run, "successes" = repair
        # passes that produced a schema-valid tool call.
        self.repair_stats = {"attempts": 0, "successes": 0}

    def _validate_data(self, data: Any) -> Dict:
        """Validate a parsed payload against ToolCall and the tool's args schema.

        Returns:
            The validated tool call as a dict, with ``args`` normalised by the
            tool-specific schema.

        Raises:
            TypeError: if *data* is not a JSON object.
            ValueError: for an unknown tool when unknown tools are disallowed.
            ValidationError: when ToolCall or the tool's args schema rejects it.
        """
        if not isinstance(data, dict):
            raise TypeError(f"expected a JSON object, got {type(data).__name__}")
        tool_call = ToolCall(**data)
        tool_schema = TOOL_SCHEMAS.get(tool_call.tool)
        if tool_schema is not None:
            tool_call.args = tool_schema(**tool_call.args).dict()
        elif not self.allow_unknown_tools:
            raise ValueError(
                f"Unknown tool {tool_call.tool!r}; known tools: {sorted(TOOL_SCHEMAS)}"
            )
        return tool_call.dict()

    def validate_tool_call(
        self, raw_output: str, allow_repair: bool = True
    ) -> ValidationResult:
        """Validate and potentially repair tool call output.

        Args:
            raw_output: Raw text produced by the LLM.
            allow_repair: Per-call switch; the repair pass runs only when this
                and ``self.repair_enabled`` are both true.

        Returns:
            On success, a ValidationResult whose ``data`` is the validated
            tool call and whose ``repaired`` flag tells whether repair was
            needed. On failure, a ValidationResult whose ``errors`` holds one
            message per attempt and whose ``attempts`` is the number of passes
            actually made.
        """
        errors: List[str] = []
        attempts_made = 0

        for attempt in range(1, min(self.max_retries, 2) + 1):
            if attempt == 1:
                parse = json.loads
            elif self.repair_enabled and allow_repair:
                self.repair_stats["attempts"] += 1
                parse = JSONRepair.attempt_repair
            else:
                errors.append(f"Attempt {attempt}: Repair disabled, skipping")
                break
            attempts_made = attempt

            try:
                data = parse(raw_output)
                if attempt > 1 and data is None:
                    # attempt_repair returns None when no strategy worked.
                    errors.append(f"Attempt {attempt}: JSON repair failed")
                    break
                validated = self._validate_data(data)
            except json.JSONDecodeError as e:
                errors.append(f"Attempt {attempt}: JSON decode error: {str(e)}")
            except ValidationError as e:
                errors.append(f"Attempt {attempt}: Schema validation error: {str(e)}")
            except (TypeError, ValueError) as e:
                errors.append(f"Attempt {attempt}: Invalid tool call: {str(e)}")
            except Exception as e:
                errors.append(f"Attempt {attempt}: Unexpected error: {str(e)}")
            else:
                if attempt > 1:
                    self.repair_stats["successes"] += 1
                return ValidationResult(
                    success=True,
                    data=validated,
                    repaired=(attempt > 1),
                    attempts=attempt,
                )

        if not errors:
            errors.append("No validation attempts were made (max_retries < 1)")
        return ValidationResult(success=False, errors=errors, attempts=attempts_made)

    def get_repair_success_rate(self) -> float:
        """Return repair successes / repair attempts (0.0 when no repair ran)."""
        if self.repair_stats["attempts"] == 0:
            return 0.0
        return self.repair_stats["successes"] / self.repair_stats["attempts"]

In [None]:
# ============================================================================
# Cell 6: LLM Integration with Retry
# ============================================================================


class LLMWithRetry:
    """LLM adapter with built-in validation and retry.

    Loads a Hugging Face causal LM and prompts it for a single JSON tool call.
    The output is validated (and, if needed, repaired) with SchemaValidator.
    When an attempt fails, the model is re-prompted with the validation error.
    """

    def __init__(self, model_id: str = "Qwen/Qwen2.5-7B-Instruct"):
        # Local import: the quantization config is only needed when a model is
        # actually loaded (4-bit loading requires the bitsandbytes package).
        from transformers import BitsAndBytesConfig

        print(f"Loading model: {model_id}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        # Passing `load_in_4bit=True` straight to from_pretrained is deprecated;
        # the supported way to request 4-bit weights is a BitsAndBytesConfig.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,  # Low VRAM option
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.float16,
            quantization_config=quantization_config,
        )
        self.validator = SchemaValidator(max_retries=3)

        # Add padding token if missing (generate() is given pad_token_id).
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_instructions(self, user_query: str, available_tools: List[str]) -> str:
        """Build the tool-calling instructions; each tool is described by its schema docstring."""
        tools_desc = "\n".join(
            f"- {tool}: {TOOL_SCHEMAS[tool].__doc__}" for tool in available_tools
        )
        return f"""You are a helpful assistant that can call tools. Given a user query, output a JSON object with tool name and arguments.

Available tools:
{tools_desc}

Output format (JSON only):
{{"tool": "tool_name", "args": {{"param": "value"}}, "reasoning": "why this tool"}}

User query: {user_query}"""

    def _encode_prompt(self, body: str) -> Dict[str, Any]:
        """Tokenize *body* and move the tensors to the model's device.

        When the tokenizer has a chat template (as Instruct checkpoints
        typically do), the body is sent as a user turn with the generation
        prompt appended. Otherwise the original completion-style
        "JSON response:" cue is used.
        """
        if getattr(self.tokenizer, "chat_template", None):
            text = self.tokenizer.apply_chat_template(
                [{"role": "user", "content": body}],
                tokenize=False,
                add_generation_prompt=True,
            )
            # The rendered template already contains the special tokens it needs.
            add_special_tokens = False
        else:
            text = f"{body}\n\nJSON response:"
            add_special_tokens = True
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
            add_special_tokens=add_special_tokens,
        )
        return {k: v.to(self.model.device) for k, v in inputs.items()}

    def _generate(self, inputs: Dict[str, Any]) -> str:
        """Sample one completion and return only the newly generated text, stripped."""
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.3,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        prompt_length = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

    def generate_tool_call(
        self, user_query: str, available_tools: List[str], max_attempts: int = 3
    ) -> ValidationResult:
        """Generate and validate a tool call, retrying with error feedback.

        Args:
            user_query: The end-user request.
            available_tools: Tool names (keys of TOOL_SCHEMAS) the model may
                call. A call to any other tool counts as a failed attempt.
            max_attempts: Number of generation attempts (default 3).

        Returns:
            The first successful ValidationResult. Otherwise a failure result
            whose ``errors`` starts with a summary line followed by every
            per-attempt error.
        """
        base_instructions = self._build_instructions(user_query, available_tools)
        body = base_instructions
        collected_errors: List[str] = []

        for attempt in range(1, max_attempts + 1):
            try:
                response = self._generate(self._encode_prompt(body))
            except Exception as e:
                logger.error(f"Generation error on attempt {attempt}: {str(e)}")
                collected_errors.append(f"Generation attempt {attempt}: {e}")
                continue

            logger.info(f"Attempt {attempt} raw output: {response}")

            # Validate with repair
            result = self.validator.validate_tool_call(response, allow_repair=True)
            if result.success and result.data["tool"] not in available_tools:
                result = ValidationResult(
                    success=False,
                    errors=[
                        f"Tool {result.data['tool']!r} is not one of the "
                        f"available tools {available_tools}"
                    ],
                    attempts=result.attempts,
                )

            if result.success:
                logger.info(f"✅ Validation successful on attempt {attempt}")
                return result

            logger.warning(f"❌ Attempt {attempt} failed: {result.errors}")
            attempt_errors = result.errors or ["unknown validation error"]
            collected_errors.extend(
                f"Generation attempt {attempt}: {err}" for err in attempt_errors
            )
            if attempt < max_attempts:
                # Rebuild from the base prompt (no accumulation) and tell the
                # model what was wrong; long pydantic messages are truncated.
                body = (
                    f"{base_instructions}\n\nPrevious attempt failed: "
                    f"{attempt_errors[-1][:500]}\nPlease output valid JSON only."
                )

        # All attempts failed
        return ValidationResult(
            success=False,
            errors=[f"All {max_attempts} generation attempts failed", *collected_errors],
            attempts=max_attempts,
        )

In [None]:
# ============================================================================
# Cell 7: Demo Functions and Tool Execution
# ============================================================================


# Upper bound (in bits) on an integer power result. This blocks inputs such as
# "9**9**9**9", which pass CalculatorArgs' character whitelist but would hang
# the process computing a gigantic integer.
_CALC_MAX_POW_BITS = 10_000


def _safe_eval_arithmetic(expression: str):
    """Evaluate a plain arithmetic expression without ``eval``.

    Accepts int/float literals, parentheses, unary +/- and the binary
    operators + - * / // % **. Note that ``ast.literal_eval`` is not an
    alternative here, because it rejects expressions like ``25 * 4``.

    Raises:
        SyntaxError: the text is not a valid Python expression.
        ValueError: an unsupported construct (names, calls, tuples, ...) or an
            integer power that exceeds _CALC_MAX_POW_BITS.
        ZeroDivisionError / OverflowError: from the arithmetic itself.
    """
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def _eval(node):
        # type(...) in (int, float) deliberately excludes bool and complex.
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_eval(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = _eval(node.left)
            right = _eval(node.right)
            if (
                isinstance(node.op, ast.Pow)
                and isinstance(left, int)
                and isinstance(right, int)
                and right > 0
                and left.bit_length() * right > _CALC_MAX_POW_BITS
            ):
                raise ValueError("Exponent too large")
            return binary_ops[type(node.op)](left, right)
        raise ValueError(f"Unsupported expression element: {type(node).__name__}")

    return _eval(ast.parse(expression, mode="eval").body)


def execute_tool_call(tool_call_data: Dict) -> Dict:
    """Execute a validated tool call.

    The calculator evaluates the expression for real. web_search and
    file_read return mock data.

    Args:
        tool_call_data: Dict with a ``"tool"`` name and its ``"args"`` dict, in
            the shape SchemaValidator returns on success.

    Returns:
        ``{"status": "success", ...}`` with a tool-specific payload, or
        ``{"status": "error", "error": <message>}``.
    """
    tool_name = tool_call_data["tool"]
    args = tool_call_data["args"]

    if tool_name == "calculator":
        try:
            # AST-based evaluator instead of eval(): only arithmetic is allowed
            # and runaway powers are rejected.
            result = _safe_eval_arithmetic(args["expression"])
            return {"status": "success", "result": result}
        except Exception as e:
            return {"status": "error", "error": str(e)}

    elif tool_name == "web_search":
        # Mock search results, truncated to max_results.
        return {
            "status": "success",
            "results": [
                {"title": f"搜尋結果：{args['query']}", "url": "https://example.com"},
                {"title": f"相關文章：{args['query']}", "url": "https://example2.com"},
            ][: args.get("max_results", 5)],
        }

    elif tool_name == "file_read":
        # Mock file read: no file is opened.
        return {
            "status": "success",
            "content": f"Mock file content from {args['filepath']} (max {args.get('max_lines', 100)} lines)",
        }

    else:
        return {"status": "error", "error": f"Unknown tool: {tool_name}"}


def demo_query_with_retry(llm: LLMWithRetry, query: str, tools: List[str]):
    """Run one end-to-end demo: generate a tool call, validate it, execute it.

    Returns:
        ``(tool_call_dict, execution_result)`` when validation succeeds, or
        ``(None, None)`` when it fails after all attempts.
    """
    print(f"\n🔍 Query: {query}")
    print(f"📋 Available tools: {tools}")

    outcome = llm.generate_tool_call(query, tools)

    # Guard clause: report the failure and bail out early.
    if not outcome.success:
        print(f"❌ Validation failed after {outcome.attempts} attempts")
        for message in outcome.errors:
            print(f"   - {message}")
        return None, None

    print(
        f"✅ Validation successful (attempts: {outcome.attempts}, repaired: {outcome.repaired})"
    )
    call_json = json.dumps(outcome.data, ensure_ascii=False, indent=2)
    print(f"📝 Tool call: {call_json}")

    executed = execute_tool_call(outcome.data)
    executed_json = json.dumps(executed, ensure_ascii=False, indent=2)
    print(f"🔧 Execution result: {executed_json}")

    return outcome.data, executed

In [None]:
# ============================================================================
# Cell 8: MVP Test Cases
# ============================================================================

# Load the model once (4-bit, low-VRAM settings live in LLMWithRetry).
print("🚀 Initializing LLM with retry capabilities...")
llm = LLMWithRetry("Qwen/Qwen2.5-7B-Instruct")

# Each entry pairs a (Chinese) user query with the tool names offered for it.
test_queries = [
    ("計算 25 * 4 + 10 的結果", ["calculator"]),
    ("搜尋關於 RAG 檢索增強生成的資料", ["web_search"]),
    ("讀取 config.yaml 檔案的前 50 行", ["file_read"]),
    ("我想要計算 (100 + 200) / 3", ["calculator", "web_search"]),
]

results = []
for query_text, offered_tools in test_queries:
    call_data, exec_data = demo_query_with_retry(llm, query_text, offered_tools)
    record = dict(
        query=query_text,
        success=call_data is not None,
        tool_call=call_data,
        execution=exec_data,
    )
    results.append(record)
    time.sleep(1)  # Brief pause between requests

In [None]:
# ============================================================================
# Cell 9: JSON Repair Testing
# ============================================================================

print("\n🔧 Testing JSON repair capabilities...")

# Deliberately malformed tool-call payloads; the first one is a valid control.
broken_json_examples = [
    '{"tool": "calculator", "args": {"expression": "2+3"}}',  # Valid (control)
    '{tool: "calculator", args: {expression: "2+3"}}',  # Unquoted keys
    "{'tool': 'calculator', 'args': {'expression': '2+3'}}",  # Single quotes
    '{"tool": "calculator", "args": {"expression": "2+3",},}',  # Trailing commas
    '```json\n{"tool": "web_search", "args": {"query": "test"}}\n```',  # Markdown
    '{"tool": "calculator", "args": {"expression": 2+3}}',  # Unquoted value
]

validator = SchemaValidator()
repair_test_results = []

for case_no, sample in enumerate(broken_json_examples, start=1):
    print(f"\n📝 Test case {case_no}: {sample}")
    outcome = validator.validate_tool_call(sample)
    repair_test_results.append(
        dict(
            input=sample,
            success=outcome.success,
            repaired=outcome.repaired,
            attempts=outcome.attempts,
        )
    )

    if not outcome.success:
        last_error = outcome.errors[-1] if outcome.errors else "Unknown error"
        print(f"   ❌ Failed: {last_error}")
        continue

    print(f"   ✅ Success (repaired: {outcome.repaired})")
    print(f"   📋 Parsed: {json.dumps(outcome.data, ensure_ascii=False)}")

# ============================================================================
# Cell 10: Smoke Test & Success Rate Analysis
# ============================================================================
# Summarises the end-to-end LLM runs (`results`, Cell 8) and the offline JSON
# repair suite (`repair_test_results` / `validator`, Cell 9), then asserts the
# smoke-test thresholds.

# Calculate overall success rates
llm_passed = sum(1 for r in results if r["success"])
repair_passed = sum(1 for r in repair_test_results if r["success"])
total_tests = len(results) + len(repair_test_results)
successful_tests = llm_passed + repair_passed
success_rate = successful_tests / total_tests if total_tests > 0 else 0

# A validated tool call can still fail when executed (e.g. division by zero),
# so executions are counted by their own status.
executions_ok = sum(
    1
    for r in results
    if r["success"] and (r["execution"] or {}).get("status") == "success"
)

# Repair counters live on each SchemaValidator instance. The LLM pipeline uses
# `llm.validator`, while the Cell 9 repair suite used its own `validator`.
# Both are reported so the repair suite's activity is visible.
repair_success_rate = llm.validator.get_repair_success_rate()
suite_repair_success_rate = validator.get_repair_success_rate()

print("\n📊 SMOKE TEST RESULTS")
print("═══════════════════════════════════════")
print(f"Overall success rate: {success_rate:.1%} ({successful_tests}/{total_tests})")
print(f"JSON repair success rate (LLM outputs): {repair_success_rate:.1%}")
print(f"JSON repair success rate (repair test suite): {suite_repair_success_rate:.1%}")
print(f"Tool execution tests: {llm_passed}/{len(results)} passed")
print(f"JSON repair tests: {repair_passed}/{len(repair_test_results)} passed")

# Detailed repair statistics
repair_stats = llm.validator.repair_stats
for stats_label, stats in (
    ("LLM outputs", repair_stats),
    ("repair test suite", validator.repair_stats),
):
    print(f"\n🔧 Repair Statistics ({stats_label}):")
    print(f"   Repair attempts: {stats['attempts']}")
    print(f"   Repair successes: {stats['successes']}")

# Assert minimum success rate for smoke test (notebook check only; `assert`
# statements are stripped under `python -O`).
assert success_rate >= 0.7, f"Success rate {success_rate:.1%} below minimum 70%"
assert executions_ok >= 2, "At least 2 tool executions should succeed"

print("\n✅ SMOKE TEST PASSED!")
print("   - Schema validation working")
print("   - JSON repair functional")
print("   - Tool execution pipeline complete")
print(f"   - Success rate: {success_rate:.1%} ≥ 70% ✓")

# ============================================================================
# When to use this notebook:
# ============================================================================
# The string below is a bare expression statement: it is not assigned and
# changes no program state. In Jupyter, as the last expression of this cell,
# it will be echoed as the cell's output. English summary of its contents:
#   1. Validating that LLM output conforms to a specific JSON schema.
#   2. Fault-tolerant tool calling in ReAct / function-calling loops.
#   3. Repairing common LLM JSON format errors (missing quotes, commas, ...).
#   4. Production structured-output scenarios that need a 90%+ success rate.
#   5. Adapting Chinese-language LLMs whose JSON output needs special handling.
#   Key benefits: automatic JSON repair and retry; pydantic schema validation
#   and type safety; detailed error tracking and statistics; a low-VRAM
#   friendly implementation; a 90%+ tool-call success-rate target.
"""
📚 使用時機 (When to use this):

1. **LLM 結構化輸出驗證**: 當需要確保 LLM 輸出符合特定 JSON schema 時
2. **工具調用容錯**: 在 ReAct/Function-calling 中需要穩定的工具調用成功率
3. **JSON 格式修復**: 處理 LLM 常見的 JSON 格式錯誤（缺引號、逗號等）
4. **生產環境穩定性**: 需要 90%+ 成功率的結構化輸出場景
5. **多語言 LLM 適配**: 中文 LLM 在 JSON 輸出上的特殊處理需求

Key benefits:
- 自動 JSON 修復與重試機制
- Pydantic schema 驗證與型別安全
- 詳細的錯誤追蹤與統計
- 低 VRAM 友善的實作
- 90%+ 工具調用成功率目標
"""

In [None]:
class ToolCall(BaseModel):
    """Complete tool call structure"""

    # NOTE(review): this cell re-declares ToolCall from Cell 3. Running it
    # rebinds the module-level name that SchemaValidator looks up at call
    # time, so it is kept identical to the Cell 3 definition. That includes
    # the class docstring, which the earlier re-declaration had dropped and
    # which pydantic uses as the model's schema description.
    tool: str = Field(..., description="Tool name to call")
    args: Dict[str, Any] = Field(..., description="Tool arguments")
    reasoning: Optional[str] = Field(None, description="Why this tool is needed")