In [None]:
%matplotlib inline

import intake
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import scipy.stats

# TODO: update zarr version in icesat2 jhub
!pip install zarr -U

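# TODO: document this notebook
# Reference: Kay et al. (2015), "The Community Earth System Model (CESM) Large Ensemble
# Project: A Community Resource for Studying Climate Change in the Presence of Internal
# Climate Variability", Bull. Amer. Meteor. Soc.
# https://journals.ametsoc.org/doi/full/10.1175/BAMS-D-13-00255.1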

In [None]:
# TODO: Add this to the intake catalog?
# Observed monthly temperature anomalies (presumably the HadCRUT4 ensemble median,
# per the file name), read over OPeNDAP and pulled fully into memory.
# NOTE(review): NOAA PSD has since been renamed PSL; if this URL stops resolving,
# the equivalent psl.noaa.gov THREDDS path may be needed — verify.
obs = (
    xr.open_dataset('https://www.esrl.noaa.gov/psd/thredds/dodsC/Datasets/cru/hadcrut4/air.mon.anom.median.nc')
    ['air']
    .load()
)
obs

In [None]:
# Open the CESM-LE daily atmosphere catalog. `intake.open_catalog` is intake's public
# entry point for loading a catalog from a path or URL (the bare `intake.Catalog(uri)`
# constructor form is legacy).
cat = intake.open_catalog('https://raw.githubusercontent.com/NCAR/cesm-lens-aws/master/intake-catalogs/atmosphere/daily.yaml')

# Reference-height air temperature (TREFHT) for the 20th-century (20C) and RCP8.5
# experiments, as lazy dask-backed DataArrays.
t_20c = cat['reference_height_temperature_20C'].to_dask()['TREFHT']
t_rcp = cat['reference_height_temperature_RCP85'].to_dask()['TREFHT']

In [None]:
# 1961-1990 reference (climatology) period taken from the 20th-century run; the
# observational anomalies are presumably relative to the same base period.
t_ref = t_20c.sel({'time': slice('1961', '1990')})
t_ref

In [None]:
def global_mean(da):
    """Return the area-weighted mean of ``da`` over its 'lat' and 'lon' dimensions.

    On a regular latitude-longitude grid a cell's area is proportional to
    cos(latitude), so an unweighted mean over-represents high latitudes.

    Assumes ``da['lat']`` is in degrees and ``da`` has no missing values
    (NOTE(review): expected for model output — confirm).
    """
    weights = np.cos(np.deg2rad(da['lat']))
    # Zonal mean first (all cells in a latitude band share one weight), then a
    # weighted average over latitude.
    return (da.mean('lon') * weights).sum('lat') / weights.sum('lat')


# Annual means, then area-weighted global means. The reference value is further
# averaged over all reference years and all ensemble members.
t_ref_ts = global_mean(t_ref.resample(time='AS').mean('time')).mean(('time', 'member_id'))
t_20c_ts = global_mean(t_20c.resample(time='AS').mean('time'))
t_rcp_ts = global_mean(t_rcp.resample(time='AS').mean('time'))

In [None]:
# Start a local dask cluster for the computations below.
# TODO: Figure out why dask workers are dying when using the k8s cluster.
# Disabled alternative: dask_kubernetes.KubeCluster() with
# cluster.adapt(minimum=20, maximum=100), passed to Client instead of n_workers.
from dask.distributed import Client

client = Client(n_workers=8)
client

In [None]:
# Trigger the dask computation of the 1961-1990 reference mean (expected to be a
# 0-d value, since it is subtracted from whole frames later). DataArray.load()
# computes in place and returns the same object.
t_ref_mean = t_ref_ts.load()
t_ref_mean

In [None]:
# Annual global means as a DataFrame: rows = years (time), columns = ensemble members.
# Unstacking 'member_id' by name (instead of unstack().T) yields this orientation
# regardless of the DataArray's dimension order.
t_20c_ts_df = t_20c_ts.to_series().unstack('member_id')
t_20c_ts_df.head()

In [None]:
# Annual global means as a DataFrame: rows = years (time), columns = ensemble members.
# Unstacking 'member_id' by name (instead of unstack().T) yields this orientation
# regardless of the DataArray's dimension order.
t_rcp_ts_df = t_rcp_ts.to_series().unstack('member_id')
t_rcp_ts_df.head()

In [None]:
# TODO: weight by days in each month
# Annual means, then an area-weighted (cos(latitude)) global mean per year.
# Observational grids typically have missing boxes (presumably NaN here), so the
# weighted sum is normalised only by the weights of boxes that have data that year.
obs_annual = obs.resample(time='AS').mean('time')
obs_weights = np.cos(np.deg2rad(obs['lat']))
obs_s = (
    (obs_annual * obs_weights).sum(('lat', 'lon'))
    / (obs_weights * obs_annual.notnull()).sum(('lat', 'lon'))
).to_series()
obs_s.head()

In [None]:

# Stack the historical and RCP8.5 annual series (rows: years, columns: members) and
# express them as anomalies from the 1961-1990 reference mean.
all_ts_anom = (
    pd.concat([t_20c_ts_df, t_rcp_ts_df])
    - t_ref_mean.data
)
# Calendar year of each row, for the x-axis of the time-series plot.
years = [stamp.year for stamp in all_ts_anom.index]

In [None]:
# Figure 2
# TODO: confirm that after using area weighted average, max temp increase is 5k

fig, ax = plt.subplots()

# All ensemble members in grey (one line per column); label only the first so the
# legend shows a single entry for the whole ensemble.
member_lines = ax.plot(years, all_ts_anom, color='grey')
member_lines[0].set_label('CESM-LE members')
ax.plot(years, all_ts_anom[1], color='black', label='Member 1')
ax.plot(obs_s.index.year, obs_s, color='red', label='Observations (HadCRUT4)')

# Use the Axes API throughout instead of mixing in pyplot's implicit current axes.
ax.set_xticks([1850, 1920, 1950, 2000, 2050, 2100])
ax.set_ylim(-1, 5)
ax.set_xlim(1850, 2100)
ax.set_ylabel('Global Surface\nTemperature Anomaly (K)')
ax.legend(loc='upper left');

In [None]:
def linear_trend(da, dim='time'):
    """Least-squares linear trend of ``da`` along ``dim``, fitted independently at every point.

    Parameters
    ----------
    da : xarray.DataArray
        Input data (numpy- or dask-backed).
    dim : str, default 'time'
        Dimension along which the straight line is fitted.

    Returns
    -------
    xarray.DataArray
        Slope per element step along ``dim`` (not per calendar unit), with ``dim``
        removed. The result is dask-backed and lazy; call ``.load()`` to compute it.
    """
    # dask='parallelized' needs the core dimension contained in a single chunk.
    da_chunk = da.chunk({dim: -1})
    trend = xr.apply_ufunc(calc_slope, da_chunk,
                           vectorize=True,
                           input_core_dims=[[dim]],
                           output_core_dims=[[]],
                           # np.float was only an alias of the builtin float; it was
                           # deprecated in NumPy 1.20 and removed in 1.24.
                           output_dtypes=[float],
                           dask='parallelized')
    return trend


def calc_slope(y):
    """Return the least-squares slope of the 1-D sequence ``y`` against 0, 1, 2, ...

    ufunc core used by ``linear_trend``.
    """
    x = np.arange(len(y))
    # polyfit returns coefficients highest degree first: [slope, intercept].
    return np.polyfit(x, y, 1)[0]

In [None]:
# DJF trends for 1979-2012 need 34 complete winters: Dec 1978-Feb 1979 through
# Dec 2011-Feb 2012. Under 'QS-DEC' resampling each DJF season is one bin that starts
# in December, so selecting Dec 1978..Feb 2012 yields only complete seasons (the
# Mar/Jun/Sep bins in between are complete as well).
# The 20C experiment does not reach 2012 (in CESM-LE it ends in 2005 and RCP8.5
# continues from 2006), so the two experiments are joined along time first.
# NOTE(review): assumes both experiments share the same member_id values — confirm;
# otherwise concat's default outer join would introduce all-NaN members.
t_20c_rcp = xr.concat([t_20c, t_rcp], dim='time')
seasons = (
    t_20c_rcp.sel(time=slice('1978-12', '2012-02'))
    .resample(time='QS-DEC')
    .mean('time')
    .load()
)

In [None]:
def is_dec(date):
    """Return True when ``date`` falls in December, otherwise False.

    Used to pick the DJF bins out of 'QS-DEC' resampled data, whose season bins are
    labelled by their first month.
    """
    return date.month == 12

In [None]:
# Keep only the DJF seasons (the 'QS-DEC' bins that start in December).
winter_seasons = seasons.sel(time=[is_dec(date) for date in seasons.time.data])
# linear_trend gives the slope per step along 'time' (one step = one winter); multiply
# by the number of winters for the total change over the period. len() of a DataArray
# is the size of its *first* dimension, which need not be 'time', so count explicitly.
n_winters = winter_seasons.sizes['time']
winter_trends = linear_trend(winter_seasons.chunk({'lat': 20, 'lon': 20, 'time': -1})).load() * n_winters

In [None]:
# Expect 34 DJF seasons (winters 1978/79 through 2011/12). Count along 'time'
# explicitly: len() of a DataArray is the size of its first dimension, which need
# not be 'time'.
assert winter_seasons.sizes['time'] == 34, dict(winter_seasons.sizes)

In [None]:
import cartopy.crs as ccrs

# Contour levels for the trend maps; reused below as colorbar ticks and tick labels.
levels = [-7, -6, -5, -4, -3, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 5, 6, 7]

# Small-multiple maps of the DJF trend for the first 20 ensemble members, four panels
# per row, drawn on a Robinson projection centred on 180 degrees longitude.
first_20_members = winter_trends.isel(member_id=slice(0, 20))
fg = first_20_members.plot(
    col='member_id',
    col_wrap=4,
    transform=ccrs.PlateCarree(),
    subplot_kws=dict(projection=ccrs.Robinson(central_longitude=180)),
    add_colorbar=False,
    levels=levels,
    cmap='RdYlBu_r',
    extend='neither',
)

for ax in fg.axes.ravel():
    ax.coastlines(color='grey')

# TODO: move the subplot title to lower left corners
# TODO: Add obs panel and ensemble mean at the end

# One horizontal colorbar shared by every panel.
fg.add_colorbar(orientation='horizontal')
fg.cbar.set_label('1979-2012 DJF surface air temperature trends (K/34 years)')
fg.cbar.set_ticks(levels)
fg.cbar.set_ticklabels(levels)