In [45]:
# Jupter notebook magic
%matplotlib inline

import os
from glob import glob
import xarray as xr
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import matplotlib
from datetime import datetime, timedelta
from copy import copy
import geopandas as gpd
from PyAstronomy import pyasl

import zarr
import s3fs
import os



import himatpy, himatpy.GRACE_MASCON.pygrace, himatpy.MAR.utils



from importlib import reload
reload(himatpy.MAR.utils)
reload(himatpy.GRACE_MASCON.pygrace)


from himatpy.GRACE_MASCON.pygrace import (extract_grace, get_mascon_gdf, masked_mascon_gdf,grace_data_df,\
                                          get_full_trend,trend_analysis, get_cmwe_trend_analysis, select_mascons, \
                                          aggregate_mascons)
from himatpy.MAR.utils import save_agg_mascons,  MAR_trend, subset_data, nc2zarr,save_agg_mascons_zarr

import timeit

# Introduction
  * This notebook showcases the improvements in code efficiency by using chunked arrays with dask.

  * The aggregation of model data to GRACE mascons also benefits from using the xarray properly and avoid the overhead of its where function. 
  * This notebook requires the subsetting of MAR data to netcdf first and convert to a zarr store (either on s3 or local) 

### Uncomment and run the cell below if subset data and zarr store have not been created. 

In [22]:
# DNAME    = 'MAR'
# MAR_locl = os.path.join( os.path.abspath('./Data'), DNAME ) 
# SUB_locl = os.path.join( os.path.abspath('./SUB') , DNAME ) 
# marfns   = sorted(glob(MAR_locl+'/*.nc'))
# subfns   = sorted(glob(SUB_locl+'/*.nc'))
# if not os.path.exists(SUB_locl): os.makedirs(SUB_locl)

# # --- subset MAR dataset: around 40s per file.  
# chunks = {'X11_210':100,'Y11_190':90} 
# for ifn, tfn in enumerate(marfns[:1]):
#     start_time = timeit.default_timer()
#     sdir,sfn = os.path.split(tfn)
#     print(ifn,sfn)
#     ofn = os.path.join(SUB_locl,sfn)
#     subset_data(tfn,ofn,zlib=True,chunks=chunks)
#     end_time = timeit.default_timer()
#     print('Processing time [s]:',end_time-start_time) 

# # --- convert subset nc files to zarr: around 2.5 min 
# # --- cannot overwrite existing zarr store, may need to remove first
# chunks = {'time':360,'X':100,'Y':90}
# S3_root  = 'pangeo-data-upload-oregon/icesat2/HMA_Validation/'
# ZAR_path = os.path.join(S3_root,'ZarrSUB',DNAME)
# start_time = timeit.default_timer()
# nc2zarr(subfns,ZAR_path,s3store=True,chunks=chunks,parallel=True)
# end_time = timeit.default_timer()
# print('Processing time [s]:',end_time-start_time) 
    

#### Preparation: read Grace data to select the mascons in the MAR domain

In [25]:
Grace_fn = 'Data/Grace/GSFC.glb.200301_201607_v02.4-ICE6G.h5'
# ---> use local copy
grace_file = os.path.abspath(Grace_fn)
f = extract_grace(grace_file,printGroups=False)

# --- save the GRACE mascon info into file for future access/read if it already exists
SNAME    = 'MAR'
SUB_locl = os.path.join( os.path.abspath('./SUB') , SNAME ) 
subfns = sorted(glob(SUB_locl+'/*.nc'))
ds = xr.open_dataset(subfns[0])
MAR_mascons_fn = 'MAR_mascons.geojson'
masked_gdf = masked_mascon_gdf(f,ds,mascons_fn = MAR_mascons_fn,verbose=True)
ds.close()

Data extracted: 
... read info of mascons in domain from MAR_mascons.geojson ...


  f = h5py.File(fpath)


## Test the code efficiency 

#### First, contrast between the first two options using aggregate_mascons in pygrace. 
  * read subset MAR data without chunks
  * read subset MAR data with chunks

Because aggregate_mascons in pygrace has to query through the model data using mascon geometry for every data array in the dataset and loops over all mascons, it becomes extremely slow. For a single subset data file (one year, ~280MB), it takes about 21 min without chunks and dask, and 2 minutes with parallelization with dask. This is still relatively long, we choose 10 mascons out of 370 and demonstrate the difference.

In [104]:
mascon10 = masked_gdf#.iloc[:20]
# --- This line tests option 1, without chunks. 
start_time = timeit.default_timer()
#save_agg_mascons(subfns[:1],'testagg',mascon10)
end_time = timeit.default_timer()
print('Processing time [s] for option 1:',end_time-start_time) 
# --- reset timer and run option 2. 
start_time = timeit.default_timer()
save_agg_mascons(subfns[:1],'testagg',mascon10,chunks=chunks)
#save_agg_mascons(subfns[:1],'testagg',mascon10,chunks={'Y':200})
end_time = timeit.default_timer()
print('Processing time [s] for option 2:',end_time-start_time) 

Processing time [s] for option 1: 3.184700108249672e-05
... aggregating HMA_MAR3_5_ICE.2000.01-12.h22.nc ...


  x = np.divide(x1, x2, out)
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


Processing time [s] for option 2: 132.01944375099993


Without chunks, 10 mascons take about 35s. Using the chunks defined above (4 chunks in X/Y plane), they take 9s and using chunks but only one chunk in X/Y plane by using `chunks={'Y':200}`, it's actually slightly faster. The recommended chunk size by dask is 100 MB or at least 1M elements. For MAR data, a variable in one year (float32) is only 54 MB. Further chunking likely makes it less efficient. Although with only one chunk in X/Y, plane, dask still parallelize the processing across variables and speeds up the processing.  

### Second, use the new save_agg_mascon_zarr function in MAR/utils 
Use 100 mascons. About 1 min is needed for the whole 16-year data. Less than 3 min 41s is needed for all the mascons. 
The function save_agg_mascon_zarr use indices of MAR grid points in GRACE mascons to aggregate the entire MAR dataset. 
This avoids repeated calls to `xr.DataArray.where`. Using np.where to find indices to avoid the overhead on `xr.Dataset.where` . 

In [123]:
%%time
mascon100 = masked_gdf.iloc[:100]
save_agg_mascons_zarr(ds_store,'testagg/aggmar_test.nc',mascon100)

CPU times: user 1min 10s, sys: 6.18 s, total: 1min 16s
Wall time: 56 s


In [51]:
with xr.open_dataset('testagg/aggmar_test.nc') as tds:
    print(tds)

<xarray.Dataset>
Dimensions:    (mascon: 92, time: 5844)
Coordinates:
    SECTOR     float32 ...
    SECTOR1_1  float32 ...
  * time       (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2015-12-31
  * mascon     (mascon) int64 6693 6694 6695 6696 6697 ... 7767 7777 7778 7779
Data variables:
    RF         (mascon, time) float32 ...
    RU_ice     (mascon, time) float32 ...
    RU_other   (mascon, time) float32 ...
    SF         (mascon, time) float32 ...
    SMB_ice    (mascon, time) float32 ...
    SMB_other  (mascon, time) float32 ...
    SU_ice     (mascon, time) float32 ...
    SU_other   (mascon, time) float32 ...
    SW_ice     (mascon, time) float32 ...
    lat        (mascon) float32 ...
    long       (mascon) float32 ...


### Another question
Why not use `xr.Dataset.where` on the entire dataset? It can be demonstrated by the example below. For 10 mascons, 8s/5s for two choices of chunks. For all 370 mascons and one year, it uses 102s/88s with two choices chunks, slightly faster by 20s than using where on DataArray's. 

In [124]:
%%time
#with xr.open_dataset(subfns[0],chunks={'Y':200}) as ds:
with xr.open_dataset(subfns[0],chunks=chunks) as ds:
    geos = [x.bounds for x in masked_gdf['geometry']]
    dslist = []
    # len(geos) is 370
    for i in range(10):
        geo = geos[i]
        tds = ds.where( (ds.long>=geo[0]) & (ds.long<=geo[2]) & (ds.lat>=geo[1]) & (ds.lat<= geo[3]) )
        dslist.append(tds.mean(axis=(-1,-2)))
    nds = xr.concat(dslist,'mascon').compute()

CPU times: user 12.9 s, sys: 205 ms, total: 13.1 s
Wall time: 8.47 s


On the zarr store for the entire 16 years, about 6 min 23s is needed for 100 mascons. Compare to uisng the function save_agg_mascons_zarr, which takes 1 min for 100 mascons.

In [122]:
%%time
fs       = s3fs.S3FileSystem(anon=False)
ds_store = s3fs.S3Map(root=ZAR_path,s3=fs,check=True)
with xr.open_zarr(ds_store) as ds:
    geos = [x.bounds for x in masked_gdf['geometry']]
    dslist = []
    start_time = timeit.default_timer()
    for i in range(100):
        geo = geos[i]
        tds = ds.where( (ds.long>=geo[0]) & (ds.long<=geo[2]) & (ds.lat>=geo[1]) & (ds.lat<= geo[3]) )
        dslist.append(tds.mean(axis=(-1,-2)))
    nds = xr.concat(dslist,'mascon').compute()

  x = np.divide(x1, x2, out)


CPU times: user 20min 21s, sys: 9.21 s, total: 20min 30s
Wall time: 6min 23s
