# Heat and salinity fluxes between basins

In [None]:
import os
import sys
from tqdm import tqdm
import numpy as np
import xarray as xr
import cmocean
import cartopy
import cartopy.crs as ccrs
import matplotlib
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
matplotlib.rc_file('../rc_file')
%load_ext autoreload
%autoreload 2
%aimport - numpy - scipy - matplotlib.pyplot

In [None]:
sys.path.append("..")
from paths import path_results, path_samoc, path_prace, file_ex_ocn_ctrl
from regions import boolean_mask, Atlantic_mask, regions_dict
from constants import rho_sw, cp_sw
from timeseries import IterateOutputCESM
from xr_DataArrays import xr_AREA, xr_DZ, dll_dims_names
from xr_regression import xr_quadtrend

In [None]:
# presumably an example ocean control-run output file; times are kept as raw numbers (no decoding)
ds = xr.open_dataset(filename_or_obj=file_ex_ocn_ctrl, decode_times=False)

In [None]:
# inspect the field that is passed as VN_adv to transport_into
ds['VNS']

In [None]:
def advection_cells(from_basin_mask, to_basin_mask):
    """ arrays with which east-/northward advection need to be multiplied

    from_basin_mask, to_basin_mask: 2D DataArrays with identical dims
        (lat, lon); nonzero/True marks the cells of the respective basin
    returns (adv_E, adv_N), nonzero only on cells of the from_basin:

    adv_E:   1   if to_basin to the East of from_basin
            -1   if to_basin to the West of from_basin
             0   elsewhere (also where to_basin lies on both sides)
    adv_N:   1   if to_basin to the North of from_basin
            -1   if to_basin to the South of from_basin
             0   elsewhere (also where to_basin lies on both sides)

    Increasing array index is taken to point East / North.
    NOTE(review): this holds for the POP grid; confirm for other grids.
    Longitude is treated as periodic, latitude is not.
    """
    assert np.shape(from_basin_mask)==np.shape(to_basin_mask)
    assert from_basin_mask.dims==to_basin_mask.dims
    (lat, lon) = from_basin_mask.dims
    # float masks: numpy refuses to subtract boolean arrays
    m0 = from_basin_mask.astype(float)
    m1 = to_basin_mask.astype(float)

    # x.roll({lon: -1})[i] == x[i+1] -> eastern neighbour (wraps around).
    # roll_coords expects a bool; rolling the coordinates too would let xarray
    # re-align the rolled mask onto m0's labels (if `lon` has an index
    # coordinate) and thereby undo the shift in `m0*...`
    east = m1.roll(shifts={lon: -1}, roll_coords=False)
    west = m1.roll(shifts={lon: 1}, roll_coords=False)
    # x.shift({lat: -1})[j] == x[j+1] -> northern neighbour; the row shifted in
    # at the domain edge is NaN and is treated as "not to_basin"
    north = m1.shift(shifts={lat: -1}).fillna(0)
    south = m1.shift(shifts={lat: 1}).fillna(0)

    adv_E = (m0*east - m0*west).fillna(0)
    adv_N = (m0*north - m0*south).fillna(0)
    # after fillna(0) "no neighbouring cells" means all entries are zero
    if not (adv_E.values.any() or adv_N.values.any()):
        print('warning, no neighbouring cells!')
    return adv_E, adv_N

def transport_into(domain, basin, VN_adv, UE_adv):
    """ computes advective fluxes into `basin` from each neighbouring basin

    input:
    domain .. 'ocn' or 'ocn_low'
    basin  .. 'Atlantic', 'Pacific', or 'Southern'
    VN_adv .. northward advection term (DataArray with a `units` attribute)
    UE_adv .. eastward advection term, same units as VN_adv:
              'degC/s'          -> heat transport in W
              'gram/kilogram/s' -> salt transport in kg/s

    output:
    xr.Dataset with one variable per neighbouring basin (named via
    `regions_dict`), summed over the dims of `dll_dims_names(domain)`;
    the sign convention (positive = into `basin`) is set by advection_cells

    raises:
    ValueError for unsupported units or basin names
    """
    assert domain in ['ocn', 'ocn_low']
    assert VN_adv.units==UE_adv.units
    if VN_adv.units=='degC/s':
        conversion = rho_sw*cp_sw
        unit = 'W'
    elif VN_adv.units=='gram/kilogram/s':
        conversion = rho_sw*1e-3
        unit = 'kg/s'
    else:
        raise ValueError('units need to be in "degC/s" or "gram/kilogram/s"')

    # basin -> (region mask number, region mask numbers of adjacent basins)
    basins = {'Atlantic': (6, [1, 8, 9]),
              'Pacific' : (2, [1, 3, 10]),
              'Southern': (1, [2, 3, 6]),
              }
    if basin not in basins:
        raise ValueError('basin needs to be "Atlantic", "Pacific", or "Southern"')
    basin_nr, neighbours = basins[basin]

    dims = list(dll_dims_names(domain=domain))
    DZ = xr_DZ(domain=domain)
    AREA = xr_AREA(domain=domain)
    basin_mask = boolean_mask(domain=domain, mask_nr=basin_nr)

    transports = []
    for n in tqdm(neighbours):
        neighbour_mask = boolean_mask(domain=domain, mask_nr=n)
        adv_E, adv_N = advection_cells(from_basin_mask=neighbour_mask, to_basin_mask=basin_mask)
        # NOTE(review): adv_E/adv_N sit on the neighbour (from) cells, so UE/VN
        # are evaluated at that cell's east/north face. If UE/VN are defined on
        # the east/north T-cell faces (POP convention), westward/southward
        # crossings actually pass through the basin cell's face -- confirm.
        transport = ((adv_E*UE_adv + adv_N*VN_adv)*AREA*DZ).sum(dim=dims)
        # apply the unit conversion so that the values match the 'W' / 'kg/s'
        # label; assumes AREA*DZ in m^3 and rho_sw, cp_sw in SI units
        transport = conversion*transport
        transport.name = f'{regions_dict[n]}'
        transport.attrs['units'] = unit
        transports.append(transport)

    # a single merge always yields a Dataset, independent of len(neighbours)
    return xr.merge(transports)

In [None]:
# per-neighbour transport into the Atlantic based on the VNS/UES fields of ds
dat = transport_into(basin='Atlantic', domain='ocn', UE_adv=ds.UES, VN_adv=ds.VNS)

In [None]:
# display the result returned by transport_into
dat

In [None]:
m0 = Atlantic_mask(domain='ocn')  # Atlantic
m1 = boolean_mask(domain='ocn', mask_nr=10)
# the diagnostic plots below use adv_N / adv_E, which otherwise only exist as
# locals inside transport_into (NameError at top level); compute them here for
# the m0 -> m1 boundary (from_basin = m0, matching the m0.TLAT/TLONG plots)
adv_E, adv_N = advection_cells(from_basin_mask=m0, to_basin_mask=m1)

In [None]:
# weight the two masks differently so both basins are distinguishable (assuming 0/1 masks: m0 -> 2, m1 -> 1)
(2*m0 + m1).plot()
plt.ylim(2000, 2400)
plt.xlim(500, 1500)

In [None]:
# show m0's latitude / longitude only where adv_N / adv_E are nonzero
for coord, flags in (('TLAT', adv_N), ('TLONG', adv_E)):
    m0[coord].where(flags).plot()
plt.ylim(2100, 2300)
plt.xlim(500, 1200)

In [None]:
# side-by-side close-up: adv_N in the left panel, adv_E in the right panel
f, axs = plt.subplots(nrows=1, ncols=2, sharey=True)
for i, (ax, field) in enumerate(zip(axs, (adv_N, adv_E))):
    field.plot(ax=ax)
    ax.set_ylim(2215, 2225)
    ax.set_xlim(925, 1025)

In [None]:
# display the opened dataset
ds