Skip to content

Commit

Permalink
train_gt: use separate_hist_future (#281)
Browse files Browse the repository at this point in the history
* train_gt: use separate_hist_future

* add changelog
  • Loading branch information
mathause committed Sep 5, 2023
1 parent fc43f37 commit cd79b13
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 50 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -130,6 +130,13 @@ Documentation
Internal Changes
^^^^^^^^^^^^^^^^

- Refactor the mesmer internals to use the new statistical core, employ helper functions
etc.:

- Use :py:func:`mesmer.utils.separate_hist_future` in :py:func:`mesmer.calibrate_mesmer.train_gt` (`#281 <https://github.com/MESMER-group/mesmer/pull/281>`_).

By `Mathias Hauser <https://github.com/mathause>`_.

- Restore compatibility with regionmask v0.9.0 (`#136 <https://github.com/MESMER-group/mesmer/pull/136>`_).
By `Mathias Hauser <https://github.com/mathause>`_.

Expand Down
2 changes: 1 addition & 1 deletion docs/source/api.rst
Expand Up @@ -187,7 +187,7 @@ Utils
:toctree: generated/

~utils.convert.convert_dict_to_arr
~utils.convert.separate_hist_future
~utils.separate_hist_future
~utils.select.extract_land
~utils.select.extract_time_period
~utils.regionmaskcompat.mask_3D_frac_approx
72 changes: 26 additions & 46 deletions mesmer/calibrate_mesmer/train_gt.py
Expand Up @@ -15,6 +15,7 @@
from mesmer.io.save_mesmer_bundle import save_mesmer_data
from mesmer.stats.linear_regression import LinearRegression
from mesmer.stats.smoothing import lowess
from mesmer.utils import separate_hist_future


def train_gt(var, targ, esm, time, cfg, save_params=True):
Expand Down Expand Up @@ -70,80 +71,59 @@ def train_gt(var, targ, esm, time, cfg, save_params=True):
"""

# specify necessary variables from config file
gen = cfg.gen
# specify necessary variables from config
method_gt = cfg.methods[targ]["gt"]
preds_gt = cfg.preds[targ]["gt"]

scenarios_tr = list(var.keys())
scenarios = list(var.keys())

# initialize parameters dictionary and fill in the metadata which does not depend on
# the applied method
# initialize param dict and fill in the metadata which does not depend on the method
params_gt = {}
params_gt["targ"] = targ
params_gt["esm"] = esm
params_gt["method"] = method_gt
params_gt["preds"] = preds_gt
params_gt["scenarios"] = scenarios_tr # single entry in case of ic ensemble
params_gt["scenarios"] = scenarios # single entry in case of ic ensemble

# apply the chosen method to the type of ensemble
gt = {}
if "LOWESS" in params_gt["method"]:
# i.e. derive gt for each scen individually
for scen in scenarios_tr:
gt[scen], frac_lowess_name = train_gt_ic_LOWESS(var[scen])
params_gt["frac_lowess"] = frac_lowess_name
# derive gt for each scen individually
for scen in scenarios:
gt[scen], frac_lowess = train_gt_ic_LOWESS(var[scen])
params_gt["frac_lowess"] = frac_lowess
else:
raise ValueError("No alternative method to LOWESS is implemented for now.")

params_gt["time"] = {}
# if hist included
if scenarios[0].startswith("h-"):

# i.e. if hist included
if scenarios_tr[0][:2] == "h-":

if gen == 5:
start_year_fut = 2005
elif gen == 6:
start_year_fut = 2014

idx_start_year_fut = np.where(time[scen] == start_year_fut)[0][0] + 1

params_gt["time"]["hist"] = time[scen][:idx_start_year_fut]
gt_s, time_s = separate_hist_future(gt, time, cfg)

# compute median LOWESS estimate of historical part across all scenarios
gt_lowess_hist_all = np.zeros([len(gt.keys()), len(params_gt["time"]["hist"])])
for i, scen in enumerate(gt.keys()):
gt_lowess_hist_all[i] = gt[scen][:idx_start_year_fut]
gt_lowess_hist = np.median(gt_lowess_hist_all, axis=0)
gt_lowess_hist_all_new = gt_s.pop("hist")

if params_gt["method"] == "LOWESS_OLSVOLC":
scen = scenarios_tr[0]
var_all = var[scen][:, :idx_start_year_fut]
for scen in scenarios_tr[1:]:
var_tmp = var[scen][:, :idx_start_year_fut]
var_all = np.vstack([var_all, var_tmp])
gt_hist_median = np.median(gt_lowess_hist_all_new, axis=0)

# check for duplicates & exclude those runs
var_all = np.unique(var_all, axis=0)
if params_gt["method"] == "LOWESS_OLSVOLC":
var_s, time_s = separate_hist_future(var, time, cfg)

params_gt["saod"], params_gt["hist"] = train_gt_ic_OLSVOLC(
var_all, gt_lowess_hist, params_gt["time"]["hist"]
var_s["hist"], gt_hist_median, time_s["hist"]
)

elif params_gt["method"] == "LOWESS":
params_gt["hist"] = gt_lowess_hist
params_gt["hist"] = gt_hist_median

# isolate future scen names
scenarios_tr_f = [scen.replace("h-", "") for scen in scenarios_tr]
gt_to_distribute = gt_s
params_gt["time"] = time_s

else:
# because first year would be already in future
idx_start_year_fut = 0
# because only future covered anyways
scenarios_tr_f = scenarios_tr
gt_to_distribute = gt
params_gt["time"] = time

for scen_f, scen in zip(scenarios_tr_f, scenarios_tr):
params_gt["time"][scen_f] = time[scen][idx_start_year_fut:]
params_gt[scen_f] = gt[scen][idx_start_year_fut:]
for scen, data in gt_to_distribute.items():
params_gt[scen] = data.squeeze()

# save the global trend parameters if requested
if save_params:
Expand All @@ -158,7 +138,7 @@ def train_gt(var, targ, esm, time, cfg, save_params=True):
*preds_gt,
targ,
esm,
*scenarios_tr,
*scenarios,
),
)

Expand Down
7 changes: 4 additions & 3 deletions mesmer/utils/convert.py
Expand Up @@ -91,15 +91,16 @@ def separate_hist_future(var_c, time_c, cfg):
var_s, time_s = {}, {}

# gather hist
hist = [var_c[scen_c][:, :idx_start_fut] for scen_c in scens_c]
hist = [np.atleast_2d(var_c[scen_c])[:, :idx_start_fut] for scen_c in scens_c]
hist = np.vstack(hist)

# exclude duplicate historical runs that are available in several scenarios
var_s["hist"] = np.unique(np.vstack(hist), axis=0)
var_s["hist"] = np.unique(hist, axis=0)
time_s["hist"] = time[:idx_start_fut]

# gather proj
for scen_f, scen_c in zip(scens_f, scens_c):
var_s[scen_f] = var_c[scen_c][:, idx_start_fut:]
var_s[scen_f] = np.atleast_2d(var_c[scen_c])[:, idx_start_fut:]
time_s[scen_f] = time_c[scen_c][idx_start_fut:]

return var_s, time_s

0 comments on commit cd79b13

Please sign in to comment.