#### This notebook calculates mean areal precipitation (MAP) from the 40-year NWM v3.0 retrospective for USGS high-resolution basin polygons (CONUS only)
- Requires the weights (fractional pixel coverage) to be pre-calculated (see `05_calculate_nwm30_pixel_weights.ipynb`)
- Reads NWM v3.0 forcing (RAINRATE) netcdf files from s3 using Sedona: https://noaa-nwm-retrospective-3-0-pds.s3.amazonaws.com/index.html#CONUS/
- Writes output to local TEEHR Evaluation (warehouse)
- Some additional maintenance and exploratory code is included below
- PROCESSING NOTES:
    - Still working on defining the optimal balance between number of cores, executors, and memory.
    - Still seeing occasional pods die with OOM errors.
    - Processing chunks of timesteps in a loop may be causing memory to build up, although the `del` and garbage collection seem to have helped.
    - Definitely more to understand here in general.

In [None]:
import os
from pathlib import Path
import logging
import shutil
import time
import gc
import glob
import re

from pyspark.sql import functions as F
from pyspark.sql.types import TimestampType
import holoviews as hv
import hvplot.pandas
import xarray as xr
import fsspec
import rioxarray
import rasterio
import geopandas as gpd
import numpy as np
import pandas as pd

import teehr
from teehr.evaluation.spark_session_utils import create_spark_session

# Frame size shared by every line plot in this notebook.
LINE_PLOT_HEIGHT = 300
LINE_PLOT_WIDTH = 600

# Build the Curve options once, then register them as the global defaults
# for all line plots.
_curve_defaults = hv.opts.Curve(
    bgcolor="#e7e9ecb8",
    show_grid=True,
    gridstyle={'grid_line_alpha': 0.5, 'grid_line_color': 'white'},
    frame_width=LINE_PLOT_WIDTH,
    frame_height=LINE_PLOT_HEIGHT,
)
hv.opts.defaults(_curve_defaults)

# Module-level logger; its handlers are attached in the next cell.
logger = logging.getLogger(__name__)

# Echo the installed TEEHR version as the cell output.
teehr.__version__

Configure the logger:

In [None]:
# Send DEBUG-and-above records from this notebook's logger to a log file.
#
# The handler lookup makes the cell safe to re-run: without it, each
# execution would attach another FileHandler to the same file and every
# record would be written multiple times.
LOG_FILENAME = 'sedona_map_processor_logger.log'

logger.setLevel(logging.DEBUG)

# Reuse the handler from a previous run of this cell, if there is one.
# FileHandler stores its target as an absolute path in `baseFilename`.
file_handler = next(
    (
        handler
        for handler in logger.handlers
        if isinstance(handler, logging.FileHandler)
        and handler.baseFilename == os.path.abspath(LOG_FILENAME)
    ),
    None,
)
if file_handler is None:
    # Create a file handler and add it to the logger.
    file_handler = logging.FileHandler(LOG_FILENAME)
    logger.addHandler(file_handler)

# Set the handler's level and define its format.
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)

In [None]:
# spark.stop()

In [None]:
%%time
# # ~7 pods/node --> Resulted in OOMKilled msgs in the driver pod logs (140 / 7 = 20 nodes)
# NUM_EXECUTORS = 140
# NUM_CORES = 2
# EXECUTOR_MEMORY = "15g"

# # ~4 pods/node
# NUM_EXECUTORS = 80
# NUM_CORES = 3
# EXECUTOR_MEMORY = "26g"

# Active executor sizing: ~3 pods/node
NUM_EXECUTORS = 60
NUM_CORES = 4
EXECUTOR_MEMORY = "35g"

# Two shuffle partitions per executor core.
NUM_SHUFFLE_PARTITIONS = 2 * NUM_CORES * NUM_EXECUTORS

# Spark settings layered on top of the TEEHR session defaults.
spark_config_overrides = {
    # The NWM retrospective bucket is read with anonymous S3 credentials.
    "spark.hadoop.fs.s3a.aws.credentials.provider":
    "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider",
    "spark.sql.shuffle.partitions": str(NUM_SHUFFLE_PARTITIONS),
    # Pin executors to a specific node group (a spot group, judging by its name).
    "spark.kubernetes.executor.node.selector.teehr-hub/nodegroup-name": "spark-r5-4xlarge-spot",
    # Enable executor decommissioning (triggered by SIGTERM), including
    # storage decommissioning.
    "spark.decommission.enabled": "true",
    "spark.executor.decommission.signal": "SIGTERM",
    "spark.storage.decommission.enabled": "true",
    # "spark.storage.decommission.rddBlocks.enabled": "true",  # default is true
    # "spark.storage.decommission.shuffleBlocks.enabled": "true",  # default is true
    # "spark.storage.decommission.fallbackStorage.path": "s3a://ciroh-rti-public-data/spark-fallback-storage/",
    # "spark.kubernetes.driver.ownPersistentVolumeClaim": "true",
    # "spark.kubernetes.driver.reusePersistentVolumeClaim": "true"
}

spark = create_spark_session(
    start_spark_cluster=True,
    executor_instances=NUM_EXECUTORS,
    executor_memory=EXECUTOR_MEMORY,
    executor_cores=NUM_CORES,
    aws_region="us-east-1",
    update_configs=spark_config_overrides,
)

dir_path = "/data/playground/slamont/teehr/warehouse/sedona/usgs_basins_map"

# Open the EXISTING Evaluation at dir_path (create_dir=False).
ev = teehr.Evaluation(spark=spark, dir_path=dir_path, create_dir=False)

In [None]:
# Regex capturing a 12-digit run of digits; the processing loop applies it to
# each file path and parses the capture with the 'yyyyMMddHHmm' format.
DATE_PATTERN = r"(\d{12})"
# Constant labels written to every output row of the MAP query.
CONFIGURATION_NAME = "nwm30_retrospective"
UNIT_NAME = "mm/s"
VARIABLE_NAME = "rainfall_hourly_rate"

# Root of the NWM v3.0 retrospective forcing files on S3; per-file paths are
# built as <BASE_S3_PATH><year>/<yyyymmddHH>00.LDASIN_DOMAIN1
BASE_S3_PATH = "s3a://noaa-nwm-retrospective-3-0-pds/CONUS/netcdf/FORCING/"

# Inclusive hourly processing window.
# NOTE(review): the notebook header describes a 40-year retrospective, but this
# window starts in 2013 -- presumably a resume point from an earlier run; confirm.
START_DATE = "2013-05-24 17:00:00"
END_DATE = "2023-01-31 23:00"   # Last file: 202301312300.LDASIN_DOMAIN1

# Number of hourly timesteps (files) processed per loop iteration. A 1-year
# test used 200, but it can go higher: larger chunks mean more efficiency and
# less IO, at the cost of more memory.
CHUNK_SIZE = 210

In [None]:
# Build one S3 path per hourly timestep in [START_DATE, END_DATE] and split
# the list into chunks of at most CHUNK_SIZE files.
start_date = pd.Timestamp(START_DATE)
end_date = pd.Timestamp(END_DATE)
dt_rng = pd.date_range(start=start_date, end=end_date, freq="h")

if dt_rng.empty:
    raise ValueError(
        f"No timesteps between START_DATE={START_DATE!r} and END_DATE={END_DATE!r}"
    )

# Ceiling division: flooring would make np.array_split spread the remainder
# over the chunks (so they exceed CHUNK_SIZE), and would yield 0 sections --
# which np.array_split rejects -- whenever the range is shorter than CHUNK_SIZE.
NUM_SPLITS = max(1, -(-len(dt_rng) // CHUNK_SIZE))
full_filepaths = [
    f"{BASE_S3_PATH}{dt.year}/{dt.year}{dt.month:02d}{dt.day:02d}{dt.hour:02d}00.LDASIN_DOMAIN1"
    for dt in dt_rng
]

# Note: array_split returns numpy string arrays; the processing loop converts
# the entries back to plain `str` before handing them to Spark.
split_full_filepaths = np.array_split(full_filepaths, NUM_SPLITS)

len(split_full_filepaths)

In [None]:
%%time
# Expose the pre-computed fractional pixel coverage table as a temporary view,
# renaming its columns to the names used by the MAP query in the processing loop.
FRACTIONS_VIEW_SQL = """
    CREATE OR REPLACE TEMPORARY VIEW fractions_view AS
    SELECT fraction_covered, id AS location_id, pos AS position_index FROM local.teehr.nwm30_usgs_hires_basins_fractional_coverage
"""
spark.sql(FRACTIONS_VIEW_SQL)

# Cache the view so that it is not re-read for every chunk.
spark.sql("CACHE TABLE fractions_view")

# Alternative to try for lower memory requirements but higher cpu requirements
# (StorageLevel would need to be imported from pyspark first):
# fractions_df = spark.table("fractions_view")
# fractions_df.persist(StorageLevel.MEMORY_AND_DISK_2)  # <-- Adding 2 replicates data across 2 nodes for fault tolerance
# fractions_df.count()  # Materialize

In [None]:
%%time
# Main processing loop: for each chunk of forcing files, read RAINRATE,
# compute the coverage-weighted mean per basin and timestep, and append the
# result to the warehouse table.
table_name = "primary_timeseries"
num_chunks = len(split_full_filepaths)

cntr = 0  # Only used by the optional compaction step at the bottom of the loop.
for chunk_number, split_filepaths_chunk in enumerate(split_full_filepaths, start=1):

    t0 = time.time()

    # np.array_split yields numpy strings; Spark needs plain Python strings.
    filepaths = [str(fp) for fp in split_filepaths_chunk]

    # Read each NetCDF file as raw bytes and let Sedona decode the RAINRATE
    # variable into a raster; keep the source path to recover the timestamp.
    nc_sdf = spark.read.format("binaryFile").load(filepaths).selectExpr(
        "RS_FromNetCDF(content, 'RAINRATE', 'x', 'y') as raster",
        "path as filepath",
    )
    # The 12-digit timestamp embedded in the file name becomes value_time.
    nc_sdf = nc_sdf.withColumn("value_time", F.regexp_extract(nc_sdf["filepath"], DATE_PATTERN, 1))  # partition by time?

    # Explode the raster values: one row per (timestep, pixel position).
    raster_exp_sdf = nc_sdf.selectExpr(
        "posexplode(RS_BandAsArray(raster, 1))",
        "value_time",
    ).selectExpr(
        "value_time as value_time",
        "col as value",
        "CAST(pos as BIGINT) as position_index"
    )
    raster_exp_sdf.createOrReplaceTempView("raster_values")

    # Calculate MAP: join pixels to the (broadcast) coverage weights on the
    # pixel position and take the weighted mean per location and timestep.
    map_results = spark.sql(f"""
        SELECT /*+ BROADCAST(w) */
            w.location_id,
            to_timestamp(r.value_time, 'yyyyMMddHHmm') AS value_time,
            SUM(r.value * w.fraction_covered) / SUM(w.fraction_covered) AS value,
            CAST(NULL AS TIMESTAMP) AS reference_time,
            '{UNIT_NAME}' AS unit_name,
            '{VARIABLE_NAME}' AS variable_name,
            '{CONFIGURATION_NAME}' AS configuration_name
        FROM 
            raster_values AS r
        JOIN 
             fractions_view AS w ON r.position_index = w.position_index
        GROUP BY 
            w.location_id, r.value_time;
    """)

    # Write to table (this is what triggers the Spark job for the chunk).
    ev.write.to_warehouse(
        source_data=map_results,
        table_name=table_name,  # Note. 1-year run stored in "temp_secondary_timeseries" table
        write_mode="append"
    )

    spark.catalog.dropTempView("raster_values")

    # Drop every reference to this chunk's DataFrames before collecting, so
    # none of them lingers until the next iteration rebinds the names.
    del nc_sdf, raster_exp_sdf, map_results
    gc.collect()
    # 1-based chunk number, so the final message reads N/N.
    logger.info(
        "Processed chunk %d/%d in %.2f mins",
        chunk_number,
        num_chunks,
        (time.time() - t0) / 60,
    )

    # # Rewrite data files to fix the small file problem once more than 200 days
    # # of timesteps have been processed since the last rewrite
    # cntr += 1
    # days_processed = (CHUNK_SIZE * cntr) / 24
    # if days_processed > 200:
    #     t1 = time.time()
    #     ev.spark.sql(f"""
    #         CALL local.system.rewrite_data_files(
    #             table => 'teehr.{table_name}',
    #             options => map('target-file-size-bytes', '134217728') -- 128 MB
    #         )
    #     """)
    #     cntr = 0
    #     print(f"Rewrote data files in {(time.time() - t1) / 60:.2f} mins")

    # break

### One-time requirement: load the first year, which was calculated earlier and stored locally, into the warehouse

In [None]:
# Start a Spark session with the default settings.
spark = create_spark_session()

# Re-open the EXISTING Evaluation in the warehouse directory (create_dir=False).
dir_path = "/data/playground/slamont/teehr/warehouse/sedona/usgs_basins_map"
ev = teehr.Evaluation(spark=spark, dir_path=dir_path, create_dir=False)

In [None]:
# Load the table holding the earlier 1-year run (see the note in the
# processing loop) as a Spark DataFrame.
sdf = ev.table(table_name="temp_secondary_timeseries").to_sdf()

In [None]:
# Preview the first three rows.
sdf.show(3)

In [None]:
# Drop the "member" column (presumably absent from the primary timeseries
# schema -- confirm) and preview the result.
primary_sdf = sdf.drop("member")
primary_sdf.show(4)

In [None]:
# Show the latest value_time present in the 1-year data.
primary_sdf.select(F.max("value_time")).show()

In [None]:
# Pull the timesteps present in the table back to the driver.
# Only distinct values are collected: the table has one row per
# (location, timestep), so collecting the raw column would ship every row's
# timestamp to the driver.
value_times = [
    row["value_time"]
    for row in (
        primary_sdf.select("value_time")
        .distinct()
        .orderBy("value_time")
        .collect()
    )
]

In [None]:
# Parse both time columns into timestamps:
#   value_time     uses the compact 'yyyyMMddHHmm' form taken from the file names
#   reference_time uses 'yyyy-MM-dd HH:mm:ss'
value_time_ts = F.to_timestamp(F.col("value_time"), "yyyyMMddHHmm")
reference_time_ts = F.to_timestamp(F.col("reference_time"), "yyyy-MM-dd HH:mm:ss")

primary_casted_sdf = (
    primary_sdf
    .withColumn("value_time", value_time_ts)
    .withColumn("reference_time", reference_time_ts)
)

In [None]:
# Inspect the schema to confirm that both time columns are now timestamps.
primary_casted_sdf.schema

In [None]:
# Switch the Evaluation's active catalog to "remote" and echo the result.
# NOTE(review): the ALTER TABLE / CALL statements further down address the
# `local` catalog explicitly -- confirm which catalog the write below should target.
ev.set_active_catalog("remote")
ev.active_catalog

In [None]:
# Append the re-typed 1-year data to primary_timeseries (presumably in the
# currently active catalog -- confirm). Appending is not idempotent:
# re-running this cell writes the same rows again.
ev.write.to_warehouse(
    source_data=primary_casted_sdf,
    table_name="primary_timeseries",
    write_mode="append",
)

### Alter table partitions and rewrite datafiles

In [None]:
# Add a partition field on the year of value_time.
ev.spark.sql("ALTER TABLE local.teehr.primary_timeseries ADD PARTITION FIELD years(value_time)")

In [None]:
# Add a partition field on the month of value_time.
# NOTE(review): a months() partition already separates data by year; confirm
# that keeping a years() field on the same column as well is intended.
ev.spark.sql("ALTER TABLE local.teehr.primary_timeseries ADD PARTITION FIELD months(value_time)")

In [None]:
%%time
# Compact the table's data files with the rewrite_data_files procedure,
# targeting 128 MB (134217728 bytes = 128 * 1024 * 1024) per file.
table_name = "primary_timeseries"

ev.spark.sql(f"""
    CALL local.system.rewrite_data_files(
        table => 'teehr.{table_name}',
        options => map('target-file-size-bytes', '134217728') -- 128 MB
    )
""")

In [None]:
# List the table's snapshot commit times, newest first (useful for choosing
# the cutoff passed to expire_snapshots).
ev.spark.sql("SELECT committed_at FROM local.teehr.primary_timeseries.snapshots ORDER BY committed_at DESC").toPandas()

In [None]:
%%time
# Call expire_snapshots with a cutoff timestamp, asking it to retain the most
# recent snapshot (retain_last => 1).
# NOTE(review): the cutoff is hard-coded -- presumably copied from the snapshot
# listing; update it before re-running this cell.
ev.spark.sql("""
    CALL local.system.expire_snapshots(
        table => 'teehr.primary_timeseries', 
        older_than => TIMESTAMP '2026-02-11 01:14:22.558',
        retain_last => 1
    )
""").show()

In [None]:
# # Coalesce approach
# table_data_dir = "/data/playground/slamont/teehr/warehouse/sedona/usgs_basins_map/local/teehr/primary_timeseries/data"

# sdf = ev.spark.read.parquet(str(table_data_dir / "*.parquet"))
# sdf.coalesce(num_cache_files).write.mode("overwrite").parquet(str(coalesced_cache_dir))

## Exploring the output

#### I needed to re-register the tables after we switched the local catalog from Hadoop to JDBC

In [None]:
# Tables to re-register; each name is a sub-directory of the warehouse's
# local/teehr directory that holds the table's metadata folder.
table_names = [
    "attributes",
    "configurations",
    "grid_pixel_coverage_weights",
    "location_attributes",
    "location_crosswalks",
    "locations",
    "nwm30_usgs_hires_basins_fractional_coverage",
    "primary_timeseries",
    "secondary_timeseries",
    "temp_nwm_rainrate_rasters",
    "temp_secondary_timeseries",
    "temp_secondary_timeseries_test",
    "units",
    "variables"
]

In [None]:
# Execute the register_table procedure for every table, pointing the catalog
# at the highest-versioned metadata file (v<N>.metadata.json) on disk.
for table_name in table_names:
    meta_dir = f"/data/playground/slamont/teehr/warehouse/sedona/usgs_basins_map/local/teehr/{table_name}/metadata"

    # Pair each metadata file with the version number in its name. Files
    # without a v<N> marker are ignored instead of raising an IndexError.
    versioned_paths = []
    for fullpath in glob.glob(meta_dir + "/*.metadata.json"):
        version_match = re.search(r'v(\d+)', Path(fullpath).stem)
        if version_match:
            versioned_paths.append((int(version_match.group(1)), fullpath))

    # A table with no usable metadata cannot be registered; skip it with a
    # clear message rather than failing on an empty lookup.
    if not versioned_paths:
        logger.warning(
            "No versioned metadata files found in %s; skipping table %s",
            meta_dir,
            table_name,
        )
        continue

    _, latest_meta_path = max(versioned_paths, key=lambda item: item[0])

    # NOTE: If it already exists this will raise an error
    ev.spark.sql(f"""
    CALL local.system.register_table(
        table => 'teehr.{table_name}',
        metadata_file => '{latest_meta_path}'
    )
    """).show()

In [None]:
# ev.list_tables()

#### Check out the start/end times, total timesteps, etc.

In [None]:
# Load the primary timeseries table as a Spark DataFrame for inspection.
sdf = ev.primary_timeseries.to_sdf()

In [None]:
# Show the first and last timestep in a single aggregation, rather than one
# job per bound (which also echoed a `(None, None)` tuple as cell output).
sdf.select(
    F.min("value_time").alias("min_value_time"),
    F.max("value_time").alias("max_value_time"),
).show()

In [None]:
# Select a single timestep for a spot check; the same hour appears in the
# name of the exactextract CSV loaded below.
one_hour_sdf = sdf.filter("value_time = '1984-01-01 06:00'")

In [None]:
# Bring the single-timestep result to the driver as a pandas DataFrame and display it.
sed_df = one_hour_sdf.toPandas()
sed_df

In [None]:
# Load the comparison results for the same hour (presumably produced with
# exactextract, judging by the path -- confirm).
ee_df = pd.read_csv("/data/playground/slamont/teehr/warehouse/sedona/exactextract/198401010600_RAINRATE_nwm_v30_results.csv")

In [None]:
# Index the comparison results by basin id (in place).
ee_df.set_index("basin_id", inplace=True)

In [None]:
# Move the index back into a regular column (in place), so that `basin_id`
# can be used in the boolean filter in the next cell.
ee_df.reset_index(inplace=True)

In [None]:
# Look up a single basin in the comparison results.
ee_df[ee_df.basin_id == "usgsbasin-01013500"]

In [None]:
# Comparison results ordered by their "mean" column, for side-by-side
# comparison with the Sedona output sorted in the next cell.
ee_df.sort_values(by="mean")

In [None]:
# Sedona-based results for the same hour, ordered by their "value" column.
sed_df.sort_values(by="value")