import uuid
import logging
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, Optional
+from functools import wraps

from asgi_correlation_id import correlation_id
from langfuse import Langfuse
from langfuse.client import StatefulGenerationClient, StatefulTraceClient
+from app.models.llm import CompletionConfig, QueryParams, LLMCallResponse

logger = logging.getLogger(__name__)

@@ -107,3 +109,110 @@ def log_error(self, error_message: str, response_id: Optional[str] = None):

    def flush(self):
        self.langfuse.flush()
+
+
+def observe_llm_execution(
+    session_id: str | None = None,
+    credentials: dict | None = None,
+):
| 118 | + """Decorator to add Langfuse observability to LLM provider execute methods. |
| 119 | +
|
| 120 | + Args: |
| 121 | + credentials: Langfuse credentials with public_key, secret_key, and host |
| 122 | + session_id: Session ID for grouping traces (conversation_id) |
| 123 | +
|
| 124 | + Usage: |
| 125 | + decorated_execute = observe_llm_execution( |
| 126 | + credentials=langfuse_creds, |
| 127 | + session_id=conversation_id |
| 128 | + )(provider_instance.execute) |
| 129 | + """ |
+
+    def decorator(func: Callable) -> Callable:
+        @wraps(func)
+        def wrapper(completion_config: CompletionConfig, query: QueryParams, **kwargs):
+            # Skip observability if no credentials provided
+            if not credentials:
+                logger.info("[Langfuse] No credentials - skipping observability")
+                return func(completion_config, query, **kwargs)
+
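+            # Initialize the Langfuse client; on failure, fall back to an unobserved call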
+            try:
+                langfuse = Langfuse(
+                    public_key=credentials.get("public_key"),
+                    secret_key=credentials.get("secret_key"),
+                    host=credentials.get("host"),
+                )
+            except Exception as e:
+                logger.warning(f"[Langfuse] Failed to initialize client: {e}")
+                return func(completion_config, query, **kwargs)
+
+            trace_metadata = {
+                "provider": completion_config.provider,
+            }
+
+            if query.conversation and query.conversation.id:
+                trace_metadata["conversation_id"] = query.conversation.id
+
+            trace = langfuse.trace(
+                name="unified-llm-call",
+                input=query.input,
+                metadata=trace_metadata,
+                tags=[completion_config.provider],
+            )
+
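+            # Nest a generation under the trace to record the model call's input and usage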
+            generation = trace.generation(
+                name=f"{completion_config.provider}-completion",
+                input=query.input,
+                model=completion_config.params.get("model"),
+            )
+
+            try:
+                # Execute the actual LLM call
+                response: LLMCallResponse | None
+                error: str | None
+                response, error = func(completion_config, query, **kwargs)
+
+                if response:
+                    generation.end(
+                        output={
+                            "status": "success",
+                            "output": response.response.output.text,
+                        },
+                        usage_details={
+                            "input": response.usage.input_tokens,
+                            "output": response.usage.output_tokens,
+                        },
+                        model=response.response.model,
+                    )
+
+                    trace.update(
+                        output={
+                            "status": "success",
+                            "output": response.response.output.text,
+                        },
+                        session_id=session_id or response.response.conversation_id,
+                    )
+                else:
+                    error_msg = error or "Unknown error"
+                    generation.end(output={"error": error_msg})
+                    trace.update(
+                        output={"status": "failure", "error": error_msg},
+                        session_id=session_id,
+                    )
+
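+                # Flush buffered events so traces are delivered even from short-lived workers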
+                langfuse.flush()
+                return response, error
+
+            except Exception as e:
+                error_msg = str(e)
+                generation.end(output={"error": error_msg})
+                trace.update(
+                    output={"status": "failure", "error": error_msg},
+                    session_id=session_id,
+                )
+                langfuse.flush()
+                raise
+
+        return wrapper
+
+    return decorator
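
For reviewers, a minimal sketch of how this decorator wires onto a provider. `LocalEchoProvider`, the credential values, and the session ID are invented for illustration; only `observe_llm_execution` and the `(response, error)` return contract come from this diff, and `completion_config`/`query` are assumed to be constructed elsewhere:

```python
# Hypothetical wiring sketch; the provider class and credential values are
# assumptions, not part of this diff.
from app.models.llm import CompletionConfig, QueryParams


class LocalEchoProvider:
    """Stand-in provider: any object whose execute() returns (response, error)."""

    def execute(self, completion_config: CompletionConfig, query: QueryParams, **kwargs):
        ...  # real providers call the model API here
        return None, "not implemented"


langfuse_creds = {
    "public_key": "pk-lf-...",
    "secret_key": "sk-lf-...",
    "host": "https://cloud.langfuse.com",
}

decorated_execute = observe_llm_execution(
    session_id="conv-1234",  # groups traces under one Langfuse session
    credentials=langfuse_creds,
)(LocalEchoProvider().execute)

# The wrapper preserves execute()'s (response, error) contract, so call
# sites do not change when observability is toggled on.
response, error = decorated_execute(completion_config, query)
```

Because the wrapper returns `func(...)` unchanged whenever credentials are missing or client construction fails, observability stays strictly best-effort: a Langfuse outage degrades tracing, never the LLM call itself.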