Skip to content

Commit

Permalink
make tests work
Browse files Browse the repository at this point in the history
  • Loading branch information
JoranAngevaare committed Jun 16, 2023
1 parent add7e25 commit a70963e
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 36 deletions.
31 changes: 29 additions & 2 deletions optim_esm_tools/analyze/cmip_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .globals import _CMIP_HANDLER_VERSION, _FOLDER_FMT
from .xarray_tools import _native_date_fmt
from optim_esm_tools.plotting.map_maker import MapMaker, make_title
from optim_esm_tools.analyze import tipping_criteria
import logging


Expand All @@ -26,7 +27,7 @@ def __init__(self, path=None, dataset=None) -> None:
self.log.warning(
f'Best is to start {self.__class__.__name__} from a synda path'
)
self.dataset = transfor_ds(dataset)
self.dataset = transform_ds(dataset)
else:
self.dataset = read_ds(path)

Expand All @@ -39,6 +40,8 @@ def log(self):

def transform_ds(
ds: xr.Dataset,
calculate_conditions: ty.Tuple[tipping_criteria._Condition] = None,
condition_kwargs: ty.Mapping = None,
variable_of_interest: ty.Tuple[str] = ('tas',),
max_time: ty.Optional[ty.Tuple[int, int, int]] = (2100, 1, 1),
min_time: ty.Optional[ty.Tuple[int, int, int]] = None,
Expand All @@ -59,7 +62,16 @@ def transform_ds(
_detrend_type (str, optional): Type of detrending applied. Defaults to 'linear'.
_ma_window (int, optional): Moving average window (assumed to be years). Defaults to 10.
"""
return _calculate_variables(
if calculate_conditions is None:
calculate_conditions = (
tipping_criteria.StartEndDifference,
tipping_criteria.StdDetrended,
tipping_criteria.MaxJump,
tipping_criteria.MaxDerivitive,
)
if condition_kwargs is None:
condition_kwargs = dict()
ds = _calculate_variables(
oet.synda_files.format_synda.recast(ds),
min_time,
max_time,
Expand All @@ -69,6 +81,18 @@ def transform_ds(
_detrend_type,
_time_var,
)
for cls in calculate_conditions:
condition = cls(**condition_kwargs)
condition_array = condition.calculate(ds)
condition_array = condition_array.assign_attrs(
dict(
short_description=cls.short_description,
long_description=condition.long_description,
name=condition_array.name,
)
)
ds[condition.short_description] = condition_array
return ds


@oet.utils.timed()
Expand Down Expand Up @@ -97,6 +121,9 @@ def read_ds(
Returns:
xr.Dataset: An xarray dataset with the appropriate variables
"""
if kwargs:
oet.config.get_logger().error(f'Not really advised yet to call with {kwargs}')
_cache = False
post_processed_file = _name_cache_file(
base,
variable_of_interest,
Expand Down
2 changes: 1 addition & 1 deletion optim_esm_tools/optim_esm_conf.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
seconds_to_year = 31557600

[versions]
cmip_handler = 0.1.13
cmip_handler = 0.2.0-prealpha

[display]
progress_bar = True
Expand Down
5 changes: 3 additions & 2 deletions optim_esm_tools/plotting/map_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def set_conditions(self, **condition_kwargs):
]

self.conditions = {
label: (condition.long_description, condition.calculate)
label: (condition.short_description, condition.calculate)
for label, condition in zip(self.labels, conditions)
}
print(self.conditions)
Expand Down Expand Up @@ -168,7 +168,8 @@ def plot_i(self, label, ax=None, coastlines=True, **kw):
def __getattr__(self, item):
print(item)
if item in self.conditions:
_, function = self.conditions[item]
key, function = self.conditions[item]
return self.data_set[key]
key = f'_{item}'
if self._cache:
if not isinstance(self._cache, dict):
Expand Down
34 changes: 3 additions & 31 deletions test/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,38 +49,10 @@ def test_map_maker_time_series(self):
plt.clf()

def test_apply_relative_units(self, unit='relative'):
data_set = oet.analyze.cmip_handler.read_ds(os.path.split(self.ayear_file)[0])
mm = oet.analyze.cmip_handler.MapMaker(data_set=data_set)
from immutabledict import immutabledict
from functools import partial

mm.conditions = immutabledict(
{
'i ii iii iv v vi vii viii ix x'.split()[i]: props
for i, props in enumerate(
zip(
(str(i) for i in range(4)),
[
partial(
oet.analyze.tipping_criteria.running_mean_diff,
unit=unit,
),
partial(
oet.analyze.tipping_criteria.running_mean_std, unit=unit
),
partial(
oet.analyze.tipping_criteria.max_change_xyr, unit=unit
),
partial(
oet.analyze.tipping_criteria.max_derivative, unit=unit
),
],
)
)
}
data_set = oet.analyze.cmip_handler.read_ds(
os.path.split(self.ayear_file)[0], condition_kwargs=dict(unit=unit)
)
for i in mm.conditions.keys():
getattr(mm, i)
mm = oet.analyze.cmip_handler.MapMaker(data_set=data_set)

def test_apply_std_unit(self):
self.test_apply_relative_units(unit='std')
Expand Down

0 comments on commit a70963e

Please sign in to comment.