In [None]:
# Import packages
import subprocess
from pathlib import Path

from hydroflows import Workflow
from hydroflows.log import setuplog
from hydroflows.methods import climate, wflow
from hydroflows.utils.example_data import fetch_data
from hydroflows.workflow.workflow_config import WorkflowConfig

# Set up logging through the hydroflows `setuplog` helper with level "INFO";
# the value it returns is kept as `logger`.
logger = setuplog(level="INFO")

In [None]:
# Fetch the climate build data ("cmip6-data") with the hydroflows example-data helper.
# The returned location is used further down to locate `data_catalog.yml`.
cmip6_dir = fetch_data(data="cmip6-data")

In [None]:
# Absolute path of the current working directory; parent of all case folders
pwd = Path(".").resolve()

# Sub-directories used by the workflow
model_dir = "models/wflow"  # wflow model directory (input)
clim_dir = "data/climatology"  # climatology data (intermediate results)

# The case directory is <pwd>/cases/<name>
name = "climate_discharge"  # for now
case_root = pwd / "cases" / name

In [None]:
# Fetch a pre-built wflow model into the model directory of this case
fetch_data(data="wflow-model", output_dir=case_root / model_dir, sub_dir=False)


In [None]:
# Set up the workflow configuration: region file, CMIP6 data catalog, the climate
# models and scenarios to process, the historical and future periods, the plotting
# switch and the climatology output directory.
config = WorkflowConfig(
    region=Path(model_dir) / "staticgeoms" / "region.geojson",
    catalog_path=Path(cmip6_dir) / "data_catalog.yml",
    cmip6_models=["NOAA-GFDL_GFDL-ESM4", "INM_INM-CM5-0", "CSIRO-ARCCSS_ACCESS-CM2"],
    cmip6_scenarios=["ssp245", "ssp585"],
    historical=[[2000, 2010]],
    future_horizons=[[2050, 2060], [2090, 2100]],
    plot_fig=True,
    clim_dir=clim_dir,
)

In [None]:
# Create a workflow called `name`, rooted at the case directory, using the config above
w = Workflow(config=config, name=name, root=case_root)
# Set wildcards: register the names "clim_models" and "clim_scenarios" with the CMIP6
# models and scenarios from the config. The rules below refer to them through the
# "{clim_models}" and "{clim_scenarios}" placeholders.
w.wildcards.set("clim_models", config.cmip6_models)
w.wildcards.set("clim_scenarios", config.cmip6_scenarios)

In [None]:
# Derive climate data statistics
# Historical climatology: the model is the "{clim_models}" wildcard placeholder, the
# scenario is fixed to "historical" and the period is read from `config.historical`.
# Region, catalog path and output directory are passed as references into the config.
hist_climatology = climate.MonthlyClimatolgy(
    region=w.get_ref("$config.region"),
    catalog_path=w.get_ref("$config.catalog_path"),
    model="{clim_models}",
    scenario="historical",
    horizon=w.get_ref("$config.historical"),
    output_dir=w.get_ref("$config.clim_dir"),
)
w.add_rule(hist_climatology, rule_id="hist_climatology")

# Future climatology: same method and inputs, but the scenario is the
# "{clim_scenarios}" wildcard placeholder and the periods are read from
# `config.future_horizons`.
future_climatology = climate.MonthlyClimatolgy(
    region=w.get_ref("$config.region"),
    catalog_path=w.get_ref("$config.catalog_path"),
    model="{clim_models}",
    scenario="{clim_scenarios}",
    horizon=w.get_ref("$config.future_horizons"),
    output_dir=w.get_ref("$config.clim_dir"),
)
w.add_rule(future_climatology, rule_id="future_climatology")

In [None]:
# Derive change factors from the statistics
# The inputs are the climatology outputs of the historical and future climatology
# methods defined earlier. `wildcard="horizons"` presumably names the wildcard that
# spans the future horizons; "{horizons}" is used in output names and directories
# further down (TODO confirm against the ClimateChangeFactors documentation).
change_factors = climate.ClimateChangeFactors(
    hist_climatology=hist_climatology.output.climatology,
    future_climatology=future_climatology.output.climatology,
    model="{clim_models}",
    scenario="{clim_scenarios}",
    horizon=w.get_ref("$config.future_horizons"),
    wildcard="horizons",
    output_dir=w.get_ref("$config.clim_dir"),
)
w.add_rule(change_factors, rule_id="change_factors")

In [None]:
# Create a model ensemble of the change factors
# The change-factor datasets are merged with `reduce_dim="model"` and `quantile=0.5`,
# presumably the median across climate models (NOTE(review): confirm). The output file
# name keeps the "{clim_scenarios}" and "{horizons}" placeholders.
change_factors_median = climate.MergeGriddedDatasets(
    datasets=change_factors.output.change_factors,
    reduce_dim="model",
    quantile=0.5,
    output_name="change_{clim_scenarios}_{horizons}_q50.nc",
    output_dir=w.get_ref("$config.clim_dir"),
)
w.add_rule(change_factors_median, rule_id="change_factors_median")

In [None]:
# Downscale the ensemble change factors to the wflow model resolution.
# The target grid is the model's staticmaps file; results are written to a
# simulation folder named after the "{clim_scenarios}_{horizons}" placeholders.
downscale = climate.DownscaleClimateDataset(
    dataset=change_factors_median.output.merged_dataset,
    target_grid=Path(model_dir) / "staticmaps.nc",
    output_dir=Path(model_dir) / "simulations" / "{clim_scenarios}_{horizons}",
)
w.add_rule(downscale, rule_id="downscale")

In [None]:
# Prepare the wflow config file: start from the default simulation's toml, point the
# forcing scale input at the downscaled change factors, set the end time, and write
# the result to the "{clim_scenarios}_{horizons}" simulation folder.
set_config = wflow.WflowConfig(
    wflow_toml=Path(model_dir) / "simulations" / "default" / "wflow_sbm.toml",
    ri_input__path_forcing_scale=downscale.output.downscaled,
    endtime="2014-01-31T00:00:00",
    output_dir=Path(model_dir) / "simulations" / "{clim_scenarios}_{horizons}",
)
w.add_rule(set_config, rule_id="set_config")

# Run the wflow model with the toml written by the previous rule, using the
# "script" run method and the given Julia run script.
wflow_run = wflow.WflowRun(
    wflow_toml=set_config.output.wflow_out_toml,
    run_method="script",
    wflow_run_script="run_wflow_change_factors.jl",
)
w.add_rule(wflow_run, rule_id="wflow_run")

In [None]:
# Test the workflow by calling its dry-run method before it is written to a Snakefile
w.dryrun()

In [None]:
# Write the workflow to a Snakefile
w.to_snakemake()

# Show the top 25 lines of the Snakefile.
# Iterating over the file (bounded by `zip` with a range) stops at end-of-file, so a
# Snakefile shorter than 25 lines is not padded with blank output lines, which is what
# repeated `readline()` calls past the end of the file would produce.
with open(w.root / "Snakefile", "r") as f:
    for _, line in zip(range(25), f):
        print(line.rstrip("\n"))

In [None]:
from IPython.display import SVG

# (test) run the workflow with snakemake and visualize the directed acyclic graph
# make sure to have snakemake and graphviz (`dot`) installed in your environment
#
# The two commands are run as separate processes instead of a shell pipeline:
# for `snakemake --dag | dot ...` the shell reports only the exit status of `dot`,
# so a failing snakemake call would go unnoticed. With `check=True` a non-zero exit
# status of either command raises `subprocess.CalledProcessError`.
dag_dot = subprocess.run(
    ["snakemake", "--dag"], cwd=w.root, stdout=subprocess.PIPE, check=True
)
# Feed the graph description written by snakemake to `dot`, which writes dag.svg
# into the workflow root.
subprocess.run(
    ["dot", "-Tsvg", "-o", "dag.svg"], input=dag_dot.stdout, cwd=w.root, check=True
)

# show the dag (pass the path explicitly as a file name rather than as SVG data)
SVG(filename=Path(w.root, "dag.svg").as_posix())