In [None]:
from pathlib import Path

import duckdb
import xarray as xr
from dask.distributed import Client, LocalCluster
from dask_gateway import Gateway
import dask.dataframe as dd
import dask
import numpy as np
import pandas as pd
from numba import jit, njit

from utils.const import LOCAL_JOINED_FILEPATH, LOCAL_DATABASE_FILEPATH, S3_JOINED_FILEPATH, S3_DATABASE_FILEPATH
from utils.numpy_metric_funcs import r_squared, relative_bias, kling_gupta_efficiency, root_mean_squared_error

In [None]:
# TODO: this should ultimately include measurement_unit and variable_name.
# One row per (primary_location_id, configuration) pair in the joined timeseries.
# GROUP BY already yields distinct pairs, so DISTINCT is unnecessary; ordering by
# both keys makes the row order (and any "first N groups" subset) deterministic.
query = f"""
    SELECT
        primary_location_id, configuration
    FROM
        read_parquet('{str(LOCAL_JOINED_FILEPATH)}')
    GROUP BY primary_location_id, configuration
    ORDER BY primary_location_id, configuration
;"""
groups_df = duckdb.sql(query).to_df()
groups_df.head()

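## Threaded DuckDB tests

Read one (primary_location_id, configuration) group per thread from the local DuckDB database, giving each thread its own cursor on a shared connection.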

In [None]:
import duckdb
from threading import Thread, current_thread
import random

# Shared read/write connection to the local DuckDB database file; reader
# threads open their own cursors from it instead of using it directly.
duckdb_con = duckdb.connect(database=LOCAL_DATABASE_FILEPATH, read_only=False)

In [None]:
def read_from_thread(duckdb_con, arr, results=None):
    """Fetch one group's primary/secondary values on a dedicated DuckDB cursor.

    Intended as a ``threading.Thread`` target: each call opens its own cursor
    from the shared connection, queries ``joined_timeseries`` for a single
    (primary_location_id, configuration) group, and records the result.

    Parameters
    ----------
    duckdb_con : duckdb.DuckDBPyConnection
        Shared connection; only ``.cursor()`` is called on it.
    arr : tuple or namedtuple
        Either a ``(primary_location_id, configuration)`` pair, or a row from
        ``DataFrame.itertuples()`` exposing ``primary_location_id`` and
        ``configuration`` attributes. Such rows hold the DataFrame index at
        position 0, so they are read by attribute, not by position.
    results : list, optional
        If given, the fetched dict is appended to it so several threads can
        collect into one list.

    Returns
    -------
    dict
        The ``fetchnumpy()`` mapping with keys ``thread_name``,
        ``primary_value`` and ``secondary_value``.
    """
    # itertuples() rows start with the index, so prefer named fields when present.
    if hasattr(arr, "primary_location_id"):
        primary_location_id, configuration = arr.primary_location_id, arr.configuration
    else:
        primary_location_id, configuration = arr[0], arr[1]

    thread_name = current_thread().name

    # All values are bound as parameters rather than spliced into the SQL text.
    group_timeseries_query = """
    SELECT
        ? AS thread_name,
        primary_value,
        secondary_value
    FROM
        joined_timeseries
    WHERE primary_location_id = ? AND configuration = ?
    ;"""

    # Create a DuckDB connection (a cursor on duckdb_con) specifically for this
    # thread, and always release it.
    local_con = duckdb_con.cursor()
    try:
        results_dict = local_con.execute(
            group_timeseries_query,
            (thread_name, str(primary_location_id), str(configuration)),
        ).fetchnumpy()
    finally:
        local_con.close()

    if results is not None:
        results.append(results_dict)
    return results_dict

In [None]:
# Build (but do not start) one reader thread per group, capped at
# MAX_READ_THREADS groups so this stays a small smoke test.
MAX_READ_THREADS = 100

threads = []
results = []  # passed to every thread so each can append its fetched rows

for i, tpl in enumerate(groups_df.head(MAX_READ_THREADS).itertuples()):
    threads.append(
        Thread(
            target=read_from_thread,
            # Pass the two keys explicitly: itertuples() rows start with the index.
            args=(duckdb_con, (tpl.primary_location_id, tpl.configuration), results),
            name=f"read_thread_{i}",
        )
    )

In [None]:
# Kick off all threads in parallel, then wait for every one to finish so that
# `results` is complete before any later cell reads it.
for thread in threads:
    thread.start()

for thread in threads:
    thread.join()

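## Dask-delayed + DuckDB tests

Wrap the per-group fetch-and-score step in `dask.delayed`. First run it serially as a baseline ("NO DASK"), then in parallel on a Dask cluster.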

In [None]:
# CON = duckdb.connect(LOCAL_DATABASE_FILEPATH, read_only=True)
# CON.query("SHOW ALL TABLES;")
# CON = duckdb.connect()

In [None]:
# LOCAL_JOINED_FILEPATH

In [None]:
@dask.delayed
def fetch_group_and_calculate_metrics(connection, primary_location_id: str, configuration: str) -> dict:
    """Fetch data for a single group and calculate metrics using numpy.

    Reads ``primary_value``/``secondary_value`` for one
    (primary_location_id, configuration) group from the joined Parquet file,
    then computes r_squared, relative_bias, kling_gupta_efficiency and
    root_mean_squared_error on them.

    Because of ``@dask.delayed``, calling this returns a lazy ``Delayed``;
    the dict is only produced by ``.compute()`` / ``dask.compute``.

    Parameters
    ----------
    connection : duckdb.DuckDBPyConnection
        Used only to open a per-call cursor; the data itself comes from
        ``read_parquet``, not from the connection's database.
    primary_location_id, configuration : str
        Group keys to filter on.

    Returns
    -------
    dict
        ``primary_location_id``, ``configuration`` and the four metric values.
        All four metrics are NaN when either value column comes back masked.
    """
    # The Parquet path is a project constant; the group keys are bound as
    # query parameters instead of being spliced into the SQL text.
    group_timeseries_query = f"""
    SELECT
        primary_value,
        secondary_value
    FROM
        read_parquet('{str(LOCAL_JOINED_FILEPATH)}')
    WHERE primary_location_id = ? AND configuration = ?
    ;"""

    # Own cursor per call; close it even if the query raises.
    local_con = connection.cursor()
    try:
        value_arr_dict = local_con.execute(
            group_timeseries_query, [str(primary_location_id), str(configuration)]
        ).fetchnumpy()
    finally:
        local_con.close()

    primary = value_arr_dict["primary_value"]
    secondary = value_arr_dict["secondary_value"]

    output = {"primary_location_id": primary_location_id, "configuration": configuration}

    # NOTE(review): fetchnumpy appears to return masked arrays when a column has
    # NULLs (original author was unsure whether NaN also triggers it) — confirm.
    # Any masked entry makes every metric NaN for this group.
    if np.ma.is_masked(primary) or np.ma.is_masked(secondary):
        output.update(
            r_squared=np.nan,
            relative_bias=np.nan,
            kling_gupta_efficiency=np.nan,
            root_mean_squared_error=np.nan,
        )
        return output

    output.update(
        r_squared=r_squared(primary, secondary),
        relative_bias=relative_bias(primary, secondary),
        kling_gupta_efficiency=kling_gupta_efficiency(primary, secondary),
        root_mean_squared_error=root_mean_squared_error(primary, secondary),
    )
    return output

In [None]:
# Smoke-test one group. The function is dask.delayed, so .compute() is needed to
# run it; an in-memory connection suffices because the data is read via read_parquet.
with duckdb.connect() as smoke_test_con:
    smoke_test_metrics = fetch_group_and_calculate_metrics(
        smoke_test_con, "usgs-01010000", "nwm20_retrospective"
    ).compute()
smoke_test_metrics

In [None]:
%%time
# NO DASK

results = []
for i, tpl in enumerate(groups_df.itertuples()):
    results.append(fetch_group_and_calculate_metrics(tpl.primary_location_id, tpl.configuration))

    if i == 30:
        break

In [None]:
# Tabulate `results` from the serial run in the cell above, one row per element.
pd.DataFrame(results)

In [None]:
# Shut down a cluster left over from an earlier run; on a fresh kernel no
# `cluster` exists yet, so skip instead of raising NameError.
if "cluster" in globals():
    cluster.close()

In [None]:
# Start a local Dask cluster and attach a client to it, then display the client.
cluster = LocalCluster()
client = Client(address=cluster)
client

In [None]:
# Tear down the local cluster: disconnect the client first, while its
# scheduler is still alive, then close the cluster itself.
client.close()
cluster.close()

In [None]:
# Request a Dask Gateway cluster with small single-core workers.
WORKER_CORES = 1
WORKER_MEMORY = 4  # units are defined by the gateway's cluster options (presumably GiB — confirm)

gateway = Gateway()
options = gateway.cluster_options()
# Displaying `options` should show an interactive widget for choosing cores, memory, etc.
options.worker_cores = WORKER_CORES
options.worker_memory = WORKER_MEMORY

cluster = gateway.new_cluster(options)
client = cluster.get_client()
client

In [None]:
N_WORKERS = 8
# Scale the current `cluster` to N_WORKERS workers.
cluster.scale(N_WORKERS)

In [None]:
# con = duckdb.connect(LOCAL_DATABASE_FILEPATH, read_only=True)
# con.sql("SET memory_limit = '20GB';")
# con.sql("SET max_memory = '20GB';")
con = duckdb.connect(database=":memory:")  # in-memory DB; tasks read the Parquet file directly

In [None]:
%%time
# WITH DASK DELAYED

results = []
for i, tpl in enumerate(groups_df.itertuples()):
    results.append(fetch_group_and_calculate_metrics(con, tpl.primary_location_id, tpl.configuration))

    # if i == 10:
    #     break

output = dask.compute(results)
results_df = pd.DataFrame(output[0])

In [None]:
# Display the per-group metrics produced by the dask run.
results_df

In [None]:
# Persist the per-group metrics, creating the results directory if it is missing.
# TODO: move this path next to the other paths in utils.const.
RESULTS_FILEPATH = Path(
    "/data/benchmarks/teehr-benchmark-202404/results/dask_duckdb_local_joined_results.parquet"
)
RESULTS_FILEPATH.parent.mkdir(parents=True, exist_ok=True)
results_df.to_parquet(RESULTS_FILEPATH)