15
15
from api .domain .config_pd import DOCUMENTS_PD , PROMPT_PD , SYSTEM_PROMPT_PD
16
16
from api .settings import Settings
17
17
18
- weave .init ("opentronsai/OpentronsAI-Phase-march-25" )
18
+ MessageType = Literal ["create" , "update" ]
19
+
20
+ weave .init ("opentronsai/OpentronsAI-Phase-May-23-25" )
19
21
# Module-level Settings instance; used below to look up the logger name.
settings : Settings = Settings ()
20
22
# Module logger from structlog's stdlib integration, named by settings.logger_name.
logger = structlog .stdlib .get_logger (settings .logger_name )
21
23
# Directory three levels above this file (parent of parent of the file's directory).
# Path(__file__) is already a Path, so no second Path() wrapper is needed.
ROOT_PATH: Path = Path(__file__).parent.parent.parent
25
27
class AnthropicPredict :
26
28
def __init__ (self , settings : Settings ) -> None :
27
29
self .settings : Settings = settings
30
+ self .max_tokens : int = 20000
28
31
self .client : Anthropic = Anthropic (api_key = settings .anthropic_api_key .get_secret_value ())
29
32
self .model_name : str = settings .anthropic_model_name
30
33
self .model_helper : str = settings .model_helper
@@ -176,7 +179,7 @@ def get_relevant_api_docs(self, query: str, user_id: str) -> str:
176
179
]
177
180
178
181
response = self .client .messages .create ( # type: ignore[call-overload]
179
- max_tokens = 2048 ,
182
+ max_tokens = 4096 ,
180
183
temperature = 0.0 ,
181
184
messages = msg ,
182
185
model = self .model_helper ,
@@ -188,16 +191,14 @@ def get_relevant_api_docs(self, query: str, user_id: str) -> str:
188
191
return response .content [0 ].text # type: ignore[no-any-return]
189
192
190
193
@tracer .wrap ()
191
- def _process_message (
192
- self , user_id : str , messages : List [MessageParam ], message_type : Literal ["create" , "update" ], max_tokens : int = 4096
193
- ) -> Message :
194
+ def _process_message (self , user_id : str , messages : List [MessageParam ], message_type : MessageType ) -> Message :
194
195
"""
195
196
Internal method to handle message processing with different system prompts.
196
197
For now, system prompt is the same.
197
198
"""
198
199
199
200
response : Message = self .client .messages .create ( # type: ignore[call-overload]
200
- max_tokens = max_tokens ,
201
+ max_tokens = self . max_tokens ,
201
202
messages = messages ,
202
203
model = self .model_name ,
203
204
system = self .system_prompt ,
@@ -219,7 +220,7 @@ def _process_message(
219
220
220
221
@tracer .wrap ()
221
222
def process_message (
222
- self , user_id : str , prompt : str , history : List [MessageParam ] | None = None , message_type : Literal [ "create" , "update" ] = "create"
223
+ self , user_id : str , prompt : str , history : List [MessageParam ] | None = None , message_type : MessageType = "create"
223
224
) -> str | None :
224
225
"""Unified method for creating and updating messages"""
225
226
try :
@@ -269,7 +270,7 @@ def process_message(
269
270
270
271
@tracer .wrap ()
271
272
def process_message_pd (
272
- self , user_id : str , prompt : str , history : List [MessageParam ] | None = None , message_type : Literal [ "create" , "update" ] = "create"
273
+ self , user_id : str , prompt : str , history : List [MessageParam ] | None = None , message_type : MessageType = "create"
273
274
) -> str | None :
274
275
"""return a partial json protocol"""
275
276
try :
@@ -281,7 +282,7 @@ def process_message_pd(
281
282
messages .append ({"role" : "user" , "content" : self .PROMPT_PD .format (USER_PROMPT = prompt )})
282
283
283
284
response : Message = self .client .messages .create (
284
- max_tokens = 20000 ,
285
+ max_tokens = self . max_tokens ,
285
286
messages = messages ,
286
287
model = self .model_name ,
287
288
system = cast (Iterable [TextBlockParam ], self .system_prompt_pd ),
0 commit comments