diff --git a/mindsql/llms/open_ai.py b/mindsql/llms/open_ai.py index b9bd4f9..923c1a9 100644 --- a/mindsql/llms/open_ai.py +++ b/mindsql/llms/open_ai.py @@ -26,6 +26,8 @@ def __init__(self, config=None, client=None): if 'api_key' not in config: raise ValueError(OPENAI_VALUE_ERROR) api_key = config.pop('api_key') + if 'model' in config: + self.model = config.pop('model') self.client = OpenAI(api_key=api_key, **config) def system_message(self, message: str) -> any: @@ -79,7 +81,7 @@ def invoke(self, prompt, **kwargs) -> str: if prompt is None or len(prompt) == 0: raise Exception(PROMPT_EMPTY_EXCEPTION) - model = self.config.get("model", "gpt-3.5-turbo") + model = self.model if self.model else "gpt-5-2025-08-07" temperature = kwargs.get("temperature", 0.1) max_tokens = kwargs.get("max_tokens", 500) response = self.client.chat.completions.create(model=model, messages=[{"role": "user", "content": prompt}],