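### TEEHR with Spark and Iceberg

This notebook does the following:

- Starts a Spark session on Kubernetes.
- Reads the USGS observed (`usgs*`) and NWM simulated (`nwm2*`) timeseries from the `ciroh-rti-public-data` S3 bucket.
- Joins them through the USGS–NWM crosswalk.
- Computes per-location metrics (Kling-Gupta efficiency, RMSE, relative bias, r²) with TEEHR pandas UDFs. This is done twice: once from the on-the-fly join and once from a pre-joined parquet dataset.

Note: although the Spark config sets `spark.sql.catalog.demo.*` credentials, no Iceberg catalog is used below. All data are read directly from parquet files.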

In [1]:
import os
import duckdb
import numpy as np
from pyspark.sql.functions import pandas_udf
import pandas as pd
from urllib.request import urlretrieve
import gc

In [2]:
from pyspark.sql import SparkSession
from pyspark import SparkConf

In [3]:
# Spark properties applied to the SparkConf by get_spark_session().
config = {
    # Kubernetes: driver service account, namespace, and executor image.
    "spark.kubernetes.authenticate.driver.serviceAccountName": "jupyter",
    "spark.kubernetes.namespace": "teehr-spark-default",
    # Raises KeyError when TEEHR_WORKER_IMAGE is unset, so a missing image fails fast.
    "spark.kubernetes.container.image": os.environ["TEEHR_WORKER_IMAGE"],
    # JVM system properties go in the *value* of extraJavaOptions. A key such as
    # "spark.executor.extraJavaOptions=-Daws.region" is not a Spark property name.
    # The driver option only takes effect if the driver JVM is not already running
    # in this kernel.
    "spark.executor.extraJavaOptions": "-Daws.region=us-east-1",
    "spark.driver.extraJavaOptions": "-Daws.region=us-east-1",
    # Executor sizing: 6 executors x 2 cores x 16g each.
    "spark.executor.instances": "6",
    "spark.executor.memory": "16g",
    "spark.executor.cores": "2",
    # Driver host/ports are pinned so executors can connect back to the driver.
    # The host is the `jupyter` service in the teehr-spark-default namespace.
    "spark.driver.blockManager.port": "7777",
    "spark.driver.port": "2222",
    "spark.driver.host": "jupyter.teehr-spark-default.svc.cluster.local",
    "spark.driver.bindAddress": "0.0.0.0",
    # S3A filesystem with anonymous (unsigned) credentials.
    "spark.hadoop.fs.s3a.impl": "org.apache.hadoop.fs.s3a.S3AFileSystem",
    "spark.hadoop.fs.s3a.aws.credentials.provider": "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider",
    # Credentials for a "demo" catalog. These are read from the environment instead
    # of being hardcoded.
    # NOTE(review): no spark.sql.catalog.demo implementation is configured here, so
    # these settings appear unused.
    # TODO: drop the fallback values once the environment always provides them.
    "spark.sql.catalog.demo.s3.access-key-id": os.environ.get(
        "DEMO_CATALOG_S3_ACCESS_KEY_ID", "minio"
    ),
    "spark.sql.catalog.demo.s3.secret-access-key": os.environ.get(
        "DEMO_CATALOG_S3_SECRET_ACCESS_KEY", "password123"
    ),
    # NOTE(review): vectorized Parquet reader disabled; the reason is not recorded here.
    "spark.sql.parquet.enableVectorizedReader": "false",
    # Schedule executors on nodes labelled dedicated=worker, using the given pod template.
    "spark.kubernetes.executor.node.selector.dedicated": "worker",
    "spark.kubernetes.executor.podTemplateFile": "/home/spark/pod-template.yaml",
}

def get_spark_session(
    app_name: str,
    conf: SparkConf,
    master: str = "k8s://https://kubernetes.default.svc.cluster.local",
    settings: "dict[str, str] | None" = None,
) -> "SparkSession":
    """Create (or return the already-active) SparkSession for this notebook.

    Args:
        app_name: Spark application name.
        conf: SparkConf to populate. It is modified in place: the master URL and
            every entry of ``settings`` are set on it.
        master: Spark master URL. Defaults to the in-cluster Kubernetes API server.
        settings: Spark properties to apply. Defaults to the notebook-level
            ``config`` dict.

    Returns:
        The session produced by ``SparkSession.builder...getOrCreate()``. If a
        session is already active, ``getOrCreate`` returns that one.
    """
    if settings is None:
        # Looked up at call time so later edits to `config` are honoured.
        settings = config
    conf.setMaster(master)
    for key, value in settings.items():
        conf.set(key, value)
    return SparkSession.builder.appName(app_name).config(conf=conf).getOrCreate()

In [4]:
# Start (or reuse) the Spark session using the settings from the config cell.
# To inspect the effective settings: spark.sparkContext.getConf().getAll()
spark = get_spark_session(app_name="teehr-workers", conf=SparkConf())

24/05/28 01:12:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [5]:
# USGS observed timeseries; used as the primary (pf) side of the joins below.
obs = spark.read.format("parquet").load(
    "s3a://ciroh-rti-public-data/teehr/protocols/science-eval/timeseries/usgs*.parquet"
)

24/05/28 01:15:24 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
                                                                                

In [6]:
# Action: forces Spark to evaluate `obs` and return its row count.
obs.count()

                                                                                

176063841

In [7]:
# Simulated (nwm2*) timeseries; used as the secondary (sf) side of the joins below.
sim = spark.read.format("parquet").load(
    "s3a://ciroh-rti-public-data/teehr/protocols/science-eval/timeseries/nwm2*.parquet"
)

In [8]:
# Action: forces Spark to evaluate `sim` and return its row count.
sim.count()

                                                                                

399505152

In [9]:
# USGS <-> nwm2 location crosswalk (primary_location_id / secondary_location_id).
xw = spark.read.format("parquet").load(
    "s3a://ciroh-rti-public-data/teehr/common/geo/usgs_nwm2*_crosswalk.conus.parquet"
)

In [10]:
# Action: forces Spark to evaluate the crosswalk and return its row count.
xw.count()

22835

In [11]:
# Expose `obs` to Spark SQL. createOrReplaceTempView keeps the cell re-runnable,
# whereas createTempView raises if "obs_temp" already exists in the session.
obs.createOrReplaceTempView("obs_temp")

In [12]:
# Expose `sim` to Spark SQL. createOrReplaceTempView keeps the cell re-runnable,
# whereas createTempView raises if "sim_temp" already exists in the session.
sim.createOrReplaceTempView("sim_temp")

In [13]:
# Expose the crosswalk to Spark SQL. createOrReplaceTempView keeps the cell
# re-runnable, whereas createTempView raises if "xw_temp" already exists.
xw.createOrReplaceTempView("xw_temp")

In [14]:
# Join simulated (sim_temp) to observed (obs_temp) values through the crosswalk
# (xw_temp), matching on value_time, measurement_unit and variable_name.
# spark.sql only builds the plan; the S3 reads and the join run when an action
# such as count() is called.
# The observed reference_time is aliased to primary_reference_time so the result
# does not carry two ambiguous `reference_time` columns.
sdf = spark.sql("""
SELECT
        sf.reference_time as reference_time
        , sf.value_time as value_time
        , sf.location_id as secondary_location_id
        , pf.reference_time as primary_reference_time
        , sf.value as secondary_value
        , sf.configuration
        , sf.measurement_unit
        , sf.variable_name
        , pf.value as primary_value
        , pf.location_id as primary_location_id
    FROM sim_temp sf
    JOIN xw_temp cf
        on cf.secondary_location_id = sf.location_id
    JOIN obs_temp pf
        on cf.primary_location_id = pf.location_id
        and sf.value_time = pf.value_time
        and sf.measurement_unit = pf.measurement_unit
        and sf.variable_name = pf.variable_name
""")

In [15]:
%%time
# count() is an action: it executes the obs/sim/crosswalk join built in `sdf`.
sdf.count()



CPU times: user 104 ms, sys: 39.9 ms, total: 144 ms
Wall time: 2min 52s


                                                                                

351877563

In [16]:
# Register the TEEHR metric pandas UDFs under SQL-callable names.
# Explicit imports instead of `import *` make the dependencies visible.
from pandas_udfs import (
    teehr_kling_gupta_efficiency,
    teehr_r_squared,
    teehr_relative_bias,
    teehr_root_mean_squared_error,
)

teehr_udfs_by_sql_name = {
    "teehr_kling_gupta_efficiency": teehr_kling_gupta_efficiency,
    "teehr_root_mean_squared_error": teehr_root_mean_squared_error,
    "teehr_relative_bias": teehr_relative_bias,
    "teehr_r_squared": teehr_r_squared,
}
# A loop statement leaves no trailing expression, so the cell no longer
# displays the UserDefinedFunction repr returned by register().
for sql_name, metric_udf in teehr_udfs_by_sql_name.items():
    spark.udf.register(sql_name, metric_udf)

<pyspark.sql.udf.UserDefinedFunction at 0x7f324acabf40>

In [17]:
%%time
sdf = spark.sql("""
WITH joined as (
    SELECT
        sf.reference_time
        , sf.value_time as value_time
        , sf.location_id as secondary_location_id
        , pf.reference_time as reference_time
        , sf.value as secondary_value
        , sf.configuration
        , sf.measurement_unit
        , sf.variable_name
        , pf.value as primary_value
        , pf.location_id as primary_location_id
    FROM sim_temp sf
    JOIN xw_temp cf
        on cf.secondary_location_id = sf.location_id
    JOIN obs_temp pf
        on cf.primary_location_id = pf.location_id
        and sf.value_time = pf.value_time
        and sf.measurement_unit = pf.measurement_unit
        and sf.variable_name = pf.variable_name
)
, metrics AS (
    SELECT
        joined.primary_location_id
        , teehr_kling_gupta_efficiency(joined.primary_value, joined.secondary_value) as kling_gupta_efficiency
        , teehr_root_mean_squared_error(joined.primary_value, joined.secondary_value) as root_mean_squared_error
        , teehr_relative_bias(joined.primary_value, joined.secondary_value) as relative_bias
        , teehr_r_squared(joined.primary_value, joined.secondary_value) as r_squared
    FROM
        joined
    GROUP BY
        joined.primary_location_id
)
SELECT
    metrics.primary_location_id
    , kling_gupta_efficiency
    , root_mean_squared_error
    , relative_bias
    , r_squared
FROM metrics
ORDER BY
    metrics.primary_location_id
LIMIT 20;
""")
sdf.show()

                                                                                

+-------------------+----------------------+-----------------------+-------------+-----------+
|primary_location_id|kling_gupta_efficiency|root_mean_squared_error|relative_bias|  r_squared|
+-------------------+----------------------+-----------------------+-------------+-----------+
|      usgs-01010000|            0.42462006|               97.13046|   -0.2157014|  0.7913283|
|      usgs-01010070|            0.44942167|              14.033999|  -0.28686345|  0.6007537|
|      usgs-01010500|            0.27008584|              221.87845|   -0.3975725|  0.8074171|
|      usgs-01011000|             0.6334492|              57.011288|  -0.19466422|  0.7688364|
|      usgs-01013500|            0.54683834|              41.360138|  -0.23086037| 0.82660383|
|      usgs-01014000|            0.37036318|              354.70465|  -0.35600448| 0.86370957|
|      usgs-01015800|            0.72119784|              31.359137|  -0.14541015| 0.90401065|
|      usgs-01017000|            0.76449686|      

In [19]:
# Pre-joined primary/secondary timeseries, used to compare against the on-the-fly join.
joined = spark.read.format("parquet").load(
    "s3a://ciroh-rti-public-data/teehr/protocols/science-eval/timeseries/joined*.parquet"
)

In [20]:
# Expose the pre-joined data to Spark SQL. createOrReplaceTempView keeps the cell
# re-runnable, whereas createTempView raises if "joined_temp" already exists.
joined.createOrReplaceTempView("joined_temp")

In [21]:
%%time
sdf = spark.sql("""
WITH joined as (
    SELECT
        *
    FROM joined_temp jt
)
, metrics AS (
    SELECT
        joined.primary_location_id
        , teehr_kling_gupta_efficiency(joined.primary_value, joined.secondary_value) as kling_gupta_efficiency
        , teehr_root_mean_squared_error(joined.primary_value, joined.secondary_value) as root_mean_squared_error
        , teehr_relative_bias(joined.primary_value, joined.secondary_value) as relative_bias
        , teehr_r_squared(joined.primary_value, joined.secondary_value) as r_squared
    FROM
        joined
    GROUP BY
        joined.primary_location_id
)
SELECT
    metrics.primary_location_id
    , kling_gupta_efficiency
    , root_mean_squared_error
    , relative_bias
    , r_squared
FROM metrics
ORDER BY
    metrics.primary_location_id
LIMIT 20;
""")
sdf.show()



+-------------------+----------------------+-----------------------+-------------+-----------+
|primary_location_id|kling_gupta_efficiency|root_mean_squared_error|relative_bias|  r_squared|
+-------------------+----------------------+-----------------------+-------------+-----------+
|      usgs-01010000|            0.42462006|              97.130455|  -0.21570143|  0.7913283|
|      usgs-01010070|            0.44942167|              14.033999|  -0.28686342|  0.6007537|
|      usgs-01010500|            0.27008584|              221.87843|   -0.3975725|  0.8074171|
|      usgs-01011000|             0.6334492|              57.011288|  -0.19466424|  0.7688364|
|      usgs-01013500|             0.5468384|              41.360138|  -0.23086038| 0.82660383|
|      usgs-01014000|            0.37036318|              354.70468|   -0.3560045| 0.86370957|
|      usgs-01015800|             0.7211978|              31.359137|  -0.14541014| 0.90401065|
|      usgs-01017000|             0.7644969|      

                                                                                

In [22]:
# Stop the Spark session and release its executors once the analysis is done.
spark.stop()

24/05/28 01:24:57 WARN ExecutorPodsWatchSnapshotSource: Kubernetes client has been closed.
