Skip to content

Commit

Permalink
Adding support for the "Custom_Model" by providing a .shc file.
Browse files Browse the repository at this point in the history
  • Loading branch information
smithara committed Dec 9, 2018
1 parent c7cc41e commit 658fd69
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 6 deletions.
40 changes: 34 additions & 6 deletions viresclient/_client_swarm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import json
from collections import OrderedDict
import os

from ._wps.environment import JINJA2_ENVIRONMENT
from ._wps import time_util
Expand Down Expand Up @@ -106,7 +107,8 @@ def __init__(self,
variables=None,
filters=None,
sampling_step=None,
response_type=None):
response_type=None,
custom_shc=None):
# Set up default values
# Obligatory - these must be replaced before the request is made
self.collection_ids = None if collection_ids is None else collection_ids
Expand All @@ -119,6 +121,7 @@ def __init__(self,
self.variables = [] if variables is None else variables
self.filters = None if filters is None else filters
self.sampling_step = None if sampling_step is None else sampling_step
self.custom_shc = None if custom_shc is None else custom_shc

self.names = ('collection_ids',
'model_ids',
Expand All @@ -127,7 +130,8 @@ def __init__(self,
'variables',
'filters',
'sampling_step',
'response_type'
'response_type',
'custom_shc'
)

@property
Expand Down Expand Up @@ -226,6 +230,17 @@ def response_type(self, response_type):
else:
raise TypeError

@property
def custom_shc(self):
return self._custom_shc

@custom_shc.setter
def custom_shc(self, custom_shc):
if isinstance(custom_shc, str) or custom_shc is None:
self._custom_shc = custom_shc
else:
raise TypeError


class SwarmRequest(ClientRequest):
"""Handles the requests to and downloads from the server.
Expand Down Expand Up @@ -300,7 +315,8 @@ def _set_available_data():
MMA_SHA_2C-Primary, MMA_SHA_2C-Secondary,
MMA_SHA_2F-Primary, MMA_SHA_2F-Secondary,
MIO_SHA_2C-Primary, MIO_SHA_2C-Secondary,
MIO_SHA_2D-Primary, MIO_SHA_2D-Secondary
MIO_SHA_2D-Primary, MIO_SHA_2D-Secondary,
Custom_Model
""".replace("\n", "").replace(" ", "").split(",")

auxiliaries = """
Expand Down Expand Up @@ -450,8 +466,8 @@ def set_collection(self, collection):
self._collection = collection
self._request_inputs.set_collection(collection)

def set_products(self, measurements=None, models=None, auxiliaries=None,
residuals=False, sampling_step=None
def set_products(self, measurements=None, models=None, custom_model=None,
auxiliaries=None, residuals=False, sampling_step=None
):
"""Set the combination of products to retrieve.
Expand All @@ -461,6 +477,7 @@ def set_products(self, measurements=None, models=None, auxiliaries=None,
Args:
measurements (list(str)): from .available_measurements(collection_key)
models (list(str)): from .available_models()
custom_model (str): path to a custom model in .shc format
auxiliaries (list(str)): from .available_auxiliaries()
residuals (bool): True if only returning measurement-model residual
sampling_step (str): ISO_8601 duration, e.g. 10 seconds: PT10S, 1 minute: PT1M
Expand Down Expand Up @@ -497,6 +514,16 @@ def set_products(self, measurements=None, models=None, auxiliaries=None,
"'{}' not available. Check available with "
"SwarmRequest.available_auxiliaries()".format(variable)
)
# Load the custom .shc file
if custom_model:
if os.path.exists(custom_model):
with open(custom_model) as custom_shc_file:
custom_shc = custom_shc_file.read()
models.append("Custom_Model")
else:
raise OSError("Custom model .shc file not found")
else:
custom_shc = None
# Set up the variables that actually get passed to the WPS request
variables = []

Expand All @@ -513,7 +540,7 @@ def set_products(self, measurements=None, models=None, auxiliaries=None,
"%s_%s" % (variable, model_name)
for model_name in models
)
else: # not a model variable
else: # not a model variable
variables.append(variable)

variables.extend(auxiliaries)
Expand All @@ -522,6 +549,7 @@ def set_products(self, measurements=None, models=None, auxiliaries=None,
self._request_inputs.model_ids = models
self._request_inputs.variables = variables
self._request_inputs.sampling_step = sampling_step
self._request_inputs.custom_shc = custom_shc

def set_range_filter(self, parameter=None, minimum=None, maximum=None):
"""Set a filter to apply.
Expand Down
8 changes: 8 additions & 0 deletions viresclient/_wps/templates/vires_fetch_filtered_data.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
</wps:Data>
</wps:Input>
{% endif -%}
{% if custom_shc -%}
<wps:Input>
<ows:Identifier>shc</ows:Identifier>
<wps:Data>
<wps:ComplexData>{{ custom_shc }}</wps:ComplexData>
</wps:Data>
</wps:Input>
{% endif -%}
<wps:Input>
<ows:Identifier>begin_time</ows:Identifier>
<wps:Data>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
</wps:Data>
</wps:Input>
{% endif -%}
{% if custom_shc -%}
<wps:Input>
<ows:Identifier>shc</ows:Identifier>
<wps:Data>
<wps:ComplexData>{{ custom_shc }}</wps:ComplexData>
</wps:Data>
</wps:Input>
{% endif -%}
<wps:Input>
<ows:Identifier>begin_time</ows:Identifier>
<wps:Data>
Expand Down

0 comments on commit 658fd69

Please sign in to comment.