In [None]:
# Import packages
import subprocess
from itertools import islice
from pathlib import Path

from IPython.display import SVG

from hydroflows import Workflow
from hydroflows.log import setuplog
from hydroflows.methods.climate import (
    ClimateFactorsGridded,
    ClimateStatistics,
    DownscaleClimateDataset,
    MergeDatasets,
)
from hydroflows.methods.wflow import WflowConfig, WflowRun
from hydroflows.utils.example_data import fetch_data
from hydroflows.workflow.workflow_config import WorkflowConfig

# Configure hydroflows logging at INFO level
logger = setuplog(level="INFO")

In [None]:
# Fetch the example CMIP6 climate dataset; the returned directory contains the
# data catalog (data_catalog.yml) that the workflow configuration refers to.
cmip6_dir = fetch_data(data="cmip6-data")

In [None]:
# Resolve the current working directory; the case folder is created beneath it.
pwd = Path().resolve()

# Workflow name; it doubles as the case folder name under `cases/`.
name = "climate_discharge"

# Relative directory names used to lay out the case
# (`simu_dir` lives inside the wflow model folder).
model_dir, data_dir, simu_dir = "models", "data", "simulations"
input_dir = data_dir + "/input"
output_dir = data_dir + "/output"
stats_dir, change_dir, assemble_dir = (
    f"{input_dir}/{sub_dir}" for sub_dir in ("stats", "change", "assemble")
)

# Root folder of this case: <pwd>/cases/<name>
case_root = Path(pwd, "cases", name)

In [None]:
# Fetch a pre-built wflow model into <case_root>/<model_dir>/wflow.
# `case_root` is reused so the case location is defined in one place only.
wflow_data_dir = fetch_data(
    data="wflow-model",
    output_dir=Path(case_root, model_dir, "wflow"),
    sub_dir=False,
)

# Path of the model relative to the case root, so paths handed to the
# workflow (rooted at `case_root`) stay relative.
wflow_model_dir = wflow_data_dir.relative_to(case_root)

In [None]:
# Workflow configuration: the model region and the CMIP6 data catalog, the
# climate models and SSP scenarios to use, and the historical / future horizons.
conf = WorkflowConfig(
    region=wflow_data_dir / "staticgeoms" / "region.geojson",
    data_libs=[Path(cmip6_dir, "data_catalog.yml")],
    cmip6_models=["NOAA-GFDL_GFDL-ESM4", "INM_INM-CM5-0", "CSIRO-ARCCSS_ACCESS-CM2"],
    cmip6_scenarios=["ssp245", "ssp585"],
    historical=[[2000, 2010]],
    future_horizons=[[2050, 2060], [2090, 2100]],
    plot_fig=True,
)

In [None]:
# Create the workflow, rooted at the case folder
w = Workflow(config=conf, name=name, root=case_root)

# Expose the configured CMIP6 models and scenarios as workflow wildcards,
# referenced by the rules as "{models}" and "{scenarios}".
for wildcard_name, config_key in (
    ("models", "cmip6_models"),
    ("scenarios", "cmip6_scenarios"),
):
    w.wildcards.set(wildcard_name, w.get_ref(f"$config.{config_key}").value)

In [None]:
# Derive climate data statistics.
# Historical statistics per CMIP6 model (the "{models}" wildcard) for the
# configured historical period, with data root `stats_dir`.
hist_stats = ClimateStatistics(
    region=w.get_ref("$config.region"),
    data_libs=w.get_ref("$config.data_libs"),
    model="{models}",
    horizon=w.get_ref("$config.historical"),
    data_root=stats_dir,
)
w.add_rule(hist_stats, rule_id="hist_stats")

# Future statistics per model and SSP scenario (the "{models}" and
# "{scenarios}" wildcards) for the configured future horizons, same data root.
fut_stats = ClimateStatistics(
    region=w.get_ref("$config.region"),
    data_libs=w.get_ref("$config.data_libs"),
    model="{models}",
    scenario="{scenarios}",
    horizon=w.get_ref("$config.future_horizons"),
    data_root=stats_dir,
)
w.add_rule(fut_stats, rule_id="fut_stats")

In [None]:
# Derive change factors from the historical and future statistics, per model
# and scenario.
# `wildcard="horizons"` presumably defines the "{horizons}" wildcard (one value
# per future horizon) that the following rules use — TODO confirm in hydroflows docs.
change_factors = ClimateFactorsGridded(
    hist_stats.output.stats,
    fut_stats.output.stats,
    model="{models}",
    scenario="{scenarios}",
    horizon=w.get_ref("$config.future_horizons"),
    wildcard="horizons",
    data_root=change_dir,
)
w.add_rule(change_factors, rule_id="change_factors")

In [None]:
# Create a model ensemble of the change factors for each scenario and horizon.
# No "{models}" wildcard is passed here, so the per-model change factors are
# presumably merged into one dataset — TODO confirm.
# Data root: the wflow model's simulation folder "<scenario>_<horizon>".
ensemble = MergeDatasets(
    change_factors.output.change_factors,
    scenario="{scenarios}",
    horizon="{horizons}",
    data_root=Path(wflow_model_dir, simu_dir, "{scenarios}_{horizons}"),
)
w.add_rule(ensemble, rule_id="ensemble")

In [None]:
# Downscale the ensemble change factors to wflow model resolution; the model's
# staticmaps.nc is passed as `ds_like`, presumably the reference grid.
# Output lands in the same "<scenario>_<horizon>" simulation folder as the ensemble.
downscale = DownscaleClimateDataset(
    dataset=ensemble.output.merged,
    ds_like=wflow_data_dir / "staticmaps.nc",
    data_root=Path(wflow_model_dir, simu_dir, "{scenarios}_{horizons}"),
)
w.add_rule(downscale, rule_id="downscale")

In [None]:
# Prepare a wflow TOML config per scenario/horizon, starting from the default
# simulation config and pointing `ri_input__path_forcing_scale` at the
# downscaled change factors.
# NOTE(review): `endtime` is hard-coded here; consider moving it to the workflow config.
set_config = WflowConfig(
    wflow_toml=wflow_data_dir / simu_dir / "default" / "wflow_sbm.toml",
    ri_input__path_forcing_scale=downscale.output.downscaled,
    scenario="{scenarios}",
    horizon="{horizons}",
    endtime="2014-01-31T00:00:00",
    data_root=Path(wflow_model_dir, simu_dir, "{scenarios}_{horizons}"),
)
w.add_rule(set_config, rule_id="set_config")

# Run the wflow model with the generated config through the Julia run script
# (run_method="script").
wflow_run = WflowRun(
    wflow_toml=set_config.output.wflow_out_toml,
    run_method="script",
    wflow_run_script="run_wflow_change_factors.jl",
)
w.add_rule(wflow_run, rule_id="wflow_run")

In [None]:
# Check the workflow with a dry run before exporting it; presumably this only
# validates rule inputs/outputs and wildcards without running any model —
# confirm in the hydroflows documentation.
w.dryrun()

In [None]:
# Write the workflow to a Snakefile in the workflow root
w.to_snakemake()

# Show (at most) the first 25 lines of the Snakefile. islice stops at
# end-of-file, so a shorter file does not produce trailing blank lines.
with open(w.root / "Snakefile", "r") as snakefile:
    for line in islice(snakefile, 25):
        print(line.rstrip("\n"))

In [None]:
# Build the workflow's directed acyclic graph (DAG) with snakemake and render it
# to SVG with graphviz. `snakemake --dag` only computes the graph; no rule is
# executed. Requires snakemake and graphviz (`dot`) in the environment.
# The two commands run separately (no shell pipe), so a snakemake failure raises
# instead of being masked by the exit status of `dot`.
dag_dot = subprocess.run(
    ["snakemake", "--dag"], cwd=w.root, stdout=subprocess.PIPE, check=True
).stdout
subprocess.run(["dot", "-Tsvg", "-o", "dag.svg"], cwd=w.root, input=dag_dot, check=True)

# Show the DAG
SVG(filename=Path(w.root, "dag.svg").as_posix())