diff --git a/config.json b/config.json index 6d80c99c..4f9830b1 100644 --- a/config.json +++ b/config.json @@ -1,6 +1,7 @@ { "language_model": { "provider": "litellm", + "model": "gpt-3.5-turbo-1106", "enable_observability_logging": true }, "github_app": { diff --git a/kaizen/llms/provider.py b/kaizen/llms/provider.py index 2780c2c0..319229fc 100644 --- a/kaizen/llms/provider.py +++ b/kaizen/llms/provider.py @@ -15,12 +15,11 @@ def __init__( max_tokens=DEFAULT_MAX_TOKENS, temperature=DEFAULT_TEMPERATURE, ): + self.config = ConfigData().get_config_data() self.system_prompt = system_prompt self.model = model self.max_tokens = max_tokens self.temperature = temperature - CONFIG_DATA = ConfigData() - self.config = CONFIG_DATA.get_config_data() if self.config.get("language_model", {}).get( "enable_observability_logging", False ): @@ -33,6 +32,9 @@ def chat_completion(self, prompt, user: str = None): {"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}, ] + if "model" in self.config.get("language_model", {}): + self.model = self.config["language_model"]["model"] + response = litellm.completion( model=self.model, messages=messages,