In [1]:
'''script to regrid CMIP6 datatsets to target grid and store them'''

import numpy as np
import xarray as xr
import dask
from dask_gateway import Gateway
from distributed import worker_client, as_completed
import intake
import pandas as pd
import os
from collections import defaultdict
from tqdm.autonotebook import tqdm
from xmip.utils import google_cmip_col
from xmip.postprocessing import combine_datasets, _match_datasets,_concat_sorted_time
from cmip_catalogue_operations import reduce_cat_to_max_num_realizations, drop_vars_from_cat, drop_older_versions
from cmip_ds_dict_operations import select_period, drop_duplicate_timesteps, drop_coords, drop_incomplete
import xesmf as xe
from typing import Dict
import gcsfs
fs = gcsfs.GCSFileSystem() #list stores, stripp zarr from filename, load 

  from tqdm.autonotebook import tqdm


In [2]:
def get_lonlat_idx_nearest_to_tgs(tg_ds,ds):
    '''Locate, for every tide gauge, the grid cell of a gridded dataset that lies closest to it.

    tg_ds = xr.DataSet containing 'lon' and 'lat' coordinates of tide gauges (treated as degrees)
    ds    = xr.DataSet containing CMIP6 data to subset; its horizontal dimensions are taken to be
            the first dimension names that contain 'lon' and 'lat', respectively

    Returns the outcome of an argmin over both horizontal dimensions of the tide gauge to grid cell
    distances. Callers index this result by the longitude and latitude dimension names.
    '''
    #find lon/lat dimension names of the gridded dataset
    lon_name = [dim for dim in ds.dims if 'lon' in dim][0]
    lat_name = [dim for dim in ds.dims if 'lat' in dim][0]

    deg2rad = np.pi/180

    #haversine formula, term by term (central angle on the unit sphere, so no Earth radius is needed
    #to decide which cell is nearest)
    sin2_half_dlat = np.sin( deg2rad * 0.5*(ds[lat_name]-tg_ds.lat) )**2
    cos_lat_product = np.cos(deg2rad*tg_ds.lat)*np.cos(deg2rad*ds[lat_name])
    sin2_half_dlon = np.sin(deg2rad*0.5*(ds[lon_name]-tg_ds.lon))**2
    central_angle = 2*np.arcsin( np.sqrt(sin2_half_dlat + cos_lat_product*sin2_half_dlon) )

    #indices of the smallest angle along both horizontal dimensions at once
    return central_angle.argmin(dim=[lon_name,lat_name])

In [3]:
# create regridders per source_id (datasets of the same source_id should be on the same grid after passing through reduce_cat_to_max_num_realizations)
def create_regridder_dict(dataset_dict: Dict[str, xr.Dataset], target_grid_ds: xr.Dataset) -> Dict[str, xe.Regridder]:
    '''Create one bilinear xesmf regridder per CMIP6 model (source_id).

    dataset_dict   = dictionary of datasets; each dataset must carry a 'source_id' attribute
    target_grid_ds = xr.Dataset describing the grid to regrid to

    Returns a dictionary mapping each source_id to an xe.Regridder built from the first dataset
    of that source_id in dataset_dict (all datasets of one source_id are assumed to share a grid).
    '''
    #group on the 'source_id' attribute itself rather than testing whether the source_id occurs
    #somewhere in the dictionary key: a substring test also matches other models whose name
    #contains this one (e.g. 'CNRM-CM6-1' vs. 'CNRM-CM6-1-HR'), which can pick the wrong grid
    first_ds_per_source = {}
    for ds in dataset_dict.values():
        first_ds_per_source.setdefault(ds.attrs['source_id'], ds) #keep the first one (we don't really care here which one we use)

    regridders = {}
    for si in tqdm(sorted(first_ds_per_source)):
        regridders[si] = xe.Regridder(first_ds_per_source[si],target_grid_ds,'bilinear',ignore_degenerate=True,periodic=True) #create regridder for this source_id
    return regridders

def create_regridder_dict_per_tg(dataset_dict,tg_coords):
    '''Create, per CMIP6 model (source_id), one nearest-neighbour xesmf regridder for every tide gauge.

    dataset_dict = dictionary of datasets; each dataset must carry a 'source_id' attribute
    tg_coords    = xr.Dataset with a 'tg' dimension holding the tide gauges to regrid to

    Returns a nested dictionary: regridders[source_id][str(<tg value array>)] -> xe.Regridder.
    The inner key is the string form of the single-element 'tg' array, so look-ups must build
    the key the same way.
    '''
    #group on the 'source_id' attribute itself rather than testing whether the source_id occurs
    #somewhere in the dictionary key: a substring test also matches other models whose name
    #contains this one (e.g. 'CNRM-CM6-1' vs. 'CNRM-CM6-1-HR'), which can pick the wrong grid
    first_ds_per_source = {}
    for ds in dataset_dict.values():
        first_ds_per_source.setdefault(ds.attrs['source_id'], ds) #keep the first one (we don't really care here which one we use)

    n_tgs = len(tg_coords.tg)
    regridders = {}
    for si in tqdm(sorted(first_ds_per_source)):
        ds = first_ds_per_source[si]
        source_regridders = {}
        for tg in range(n_tgs):
            single_tg = tg_coords.isel(tg=[tg]) #list index keeps the 'tg' dimension (length 1)
            source_regridders[str(single_tg.tg.values)] = xe.Regridder(ds,single_tg, method='nearest_s2d',periodic=True,ignore_degenerate=True)
        regridders[si] = source_regridders

    return regridders


def create_nearest_idx_dict(dataset_dict,tg_coords):
    '''Compute, per CMIP6 model (source_id), the indices of the grid cells nearest to the tide gauges.

    dataset_dict = dictionary of datasets; each dataset must carry a 'source_id' attribute
    tg_coords    = xr.Dataset with the tide gauge coordinates, passed on to get_lonlat_idx_nearest_to_tgs

    Returns a dictionary mapping each source_id to the result of get_lonlat_idx_nearest_to_tgs for
    the first dataset of that source_id (all datasets of one source_id are assumed to share a grid).
    '''
    #group on the 'source_id' attribute itself rather than testing whether the source_id occurs
    #somewhere in the dictionary key: a substring test also matches other models whose name
    #contains this one (e.g. 'CNRM-CM6-1' vs. 'CNRM-CM6-1-HR'), which would yield indices that
    #belong to a different grid
    first_ds_per_source = {}
    for ds in dataset_dict.values():
        first_ds_per_source.setdefault(ds.attrs['source_id'], ds) #keep the first one (we don't really care here which one we use)

    indices = {}
    for si in tqdm(sorted(first_ds_per_source)):
        indices[si] = get_lonlat_idx_nearest_to_tgs(tg_coords,first_ds_per_source[si])

    return indices
    
def make_filepath(output_path, ds):
    '''Build the output path of a dataset as <output_path>/<variable_id>/<source_id>/<store name>.

    output_path = root location to store under
    ds          = dataset with 'ddict_key' and 'variable_id' in its attributes and a 'source_id' attribute

    The store name is the part of the 'ddict_key' attribute in front of the first '.gs'
    (the whole key if '.gs' does not occur in it).
    '''
    store_name, _, _ = ds.attrs["ddict_key"].partition('.gs')
    path_parts = (output_path, ds.attrs['variable_id'], ds.source_id, store_name)
    return os.path.join(*path_parts)

def pr_flux_to_m(ds):
    '''Convert precipitation from a flux ('kg m-2 s-1') to an accumulated amount per day ('m').

    ds = dataset that may or may not contain the variable 'pr'

    Returns ds. Datasets without 'pr', or with 'pr' already in 'm', are returned untouched.
    If a conversion is needed, 'pr' is replaced on the dataset that was passed in (the input
    object is modified) and its 'units' attribute is set to 'm'.

    Raises ValueError if 'pr' has units other than 'kg m-2 s-1' or 'm' (including no units at all).
    '''
    if 'pr' not in ds.variables: #nothing to convert
        return ds

    units = ds['pr'].attrs.get('units') #None if the variable carries no units attribute

    if units == 'kg m-2 s-1':
        with xr.set_options(keep_attrs=True): #keep the variable's attributes through the arithmetic
            ds['pr'] = 24*3600*ds['pr']/1000 #multiply by number of seconds in a day to get to kg m-2, and divide by density (kg/m3) to get to m
        ds['pr'].attrs['units'] = 'm'
    elif units != 'm':
        #the original `raise('...')` raised a TypeError because a string is not an exception
        raise ValueError(f'Variable pr has unrecognized unit: {units!r}.')
    return ds

def drop_duplicate_timesteps_from_ds(ds):
    '''Return ds with each distinct value of 'time' kept only once.

    If every timestep is already unique, ds itself is returned unchanged. Otherwise the first
    occurrence of each time value is selected, in the (sorted) order in which np.unique reports them.
    '''
    _, first_occurrence = np.unique(ds.time,return_index=True)

    #same number of unique values as timesteps: nothing to drop
    if len(first_occurrence) == len(ds.time):
        return ds
    return ds.isel(time=first_occurrence)

In [4]:
#configure settings
# output_path = 'gs://leap-persistent/timh37/CMIP6/timeseries_eu_1p5/' #alternative output location (disabled)
output_path = 'gs://leap-persistent/timh37/CMIP6/datasets_eu_gesla2_tgs/' #root location below which make_filepath() builds the path of each output store
overwrite_existing = True #whether or not to process files for which output already exists; used as the default of write_wrapper(overwrite=...). NOTE(review): the batch-writing cell passes overwrite=False explicitly, so this value is not what decides there

tg_coords = xr.open_dataset('/home/jovyan/CMIP6cex/cmip6_processing/gssr_mlr_coefs_1p5_9deg_gesla2.nc') #tide gauge dataset; its 'tg' dimension and 'lon'/'lat' values are used below. NOTE(review): absolute path inside a user's home directory

query_vars = ['sfcWind','pr'] #variables to process. NOTE(review): not referenced by any other cell visible here; the catalogue search uses required_vars
required_vars = ['sfcWind','pr','psl'] #variables that included models should provide

ssps = ['ssp245','ssp585'] #scenarios to search for, each together with 'historical'

In [5]:
#query simulations & manipulate data catalogue:
col = google_cmip_col() #google cloud catalog
qc_col = intake.open_esm_datastore("https://storage.googleapis.com/leap-persistent-ro/data-library/catalogs/cmip6-test/leap-pangeo-cmip6-test.json") #temporary pangeo-leap-forge catalogue
noqc_col = intake.open_esm_datastore("https://storage.googleapis.com/leap-persistent-ro/data-library/catalogs/cmip6-test/leap-pangeo-cmip6-noqc-test.json") #second pangeo-leap-forge catalogue ('noqc' in its name)

col_df = col.df
qc_df = qc_col.df
nonqc_df = noqc_col.df

#assign priority for keeping duplicate datasets in these catalogs
#(this adds a 'prio' column, which therefore also ends up in the merged catalogue)
col_df['prio'] = 2
qc_df['prio'] = 3
nonqc_df['prio'] = 1

col.esmcat._df = pd.concat([col_df,qc_df,nonqc_df],ignore_index=True) #merge these catalogs (overwrites the dataframe of the google cloud catalog object in place)
ssp_cats = defaultdict(dict) #search result per ssp

#search catalog per ssp (need to do this for each SSP separately as availability may differ between them)
for s,ssp in enumerate(ssps): #the index s is not used
    ssp_cat = col.search( #find instances providing all required_vars for both historical & ssp experiments
    experiment_id=['historical',ssp],
    table_id='day',
    variable_id=required_vars,
    require_all_on=['source_id', 'member_id','grid_label'])
    ssp_cats[ssp] = ssp_cat
    
#merge catalogues for all ssps, and drop duplicate historical simulations
#NOTE: 'ssp' below is the variable left over from the loop above, i.e. the catalogue object of the
#last scenario is reused as the container and its dataframe is replaced by the merged one
ssp_cats_merged = ssp_cats[ssp]
ssp_cats_merged.esmcat._df = pd.concat([v.df for k,v in ssp_cats.items()],ignore_index=True).drop_duplicates(ignore_index=True)
ssp_cats_merged = drop_older_versions(ssp_cats_merged) #if google cloud and leap-pangeo catalogs provide duplicate datasets, keep the newest version, and if the versions are identical, keep the highest priority catalog
ssp_cats_merged = reduce_cat_to_max_num_realizations(ssp_cats_merged) #per model, select grid and 'ipf' combination providing most realizations (needs to be applied to both SSPs together to ensure the same variants are used under both scenarios)

In [6]:
#open datasets in dictionary
#with no groupby attributes the dictionary keys are built from all catalogue columns (see the
#message printed below this cell), so each key also contains the zstore location
ssp_cats_merged.esmcat.aggregation_control.groupby_attrs = []
ddict_all = ssp_cats_merged.to_dataset_dict(zarr_kwargs={'use_cftime':True},aggregate=True) # single stores (Perhaps we dont need some of them, but at this point we do not really care)


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.zstore.dcpp_init_year.version.prio'


In [7]:
#look up, for each unique CMIP6 model, the indices of the grid cells nearest to the tide gauges
#NOTE: despite its name, regridder_dict holds these index look-ups here, not xesmf regridders
regridder_dict = create_nearest_idx_dict(ddict_all, tg_coords)
#alternative: one xesmf nearest-neighbour regridder per model and tide gauge (needed by the xesmf regridding step)
#regridder_dict = create_regridder_dict_per_tg(ddict_all,tg_coords)

  0%|          | 0/28 [00:00<?, ?it/s]

In [8]:
## start new dask cluster
gateway = Gateway()

# close existing clusters (be careful if you have multiple clusters/servers open!)
# every cluster the gateway lists is connected to and shut down, without further checks
open_clusters = gateway.list_clusters()
print(list(open_clusters))
if len(open_clusters)>0:
    for c in open_clusters:
        cluster = gateway.connect(c.name)
        cluster.shutdown()  


options = gateway.cluster_options()
options.worker_memory = 18 # NOTE(review): presumably GiB per worker; confirm against the gateway's option definitions
# options.worker_cores = 12

# Create a cluster with those options
cluster = gateway.new_cluster(options)
client = cluster.get_client()
cluster.adapt(20, 100) # adaptive scaling between a minimum of 20 and a maximum of 100 workers
client

[]


0,1
Connection method: Cluster object,Cluster type: dask_gateway.GatewayCluster
Dashboard: /services/dask-gateway/clusters/prod.e7faf44369df4dda994b6fa6e9ba7f59/status,


Regridding step (using xesmf):

#NOTE: this variant expects regridder_dict to hold, per source_id, a dictionary of xesmf regridders
#keyed by str(<tg value array>), i.e. the output of create_regridder_dict_per_tg. It does not work
#with the index look-ups returned by create_nearest_idx_dict.
regridded_datasets = []
for key,ds in tqdm(ddict_all.items()):
    ds.attrs["ddict_key"] = key #add current key information to attributes
    #output_fn = make_filepath(output_path, ds)
    
    ds = ds.isel(dcpp_init_year=0,drop=True) #remove this coordinate

    regridders = regridder_dict[ds.attrs['source_id']] #select regridder for this source_id
      
    #apply the regridder of each tide gauge separately and stack the results along 'tg'
    datasets=[]
    for tg in range(len(tg_coords.tg)):
        single_tg = tg_coords.isel(tg=[tg])
        regridder = regridders[str(single_tg.tg.values)] #key must be built exactly as in create_regridder_dict_per_tg
        datasets.append(regridder(ds,keep_attrs=True).squeeze())
      
    ds_at_tgs = xr.concat(datasets, dim='tg')
    ds_at_tgs['tg']=tg_coords.tg.values #label the new dimension with the tide gauge values

    
    
    
    #regridded_ds = regridder(ds, keep_attrs=True) #do the regridding
    
    
    #idx = get_lonlat_idx_nearest_to_tgs(tg_coords,ds)
    
    #lon_name = list(k for k in ds.dims if 'lon' in k)[0]
    #lat_name = list(k for k in ds.dims if 'lat' in k)[0]
    
    #regridded_ds = ds.isel({lat_name:idx[lat_name],lon_name:idx[lon_name]})
    
    regridded_datasets.append(ds_at_tgs.unify_chunks().chunk({'time':40000,'tg':1000})) #append to list

Regridding step (using fancy indexing):

In [9]:
#subset every dataset to the grid cells nearest to the tide gauges, using the per-model indices
#stored in regridder_dict (as returned by create_nearest_idx_dict)
regridded_datasets = []
for key,ds in tqdm(ddict_all.items()):
    ds.attrs["ddict_key"] = key #add current key information to attributes (make_filepath reads it back)
    output_fn = make_filepath(output_path, ds) #NOTE(review): not used further in this cell; write_wrapper computes the target path itself
    
    ds = ds.isel(dcpp_init_year=0,drop=True) #remove this coordinate

    #regridders = regridder_dict[ds.attrs['source_id']] #select regridder for this source_id
    idx = regridder_dict[ds.attrs['source_id']] #nearest-cell indices for this source_id
   
    #find lon/lat dimension names (same rule as in get_lonlat_idx_nearest_to_tgs)
    lon_name = list(k for k in ds.dims if 'lon' in k)[0]
    lat_name = list(k for k in ds.dims if 'lat' in k)[0]
    
    #index both horizontal dimensions with the stored indices; the result is chunked along 'tg' below,
    #so the indexers are expected to carry the tide gauge dimension
    regridded_ds = ds.isel({lat_name:idx[lat_name],lon_name:idx[lon_name]})
    
    regridded_datasets.append(regridded_ds.unify_chunks().chunk({'time':40000,'tg':1000})) #append to list

  0%|          | 0/2373 [00:00<?, ?it/s]

Save regridded datasets to zarr in parallel, in batches, using the dask cluster:

In [None]:
# following https://stackoverflow.com/questions/66769922/concurrently-write-xarray-datasets-to-zarr-how-to-efficiently-scale-with-dask
def write_wrapper(ds, overwrite=overwrite_existing, fs=None): #wrapper around to_zarr to submit to dask distributed cluster
    '''Prepare one dataset and write it to its zarr store; meant to be run as a task on a dask worker.

    ds        = dataset to store; its target path is derived with make_filepath(output_path, ds)
    overwrite = if true, always (re)write the store; the default is the value overwrite_existing
                had when this function was defined
    fs        = filesystem object providing exists(path); only consulted when overwrite is false.
                With the default None and overwrite false, the resulting error is reported as a failure.

    Returns a tuple (target, status), where status is 'written freshly', 'already written, skipped'
    or 'Failed with: <error>'. Errors raised during preparation or writing are not re-raised.
    '''
    target = make_filepath(output_path, ds)
    with worker_client():
        try:
            #skip only when overwriting is off and the store is already there
            store_present = (not overwrite) and fs.exists(target)
            if store_present:
                return target, 'already written, skipped'

            trimmed = ds.drop_dims(['bnds','nbnd','height'],errors="ignore") #drop these dimensions if present
            converted = pr_flux_to_m(trimmed) #convert pr units to m if pr in ds
            deduplicated = drop_duplicate_timesteps_from_ds(converted)

            with_member_dim = deduplicated.expand_dims('member_id')
            with_member_dim.to_zarr(store=target, mode='w') #store
            return target, 'written freshly'
        except Exception as e:
            return target, f"Failed with: {e}"

# There is some more advanced way of doing this with the `as_completed` iterator, to achieve a 'steady' supply of submissions to the client. # (see answers in https://stackoverflow.com/questions/66769922/concurrently-write-xarray-datasets-to-zarr-how-to-efficiently-scale-with-dask), 
# but for our intents and purposes, we can just submit medium sized batches here:
# This seems to scale ok (there is still downtime between the batch submissions). For comparison, just using a big cluster and looping over the datasets achieved ~3x speed up (not bad), 
# but here we are looking at 10+x

interval = 10 # number of datasets submitted per batch; this seems to work fine, except a few warnings about a large graph... You could play with this, but higher numbers seemed to 
# make the scheduler quite unstable...
regridded_datasets_batches = [regridded_datasets[a:a+interval] for a in range(0,len(regridded_datasets), interval)] # consecutive slices of at most `interval` datasets

written_stores = [] # collects the (target, status) tuples returned by write_wrapper

for ds_batch in tqdm(regridded_datasets_batches):
    # futures = [client.submit(write_wrapper, ds) for ds in ds_batch]
    # NOTE(review): overwrite=False is passed explicitly, so existing stores are skipped regardless of overwrite_existing
    futures = client.map(write_wrapper, ds_batch, overwrite=False, fs=fs)
    # wait for the whole batch; results are gathered in the order in which the tasks finish
    for future, result in as_completed(futures, with_results=True):
        written_stores.append(result)
        future.release()
    # do we need to deal with failed futures?
    # (write_wrapper reports errors raised while preparing or writing as a 'Failed with: ...' status instead of raising)
    # explicitly delete futures to ease pressure on client (JB: I do not 100% understand how this works TBH).
    del futures

  0%|          | 0/238 [00:00<?, ?it/s]

This may cause some slowdown.
Consider scattering data ahead of time and using futures.
This may cause some slowdown.
Consider scattering data ahead of time and using futures.
Exception in callback None()
handle: <Handle cancelled>
Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/tornado/iostream.py", line 1367, in _do_ssl_handshake
    self.socket.do_handshake()
  File "/srv/conda/envs/notebook/lib/python3.10/ssl.py", line 1342, in do_handshake
    self._sslobj.do_handshake()
ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate (_ssl.c:1007)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 192,