In [None]:
import os
from pathlib import Path
import shutil

import teehr
import pandas as pd
import requests

# Display the installed teehr version so the notebook's results can be tied to it.
teehr.__version__

### Setup

In [None]:
from teehr.evaluation.spark_session_utils import create_spark_session

# Spark session used by the Evaluation below, set up for the "default" AWS profile.
spark = create_spark_session(aws_profile="default")

In [None]:
%%time
dir_path = "/data/temp_warehouse"

ev = teehr.Evaluation(
    spark=spark,
    dir_path=dir_path,
    create_dir=False
)

In [None]:
# Toggle for the crosswalk cells further down:
#   True  -> query the NWPS API and write nwps_rfc_crosswalk.csv
#   False -> load a nwps_rfc_crosswalk.csv written by an earlier run
make_crosswalk = True

### Confirm the remote catalog is active

In [None]:
# Make "remote" the active catalog, then display ev.active_catalog to confirm the switch.
ev.set_active_catalog("remote")

ev.active_catalog

In [None]:
# Row count of the locations table (sanity check that the catalog is readable).
ev.locations.to_sdf().count()

In [None]:
# Row count of the secondary_timeseries table.
# NOTE(review): on a large remote catalog this count may scan the whole table — confirm the cost is acceptable.
ev.secondary_timeseries.to_sdf().count()

### Examine original tables

In [None]:
# Show the existing variables (untruncated) before a new one is added further down.
ev.variables.to_sdf().show(truncate=False)

In [None]:
# Show the existing configurations (untruncated) before a new one is added further down.
ev.configurations.to_sdf().show(truncate=False)

### Add configuration

In [None]:
from teehr import Configuration

# Secondary configuration under which the NWPS RFC streamflow forecasts are registered.
configuration = Configuration(
    type="secondary",
    name="nwpsrfc_streamflow_forecast",
    description="NWPS RFC Streamflow Forecast",
)
ev.configurations.add(configuration)

### Add variable

In [None]:
from teehr import Variable

# Variable name for the NWPS RFC forecast values.
variable = Variable(
    long_name="Instantaneous 6-hour streamflow",
    name="streamflow_6hr_inst",
)
ev.variables.add(variable)

### Add crosswalk entries

In [None]:
def _fetch_rfc_lid(session, usgs_id, timeout):
    """Return the NWPS LID for one USGS site number, or None if it cannot be obtained."""
    endpoint = f"https://api.water.noaa.gov/nwps/v1/gauges/{usgs_id}" # sample: https://api.water.noaa.gov/nwps/v1/gauges/01347000
    try:
        response = session.get(endpoint, timeout=timeout)
        response.raise_for_status()
        metadata = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers JSON decode failures on requests versions where
        # JSONDecodeError is not a RequestException subclass.
        print(f'\t\texception: {e}')
        return None
    # A missing, null or empty 'lid' is treated the same way: no usable LID.
    lid = metadata.get('lid') if isinstance(metadata, dict) else None
    if lid:
        print(f'\t\tSuccessfully retrieved LID for {usgs_id}')
        return lid
    print(f'\t\tLID not found for usgs_id: {usgs_id}')
    return None


def get_new_crosswalks(evaluation=None, session=None, timeout: float = 30.0) -> pd.DataFrame:
    """Get new crosswalks for NWPS RFC Streamflow Forecast.

    Reads the distinct ``primary_location_id`` values from the evaluation's
    ``location_crosswalks`` table and keeps those prefixed with ``usgs-``.
    For each USGS site number, it queries the NWPS gauge metadata endpoint
    for the gauge's ``lid``.

    Parameters
    ----------
    evaluation : teehr.Evaluation, optional
        Evaluation to read existing crosswalks from. Defaults to the
        notebook-level ``ev``.
    session : object with ``get(url, timeout=...)``, optional
        HTTP client used for the per-gauge requests. Defaults to a new
        ``requests.Session`` (closed on return), which reuses the HTTPS
        connection across gauges.
    timeout : float, default 30.0
        Per-request timeout in seconds, so one unresponsive request cannot
        hang the notebook indefinitely.

    Returns
    -------
    pd.DataFrame
        Columns ``primary_location_id`` (``usgs-<site>``) and
        ``secondary_location_id`` (``nwpsrfc-<lid>``), one row per gauge
        whose LID was found. It has zero rows (same columns) if none were found.

    NOTE(review): if two USGS sites resolve to the same LID, the result will
    contain duplicate secondary_location_id values — confirm that is acceptable
    for location_crosswalks.
    """
    if evaluation is None:
        # Backward-compatible fallback to the notebook-level Evaluation.
        evaluation = ev

    primary_location_rows = (
        evaluation.location_crosswalks.to_sdf()
        .select('primary_location_id')
        .distinct()
        .collect()
    )
    primary_location_ids = [row.primary_location_id for row in primary_location_rows]
    print(f'Extracted {len(primary_location_ids)} primary location IDs from remote location_crosswalks table')

    # Guard against null IDs before calling str methods on them.
    usgs_ids_stripped = [
        pid.removeprefix('usgs-')
        for pid in primary_location_ids
        if isinstance(pid, str) and pid.startswith('usgs-')
    ]
    print(f'Extracted {len(usgs_ids_stripped)} USGS entries from complete list of primary_location_ids\n\n')

    print('Starting routine to obtain RFC IDs from NWPS API....')

    n = len(usgs_ids_stripped)
    rfc_lids = {}
    owns_session = session is None
    if owns_session:
        session = requests.Session()
    try:
        for count, usgs_id in enumerate(usgs_ids_stripped, start=1):
            print(f'\tFetching metadata for {usgs_id}...({count} of {n})')
            lid = _fetch_rfc_lid(session, usgs_id, timeout)
            if lid is not None:
                rfc_lids[usgs_id] = lid
    finally:
        # Only close a session this function created; a caller-supplied one stays open.
        if owns_session:
            session.close()

    # Avoid ZeroDivisionError when the catalog has no USGS locations.
    pct = 100 * len(rfc_lids) / n if n else 0.0
    print(f'Obtained {len(rfc_lids)} usgs-id + rfc-lid pairs ({pct:.1f}%)\n\n')

    print('Assembling result dataframe with appropriate prefixes....')

    return pd.DataFrame({
        'primary_location_id': [f'usgs-{usgs_id}' for usgs_id in rfc_lids],
        'secondary_location_id': [f'nwpsrfc-{lid}' for lid in rfc_lids.values()],
    })

In [None]:
# Build the crosswalk from the NWPS API (one HTTP request per USGS gauge) and
# cache it as CSV so a later run can set make_crosswalk = False and skip this.
if make_crosswalk:
    df = get_new_crosswalks()
    filename = 'nwps_rfc_crosswalk.csv'
    df.to_csv(
        filename,
        index=False,
    )

In [None]:
# When make_crosswalk is True, reuse the DataFrame built in the previous cell.
# Otherwise read the CSV written by an earlier run. Either way, load it into the
# evaluation's location_crosswalks table.
if not make_crosswalk:
    df = pd.read_csv('nwps_rfc_crosswalk.csv')
ev.location_crosswalks.load_dataframe(df)

### Stop the Spark session

In [None]:
# Stop the Spark session held by the Evaluation to release its resources.
ev.spark.stop()