# Parameter Editing
This notebook gives a quick recipe for doing calibration or sensitivity analysis with pywatershed. 
Parameters, more specifically instances of the `Parameters` class, are read-only by design: the parameters 
that are supplied should be the ones that are used, and the code should not opaquely modify them.

As a consequence, one has to convert a `Parameters` object to something editable. Below, this is accomplished by calling
`the_parameters.to_dd()`, which returns a DatasetDict, which is editable. One has to know something about
how DatasetDicts are constructed to edit them effectively; that information can be found in the [documentation](https://pywatershed.readthedocs.io/en/main/api/generated/pywatershed.base.DatasetDict.html#pywatershed.base.DatasetDict). 
The edited DatasetDict can be made a Parameters object again by `Parameters(**param_dict.data)`, as shown below. 

Note this notebook needs notebooks 0-1 to have been run in advance.

In [None]:
# auto-format the code in this notebook
%load_ext jupyter_black

In [None]:
import pathlib as pl
from pprint import pprint
import shutil

import numpy as np
import pywatershed as pws
import xarray as xr

In [None]:
# Input domain directory (under the pywatershed root) and the directory
# that collects everything this notebook writes.
domain_dir = pws.constants.__pywatershed_root__ / "data/drb_2yr"
nb_output_dir = pl.Path("./param_edits")
# Create the output directory first, then its "params" subdirectory.
for needed_dir in (nb_output_dir, nb_output_dir / "params"):
    needed_dir.mkdir(exist_ok=True)

In [None]:
# Load the legacy PRMS parameter file found in the domain directory
legacy_param_file = domain_dir / "myparam.param"
params = pws.parameters.PrmsParameters.load(legacy_param_file)

In [None]:
param_list = []  # only needed for the in-memory alternative noted below
param_files = []
for ii in range(11):
    # to_dd() copies by default, so every pass starts from the original values
    param_dict = params.to_dd()
    # multipliers start at 0.75 and grow by 0.05 per iteration
    multiplier = ii * 0.05 + 0.75
    print("multiplier = ", multiplier)
    param_dict.data_vars["K_coef"] *= multiplier

    # one netcdf file per perturbation, tagged with the zero-padded index
    param_file_name = (
        nb_output_dir / f"params/perturbed_params_{str(ii).zfill(3)}.nc"
    )
    param_files.append(param_file_name)

    # Writing to netcdf is optional when everything stays in memory; the
    # objects could instead be collected with
    # param_list.append(pws.Parameters(**param_dict.data))
    perturbed_params = pws.Parameters(**param_dict.data)
    # written via xarray; more work is necessary for nc4 export
    perturbed_params.to_netcdf(param_file_name, use_xr=True)

In [None]:
# Check that the values read back from each file are what we expect: every
# K_coef value in a file should be the original value scaled by one common
# multiplier.
for ff in param_files:
    # A problem arises on read with xarray's default decoding (it would need
    # xr.open_dataset(ff, decode_times=False, decode_timedelta=False)), so
    # we simply open the netcdf file as Parameters instead.
    new_params = pws.parameters.PrmsParameters.from_netcdf(ff)
    k_coef = new_params.data_vars["K_coef"]
    multipliers = k_coef / params.data_vars["K_coef"]
    # Compare the magnitude of the deviation: without abs() any negative
    # deviation, however large, would satisfy the inequality.
    assert (
        np.abs(multipliers - multipliers[0]) < 1e-15
    ).all(), f"K_coef multipliers are not uniform in {ff}"
    print(multipliers[0])

## A helper function for running the parameters through the model

In [None]:
def run_channel_model(output_dir_parent, param_file):
    """Run the PRMSChannel process with one parameter file.

    Args:
        output_dir_parent: pathlib.Path of the directory under which a
            per-run output directory is created.
        param_file: pathlib.Path of a netcdf parameter file. The text after
            the last underscore of its name (suffix removed) identifies the
            run.

    Returns:
        The pathlib.Path of the directory passed to the model as
        "netcdf_output_dir", i.e. output_dir_parent / "run_params_<id>".
    """
    # For concurrent.futures this function is written out to a file/module,
    # so it must import whatever will not be in scope in that case.
    import numpy as np
    import pywatershed as pws

    domain_dir = pws.constants.__pywatershed_root__ / "data/drb_2yr"
    params = pws.parameters.PrmsParameters.from_netcdf(param_file)

    # e.g. "perturbed_params_003.nc" -> "003"
    param_id = param_file.stem.rsplit("_", 1)[-1]
    nc_output_dir = output_dir_parent / f"run_params_{param_id}"
    nc_output_dir.mkdir(parents=True, exist_ok=True)

    control = pws.Control.load(domain_dir / "control.test")
    control.edit_end_time(np.datetime64("1979-07-01T00:00:00"))
    # these entries take precedence over the options loaded from the file
    run_options = {
        "input_dir": "01_multi-process_models/nhm_memory",
        "budget_type": "warn",
        "calc_method": "numba",
        "netcdf_output_dir": nc_output_dir,
    }
    control.options = control.options | run_options

    channel_model = pws.Model(
        [pws.PRMSChannel], control=control, parameters=params
    )
    channel_model.run(finalize=True)
    return nc_output_dir

## Serial execution of the model over the parameter files

In [None]:
%%time
serial_output_parent = nb_output_dir / "serial"
# remove results of any previous execution before recreating the directory
if serial_output_parent.exists():
    shutil.rmtree(serial_output_parent)
serial_output_parent.mkdir()
# run the model once per parameter file, one after the other
serial_output_dirs = [
    run_channel_model(serial_output_parent, ff) for ff in param_files
]

In [None]:
# display the output directories returned by the serial runs
serial_output_dirs

## concurrent.futures approach
For concurrent.futures to work in an interactive setting, the iterated/mapped function has to be imported from a module; it cannot be defined in the notebook/interactive session itself. We can simply write the function out to a file (making sure, as noted in the function above, that everything it needs is in scope when it is imported).

In [None]:
import inspect

# Save the function's source code as a module file that can be imported
channel_model_source = inspect.getsource(run_channel_model)
with open("param_edits/run_channel_model.py", "w") as module_file:
    module_file.write(channel_model_source)

In [None]:
%%time
import time
from concurrent.futures import ProcessPoolExecutor as PoolExecutor
from concurrent.futures import as_completed
from functools import partial
from param_edits.run_channel_model import run_channel_model

# start from an empty output directory for the parallel runs
parallel_output_parent = nb_output_dir / "parallel"
if parallel_output_parent.exists():
    shutil.rmtree(parallel_output_parent)
parallel_output_parent.mkdir()

# you can set your choice of max_workers
with PoolExecutor(max_workers=11) as executor:
    # executor.map returns a lazy iterator and an exception raised in a
    # worker only surfaces when its result is retrieved. Consuming the
    # iterator with list() makes failed runs raise here and leaves a real
    # list of output directories, like serial_output_dirs.
    parallel_output_dirs = list(
        executor.map(
            partial(run_channel_model, parallel_output_parent), param_files
        )
    )

# Checks

In [None]:
# check serial == parallel
serial_runs = sorted(serial_output_parent.glob("*"))
parallel_runs = sorted(parallel_output_parent.glob("*"))

# zip() stops at the shorter input, so missing runs would otherwise go
# unnoticed and an empty directory would make the whole check pass vacuously
assert serial_runs, f"no serial runs found in {serial_output_parent}"
assert [ss.name for ss in serial_runs] == [
    pp.name for pp in parallel_runs
], "serial and parallel run directories do not match"

for ss, pp in zip(serial_runs, parallel_runs):
    serial_files = sorted(ss.glob("*.nc"))
    parallel_files = sorted(pp.glob("*.nc"))
    assert serial_files, f"no netcdf output found in {ss}"
    assert [sf.name for sf in serial_files] == [
        pf.name for pf in parallel_files
    ], f"output files differ between {ss} and {pp}"
    for sf, pf in zip(serial_files, parallel_files):
        # context managers close the files once they have been compared
        with xr.open_dataset(sf) as s_ds:
            with xr.open_dataset(pf) as p_ds:
                xr.testing.assert_allclose(s_ds, p_ds)

In [None]:
# check serial 5 is the same as in notebook 02
# (the multiplier for index 5 is 5 * 0.05 + 0.75 = 1.0, i.e. K_coef unchanged)
run_005 = serial_output_parent / "run_params_005"
files_005 = sorted(run_005.glob("*.nc"))
# an empty glob would make the loop below pass without checking anything
assert files_005, f"no netcdf output found in {run_005}"
for ff in files_005:
    # the budget file is excluded from the comparison
    if ff.name == "PRMSChannel_budget.nc":
        continue
    reference_file = pl.Path("01_multi-process_models/nhm_memory") / ff.name
    # context managers close the files once they have been compared
    with xr.open_dataset(ff) as ds_005:
        with xr.open_dataset(reference_file) as ds_02:
            xr.testing.assert_allclose(ds_005, ds_02)