In [None]:
"""
Created on Wed Apr 13 14:17 2022

Prepare vertical profiles of T and S (variables `theta_ocean` and
`salinity_ocean`) for each ice shelf of one NEMO run:

1. read the ice-shelf masks and the mean T/S profiles,
2. keep only ice shelves with a finite `front_bot_depth_max` and a large enough
   `isf_area_here`,
3. keep the first depth level plus the levels with depth >= 100, and
   forward-fill remaining gaps along depth,
4. flatten each profile into one variable per depth level
   (T_001, T_002, ..., S_001, S_002, ...),
5. write the flattened profiles onto each ice shelf's mask, one netCDF file per
   ice shelf.

Author: @claraburgard
"""

In [None]:
import numpy as np
import xarray as xr
from tqdm.notebook import trange, tqdm
import basal_melt_neural_networks.data_formatting as dfmt

from dask import delayed

import distributed
import glob

## Read in the data

In [None]:
# Identifier of the NEMO simulation processed in this notebook; it is embedded
# in every input/output directory name below.
nemo_run = "OPM031"

In [None]:
# Root directory of all interim products. Every path below except plot_path is
# derived from it, so moving the data only requires changing this one line.
interim_path = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/'

# Per-run directories (the run name is part of each directory name)
inputpath_data = interim_path + 'NEMO_eORCA025.L121_' + nemo_run + '_ANT_STEREO/'
inputpath_mask = interim_path + 'ANTARCTICA_IS_MASKS/nemo_5km_' + nemo_run + '/'
inputpath_profiles = interim_path + 'T_S_PROF/nemo_5km_' + nemo_run + '/'
outputpath_simple = interim_path + 'SIMPLE/nemo_5km_' + nemo_run + '/'
inputpath_plumes = interim_path + 'PLUMES/nemo_5km_' + nemo_run + '/'

# Run-independent locations
outputpath = interim_path
plot_path = '/bettik/burgardc/PLOTS/first-look/'

In [None]:
# Chunk edge length (grid points) used later when chunking the datasets along x and y for dask.
chunk_size = 300

# Domain limits meant to make the domain a little smaller so the computation is
# even more efficient (the isf file was already cropped when it was created).
# Presumably projected coordinates in metres - TODO confirm. Not used in the
# cells shown here.
map_lim = [-3_000_000, 3_000_000]

In [None]:
# Minimum value of `isf_area_here` for an ice shelf to be kept
# (units are those of the mask file - TODO confirm, presumably km^2).
min_isf_area = 2500

file_isf_orig = xr.open_dataset(inputpath_mask+'nemo_5km_isf_masks_and_info_and_distance_new.nc')

# Keep ice shelves whose front_bot_depth_max is finite. `where(..., drop=True)`
# masks entries with NaN before dropping them, which promotes the integer Nisf
# labels to float - hence the cast back to int.
nonnan_Nisf = file_isf_orig['Nisf'].where(np.isfinite(file_isf_orig['front_bot_depth_max']), drop=True).astype(int)
file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)

# Of those, keep the ice shelves that are large enough. Cast to int for the same
# reason as above, so large_isf holds integer labels just like nonnan_Nisf.
large_isf = file_isf_nonnan['Nisf'].where(file_isf_nonnan['isf_area_here'] >= min_isf_area, drop=True).astype(int)
file_isf = file_isf_nonnan.sel(Nisf=large_isf)

In [None]:
# Mean T/S profiles per ice shelf. Input files used in earlier versions
# (opened with xr.open_mfdataset(..., chunks={'Nisf': 1})):
#   T_S_profiles_per_iceshelf_1980-2018.nc
#   T_S_mean_prof_km_1980-2018.nc
#   T_S_mean_prof_km_contshelf_1980-2018.nc
file_TS_orig = xr.open_dataset(inputpath_profiles + 'T_S_mean_prof_corrected_km_contshelf_and_offshore_1980-2018.nc')

# Restrict to the retained ice shelves.
file_TS = file_TS_orig.sel(Nisf=large_isf)

# Pick profile_domain = 50 (meaning not visible here; the file name suggests a
# distance in km - TODO confirm), then discard the now-scalar coordinate.
file_TS_dom = (
    file_TS
    .sel(profile_domain=50)
    .drop_vars('profile_domain')
)

In [None]:
# Remove every depth level at which all values are NaN.
# Earlier alternative, cutting at the deepest ice-front bottom instead (kept in
# case offshore profiles are used at some point):
#   file_TS_dom.where(file_TS_dom.depth < file_isf['front_bot_depth_max'].max(), drop=True)
file_TS_cut_bot = file_TS_dom.dropna(dim='depth', how='all')

In [None]:
# Keep the first (presumably shallowest) depth level plus every level with
# depth >= min_deep_level (presumably metres, positive downward - TODO confirm).
# The mask uses file_TS_cut_bot's own depth coordinate, so it does not depend on
# implicit alignment with file_TS_dom.
min_deep_level = 100
deep_levels = file_TS_cut_bot.where(file_TS_cut_bot.depth >= min_deep_level, drop=True)
if file_TS_cut_bot.depth[0] >= min_deep_level:
    # The first level is already in deep_levels; concatenating it again would
    # create a duplicate depth label.
    file_TS_cut_top_bot = deep_levels
else:
    file_TS_cut_top_bot = xr.concat([file_TS_cut_bot.isel(depth=0), deep_levels], dim='depth')

In [None]:
# Fill the remaining NaN gaps by carrying the last valid value forward along
# the depth dimension (toward higher depth indices).
filled_TS = file_TS_cut_top_bot.ffill(dim='depth')

In [None]:
# Flatten the depth dimension: every depth level becomes its own variable,
# T_001, T_002, ... for theta_ocean and S_001, S_002, ... for salinity_ocean
# (1-based, zero-padded to 3 digits). The result is written to disk and
# reloaded by the mapping cells below.
n_levels = filled_TS.sizes['depth']
T_vars = {}
S_vars = {}
for dd in tqdm(range(n_levels)):
    level_tag = str(dd+1).zfill(3)
    # Select by position rather than by depth value, which is robust to float or
    # duplicate depth labels; drop=True discards the now-scalar depth coordinate.
    T_vars['T_'+level_tag] = filled_TS['theta_ocean'].isel(depth=dd, drop=True)
    S_vars['S_'+level_tag] = filled_TS['salinity_ocean'].isel(depth=dd, drop=True)

# Build the dataset once (T variables first, then S) instead of re-merging a
# growing dataset at every level.
T_S_prof_flat = xr.Dataset({**T_vars, **S_vars})
T_S_prof_flat.to_netcdf(inputpath_profiles+'flattened_T_S_profiles_not_mapped_yet.nc')

## Start Dask

In [None]:
# Local dask cluster: 12 workers with a 4 GB memory limit each, /tmp as the
# local scratch directory, and the dashboard served on port 8795.
client = distributed.Client(
    n_workers=12,
    memory_limit='4GB',
    local_directory='/tmp',
    dashboard_address=':8795',
)

In [None]:
# Chunk the ice-shelf dataset for dask: chunk_size x chunk_size tiles in (x, y)
# and one ice shelf per chunk along Nisf.
file_isf = file_isf.chunk(chunks=dict(x=chunk_size, y=chunk_size, Nisf=1))

In [None]:
# Quick look at the gap-filled profiles (displayed as the cell output).
filled_TS

In [None]:
# Names of the flattened per-level variables (T_001, ..., S_001, ...), using the
# same 1-based, 3-digit naming as when the flattened profiles were written.
# A plain comprehension replaces the progress bar, which wrapped a trivial loop
# over an enumerate() with no known length.
n_levels = filled_TS.sizes['depth']
T_list = ['T_' + str(dd + 1).zfill(3) for dd in range(n_levels)]
S_list = ['S_' + str(dd + 1).zfill(3) for dd in range(n_levels)]

In [None]:
T_S_prof_flat = xr.open_dataset(inputpath_profiles+'flattened_T_S_profiles_not_mapped_yet.nc')
isf_mask = file_isf['ISF_mask']
T_prof_flat_map = isf_mask
S_prof_flat_map = isf_mask

# For each ice shelf:
#   1. cut the mask down to the cells of that ice shelf (drop=True trims the
#      area outside them),
#   2. replace the cells equal to the ice-shelf label with its flattened T and
#      S profiles,
#   3. write one file per ice shelf.
for kisf in tqdm(file_isf.Nisf):
    in_isf = isf_mask == kisf

    T_prof_map_isf = T_prof_flat_map.where(in_isf, drop=True)
    T_prof_isf = T_prof_map_isf.where(T_prof_map_isf != kisf, T_S_prof_flat[T_list].sel(Nisf=[kisf]))

    S_prof_map_isf = S_prof_flat_map.where(in_isf, drop=True)
    # Mask the S map with its own values so S does not depend on the T map.
    S_prof_isf = S_prof_map_isf.where(S_prof_map_isf != kisf, T_S_prof_flat[S_list].sel(Nisf=[kisf]))

    T_S_prof_isf = xr.merge([T_prof_isf, S_prof_isf])
    # int() yields a clean zero-padded label (e.g. '010') even if the Nisf
    # labels are stored as floats, where str() would give '10.0'.
    T_S_prof_isf.to_netcdf(inputpath_profiles + 'flattened_T_S_profiles_isf' + str(int(kisf)).zfill(3) + '.nc')

In [None]:
# Display the mapped profiles of the last ice shelf processed by the loop above.
T_S_prof_isf

**Note:** all of the following alternative solutions crash.

In [None]:
# Alternative approach (crashes, per the note above; cause not recorded -
# presumably memory, TODO confirm): build a single full-domain map by replacing,
# ice shelf by ice shelf, the mask cells of that ice shelf with its flattened
# T and S profiles, then write everything to one file.
T_S_prof_flat = xr.open_dataset(inputpath_profiles+'flattened_T_S_profiles_not_mapped_yet.nc')
# start from the mask, keeping only values > 1 (all other cells become NaN)
T_prof_flat_map = file_isf['ISF_mask'].where(file_isf['ISF_mask']>1).copy()
S_prof_flat_map = file_isf['ISF_mask'].where(file_isf['ISF_mask']>1).copy()

for kisf in tqdm(file_isf.Nisf):
    # cells of ice shelf kisf get its profiles; all other cells keep their current value
    T_prof_flat_map = T_prof_flat_map.where(file_isf['ISF_mask']!=kisf, T_S_prof_flat[T_list].sel(Nisf=kisf).drop('Nisf'))
    S_prof_flat_map = S_prof_flat_map.where(file_isf['ISF_mask']!=kisf, T_S_prof_flat[S_list].sel(Nisf=kisf).drop('Nisf'))

T_S_prof_flat_map = xr.merge([T_prof_flat_map, S_prof_flat_map])    
T_S_prof_flat_map.to_netcdf(inputpath_profiles+'flattened_T_S_profiles.nc')

In [None]:
# Alternative approach (also crashes, per the note above): a vectorised version
# that compares ISF_mask against all Nisf labels at once through broadcasting,
# fills matching cells with the profiles of the corresponding ice shelf, and sums
# over Nisf to collapse back to a single map.
# NOTE(review): the chunks dict names x, y and time; the flattened file may not
# have all of these dimensions - TODO confirm.
T_S_prof_flat = xr.open_dataset(inputpath_profiles+'flattened_T_S_profiles_not_mapped_yet.nc', chunks={'x': chunk_size, 'y': chunk_size, 'time': 5, 'Nisf': 1})
# start from the mask, keeping only values > 1 (all other cells become NaN)
T_S_prof_flat_map = file_isf['ISF_mask'].where(file_isf['ISF_mask']>1).copy()
# where a cell's mask value equals an Nisf label, take that ice shelf's profiles
T_S_prof_flat_map = T_S_prof_flat_map.where(file_isf['ISF_mask']!=file_isf.Nisf, T_S_prof_flat.where(T_S_prof_flat.Nisf==file_isf.Nisf)).drop('Nisf')
# keep only the cells that match their Nisf label (NaN elsewhere)
T_S_prof_flat_map = T_S_prof_flat_map.where(file_isf['ISF_mask']==file_isf.Nisf)

# sum over Nisf merges the per-ice-shelf layers into one map before writing
T_S_prof_flat_map.sum('Nisf').to_netcdf(inputpath_profiles+'flattened_T_S_profiles.nc')