feat(agent/core): Add Anthropic Claude 3 support #7085

Merged
merged 35 commits into master from reinier/open-591-add-claude-3-support on May 4, 2024
Changes from all commits (35 commits)
76e9fd1
refactor(agent/core): Tweak `model_providers.schema`
Pwuts Apr 16, 2024
7ecc459
feat(agent/core): Add `AnthropicProvider`
Pwuts Apr 16, 2024
95bbda0
Merge branch 'master' into reinier/open-591-add-claude-3-support
Pwuts Apr 18, 2024
001fe75
feat(agent/core): Allow zero-argument instantiation of `AnthropicProv…
Pwuts Apr 18, 2024
8beaedd
feat(agent/core): Add `max_output_tokens` parameter to `create_chat_c…
Pwuts Apr 18, 2024
d5eb79f
refactor(agent): Add `ChatModelProvider.get_available_models()` and r…
Pwuts Apr 20, 2024
0594653
refactor(agent/core): Allow `ModelProviderBudget` zero-argument insta…
Pwuts Apr 20, 2024
651c99d
refactor(agent/core): Add shared attributes and constructor to `Model…
Pwuts Apr 20, 2024
2da125d
fix(agent): Change `max_tokens` to `max_output_tokens` in `create_cha…
Pwuts Apr 20, 2024
9d38dbd
feat(agent): Allow use of any available LLM provider through `MultiPr…
Pwuts Apr 20, 2024
70c97ca
feat(agent): Enable use of tool calling API(s) by default
Pwuts Apr 20, 2024
dcd1685
fix(agent/core): Make retry mechanism of `AnthropicProvider` specific…
Pwuts Apr 20, 2024
2aa4ca5
fix(agent/core): Set `retries_per_request` to 7 by default
Pwuts Apr 20, 2024
a60854e
Merge branch 'master' into reinier/open-591-add-claude-3-support
Pwuts Apr 22, 2024
1ae07a5
Merge branch 'master' into reinier/open-591-add-claude-3-support
Pwuts Apr 23, 2024
92ff5a4
Merge branch 'master' into reinier/open-591-add-claude-3-support
Pwuts Apr 23, 2024
933ec93
Merge branch 'master' into reinier/open-591-add-claude-3-support
Pwuts Apr 24, 2024
72d0248
Adhere to Anthropic message schema in `AnthropicProvider` parse-fix m…
Pwuts Apr 24, 2024
e492258
feat(agent/core): Validate function call arguments in `create_chat_co…
Pwuts Apr 25, 2024
d5d8bfc
Merge branch 'master' into reinier/open-591-add-claude-3-support
Pwuts Apr 27, 2024
a62ec0f
Add logic to `AnthropicProvider` to merge prefill message into response
Pwuts Apr 30, 2024
6503b9e
Merge branch 'master' into reinier/open-591-add-claude-3-support
Pwuts May 1, 2024
e65b57f
feat(agent): Formalize pre-filling as a global feature
Pwuts May 1, 2024
b32778b
Rename `get_openai_command_specs` to `function_specs_from_commands`
Pwuts May 1, 2024
fd067a0
Fix clock message format
Pwuts May 1, 2024
c10fafa
Remove problematic "final instruction message" from `OneShot` prompt
Pwuts May 1, 2024
338986d
Remove unused imports
Pwuts May 1, 2024
32fe727
Remove unused import in benchmarks.py
Pwuts May 2, 2024
d2a9be0
Fix pre-fill interaction with parse-fix mechanism
Pwuts May 2, 2024
2e197eb
Merge branch 'master' into reinier/open-591-add-claude-3-support
Pwuts May 2, 2024
207957b
Revert "Remove problematic "final instruction message" from `OneShot`…
Pwuts May 2, 2024
94985ef
Implement mechanism to merge subsequent user messages in `AnthropicPr…
Pwuts May 2, 2024
1ca8d20
Revert "feat(agent): Enable use of tool calling API(s) by default"
Pwuts May 2, 2024
bb4f8a2
Add `ANTHROPIC_API_KEY` to .env.template and docs
Pwuts May 3, 2024
0f778e5
Merge branch 'master' into reinier/open-591-add-claude-3-support
Pwuts May 4, 2024
7 changes: 5 additions & 2 deletions autogpts/autogpt/.env.template
@@ -2,8 +2,11 @@
### AutoGPT - GENERAL SETTINGS
################################################################################

## OPENAI_API_KEY - OpenAI API Key (Example: my-openai-api-key)
OPENAI_API_KEY=your-openai-api-key
## OPENAI_API_KEY - OpenAI API Key (Example: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
# OPENAI_API_KEY=

## ANTHROPIC_API_KEY - Anthropic API Key (Example: sk-ant-api03-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
# ANTHROPIC_API_KEY=

## TELEMETRY_OPT_IN - Share telemetry on errors and other issues with the AutoGPT team, e.g. through Sentry.
## This helps us to spot and solve problems earlier & faster. (Default: DISABLED)
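The new `ANTHROPIC_API_KEY` entry is read from the environment the same way the existing OpenAI key is. As a rough illustration only (the class and method names below are assumptions, not AutoGPT's actual credential classes), resolving the key at startup could look like this:

```python
import os

# Hypothetical sketch: resolve the Anthropic key from .env / the process
# environment, mirroring how the OpenAI key is handled. Names are illustrative.
class AnthropicCredentials:
    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    @classmethod
    def from_env(cls) -> "AnthropicCredentials":
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise RuntimeError(
                "ANTHROPIC_API_KEY is not set; add it to .env or the environment."
            )
        return cls(api_key=api_key)
```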
10 changes: 2 additions & 8 deletions autogpts/autogpt/agbenchmark_config/benchmarks.py
@@ -5,8 +5,7 @@

from autogpt.agent_manager.agent_manager import AgentManager
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.agents.prompt_strategies.one_shot import OneShotAgentPromptStrategy
from autogpt.app.main import _configure_openai_provider, run_interaction_loop
from autogpt.app.main import _configure_llm_provider, run_interaction_loop
from autogpt.config import AIProfile, ConfigBuilder
from autogpt.file_storage import FileStorageBackendName, get_storage
from autogpt.logs.config import configure_logging
@@ -38,10 +37,6 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
ai_goals=[task],
)

agent_prompt_config = OneShotAgentPromptStrategy.default_configuration.copy(
deep=True
)
agent_prompt_config.use_functions_api = config.openai_functions
agent_settings = AgentSettings(
name=Agent.default_settings.name,
agent_id=AgentManager.generate_id("AutoGPT-benchmark"),
@@ -53,7 +48,6 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
allow_fs_access=not config.restrict_to_workspace,
use_functions_api=config.openai_functions,
),
prompt_config=agent_prompt_config,
history=Agent.default_settings.history.copy(deep=True),
)

@@ -66,7 +60,7 @@

agent = Agent(
settings=agent_settings,
llm_provider=_configure_openai_provider(config),
llm_provider=_configure_llm_provider(config),
file_storage=file_storage,
legacy_config=config,
)
43 changes: 10 additions & 33 deletions autogpts/autogpt/autogpt/agents/agent.py
@@ -19,15 +19,14 @@
from autogpt.core.configuration import Configurable
from autogpt.core.prompting import ChatPrompt
from autogpt.core.resource.model_providers import (
AssistantChatMessage,
AssistantFunctionCall,
ChatMessage,
ChatModelProvider,
ChatModelResponse,
)
from autogpt.core.runner.client_lib.logging.helpers import dump_prompt
from autogpt.file_storage.base import FileStorage
from autogpt.llm.providers.openai import get_openai_command_specs
from autogpt.llm.providers.openai import function_specs_from_commands
from autogpt.logs.log_cycle import (
CURRENT_CONTEXT_FILE_NAME,
NEXT_ACTION_FILE_NAME,
@@ -46,7 +45,6 @@
AgentException,
AgentTerminated,
CommandExecutionError,
InvalidArgumentError,
UnknownCommandError,
)

@@ -104,7 +102,11 @@ def __init__(
self.ai_profile = settings.ai_profile
self.directives = settings.directives
prompt_config = OneShotAgentPromptStrategy.default_configuration.copy(deep=True)
prompt_config.use_functions_api = settings.config.use_functions_api
prompt_config.use_functions_api = (
settings.config.use_functions_api
# Anthropic currently doesn't support tools + prefilling :(
and self.llm.provider_name != "anthropic"
Comment on lines +105 to +108


Suggestion: The conditional logic for setting use_functions_api in the OneShotAgentPromptStrategy configuration is based on the provider name being "anthropic". This is a string comparison that should be case-insensitive to avoid potential bugs due to case variations. [best practice]

Suggested change
- prompt_config.use_functions_api = (
-     settings.config.use_functions_api
-     # Anthropic currently doesn't support tools + prefilling :(
-     and self.llm.provider_name != "anthropic"
+ prompt_config.use_functions_api = (
+     settings.config.use_functions_api
+     and self.llm.provider_name.lower() != "anthropic"
+ )

)
self.prompt_strategy = OneShotAgentPromptStrategy(prompt_config, logger)
self.commands: list[Command] = []

@@ -172,7 +174,7 @@ async def propose_action(self) -> OneShotAgentActionProposal:
task=self.state.task,
ai_profile=self.state.ai_profile,
ai_directives=directives,
commands=get_openai_command_specs(self.commands),
commands=function_specs_from_commands(self.commands),
include_os_info=self.legacy_config.execute_local_commands,
)

@@ -202,12 +204,9 @@ async def complete_and_parse(
] = await self.llm_provider.create_chat_completion(
prompt.messages,
model_name=self.llm.name,
completion_parser=self.parse_and_validate_response,
functions=(
get_openai_command_specs(self.commands)
if self.config.use_functions_api
else []
),
completion_parser=self.prompt_strategy.parse_response_content,
functions=prompt.functions,
prefill_response=prompt.prefill_response,
)
result = response.parsed_result

@@ -223,28 +222,6 @@

return result

def parse_and_validate_response(
self, llm_response: AssistantChatMessage
) -> OneShotAgentActionProposal:
parsed_response = self.prompt_strategy.parse_response_content(llm_response)

# Validate command arguments
command_name = parsed_response.use_tool.name
command = self._get_command(command_name)
if arg_errors := command.validate_args(parsed_response.use_tool.arguments)[1]:
fmt_errors = [
f"{'.'.join(str(p) for p in f.path)}: {f.message}"
if f.path
else f.message
for f in arg_errors
]
raise InvalidArgumentError(
f"The set of arguments supplied for {command_name} is invalid:\n"
+ "\n".join(fmt_errors)
)

return parsed_response

async def execute(
self,
proposal: OneShotAgentActionProposal,
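The key change in this file is that the native tool-calling API is now gated on the provider: per the code comment above, Anthropic did not support combining tool use with response prefilling at the time. A minimal sketch of that gating logic, folding in the reviewer's case-insensitivity suggestion (provider names are assumptions):

```python
def resolve_use_functions_api(configured: bool, provider_name: str) -> bool:
    """Return whether the native tool-calling API should be used.

    Tools are only used when enabled in config AND the provider supports them
    together with response prefilling; per the PR comment, Anthropic did not
    support tools + prefilling at the time, so it falls back to JSON prompting.
    """
    return configured and provider_name.lower() != "anthropic"


# Example: with OpenAI the configured flag passes through; with Anthropic it is forced off.
assert resolve_use_functions_api(True, "openai") is True
assert resolve_use_functions_api(True, "Anthropic") is False
```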
15 changes: 8 additions & 7 deletions autogpts/autogpt/autogpt/agents/base.py
@@ -39,11 +39,12 @@
SystemSettings,
UserConfigurable,
)
from autogpt.core.resource.model_providers import AssistantFunctionCall
from autogpt.core.resource.model_providers.openai import (
OPEN_AI_CHAT_MODELS,
OpenAIModelName,
from autogpt.core.resource.model_providers import (
CHAT_MODELS,
AssistantFunctionCall,
ModelName,
)
from autogpt.core.resource.model_providers.openai import OpenAIModelName
from autogpt.models.utils import ModelWithSummary
from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT

@@ -56,8 +57,8 @@
class BaseAgentConfiguration(SystemConfiguration):
allow_fs_access: bool = UserConfigurable(default=False)

fast_llm: OpenAIModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
smart_llm: OpenAIModelName = UserConfigurable(default=OpenAIModelName.GPT4)
fast_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
smart_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT4)
use_functions_api: bool = UserConfigurable(default=False)

default_cycle_instruction: str = DEFAULT_TRIGGERING_PROMPT
@@ -174,7 +175,7 @@ def llm(self) -> ChatModelInfo:
llm_name = (
self.config.smart_llm if self.config.big_brain else self.config.fast_llm
)
return OPEN_AI_CHAT_MODELS[llm_name]
return CHAT_MODELS[llm_name]

@property
def send_token_limit(self) -> int:
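`fast_llm`/`smart_llm` now accept any `ModelName`, and `self.llm` is resolved through a provider-agnostic `CHAT_MODELS` registry instead of the OpenAI-only table. A sketch of what such a registry amounts to (the table contents and the Anthropic table name are illustrative assumptions; the real module layout may differ):

```python
# Hypothetical sketch: merge per-provider model-info tables so that
# CHAT_MODELS[model_name] works for both OpenAI and Anthropic model names.
OPEN_AI_CHAT_MODELS = {
    "gpt-4-turbo": {"provider": "openai", "context_window": 128_000},
}
ANTHROPIC_CHAT_MODELS = {
    "claude-3-opus-20240229": {"provider": "anthropic", "context_window": 200_000},
}

CHAT_MODELS = {**OPEN_AI_CHAT_MODELS, **ANTHROPIC_CHAT_MODELS}

# BaseAgent.llm can then look up model info regardless of provider:
llm_info = CHAT_MODELS["claude-3-opus-20240229"]
```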
42 changes: 28 additions & 14 deletions autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py
@@ -122,7 +122,7 @@
1. System prompt
3. `cycle_instruction`
"""
system_prompt = self.build_system_prompt(
system_prompt, response_prefill = self.build_system_prompt(

ai_profile=ai_profile,
ai_directives=ai_directives,
commands=commands,
@@ -131,24 +131,34 @@

final_instruction_msg = ChatMessage.user(self.config.choose_action_instruction)

prompt = ChatPrompt(
return ChatPrompt(

messages=[
ChatMessage.system(system_prompt),
ChatMessage.user(f'"""{task}"""'),
*messages,
final_instruction_msg,
],
prefill_response=response_prefill,
functions=commands if self.config.use_functions_api else [],
)

return prompt

def build_system_prompt(
self,
ai_profile: AIProfile,
ai_directives: AIDirectives,
commands: list[CompletionModelFunction],
include_os_info: bool,
) -> str:
) -> tuple[str, str]:
"""
Builds the system prompt.

Returns:
str: The system prompt body
str: The desired start for the LLM's response; used to steer the output
"""
response_fmt_instruction, response_prefill = self.response_format_instruction(

self.config.use_functions_api
)
system_prompt_parts = (
self._generate_intro_prompt(ai_profile)
+ (self._generate_os_info() if include_os_info else [])
@@ -169,16 +179,16 @@
" in the next message. Your job is to complete the task while following"
" your directives as given above, and terminate when your task is done."
]
+ [
"## RESPONSE FORMAT\n"
+ self.response_format_instruction(self.config.use_functions_api)
]
+ ["## RESPONSE FORMAT\n" + response_fmt_instruction]
)

# Join non-empty parts together into paragraph format
return "\n\n".join(filter(None, system_prompt_parts)).strip("\n")
return (

"\n\n".join(filter(None, system_prompt_parts)).strip("\n"),
response_prefill,
)
Comment on lines +187 to +189


Suggestion: The method validate_call in ToolResultMessage class raises a ValueError if the function call name does not match the expected name. This exception handling could be enhanced by providing more detailed error information, such as the expected function name and the provided arguments, to aid in debugging. [enhancement]

Suggested change
-     "\n\n".join(filter(None, system_prompt_parts)).strip("\n"),
-     response_prefill,
- )
+ raise ValueError(
+     f"Can't validate function call '{function_call.name}' with arguments {function_call.arguments} using expected function '{self.name}' spec"
+ )


def response_format_instruction(self, use_functions_api: bool) -> str:
def response_format_instruction(self, use_functions_api: bool) -> tuple[str, str]:
response_schema = self.response_schema.copy(deep=True)
if (
use_functions_api
@@ -193,11 +203,15 @@
"\n",
response_schema.to_typescript_object_interface(_RESPONSE_INTERFACE_NAME),
)
response_prefill = f'{{\n "{list(response_schema.properties.keys())[0]}":'


Suggestion: The method response_format_instruction returns a tuple that includes a string and a prefill response. The prefill response is constructed from the first key of response_schema.properties. This could lead to an IndexError if response_schema.properties is empty. To prevent this, add a check to ensure that response_schema.properties is not empty before accessing its keys. [possible issue]

Suggested change
- response_prefill = f'{{\n    "{list(response_schema.properties.keys())[0]}":'
+ if not response_schema.properties:
+     raise ValueError("No properties found in response schema.")
+ response_prefill = f'{{\n    "{list(response_schema.properties.keys())[0]}":'


return (
f"YOU MUST ALWAYS RESPOND WITH A JSON OBJECT OF THE FOLLOWING TYPE:\n"
f"{response_format}"
+ ("\n\nYOU MUST ALSO INVOKE A TOOL!" if use_functions_api else "")
(
f"YOU MUST ALWAYS RESPOND WITH A JSON OBJECT OF THE FOLLOWING TYPE:\n"
f"{response_format}"
+ ("\n\nYOU MUST ALSO INVOKE A TOOL!" if use_functions_api else "")
),
response_prefill,
)

def _generate_intro_prompt(self, ai_profile: AIProfile) -> list[str]:
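The prompt strategy now returns a `response_prefill` (the opening of the expected JSON object, derived from the first key of the response schema) alongside the system prompt and passes it to the provider via `prefill_response`. The general technique, sketched against an Anthropic-style messages API: the prefill is sent as the start of the assistant turn, and the provider glues it back onto the completion before parsing. Helper names below are illustrative, not the PR's actual implementation.

```python
import json

def apply_prefill(messages: list[dict], prefill: str) -> list[dict]:
    # Anthropic-style chat APIs continue generating from a trailing assistant
    # message, so the prefill steers the model straight into the expected JSON.
    return messages + [{"role": "assistant", "content": prefill}]

def merge_prefill(prefill: str, completion_text: str) -> str:
    # The API response only contains the continuation, so the prefill has to be
    # re-attached before the JSON parser sees it.
    return prefill + completion_text

# Example: if the response schema's first property is "thoughts", the prefill
# is '{\n  "thoughts":' and the merged text parses as a complete object.
prefill = '{\n  "thoughts":'
continuation = ' {"text": "planning next step"}\n}'
print(json.loads(merge_prefill(prefill, continuation)))
```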
25 changes: 11 additions & 14 deletions autogpts/autogpt/autogpt/app/agent_protocol_server.py
@@ -34,7 +34,6 @@
from autogpt.app.utils import is_port_free
from autogpt.config import Config
from autogpt.core.resource.model_providers import ChatModelProvider, ModelProviderBudget
from autogpt.core.resource.model_providers.openai import OpenAIProvider
from autogpt.file_storage import FileStorage
from autogpt.models.action_history import ActionErrorResult, ActionSuccessResult
from autogpt.utils.exceptions import AgentFinished
@@ -464,20 +463,18 @@
if task.additional_input and (user_id := task.additional_input.get("user_id")):
_extra_request_headers["AutoGPT-UserID"] = user_id

task_llm_provider = None
if isinstance(self.llm_provider, OpenAIProvider):
settings = self.llm_provider._settings.copy()
settings.budget = task_llm_budget
settings.configuration = task_llm_provider_config # type: ignore
task_llm_provider = OpenAIProvider(
settings=settings,
logger=logger.getChild(f"Task-{task.task_id}_OpenAIProvider"),
)

if task_llm_provider and task_llm_provider._budget:
self._task_budgets[task.task_id] = task_llm_provider._budget
settings = self.llm_provider._settings.copy()
settings.budget = task_llm_budget
settings.configuration = task_llm_provider_config
task_llm_provider = self.llm_provider.__class__(

settings=settings,
logger=logger.getChild(
f"Task-{task.task_id}_{self.llm_provider.__class__.__name__}"
),
)
self._task_budgets[task.task_id] = task_llm_provider._budget # type: ignore


return task_llm_provider or self.llm_provider
return task_llm_provider

Comment on lines +466 to +477


Suggestion: Refactor the method _get_task_llm_provider to separate concerns, improving readability and maintainability by extracting the settings configuration into a separate method. [maintainability]

Suggested change
-     settings = self.llm_provider._settings.copy()
-     settings.budget = task_llm_budget
-     settings.configuration = task_llm_provider_config
-     task_llm_provider = self.llm_provider.__class__(
-         settings=settings,
-         logger=logger.getChild(
-             f"Task-{task.task_id}_{self.llm_provider.__class__.__name__}"
-         ),
-     )
-     self._task_budgets[task.task_id] = task_llm_provider._budget  # type: ignore
-     return task_llm_provider or self.llm_provider
-     return task_llm_provider
+ def _configure_task_settings(self, task: Task, logger: logging.Logger) -> ModelProviderSettings:
+     settings = self.llm_provider._settings.copy()
+     settings.budget = task_llm_budget
+     settings.configuration = task_llm_provider_config
+     return settings
+ def _get_task_llm_provider(self, task: Task, logger: logging.Logger) -> ModelProvider:
+     if task.additional_input and (user_id := task.additional_input.get("user_id")):
+         _extra_request_headers["AutoGPT-UserID"] = user_id
+     settings = self._configure_task_settings(task, logger)
+     task_llm_provider = self.llm_provider.__class__(
+         settings=settings,
+         logger=logger.getChild(
+             f"Task-{task.task_id}_{self.llm_provider.__class__.__name__}"
+         ),
+     )
+     self._task_budgets[task.task_id] = task_llm_provider._budget  # type: ignore
+     return task_llm_provider



def task_agent_id(task_id: str | int) -> str:
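`_get_task_llm_provider` now builds a task-scoped provider of whatever class the configured provider is (previously this only worked for `OpenAIProvider`), each instance carrying its own budget so per-task spend can be reported. A rough sketch of the idea, with the `_settings`/`_budget` attribute names taken from the diff but otherwise simplified:

```python
import logging

def make_task_provider(llm_provider, task_id: str, task_budgets: dict):
    # Copy the configured provider's settings and give the copy its own budget
    # object, so usage for this task is tracked separately from other tasks.
    settings = llm_provider._settings.copy()
    settings.budget = task_budgets.setdefault(task_id, settings.budget)

    provider_cls = llm_provider.__class__  # works for OpenAI, Anthropic, MultiProvider, ...
    return provider_cls(
        settings=settings,
        logger=logging.getLogger(f"Task-{task_id}_{provider_cls.__name__}"),
    )
```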
10 changes: 5 additions & 5 deletions autogpts/autogpt/autogpt/app/configurator.py
@@ -10,7 +10,7 @@

from autogpt.config import Config
from autogpt.config.config import GPT_3_MODEL, GPT_4_MODEL
from autogpt.core.resource.model_providers.openai import OpenAIModelName, OpenAIProvider
from autogpt.core.resource.model_providers import ModelName, MultiProvider
from autogpt.logs.helpers import request_user_double_check
from autogpt.memory.vector import get_supported_memory_backends
from autogpt.utils import utils
@@ -150,11 +150,11 @@ async def apply_overrides_to_config(


async def check_model(
model_name: OpenAIModelName, model_type: Literal["smart_llm", "fast_llm"]
) -> OpenAIModelName:
model_name: ModelName, model_type: Literal["smart_llm", "fast_llm"]
) -> ModelName:
"""Check if model is available for use. If not, return gpt-3.5-turbo."""
openai = OpenAIProvider()
models = await openai.get_available_models()
multi_provider = MultiProvider()
models = await multi_provider.get_available_models()

if any(model_name == m.name for m in models):
return model_name
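`check_model` now queries availability through a `MultiProvider` rather than instantiating `OpenAIProvider` directly, so Anthropic models can pass the check too. A hedged sketch of the availability check and fallback (the real function returns AutoGPT's `GPT_3_MODEL` constant and uses a logger rather than printing):

```python
async def check_model(model_name: str, model_type: str, provider) -> str:
    # `provider` is anything exposing an async get_available_models() that
    # yields objects with a `.name`, e.g. a MultiProvider-like aggregator.
    models = await provider.get_available_models()
    if any(model_name == m.name for m in models):
        return model_name

    # Fall back to a broadly available default when the requested model cannot
    # be reached with the configured credentials.
    print(f"{model_type} '{model_name}' is unavailable; falling back to gpt-3.5-turbo")
    return "gpt-3.5-turbo"
```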
30 changes: 9 additions & 21 deletions autogpts/autogpt/autogpt/app/main.py
@@ -35,7 +35,7 @@
ConfigBuilder,
assert_config_has_openai_api_key,
)
from autogpt.core.resource.model_providers.openai import OpenAIProvider
from autogpt.core.resource.model_providers import MultiProvider
from autogpt.core.runner.client_lib.utils import coroutine
from autogpt.file_storage import FileStorageBackendName, get_storage
from autogpt.logs.config import configure_logging
@@ -123,7 +123,7 @@
skip_news=skip_news,
)

llm_provider = _configure_openai_provider(config)
llm_provider = _configure_llm_provider(config)


logger = logging.getLogger(__name__)

@@ -399,7 +399,7 @@
allow_downloads=allow_downloads,
)

llm_provider = _configure_openai_provider(config)
llm_provider = _configure_llm_provider(config)


# Set up & start server
database = AgentDB(
@@ -421,24 +421,12 @@
)


def _configure_openai_provider(config: Config) -> OpenAIProvider:
"""Create a configured OpenAIProvider object.

Args:
config: The program's configuration.

Returns:
A configured OpenAIProvider object.
"""
if config.openai_credentials is None:
raise RuntimeError("OpenAI key is not configured")

openai_settings = OpenAIProvider.default_settings.copy(deep=True)
openai_settings.credentials = config.openai_credentials
return OpenAIProvider(
settings=openai_settings,
logger=logging.getLogger("OpenAIProvider"),
)
def _configure_llm_provider(config: Config) -> MultiProvider:
multi_provider = MultiProvider()
for model in [config.smart_llm, config.fast_llm]:
# Ensure model providers for configured LLMs are available
multi_provider.get_model_provider(model)
return multi_provider
Comment on lines +424 to +429


Suggestion: The function _configure_llm_provider should handle the case where the configuration does not specify any models, to avoid runtime errors when accessing config.smart_llm or config.fast_llm. [enhancement]

Suggested change
- def _configure_llm_provider(config: Config) -> MultiProvider:
-     multi_provider = MultiProvider()
-     for model in [config.smart_llm, config.fast_llm]:
-         # Ensure model providers for configured LLMs are available
-         multi_provider.get_model_provider(model)
-     return multi_provider
+ def _configure_llm_provider(config: Config) -> MultiProvider:
+     multi_provider = MultiProvider()
+     models = [model for model in [config.smart_llm, config.fast_llm] if model]
+     for model in models:
+         # Ensure model providers for configured LLMs are available
+         multi_provider.get_model_provider(model)
+     return multi_provider

Comment on lines +424 to +429


Suggestion: The method _configure_llm_provider should include error handling or logging to provide feedback when a model provider is not available, which would improve maintainability and debugging. [maintainability]

Suggested change
- def _configure_llm_provider(config: Config) -> MultiProvider:
-     multi_provider = MultiProvider()
-     for model in [config.smart_llm, config.fast_llm]:
-         # Ensure model providers for configured LLMs are available
-         multi_provider.get_model_provider(model)
-     return multi_provider
+ def _configure_llm_provider(config: Config) -> MultiProvider:
+     multi_provider = MultiProvider()
+     for model in [config.smart_llm, config.fast_llm]:
+         try:
+             # Ensure model providers for configured LLMs are available
+             multi_provider.get_model_provider(model)
+         except Exception as e:
+             logging.getLogger(__name__).error(f"Failed to configure model provider for {model}: {str(e)}")
+             continue
+     return multi_provider



def _get_cycle_budget(continuous_mode: bool, continuous_limit: int) -> int | float:
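`_configure_llm_provider` now just instantiates a `MultiProvider` and asks it for the provider behind each configured model, so missing credentials surface at startup rather than mid-run. The routing idea behind a multi-provider, as an illustrative sketch (class and attribute names are assumptions, not the PR's actual implementation):

```python
class SketchMultiProvider:
    """Routes chat requests to whichever provider owns the requested model."""

    def __init__(self, providers_by_prefix: dict[str, object]) -> None:
        # e.g. {"gpt-": openai_provider, "claude-": anthropic_provider}
        self._providers_by_prefix = providers_by_prefix

    def get_model_provider(self, model_name: str):
        for prefix, provider in self._providers_by_prefix.items():
            if model_name.startswith(prefix):
                return provider
        raise ValueError(f"No provider registered for model '{model_name}'")

    async def create_chat_completion(self, messages, model_name: str, **kwargs):
        # Delegation keeps the caller provider-agnostic: the agent only knows
        # model names, never which vendor serves them.
        provider = self.get_model_provider(model_name)
        return await provider.create_chat_completion(messages, model_name=model_name, **kwargs)
```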