### TEEHR with Spark and Iceberg

In [None]:
import os
import duckdb
import numpy as np
from pyspark.sql.functions import pandas_udf
import pandas as pd
from urllib.request import urlretrieve
import gc

In [None]:
from pyspark.sql import SparkSession
from pyspark import SparkConf

In [None]:
# Spark properties applied to every session built by get_spark_session().
# Keys are Spark property names; all values are strings.
config = {
    # Kubernetes identity, namespace and executor image.
    "spark.kubernetes.authenticate.driver.serviceAccountName": "jupyter",
    "spark.kubernetes.namespace": "teehr-spark-default",
    # Raises KeyError right here if TEEHR_WORKER_IMAGE is not set.
    "spark.kubernetes.container.image": os.environ["TEEHR_WORKER_IMAGE"],
    # The complete JVM flag (-Dname=value) must live in the property VALUE.
    # Putting "=-Daws.region" into the key names a different property and
    # leaves extraJavaOptions (and therefore aws.region) unset.
    "spark.executor.extraJavaOptions": "-Daws.region=us-east-1",
    "spark.driver.extraJavaOptions": "-Daws.region=us-east-1",
    # Executor sizing.
    "spark.executor.instances": "6",
    "spark.executor.memory": "16g",
    "spark.executor.cores": "2",
    # Fixed driver ports/host; presumably so executors can reach the driver
    # running in the notebook pod -- confirm against the cluster's Service.
    "spark.driver.blockManager.port": "7777",
    "spark.driver.port": "2222",
    "spark.driver.host": "jupyter.teehr-spark-default.svc.cluster.local",
    "spark.driver.bindAddress": "0.0.0.0",
    # s3a filesystem access using the anonymous credentials provider.
    "spark.hadoop.fs.s3a.impl": "org.apache.hadoop.fs.s3a.S3AFileSystem",
    "spark.hadoop.fs.s3a.aws.credentials.provider": "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider",
    # NOTE(review): credentials are hard-coded in the notebook; presumably
    # local development defaults -- confirm, and prefer reading them from the
    # environment or a Kubernetes secret.
    # NOTE(review): no other spark.sql.catalog.demo.* setting is defined in
    # this dict -- confirm the "demo" catalog is configured elsewhere.
    "spark.sql.catalog.demo.s3.access-key-id": "minio",
    "spark.sql.catalog.demo.s3.secret-access-key": "password123",
    # Vectorized Parquet reader is switched off; reason not recorded here.
    "spark.sql.parquet.enableVectorizedReader": "false",
    # Executor pod placement and pod template.
    "spark.kubernetes.executor.node.selector.dedicated": "worker",
    "spark.kubernetes.executor.podTemplateFile": "/home/spark/pod-template.yaml",
}

def get_spark_session(app_name: str, conf: SparkConf):
    """Return a SparkSession that targets the in-cluster Kubernetes master.

    Every entry of the module-level ``config`` mapping is applied on top of
    ``conf``. Note that ``conf`` is modified in place.

    Args:
        app_name: Application name passed to the session builder.
        conf: Base Spark configuration to extend.

    Returns:
        The result of ``getOrCreate()`` on the configured builder, which may
        be an already-active session rather than a newly created one.
    """
    # setMaster() returns the conf itself, so the overrides can be chained.
    conf.setMaster(
        "k8s://https://kubernetes.default.svc.cluster.local"
    ).setAll(list(config.items()))
    builder = SparkSession.builder.appName(app_name).config(conf=conf)
    return builder.getOrCreate()

In [None]:
# Session shared by every query cell below.
spark = get_spark_session(app_name="teehr-workers", conf=SparkConf())
# Uncomment to inspect the effective configuration:
# spark.sparkContext.getConf().getAll()

In [None]:
# Database name used as the `<database>.<table>` prefix in the SQL cells.
DATABASE = "streamflow"
# Table names interpolated into the SQL cells.
PRIMARY_TABLE = "primary"
# SECONDARY_TABLE and CROSSWALK_TABLE are not referenced by the cells
# visible in this notebook.
SECONDARY_TABLE = "secondary"
CROSSWALK_TABLE = "crosswalk"
# Table read by the metrics queries.
JOINED_TABLE = "joined"

In [None]:
%%time
spark.sql(f"SELECT value_time, value FROM {DATABASE}.{PRIMARY_TABLE} WHERE location_id = 'usgs-01116905' ORDER BY value_time;").show()

In [None]:
# NOTE(review): the wildcard import hides which names come from pandas_udfs;
# only the four functions registered below are used in this cell.
from pandas_udfs import *

# Make each metric callable from Spark SQL under the same name it has in
# Python; the metrics queries refer to these SQL names.
spark.udf.register(name="teehr_kling_gupta_efficiency", f=teehr_kling_gupta_efficiency)
spark.udf.register(name="teehr_root_mean_squared_error", f=teehr_root_mean_squared_error)
spark.udf.register(name="teehr_relative_bias", f=teehr_relative_bias)
spark.udf.register(name="teehr_r_squared", f=teehr_r_squared)

In [None]:
%%time
sdf = spark.sql(f"""
WITH metrics AS (
    SELECT
        joined.primary_location_id
        , teehr_kling_gupta_efficiency(joined.primary_value, joined.secondary_value) as kling_gupta_efficiency
        , teehr_root_mean_squared_error(joined.primary_value, joined.secondary_value) as root_mean_squared_error
        , teehr_relative_bias(joined.primary_value, joined.secondary_value) as relative_bias
        , teehr_r_squared(joined.primary_value, joined.secondary_value) as r_squared
    FROM
        {DATABASE}.{JOINED_TABLE} as joined
    GROUP BY
        joined.primary_location_id
)
SELECT
    metrics.primary_location_id
    , kling_gupta_efficiency
    , root_mean_squared_error
    , relative_bias
    , r_squared
FROM metrics
ORDER BY
    metrics.primary_location_id;
""")
sdf.show()

In [None]:
%%time
sdf = spark.sql(f"""
WITH metrics AS (
    SELECT
        joined.primary_location_id
        , teehr_kling_gupta_efficiency(joined.primary_value, joined.secondary_value) as kling_gupta_efficiency
        , teehr_root_mean_squared_error(joined.primary_value, joined.secondary_value) as root_mean_squared_error
        , teehr_relative_bias(joined.primary_value, joined.secondary_value) as relative_bias
        , teehr_r_squared(joined.primary_value, joined.secondary_value) as r_squared
    FROM
        {DATABASE}.{JOINED_TABLE} as joined
    GROUP BY
        joined.primary_location_id
)
SELECT
    metrics.primary_location_id
    , kling_gupta_efficiency
    , root_mean_squared_error
    , relative_bias
    , r_squared
FROM metrics
WHERE
    primary_location_id='usgs-01021480'
ORDER BY
    metrics.primary_location_id;
""")
sdf.show()

In [None]:
# Shut the session down now that the notebook's queries have run.
spark.stop()