Skip to content

Commit

Permalink
Rebased and fixed compatibility with PR geopython#658:
Browse files Browse the repository at this point in the history
* Fixed EDR provider signature (added locale)
* Fixed EDR API routes and query function (and improved parameter-name handling)
* Fixed EDR tests
  • Loading branch information
GeoSander committed Mar 16, 2021
1 parent 089162e commit 2dad88e
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 117 deletions.
108 changes: 48 additions & 60 deletions pygeoapi/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,8 @@ def __init__(self, request, supported_locales):
self._path_info = request.headers.environ['PATH_INFO'].strip('/')

# Extract locale from params or headers
# _l_param stores a boolean -> True if language was found in query str
self._raw_locale, self._locale, self._l_param = \
self._get_locale(request.headers, supported_locales)
self._raw_locale, self._locale = self._get_locale(request.headers,
supported_locales)

# Determine format
self._format = self._get_format(request.headers)
Expand All @@ -174,18 +173,18 @@ def _get_params(request):
return {}

def _get_locale(self, headers, supported_locales):
""" Detects locale from "l=<language>" or Accept-Language header.
Returns a tuple of (raw, locale, True) if found in the query params.
Returns a tuple of (raw, locale, False) if found in headers.
Returns a tuple of (raw, default locale, False) if not found.
""" Detects locale from "l=<language>" param or Accept-Language header.
Returns a tuple of (raw, locale) if found in params or headers.
Returns a tuple of (raw default, default locale) if not found.
:param headers: A dict with Request headers
:param supported_locales: List or set of supported Locale instances
:returns: A tuple of (Locale, bool)
:returns: A tuple of (str, Locale)
"""
raw = None
try:
default_locale = l10n.str2locale(supported_locales[0])
default_str = l10n.locale2str(default_locale)
except (TypeError, IndexError, l10n.LocaleError) as err:
# This should normally not happen, since the API class already
# loads the supported languages from the config, which raises
Expand All @@ -203,11 +202,11 @@ def _get_locale(self, headers, supported_locales):
raw = loc_str
# Check of locale string is a good match for the UI
loc = l10n.best_match(loc_str, supported_locales)
precedence = func is l10n.locale_from_params
if loc != default_locale or precedence:
return raw, loc, precedence
is_override = func is l10n.locale_from_params
if loc != default_locale or is_override:
return raw, loc

return raw or supported_locales[0], default_locale, False
return raw or default_str, default_locale

def _get_format(self, headers) -> Union[str, None]:
"""
Expand Down Expand Up @@ -782,7 +781,7 @@ def describe_collections(self, request: Union[APIRequest, Any], dataset=None):
try:
p = load_plugin('provider', get_provider_by_type(
self.config['resources'][dataset]['providers'],
'edr'))
'edr'), request.raw_locale)
parameters = p.get_fields()
if parameters:
collection['parameters'] = {}
Expand Down Expand Up @@ -2329,116 +2328,102 @@ def delete_process_job(self, process_id, job_id):
LOGGER.info(response)
return {}, http_status, response

def get_collection_edr_query(self, headers, args, dataset, instance,
query_type):
@pre_process
def get_collection_edr_query(self, request: Union[APIRequest, Any],
dataset, instance, query_type):
"""
Queries collection EDR
:param headers: dict of HTTP headers
:param args: dict of HTTP request parameters
:param request: APIRequest instance with query params
:param dataset: dataset name
:param dataset: instance name
:param instance: instance name
:param query_type: EDR query type
:returns: tuple of headers, status code, content
"""

headers_ = HEADERS.copy()

query_args = {}
formats = FORMATS
formats.extend(f.lower() for f in PLUGINS['formatter'].keys())
if not request.is_valid(PLUGINS['formatter'].keys()):
return self.get_format_exception(request)
headers = request.get_response_headers()

collections = filter_dict_by_key_value(self.config['resources'],
'type', 'collection')

format_ = check_format(args, headers)

if dataset not in collections.keys():
msg = 'Invalid collection'
return self.get_exception(
400, headers_, format_, 'InvalidParameterValue', msg)

if format_ is not None and format_ not in formats:
msg = 'Invalid format'
return self.get_exception(
400, headers_, format_, 'InvalidParameterValue', msg)
400, headers, request.format, 'InvalidParameterValue', msg)

LOGGER.debug('Processing query parameters')

LOGGER.debug('Processing datetime parameter')
datetime_ = args.get('datetime')
datetime_ = request.params.get('datetime')
try:
datetime_ = validate_datetime(collections[dataset]['extents'],
datetime_)
except ValueError as err:
msg = str(err)
return self.get_exception(
400, headers_, format_, 'InvalidParameterValue', msg)
400, headers, request.format, 'InvalidParameterValue', msg)

LOGGER.debug('Processing parameter-name parameter')
parameternames = args.get('parameter-name', [])
if parameternames:
parameternames = request.params.get('parameter-name') or []
if isinstance(parameternames, str):
parameternames = parameternames.split(',')

LOGGER.debug('Processing coords parameter')
wkt = args.get('coords', None)
wkt = request.params.get('coords', None)

if wkt is None:
if not wkt:
msg = 'missing coords parameter'
return self.get_exception(
400, headers_, format_, 'InvalidParameterValue', msg)
400, headers, request.format, 'InvalidParameterValue', msg)

try:
wkt = shapely_loads(wkt)
except WKTReadingError:
msg = 'invalid coords parameter'
return self.get_exception(
400, headers_, format_, 'InvalidParameterValue', msg)
400, headers, request.format, 'InvalidParameterValue', msg)

LOGGER.debug('Processing z parameter')
z = args.get('z')
z = request.params.get('z')

LOGGER.debug('Loading provider')
try:
p = load_plugin('provider', get_provider_by_type(
collections[dataset]['providers'], 'edr'))
collections[dataset]['providers'], 'edr'), request.raw_locale)
except ProviderTypeError:
msg = 'invalid provider type'
return self.get_exception(
500, headers_, format_, 'NoApplicableCode', msg)
500, headers, request.format, 'NoApplicableCode', msg)
except ProviderConnectionError:
msg = 'connection error (check logs)'
return self.get_exception(
500, headers_, format_, 'NoApplicableCode', msg)
500, headers, request.format, 'NoApplicableCode', msg)
except ProviderQueryError:
msg = 'query error (check logs)'
return self.get_exception(
500, headers_, format_, 'NoApplicableCode', msg)
500, headers, request.format, 'NoApplicableCode', msg)

if instance is not None and not p.get_instance(instance):
msg = 'Invalid instance identifier'
return self.get_exception(
400, headers_, format_, 'InvalidParameterValue', msg)
400, headers, request.format, 'InvalidParameterValue', msg)

if query_type not in p.get_query_types():
msg = 'Unsupported query type'
return self.get_exception(
400, headers_, format_, 'InvalidParameterValue', msg)

parametername_matches = list(
filter(
lambda p: p['id'] in parameternames, p.get_fields()['field']
)
)
400, headers, request.format, 'InvalidParameterValue', msg)

if len(parametername_matches) < len(parameternames):
if parameternames and not any((fld['id'] in parameternames)
for fld in p.get_fields()['field']):
msg = 'Invalid parameter-name'
return self.get_exception(
400, headers_, format_, 'InvalidParameterValue', msg)
400, headers, request.format, 'InvalidParameterValue', msg)

query_args = dict(
query_type=query_type,
instance=instance,
format_=format_,
format_=request.format,
datetime_=datetime_,
select_properties=parameternames,
wkt=wkt,
Expand All @@ -2450,20 +2435,23 @@ def get_collection_edr_query(self, headers, args, dataset, instance,
except ProviderNoDataError:
msg = 'No data found'
return self.get_exception(
204, headers_, format_, 'NoMatch', msg)
204, headers, request.format, 'NoMatch', msg)
except ProviderQueryError:
msg = 'query error (check logs)'
return self.get_exception(
500, headers_, format_, 'NoApplicableCode', msg)
500, headers, request.format, 'NoApplicableCode', msg)

if format_ == 'html': # render
headers_['Content-Type'] = 'text/html'
if p.locale:
# If provider supports locales, override/set response locale
headers['Content-Language'] = p.locale

if request.format == 'html': # render
content = render_j2_template(
self.config, 'collections/edr/query.html', data)
else:
content = to_json(data, self.pretty_print)

return headers_, 200, content
return headers, 200, content

@pre_process
@jsonldify
Expand Down
13 changes: 2 additions & 11 deletions pygeoapi/flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,18 +359,9 @@ def get_collection_edr_query(collection_id, instance_id=None):
:returns: HTTP response
"""

query_type = request.path.split('/')[-1]

headers, status_code, content = api_.get_collection_edr_query(
request.headers, request.args, collection_id, instance_id, query_type)

response = make_response(content, status_code)

if headers:
response.headers = headers

return response
return get_response(api_.get_collection_edr_query(request, collection_id,
instance_id, query_type))


@BLUEPRINT.route('/stac')
Expand Down
4 changes: 2 additions & 2 deletions pygeoapi/provider/xarray_edr.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
class XarrayEDRProvider(XarrayProvider):
"""EDR Provider"""

def __init__(self, provider_def):
def __init__(self, provider_def, requested_locale=None):
"""
Initialize object
Expand All @@ -47,7 +47,7 @@ def __init__(self, provider_def):
:returns: pygeoapi.provider.rasterio_.RasterioProvider
"""

XarrayProvider.__init__(self, provider_def)
XarrayProvider.__init__(self, provider_def, requested_locale)
self.instances = []

def get_fields(self):
Expand Down
15 changes: 3 additions & 12 deletions pygeoapi/starlette_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,18 +398,9 @@ async def get_collection_edr_query(request: Request, collection_id=None, instanc
if 'instance_id' in request.path_params:
instance_id = request.path_params['instance_id']

query_type = request.path.split('/')[-1]

headers, status_code, content = api_.get_collection_edr_query(
request.headers, request.query_params, collection_id, instance_id,
query_type)

response = Response(content=content, status_code=status_code)

if headers:
response.headers.update(headers)

return response
query_type = request.path.split('/')[-1] # noqa
return get_response(api_.get_collection_edr_query(request, collection_id,
instance_id, query_type))


@app.route('/stac')
Expand Down
Loading

0 comments on commit 2dad88e

Please sign in to comment.