@@ -1,7 +1,4 @@
-__author__ = "lucaspavanelli"
-
-"""
-Copyright 2024 The aiXplain SDK authors
+"""Copyright 2024 The aiXplain SDK authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -20,6 +17,8 @@
 Description:
     Large Language Model Class
 """
+
+__author__ = "lucaspavanelli"
 import time
 import logging
 import traceback
@@ -63,7 +62,7 @@ def __init__(
         function: Optional[Function] = None,
         is_subscribed: bool = False,
         cost: Optional[Dict] = None,
-        temperature: float = 0.001,
+        temperature: Optional[float] = None,
         function_type: Optional[FunctionType] = FunctionType.AI,
         **additional_info,
     ) -> None:
@@ -79,14 +78,16 @@ def __init__(
             function (Function, optional): Model's AI function. Must be Function.TEXT_GENERATION.
             is_subscribed (bool, optional): Whether the user is subscribed. Defaults to False.
             cost (Dict, optional): Cost of the model. Defaults to None.
-            temperature (float, optional): Default temperature for text generation. Defaults to 0.001.
+            temperature (Optional[float], optional): Default temperature for text generation. Defaults to None.
             function_type (FunctionType, optional): Type of the function. Defaults to FunctionType.AI.
             **additional_info: Any additional model info to be saved.
 
         Raises:
             AssertionError: If function is not Function.TEXT_GENERATION.
         """
-        assert function == Function.TEXT_GENERATION, "LLM only supports large language models (i.e. text generation function)"
+        assert function == Function.TEXT_GENERATION, (
+            "LLM only supports large language models (i.e. text generation function)"
+        )
         super().__init__(
             id=id,
             name=name,
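With this change, a model constructed without an explicit `temperature` keeps `None` instead of a hard-coded `0.001`, so the provider's own default applies at request time. A minimal usage sketch; the import paths and the model id are illustrative assumptions, not part of this diff:

```python
from aixplain.enums import Function
from aixplain.modules.model.llm_model import LLM  # assumed module path

# Omitting `temperature` no longer forces 0.001; the attribute stays None
# and the provider-side default is used when the request is built.
llm = LLM(id="LLM_ID", name="my-llm", function=Function.TEXT_GENERATION)
assert llm.temperature is None
```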
@@ -112,12 +113,13 @@ def run(
         history: Optional[List[Dict]] = None,
         temperature: Optional[float] = None,
         max_tokens: int = 128,
-        top_p: float = 1.0,
+        top_p: Optional[float] = None,
         name: Text = "model_process",
         timeout: float = 300,
         parameters: Optional[Dict] = None,
         wait_time: float = 0.5,
         stream: bool = False,
+        response_format: Optional[Text] = None,
     ) -> Union[ModelResponse, ModelResponseStreamer]:
         """Run the LLM model synchronously to generate text.
 
@@ -138,8 +140,8 @@ def run(
                 Defaults to None.
             max_tokens (int, optional): Maximum number of tokens to generate.
                 Defaults to 128.
-            top_p (float, optional): Nucleus sampling parameter. Only tokens with cumulative
-                probability < top_p are considered. Defaults to 1.0.
+            top_p (Optional[float], optional): Nucleus sampling parameter. Only tokens with cumulative
+                probability < top_p are considered. Defaults to None.
             name (Text, optional): Identifier for this model run. Useful for logging.
                 Defaults to "model_process".
             timeout (float, optional): Maximum time in seconds to wait for completion.
@@ -150,6 +152,8 @@ def run(
                 Defaults to 0.5.
             stream (bool, optional): Whether to stream the model's output tokens.
                 Defaults to False.
+            response_format (Optional[Union[str, dict, BaseModel]], optional):
+                Specifies the desired output structure or format of the model's response.
 
         Returns:
             Union[ModelResponse, ModelResponseStreamer]: If stream=False, returns a ModelResponse
@@ -166,9 +170,13 @@ def run(
         parameters.setdefault("context", context)
         parameters.setdefault("prompt", prompt)
         parameters.setdefault("history", history)
-        parameters.setdefault("temperature", temperature if temperature is not None else self.temperature)
+        temp_value = temperature if temperature is not None else self.temperature
+        if temp_value is not None:
+            parameters.setdefault("temperature", temp_value)
         parameters.setdefault("max_tokens", max_tokens)
-        parameters.setdefault("top_p", top_p)
+        if top_p is not None:
+            parameters.setdefault("top_p", top_p)
+        parameters.setdefault("response_format", response_format)
 
         if stream:
             return self.run_stream(data=data, parameters=parameters)
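The effect of the conditional `setdefault` calls above is that `temperature` and `top_p` are only sent when a value is actually available, while `response_format` is always forwarded (possibly as `None`). A standalone sketch of that assembly logic, using a hypothetical helper rather than SDK code, to show which keys end up in the request:

```python
def build_parameters(temperature=None, self_temperature=None, top_p=None,
                     max_tokens=128, response_format=None, parameters=None):
    """Mirror of the parameter assembly in run()/run_async() above."""
    parameters = {} if parameters is None else dict(parameters)
    temp_value = temperature if temperature is not None else self_temperature
    if temp_value is not None:
        parameters.setdefault("temperature", temp_value)
    parameters.setdefault("max_tokens", max_tokens)
    if top_p is not None:
        parameters.setdefault("top_p", top_p)
    parameters.setdefault("response_format", response_format)
    return parameters

print(build_parameters())
# {'max_tokens': 128, 'response_format': None}
print(build_parameters(temperature=0.2, top_p=0.9, response_format="json_object"))
# {'temperature': 0.2, 'max_tokens': 128, 'top_p': 0.9, 'response_format': 'json_object'}
```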
@@ -210,9 +218,10 @@ def run_async(
         history: Optional[List[Dict]] = None,
         temperature: Optional[float] = None,
         max_tokens: int = 128,
-        top_p: float = 1.0,
+        top_p: Optional[float] = None,
         name: Text = "model_process",
         parameters: Optional[Dict] = None,
+        response_format: Optional[Text] = None,
     ) -> ModelResponse:
         """Run the LLM model asynchronously to generate text.
 
@@ -233,12 +242,14 @@ def run_async(
                 Defaults to None.
             max_tokens (int, optional): Maximum number of tokens to generate.
                 Defaults to 128.
-            top_p (float, optional): Nucleus sampling parameter. Only tokens with cumulative
-                probability < top_p are considered. Defaults to 1.0.
+            top_p (Optional[float], optional): Nucleus sampling parameter. Only tokens with cumulative
+                probability < top_p are considered. Defaults to None.
             name (Text, optional): Identifier for this model run. Useful for logging.
                 Defaults to "model_process".
             parameters (Optional[Dict], optional): Additional model-specific parameters.
                 Defaults to None.
+            response_format (Optional[Text], optional): Desired output format specification.
+                Defaults to None.
 
         Returns:
             ModelResponse: A response object containing:
@@ -261,9 +272,13 @@ def run_async(
         parameters.setdefault("context", context)
         parameters.setdefault("prompt", prompt)
         parameters.setdefault("history", history)
-        parameters.setdefault("temperature", temperature if temperature is not None else self.temperature)
+        temp_value = temperature if temperature is not None else self.temperature
+        if temp_value is not None:
+            parameters.setdefault("temperature", temp_value)
         parameters.setdefault("max_tokens", max_tokens)
-        parameters.setdefault("top_p", top_p)
+        if top_p is not None:
+            parameters.setdefault("top_p", top_p)
+        parameters.setdefault("response_format", response_format)
         payload = build_payload(data=data, parameters=parameters)
         response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key)
         return ModelResponse(
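A hypothetical call sketch for the async path, reusing the `llm` instance from the earlier example; the prompt text, the `response_format` value, and the polling step are illustrative assumptions, not shown in this diff:

```python
response = llm.run_async(
    data="Summarize the release notes in two sentences.",
    max_tokens=256,
    response_format="json_object",  # forwarded as parameters["response_format"]
)
# run_async returns immediately; the final result is fetched by polling
# (assumes the base model's poll helper and that the response carries the poll URL).
result = llm.poll(response.url)
```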