diff --git a/autogpt/config/config.py b/autogpt/config/config.py index 579f0c4e6f9..ae2f7bedc54 100644 --- a/autogpt/config/config.py +++ b/autogpt/config/config.py @@ -144,7 +144,18 @@ def validate_plugins(cls, p: AutoGPTPluginTemplate | Any): ), f"Plugins must subclass AutoGPTPluginTemplate; {p} is a template instance" return p - def get_azure_kwargs(self, model: str) -> dict[str, str]: + def get_openai_credentials(self, model: str) -> dict[str, str]: + credentials = { + "api_key": self.openai_api_key, + "api_base": self.openai_api_base, + "organization": self.openai_organization, + } + if self.use_azure: + azure_credentials = self.get_azure_credentials(model) + credentials.update(azure_credentials) + return credentials + + def get_azure_credentials(self, model: str) -> dict[str, str]: """Get the kwargs for the Azure API.""" # Fix --gpt3only and --gpt4only in combination with Azure diff --git a/autogpt/llm/utils/__init__.py b/autogpt/llm/utils/__init__.py index 74e88dc6751..ff485260ded 100644 --- a/autogpt/llm/utils/__init__.py +++ b/autogpt/llm/utils/__init__.py @@ -78,17 +78,14 @@ def create_text_completion( if temperature is None: temperature = config.temperature - if config.use_azure: - kwargs = config.get_azure_kwargs(model) - else: - kwargs = {"model": model} + kwargs = {"model": model} + kwargs.update(config.get_openai_credentials(model)) response = iopenai.create_text_completion( prompt=prompt, **kwargs, temperature=temperature, max_tokens=max_output_tokens, - api_key=config.openai_api_key, ) logger.debug(f"Response: {response}") @@ -150,9 +147,7 @@ def create_chat_completion( if message is not None: return message - chat_completion_kwargs["api_key"] = config.openai_api_key - if config.use_azure: - chat_completion_kwargs.update(config.get_azure_kwargs(model)) + chat_completion_kwargs.update(config.get_openai_credentials(model)) if functions: chat_completion_kwargs["functions"] = [ @@ -196,12 +191,7 @@ def check_model( config: Config, ) -> str: """Check if model is available for use. If not, return gpt-3.5-turbo.""" - openai_credentials = { - "api_key": config.openai_api_key, - } - if config.use_azure: - openai_credentials.update(config.get_azure_kwargs(model_name)) - + openai_credentials = config.get_openai_credentials(model_name) api_manager = ApiManager() models = api_manager.get_models(**openai_credentials) diff --git a/autogpt/memory/vector/utils.py b/autogpt/memory/vector/utils.py index 74438f28c56..eb69125666a 100644 --- a/autogpt/memory/vector/utils.py +++ b/autogpt/memory/vector/utils.py @@ -41,10 +41,8 @@ def get_embedding( input = [text.replace("\n", " ") for text in input] model = config.embedding_model - if config.use_azure: - kwargs = config.get_azure_kwargs(model) - else: - kwargs = {"model": model} + kwargs = {"model": model} + kwargs.update(config.get_openai_credentials(model)) logger.debug( f"Getting embedding{f's for {len(input)} inputs' if multiple else ''}" @@ -57,7 +55,6 @@ def get_embedding( embeddings = iopenai.create_embedding( input, **kwargs, - api_key=config.openai_api_key, ).data if not multiple: diff --git a/autogpt/processing/text.py b/autogpt/processing/text.py index ddb64df1887..faaa50e000d 100644 --- a/autogpt/processing/text.py +++ b/autogpt/processing/text.py @@ -137,7 +137,6 @@ def summarize_text( logger.info(f"Summarized {len(chunks)} chunks") summary, _ = summarize_text("\n\n".join(summaries), config) - return summary.strip(), [ (summaries[i], chunks[i][0]) for i in range(0, len(chunks)) ] diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index b441aa9484d..7abbfcd52fd 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -174,18 +174,32 @@ def test_azure_config(config: Config, workspace: Workspace) -> None: fast_llm = config.fast_llm smart_llm = config.smart_llm - assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "FAST-LLM_ID" - assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "SMART-LLM_ID" + assert ( + config.get_azure_credentials(config.fast_llm)["deployment_id"] == "FAST-LLM_ID" + ) + assert ( + config.get_azure_credentials(config.smart_llm)["deployment_id"] + == "SMART-LLM_ID" + ) # Emulate --gpt4only config.fast_llm = smart_llm - assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "SMART-LLM_ID" - assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "SMART-LLM_ID" + assert ( + config.get_azure_credentials(config.fast_llm)["deployment_id"] == "SMART-LLM_ID" + ) + assert ( + config.get_azure_credentials(config.smart_llm)["deployment_id"] + == "SMART-LLM_ID" + ) # Emulate --gpt3only config.fast_llm = config.smart_llm = fast_llm - assert config.get_azure_kwargs(config.fast_llm)["deployment_id"] == "FAST-LLM_ID" - assert config.get_azure_kwargs(config.smart_llm)["deployment_id"] == "FAST-LLM_ID" + assert ( + config.get_azure_credentials(config.fast_llm)["deployment_id"] == "FAST-LLM_ID" + ) + assert ( + config.get_azure_credentials(config.smart_llm)["deployment_id"] == "FAST-LLM_ID" + ) del os.environ["USE_AZURE"] del os.environ["AZURE_CONFIG_FILE"]