Skip to content

Commit

Permalink
fix: cast loaded env vars type using config type-hints (#987)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunato committed Jan 15, 2024
1 parent fe2a514 commit 270f673
Show file tree
Hide file tree
Showing 11 changed files with 108 additions and 50 deletions.
7 changes: 4 additions & 3 deletions eodag/api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from eodag.utils.stac_reader import fetch_stac_items

if TYPE_CHECKING:
from pydantic.fields import FieldInfo
from shapely.geometry.base import BaseGeometry
from whoosh.index import Index

Expand Down Expand Up @@ -2086,7 +2087,7 @@ def list_queryables(
self,
provider: Optional[str] = None,
product_type: Optional[str] = None,
) -> Dict[str, Tuple[Annotated, Any]]:
) -> Dict[str, Tuple[Annotated[Any, FieldInfo], Any]]:
"""Fetch the queryable properties for a given product type and/or provider.
:param provider: (optional) The provider.
Expand All @@ -2096,7 +2097,7 @@ def list_queryables(
:returns: A dict containing the EODAG queryable properties, associating
parameters to a tuple containing their annotaded type and default
value
:rtype: Dict[str, Tuple[Annotated, Any]]
:rtype: Dict[str, Tuple[Annotated[Any, FieldInfo], Any]]
"""
# unknown product type
if product_type is not None and product_type not in self.list_product_types(
Expand All @@ -2106,7 +2107,7 @@ def list_queryables(

# dictionary of the queryable properties of the providers supporting the given product type
providers_available_queryables: Dict[
str, Dict[str, Tuple[Annotated, Any]]
str, Dict[str, Tuple[Annotated[Any, FieldInfo], Any]]
] = dict()

if provider is None and product_type is None:
Expand Down
46 changes: 35 additions & 11 deletions eodag/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import logging
import os
import tempfile
from inspect import isclass
from typing import (
TYPE_CHECKING,
Any,
Dict,
ItemsView,
Expand All @@ -32,20 +32,24 @@
TypedDict,
Union,
ValuesView,
get_type_hints,
)

import orjson
import requests
import yaml
import yaml.constructor
import yaml.parser
from jsonpath_ng import JSONPath
from pkg_resources import resource_filename
from requests.auth import AuthBase

from eodag.utils import (
HTTP_REQ_TIMEOUT,
USER_AGENT,
cached_yaml_load,
cached_yaml_load_all,
cast_scalar_value,
deepcopy,
dict_items_recursive_apply,
merge_mappings,
Expand All @@ -56,10 +60,6 @@
)
from eodag.utils.exceptions import ValidationError

if TYPE_CHECKING:
from jsonpath_ng import JSONPath
from requests.auth import AuthBase

logger = logging.getLogger("eodag.config")

EXT_PRODUCT_TYPES_CONF_URI = (
Expand Down Expand Up @@ -250,7 +250,7 @@ class OrderStatusOnSuccess(TypedDict):
need_auth: bool
result_type: str
results_entry: str
pagination: Pagination
pagination: PluginConfig.Pagination
query_params_key: str
discover_metadata: Dict[str, str]
discover_product_types: Dict[str, Any]
Expand All @@ -266,7 +266,7 @@ class OrderStatusOnSuccess(TypedDict):
merge_responses: bool # PostJsonSearch for aws_eos
collection: bool # PostJsonSearch for aws_eos
max_connections: int # StaticStacSearch
timeout: int # StaticStacSearch
timeout: float # StaticStacSearch

# download -------------------------------------------------------------------------
base_uri: str
Expand All @@ -275,7 +275,7 @@ class OrderStatusOnSuccess(TypedDict):
order_enabled: bool # HTTPDownload
order_method: str # HTTPDownload
order_headers: Dict[str, str] # HTTPDownload
order_status_on_success: OrderStatusOnSuccess
order_status_on_success: PluginConfig.OrderStatusOnSuccess
bucket_path_level: int # S3RestDownload

# auth -----------------------------------------------------------------------------
Expand Down Expand Up @@ -471,13 +471,38 @@ def build_mapping_from_env(
:type mapping: dict
"""
parts = env_var.split("__")
if len(parts) == 1:
iter_parts = iter(parts)
env_type = get_type_hints(PluginConfig).get(next(iter_parts, ""), str)
child_env_type = (
get_type_hints(env_type).get(next(iter_parts, ""), None)
if isclass(env_type)
else None
)
if len(parts) == 2 and child_env_type:
# for nested config (pagination, ...)
# try converting env_value type from type hints
try:
env_value = cast_scalar_value(env_value, child_env_type)
except TypeError:
logger.warning(
f"Could not convert {parts} value {env_value} to {child_env_type}"
)
mapping.setdefault(parts[0], {})
mapping[parts[0]][parts[1]] = env_value
elif len(parts) == 1:
# try converting env_value type from type hints
try:
env_value = cast_scalar_value(env_value, env_type)
except TypeError:
logger.warning(
f"Could not convert {parts[0]} value {env_value} to {env_type}"
)
mapping[parts[0]] = env_value
else:
new_map = mapping.setdefault(parts[0], {})
build_mapping_from_env("__".join(parts[1:]), env_value, new_map)

mapping_from_env = {}
mapping_from_env: Dict[str, Any] = {}
for env_var in os.environ:
if env_var.startswith("EODAG__"):
build_mapping_from_env(
Expand All @@ -500,7 +525,6 @@ def override_config_from_mapping(
:type mapping: dict
"""
for provider, new_conf in mapping.items():
new_conf: Dict[str, Any]
old_conf: Optional[Dict[str, Any]] = config.get(provider)
if old_conf is not None:
old_conf.update(new_conf)
Expand Down
4 changes: 3 additions & 1 deletion eodag/plugins/apis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

if TYPE_CHECKING:
from pydantic.fields import FieldInfo

from eodag.api.product import EOProduct
from eodag.api.search_result import SearchResult
from eodag.config import PluginConfig
Expand Down Expand Up @@ -95,7 +97,7 @@ def discover_product_types(self) -> Optional[Dict[str, Any]]:

def discover_queryables(
self, product_type: Optional[str] = None
) -> Optional[Dict[str, Tuple[Annotated, Any]]]:
) -> Optional[Dict[str, Tuple[Annotated[Any, FieldInfo], Any]]]:
"""Fetch queryables list from provider using `discover_queryables` conf"""
return None

Expand Down
6 changes: 4 additions & 2 deletions eodag/plugins/apis/usgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,9 @@ def download_request(
stream.raise_for_status()
except RequestException as e:
if e.response and hasattr(e.response, "content"):
error_message = f"{e.response.content} - {e}"
error_message = (
f"{e.response.content.decode('utf-8')} - {e}"
)
else:
error_message = str(e)
raise NotAvailableError(error_message)
Expand All @@ -341,7 +343,7 @@ def download_request(
progress_callback(len(chunk))
except requests.exceptions.Timeout as e:
if e.response and hasattr(e.response, "content"):
error_message = f"{e.response.content} - {e}"
error_message = f"{e.response.content.decode('utf-8')} - {e}"
else:
error_message = str(e)
raise NotAvailableError(error_message)
Expand Down
13 changes: 9 additions & 4 deletions eodag/plugins/download/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def __init__(self, provider: str, config: PluginConfig) -> None:
super(HTTPDownload, self).__init__(provider, config)
if not hasattr(self.config, "base_uri"):
raise MisconfiguredError(
"{} plugin require a base_uri configuration key".format(self.__name__)
"{} plugin require a base_uri configuration key".format(
type(self).__name__
)
)

def orderDownload(
Expand Down Expand Up @@ -166,20 +168,23 @@ def orderDownload(
logger.debug(ordered_message)
logger.info("%s was ordered", product.properties["title"])
except RequestException as e:
if e.response and hasattr(e.response, "content"):
error_message = f"{e.response.content.decode('utf-8')} - {e}"
else:
error_message = str(e)
logger.warning(
"%s could not be ordered, request returned %s",
product.properties["title"],
f"{e.response.content} - {e}",
error_message,
)

order_metadata_mapping = getattr(self.config, "order_on_response", {}).get(
"metadata_mapping", {}
)
if order_metadata_mapping:
logger.debug("Parsing order response to update product metada-mapping")
order_metadata_mapping_jsonpath = {}
order_metadata_mapping_jsonpath = mtd_cfg_as_conversion_and_querypath(
order_metadata_mapping, order_metadata_mapping_jsonpath
order_metadata_mapping,
)
properties_update = properties_from_json(
response.json(),
Expand Down
4 changes: 3 additions & 1 deletion eodag/plugins/search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
)

if TYPE_CHECKING:
from pydantic.fields import FieldInfo

from eodag.api.product import EOProduct
from eodag.config import PluginConfig
from eodag.utils import Annotated
Expand Down Expand Up @@ -90,7 +92,7 @@ def discover_product_types(self) -> Optional[Dict[str, Any]]:

def discover_queryables(
self, product_type: Optional[str] = None
) -> Optional[Dict[str, Tuple[Annotated, Any]]]:
) -> Optional[Dict[str, Tuple[Annotated[Any, FieldInfo], Any]]]:
"""Fetch queryables list from provider using `discover_queryables` conf"""
return None

Expand Down
2 changes: 1 addition & 1 deletion eodag/plugins/search/qssearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,7 +1270,7 @@ def normalize_results(

def discover_queryables(
self, product_type: Optional[str] = None
) -> Optional[Dict[str, Tuple[Annotated, Any]]]:
) -> Optional[Dict[str, Tuple[Annotated[Any, FieldInfo], Any]]]:
"""Fetch queryables list from provider using `discover_queryables` conf
:param product_type: (optional) product type
Expand Down
7 changes: 5 additions & 2 deletions eodag/rest/types/stac_queryables.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field

from eodag.types import python_field_definition_to_json
from eodag.utils import Annotated

if TYPE_CHECKING:
from pydantic.fields import FieldInfo


class StacQueryableProperty(BaseModel):
"""A class representing a queryable property.
Expand All @@ -49,7 +52,7 @@ def update_properties(self, new_properties: dict):

@classmethod
def from_python_field_definition(
cls, id: str, python_field_definition: Tuple[Annotated, Any]
cls, id: str, python_field_definition: Tuple[Annotated[Any, FieldInfo], Any]
) -> StacQueryableProperty:
"""Build Model from python_field_definition"""
def_dict = python_field_definition_to_json(python_field_definition)
Expand Down
7 changes: 4 additions & 3 deletions eodag/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def json_field_definition_to_python(


def python_field_definition_to_json(
python_field_definition: Tuple[Annotated, Any]
python_field_definition: Tuple[Annotated[Any, FieldInfo], Any]
) -> Dict[str, Any]:
"""Get json field definition from python `typing.Annotated`
Expand All @@ -144,7 +144,8 @@ def python_field_definition_to_json(
or len(python_field_definition) != 2
):
raise ValidationError(
"%s must be an instance of Tuple[Annotated, Any]" % python_field_definition
"%s must be an instance of Tuple[Annotated[Any, FieldInfo], Any]"
% python_field_definition
)

python_field_annotated = python_field_definition[0]
Expand Down Expand Up @@ -196,7 +197,7 @@ def python_field_definition_to_json(

def model_fields_to_annotated_tuple(
model_fields: Dict[str, FieldInfo]
) -> Dict[str, Tuple[Annotated, Any]]:
) -> Dict[str, Tuple[Annotated[Any, FieldInfo], Any]]:
"""Convert BaseModel.model_fields from FieldInfo to Annotated tuple usable as create_model argument
>>> from pydantic import create_model
Expand Down
55 changes: 33 additions & 22 deletions eodag/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,29 +373,9 @@ def merge_mappings(mapping1: Dict[Any, Any], mapping2: Dict[Any, Any]) -> None:
and current_value_type == list
):
mapping1[m1_keys_lowercase.get(key, key)] = value
elif isinstance(value, str):
# Bool is a type with special meaning in Python, thus the special
# case
if current_value_type is bool:
if value.capitalize() not in ("True", "False"):
raise ValueError(
"Only true or false strings (case insensitive) are "
"allowed for booleans"
)
# Get the real Python value of the boolean. e.g: value='tRuE'
# => eval(value.capitalize())=True.
# str.capitalize() transforms the first character of the string
# to a capital letter
mapping1[m1_keys_lowercase.get(key, key)] = eval(
value.capitalize()
)
else:
mapping1[
m1_keys_lowercase.get(key, key)
] = current_value_type(value)
else:
mapping1[m1_keys_lowercase.get(key, key)] = current_value_type(
value
mapping1[m1_keys_lowercase.get(key, key)] = cast_scalar_value(
value, current_value_type
)
except (TypeError, ValueError):
# Ignore any override value that does not have the same type
Expand Down Expand Up @@ -1398,3 +1378,34 @@ def parse_header(header: str) -> Message:
m = Message()
m["content-type"] = header
return m


def cast_scalar_value(value: Any, new_type: Any) -> Any:
"""Convert a scalar (not nested) value type to the given one
>>> cast_scalar_value('1', int)
1
>>> cast_scalar_value(1, str)
'1'
>>> cast_scalar_value('false', bool)
False
:param value: the scalar value to convert
:param new_type: the wanted type
:returns: scalar value converted to new_type
"""
if isinstance(value, str) and new_type is bool:
# Bool is a type with special meaning in Python, thus the special
# case
if value.capitalize() not in ("True", "False"):
raise ValueError(
"Only true or false strings (case insensitive) are "
"allowed for booleans"
)
# Get the real Python value of the boolean. e.g: value='tRuE'
# => eval(value.capitalize())=True.
# str.capitalize() transforms the first character of the string
# to a capital letter
return eval(value.capitalize())

return new_type(value)

0 comments on commit 270f673

Please sign in to comment.