In [None]:
%load_ext autoreload
%autoreload 2

import json
import os
import pprint
import random
import shutil
from functools import reduce, partial
from operator import mul
import yaml

import xarray as xr
import yaml
from distributed import Client
from distributed.utils import format_bytes
from tqdm.auto import tqdm
import pandas as pd
from collections import Counter

import dask
import intake
from ncar_jobqueue import NCARCluster
from helpers import (create_grid_dataset, enforce_chunking, get_grid_vars,
                     print_ds_info, process_variables, save_data, zarr_store, fix_time, inspect_written_stores)

#dask.config.set({"distributed.dashboard.link": "/proxy/{port}/status"})
xr.set_options(keep_attrs=True)
import numpy as np

In [None]:
# Create the dask cluster object via ncar_jobqueue's NCARCluster.
# NOTE(review): memory="10GB" is presumably the per-job/worker memory request — confirm
# against the ncar_jobqueue / dask-jobqueue configuration in use.
cluster = NCARCluster(memory="10GB")

In [None]:
# Ask the cluster to scale to 20 (presumably 20 workers — confirm for this cluster type).
cluster.scale(20)

In [None]:
# Attach a distributed Client to the cluster created above; the bare `client`
# expression makes the client the cell's displayed output.
client = Client(cluster)
client

In [None]:
# Display the cluster object as the cell output.
cluster

In [None]:
# Open the intake-esm datastore described by this JSON catalog file and display it.
# `col` is the catalog that the config-parsing cell later passes to `process_variables`.
col = intake.open_esm_datastore(
    "/glade/work/mgrover/intake-esm-catalogs/new-cesm2-le.json",
)
col

In [None]:
# Display the cluster object again.
# NOTE(review): same expression as the earlier display cell — keep only if re-checking
# the cluster state at this point is intentional.
cluster

In [None]:
# Root directory handed to `zarr_store(..., dirout=dirout)` when the stores are written.
# The default is the original user-specific scratch location; set the LENS2_DIROUT
# environment variable to redirect output without editing the notebook.
dirout = os.environ.get("LENS2_DIROUT", "/glade/scratch/abanihi/lens2-aws")

In [None]:
def _preprocess(ds, variables):
    """Drop all unnecessary variables and coordinates.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to trim.
    variables : str or iterable of str
        Name(s) of the data variables to keep. A single name may be given as a
        plain string.

    Returns
    -------
    xarray.Dataset
        Dataset without the unrequested data variables and without coordinates
        whose name is not a dimension of any time-dependent data variable.
        ``time`` and the time-bounds names (``time_bound``, ``time_bnds``,
        ``time_bounds``) are never dropped. The ``history`` global attribute is
        removed when present.
    """
    # A bare string would turn the membership test below into a substring test
    # (e.g. "T" in "TREFHT"), silently keeping variables that were not requested.
    if isinstance(variables, str):
        variables = [variables]
    wanted = set(variables)
    always_keep = {"time", "time_bound", "time_bnds", "time_bounds"}

    vars_to_drop = [vname for vname in ds.data_vars if vname not in wanted]

    # Variables without a time dimension, and anything that looks like a bounds
    # variable, are treated as coordinates rather than data.
    coord_vars = [
        vname
        for vname in ds.data_vars
        if "time" not in ds[vname].dims or "bound" in vname or "bnds" in vname
    ]
    ds_fixed = ds.set_coords(coord_vars)

    # Dimensions used by the remaining data variables; a coordinate whose name is
    # not one of these is considered unnecessary.
    used_dims = set()
    for data_var in ds_fixed.data_vars:
        used_dims.update(ds_fixed[data_var].dims)
    coords_to_drop = [coord for coord in ds_fixed.coords if coord not in used_dims]

    names_to_drop = list(set(vars_to_drop + coords_to_drop) - always_keep)

    # drop_vars is the explicit replacement for the legacy Dataset.drop.
    # reset_coords turns the surviving non-index coordinates (e.g. time bounds)
    # back into data variables.
    ds_fixed = ds_fixed.drop_vars(names_to_drop).reset_coords()
    ds_fixed.attrs.pop("history", None)
    return ds_fixed

In [None]:
# NOTE(review): this value is replaced by `variables = []` in the config-parsing cell
# below before any visible code reads it — looks like a leftover; confirm and remove.
variables=['T']

In [None]:
# Load the run configuration (component -> stream -> settings) from YAML.
with open("test_config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# An empty or non-mapping YAML file would otherwise only fail later, at
# `config.items()`, with a much less obvious error.
if not isinstance(config, dict):
    raise ValueError(
        f"test_config.yaml must contain a top-level mapping, got {type(config).__name__}"
    )

# Pretty-print: the configuration is deeply nested and unreadable on one line.
pprint.pprint(config)

In [None]:
# Build one run description per (component, stream, variable category, experiment)
# combination that has matching entries in the catalog.
run_config = []
variables = []
for component, stream_val in config.items():
    for stream, v in stream_val.items():
        frequency = v["frequency"]
        freq = v["freq"]
        time_bounds_dim = v["time_bounds_dim"]

        for v_cat, category in v["variable_category"].items():
            for exp, exp_settings in category["experiment"].items():
                chunks = exp_settings["chunks"]
                variable = category["variable"]
                variables.extend(variable)

                col_subset, query = process_variables(
                    col, variable, component, stream, exp
                )

                # Only record combinations that matched something in the catalog.
                # The append must stay under this guard: otherwise an empty subset
                # would re-append the previous run's dict (or raise NameError when
                # the very first subset is empty).
                if col_subset.df.empty:
                    print(
                        f"no catalog entries for component={component}, "
                        f"stream={stream}, experiment={exp}, variables={variable}"
                    )
                    continue

                run_config.append(
                    {
                        "query": query,
                        "col": col_subset,
                        "chunks": chunks,
                        "frequency": frequency,
                        "freq": freq,
                        "time_bounds_dim": time_bounds_dim,
                    }
                )

In [None]:
# Display the assembled list of run descriptions as the cell output.
run_config

In [None]:
def determine_chunk_size(ds, chunksize_optimal=100e6):
    """Suggest chunk sizes: full length in every dimension, reduced along ``time``.

    The number of time steps per chunk is the dataset's time length scaled by
    ``chunksize_optimal / ds.nbytes`` and is never smaller than 1.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``time`` coordinate/dimension.
    chunksize_optimal : float, optional
        Desired chunk size in bytes (default 100e6).

    Returns
    -------
    dict
        Mapping of dimension name to chunk size. All dimensions keep their full
        length except ``time``, which gets the computed number of time steps.
    """
    ntime = len(ds.time)      # number of time slices
    total_bytes = ds.nbytes   # size of the dataset as reported by ds.nbytes

    if total_bytes > 0:
        chunksize = max(int(ntime * chunksize_optimal / total_bytes), 1)
    else:
        # Nothing to split (and no division by zero): a single chunk suffices.
        chunksize = max(ntime, 1)

    # Work on a copy so that setting the time chunk never writes into the
    # mapping that backs the dataset's own dimension sizes.
    target_chunks = dict(ds.sizes)
    target_chunks["time"] = chunksize

    return target_chunks

In [None]:
# Separator used to split the dataset keys returned by `to_dataset_dict`.
field_separator = '.'

# Time span per experiment, passed to `fix_time` as start/end. Experiments that are
# not listed here are not written.
experiment_time_ranges = {
    "historical": ("1850-01", "2015-01"),
    "ssp370": ("2015-01", "2101-01"),
}

for run in tqdm(run_config, desc="runs"):
    print("*" * 120)
    query = run["query"]
    print(f"query = {query}")
    frequency = run["frequency"]
    chunks = run["chunks"]
    cftime_freq = run["freq"]
    time_bounds_dim = run["time_bounds_dim"]
    component = query['component']
    experiment = query['experiment']
    stream = query['stream']

    # Only non-"hourly6" runs of a known experiment are written. Decide this before
    # opening any dataset so that runs which would be ignored cost nothing.
    if frequency == "hourly6" or experiment not in experiment_time_ranges:
        print(
            f"skipping component={component}, stream={stream}, "
            f"experiment={experiment}, frequency={frequency}: no write rule"
        )
        continue
    start_time, end_time = experiment_time_ranges[experiment]

    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        dsets = run["col"].to_dataset_dict(
            cdf_kwargs={"chunks": chunks, "decode_times": True, "use_cftime": True},
            progressbar=True,
        )

    # Helper from `helpers`; presumably re-applies the configured chunking to every
    # dataset — confirm against its implementation.
    dsets = enforce_chunking(dsets, chunks, field_separator)

    for key, ds in tqdm(dsets.items(), desc="Saving zarr store"):
        ds = ds.sortby('time')
        ds = _preprocess(ds, query['variable'])
        # NOTE(review): the per-dataset result of determine_chunk_size(ds) was never
        # applied to `ds` or passed on to `save_data`, so it is not computed here;
        # chunking comes from the configured `chunks` above.
        print(ds.get_index("time").is_monotonic_increasing)

        # The last two fields of the dataset key are used as the forcing variant
        # and the variable name.
        key_parts = key.split(field_separator)
        forcing_variant = key_parts[-2]
        variable = key_parts[-1]

        ds = fix_time(
            ds,
            start=start_time,
            end=end_time,
            freq=cftime_freq,
            time_bounds_dim=time_bounds_dim,
        )

        # NOTE(review): semantics of write=False are defined in helpers.zarr_store —
        # confirm it is the intended setting for a production write.
        store = zarr_store(
            experiment,
            component,
            frequency,
            forcing_variant,
            variable,
            write=False,
            dirout=dirout,
        )

        save_data(ds, store)
                