-
Notifications
You must be signed in to change notification settings - Fork 322
/
OrchestratorBase.py
93 lines (83 loc) · 3.46 KB
/
OrchestratorBase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import logging
from uuid import uuid4
from typing import List, Optional
from abc import ABC, abstractmethod
from ..loggers.ConversationLogger import ConversationLogger
from ..helpers.config.ConfigHelper import ConfigHelper
from ..parser.OutputParserTool import OutputParserTool
from ..tools.ContentSafetyChecker import ContentSafetyChecker
logger = logging.getLogger(__name__)
class OrchestratorBase(ABC):
    """Abstract base for chat orchestrators.

    Owns the per-message plumbing shared by all orchestrator implementations:
    a unique message id, token accounting, content-safety screening of both
    the user question and the generated answer, conversation logging, and
    output parsing. Subclasses implement :meth:`orchestrate`.
    """

    def __init__(self) -> None:
        super().__init__()
        # Active app configuration (falls back to defaults when none is set).
        self.config = ConfigHelper.get_active_config_or_default()
        # Fresh id per orchestrator instance, i.e. per handled message.
        self.message_id = str(uuid4())
        # Running token counters, accumulated via log_tokens().
        self.tokens = {"prompt": 0, "completion": 0, "total": 0}
        # Lazy %-style args: avoids formatting when DEBUG is disabled.
        logger.debug(
            "New message id: %s with tokens %s", self.message_id, self.tokens
        )
        self.conversation_logger: ConversationLogger = ConversationLogger()
        self.content_safety_checker = ContentSafetyChecker()
        self.output_parser = OutputParserTool()

    def log_tokens(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Accumulate token usage for the current message.

        Args:
            prompt_tokens: Tokens consumed by the prompt of one LLM call.
            completion_tokens: Tokens produced by the completion of that call.
        """
        self.tokens["prompt"] += prompt_tokens
        self.tokens["completion"] += completion_tokens
        self.tokens["total"] += prompt_tokens + completion_tokens

    @abstractmethod
    async def orchestrate(
        self, user_message: str, chat_history: list[dict], **kwargs: dict
    ) -> list[dict]:
        """Produce the response messages for *user_message*.

        Implemented by concrete orchestrators; returns a list of message
        dicts to be sent back to the client.
        """
        pass

    def call_content_safety_input(self, user_message: str) -> Optional[list]:
        """Screen the user question with content safety.

        Returns parsed replacement messages when harmful content was
        detected (the filtered text differs from the original), otherwise
        ``None`` to signal the question is safe to process.
        """
        logger.debug("Calling content safety with question")
        filtered_user_message = (
            self.content_safety_checker.validate_input_and_replace_if_harmful(
                user_message
            )
        )
        if user_message != filtered_user_message:
            logger.warning("Content safety detected harmful content in question")
            messages = self.output_parser.parse(
                question=user_message, answer=filtered_user_message
            )
            return messages
        return None

    def call_content_safety_output(
        self, user_message: str, answer: str
    ) -> Optional[list]:
        """Screen the generated answer with content safety.

        Returns parsed replacement messages when harmful content was
        detected in *answer*, otherwise ``None`` to signal the answer is
        safe to return as-is.
        """
        logger.debug("Calling content safety with answer")
        filtered_answer = (
            self.content_safety_checker.validate_output_and_replace_if_harmful(answer)
        )
        if answer != filtered_answer:
            logger.warning("Content safety detected harmful content in answer")
            messages = self.output_parser.parse(
                question=user_message, answer=filtered_answer
            )
            return messages
        return None

    async def handle_message(
        self,
        user_message: str,
        chat_history: list[dict],
        conversation_id: Optional[str],
        **kwargs: Optional[dict],
    ) -> list[dict]:
        """Orchestrate a user message and apply configured logging.

        Delegates to :meth:`orchestrate`, then — depending on config —
        emits token-consumption telemetry and logs the user interaction.

        Returns:
            The list of response message dicts from the orchestrator.
            (Return annotation fixed from ``dict``: ``orchestrate`` is
            declared to return ``list[dict]`` and its result is returned
            unchanged.)
        """
        result = await self.orchestrate(user_message, chat_history, **kwargs)
        if self.config.logging.log_tokens:
            custom_dimensions = {
                "conversation_id": conversation_id,
                "message_id": self.message_id,
                "prompt_tokens": self.tokens["prompt"],
                "completion_tokens": self.tokens["completion"],
                "total_tokens": self.tokens["total"],
            }
            logger.info("Token Consumption", extra=custom_dimensions)
        if self.config.logging.log_user_interactions:
            self.conversation_logger.log(
                messages=[
                    {
                        "role": "user",
                        "content": user_message,
                        "conversation_id": conversation_id,
                    }
                ]
                + result
            )
        return result