In [None]:
%matplotlib inline
from ncar_jobqueue import NCARCluster
from dask.distributed import Client
from distributed.utils import format_bytes
import esmlab
import xarray as xr
import numpy as np
import dask

In [None]:
# Point dask dashboard links at the proxied path (presumably a Jupyter server proxy on the hub).
dashboard_link = '/proxy/{port}/status'
dask.config.set({'distributed.dashboard.link': dashboard_link})

In [None]:
cluster = NCARCluster(memory="100GB")

# Adaptive scaling bounds: at least 1 and at most 10 dask workers.
adapt_kwargs = dict(minimum=1, maximum=10, wait_count=30)
cluster.adapt(**adapt_kwargs)

# Attach this (local) notebook process to the remote workers.
client = Client(cluster)
cluster

In [None]:
import intake
from intake_esm import config

In [None]:
# Directory holding the pre-built intake-esm collection databases.
CATALOG_DIR = '/glade/u/home/abanihi/.intake_esm/collections'

with config.set({'database-directory': CATALOG_DIR}):
    col = intake.open_esm_metadatastore(collection_name='CESM1-LE')

col.df.head()

In [None]:
# Variables retained by clean_ds; everything else is dropped.
keep_vars = ['O2', 'z_t', 'KMT', 'TAREA', 'TLONG', 'TLAT', 'time', 'time_bound', 'member_id']

# Analysis window (inclusive years), expressed as zero-padded 'YYYY' string slices.
yr0, yr1 = 2006, 2055
# The CTRL run is sampled over a same-length window shifted back by this many years.
CTRL_YEAR_OFFSET = 1448


def _year_slice(first, last):
    """Return a slice of zero-padded 'YYYY' strings from `first` to `last`."""
    return slice(f'{first:04d}', f'{last:04d}')


time_slice = _year_slice(yr0, yr1)
ctrl_time_slice = _year_slice(yr0 - CTRL_YEAR_OFFSET, yr1 - CTRL_YEAR_OFFSET)

In [None]:
def clean_ds(ds, keep=None, z_t=200e2):
    """Strip a POP dataset down to the variables used here, at a single depth.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset opened with ``decode_times=False``. It is NOT modified; all
        edits are made on a shallow copy.
    keep : iterable of str, optional
        Names of variables to retain. Defaults to the notebook-level
        ``keep_vars``.
    z_t : float, optional
        Value on the ``z_t`` axis to select (nearest level). Default 200e2
        (presumably 200 m with POP depths in cm — confirm).

    Returns
    -------
    xarray.Dataset
        Copy with global attrs cleared, explicit time units/calendar, all
        non-index coordinates demoted to data variables, only `keep`
        variables kept, and the nearest `z_t` level selected.
    """
    if keep is None:
        keep = keep_vars
    keep = set(keep)

    # Shallow copy: the attribute edits below must not leak into the caller's dataset
    # (the copied variables get their own attrs dicts; array data is still shared).
    ds = ds.copy(deep=False)
    ds.attrs = {}
    ds.time.attrs['units'] = 'days since 0000-01-01 00:00:00'
    ds.time.attrs['calendar'] = 'noleap'

    # Demote non-index coordinates to ordinary data variables.
    non_dim_coords = set(ds.coords) - set(ds.dims)
    if non_dim_coords:
        ds = ds.reset_coords(list(non_dim_coords))

    to_drop = [v for v in ds.variables if v not in keep]
    # `drop_vars` is the current API; fall back to the deprecated `drop` on older xarray.
    ds = ds.drop_vars(to_drop) if hasattr(ds, 'drop_vars') else ds.drop(to_drop)

    return ds.sel(z_t=z_t, method='nearest')


def sel_time(ds, indexer_val, time_coord_name=None, year_offset=None):
    """Select along the time axis of `ds` using `indexer_val`, then restore the raw time encoding.

    The esmlab accessor's ``compute_time_var`` is applied before ``.sel`` and
    ``uncompute_time_var`` afterwards (presumably converting numeric time to
    datetime-like values and back, so string slices such as '2006' can be used
    — confirm against esmlab).

    Parameters
    ----------
    ds : xarray.Dataset
    indexer_val : label or slice passed to ``.sel`` on the time coordinate.
    time_coord_name : str, optional
        Forwarded to ``esmlab.set_time``; the accessor's resolved name is used.
    year_offset : optional
        Forwarded to ``compute_time_var``.
    """
    accessor = ds.esmlab.set_time(time_coord_name=time_coord_name)
    tname = accessor.time_coord_name

    decoded = accessor.compute_time_var(year_offset=year_offset)
    subset = decoded.sel(**{tname: indexer_val})

    return subset.esmlab.set_time(time_coord_name=tname).uncompute_time_var()

In [None]:
%%time
dsets = col.search(experiment=['CTRL'], variable=['O2'])\
        .to_xarray(decode_times=False, chunks={'time': 240})
_, dd = dsets.popitem()
ctrl = clean_ds(dd)
ctrl = sel_time(ctrl, ctrl_time_slice)
ctrl = esmlab.resample(ctrl, freq='ann')
print(ctrl)

In [None]:
%%time
dsets = col.search(experiment=['20C', 'RCP85'], variable=['O2'])\
        .to_xarray(decode_times=False, chunks={'time': 240})
print(dsets.keys())
ds_20c = dsets['ocn.20C.pop.h']
ds_rcp85 = dsets['ocn.RCP85.pop.h']
dd = xr.concat([ds_20c, ds_rcp85], dim='time')
ds = clean_ds(dd)
ds = sel_time(ds, time_slice)
ds = esmlab.resample(ds, freq='ann')
print(ds)

In [None]:
# Total size of the CTRL dataset's variables.
ctrl_nbytes = ctrl.nbytes
format_bytes(ctrl_nbytes)

In [None]:
# Total size of the 20C+RCP85 dataset's variables.
ds_nbytes = ds.nbytes
format_bytes(ds_nbytes)

In [None]:
def linear_trend(da, dim='time'):
    """Least-squares linear slope of `da` along `dim`, computed independently at every other point.

    The x-values are the integer positions 0..n-1 used by ``calc_slope``, so
    the slope is in units of `da` per step along `dim`, not per unit of the
    coordinate.

    Parameters
    ----------
    da : xarray.DataArray
        Array containing dimension `dim`.
    dim : str
        Dimension to fit along. It is rechunked into a single chunk because
        each fit needs the whole series.

    Returns
    -------
    xarray.DataArray
        Dask-backed float64 array with `dim` removed; call ``.load()`` or
        ``.compute()`` to evaluate.
    """
    da_chunk = da.chunk({dim: -1})
    trend = xr.apply_ufunc(calc_slope, da_chunk,
                           vectorize=True,
                           input_core_dims=[[dim]],
                           output_core_dims=[[]],
                           # np.float was only an alias of the builtin float and was removed in NumPy 1.24.
                           output_dtypes=[float],
                           dask='parallelized')
    return trend
    

def calc_slope(y):
    """Slope of a degree-1 least-squares fit of `y` against its index 0..n-1.

    ufunc used by ``linear_trend`` (applied to one 1-D series at a time).

    Parameters
    ----------
    y : array_like, 1-D

    Returns
    -------
    float
        The fitted slope, or NaN when `y` has fewer than two points or
        contains any non-finite value (e.g. land points). Those inputs would
        otherwise make ``np.polyfit`` raise (``LinAlgError`` / ``TypeError``)
        or return a meaningless value.
    """
    y = np.asarray(y, dtype=float)
    if y.size < 2 or not np.all(np.isfinite(y)):
        return np.nan
    x = np.arange(y.size)
    return np.polyfit(x, y, 1)[0]

In [None]:
%%time
ctrl_trend  = linear_trend(ctrl.O2.chunk({'nlat': 40, 'nlon': 40, 'time': -1})).load()
ctrl_trend = ctrl_trend * len(ctrl.time)
ctrl_trend.attrs['units'] = f'mmol m$^{-3}$ ({len(ctrl.time)} yr)$^{-1}$'
ctrl_trend.plot()

In [None]:
%%time 
ds_trend = linear_trend(ds.O2.chunk({'nlat': 40, 'nlon': 40, 'time': -1})).load()
ds_trend = ds_trend * len(ds.time)
ds_trend.attrs['units'] = f'mmol m$^{{-3}}$ ({len(ds.time)} yr)$^{{-1}}$'
ds_trend.isel(member_id=3).plot()

In [None]:
def _drop_time(da):
    """Return `da` at its first time step if it has a 'time' dim, otherwise unchanged.

    TLAT/TLONG in `ds` carry a time dimension (later cells index them with
    ``.isel(time=0)``). Without this, the regional mask would broadcast a
    spurious time dimension into the trend.
    """
    return da.isel(time=0) if 'time' in da.dims else da


# North Pacific box: 10–65°N, 120–260°E (strict inequalities).
npac_lat = _drop_time(ds.TLAT)
npac_lon = _drop_time(ds.TLONG)
in_npac = (10 < npac_lat) & (npac_lat < 65) & (120 < npac_lon) & (npac_lon < 260)

npac_trend = ds_trend.where(in_npac)
# Area-weighted mean over the horizontal grid -> one value per ensemble member.
npac_trend = esmlab.weighted_mean(npac_trend, dim=('nlat', 'nlon'), weights=_drop_time(ds.TAREA)).load()

In [None]:
# Area-weighted North Pacific mean O2 trend (TAREA-weighted over nlat/nlon).
npac_trend

In [None]:
# Exploratory: position of the largest regional-mean trend (a positional index, not a member_id label).
npac_trend.argmax()

In [None]:
# Largest regional-mean trend across the ensemble.
npac_trend.max()

In [None]:
# member_id of the member whose regional-mean trend equals the maximum.
a = (npac_trend
     .where(npac_trend == npac_trend.max(), drop=True)
     .squeeze()
     .member_id
     .values)
a

In [None]:
# Gridded trend for member `a` (the member with the largest North Pacific mean trend).
ds_trend.sel(member_id=a)

In [None]:
# NOTE(review): this compares the gridded trend cell-by-cell to the *regional-mean* maximum;
# exact float equality at a grid point is unlikely, so this probably returns an (almost) empty
# array. Looks like leftover exploration — confirm it is still needed.
ds_trend.where(ds_trend==npac_trend.max(), drop=True)

In [None]:
# member_ids with the largest and the smallest regional-mean trend, in that order.
member_id_pick = [
    npac_trend.where(npac_trend == extreme, drop=True).member_id.values.astype('int')[0]
    for extreme in (npac_trend.max(), npac_trend.min())
]

In [None]:
# [member with max North Pacific mean trend, member with min North Pacific mean trend]
member_id_pick

In [None]:
# Pull the grid coordinates into memory before plotting; the cell displays TLONG.
for grid_name in ('TLAT', 'TLONG'):
    loaded = ds[grid_name].load()
loaded

In [None]:
import cartopy

In [None]:
def pop_add_cyclic(ds):
    """Return a new dataset on the POP grid with the longitude axis re-windowed and a cyclic column appended.

    New 2-D TLAT/TLONG arrays are built. Then every data variable that has
    both 'nlat' and 'nlon' dims is doubled along nlon, the window [xL, xR)
    (width ni) is taken, and the window's first column is appended at the
    end. Each such field therefore ends up with ni + 1 columns. For
    longitude, the appended column is +360.

    Data variables without both 'nlat' and 'nlon' are copied unchanged.
    Coordinates are copied only if they do NOT have both dims.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain 2-D TLAT and TLONG. They are assumed to be ordered
        (nlat, nlon) and held in memory (``.data`` is passed to numpy
        routines) — TODO confirm for dask-backed input.

    Returns
    -------
    xarray.Dataset
    """
    
    nj = ds.TLAT.shape[0]  # not used below
    ni = ds.TLONG.shape[1]  # number of nlon columns (assumes TLONG is (nlat, nlon))

    # Window [xL, xR) of width ni into the nlon-doubled arrays below.
    xL = int(ni/2 - 1)
    xR = int(xL + ni)

    tlon = ds.TLONG.data
    tlat = ds.TLAT.data
    
    # Subtract 360 from every longitude >= the minimum of the first column
    # (presumably to make longitude continuous across the window — confirm).
    tlon = np.where(np.greater_equal(tlon, np.min(tlon[:,0])), tlon-360., tlon)  
    lon  = np.concatenate((tlon, tlon + 360.), 1)
    lon = lon[:, xL:xR]

    # Hard-coded fix-ups for ni == 320 (presumably the POP gx1 ~1 degree grid; the
    # row ranges 367:-3 and 367: are grid-specific — TODO confirm against the grid).
    if ni == 320:
        lon[367:-3, 0] = lon[367:-3, 0] + 360.        
    lon = lon - 360.
    
    # Append the cyclic column: first column shifted by +360.
    lon = np.hstack((lon, lon[:, 0:1] + 360.))
    if ni == 320:
        lon[367:, -1] = lon[367:, -1] - 360.

    #-- trick cartopy into doing the right thing:
    #   it gets confused when the cyclic coords are identical
    lon[:, 0] = lon[:, 0] - 1e-8

    #-- periodicity
    # Latitude gets the same doubling/windowing, with its first column appended unchanged.
    lat = np.concatenate((tlat, tlat), 1)
    lat = lat[:, xL:xR]
    lat = np.hstack((lat, lat[:,0:1]))

    TLAT = xr.DataArray(lat, dims=('nlat', 'nlon'))
    TLONG = xr.DataArray(lon, dims=('nlat', 'nlon'))
    
    dso = xr.Dataset({'TLAT': TLAT, 'TLONG': TLONG})

    # copy vars
    varlist = [v for v in ds.data_vars if v not in ['TLAT', 'TLONG']]
    for v in varlist:
        v_dims = ds[v].dims
        if not ('nlat' in v_dims and 'nlon' in v_dims):
            dso[v] = ds[v]
        else:
            # determine and sort other dimensions
            # (keeps their original relative order from v_dims)
            other_dims = set(v_dims) - {'nlat', 'nlon'}
            other_dims = tuple([d for d in v_dims if d in other_dims])
            lon_dim = ds[v].dims.index('nlon')
            field = ds[v].data
            # NOTE(review): the concatenation uses the actual nlon axis, but the slicing
            # below and the output dims (other_dims + ('nlat', 'nlon')) assume nlat/nlon
            # are the last two axes of ds[v] — confirm inputs are ordered that way.
            field = np.concatenate((field, field), lon_dim)
            field = field[..., :, xL:xR]
            field = np.concatenate((field, field[..., :, 0:1]), lon_dim)       
            dso[v] = xr.DataArray(field, dims=other_dims+('nlat', 'nlon'), 
                                  attrs=ds[v].attrs)


    # copy coords
    # (coords spanning both nlat and nlon are skipped; their shapes no longer match)
    for v, da in ds.coords.items():
        if not ('nlat' in da.dims and 'nlon' in da.dims):
            dso = dso.assign_coords(**{v: da})
                
            
    return dso

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import cmocean
import cartopy.crs as ccrs

In [None]:
levels = np.arange(-100, 45, 5)  # -100 ... 40 in steps of 5
# Matplotlib 3.2 renamed DivergingNorm to TwoSlopeNorm and later removed the old name;
# use the new name when available so this cell runs on both old and new Matplotlib.
_two_slope_norm = getattr(colors, 'TwoSlopeNorm', None) or colors.DivergingNorm
norm = _two_slope_norm(vmin=levels[0], vmax=levels[-1], vcenter=0.)

# Trend plus first-time-step grid coordinates, with a cyclic longitude column appended.
dsp = pop_add_cyclic(xr.Dataset({'trend': ds_trend, 'TLAT': ds.TLAT.isel(time=0), 'TLONG': ds.TLONG.isel(time=0)}))

# Forced = ensemble-mean trend; internal = selected member minus the ensemble mean.
forced = ds_trend.mean('member_id').compute()
internal = [(ds_trend.sel(member_id=i) - forced).compute() for i in member_id_pick]

# Map extent [lon_min, lon_max, lat_min, lat_max] for the North Pacific panels.
extent = [120, 260, 10, 65]
prj = ccrs.Mercator(central_longitude=np.mean(extent[0:2]),
                    min_latitude=extent[2],
                    max_latitude=extent[3])

In [None]:
def one_plot(da, lines=True, ax=None):
    """Draw filled and line contours of `da` on a cartopy GeoAxes, plus land and coastlines.

    Uses the notebook-level ``ds`` (first time step of TLONG/TLAT) as plot
    coordinates, and the notebook-level ``levels`` and ``norm``.

    Parameters
    ----------
    da : xarray.DataArray
        Field on the same (nlat, nlon) grid as ``ds.TLONG`` (assumed — confirm
        shapes match).
    lines : bool
        If True, label the contour lines. The lines themselves are always drawn.
    ax : cartopy GeoAxes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``), i.e.
        the one most recently created with ``fig.add_subplot``.

    Returns
    -------
    The filled-contour set from ``contourf`` (for use with a colorbar).
    """
    import cartopy.feature as cfeature

    if ax is None:
        ax = plt.gca()

    lon = ds.TLONG.isel(time=0)
    lat = ds.TLAT.isel(time=0)

    # filled contours
    cf = ax.contourf(lon, lat, da,
                     levels=levels, norm=norm,
                     cmap=cmocean.cm.curl_r,
                     extend='both',
                     transform=ccrs.PlateCarree())

    # contour lines
    cs = ax.contour(lon, lat, da,
                    colors='k', levels=levels, linewidths=0.5,
                    transform=ccrs.PlateCarree())

    if lines:
        # '%g' instead of '%r': under NumPy >= 2 the repr of a level reads 'np.float64(-10.0)'.
        ax.clabel(cs, fontsize=6, inline=True, fmt='%g')

    # land
    ax.add_feature(
        cfeature.NaturalEarthFeature('physical', 'land', '110m', facecolor='lightgray'))

    ax.coastlines(linewidth=0.5)

    return cf

In [None]:
fig = plt.figure(figsize=(12, 6))

# 2x3 panel grid: columns are Total / Internal / Forced, rows are the two members
# in member_id_pick (the forced column shows the same field in both rows).
# Each entry: (subplot index, field, member label or None, column title or None).
panel_specs = [
    (1, ds_trend.sel(member_id=member_id_pick[0]), member_id_pick[0], 'Total'),
    (4, ds_trend.sel(member_id=member_id_pick[1]), member_id_pick[1], None),
    (2, internal[0], member_id_pick[0], 'Internal'),
    (5, internal[1], member_id_pick[1], None),
    (3, forced, None, 'Forced'),
    (6, forced, None, None),
]

axs = []
for subplot_idx, field, member, title in panel_specs:
    # `ax` is a notebook-level name; one_plot draws on this axes.
    ax = fig.add_subplot(2, 3, subplot_idx, projection=prj)
    ax.set_extent(extent)
    cf = one_plot(field)
    if member is not None:
        ax.text(235., 60, f'{member:03d}',
                transform=ccrs.PlateCarree())
    if title is not None:
        ax.set_title(title)
    axs.append(ax)

# Tighten spacing between panels.
plt.subplots_adjust(hspace=0.02, wspace=0.02)

# Shared horizontal colorbar from the last panel's filled contours, labelled with the trend units.
cb = plt.colorbar(cf, shrink=0.5, orientation='horizontal', ax=axs, pad=0.075)
cb.ax.set_title(dsp.trend.units);

plt.savefig('trend-decomp-O2-200m-NPac.png', dpi=300, bbox_inches='tight')