# Model Intercomparison: melt patterns
## Figs. 2 (idealized) and 4 (real)

In [None]:
import sys
import numpy as np
import xarray as xr
import cartopy.crs as ccrs
import geopandas
import matplotlib
import matplotlib.pyplot as plt

sys.path.append('..')

from forcing import Forcing
from real_geometry import RealGeometry, glaciers
from ideal_geometry import IdealGeometry
from Plume import PlumeModel
from PICO import PicoModel
from PICOP import PicoPlumeModel
from Simple import SimpleModels

%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
%load_ext autoreload
%autoreload 2

## Experiment 1: melt rates for reference setups

In [None]:
def plot_mask(ax, ds):
    """Draw the grid's mask classes onto ``ax`` as a background layer.

    Mask value 1 is recoloured as value 2, and cells with mask value 3 (the
    cells the melt maps in this notebook are restricted to) are blanked out
    (NaN), so the field plotted on those cells stays visible.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    ds : dataset providing ``x``, ``y`` and an integer-coded ``mask``;
        coordinates are divided by 1e3 for plotting (metres -> km).
    """
    mask = ds.mask
    # merge class 1 into class 2, then hide class 3
    shown = xr.where(mask == 1, 2, mask).where(mask != 3)
    style = {
        'norm': matplotlib.colors.Normalize(vmin=-4, vmax=3),
        'cmap': plt.get_cmap('ocean'),
        'shading': 'nearest',
    }
    km = 1e3
    ax.pcolormesh(ds.x / km, ds.y / km, shown, **style)

In [None]:
models = 'Simple Plume PICO PICOP Layer'.split()  # melt parameterizations compared below

In [None]:
# Logarithmic colour scale for melt rates; non-positive values are floored at
# 1e-9 below so they remain representable on a LogNorm.
kw = dict(shading='auto', norm=matplotlib.colors.LogNorm(vmin=.1, vmax=10**2.5), cmap='inferno')
kw_grid = {'width_ratios': [1, 1, 1, .01]}  # three map columns + a thin column for colorbars
kw_subplots = dict(constrained_layout=True)
f, axs = plt.subplots(3, 4, figsize=(8, 8), **kw_subplots, gridspec_kw=kw_grid)

# Panels a)-f) fill the first two rows, left to right.
titles = ['a) Geometry', r'b) M$_+$', 'c) Plume', 'd) PICO', 'e) PICOP', 'f) Layer']
for i, title in enumerate(titles):
    row, col = divmod(i, 3)
    axs[row, col].set_title(title, loc='left')

# Idealized 'Ocean1' geometry with the ISOMIP 'WARM' forcing.
ds = IdealGeometry('Ocean1').create()
ds = Forcing(ds).isomip('WARM')
depths = axs[0, 0].pcolormesh(ds.x/1e3, ds.y/1e3, ds.draft, shading='nearest')

models = ['Simple', 'Plume', 'PICO', 'PICOP', 'Layer']
for j, model in enumerate(models):
    # model j goes into panel j+1 (panel 0 shows the geometry)
    ax = axs[(j + 1) // 3, (j + 1) % 3]
    if model == 'Simple':
        results = SimpleModels(ds).compute()
        x, y, m, melt = results.x/1e3, results.y/1e3, results.Mp, results.Mp.mean().values
    elif model == 'Plume':
        results = PlumeModel(ds).compute_plume()
        x, y = results.x/1e3, results.y/1e3
        m = xr.where(results.m > 0, results.m, 1e-9).where(results.mask == 3)
        # NOTE(review): this mean is taken over all cells, whereas the Layer
        # branch averages over mask==3 only -- confirm this is intended.
        melt = results.m.mean().values
    elif model == 'PICO':
        _, results = PicoModel(ds).compute_pico()
        x, y, m, melt = results.x/1e3, results.y/1e3, results.melt, results.melt.mean().values
    elif model == 'PICOP':
        _, _, results = PicoPlumeModel(ds).compute_picop()
        x, y = results.x/1e3, results.y/1e3
        m = xr.where(results.m > 0, results.m, 1e-9).where(results.mask == 3)
        melt = results.m.mean().values  # NOTE(review): unmasked mean, see Plume branch
    elif model == 'Layer':
        # Precomputed Layer-model output; load_dataset reads it into memory
        # and closes the file instead of leaving a lazy handle open.
        dl = xr.load_dataset('../../results/Layer/Layer_Ocean1_5_ISOMIP_0.nc')
        shelf_melt = dl.melt.where(dl.mask == 3)
        x, y, m, melt = dl.x/1e3, dl.y/1e3, shelf_melt, shelf_melt.mean().values
    else:
        # without this, the previous iteration's x/y/m would be re-plotted silently
        raise ValueError(f'unknown model: {model!r}')

    rates = ax.pcolormesh(x, y, m, **kw)
    ax.text(.95, .5, f'{melt:.2f} m/yr', transform=ax.transAxes, ha='right', va='center', color='w', fontsize=12)

axs[1, 1].set_xlabel('x  [km]')
colorbars = [(depths, 'shelf base depth [m]'), (rates, 'melt rate [m/yr]')]
for i, (mappable, cb_label) in enumerate(colorbars):
    plt.colorbar(mappable, ax=axs[i, -1], label=cb_label)
    axs[i, 0].set_ylabel('y  [km]')
    for j in range(3):
        if i == 0:
            # the narrow 4th column only anchors colorbars/legends: hide its axes
            axs[j, -1].axis('off')
            axs[i, j].set_xticklabels([])
        plot_mask(ax=axs[i, j], ds=ds)
        axs[i, j].set_xlim((450, None))  # crop maps to x >= 450 km
        if j > 0:
            axs[i, j].set_yticklabels([])
            
""" sensitivities """
ds_MIP = xr.open_dataset(f'../../results/MIP/ISOMIP_melt_rates.nc')  # created in MIP_sensitivity.ipynb
lm = [r'$M_{+}$','Plume','PICO','PICOP','Layer']

# axs[2,1].set_yticklabels([])
axs[2,0].axhline(0, c='k', lw=.5)
axs[2,0].plot(ds_MIP.T_bottom,  7/2*(ds_MIP.T_bottom+2), c='lightgrey', label='7 m/yr/K')
axs[2,0].plot(ds_MIP.T_bottom, 16/2*(ds_MIP.T_bottom+2), c='darkgrey', label='16 m/yr/K')
lb = axs[2,2].axhline( 7/2, c='lightgrey', label=r'$M_{Lev}^-$')
ub = axs[2,2].axhline(16/2, c='darkgrey', label=r'$M_{Lev}^+$')
for i in range(3):
    cold = axs[2,i].axvspan(-1.95,-1.85, alpha=.1, color='blue', label='COLD')
    warm = axs[2,i].axvspan(.95,1.05, alpha=.1, color='red', label='WARM')
    axs[2,i].set_title(['g) mean melt', 'h) ground. zone melt', 'i) melt sensitivity'][i], loc='left')
handles = [lb,ub]
for j, model in enumerate(['Mp']+models[1:]):
    ds = ds_MIP.sel({'geometry':'Ocean1', 'model':model}).copy()
    axs[2,0].plot(ds.T_bottom, ds.melt_avg)
    axs[2,1].plot(ds.T_bottom, ds.melt_grl)
    s, = axs[2,2].plot(ds.T_bottom, -np.gradient(ds.melt_avg, .29), label=lm[j])  # step size in K
    handles.append(s)
axs[2,0].legend(handles=[cold,warm], loc='upper center', frameon=False)
axs[2,3].legend(handles=handles, handlelength=1.5, bbox_to_anchor=(-.2, 0.5), loc='center left')
axs[2,1].set_xlabel(r'bottom potential temperature [$^\circ\!$C]')
axs[2,0].set_ylabel('melt rate [m/yr]')
axs[2,2].yaxis.set_label_position("right")
axs[2,2].yaxis.tick_right()
axs[2,2].set_ylabel('sensitivity [m/yr/K]')
axs[2,2].set_ylim((0, None))
# f.align_ylabels()

plt.savefig('../../figures/ISOMIP', dpi=300)