Skip to content

Commit

Permalink
Skip checking creds (#5737)
Browse files Browse the repository at this point in the history
* skip checking creds

* fix bug

* ok mypy

* add comment

* rename required_credentials

* fix

* fix top_retail bug

* don't require creds in some nasdaq fetchers

* black

* this was not here before

* fix tests

* update fetcher test

* update contributing guidelines
  • Loading branch information
montezdesousa authored and piiq committed Nov 17, 2023
1 parent ab24988 commit 6f98a8a
Show file tree
Hide file tree
Showing 34 changed files with 136 additions and 74 deletions.
24 changes: 21 additions & 3 deletions openbb_platform/CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ class <ProviderName>EquityHistoricalData(EquityHistoricalData):

The `Fetcher` class is responsible for making the request to the API endpoint and providing the output.

It will receive the Query Parameters, and it will return the output while leveraging the pydantic model schemas.
It will receive the query parameters, and it will return the output while leveraging the pydantic model schemas.

For the `EquityHistorical` example, this would look like the following:

Expand Down Expand Up @@ -382,7 +382,9 @@ class <ProviderName>EquityHistoricalFetcher(

@staticmethod
def transform_data(
query: <ProviderName>EquityHistoricalQueryParams,
data: dict,
**kwargs: Any,
) -> List[<ProviderName>EquityHistoricalData]:
"""Transform the data to the standard format."""

Expand All @@ -391,6 +393,22 @@ class <ProviderName>EquityHistoricalFetcher(

> Make sure that you're following the TET pattern when building a `Fetcher` - **Transform, Extract, Transform**. See more on this [here](#the-tet-pattern).
By default the credentials declared on each `Provider` are required. This means that before a query is executed, we check that all the credentials are present and if not an exception is raised. If you want to make credentials optional on a given fetcher, even though they are declared on the `Provider`, you can add `require_credentials=False` to the `Fetcher` class. See the following example:

```python
class <ProviderName>EquityHistoricalFetcher(
Fetcher[
<ProviderName>EquityHistoricalQueryParams,
List[<ProviderName>EquityHistoricalData],
]
):
"""Transform the query, extract and transform the data."""

require_credentials = False

...
```

#### Make the provider visible

In order to make the new provider visible to the OpenBB Platform, you'll need to add it to the `__init__.py` file of the `providers/<provider_name>/openbb_<provider_name>/` folder.
Expand All @@ -405,14 +423,14 @@ from openbb_<provider_name>.models.equity_historical import <ProviderName>Equity
name="<provider_name>",
website="<URL to the provider website>",
description="Provider description goes here",
required_credentials=["api_key"],
credentials=["api_key"],
fetcher_dict={
"EquityHistorical": <ProviderName>EquityHistoricalFetcher,
},
)
```

If the provider does not require any credentials, you can remove that parameter. On the other hand, if it requires more than 2 items to authenticate, you can add a list of all the required items to the `required_credentials` list.
If the provider does not require any credentials, you can remove that parameter. On the other hand, if it requires more than 2 items to authenticate, you can add a list of all the required items to the `credentials` list.

After running `pip install .` on `openbb_platform/providers/<provider_name>` your provider should be ready for usage, both from the Python interface and the API.

Expand Down
2 changes: 1 addition & 1 deletion openbb_platform/platform/core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ Steps to create an `OBBject` extension:

```python
from openbb_core.app.model.extension import Extension
ext = Extension(name="example", required_credentials=["some_api_key"])
ext = Extension(name="example", credentials=["some_api_key"])
```

3. Optionally declare an `OBBject` accessor, it will use the extension name:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ class CredentialsLoader:

@staticmethod
def prepare(
required_credentials: Dict[str, Set[str]],
credentials: Dict[str, Set[str]],
) -> Dict[str, Tuple[object, None]]:
"""Prepare credentials map to be used in the Credentials model"""
formatted: Dict[str, Tuple[object, None]] = {}
for origin, creds in required_credentials.items():
for origin, creds in credentials.items():
for c in creds:
# Not sure we should do this, if you require the same credential it breaks
# if c in formatted:
Expand All @@ -61,7 +61,7 @@ def from_obbject(self) -> None:
try:
entry = entry_point.load()
if isinstance(entry, Extension):
for c in entry.required_credentials:
for c in entry.credentials:
self.credentials["obbject"].add(c)
except Exception as e:
traceback.print_exception(type(e), e, e.__traceback__)
Expand All @@ -70,7 +70,7 @@ def from_obbject(self) -> None:
def from_providers(self) -> None:
"""Load credentials from providers"""
self.credentials["providers"] = set()
for c in ProviderInterface().required_credentials:
for c in ProviderInterface().credentials:
self.credentials["providers"].add(c)

def load(self) -> BaseModel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ class Extension:
def __init__(
self,
name: str,
required_credentials: Optional[List[str]] = None,
credentials: Optional[List[str]] = None,
) -> None:
"""Initialize the extension.
Parameters
----------
name : str
Name of the extension.
required_credentials : Optional[List[str]], optional
credentials : Optional[List[str]], optional
List of required credentials, by default None
"""
self.name = name
self.required_credentials = required_credentials or []
self.credentials = credentials or []

@property
def obbject_accessor(self) -> Callable:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class ProviderInterface(metaclass=SingletonMeta):
----------
map : MapType
Dictionary of provider information.
required_credentials: List[str]
List of required_credentials.
credentials: List[str]
List of credentials.
model_providers : Dict[str, ProviderChoices]
Dictionary of provider choices by model.
params : Dict[str, Dict[str, Union[StandardParams, ExtraParams]]]
Expand Down Expand Up @@ -110,9 +110,9 @@ def map(self) -> MapType:
return self._map

@property
def required_credentials(self) -> List[str]:
def credentials(self) -> List[str]:
"""Dictionary of required credentials by provider."""
return self._registry_map.required_credentials
return self._registry_map.credentials

@property
def model_providers(self) -> Dict[str, ProviderChoices]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def test_credentials():
with patch(
target="openbb_core.app.model.credentials.ProviderInterface"
) as mock_provider_interface:
mock_provider_interface.required_credentials = {
mock_provider_interface.credentials = {
"benzinga_api_key": (typing.Optional[str], None),
"polygon_api_key": (typing.Optional[str], None),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def test_map(provider_interface):
assert "EquityHistorical" in provider_interface_map


def test_required_credentials(provider_interface):
def test_credentials(provider_interface):
"""Test required credentials."""
required_credentials = provider_interface.required_credentials
assert isinstance(required_credentials, list)
assert len(required_credentials) > 0
credentials = provider_interface.credentials
assert isinstance(credentials, list)
assert len(credentials) > 0


def test_model_providers(provider_interface):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def __get__(self, obj, owner):
class Fetcher(Generic[Q, R]):
"""Abstract class for the fetcher."""

# Tell query executor if credentials are required. Can be overridden by subclasses.
require_credentials = True

@staticmethod
def transform_query(params: Dict[str, Any]) -> Q:
"""Transform the params to the provider-specific query."""
Expand Down Expand Up @@ -108,6 +111,9 @@ def test(
data = cls.extract_data(query=query, credentials=credentials, **kwargs)
transformed_data = cls.transform_data(query=query, data=data, **kwargs)

# Class Assertions
assert isinstance(cls.require_credentials, bool)

# Query Assertions
assert query
assert issubclass(type(query), cls.query_params_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(
name: str,
description: str,
website: Optional[str] = None,
required_credentials: Optional[List[str]] = None,
credentials: Optional[List[str]] = None,
fetcher_dict: Optional[Dict[str, Type[Fetcher]]] = None,
) -> None:
"""Initialize the provider.
Expand All @@ -26,7 +26,7 @@ def __init__(
Description of the provider.
website : Optional[str]
Website of the provider, by default None.
required_credentials : Optional[List[str]], optional
credentials : Optional[List[str]], optional
List of required credentials, by default None
fetcher_dict : Optional[Dict[str, Type[Fetcher]]]
Dictionary of fetchers, by default None.
Expand All @@ -35,9 +35,9 @@ def __init__(
self.description = description
self.website = website
self.fetcher_dict = fetcher_dict or {}
if required_credentials is None:
self.required_credentials: List = []
if credentials is None:
self.credentials: List = []
else:
self.required_credentials = []
for rq in required_credentials:
self.required_credentials.append(f"{self.name.lower()}_{rq}")
self.credentials = []
for c in credentials:
self.credentials.append(f"{self.name.lower()}_{c}")
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,24 @@ def get_fetcher(self, provider: Provider, model_name: str) -> Type[Fetcher]:

@staticmethod
def filter_credentials(
provider: Provider, credentials: Optional[Dict[str, SecretStr]]
credentials: Optional[Dict[str, SecretStr]],
provider: Provider,
require_credentials: bool,
) -> Dict[str, str]:
"""Filter credentials and check if they match provider requirements."""
if provider.required_credentials is not None:
filtered_credentials = {}

if provider.credentials:
if credentials is None:
credentials = {}

filtered_credentials = {}
for c in provider.required_credentials:
for c in provider.credentials:
credential_value = credentials.get(c)
if c not in credentials or credential_value is None:
raise ProviderError(f"Missing credential '{c}'.")
filtered_credentials[c] = credential_value.get_secret_value()
if require_credentials:
raise ProviderError(f"Missing credential '{c}'.")
else:
filtered_credentials[c] = credential_value.get_secret_value()

return filtered_credentials

Expand Down Expand Up @@ -80,8 +85,10 @@ def execute(
Query result.
"""
provider = self.get_provider(provider_name)
filtered_credentials = self.filter_credentials(provider, credentials)
fetcher = self.get_fetcher(provider, model_name)
filtered_credentials = self.filter_credentials(
credentials, provider, fetcher.require_credentials
)

try:
return fetcher.fetch_data(params, filtered_credentials, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class RegistryMap:
def __init__(self, registry: Optional[Registry] = None) -> None:
"""Initialize Registry Map."""
self._registry = registry or RegistryLoader.from_extensions()
self._required_credentials = self._get_required_credentials(self._registry)
self._credentials = self._get_credentials(self._registry)
self._available_providers = self._get_available_providers(self._registry)
self._map, self._return_map = self._get_map(self._registry)
self._models = self._get_models(self._map)
Expand All @@ -36,9 +36,9 @@ def available_providers(self) -> List[str]:
return self._available_providers

@property
def required_credentials(self) -> List[str]:
def credentials(self) -> List[str]:
"""Get list of required credentials."""
return self._required_credentials
return self._credentials

@property
def map(self) -> MapType:
Expand All @@ -55,11 +55,11 @@ def models(self) -> List[str]:
"""Get available models."""
return self._models

def _get_required_credentials(self, registry: Registry) -> List[str]:
def _get_credentials(self, registry: Registry) -> List[str]:
"""Get list of required credentials."""
cred_list = []
for provider in registry.providers.values():
for c in provider.required_credentials:
for c in provider.credentials:
cred_list.append(c)
return cred_list

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_provider_initialization():
assert provider.name == "TestProvider"
assert provider.description == "A simple test provider."
assert provider.website is None
assert provider.required_credentials == []
assert provider.credentials == []
assert provider.fetcher_dict == {}


Expand All @@ -20,25 +20,25 @@ def test_provider_with_optional_parameters():
name="TestProvider",
description="A simple test provider.",
website="https://testprovider.example.com",
required_credentials=["api_key"],
credentials=["api_key"],
fetcher_dict={"fetcher1": None},
)

assert provider.name == "TestProvider"
assert provider.description == "A simple test provider."
assert provider.website == "https://testprovider.example.com"
assert provider.required_credentials == ["testprovider_api_key"]
assert provider.credentials == ["testprovider_api_key"]
assert provider.fetcher_dict == {"fetcher1": None}


def test_provider_required_credentials_formatting():
def test_provider_credentials_formatting():
"""Test the formatting of required credentials."""
required_credentials = ["key1", "key2"]
credentials = ["key1", "key2"]
provider = Provider(
name="TestProvider",
description="A simple test provider.",
required_credentials=required_credentials,
credentials=credentials,
)

expected_credentials = ["testprovider_key1", "testprovider_key2"]
assert provider.required_credentials == expected_credentials
assert provider.credentials == expected_credentials
33 changes: 26 additions & 7 deletions openbb_platform/platform/provider/tests/test_query_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,40 @@ def test_get_fetcher_failure(mock_query_executor):
def test_filter_credentials_success(mock_query_executor):
"""Test if credentials are properly filtered."""
provider = mock_query_executor.get_provider("test_provider")
provider.required_credentials = ["api_key"]
credentials = {"api_key": SecretStr("12345")}
provider.credentials = ["test_provider_api_key"]
credentials = {
"test_provider_api_key": SecretStr("12345"),
"other_api_key": SecretStr("12345"),
}

filtered_credentials = mock_query_executor.filter_credentials(provider, credentials)
filtered_credentials = mock_query_executor.filter_credentials(
credentials, provider, True
)

assert filtered_credentials == {"api_key": "12345"}
assert filtered_credentials == {"test_provider_api_key": "12345"}


def test_filter_credentials_missing(mock_query_executor):
def test_filter_credentials_missing_require(mock_query_executor):
"""Test if the proper error is raised when a credential is missing."""
provider = mock_query_executor.get_provider("test_provider")
provider.required_credentials = ["api_key"]
provider.credentials = ["test_provider_api_key"]
credentials = {"other_api_key": SecretStr("12345")}

with pytest.raises(ProviderError, match="Missing credential"):
mock_query_executor.filter_credentials(provider, {})
mock_query_executor.filter_credentials(credentials, provider, True)


def test_filter_credentials_missing_dont_require(mock_query_executor):
"""Test if the proper error is raised when a credential is missing."""
provider = mock_query_executor.get_provider("test_provider")
provider.credentials = ["test_provider_api_key"]
credentials = {"other_api_key": SecretStr("12345")}

filtered_credentials = mock_query_executor.filter_credentials(
credentials, provider, False
)

assert filtered_credentials == {}


def test_execute_success(mock_query_executor):
Expand Down
Loading

0 comments on commit 6f98a8a

Please sign in to comment.