# Freshwater/Salt Fluxes



1. divergence
2. surface fluxes


In [None]:
import os
import sys
sys.path.append("..")
import numpy as np
import pickle
import xarray as xr
import cmocean
import cartopy
import warnings  
import cartopy.crs as ccrs
import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Notebook setup: inline figures, project-wide matplotlib style read from ../rc_file,
# inline figures shown without bbox cropping (bbox_inches=None), and autoreload so
# edited project modules are re-imported before each cell executes.
%matplotlib inline
matplotlib.rc_file('../rc_file')
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
%load_ext autoreload
%autoreload 2

In [None]:
from tqdm import notebook
from paths import file_ex_ocn_ctrl, file_ex_ocn_lpd
from paths import path_results, path_prace, file_RMASK_ocn, file_RMASK_ocn_low
from FW_plots import FW_region_plot, FW_summary_plot
from FW_budget import load_obj, lat_bands
from constants import rho_sw
from timeseries import IterateOutputCESM
from FW_transport import calc_section_transport, sections_high, sections_low
from xr_DataArrays import xr_AREA, xr_DZ
from xr_regression import ocn_field_regression, xr_linear_trend
# NOTE(review): this silences *every* warning for the rest of the session,
# including deprecation warnings from numpy/xarray; consider narrowing the filter.
warnings.filterwarnings('ignore')

In [None]:
# Example ocean output files: high-resolution ctrl (dsh) and low-resolution lpd (dsl),
# opened without CF time decoding.
dsh, dsl = (xr.open_dataset(fn, decode_times=False) for fn in (file_ex_ocn_ctrl, file_ex_ocn_lpd))

In [None]:
# Region masks and cell areas for the high- (ocn) and low-resolution (ocn_low) grids.
# The Atlantic mask is True where REGION_MASK takes one of ATL_REGION_IDS.
ATL_REGION_IDS = (6, 8, 9, 10, 11)  # NOTE(review): ids assumed to be the Atlantic basin and its marginal seas; confirm against the POP region table


def _region_mask(rmask, ids=ATL_REGION_IDS):
    """Return a boolean DataArray on `rmask`'s grid that is True where `rmask` is in `ids`."""
    # np.isin keeps the input shape (np.in1d flattened it and is deprecated in NumPy 2)
    return xr.DataArray(np.isin(rmask, ids), dims=rmask.dims, coords=rmask.coords)


do = xr.open_dataset(file_ex_ocn_ctrl, decode_times=False)
RMASK_ocn = do.REGION_MASK  # same file and options as `do`; no need to open it a second time
RMASK_low = xr.open_dataset(file_RMASK_ocn_low, decode_times=False).REGION_MASK
Atl_MASK_ocn = _region_mask(RMASK_ocn)
Atl_MASK_low = _region_mask(RMASK_low)
AREA_ocn = xr_AREA(domain='ocn')
AREA_low = xr_AREA(domain='ocn_low')

## surface flux integrals

## Time derivative of content: $\partial_t \mathrm{FW}$ / $\partial_t \mathrm{SALT}$

In [None]:
# quick look at the high-resolution (ctrl grid) Atlantic mask
Atl_MASK_ocn.plot()

In [None]:
# quick look at the low-resolution (ocn_low grid) Atlantic mask
Atl_MASK_low.plot()

In [None]:
# Atlantic salt content per latitude band, as anomaly about its time mean.
# Left panel: low-resolution runs (lpd solid, lr1 dashed).
# Right panel: high-resolution runs (ctrl solid, rcp dashed).
f, ax = plt.subplots(1, 2, figsize=(6.4,3), sharey=True)
runs = ['lpd', 'lr1', 'ctrl', 'rcp']
for i, run in notebook.tqdm(enumerate(runs), total=len(runs)):
    if run in ['ctrl', 'rcp']:  # HIGH
        ax_, AREA, MASK = ax[1], AREA_ocn, Atl_MASK_ocn  # reuse areas computed with the masks
    elif run in ['lpd', 'lr1']:  # LOW
        ax_, AREA, MASK = ax[0], AREA_low, Atl_MASK_low
    fn_top = f'{path_prace}/SALT/SALT_dz_0-1000m_{run}.nc'
    fn_bot = f'{path_prace}/SALT/SALT_dz_below_1000m_{run}.nc'
    with xr.open_dataarray(fn_top) as top, xr.open_dataarray(fn_bot) as bottom:
        # sum of the 0-1000 m and below-1000 m parts; load before the files are closed
        da = (top + bottom).load()
    for j, (latS, latN) in notebook.tqdm(enumerate(lat_bands), total=len(lat_bands)):
        MASK_ = MASK.where(MASK.TLAT<latN).where(MASK.TLAT>latS)
        salt = (da*AREA).where(MASK_==1).sum(dim=['nlat','nlon'])
        # x axis follows the series length (was hard-coded to 101 steps);
        # one colour per band in both runs of a panel, only the solid run is labelled
        ax_.plot(np.arange(len(salt)), salt-salt.mean(), c=f'C{j}',
                 label=latS if i%2==0 else None, ls=['-','--'][i%2])
        print(f'{run:4}', f'{latS:4}', f'{latN:4}', f'{salt.mean().values:4.1e}')
for i in range(2):  ax[i].legend()

In [None]:
# d/dt SALT
# For each selected run and latitude band, compute one value per model year (te exclusive):
#   SALT_<band>: Atlantic volume-weighted mean of SALT
#   FW_<band>:   Atlantic volume-weighted mean of (SALT - 35)
# The values are collected in a dict and written with save_obj, one file per run.
from FW_budget import save_obj  # NOTE(review): was used without import; assumed to live next to load_obj

year_ranges = {'lpd': (500, 530), 'lr1': (2000, 2101), 'ctrl': (200, 230), 'rcp': (2000, 2101)}
for i, run in enumerate(['lpd', 'lr1', 'ctrl', 'rcp']):
    if i in [0,1]: continue  # only the high-resolution runs are processed in this pass
    ts, te = year_ranges[run]
    if run in ['ctrl', 'rcp']:  # HIGH
        AREA, DZT = xr_AREA('ocn'), xr_DZ('ocn')
        MASK = Atl_MASK_ocn
    elif run in ['lpd', 'lr1']:  # LOW
        AREA, DZT = xr_AREA('ocn_low'), xr_DZ('ocn_low')
        MASK = Atl_MASK_low

    d = {}  # fresh per run so results of one run never end up in another run's file
    for j, (latS, latN) in enumerate(lat_bands):
        if j<3: continue
        # Explicit boolean mask: .where() fills NaN outside the band and NaN is truthy,
        # so using the raw result as a condition would also select all cells outside the band.
        MASK_ = MASK.where(MASK.TLAT<latN).where(MASK.TLAT>latS) == 1
        vol = (DZT*AREA).where(MASK_).sum()  # m^3
        print(latS, latN, vol)
        S_list, F_list = [], []
        for y in notebook.tqdm(np.arange(ts, te)):
            fn_y = f'{path_prace}/{run}/ocn_yrly_SALT_{y:04d}.nc'
            with xr.open_dataset(fn_y, decode_times=False) as ds:
                SALT = ds.SALT
                S_list.append((SALT*AREA*DZT).where(MASK_).sum()/vol)
                F_list.append(((SALT-35.)*AREA*DZT).where(MASK_).sum()/vol)
        # concatenate once instead of re-concatenating the growing series every year
        d[f'SALT_{latS}N_{latN}N'] = xr.concat(S_list, dim='time')
        d[f'FW_{latS}N_{latN}N']   = xr.concat(F_list, dim='time')
    fn = f'{path_results}/SFWF/Atlantic_SALT_integrals_{run}'
    save_obj(d, fn)

In [None]:
# load the per-band lr1 integrals saved under SFWF/Atlantic_SALT_integrals_{run}
dlr1 = load_obj(f'{path_results}/SFWF/Atlantic_SALT_integrals_lr1')

In [None]:
# display the loaded object
dlr1

In [None]:
# lr1 band integrals, one colour per band.
# Top row: raw series. Bottom row: anomalies about the time mean.
# Left column: FW_*. Right column: SALT_*.
f, ax = plt.subplots(2,2, figsize=(6.4,5), sharex=True)
for i, lats in enumerate(lat_bands[1:]):
    latS, latN = lats
    fw_key, salt_key = f'FW_{latS}N_{latN}N', f'SALT_{latS}N_{latN}N'
    # the d/dt SALT loop does not compute every band (it skips j<3),
    # so skip bands missing from the file instead of raising KeyError
    if fw_key not in dlr1 or salt_key not in dlr1:
        continue
    fw, salt = dlr1[fw_key], dlr1[salt_key]
    ax[0,0].plot(fw, c=f'C{i}', label=str(lats))
    ax[1,0].plot(fw-fw.mean(), c=f'C{i}')
    ax[0,1].plot(salt, c=f'C{i}')
    ax[1,1].plot(salt-salt.mean(), c=f'C{i}')

ax[0,0].legend()

## box plot

In [None]:
# summary ("box") plots for the FW and SALT quantities, drawn by FW_plots.FW_summary_plot
FW_summary_plot('FW')
FW_summary_plot('SALT')

In [None]:
# per-region plots for the FW and SALT quantities, drawn by FW_plots.FW_region_plot
FW_region_plot('FW')
FW_region_plot('SALT')