# Surface Heat Flux into the ocean

In [None]:
import sys
import scipy as sp
import numpy as np
import xarray as xr
import cartopy
import cartopy.crs as ccrs
import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Notebook setup: inline figures, the project's matplotlib rc file, and autoreload.
%matplotlib inline
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
# Load matplotlib settings from the rc file one directory above this notebook.
matplotlib.rc_file('../rc_file')
%load_ext autoreload
%autoreload 2
# NOTE(review): presumably meant to exclude numpy, scipy and matplotlib.pyplot from
# autoreloading -- confirm that IPython parses this spaced "- name" form as intended.
%aimport - numpy - scipy - matplotlib.pyplot

In [None]:
sys.path.append("..")
from tqdm import tqdm
from maps import map_robinson
from paths import path_results, path_prace, file_ex_ocn_ctrl, file_ex_atm_ctrl
from filters import lowpass
from regions import regions_dict, boolean_mask
from constants import spy, latent_heat_vapor
from timeseries import IterateOutputCESM
from xr_regression import xr_linear_trends_2D, xr_quadtrend
from xr_DataArrays import xr_AREA
from bb_analysis_timeseries import AnalyzeTimeSeries as ATS

## ocean variable `SHF`: _total surface heat flux including SW_

In [None]:
# AnalyzeBudget().surface_heat_flux(run)

In [None]:
# Positional time windows shared by the TOA and SHF series:
# the ctrl run keeps entries 50..299, the lpd run keeps entries 0..249.
window_ctrl = slice(50,300)
window_lpd = slice(0,250)

# OHC data (trimmed to the same positions, here with explicit index arrays)
ctrl = xr.open_dataset(f'{path_prace}/OHC/OHC_integrals_ctrl.nc')
ctrl = ctrl.isel(time=np.arange(50,300))
lpd = xr.open_dataset(f'{path_prace}/OHC/OHC_integrals_lpd.nc' )
lpd = lpd.isel(time=np.arange(0,250))

# companion '_qd' OHC files, opened without time decoding and left untrimmed
ctrl_qd = xr.open_dataset(f'{path_prace}/OHC/OHC_integrals_ctrl_qd.nc', decode_times=False)
lpd_qd = xr.open_dataset(f'{path_prace}/OHC/OHC_integrals_lpd_qd.nc' , decode_times=False)

# top of atmosphere imbalance
TOA_ctrl = xr.open_dataarray(f'{path_prace}/TOA/TOM_ctrl.nc', decode_times=False)
TOA_ctrl = TOA_ctrl.isel(time=window_ctrl)
TOA_lpd = xr.open_dataarray(f'{path_prace}/TOA/TOM_lpd.nc' , decode_times=False)
TOA_lpd = TOA_lpd.isel(time=window_lpd)

# SHF
SHF_ctrl = xr.open_dataset(f'{path_prace}/OHC/SHF_ctrl.nc', decode_times=False)
SHF_ctrl = SHF_ctrl.isel(time=window_ctrl)
SHF_lpd = xr.open_dataset(f'{path_prace}/OHC/SHF_lpd.nc' , decode_times=False)
SHF_lpd = SHF_lpd.isel(time=window_lpd)

In [None]:
# Region mask, read as a DataArray from its own file.
from paths import file_RMASK_ocn, file_ex_ocn_ctrl 
MASK = xr.open_dataarray(file_RMASK_ocn)

In [None]:
# TAREA field taken from the example ocean output file of the ctrl run.
AREA = xr.open_dataset(file_ex_ocn_ctrl, decode_times=False).TAREA
# A single SHF file (suffix 0050) from the ctrl directory.
da = xr.open_dataarray(f'{path_prace}/ctrl/ocn_yrly_SHF_0050.nc')

In [None]:
# Quick look at the SHF field loaded above.
da.plot()

In [None]:
# Quick look at the TAREA field.
AREA.plot()

In [None]:
# TAREA-weighted SHF per grid cell.
(AREA*da).plot()

In [None]:
from constants import spy
# Grid sum of TAREA*SHF scaled by spy/1e4: first over all cells, then only where MASK > 0.
# NOTE(review): the 1e4 factor presumably converts TAREA from cm^2 to m^2 -- confirm units.
print((AREA*da).sum(dim=['nlat', 'nlon'])*spy/1e4)
print((AREA*da.where(MASK>0)).sum(dim=['nlat', 'nlon'])*spy/1e4)

In [None]:
# NOTE(review): SHF_ctrl is a Dataset (opened with xr.open_dataset) -- confirm that
# AnalyzeTimeSeries accepts a Dataset rather than a single DataArray.
ATS(SHF_ctrl).mc_ar1_spectrum(filter_type='lowpass', filter_cutoff=13)

In [None]:
# Basin-integrated SHF time series, one column per run (SHF_ctrl, SHF_lpd):
#   top row    - raw series (thin) with lowpass(SHF_, 13) drawn on top (thicker)
#   bottom row - lowpass of the series minus its xr_quadtrend fit
# The first and last 7 entries of every lowpassed series are trimmed before plotting.
# NOTE(review): the y-labels say ZJ/yr but the data are not divided by 1e21 here
# (the anomaly limits are +-2e21) -- confirm the intended units.
f, ax = plt.subplots(2, 2, figsize=(8,5), sharey='row', constrained_layout=True)
for i, SHF in enumerate([SHF_ctrl, SHF_lpd]):
    for j in range(2):
        ax[j,i].axhline(0, c='k')
        ax[j,i].tick_params(labelsize=12)
    years = SHF.time/365
    for j, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Southern']):
        # Global in black, the basins in C0..C2; the colour has to be assigned in
        # both branches, otherwise every basin silently reuses the previous colour.
        c = 'k' if j==0 else f'C{j-1}'
        SHF_ = SHF[f'{ocean}_Ocean']
        ax[0,i].plot(years, SHF_, label=ocean  , c=c, lw=.5)
        ax[0,i].plot(years[7:-7], lowpass(SHF_, 13)[7:-7], c=c, lw=1)
        ax[1,i].plot(years[7:-7], lowpass(SHF_-xr_quadtrend(SHF_), 13)[7:-7], c=c, lw=1)
    ax[1,i].set_xlabel('time [model years]', fontsize=14)
ax[0,0].legend(ncol=2, fontsize=10)
ax[0,0].set_ylabel('SHF [ZJ/yr]', fontsize=12)
ax[1,0].set_ylabel('SHF anomaly [ZJ/yr]', fontsize=12)
ax[1,0].set_ylim(-2e21,2e21)
plt.savefig(f'{path_results}/Battisti/SHF_time_series')

In [None]:
# Spectra of the quadratically detrended SHF series on log-log axes:
# one colour per basin, solid lines for SHF_ctrl and dashed lines for SHF_lpd.
f, ax = plt.subplots(1,1, constrained_layout=True)
ax.set_xscale('log')
ax.set_yscale('log')
for i, SHF in enumerate([SHF_ctrl, SHF_lpd]):
    linestyle = ['-', '--'][i]
    for j, ocean in enumerate(['Global', 'Atlantic', 'Pacific', 'Southern']):
        SHF_ = SHF[f'{ocean}_Ocean']
        detrended = SHF_-xr_quadtrend(SHF_)
        spec = ATS(detrended).spectrum()
        ax.plot(spec[1], spec[0], c=f'C{j}', lw=1, ls=linestyle, label=ocean)
    # only the first run feeds the legend, so each basin is listed once
    if i==0:
        ax.legend()

In [None]:
# Global SHF and TOA imbalance, one panel per run. For each quantity the raw series
# and its lowpass (trimmed by 7 entries at both ends) are drawn in the same colour.
f, ax = plt.subplots(1, 2, sharey=True, figsize=(10, 5), constrained_layout=True)
for i, run in enumerate(['ctrl', 'lpd']):
    panel = ax[i]
    toa_shift = [0,154][i]  # added to the TOA time axis of the respective run
    da_SHF = [SHF_ctrl, SHF_lpd][i]['Global_Ocean']
    da_OHC = [ctrl, lpd][i]['OHC_Global_Ocean']
    da_TOA = [TOA_ctrl, TOA_lpd][i]

    panel.axhline(0, c='grey', lw=.5)

    shf_years = da_SHF.time/365
    panel.plot(shf_years, da_SHF/1e21, c='C0', label='SHF')
    panel.plot(shf_years[7:-7], lowpass(da_SHF,13)[7:-7]/1e21, c='C0')

    panel.plot(da_TOA.time+toa_shift, da_TOA*spy/1e21, c='C1', label='TOA')
    panel.plot(da_TOA.time[7:-7]+toa_shift, lowpass(da_TOA, 13)[7:-7]*spy/1e21, c='C1')

    panel.set_xlabel('time [model years]')
    panel.text(.05,.9, ['HIGH', 'LOW'][i], transform=panel.transAxes)
    panel.tick_params(labeltop=False, labelright=True)
ax[0].legend(loc=3)

In [None]:
# Residuals of the TOA imbalance against the OHC change and against SHF (all scaled
# by 1/1e21), one panel per run; thick horizontal lines mark the respective means.
f, ax = plt.subplots(1, 2, sharey=True, figsize=(12, 5))
for i, run in enumerate(['ctrl', 'lpd']):
    panel = ax[i]
    da_SHF = [SHF_ctrl, SHF_lpd][i]['Global_Ocean']
    da_OHC = [ctrl, lpd][i]['OHC_Global_Ocean']
    da_TOA = [TOA_ctrl, TOA_lpd][i]

    # plain numpy arrays, so the subtractions below are purely positional
    toa = da_TOA.values*spy/1e21
    ohc_change = (da_OHC-da_OHC.shift(time=1)).values/1e21  # first entry is NaN
    toa_minus_dohc = toa - ohc_change
    toa_minus_shf = toa - da_SHF.values/1e21

    panel.axhline(0, c='grey', lw=.5)
    panel.plot(da_OHC.time/365, toa_minus_dohc, label=r'TOA-$\Delta$OHC')
    panel.axhline(np.nanmean(toa_minus_dohc), lw=2, c='C0')
    panel.plot(da_TOA.time+[0,154][i], toa_minus_shf, label='TOA-SHF')
    panel.axhline(np.mean(toa_minus_shf), lw=2, c='C1')
    panel.set_ylim((-7,7))
    panel.set_xlabel('time [model years]')
    panel.text(.05,.9, ['HIGH', 'LOW'][i], transform=panel.transAxes)
ax[0].legend(loc=3)

In [None]:
# TOA-SHF with its time mean removed, one panel per run: the lowpassed difference
# (labelled) on top of the faint raw difference. The TOA-dOHC curves are not drawn here.
f, ax = plt.subplots(1, 2, sharey=True, figsize=(12, 5))
for i, run in enumerate(['ctrl', 'lpd']):
    panel = ax[i]
    toa_shift = [0,154][i]
    da_SHF = [SHF_ctrl, SHF_lpd][i]['Global_Ocean']
    da_OHC = [ctrl, lpd][i]['OHC_Global_Ocean']
    da_TOA = [TOA_ctrl, TOA_lpd][i]

    raw_diff = da_TOA.values*spy/1e21 - da_SHF.values/1e21
    mean_diff = np.mean(raw_diff)
    smooth_diff = lowpass(da_TOA,13).values[7:-7]*spy/1e21 - lowpass(da_SHF,13)[7:-7].values/1e21

    panel.axhline(0, c='grey', lw=.5)
    panel.plot(da_TOA.time[7:-7]+toa_shift, smooth_diff - mean_diff, label='TOA-SHF')
    panel.plot(da_TOA.time+toa_shift, raw_diff - mean_diff, lw=.5, alpha=.5)
    panel.set_ylim((-7,7))
    panel.set_xlabel('time [model years]')
    panel.text(.05,.9, ['HIGH', 'LOW'][i], transform=panel.transAxes)
ax[0].legend(loc=3)

In [None]:
# Scatter of the OHC change and of the TOA imbalance against SHF (all scaled by
# 1/1e21), one panel per run, with a diagonal reference line from (-10,-10) to (15,15).
f, ax = plt.subplots(1, 2, sharey=True, figsize=(12, 5))
for i, run in enumerate(['ctrl', 'lpd']):
    panel = ax[i]
    da_SHF = [SHF_ctrl, SHF_lpd][i]['Global_Ocean']
    da_OHC = [ctrl, lpd][i]['OHC_Global_Ocean']
    da_TOA = [TOA_ctrl, TOA_lpd][i]

    panel.axhline(0, c='grey', lw=.5)
    panel.axvline(0, c='grey', lw=.5)
    panel.plot([-10,15],[-10,15])

    ohc_change = da_OHC-da_OHC.shift(time=1)
    panel.scatter(da_SHF/1e21, ohc_change/1e21, label='SHF-OHC')
    panel.scatter(da_SHF/1e21, da_TOA*spy/1e21    , label='SHF-TOA')

    panel.set_xlabel('SHF [ZJ/yr]')
    panel.text(.05,.9, ['HIGH', 'LOW'][i], transform=panel.transAxes)
ax[0].legend(loc=3)

In [None]:
# Same scatter as for the raw series, but on lowpassed data (trimmed by 7 entries at
# both ends), one panel per run, with a diagonal reference line from (-1,-1) to (10,10).
f, ax = plt.subplots(1, 2, sharey=True, figsize=(12, 5))
for i, run in enumerate(['ctrl', 'lpd']):
    da_SHF = [SHF_ctrl, SHF_lpd][i]['Global_Ocean']
    da_OHC = [ctrl, lpd][i]['OHC_Global_Ocean']
    da_TOA = [TOA_ctrl, TOA_lpd][i]
    ax[i].axhline(0, c='grey', lw=.5)
    ax[i].axvline(0, c='grey', lw=.5)
    ax[i].plot([-1,10],[-1,10])
    # The OHC difference loses its first entry (NaN from the shift), so SHF is
    # shortened by one entry as well to keep both series the same length.
    ax[i].scatter(lowpass(da_SHF[1:]/1e21, 13)[7:-7], lowpass((da_OHC-da_OHC.shift(time=1)).dropna(dim='time')/1e21,13)[7:-7], label='SHF-OHC')
    ax[i].scatter(lowpass(da_SHF/1e21    , 13)[7:-7], lowpass(da_TOA*spy/1e21, 13)[7:-7], label='SHF-TOA')
    ax[i].set_xlabel('SHF [ZJ/yr]')
    # raw string: '\D' is an invalid escape sequence in a normal string literal
    ax[i].set_ylabel(r'$\Delta$OHC / TOA [ZJ/yr]')
    ax[i].text(.05,.9, ['HIGH', 'LOW'][i], transform=ax[i].transAxes)
ax[0].legend(loc=3)

In [None]:
# Correlation matrix of the lowpassed TOA and global SHF series (both scaled by 1/1e21
# and trimmed by 7 entries at each end), printed once per run.
for i, run in enumerate(['ctrl', 'lpd']):
    da_SHF = [SHF_ctrl, SHF_lpd][i]['Global_Ocean']
    da_TOA = [TOA_ctrl, TOA_lpd][i]
    toa_smooth = lowpass(da_TOA*spy/1e21,13)[7:-7]
    shf_smooth = lowpass(da_SHF/1e21    , 13)[7:-7]
    print(np.corrcoef(toa_smooth, shf_smooth))

The SHF is too small by 2 ZJ/year; this is likely because water bodies other than the ocean also take up heat.

# atmospheric fluxes

In [None]:
# Example atmosphere output file of the ctrl run, opened without decoding times.
ds = xr.open_dataset(file_ex_atm_ctrl, decode_times=False)

In [None]:
# Quick look at the SRFRAD field.
ds.SRFRAD.plot()

In [None]:
# Quick look at the FLNS field.
ds.FLNS.plot()

In [None]:
# Solar flux variables in the atmosphere output and their descriptions
# (kept as comments so that this cell is valid Python):
#   FSDS   Downwelling solar flux at surface
#   FSDSC  Clearsky downwelling solar flux at surface
#   FSNS   Net solar flux at surface
#   FSNSC  Clearsky net solar flux at surface

In [None]:
# Quick look at FSNS, the net solar flux at the surface.
ds.FSNS.plot()

In [None]:
# Residual FSNS - FLNS - SRFRAD.
# NOTE(review): presumably checks that SRFRAD equals net solar minus FLNS -- confirm.
(ds.FSNS-ds.FLNS-ds.SRFRAD).plot()

In [None]:
# T averaged over lat and lon.
ds['T'].mean(dim=('lat','lon')).plot()

In [None]:
# Rebinds MASK: from here on it is the REGION_MASK field of the example ocean file
# instead of the array read from file_RMASK_ocn earlier in the notebook.
MASK = xr.open_dataset(file_ex_ocn_ctrl, decode_times=False).REGION_MASK

In [None]:
# Rebinds ds to a single SHF file (suffix 0200) addressed by an absolute path.
ds = xr.open_dataset('/projects/0/samoc/andre/CESM/ctrl/ocn_yrly_SHF_0200.nc', decode_times=False)

In [None]:
# Is any SHF value above 1000 where MASK > 0?
np.any(ds.SHF.where(MASK>0)>1000)

In [None]:
# Area field returned by xr_AREA for the 'ocn' domain.
TAREA = xr_AREA('ocn')

In [None]:
TAREA.plot()

In [None]:
# Area-weighted SHF, shown only where MASK > 0.
(ds.SHF*TAREA).where(MASK>0).plot()

In [None]:
# Sum of the area-weighted SHF over the grid where MASK > 0, multiplied by spy.
SHF_total = (ds.SHF*TAREA).where(MASK>0).sum(dim=('nlat', 'nlon'))*spy

In [None]:
SHF_total

In [None]:
# Example atmosphere output file again, under its own name ds_a.
ds_a = xr.open_dataset(file_ex_atm_ctrl, decode_times=False)

In [None]:
# Check whether the variable T carries a 'long_name' attribute.
'long_name' in ds_a['T'].attrs

In [None]:
# Print every variable name of ds_a, followed by its indented 'long_name'
# attribute whenever the variable has one.
for var in ds_a.variables:
    print(var)
    attrs = ds_a[var].attrs
    if 'long_name' not in attrs:
        continue
    print(f'  {attrs["long_name"]}')

In [None]:
# Example ocean output file of the ctrl run, opened without decoding times.
ds_m = xr.open_dataset(file_ex_ocn_ctrl, decode_times=False)

In [None]:
ds_m

In [None]:
# MELTH_F where MASK > 0, with the colour scale starting at -300.
ds_m.MELTH_F.where(MASK>0).plot(vmin=-300)

In [None]:
# Surface flux components, each restricted to MASK > 0.
# EVAP_F multiplied by latent_heat_vapor
latent_heat = (ds_m.EVAP_F[0]*latent_heat_vapor).where(MASK>0)
sens_heat   = ds_m.SENH_F[0].where(MASK>0)
# LWUP_F minus LWDN_F
lw_net      = (ds_m.LWUP_F[0]-ds_m.LWDN_F[0]).where(MASK>0)
# NOTE(review): unlike the three terms above, MELTH_F is not indexed with [0], so it
# keeps its leading dimension -- confirm this asymmetry is intended.
melt_heat   = ds_m.MELTH_F.where(MASK>0)

In [None]:
# Quick looks at the individual flux components.
latent_heat.plot()

In [None]:
sens_heat.plot()

In [None]:
lw_net.plot()

In [None]:
# Sum of the four components.
(latent_heat+sens_heat+lw_net+melt_heat).plot()

In [None]:
# SHF of the example file, saved as a figure.
ds_m.SHF.plot()
plt.title('SHF')
plt.savefig(f'{path_results}/SHF/ex_SHF')

In [None]:
# Residual of SHF after subtracting the four components (see the title of the next plot).
sw_net = (ds_m.SHF-(latent_heat+sens_heat+lw_net+melt_heat))

In [None]:
sw_net.plot(vmin=0)
plt.title('SW = SHF - LATENT - SENS - LWnet - MELT')
plt.savefig(f'{path_results}/SHF/ex_SW')

In [None]:
# Only the cells where the residual is negative.
# NOTE(review): vmin=0 on strictly negative data looks unintended -- confirm.
sw_net.where(sw_net<0).plot(vmin=0)

In [None]:
# example
# Takes the third element of the first tuple yielded by IterateOutputCESM; the
# stacking loop in this notebook opens that same element as a file.
file_ex = next(IterateOutputCESM('ocn_rect', 'ctrl', 'yrly', name='SHF'))[2]
ds = xr.open_dataset(file_ex, decode_times=False)
ds.SHF.plot()

In [None]:
# stacking files into one xr Dataset object
# For each run, every file yielded by IterateOutputCESM is opened and all of them are
# concatenated along `time`; the result is written to SHF_yrly_<run>.nc.
for run in ['ctrl', 'rcp']:
    # Collect first and concatenate once: concatenating pairwise inside the loop
    # copies the growing dataset on every iteration.
    yearly = [xr.open_dataset(s, decode_times=False)
              for (y,m,s) in IterateOutputCESM('ocn_rect', run, 'yrly', name='SHF')]
    # A run without files must fail loudly instead of writing a `ds_new` left over
    # from the previous run under this run's file name.
    if not yearly:
        raise FileNotFoundError(f'no yearly ocn_rect SHF files found for run {run!r}')
    ds_new = xr.concat(yearly, dim='time')
    ds_new.to_netcdf(f'{path_results}/SHF/SHF_yrly_{run}.nc')

In [None]:
# Reload the stacked files written by the stacking cell.
# NOTE: this rebinds SHF_ctrl, which earlier cells used for the dataset read from
# {path_prace}/OHC/SHF_ctrl.nc.
SHF_ctrl = xr.open_dataset(f'{path_results}/SHF/SHF_yrly_ctrl.nc', decode_times=False)
SHF_rcp  = xr.open_dataset(f'{path_results}/SHF/SHF_yrly_rcp.nc' , decode_times=False)

In [None]:
SHF_ctrl

# Global time series

In [None]:
from xr_DataArrays import xr_AREA

In [None]:
# Rebinds AREA to the area field returned by xr_AREA for the 'ocn_rect' domain.
AREA = xr_AREA('ocn_rect')

In [None]:
# surface integral
sec_per_year = 3600*24*365


def _yearly_surface_integral(shf):
    """Return sec_per_year times the AREA-weighted sum of `shf` over lat and lon.

    AREA is kept only where `shf` < 500; all other cells become NaN in the
    product and therefore drop out of the sum.
    """
    valid_area = AREA.where(shf<500)
    return sec_per_year*(valid_area*shf).sum(dim=['lat', 'lon'])


SHF_imbal_ctrl = _yearly_surface_integral(SHF_ctrl.SHF)  # [J/year]
SHF_imbal_rcp = _yearly_surface_integral(SHF_rcp.SHF)

In [None]:
# Integrated SHF imbalance of both runs, scaled by 1/1e21, saved as a figure.
fig = plt.figure(figsize=(8,5))
axis = fig.gca()
axis.tick_params(labelsize=14)
for series, name in ((SHF_imbal_ctrl, 'CTRL'), (SHF_imbal_rcp, 'RCP')):
    axis.plot(series/1e21, lw=2, label=name)
axis.legend(ncol=3, frameon=False, fontsize=16)
axis.set_ylabel('SHF imbalance [ZJ/year]', fontsize=16)
axis.set_xlabel('time [years]', fontsize=16)
fig.tight_layout()
fig.savefig(f'{path_results}/SHF/SHF_integrated_imbalance')

In [None]:
# Spacing between consecutive latitudes. The difference is taken on the raw values:
# subtracting the two shifted xarray slices would align them on their shared `lat`
# labels first and produce zeros everywhere.
plt.plot(np.diff(SHF_ctrl.lat.values))
# Number of latitude points; the bare `len()` call raised a TypeError.
# NOTE(review): the intended argument was presumably the latitude coordinate -- confirm.
print(len(SHF_ctrl.lat))
plt.figure()
# First column of the AREA field.
plt.plot(AREA[:,0])

# trends

In [None]:
# Linear trends on the first 10x10 entries of the last two axes of the SHF field.
SHF_trend_ctrl = xr_linear_trends_2D(SHF_ctrl.SHF[:,:10,:10], ('lat', 'lon'))
# produces LinAlg error

In [None]:
SHF_ctrl

In [None]:
# Robinson map of the time-mean SHF of the ctrl run.
cmap = 'RdBu_r'
label = 'Surface heat flux [W/m$^2$]'
(minv, maxv) = (-250, 250)
filename = f'{path_results}/SHF/SHF_ctrl_mean'

SHF_ctrl_mean = SHF_ctrl.SHF[:,:,:].mean(dim='time')
f = map_robinson(
    xr_DataArray=SHF_ctrl_mean,
    cmap=cmap,
    minv=minv,
    maxv=maxv,
    label=label,
    filename=filename,
)

In [None]:
# Robinson map of the ctrl SHF change: mean over the last 30 entries minus the
# mean over the first 30 entries.
last_30_mean = SHF_ctrl.SHF[-30:,:,:].mean(dim='time')
first_30_mean = SHF_ctrl.SHF[:30,:,:].mean(dim='time')
SHF_ctrl_diff = last_30_mean-first_30_mean

cmap = 'RdBu_r'
label = 'Surface heat flux [W/m$^2$]'
(minv, maxv) = (-25, 25)
filename = f'{path_results}/SHF/SHF_ctrl_last_minus_first_30yrs'

f = map_robinson(
    xr_DataArray=SHF_ctrl_diff,
    cmap=cmap,
    minv=minv,
    maxv=maxv,
    label=label,
    filename=filename,
)

In [None]:
# Robinson map of the rcp SHF averaged over its last 10 entries minus the
# full time mean of the ctrl SHF.
rcp_last_10_mean = SHF_rcp.SHF[-10:,:,:].mean(dim='time')
ctrl_mean = SHF_ctrl.SHF[:,:,:].mean(dim='time')
SHF_rcp_ctrl = rcp_last_10_mean-ctrl_mean

cmap = 'RdBu_r'
label = 'Surface heat flux [W/m$^2$]'
(minv, maxv) = (-50, 50)
filename = f'{path_results}/SHF/SHF_rcp_last_10_minus_ctrl_avg'

f = map_robinson(
    xr_DataArray=SHF_rcp_ctrl,
    cmap=cmap,
    minv=minv,
    maxv=maxv,
    label=label,
    filename=filename,
)

# Why is the SHF negative?

In [None]:
# Unweighted sum of the first time entry of the rcp SHF field.
SHF_rcp.SHF[0,:,:].sum()

In [None]:
# Map of that same first time entry.
SHF_rcp.SHF[0,:,:].plot()

In [None]:
from paths import file_ex_ocn_ctrl, file_ex_ocn_rect

In [None]:
# Example ocean output on the model grid (ds_hr) and the example file given by
# file_ex_ocn_rect (ds_lr), both opened without decoding times.
ds_hr = xr.open_dataset(file_ex_ocn_ctrl, decode_times=False)
ds_lr = xr.open_dataset(file_ex_ocn_rect, decode_times=False)

In [None]:
ds_hr

In [None]:
# SHF weighted by TAREA and summed over the grid (no region mask applied).
(ds_hr.SHF*ds_hr.TAREA).sum(dim=('nlat','nlon'))

In [None]:
ds_hr.SHF.plot()

In [None]:
ds_lr.SHF.plot()

In [None]:
# Inspect the PD variable of the model-grid file.
ds_hr.PD