In [None]:
import geopandas as gpd
import numpy as np
import pandas as pd
import time
import re
from joblib import Parallel, delayed, parallel_config, parallel_backend

In [None]:
import os 
import sys 
# Put the parent directory on sys.path so the project modules imported below
# (stream_cat_config, StreamCat_functions) can be resolved -- presumably they
# live one level above this notebook; TODO confirm.
# The membership check keeps re-running this cell from appending duplicates.
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from stream_cat_config import (
    LYR_DIR,
    MASK_DIR_RP100,
    MASK_DIR_SLP10,
    MASK_DIR_SLP20,
    ACCUM_DIR,
    NHD_DIR,
    OUT_DIR,
    PCT_FULL_FILE,
    PCT_FULL_FILE_RP100
)

from StreamCat_functions import (
    PointInPoly,
    createCatStats,
    mask_points
)

In [None]:
from dask.distributed import Client, LocalCluster
# Start a local Dask cluster (8 workers, 2 threads per worker) and connect a
# client to it; `client` is passed to the zone-processing functions in later cells.
cluster = LocalCluster(n_workers=8, threads_per_worker=2)
client = Client(cluster)
# Last expression of the cell: shows the dashboard URL for monitoring the cluster.
client.dashboard_link

In [None]:
# Control table: the rows with run == 1 are the ones processed by the loops below.
ctl = pd.read_csv('../ControlTable_StreamCat.csv')
# Not referenced by the cells visible in this notebook section.
inter_vpu = pd.read_csv("../InterVPU.csv")
# `.item()` unwraps the single Python object stored in the .npy file; later cells
# iterate it with `.items()` as zone -> hydroregion pairs.
# NOTE(review): allow_pickle=True unpickles arbitrary objects, which can execute
# code on load -- only use with a vpu_inputs.npy file from a trusted source.
INPUTS = np.load(ACCUM_DIR +"/vpu_inputs.npy", allow_pickle=True).item()

# Not referenced by the cells visible in this notebook section.
already_processed = []

In [None]:
# Sanity check before the long run: show the zone -> hydroregion mapping, how
# many zones it holds, and how many control-table rows are flagged run == 1.
print(INPUTS)
print(len(INPUTS))
print(int((ctl["run"] == 1).sum()))

In [None]:
# Scratch: commented-out sketches of running the work through joblib's dask
# backend with `parallel_config`. Nothing in this cell executes.
#
# with parallel_config(backend="dask"):
#     Parallel(verbose=100)(
#         delayed(long_running_function)(i) for i in range(10)
#     )
#
# with parallel_config(backend="dask"):
#     zone_results = Parallel(verbose=100)(
#         delayed(process_zone)(zone, hydroregion) for zone, hydroregion in INPUTS.items()
#     )
#
# with parallel_config(backend="dask", verbose=100):

In [None]:
N_JOBS = 8  # n_jobs value handed to joblib.Parallel for every control-table row


def process_zone(zone, hydroregion, client):
    """Compute and save catchment statistics for one zone of the current ``row``.

    Defined once, outside the loop, instead of being redefined per iteration.
    It reads the notebook-level names that the loop below assigns for the row
    being processed: ``row``, ``layer``, ``mask_dir``, ``apm`` and, for Point
    rows, ``points``, ``pct_full`` and ``summary``.

    Parameters
    ----------
    zone : key of ``INPUTS``; used in the NHDPlus path and the output file name.
    hydroregion : value of ``INPUTS`` for ``zone``; used in the NHDPlus path.
    client : Dask client forwarded to ``createCatStats`` as ``dask_client``.

    Returns
    -------
    None. Writes ``{OUT_DIR}/{row.FullTableName}_{zone}.csv`` unless that file
    already exists, in which case the zone is skipped.
    """
    zone_start_time = time.time()
    out_csv = f"{OUT_DIR}/{row.FullTableName}_{zone}.csv"
    if os.path.exists(out_csv):
        # Output from a previous run is present: skip so reruns resume cheaply.
        return
    print(zone, end=", ")
    pre = f"{NHD_DIR}/NHDPlus{hydroregion}/NHDPlus{zone}"
    if row.accum_type == "Point":
        izd = f"{pre}/NHDPlusCatchment/Catchment.shp"
        # NOTE(review): the process_zone defined in a later cell also passes
        # dask_client=client to PointInPoly; confirm which call matches the
        # current PointInPoly signature.
        cat = PointInPoly(
            points, zone, izd, pct_full, mask_dir, apm, summary, use_dask=True
        )
    else:
        # With a mask, the zone raster comes from the mask directory;
        # otherwise the NHDPlus catchment grid is used.
        izd = (
            f"{mask_dir}/{zone}.tif"
            if mask_dir
            else f"{pre}/NHDPlusCatchment/cat"
        )
        cat = createCatStats(
            row.accum_type,
            layer,
            izd,
            OUT_DIR,
            zone,
            row.by_RPU,
            mask_dir,
            NHD_DIR,
            hydroregion,
            apm,
            use_dask=True,
            dask_client=client
        )
    cat.to_csv(out_csv, index=False)
    zone_end_time = time.time()
    print(f"Time to finish processing stats for zone {zone} / region {hydroregion} {(zone_end_time - zone_start_time) / 60} minutes")


for _, row in ctl.query("run == 1").iterrows():

    apm = "" if row.AppendMetric == "none" else row.AppendMetric
    # use_mask selects the mask directory: 1 -> RP100, 2 -> SLP10, 3 -> SLP20;
    # any other value means no mask.
    mask_dir = {
        1: MASK_DIR_RP100,
        2: MASK_DIR_SLP10,
        3: MASK_DIR_SLP20,
    }.get(row.use_mask, "")
    # The layer file name is rebuilt from the first run of digits in LandscapeLayer.
    # NOTE(review): this is applied to every row, including Point rows whose layer
    # is then read with gpd.read_file -- confirm that is intended.
    landscape_layer_year = re.findall(r'\d+', row.LandscapeLayer)
    if not landscape_layer_year:
        # Fail with a clear message instead of a bare IndexError on [0].
        raise ValueError(
            f"No digits found in LandscapeLayer `{row.LandscapeLayer}` "
            f"for `{row.FullTableName}`; cannot build the layer file name"
        )
    actual_layer_name = f"Annual_NLCD_LndCov_{landscape_layer_year[0]}_CU_C1V0.tif"
    # NOTE(review): when LandscapeLayer already contains a path separator the
    # bare file name is used without any directory -- confirm that is intended.
    layer = (
        actual_layer_name
        if "/" in row.LandscapeLayer or "\\" in row.LandscapeLayer
        else (f"{LYR_DIR}/{actual_layer_name}")
    )  # use abspath
    # print(layer)
    if isinstance(row.summaryfield, str):
        summary = row.summaryfield.split(";")
    else:
        summary = None
    if row.accum_type == "Point":
        # Load in point geopandas table and Pct_Full table
        # TODO: script to create this PCT_FULL_FILE
        pct_full = pd.read_csv(
            PCT_FULL_FILE if row.use_mask == 0 else PCT_FULL_FILE_RP100
        )
        points = gpd.read_file(layer)
        if mask_dir:
            points = mask_points(points, mask_dir, INPUTS)
    # File string to store InterVPUs needed for adjustments
    Connector = f"{OUT_DIR}/{row.FullTableName}_connectors.csv"
    print(
        f"Acquiring `{row.FullTableName}` catchment statistics...",
        end="",
        flush=True,
    )
    start_time = time.time()

    # One task per zone, dispatched through joblib's dask backend.
    with parallel_backend('dask'):
        zone_results = Parallel(n_jobs=N_JOBS, verbose=100)(
            delayed(process_zone) (zone, hydroregion, client) for zone, hydroregion in INPUTS.items()
        )
    end_time = time.time()
    # Report the requested n_jobs rather than os.cpu_count(), which is unrelated
    # to how many jobs were asked for.
    print(f"Processed {len(INPUTS)} zones in {end_time - start_time} seconds (n_jobs={N_JOBS}, dask backend)")
    print("done!")

In [None]:
def process_zone(zone, hydroregion, client):
    """Write the catchment-statistics CSV for one zone of the current ``row``.

    Reads the notebook-level names set by the processing loop (``row``,
    ``layer``, ``mask_dir``, ``apm`` and, for Point rows, ``points``,
    ``pct_full``, ``summary``). ``client`` is forwarded as ``dask_client`` to
    ``createCatStats`` / ``PointInPoly``. Returns None; a zone whose output
    file already exists is skipped.
    """
    out_csv = f"{OUT_DIR}/{row.FullTableName}_{zone}.csv"
    if os.path.exists(out_csv):
        return

    print(zone, end=", ")
    pre = f"{NHD_DIR}/NHDPlus{hydroregion}/NHDPlus{zone}"
    if row.accum_type == "Point":
        # Point layers are intersected with the catchment polygons.
        izd = f"{pre}/NHDPlusCatchment/Catchment.shp"
        cat = PointInPoly(
            points, zone, izd, pct_full, mask_dir, apm, summary, use_dask=True, dask_client=client
        )
    else:
        # Use the per-zone mask raster when a mask directory is set,
        # otherwise the NHDPlus catchment grid.
        if mask_dir:
            izd = f"{mask_dir}/{zone}.tif"
        else:
            izd = f"{pre}/NHDPlusCatchment/cat"
        cat = createCatStats(
            row.accum_type,
            layer,
            izd,
            OUT_DIR,
            zone,
            row.by_RPU,
            mask_dir,
            NHD_DIR,
            hydroregion,
            apm,
            use_dask=True,
            dask_client=client,
        )
    cat.to_csv(out_csv, index=False)

# Step 3: Combine Joblib and Dask
for _, row in ctl.query("run == 1").iterrows():
    # Metric suffix: the literal "none" in the control table means no suffix.
    if row.AppendMetric == "none":
        apm = ""
    else:
        apm = row.AppendMetric

    # use_mask selects the mask directory; any value other than 1/2/3 means no mask.
    mask_dir = {
        1: MASK_DIR_RP100,
        2: MASK_DIR_SLP10,
        3: MASK_DIR_SLP20,
    }.get(row.use_mask, "")

    # Layer file name is built from the first run of digits in LandscapeLayer.
    landscape_layer_year = re.findall(r'\d+', row.LandscapeLayer)
    actual_layer_name = f"Annual_NLCD_LndCov_{landscape_layer_year[0]}_CU_C1V0.tif"
    has_path_sep = "/" in row.LandscapeLayer or "\\" in row.LandscapeLayer
    layer = actual_layer_name if has_path_sep else f"{LYR_DIR}/{actual_layer_name}"

    # summaryfield holds a ";"-separated list when it is a string.
    summary = row.summaryfield.split(";") if isinstance(row.summaryfield, str) else None

    if row.accum_type == "Point":
        # Point rows need the Pct_Full table and the point layer itself.
        pct_full_path = PCT_FULL_FILE if row.use_mask == 0 else PCT_FULL_FILE_RP100
        pct_full = pd.read_csv(pct_full_path)
        points = gpd.read_file(layer)
        if mask_dir:
            points = mask_points(points, mask_dir, INPUTS)

    Connector = f"{OUT_DIR}/{row.FullTableName}_connectors.csv"
    print(f"Acquiring `{row.FullTableName}` catchment statistics...", end="", flush=True)

    # One process_zone task per zone, dispatched through joblib's dask backend.
    with parallel_backend('dask'):
        start_time = time.time()
        zone_tasks = (
            delayed(process_zone)(zone, hydroregion, client)
            for zone, hydroregion in INPUTS.items()
        )
        zone_results = Parallel(n_jobs=8, verbose=100)(zone_tasks)
        end_time = time.time()
        print(f"Processed {len(INPUTS)} zones in {end_time - start_time} seconds using Joblib + Dask")
        print("done!")