In [1]:
import xarray as xr
import numpy as np
import xgcm
import gcsfs
from matplotlib import pyplot as plt
plt.rcParams['figure.figsize'] = (14,8)
%matplotlib inline

In [2]:
infile = '../raw_data/B.E.13.B1950C5.ne120_t12.cesm-ihesp-1950cntl.013.pop.h.nday1.0100-01-01.nc'
ds = xr.open_dataset(infile,chunks={'time':1})
ds

In [3]:
def dims_from_grid_loc(grid_loc):
    grid_loc = str(grid_loc)
    ndim = int(grid_loc[0])
    x_loc_key = int(grid_loc[1])
    y_loc_key = int(grid_loc[2])
    z_loc_key = int(grid_loc[3])
    
    x_loc = {1: 'nlon_t', 2: 'nlon_u'}[x_loc_key]
    y_loc = {1: 'nlat_t', 2: 'nlat_u'}[y_loc_key]
    z_loc = {0: 'surface', 1: 'z_t', 2: 'z_w', 3: 'z_w_bot', 4: 'z_t_150m'}[z_loc_key]
    
    if ndim == 3:
        return z_loc, y_loc, x_loc
    elif ndim == 2:
        return y_loc, x_loc

In [4]:
def label_coord_grid_locs(ds):
    grid_locs = {'ANGLE': '2220', 'ANGLET': '2110',
                 'DXT': '2110', 'DXU': '2220',
                 'DYT': '2110', 'DYU': '2220',
                 'HT': '2110', 'HU': '2220',
                 'HTE': '2210', 'HTN': '2120',
                 'HUS': '2210', 'HUW': '2120',
                 'KMT': '2110', 'KMU': '2220',
                 'REGION_MASK': '2110',
                 'TAREA': '2110', 'TLAT': '2110', 'TLONG': '2110',
                 'UAREA': '2220', 'ULAT': '2220', 'ULONG': '2220'}
    ds_new = ds.copy()
    for vname, grid_loc in grid_locs.items():
        ds_new[vname].attrs['grid_loc'] = grid_loc
    return ds_new

In [5]:
# create some actual dimension coordinates
def add_pop_dims_to_dataset(ds):
    ds_new = ds.copy()
    ds_new['nlon_u'] = xr.Variable(('nlon_u'), np.arange(len(ds.nlon)) + 1, {'axis': 'X', 'c_grid_axis_shift': 0.5})
    ds_new['nlat_u'] = xr.Variable(('nlat_u'), np.arange(len(ds.nlat)) + 1, {'axis': 'Y', 'c_grid_axis_shift': 0.5})
    ds_new['nlon_t'] = xr.Variable(('nlon_t'), np.arange(len(ds.nlon)) + 0.5, {'axis': 'X'})
    ds_new['nlat_t'] = xr.Variable(('nlat_t'), np.arange(len(ds.nlat)) + 0.5, {'axis': 'Y'})
    
    # add metadata to z grid
    ds_new['z_t'].attrs.update({'axis': 'Z'})
    ds_new['z_w'].attrs.update({'axis': 'Z', 'c_grid_axis_shift': -0.5})
    ds_new['z_w_top'].attrs.update({'axis': 'Z', 'c_grid_axis_shift': -0.5})
    ds_new['z_w_bot'].attrs.update({'axis': 'Z', 'c_grid_axis_shift': 0.5})
    
    return ds_new

In [6]:
def relabel_pop_dims(ds):
    ds_new = label_coord_grid_locs(ds)
    ds_new = add_pop_dims_to_dataset(ds_new)
    for vname in ds_new.variables:
        if 'grid_loc' in ds_new[vname].attrs:
            da = ds_new[vname]
            dims_orig = da.dims
            new_spatial_dims = dims_from_grid_loc(da.attrs['grid_loc'])
            if dims_orig[0] == 'time':
                dims = ('time',) + new_spatial_dims
            else:
                dims = new_spatial_dims
            ds_new[vname] = xr.Variable(dims, da.data, da.attrs, da.encoding, fastpath=True)
    return ds_new

In [7]:
ds_new = relabel_pop_dims(ds)
ds_new

In [8]:
grid = xgcm.Grid(ds_new, periodic=['X'])
grid

<xgcm.Grid>
Y Axis (not periodic):
  * center   nlat_t --> right
  * right    nlat_u --> center
Z Axis (not periodic):
  * center   z_t --> left
  * left     z_w_top --> center
  * right    z_w_bot --> center
X Axis (periodic):
  * center   nlon_t --> right
  * right    nlon_u --> center

In [9]:
# here we strip the coordinates out of the dataset
# this makes the calculation below work better
ds_coords = ds_new.coords.to_dataset().reset_coords().drop('time')
ds_raw = ds_new.reset_coords(drop=True)

In [10]:
ds_coords

In [11]:
assert 'DXU' in ds_coords.data_vars

AssertionError: 