# Melt pattern validation for Pine Island and Crosson-Dotson (Figs. 4/5)

In [None]:
import sys
import numpy as np
import xarray as xr
import pyproj
import cmocean as cmo
import warnings
import matplotlib as mpl
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Make the project modules imported below (PICO, Plume, PICOP, Simple, forcing,
# plotfunctions, real_geometry) importable; presumably they live in the parent
# directory of this notebook — verify against the repository layout.
sys.path.append("..")
# NOTE(review): this silences ALL Python warnings for the whole session,
# including any RuntimeWarnings raised by the NumPy/xarray reductions below;
# consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")

# Likewise ignore every NumPy floating-point error category (divide, over,
# under, invalid) for the rest of the session.
np.seterr(all='ignore')
%matplotlib inline
# Disable bbox cropping of inline figures; presumably so the rcParams subplot
# margins set in the next cell are displayed as configured.
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
# Automatically reload edited project modules before each cell executes.
%load_ext autoreload
%autoreload 2

from PICO import PicoModel
from Plume import PlumeModel
from PICOP import PicoPlumeModel
from Simple import SimpleModels
from forcing import Forcing
from plotfunctions import *
from real_geometry import RealGeometry

In [None]:
# Narrow outer margins and small inter-panel gaps for the 3x3 map grids
# created in the figure cell below.
mpl.rcParams.update({
    'figure.subplot.bottom': .01,
    'figure.subplot.top':    .94,
    'figure.subplot.left':   .01,
    'figure.subplot.right':  .99,
    'figure.subplot.wspace': .05,
    'figure.subplot.hspace': .15,
})

In [None]:
# MITgcm averaging windows. The "_w" window feeds melt_w and the "_c" window
# feeds melt_c; the figure cell combines them with the weights n_w / n_c.
timep_w, timep_c = (
    slice("2006-1-1", "2012-1-1"),
    slice("2013-1-1", "2017-1-1"),
)
# Sub-directory of ../../results/Layer/ that holds the Layer-model output files.
date = '2021_02_25/'

In [None]:
# auxiliary functions
def printvals(ax, melt):
    """Write the NaN-ignoring mean of `melt` onto the map axes `ax`.

    The value is formatted with one decimal, drawn in white and horizontally
    centred at the module-level position (plon, plat), which is interpreted
    in PlateCarree lon/lat coordinates.
    """
    label = '{:.1f}'.format(np.nanmean(melt))
    lonlat = ccrs.PlateCarree()
    ax.text(plon, plat, label, transform=lonlat, c='w', ha='center')
    
def createax(n, title):
    """Add map panel `n` of a 3x3 grid to the module-level figure and return it.

    Reads the module-level `fig`, `proj` (map projection of the panel) and
    `axex` (extent passed to `set_extent` in PlateCarree lon/lat). Calls
    `makebackground` (from plotfunctions) on the new axes and gives it a
    left-aligned `title`.
    """
    panel = fig.add_subplot(3, 3, n, projection=proj)
    makebackground(panel)
    extent_crs = ccrs.PlateCarree()
    panel.set_extent(axex, crs=extent_crs)
    panel.set_title(title, loc='left')
    return panel

def weightmelt(melt_w, melt_c, weights=None):
    """Return the weighted average of two melt fields.

    Parameters
    ----------
    melt_w, melt_c : scalar, numpy array or xarray object
        The two melt fields to combine (the "_w" and "_c" variants computed in
        the figure cell). NaNs propagate element-wise.
    weights : (float, float), optional
        Relative weights ``(w_w, w_c)`` for ``melt_w`` and ``melt_c``. When
        omitted, the module-level ``n_w`` and ``n_c`` (assigned per geometry in
        the figure cell) are used, which is the original behaviour.

    Returns
    -------
    ``(w_w*melt_w + w_c*melt_c) / (w_w + w_c)``

    Raises
    ------
    ValueError
        If the two weights sum to zero (the average is undefined).
    """
    # Fall back to the notebook globals only when no explicit weights are given,
    # so the function is also usable outside the main plotting loop.
    w_w, w_c = (n_w, n_c) if weights is None else weights
    total = w_w + w_c
    if total == 0:
        raise ValueError("weightmelt: weights must not sum to zero")
    return (w_w*melt_w + w_c*melt_c)/total

In [None]:
%%time
for i, geom in enumerate(['PineIsland','CrossDots']):
    if geom=='PineIsland':
        proj = ccrs.SouthPolarStereo(true_scale_latitude=-75,central_longitude=245-360)
        axex = [257.7,260.3,-75.4,-74.3]
        y1,y2 = -75.45,-74.05     # Boundaries MITgcm
        x1,x2 = 257,262
        plon,plat = 258.1,-74.8   # Locations for text
        Tdeep, ztclw, ztclc = 1.2, -400, -600
        spinup = 180
        mobs = '13-15'            # Literature value observations
        n_w, n_c = 4, 4
        fn_obs = '../../data/davidshean/PineIsland_init.nc'
    elif geom=='CrossDots':
        proj = ccrs.SouthPolarStereo(true_scale_latitude=-75,central_longitude=245-360)
        axex = [245.3,250.8,-75.3,-74.15]
        y1,y2 = -75.45,-74.05
        x1,x2 = 245.3,251
        plon,plat = 250.5,-74.4
        Tdeep, ztclw, ztclc = 0.7, -400, -700
        spinup = 180
        mobs = '5.5-7'
        n_w, n_c = 2, 5
        fn_obs = '../../data/gourmelen/CrossonDotson.nc'
        
    fig = plt.figure(figsize=(8,[11,7][i]))

    """Geometry"""
    ax = createax(1,'a) Geometry')
    ds = xr.open_dataset('../../data/BedMachine/BedMachineAntarctica_2020-07-15_v02.nc')
    ds = ds.isel(x=slice(3000,4000),y=slice(6500,9000))
    ds = add_lonlat(ds)
    draft = xr.where(ds.mask==3,(ds.surface-ds.thickness).astype('float64'),np.nan)
    IM = plt.pcolormesh(ds.lon,ds.lat,draft,vmin=-1000,vmax=0,cmap=plt.get_cmap('cmo.rain_r'),transform=ccrs.PlateCarree(),shading='auto')

    """Colorbar"""
    ax3 = fig.add_subplot(333)
    ax3.set_visible(False)
    axins = inset_axes(ax3,width="5%",height="100%",loc='lower left',bbox_to_anchor=(.1,0, 1, 1),bbox_transform=ax3.transAxes,borderpad=0)
    cbar = plt.colorbar(IM, cax=axins,extend='min')
    cbar.set_label('Draft depth [m]', labelpad=0)

    """Observations"""
    ax = createax(2,'b) Observations')
    ds = xr.open_dataset(fn_obs)
    ds = add_lonlat(ds)
    melt = ds.Band1
    ax.text(plon,plat,mobs,transform=ccrs.PlateCarree(),c='w',ha='center')
    IM = plotmelt(ax,ds.lon,ds.lat,melt)

    """Colorbar"""
    axins = inset_axes(ax3,width="5%",height="100%",loc='lower left',bbox_to_anchor=(.6,0, 1, 1),bbox_transform=ax3.transAxes,borderpad=0)
    cbar = plt.colorbar(IM, cax=axins,extend='both')
    cbar.set_ticks([1,10,100])
    cbar.set_ticklabels([1,10,100])
    cbar.set_label('Melt rate [m/yr]', labelpad=0)
    
    """ Models """
    dp = RealGeometry(geom,n=5).create()
    dp = add_lonlat(dp)
    dw = Forcing(dp).tanh2(ztcl=ztclw, Tdeep=Tdeep)
    dc = Forcing(dp).tanh2(ztcl=ztclc, Tdeep=Tdeep)
    
    for j, model in enumerate(['Simple','Plume', 'PICO', 'PICOP', 'Layer']):
        title = [r'c) M$_+$','d) Plume','e) PICO','f) PICOP','g) Layer'][j]
        ax = createax(4+j,title)
        if model=='Simple':
            ds = SimpleModels(dw).compute()
            melt_w = ds['Mp']
            melt_c = SimpleModels(dc).compute()['Mp']
        elif model=='Plume':
            ds = PlumeModel(dw).compute_plume()
            melt_w = ds['m']
            melt_c = PlumeModel(dc).compute_plume()['m']
        elif model=='PICO':
            ds = PicoModel(ds=dw).compute_pico()[1]
            melt_w = ds['melt']
            melt_c = PicoModel(ds=dc).compute_pico()[1]['melt']
        elif model=='PICOP':
            ds = PicoPlumeModel(ds=dw).compute_picop()[2]
            melt_w = ds['m']
            melt_c = PicoPlumeModel(ds=dc).compute_picop()[2]['m']
        elif model=='Layer':
            ds = xr.open_dataset(f'../../results/Layer/{date}{geom}_tanh2_Tdeep{Tdeep}_ztcl{ztclw}_{spinup:.3f}.nc')
            ds = add_lonlat(ds)
            melt_w = np.where(ds.mask==3,ds.melt,np.nan)
            ds = xr.open_dataset(f'../../results/Layer/{date}{geom}_tanh2_Tdeep{Tdeep}_ztcl{ztclc}_{spinup:.3f}.nc')
            ds = add_lonlat(ds)
            melt_c = np.where(ds.mask==3,ds.melt,np.nan)
        
        melt = weightmelt(melt_w,melt_c)
        printvals(ax,melt)
        melt = np.where(melt<1,1,melt)
        plotmelt(ax,ds.lon,ds.lat,melt)

    """MITgcm"""
    ax = createax(9,'h) MITgcm')
    ds = xr.open_dataset('../../data/paulholland/melt.nc')
    ds = ds.sel(LONGITUDE=slice(x1,x2),LATITUDE=slice(y1,y2),TIME=timep_w)
    ds = ds.mean(dim='TIME')
    melt_w = xr.where(ds.melt==0,np.nan,ds.melt)
    ds = xr.open_dataset('../../data/paulholland/melt.nc')
    ds = ds.sel(LONGITUDE=slice(x1,x2),LATITUDE=slice(y1,y2),TIME=timep_c)
    ds = ds.mean(dim='TIME')    
    lon = ds.LONGITUDE
    lat = ds.LATITUDE-.05
    melt_c = xr.where(ds.melt==0,np.nan,ds.melt)
    melt = weightmelt(melt_w,melt_c)
    printvals(ax,melt)
    plotmelt(ax,lon,lat,melt)

    """Save figure"""
    plt.savefig(f"../../figures/Validation_{['PIG','CD'][i]}.png",dpi=300)
    plt.show()