From 199a87ccaa17719a3923bb67edebc5ee01cd50b2 Mon Sep 17 00:00:00 2001 From: Joran Angevaare Date: Tue, 20 Jun 2023 09:34:42 +0200 Subject: [PATCH 1/8] add region finder --- optim_esm_tools/_test_utils.py | 53 ++- optim_esm_tools/analyze/__init__.py | 1 + optim_esm_tools/analyze/cmip_handler.py | 7 +- optim_esm_tools/analyze/region_finding.py | 424 ++++++++++++++++++++ optim_esm_tools/analyze/tipping_criteria.py | 13 +- optim_esm_tools/cmip_files/find_matches.py | 2 +- optim_esm_tools/optim_esm_conf.ini | 4 +- optim_esm_tools/plotting/map_maker.py | 23 +- test/test_pangeo_download.py | 6 +- test/test_workflow_pangeo.py | 88 ++++ 10 files changed, 577 insertions(+), 44 deletions(-) create mode 100644 optim_esm_tools/analyze/region_finding.py create mode 100644 test/test_workflow_pangeo.py diff --git a/optim_esm_tools/_test_utils.py b/optim_esm_tools/_test_utils.py index 5fac5dc3..318f52d7 100644 --- a/optim_esm_tools/_test_utils.py +++ b/optim_esm_tools/_test_utils.py @@ -4,35 +4,60 @@ EXMPLE_DATA_SET = 'CMIP6/ScenarioMIP/CCCma/CanESM5/ssp585/r3i1p2f1/Amon/tas/gn/v20190429/tas_Amon_CanESM5_ssp585_r3i1p2f1_gn_201501-210012.nc' -def get_file_from_pangeo(): +def get_file_from_pangeo(experiment_id='ssp585', refresh=True): + dest_folder = os.path.split( + get_example_data_loc().replace('ssp585', experiment_id) + )[0] + write_to = os.path.join(dest_folder, 'test.nc') + if os.path.exists(write_to) and not refresh: + print(f'already file at {write_to}') + return write_to + from xmip.utils import google_cmip_col - # import cftime col = google_cmip_col() - search = col.search( + query = dict( source_id='CanESM5', variable_id='tas', table_id='Amon', - experiment_id='ssp585', - member_id=['r3i1p2f1'], + experiment_id=experiment_id, ) + if experiment_id in ['historical', 'ssp585']: + query.update(dict(member_id=['r3i1p2f1'])) + else: + query.update(dict(member_id=['r1i1p1f1'])) + search = col.search(**query) ddict = search.to_dataset_dict( xarray_open_kwargs={'use_cftime': True}, ) data = list(ddict.values())[0] - # data = data.groupby('time.year').mean('time') - # data = data.rename(year='time') - data = data.mean(set(data.dims) - {'x', 'y', 'time'}) - # data['time'] = [cftime.DatetimeNoLeap(y,1,1) for y in data['time']] - write_to = get_example_data_loc() - dest_folder = os.path.split(write_to)[0] + data = data.mean(set(data.dims) - {'x', 'y', 'lat', 'lon', 'time'}) + os.makedirs(dest_folder, exist_ok=True) - if os.path.exists(write_to): - print(f'already file at {write_to}') - write_to = os.path.join(dest_folder, 'test.nc') data.to_netcdf(write_to) + return write_to + + +def year_means(path, refresh=True): + new_dir = os.path.split(path.replace('Amon', 'AYear'))[0] + new_dest = os.path.join(new_dir, 'test_merged.nc') + if os.path.exists(new_dest) and not refresh: + print(f'File at {new_dest} already exists') + return new_dest + import cftime + import optim_esm_tools as oet + + data = oet.cmip_files.io.load_glob(path) + + data = data.groupby('time.year').mean('time') + data = data.rename(year='time') + data['time'] = [cftime.DatetimeNoLeap(y, 1, 1) for y in data['time']] + + os.makedirs(new_dir, exist_ok=True) + data.to_netcdf(new_dest) + return new_dest def get_synda_loc(): diff --git a/optim_esm_tools/analyze/__init__.py b/optim_esm_tools/analyze/__init__.py index 441b5c5f..aad53741 100644 --- a/optim_esm_tools/analyze/__init__.py +++ b/optim_esm_tools/analyze/__init__.py @@ -3,3 +3,4 @@ from . import xarray_tools from . import clustering +from . import region_finding diff --git a/optim_esm_tools/analyze/cmip_handler.py b/optim_esm_tools/analyze/cmip_handler.py index b340f136..cd0f23db 100644 --- a/optim_esm_tools/analyze/cmip_handler.py +++ b/optim_esm_tools/analyze/cmip_handler.py @@ -103,6 +103,7 @@ def read_ds( min_time: ty.Optional[ty.Tuple[int, int, int]] = None, _ma_window: int = 10, _cache: bool = True, + _file_name: str = 'merged.nc', **kwargs, ) -> xr.Dataset: """Read a dataset from a folder called "base". @@ -136,7 +137,7 @@ def read_ds( if os.path.exists(post_processed_file) and _cache: return oet.synda_files.format_synda.load_glob(post_processed_file) - data_path = os.path.join(base, 'merged.nc') + data_path = os.path.join(base, _file_name) if not os.path.exists(data_path): warn(f'No dataset at {data_path}') return None @@ -178,8 +179,8 @@ def _name_cache_file( path = os.path.join( base, f'{variable_of_interest}' - f'_{min_time if min_time else ""}' - f'_{max_time if max_time else ""}' + f'_s{min_time if min_time else ""}' + f'_e{max_time if max_time else ""}' f'_ma{_ma_window}' f'_optimesm_v{version}.nc', ) diff --git a/optim_esm_tools/analyze/region_finding.py b/optim_esm_tools/analyze/region_finding.py new file mode 100644 index 00000000..3d58176e --- /dev/null +++ b/optim_esm_tools/analyze/region_finding.py @@ -0,0 +1,424 @@ +import os +import optim_esm_tools as oet +from optim_esm_tools.plotting.map_maker import MapMaker +import numpy as np +import matplotlib.pyplot as plt +import typing as ty +from optim_esm_tools.analyze import tipping_criteria +import logging +from optim_esm_tools.analyze.cmip_handler import transform_ds, read_ds +import typing as ty +import matplotlib.pyplot as plt +from functools import wraps + +import inspect +from optim_esm_tools.analyze.clustering import build_cluster_mask +from optim_esm_tools.plotting.plot import setup_map +from immutabledict import immutabledict + +def _show(show): + if show: + plt.show() + else: + plt.clf() + plt.close() + + +def mask_xr_ds(ds_masked, da_mask): + for k, v in ds_masked.data_vars.items(): + if all(xy in list(v.dims) for xy in 'xy'): + ds_masked[k] = ds_masked[k].where(da_mask, drop=False) + return ds_masked + + +def plt_show(*a): + def somedec_outer(fn): + @wraps(fn) + def plt_func(*args, **kwargs): + res = fn(*args, **kwargs) + self = args[0] + if getattr(self, 'show', False): + plt.show() + else: + plt.clf() + plt.close() + return res + + return plt_func + + if a and isinstance(a[0], ty.Callable): + # Decorator that isn't closed + return somedec_outer(a[0]) + return somedec_outer + + +def apply_options(*a): + def somedec_outer(fn): + @wraps(fn) + def timed_func(*args, **kwargs): + self = args[0] + takes = inspect.signature(fn).parameters + kwargs.update({k: v for k, v in self.extra_opt.items() + if k in takes + }) + res = fn(*args, **kwargs) + return res + + return timed_func + + if a and isinstance(a[0], ty.Callable): + # Decorator that isn't closed + return somedec_outer(a[0]) + return somedec_outer + + +class ResultDataSet: + _logger: logging.Logger = None + labels: tuple = tuple('ii iii'.split()) + show: bool = True + + show_basic = True + criteria = (tipping_criteria.StdDetrended, tipping_criteria.MaxJump) + extra_opt = None + + def __init__(self, + variable='tas', + path=None, + dataset=None, + transform=True, + save_kw=None, + extra_opt=None, + read_ds_kw=None) -> None: + if path is None: + if transform: + self.log.warning( + f'Best is to start {self.__class__.__name__} from a synda path' + ) + self.dataset = transform_ds(dataset) + else: + self.dataset = dataset + else: + read_ds_kw = dict() if read_ds_kw is None else read_ds_kw + self.dataset = read_ds(path, **read_ds_kw) + if save_kw is None: + save_kw = dict(save_in='./', + file_types=('png', 'pdf',), + skip=False, + sub_dir=None) + if extra_opt is None: + extra_opt = dict() + self.extra_opt = extra_opt + self.save_kw = save_kw + self.variable = variable + + @property + def log(self): + if self._logger is None: + self._logger = oet.config.get_logger() + return self._logger + + @apply_options + def workflow(self, show_basic=True): + if show_basic: + self.plot_basic_map() + masks = self.get_masks() + self.plot_masks(masks) + self.plot_mask_time_series(masks) + + @plt_show + def plot_basic_map(self): + self._plot_basic_map() + self.save(f'{self.title_label}_global_map') + + def _plot_basic_map(self): + mm = MapMaker(self.dataset) + return mm.plot_all(2) + + def save(self, name): + oet.utils.save_fig(name, **self.save_kw) + + @property + def title(self): + return MapMaker(self.dataset).title + + @property + def title_label(self): + return MapMaker(self.dataset).title.replace(' ', '_') + + +class MaxRegion(ResultDataSet): + def get_masks(self) -> dict: + """Get mask for max of ii and iii and a box arround that""" + labels = [crit.short_description for crit in self.criteria] + masks = {label: self.dataset[label].values == self.dataset[label].values.max() + for label in labels + } + return masks + + @plt_show + def plot_masks(self, masks, ax=None, legend=True): + res = self._plot_masks(masks=masks, ax=ax, legend=legend,) + self.save(f'{self.title_label}_map_maxes_{"-".join(self.labels)}') + + @apply_options + def _plot_masks(self, masks, ax=None, legend=True): + points = {} + for key, mask_2d in masks.items(): + points[key] = self._mask_to_coord(mask_2d) + if ax is None: + oet.plotting.plot.setup_map() + ax = plt.gca() + for i, (label, xy) in enumerate(zip( + self.labels, + points.values()) + ): + ax.scatter(*xy, marker='oxv^'[i], label=f'Maximum {label}') + if legend: + ax.legend(**oet.utils.legend_kw()) + plt.suptitle(self.title, y=0.95) + plt.ylim(-90, 90) + plt.xlim(-180, 180) + + def _mask_to_coord(self, mask_2d): + arg_mask = np.argwhere(mask_2d)[0] + x = self.dataset.x[arg_mask[1]] + y = self.dataset.y[arg_mask[0]] + return x, y + + def _plot_basic_map(self): + mm = MapMaker(self.dataset) + axes = mm.plot_all(2) + masks = self.get_masks() + for ax in axes: + self._plot_masks(masks, ax=ax, legend=False) + plt.suptitle(self.title, y=0.95) + + @plt_show + def plot_mask_time_series(self, masks, time_series_joined=True): + res = self._plot_mask_time_series( + masks, time_series_joined=time_series_joined) + if time_series_joined: + self.save( + f'{self.title_label}_time_series_maxes_{"-".join(self.labels)}') + return res + + @apply_options + def _plot_mask_time_series(self, masks, time_series_joined=True, only_rm=False, axes=None): + legend_kw = oet.utils.legend_kw( + loc='upper left', bbox_to_anchor=None, mode=None, ncol=4) + for label, mask_2d in zip(self.labels, masks.values()): + x, y = self._mask_to_coord(mask_2d) + plot_labels = {f'{self.variable}': f'{label} at {x:.1f}:{y:.1f}', + f'{self.variable}_detrend': f'{label} at {x:.1f}:{y:.1f}', + f'{self.variable}_detrend_run_mean_10': f'$RM_{{10}}$ {label} at {x:.1f}:{y:.1f}', + f'{self.variable}_run_mean_10': f'$RM_{{10}}$ {label} at {x:.1f}:{y:.1f}'} + argwhere = np.argwhere(mask_2d)[0] + ds_sel = self.dataset.isel(x=argwhere[1], y=argwhere[0]) + mm_sel = MapMaker(ds_sel) + axes = mm_sel.time_series( + other_dim=(), + interval=False, + labels=plot_labels, + axes=axes, + only_rm=only_rm) + if time_series_joined is False: + axes = None + plt.suptitle(f'Max. {label} {self.title}', y=0.95) + self.save(f'{self.title_label}_time_series_max_{label}') + _show(self.show) + if not time_series_joined: + return + + for ax in axes: + ax.legend(**legend_kw) + plt.suptitle(f'Max. {"-".join(self.labels)} {self.title}', y=0.95) + + +class Percentiles(ResultDataSet): + @apply_options + def get_masks(self, percentiles=99) -> dict: + """Get mask for max of ii and iii and a box arround that""" + labels = [crit.short_description for crit in self.criteria] + masks = [] + vmin_vmax = [] + + for lab in labels: + arr = self.dataset[lab].values.T + vmin_vmax.append([np.min(arr), np.max(arr)]) + thr = np.percentile(arr, percentiles) + masks.append(arr >= thr) + + all_mask = np.ones_like(masks[0]) + for m in masks: + all_mask &= m + + masks, clusters = build_cluster_mask( + all_mask, self.dataset['x'].values, self.dataset['y'].values) + return masks, clusters + + @plt_show + def plot_masks(self, masks_and_clusters, ax=None, legend=True): + if not len(masks_and_clusters[0]): + return + res = self._plot_masks( + masks_and_clusters=masks_and_clusters, ax=ax, legend=legend,) + self.save(f'{self.title_label}_map_clusters_{"-".join(self.labels)}') + + @apply_options + def _plot_masks(self, masks_and_clusters, + scatter_medians=True, + ax=None, legend=True, mask_cbar_kw=None,cluster_kw=None): + masks, clusters = masks_and_clusters + all_masks = np.zeros(masks[0].shape, np.int16) + + for m, c in zip(masks, clusters): + all_masks[m] = len(c) + if ax is None: + setup_map() + ax = plt.gca() + if mask_cbar_kw is None: + mask_cbar_kw = dict(extend='neither', label='Number of gridcells') + mask_cbar_kw.setdefault('orientation', 'horizontal') + ds_dummy = self.dataset.copy() + + all_masks = all_masks.astype(np.float16) + all_masks[all_masks == 0] = np.nan + ds_dummy['n_grid_cells'] = (('y', 'x'), all_masks) + + ds_dummy['n_grid_cells'].plot(cbar_kwargs=mask_cbar_kw, + + vmin=0, extend='neither') + plt.title('') + if scatter_medians: + if cluster_kw is None: + cluster_kw = dict() + for m_i, cluster in enumerate(clusters): + ax.scatter(*np.median(cluster, axis=0), label=f'cluster {m_i}', **cluster_kw) + if legend: + plt.legend(**oet.utils.legend_kw()) + plt.suptitle(f'Clusters {self.title}', y=0.95) + return ax + + def _plot_basic_map(self): + mm = MapMaker(self.dataset) + axes = mm.plot_all(2) + plt.suptitle(self.title, y=0.95) + return axes + + # Could add some masked selection on top +# masks, _ = self.get_masks() + +# all_masks = masks[0] +# for m in masks[1:]: +# all_masks &= m +# ds_masked = mask_xr_ds(self.dataset.copy(), all_masks) +# mm_sel = MapMaker(ds_masked) +# for label, ax in zip(mm.labels, axes): +# plt.sca(ax) +# mm_sel.plot_i(label, ax=ax, coastlines=False) + + @plt_show + def plot_mask_time_series(self, masks_and_clusters, time_series_joined=True): + res = self._plot_mask_time_series( + masks_and_clusters, time_series_joined=time_series_joined) + if time_series_joined: + self.save(f'{self.title_label}_all_clusters') + return res + + @apply_options + def _plot_mask_time_series(self, masks_and_clusters, time_series_joined=True, only_rm=None, axes=None): + if only_rm is None: + only_rm = True if (len(masks_and_clusters[0]) > 1 and time_series_joined) else False + masks, clusters = masks_and_clusters + legend_kw = oet.utils.legend_kw( + loc='upper left', bbox_to_anchor=None, mode=None, ncol=4) + for m_i, (mask, cluster) in enumerate(zip(masks, clusters)): + x, y = np.median(cluster, axis=0) + plot_labels = {f'{self.variable}': f'Cluster {m_i} near ~{x:.1f}:{y:.1f}', + f'{self.variable}_detrend': f'Cluster {m_i} near ~{x:.1f}:{y:.1f}', + f'{self.variable}_detrend_run_mean_10': f'Cluster {m_i} $RM_{{10}}$ near ~{x:.1f}:{y:.1f}', + f'{self.variable}_run_mean_10': f'Cluster {m_i} $RM_{{10}}$ near ~{x:.1f}:{y:.1f}'} + ds_sel = mask_xr_ds(self.dataset.copy(), mask) + mm_sel = MapMaker(ds_sel) + axes = mm_sel.time_series( + other_dim=('x', 'y'), + interval=True, + labels=plot_labels, + axes=axes, + only_rm=only_rm) + if time_series_joined == False: + axes = None + plt.suptitle(f'Cluster. {m_i} {self.title}', y=0.95) + self.save(f'{self.title_label}_cluster_{m_i}') + _show(self.show) + if not time_series_joined: + return + + if axes is not None: + for ax in axes: + ax.legend(**legend_kw) + plt.suptitle(f'Clusters {self.title}', y=0.95) + + +class PercentilesHistory(Percentiles): + @apply_options + def get_masks(self, percentiles=99, read_ds_kw=None) -> dict: + if read_ds_kw is None: + read_ds_kw = dict() + for k, v in dict(min_time=None, max_time=None).items(): + read_ds_kw.setdefault(k, v) + historical_path = self.find_historical()[0] + historical_ds = read_ds(historical_path, **read_ds_kw) + + labels = [crit.short_description for crit in self.criteria] + masks = [] + + for lab in labels: + arr = self.dataset[lab].values.T + arr_historical = historical_ds[lab].values.T + thr = np.percentile(arr_historical, percentiles) + masks.append(arr >= thr) + + all_mask = np.ones_like(masks[0]) + for m in masks: + all_mask &= m + + masks, clusters = build_cluster_mask( + all_mask, self.dataset['x'].values, self.dataset['y'].values) + return masks, clusters + + @apply_options + def find_historical(self, match_to='piControl'): + from optim_esm_tools.config import config + base = os.path.join(os.sep, + *self.dataset.attrs['path'].split(os.sep)[:-len(config['CMIP_files']['folder_fmt'].split())-1]) + + search = oet.cmip_files.find_matches.folder_to_dict( + self.dataset.attrs['path']) + search['activity_id'] = 'CMIP' + if search['experiment_id'] == match_to: + raise NotImplementedError() + search['experiment_id'] = match_to + + first_try = oet.cmip_files.find_matches.find_matches(base, **search) + if first_try: + return first_try + self.log.warning( + 'No results at first try, retying with any variant_label') + search.update(dict(variant_label='*', )) + + second_try = oet.cmip_files.find_matches.find_matches(base, **search) + if second_try: + return second_try + self.log.warning('No results at second try, retying with any version') + search.update(dict(version='*', )) + third_try = oet.cmip_files.find_matches.find_matches(base, **search) + if third_try: + return third_try + raise RuntimeError + + @property + def log(self): + if self._logger is None: + self._logger = oet.config.get_logger() + return self._logger diff --git a/optim_esm_tools/analyze/tipping_criteria.py b/optim_esm_tools/analyze/tipping_criteria.py index 5b2a9f13..7978445c 100644 --- a/optim_esm_tools/analyze/tipping_criteria.py +++ b/optim_esm_tools/analyze/tipping_criteria.py @@ -303,7 +303,7 @@ def max_derivative( def _remove_any_none_times(da, time_dim): data_var = da.copy() - time_null = da.isnull().all(dim=set(da.dims) - {time_dim}) + time_null = data_var.isnull().all(dim=set(data_var.dims) - {time_dim}) if np.all(time_null): # If we take a running mean of 10 (the default), and the array is shorter than # 10 years we will run into issues here because a the window is longer than the @@ -312,5 +312,14 @@ def _remove_any_none_times(da, time_dim): f'This array only has NaN values, perhaps array too short ({len(time_null)} < 10)?' ) if np.any(time_null): - data_var = data_var.load().where(~time_null, drop=True) + try: + # For some reason only alt_calc seems to work even if it should be equivalent to the data_var + # I think there is some fishy indexing going on in pandas <-> dask + # Maybe worth raising an issue? + alt_calc = xr.where(~time_null, da, np.nan).dropna('time') + data_var = data_var.load().where(~time_null, drop=True) + assert np.all((alt_calc == data_var).values) + except IndexError as e: + print(e) + return alt_calc return data_var diff --git a/optim_esm_tools/cmip_files/find_matches.py b/optim_esm_tools/cmip_files/find_matches.py index 1c849090..b8572a26 100644 --- a/optim_esm_tools/cmip_files/find_matches.py +++ b/optim_esm_tools/cmip_files/find_matches.py @@ -129,7 +129,7 @@ def _variant_label_id_and_version(full_path): raise ValueError( f'could not find run and version from {full_path} {run_variant_number} {grid_version}' ) - return run_variant_number, -grid_version + return -grid_version, run_variant_number def folder_to_dict(path): diff --git a/optim_esm_tools/optim_esm_conf.ini b/optim_esm_tools/optim_esm_conf.ini index 4392cf3b..f93ce234 100644 --- a/optim_esm_tools/optim_esm_conf.ini +++ b/optim_esm_tools/optim_esm_conf.ini @@ -5,7 +5,7 @@ seconds_to_year = 31557600 [versions] -cmip_handler = 0.2.0 +cmip_handler = 0.2.1 [display] progress_bar = True @@ -26,4 +26,4 @@ logging_level=WARNING # For the wrapper that monitors real time of functions (@timed) [time_tool] min_seconds=5 -reporter='print' +reporter=print diff --git a/optim_esm_tools/plotting/map_maker.py b/optim_esm_tools/plotting/map_maker.py index d8c5ec39..3ebf15c4 100644 --- a/optim_esm_tools/plotting/map_maker.py +++ b/optim_esm_tools/plotting/map_maker.py @@ -60,7 +60,6 @@ def __init__( self, data_set: xr.Dataset, normalizations: ty.Union[None, ty.Mapping, ty.Iterable] = None, - cache: bool = False, **conditions, ): self.data_set = data_set @@ -91,7 +90,6 @@ def _incorrect_format(): f'{self.conditions.keys()} to vmin, vmax, ' f'got {self.normalizations} (from {normalizations})' ) - self._cache = cache def plot(self, *a, **kw): print('Depricated use plot_all') @@ -167,21 +165,8 @@ def plot_i(self, label, ax=None, coastlines=True, **kw): def __getattr__(self, item): if item in self.conditions: - key, function = self.conditions[item] + key, _ = self.conditions[item] return self.data_set[key] - key = f'_{item}' - if self._cache: - if not isinstance(self._cache, dict): - self._cache = dict() - if key in self._cache: - data = self._cache.get(key) - return data - - data = function(self.data_set) - if self._cache or isinstance(self._cache, dict): - self._cache[key] = data.load() - return data - return self.__getattribute__(item) @staticmethod @@ -293,14 +278,12 @@ def _ddt_ts( # mean, std = self._mean_and_std(dy_dt, variable=None, other_dim=other_dim) # plot_kw['label'] = variable # self._ts_single(ds[time].values, mean, std, plot_kw, fill_kw) - label = f'd/dt {labels.get(variable, variable)}' + label = labels.get(variable, variable) dy_dt.plot(label=label, **plot_kw) dy_dt_rm = da_rm.dropna(time).differentiate(time) dy_dt_rm *= _SECONDS_TO_YEAR - label = ( - f"d/dt {labels.get(variable_rm, f'{variable} running mean {running_mean}')}" - ) + label = f"{labels.get(variable_rm, f'{variable} running mean {running_mean}')}" dy_dt_rm.plot(label=label, **plot_kw) # mean, std = self._mean_and_std(dy_dt_rm, variable=None, other_dim=other_dim) # plot_kw['label'] = variable diff --git a/test/test_pangeo_download.py b/test/test_pangeo_download.py index 7ea24d20..5c0a42e2 100644 --- a/test/test_pangeo_download.py +++ b/test/test_pangeo_download.py @@ -1,5 +1,7 @@ from optim_esm_tools._test_utils import get_file_from_pangeo +import pytest -def test_download(): - get_file_from_pangeo() +@pytest.mark.parametrize('scenario', ['ssp585', 'piControl', 'historical']) +def test_download(scenario): + get_file_from_pangeo(scenario) diff --git a/test/test_workflow_pangeo.py b/test/test_workflow_pangeo.py new file mode 100644 index 00000000..f1a84ef2 --- /dev/null +++ b/test/test_workflow_pangeo.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +import unittest +import optim_esm_tools._test_utils +from optim_esm_tools.analyze import region_finding +import tempfile +import pytest +import os + +class Work(unittest.TestCase): + # example_data_set = oet._test_utils.EXAMPLE_DATA_SET + def test(self): + for data_name in ['ssp585', 'piControl']: + self.get_path(data_name) + + + @staticmethod + def get_path(data_name): + path = optim_esm_tools._test_utils.get_file_from_pangeo(data_name) + year_path = optim_esm_tools._test_utils.year_means(path) + assert year_path + assert os.path.exists(year_path) + return year_path + + @pytest.mark.paramerize('make', ['region_finding', 'Percentiles', 'PercentilesHistory']) + def test_build_plots(self, make='MaxRegion',): + cls = getattr(region_finding, make) + with tempfile.TemporaryDirectory() as temp_dir: + print(make) + save_kw = dict( + save_in = temp_dir, + sub_dir = None, + file_types=('png', 'pdf'), + skip= False, +) + head, tail = os.path.split(self.get_path('ssp585')) + r=cls(path=head, read_ds_kw=dict(_file_name=tail), transform=True, save_kw=save_kw, extra_opt=dict()) + r.show=False + r.workflow() + # def from_amon_to_ayear(self): + # if os.path.exists(self.ayear_file): + # return + + # os.makedirs(os.path.split(self.ayear_file)[0], exist_ok=1) + # # Doesn't work? + # # cdo.Cdo().yearmonmean(self.amon_file, self.ayear_file) + # cmd = f'cdo yearmonmean {self.amon_file} {self.ayear_file}' + # print(cmd) + # subprocess.call(cmd, shell=True) + # assert os.path.exists(self.ayear_file), self.ayear_file + + # @classmethod + # def setUpClass(cls): + # cls.base = os.path.join(os.environ['ST_HOME'], 'data') + # cls.amon_file = get_example_data_loc() + # cls.ayear_file = os.path.join( + # os.path.split(cls.amon_file.replace('Amon', 'AYear'))[0], 'merged.nc' + # ) + + # def setUp(self): + # self.from_amon_to_ayear() + # super().setUp() + + # def test_read_data(self): + # data_set = oet.synda_files.format_synda.load_glob(self.ayear_file) + + # def test_make_map(self): + # data_set = oet.analyze.cmip_handler.read_ds(os.path.split(self.ayear_file)[0]) + # oet.analyze.cmip_handler.MapMaker(data_set=data_set).plot_all(2) + # plt.clf() + + # def test_map_maker_time_series(self): + # data_set = oet.analyze.cmip_handler.read_ds(os.path.split(self.ayear_file)[0]) + # oet.analyze.cmip_handler.MapMaker(data_set=data_set).time_series() + # plt.clf() + + # def test_apply_relative_units(self, unit='relative'): + # data_set = oet.analyze.cmip_handler.read_ds( + # os.path.split(self.ayear_file)[0], condition_kwargs=dict(unit=unit) + # ) + # mm = oet.analyze.cmip_handler.MapMaker(data_set=data_set) + + # def test_apply_std_unit(self): + # self.test_apply_relative_units(unit='std') + + # @classmethod + # def tearDownClass(cls) -> None: + # os.remove(cls.ayear_file) + # return super().tearDownClass() From e39095a01e5e5b8d745019a640c08675147fd46f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Jun 2023 07:35:01 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- optim_esm_tools/analyze/region_finding.py | 188 ++++++++++++++-------- test/test_workflow_pangeo.py | 36 +++-- 2 files changed, 142 insertions(+), 82 deletions(-) diff --git a/optim_esm_tools/analyze/region_finding.py b/optim_esm_tools/analyze/region_finding.py index 3d58176e..04636f4d 100644 --- a/optim_esm_tools/analyze/region_finding.py +++ b/optim_esm_tools/analyze/region_finding.py @@ -16,6 +16,7 @@ from optim_esm_tools.plotting.plot import setup_map from immutabledict import immutabledict + def _show(show): if show: plt.show() @@ -58,9 +59,7 @@ def somedec_outer(fn): def timed_func(*args, **kwargs): self = args[0] takes = inspect.signature(fn).parameters - kwargs.update({k: v for k, v in self.extra_opt.items() - if k in takes - }) + kwargs.update({k: v for k, v in self.extra_opt.items() if k in takes}) res = fn(*args, **kwargs) return res @@ -81,14 +80,16 @@ class ResultDataSet: criteria = (tipping_criteria.StdDetrended, tipping_criteria.MaxJump) extra_opt = None - def __init__(self, - variable='tas', - path=None, - dataset=None, - transform=True, - save_kw=None, - extra_opt=None, - read_ds_kw=None) -> None: + def __init__( + self, + variable='tas', + path=None, + dataset=None, + transform=True, + save_kw=None, + extra_opt=None, + read_ds_kw=None, + ) -> None: if path is None: if transform: self.log.warning( @@ -101,10 +102,15 @@ def __init__(self, read_ds_kw = dict() if read_ds_kw is None else read_ds_kw self.dataset = read_ds(path, **read_ds_kw) if save_kw is None: - save_kw = dict(save_in='./', - file_types=('png', 'pdf',), - skip=False, - sub_dir=None) + save_kw = dict( + save_in='./', + file_types=( + 'png', + 'pdf', + ), + skip=False, + sub_dir=None, + ) if extra_opt is None: extra_opt = dict() self.extra_opt = extra_opt @@ -150,14 +156,19 @@ class MaxRegion(ResultDataSet): def get_masks(self) -> dict: """Get mask for max of ii and iii and a box arround that""" labels = [crit.short_description for crit in self.criteria] - masks = {label: self.dataset[label].values == self.dataset[label].values.max() - for label in labels - } + masks = { + label: self.dataset[label].values == self.dataset[label].values.max() + for label in labels + } return masks @plt_show def plot_masks(self, masks, ax=None, legend=True): - res = self._plot_masks(masks=masks, ax=ax, legend=legend,) + res = self._plot_masks( + masks=masks, + ax=ax, + legend=legend, + ) self.save(f'{self.title_label}_map_maxes_{"-".join(self.labels)}') @apply_options @@ -168,10 +179,7 @@ def _plot_masks(self, masks, ax=None, legend=True): if ax is None: oet.plotting.plot.setup_map() ax = plt.gca() - for i, (label, xy) in enumerate(zip( - self.labels, - points.values()) - ): + for i, (label, xy) in enumerate(zip(self.labels, points.values())): ax.scatter(*xy, marker='oxv^'[i], label=f'Maximum {label}') if legend: ax.legend(**oet.utils.legend_kw()) @@ -195,23 +203,26 @@ def _plot_basic_map(self): @plt_show def plot_mask_time_series(self, masks, time_series_joined=True): - res = self._plot_mask_time_series( - masks, time_series_joined=time_series_joined) + res = self._plot_mask_time_series(masks, time_series_joined=time_series_joined) if time_series_joined: - self.save( - f'{self.title_label}_time_series_maxes_{"-".join(self.labels)}') + self.save(f'{self.title_label}_time_series_maxes_{"-".join(self.labels)}') return res @apply_options - def _plot_mask_time_series(self, masks, time_series_joined=True, only_rm=False, axes=None): + def _plot_mask_time_series( + self, masks, time_series_joined=True, only_rm=False, axes=None + ): legend_kw = oet.utils.legend_kw( - loc='upper left', bbox_to_anchor=None, mode=None, ncol=4) + loc='upper left', bbox_to_anchor=None, mode=None, ncol=4 + ) for label, mask_2d in zip(self.labels, masks.values()): x, y = self._mask_to_coord(mask_2d) - plot_labels = {f'{self.variable}': f'{label} at {x:.1f}:{y:.1f}', - f'{self.variable}_detrend': f'{label} at {x:.1f}:{y:.1f}', - f'{self.variable}_detrend_run_mean_10': f'$RM_{{10}}$ {label} at {x:.1f}:{y:.1f}', - f'{self.variable}_run_mean_10': f'$RM_{{10}}$ {label} at {x:.1f}:{y:.1f}'} + plot_labels = { + f'{self.variable}': f'{label} at {x:.1f}:{y:.1f}', + f'{self.variable}_detrend': f'{label} at {x:.1f}:{y:.1f}', + f'{self.variable}_detrend_run_mean_10': f'$RM_{{10}}$ {label} at {x:.1f}:{y:.1f}', + f'{self.variable}_run_mean_10': f'$RM_{{10}}$ {label} at {x:.1f}:{y:.1f}', + } argwhere = np.argwhere(mask_2d)[0] ds_sel = self.dataset.isel(x=argwhere[1], y=argwhere[0]) mm_sel = MapMaker(ds_sel) @@ -220,7 +231,8 @@ def _plot_mask_time_series(self, masks, time_series_joined=True, only_rm=False, interval=False, labels=plot_labels, axes=axes, - only_rm=only_rm) + only_rm=only_rm, + ) if time_series_joined is False: axes = None plt.suptitle(f'Max. {label} {self.title}', y=0.95) @@ -253,7 +265,8 @@ def get_masks(self, percentiles=99) -> dict: all_mask &= m masks, clusters = build_cluster_mask( - all_mask, self.dataset['x'].values, self.dataset['y'].values) + all_mask, self.dataset['x'].values, self.dataset['y'].values + ) return masks, clusters @plt_show @@ -261,13 +274,22 @@ def plot_masks(self, masks_and_clusters, ax=None, legend=True): if not len(masks_and_clusters[0]): return res = self._plot_masks( - masks_and_clusters=masks_and_clusters, ax=ax, legend=legend,) + masks_and_clusters=masks_and_clusters, + ax=ax, + legend=legend, + ) self.save(f'{self.title_label}_map_clusters_{"-".join(self.labels)}') @apply_options - def _plot_masks(self, masks_and_clusters, - scatter_medians=True, - ax=None, legend=True, mask_cbar_kw=None,cluster_kw=None): + def _plot_masks( + self, + masks_and_clusters, + scatter_medians=True, + ax=None, + legend=True, + mask_cbar_kw=None, + cluster_kw=None, + ): masks, clusters = masks_and_clusters all_masks = np.zeros(masks[0].shape, np.int16) @@ -285,15 +307,17 @@ def _plot_masks(self, masks_and_clusters, all_masks[all_masks == 0] = np.nan ds_dummy['n_grid_cells'] = (('y', 'x'), all_masks) - ds_dummy['n_grid_cells'].plot(cbar_kwargs=mask_cbar_kw, - - vmin=0, extend='neither') + ds_dummy['n_grid_cells'].plot( + cbar_kwargs=mask_cbar_kw, vmin=0, extend='neither' + ) plt.title('') if scatter_medians: if cluster_kw is None: cluster_kw = dict() for m_i, cluster in enumerate(clusters): - ax.scatter(*np.median(cluster, axis=0), label=f'cluster {m_i}', **cluster_kw) + ax.scatter( + *np.median(cluster, axis=0), label=f'cluster {m_i}', **cluster_kw + ) if legend: plt.legend(**oet.utils.legend_kw()) plt.suptitle(f'Clusters {self.title}', y=0.95) @@ -306,38 +330,49 @@ def _plot_basic_map(self): return axes # Could add some masked selection on top -# masks, _ = self.get_masks() -# all_masks = masks[0] -# for m in masks[1:]: -# all_masks &= m -# ds_masked = mask_xr_ds(self.dataset.copy(), all_masks) -# mm_sel = MapMaker(ds_masked) -# for label, ax in zip(mm.labels, axes): -# plt.sca(ax) -# mm_sel.plot_i(label, ax=ax, coastlines=False) + # masks, _ = self.get_masks() + + # all_masks = masks[0] + # for m in masks[1:]: + # all_masks &= m + # ds_masked = mask_xr_ds(self.dataset.copy(), all_masks) + # mm_sel = MapMaker(ds_masked) + # for label, ax in zip(mm.labels, axes): + # plt.sca(ax) + # mm_sel.plot_i(label, ax=ax, coastlines=False) @plt_show def plot_mask_time_series(self, masks_and_clusters, time_series_joined=True): res = self._plot_mask_time_series( - masks_and_clusters, time_series_joined=time_series_joined) + masks_and_clusters, time_series_joined=time_series_joined + ) if time_series_joined: self.save(f'{self.title_label}_all_clusters') return res @apply_options - def _plot_mask_time_series(self, masks_and_clusters, time_series_joined=True, only_rm=None, axes=None): + def _plot_mask_time_series( + self, masks_and_clusters, time_series_joined=True, only_rm=None, axes=None + ): if only_rm is None: - only_rm = True if (len(masks_and_clusters[0]) > 1 and time_series_joined) else False + only_rm = ( + True + if (len(masks_and_clusters[0]) > 1 and time_series_joined) + else False + ) masks, clusters = masks_and_clusters legend_kw = oet.utils.legend_kw( - loc='upper left', bbox_to_anchor=None, mode=None, ncol=4) + loc='upper left', bbox_to_anchor=None, mode=None, ncol=4 + ) for m_i, (mask, cluster) in enumerate(zip(masks, clusters)): x, y = np.median(cluster, axis=0) - plot_labels = {f'{self.variable}': f'Cluster {m_i} near ~{x:.1f}:{y:.1f}', - f'{self.variable}_detrend': f'Cluster {m_i} near ~{x:.1f}:{y:.1f}', - f'{self.variable}_detrend_run_mean_10': f'Cluster {m_i} $RM_{{10}}$ near ~{x:.1f}:{y:.1f}', - f'{self.variable}_run_mean_10': f'Cluster {m_i} $RM_{{10}}$ near ~{x:.1f}:{y:.1f}'} + plot_labels = { + f'{self.variable}': f'Cluster {m_i} near ~{x:.1f}:{y:.1f}', + f'{self.variable}_detrend': f'Cluster {m_i} near ~{x:.1f}:{y:.1f}', + f'{self.variable}_detrend_run_mean_10': f'Cluster {m_i} $RM_{{10}}$ near ~{x:.1f}:{y:.1f}', + f'{self.variable}_run_mean_10': f'Cluster {m_i} $RM_{{10}}$ near ~{x:.1f}:{y:.1f}', + } ds_sel = mask_xr_ds(self.dataset.copy(), mask) mm_sel = MapMaker(ds_sel) axes = mm_sel.time_series( @@ -345,7 +380,8 @@ def _plot_mask_time_series(self, masks_and_clusters, time_series_joined=True, on interval=True, labels=plot_labels, axes=axes, - only_rm=only_rm) + only_rm=only_rm, + ) if time_series_joined == False: axes = None plt.suptitle(f'Cluster. {m_i} {self.title}', y=0.95) @@ -384,17 +420,22 @@ def get_masks(self, percentiles=99, read_ds_kw=None) -> dict: all_mask &= m masks, clusters = build_cluster_mask( - all_mask, self.dataset['x'].values, self.dataset['y'].values) + all_mask, self.dataset['x'].values, self.dataset['y'].values + ) return masks, clusters @apply_options def find_historical(self, match_to='piControl'): from optim_esm_tools.config import config - base = os.path.join(os.sep, - *self.dataset.attrs['path'].split(os.sep)[:-len(config['CMIP_files']['folder_fmt'].split())-1]) - search = oet.cmip_files.find_matches.folder_to_dict( - self.dataset.attrs['path']) + base = os.path.join( + os.sep, + *self.dataset.attrs['path'].split(os.sep)[ + : -len(config['CMIP_files']['folder_fmt'].split()) - 1 + ], + ) + + search = oet.cmip_files.find_matches.folder_to_dict(self.dataset.attrs['path']) search['activity_id'] = 'CMIP' if search['experiment_id'] == match_to: raise NotImplementedError() @@ -403,15 +444,22 @@ def find_historical(self, match_to='piControl'): first_try = oet.cmip_files.find_matches.find_matches(base, **search) if first_try: return first_try - self.log.warning( - 'No results at first try, retying with any variant_label') - search.update(dict(variant_label='*', )) + self.log.warning('No results at first try, retying with any variant_label') + search.update( + dict( + variant_label='*', + ) + ) second_try = oet.cmip_files.find_matches.find_matches(base, **search) if second_try: return second_try self.log.warning('No results at second try, retying with any version') - search.update(dict(version='*', )) + search.update( + dict( + version='*', + ) + ) third_try = oet.cmip_files.find_matches.find_matches(base, **search) if third_try: return third_try diff --git a/test/test_workflow_pangeo.py b/test/test_workflow_pangeo.py index f1a84ef2..6f9b81ab 100644 --- a/test/test_workflow_pangeo.py +++ b/test/test_workflow_pangeo.py @@ -6,13 +6,13 @@ import pytest import os + class Work(unittest.TestCase): # example_data_set = oet._test_utils.EXAMPLE_DATA_SET def test(self): for data_name in ['ssp585', 'piControl']: self.get_path(data_name) - @staticmethod def get_path(data_name): path = optim_esm_tools._test_utils.get_file_from_pangeo(data_name) @@ -20,22 +20,34 @@ def get_path(data_name): assert year_path assert os.path.exists(year_path) return year_path - - @pytest.mark.paramerize('make', ['region_finding', 'Percentiles', 'PercentilesHistory']) - def test_build_plots(self, make='MaxRegion',): + + @pytest.mark.paramerize( + 'make', ['region_finding', 'Percentiles', 'PercentilesHistory'] + ) + def test_build_plots( + self, + make='MaxRegion', + ): cls = getattr(region_finding, make) with tempfile.TemporaryDirectory() as temp_dir: print(make) save_kw = dict( - save_in = temp_dir, - sub_dir = None, - file_types=('png', 'pdf'), - skip= False, -) - head, tail = os.path.split(self.get_path('ssp585')) - r=cls(path=head, read_ds_kw=dict(_file_name=tail), transform=True, save_kw=save_kw, extra_opt=dict()) - r.show=False + save_in=temp_dir, + sub_dir=None, + file_types=('png', 'pdf'), + skip=False, + ) + head, tail = os.path.split(self.get_path('ssp585')) + r = cls( + path=head, + read_ds_kw=dict(_file_name=tail), + transform=True, + save_kw=save_kw, + extra_opt=dict(), + ) + r.show = False r.workflow() + # def from_amon_to_ayear(self): # if os.path.exists(self.ayear_file): # return From 9923ab3cd26c1bf63739fda52ae2d60ed47a4e57 Mon Sep 17 00:00:00 2001 From: Joran Angevaare Date: Tue, 20 Jun 2023 09:47:58 +0200 Subject: [PATCH 3/8] retry --- test/test_plotting.py | 6 +-- test/test_region_finding.py | 50 ++++++++++++++++++ test/test_workflow_pangeo.py | 100 ----------------------------------- 3 files changed, 53 insertions(+), 103 deletions(-) create mode 100644 test/test_region_finding.py delete mode 100644 test/test_workflow_pangeo.py diff --git a/test/test_plotting.py b/test/test_plotting.py index 3d4e2bac..90c51a88 100644 --- a/test/test_plotting.py +++ b/test/test_plotting.py @@ -1,4 +1,4 @@ -def test_map_basic(): - import optim_esm_tools +# def test_map_basic(): +# import optim_esm_tools - optim_esm_tools.plotting.plot.setup_map() +# optim_esm_tools.plotting.plot.setup_map() diff --git a/test/test_region_finding.py b/test/test_region_finding.py new file mode 100644 index 00000000..224b99c3 --- /dev/null +++ b/test/test_region_finding.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +import unittest +import optim_esm_tools._test_utils +from optim_esm_tools.analyze import region_finding +import tempfile +import pytest +import os + + +class Work(unittest.TestCase): + def test(self): + for data_name in ['ssp585', 'piControl']: + self.get_path(data_name) + + @staticmethod + def get_path(data_name, refresh=True): + path = optim_esm_tools._test_utils.get_file_from_pangeo(data_name, refresh=refresh) + year_path = optim_esm_tools._test_utils.year_means(path, refresh=refresh) + assert year_path + assert os.path.exists(year_path) + return year_path + + @pytest.mark.paramerize( + 'make', ['region_finding', 'Percentiles', 'PercentilesHistory'] + ) + def test_build_plots( + self, + make='MaxRegion', + ): + cls = getattr(region_finding, make) + extra_opt=dict(time_series_joined=True, scatter_medians=True) + with tempfile.TemporaryDirectory() as temp_dir: + print(make) + save_kw = dict( + save_in=temp_dir, + sub_dir=None, + file_types=('png',), + dpi=25, + skip=False, + ) + head, tail = os.path.split(self.get_path('ssp585', refresh=False)) + r = cls( + path=head, + read_ds_kw=dict(_file_name=tail), + transform=True, + save_kw=save_kw, + extra_opt=extra_opt, + ) + r.show = False + r.workflow() diff --git a/test/test_workflow_pangeo.py b/test/test_workflow_pangeo.py deleted file mode 100644 index 6f9b81ab..00000000 --- a/test/test_workflow_pangeo.py +++ /dev/null @@ -1,100 +0,0 @@ -# -*- coding: utf-8 -*- -import unittest -import optim_esm_tools._test_utils -from optim_esm_tools.analyze import region_finding -import tempfile -import pytest -import os - - -class Work(unittest.TestCase): - # example_data_set = oet._test_utils.EXAMPLE_DATA_SET - def test(self): - for data_name in ['ssp585', 'piControl']: - self.get_path(data_name) - - @staticmethod - def get_path(data_name): - path = optim_esm_tools._test_utils.get_file_from_pangeo(data_name) - year_path = optim_esm_tools._test_utils.year_means(path) - assert year_path - assert os.path.exists(year_path) - return year_path - - @pytest.mark.paramerize( - 'make', ['region_finding', 'Percentiles', 'PercentilesHistory'] - ) - def test_build_plots( - self, - make='MaxRegion', - ): - cls = getattr(region_finding, make) - with tempfile.TemporaryDirectory() as temp_dir: - print(make) - save_kw = dict( - save_in=temp_dir, - sub_dir=None, - file_types=('png', 'pdf'), - skip=False, - ) - head, tail = os.path.split(self.get_path('ssp585')) - r = cls( - path=head, - read_ds_kw=dict(_file_name=tail), - transform=True, - save_kw=save_kw, - extra_opt=dict(), - ) - r.show = False - r.workflow() - - # def from_amon_to_ayear(self): - # if os.path.exists(self.ayear_file): - # return - - # os.makedirs(os.path.split(self.ayear_file)[0], exist_ok=1) - # # Doesn't work? - # # cdo.Cdo().yearmonmean(self.amon_file, self.ayear_file) - # cmd = f'cdo yearmonmean {self.amon_file} {self.ayear_file}' - # print(cmd) - # subprocess.call(cmd, shell=True) - # assert os.path.exists(self.ayear_file), self.ayear_file - - # @classmethod - # def setUpClass(cls): - # cls.base = os.path.join(os.environ['ST_HOME'], 'data') - # cls.amon_file = get_example_data_loc() - # cls.ayear_file = os.path.join( - # os.path.split(cls.amon_file.replace('Amon', 'AYear'))[0], 'merged.nc' - # ) - - # def setUp(self): - # self.from_amon_to_ayear() - # super().setUp() - - # def test_read_data(self): - # data_set = oet.synda_files.format_synda.load_glob(self.ayear_file) - - # def test_make_map(self): - # data_set = oet.analyze.cmip_handler.read_ds(os.path.split(self.ayear_file)[0]) - # oet.analyze.cmip_handler.MapMaker(data_set=data_set).plot_all(2) - # plt.clf() - - # def test_map_maker_time_series(self): - # data_set = oet.analyze.cmip_handler.read_ds(os.path.split(self.ayear_file)[0]) - # oet.analyze.cmip_handler.MapMaker(data_set=data_set).time_series() - # plt.clf() - - # def test_apply_relative_units(self, unit='relative'): - # data_set = oet.analyze.cmip_handler.read_ds( - # os.path.split(self.ayear_file)[0], condition_kwargs=dict(unit=unit) - # ) - # mm = oet.analyze.cmip_handler.MapMaker(data_set=data_set) - - # def test_apply_std_unit(self): - # self.test_apply_relative_units(unit='std') - - # @classmethod - # def tearDownClass(cls) -> None: - # os.remove(cls.ayear_file) - # return super().tearDownClass() From b9058ed65386891db42b6d615af0375243167532 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Jun 2023 07:48:21 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/test_region_finding.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_region_finding.py b/test/test_region_finding.py index 224b99c3..7c7b0792 100644 --- a/test/test_region_finding.py +++ b/test/test_region_finding.py @@ -14,7 +14,9 @@ def test(self): @staticmethod def get_path(data_name, refresh=True): - path = optim_esm_tools._test_utils.get_file_from_pangeo(data_name, refresh=refresh) + path = optim_esm_tools._test_utils.get_file_from_pangeo( + data_name, refresh=refresh + ) year_path = optim_esm_tools._test_utils.year_means(path, refresh=refresh) assert year_path assert os.path.exists(year_path) @@ -28,7 +30,7 @@ def test_build_plots( make='MaxRegion', ): cls = getattr(region_finding, make) - extra_opt=dict(time_series_joined=True, scatter_medians=True) + extra_opt = dict(time_series_joined=True, scatter_medians=True) with tempfile.TemporaryDirectory() as temp_dir: print(make) save_kw = dict( From b353904777aad1bae6520eae27d16d3aefd8a7e9 Mon Sep 17 00:00:00 2001 From: Joran Angevaare Date: Tue, 20 Jun 2023 10:58:24 +0200 Subject: [PATCH 5/8] fix finding --- optim_esm_tools/_test_utils.py | 2 ++ optim_esm_tools/analyze/region_finding.py | 7 ++++--- test/test_region_finding.py | 14 ++++++++------ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/optim_esm_tools/_test_utils.py b/optim_esm_tools/_test_utils.py index 318f52d7..7cfc29a1 100644 --- a/optim_esm_tools/_test_utils.py +++ b/optim_esm_tools/_test_utils.py @@ -8,6 +8,8 @@ def get_file_from_pangeo(experiment_id='ssp585', refresh=True): dest_folder = os.path.split( get_example_data_loc().replace('ssp585', experiment_id) )[0] + if experiment_id in ['piControl', 'historical']: + dest_folder = dest_folder.replace('ScenarioMIP', 'CMIP') write_to = os.path.join(dest_folder, 'test.nc') if os.path.exists(write_to) and not refresh: print(f'already file at {write_to}') diff --git a/optim_esm_tools/analyze/region_finding.py b/optim_esm_tools/analyze/region_finding.py index 04636f4d..f5d88f3f 100644 --- a/optim_esm_tools/analyze/region_finding.py +++ b/optim_esm_tools/analyze/region_finding.py @@ -113,6 +113,7 @@ def __init__( ) if extra_opt is None: extra_opt = dict() + extra_opt.update(dict(read_ds_kw=read_ds_kw)) self.extra_opt = extra_opt self.save_kw = save_kw self.variable = variable @@ -425,13 +426,13 @@ def get_masks(self, percentiles=99, read_ds_kw=None) -> dict: return masks, clusters @apply_options - def find_historical(self, match_to='piControl'): + def find_historical(self, match_to='piControl', look_back_extra=1): from optim_esm_tools.config import config base = os.path.join( os.sep, *self.dataset.attrs['path'].split(os.sep)[ - : -len(config['CMIP_files']['folder_fmt'].split()) - 1 + : -len(config['CMIP_files']['folder_fmt'].split()) -look_back_extra ], ) @@ -463,7 +464,7 @@ def find_historical(self, match_to='piControl'): third_try = oet.cmip_files.find_matches.find_matches(base, **search) if third_try: return third_try - raise RuntimeError + raise RuntimeError(f'Looked for {search}, in {base} found nothing') @property def log(self): diff --git a/test/test_region_finding.py b/test/test_region_finding.py index 7c7b0792..b10dc2a5 100644 --- a/test/test_region_finding.py +++ b/test/test_region_finding.py @@ -3,14 +3,14 @@ import optim_esm_tools._test_utils from optim_esm_tools.analyze import region_finding import tempfile -import pytest import os class Work(unittest.TestCase): - def test(self): + @classmethod + def setUpClass(cls): for data_name in ['ssp585', 'piControl']: - self.get_path(data_name) + cls.get_path(data_name) @staticmethod def get_path(data_name, refresh=True): @@ -22,14 +22,12 @@ def get_path(data_name, refresh=True): assert os.path.exists(year_path) return year_path - @pytest.mark.paramerize( - 'make', ['region_finding', 'Percentiles', 'PercentilesHistory'] - ) def test_build_plots( self, make='MaxRegion', ): cls = getattr(region_finding, make) + print(cls) extra_opt = dict(time_series_joined=True, scatter_medians=True) with tempfile.TemporaryDirectory() as temp_dir: print(make) @@ -50,3 +48,7 @@ def test_build_plots( ) r.show = False r.workflow() + def test_percentiles(self): + self.test_build_plots('Percentiles') + def test_percentiles_history(self): + self.test_build_plots('PercentilesHistory') \ No newline at end of file From 668f312fdda64ab39a80474d14e9c40d7ecd003d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Jun 2023 08:58:35 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- optim_esm_tools/analyze/region_finding.py | 2 +- test/test_region_finding.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/optim_esm_tools/analyze/region_finding.py b/optim_esm_tools/analyze/region_finding.py index f5d88f3f..4d46353b 100644 --- a/optim_esm_tools/analyze/region_finding.py +++ b/optim_esm_tools/analyze/region_finding.py @@ -432,7 +432,7 @@ def find_historical(self, match_to='piControl', look_back_extra=1): base = os.path.join( os.sep, *self.dataset.attrs['path'].split(os.sep)[ - : -len(config['CMIP_files']['folder_fmt'].split()) -look_back_extra + : -len(config['CMIP_files']['folder_fmt'].split()) - look_back_extra ], ) diff --git a/test/test_region_finding.py b/test/test_region_finding.py index b10dc2a5..ecf32b69 100644 --- a/test/test_region_finding.py +++ b/test/test_region_finding.py @@ -48,7 +48,9 @@ def test_build_plots( ) r.show = False r.workflow() + def test_percentiles(self): self.test_build_plots('Percentiles') + def test_percentiles_history(self): - self.test_build_plots('PercentilesHistory') \ No newline at end of file + self.test_build_plots('PercentilesHistory') From bc974e1e3a001579c180070432b66df87db2bc77 Mon Sep 17 00:00:00 2001 From: Joran Angevaare Date: Tue, 20 Jun 2023 11:04:44 +0200 Subject: [PATCH 7/8] reactivate test --- test/test_plotting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_plotting.py b/test/test_plotting.py index 90c51a88..3d4e2bac 100644 --- a/test/test_plotting.py +++ b/test/test_plotting.py @@ -1,4 +1,4 @@ -# def test_map_basic(): -# import optim_esm_tools +def test_map_basic(): + import optim_esm_tools -# optim_esm_tools.plotting.plot.setup_map() + optim_esm_tools.plotting.plot.setup_map() From 0a043e1ad68cc3443f1c2fdd9469937fc943e69f Mon Sep 17 00:00:00 2001 From: Joran Angevaare Date: Tue, 20 Jun 2023 11:17:54 +0200 Subject: [PATCH 8/8] low percentiles for signal --- optim_esm_tools/analyze/region_finding.py | 21 ++++----------------- test/test_region_finding.py | 2 +- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/optim_esm_tools/analyze/region_finding.py b/optim_esm_tools/analyze/region_finding.py index 4d46353b..77ee4d6e 100644 --- a/optim_esm_tools/analyze/region_finding.py +++ b/optim_esm_tools/analyze/region_finding.py @@ -273,6 +273,7 @@ def get_masks(self, percentiles=99) -> dict: @plt_show def plot_masks(self, masks_and_clusters, ax=None, legend=True): if not len(masks_and_clusters[0]): + self.log.warning('No clusters found!') return res = self._plot_masks( masks_and_clusters=masks_and_clusters, @@ -446,28 +447,14 @@ def find_historical(self, match_to='piControl', look_back_extra=1): if first_try: return first_try self.log.warning('No results at first try, retying with any variant_label') - search.update( - dict( - variant_label='*', - ) - ) + search.update( dict( variant_label='*', ) ) second_try = oet.cmip_files.find_matches.find_matches(base, **search) if second_try: return second_try self.log.warning('No results at second try, retying with any version') - search.update( - dict( - version='*', - ) - ) + search.update( dict( version='*', ) ) third_try = oet.cmip_files.find_matches.find_matches(base, **search) if third_try: return third_try - raise RuntimeError(f'Looked for {search}, in {base} found nothing') - - @property - def log(self): - if self._logger is None: - self._logger = oet.config.get_logger() - return self._logger + raise RuntimeError(f'Looked for {search}, in {base} found nothing') \ No newline at end of file diff --git a/test/test_region_finding.py b/test/test_region_finding.py index ecf32b69..8f060325 100644 --- a/test/test_region_finding.py +++ b/test/test_region_finding.py @@ -28,7 +28,7 @@ def test_build_plots( ): cls = getattr(region_finding, make) print(cls) - extra_opt = dict(time_series_joined=True, scatter_medians=True) + extra_opt = dict(time_series_joined=True, scatter_medians=True, percentiles=50) with tempfile.TemporaryDirectory() as temp_dir: print(make) save_kw = dict(