diff --git a/openbb_platform/core/openbb_core/api/router/commands.py b/openbb_platform/core/openbb_core/api/router/commands.py index 7b644cca372e..d524156ddbd2 100644 --- a/openbb_platform/core/openbb_core/api/router/commands.py +++ b/openbb_platform/core/openbb_core/api/router/commands.py @@ -51,11 +51,16 @@ def build_new_signature(path: str, func: Callable) -> Signature: parameter_list = sig.parameters.values() return_annotation = sig.return_annotation new_parameter_list = [] - - for parameter in parameter_list: + var_kw_pos = len(parameter_list) + for pos, parameter in enumerate(parameter_list): if parameter.name == "cc" and parameter.annotation == CommandContext: continue + if parameter.kind == Parameter.VAR_KEYWORD: + # We track VAR_KEYWORD parameter to insert the any additional + # parameters we need to add before it and avoid a SyntaxError + var_kw_pos = pos + new_parameter_list.append( Parameter( parameter.name, @@ -66,18 +71,21 @@ def build_new_signature(path: str, func: Callable) -> Signature: ) if CHARTING_INSTALLED and path.replace("/", "_")[1:] in Charting.functions(): - new_parameter_list.append( + new_parameter_list.insert( + var_kw_pos, Parameter( "chart", kind=Parameter.POSITIONAL_OR_KEYWORD, default=False, annotation=bool, - ) + ), ) + var_kw_pos += 1 if custom_headers := SystemService().system_settings.api_settings.custom_headers: for name, default in custom_headers.items(): - new_parameter_list.append( + new_parameter_list.insert( + var_kw_pos, Parameter( name.replace("-", "_"), kind=Parameter.POSITIONAL_OR_KEYWORD, @@ -85,11 +93,13 @@ def build_new_signature(path: str, func: Callable) -> Signature: annotation=Annotated[ Optional[str], Header(include_in_schema=False) ], - ) + ), ) + var_kw_pos += 1 if Env().API_AUTH: - new_parameter_list.append( + new_parameter_list.insert( + var_kw_pos, Parameter( "__authenticated_user_settings", kind=Parameter.POSITIONAL_OR_KEYWORD, @@ -97,8 +107,9 @@ def build_new_signature(path: str, func: Callable) -> Signature: annotation=Annotated[ UserSettings, Depends(AuthService().user_settings_hook) ], - ) + ), ) + var_kw_pos += 1 return Signature( parameters=new_parameter_list, diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py index 33fb4def99fe..9f8b52c8714b 100644 --- a/openbb_platform/core/openbb_core/app/command_runner.py +++ b/openbb_platform/core/openbb_core/app/command_runner.py @@ -20,7 +20,7 @@ from openbb_core.app.model.obbject import OBBject from openbb_core.app.model.system_settings import SystemSettings from openbb_core.app.model.user_settings import UserSettings -from openbb_core.app.provider_interface import ProviderInterface +from openbb_core.app.provider_interface import ExtraParams, ProviderInterface from openbb_core.app.router import CommandMap from openbb_core.app.service.system_service import SystemService from openbb_core.app.service.user_service import UserService @@ -185,13 +185,16 @@ def _warn_kwargs( ) -> None: """Warn if kwargs received and ignored by the validation model.""" # We only check the extra_params annotation because ignored fields - # will always be kwargs + # will always be there annotation = getattr( model.model_fields.get("extra_params", None), "annotation", None ) - if annotation: - # When there is no annotation there is nothing to warn - valid = asdict(annotation()) if is_dataclass(annotation) else {} # type: ignore + if is_dataclass(annotation) and any( + t is ExtraParams for t in getattr(annotation, "__bases__", []) + ): + # We only warn when endpoint defines ExtraParams, so we need + # to check if the annotation is a dataclass and child of ExtraParams + valid = asdict(annotation()) # type: ignore provider = provider_choices.get("provider", None) for p in extra_params: if field := valid.get(p): @@ -216,7 +219,12 @@ def _warn_kwargs( @staticmethod def _as_dict(obj: Any) -> Dict[str, Any]: """Safely convert an object to a dict.""" - return asdict(obj) if is_dataclass(obj) else dict(obj) + try: + if isinstance(obj, dict): + return obj + return asdict(obj) if is_dataclass(obj) else dict(obj) + except Exception: + return {} @staticmethod def validate_kwargs( @@ -227,7 +235,7 @@ def validate_kwargs( sig = signature(func) fields = { n: ( - p.annotation, + Any if p.annotation is Parameter.empty else p.annotation, ... if p.default is Parameter.empty else p.default, ) for n, p in sig.parameters.items() diff --git a/openbb_platform/core/openbb_core/app/static/package_builder.py b/openbb_platform/core/openbb_core/app/static/package_builder.py index 793f82731c6c..0faab0626056 100644 --- a/openbb_platform/core/openbb_core/app/static/package_builder.py +++ b/openbb_platform/core/openbb_core/app/static/package_builder.py @@ -528,10 +528,12 @@ def get_deprecation_message(path: str) -> str: return getattr(PathHandler.build_route_map()[path], "summary", "") @staticmethod - def reorder_params(params: Dict[str, Parameter]) -> "OrderedDict[str, Parameter]": - """Reorder the params.""" + def reorder_params( + params: Dict[str, Parameter], var_kw: Optional[List[str]] = None + ) -> "OrderedDict[str, Parameter]": + """Reorder the params and make sure VAR_KEYWORD come after 'provider.""" formatted_keys = list(params.keys()) - for k in ["provider", "extra_params"]: + for k in ["provider"] + (var_kw or []): if k in formatted_keys: formatted_keys.remove(k) formatted_keys.append(k) @@ -563,14 +565,11 @@ def format_params( ) formatted: Dict[str, Parameter] = {} - + var_kw = [] for name, param in parameter_map.items(): if name == "extra_params": formatted[name] = Parameter(name="kwargs", kind=Parameter.VAR_KEYWORD) - elif name == "kwargs": - formatted["**" + name] = Parameter( - name="kwargs", kind=Parameter.VAR_KEYWORD, annotation=Any - ) + var_kw.append(name) elif name == "provider_choices": fields = param.annotation.__args__[0].__dataclass_fields__ field = fields["provider"] @@ -624,12 +623,14 @@ def format_params( formatted[name] = Parameter( name=name, - kind=Parameter.POSITIONAL_OR_KEYWORD, + kind=param.kind, annotation=updated_type, default=param.default, ) + if param.kind == Parameter.VAR_KEYWORD: + var_kw.append(name) - return MethodDefinition.reorder_params(params=formatted) + return MethodDefinition.reorder_params(params=formatted, var_kw=var_kw) @staticmethod def add_field_custom_annotations( diff --git a/openbb_platform/core/tests/app/static/test_package_builder.py b/openbb_platform/core/tests/app/static/test_package_builder.py index aa7c00294757..029c200fa45a 100644 --- a/openbb_platform/core/tests/app/static/test_package_builder.py +++ b/openbb_platform/core/tests/app/static/test_package_builder.py @@ -206,17 +206,57 @@ class TestAnnotatedDataClass: assert result -def test_reorder_params(method_definition): - """Test reorder params.""" - params = { - "provider": Parameter.empty, - "extra_params": Parameter.empty, - "param1": Parameter.empty, - "param2": Parameter.empty, - } - result = method_definition.reorder_params(params=params) +@pytest.mark.parametrize( + "params, var_kw, expected", + [ + ( + { + "provider": Parameter.empty, + "extra_params": Parameter.empty, + "param1": Parameter.empty, + "param2": Parameter.empty, + }, + None, + ["extra_params", "param1", "param2", "provider"], + ), + ( + { + "param1": Parameter.empty, + "provider": Parameter.empty, + "extra_params": Parameter.empty, + "param2": Parameter.empty, + }, + ["extra_params"], + ["param1", "param2", "provider", "extra_params"], + ), + ( + { + "param2": Parameter.empty, + "any_kwargs": Parameter.empty, + "provider": Parameter.empty, + "param1": Parameter.empty, + }, + ["any_kwargs"], + ["param2", "param1", "provider", "any_kwargs"], + ), + ( + { + "any_kwargs": Parameter.empty, + "extra_params": Parameter.empty, + "provider": Parameter.empty, + "param1": Parameter.empty, + "param2": Parameter.empty, + }, + ["any_kwargs", "extra_params"], + ["param1", "param2", "provider", "any_kwargs", "extra_params"], + ), + ], +) +def test_reorder_params(method_definition, params, var_kw, expected): + """Test reorder params, ensure var_kw are last after 'provider'.""" + result = method_definition.reorder_params(params, var_kw) assert result - assert list(result.keys()) == ["param1", "param2", "provider", "extra_params"] + assert list(result.keys()) == expected def test_build_func_params(method_definition): diff --git a/openbb_platform/core/tests/app/test_command_runner.py b/openbb_platform/core/tests/app/test_command_runner.py index 3447d66407bf..cd751e9c592d 100644 --- a/openbb_platform/core/tests/app/test_command_runner.py +++ b/openbb_platform/core/tests/app/test_command_runner.py @@ -16,6 +16,7 @@ from openbb_core.app.model.command_context import CommandContext from openbb_core.app.model.system_settings import SystemSettings from openbb_core.app.model.user_settings import UserSettings +from openbb_core.app.provider_interface import ExtraParams from openbb_core.app.router import CommandMap from pydantic import BaseModel, ConfigDict @@ -228,45 +229,57 @@ def test_parameters_builder_validate_kwargs(mock_func): @pytest.mark.parametrize( - "provider_choices, extra_params, expect", + "provider_choices, extra_params, base, expect", [ ( {"provider": "provider1"}, {"exists_in_2": ...}, + ExtraParams, OpenBBWarning, ), ( {"provider": "inexistent_provider"}, {"exists_in_both": ...}, + ExtraParams, OpenBBWarning, ), ( {}, {"inexistent_field": ...}, + ExtraParams, OpenBBWarning, ), + ( + {}, + {"inexistent_field": ...}, + object, + None, + ), ( {"provider": "provider2"}, {"exists_in_2": ...}, + ExtraParams, None, ), ( {"provider": "provider2"}, {"exists_in_both": ...}, + ExtraParams, None, ), ( {}, {"exists_in_both": ...}, + ExtraParams, None, ), ], ) -def test_parameters_builder__warn_kwargs(provider_choices, extra_params, expect): +def test_parameters_builder__warn_kwargs(provider_choices, extra_params, base, expect): """Test _warn_kwargs.""" @dataclass - class SomeModel: + class SomeModel(base): """SomeModel""" exists_in_2: QueryParam = Query(..., title="provider2")