# SST detrending
This notebook visualizes results from `SST_generation.py`.

In [None]:
import os
import sys
sys.path.append("..")
import numpy as np
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
# Exclude these third-party modules from automatic reloading.
# NOTE(review): IPython's autoreload docs give the syntax as `%aimport -foo, -bar`
# (comma-separated, '-' attached to each name); with the spaced form below the
# modules may not actually be excluded - verify.
%aimport - numpy - scipy - matplotlib.pyplot
# Apply the project's matplotlib rc settings from the parent directory.
matplotlib.rc_file('../rc_file')

In [None]:
from paths import path_samoc, path_results, file_ex_ocn_ctrl
from timeseries import IterateOutputCESM
from ab_derivation_SST import DeriveSST as DS
from bc_analysis_fields import AnalyzeField as AF
from bd_analysis_indices import AnalyzeIndex as AI
from xr_regression import xr_quadtrend

## How to detrend?

### Quadratic pointwise detrending
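Which start year should the quadratic detrending fit use? The plots below compare, at a few sample grid points, quadratic trends fitted to the full record and to the record starting at years 40, 100, and 150. The fits are offset vertically (by +0.5, +1, and +1.5) so that the curves do not overlap.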

In [None]:
# Compare quadratic trends fitted from different start years at sample grid
# points (nlon=800) of the CTRL run. Shorter-record fits are shifted upward by
# the given offset so that all curves remain distinguishable.
da = xr.open_dataarray(f'{path_samoc}/SST/SST_yrly_ctrl.nc', decode_times=False)
for lat in [400, 1700, 1800, 1900]:
    point = da.sel({'nlon':800, 'nlat':lat})  # select once, reuse for every fit
    plt.figure()
    point.plot(label='SST')
    # (start index along the first/time axis, vertical plotting offset)
    for start, offset in [(150, 1.5), (100, 1), (40, .5)]:
        xr_quadtrend(point[start:] + offset).plot(label=f'fit from year {start} (+{offset})')
    xr_quadtrend(point).plot(label='fit over full record')
    plt.legend()


In [None]:
# Compare quadratic trends fitted from different start years at sample grid
# points (nlon=0) of the LPD run. Shorter-record fits are shifted upward by
# the given offset so that all curves remain distinguishable.
da = xr.open_dataarray(f'{path_samoc}/SST/SST_yrly_lpd.nc', decode_times=False)
for lat in [50, 250, 300, 350]:
    point = da.sel({'nlon':0, 'nlat':lat})  # select once, reuse for every fit
    plt.figure()
    point.plot(label='SST')
    # (start index along the first/time axis, vertical plotting offset)
    for start, offset in [(150, 1.5), (100, 1), (40, .5)]:
        xr_quadtrend(point[start:] + offset).plot(label=f'fit from year {start} (+{offset})')
    xr_quadtrend(point).plot(label='fit over full record')
    plt.legend()


### Scaled GMSST detrending

In [None]:
# Generate the yearly global-mean SST files for CTRL and LPD unless they
# already exist (a previous run took about 1:32 min for both).
# An explicit existence check replaces `assert` + bare `except:`: asserts are
# stripped under `python -O`, and a bare except would hide unrelated errors.
for run in ['ctrl', 'lpd']:
    fn = f'{path_samoc}/SST/GMSST_yrly_{run}.nc'
    if os.path.exists(fn):
        print(f'file exists: {fn}')
    else:
        DS().generate_yrly_global_mean_SST(run=run)

In [None]:
# Yearly global-mean SST of CTRL and LPD with a quadratic fit overlaid.
# Each run is shifted down by i/2 so the curves do not overlap; for CTRL the
# first 40 years are excluded from the fit. Time is stored in days (/365 -> years).
for i, run in enumerate(['ctrl', 'lpd']):
    da = xr.open_dataarray(f'{path_samoc}/SST/GMSST_yrly_{run}.nc')
    plt.plot(da.time/365, da - i/2, lw=.5)
    x = da[40:] if run == 'ctrl' else da
    pf = np.polynomial.polynomial.polyfit(x.time, x, 2)
    # polyval evaluates pf[0] + pf[1]*t + pf[2]*t**2 (increasing-degree coefficients)
    plt.plot(x.time/365, np.polynomial.polynomial.polyval(x.time, pf) - i/2)
# raw string: '\c' is not a valid escape sequence in a normal string literal
plt.ylabel(r'global mean SST [$^\circ$C]')
plt.xlabel('time [years]')

## GMST detrending

### GMST time series

In [None]:
# Open the yearly GMST datasets of all four runs; keep them both as individual
# names and, in run order, in the list `gmsts`.
gmsts = [xr.open_dataset(f'{path_samoc}/GMST/GMST_yrly_{run}.nc')
         for run in ('ctrl', 'rcp', 'lpd', 'lpi')]
gmst_ctrl, gmst_rcp, gmst_lpd, gmst_lpi = gmsts

In [None]:
# Plot the forced GMST response (`beta.forcing`) of each run, one figure per run.
# For runs with a GMST time series in `gmsts`, the demeaned GMST is plotted too.
# NOTE(review): this cell needs `betas` (loaded in the "scaling factors" cell
# further down) and `runs` (not defined in this notebook); run those first.
for i, beta in enumerate(betas):
    plt.figure(figsize=(8,5))
    plt.tick_params(labelsize=14)
    plt.xlabel('time [years]', fontsize=16)
    plt.ylabel('forced GMST response', fontsize=16)
    if i < len(gmsts):
        gmst = gmsts[i]
        plt.plot(gmst.time/365, gmst.GMST - gmst.GMST.mean())
        time = beta.time/365
    else:
        # entries without a GMST file (e.g. beta_had) get a calendar-year axis;
        # 1861 is presumably the first year of that record - confirm
        time = beta.time/365 + 1861
    plt.plot(time, beta.forcing, c='C1')
    plt.tight_layout()
    plt.savefig(f'{path_results}/GMST/GMST_forced_signal_{runs[i]}')

### Standard deviation of the detrended SST field

In [None]:
# Open the detrended yearly SST fields of the four model runs and collect them,
# followed by the `had` field, in `SSTs_ac`.
# NOTE(review): `SST_GMST_dt_yrly_had` is not defined anywhere in this notebook.
SST_dt_yrly_ctrl, SST_dt_yrly_rcp, SST_dt_yrly_lpd, SST_dt_yrly_lpi = (
    xr.open_dataarray(f'{path_samoc}/SST/SST_yrly_detr_{run}.nc', decode_times=False)
    for run in ('ctrl', 'rcp', 'lpd', 'lpi')
)
SSTs_ac = [SST_dt_yrly_ctrl, SST_dt_yrly_rcp, SST_dt_yrly_lpd, SST_dt_yrly_lpi,
           SST_GMST_dt_yrly_had]

In [None]:
%%time
for i, SST_ac in enumerate(SSTs_ac):
    #     if i!=2: continue
    run = runs[i]
    fn = f'{path_samoc}/SST/SST_std_{run}.nc'
    fa = FieldAnalysis(SST_ac[-100:])
#     xa = fa.make_standard_deviation_map(fn=fn)
    xa = xr.open_dataarray(fn)
    
    fn = f'{path_results}/SST/SST_std_map_{run}'
    domain = map_domains[i]
    label = 'standard deviation of SST [K]'
    cmap = 'viridis'
    txt1 = f'{run.upper()}\ndetr.'
    txt2 = 'last 100\n years'
    make_map(xa=xa, domain=domain, proj='rob', cmap=cmap, minv=0, maxv=1,
             label=label, filename=fn, text1=txt1, text2=txt2)

### Scaling factors

In [None]:
# Open the SST-on-GMST scaling datasets of all five runs; keep them both as
# individual names and, in run order, in the list `betas`.
betas = [
    xr.open_dataset(f'{path_samoc}/SST/SST_beta_GMST_yrly_{run}.nc', decode_times=False)
    for run in ('ctrl', 'rcp', 'lpd', 'lpi', 'had')
]
beta_ctrl, beta_rcp, beta_lpd, beta_lpi, beta_had = betas

In [None]:
# Map the masked `slope` variable of each beta dataset (labelled "scaling SST(GMST)").
# NOTE(review): cmocean, masks, map_domains, runs and make_map are not imported or
# defined in this notebook - confirm where they come from.
for i, beta in enumerate(betas):
    out_file = f'{path_results}/SST/SST_GMST_beta_{runs[i]}'
    dom = map_domains[i]
    make_map(xa=beta.slope.where(masks[i]) if False else None, domain=dom) if False else None
    colormap = cmocean.cm.curl
    make_map(xa=beta.slope.where(masks[i]), domain=dom, proj='rob', cmap=colormap,
             minv=-2, maxv=2, label='scaling SST(GMST)', filename=out_file,
             text1=None, text2=None, rects=None, sig=None, clon=200)

### SST autocorrelation

In [None]:
%%time
for i, SST_ac in enumerate(SSTs_ac):
    #     if i!=2: continue
    run = runs[i]
    fn = f'{path_samoc}/SST/SST_autocorrelation_{run}.nc'
    fa = FieldAnalysis(SST_ac[-100:])
    xa = fa.make_autocorrelation_map(fn=fn)
    
    fn = f'{path_results}/SST/SST_autocorrelation_map_{run}'
    domain = map_domains[i]
    label = 'autocorrelation of SST'
    cmap = cmocean.cm.curl
    txt1 = f'{run.upper()}\ndetr.'
    txt2 = '100 years'
    make_map(xa=xa, domain=domain, proj='rob', cmap=cmap, minv=-1, maxv=1,
             label=label, filename=fn, text1=txt1, text2=txt2)