From a42f14de5a93164c2e75dbc9acaf720cb2eb0782 Mon Sep 17 00:00:00 2001
From: qmac
Date: Sun, 1 Apr 2018 22:21:38 -0500
Subject: [PATCH] consolidated some batch caching code

---
 pliers/transformers/base.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/pliers/transformers/base.py b/pliers/transformers/base.py
index aafe6e8d..aaddca9e 100644
--- a/pliers/transformers/base.py
+++ b/pliers/transformers/base.py
@@ -236,18 +236,13 @@ def _iterate(self, stims, validation='strict', *args, **kwargs):
         results = []
         for batch in progress_bar_wrapper(batches):
             use_cache = config.get_option('cache_transformers')
-            target_inds = []
+            target_inds = {}
             non_cached = []
-            transformed_keys = set()
             for stim in batch:
                 key = hash((hash(self), hash(stim)))
-                if use_cache and (key in _cache or key in transformed_keys):
-                    target_inds.append(-1)  # signals to query cache
-                else:
-                    target_inds.append(len(non_cached))
+                if not (use_cache and (key in _cache or key in target_inds)):
+                    target_inds[key] = len(non_cached)
                     non_cached.append(stim)
-                # Can't use _cache in case _transform fails
-                transformed_keys.add(key)

             if len(non_cached) > 0:
                 batch_results = self._transform(non_cached, *args, **kwargs)

             for i, stim in enumerate(batch):
                 key = hash((hash(self), hash(stim)))
-                if target_inds[i] == -1:
-                    results.append(_cache[key])
-                else:
-                    result = batch_results[target_inds[i]]
+                if key in target_inds:
+                    result = batch_results[target_inds[key]]
                     result = _log_transformation(stim, result, self)
                     self._propagate_context(stim, result)
                     if use_cache:
@@ -267,6 +260,8 @@ def _iterate(self, stims, validation='strict', *args, **kwargs):
                         result = list(result)
                     _cache[key] = result
                 results.append(result)
+                else:
+                    results.append(_cache[key])
         return results

     def _transform(self, stim, *args, **kwargs):