# Create subchunked kerchunk references from the original references

In [None]:
import fsspec
import xarray as xr
import os
import ujson
from kerchunk.utils import subchunk, inline_array
from pathlib import Path

In [None]:
from dotenv import load_dotenv

In [None]:
# Populate os.environ from this .env file; presumably it supplies the
# AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY values read in the cluster cell.
load_dotenv(dotenv_path='/shared/users/nebari-setup/chen_keys.env')

In [None]:
%%time
import os
cluster_type = 'Gateway'

if cluster_type == 'Gateway':
    from dask_gateway import Gateway

    # Gateway client using its default connection settings (no arguments).
    gateway = Gateway()

    # Nebari-specific options: drop the conda_environment and profile
    # assignments when running on Daskhub or Planetary Computer.
    options = gateway.cluster_options()
    options.conda_environment = 'global/global-pangeo'
    options.profile = 'Small Worker'
    # Forward the notebook's AWS credentials to the workers, presumably so they
    # can authenticate to S3. Raises KeyError if either variable is unset.
    options.environment_vars = {
        name: os.environ[name]
        for name in ('AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY')
    }

    # Launch the cluster, attach a client to it, then let it scale
    # adaptively between 4 and 30 workers.
    cluster = gateway.new_cluster(options)
    client = cluster.get_client()
    cluster.adapt(minimum=4, maximum=30)

In [None]:
# fsspec S3 storage options: authenticated (non-anonymous) access, with the
# filesystem-instance cache and the directory-listings cache both disabled.
so = {'anon': False, 'skip_instance_cache': True, 'use_listings_cache': False}

In [None]:
# S3 filesystem used by the rest of the notebook, configured from `so`.
fs = fsspec.filesystem(protocol='s3', **so)

In [None]:
# Bucket + prefix of the input kerchunk reference JSONs (presumably one per
# source file, given the "individual_jsons" name).
json_dir = 's3://umassd-fvcom' '/gom3/hindcast/individual_jsons'

In [None]:
# Every reference JSON directly under json_dir.
ref_list = fs.glob(json_dir + '/*.json')
# Report the count, then the first and last matching paths on separate lines.
print(len(ref_list))
print(ref_list[0], ref_list[-1], sep='\n')

In [None]:
# d0 = ujson.load(fs.open(ref_list[0]))

In [None]:
#ds = xr.open_dataset(d0, engine="kerchunk", chunks={},
#            drop_variables= ['Itime', 'Itime2', 'Times', 'file_date', 'iint', 'nprocs'],
#            storage_options=dict(remote_protocol='s3', remote_options=so))

In [None]:
#siglev_vars = []
#for v in ds.variables.keys():
#    if 'siglev' in ds[v].dims:
#        siglev_vars.append(v)

In [None]:
#siglay_vars = []
#for v in ds.variables.keys():
#    if 'siglay' in ds[v].dims:
#        siglay_vars.append(v)

In [None]:
#nlev = len(ds['siglev'])
#nlay = len(ds['siglay'])

In [None]:
# Vertical sizes, hard-coded in place of the commented-out derivation from a
# sample dataset: 45 siglay entries and one more (46) siglev entry.
nlay = 45
nlev = nlay + 1
# Variables that carry each vertical dimension (apparently recorded from an
# earlier inspection of a sample dataset; see the commented-out cells above).
siglev_vars = 'kh km kq l omega q2 q2l siglev'.split()
siglay_vars = 'salinity siglay temp u v ww'.split()

In [None]:
def subchunk_ref(ref, out_dir='s3://umassd-fvcom/gom3/hindcast/subchunk_jsons'):
    """Subchunk the vertical variables of one kerchunk reference file.

    Reads the reference JSON at ``ref`` through the notebook-level ``fs``.
    It applies ``kerchunk.utils.subchunk`` to every variable in ``siglev_vars``
    (factor ``nlev``) and ``siglay_vars`` (factor ``nlay``). The modified
    references are written to ``<out_dir>/<stem of ref>.json``.

    Parameters
    ----------
    ref : str
        Path of the input reference JSON (e.g. an entry of ``fs.glob``).
    out_dir : str, optional
        Destination directory; defaults to the original hard-coded location.

    Returns
    -------
    str
        Path of the JSON file that was written.
    """
    # Use a context manager so the S3 read handle is closed; the previous
    # version left it open.
    with fs.open(ref) as src:
        refs = ujson.load(src)
    for var in siglev_vars:
        refs = subchunk(store=refs, variable=var, factor=nlev)
    for var in siglay_vars:
        refs = subchunk(store=refs, variable=var, factor=nlay)
    outf = f"{out_dir.rstrip('/')}/{Path(ref).stem}.json"
    with fs.open(outf, 'wb') as dst:
        dst.write(ujson.dumps(refs).encode())
    return outf

In [None]:
%%time
import dask
# One delayed subchunk_ref task per reference file, computed together;
# failed tasks may be retried up to 10 times. The results are discarded.
_ = dask.compute(*map(dask.delayed(subchunk_ref), ref_list), retries=10)

In [None]:
%%time
# Time a single-file run executed directly in the notebook process.
subchunk_ref(ref=ref_list[0])

In [None]:
# `d0` only exists as a local inside subchunk_ref (its notebook-level
# definition is commented out), so it is undefined here. Load the subchunked
# reference set written for the first input file instead.
d0 = ujson.loads(fs.cat(
    f's3://umassd-fvcom/gom3/hindcast/subchunk_jsons/{Path(ref_list[0]).stem}.json'))
ds = xr.open_dataset(
    d0, engine="kerchunk", chunks={},
    drop_variables=['Itime', 'Itime2', 'Times', 'file_date', 'iint', 'nprocs'],
    storage_options=dict(remote_protocol='s3', remote_options=so))

In [None]:
%%time
# Load the last 100 entries of temp's first axis at the final index of its
# second axis. Presumably this is the last 100 time steps of the last siglay
# layer; confirm the dimension order.
da = ds.temp[-100:, -1, :].load()

In [None]:
# Disconnect the Dask client obtained from the gateway cluster.
client.close()

In [None]:
# Stop the Dask Gateway cluster created at the top of the notebook.
cluster.shutdown()