# Parameter Ensemble
This notebook shows how to edit and work with parameters in pywatershed. First we look at the data model used by the `Parameters` class to build a small ensemble of parameters for the `PRMSChannel` hydrologic process. Then we do a little bit of (embarrassingly) parallel programming using Python's [concurrent.futures](https://docs.python.org/3/library/concurrent.futures.html) to run this ensemble in parallel as well as in serial. This provides a skeleton recipe for how to do calibration or sensitivity analysis with pywatershed.

It is a design feature that the `Parameters` class is read-only. This is because we don't want the code or developers modifying parameters opaquely under the hood. While that practice is commonplace, it undermines the idea of reproducible research and causes more headaches than it solves. So we guard against it with software design. The trick is that to edit the contents of a `Parameters` object, we first have to convert it to another, editable class.
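The read-only idea can be illustrated without pywatershed at all. The sketch below uses the standard library's `MappingProxyType` as a stand-in; it is not how pywatershed implements `Parameters`, and the data values are made up for illustration.

```python
from types import MappingProxyType

# Hypothetical stand-in for parameter data; not pywatershed's implementation.
_data = {"K_coef": [1.0, 2.0, 3.0]}
read_only = MappingProxyType(_data)

edit_rejected = False
try:
    read_only["K_coef"] = [0.0]  # assignment through the read-only view fails
except TypeError as err:
    edit_rejected = True
    print("edit rejected:", err)

# To edit, first convert to an editable copy of a different type.
editable = {name: list(values) for name, values in read_only.items()}
editable["K_coef"] = [value * 2.0 for value in editable["K_coef"]]

# The read-only original is unchanged by edits to the copy.
print(read_only["K_coef"], editable["K_coef"])
```

The pattern is the same one used in the rest of this notebook: copy out to an editable type, change the copy, and keep the original intact.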

Let's get started. 

Note this notebook needs notebooks 0-1 to have been run in advance.

## Preliminaries

In [None]:
# auto-format the code in this notebook
%load_ext jupyter_black

In [None]:
import os
import pathlib as pl
from pprint import pprint
import shutil

import numpy as np
import pywatershed as pws
import xarray as xr

We'll use a PRMS-native parameter file from one of the domains supplied with pywatershed on install.

In [None]:
domain_dir = pws.constants.__pywatershed_root__ / "data/drb_2yr"
nb_output_dir = pl.Path("./05_parameter_ensemble")
nb_output_dir.mkdir(exist_ok=True)
(nb_output_dir / "params").mkdir(exist_ok=True)

In [None]:
params = pws.parameters.PrmsParameters.load(domain_dir / "myparam.param")

In [None]:
print(params)
isinstance(params, pws.Parameters)

## Create an ensemble of parameters
Now that we have the PRMS parameters as a `pws.Parameters` object, actually as its subclass `PrmsParameters`, we'll conduct a simple demonstration of how to generate an ensemble of values for `K_coef`, the Muskingum storage coefficient which affects the travel time of waves in the `PRMSChannel` representation.

As mentioned above, we have to get the parameter data into a different class to be able to edit it. Here we have two options: an `xarray.Dataset` or a `pywatershed.DatasetDict`. `pywatershed.DatasetDict` provides invertible mappings between the two.

First we'll demonstrate the approach with `xarray.Dataset`. We'll create an ensemble with 11 members and write the new parameter datasets, including all the variables, out to disk as separate NetCDF files. Note we could do this in memory and not write to disk, but generally it is preferable to have a record of inputs. This also demonstrates how easily one can convert a native PRMS parameter file to a NetCDF file.

We'll just multiply the `K_coef` coefficient by the 11 numbers 0.75, 0.8, ... , 1.2, 1.25 to get our 11 realizations.
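As a quick sanity check on the multipliers themselves, here is a small standalone sketch that uses the same formula as the loop in this notebook. Note that member 005 gets a multiplier of 1.0, so it corresponds to the unperturbed parameters.

```python
import math

n_members = 11
multipliers = [ii * 0.05 + 0.75 for ii in range(n_members)]
for ii, multiplier in enumerate(multipliers):
    # the zero-padded index matches the suffix used in the file names
    print(f"member {ii:03d}: multiplier = {multiplier:.2f}")

# the middle member leaves K_coef unchanged, up to floating-point rounding
assert math.isclose(multipliers[5], 1.0)
```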

In [None]:
param_files = []  # get a list of written NetCDF files back at the end
n_members = 11
for ii in range(n_members):
    param_ds = params.to_xr_ds()  # copies by default
    multiplier = ii * 0.05 + 0.75
    print("multiplier = ", multiplier)
    param_ds["K_coef"] *= multiplier
    param_file_name = (
        nb_output_dir / f"params/perturbed_params_xr_{str(ii).zfill(3)}.nc"
    )
    param_files += [param_file_name]
    param_ds.to_netcdf(param_file_name)

We can take a look at the final `param_ds`, which is still in memory. It has 144 variables, so you'll need to click the triangle to see the list. The little icon of a paper with a bent corner shows metadata, and the stacked disks icon gives a Python `repr`.

In [None]:
param_ds

Do a check that the values in the file divided by the original values reproduce the factors in order.

In [None]:
for ff in param_files:
    new_params = xr.open_dataset(
        ff, decode_times=False, decode_timedelta=False
    )
    k_coef = new_params["K_coef"]
    # new_params = pws.parameters.PrmsParameters.from_netcdf(ff)
    # k_coef = new_params.data_vars["K_coef"]
    multipliers = k_coef / params.data_vars["K_coef"]
    assert (abs(multipliers - multipliers[0]) < 1e-15).all()
    print(multipliers[0].values)
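When comparing against a tolerance, the absolute value matters: a one-sided comparison lets arbitrarily large negative deviations pass. A minimal sketch with made-up ratios (not values from the parameter files):

```python
# Hypothetical ratios: the second entry is clearly wrong.
ratios = [0.8, 0.3, 0.8]
reference = ratios[0]

# one-sided: negative deviations are always below the tolerance
one_sided = all(rr - reference < 1e-15 for rr in ratios)
# two-sided: deviations in either direction are caught
two_sided = all(abs(rr - reference) < 1e-15 for rr in ratios)

print(one_sided, two_sided)  # True False
```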

In [None]:
del param_files

Now we demonstrate the use of a `pywatershed.DatasetDict`, which you can read about in the [documentation](https://pywatershed.readthedocs.io/en/main/api/generated/pywatershed.base.DatasetDict.html#pywatershed.base.DatasetDict). Note that the edited `DatasetDict` can be made a `Parameters` object again by `Parameters(**param_dict.data)`, but we'll just write directly to file and then load as a `Parameters` object. These are slightly different choices from above, showing additional flexibility. We still choose to write the parameter ensemble to disk, however.


In [None]:
param_files = []
for ii in range(11):
    param_dict = params.to_dd()  # copies by default
    multiplier = ii * 0.05 + 0.75
    print("multiplier = ", multiplier)
    param_dict.data_vars["K_coef"] *= multiplier
    param_file_name = (
        nb_output_dir / f"params/perturbed_params_{str(ii).zfill(3)}.nc"
    )
    param_files += [param_file_name]
    param_dict.to_netcdf(
        param_file_name, use_xr=True
    )  # using xarray, more work necessary for nc4 export

Same check as above, but this time we read the NetCDF file into a `PrmsParameters` object rather than an `xarray.Dataset`.

In [None]:
for ff in param_files:
    # the problem arises on the read with xarray default decoding
    # but we can just open the netcdf file as Parameters
    # ds = xr.open_dataset(ff, decode_times=False, decode_timedelta=False)
    # k_coef = ds["K_coef"]
    new_params = pws.parameters.PrmsParameters.from_netcdf(ff)
    k_coef = new_params.data_vars["K_coef"]
    multipliers = k_coef / params.data_vars["K_coef"]
    assert (abs(multipliers - multipliers[0]) < 1e-15).all()
    print(multipliers[0])

## Run the parameter ensemble
We'll write a helper function for running the parameters through the model. Note the comments in the function about details related to concurrent.futures.

In [None]:
def run_channel_model(output_dir_parent, param_file):
    # for concurrent.futures we have to write this function to a file/module,
    # so we have to import things here that won't be in scope in that case.
    import numpy as np
    import pywatershed as pws

    domain_dir = pws.constants.__pywatershed_root__ / "data/drb_2yr"

    params = pws.parameters.PrmsParameters.from_netcdf(param_file)

    param_id = param_file.with_suffix("").name.split("_")[-1]
    nc_output_dir = output_dir_parent / f"run_params_{param_id}"
    nc_output_dir.mkdir(parents=True, exist_ok=True)

    control = pws.Control.load_prms(
        domain_dir / "nhm.control", warn_unused_options=False
    )
    control.edit_end_time(np.datetime64("1979-07-01T00:00:00"))
    control.options = control.options | {
        "input_dir": "01_multi-process_models/nhm_memory",
        "budget_type": "warn",
        "calc_method": "numba",
        "netcdf_output_dir": nc_output_dir,
    }

    model = pws.Model(
        [pws.PRMSChannel],
        control=control,
        parameters=params,
    )

    model.run(finalize=True)
    return nc_output_dir
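The helper derives each run's output directory from the parameter file name. That step can be sketched on its own with `pathlib`; the file path below is only an example following the naming pattern used above.

```python
import pathlib as pl

# example path following the naming pattern of the parameter files
param_file = pl.Path("05_parameter_ensemble/params/perturbed_params_007.nc")

# drop the ".nc" suffix, then take the text after the last underscore
param_id = param_file.with_suffix("").name.split("_")[-1]
run_dir = pl.Path("serial") / f"run_params_{param_id}"

print(param_id, run_dir.name)  # 007 run_params_007
```

Because the member index is zero-padded, sorting the run directories by name also sorts them by member.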

### Serial run
We'll perform serial execution of the model over the parameter files.

In [None]:
%%time
serial_output_dirs = []
serial_output_parent = nb_output_dir / "serial"
if serial_output_parent.exists():
    shutil.rmtree(serial_output_parent)
serial_output_parent.mkdir()
for ff in param_files:
    serial_output_dirs += [run_channel_model(serial_output_parent, ff)]

In [None]:
serial_output_dirs

### concurrent.futures run
For [concurrent futures](https://docs.python.org/3/library/concurrent.futures.html) to work in an interactive setting, the iterated/mapped function has to be imported from a module; it cannot be defined in the notebook/interactive session itself. We can simply write the function out to a file (making sure, as noted in the function above, that everything it needs is in scope when it is imported).

In [None]:
# the name of nb_output_dir starts with a digit, so it cannot be used in an
# import statement; we'll create another directory to import from and
# delete it later
import inspect

import_dir = pl.Path("param_ensemble_tmp")
import_dir.mkdir(exist_ok=True)
with open(import_dir / "run_channel_model_mod.py", "w") as the_file:
    the_file.write(inspect.getsource(run_channel_model))
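The write-then-import step can be sketched in isolation. This is a minimal sketch, assuming a made-up function `scale` and module name `scale_mod`; the source is given as a string here, whereas this notebook obtains it with `inspect.getsource`.

```python
import importlib
import pathlib as pl
import sys
import tempfile

# Hypothetical function source; any imports the function needs belong
# inside its body so they are in scope after the module is imported.
source = '''
def scale(factor, value):
    return factor * value
'''

with tempfile.TemporaryDirectory() as tmp:
    (pl.Path(tmp) / "scale_mod.py").write_text(source)
    sys.path.insert(0, tmp)  # make the temporary directory importable
    try:
        importlib.invalidate_caches()
        scale_mod = importlib.import_module("scale_mod")
    finally:
        sys.path.remove(tmp)

print(scale_mod.scale(2, 21))  # 42
```

The notebook instead writes the module into a directory under the current working directory, which is importable without touching `sys.path`.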

In [None]:
%%time
from concurrent.futures import ProcessPoolExecutor as PoolExecutor
from functools import partial
from param_ensemble_tmp.run_channel_model_mod import run_channel_model

parallel_output_parent = pl.Path("parallel")
if parallel_output_parent.exists():
    shutil.rmtree(parallel_output_parent)
parallel_output_parent.mkdir()

# you can set your choice of max_workers
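with PoolExecutor(max_workers=4) as executor:
    # executor.map is lazy: list() collects the results in input order and
    # re-raises any exception that occurred in a worker
    parallel_output_dirs = list(
        executor.map(
            partial(run_channel_model, parallel_output_parent), param_files
        )
    )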
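The `map` plus `partial` pattern can be tried out on a toy function. To keep the sketch runnable anywhere, it swaps in `ThreadPoolExecutor` for the `ProcessPoolExecutor` used in this notebook; the calling pattern is the same. The function `run_one` is a hypothetical stand-in for `run_channel_model`.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial


def run_one(output_parent, member):
    # hypothetical stand-in for run_channel_model
    if member < 0:
        raise ValueError(f"bad member: {member}")
    return f"{output_parent}/run_params_{member:03d}"


run_in_parallel_dir = partial(run_one, "parallel")

with ThreadPoolExecutor(max_workers=4) as executor:
    # list() consumes the lazy iterator; results come back in input order
    output_dirs = list(executor.map(run_in_parallel_dir, range(3)))
print(output_dirs)

with ThreadPoolExecutor(max_workers=4) as executor:
    lazy_results = executor.map(run_in_parallel_dir, [-1])

# a failure in a worker is only raised once the results are iterated
error_seen = False
try:
    list(lazy_results)
except ValueError as err:
    error_seen = True
    print("raised on iteration:", err)
```

This is why it is worth iterating over the results of `executor.map` at some point: otherwise a failed ensemble member can go unnoticed.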

In [None]:
shutil.rmtree(import_dir)

### Check serial and parallel 
See that these runs gave the same results.

In [None]:
# check serial == parallel
serial_runs = sorted(serial_output_parent.glob("*"))
parallel_runs = sorted(parallel_output_parent.glob("*"))

for ss, pp in zip(serial_runs, parallel_runs):
    serial_files = sorted(ss.glob("*.nc"))
    parallel_files = sorted(pp.glob("*.nc"))
    for sf, pf in zip(serial_files, parallel_files):
        s_ds = xr.open_dataset(sf)
        p_ds = xr.open_dataset(pf)
        xr.testing.assert_allclose(s_ds, p_ds)
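One caveat with `zip` over two sorted lists is that it stops at the shorter list and pairs purely by position, so a missing output file could go unnoticed or shift the pairing. A minimal sketch of pairing by file name instead, using a hypothetical helper `pair_by_name` and empty placeholder files:

```python
import pathlib as pl
import tempfile


def pair_by_name(dir_a, dir_b, pattern="*.nc"):
    """Pair files with equal names; also return names found in only one dir."""
    files_a = {ff.name: ff for ff in dir_a.glob(pattern)}
    files_b = {ff.name: ff for ff in dir_b.glob(pattern)}
    common = sorted(files_a.keys() & files_b.keys())
    unmatched = sorted(files_a.keys() ^ files_b.keys())
    return [(files_a[nn], files_b[nn]) for nn in common], unmatched


with tempfile.TemporaryDirectory() as tmp:
    serial_dir = pl.Path(tmp) / "serial"
    parallel_dir = pl.Path(tmp) / "parallel"
    # placeholder files: "a.nc" is missing from the parallel directory
    for dd, names in ((serial_dir, ["a.nc", "b.nc"]), (parallel_dir, ["b.nc"])):
        dd.mkdir()
        for nn in names:
            (dd / nn).touch()
    pairs, unmatched = pair_by_name(serial_dir, parallel_dir)
    pair_names = [(sf.name, pf.name) for sf, pf in pairs]

print(pair_names, unmatched)  # [('b.nc', 'b.nc')] ['a.nc']
```

Asserting that `unmatched` is empty before comparing the paired datasets would make the serial versus parallel check stricter.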

We can also check that the original parameters give the same results as in notebook `01_multi-process_models.ipynb`. Ensemble member 005 has a multiplier of 1.0, so its parameters are the original ones.

In [None]:
run_005 = serial_output_parent / "run_params_005"
files_005 = sorted(run_005.glob("*.nc"))
for ff in files_005:
    if ff.name == "PRMSChannel_budget.nc":
        continue
    ds_005 = xr.open_dataset(ff)
    ds_01 = xr.open_dataset(
        pl.Path("01_multi-process_models/nhm_memory") / ff.name
    )
    xr.testing.assert_allclose(ds_005, ds_01)