Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update example script #80

Merged
merged 4 commits into from Jul 15, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
84 changes: 21 additions & 63 deletions examples/train_create_emus_automated.py
@@ -1,57 +1,44 @@
# add the path to the folders one level higher (i.e., to mesmer and configs)
import sys
import warnings

# add the path to the configs folder
sys.path.append("../")

import os.path
import warnings

# load in configurations used in this script
import configs.config_across_scen_T_cmip6ng_test as cfg

# import MESMER tools
from mesmer.calibrate_mesmer import train_gt, train_gv, train_lt, train_lv
from mesmer.create_emulations import (
create_emus_g,
create_emus_gt,
create_emus_gv,
create_emus_l,
create_emus_lt,
create_emus_lv,
)
from mesmer.io import (
load_cmipng,
load_phi_gc,
load_regs_ls_wgt_lon_lat,
save_mesmer_bundle,
)
from mesmer.io import load_cmipng, load_phi_gc, load_regs_ls_wgt_lon_lat
from mesmer.utils import convert_dict_to_arr, extract_land, separate_hist_future

# where to save the bundle
bundle_out_file = os.path.join("tests", "test-data", "test-mesmer-bundle.pkl")
os.makedirs(os.path.dirname(bundle_out_file), exist_ok=True)

# specify the target variable
targ = cfg.targs[0]
print(targ)

# load in the ESM runs
esms = cfg.esms
print(esms)
print(len(esms))

# load in tas with global coverage
tas_g_dict = {} # tas with global coverage
GSAT_dict = {} # global mean tas
GHFDS_dict = {} # global mean hfds (needed as predictor)
tas_g = {}
GSAT = {}
GHFDS = {}
time = {}

for esm in esms:
print(esm)
tas_g_dict[esm] = {}
GSAT_dict[esm] = {}
GHFDS_dict[esm] = {}
time[esm] = {}

for scen in cfg.scenarios:
Expand All @@ -62,19 +49,17 @@

if tas_g_tmp is None:
warnings.warn(f"Scenario {scen} does not exist for tas for ESM {esm}")
else: # if scen exists: save fields + load hfds fields for it too
else: # if scen exists: save fields
tas_g_dict[esm][scen], GSAT_dict[esm][scen], lon, lat, time[esm][scen] = (
tas_g_tmp,
GSAT_tmp,
lon_tmp,
lat_tmp,
time_tmp,
)
_, GHFDS_dict[esm][scen], _, _, _ = load_cmipng("hfds", esm, scen, cfg)

tas_g[esm] = convert_dict_to_arr(tas_g_dict[esm])
GSAT[esm] = convert_dict_to_arr(GSAT_dict[esm])
GHFDS[esm] = convert_dict_to_arr(GHFDS_dict[esm])

# load in the constant files
reg_dict, ls, wgt_g, lon, lat = load_regs_ls_wgt_lon_lat(cfg.reg_type, lon, lat)
Expand All @@ -89,62 +74,49 @@
print(esm)

print(esm, "Start with global trend module")

params_gt_T = train_gt(GSAT[esm], targ, esm, time[esm], cfg, save_params=True)
params_gt_hfds = train_gt(GHFDS[esm], "hfds", esm, time[esm], cfg, save_params=True)

preds_gt = {"time": time[esm]}
emus_gt_T = create_emus_gt(
params_gt_T, preds_gt, cfg, concat_h_f=True, save_emus=True
)
gt_T_s = create_emus_gt(
params_gt_T, preds_gt, cfg, concat_h_f=False, save_emus=False
)
emus_gt_T = create_emus_gt(
params_gt_T, preds_gt, cfg, concat_h_f=True, save_emus=True
)

print(
esm,
"Start preparing predictors for global variability, local trends, and local variability",
)
gt_T2_s = {}
for scen in gt_T_s.keys():
gt_T2_s[scen] = gt_T_s[scen] ** 2

gt_hfds_s = create_emus_gt(
params_gt_hfds, preds_gt, cfg, concat_h_f=False, save_emus=False
)

gv_novolc_T = {}
for scen in emus_gt_T.keys():
gv_novolc_T[scen] = GSAT[esm][scen] - emus_gt_T[scen]
gv_novolc_T_s, time_s = separate_hist_future(gv_novolc_T, time[esm], cfg)
GSAT_s, time_s = separate_hist_future(GSAT[esm], time[esm], cfg)
gv_novolc_T_s = {}
for scen in gt_T_s.keys():
gv_novolc_T_s[scen] = GSAT_s[scen] - gt_T_s[scen]

tas_s, time_s = separate_hist_future(tas[esm], time[esm], cfg)

print(esm, "Start with global variability module")

params_gv_T = train_gv(gv_novolc_T_s, targ, esm, cfg, save_params=True)

time_v = {}
scen = list(emus_gt_T.keys())[0]
time_v["all"] = time[esm][scen]
# note: scen is taken from emus_gt_T.keys() here, which is
# necessary to derive compatible emus_gt and emus_gv
preds_gv = {"time": time_v}
emus_gv_T = create_emus_gv(params_gv_T, preds_gv, cfg, save_emus=True)

print(esm, "Merge the global trend and the global variability.")
emus_g_T = create_emus_g(
emus_gt_T, emus_gv_T, params_gt_T, params_gv_T, cfg, save_emus=True
)

print(esm, "Start with local trends module")

preds = {
"gttas": gt_T_s,
"gttas2": gt_T2_s,
"gthfds": gt_hfds_s,
"gvtas": gv_novolc_T_s,
} # predictors_list
targs = {"tas": tas_s} # targets list
}
targs = {"tas": tas_s}
params_lt, params_lv = train_lt(preds, targs, esm, cfg, save_params=True)

preds_lt = {"gttas": gt_T_s, "gttas2": gt_T2_s, "gthfds": gt_hfds_s}
preds_lt = {"gttas": gt_T_s}
lt_s = create_emus_lt(params_lt, preds_lt, cfg, concat_h_f=False, save_emus=True)
emus_lt = create_emus_lt(params_lt, preds_lt, cfg, concat_h_f=True, save_emus=True)

Expand All @@ -161,9 +133,7 @@

# load in the auxiliary files
aux = {}
aux["phi_gc"] = load_phi_gc(
lon, lat, ls, cfg, L_start=1750, L_end=2000, L_interval=250
) # the default L values give better results, but these settings are much faster and need less space
aux["phi_gc"] = load_phi_gc(lon, lat, ls, cfg)

# train lv AR1_sci on residual variability
targs_res_lv = {"tas": res_lv_s}
Expand All @@ -178,15 +148,3 @@
# create full emulations
print(esm, "Merge the local trends and the local variability.")
emus_l = create_emus_l(emus_lt, emus_lv, params_lt, params_lv, cfg, save_emus=True)

save_mesmer_bundle(
bundle_out_file,
params_lt,
params_lv,
params_gv_T,
seeds=cfg.seed,
land_fractions=ls["grid_l_m"],
lat=lat["c"],
lon=lon["c"],
time=time_s,
)