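# Surface Freshwater Flux / Virtual Salt Flux: E - P - R

Sign convention: in the model output used here, evaporation (`EVAP_F`) enters with a negative sign (it is plotted on a -7.5 to 0 mm/day range). The maps below therefore show the net freshwater input P + R + E, which equals -(E - P - R) when E is counted as a positive magnitude.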

## Observational datasets
- precipitation: Legates, ERA-Interim (ERAI)
- total water flux: Large-Yeager, WHOI

In [None]:
import os
import sys
sys.path.append("..")
import numpy as np
import xarray as xr
import cmocean
import cartopy
import cartopy.crs as ccrs
import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Render figures inline and apply the project's shared matplotlib style file.
%matplotlib inline
matplotlib.rc_file('../rc_file')
# Do not crop inline figures to a tight bounding box; the plots in this notebook
# place their axes by hand with set_position, so the full canvas is kept.
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
# Re-import edited modules (e.g. the local paths / timeseries / xr_regression helpers)
# automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from tqdm import tqdm_notebook
from paths import path_results, path_prace, file_RMASK_ocn, file_RMASK_ocn_low, file_ex_ocn_ctrl
from timeseries import IterateOutputCESM
from xr_regression import ocn_field_regression

In [None]:
# Ocean region masks for the two grids used below.
# RMASK_ocn is read from an example ctrl output file (opened without time decoding);
# RMASK_ocn_low comes from a dedicated mask file and is paired with the lpd/lr1 runs.
do = xr.open_dataset(file_ex_ocn_ctrl, decode_times=False)
RMASK_ocn = do['REGION_MASK']
RMASK_ocn_low = xr.open_dataset(file_RMASK_ocn_low)['REGION_MASK']

In [None]:
%%time
# # ctrl: 2min 14s, lpd: 3 sec
# run='lpd'
# if run=='ctrl':  yy = 200
# elif run=='lpd':  yy = 500
# xr.concat([xr.open_dataset(f'{path_prace}/{run}/ocn_yrly_EVAP_F_PREC_F_ROFF_F_0{y}.nc') for y in np.arange(yy,yy+30)],
#           dim='time').mean(dim='time').to_netcdf(f'{path_prace}/{run}/EVAP_PREC_ROFF_{run}_mean_{yy}-{yy+30}.nc')
EVAP_mean_ctrl = xr.open_dataset(f'{path_prace}/ctrl/EVAP_PREC_ROFF_ctrl_mean_200-230.nc').EVAP_F
PREC_mean_ctrl = xr.open_dataset(f'{path_prace}/ctrl/EVAP_PREC_ROFF_ctrl_mean_200-230.nc').PREC_F
ROFF_mean_ctrl = xr.open_dataset(f'{path_prace}/ctrl/EVAP_PREC_ROFF_ctrl_mean_200-230.nc').ROFF_F

EVAP_mean_lpd  = xr.open_dataset(f'{path_prace}/lpd/EVAP_PREC_ROFF_lpd_mean_500-530.nc').EVAP_F
PREC_mean_lpd  = xr.open_dataset(f'{path_prace}/lpd/EVAP_PREC_ROFF_lpd_mean_500-530.nc').PREC_F
ROFF_mean_lpd  = xr.open_dataset(f'{path_prace}/lpd/EVAP_PREC_ROFF_lpd_mean_500-530.nc').ROFF_F

In [None]:
%%time
# concat: lr1: 11 sec
# trend: lr1: 3 sec
# run = 'lr1'
# for i, (y,m,f) in tqdm_notebook(enumerate(IterateOutputCESM(domain='ocn', run=run, tavg='yrly', name='EVAP_F_PREC_F_ROFF_F'))):
#     ds = xr.open_dataset(f, decode_times=False)
#     if i==0:  EVAP, PREC, ROFF = [], [], []
#     EVAP.append(ds.EVAP_F)
#     PREC.append(ds.PREC_F)
#     ROFF.append(ds.ROFF_F)
# xr.concat(EVAP, dim='time').to_netcdf(f'{path_prace}/{run}/EVAP_yrly.nc')
# xr.concat(PREC, dim='time').to_netcdf(f'{path_prace}/{run}/PREC_yrly.nc')
# xr.concat(ROFF, dim='time').to_netcdf(f'{path_prace}/{run}/ROFF_yrly.nc')
# EVAP = xr.open_dataarray(f'{path_prace}/{run}/EVAP_yrly.nc', decode_times=False)
# PREC = xr.open_dataarray(f'{path_prace}/{run}/PREC_yrly.nc', decode_times=False)
# ROFF = xr.open_dataarray(f'{path_prace}/{run}/ROFF_yrly.nc', decode_times=False)

# EVAP_trend = ocn_field_regression(xa=EVAP, run=run)
# PREC_trend = ocn_field_regression(xa=PREC, run=run)
# ROFF_trend = ocn_field_regression(xa=ROFF, run=run)
# EVAP_trend[0].to_netcdf(f'{path_prace}/{run}/EVAP_yrly_trend.nc')
# PREC_trend[0].to_netcdf(f'{path_prace}/{run}/PREC_yrly_trend.nc')
# ROFF_trend[0].to_netcdf(f'{path_prace}/{run}/ROFF_yrly_trend.nc')
EVAP_trend_rcp = xr.open_dataarray(f'{path_prace}/rcp/EVAP_yrly_trend.nc')
PREC_trend_rcp = xr.open_dataarray(f'{path_prace}/rcp/PREC_yrly_trend.nc')
ROFF_trend_rcp = xr.open_dataarray(f'{path_prace}/rcp/ROFF_yrly_trend.nc')

EVAP_trend_lr1 = xr.open_dataarray(f'{path_prace}/lr1/EVAP_yrly_trend.nc')
PREC_trend_lr1 = xr.open_dataarray(f'{path_prace}/lr1/PREC_yrly_trend.nc')
ROFF_trend_lr1 = xr.open_dataarray(f'{path_prace}/lr1/ROFF_yrly_trend.nc')

In [None]:
# Map panels of the surface freshwater flux components, one figure per scenario run.
# Rows: precipitation, evaporation, and the net flux P+R+E; columns: the reference-run
# mean and the scenario-run trend. Maps use a Lambert azimuthal equal-area projection
# centred at 30W/20N. rcp is paired with the ctrl mean and RMASK_ocn, lr1 with the lpd
# mean and RMASK_ocn_low. Each figure is saved to {path_results}/SFWF/SFWF_map_{run}.
for k, run in enumerate(['rcp', 'lr1']):
    (PREC_mean, EVAP_mean, ROFF_mean) = [(PREC_mean_ctrl, EVAP_mean_ctrl, ROFF_mean_ctrl),
                                         (PREC_mean_lpd , EVAP_mean_lpd , ROFF_mean_lpd )][k]
    (PREC_trend, EVAP_trend, ROFF_trend) = [(PREC_trend_rcp, EVAP_trend_rcp, ROFF_trend_rcp),
                                            (PREC_trend_lr1, EVAP_trend_lr1, ROFF_trend_lr1)][k]
    RMASK = [RMASK_ocn, RMASK_ocn_low][k]
    f = plt.figure(figsize=(6.4,12), constrained_layout=False)
    for i, q1 in enumerate(['Precipitation', 'Evaporation', 'P+R+E']):
        # Build only this row's fields: the three-way sums are computed for the
        # net-flux row alone instead of on every row.
        if i==0:
            mean, trend = PREC_mean, PREC_trend
        elif i==1:
            mean, trend = EVAP_mean, EVAP_trend
        else:
            mean  = PREC_mean+EVAP_mean+ROFF_mean
            trend = PREC_trend+EVAP_trend+ROFF_trend
        mean_min, mean_max = [0,-7.5,-7.5][i], [7.5,0,7.5][i]   # colour range in mm/day
        trend_min, trend_max = -2.5, 2.5                          # colour range in mm/day/100y
        mean_cmap, trend_cmap = ['viridis', 'viridis_r', 'cmo.tarn'][i], ['cmo.balance', 'cmo.balance_r', 'cmo.balance'][i]

        # Row label drawn in a narrow, invisible axes to the left of the maps.
        ax = f.add_subplot(3, 5, 1+i*5)
        ax.axis('off')
        ax.set_position([.04,.01+(2-i)*.32,.02,.3])
        ax.text(.5, .5, q1, transform=ax.transAxes, rotation='vertical', va='center', ha='right', fontsize=20)
        for j, q2 in tqdm_notebook(enumerate(['mean', 'trend'])):
            # Keep ocean points only (RMASK>0) and convert just the field this panel shows.
            if j==0:
                xa = mean.where(RMASK>0)*24*3600            # kg/m2/s -> mm/day
            else:
                # kg/m2/s per time unit -> mm/day per 100 years.
                # NOTE(review): the *365 implies the regression slope is per day of
                # model time; confirm against ocn_field_regression.
                xa = trend.where(RMASK>0)*24*3600*365*100
            ax = f.add_subplot(3, 3, 3*i+2+j, projection=ccrs.LambertAzimuthalEqualArea(central_longitude=-30, central_latitude=20))
            minv, maxv, cmap = [mean_min, trend_min][j], [mean_max, trend_max][j], [mean_cmap, trend_cmap][j]
            if i==0:  ax.text(.5, 1.05, ['mean [mm/day]', 'trend [mm/day/100y]'][j], transform=ax.transAxes, fontsize=16, ha='center')
            ax.set_position([.05+j*.46,.01+(2-i)*.32,.45,.3])
            # NOTE(review): the extent is expressed in a default LambertAzimuthalEqualArea
            # (centred at 0E/0N), not in this map's -30E/20N projection; confirm intended.
            ax.set_extent((-6e6, 3.5e6, -8.5e6, 1e7), crs=ccrs.LambertAzimuthalEqualArea())
            cax, kw = matplotlib.colorbar.make_axes(ax,location='right',pad=0.01,shrink=0.9)
            im = ax.pcolormesh(xa.ULONG, xa.ULAT, xa, cmap=cmap,
                               vmin=minv, vmax=maxv, transform=ccrs.PlateCarree())
            cbar = f.colorbar(im, cax=cax, extend='both', **kw)
            cbar.ax.tick_params(labelsize=10)

            ax.add_feature(cartopy.feature.LAND, zorder=2, edgecolor='black', facecolor='grey')
            if i==2:
                # Net-flux row: overlay river centerlines (50m Natural Earth in light grey,
                # then cartopy's default RIVERS feature on top).
                rivers = cartopy.feature.NaturalEarthFeature(
                            category='physical', name='rivers_lake_centerlines',
                            scale='50m', facecolor='none', edgecolor='lightgrey')
                ax.add_feature(rivers, linewidth=.3, zorder=3)
                ax.add_feature(cartopy.feature.RIVERS, zorder=4)
            gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False)
            gl.ylocator = matplotlib.ticker.FixedLocator([-90, -60, -30, 0, 30, 60, 90])
    plt.savefig(f'{path_results}/SFWF/SFWF_map_{run}')