In [None]:
#!/usr/bin/env python
# coding: utf-8
# example usage: python compute_nhd_routing_SingleSeg.py -v -t -w -n Mainstems_CONUS
# python compute_nhd_routing_SingleSeg_v02.py --test -t -v --debuglevel 1
# python compute_nhd_routing_SingleSeg_v02.py --test-full-pocono -t -v --debuglevel 1


# -*- coding: utf-8 -*-
"""NHD Network traversal

A demonstration version of this code is stored in this Colaboratory notebook:
    https://colab.research.google.com/drive/1ocgg1JiOGBUl3jfSUPCEVnW5WNaqLKCD

"""
## Parallel execution
import os
import sys
import time
import numpy as np
import argparse
import pathlib
import glob
import pandas as pd
from collections import defaultdict
from functools import partial
from joblib import delayed, Parallel
from itertools import chain, islice
from operator import itemgetter


ENV_IS_CL = False
if ENV_IS_CL:
    root = pathlib.Path("/", "content", "t-route")
elif not ENV_IS_CL:
    root = pathlib.Path("../..").resolve()
    # sys.path.append(r"../python_framework_v02")

    # TODO: automate compile for the package scripts
    sys.path.append("fast_reach")

## network and reach utilities
import troute.nhd_network_utilities_v02 as nnu
import mc_reach
import troute.nhd_network as nhd_network
import troute.nhd_io as nhd_io
import build_tests  # TODO: Determine whether and how to incorporate this into setup.py


def writetoFile(file, writeString):
    file.write(writeString)
    file.write("\n")


def constant_qlats(index_dataset, nsteps, qlat):
    q = np.full((len(index_dataset.index), nsteps), qlat, dtype="float32")
    ql = pd.DataFrame(q, index=index_dataset.index, columns=range(nsteps))
    return ql


def compute_nhd_routing_v02(
    connections,
    rconn,
    reaches_bytw,
    compute_func,
    parallel_compute_method,
    subnetwork_target_size,
    cpu_pool,
    nts,
    qts_subdivisions,
    independent_networks,
    param_df,
    qlats,
    q0,
    assume_short_ts,
):

    start_time = time.time()
    if parallel_compute_method == "by-subnetwork-jit-clustered":
        networks_with_subnetworks_ordered_jit = nhd_network.build_subnetworks(
            connections, rconn, subnetwork_target_size
        )
        subnetworks_only_ordered_jit = defaultdict(dict)
        subnetworks = defaultdict(dict)
        for tw, ordered_network in networks_with_subnetworks_ordered_jit.items():
            intw = independent_networks[tw]
            for order, subnet_sets in ordered_network.items():
                subnetworks_only_ordered_jit[order].update(subnet_sets)
                for subn_tw, subnetwork in subnet_sets.items():
                    subnetworks[subn_tw] = {k: intw[k] for k in subnetwork}

        reaches_ordered_bysubntw = defaultdict(dict)
        for order, ordered_subn_dict in subnetworks_only_ordered_jit.items():
            for subn_tw, subnet in ordered_subn_dict.items():
                conn_subn = {k: connections[k] for k in subnet if k in connections}
                rconn_subn = {k: rconn[k] for k in subnet if k in rconn}
                path_func = partial(nhd_network.split_at_junction, rconn_subn)
                reaches_ordered_bysubntw[order][
                    subn_tw
                ] = nhd_network.dfs_decomposition(rconn_subn, path_func)

        cluster_threshold = 0.65  # When a job has a total segment count 65% of the target size, compute it
        # Otherwise, keep adding reaches.

        reaches_ordered_bysubntw_clustered = defaultdict(dict)

        for order in subnetworks_only_ordered_jit:
            cluster = 0
            reaches_ordered_bysubntw_clustered[order][cluster] = {
                "segs": [],
                "upstreams": {},
                "tw": [],
                "subn_reach_list": [],
            }
            for twi, (subn_tw, subn_reach_list) in enumerate(
                reaches_ordered_bysubntw[order].items(), 1
            ):
                segs = list(chain.from_iterable(subn_reach_list))
                reaches_ordered_bysubntw_clustered[order][cluster]["segs"].extend(segs)
                reaches_ordered_bysubntw_clustered[order][cluster]["upstreams"].update(
                    subnetworks[subn_tw]
                )

                reaches_ordered_bysubntw_clustered[order][cluster]["tw"].append(subn_tw)
                reaches_ordered_bysubntw_clustered[order][cluster][
                    "subn_reach_list"
                ].extend(subn_reach_list)

                if (
                    len(reaches_ordered_bysubntw_clustered[order][cluster]["segs"])
                    >= cluster_threshold * subnetwork_target_size
                ) and (
                    twi
                    < len(reaches_ordered_bysubntw[order])
                    # i.e., we haven't reached the end
                    # TODO: perhaps this should be a while condition...
                ):
                    cluster += 1
                    reaches_ordered_bysubntw_clustered[order][cluster] = {
                        "segs": [],
                        "upstreams": {},
                        "tw": [],
                        "subn_reach_list": [],
                    }

        if 1 == 1:
            print("JIT Preprocessing time %s seconds." % (time.time() - start_time))
            print("starting Parallel JIT calculation")

        start_para_time = time.time()
        # if 1 == 1:
        with Parallel(n_jobs=cpu_pool, backend="threading") as parallel:
            results_subn = defaultdict(list)
            flowveldepth_interorder = {}

            for order in range(max(subnetworks_only_ordered_jit.keys()), -1, -1):
                jobs = []
                for cluster, clustered_subns in reaches_ordered_bysubntw_clustered[
                    order
                ].items():
                    segs = clustered_subns["segs"]
                    offnetwork_upstreams = set()
                    segs_set = set(segs)
                    for seg in segs:
                        for us in rconn[seg]:
                            if us not in segs_set:
                                offnetwork_upstreams.add(us)

                    segs.extend(offnetwork_upstreams)
                    param_df_sub = param_df.loc[
                        segs, ["dt", "bw", "tw", "twcc", "dx", "n", "ncc", "cs", "s0"]
                    ].sort_index()
                    if order < max(subnetworks_only_ordered_jit.keys()):
                        for us_subn_tw in offnetwork_upstreams:
                            subn_tw_sortposition = param_df_sub.index.get_loc(
                                us_subn_tw
                            )
                            flowveldepth_interorder[us_subn_tw][
                                "position_index"
                            ] = subn_tw_sortposition
                    qlat_sub = qlats.loc[segs].sort_index()
                    q0_sub = q0.loc[segs].sort_index()
                    subn_reach_list = clustered_subns["subn_reach_list"]
                    upstreams = clustered_subns["upstreams"]

                    # results_subn[order].append(
                    #     compute_func(
                    jobs.append(
                        delayed(compute_func)(
                            nts,
                            qts_subdivisions,
                            subn_reach_list,
                            upstreams,
                            param_df_sub.index.values,
                            param_df_sub.columns.values,
                            param_df_sub.values,
                            qlat_sub.values,
                            q0_sub.values,
                            # flowveldepth_interorder,  # obtain keys and values from this dataset
                            {
                                us: fvd
                                for us, fvd in flowveldepth_interorder.items()
                                if us in offnetwork_upstreams
                            },
                            assume_short_ts,
                        )
                    )
                results_subn[order] = parallel(jobs)

                if order > 0:  # This is not needed for the last rank of subnetworks
                    flowveldepth_interorder = {}
                    for ci, (cluster, clustered_subns) in enumerate(
                        reaches_ordered_bysubntw_clustered[order].items()
                    ):
                        for subn_tw in clustered_subns["tw"]:
                            # TODO: This index step is necessary because we sort the segment index
                            # TODO: I think there are a number of ways we could remove the sorting step
                            #       -- the binary search could be replaced with an index based on the known topology
                            flowveldepth_interorder[subn_tw] = {}
                            subn_tw_sortposition = (
                                results_subn[order][ci][0].tolist().index(subn_tw)
                            )
                            flowveldepth_interorder[subn_tw]["results"] = results_subn[
                                order
                            ][ci][1][subn_tw_sortposition]
                            # what will it take to get just the tw FVD values into an array to pass to the next loop?
                            # There will be an empty array initialized at the top of the loop, then re-populated here.
                            # we don't have to bother with populating it after the last group

        results = []
        for order in subnetworks_only_ordered_jit:
            results.extend(results_subn[order])

        if 1 == 1:
            print("PARALLEL TIME %s seconds." % (time.time() - start_para_time))

    elif parallel_compute_method == "by-subnetwork-jit":
        networks_with_subnetworks_ordered_jit = nhd_network.build_subnetworks(
            connections, rconn, subnetwork_target_size
        )
        subnetworks_only_ordered_jit = defaultdict(dict)
        subnetworks = defaultdict(dict)
        for tw, ordered_network in networks_with_subnetworks_ordered_jit.items():
            intw = independent_networks[tw]
            for order, subnet_sets in ordered_network.items():
                subnetworks_only_ordered_jit[order].update(subnet_sets)
                for subn_tw, subnetwork in subnet_sets.items():
                    subnetworks[subn_tw] = {k: intw[k] for k in subnetwork}

        reaches_ordered_bysubntw = defaultdict(dict)
        for order, ordered_subn_dict in subnetworks_only_ordered_jit.items():
            for subn_tw, subnet in ordered_subn_dict.items():
                conn_subn = {k: connections[k] for k in subnet if k in connections}
                rconn_subn = {k: rconn[k] for k in subnet if k in rconn}
                path_func = partial(nhd_network.split_at_junction, rconn_subn)
                reaches_ordered_bysubntw[order][
                    subn_tw
                ] = nhd_network.dfs_decomposition(rconn_subn, path_func)

        if 1 == 1:
            print("JIT Preprocessing time %s seconds." % (time.time() - start_time))
            print("starting Parallel JIT calculation")

        start_para_time = time.time()
        with Parallel(n_jobs=cpu_pool, backend="threading") as parallel:
            results_subn = defaultdict(list)
            flowveldepth_interorder = {}

            for order in range(max(subnetworks_only_ordered_jit.keys()), -1, -1):
                jobs = []
                for twi, (subn_tw, subn_reach_list) in enumerate(
                    reaches_ordered_bysubntw[order].items(), 1
                ):
                    # TODO: Confirm that a list here is best -- we are sorting,
                    # so a set might be sufficient/better
                    segs = list(chain.from_iterable(subn_reach_list))
                    offnetwork_upstreams = set()
                    segs_set = set(segs)
                    for seg in segs:
                        for us in rconn[seg]:
                            if us not in segs_set:
                                offnetwork_upstreams.add(us)

                    segs.extend(offnetwork_upstreams)
                    param_df_sub = param_df.loc[
                        segs, ["dt", "bw", "tw", "twcc", "dx", "n", "ncc", "cs", "s0"]
                    ].sort_index()
                    if order < max(subnetworks_only_ordered_jit.keys()):
                        for us_subn_tw in offnetwork_upstreams:
                            subn_tw_sortposition = param_df_sub.index.get_loc(
                                us_subn_tw
                            )
                            flowveldepth_interorder[us_subn_tw][
                                "position_index"
                            ] = subn_tw_sortposition
                    qlat_sub = qlats.loc[segs].sort_index()
                    q0_sub = q0.loc[segs].sort_index()
                    jobs.append(
                        delayed(compute_func)(
                            nts,
                            qts_subdivisions,
                            subn_reach_list,
                            subnetworks[subn_tw],
                            param_df_sub.index.values,
                            param_df_sub.columns.values,
                            param_df_sub.values,
                            qlat_sub.values,
                            q0_sub.values,
                            # flowveldepth_interorder,  # obtain keys and values from this dataset
                            {
                                us: fvd
                                for us, fvd in flowveldepth_interorder.items()
                                if us in offnetwork_upstreams
                            },
                            assume_short_ts,
                        )
                    )

                results_subn[order] = parallel(jobs)

                if order > 0:  # This is not needed for the last rank of subnetworks
                    flowveldepth_interorder = {}
                    for twi, subn_tw in enumerate(reaches_ordered_bysubntw[order]):
                        # TODO: This index step is necessary because we sort the segment index
                        # TODO: I think there are a number of ways we could remove the sorting step
                        #       -- the binary search could be replaced with an index based on the known topology
                        flowveldepth_interorder[subn_tw] = {}
                        subn_tw_sortposition = (
                            results_subn[order][twi][0].tolist().index(subn_tw)
                        )
                        flowveldepth_interorder[subn_tw]["results"] = results_subn[
                            order
                        ][twi][1][subn_tw_sortposition]
                        # what will it take to get just the tw FVD values into an array to pass to the next loop?
                        # There will be an empty array initialized at the top of the loop, then re-populated here.
                        # we don't have to bother with populating it after the last group

        results = []
        for order in subnetworks_only_ordered_jit:
            results.extend(results_subn[order])

        if 1 == 1:
            print("PARALLEL TIME %s seconds." % (time.time() - start_para_time))

    elif parallel_compute_method == "by-network":
        with Parallel(n_jobs=cpu_pool, backend="threading") as parallel:
            jobs = []
            for twi, (tw, reach_list) in enumerate(reaches_bytw.items(), 1):
                segs = list(chain.from_iterable(reach_list))
                param_df_sub = param_df.loc[
                    segs, ["dt", "bw", "tw", "twcc", "dx", "n", "ncc", "cs", "s0"]
                ].sort_index()
                qlat_sub = qlats.loc[segs].sort_index()
                q0_sub = q0.loc[segs].sort_index()
                jobs.append(
                    delayed(compute_func)(
                        nts,
                        qts_subdivisions,
                        reach_list,
                        independent_networks[tw],
                        param_df_sub.index.values,
                        param_df_sub.columns.values,
                        param_df_sub.values,
                        qlat_sub.values,
                        q0_sub.values,
                        {},
                        assume_short_ts,
                    )
                )
            results = parallel(jobs)

    else:  # Execute in serial
        results = []
        for twi, (tw, reach_list) in enumerate(reaches_bytw.items(), 1):
            segs = list(chain.from_iterable(reach_list))
            param_df_sub = param_df.loc[
                segs, ["dt", "bw", "tw", "twcc", "dx", "n", "ncc", "cs", "s0"]
            ].sort_index()
            qlat_sub = qlats.loc[segs].sort_index()
            q0_sub = q0.loc[segs].sort_index()
            
            if tw==8777215:
                print(f"qts_subdivisions:{qts_subdivisions}")
                print(f"tw:{tw} reach_list{reach_list}")
                print(f"param_df_sub.index.values: {param_df_sub.index.values}")
                print(f"param_df_sub.columns.values: {param_df_sub.columns.values}")
                print(f"param_df_sub.values: {param_df_sub.values}")
                print(f"qlat_sub.values: {qlat_sub.values}")
                print()
                print(f"independent_networks[tw]: {independent_networks[tw]}")
                print(f"param_df: {param_df}")
                print(f"qlats: {qlats}")
                
                results.append(
                    compute_func(
                        nts,
                        qts_subdivisions,
                        reach_list,
                        independent_networks[tw],
                        param_df_sub.index.values,
                        param_df_sub.columns.values,
                        param_df_sub.values,
                        qlat_sub.values,
                        q0_sub.values,
                        {},
                        assume_short_ts,
                    )
                )

    return results


def _input_handler():

    #args = _handle_args()
    #custom_input_file = args.custom_input_file
    custom_input_file="../../test/input/yaml/CustomInput_florence_933020089_dt300.yaml"
    
    supernetwork_parameters = None
    waterbody_parameters = {}
    forcing_parameters = {}
    restart_parameters = {}
    output_parameters = {}
    run_parameters = {}
    parity_parameters = {}

    if custom_input_file:
        (
            supernetwork_parameters,
            waterbody_parameters,
            forcing_parameters,
            restart_parameters,
            output_parameters,
            run_parameters,
            parity_parameters,
        ) = nhd_io.read_custom_input(custom_input_file)
        run_parameters["debuglevel"] *= -1

    return (
        supernetwork_parameters,
        waterbody_parameters,
        forcing_parameters,
        restart_parameters,
        output_parameters,
        run_parameters,
        parity_parameters,
    )

In [10]:
if 1==1:
#def main():

    (
        supernetwork_parameters,
        waterbody_parameters,
        forcing_parameters,
        restart_parameters,
        output_parameters,
        run_parameters,
        parity_parameters,
    ) = _input_handler()

    dt = run_parameters.get("dt", None)
    nts = run_parameters.get("nts", None)
    verbose = run_parameters.get("verbose", None)
    showtiming = run_parameters.get("showtiming", None)
    debuglevel = run_parameters.get("debuglevel", 0)

    if verbose:
        print("creating supernetwork connections set")
    if showtiming:
        start_time = time.time()

    # STEP 1: Build basic network connections graph
    connections, wbodies, param_df = nnu.build_connections(supernetwork_parameters, dt)
    
    #print(f"connections: {connections}")
    #print()
    #print(f"param_df: {param_df}")
    
    if verbose:
        print("supernetwork connections set complete")
    if showtiming:
        print("... in %s seconds." % (time.time() - start_time))

    # STEP 2: Identify Independent Networks and Reaches by Network
    if showtiming:
        start_time = time.time()
    if verbose:
        print("organizing connections into reaches ...")

    independent_networks, reaches_bytw, rconn = nnu.organize_independent_networks(
        connections
    )

    if verbose:
        print("reach organization complete")
    if showtiming:
        print("... in %s seconds." % (time.time() - start_time))

creating supernetwork connections set
supernetwork connections set complete
... in 0.23984527587890625 seconds.
organizing connections into reaches ...
reach organization complete
... in 0.004957675933837891 seconds.


In [17]:
#rconn
len(rconn)
rconn.items()

dict_items([(8699069, [8699159]), (8699073, [8699079]), (8699075, [8699253]), (8699077, []), (8699161, [8699077]), (8699079, [8699083]), (8699081, [8699251]), (8699163, [8699081]), (8699083, [8699247]), (8699085, [8699165]), (8699159, []), (8699165, [8699161, 8699163]), (8699245, []), (8699247, []), (8699249, []), (8699251, []), (8699253, []), (8700889, []), (8701483, [8700889]), (8700891, [8701549]), (8700899, []), (8700901, [8700899, 8700997]), (8701475, [8700901]), (8700931, [8701551]), (8700935, [8701495]), (8700943, []), (8701487, [8700943]), (8700945, []), (8700951, [8701553]), (8700967, []), (8701365, [8700967, 8700973]), (8700973, [8700981, 8700985]), (8700981, [8700989, 8701371]), (8700983, []), (8700985, [8700991, 8701043]), (8700989, [8701033, 8701071]), (8700991, []), (8700997, [8701555]), (8701033, [8701105, 8701219]), (8701037, []), (8701045, [8701037]), (8701041, [8701557]), (8701371, [8701041, 8701063]), (8701043, [8701045, 8701051]), (8701051, [8701499]), (8701063, [87

In [15]:
njt=sum[len(r) for i, r in rconn.items() if len(r) >1]

SyntaxError: invalid syntax (<ipython-input-15-82681cfeae92>, line 1)

In [9]:
    # STEP 4: Handle Channel Initial States
    if showtiming:
        start_time = time.time()
    if verbose:
        print("setting channel initial states ...")

    q0 = nnu.build_channel_initial_state(restart_parameters, param_df.index)

    if verbose:
        print("channel initial states complete")
    if showtiming:
        print("... in %s seconds." % (time.time() - start_time))
        start_time = time.time()

    # STEP 5: Read (or set) QLateral Inputs
    if showtiming:
        start_time = time.time()
    if verbose:
        print("creating qlateral array ...")

    qlats = nnu.build_qlateral_array(forcing_parameters, connections.keys(), nts)

    if verbose:
        print("qlateral array complete")
    if showtiming:
        print("... in %s seconds." % (time.time() - start_time))

    ################### Main Execution Loop across ordered networks
    if showtiming:
        main_start_time = time.time()
    if verbose:
        print(f"executing routing computation ...")

    if run_parameters.get("compute_method", None) == "standard cython compute network":
        compute_func = mc_reach.compute_network
    else:
        compute_func = mc_reach.compute_network

    results = compute_nhd_routing_v02(
        connections,
        rconn,
        reaches_bytw,
        compute_func,
        run_parameters.get("parallel_compute_method", None),
        run_parameters.get("subnetwork_target_size", 1),
        # The default here might be the whole network or some percentage...
        run_parameters.get("cpu_pool", None),
        run_parameters.get("nts", 1),
        run_parameters.get("qts_subdivisions", 1),
        independent_networks,
        param_df,
        qlats,
        q0,
        run_parameters.get("assume_short_ts", False),
    )

    csv_output_folder = output_parameters.get("csv_output_folder", None)
    if (debuglevel <= -1) or csv_output_folder:
        qvd_columns = pd.MultiIndex.from_product(
            [range(nts), ["q", "v", "d"]]
        ).to_flat_index()
        flowveldepth = pd.concat(
            [pd.DataFrame(d, index=i, columns=qvd_columns) for i, d in results],
            copy=False,
        )

        if csv_output_folder:
            flowveldepth = flowveldepth.sort_index()
            output_path = pathlib.Path(csv_output_folder).resolve()
            flowveldepth.to_csv(output_path.joinpath(f"{args.supernetwork}.csv"))

        if debuglevel <= -1:
            print(flowveldepth)

    if verbose:
        print("ordered reach computation complete")
    if showtiming:
        print("... in %s seconds." % (time.time() - start_time))

    #if "parity_check_input_folder" in parity_parameters:
    if not "parity_check_input_folder" in parity_parameters:

        if verbose:
            print(
                "conducting parity check, comparing WRF Hydro results against t-route results"
            )

        build_tests.parity_check(
            parity_parameters, run_parameters["nts"], run_parameters["dt"], results,
        )


if __name__ == "__main__":
    main()

creating supernetwork connections set
supernetwork connections set complete
... in 0.07902908325195312 seconds.
organizing connections into reaches ...
reach organization complete
... in 0.004800081253051758 seconds.
setting channel initial states ...
channel initial states complete
... in 0.02571272850036621 seconds.
creating qlateral array ...
qlateral array complete
... in 0.5028712749481201 seconds.
executing routing computation ...
qts_subdivisions:12
tw:8777215 reach_list[[8778703, 8777129], [8777131], [8777143], [8778699, 8777145], [8777159, 8777215]]
param_df_sub.index.values: [8777129 8777131 8777143 8777145 8777159 8777215 8778699 8778703]
param_df_sub.columns.values: ['dt' 'bw' 'tw' 'twcc' 'dx' 'n' 'ncc' 'cs' 's0']
param_df_sub.values: [[3.0000000e+02 1.6090509e+00 2.6817515e+00 8.0452538e+00 6.5200000e+02
  5.9999999e-02 1.2000000e-01 7.5519782e-01 1.9280000e-02]
 [3.0000000e+02 1.5855994e+00 2.6426656e+00 7.9279966e+00 1.6530000e+03
  5.9999999e-02 1.2000000e-01 7.6010537e