In [None]:
"""
Created on Fri Jul 09 09:39 2021

This script is to transform the mean profiles to 2D fields to accelerate the tuning process

@author: Clara Burgard
"""

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import basal_melt_param.useful_functions as uf
import basal_melt_param.melt_functions as meltf
from basal_melt_param.constants import *
import basal_melt_param.T_S_profile_functions as tspf
from scipy import stats
from dask import delayed

import distributed
import glob

In [None]:
# Show matplotlib figures in separate interactive Qt5 windows instead of inline
%matplotlib qt5

In [None]:
# Local dask cluster that executes the lazy xarray computations below:
# 18 workers with a 6 GB memory limit, spill directory /tmp, dashboard on port 8795.
client = distributed.Client(
    n_workers=18,
    memory_limit='6GB',
    local_directory='/tmp',
    dashboard_address=':8795',
)

In [None]:
# Display the dask client summary
client

READ IN DATA

In [None]:
# NEMO run to process; available runs: 'bf663', 'bi646'
nemo_run = 'bi646'

In [None]:
# Input directories for the selected NEMO run; each ends with '/' so file names
# can be appended directly.
interim_dir = '/bettik/burgardc/DATA/NN_PARAM/interim/'
inputpath_data = f'{interim_dir}SMITH_{nemo_run}/'
inputpath_profiles = f'{interim_dir}T_S_PROF/SMITH_{nemo_run}/'
inputpath_mask = f'{interim_dir}ANTARCTICA_IS_MASKS/SMITH_{nemo_run}/'

# Not used in this notebook (kept for reference):
#outputpath_simple = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/SIMPLE/nemo_5km_'+nemo_run+'/'
#inputpath_plumes = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/PLUMES/nemo_5km_'+nemo_run+'/'

# Output directory for plots (not used in the cells below)
plot_path = '/bettik/burgardc/PLOTS/NN_plots/Smith_formatting/'



In [None]:
# Crop the domain to make the computation cheaper (the ice-shelf mask file was
# already cropped when it was created). The same limits are applied to x and y.
map_lim = [-3_000_000, 3_000_000]

# Dask chunk size used along x and y
chunk_size = 500

Files for the parameterisation

In [None]:
# Corrected ice draft / bathymetry, cropped to map_lim and chunked along x/y.
# The time coordinate is relabelled 0..n-1 so that it can be selected with the
# integer year index used in the loop below.
file_other = xr.open_mfdataset(f'{inputpath_data}corrected_draft_bathy_isf.nc')
file_other_cut = (
    uf.cut_domain_stereo(file_other, map_lim, map_lim)
    .chunk(chunks={'x': chunk_size, 'y': chunk_size})
)
file_other_cut = file_other_cut.assign_coords(time=range(len(file_other_cut.time)))

# Ice-shelf draft concentration (disabled; `file_isf_conc` is needed by the
# weighted-mean test near the end of the notebook):
#file_conc = xr.open_mfdataset(inputpath_data+'isfdraft_conc_Ant_stereo.nc')
#file_conc_cut = uf.cut_domain_stereo(file_conc, map_lim, map_lim).chunk(chunks={'x': chunk_size, 'y': chunk_size})
#file_isf_conc = file_conc_cut['isfdraft_conc']


In [None]:
# For every year index, sample the mean T/S profile of each large ice shelf at the
# depth of interest (ice draft, capped at the deepest front entrance depth) and
# write the resulting 2D fields (theta_in, salinity_in, freezing_T,
# thermal_forcing, depth_of_int) to one netCDF file per year.

# Number of yearly mask/profile files (yy00 ... yy35)
n_years = 36
# Minimum `isf_area_here` for a shelf to be kept (same units as that field --
# presumably km2, TODO confirm)
min_isf_area = 2500

for tt in tqdm(range(n_years)):
    yy = f'{tt:02d}'  # zero-padded year index used in all file names

    # Ice-shelf masks/info: keep shelves with a finite deepest front entrance depth
    # and a large enough area
    file_isf_orig = xr.open_mfdataset(
        f'{inputpath_mask}nemo_5km_isf_masks_and_info_and_distance_oneFRIS_yy{yy}.nc')
    nonnan_Nisf = file_isf_orig['Nisf'].where(
        np.isfinite(file_isf_orig['front_bot_depth_max']), drop=True).astype(int)
    file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)
    large_isf = file_isf_nonnan['Nisf'].where(file_isf_nonnan['isf_area_here'] >= min_isf_area, drop=True)
    file_isf = file_isf_nonnan.sel(Nisf=large_isf).chunk(chunks={'x': chunk_size, 'y': chunk_size})

    # True where a grid cell belongs to the shelf labelled by the Nisf coordinate
    own_shelf = file_isf['ISF_mask'] == file_isf.Nisf

    ice_draft_pos = file_other_cut['corrected_isfdraft'].sel(time=tt)
    ice_draft_isf = ice_draft_pos.where(own_shelf)

    # Depth of interest: ice draft, capped at the deepest front entrance depth.
    # Outside a shelf ice_draft_isf is NaN, so the comparison is False and
    # front_bot_depth_max is filled in there -- those cells must be masked out
    # again before collapsing Nisf (see below).
    depth_of_int = ice_draft_isf.where(ice_draft_isf < file_isf['front_bot_depth_max'],
                                       file_isf['front_bot_depth_max']).chunk(chunks={'Nisf': 1})

    file_TS_orig = xr.open_mfdataset(
        f'{inputpath_profiles}T_S_mean_prof_corrected_km_contshelf_yy{yy}.nc', chunks={'Nisf': 1})
    file_TS = file_TS_orig.sel(Nisf=large_isf).chunk(chunks={'depth': 20})
    # Fill gaps along depth with the last valid value (forward fill)
    filled_TS = file_TS.ffill(dim='depth')

    # Sample each shelf's profile at the depth of interest, keep the value only on
    # that shelf's own cells and collapse Nisf into one 2D field.
    # NOTE: sum() over an all-NaN slice gives 0, so theta_in/salinity_in are 0
    # outside the selected shelves (pass min_count=1 to get NaN instead).
    T_isf = (filled_TS['theta_ocean'].interp({'depth': depth_of_int}).drop_vars('depth')
             .where(own_shelf).sum('Nisf')
             .to_dataset(name='theta_in'))
    T_isf['salinity_in'] = (filled_TS['salinity_ocean'].interp({'depth': depth_of_int}).drop_vars('depth')
                            .where(own_shelf).sum('Nisf'))

    # Collapse the depth of interest the same way. Masking with own_shelf (and not
    # ISF_mask > 1) is essential: otherwise the front_bot_depth_max of every other
    # shelf is added to each cell's depth.
    depth_of_int = depth_of_int.where(own_shelf).sum('Nisf')
    depth_of_int = depth_of_int.where(depth_of_int > 0)
    # freezing_temperature is called with the negated depth, as before
    T_isf['freezing_T'] = meltf.freezing_temperature(T_isf['salinity_in'], -1 * depth_of_int)
    T_isf['thermal_forcing'] = T_isf['theta_in'] - T_isf['freezing_T']
    T_isf['depth_of_int'] = depth_of_int

    T_isf.to_netcdf(f'{inputpath_profiles}T_S_2D_fields_isf_draft_oneFRIS_yy{yy}.nc', mode='w')
    del T_isf
    del depth_of_int

In [None]:
# Re-open the 2D fields written for year index `tt` (the last year after the loop)
T_isf = xr.open_dataset(f'{inputpath_profiles}T_S_2D_fields_isf_draft_oneFRIS_yy{tt:02d}.nc')

In [None]:
# Quick look at the sampled salinity, restricted to cells with ISF_mask > 1
shelf_cells = file_isf['ISF_mask'] > 1
T_isf2 = T_isf.where(shelf_cells)
T_isf2.salinity_in.plot()

In [None]:
# Per-shelf depth of interest: ice draft capped at the deepest front entrance depth,
# set to 0 outside each shelf's own cells. It is rebuilt from ice_draft_isf (instead
# of from itself), so the cell is idempotent and still works after the loop above
# has deleted `depth_of_int`.
depth_of_int = ice_draft_isf.where(ice_draft_isf < file_isf['front_bot_depth_max'],
                                   file_isf['front_bot_depth_max'])
depth_of_int = depth_of_int.where(file_isf['ISF_mask'] == file_isf.Nisf, 0).chunk({'Nisf': 1})


In [None]:
# Re-open the mean T/S profiles of year `tt` for the large shelves, forward-fill
# gaps along depth and sample temperature at the depth of interest.
file_TS_orig = xr.open_mfdataset(
    f'{inputpath_profiles}T_S_mean_prof_corrected_km_contshelf_yy{tt:02d}.nc', chunks={'Nisf': 1})
file_TS = file_TS_orig.sel(Nisf=large_isf).chunk(chunks={'depth': 20})
filled_TS = file_TS.ffill(dim='depth')
# drop_vars replaces the deprecated DataArray.drop
T_isf = filled_TS['theta_ocean'].interp({'depth': depth_of_int}).drop_vars('depth')

In [None]:
# Build the dataset of 2D fields: each cell takes the values of its own ice shelf.
# Theta is now projected onto 2D too -- previously only salinity was collapsed over
# Nisf, so theta_in (and thus thermal_forcing) kept a spurious Nisf dimension.
# Re-projecting an already-projected field leaves it unchanged, because each cell
# matches at most one Nisf.
own_shelf = file_isf['ISF_mask'] == file_isf.Nisf
T_isf = T_isf.where(own_shelf).sum('Nisf').to_dataset(name='theta_in')
T_isf['salinity_in'] = (
    filled_TS['salinity_ocean'].interp({'depth': depth_of_int}).drop_vars('depth')
    .where(own_shelf).sum('Nisf')
)
# Mask with own_shelf rather than ISF_mask > 1, so each cell only picks up its own
# shelf's depth whether or not depth_of_int was zeroed outside the shelves before.
depth_of_int = depth_of_int.where(own_shelf).sum('Nisf')
depth_of_int = depth_of_int.where(depth_of_int > 0)
T_isf['freezing_T'] = meltf.freezing_temperature(T_isf['salinity_in'], -1 * depth_of_int)
T_isf['thermal_forcing'] = T_isf['theta_in'] - T_isf['freezing_T']
T_isf['depth_of_int'] = depth_of_int


In [None]:
# Reference depth used to sample the profiles
# (options used in this notebook: 'isf_draft', 'GL_depth_cavity', 'GL_depth_lazero', 'bottom_front')
ref = 'isf_draft'

In [None]:
# Write the 2D fields for reference depth `ref` and year index `tt`
out_file = f'{inputpath_profiles}T_S_2D_fields_{ref}_oneFRIS_yy{tt:02d}.nc'
T_isf.to_netcdf(out_file, mode='w')

In [None]:
# Load the ice draft of the first year into memory for inspection.
# NOTE(review): this resets tt to 0 while file_isf / filled_TS / ice_draft_isf still
# come from the last loop year, so later cells that write files named with `tt`
# may mix years -- confirm before relying on their output.
tt = 0
ice_draft_pos = file_other_cut['corrected_isfdraft'].sel(time=tt).load()
ice_draft_pos

In [None]:
# Inspect the forward-filled temperature profiles
filled_TS['theta_ocean']

In [None]:
        # Sample each shelf's forward-filled temperature profile at the depth of interest
        # (drop_vars replaces the deprecated DataArray.drop; debug print removed)
        T_isf = filled_TS['theta_ocean'].interp({'depth': depth_of_int}).drop_vars('depth')

In [None]:
# Inspect the current depth of interest
depth_of_int

In [None]:
# Inspect the forward-filled temperature profiles again (same expression as an earlier cell)
filled_TS['theta_ocean']

In [None]:
        # Keep each shelf's temperature only on its own cells, then collapse Nisf into
        # one 2D field (cells outside the selected shelves become 0 through the sum)
        T_isf = (
            T_isf
            .where(file_isf['ISF_mask'] == file_isf.Nisf)
            .sum('Nisf')
        )  # optionally: .where(depth_of_int>0)

In [None]:
        # Build the 2D-field dataset for reference depth `ref` and write it for year `tt`.
        # Expects T_isf to already be projected onto 2D (previous cell).
        own_shelf = file_isf['ISF_mask'] == file_isf.Nisf
        T_isf = T_isf.to_dataset(name='theta_in')
        T_isf['salinity_in'] = (
            filled_TS['salinity_ocean'].interp({'depth': depth_of_int}).drop_vars('depth')
            .where(own_shelf).sum('Nisf')
        )
        # Mask with own_shelf rather than ISF_mask > 1: this picks only each cell's own
        # shelf, and it also works when depth_of_int is already 2D (no Nisf dimension),
        # where the old mask made sum('Nisf') fail.
        depth_of_int = depth_of_int.where(own_shelf).sum('Nisf')
        depth_of_int = depth_of_int.where(depth_of_int > 0)
        T_isf['freezing_T'] = meltf.freezing_temperature(T_isf['salinity_in'], -1 * depth_of_int)
        T_isf['thermal_forcing'] = T_isf['theta_in'] - T_isf['freezing_T']
        T_isf['depth_of_int'] = depth_of_int

        T_isf.to_netcdf(f'{inputpath_profiles}T_S_2D_fields_{ref}_oneFRIS_yy{tt:02d}.nc', mode='w')

First 2D fields of thermal forcing

In [None]:
# Ice draft kept only on each shelf's own cells (NaN elsewhere), one Nisf slice per shelf.
# NOTE(review): `param_var_of_int` is not defined anywhere in this notebook (its loading
# cell seems to be missing); presumably it carries the same ISF_mask field as `file_isf`
# -- confirm before running.
ice_draft_isf = ice_draft_pos.where(param_var_of_int['ISF_mask'] == file_isf.Nisf).chunk(chunks={'Nisf': 1})
#plume_isf = plume_charac.where(param_var_of_int['ISF_mask'] == file_isf.Nisf).chunk(chunks={'Nisf': 1})

In [None]:
# Fill gaps along depth with the last valid value (forward fill)
filled_TS = file_TS.ffill(dim='depth')

# NOTE(review): `param_var_of_int` and `plume_charac` are not defined in this notebook
# (their loading cells appear to have been removed) -- confirm before running.
# Other reference depths: 'GL_depth_cavity', 'GL_depth_lazero', 'bottom_front'.
own_shelf = file_isf['ISF_mask'] == file_isf.Nisf

for ref in ['isf_draft']:

    print(ref)
    if ref == 'bottom_front':
        # DOES NOT WORK YET
        # Deepest front entrance depth of each shelf, merged into one field over all shelves
        for n, kisf in enumerate(file_isf.Nisf):
            depth_of_int_kisf = param_var_of_int['front_bot_depth_max'].sel(Nisf=kisf).where(file_isf['ISF_mask'] == kisf)
            if n == 0:
                depth_of_int = depth_of_int_kisf.squeeze().drop_vars('Nisf')
            else:
                depth_of_int = depth_of_int.combine_first(depth_of_int_kisf).squeeze().drop_vars('Nisf')

    elif ref == 'isf_draft':
        # Ice draft depth, capped at the deepest front entrance depth; 0 outside each shelf
        depth_of_int = ice_draft_isf.where(ice_draft_isf < param_var_of_int['front_bot_depth_max'],
                                           param_var_of_int['front_bot_depth_max']).chunk(chunks={'Nisf': 1})
        depth_of_int = depth_of_int.where(own_shelf, 0).chunk({'Nisf': 1})

    elif ref == 'GL_depth_cavity':
        # Deepest grounding-line point ('simple' option); 0 outside each shelf
        depth_of_int = -1 * plume_charac['zGL'].sel(option='simple').where(own_shelf, 0).chunk({'Nisf': 1})

    elif ref == 'GL_depth_lazero':
        # Grounding-line depth from the 'lazero' option; 0 outside each shelf
        depth_of_int = -1 * plume_charac['zGL'].sel(option='lazero').where(own_shelf, 0).chunk({'Nisf': 1})

    # Sample each shelf's profile at the depth of interest and collapse Nisf into 2D
    # (sum() over all-NaN gives 0 outside the selected shelves).
    T_isf = (filled_TS['theta_ocean'].interp({'depth': depth_of_int}).drop_vars('depth')
             .where(own_shelf).sum('Nisf')
             .to_dataset(name='theta_in'))
    T_isf['salinity_in'] = (filled_TS['salinity_ocean'].interp({'depth': depth_of_int}).drop_vars('depth')
                            .where(own_shelf).sum('Nisf'))
    # own_shelf (rather than ISF_mask > 1) gives the same result for the zero-filled
    # refs and also accepts a depth_of_int without an Nisf dimension ('bottom_front').
    depth_of_int = depth_of_int.where(own_shelf).sum('Nisf')
    depth_of_int = depth_of_int.where(depth_of_int > 0)
    T_isf['freezing_T'] = meltf.freezing_temperature(T_isf['salinity_in'], -1 * depth_of_int)
    T_isf['thermal_forcing'] = T_isf['theta_in'] - T_isf['freezing_T']
    T_isf['depth_of_int'] = depth_of_int

    T_isf.to_netcdf(inputpath_profiles + 'T_S_2D_fields_' + ref + '_oneFRIS.nc', mode='w')
    del T_isf
    del depth_of_int

In [None]:
# Re-open 2D fields for the 'isf_draft' reference, chunked along x/y.
# NOTE(review): this file name has no '_oneFRIS' / '_yyNN' suffix, unlike every file
# written in this notebook -- it may be output of an older version; confirm.
T_S_2D_isfdraft = xr.open_mfdataset(
    f'{inputpath_profiles}T_S_2D_fields_isf_draft.nc',
    chunks={'x': chunk_size, 'y': chunk_size},
)

In [None]:
# Inspect the re-opened 2D fields
T_S_2D_isfdraft

In [None]:
# Per-shelf mean of thermal forcing over x/y, weighted by the ice-shelf concentration
# (argument order data, dims, weights presumed from the call -- confirm in uf.weighted_mean).
# NOTE(review): `file_isf_conc` only exists in commented-out lines near the top of the
# notebook (file_conc / file_isf_conc); uncomment them before running this cell.
own_shelf_mask = file_isf['ISF_mask'] == file_isf.Nisf
tf_per_shelf = T_S_2D_isfdraft['thermal_forcing'].where(own_shelf_mask).chunk({'Nisf': 1})
conc_per_shelf = file_isf_conc.where(own_shelf_mask).chunk({'Nisf': 1})
weighted_mean_test = uf.weighted_mean(tf_per_shelf, ['x', 'y'], conc_per_shelf)

In [None]:
# DOES NOT WORK YET!
weighted_mean_test.isel(time=0,profile_domain=0).where(file_isf['ISF_mask'] == file_isf.Nisf).chunk({'Nisf': 1}).sum('Nisf').plot()

