Skip to content

基于历史工具调用轨迹来优化 Agent #139

@Pines-Cheng

Description

@Pines-Cheng

1. 收集和结构化历史轨迹数据

import difflib
import json
import time
import uuid
from collections import Counter, defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional

@dataclass
class ToolCall:
    """A single recorded tool invocation.

    Attributes:
        tool_name: Name of the tool that was called.
        parameters: Arguments passed to the tool.
        result: Whatever the tool returned.
        success: Whether the call is considered successful.
        error_message: Error text for a failed call; ``None`` by default.
        timestamp: When the call was recorded; ``None`` if not captured.
        duration_ms: Duration of the call in milliseconds; 0 if not measured.
    """
    tool_name: str
    parameters: Dict[str, Any]
    result: Any
    success: bool
    error_message: Optional[str] = None
    # The default is None, so the annotation must admit None.
    timestamp: Optional[datetime] = None
    duration_ms: float = 0

@dataclass
class AgentTrace:
    """One complete agent run: the query, the tool calls made, and the outcome.

    Attributes:
        session_id: Identifier of the run.
        user_query: The user request that started the run.
        tool_calls: Tool invocations in the order they were recorded.
        final_response: The agent's final answer.
        total_duration_ms: Duration of the whole run in milliseconds.
        success: Whether the run is considered successful.
        metadata: Optional free-form extra data; ``None`` when absent.
    """
    session_id: str
    user_query: str
    tool_calls: List["ToolCall"]
    final_response: str
    total_duration_ms: float
    success: bool
    # The default is None, so the annotation must admit None.
    metadata: Optional[Dict[str, Any]] = None

2. 轨迹分析和模式识别

class TraceAnalyzer:
    """Mine historical AgentTrace records for recurring tool-usage patterns."""

    def __init__(self, traces: List["AgentTrace"]):
        # Stored by reference: every analysis reads the list's current contents.
        self.traces = traces

    def analyze_patterns(self) -> Dict[str, Any]:
        """Run every analysis and return the results keyed by pattern type.

        Returns:
            A dict with keys ``common_sequences``, ``error_patterns``,
            ``performance_bottlenecks`` and ``redundant_calls``. Each holds the
            return value of the corresponding ``find_*`` method.
        """
        return {
            'common_sequences': self.find_common_sequences(),
            'error_patterns': self.find_error_patterns(),
            'performance_bottlenecks': self.find_bottlenecks(),
            'redundant_calls': self.find_redundant_calls(),
        }

    def find_common_sequences(self) -> List[tuple]:
        """Count identical tool-call sequences across traces.

        Returns:
            ``(sequence, count)`` pairs, where ``sequence`` is a tuple of tool
            names in call order. The most frequent come first, and ties keep
            first-seen order. A trace with no tool calls contributes the empty
            tuple.
        """
        counts = Counter(
            tuple(call.tool_name for call in trace.tool_calls) for trace in self.traces
        )
        return counts.most_common()

    def find_error_patterns(self) -> List[Dict[str, Any]]:
        """Collect the context of every failed tool call.

        Returns:
            One dict per failed call, with these keys:
            ``tool``, ``error``, ``parameters``, and ``previous_tools`` (the
            names of the calls made before it in the same trace).
        """
        error_contexts = []
        for trace in self.traces:
            for index, call in enumerate(trace.tool_calls):
                if call.success:
                    continue
                error_contexts.append({
                    'tool': call.tool_name,
                    'error': call.error_message,
                    'previous_tools': [c.tool_name for c in trace.tool_calls[:index]],
                    'parameters': call.parameters,
                })
        return error_contexts

    def find_bottlenecks(self) -> List[Dict[str, Any]]:
        """Summarise per-tool latency, slowest first.

        Returns:
            One dict per tool name, with keys ``tool``, ``calls``,
            ``avg_duration_ms`` and ``max_duration_ms``. The list is sorted by
            ``avg_duration_ms`` in descending order.
        """
        durations: Dict[str, List[float]] = defaultdict(list)
        for trace in self.traces:
            for call in trace.tool_calls:
                durations[call.tool_name].append(call.duration_ms)
        # Every list is non-empty: a tool only gets an entry once it has a duration.
        stats = [
            {
                'tool': tool,
                'calls': len(values),
                'avg_duration_ms': sum(values) / len(values),
                'max_duration_ms': max(values),
            }
            for tool, values in durations.items()
        ]
        stats.sort(key=lambda entry: entry['avg_duration_ms'], reverse=True)
        return stats

    def find_redundant_calls(self) -> List[Dict[str, Any]]:
        """Find repeat calls to the same tool, with equal parameters, within one trace.

        The first occurrence is not reported. Each later duplicate yields a dict
        with keys ``session_id``, ``tool``, ``parameters`` and ``index`` (its
        position in ``trace.tool_calls``).
        """
        redundant = []
        for trace in self.traces:
            seen = set()
            for index, call in enumerate(trace.tool_calls):
                key = (call.tool_name, self._parameters_key(call.parameters))
                if key in seen:
                    redundant.append({
                        'session_id': trace.session_id,
                        'tool': call.tool_name,
                        'parameters': call.parameters,
                        'index': index,
                    })
                else:
                    seen.add(key)
        return redundant

    @staticmethod
    def _parameters_key(parameters: Any) -> str:
        """Return a canonical string for *parameters*, so equal mappings compare equal."""
        try:
            # sort_keys makes key order irrelevant. default=repr still lets values
            # that JSON cannot encode contribute to the key.
            return json.dumps(parameters, sort_keys=True, ensure_ascii=False, default=repr)
        except (TypeError, ValueError):
            # Covers unsortable mixed-type keys and circular references.
            return repr(parameters)

3. 构建优化提示词

class AgentOptimizer:
    """Build a system prompt that feeds lessons from historical traces to the agent."""

    def __init__(self, analyzer: "TraceAnalyzer"):
        # Must expose ``traces`` and ``analyze_patterns()``, as TraceAnalyzer does.
        self.analyzer = analyzer

    def generate_optimization_prompt(self, current_query: str) -> str:
        """Return a system prompt combining historical lessons with *current_query*.

        The prompt has these sections:
        - frequent tool sequences from successful traces;
        - up to five failure contexts;
        - slow-tool hints;
        - the strategies used by similar successful queries.
        """
        patterns = self.analyzer.analyze_patterns()
        # The "success patterns" section uses successful traces only. The
        # analyzer's common_sequences would also include failed runs.
        success_sequences = self._successful_sequences()

        prompt = f"""
你是一个智能助手,需要根据用户查询选择合适的工具来完成任务。

## 历史经验总结:

### 1. 常见成功模式:
{self._format_success_patterns(success_sequences[:5])}

### 2. 需要避免的错误:
{self._format_error_patterns(patterns['error_patterns'][:5])}

### 3. 性能优化建议:
{self._format_performance_tips(patterns['performance_bottlenecks'])}

### 4. 相似查询的最佳实践:
{self._find_similar_queries(current_query)}

## 当前任务:
用户查询:{current_query}

请基于以上历史经验,选择最优的工具调用策略。
"""
        return prompt

    def _successful_sequences(self) -> list:
        """Count tool-call sequences among successful traces that made tool calls.

        Returns ``(sequence, count)`` pairs, most frequent first.
        """
        counts = Counter(
            tuple(call.tool_name for call in trace.tool_calls)
            for trace in self.analyzer.traces
            if trace.success and trace.tool_calls
        )
        return counts.most_common()

    @staticmethod
    def _format_success_patterns(sequences) -> str:
        """Render ``(tool_sequence, count)`` pairs as bullet lines, skipping empty sequences."""
        lines = [
            f"- {' -> '.join(sequence)}(出现 {count} 次)"
            for sequence, count in sequences
            if sequence
        ]
        return "\n".join(lines) or "- 暂无足够的历史数据"

    @staticmethod
    def _format_error_patterns(errors) -> str:
        """Render failure-context dicts as bullet lines.

        Each dict has keys ``tool``, ``error`` and ``previous_tools``; missing
        values fall back to placeholder text.
        """
        lines = []
        for err in errors:
            previous = " -> ".join(err.get('previous_tools') or []) or "无"
            message = err.get('error') or "未知错误"
            lines.append(f"- 调用 {err.get('tool')} 失败:{message}(此前调用:{previous})")
        return "\n".join(lines) or "- 暂无"

    @staticmethod
    def _format_performance_tips(bottlenecks, limit: int = 5) -> str:
        """Render up to *limit* slow-tool entries as bullet lines.

        A dict carrying ``tool`` and ``avg_duration_ms`` gets a formatted hint.
        Any other entry is rendered with ``str()``.
        """
        lines = []
        for item in list(bottlenecks or [])[:limit]:
            if isinstance(item, dict) and 'tool' in item and 'avg_duration_ms' in item:
                lines.append(
                    f"- {item['tool']} 平均耗时 {item['avg_duration_ms']:.0f} ms,请避免不必要的重复调用"
                )
            else:
                lines.append(f"- {item}")
        return "\n".join(lines) or "- 暂无"

    def _find_similar_queries(self, query: str) -> str:
        """Format the tool chains used by the most similar successful historical queries."""
        # Replace _get_similar_traces with vector similarity if text matching is too coarse.
        tips = []
        for trace in self._get_similar_traces(query, top_k=3):
            if trace.success:
                tools = " -> ".join(call.tool_name for call in trace.tool_calls)
                tips.append(f"- 查询:'{trace.user_query}' 成功使用:{tools}")
        return "\n".join(tips) or "- 暂无相似的成功案例"

    def _get_similar_traces(self, query: str, top_k: int = 3, min_similarity: float = 0.0) -> list:
        """Return up to *top_k* successful traces whose query best matches *query*.

        Similarity is ``difflib.SequenceMatcher.ratio`` on the raw query strings.
        It is character-level, so Chinese text needs no tokenizer. Only traces
        scoring strictly above *min_similarity* are kept; the default of 0.0
        drops queries with zero similarity.
        """
        scored = []
        for trace in self.analyzer.traces:
            # Filter before ranking, so failed traces cannot crowd out successful ones.
            if not (trace.success and trace.tool_calls):
                continue
            ratio = difflib.SequenceMatcher(None, query, trace.user_query).ratio()
            if ratio > min_similarity:
                scored.append((ratio, trace))
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [trace for _, trace in scored[:top_k]]

4. 实时优化策略

class AdaptiveAgent:
    """Wrap a base agent so that each run uses, and adds to, the trace history."""

    def __init__(self, base_agent, trace_history: List["AgentTrace"]):
        self.base_agent = base_agent
        # Keep a private copy, so _save_trace can append without mutating the
        # caller's list. The analyzer receives this same list object, so later
        # prompts are built from the updated history.
        self.trace_history = list(trace_history or [])
        self.optimizer = AgentOptimizer(TraceAnalyzer(self.trace_history))
        self.current_trace = None

    def execute(self, query: str) -> str:
        """Run the base agent on *query* with a history-informed system prompt.

        The run is recorded as ``self.current_trace``, with its total duration.
        It is passed to ``_save_trace`` whether the run returns or raises.

        Raises:
            Whatever ``base_agent.run`` raises. The trace is marked failed first.
        """
        optimized_prompt = self.optimizer.generate_optimization_prompt(query)

        self.current_trace = AgentTrace(
            session_id=self._generate_session_id(),
            user_query=query,
            tool_calls=[],
            final_response="",
            total_duration_ms=0,
            success=True,
        )

        started = time.perf_counter()
        try:
            response = self.base_agent.run(
                query=query,
                system_prompt=optimized_prompt,
                callbacks=[self._trace_callback],
            )
        except BaseException:
            # BaseException also covers KeyboardInterrupt and cancellation, so an
            # aborted run is not recorded as a success. It is always re-raised.
            self.current_trace.success = False
            raise
        else:
            self.current_trace.final_response = response
            return response
        finally:
            self.current_trace.total_duration_ms = (time.perf_counter() - started) * 1000
            self._save_trace(self.current_trace)

    def _trace_callback(self, tool_name, params, result,
                        success=True, error_message=None, duration_ms=0):
        """Append one ToolCall to the trace of the run in progress.

        The trailing keyword arguments are optional. Callers that pass only
        ``(tool_name, params, result)`` keep the original behaviour: the call
        is recorded as successful with duration 0.
        """
        self.current_trace.tool_calls.append(
            ToolCall(
                tool_name=tool_name,
                parameters=params,
                result=result,
                success=success,
                error_message=error_message,
                timestamp=datetime.now(),
                duration_ms=duration_ms,
            )
        )

    @staticmethod
    def _generate_session_id() -> str:
        """Return a new random session identifier (32 hex characters)."""
        return uuid.uuid4().hex

    def _save_trace(self, trace) -> None:
        """Append a finished trace to the in-memory history used for later prompts.

        Override this to persist traces as well, for example to the store the
        history was loaded from.
        """
        self.trace_history.append(trace)

5. 高级优化技术

class AdvancedOptimizer:
    """Derive simple "tool A is usually followed by tool B" rules from traces."""

    def __init__(self, traces: List["AgentTrace"]):
        self.traces = traces

    def learn_tool_dependencies(self) -> Dict[str, Dict[str, int]]:
        """Count how often each tool is immediately followed by each other tool.

        Returns:
            ``{tool: {next_tool: count}}`` over every adjacent pair of calls in
            every trace. A tool that is never followed by another call does not
            appear as an outer key.
        """
        transitions: Dict[str, Counter] = defaultdict(Counter)
        for trace in self.traces:
            names = [call.tool_name for call in trace.tool_calls]
            for current, following in zip(names, names[1:]):
                transitions[current][following] += 1
        # Return plain dicts, so a caller's lookup of a missing key raises
        # instead of silently inserting an entry.
        return {tool: dict(counts) for tool, counts in transitions.items()}

    def generate_meta_prompt(self, threshold: int = 5) -> str:
        """Render one rule per tool whose most frequent successor is common enough.

        Args:
            threshold: The transition must occur strictly more than this many
                times to become a rule. The default, 5, is the previous
                hard-coded value.

        Returns:
            Newline-joined rule lines; an empty string if no rule qualifies.
        """
        rules = []
        for tool, successors in self.learn_tool_dependencies().items():
            # On a tie, max() keeps the first successor seen.
            next_tool, count = max(successors.items(), key=lambda kv: kv[1])
            if count > threshold:
                rules.append(f"- 使用 {tool} 后,通常需要使用 {next_tool}")
        return "\n".join(rules)

6. 实际使用示例

# Example usage. load_traces_from_database, base_agent, response_quality_is_good
# and update_optimization_model are not defined in this snippet; supply your own.

# Initialization
trace_history = load_traces_from_database()  # load historical traces
agent = AdaptiveAgent(base_agent, trace_history)

# Run the history-optimized agent
response = agent.execute("帮我分析最近一周的销售数据")

# Continuous learning
if response_quality_is_good:
    # execute() already passes every trace, successful or not, to _save_trace
    # in its finally clause. This extra hook is where a trace judged good can
    # additionally be fed into a separate optimization model.
    update_optimization_model(agent.current_trace)

这种方法的优势:

  1. 自适应学习:同时从成功和失败的执行轨迹中学习
  2. 避免重复错误:识别历史上的错误模式,并在提示词中提醒规避
  3. 性能优化:发现耗时瓶颈和冗余调用,减少不必要的工具调用
  4. 上下文感知:参考相似查询的成功经验,做出更好的决策

你可以根据具体需求调整这个框架,例如采用更精细的相似度计算(如基于向量嵌入的语义检索),或使用机器学习模型来预测最佳的工具调用序列。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions