Skip to content

Commit

Permalink
Merge branch 'master' into add_coveralls
Browse files Browse the repository at this point in the history
  • Loading branch information
JoranAngevaare committed Nov 24, 2020
2 parents 771f2e3 + a214b21 commit 0dcd6b6
Showing 1 changed file with 42 additions and 2 deletions.
44 changes: 42 additions & 2 deletions straxen/corrections_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
except (RuntimeError, FileNotFoundError):
# We might be on a travis job
pass
import straxen
import os
export, __all__ = strax.exporter()


Expand Down Expand Up @@ -138,17 +140,34 @@ def get_elife(self, run_id, model_type, global_version):
else:
raise ValueError(f'model type {model_type} not implemented for electron lifetime')

def get_pmt_gains(self, run_id, model_type, global_version, gain_dtype = np.float32):
def get_pmt_gains(self, run_id, model_type, global_version,
cacheable_versions=('ONLINE',),
gain_dtype=np.float32):
"""
Smart logic to return pmt gains to PE values.
:param run_id: run id from runDB
:param model_type: to_pe_model (gain model)
:param global_version: global version
:param cacheable_versions: versions that are allowed to be
cached in ./resource_cache
:param gain_dtype: dtype of the gains to be returned as array
:return: array of pmt gains to PE values
"""
to_pe = None
cache_name = None

if model_type == 'to_pe_model':
to_pe = self._get_correction(run_id, 'pmt', global_version)
if global_version in cacheable_versions:
# Try to load from cache, if it does not exist it will be created below
cache_name = cacheable_naming(run_id, model_type, global_version)
try:
to_pe = straxen.get_resource(cache_name, fmt='npy')
except (ValueError, FileNotFoundError):
pass

if to_pe is None:
to_pe = self._get_correction(run_id, 'pmt', global_version)

# be cautious with very early runs, check that not all are None
if np.isnan(to_pe).all():
raise ValueError(
Expand All @@ -171,6 +190,14 @@ def get_pmt_gains(self, run_id, model_type, global_version, gain_dtype = np.floa
raise GainsNotFoundError(
f'Gains returned by CMT are None for PMT_i = {pmts_affected}. '
f'Cannot proceed with processing. Report to CMT-maintainers.')

if (cache_name is not None
and global_version in cacheable_versions
and not os.path.exists(cache_name)):
# This is an array we can save since it's in the cacheable
# versions but it has not been saved yet. Next time we need
# it, we can get it from our cache.
np.save(cache_name, to_pe, allow_pickle=False)
return to_pe

def get_lce(self, run_id, s, position, global_version='v1'):
Expand Down Expand Up @@ -213,6 +240,19 @@ def get_start_time(self, run_id):
return time.replace(tzinfo=pytz.utc)


def cacheable_naming(*args, format='.npy', base='./resource_cache/'):
"""Convert args to consistent naming convention for array to be cached"""
if not os.path.exists(base):
try:
os.mkdir(base)
except (FileExistsError, PermissionError):
pass
for arg in args:
if not type(arg) == str:
raise TypeError(f'One or more args of {args} are not strings')
return base + '_'.join(args) + format


class GainsNotFoundError(Exception):
"""Fatal error if a None value is returned by the corrections"""
pass

0 comments on commit 0dcd6b6

Please sign in to comment.