# Tool Call Decision Trace

Debug tool call outputs by extracting decision graphs from logits.

This notebook captures logits at every generation step to:
- Identify decision points (tool name, param names, param values)
- Calculate confidence at each decision
- Find the "weakest link" in tool call generation
- Detect potential hallucinations

In [None]:
"""Setup: Import dependencies and load model."""

import json
import re
from dataclasses import dataclass, field

import mlx.core as mx
from mlx_lm import load
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

console = Console()

# Load model
print("Loading model...")
model, tokenizer = load("mlx-community/Qwen3-4B-4bit")
print("Model loaded!")

In [11]:
"""Data structures for tracing."""

@dataclass
class DecisionPoint:
    """A single token decision during generation."""
    position: int           # Token position in output
    token_id: int           # Selected token ID
    token_str: str          # Decoded token string
    decision_type: str      # "tool_name" | "param_name" | "param_value" | "structure" | "other"
    prob: float             # Probability of selected token
    alternatives: list      # Top-5 alternatives [(token_str, prob), ...]
    entropy: float          # Distribution entropy


@dataclass
class ToolCallTrace:
    """Complete trace of a tool call generation."""
    raw_output: str                           # Full generated text
    parsed_call: dict                         # {"name": "...", "arguments": {...}}
    decisions: list = field(default_factory=list)  # List of DecisionPoint
    weakest_link: DecisionPoint = None        # Lowest confidence decision
    total_confidence: float = 1.0             # Product of all decision probs


print("Dataclasses defined!")

Dataclasses defined!


In [3]:
"""ToolCallTracer: Generate tool calls while capturing logits."""

class ToolCallTracer:
    """Generate tool calls and capture full logit history."""
    
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.logit_history = []  # Captured during generation
        self.token_history = []  # Generated tokens
        
    def generate_with_trace(self, prompt: str, max_tokens: int = 100, 
                           temperature: float = 0.0) -> ToolCallTrace:
        """Generate while capturing logits at every step."""
        self.logit_history = []
        self.token_history = []
        
        # Encode prompt
        tokens = mx.array(self.tokenizer.encode(prompt))
        prompt_len = len(tokens)
        
        # Generate token by token
        for _ in range(max_tokens):
            logits = self.model(tokens[None])[0, -1, :]
            
            # Store logits before any modification
            self.logit_history.append(logits)
            
            # Sample next token
            if temperature == 0:
                next_token = mx.argmax(logits).item()
            else:
                scaled_logits = logits / temperature
                next_token = mx.random.categorical(scaled_logits).item()
            
            self.token_history.append(next_token)
            
            # Check for EOS
            if next_token == self.tokenizer.eos_token_id:
                break
                
            tokens = mx.concatenate([tokens, mx.array([next_token])])
        
        # Decode output
        raw_output = self.tokenizer.decode(self.token_history)
        
        # Parse tool call
        parsed_call = self._parse_tool_call(raw_output)
        
        # Extract decision points
        decisions = self._extract_decisions(raw_output, parsed_call)
        
        # Find weakest link and total confidence
        weakest_link = None
        total_confidence = 1.0
        
        for d in decisions:
            total_confidence *= d.prob
            if weakest_link is None or d.prob < weakest_link.prob:
                weakest_link = d
        
        return ToolCallTrace(
            raw_output=raw_output,
            parsed_call=parsed_call,
            decisions=decisions,
            weakest_link=weakest_link,
            total_confidence=total_confidence,
        )
    
    def _parse_tool_call(self, text: str) -> dict:
        """Extract tool call from <tool_call>...</tool_call> tags."""
        # Try to find tool_call tags
        match = re.search(r'<tool_call>\s*(.+?)\s*</tool_call>', text, re.DOTALL)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                pass
        
        # Try to find raw JSON
        match = re.search(r'\{"name":\s*"([^"]+)".*?"arguments":\s*(\{[^}]+\})\}', text, re.DOTALL)
        if match:
            try:
                return {"name": match.group(1), "arguments": json.loads(match.group(2))}
            except json.JSONDecodeError:
                pass
        
        return {"name": None, "arguments": {}}
    
    def _classify_token(self, pos: int, token_str: str, raw_output: str, parsed_call: dict) -> str:
        """Classify token as tool_name, param_name, param_value, structure, or other."""
        # Get context up to this position
        context = self.tokenizer.decode(self.token_history[:pos+1])
        
        # Check if in tool name
        if parsed_call.get("name"):
            name = parsed_call["name"]
            if name in context and name not in self.tokenizer.decode(self.token_history[:pos]):
                if token_str.strip() in name:
                    return "tool_name"
        
        # Check if in param name
        for param_name in parsed_call.get("arguments", {}).keys():
            if param_name in context:
                prev_context = self.tokenizer.decode(self.token_history[:pos]) if pos > 0 else ""
                if param_name not in prev_context:
                    if token_str.strip() in param_name or param_name.startswith(token_str.strip()):
                        return "param_name"
        
        # Check if in param value
        for param_name, param_value in parsed_call.get("arguments", {}).items():
            param_str = str(param_value)
            if param_str in context:
                prev_context = self.tokenizer.decode(self.token_history[:pos]) if pos > 0 else ""
                if param_str not in prev_context:
                    if token_str.strip() in param_str:
                        return "param_value"
        
        # Check for JSON structure tokens
        structure_tokens = ['{', '}', '[', ']', ':', ',', '"', '<', '>', '/']
        if any(s in token_str for s in structure_tokens):
            return "structure"
        
        return "other"
    
    def _extract_decisions(self, raw_output: str, parsed_call: dict) -> list:
        """Extract decision points from logit history."""
        decisions = []
        
        for pos, (token_id, logits) in enumerate(zip(self.token_history, self.logit_history)):
            # Decode token
            token_str = self.tokenizer.decode([token_id])
            
            # Compute probabilities
            probs = mx.softmax(logits)
            prob = probs[token_id].item()
            
            # Compute entropy
            log_probs = mx.log(probs + 1e-10)
            entropy = -mx.sum(probs * log_probs).item()
            
            # Get top-5 alternatives
            top_indices = mx.argsort(probs)[-5:][::-1].tolist()
            alternatives = []
            for idx in top_indices:
                alt_str = self.tokenizer.decode([idx])
                alt_prob = probs[idx].item()
                alternatives.append((alt_str, alt_prob))
            
            # Classify token
            decision_type = self._classify_token(pos, token_str, raw_output, parsed_call)
            
            decisions.append(DecisionPoint(
                position=pos,
                token_id=token_id,
                token_str=token_str,
                decision_type=decision_type,
                prob=prob,
                alternatives=alternatives,
                entropy=entropy,
            ))
        
        return decisions


print("ToolCallTracer defined!")

ToolCallTracer defined!


In [4]:
"""ToolCallVerifier: Verify tool calls using decision trace."""

class ToolCallVerifier:
    """Verify tool calls and detect potential issues."""
    
    def __init__(self, allowed_tools: list, user_input: str = ""):
        self.allowed_tools = allowed_tools
        self.user_input = user_input.lower()
    
    def verify(self, trace: ToolCallTrace) -> dict:
        """Verify a tool call trace and return issues."""
        issues = []
        
        # Check if tool is allowed
        tool_name = trace.parsed_call.get("name")
        valid_tool = tool_name in self.allowed_tools if tool_name else False
        if not valid_tool:
            issues.append(f"Unknown tool: {tool_name}")
        
        # Check confidence threshold
        confidence_ok = trace.weakest_link is None or trace.weakest_link.prob > 0.5
        if not confidence_ok:
            issues.append(f"Low confidence at: {trace.weakest_link.token_str!r} ({trace.weakest_link.prob:.2%})")
        
        # Check for hallucination risk
        hallucination_risk, hallucinated_values = self._check_hallucination(trace)
        if hallucination_risk:
            for val in hallucinated_values:
                issues.append(f"Potential hallucination: {val!r} not in user input")
        
        # Check for low-confidence param values
        low_conf_params = []
        for d in trace.decisions:
            if d.decision_type == "param_value" and d.prob < 0.7:
                low_conf_params.append((d.token_str, d.prob))
        if low_conf_params:
            issues.append(f"Low-confidence param values: {low_conf_params}")
        
        return {
            "valid_tool": valid_tool,
            "confidence_ok": confidence_ok,
            "hallucination_risk": hallucination_risk,
            "total_confidence": trace.total_confidence,
            "weakest_prob": trace.weakest_link.prob if trace.weakest_link else 1.0,
            "issues": issues,
        }
    
    def _check_hallucination(self, trace: ToolCallTrace) -> tuple:
        """Check if param values appear in user input."""
        if not self.user_input:
            return False, []
        
        hallucinated = []
        for param_name, param_value in trace.parsed_call.get("arguments", {}).items():
            param_str = str(param_value).lower()
            # Skip short values that might be defaults
            if len(param_str) < 3:
                continue
            if param_str not in self.user_input:
                hallucinated.append(f"{param_name}={param_value}")
        
        return len(hallucinated) > 0, hallucinated


print("ToolCallVerifier defined!")

ToolCallVerifier defined!


In [5]:
"""Visualization: Color-coded trace display."""

def get_confidence_color(prob: float) -> str:
    """Get color based on confidence level."""
    if prob >= 0.8:
        return "green"
    elif prob >= 0.5:
        return "yellow"
    else:
        return "red"


def print_trace(trace: ToolCallTrace):
    """Print trace with color-coded confidence."""
    console = Console()
    
    # Header
    console.print(Panel("[bold]Tool Call Decision Trace[/bold]", expand=False))
    
    # Parsed call
    console.print(f"\n[bold cyan]Parsed Tool Call:[/bold cyan]")
    console.print(f"  Name: {trace.parsed_call.get('name', 'N/A')}")
    console.print(f"  Arguments: {trace.parsed_call.get('arguments', {})}")
    
    # Summary stats
    console.print(f"\n[bold cyan]Confidence Summary:[/bold cyan]")
    console.print(f"  Total confidence: {trace.total_confidence:.4f}")
    if trace.weakest_link:
        color = get_confidence_color(trace.weakest_link.prob)
        console.print(f"  Weakest link: [{color}]{trace.weakest_link.token_str!r}[/{color}] "
                     f"(prob: {trace.weakest_link.prob:.2%}, type: {trace.weakest_link.decision_type})")
    
    # Color-coded output
    console.print(f"\n[bold cyan]Token-by-Token Output:[/bold cyan]")
    
    output_text = Text()
    for d in trace.decisions:
        color = get_confidence_color(d.prob)
        output_text.append(d.token_str, style=color)
    
    console.print(output_text)
    
    # Decision table
    console.print(f"\n[bold cyan]Decision Points:[/bold cyan]")
    
    table = Table()
    table.add_column("Pos", justify="right", style="dim")
    table.add_column("Token", style="cyan")
    table.add_column("Type", style="magenta")
    table.add_column("Prob", justify="right")
    table.add_column("Entropy", justify="right", style="dim")
    table.add_column("Top Alternatives")
    
    for d in trace.decisions:
        color = get_confidence_color(d.prob)
        prob_str = f"[{color}]{d.prob:.2%}[/{color}]"
        
        # Format alternatives (skip the selected one)
        alts = [f"{t!r}:{p:.1%}" for t, p in d.alternatives if t != d.token_str][:3]
        alts_str = ", ".join(alts) if alts else "-"
        
        table.add_row(
            str(d.position),
            repr(d.token_str),
            d.decision_type,
            prob_str,
            f"{d.entropy:.2f}",
            alts_str,
        )
    
    console.print(table)


def print_verification(result: dict):
    """Print verification results."""
    console = Console()
    
    console.print(f"\n[bold cyan]Verification Results:[/bold cyan]")
    
    status_color = "green" if not result["issues"] else "red"
    console.print(f"  Valid tool: {'Yes' if result['valid_tool'] else 'No'}")
    console.print(f"  Confidence OK: {'Yes' if result['confidence_ok'] else 'No'}")
    console.print(f"  Hallucination risk: {'Yes' if result['hallucination_risk'] else 'No'}")
    console.print(f"  Total confidence: {result['total_confidence']:.4f}")
    
    if result["issues"]:
        console.print(f"\n[bold red]Issues:[/bold red]")
        for issue in result["issues"]:
            console.print(f"  - {issue}")
    else:
        console.print(f"\n[bold green]No issues found![/bold green]")


print("Visualization functions defined!")

Visualization functions defined!


In [17]:
"""Sample Prompt: Eney Assistant format with weather_forecast tool."""

SYSTEM_PROMPT = """# Role
You are **Eney** â€” a macOS Assistant built by MacPaw.

# Tools
## weather_forecast
Get weather forecast for a location.
Parameters:
- location (string, required): City name
- days (integer, optional): Number of days (1-7)
## climate_forecast
Get climate forecast for a region.
Parameters:
- location (string, required): City name
- days (integer, optional): Number of days (1-7)


# Response Format
When you need to call a tool, output it in this format:
<tool_call>{"name": "tool_name", "arguments": {"param": "value"}}</tool_call>"""

USER_MESSAGE = "Is it raining in Paris?"

# Format with chat template (Qwen3 format)
PROMPT = f"""<|im_start|>system
{SYSTEM_PROMPT}<|im_end|>
<|im_start|>user
{USER_MESSAGE}<|im_end|>
<|im_start|>assistant
<think>
</think>
"""

ALLOWED_TOOLS = ["weather_forecast"]

print("Prompt configured!")
print(f"User message: {USER_MESSAGE!r}")

Prompt configured!
User message: 'Is it raining in Paris?'


In [None]:
"""Demo: Run tracer and visualize."""

# Create tracer
tracer = ToolCallTracer(model, tokenizer)

# Generate with trace
print("Generating tool call...")
trace = tracer.generate_with_trace(PROMPT, max_tokens=100, temperature=0.0)

# Print trace
print_trace(trace)

Generating tool call...


In [14]:
"""Demo: Verify the tool call."""

# Create verifier
verifier = ToolCallVerifier(ALLOWED_TOOLS, user_input=USER_MESSAGE)

# Verify trace
result = verifier.verify(trace)

# Print results
print_verification(result)

In [15]:
"""Analysis: Compare different temperatures."""

temperatures = [0.0, 0.3, 0.7]

for temp in temperatures:
    print(f"\n{'='*60}")
    print(f"Temperature: {temp}")
    print('='*60)
    
    trace = tracer.generate_with_trace(PROMPT, max_tokens=100, temperature=temp)
    
    print(f"\nOutput: {trace.raw_output[:200]}...")
    print(f"Parsed: {trace.parsed_call}")
    print(f"Total confidence: {trace.total_confidence:.4f}")
    if trace.weakest_link:
        print(f"Weakest: {trace.weakest_link.token_str!r} ({trace.weakest_link.prob:.2%})")
    
    result = verifier.verify(trace)
    print(f"Issues: {result['issues'] if result['issues'] else 'None'}")


Temperature: 0.0

Output: <tool_call>{"name": "weather_forecast", "arguments": {"location": "London", "days": 1}}</tool_call><|im_end|>...
Parsed: {'name': 'weather_forecast', 'arguments': {'location': 'London', 'days': 1}}
Total confidence: 0.9537
Weakest: '<tool_call>' (96.88%)
Issues: None

Temperature: 0.3

Output: <tool_call>{"name": "weather_forecast", "arguments": {"location": "London", "days": 1}}</tool_call><|im_end|>...
Parsed: {'name': 'weather_forecast', 'arguments': {'location': 'London', 'days': 1}}
Total confidence: 0.9537
Weakest: '<tool_call>' (96.88%)
Issues: None

Temperature: 0.7

Output: <tool_call>{"name": "weather_forecast", "arguments": {"location": "London", "days": 1}}</tool_call><|im_end|>...
Parsed: {'name': 'weather_forecast', 'arguments': {'location': 'London', 'days': 1}}
Total confidence: 0.9537
Weakest: '<tool_call>' (96.88%)
Issues: None


In [16]:
"""Analysis: Compare different user prompts."""

test_prompts = [
    "Check the weather in London",
    "What's the weather like in Tokyo?",
    "Weather forecast for New York, 5 days",
    "Is it raining in Paris?",
]

for user_msg in test_prompts:
    prompt = f"""<|im_start|>system
{SYSTEM_PROMPT}<|im_end|>
<|im_start|>user
{user_msg}<|im_end|>
<|im_start|>assistant
<think>
</think>
"""
    
    print(f"\n{'='*60}")
    print(f"User: {user_msg}")
    print('='*60)
    
    trace = tracer.generate_with_trace(prompt, max_tokens=100, temperature=0.0)
    verifier = ToolCallVerifier(ALLOWED_TOOLS, user_input=user_msg)
    result = verifier.verify(trace)
    
    print(f"Parsed: {trace.parsed_call}")
    print(f"Total confidence: {trace.total_confidence:.4f}")
    print(f"Issues: {result['issues'] if result['issues'] else 'None'}")


User: Check the weather in London
Parsed: {'name': 'weather_forecast', 'arguments': {'location': 'London', 'days': 1}}
Total confidence: 0.9537
Issues: None

User: What's the weather like in Tokyo?
Parsed: {'name': 'weather_forecast', 'arguments': {'location': 'Tokyo', 'days': 1}}
Total confidence: 0.5634
Issues: None

User: Weather forecast for New York, 5 days
Parsed: {'name': 'weather_forecast', 'arguments': {'location': 'New York', 'days': 5}}
Total confidence: 1.0000
Issues: None

User: Is it raining in Paris?
Parsed: {'name': 'weather_forecast', 'arguments': {'location': 'Paris', 'days': 1}}
Total confidence: 0.2163
Issues: None


## Key Insights

1. **Confidence Colors**:
   - Green: High confidence (>80%)
   - Yellow: Medium confidence (50-80%)
   - Red: Low confidence (<50%)

2. **Decision Types**:
   - `tool_name`: Part of the tool name
   - `param_name`: Part of a parameter name
   - `param_value`: Part of a parameter value
   - `structure`: JSON/XML structure tokens
   - `other`: Other tokens

3. **Weakest Link**:
   - The token with lowest probability indicates where the model was most uncertain
   - Low confidence in `param_value` may indicate hallucination risk

4. **Verification**:
   - Checks tool validity, confidence threshold, and hallucination risk
   - Param values not in user input flagged as potential hallucinations