In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import dask
import intake
import fsspec
import os
from collections import defaultdict
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
from xmip.utils import google_cmip_col
from xmip.preprocessing import combined_preprocessing,_drop_coords
from xmip.postprocessing import merge_variables, combine_datasets, concat_experiments,_concat_sorted_time

  from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!


In [2]:
def pr_units_to_m(ddict_in):
    """Convert precipitation flux ('kg m-2 s-1') to daily accumulated precipitation ('m').

    For every dataset the ``pr`` variable is multiplied by the number of
    seconds in a day (kg m-2 s-1 -> kg m-2) and divided by 1000, the density of
    water in kg m-3 (kg m-2 -> m). All other attributes of ``pr`` are kept and
    its ``units`` attribute is set to 'm'.

    Neither ``ddict_in`` nor the datasets stored in it are modified: a new
    dictionary holding new dataset objects is returned.

    Parameters
    ----------
    ddict_in : dict
        Maps dataset keys to xarray datasets with a ``pr`` variable whose
        ``units`` attribute equals 'kg m-2 s-1'.

    Returns
    -------
    dict
        Same keys as ``ddict_in``, values are the unit-converted datasets.

    Raises
    ------
    ValueError
        If the ``units`` attribute of ``pr`` is missing or is not
        'kg m-2 s-1' (for example because the dataset was already converted).
    """
    ddict_out = {}
    for k, v in ddict_in.items():

        units = v['pr'].attrs.get('units')
        if units != 'kg m-2 s-1':
            raise ValueError(
                f"Expected 'pr' in units of 'kg m-2 s-1' but found {units!r} for dataset: {k}"
            )

        #convert 'kg m-2 s-1' to daily accumulated 'm'
        with xr.set_options(keep_attrs=True):
            pr_m = 24*3600*v['pr']/1000 #multiply by number of seconds in a day to get to kg m-2, and divide by density (kg/m3) to get to m

        # assign/assign_attrs return new objects, so the caller's dataset keeps its original 'pr'
        ddict_out[k] = v.assign(pr=pr_m.assign_attrs(units='m'))

    return ddict_out

def preselect_years(ddict_in,start_year,end_year):
    """Restrict historical and SSP datasets to the period start_year..end_year.

    Which datasets are kept depends on where the period lies relative to 2014:

    * period entirely after 2014: only keys containing 'ssp' are kept,
      sliced from ``start_year`` to ``end_year``;
    * period ending in or before 2014: only keys containing 'historical' are
      kept, sliced from ``start_year`` to ``end_year``;
    * period spanning 2014: 'ssp' datasets are cut off after ``end_year`` and
      'historical' datasets are cut off before ``start_year``.

    Datasets whose key matches none of the relevant tags are left out of the
    result.

    Parameters
    ----------
    ddict_in : dict
        Maps dataset keys (strings) to objects with a ``sel(time=slice)`` method.
    start_year, end_year : int
        First and last year of the period; ``start_year`` must be smaller
        than ``end_year``.

    Returns
    -------
    collections.defaultdict
        The selected datasets, keyed like ``ddict_in``.

    Raises
    ------
    Exception
        If ``start_year`` is not strictly smaller than ``end_year``.
    """
    if start_year >= end_year:
        raise Exception("Start year must come before end year.")

    first, last = str(start_year), str(end_year)

    # key tag -> time window to apply; tags are tried in the order listed
    if start_year > 2014: #only using SSP
        windows = {'ssp': slice(first, last)}
    elif end_year <= 2014: #only using historical
        windows = {'historical': slice(first, last)}
    elif start_year <= 2014 and end_year > 2014: #using both
        windows = {'ssp': slice(None, last), 'historical': slice(first, None)}
    else:
        windows = {}

    selected = defaultdict(dict)
    for name, dataset in ddict_in.items():
        for tag, window in windows.items():
            if tag in name:
                selected[name] = dataset.sel(time=window)
                break

    return selected #NB: may result in no timesteps being selected at all

def drop_duplicate_timesteps(ddict_in):
    """Remove repeated time steps from every dataset in a dictionary.

    For each dataset the unique values of ``time`` are determined with
    ``np.unique``; if there are fewer unique values than time steps, the
    dataset is replaced by a selection of the first occurrence of every unique
    time value (in the sorted order returned by ``np.unique``) and a message
    is printed. Datasets without duplicates are passed through unchanged.

    ``ddict_in`` itself is not modified; a new dictionary is returned.

    Parameters
    ----------
    ddict_in : dict
        Maps dataset keys (strings) to datasets with a ``time`` coordinate and
        an ``isel(time=...)`` method.

    Returns
    -------
    dict
        Same keys as ``ddict_in``; values are the de-duplicated datasets.
    """
    ddict_out = dict(ddict_in) # shallow copy, so the caller's dictionary is left untouched
    for k, v in ddict_in.items():

        unique_time, idx = np.unique(v.time,return_index=True)

        if len(v.time) != len(unique_time):
            ddict_out[k] = v.isel(time=idx)
            print('Dropping duplicate timesteps for:' + k)

    return ddict_out

def drop_coords(ddict_in,coords_to_drop):
    """Drop the given dimensions from every dataset in a dictionary.

    Despite its name this calls ``drop_dims`` on each dataset, i.e. the listed
    names are treated as dimensions. Names that a dataset does not have are
    skipped (``errors="ignore"``).

    ``ddict_in`` itself is not modified; a new dictionary is returned.

    Parameters
    ----------
    ddict_in : dict
        Maps dataset keys to datasets providing a ``drop_dims`` method.
    coords_to_drop : str or list of str
        Dimension name(s) to drop, passed straight to ``drop_dims``.

    Returns
    -------
    dict
        Same keys as ``ddict_in``; values are the reduced datasets.
    """
    return {
        k: v.drop_dims(coords_to_drop,errors="ignore")
        for k, v in ddict_in.items()
    }

def concat_realizations_most_common_ipf(ds_list):
    '''Concatenate along 'member_id' only the realizations that share the most common 'ipf'.

    The 'ipf' part of a member id is everything from its first 'i' onwards
    (e.g. 'r3i1p1f2' -> 'i1p1f2'). The 'ipf' that occurs most often in
    `ds_list` is selected; if several are equally common the alphabetically
    first one is taken. Only datasets whose 'ipf' is exactly equal to the
    selected one are concatenated (a substring test would also pick up e.g.
    'r1i1p1f10' when 'i1p1f1' was selected).

    To avoid NaN padding by the outer join, the 'lat' and 'lon' of the first
    selected dataset are copied onto every other selected dataset that has
    'lat' and 'lon' of the same shape. This is done on shallow copies, so the
    datasets in `ds_list` are not modified.

    Parameters
    ----------
    ds_list : list
        xarray datasets, each with a 'member_id' coordinate (its first entry
        is used) and 'lat' and 'lon'.

    Returns
    -------
    The selected datasets concatenated along 'member_id'.

    Raises
    ------
    ValueError
        If `ds_list` is empty.
    '''
    from collections import Counter

    if not ds_list:
        raise ValueError("ds_list is empty: there are no realizations to concatenate.")

    def ipf_of(member_id):
        # separate 'ipf' from 'r'; fall back to the full id if it contains no 'i'
        start = member_id.find('i')
        return member_id[start:] if start >= 0 else member_id

    ipf_ids = [ipf_of(ds.member_id.data[0]) for ds in ds_list]

    # most frequent 'ipf'; ties are resolved by taking the alphabetically first 'ipf'
    counts = Counter(ipf_ids)
    most_common_ipf = min(counts, key=lambda ipf: (-counts[ipf], ipf))

    # pick only the matching datasets from the list (exact match on the 'ipf' part)
    ds_pick = [ds for ds, ipf in zip(ds_list, ipf_ids) if ipf == most_common_ipf]

    #not ideal way to ensure coordinates are the same, otherwise differences in coordinates are padded with nans if using {join='outer'}
    lat = ds_pick[0]['lat']
    lon = ds_pick[0]['lon']
    for idx, ds in enumerate(ds_pick):
        if idx == 0:
            continue # the first dataset is the reference itself
        if ds['lat'].shape == lat.shape and ds['lon'].shape == lon.shape:
            ds = ds.copy() # shallow copy: leaves the dataset held by the caller untouched
            ds['lat'] = lat
            ds['lon'] = lon
            ds_pick[idx] = ds

    return xr.concat(ds_pick, dim='member_id', join='outer', coords='minimal',compat='override') #return xr.concat(ds_pick, dim='member_id')

In [56]:
# CMIP6 models (source_id values) to query
my_models = [
    'BCC-CSM2-MR', 'CESM2', 'CESM2-WACCM', 'CMCC-ESM2', 'CMCC-CM2-SR5', 'EC-Earth3',
    'GFDL-CM4', 'GFDL-ESM4', 'HadGEM3-GC31-MM', 'MIROC6', 'MPI-ESM1-2-HR', 'MRI-ESM2-0',
    'NorESM2-MM', 'TaiESM1',
]

# single-model alternative:
#my_models = ['MPI-ESM1-2-HR']

col = google_cmip_col()
experiment_id = 'ssp585'  # NOTE: not used by the search in this cell, which lists both experiments explicitly
source_id = my_models

# keyword arguments forwarded to `to_dataset_dict`
kwargs = dict(
    zarr_kwargs=dict(consolidated=True, use_cftime=True),
    aggregate=False,
)

# daily precipitation for the historical and SSP5-8.5 experiments
cat_data = col.search(
    source_id=source_id,
    experiment_id=['historical', 'ssp585'],
    table_id='day',
    variable_id=['pr'],
    require_all_on=['source_id', 'member_id', 'grid_label'],
)
ddict = cat_data.to_dataset_dict(**kwargs)
# alternative with xmip preprocessing (gives a lot of 'renaming failed' warnings here):
#ddict = cat_data.to_dataset_dict(**kwargs,preprocess=combined_preprocessing)

  ddict = cat_data.to_dataset_dict(**kwargs)



--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.zstore.dcpp_init_year.version'


**NB: I don't seem to need any preprocessing. If I turn it on I get a lot of renaming failed warnings.**

In [57]:
# Pre-process the queried datasets; every step replaces `ddict` with the dictionary returned by the helper.
ddict = pr_units_to_m(ddict) #convert 'pr' from 'kg m-2 s-1' to daily accumulated 'm'
ddict = drop_duplicate_timesteps(ddict) #CESM2-WACCM has duplicate timesteps
ddict = preselect_years(ddict,1850,2100) #keep historical data from 1850 onwards and SSP data up to 2100
ddict = drop_coords(ddict,['bnds']) #drop the 'bnds' dimension where present

Dropping duplicate timesteps for:ScenarioMIP.NCAR.CESM2-WACCM.ssp585.r1i1p1f1.day.pr.gn.gs://cmip6/CMIP6/ScenarioMIP/NCAR/CESM2-WACCM/ssp585/r1i1p1f1/day/pr/gn/v20200702/.nan.20200702


In [58]:
# Combine the datasets that share the attributes listed in `match_attrs`, using
# concat_realizations_most_common_ipf (concatenation along 'member_id') as the combine function.
# dask's 'array.slicing.split_large_chunks' option is set to False for the duration of the call.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ddict_concat_mem = combine_datasets(
        ddict,
        concat_realizations_most_common_ipf,
        match_attrs=['source_id', 'grid_label', 'experiment_id', 'table_id']
    )
#NB: 'grid_label' is one of the matched attributes, so a model that provides several grid labels ends up with several datasets.
# A function that keeps only the grid label with the most variants is probably needed.
# Since this occurs for only a few models, it is probably OK to save these for now and remove them later.

**NB: Padding with NaNs above creates large time chunks for runs with missing timesteps. When loading this data into memory the kernel seems to crash. How to solve?**

**NB: `concat_realizations_most_common_ipf()` keeps multiple datasets for the same model with different grid labels. May be useful to filter these out too later, but it only occurs for a few models**

Combine the historical and SSP data if desired (**a custom combination may be needed when querying multiple SSPs at once**):

In [59]:
# Combine the member-concatenated datasets that share source_id, grid_label and table_id.
# 'experiment_id' is not among the matched attributes, so the historical and SSP datasets of a
# model fall into the same group and are passed together to _concat_sorted_time.
ddict_concat = combine_datasets(ddict_concat_mem,
                        _concat_sorted_time,
                       match_attrs =['source_id', 'grid_label','table_id'] ) #appends SSP to historical

Do the subsetting at grid cells nearest to the tide gauges:

In [60]:
# TODO: absolute, user-specific path; consider making the data location configurable
mlrcoefs = xr.open_dataset('/home/jovyan/CMIP6cf/gssr_coefs_1degRes_forcing.nc') #contains coordinates of and MLR coefficients at TGs

ddict_near_tgs = defaultdict(dict)

# loop-invariant quantities for the distance computation
deg2rad = np.pi/180
tg_lat = mlrcoefs.tg.lat
tg_lon = mlrcoefs.tg.lon

for key,ds in tqdm(ddict_concat.items()):

    ds = ds.isel(dcpp_init_year=0,drop=True)

    #find lon/lat coordinate names
    lon_name = [name for name in ds.dims if 'lon' in name][0]
    lat_name = [name for name in ds.dims if 'lat' in name][0]

    # Changing the longitudes to -180 -> 180 (to avoid getting NaNs at the 0-meridian) is
    # probably not necessary, the Haversine formula should work regardless. Kept for reference:
    #ds.coords[lon_name] = ((ds.coords[lon_name] + 180) % 360) - 180 #wrap around 0
    #ds = ds.reindex({ lon_name : np.sort(ds[lon_name])})

    # Haversine formula: central angle (in radians) between every grid cell and every tide gauge
    dlat = ds[lat_name]-tg_lat
    dlon = ds[lon_name]-tg_lon
    distances = 2*np.arcsin( np.sqrt(
        np.sin( deg2rad * 0.5*dlat )**2 +
        np.cos(deg2rad*tg_lat)*np.cos(deg2rad*ds[lat_name])*np.sin(deg2rad*0.5*dlon)**2) )

    # index of the grid cell with the smallest distance, for every tide gauge
    idx_nearest = distances.argmin(dim=[lon_name,lat_name])
    ds_nearest = ds[idx_nearest]

    # use the detected coordinate names rather than assuming they are called 'lon' and 'lat'
    ds_nearest = ds_nearest.rename_vars({lon_name:'gridcell_lon',lat_name:'gridcell_lat'}) #store nearest grid cell coordinates for later
    ds_nearest = ds_nearest.assign_coords(lon=mlrcoefs.lon,lat=mlrcoefs.lat) #replace with TG coordinates

    ddict_near_tgs[key] = ds_nearest
    #ds_nearest.load()
    #store at this point?

  0%|          | 0/14 [00:00<?, ?it/s]