Commit
Merge 00e31fd into 9b7ab8c
JoranAngevaare committed Aug 24, 2023
2 parents 9b7ab8c + 00e31fd commit 6854152
Showing 7 changed files with 348 additions and 2 deletions.
1 change: 1 addition & 0 deletions optim_esm_tools/__init__.py
@@ -13,3 +13,4 @@
from .analyze.cmip_handler import read_ds
from .analyze.io import load_glob
from .plotting.map_maker import MapMaker
from .utils import print_versions
1 change: 1 addition & 0 deletions optim_esm_tools/analyze/__init__.py
@@ -9,3 +9,4 @@
from . import region_finding
from . import time_statistics
from . import concise_dataframe
from . import combine_variables
9 changes: 8 additions & 1 deletion optim_esm_tools/analyze/cmip_handler.py
@@ -83,6 +83,7 @@ def read_ds(
_ma_window: ty.Optional[int] = None,
_cache: bool = True,
_file_name: str = None,
_skip_folder_info: bool = False,
**kwargs,
) -> xr.Dataset:
"""Read a dataset from a folder called "base".
@@ -105,6 +106,8 @@ def read_ds(
        _cache (bool, optional): cache the dataset with its extra fields to allow faster
            (re)loading. Defaults to True.
        _file_name (str, optional): name to match. Defaults to config settings.
        _skip_folder_info (bool, optional): if set to True, do not infer the properties from the
            (synda) path of the file.
kwargs:
any kwargs are passed onto transform_ds.
@@ -173,7 +176,11 @@ def read_ds(
folders = base.split(os.sep)

# start with -1 (for i==0)
metadata = {k: folders[-i - 1] for i, k in enumerate(_FOLDER_FMT[::-1])}
metadata = (
{}
if _skip_folder_info
else {k: folders[-i - 1] for i, k in enumerate(_FOLDER_FMT[::-1])}
)
metadata.update(dict(path=base, file=res_file, running_mean_period=_ma_window))

data_set.attrs.update(metadata)
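Aside (not part of this commit): a minimal usage sketch of the new _skip_folder_info flag, assuming the first positional argument of read_ds is the base folder; the path below is purely hypothetical.

import optim_esm_tools as oet

# With _skip_folder_info=True the (synda) folder hierarchy is not parsed into
# metadata; only path, file and running_mean_period are added to ds.attrs.
ds = oet.read_ds(
    '/data/non_synda_layout/tas_piControl',  # hypothetical base folder
    _skip_folder_info=True,
)
print(ds.attrs.get('running_mean_period'))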
227 changes: 227 additions & 0 deletions optim_esm_tools/analyze/combine_variables.py
@@ -0,0 +1,227 @@
import os
import typing as ty
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

import optim_esm_tools as oet
from optim_esm_tools.analyze.time_statistics import default_thresholds


class VariableMerger:
full_paths = None
source_files: ty.Mapping
common_mask: xr.DataArray

def __init__(self, paths, other_paths=None, merge_method='logical_or'):
self.mask_paths = paths
self.other_paths = other_paths or []
self.merge_method = merge_method
source_files, common_mask = self.process_masks()
self.source_files = source_files
self.common_mask = common_mask

def squash_sources(self) -> xr.Dataset:
common_mask = (
self.common_mask
if self.common_mask.dims == ('lat', 'lon')
else oet.analyze.xarray_tools.reverse_name_mask_coords(self.common_mask)
)

new_ds = defaultdict(dict)
new_ds['data_vars']['global_mask'] = common_mask
for var, path in self.source_files.items():
_ds = oet.load_glob(path)
new_ds['data_vars'][var] = (
_ds[var]
.where(common_mask)
.mean(oet.config.config['analyze']['lon_lat_dim'].split(','))
)
new_ds['data_vars'][var].attrs = _ds[var].attrs

# Make one copy - just use the last dataset
new_ds['data_vars']['cell_area'] = _ds['cell_area']
keys = sorted(list(self.source_files.keys()))
new_ds['attrs'] = dict(
variables=keys,
source_files=[self.source_files[k] for k in keys],
mask_files=sorted(self.mask_paths),
)
try:
new_ds = xr.Dataset(**new_ds)
except TypeError as e:
            print(f'Ran into {e}; falling back because of cftime')
            # cftime cannot compare its own datetime formats, so build the
            # Dataset without the data_vars first ...
            data_vars = new_ds.pop('data_vars')
            new_ds = xr.Dataset(**new_ds)

            # ... and let xarray sort out the conversion when the variables
            # are assigned one by one.
for k, v in data_vars.items():
new_ds[k] = v
return new_ds

def make_fig(self, new_ds=None, fig_kw=None):
new_ds = new_ds or self.squash_sources()
variables = list(new_ds.attrs['variables'])
mapping = {str(i): v for i, v in enumerate(variables)}
keys = list(mapping.keys()) + ['t']

fig_kw = fig_kw or dict(
mosaic=''.join(f'{k}.\n' for k in keys),
figsize=(17, 4 * ((1 + len(keys)) / 3)),
gridspec_kw=dict(width_ratios=[1, 1], wspace=0.1, hspace=0.05),
)
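        # The mosaic creates one row per entry in `keys`: a time-series panel per
        # variable plus a row 't' for the statistics table, all in the left column.
        # The '.' keeps the right column free for the map added below with
        # add_subplot(1, 2, 2).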

fig, axes = plt.subplot_mosaic(**fig_kw)

if len(keys) > 1:
for k in keys[1:]:
axes[k].sharex(axes[keys[0]])

for key, var in mapping.items():
plt.sca(axes[key])
plot_kw = dict(label=var)
oet.plotting.map_maker.plot_simple(new_ds, var, **plot_kw)
plt.legend(loc='center left')

ax = plt.gcf().add_subplot(
1, 2, 2, projection=oet.plotting.plot.get_cartopy_projection()
)
oet.plotting.map_maker.overlay_area_mask(
new_ds.where(new_ds['global_mask']).copy(), ax=ax
)
res_f, tips = result_table(new_ds)
add_table(res_f=res_f, tips=tips, ax=axes['t'])

def process_masks(self) -> ty.Tuple[dict, xr.DataArray]:
source_files = {}
common_mask = None
for path in self.mask_paths:
ds = oet.load_glob(path)
# Source files may be non-unique!
source_files[ds.attrs['variable_id']] = ds.attrs['file']
common_mask = self.combine_masks(common_mask, ds)
for other_path in self.other_paths:
if other_path == '':
continue
ds = oet.load_glob(other_path)
# Source files may be non-unique!
var = ds.attrs['variable_id']
if var not in source_files:
source_files[var] = ds.attrs['file']
return source_files, common_mask

def combine_masks(
self,
common_mask: ty.Optional[xr.DataArray],
other_dataset: xr.Dataset,
field: ty.Optional[str] = None,
) -> xr.DataArray:
field = field or (
'global_mask' if 'global_mask' in other_dataset else 'cell_area'
)
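        # Fall back to 'cell_area' if the dataset carries no boolean 'global_mask';
        # when combining below, the field is cast to bool, so any nonzero cell area
        # counts as True.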
is_the_first_instance = common_mask is None
if is_the_first_instance:
return other_dataset[field]
if self.merge_method == 'logical_or':
return common_mask.astype(np.bool_) | other_dataset[field].astype(np.bool_)
else:
raise NotImplementedError


def change_plt_table_height():
"""Increase the height of rows in plt.table
Unfortunately, the options that you can pass to plt.table are insufficient to render a table
that has rows with sufficient heights that work with a font that is not the default. From the
plt.table implementation, I figured I could change these (rather patchy) lines in the source
code:
https://github.com/matplotlib/matplotlib/blob/b7dfdc5c97510733770429f38870a623426d0cdc/lib/matplotlib/table.py#L391
Matplotlib version matplotlib==3.7.2
"""
import matplotlib

print('Change default plt.table row height')

def _approx_text_height(self):
return 1.5 * (
self.FONTSIZE / 72.0 * self.figure.dpi / self._axes.bbox.height * 1.2
)

matplotlib.table.Table._approx_text_height = _approx_text_height


def add_table(res_f, tips, ax=None, fontsize=16):
ax = ax or plt.gcf().add_subplot(2, 2, 4)
ax.axis('off')
ax.axis('tight')

table = ax.table(
cellText=res_f.values,
rowLabels=res_f.index,
colLabels=res_f.columns,
cellColours=[
[([0.75, 1, 0.75] if v else [1, 1, 1]) for v in row] for row in tips.values
],
loc='bottom',
colLoc='center',
rowLoc='center',
cellLoc='center',
)
table.set_fontsize(fontsize)


def result_table(ds, formats=None):
res = {
field: summarize_stats(ds, field, path)
for field, path in zip(ds.attrs['variables'], ds.attrs['source_files'])
}
thrs = default_thresholds()
is_tip = pd.DataFrame(
{
k: {
t: (thrs[t][0](v, thrs[t][1]) if v is not None else False)
for t, v in d.items()
}
for k, d in res.items()
}
).T

formats = formats or dict(
n_breaks='.0f',
p_symmetry='.3f',
p_dip='.3f',
max_jump='.1f',
n_std_global='.1f',
)
res_f = pd.DataFrame(res).T
for k, f in formats.items():
res_f[k] = res_f[k].map(f'{{:,{f}}}'.format)

order = list(formats.keys())
return res_f[order], is_tip[order]


def summarize_stats(ds, field, path):
return {
'n_breaks': oet.analyze.time_statistics.calculate_n_breaks(ds, field=field),
'p_symmetry': oet.analyze.time_statistics.calculate_symmetry_test(
ds, field=field
),
'p_dip': oet.analyze.time_statistics.calculate_dip_test(ds, field=field),
'n_std_global': oet.analyze.time_statistics.n_times_global_std(
ds=oet.load_glob(path).where(ds['global_mask'])
),
'max_jump': oet.analyze.time_statistics.calculate_max_jump_in_std_history(
ds=oet.load_glob(path).where(ds['global_mask']), mask=ds['global_mask']
),
}


if __name__ == '__main__':
change_plt_table_height()
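Aside (not part of this commit): a usage sketch of VariableMerger, assuming each mask file contains a 'global_mask' (or 'cell_area') variable plus the 'variable_id' and 'file' attributes that process_masks relies on; the file names are hypothetical.

import matplotlib.pyplot as plt
from optim_esm_tools.analyze.combine_variables import VariableMerger

merger = VariableMerger(
    paths=['masks/tas_masked.nc', 'masks/siconc_masked.nc'],  # hypothetical mask files
    merge_method='logical_or',  # currently the only implemented merge method
)

combined = merger.squash_sources()  # one Dataset: per-variable means over the common mask
merger.make_fig(combined)           # time series, map overlay and statistics table
plt.show()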
33 changes: 33 additions & 0 deletions optim_esm_tools/analyze/time_statistics.py
@@ -4,6 +4,7 @@
import typing as ty
from functools import partial
import os
import operator


class TimeStatistics:
@@ -77,6 +78,38 @@ def calculate_statistics(self) -> ty.Dict[str, ty.Optional[float]]:
}


def default_thresholds(
max_jump=None,
p_dip=None,
p_symmetry=None,
n_breaks=None,
n_std_global=None,
):
return dict(
max_jump=(
operator.ge,
max_jump or float(oet.config.config['tipping_thresholds']['max_jump']),
),
p_dip=(
operator.le,
p_dip or float(oet.config.config['tipping_thresholds']['p_dip']),
),
p_symmetry=(
operator.le,
p_symmetry or float(oet.config.config['tipping_thresholds']['p_symmetry']),
),
n_breaks=(
operator.ge,
n_breaks or float(oet.config.config['tipping_thresholds']['n_breaks']),
),
n_std_global=(
operator.ge,
n_std_global
or float(oet.config.config['tipping_thresholds']['n_std_global']),
),
)


def _get_ds_global(ds, **read_kw):
path = ds.attrs['file']
if os.path.exists(path):
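Aside (not part of this commit): a sketch of how the (operator, threshold) pairs are meant to be applied, mirroring the check thrs[t][0](v, thrs[t][1]) in result_table above; the statistic value is made up.

from optim_esm_tools.analyze.time_statistics import default_thresholds

thresholds = default_thresholds()        # falls back to the [tipping_thresholds] config section
op, threshold = thresholds['max_jump']   # (operator.ge, 4.0) with the config defaults below
max_jump = 5.2                           # hypothetical statistic for one variable
print(op(max_jump, threshold))           # True, since 5.2 >= 4.0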
7 changes: 6 additions & 1 deletion optim_esm_tools/optim_esm_conf.ini
@@ -83,7 +83,12 @@ excluded =
; # Projection fails
; DKRZ MPI-ESM1-2-LR ssp119 r1i1p1f1 siconc * * v20210901


[tipping_thresholds]
max_jump=4
p_dip=0.01
p_symmetry=0.001
n_breaks=1
n_std_global=3

[log]
logging_level = WARNING
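Aside (not part of this commit): these values are read in default_thresholds (time_statistics.py above) via oet.config.config['tipping_thresholds'][...]; each threshold can also be overridden per call instead of editing this file.

from optim_esm_tools.analyze.time_statistics import default_thresholds

loose = default_thresholds()                              # uses the config values above, e.g. p_dip=0.01
strict = default_thresholds(p_dip=0.001, n_std_global=5)  # per-call overrides
print(loose['p_dip'], strict['p_dip'])                    # (operator.le, 0.01) (operator.le, 0.001)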