Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

train_gt: use separate_hist_future #281

Merged
merged 2 commits into from Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -130,6 +130,13 @@ Documentation
Internal Changes
^^^^^^^^^^^^^^^^

- Refactor the mesmer internals to use the new statistical core, employ helper functions
etc.:

- Use :py:func:`mesmer.utils.separate_hist_future` in :py:func:`mesmer.calibrate_mesmer.train_gt` (`#281 <https://github.com/MESMER-group/mesmer/pull/281>`_).

By `Mathias Hauser <https://github.com/mathause>`_.

- Restore compatibility with regionmask v0.9.0 (`#136 <https://github.com/MESMER-group/mesmer/pull/136>`_).
By `Mathias Hauser <https://github.com/mathause>`_.

Expand Down
2 changes: 1 addition & 1 deletion docs/source/api.rst
Expand Up @@ -187,7 +187,7 @@ Utils
:toctree: generated/

~utils.convert.convert_dict_to_arr
~utils.convert.separate_hist_future
~utils.separate_hist_future
~utils.select.extract_land
~utils.select.extract_time_period
~utils.regionmaskcompat.mask_3D_frac_approx
72 changes: 26 additions & 46 deletions mesmer/calibrate_mesmer/train_gt.py
Expand Up @@ -15,6 +15,7 @@
from mesmer.io.save_mesmer_bundle import save_mesmer_data
from mesmer.stats.linear_regression import LinearRegression
from mesmer.stats.smoothing import lowess
from mesmer.utils import separate_hist_future


def train_gt(var, targ, esm, time, cfg, save_params=True):
Expand Down Expand Up @@ -70,80 +71,59 @@ def train_gt(var, targ, esm, time, cfg, save_params=True):

"""

# specify necessary variables from config file
gen = cfg.gen
# specify necessary variables from config
method_gt = cfg.methods[targ]["gt"]
preds_gt = cfg.preds[targ]["gt"]

scenarios_tr = list(var.keys())
scenarios = list(var.keys())

# initialize parameters dictionary and fill in the metadata which does not depend on
# the applied method
# initialize param dict and fill in the metadata which does not depend on the method
params_gt = {}
params_gt["targ"] = targ
params_gt["esm"] = esm
params_gt["method"] = method_gt
params_gt["preds"] = preds_gt
params_gt["scenarios"] = scenarios_tr # single entry in case of ic ensemble
params_gt["scenarios"] = scenarios # single entry in case of ic ensemble

# apply the chosen method to the type of ensemble
gt = {}
if "LOWESS" in params_gt["method"]:
# i.e. derive gt for each scen individually
for scen in scenarios_tr:
gt[scen], frac_lowess_name = train_gt_ic_LOWESS(var[scen])
params_gt["frac_lowess"] = frac_lowess_name
# derive gt for each scen individually
for scen in scenarios:
gt[scen], frac_lowess = train_gt_ic_LOWESS(var[scen])
params_gt["frac_lowess"] = frac_lowess
else:
raise ValueError("No alternative method to LOWESS is implemented for now.")

params_gt["time"] = {}
# if hist included
if scenarios[0].startswith("h-"):

# i.e. if hist included
if scenarios_tr[0][:2] == "h-":

if gen == 5:
start_year_fut = 2005
elif gen == 6:
start_year_fut = 2014

idx_start_year_fut = np.where(time[scen] == start_year_fut)[0][0] + 1

params_gt["time"]["hist"] = time[scen][:idx_start_year_fut]
gt_s, time_s = separate_hist_future(gt, time, cfg)

# compute median LOWESS estimate of historical part across all scenarios
gt_lowess_hist_all = np.zeros([len(gt.keys()), len(params_gt["time"]["hist"])])
for i, scen in enumerate(gt.keys()):
gt_lowess_hist_all[i] = gt[scen][:idx_start_year_fut]
gt_lowess_hist = np.median(gt_lowess_hist_all, axis=0)
gt_lowess_hist_all_new = gt_s.pop("hist")

if params_gt["method"] == "LOWESS_OLSVOLC":
scen = scenarios_tr[0]
var_all = var[scen][:, :idx_start_year_fut]
for scen in scenarios_tr[1:]:
var_tmp = var[scen][:, :idx_start_year_fut]
var_all = np.vstack([var_all, var_tmp])
gt_hist_median = np.median(gt_lowess_hist_all_new, axis=0)

# check for duplicates & exclude those runs
var_all = np.unique(var_all, axis=0)
if params_gt["method"] == "LOWESS_OLSVOLC":
var_s, time_s = separate_hist_future(var, time, cfg)

params_gt["saod"], params_gt["hist"] = train_gt_ic_OLSVOLC(
var_all, gt_lowess_hist, params_gt["time"]["hist"]
var_s["hist"], gt_hist_median, time_s["hist"]
)

elif params_gt["method"] == "LOWESS":
params_gt["hist"] = gt_lowess_hist
params_gt["hist"] = gt_hist_median

# isolate future scen names
scenarios_tr_f = [scen.replace("h-", "") for scen in scenarios_tr]
gt_to_distribute = gt_s
params_gt["time"] = time_s

else:
# because first year would be already in future
idx_start_year_fut = 0
# because only future covered anyways
scenarios_tr_f = scenarios_tr
gt_to_distribute = gt
params_gt["time"] = time

for scen_f, scen in zip(scenarios_tr_f, scenarios_tr):
params_gt["time"][scen_f] = time[scen][idx_start_year_fut:]
params_gt[scen_f] = gt[scen][idx_start_year_fut:]
for scen, data in gt_to_distribute.items():
params_gt[scen] = data.squeeze()

# save the global trend parameters if requested
if save_params:
Expand All @@ -158,7 +138,7 @@ def train_gt(var, targ, esm, time, cfg, save_params=True):
*preds_gt,
targ,
esm,
*scenarios_tr,
*scenarios,
),
)

Expand Down
7 changes: 4 additions & 3 deletions mesmer/utils/convert.py
Expand Up @@ -91,15 +91,16 @@ def separate_hist_future(var_c, time_c, cfg):
var_s, time_s = {}, {}

# gather hist
hist = [var_c[scen_c][:, :idx_start_fut] for scen_c in scens_c]
hist = [np.atleast_2d(var_c[scen_c])[:, :idx_start_fut] for scen_c in scens_c]
hist = np.vstack(hist)

# exclude duplicate historical runs that are available in several scenarios
var_s["hist"] = np.unique(np.vstack(hist), axis=0)
var_s["hist"] = np.unique(hist, axis=0)
time_s["hist"] = time[:idx_start_fut]

# gather proj
for scen_f, scen_c in zip(scens_f, scens_c):
var_s[scen_f] = var_c[scen_c][:, idx_start_fut:]
var_s[scen_f] = np.atleast_2d(var_c[scen_c])[:, idx_start_fut:]
time_s[scen_f] = time_c[scen_c][idx_start_fut:]

return var_s, time_s