Merge f3fb29f into 0483b48
JoranAngevaare committed Jun 13, 2024
2 parents 0483b48 + f3fb29f commit 1968797
Showing 3 changed files with 205 additions and 76 deletions.
15 changes: 4 additions & 11 deletions optim_esm_tools/analyze/cmip_handler.py
@@ -166,7 +166,7 @@ def read_ds(

if pre_process:
data_set = oet.analyze.pre_process.get_preprocessed_ds(
source=data_path,
sources=data_path,
historical_path=_historical_path,
max_time=max_time,
min_time=min_time,
@@ -201,17 +201,10 @@ def read_ds(

if _cache:
log.info(f'Write {res_file}')
comp_kw = {}
if oet.config.config['CMIP_files']['compress'] == 'True':
comp_kw = dict(
format='NETCDF4',
engine='netcdf4',
encoding={
k: {'zlib': True, 'complevel': 1} for k in data_set.data_vars
},
)

data_set.to_netcdf(res_file, **comp_kw)
oet.analyze.pre_process.save_nc(data_set, res_file)
else:
data_set.to_netcdf(res_file)

return data_set

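The inline compression keywords dropped above are now centralized in the pre_process.save_nc helper added later in this commit. A minimal sketch of the equivalent xarray call (the name save_compressed is hypothetical, ds stands for any dataset):

import xarray as xr


def save_compressed(ds: xr.Dataset, path: str) -> None:
    # zlib level-1 compression for every data variable, written as NETCDF4
    encoding = {k: {'zlib': True, 'complevel': 1} for k in ds.data_vars}
    ds.to_netcdf(path, format='NETCDF4', engine='netcdf4', encoding=encoding)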
11 changes: 7 additions & 4 deletions optim_esm_tools/analyze/globals.py
@@ -1,6 +1,9 @@
import typing as ty
from optim_esm_tools.config import config

_SECONDS_TO_YEAR = int(config['constants']['seconds_to_year'])
_FOLDER_FMT = config['CMIP_files']['folder_fmt'].split()
_CMIP_HANDLER_VERSION = config['versions']['cmip_handler']
_DEFAULT_MAX_TIME = tuple(int(s) for s in config['analyze']['max_time'].split())
_SECONDS_TO_YEAR: int = int(config['constants']['seconds_to_year'])
_FOLDER_FMT: ty.List[str] = config['CMIP_files']['folder_fmt'].split()
_CMIP_HANDLER_VERSION: str = config['versions']['cmip_handler']
_DEFAULT_MAX_TIME: ty.Tuple[int, ...] = tuple(
int(s) for s in config['analyze']['max_time'].split()
)
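A minimal sketch of what the annotated parsing above produces, assuming the config carried the (hypothetical) value '9999 12 30' for analyze.max_time:

max_time_raw = '9999 12 30'  # hypothetical raw config string
default_max_time = tuple(int(s) for s in max_time_raw.split())
assert default_max_time == (9999, 12, 30)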
255 changes: 194 additions & 61 deletions optim_esm_tools/analyze/pre_process.py
@@ -10,10 +10,21 @@
from optim_esm_tools.analyze.xarray_tools import _native_date_fmt
from optim_esm_tools.config import config
from optim_esm_tools.config import get_logger
from optim_esm_tools.utils import timed


def get_preprocessed_ds(source, **kw):
from optim_esm_tools.utils import timed, check_accepts, to_str_tuple
from pandas.util._decorators import deprecate_kwarg


@deprecate_kwarg('source', 'sources')
@check_accepts(dict(return_type=('path', 'data_set')))
def get_preprocessed_ds(
sources: ty.Union[str, tuple, list],
year_mean=False,
_check_duplicate_years=True,
return_type='data_set',
skip_compression=False,
temp_dir_location=None,
**kw,
):
"""Create a temporary working directory for pre-process and delete all
intermediate files."""
if 'working_dir' in kw: # pragma: no cover
@@ -23,7 +34,29 @@ def get_preprocessed_ds(source, **kw):
'dataset, and remove all local files.'
)
get_logger().warning(message)
with tempfile.TemporaryDirectory() as temp_dir:
if return_type == 'path':
assert 'save_as' in kw
store_final = kw.pop('save_as')

with tempfile.TemporaryDirectory(dir=temp_dir_location) as temp_dir:
if year_mean:
old_sources = to_str_tuple(sources)
new_sources = [
os.path.join(temp_dir, os.path.split(p)[1]) for p in old_sources
]
for o, n in zip(old_sources, new_sources):
_year_mon_mean(o, n)
assert os.path.exists(n)
sources = new_sources
if isinstance(sources, (list, tuple)) and len(sources) > 1:
source_tmp = os.path.join(temp_dir, 'sources_merged.nc')
_merge_sources(list(sources), source_tmp)
source = source_tmp
elif len(sources) == 1:
source = sources[0]
else:
assert isinstance(sources, str)
source = sources
defaults = dict(
source=source,
working_dir=temp_dir,
@@ -32,11 +65,48 @@ def get_preprocessed_ds(source, **kw):
)
for k, v in defaults.items():
kw.setdefault(k, v)
intermediate_file = pre_process(**kw)
# After we close this "with" block the file is gone, so load it now to be sure we have everything we need
ds = load_glob(intermediate_file).load() # type: ignore
sanity_check(ds)
return ds
intermediate_file = pre_process(
**kw,
_check_duplicate_years=_check_duplicate_years,
)
ds = load_glob(intermediate_file)
if return_type == 'data_set':
# After we close this "with" block the file is gone, so load it now to be sure we have everything we need
ds = ds.load() # type: ignore
ret = ds
elif skip_compression:
shutil.move(intermediate_file, store_final)
ret = store_final
else:
save_nc(ds, store_final)
ret = store_final

if _check_duplicate_years:
sanity_check(ds)
return ret
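A usage sketch of the extended signature, with hypothetical file paths; per the code above, year_mean first collapses each input file to yearly means with cdo, and return_type='path' requires a save_as target:

import optim_esm_tools as oet

# load the merged, pre-processed dataset into memory (the default return_type)
ds = oet.analyze.pre_process.get_preprocessed_ds(
    sources=['/data/tas_2000.nc', '/data/tas_2001.nc'],
    year_mean=True,
)

# or write the result straight to disk and get the path back
path = oet.analyze.pre_process.get_preprocessed_ds(
    sources='/data/tas_2000.nc',
    return_type='path',
    save_as='/data/tas_preprocessed.nc',
)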


def _merge_sources(source_files: ty.List[str], f_tmp: str) -> None: # pragma: no cover
import cdo

cdo_int = cdo.Cdo()
cdo_int.mergetime(input=source_files, output=f_tmp)


def _year_mon_mean(input, output):
import cdo

cdo_int = cdo.Cdo()
cdo_int.yearmonmean(input=input, output=output)


def save_nc(ds, path):
comp_kw = dict(
format='NETCDF4',
engine='netcdf4',
encoding={k: {'zlib': True, 'complevel': 1} for k in ds.data_vars},
)
ds.to_netcdf(path, **comp_kw)


def sanity_check(ds):
Expand All @@ -50,18 +120,20 @@ def sanity_check(ds):
t_prev = t_cur


@timed
def pre_process(
source: str,
historical_path: ty.Optional[str] = None,
target_grid: ty.Optional[str] = None,
target_grid: ty.Union[None, str, bool] = None,
max_time: ty.Optional[ty.Tuple[int, int, int]] = _DEFAULT_MAX_TIME,
min_time: ty.Optional[ty.Tuple[int, int, int]] = None,
save_as: ty.Optional[str] = None,
clean_up: bool = True,
_ma_window: ty.Union[int, str, None] = None,
variable_id: ty.Optional[str] = None,
working_dir: ty.Optional[str] = None,
_check_duplicate_years=True,
do_detrend=True,
do_running_mean=True,
) -> str: # type: ignore
"""Apply several preprocessing steps to the file located at <source>:
@@ -98,33 +170,33 @@ def pre_process(
if historical_path is not None:
_remove_bad_vars(historical_path)
variable_id = variable_id or _read_variable_id(source)
max_time = max_time or (9999, 12, 30) # unreasonably far away
min_time = min_time or (0, 1, 1) # unreasonably long ago
use_max_time = max_time or (9999, 12, 30) # unreasonably far away
use_min_time = min_time or (0, 1, 1) # unreasonably long ago

do_regrid = target_grid is not False
target_grid = target_grid or config['analyze']['regrid_to']

_ma_window = _ma_window or config['analyze']['moving_average_years']
_check_time_range(source, max_time, min_time, _ma_window)
_check_time_range(source, use_max_time, use_min_time, _ma_window)

cdo_int = cdo.Cdo()
head, _ = os.path.split(source)
working_dir = working_dir or head

# Several intermediate files
f_time = os.path.join(working_dir, 'time_sel.nc')
f_det = os.path.join(working_dir, 'detrend.nc')
f_det_rm = os.path.join(working_dir, f'detrend_rm_{_ma_window}.nc')
f_rm = os.path.join(working_dir, f'rm_{_ma_window}.nc')
f_tmp = os.path.join(working_dir, 'tmp.nc')
f_regrid = os.path.join(working_dir, 'regrid.nc')
f_area = os.path.join(working_dir, 'area.nc')
f_det = os.path.join(working_dir, 'detrend.nc')
f_rm = os.path.join(working_dir, f'rm_{_ma_window}.nc')
f_det_rm = os.path.join(working_dir, f'detrend_rm_{_ma_window}.nc')
files = [f_time, f_det, f_det_rm, f_rm, f_tmp, f_regrid, f_area]

save_as = save_as or os.path.join(working_dir, 'result.nc')

# Several names:
var = variable_id
var_det = f'{var}_detrend'
var_rm = f'{var}_run_mean_{_ma_window}'
var_det_rm = f'{var_det}_run_mean_{_ma_window}'

for p in files + [save_as]:
if p == source:
@@ -143,41 +215,104 @@ def pre_process(
f_tmp,
)
source = f_tmp
_remove_duplicate_time_stamps(source)
time_range = f'{_fmt_date(min_time)},{_fmt_date(max_time)}'
cdo_int.seldate(time_range, input=source, output=f_time) # type: ignore

cdo_int.remapbil(target_grid, input=f_time, output=f_regrid) # type: ignore
cdo_int.gridarea(input=f_regrid, output=f_area) # type: ignore

cdo_int.detrend(input=f_regrid, output=f_tmp) # type: ignore
cdo_int.chname(f'{var},{var_det}', input=f_tmp, output=f_det) # type: ignore
os.remove(f_tmp)

cdo_int.runmean(_ma_window, input=f_regrid, output=f_tmp) # type: ignore
_run_mean_patch(
f_start=f_regrid,
f_rm=f_tmp,
f_out=f_rm,
ma_window=_ma_window,
var_name=var,
var_rm_name=var_rm,
)
os.remove(f_tmp)

cdo_int.detrend(input=f_rm, output=f_tmp) # type: ignore
cdo_int.chname(f'{var_rm},{var_det_rm}', input=f_tmp, output=f_det_rm) # type: ignore
# remove in cleanup
if _check_duplicate_years:
_remove_duplicate_time_stamps(source)

next_source = source
if min_time is not None or max_time is not None:
time_range = f'{_fmt_date(use_min_time)},{_fmt_date(use_max_time)}'
cdo_int.seldate(time_range, input=next_source, output=f_time) # type: ignore
next_source = f_time

if do_regrid:
cdo_int.remapbil(target_grid, input=next_source, output=f_regrid) # type: ignore
cdo_int.gridarea(input=f_regrid, output=f_area) # type: ignore
next_source = f_regrid
input_files = [next_source, f_area]
else:
input_files = [next_source]

if do_detrend:
var_det = f'{var}_detrend'
cdo_int.detrend(input=next_source, output=f_tmp) # type: ignore
cdo_int.chname(f'{var},{var_det}', input=f_tmp, output=f_det) # type: ignore
os.remove(f_tmp)
input_files += [f_det]

if do_running_mean:
var_rm = f'{var}_run_mean_{_ma_window}'
cdo_int.runmean(_ma_window, input=f_regrid, output=f_tmp) # type: ignore
_run_mean_patch(
f_start=f_regrid,
f_rm=f_tmp,
f_out=f_rm,
ma_window=_ma_window,
var_name=var,
var_rm_name=var_rm,
)
os.remove(f_tmp)
input_files += [f_rm]

input_files = ' '.join([f_regrid, f_det, f_det_rm, f_rm, f_area])
cdo_int.merge(input=input_files, output=save_as) # type: ignore
if do_running_mean and do_detrend:
var_det_rm = f'{var_det}_run_mean_{_ma_window}'
cdo_int.detrend(input=f_rm, output=f_tmp) # type: ignore
cdo_int.chname(f'{var_rm},{var_det_rm}', input=f_tmp, output=f_det_rm) # type: ignore
input_files += [f_det_rm]
get_logger().warning(f'Join {input_files} to {save_as}')
cdo_int.merge(input=' '.join(input_files), output=save_as) # type: ignore

if clean_up: # pragma: no cover
for p in files:
os.remove(p)
if os.path.exists(p):
os.remove(p)
return save_as
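For orientation, a condensed sketch of the cdo chain pre_process builds when regridding, detrending and the running mean are all enabled; the grid name, file names and the 10-year window are hypothetical, and the chname renaming, _run_mean_patch padding and historical-merge steps are omitted:

import cdo

cdo_int = cdo.Cdo()
cdo_int.seldate('2015-01-01,2100-12-30', input='source.nc', output='time_sel.nc')  # time selection
cdo_int.remapbil('global_1', input='time_sel.nc', output='regrid.nc')  # bilinear regrid
cdo_int.gridarea(input='regrid.nc', output='area.nc')  # cell areas
cdo_int.detrend(input='regrid.nc', output='detrend.nc')  # linear detrend
cdo_int.runmean('10', input='regrid.nc', output='rm_10.nc')  # 10-year running mean
cdo_int.detrend(input='rm_10.nc', output='detrend_rm_10.nc')  # detrend the running mean
cdo_int.merge(
    input='regrid.nc area.nc detrend.nc rm_10.nc detrend_rm_10.nc',
    output='result.nc',
)  # combine everything into one file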


def _quick_drop_duplicates(ds, t_span, t_len, path):
ds = ds.drop_duplicates('time')
if (t_new_len := len(ds['time'])) > t_span + 1:
raise ValueError(f'{t_new_len} too long! Started with {t_len} and {t_span}')
get_logger().warning('Timestamp issue solved')
with tempfile.TemporaryDirectory() as temp_dir:
save_as = os.path.join(temp_dir, 'temp.nc')
ds.to_netcdf(save_as)
# move the old file
os.rename(path, os.path.join(os.path.split(path)[0], 'faulty_merged.nc'))
shutil.copy2(save_as, path)


def _drop_duplicates_carefully(ds, t_span, t_len, path):
from tqdm import tqdm

# This is only done for huge datasets, so /tmp may not have enough space; work next to the original file instead.
work_dir = os.path.split(path)[0]
with tempfile.TemporaryDirectory(dir=work_dir) as temp_dir:
keep_years = np.argwhere(np.diff([t.year for t in ds['time'].values]) == 1)[
:,
0,
]
keep_years = [0] + [int(i) + 1 for i in keep_years]
saves = []
for i in tqdm(keep_years):
save_as = os.path.join(temp_dir, f'temp_{i}.nc')
saves.append(save_as)
ds.isel(time=slice(i, i + 1)).load().to_netcdf(save_as)
# move the old file
_tempf = os.path.join(temp_dir, 'temp_merge.nc')
get_logger().warning(f'Merging {saves} -> {_tempf}')
_merge_sources(saves, _tempf)
ds = load_glob(_tempf)
if (t_new_len := len(ds['time'])) > t_span + 1:
raise ValueError(
f'{t_new_len} too long! Started with {t_len} and {t_span}',
)
get_logger().warning('Timestamp issue solved')

os.rename(path, os.path.join(os.path.split(path)[0], 'faulty_merged.nc'))
os.rename(_tempf, path)
get_logger().warning('_remove_duplicate_time_stamps -> Fixed!')
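A worked example of the keep_years selection above, with hypothetical years: the diff/argwhere combination keeps the first occurrence of each year and skips later duplicates:

import numpy as np

years = [2000, 2000, 2001, 2002]  # index 1 duplicates index 0
jumps = np.argwhere(np.diff(years) == 1)[:, 0]  # -> array([1, 2])
keep_years = [0] + [int(i) + 1 for i in jumps]  # -> [0, 2, 3]; the duplicate at index 1 is dropped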


def _remove_duplicate_time_stamps(path): # pragma: no cover
ds = load_glob(path)
if (t_len := len(ds['time'])) > (
@@ -186,24 +321,18 @@ def _remove_duplicate_time_stamps(path): # pragma: no cover
get_logger().warning(
f'Finding {t_len} timestamps in {t_span} years - removing duplicates',
)
ds = ds.drop_duplicates('time')
if (t_new_len := len(ds['time'])) > t_span + 1:
raise ValueError(f'{t_new_len} too long! Started with {t_len} and {t_span}')
get_logger().warning('Timestamp issue solved')
with tempfile.TemporaryDirectory() as temp_dir:
save_as = os.path.join(temp_dir, 'temp.nc')
ds.to_netcdf(save_as)
# move the old file
os.rename(path, os.path.join(os.path.split(path)[0], 'faulty_merged.nc'))
shutil.copy2(save_as, path)
if ds.nbytes / 1e6 < 1_000:
_quick_drop_duplicates(ds, t_span, t_len, path)
else:
_drop_duplicates_carefully(ds, t_span, t_len, path)


def _remap_and_merge(
cdo_int,
cdo,
historical_path: str,
source: str,
target_grid: str,
target_grid: ty.Union[bool, str],
working_dir: str,
f_tmp: str,
) -> None: # pragma: no cover
@@ -236,7 +365,11 @@ def _remap_and_merge(
try:
cdo_int.mergetime(input=[historical_path, source], output=f_tmp)
except cdo.CDOException as e: # pragma: no cover
get_logger().error(f'Ran into {e}, let\'s regrid first and retry')
get_logger().error(f"Ran into {e}, let's regrid first and retry")
if target_grid is False:
raise ValueError(
f'Cannot merge {historical_path} and {source} since target grid is False and we ran into {e}',
) from e
cdo_int.remapbil(
target_grid,
input=historical_path,
