In [20]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import cartopy
import dask
import intake
import fsspec
import eofs
from eofs.xarray import Eof
from dask.distributed import Client
from sklearn.decomposition import PCA

In [21]:
### determine grid cells around TGs to index
mlrcoefs = xr.open_dataset('gssr_coefs_1degRes_forcing.nc')
num_degrees = 2  # full width (degrees) of the lat/lon window centred on each TG
half_window = num_degrees/2

# Grid over the downloaded ERA5 domain: 1-degree spacing with coordinates at
# half degrees (longitude -39.5 ... 29.5 ascending, latitude 69.5 ... 10.5 descending).
era5_lons = np.arange(-40,30,1)+1/2
era5_lats = np.arange(70,10,-1)-1/2
era5_grid = xr.Dataset(
    {
        "longitude": (["longitude"], era5_lons, {"units": "degrees_east"}),
        "latitude": (["latitude"], era5_lats, {"units": "degrees_north"}),
    }
)

# One row per TG; the two columns hold the first two grid coordinates that fall
# inside [TG position - num_degrees/2, TG position + num_degrees/2].
n_tg = len(mlrcoefs.tg)
lat_ranges = np.zeros((n_tg,2))
lon_ranges = np.zeros((n_tg,2))
for row,tg_name in enumerate(mlrcoefs.tg.values):
    tg_coefs = mlrcoefs.sel(tg=tg_name)

    lat_in_window = ((era5_grid.latitude>=(tg_coefs.lat-half_window))
                     & (era5_grid.latitude<=(tg_coefs.lat+half_window)))
    lon_in_window = ((era5_grid.longitude>=(tg_coefs.lon-half_window))
                     & (era5_grid.longitude<=(tg_coefs.lon+half_window)))

    # Keep only the first two matches (in grid order) so every row has length 2.
    lat_ranges[row,:] = era5_grid.latitude[lat_in_window][0:2]
    lon_ranges[row,:] = era5_grid.longitude[lon_in_window][0:2]

# Wrap the coordinate pairs as DataArrays indexed by tg, so they can be used
# to index wind and pressure at the gridded coordinates around the TGs.
tg_coord = mlrcoefs.tg
lons_da = xr.DataArray(lon_ranges,dims=['tg','lon_around_tg'],coords={'tg':tg_coord,'lon_around_tg':[0,1]})
lats_da = xr.DataArray(lat_ranges,dims=['tg','lat_around_tg'],coords={'tg':tg_coord,'lat_around_tg':[0,1]})

###

# Pangeo CMIP6 catalog: the raw CSV listing of zarr stores and the intake-esm
# datastore that is searched below.
df = pd.read_csv('https://storage.googleapis.com/cmip6/cmip6-zarr-consolidated-stores.csv')
col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/pangeo-cmip6.json")

# Alternative multi-model query (daily sfcWind from the historical experiment).
# Kept under its own name for reference; it is NOT the query used for the search.
query_all_models = dict(
    experiment_id='historical',
    table_id='day',
    variable_id='sfcWind',
    source_id=['BCC-CSM2-MR','CESM2','CESM2-WACCM','CMCC-ESM2','CMCC-CM2-SR5','EC-Earth3','GFDL-CM4','GFDL-ESM4','HadGEM3-GC31-MM','MIROC6','MPI-ESM1-2-HR','MRI-ESM2-0','NorESM2-MM','TaiESM1']
)

# Query actually used: one model, one realization, two variables.
query = dict(
    experiment_id='historical',     # historical experiment
    table_id='day',                 # 'day' table
    variable_id=['sfcWind','psl'],  # both variables are requested in a single search
    member_id = 'r1i1p1f1',         # arbitrarily pick one realization (i.e. just one set of initial conditions)
    source_id='MPI-ESM1-2-HR'
)

In [22]:
# Restrict the catalog to the entries matching `query` ...
col_subset = col.search(**query)

# ... and display the table of matching entries.
test = col_subset.df
test

Unnamed: 0,activity_id,institution_id,source_id,experiment_id,member_id,table_id,variable_id,grid_label,zstore,dcpp_init_year,version
0,CMIP,MPI-M,MPI-ESM1-2-HR,historical,r1i1p1f1,day,psl,gn,gs://cmip6/CMIP6/CMIP/MPI-M/MPI-ESM1-2-HR/hist...,,20190710
1,CMIP,MPI-M,MPI-ESM1-2-HR,historical,r1i1p1f1,day,sfcWind,gn,gs://cmip6/CMIP6/CMIP/MPI-M/MPI-ESM1-2-HR/hist...,,20190710


In [23]:
def drop_all_bounds(ds):
    """Drop coordinates like 'time_bounds' from datasets,
    which can lead to issues when merging.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose coordinates are inspected.

    Returns
    -------
    xarray.Dataset
        Result of ``ds.drop_vars`` for every coordinate whose name contains
        '_bounds' or '_bnds'. Only ``ds.coords`` is inspected.
    """
    bounds_names = [vname for vname in ds.coords
                    if '_bounds' in vname or '_bnds' in vname]
    # `Dataset.drop` is only a backward-compatibility alias; `drop_vars` is the
    # explicit method for removing variables by name.
    return ds.drop_vars(bounds_names)

def open_dsets(df):
    """Open datasets from cloud storage and return a merged xarray dataset.

    Every store URL in ``df.zstore`` is opened with ``xr.open_zarr`` through an
    fsspec mapper, passed through ``drop_all_bounds``, and the resulting
    datasets are merged with ``join='exact'``.

    Parameters
    ----------
    df : object with a ``zstore`` attribute
        Iterating ``df.zstore`` must yield one store URL per dataset
        (in this notebook: rows of the intake-esm catalog table).

    Returns
    -------
    xarray.Dataset or None
        The merged dataset, or ``None`` when ``xr.merge`` raises ``ValueError``.
        In that case a ``RuntimeWarning`` describing the failure is emitted, so
        the ``None`` does not go unnoticed.
    """
    import warnings

    dsets = [xr.open_zarr(fsspec.get_mapper(ds_url), consolidated=True)
             .pipe(drop_all_bounds)
             for ds_url in df.zstore]
    try:
        return xr.merge(dsets, join='exact')
    except ValueError as err:
        # Best effort: callers still receive None, but the reason is reported
        # instead of being swallowed silently.
        warnings.warn(
            f"Could not merge {len(dsets)} dataset(s) with join='exact'; "
            f"returning None ({err})",
            RuntimeWarning,
            stacklevel=2,
        )
        return None

def open_delayed(df):
    """Return a ``dask.delayed`` call of ``open_dsets(df)``.

    Nothing is opened here; the returned delayed object stands for the
    ``open_dsets`` call, so many datasets can be opened in parallel later.
    """
    lazy_open = dask.delayed(open_dsets)
    return lazy_open(df)

def regrid_to_era5(ds,era5_grid):
    """Regrid ``ds`` onto ``era5_grid`` using bilinear interpolation.

    Builds an ``xesmf.Regridder`` from ``ds`` (source grid) to ``era5_grid``
    (target grid) with the 'bilinear' method and applies it to ``ds``.

    Returns the regridded version of ``ds``.
    """
    to_era5 = xe.Regridder(ds, era5_grid, 'bilinear')
    return to_era5(ds)


In [24]:
from collections import defaultdict

# Without this guard an empty search would skip the loop below, leaving `df`
# bound to whatever it was before this cell instead of a catalog subset.
if len(col_subset.df) == 0:
    raise ValueError("Catalog search returned no datasets; check `query`.")

# dsets[source_id][experiment_id] -> delayed `open_dsets` call for that group.
dsets = defaultdict(dict)
for (source_id, experiment_id), df in col_subset.df.groupby(by=['source_id', 'experiment_id']):
    dsets[source_id][experiment_id] = open_delayed(df)

# Eagerly open the datasets of the last (source_id, experiment_id) group
# iterated above; `df` is still bound to that group's rows.
forcing = open_dsets(df)
if forcing is None:
    # `open_dsets` returns None when the datasets cannot be merged.
    raise ValueError("open_dsets could not merge the selected datasets.")
forcing

Unnamed: 0,Array,Chunk
Bytes,16.55 GiB,115.88 MiB
Shape,"(60265, 192, 384)","(412, 192, 384)"
Dask graph,147 chunks in 2 graph layers,147 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 16.55 GiB 115.88 MiB Shape (60265, 192, 384) (412, 192, 384) Dask graph 147 chunks in 2 graph layers Data type float32 numpy.ndarray",384  192  60265,

Unnamed: 0,Array,Chunk
Bytes,16.55 GiB,115.88 MiB
Shape,"(60265, 192, 384)","(412, 192, 384)"
Dask graph,147 chunks in 2 graph layers,147 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,16.55 GiB,87.75 MiB
Shape,"(60265, 192, 384)","(312, 192, 384)"
Dask graph,194 chunks in 2 graph layers,194 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 16.55 GiB 87.75 MiB Shape (60265, 192, 384) (312, 192, 384) Dask graph 194 chunks in 2 graph layers Data type float32 numpy.ndarray",384  192  60265,

Unnamed: 0,Array,Chunk
Bytes,16.55 GiB,87.75 MiB
Shape,"(60265, 192, 384)","(312, 192, 384)"
Dask graph,194 chunks in 2 graph layers,194 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [25]:
# Name of the longitude dimension: first dimension of sfcWind containing 'lon'.
lon_coord = [dim for dim in forcing['sfcWind'].dims if 'lon' in dim][0]

# Map longitudes into [-180, 180) so that they wrap around 0 ...
wrapped_lon = ((forcing.coords[lon_coord] + 180) % 360) - 180
forcing.coords[lon_coord] = wrapped_lon
# ... and reorder the data so that longitude is ascending again.
ascending_lon = np.sort(forcing[lon_coord])
forcing = forcing.reindex({lon_coord: ascending_lon})

# Regrid onto the ERA5 grid, then pick the grid cells around each TG
# (lats_da / lons_da hold two latitudes / longitudes per tg).
regridded_forcing = regrid_to_era5(forcing, era5_grid)
forcing_around_tgs = regridded_forcing.sel(latitude=lats_da, longitude=lons_da)


#normalized_forcing = normalized_forcing.stack(coord=['lon_around_tg','lat_around_tg'])