In [None]:
import os
import sys
import dask
import typer
import numpy
import shutil
import pyproj
import xarray
import pandas
import logging
import rioxarray
import geopandas
from pathlib import Path
from dask.distributed import Client
from geocube.api.core import make_geocube

In [None]:
# Send INFO-and-above log records from every library to the notebook output.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Re-running this cell must not stack a second stdout handler on the root
# logger (each extra handler would print every message one more time), so the
# handler is tagged with a name and reused when it is already attached.
handler = next(
    (h for h in logger.handlers if h.get_name() == 'notebook-stdout'),
    None,
)
if handler is None:
    handler = logging.StreamHandler(sys.stdout)
    handler.set_name('notebook-stdout')
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)

In [None]:
# Start a local Dask cluster (four workers, 4GB memory limit each) and show
# its summary as the cell output.
client = Client(memory_limit='4GB', n_workers=4)
client

In [None]:
def load_data(forcing, geopackage):
    """Load both inputs of the workflow.

    Parameters
    ----------
    forcing : path of the forcing file, handed to `load_ds`.
    geopackage : path of the hydrofabric GeoPackage, handed to `load_gdf`.

    Returns
    -------
    tuple
        ``(ds, gdf)`` - the forcing dataset followed by the hydrofabric
        GeoDataFrame.
    """
    # The tuple is evaluated left to right, so the forcing is read first.
    return load_ds(forcing), load_gdf(geopackage)

def load_ds(forcing):
    """Open the forcing file at `forcing` with xarray and return the dataset."""
    return xarray.open_dataset(forcing)
    
def load_gdf(geopackage, layer='divides', target_crs=None):
    """Read a hydrofabric layer and reproject it.

    Parameters
    ----------
    geopackage : path of the GeoPackage to read.
    layer : name of the layer to load; defaults to ``'divides'``.
    target_crs : optional CRS to reproject to, in any form accepted by
        ``GeoDataFrame.to_crs``. When omitted, a Lambert conformal conic
        projection on a sphere of radius 6370000 is used; this default
        assumes AORC forcing.

    Returns
    -------
    geopandas.GeoDataFrame
        The requested layer, reprojected to `target_crs`.
    """
    # load hydrofabric
    gdf = geopandas.read_file(geopackage, layer=layer)

    if target_crs is None:
        # Default projection, assuming AORC forcing.
        # TODO: derive this from the projection defined in the forcing
        # dataset instead of hard-coding it.
        target_crs = pyproj.Proj(proj='lcc',
                                 lat_1=30.,
                                 lat_2=60.,
                                 lat_0=40.0000076293945, lon_0=-97.,
                                 a=6370000, b=6370000).crs

    return gdf.to_crs(target_crs)

@dask.delayed
def prepare_zonal(in_ds, gdf):
    """Attach a per-cell catchment id coordinate (`cat`) to the forcing dataset.

    Parameters
    ----------
    in_ds : forcing dataset; it must have ``latitude`` and ``longitude``
        dimensions, because the new coordinate is defined on them.
    gdf : GeoDataFrame with a string ``id`` column; the integer after the
        last ``'-'`` in each id becomes the zone value. `gdf` itself is left
        unmodified.

    Returns
    -------
    The dataset with a ``cat`` coordinate holding the rasterized zone ids.
    """
    # Build the zonal id column on a copy: a delayed task should not mutate
    # its inputs, and the caller's GeoDataFrame would otherwise silently gain
    # a 'cat' column.
    zones = gdf.assign(cat=gdf.id.str.split('-').str[-1].astype(int))

    # set the aorc crs.
    # TODO: This should be set when the dataset is saved, not here.
    # inplace=True is kept on purpose so the forcing data is not copied just
    # to attach CRS metadata.
    in_ds = in_ds.rio.write_crs('EPSG:4326', inplace=True)

    # create a grid for the geocube
    out_grid = make_geocube(
        vector_data=zones,
        measurements=["cat"],
        like=in_ds  # ensure the data are on the same grid
    )

    # add the catchment variable to the original dataset
    in_ds = in_ds.assign_coords(cat=(['latitude', 'longitude'], out_grid.cat.data))

    return in_ds

@dask.delayed
def delayed_zonal_computation(ds):
    """Return the mean of `ds` within each group of its `cat` coordinate."""
    per_catchment = ds.groupby(ds.cat)
    return per_catchment.mean()


In [None]:
# Input files and the directory that receives the per-catchment CSVs.
geopackage = 'input-data/wb-2917533_upstream_subset.gpkg'
forcing = 'input-data/results.nc'
output_data = 'output-data'

results = []

ds, gdf = load_data(forcing, geopackage)
# For a quick trial run, subset the time axis here first, e.g.
#   ds = ds.isel(time=range(0,100))

# Broadcast both inputs to the workers (forcing first, then hydrofabric).
scattered_ds, scattered_gdf = (
    client.scatter(payload, broadcast=True) for payload in (ds, gdf)
)

# Tag every grid cell with its catchment id, then broadcast the tagged
# dataset for the zonal computation.
zonal_ds = prepare_zonal(scattered_ds, scattered_gdf).compute()
scattered_zonal_ds = client.scatter(zonal_ds, broadcast=True)

# clean up: the futures of the raw inputs are no longer needed
del scattered_ds, scattered_gdf

In [None]:
# Build the lazy zonal-mean task on the broadcast dataset; nothing runs until
# .compute() is called on it.
r = delayed_zonal_computation(scattered_zonal_ds)

In [None]:
%%time
# Execute the zonal-mean task and bring the result back into the notebook.
results = r.compute()

In [None]:
@dask.delayed
def save_to_csv(results, cat_id, output_dir):
    """Write the forcing series of one catchment to ``<output_dir>/cat-<id>.csv``.

    Parameters
    ----------
    results : dataset with a ``cat`` coordinate, one entry per catchment.
    cat_id : label of the catchment to select; it must be convertible with
        ``int()`` because the file name is built from ``int(cat_id)``.
    output_dir : directory that receives the file; created if missing.

    Missing values are written as 0 and ``APCP_surface`` is multiplied by
    3600 before writing. Nothing is returned.
    """
    columns = ['APCP_surface',
               'DLWRF_surface',
               'DSWRF_surface',
               'PRES_surface',
               'SPFH_2maboveground',
               'TMP_2maboveground',
               'UGRD_10maboveground',
               'VGRD_10maboveground']

    # Build the table before touching the file system, so a failed selection
    # does not leave an empty, truncated CSV behind.
    df = results.sel(dict(cat=cat_id)).to_dataframe().fillna(0.)
    df['APCP_surface'] = df.APCP_surface * 3600

    out_path = Path(output_dir) / f'cat-{int(cat_id)}.csv'
    # exist_ok keeps this safe when several tasks create the directory at once.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Hand pandas the path instead of an open text handle: pandas then opens
    # the file itself with the newline handling the csv writer needs.
    df.to_csv(out_path, columns=columns)

In [None]:
# Relabel the catchment coordinate: float ids become integer-valued strings.
results = results.assign_coords(cat=results.cat.astype(int).astype(str))
results_scattered = client.scatter(results, broadcast=True)

# Queue one lazy CSV-writing task per catchment; nothing is written yet.
delayed_write = [
    save_to_csv(results_scattered, cat_id, output_data)
    for cat_id in results.cat.values
]

In [None]:
# Run every queued CSV-writing task. save_to_csv has no return statement, so
# the computed values carry no information and are discarded.
_ = dask.compute(delayed_write)

In [None]:
# Point `ds` at the zonal results for the ad-hoc checks below; this replaces
# the forcing dataset that `ds` referred to earlier in the notebook.
ds = results

In [None]:
# First and last timestamps of the results.
st, et = ds.time.values.min(), ds.time.values.max()

In [None]:
# Length of the record in days. Dividing by a one-day timedelta64 gives the
# right answer for any datetime resolution, whereas scaling `.item()` by 1e-9
# is only correct when the difference is stored in nanoseconds.
(et - st) / numpy.timedelta64(1, 'D')

In [None]:
from datetime import datetime

In [None]:
# The same two bounds, converted with pandas.to_datetime.
st, et = (
    pandas.to_datetime(bound)
    for bound in (ds.time.values.min(), ds.time.values.max())
)

In [None]:
# Length of the record in seconds (st and et come from pandas.to_datetime).
(et-st).total_seconds()

In [None]:
# Catchment labels that appear in the zonal results.
computed_catchments = [label for label in ds.cat.values]

In [None]:
# Every catchment id listed in the hydrofabric.
known_catchments = gdf['id'].values

In [None]:
# Compare how many catchments the hydrofabric lists with how many received
# zonal results; a positive difference means some catchments have no results.
# NOTE(review): the message announces that synthetic data is being computed,
# but no such step is visible in this notebook - confirm or reword.
diff =  len(known_catchments) - len(computed_catchments)
if diff > 0:
    print(f'{diff} catchments missing from NGen Subset.\nComputing synthetic data for these')

In [None]:
# Report every hydrofabric catchment that has no zonal result. The computed
# labels are put in a set once, so each membership test is a hash lookup
# instead of a scan over the whole list.
computed_ids = set(computed_catchments)
for known_id in known_catchments:
    # Compare on the trailing piece of the id, after the last '-'.
    _id = known_id.split('-')[-1]
    if _id not in computed_ids:
        print(f'missing {_id}')


In [None]:
# Spot check: the trailing piece of the first hydrofabric id after splitting
# on '-'.
known_catchments[0].split('-')[-1]

In [None]:
# Display the dataset currently bound to `ds` (the zonal results).
ds