# Global mean surface temperature (GMST) of CESM runs

In [None]:
import os
import sys
sys.path.append("..")
import numpy as np
import scipy.stats as stats
import xarray as xr
import seaborn as sns
import cartopy
import cartopy.crs as ccrs
import datetime
import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
# Render figures inline and disable the inline backend's default tight bounding-box cropping
%matplotlib inline
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
# Automatically reload edited project modules (maps, GMST, paths, ...) before executing code
%load_ext autoreload
%autoreload 2

In [None]:
from maps import map_robinson, map_eq_earth
from GMST import GMST_timeseries, GMST_regression, atm_heat_content, GMST_GISTEMP
from paths import path_results, path_samoc, path_data
# from analysis import TimeSeriesAnalysis
from plotting import shifted_color_map, discrete_cmap
from constants import abs_zero, cp_air
from timeseries import IterateOutputCESM
from xr_integrate import xr_surf_mean, xr_zonal_mean
from xr_DataArrays import xr_AREA

In [None]:
# Quick look at the CTRL GMST series without its last entry.
# NOTE(review): gmst_ctrl is only created by the GMST_timeseries('ctrl') cell further down
# (and later rebound via xr.open_dataarray) - run that cell before this one.
gmst_ctrl['GMST'][:-1].plot()

In [None]:
# CTRL GMST restricted to positions 50..299 along the time dimension
test = gmst_ctrl['GMST'].isel(time=slice(50, 300))

In [None]:
# Time coordinate divided by 365 (days -> years, assuming a day-based time axis - confirm)
test['time'] / 365

In [None]:
# Plot the sliced CTRL GMST against time / 365
plt.plot(test['time'] / 365, test)

In [None]:
%%time
# Build the GMST dataset for the CTRL run with GMST_timeseries (ca. 30 sec per run).
# Uncomment the other runs as needed; each assigns to its own gmst_<run> variable.
gmst_ctrl = GMST_timeseries('ctrl')
# gmst_rcp  = GMST_timeseries('rcp')
# gmst_lpd  = GMST_timeseries('lpd')
# gmst_lpi  = GMST_timeseries('lpi')

In [None]:
# Yearly GMST datasets that also carry the fitted trends (lin_fit / quad_fit) and the
# zonal-mean temperature (T_zonal) used further below, one file per run.
gmst_wt_ctrl, gmst_wt_rcp, gmst_wt_lpd, gmst_wt_lpi = (
    xr.open_dataset(f'{path_samoc}/GMST/GMST_with_trends_yrly_{run_name}.nc')
    for run_name in ('ctrl', 'rcp', 'lpd', 'lpi')
)

In [None]:
# CTRL GMST, plus LPD GMST moved down by 0.5 so the two curves do not overlap
ctrl_gmst = gmst_wt_ctrl['GMST']
lpd_gmst = gmst_wt_lpd['GMST']
ctrl_gmst.plot()
(lpd_gmst - .5).plot()


In [None]:
from xr_regression import xr_quadtrend
from  filters import lowpass

In [None]:
# Residuals after removing a quadratic trend: all of CTRL, and the first 200 LPD entries
# (the latter shifted down by 0.5 for readability)
ctrl_gmst = gmst_wt_ctrl['GMST']
lpd_gmst = gmst_wt_lpd['GMST'][:200]
(ctrl_gmst - xr_quadtrend(ctrl_gmst)).plot()
(lpd_gmst - xr_quadtrend(lpd_gmst) - .5).plot()


In [None]:
# Quadratic-trend residuals of CTRL and of the first 200 LPD entries, passed through
# lowpass(..., 10) and plotted without any vertical offset
ctrl_gmst = gmst_wt_ctrl['GMST']
lpd_gmst = gmst_wt_lpd['GMST'][:200]
lowpass(ctrl_gmst - xr_quadtrend(ctrl_gmst), 10).plot()
lowpass(lpd_gmst - xr_quadtrend(lpd_gmst), 10).plot()


In [None]:
# gmst_wt_ctrl.GMST.to_netcdf(f'{path_samoc}/GMST/GMST_yrly_ctrl.nc')
# gmst_wt_rcp .GMST.to_netcdf(f'{path_samoc}/GMST/GMST_yrly_rcp.nc' )
# gmst_wt_lpd .GMST.to_netcdf(f'{path_samoc}/GMST/GMST_yrly_lpd.nc' )
# gmst_wt_lpi .GMST.to_netcdf(f'{path_samoc}/GMST/GMST_yrly_lpi.nc' )

In [None]:
# Yearly GMST series (DataArrays) of the four model runs
gmst_ctrl, gmst_rcp, gmst_lpd, gmst_lpi = (
    xr.open_dataarray(f'{path_samoc}/GMST/GMST_yrly_{run_name}.nc')
    for run_name in ('ctrl', 'rcp', 'lpd', 'lpi')
)
# The 'had' file is read with raw (undecoded) time values
gmst_had = xr.open_dataarray(f'{path_samoc}/GMST/GMST_yrly_had.nc', decode_times=False)

In [None]:
# Write detrended yearly GMST per run: CTRL and RCP minus their quadratic fit,
# LPD and LPI minus their linear fit (fits taken from the GMST_with_trends datasets).
for run_name, series, fit_ds, fit_var in [
    ('ctrl', gmst_ctrl, gmst_wt_ctrl, 'quad_fit'),
    ('rcp',  gmst_rcp,  gmst_wt_rcp,  'quad_fit'),
    ('lpd',  gmst_lpd,  gmst_wt_lpd,  'lin_fit'),
    ('lpi',  gmst_lpi,  gmst_wt_lpi,  'lin_fit'),
]:
    (series - fit_ds[fit_var]).to_netcdf(f'{path_samoc}/GMST/GMST_dt_yrly_{run_name}.nc')

In [None]:
# Detrended yearly GMST series.
# NOTE(review): GMST_dt_yrly_had.nc is not written by the detrending cell above - it must
# already exist on disk.
gmst_dt_ctrl, gmst_dt_rcp, gmst_dt_lpd, gmst_dt_lpi, gmst_dt_had = (
    xr.open_dataarray(f'{path_samoc}/GMST/GMST_dt_yrly_{run_name}.nc')
    for run_name in ('ctrl', 'rcp', 'lpd', 'lpi', 'had')
)

In [None]:
# CTRL and RCP GMST against time / 365; the RCP curve is shifted by -1800 years for display.
fig = plt.figure(figsize=(8, 5))
plt.tick_params(labelsize=14)
plt.plot(gmst_ctrl.time/365,        gmst_ctrl.values, lw=2, label='CTRL')
plt.plot(gmst_rcp.time/365 - 1800,  gmst_rcp.values,  lw=2, label='RCP')
plt.xlabel('time [years]', fontsize=16)
# raw string: '\c' is not a valid Python escape sequence (SyntaxWarning on Python 3.12+)
plt.ylabel(r'GMST [$^\circ$C]', fontsize=16)
plt.legend(fontsize=16)
plt.tight_layout()
plt.savefig(f'{path_results}/GMST/GMST_rcp_ctrl')

In [None]:
N_ctrl = len(gmst_ctrl.coords['time'])
N_rcp  = len(gmst_rcp.coords['time'])
print(N_ctrl, N_rcp)

# Time mean of the CTRL zonal-mean temperature. It does not depend on the loop index,
# so compute it once instead of once per plotted line.
T_zonal_ctrl_mean = gmst_wt_ctrl.T_zonal.mean(dim='time')
lat = gmst_wt_rcp.coords['lat']

fig = plt.figure(figsize=(8, 5))
plt.tick_params(labelsize=14)
# Thin lines: each RCP time step minus the CTRL mean, coloured by position in the run
for i in range(N_rcp):
    plt.plot(lat, gmst_wt_rcp.T_zonal[i, :] - T_zonal_ctrl_mean,
             color=plt.cm.rainbow(i/N_rcp), alpha=.4, lw=.5, label='')
# Thick dashed lines: means over 10-step blocks [i-5, i+5) for i = 5, 15, 25, ...
# (decades, given yearly input files)
for i in range(5, N_rcp, 10):
    plt.plot(lat, gmst_wt_rcp.T_zonal[i-5:i+5, :].mean(dim='time') - T_zonal_ctrl_mean,
             color=plt.cm.rainbow(i/N_rcp), alpha=1, lw=2.5, ls='--',
             label=f'decade {(i - 5)//10 + 1}')
plt.axhline(0, c='k', lw=.5)
plt.legend(ncol=3, frameon=False, fontsize=16)
plt.ylabel('zonally averaged temp. RCP - CTRL [K]', fontsize=16)
plt.xlabel('Latitude', fontsize=16)
plt.xticks(np.arange(-90, 91, 30))
plt.tight_layout()
plt.savefig(f'{path_results}/GMST/T_zonal_rcp-ctrl')

## Low-resolution runs

In [None]:
# Observational series from the HadCRUT data directory, read with raw numeric times
hadcrut = xr.open_dataarray(
    f'{path_data}/HadCRUT/ihad4_krig_v2_0-360E_-90-90N_n_mean1_anom_30.nc',
    decode_times=False,
)

In [None]:
# Wrap each detrended GMST series in a TimeSeriesAnalysis object (used below for rolling
# trends and spectra). The top-of-notebook `from analysis import TimeSeriesAnalysis` is
# commented out, which makes this cell fail with NameError, so import it here.
from analysis import TimeSeriesAnalysis

ctrl = TimeSeriesAnalysis(gmst_dt_ctrl)
rcp  = TimeSeriesAnalysis(gmst_dt_rcp )
lpd  = TimeSeriesAnalysis(gmst_dt_lpd )
lpi  = TimeSeriesAnalysis(gmst_dt_lpi )
had  = TimeSeriesAnalysis(gmst_dt_had )

In [None]:
# Five stacked panels on a shared time axis:
#   ax[0]    GMST minus its mean (RCP scaled by 1/3),
#   ax[1]    detrended GMST,
#   ax[2:5]  rolling trends over 5, 10 and 15 year windows.
# Each run is given a fixed offset so the runs sit side by side along the x-axis.
f, ax = plt.subplots(5, 1, figsize=(12, 12), sharex=True)
for panel in ax:
    panel.tick_params(labelsize=14)
    panel.axhline(0, c='k', lw=.5)

# x positions: model time / 365 plus a per-run display offset
time_ctrl   = gmst_ctrl.time/365 + 1850
time_rcp    = gmst_rcp.time/365  +  200
time_lpd    = gmst_lpd.time/365  + 1350
time_lpi    = gmst_lpi.time/365  - 1600
time_had    = np.arange(2350, 2519)       # raw HadCRUT anomalies (ax[0])
time_had_dt = gmst_had.time[11:] + 2350   # detrended 'had' series (ax[1:])

ax[0].plot(time_ctrl, gmst_ctrl - gmst_ctrl.mean(),       c='C0')
ax[0].plot(time_rcp,  (gmst_rcp - gmst_rcp.mean())/3,     c='C1')
ax[0].plot(time_lpd,  gmst_lpd - gmst_lpd.mean(),         c='C2')
ax[0].plot(time_lpi,  gmst_lpi - gmst_lpi.mean(),         c='C3')
ax[0].plot(time_had,  hadcrut - hadcrut.mean(),           c='C4')

# (x positions, detrended series, TimeSeriesAnalysis) per run; list order sets colours C0..C4
detrended_runs = [
    (time_ctrl,   gmst_dt_ctrl, ctrl),
    (time_rcp,    gmst_dt_rcp,  rcp),
    (time_lpd,    gmst_dt_lpd,  lpd),
    (time_lpi,    gmst_dt_lpi,  lpi),
    (time_had_dt, gmst_dt_had,  had),
]
for k, (xs, detrended, tsa) in enumerate(detrended_runs):
    ax[1].plot(xs, detrended, c=f'C{k}')
    for i, window in enumerate([5, 10, 15]):
        ax[i+2].plot(xs, tsa.rolling_trends(window), c=f'C{k}')
# raw strings: '\c' is not a valid Python escape sequence
for panel in ax[2:]:
    panel.set_ylabel(r'trend [$^\circ$C/yr]', fontsize=16)

ax[0].text(1950, .75, 'CTRL'         , fontsize=16, color='C0')
ax[0].text(2200, .75, 'RCP'          , fontsize=16, color='C1')
ax[0].text(2200, .5, r'$\times\frac{1}{3}$' , fontsize=20, color='C1')
ax[0].text(1500, .75, 'pres. day low', fontsize=16, color='C2')
ax[0].text(1280, .75, 'pre-ind. low' , fontsize=16, color='C3')
# the C4 curve in ax[0] is the HadCRUT series loaded above, not HadISST
ax[0].text(2320, .75, 'HadCRUT'      , fontsize=16, color='C4')

ax[-1].set_xlabel('time [years]', fontsize=16)
ax[0].set_ylabel(r'GMST [$^\circ$C]', fontsize=16)
ax[1].set_ylabel(r'detrended [$^\circ$C]', fontsize=16)

ax[-1].set_xticks(np.arange(1200, 2800, 200))
ax[-1].set_xlim((1230, 2550))
f.align_ylabels()
f.tight_layout()
plt.savefig(f'{path_results}/GMST/GMST_with_trends_timeseries')

In [None]:
# MPI-GE GMST dataset (its tsurf and tsurf_mean variables are used below)
mpi_ge_file = f'{path_samoc}/GMST/GMST_MPI_GE.nc'
mpi_ge = xr.open_dataset(mpi_ge_file)

In [None]:
# Display the MPI-GE dataset summary (variables such as tsurf / tsurf_mean and their dims)
mpi_ge

In [None]:
# Subtract tsurf_mean from every variable of the dataset (presumably the ensemble mean -
# confirm), then flatten member x time into one sample dimension 'z' for the KDEs.
mpi_ge_dt = mpi_ge - mpi_ge['tsurf_mean']
mpi_ge_dt_stacked = mpi_ge_dt['tsurf'].stack(z=('member', 'time'))

In [None]:
# Linear-fit slopes of the MPI-GE tsurf anomalies over rolling time windows of 5, 10 and 15
# steps, pooled over all ensemble members: A[window] is a flat list of slopes.
# Windows xarray considers incomplete presumably come back as NaN and give NaN slopes;
# these are filtered out before plotting.
A = {5: [], 10: [], 15: []}
for window in [5, 10, 15]:
    for _, arr_window in mpi_ge_dt.tsurf.rolling({'time': window}, center=True):
        # use the actual ensemble size rather than assuming 100 members
        for i in range(arr_window.sizes['member']):
            member_window = arr_window.isel(member=i)
            if len(member_window) > 1:  # a slope needs at least two points
                steps = np.arange(len(member_window))
                A[window].append(np.polyfit(steps, member_window.values, 1)[0])

In [None]:
def plot_kde(ax, x, A, label, c, ls='-', lw=1):
    """Draw the Gaussian kernel density estimate of samples `A`, evaluated on `x`, into `ax`."""
    density = stats.gaussian_kde(A)
    ax.plot(x, density(x), label=label, c=c, ls=ls, lw=lw)
    
def plot_kde_diff(ax, x, A, B, label, c, ls='-', lw=1):
    """Draw KDE(A) minus KDE(B), both Gaussian KDEs evaluated on `x`, into `ax`."""
    density_a = stats.gaussian_kde(A)(x)
    density_b = stats.gaussian_kde(B)(x)
    ax.plot(x, density_a - density_b, label=label, c=c, ls=ls, lw=lw)

# Distributions of detrended GMST (column 0) and of 5/10/15-year rolling trends (columns 1-3).
# Top row: Gaussian KDE per run; bottom row: that KDE minus the KDE of HAD (index 4).
# Thin dashed curves for CTRL, LPD and LPI use only the last 150 entries of each series.
gmst_dts = [gmst_dt_ctrl, gmst_dt_rcp, gmst_dt_lpd, gmst_dt_lpi, gmst_dt_had]
TSAs = [ctrl, rcp, lpd, lpi, had]
labels = ['CTRL', 'RCP', 'LPD', 'LPI', 'HAD']
windows = [5, 10, 15]

f, ax = plt.subplots(2, 4, figsize=(12, 8))
x1 = np.linspace(-.35, .35, 30)  # evaluation grid for detrended GMST
x2 = np.linspace(-.15, .15, 30)  # evaluation grid for 5 yr trends
x3 = np.linspace(-.05, .05, 30)  # evaluation grid for 10 yr trends
x4 = np.linspace(-.03, .03, 30)  # evaluation grid for 15 yr trends
trend_grids = [x2, x3, x4]

for col in range(4):
    for row in range(2):
        ax[row, col].tick_params(labelsize=14, labelleft=False)
        ax[row, col].axhline(0, c='k', lw=.5)
        ax[row, col].axvline(0, c='k', lw=.5)

# HAD is the reference in every difference panel: compute its trends once, not per run
had_trends = {window: TSAs[4].rolling_trends(window).dropna(dim='time') for window in windows}

for i, (gmst_dt, tsa, label) in enumerate(zip(gmst_dts, TSAs, labels)):
    color = f'C{i}'
    lw = 2 if i == 4 else 1  # draw the HAD reference thicker
    run_trends = {window: tsa.rolling_trends(window) for window in windows}

    if i in (0, 2, 3):
        recent = gmst_dt[-150:]
        plot_kde(ax[0, 0], x1, recent, None, color, ls='--', lw=.8)
        plot_kde_diff(ax[1, 0], x1, recent, gmst_dts[4], None, color, ls='--', lw=.8)
        for j, window in enumerate(windows):
            # slice first, then drop NaNs (same order as before)
            recent_trends = run_trends[window][-150:].dropna(dim='time')
            plot_kde(ax[0, j+1], trend_grids[j], recent_trends, None, color, ls='--', lw=.8)
            plot_kde_diff(ax[1, j+1], trend_grids[j], recent_trends, had_trends[window],
                          None, color, ls='--', lw=.8)

    plot_kde(ax[0, 0], x1, gmst_dt, label, color, lw=lw)
    plot_kde_diff(ax[1, 0], x1, gmst_dt, gmst_dts[4], label, color, lw=lw)
    for j, window in enumerate(windows):
        all_trends = run_trends[window].dropna(dim='time')
        plot_kde(ax[0, j+1], trend_grids[j], all_trends, label, color, lw=lw)
        plot_kde_diff(ax[1, j+1], trend_grids[j], all_trends, had_trends[window],
                      label, color, lw=lw)

# MPI-GE for comparison; NaN slopes are dropped before estimating the KDE
plot_kde(ax[0, 0], x1, mpi_ge_dt_stacked, 'MPIGE', c='C5')
plot_kde_diff(ax[1, 0], x1, mpi_ge_dt_stacked, gmst_dts[4], 'MPIGE', c='C5')
for j, window in enumerate(windows):
    trends = [slope for slope in A[window] if not np.isnan(slope)]
    plot_kde(ax[0, j+1], trend_grids[j], trends, '', c='C5')
    plot_kde_diff(ax[1, j+1], trend_grids[j], trends, had_trends[window], '', c='C5')

ax[0, 0].legend()
# raw strings: '\c' is not a valid Python escape sequence
ax[0, 0].set_ylabel('relative abundance'         , fontsize=16)
ax[1, 0].set_ylabel('difference to HAD'          , fontsize=16)
ax[1, 0].set_xlabel(r'detr. GMST [$^\circ$C]'    , fontsize=16)
ax[1, 1].set_xlabel(r'5 yr trend [$^\circ$C/yr]' , fontsize=16)
ax[1, 2].set_xlabel(r'10 yr trend [$^\circ$C/yr]', fontsize=16)
ax[1, 3].set_xlabel(r'15 yr trend [$^\circ$C/yr]', fontsize=16)
plt.tight_layout()
plt.savefig(f'{path_results}/GMST/GMST_trends_kde', dpi=600)

In [None]:
# Spectrum of each detrended series, in the order CTRL, RCP, LPD, LPI, HAD
spectrum_ctrl, spectrum_rcp, spectrum_lpd, spectrum_lpi, spectrum_had = (
    analysis_obj.spectrum() for analysis_obj in (ctrl, rcp, lpd, lpi, had)
)
spectra = [spectrum_ctrl, spectrum_rcp, spectrum_lpd, spectrum_lpi, spectrum_had]

In [None]:
# All spectra on one axis with a log-scaled y-axis: element 1 of each spectrum on x,
# element 0 on y (presumably frequency and power - confirm against TimeSeriesAnalysis)
f, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.set_yscale('log')
for spec in spectra:
    ax.plot(spec[1], spec[0])

In [None]:
# Plot all spectra and autocorrelations for each series.
# NOTE(review): I_ctrl, I_rcp, I_lpd, I_lpi and I_had are not defined in any visible cell
# (the TimeSeriesAnalysis objects are named ctrl, rcp, lpd, lpi, had), so this cell raises
# NameError as written - confirm which objects were intended.
for gmst in [I_ctrl, I_rcp, I_lpd, I_lpi, I_had]:
    gmst.plot_all_spectra()
    gmst.plot_all_autocorrelations()
#     for run in runs:
#         gmst.plot_spectrum_ar1(run)

# Linear trend maps

In [None]:
# %%time
# # ca 10:30 min
# trends_ctrl = GMST_regression('ctrl')
# trends_ctrl.to_netcdf(path=f'{path_results}/GMST/trend_ctrl.nc' , mode='w')
# trends_rcp  = GMST_regression('rcp')
# trends_rcp.to_netcdf(path=f'{path_results}/GMST/trend_rcp.nc' , mode='w')

In [None]:
# Linear trend maps saved by the (commented-out) GMST_regression cell above
trends_ctrl, trends_rcp = (
    xr.open_dataarray(f'{path_results}/GMST/trend_{run_name}.nc')
    for run_name in ('ctrl', 'rcp')
)

In [None]:
# Equal-Earth maps of the surface air temperature trend [K/century] for
# CTRL (years 100-299) and RCP (2000-2099).
times = ['100-299', '2000-2099']
# Colour limits and colormap do not depend on the run, so build them once. This also avoids
# calling shifted_color_map(name='shrunk') twice, in case it registers the colormap by name.
minv, maxv = -2, 6
cmap = shifted_color_map(mpl.cm.RdBu_r, start=.33, midpoint=0.5, stop=1., name='shrunk')
cmap = discrete_cmap(16, cmap)
for run, period, trends in zip(['ctrl', 'rcp'], times, [trends_ctrl, trends_rcp]):
    label = f'{period} air surface temperature trend [K/century]'
    filename = f'{path_results}/GMST/T_trend_map_{run}'
    map_eq_earth(xa=trends, domain='atm', cmap=cmap, minv=minv, maxv=maxv, label=label, filename=filename)