In [2]:
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import geopandas as gpd
import xarray as xr

from glob import glob
from tqdm import tqdm
from pathlib import Path
from datetime import timedelta
from shapely import wkt
import ast
from shapely.geometry import Point, LineString

datasets = Path("/nas/cee-water/cjgleason/data")
era5_dir = datasets / "ERA5-Land/sub_basin_timeseries"
swot_lake_dir = datasets / 'hydrocron' / 'lake'

save_dir = Path("/nas/cee-water/cjgleason/ted/swot-ml/data/distributed")
preprocess_dir = save_dir / "preprocess"
metadata_dir = save_dir / "metadata"

ts_dir = save_dir / "time_series"
ts_dir.mkdir(exist_ok=True)

# basin_name = 'Ohio'
basin_name = 'Upper_Miss'


matchups = gpd.read_file(metadata_dir / f'{basin_name}_matchups.geojson').set_index("HYBAS_ID")
matchups.index = matchups.index.astype(str)


# Safely convert stringified lists/dicts back to Python objects
def safe_literal_eval(df, col):
    df[col] = df[col].apply(lambda x: ast.literal_eval(x))
    return df

for col in ["mb_values", "lake_reach_ids", "lake_pld_ids"]:
    matchups = safe_literal_eval(matchups, col)


In [3]:
import global_gauges as gg

facade = gg.GaugeDataFacade(providers='usgs')
sites = facade.get_stations_n_days(90)

site_ids = matchups[~matchups['site_id'].isna()]['site_id'].unique()
sites = sites.loc[site_ids]



In [6]:
import pyarrow.dataset as ds
import itertools 

lake_ids = list(itertools.chain.from_iterable(matchups['lake_reach_ids']))
river_reaches = matchups['reach_id'].dropna().astype(int).to_list()
all_reaches = lake_ids + river_reaches

fields=[
    'reach_id', 'time','wse', 'wse_u', 'wse_r_u',
    'slope', 'slope_u', 'slope_r_u', 'slope2', 'slope2_u', 'slope2_r_u',
    'width', 'width_u',
    'area_total', 'area_tot_u', 'area_detct', 'area_det_u', 'area_wse',
    'layovr_val', 'node_dist', 'loc_offset', 'xtrk_dist',
    'reach_q', 'reach_q_b', 'dark_frac', 'ice_clim_f', 'partial_f',
    'obs_frac_n', 'xovr_cal_q', 'dry_trop_c', 'wet_trop_c', 'iono_c', 'xovr_cal_c'
]

continents = ['na']
swot = []
for con in continents:
    con_file = datasets / 'hydrocron' / 'reach' / (con + '_hydrocron_reach.parquet')
    dataset = ds.dataset(con_file, format="parquet")
    table = dataset.to_table(
        columns=fields,
        filter=(ds.field("reach_id").isin(all_reaches))
    )
    swot.append(table.to_pandas())
    
all_swot = pd.concat(swot)
all_swot = all_swot[all_swot['wse'] != -999999999999.0]
all_swot['d_wse'] = all_swot['wse'] - all_swot.groupby('reach_id')['wse'].transform('median')
all_swot['d_width'] = all_swot['width'] - all_swot.groupby('reach_id')['width'].transform('median')

all_swot = all_swot.set_index(['reach_id','time'])

In [5]:
glow_dir = datasets / "GLOW-S" / "daily_reach_aggregated"

glow = pd.read_parquet(glow_dir / "region_7_daily_median.parquet")
glow

Unnamed: 0_level_0,Unnamed: 1_level_0,width
COMID,date,Unnamed: 2_level_1
71000001,2019-01-15,854.720882
71000001,2019-02-04,1.663859
71000001,2019-02-09,1075.322720
71000001,2019-05-23,1527.924744
71000001,2019-06-11,1645.340463
...,...,...
78028470,2020-06-22,115.782824
78028470,2020-12-24,540.976377
78028470,2021-03-29,251.693162
78028470,2022-08-21,304.667251


In [7]:
# Filter out reach_ids that already have a processed file
def get_reaches_to_process():
    to_process = matchups.copy()
    print(f"Total matched reach_ids: {len(to_process.index)}")

    processed_stems = [f.stem for f in list(ts_dir.glob('*.nc'))]
    processed_mask = to_process.index.astype(str).isin(processed_stems)
    to_process = to_process[~processed_mask]
    print(f"Number of unprocessed basins: {len(to_process.index)}")
    
    all_swot_ids = all_swot.index.get_level_values('reach_id').unique()
    processed_mask = to_process.index.isin(all_swot_ids)
    to_process = to_process[~processed_mask]
    print(f"Number of unprocessed basins with SWOT data: {len(to_process.index)}")
    
    era5_stems = [p.stem for p in (era5_dir / 'basin_timeseries').glob('*.parquet')]
    processed_mask = to_process.index.astype(str).isin(era5_stems)
    to_process = to_process[~processed_mask]
    print(f"Number of unprocessed basins with era5 file: {len(to_process.index)}")
    
    return to_process.index.unique()

to_process = get_reaches_to_process()

Total matched reach_ids: 1712
Number of unprocessed basins: 1712
Number of unprocessed basins with SWOT data: 1712
Number of unprocessed basins with era5 file: 1712


In [8]:
t1 = "1980-01-01"
t2 = "2024-12-31"

pld_fields = [
    "n_overlap", "wse", "wse_u", "wse_r_u", "wse_std",
    "area_total", "area_tot_u", "area_detct", "area_det_u",
    "dark_frac", "xovr_cal_q",  "layovr_val", "xtrk_dist",
    "quality_f",  "partial_f", "ice_clim_f",
    "dry_trop_c", "wet_trop_c", "iono_c", "xovr_cal_c"
]

def get_glow_s(merit_reaches):
    glow_mask = glow.index.get_level_values('COMID').isin(merit_reaches)
    if glow_mask.any():
        glow_ix = glow[glow_mask].groupby('date').median()
        glow_ix.index = pd.to_datetime(glow_ix.index).tz_localize('UTC')
        return glow_ix.rename(columns={'width':'glow_width'})
    return pd.DataFrame(columns = ['glow_width'])


def get_swot_r(reach_id):
    if reach_id in all_swot.index.get_level_values('reach_id'):
        swot = all_swot.xs(reach_id, level='reach_id')
        swot.index = pd.to_datetime(swot.index).tz_localize('UTC')
        return swot
    return pd.DataFrame(columns = all_swot.columns)

def get_swot_l(pld_ids: list):
    lake_dfs = []
    for pld_id in pld_ids:
        path = Path(swot_lake_dir / f"{pld_id}.parquet")
        if path.is_file():
            swot = pd.read_parquet(path)[pld_fields]
            swot = swot.replace(-999999999999.0, np.nan)
            swot.dropna(subset=['wse', 'area_total'])
    
            wse = swot['wse']
            area = swot['area_total']
            swot['d_wse'] = wse - wse.median()
            swot['d_area'] = area - area.median()
            swot['d_volume'] = swot['d_wse'] * (0.5*(area + area.median()))
            lake_dfs.append(swot)
        else:
            lake_dfs.append([])

    df_lens = [len(d) for d in lake_dfs]
    if any(l>0 for l in df_lens):
        swot = lake_dfs[np.argmax(df_lens)]
        swot.index = swot.index.normalize().tz_convert('UTC')
        return swot
    else:
        new_fields = ['d_wse', 'd_area', 'd_volume']
        return pd.DataFrame(columns = pld_fields + new_fields)
    

def get_gauge(site_id):
    if site_id is not None:
        gauge = facade.get_daily_values(site_id, t1, t2).droplevel('site_id')
        gauge.index = gauge.index.tz_localize('UTC')
        return gauge[['discharge']]
    return pd.DataFrame(columns = ['discharge'])


def get_era5(reach_id):
    path = Path(era5_dir / f"{reach_id}.parquet")
    
    era5 = pd.read_parquet(path)
    era5.index = era5.index.tz_localize('UTC')
    era5.index.name = 'datetime'
    era5.fillna(0, inplace=True)
    return era5

In [9]:
# %%
ts_dir.mkdir(exist_ok=True)
date_range = pd.date_range(start=t1, end=t2, freq='D', tz='UTC')
to_process = get_reaches_to_process()
# to_process = matchups.index.unique()

for hybas_id in tqdm(to_process, total=len(to_process), desc="Writing files"):
    nc_file_path = ts_dir / f"{hybas_id}.nc"
    
    matchups_ix = matchups.loc[hybas_id]

  
    era5_df = get_era5(hybas_id)
    swot_r_df = get_swot_r(matchups_ix['reach_id']).add_suffix('_reach')
    swot_l_df = get_swot_l(matchups_ix['lake_pld_ids']).add_suffix('_lake')
    glow_df = get_glow_s(matchups_ix['mb_values'])
    gauge_df = get_gauge(matchups_ix['site_id'])
         
    # Merge all of the filtered datasets together.
    dataframes = [era5_df, swot_r_df, swot_l_df, glow_df, gauge_df]
    df = pd.DataFrame(index=date_range)
    for dataframe in dataframes:
        df = df.join(dataframe, how='left')
    
    df = df[~df.index.duplicated(keep='first')]
    df.index.name = 'date'
    
    if pd.infer_freq(df.index) !=  "D":
        raise RuntimeError("Non-daily time freq found")

    # Save netcdf
    dates = df.index.values
    ds = xr.Dataset.from_dataframe(df)
    ds = ds.assign_coords(date=(dates))
    ds.to_netcdf(nc_file_path)

Total matched reach_ids: 1712
Number of unprocessed basins: 1712
Number of unprocessed basins with SWOT data: 1712
Number of unprocessed basins with era5 file: 1712


Writing files: 100%|██████████| 1712/1712 [12:49<00:00,  2.22it/s]


In [None]:
list(glow_df)

In [None]:
plt.scatter(df.index, df['d_volume_lake'])
plt.scatter(df.index, df['d_wse_reach'])

In [None]:
plt.scatter(df.index, df['discharge'])

In [None]:
nc_file_path

In [None]:
for p in tqdm(ts_dir.glob("*.nc")):
    tmp_ds = xr.open_dataset(p)
    tmp_df = tmp_ds.to_dataframe()
    if (~tmp_df['glow_width'].isna()).sum() > 0:
        break
tmp_df

In [None]:
plt.scatter(tmp_df.index, tmp_df['glow_width'])

In [None]:
sample = xr.open_dataset(next(ts_dir.glob('*.nc')))
list(sample.variables)

In [None]:
for fp in ts_dir.glob('*.nc'):
    df = xr.open_dataset(fp).to_dataframe()
    
    if df['discharge'].isna().mean()<1:
        break
        
df

In [None]:
%matplotlib widget

plt.close('all')
# df['discharge'].plot()
# plt.errorbar(df['discharge'], df['d_wse'], yerr=df['wse_u'], fmt="o")
plt.scatter(df['discharge'], df['d_wse'], c=df['wse_u'])