In [None]:
%matplotlib inline

# Put the project's evaluation directories at the front of sys.path so the
# notebook can import their modules. Resulting priority order (highest first):
# loading, queries, evaluation -- the same as three successive insert(0, ...) calls.
import sys
sys.path[0:0] = [
    '../../evaluation/loading',
    '../../evaluation/queries',
    '../../evaluation',
]

In [None]:
import os
import shutil
import fsspec
import ujson
from kerchunk.hdf import SingleHdf5ToZarr
from kerchunk.combine import MultiZarrToZarr
import xarray as xr
import dask
import hvplot.xarray
from datetime import datetime, timedelta
import pandas as pd
import pickle
import numpy as np
import time
import gc

In [None]:
# Import the project-local helper modules (made importable by the sys.path
# entries added in the first cell) and reload them, so edits to their .py
# files take effect without restarting the kernel.
import importlib
import queries
import config
import utils
importlib.reload(queries)
importlib.reload(config)
importlib.reload(utils)
# grid_to_parquet is imported only after utils has been reloaded.
import grid_to_parquet
importlib.reload(grid_to_parquet)
# datetime/timedelta are also imported in the imports cell; re-imported here.
from datetime import datetime, timedelta

In [None]:
from dask.distributed import Client, LocalCluster, progress

# Start a 16-worker local dask cluster; creating the Client makes it the
# default scheduler for the dask.compute() calls further down.
cluster = LocalCluster(n_workers=16)
client = Client(address=cluster)
client

In [None]:
# GCS filesystem used to read the public NWM bucket (anon=True requested).
fs = fsspec.filesystem(protocol='gcs', anon=True)
# Local filesystem: an empty protocol string falls back to fsspec's default "file".
fs2 = fsspec.filesystem(protocol="file")

In [None]:
# Keyword arguments for fs.open() when reading the remote NetCDF/HDF5 blobs.
# default_fill_cache=True caches data read in between file chunks; setting it
# to False avoids that caching and lowers memory usage.
so = {
    'mode': 'rb',
    'anon': True,
    'default_fill_cache': True,
    'default_cache_type': 'first',
}

In [None]:
def gen_zarr_json(blob_name: str) -> str:
    """Write kerchunk references for one NWM blob into the local cache.

    Parameters
    ----------
    blob_name : str
        Object path inside the ``national-water-model`` GCS bucket.

    Returns
    -------
    str
        Path of the written reference JSON: ``<cache dir>/<blob_name>.json``.
    """
    json_out = os.path.join(utils.get_cache_dir(), f"{blob_name}.json")
    utils.make_parent_dir(json_out)
    blob_in = f"gcs://national-water-model/{blob_name}"

    # translate() stays inside the `with` because the reference builder
    # reads the HDF5 structure through `infile`.
    with fs.open(blob_in, **so) as infile:
        refs = SingleHdf5ToZarr(infile, blob_in, inline_threshold=300).translate()

    # Open the output only after translation succeeded, so a failed remote read
    # does not leave an empty/truncated JSON behind in the cache.
    with open(json_out, 'wb') as f:
        f.write(ujson.dumps(refs).encode())
    return json_out

In [None]:
def calc_zonal_stats_weights(
    src: xr.DataArray,
    weights_filepath: str
) -> pd.DataFrame:
    """Calculate the NaN-aware mean of ``src`` over each zone of a weights file.

    Parameters
    ----------
    src : xr.DataArray
        Raster; only ``src.values[0]`` (first slice along the leading axis) is
        used. Cells equal to ``src.rio.nodata`` are treated as missing.
    weights_filepath : str
        File read with ``utils.read_weights_file``; each value of the returned
        mapping is used directly as an index into the 2-D slice
        (presumably row/col index arrays -- TODO confirm).

    Returns
    -------
    pd.DataFrame
        Columns ``catchment_id`` (mapping keys) and ``value`` (zone mean).
    """
    crosswalk_dict = utils.read_weights_file(weights_filepath)

    # Work on a floating-point copy: the caller's (possibly xarray-cached)
    # buffer is not overwritten, and integer rasters can still hold NaN.
    # Float inputs keep their dtype (float32 stays float32).
    r_array = src.values[0]
    r_array = r_array.astype(np.result_type(r_array.dtype, np.float32))

    nodata = src.rio.nodata
    if nodata is not None:
        r_array[r_array == nodata] = np.nan

    mean_dict = {
        key: np.nanmean(r_array[value])
        for key, value in crosswalk_dict.items()
    }

    df = pd.DataFrame.from_dict(mean_dict, orient='index', columns=['value'])
    return df.reset_index(names="catchment_id")

In [None]:
def get_dataset(
    zarr_json: str
) -> xr.Dataset:
    """Open a kerchunk reference JSON as an ``xarray.Dataset``.

    Parameters
    ----------
    zarr_json : str
        Path of the reference JSON (e.g. one written by ``gen_zarr_json`` or
        by ``MultiZarrToZarr.translate``). The chunk bytes it references are
        fetched from GCS with ``anon=True`` remote options.

    Returns
    -------
    ds : xarray.Dataset
        Dataset backed by the references; close it when done.
    """
    backend_args = {
        "consolidated": False,
        "storage_options": {
            "fo": zarr_json,
            "remote_protocol": "gcs",
            "remote_options": {'anon': True},
        },
    }
    return xr.open_dataset(
        "reference://", engine="zarr",
        backend_kwargs=backend_args
    )

In [None]:
def _parse_forcing_json_path(zarr_json: str) -> tuple:
    """Return ``(reference_time, value_time, configuration)`` parsed from a cached JSON path.

    Only the last three ``/``-separated components are used, so the result does
    not depend on how deep the cache directory is. Expected layout (as implied
    by the parsing): ``.../<x>.<YYYYMMDD>/<configuration>/<x>.t<HH>z.<x>.<x>.f<NNN>...``.
    """
    date_dir, configuration, file_name = zarr_json.split("/")[-3:]
    file_fields = file_name.split(".")
    reference_time = datetime.strptime(
        date_dir.split(".")[1] + file_fields[1],
        "%Y%m%dt%Hz"
    )
    offset_hours = int(file_fields[4][1:])  # "f001" -> 1
    value_time = reference_time + timedelta(hours=offset_hours)
    return reference_time, value_time, configuration


def calculate_map_forcing(
    zarr_json: str,
    weights_filepath: str
) -> pd.DataFrame:
    """Calculate the MAP for a single NetCDF file (i.e. one timestep).

    Parameters
    ----------
    zarr_json : str
        Path of the reference JSON written by ``gen_zarr_json``; its last three
        path components encode reference date, configuration and forecast hour.
    weights_filepath : str
        Weights file passed through to ``calc_zonal_stats_weights``.

    Returns
    -------
    pd.DataFrame
        Zone means from ``RAINRATE`` plus ``reference_time``, ``value_time``,
        ``configuration``, ``measurement_unit`` and ``variable_name`` columns;
        string-valued columns are stored as ``category`` to save memory.

    ToDo: add way to filter which catchments are calculated
    """
    reference_time, value_time, configuration = _parse_forcing_json_path(zarr_json)

    # The context manager closes the dataset even if the zonal stats raise.
    with get_dataset(zarr_json) as ds:
        src = ds["RAINRATE"]
        measurement_unit = src.attrs["units"]
        variable_name = src.attrs["standard_name"]
        df = calc_zonal_stats_weights(src, weights_filepath)

    # Set metainfo for MAP
    df["reference_time"] = reference_time
    df["value_time"] = value_time
    df["configuration"] = configuration
    df["measurement_unit"] = measurement_unit
    df["variable_name"] = variable_name

    # Reduce memory footprint
    return df.astype({
        "configuration": "category",
        "measurement_unit": "category",
        "variable_name": "category",
        "catchment_id": "category",
    })

In [None]:
# Ingest criteria: which forecast cycles the loop below processes.
ingest_days = 30
# Reference time of the first cycle to ingest (18Z on 2022-12-18).
start_dt = datetime(2022, 12, 18, 18)
# Spacing between successive cycles: 6 h, i.e. four cycles per day.
td = timedelta(hours=6)
# Only one cycle for now; use `ingest_days * 4` to cover the whole window.
number_of_forecasts = 1

In [None]:
%%time
# Loop through forecast cycles: build reference JSONs, compute MAP, save parquet.
for f in range(number_of_forecasts):
    reference_time = start_dt + td * f
    ref_time_str = reference_time.strftime("%Y%m%dT%HZ")

    print(f"Processing: {ref_time_str}")

    blob_list = grid_to_parquet.list_blobs_forcing(
        configuration="forcing_medium_range",
        reference_time=ref_time_str,
        must_contain="forcing",
    )
    if not blob_list:
        # pd.concat([]) below would raise ValueError; skip cycles with no files.
        print(f"No forcing blobs found for {ref_time_str}; skipping")
        continue

    # Generate one kerchunk reference JSON per blob (retries absorb transient failures)
    t0 = time.perf_counter()
    zarr_json_list = dask.compute(*[dask.delayed(gen_zarr_json)(b) for b in blob_list], retries=10)
    print(f"Generate Zarr took: {time.perf_counter() - t0}")

    # Calculate MAP for every timestep in parallel, then join into one pd.DataFrame
    t0 = time.perf_counter()
    dfs = [
        dask.delayed(calculate_map_forcing)(
            zarr_json,
            weights_filepath=config.HUC10_MEDIUM_RANGE_WEIGHTS_FILEPATH,
        )
        for zarr_json in zarr_json_list
    ]
    results = dask.compute(*dfs)
    df = pd.concat(results)
    print(f"Download and Calculate MAP took: {time.perf_counter() - t0}")

    # Save as parquet file
    parquet_filepath = os.path.join(config.MEDIUM_RANGE_FORCING_PARQUET, f"{ref_time_str}.parquet")
    utils.make_parent_dir(parquet_filepath)
    df.to_parquet(parquet_filepath)

In [None]:
# Combined MAP DataFrame of the last forecast cycle processed by the loop above.
df

In [None]:
# Now let's try combining the per-timestep reference JSONs into a single MultiZarr reference and computing MAP from it

In [None]:
%%time
# Combine the per-timestep reference JSONs of the last processed cycle
# (zarr_json_list / ref_time_str from the loop above) along "time".
mzz = MultiZarrToZarr(
    zarr_json_list,
    concat_dims=['time'],
    identical_dims=['x', 'y'],
    remote_protocol='gcs',
    remote_options={'anon': True},
)
json_out = os.path.join(utils.get_cache_dir(), f"{ref_time_str}.json")
mzz.translate(json_out)

In [None]:
# Open the combined (all-timesteps) reference JSON written above.
ds = get_dataset(zarr_json=json_out)

In [None]:
# Raw values of the combined dataset's time coordinate.
timesteps = ds["time"].data

In [None]:
%%time
# Zonal stats for each timestep of the combined dataset. Each result is tagged
# with its timestep so rows from different times stay distinguishable in df2.
rainrate = ds["RAINRATE"]
dfs2 = []
for t in timesteps:
    dfs2.append(dask.delayed(calc_zonal_stats_weights)(
        src=rainrate.sel(time=[t]),
        weights_filepath=config.HUC10_MEDIUM_RANGE_WEIGHTS_FILEPATH,
    ))
results = dask.compute(*dfs2)
df2 = pd.concat(
    [res.assign(value_time=t_val) for res, t_val in zip(results, timesteps)]
)

In [None]:
# Spot-check the rows of a single catchment.
df.loc[df["catchment_id"].eq("1016000606")]

In [None]:
# Inspect the last timestep explicitly instead of relying on the loop
# variable `t` left over from an earlier cell.
ds["RAINRATE"].sel(time=[timesteps[-1]])