# Surface Freshwater Flux / Virtual Salt Flux: E - P - R

## Observational datasets
- precipitation: Legates, ERA-Interim (ERAI)
- total freshwater flux: Large-Yeager, WHOI

In [None]:
import os
import sys
import time
sys.path.append("..")
import numpy as np
import xarray as xr
import cmocean
import cartopy
import cartopy.crs as ccrs
import matplotlib
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline
# load the project's matplotlib rc settings from file
matplotlib.rc_file('../rc_file')
# do not crop inline figures to a tight bounding box; keep their set size
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
# automatically reload imported (project) modules before running a cell
%load_ext autoreload
%autoreload 2

In [None]:
from tqdm import notebook
from paths import path_results, path_prace, file_RMASK_ocn, file_RMASK_ocn_low, file_ex_ocn_ctrl, file_ex_ocn_lpd
from timeseries import IterateOutputCESM
from xr_DataArrays import xr_AREA
from xr_regression import ocn_field_regression

In [None]:
# Example ocean output of the ctrl (dh) and lpd (dl) runs, opened without
# time decoding, plus the REGION_MASK of each grid.
dh, dl = (xr.open_dataset(fname, decode_times=False)
          for fname in (file_ex_ocn_ctrl, file_ex_ocn_lpd))
RMASK_ocn, RMASK_ocn_low = dh.REGION_MASK, dl.REGION_MASK

In [None]:
# Atlantic masks (REGION_MASK values listed in ATL_REGIONS) and cell areas for
# the high-resolution ('ocn') and low-resolution ('ocn_low') ocean grids.
ATL_REGIONS = [6, 7, 8, 9, 12]


def _atlantic_mask(rmask):
    """Return a boolean DataArray (dims/coords of `rmask`), True where rmask is in ATL_REGIONS."""
    # np.isin preserves the input shape; np.in1d is deprecated since NumPy 2.0
    return xr.DataArray(np.isin(rmask, ATL_REGIONS), dims=rmask.dims, coords=rmask.coords)


dh = xr.open_dataset(file_ex_ocn_ctrl, decode_times=False)
RMASK_ocn = dh.REGION_MASK  # reuse the dataset instead of opening the same file again
RMASK_low = xr.open_dataset(file_RMASK_ocn_low, decode_times=False).REGION_MASK
Atl_MASK_ocn = _atlantic_mask(RMASK_ocn)
Atl_MASK_low = _atlantic_mask(RMASK_low)
# passed positionally: the original keyword `dhmain=` looks like a find/replace
# artefact of `domain=`; NOTE(review): confirm against xr_DataArrays.xr_AREA
AREA_ocn = xr_AREA('ocn')
AREA_low = xr_AREA('ocn_low')

In [None]:
# Rows: SFWF, sum of its flux components, SALT_F, and the residual
# SFWF - components - scaled SALT_F; columns: dh (left) and dl (right).
f, ax = plt.subplots(4, 2, figsize=(12, 15))
for i, d in enumerate([dh, dl]):
    maxv = 5e-4
    kw = dict(add_colorbar=False, vmax=maxv)
    # IOFF_F is only added for dl (not included in the dh output)
    Sum = d.PREC_F + d.EVAP_F + d.MELT_F + d.ROFF_F
    if i == 1:
        Sum = Sum + d.IOFF_F
    for row, field in enumerate((d.SFWF, Sum, d.SALT_F)):
        field.plot(ax=ax[row, i], **kw)

    kw = dict(add_colorbar=False, vmax=1e-5)
    residual = d.SFWF - Sum - d.SALT_F*d.sflux_factor/d.salinity_factor
    residual.plot(ax=ax[3, i], **kw)
    


The residual visible for the HIGH run (left column) is the IOFF_F ice runoff, which is not included in the model output. For the low-resolution run (right column) it is the restoring freshwater flux, which mainly moves freshwater between the closed-off Baltic, Black and Red Seas and the open ocean. It is not clear what happens in the Caspian Sea and the Kara(?) Sea.

In [None]:
# ROFF_F of dh over ocean points (REGION_MASK > 0), zoomed to a grid-index window
ocean_roff = dh.ROFF_F.where(dh.REGION_MASK > 0)
ocean_roff.plot(vmin=0, vmax=1e-4)
plt.xlim(300, 1000)
plt.ylim(1700, 2400)

In [None]:
# scratch value; NOTE(review): presumably the inverse of a reference salinity of 35 — confirm
1/35

In [None]:
from FW_budget import make_SFWF_surface_int_dict
# make_SFWF_trends is called in later cells but was never imported (NameError);
# NOTE(review): assumed to be defined in FW_budget alongside the function above — confirm
from FW_budget import make_SFWF_trends

In [None]:
# run FW_budget.make_SFWF_surface_int_dict (no arguments); its return value is only displayed
make_SFWF_surface_int_dict()

In [None]:
# NOTE(review): make_SFWF_trends is not imported anywhere in the visible notebook;
# presumably it comes from FW_budget — confirm before running this cell
make_SFWF_trends('lr1')

In [None]:
# NOTE(review): make_SFWF_trends is not imported anywhere in the visible notebook;
# presumably it comes from FW_budget — confirm before running this cell
make_SFWF_trends('rcp')

In [None]:
# print every item yielded for the yearly SFWF output of the rcp run
sfwf_files = IterateOutputCESM(domain='ocn', run='rcp', tavg='yrly', name='SFWF')
for q in sfwf_files:
    print(q)

In [None]:
# quick look at the stored lpd SFWF mean field
lpd_sfwf_mean = xr.open_dataarray(f'{path_prace}/lpd/SFWF_lpd_mean_500-529.nc')
lpd_sfwf_mean.plot()


In [None]:
# the rcp EVAP_F "yrly_trend" file
Et_rcp = xr.open_dataarray(
    f'{path_prace}/rcp/EVAP_F_yrly_trend_rcp.nc'
)

In [None]:
# display the rcp EVAP_F trend DataArray opened above
Et_rcp

In [None]:
# Residual SFWF - (P+E+M+R+I) - scaled SALT_F of the low-resolution run, zoomed in.
# Computed explicitly from `dl` instead of relying on `d`/`Sum` leaking out of the
# comparison loop (whose last pass is `dl`), so the cell works in any run order.
Sum_low = dl.PREC_F + dl.EVAP_F + dl.MELT_F + dl.ROFF_F + dl.IOFF_F
(dl.SFWF-Sum_low-dl.SALT_F*dl.sflux_factor/dl.salinity_factor).plot()
plt.xlim((30,120))
plt.ylim((200,384))


In [None]:
# inspect dh.salinity_factor (used to scale SALT_F in the residuals above)
dh.salinity_factor

In [None]:
# inspect the SFWF variable of dh
dh.SFWF

In [None]:
# Stored time-mean EVAP_F / PREC_F / ROFF_F fields (ctrl file: _mean_200-230, lpd file: _mean_500-530)
ds_ctrl = xr.open_dataset(f'{path_prace}/ctrl/EVAP_F_PREC_F_ROFF_F_ctrl_mean_200-230.nc')
EVAP_mean_ctrl, PREC_mean_ctrl, ROFF_mean_ctrl = ds_ctrl.EVAP_F, ds_ctrl.PREC_F, ds_ctrl.ROFF_F

ds_lpd = xr.open_dataset(f'{path_prace}/lpd/EVAP_F_PREC_F_ROFF_F_lpd_mean_500-530.nc')
EVAP_mean_lpd, PREC_mean_lpd, ROFF_mean_lpd = ds_lpd.EVAP_F, ds_lpd.PREC_F, ds_lpd.ROFF_F

In [None]:
# Total freshwater flux E+P+R over the Atlantic mask between 35S and 60N (TLAT bounds below)
for run, (Em, Pm, Rm), AREA, ATL_MASK in zip(
        ['ctrl', 'lpd'],
        [(EVAP_mean_ctrl, PREC_mean_ctrl, ROFF_mean_ctrl),
         (EVAP_mean_lpd , PREC_mean_lpd , ROFF_mean_lpd)],
        [AREA_ocn, AREA_low],
        [Atl_MASK_ocn, Atl_MASK_low]):
    # Combine the conditions with `&` so MASK stays boolean. Chaining `.where()`
    # turned out-of-band cells into NaN, and `.where(MASK)` treats NaN as True,
    # which silently dropped the latitude restriction.
    MASK = ATL_MASK & (Em.TLAT > -35) & (Em.TLAT < 60)
    plt.figure()
    MASK.astype(float).plot()
    # NOTE(review): assumes AREA in m^2 and fluxes in kg/m^2/s, so /1e3 (kg/m^3)
    # gives m^3/s and /1e6 gives Sv — confirm the units of xr_AREA
    total = (AREA*(Em+Pm+Rm)).where(MASK).sum()/1e6/1e3
    print(f'{run}: {float(total):.4f}')

In [None]:
# Stored yearly-trend fields of E, P and R for the rcp and lr1 runs
EVAP_trend_rcp, PREC_trend_rcp, ROFF_trend_rcp = (
    xr.open_dataarray(f'{path_prace}/rcp/EVAP_yrly_trend.nc'),
    xr.open_dataarray(f'{path_prace}/rcp/PREC_yrly_trend.nc'),
    xr.open_dataarray(f'{path_prace}/rcp/ROFF_yrly_trend.nc'),
)

EVAP_trend_lr1, PREC_trend_lr1, ROFF_trend_lr1 = (
    xr.open_dataarray(f'{path_prace}/lr1/EVAP_yrly_trend.nc'),
    xr.open_dataarray(f'{path_prace}/lr1/PREC_yrly_trend.nc'),
    xr.open_dataarray(f'{path_prace}/lr1/ROFF_yrly_trend.nc'),
)

In [None]:
# Atlantic-centred maps of the mean (ctrl / lpd) and trend (rcp / lr1) of
# precipitation, evaporation and P+R+E; one figure per run, saved as SFWF_map_{run}.
# Fluxes are scaled from kg/m2/s to mm/day (x 24*3600); trends additionally x 365*100.
# NOTE(review): the 365 factor assumes the stored trend is per day — confirm its time unit.
runs = {
    'rcp': dict(mean=(PREC_mean_ctrl, EVAP_mean_ctrl, ROFF_mean_ctrl),
                trend=(PREC_trend_rcp, EVAP_trend_rcp, ROFF_trend_rcp),
                rmask=RMASK_ocn),
    'lr1': dict(mean=(PREC_mean_lpd, EVAP_mean_lpd, ROFF_mean_lpd),
                trend=(PREC_trend_lr1, EVAP_trend_lr1, ROFF_trend_lr1),
                rmask=RMASK_ocn_low),
}
# per row: label, (mean vmin, vmax), mean colormap, trend colormap
rows = [
    ('Precipitation', (0, 7.5),    'viridis',   'cmo.balance'),
    ('Evaporation',   (-7.5, 0),   'viridis_r', 'cmo.balance_r'),
    ('P+R+E',         (-7.5, 7.5), 'cmo.tarn',  'cmo.balance'),
]
trend_lims = (-2.5, 2.5)
map_proj = ccrs.LambertAzimuthalEqualArea(central_longitude=-30, central_latitude=20)

for run, fields in runs.items():
    PREC_mean, EVAP_mean, ROFF_mean = fields['mean']
    PREC_trend, EVAP_trend, ROFF_trend = fields['trend']
    ocean = fields['rmask'] > 0
    f = plt.figure(figsize=(6.4, 12), constrained_layout=False)
    for i, (q1, mean_lims, mean_cmap, trend_cmap) in enumerate(rows):
        # only build the P+R+E sums for the row that shows them
        if i == 0:
            mean, trend = PREC_mean, PREC_trend
        elif i == 1:
            mean, trend = EVAP_mean, EVAP_trend
        else:
            mean = PREC_mean + EVAP_mean + ROFF_mean
            trend = PREC_trend + EVAP_trend + ROFF_trend

        # invisible axes that only carry the vertical row label
        ax = f.add_subplot(3, 5, 1+i*5)
        ax.axis('off')
        ax.set_position([.04, .01+(2-i)*.32, .02, .3])
        ax.text(.5, .5, q1, transform=ax.transAxes, rotation='vertical', va='center', ha='right', fontsize=20)

        panels = [
            ('mean [mm/day]',       mean,  24*3600,         mean_lims,  mean_cmap),
            ('trend [mm/day/100y]', trend, 24*3600*365*100, trend_lims, trend_cmap),
        ]
        # (the undefined `tqdm_notebook` progress bar over these two panels was removed)
        for j, (title, field, factor, (minv, maxv), cmap) in enumerate(panels):
            xa = field.where(ocean)*factor
            ax = f.add_subplot(3, 3, 3*i+2+j, projection=map_proj)
            if i == 0:
                ax.text(.5, 1.05, title, transform=ax.transAxes, fontsize=16, ha='center')
            ax.set_position([.05+j*.46, .01+(2-i)*.32, .45, .3])
            # NOTE(review): extent is given in a default-centred LAEA crs, not in map_proj;
            # kept as in the original since the numbers were tuned for it
            ax.set_extent((-6e6, 3.5e6, -8.5e6, 1e7), crs=ccrs.LambertAzimuthalEqualArea())
            cax, kw = matplotlib.colorbar.make_axes(ax, location='right', pad=0.01, shrink=0.9)
            im = ax.pcolormesh(xa.ULONG, xa.ULAT, xa, cmap=cmap,
                               vmin=minv, vmax=maxv, transform=ccrs.PlateCarree())
            cbar = f.colorbar(im, cax=cax, extend='both', **kw)
            cbar.ax.tick_params(labelsize=10)

            ax.add_feature(cartopy.feature.LAND, zorder=2, edgecolor='black', facecolor='grey')
            if i == 2:
                rivers = cartopy.feature.NaturalEarthFeature(
                            category='physical', name='rivers_lake_centerlines',
                            scale='50m', facecolor='none', edgecolor='lightgrey')
                ax.add_feature(rivers, linewidth=.3, zorder=3)
                ax.add_feature(cartopy.feature.RIVERS, zorder=4)
            gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False)
            gl.ylocator = matplotlib.ticker.FixedLocator([-90, -60, -30, 0, 30, 60, 90])
    f.savefig(f'{path_results}/SFWF/SFWF_map_{run}')