Skip to content

Commit

Permalink
fix: add defaults to cds queryables
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunato committed Feb 19, 2024
1 parent 9c61377 commit 56c8be8
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 56 deletions.
96 changes: 56 additions & 40 deletions docs/notebooks/api_user_guide/4_search.ipynb

Large diffs are not rendered by default.

22 changes: 20 additions & 2 deletions eodag/plugins/apis/base.py
Expand Up @@ -20,20 +20,21 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from pydantic.fields import FieldInfo
from pydantic.fields import Field, FieldInfo

if TYPE_CHECKING:
from eodag.api.product import EOProduct
from eodag.api.search_result import SearchResult
from eodag.config import PluginConfig
from eodag.utils import DownloadedCallback, ProgressCallback, Annotated
from eodag.utils import DownloadedCallback, ProgressCallback

from eodag.plugins.base import PluginTopic
from eodag.utils import (
DEFAULT_DOWNLOAD_TIMEOUT,
DEFAULT_DOWNLOAD_WAIT,
DEFAULT_ITEMS_PER_PAGE,
DEFAULT_PAGE,
Annotated,
)

logger = logging.getLogger("eodag.apis.base")
Expand Down Expand Up @@ -108,6 +109,23 @@ def discover_queryables(
"""
return None

def get_defaults_as_queryables(
self, product_type: str
) -> Dict[str, Annotated[Any, FieldInfo]]:
"""
Return given product type defaut settings as queryables
:param product_type: given product type
:type product_type: str
:returns: queryable parameters dict
:rtype: Dict[str, Annotated[Any, FieldInfo]]
"""
defaults = self.config.products.get(product_type, {})
queryables = {}
for parameter, value in defaults.items():
queryables[parameter] = Annotated[type(value), Field(default=value)]
return queryables

def download(
self,
product: EOProduct,
Expand Down
20 changes: 11 additions & 9 deletions eodag/plugins/apis/cds.py
Expand Up @@ -76,8 +76,6 @@
logger = logging.getLogger("eodag.apis.cds")

CDS_KNOWN_FORMATS = {"grib": "grib", "netcdf": "nc"}
# always available queryables (needed as not available in constraints)
CDS_ALLOWED_QUERYABLES = ["format"]


class CdsApi(HTTPDownload, Api, BuildPostSearchResult):
Expand Down Expand Up @@ -481,6 +479,12 @@ def discover_queryables(
constraints = fetch_constraints(constraints_file_url, self)
if not constraints:
return {}

# defaults
default_queryables = self.get_defaults_as_queryables(product_type)
# remove dataset from queryables
default_queryables.pop("dataset", None)

constraint_params: Dict[str, Dict[str, Set[Any]]] = {}
if len(kwargs) == 0:
# get values from constraints without additional filters
Expand All @@ -493,11 +497,7 @@ def discover_queryables(
constraint_params[key]["enum"] = set(constraint[key])
else:
# get values from constraints with additional filters
constraints_input_params = {
k: v
for k, v in non_empty_kwargs.items()
if k not in CDS_ALLOWED_QUERYABLES
}
constraints_input_params = {k: v for k, v in non_empty_kwargs.items()}
constraint_params = get_constraint_queryables_with_additional_params(
constraints, constraints_input_params, self, product_type
)
Expand All @@ -506,7 +506,9 @@ def discover_queryables(
not_queryables = set()
for constraint_param in constraint_params["not_available"]["enum"]:
param = CommonQueryables.get_queryable_from_alias(constraint_param)
if param in CommonQueryables.model_fields:
if param in dict(
CommonQueryables.model_fields, **default_queryables
):
non_empty_kwargs.pop(constraint_param)
else:
not_queryables.add(constraint_param)
Expand Down Expand Up @@ -535,4 +537,4 @@ def discover_queryables(
field_definitions[param] = get_args(annotated_def)

python_queryables = create_model("m", **field_definitions).model_fields
return model_fields_to_annotated(python_queryables)
return dict(default_queryables, **model_fields_to_annotated(python_queryables))
21 changes: 19 additions & 2 deletions eodag/plugins/search/base.py
Expand Up @@ -20,7 +20,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from pydantic.fields import FieldInfo
from pydantic.fields import Field, FieldInfo

from eodag.api.product.metadata_mapping import (
DEFAULT_METADATA_MAPPING,
Expand All @@ -31,13 +31,13 @@
DEFAULT_ITEMS_PER_PAGE,
DEFAULT_PAGE,
GENERIC_PRODUCT_TYPE,
Annotated,
format_dict_items,
)

if TYPE_CHECKING:
from eodag.api.product import EOProduct
from eodag.config import PluginConfig
from eodag.utils import Annotated

logger = logging.getLogger("eodag.search.base")

Expand Down Expand Up @@ -103,6 +103,23 @@ def discover_queryables(
"""
return None

def get_defaults_as_queryables(
self, product_type: str
) -> Dict[str, Annotated[Any, FieldInfo]]:
"""
Return given product type defaut settings as queryables
:param product_type: given product type
:type product_type: str
:returns: queryable parameters dict
:rtype: Dict[str, Annotated[Any, FieldInfo]]
"""
defaults = self.config.products.get(product_type, {})
queryables = {}
for parameter, value in defaults.items():
queryables[parameter] = Annotated[type(value), Field(default=value)]
return queryables

def map_product_type(
self, product_type: Optional[str], **kwargs: Any
) -> Optional[str]:
Expand Down
4 changes: 2 additions & 2 deletions tests/units/test_apis_plugins.py
Expand Up @@ -925,14 +925,14 @@ def test_plugins_apis_cds_discover_queryables(self, mock_requests_constraints):
queryables = self.api_plugin.discover_queryables(
productType="CAMS_EU_AIR_QUALITY_RE"
)
self.assertEqual(8, len(queryables))
self.assertEqual(12, len(queryables))
self.assertIn("variable", queryables)
# with additional param
queryables = self.api_plugin.discover_queryables(
productType="CAMS_EU_AIR_QUALITY_RE",
variable="a",
)
self.assertEqual(8, len(queryables))
self.assertEqual(12, len(queryables))
queryable = queryables.get("variable")
self.assertEqual("a", queryable.__metadata__[0].get_default())
queryable = queryables.get("month")
Expand Down
2 changes: 1 addition & 1 deletion tests/units/test_http_server.py
Expand Up @@ -1205,7 +1205,7 @@ def test_product_type_queryables_from_constraints(self, mock_requests_constraint
headers=USER_AGENT,
timeout=5,
)
self.assertEqual(9, len(res["properties"]))
self.assertEqual(10, len(res["properties"]))
self.assertIn("year", res["properties"])
self.assertIn("ids", res["properties"])
self.assertIn("geometry", res["properties"])
Expand Down

0 comments on commit 56c8be8

Please sign in to comment.