From ee2e130bd0255a450f516978de7b62af65b5115d Mon Sep 17 00:00:00 2001 From: Saurav Panda Date: Tue, 28 May 2024 00:51:51 -0400 Subject: [PATCH 1/2] fix: added model update functionality --- config.json | 1 + kaizen/llms/provider.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/config.json b/config.json index 6d80c99c..4f9830b1 100644 --- a/config.json +++ b/config.json @@ -1,6 +1,7 @@ { "language_model": { "provider": "litellm", + "model": "gpt-3.5-turbo-1106", "enable_observability_logging": true }, "github_app": { diff --git a/kaizen/llms/provider.py b/kaizen/llms/provider.py index 2780c2c0..17c22823 100644 --- a/kaizen/llms/provider.py +++ b/kaizen/llms/provider.py @@ -15,12 +15,12 @@ def __init__( max_tokens=DEFAULT_MAX_TOKENS, temperature=DEFAULT_TEMPERATURE, ): + CONFIG_DATA = ConfigData() + self.config = CONFIG_DATA.get_config_data() self.system_prompt = system_prompt self.model = model self.max_tokens = max_tokens self.temperature = temperature - CONFIG_DATA = ConfigData() - self.config = CONFIG_DATA.get_config_data() if self.config.get("language_model", {}).get( "enable_observability_logging", False ): @@ -33,6 +33,9 @@ def chat_completion(self, prompt, user: str = None): {"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}, ] + if "model" in self.config.get("language_model", {}): + self.model = self.config["language_model"]["model"] + response = litellm.completion( model=self.model, messages=messages, From b822cdb80f7764b907de57d5291a83ff60a19b29 Mon Sep 17 00:00:00 2001 From: Saurav Panda Date: Tue, 28 May 2024 00:54:02 -0400 Subject: [PATCH 2/2] fix: directly assigning config data --- kaizen/llms/provider.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/kaizen/llms/provider.py b/kaizen/llms/provider.py index 17c22823..319229fc 100644 --- a/kaizen/llms/provider.py +++ b/kaizen/llms/provider.py @@ -15,8 +15,7 @@ def __init__( max_tokens=DEFAULT_MAX_TOKENS, temperature=DEFAULT_TEMPERATURE, ): - CONFIG_DATA = ConfigData() - self.config = CONFIG_DATA.get_config_data() + self.config = ConfigData().get_config_data() self.system_prompt = system_prompt self.model = model self.max_tokens = max_tokens