In [None]:
import os
import sys
sys.path.append("..")
import numpy as np
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Pass bbox_inches=None to the inline backend's figure printing
# (presumably to keep the figure size as specified instead of a tight crop -- TODO confirm).
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
# Load the autoreload extension and switch it to mode 2.
%load_ext autoreload
%autoreload 2
# NOTE(review): presumably meant to exclude numpy, scipy and matplotlib.pyplot
# from autoreloading -- confirm the `- module` syntax against the IPython docs.
%aimport - numpy - scipy - matplotlib.pyplot
# Load matplotlib settings from the file named 'rc_file' (relative path).
matplotlib.rc_file('rc_file')

In [None]:
from paths import path_samoc, path_results, file_ex_ocn_ctrl
from timeseries import IterateOutputCESM
from ab_derivation_SST import DeriveSST as DS
from bc_analysis_fields import AnalyzeField as AF
from bd_analysis_indices import AnalyzeIndex as AI
from xr_regression import xr_quadtrend

## 1. generating full SST fields
### annual: from yearly TEMP_PD averaged data

In [None]:
%%time
# 10 mins for ctrl
# Generate the yearly SST file of each run unless it is already on disk.
for run in ['ctrl', 'lpd', 'had']:
    fn = f'{path_samoc}/SST/SST_yrly_{run}.nc'
    # Explicit existence check: `assert` is stripped under `python -O`, and a
    # bare `except:` would also swallow KeyboardInterrupt and genuine errors.
    if os.path.exists(fn):
        print(f'file exists: {fn}')
    else:
        DS().generate_yrly_SST_files(run)

In [None]:
# Quick-look map of the yearly SST file of each run, one figure per run.
# The field at index 0 of the leading dimension (presumably time) is plotted
# on a fixed colour range of -1.8 to 30.
for run in ('ctrl', 'lpd'):
    plt.figure()
    sst_file = f'{path_samoc}/SST/SST_yrly_{run}.nc'
    sst_yrly = xr.open_dataarray(sst_file, decode_times=False)
    first_field = sst_yrly[0, :, :]
    first_field.plot(vmin=-1.8, vmax=30)

### monthly: from monthly averaged model output

In [None]:
# %%time
# DS.generate_monthly_SST_files('ctrl')

In [None]:
# %%time
# # 22 sec for ctrl, 11 sec for lpd
# DS.generate_monthly_mock_linear_GMST_files('lpd')

## 2. detrending/deseasonalizing

In [None]:
# %%time
# 1:32 min for both
# Generate the yearly global-mean SST file of each run unless it already exists.
for run in ['ctrl', 'lpd']:
    fn = f'{path_samoc}/SST/GMSST_yrly_{run}.nc'
    # Explicit existence check: `assert` is stripped under `python -O`, and a
    # bare `except:` would also swallow KeyboardInterrupt and genuine errors.
    if os.path.exists(fn):
        print(f'file exists: {fn}')
    else:
        DS().generate_yrly_global_mean_SST(run=run)

In [None]:
# Plot the yearly global-mean SST of each run together with a quadratic
# least-squares fit (both runs are drawn into the same axes).
for run in ['ctrl', 'lpd']:
    da = xr.open_dataarray(f'{path_samoc}/SST/GMSST_yrly_{run}.nc')
    # time is divided by 365 for the x axis (presumably days -> years, cf. the x label)
    plt.plot(da.time/365, da, lw=.5)
    # the ctrl fit skips the first 40 entries; lpd is fitted over its full length
    if run=='ctrl':  x = da[40:]
    else:            x = da
    # polyfit returns the coefficients ordered from lowest to highest degree
    pf = np.polynomial.polynomial.polyfit(x.time, x, 2)
    plt.plot(x.time/365, pf[2]*x.time**2 + pf[1]*x.time + pf[0])
#     plt.axvline(40, c='grey', lw=.5)
    # raw string: '\c' is an invalid escape sequence in a normal string literal
    # (the resulting label text is unchanged)
    plt.ylabel(r'global mean SST [$^\circ$C]')
    plt.xlabel('time [years]')

In [None]:
%%time
# 6:16 for both ctrl and lpd
# ctrl: 2:40 mins for 149 years
# lpd: < 1 min for 149 years
# had: 4:38 mins for both single and two factor detrending
# Remove the forced (GMST) signal from the yearly SST fields unless the
# detrended output files already exist.
for run in ['ctrl', 'lpd']:
    print(run)
    fn = f'{path_samoc}/SST/SST_GMST_sqdt_yrly_{run}.nc'
    # Explicit existence check: `assert` is stripped under `python -O`, and a
    # bare `except:` would also swallow KeyboardInterrupt and genuine errors.
    if os.path.exists(fn):
        print(f'  file exists: {fn}')
    else:
        DS().SST_remove_forced_signal(run=run, tres='yrly', detrend_signal='GMST', time_slice='full')

print('had')
# 'had' needs both the 'sfdt' and the 'tfdt' file; if either is missing the
# single call below is made (presumably it produces both -- confirm in DeriveSST).
had_files = [f'{path_samoc}/SST/SST_GMST_{dt}_yrly_had.nc' for dt in ['sfdt', 'tfdt']]
if all(os.path.exists(fn) for fn in had_files):
    for fn in had_files:
        print(f'  file exists: {fn}')
else:
    DS().SST_remove_forced_signal(run='had', tres='yrly', detrend_signal='GMST', time_slice='full')

In [None]:
# for i, run in enumerate(['ctrl', 'lpd', 'had']):
#     dt = ['sqdt', 'sqdt', 'tfdt'][i]
#     fn = f'{path_samoc}/SST/SST_GMST_{dt}_yrly_{run}.nc'
#     assert os.path.exists(fn)
#     if run=='ctrl':
#         AREA = xr_AREA(domain='ocn')
#         REGION_MASK = xr.open_dataset(file_ex_ocn_ctrl, decode_times=False).REGION_MASK
#     elif run=='lpd':
#         AREA = xr_AREA(domain='ocn_low')
#         REGION_MASK = xr.open_dataset(file_ex_ocn_lpd, decode_times=False).REGION_MASK
#     elif run=='had':
#         AREA = xr_AREA(domain='ocn_low')
#         REGION_MASK = xr.open_dataset(file_ex_ocn_lpd, decode_times=False).REGION_MASK
#     da = xr.open_dataarray(fn, decode_times=False)
#     da.mean()

### loop over 149-year-long segments starting 10 years apart

In [None]:
# First year of every 149-year-long segment; consecutive segments start
# 10 years apart (the `stop` value of each range is exclusive).
ctrl_starts = np.arange(start=1, stop=152, step=10)    # 1, 11, ..., 151
lpd_starts = np.arange(start=154, stop=415, step=10)   # 154, 164, ..., 414

## 3. derive raw SST indices

### full time series

In [None]:
%%time
# ca. 1:10 min for single ctrl run, 8:45 for all
# 11 seconds for lpd and had combined
# Derive the raw SST indices of a run unless all five raw index files exist.
for run in ['had',  'lpd', 'ctrl']:
    if run=='had':
        dt = 'tfdt'  # two-factor detrending, or 'sfdt' single-factor detrending
    elif run in ['ctrl', 'lpd', 'rcp']:
        dt = 'sqdt'  # scaled quadratic detrending

    raw_files = [f'{path_samoc}/SST/{idx}_GMST_{dt}_raw_{run}.nc'
                 for idx in ['AMO', 'SOM', 'TPI1', 'TPI2', 'TPI3']]
    # Explicit existence check: `assert` is stripped under `python -O`, and a
    # bare `except:` would also swallow KeyboardInterrupt and genuine errors.
    if all(os.path.exists(fn) for fn in raw_files):
        print(f'raw index files for {run} exist')
    else:
        AI().derive_all_SST_avg_indices(run, 'full')

## 4. filter SST indices

In [None]:
%%time
# 1 sec
# Derive the final SST indices from the full time series of every run.
full_series_runs = ['ctrl', 'lpd', 'had']
for run in full_series_runs:
    index_analysis = AI()
    index_analysis.derive_final_SST_indices(run=run, tslice='full')

In [None]:
%%time
# 7 sec
# Derive the final SST indices for every 149-year segment of ctrl and lpd
# unless all three index files of that segment already exist.
for run, starts in zip(['ctrl', 'lpd'], [ctrl_starts, lpd_starts]):
    for t in starts:
        tslice = (t, t+148)  # (first year, last year) of the segment
        index_files = [f'{path_samoc}/SST/{idx}_{run}_{tslice[0]}_{tslice[1]}.nc'
                       for idx in ['AMO', 'SOM', 'TPI']]
        # Explicit existence check: `assert` is stripped under `python -O`, and a
        # bare `except:` would also swallow KeyboardInterrupt and genuine errors.
        if all(os.path.exists(fn) for fn in index_files):
            print(f'filtered index files for {run} of segment {tslice} exist')
        else:
            AI().derive_final_SST_indices(run, tslice)

In [None]:
def _plot_index_with_edges(ax, da, color, year_offset=None):
    """Plot one index time series into `ax`.

    The interior of the series is drawn with the default line style; the first
    and the last 7 values are drawn dotted (presumably because the low-pass
    filter is unreliable near the ends -- confirm).

    ax          -- matplotlib axes to draw into
    da          -- 1D DataArray with a numeric `time` coordinate (divided by 365 for the x axis)
    color       -- matplotlib colour used for all three line pieces
    year_offset -- optional constant added to the x values
    """
    years = da.time/365
    if year_offset is not None:
        years = years + year_offset
    ax.plot(years[7:-7], da[7:-7], c=color)
    ax.plot(years[:7], da[:7], ls=':', c=color)
    ax.plot(years[-7:], da[-7:], ls=':', c=color)


# One figure per run with one panel per index; for ctrl and lpd every
# 149-year segment is overlaid, cycling through the colours C0..C9.
segment_starts = {'ctrl': ctrl_starts, 'lpd': lpd_starts}
for run in ['ctrl', 'lpd', 'had']:
    f, ax = plt.subplots(3, 1, figsize=(8,8), sharex=True)
    ax[2].set_xlabel('time [year]', fontsize=14)
    for j, idx in enumerate(['AMO', 'SOM', 'TPI']):
        if run=='had':
            fn = f'{path_samoc}/SST/{idx}_{run}.nc'
            da = xr.open_dataarray(fn)
            # Explicit colour: the original used `k`, which is only defined by the
            # segment loop of a previous run (leaked loop variable).
            _plot_index_with_edges(ax[j], da, color='C0', year_offset=1870)
        else:
            for k, t in enumerate(segment_starts[run]):
                tslice = (t, t+148)
                fn = f'{path_samoc}/SST/{idx}_{run}_{tslice[0]}_{tslice[1]}.nc'
                da = xr.open_dataarray(fn)
                _plot_index_with_edges(ax[j], da, color=f'C{k%10}')
        ax[j].set_ylabel(idx, fontsize=14)
        ax[j].axhline(0, c='k', lw=.5)
        ax[j].tick_params(labelsize=14)
    plt.tight_layout()
    plt.savefig(f'{path_results}/SST/SST_indices_segments_lowpass13_{run}')

## autocorrelation fields

In [None]:
%%time
# 5:40 mins for all
# 2:17 for all lpd and had
# Derive the yearly SST autocorrelation field of the full run unless the file exists.
for run in ['ctrl']:  # 'lpd', 'had'
    fn = f'{path_samoc}/SST/SST_autocorrelation_{run}.nc'
    # Explicit existence check: `assert` is stripped under `python -O`, and a
    # bare `except:` would also swallow KeyboardInterrupt and genuine errors.
    if os.path.exists(fn):
        print(f'file exists: {fn}')
    else:
        AI().derive_yrly_autocorrelations(run, 'full')

In [None]:
%%time
# 3 mins for 149 year segment of ctrl
# Derive the yearly SST autocorrelation field of every 149-year segment of
# ctrl and lpd unless the segment's file already exists.
for run, starts in zip(['ctrl', 'lpd'], [ctrl_starts, lpd_starts]):
    for t in starts:
        tslice = (t, t+148)  # (first year, last year) of the segment
        fn = f'{path_samoc}/SST/SST_autocorrelation_{run}_{tslice[0]}_{tslice[1]}.nc'
        # Explicit existence check: `assert` is stripped under `python -O`, and a
        # bare `except:` would also swallow KeyboardInterrupt and genuine errors.
        if os.path.exists(fn):
            print(f'file exists: {fn}')
        else:
            AI().derive_yrly_autocorrelations(run, tslice)

## regression files

In [None]:
%%time
# 10 sec for  lpd
# 4 sec for had
# Create the regression files of the full run unless all three already exist.
for run in ['lpd', 'had']:  # ctrl autocorrelation file does not exist for full
    regr_files = [f'{path_samoc}/SST/{idx}_regr_{run}.nc'
                  for idx in ['AMO', 'SOM', 'TPI']]
    # Explicit existence check: `assert` is stripped under `python -O`, and a
    # bare `except:` would also swallow KeyboardInterrupt and genuine errors.
    if all(os.path.exists(fn) for fn in regr_files):
        print(f'regression files for {run} exist')
    else:
        AI().make_yrly_regression_files(run, 'full')

In [None]:
from maps import regr_map

In [None]:
# (removed) A bare `ds` expression used to be evaluated here. `ds` is only
# assigned in the regression-map cell that follows, so on a fresh kernel this
# raised a NameError and stopped "Restart & Run All".

In [None]:
%%time
# Draw regression maps: one map per index for the full lpd/had runs, and one
# map per index and 149-year segment for ctrl/lpd.
# NOTE(review): the segment regression files read here are created by the
# cell that follows (make_yrly_regression_files); on a fresh kernel that cell
# has to run first.
# Start years are looked up by run name. The original indexed
# [ctrl_starts, lpd_starts] by list position, which handed the ctrl start years
# to 'lpd' once 'ctrl' was commented out of the list below.
segment_starts = {'ctrl': ctrl_starts, 'lpd': lpd_starts}
for run in ['lpd', 'had']:  # 'ctrl',
    if run=='ctrl':
        TLAT = xr.open_dataset(file_ex_ocn_ctrl,
                               decode_times=False).TLAT.coords['TLAT']
    for idx in ['AMO', 'SOM', 'TPI']:
        print(run, idx)
        if run in ['lpd', 'had']:  # full run
            fn = f'{path_samoc}/SST/{idx}_regr_{run}.nc'
            ds = xr.open_dataset(fn)
            fn_new = f'{path_results}/SST/{idx}_regr_map_{run}'
            # NOTE(review): run='had' is passed regardless of the actual run --
            # confirm against regr_map whether this is intended.
            regr_map(ds, index=idx, run='had', fn=fn_new)
        if run in segment_starts:  # segments
            for t in segment_starts[run]:
                tslice = (t, t+148)  # (first year, last year) of the segment
                fn = f'{path_samoc}/SST/{idx}_regr_{run}_{tslice[0]}_{tslice[1]}.nc'
                ds = xr.open_dataset(fn)
                if run=='ctrl':
                    # attach the TLAT coordinate taken from the ctrl example file
                    ds.coords['TLAT'] = TLAT
                fn_new = f'{path_results}/SST/{idx}_regr_map_{run}_{tslice[0]}_{tslice[1]}'
                # NOTE(review): run='had' here as well -- see the note above.
                regr_map(ds, index=idx, run='had', fn=fn_new)

In [None]:
%%time
# ca. 5 mins for 149 year segment of ctrl
# 3 sec for lpd, 1:30 mins for all
# Create the regression files of every 149-year segment of ctrl and lpd unless
# all three files of that segment already exist.
for run, starts in zip(['ctrl', 'lpd'], [ctrl_starts, lpd_starts]):
    for t in starts:
        tslice = (t, t+148)  # (first year, last year) of the segment
        regr_files = [f'{path_samoc}/SST/{idx}_regr_{run}_{tslice[0]}_{tslice[1]}.nc'
                      for idx in ['AMO', 'SOM', 'TPI']]
        # Explicit existence check: `assert` is stripped under `python -O`, and a
        # bare `except:` would also swallow KeyboardInterrupt and genuine errors.
        if all(os.path.exists(fn) for fn in regr_files):
            print(f'regression files for {run} of segment {tslice} exist')
        else:
            AI().make_yrly_regression_files(run, tslice)

### pattern correlation

In [None]:
%%time
# ? for 149 year segment of ctrl
# 1 mins for all lpd

# For every index, correlate the regression slope pattern of each 149-year
# ctrl/lpd segment with the 'had' slope pattern and store one value per
# segment start year.
for i, idx in enumerate(['AMO', 'SOM', 'TPI']):
    # `selection` argument handed to spatial_correlation: a lon/lat box for AMO,
    # the integers 1 and 2 for SOM and TPI (meaning defined in AnalyzeField)
    region = [{'longitude':slice(-80,0), 'latitude':slice(60,0)}, 1, 2][i]
    had   = xr.open_dataset(f'{path_samoc}/SST/{idx}_regr_had.nc').slope
    for run, starts in zip(['ctrl', 'lpd'], [ctrl_starts, lpd_starts]):
        fn_new = f'{path_samoc}/SST/{idx}_spatial_correlations_{run}.nc'

        # Explicit existence check: `assert` is stripped under `python -O`, and a
        # bare `except:` would also swallow KeyboardInterrupt and genuine errors.
        if os.path.exists(fn_new):
            print(f'file exists: {fn_new}')
            continue

        da = xr.DataArray(data=np.zeros(len(starts)),
                          coords={'time': starts},
                          dims='time')
        if run=='ctrl':
            TLAT = xr.open_dataset(file_ex_ocn_ctrl,
                                   decode_times=False).TLAT.coords['TLAT']
        for k, t in enumerate(starts):
            tslice = (t, t+148)  # (first year, last year) of the segment
            fn = f'{path_samoc}/SST/{idx}_regr_{run}_{tslice[0]}_{tslice[1]}.nc'
            segment = xr.open_dataset(fn).slope
            if run=='ctrl':
                # attach the TLAT coordinate taken from the ctrl example file
                segment.coords['TLAT'] = TLAT
            da.values[k] = AF().spatial_correlation(field_A=had, field_B=segment,
                                                    selection=region)
        # Write once, after all segments are done: the file doubles as the
        # "already computed" marker, so a partially filled file must not be left
        # behind by an interrupted run.
        da.to_netcdf(fn_new)
            


In [None]:
# Spatial correlation of each segment's regression pattern, plotted against
# the segment's starting year: left panel ctrl, right panel lpd, one line per index.
f, ax = plt.subplots(1, 2, figsize=(10,5), sharey=True)
for j, run in enumerate(['ctrl', 'lpd']):
    for idx in ['AMO', 'SOM', 'TPI']:
        fn = f'{path_samoc}/SST/{idx}_spatial_correlations_{run}.nc'
        da = xr.open_dataarray(fn)
        da.plot(label=idx, ax=ax[j])
    # Panel decoration is applied once per panel; the original redrew the run
    # label, axis label, tick settings and zero line once for every index.
    ax[j].set_xlabel('starting year of segment', fontsize=14)
    ax[j].text(da.time[0], .7, run.upper(), fontsize=14)
    ax[j].tick_params(labelsize=14)
    ax[j].axhline(0, c='k', lw=.5)

# vertical marker lines at fixed years (ctrl: 100 and 151, lpd: 268)
ax[0].axvline(100, c='grey', lw=.5)
ax[0].axvline(151, c='grey', lw=.5)
ax[1].axvline(268, c='grey', lw=.5)

ax[0].legend(fontsize=14, ncol=3, loc=8)
ax[0].set_ylabel('spatial correlation coefficient', fontsize=14)
plt.tight_layout()
plt.savefig(f'{path_results}/SST/spatial_correlation(t)_ctrl_lpd')