In [None]:
import os
import pickle
import shutil
from datetime import datetime, timedelta

import dask
import fsspec
import hvplot.xarray
import numpy as np
import pandas as pd
import ujson
import xarray as xr
from kerchunk.combine import MultiZarrToZarr
from kerchunk.hdf import SingleHdf5ToZarr

In [None]:
from dask.distributed import Client, LocalCluster, progress

# Start a Dask cluster on this machine (default settings) and connect a client to it.
cluster = LocalCluster()
client = Client(cluster)
# Bare last expression so the notebook renders the client summary as the cell output.
client

In [None]:
# adding project dirs to path so code may be referenced from the notebook
# NOTE: these are relative paths, so they resolve against the kernel's current
# working directory. Each entry is inserted at position 0, so the directory
# inserted last is searched first.
import sys
sys.path.insert(0, '../../evaluation')
sys.path.insert(0, '../../evaluation/queries')
sys.path.insert(0, '../../evaluation/loading')


In [None]:
# Query some forecast data from parquet files
import importlib
import queries
import config
# NOTE: the project `utils` module is bound to the name `hu`; the bare name
# `utils` is not defined by this cell.
import utils as hu
# Re-execute the already-imported project modules (presumably so that edits made
# on disk are picked up without restarting the kernel).
importlib.reload(queries)
importlib.reload(config)
importlib.reload(hu)
import grid_to_parquet
importlib.reload(grid_to_parquet)
from datetime import datetime, timedelta

In [None]:
# # Setup some criteria
# ingest_days = 30
# start_dt = datetime(2022, 12, 18, 6) # First one is at 00Z in date
# td = timedelta(hours=6)
# number_of_forecasts = 1 #ingest_days * 4

In [None]:
# Anonymous handle on the 'gcs' filesystem; gen_json2 uses it to open remote blobs.
fs = fsspec.filesystem('gcs', anon=True)
# Filesystem for the empty protocol string -- presumably fsspec's local-file
# default; TODO confirm. Only referenced from commented-out cells in this notebook.
fs2 = fsspec.filesystem('')

In [None]:
# json_dir = 'forcing_jsons/'

# if not os.path.exists(json_dir):
#     os.makedirs(json_dir)

In [None]:
# Keyword arguments forwarded to fs.open() whenever a remote blob is read.
so = dict(mode='rb', anon=True, default_fill_cache=False, default_cache_type='first') # args to fs.open()
# default_fill_cache=False avoids caching data in between file chunks, which lowers memory usage.

In [None]:
# def gen_json(u):
#     with fs.open(u, **so) as infile:
#         h5chunks = SingleHdf5ToZarr(infile, u, inline_threshold=300)
#         p = u.split('/')
#         date = p[3]
#         fname = p[5]
#         outf = f'{json_dir}{date}.{fname}.json'
#         with open(outf, 'wb') as f:
#             f.write(ujson.dumps(h5chunks.translate()).encode());

In [None]:
# %%time
# print(datetime.now())
# # Loop through forecasts, fetch and insert
# for f in range(number_of_forecasts):
#     reference_time = start_dt + td * f
#     ref_time_str = reference_time.strftime("%Y%m%dT%HZ")
#     configuration = "forcing_medium_range"

#     print(f"Start download of {ref_time_str}")

#     blob_list = grid_to_parquet.list_blobs_forcing(
#         configuration=configuration,
#         reference_time = ref_time_str,
#         must_contain = "forcing"
#     )
    
#     blob_list = [f"gcs://national-water-model/{b}" for b in blob_list]
    
#     results = dask.compute(*[dask.delayed(gen_json)(u) for u in blob_list], retries=10)
    

In [None]:
# json_list = fs2.glob(f'{json_dir}/nwm.20221218.nwm.t06z*.json')
# json_list = sorted(json_list)

In [None]:
# mzz = MultiZarrToZarr(json_list,
#         remote_protocol='gcs',
#         remote_options={'anon':True},
#         concat_dims=['time'],
#         identical_dims = ['x', 'y'],
#     )

In [None]:
# %%time
# mzz.translate('nwm.json')

In [None]:
# backend_args = { "consolidated": False,
#                  "storage_options": { "fo": 'forcing_jsons/nwm.20221218.nwm.t06z.medium_range.forcing.f001.conus.nc.json',
#                                 "remote_protocol": "gcs", 
#                                 "remote_options": {'anon':True} }}
# ds = xr.open_dataset(
#     "reference://", engine="zarr",
#     backend_kwargs=backend_args
# )

In [None]:
# %%time
# src = ds["RAINRATE"].persist()

In [None]:
# %%time
# df = grid_to_parquet.calc_zonal_stats_weights(
#     src=src,
#     weights_filepath=config.HUC10_MEDIUM_RANGE_WEIGHTS_FILEPATH
# )
# df

In [None]:
def gen_json2(blob_in, json_out):
    """Write kerchunk reference JSON for a single remote HDF5 blob.

    Parameters
    ----------
    blob_in : str
        URL of the source blob. It is opened through the notebook-level ``fs``
        filesystem using the ``so`` open options.
    json_out : str
        Local path of the JSON file to write. The file is opened in ``'wb'``
        mode, so an existing file is overwritten.
    """
    print(f"gen_json: {json_out}")
    with fs.open(blob_in, **so) as remote_file:
        zarr_refs = SingleHdf5ToZarr(remote_file, blob_in, inline_threshold=300)
        with open(json_out, 'wb') as sink:
            # The translation runs while both the remote blob and the output
            # file are open; the resulting dict is serialised as UTF-8 JSON.
            encoded = ujson.dumps(zarr_refs.translate()).encode()
            sink.write(encoded)

In [None]:
def calc_zonal_stats_weights(
    src: xr.DataArray,
    weights_filepath: str,
) -> pd.DataFrame:
    """Compute the NaN-aware mean of ``src`` for every zone in a pickled crosswalk.

    Parameters
    ----------
    src : xr.DataArray
        Gridded variable. Only the first slice along the leading axis
        (``src.values[0]``) is used. The ``rio`` accessor must be available on
        ``src`` because the nodata value is read from ``src.rio.nodata``.
    weights_filepath : str
        Path to a pickle file holding a dict. Each key becomes a
        ``catchment_id``; each value is used to index the grid slice
        (``r_array[value]``) to select the cells belonging to that key.

    Returns
    -------
    pd.DataFrame
        One row per crosswalk key with columns ``catchment_id`` and ``value``
        (the ``np.nanmean`` of the selected cells).

    Notes
    -----
    Cells equal to the nodata value are overwritten with NaN directly on the
    array obtained from ``src.values``; no copy is made, so the masking may be
    visible through ``src`` afterwards.
    """
    # Open weights dict from pickle.
    # This could probably be done once and passed as a reference.
    # SECURITY: pickle.load can execute arbitrary code -- only use weight files
    # that come from a trusted source.
    with open(weights_filepath, 'rb') as f:
        crosswalk_dict = pickle.load(f)

    r_array = src.values[0]

    # Mask nodata cells so they are ignored by nanmean. Skip the comparison
    # entirely when no nodata value is defined.
    nodata = src.rio.nodata
    if nodata is not None:
        r_array[r_array == nodata] = np.nan

    mean_dict = {
        key: np.nanmean(r_array[value])
        for key, value in crosswalk_dict.items()
    }

    df = pd.DataFrame.from_dict(mean_dict,
                                orient='index',
                                columns=['value'])

    # Turn the crosswalk keys (currently the index) into a regular column.
    return df.reset_index(names="catchment_id")

In [None]:
def get_dataset(
        blob_name: str,
) -> xr.Dataset:
    """Open a blob from the ``national-water-model`` GCS bucket as an xarray.Dataset.

    A kerchunk reference JSON is generated for the blob on every call (via
    ``gen_json2``) and stored under ``hu.get_cache_dir()``; the dataset is then
    opened through the ``reference://`` filesystem with the zarr engine.

    Based largely on OWP HydroTools.

    Parameters
    ----------
    blob_name: str, required
        Name of blob to retrieve, relative to the bucket root.

    Returns
    -------
    ds : xarray.Dataset
        The data stored in the blob.

    """
    print(f"get_dataset: {blob_name}")

    # The reference file mirrors the blob's path inside the cache directory.
    reference_path = os.path.join(hu.get_cache_dir(), blob_name) + ".json"
    hu.make_parent_dir(reference_path)

    gen_json2(f"gcs://national-water-model/{blob_name}", reference_path)

    # Options for the reference filesystem: where the reference JSON lives and
    # how the remote chunks it points to are reached.
    reference_fs_options = {
        "fo": reference_path,
        "remote_protocol": "gcs",
        "remote_options": {'anon': True},
    }
    return xr.open_dataset(
        "reference://",
        engine="zarr",
        backend_kwargs={
            "consolidated": False,
            "storage_options": reference_fs_options,
        },
    )

In [None]:
def calculate_map_forcing(
    blob_name: str,
    weights_filepath: str,
) -> pd.DataFrame:
    """Calculate the MAP for a single NetCDF file (i.e. one timestep).

    MAP is presumably "mean areal precipitation"; the value comes from the
    ``RAINRATE`` variable averaged per catchment by ``calc_zonal_stats_weights``.

    Parameters
    ----------
    blob_name : str
        Slash-separated blob path. The first segment supplies the date (second
        dot-separated field), the second segment is the configuration, and the
        third segment supplies the cycle (second field, e.g. ``t06z``) and the
        forecast offset (fifth field, e.g. ``f001``).
    weights_filepath : str
        Passed through unchanged to ``calc_zonal_stats_weights``.

    Returns
    -------
    pd.DataFrame
        The zonal means plus ``reference_time``, ``value_time``,
        ``configuration``, ``measurement_unit`` and ``variable_name`` columns.

    ToDo: add way to filter which catchments are calculated
    """
    # Decode the forecast metadata embedded in the blob path.
    parts = blob_name.split("/")
    name_fields = parts[2].split(".")
    date_token = parts[0].split(".")[1]
    cycle_token = name_fields[1]
    reference_time = datetime.strptime(date_token + cycle_token, "%Y%m%dt%Hz")
    lead_hours = int(name_fields[4][1:])  # f001 -> 1
    value_time = reference_time + timedelta(hours=lead_hours)
    configuration = parts[1]

    # Fetch the gridded variable and the attributes reported with the results.
    rain = get_dataset(blob_name)["RAINRATE"]
    units = rain.attrs["units"]
    standard_name = rain.attrs["standard_name"]

    # Per-catchment means.
    df = calc_zonal_stats_weights(rain, weights_filepath)

    # Attach the metadata; every row gets the same value per column.
    metadata = {
        "reference_time": reference_time,
        "value_time": value_time,
        "configuration": configuration,
        "measurement_unit": units,
        "variable_name": standard_name,
    }
    for column, value in metadata.items():
        df[column] = value

    # Reduce memory foot print by storing the repetitive columns as categoricals.
    for column in ("configuration", "measurement_unit", "variable_name", "catchment_id"):
        df[column] = df[column].astype("category")

    return df

In [None]:
# Setup some criteria
# ingest_days is only referenced by the commented-out expression on the
# number_of_forecasts line in the cells visible here.
ingest_days = 30
start_dt = datetime(2022, 12, 18) # First one is at 00Z in date
# Step added per loop iteration to get each forecast's reference time.
td = timedelta(hours=6)
number_of_forecasts = 1 #ingest_days * 4

In [None]:
# Print wall-clock time before and after the loop as a crude duration measure.
print(datetime.now())

# Loop through forecasts, fetch and insert
for f in range(number_of_forecasts):
    reference_time = start_dt + td * f
    ref_time_str = reference_time.strftime("%Y%m%dT%HZ")

    print(f"Processing: {ref_time_str}")

    # NOTE: the trailing [:2] keeps only the first two blobs of the forecast.
    blob_list = grid_to_parquet.list_blobs_forcing(
        configuration = "forcing_medium_range",
        reference_time = ref_time_str,
        must_contain = "forcing"
    )[:2]

    # Build one lazy Dask task per blob; nothing is computed until
    # dask.compute() below. For serial debugging, call calculate_map_forcing
    # directly and pd.concat(dfs) instead.
    dfs = []
    for blob_name in blob_list:
        print(blob_name)
        df = dask.delayed(calculate_map_forcing)(
            blob_name,
            weights_filepath=config.HUC10_MEDIUM_RANGE_WEIGHTS_FILEPATH
        )
        dfs.append(df)

    # Run all tasks, then join all timesteps into single pd.DataFrame
    results = dask.compute(*dfs)
    df = pd.concat(results)

    # Save as parquet file
    parquet_filepath = os.path.join(config.MEDIUM_RANGE_FORCING_PARQUET, f"{ref_time_str}.parquet")
    # The project utils module is imported as `hu`; the bare name `utils` is
    # never bound in this notebook, so it must be reached through `hu`.
    hu.make_parent_dir(parquet_filepath)
    df.to_parquet(parquet_filepath)

    # Print out some DataFrame stats
    # print(df.info(verbose=True, memory_usage='deep'))
    # print(df.memory_usage(index=True, deep=True))
print(datetime.now())