In [None]:
"""
Created on Mon Jan 09 10:16 2023

This is a script to cut out the T and S in the 50 km in front of the ice front for SMITH data

@author: Clara Burgard
"""

- calculate the distance to the ice front for the small domain in front of each ice shelf
- select the ocean points within ~50 km of the ice front

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
#from tqdm import tqdm
import gsw
import matplotlib.pyplot as plt
import basal_melt_param.useful_functions as uf
import basal_melt_param.T_S_profile_functions as tspf
import basal_melt_param.melt_functions as meltf
import basal_melt_param.box_functions as bf

from scipy.spatial import cKDTree


import itertools

import distributed
import glob

In [None]:
# Start a dask.distributed client with 8 workers and a worker memory limit of 6GB;
# scratch files go to /tmp and the dashboard is served on port 8795.
client = distributed.Client(
    n_workers=8,
    dashboard_address=':8795',
    local_directory='/tmp',
    memory_limit='6GB',
)

In [None]:
# Show the client created in the previous cell as this cell's output.
client

In [None]:
# IPython line magic: use the Qt5 backend for matplotlib figures.
%matplotlib qt5

READ IN THE DATA

In [None]:
# Name of the NEMO run to process.
nemo_run = 'EPM026'

# First and last year (both inclusive) processed for each supported run.
year_range_per_run = {
    'OPM006': (1989, 2018),
    'OPM021': (1989, 2018),
    'OPM026': (1989, 2018),
    'OPM016': (1980, 2008),
    'OPM018': (1980, 2008),
    'OPM027': (1999, 2038),
    'OPM031': (1999, 2068),
    'EPM031': (2049, 2058),
    'EPM026': (2049, 2058),
    'EPM034': (2119, 2128),
}

# Fail here with a clear message instead of leaving yy_start / yy_end undefined,
# which would only surface later as a NameError in the yearly loops.
if nemo_run not in year_range_per_run:
    raise ValueError(
        "Unknown nemo_run '" + nemo_run + "'; expected one of: "
        + ', '.join(sorted(year_range_per_run))
    )

yy_start, yy_end = year_range_per_run[nemo_run]

In [None]:
# The first letter of the run name selects the directory tree holding the interim data.
if nemo_run[0] == 'O':
    inputpath_base = '/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/'
elif nemo_run[0] == 'E':
    inputpath_base = '/bettik/burgardc/DATA/NN_PARAM/interim/'
else:
    # Without this, the three paths below would stay undefined and the failure
    # would only show up later as a NameError.
    raise ValueError("Cannot derive input paths for nemo_run '" + nemo_run + "': name must start with 'O' or 'E'")

# The sub-directories are laid out identically in both trees.
inputpath_data = inputpath_base + 'NEMO_eORCA025.L121_' + nemo_run + '_ANT_STEREO/'
inputpath_profiles = inputpath_base + 'T_S_PROF/nemo_5km_' + nemo_run + '/'
inputpath_isf = inputpath_base + 'ANTARCTICA_IS_MASKS/nemo_5km_' + nemo_run + '/'

# make the domain a little smaller to make the computation even more efficient - file isf has already been made smaller at its creation
map_lim = [-3000000,3000000]

PREPARE MASK AROUND FRONT (TO RUN WITHOUT DASK!)

In [None]:
def distance_isf_points_from_line_small_domain(isf_points_da,line_points_da):

    """
    Compute the distance from each selected grid point to the nearest point of a line.

    Both inputs are masks on a ``(y, x)`` grid. Every entry > 0 of ``isf_points_da`` is
    treated as a query point and every entry > 0 of ``line_points_da`` as a point of the
    line (for example the ice front). The distance is the nearest-neighbour distance
    returned by ``cKDTree.query`` on the ``(y, x)`` coordinate values, so it is expressed
    in the units of those coordinates.

    Parameters
    ----------
    isf_points_da : xarray.DataArray
        mask with dimensions ``y`` and ``x``; entries > 0 mark the points for which a distance is computed
    line_points_da : xarray.DataArray
        mask with dimensions ``y`` and ``x``; entries > 0 mark the points forming the line

    Returns
    -------
    xr_dist_to_line : xarray.DataArray
        distance of each selected point to the nearest line point, unstacked back onto ``y`` and ``x``
    """

    def _selected_points(mask_da):
        # flatten (y, x) into one 'grid' dimension and keep only entries > 0 (this also removes nans)
        flat = mask_da.stack(grid=['y', 'x'])
        return flat[flat>0]

    def _yx_pairs(flat_da):
        # (y, x) coordinate pairs behind the 'grid' dimension, as a float array
        return flat_da.indexes['grid'].to_frame().values.astype(float)

    query_points = _selected_points(isf_points_da)
    line_points = _selected_points(line_points_da)

    # nearest-neighbour search against the line points
    nearest_dist, _ = cKDTree(_yx_pairs(line_points)).query(_yx_pairs(query_points))

    # reuse the coordinates of the query points, then go back from 'grid' to the (y, x) layout
    return query_points.copy(data=nearest_dist).unstack('grid')

In [None]:
# Debugging leftover: `file_mask` is only created inside the yearly loop of the next cell,
# so a bare `file_mask` raises NameError on a fresh kernel (Restart & Run All).
# Show it when it exists, otherwise produce no output.
globals().get('file_mask')

In [None]:
# Size of the box around each ice front, passed to tspf.mask_boxes_around_IF_new.
# Loop-invariant, so defined once before the yearly loop.
# NB (original note): 5.0 x 1.75 is the effective resolution at 70S for a model of 1 degree resolution
# in longitude (assuming 5 delta X and a Mercator grid)
lon_box = np.array([10.0])
lat_box = np.array([3.5])

# For every year: build the continental-shelf ocean mask, restrict it to a box around each
# ice front, compute the distance of those points to the ice front and write it to disk.
# Progress is reported by tqdm, so no explicit print of the year is needed.
for tt in tqdm(range(yy_start,yy_end+1)):

    # Everything needed is loaded into memory inside the `with` blocks so that the
    # netCDF files are closed again before the next iteration.
    with xr.open_mfdataset(inputpath_data+'other_mask_vars_Ant_stereo_'+str(tt)+'.nc') as file_mask_orig:
        file_mask = uf.cut_domain_stereo(file_mask_orig, map_lim, map_lim).load()

    # only points where 'bathy_metry' is at most 1500
    contshelf = file_mask['bathy_metry'] <= 1500

    with xr.open_dataset(inputpath_profiles+'T_S_theta_ocean_corrected_'+str(tt)+'.nc') as T_S_ocean_opened:
        T_S_ocean_file = T_S_ocean_opened.load()

    with xr.open_dataset(inputpath_isf+'nemo_5km_isf_masks_and_info_and_distance_oneFRIS_'+str(tt)+'.nc') as file_isf_orig:
        # keep only ice shelves with a finite 'front_bot_depth_max' ...
        nonnan_Nisf = file_isf_orig['Nisf'].where(np.isfinite(file_isf_orig['front_bot_depth_max']), drop=True).astype(int)
        file_isf_nonnan = file_isf_orig.sel(Nisf=nonnan_Nisf)
        # ... and with 'isf_area_here' of at least 2500
        large_isf = file_isf_nonnan['Nisf'].where(file_isf_nonnan['isf_area_here'] >= 2500, drop=True)
        file_isf = file_isf_nonnan.sel(Nisf=large_isf).load()
    if 'labels' in file_isf.coords.keys():
        file_isf = file_isf.drop('labels')

    lon = file_isf['longitude']
    lat = file_isf['latitude']

    # ocean points: finite temperature at the first depth level
    ocean = np.isfinite(T_S_ocean_file['theta_ocean'].isel(depth=0)).drop('depth')

    # candidate points: ocean AND continental shelf
    mask_50km = (ocean & contshelf).load()

    close_region_around_isf_mask = tspf.mask_boxes_around_IF_new(lon, lat, mask_50km,
                                    file_isf['front_min_lon'], file_isf['front_max_lon'],
                                    file_isf['front_min_lat'], file_isf['front_max_lat'],
                                    lon_box, lat_box,
                                    file_isf['isf_name'])

    IF_mask = file_isf['IF_mask']
    dist_list = [ ]
    for kisf in file_isf['Nisf']:
        front_points = IF_mask == kisf
        # skip ice shelves without any ice-front point in the mask
        if front_points.any():
            region_to_cut_out = close_region_around_isf_mask.sel(Nisf=kisf).squeeze()
            region_to_cut_out = region_to_cut_out.where(region_to_cut_out > 0, drop=True)
            IF_region = IF_mask.where(front_points, drop=True)

            dist_from_front = distance_isf_points_from_line_small_domain(region_to_cut_out,IF_region)
            dist_list.append(dist_from_front)

    # put the per-ice-shelf results back onto the coordinates of file_isf and write them out
    dist_all = xr.concat(dist_list, dim='Nisf').reindex_like(file_isf)
    dist_all.to_dataset(name='dist_from_front').to_netcdf(inputpath_profiles+'dist_to_ice_front_only_contshelf_'+str(tt)+'.nc')

COMPUTING THE MEAN PROFILES

CONTINENTAL SHELF

(needs ~ 18 x 6 GB of memory)

In [None]:
# Upper bound on 'dist_from_front' for a point to be included in the mean profiles.
# NOTE(review): presumably 50 km expressed in metres, despite the "km" in the name -
# confirm against the units of the x/y coordinates.
mask_domain_distkm = 50_000

In [None]:
# Debugging leftover: `dist_to_front_file_yy` is only created inside the loop of the next cell,
# so a bare reference raises NameError on a fresh kernel (Restart & Run All).
# Show it when it exists, otherwise produce no output.
globals().get('dist_to_front_file_yy')

In [None]:
# For every year: average the T/S fields horizontally over the points lying within
# mask_domain_distkm of the ice front and write the mean profiles to disk.
for yy in tqdm(range(yy_start,yy_end+1)):
    dist_path = inputpath_profiles+'dist_to_ice_front_only_contshelf_'+str(yy)+'.nc'
    T_S_path = inputpath_profiles+'T_S_theta_ocean_corrected_'+str(yy)+'.nc'

    # The lazy (chunked) computation is triggered by to_netcdf, so everything happens
    # inside the `with` block; both input files are closed before the next year.
    with xr.open_dataset(dist_path) as dist_opened, xr.open_dataset(T_S_path) as T_S_opened:
        dist_to_front_file_yy = dist_opened.chunk({'Nisf': 5})
        T_S_ocean_file_yy = T_S_opened.chunk({'depth': 5})

        # numerator: sum of the fields over the points close enough to the front
        dist_to_front = dist_to_front_file_yy['dist_from_front']
        mask_km = dist_to_front <= mask_domain_distkm
        ds_sum = (T_S_ocean_file_yy * mask_km).sum(['x','y'])

        # denominator: number of those points with salinity > 0
        mask_depth = T_S_ocean_file_yy['salinity_ocean'].squeeze().drop('time') > 0
        mask_all = mask_km & mask_depth
        mask_sum = mask_all.sum(['x','y']).load()

        ds_mean = ds_sum/mask_sum
        ds_mean.to_netcdf(inputpath_profiles+'T_S_mean_prof_corrected_km_contshelf_'+str(yy)+'.nc')

In [None]:
# Gather the yearly mean-profile files and write them as one file concatenated along 'time'.
yearly_profile_paths = [
    inputpath_profiles+'T_S_mean_prof_corrected_km_contshelf_'+str(tt)+'.nc'
    for tt in range(yy_start,yy_end+1)
]
TS_list = [xr.open_dataset(profile_path) for profile_path in yearly_profile_paths]
TS_ds = xr.concat(TS_list, dim='time')
TS_ds.to_netcdf(inputpath_profiles+'T_S_mean_prof_corrected_km_contshelf_allyy.nc')

THESE ARE REMNANTS OF TRYING WITH DASK BUT SOMEHOW DID NOT WORK :/

If the workers don't die, everything can be computed in one go (took approx. 1 hour with 12 cores); if the workers die, divide the work by years.

In [None]:
# Collect the per-time-step file names. These lists are read by the `open_mfdataset`
# calls of the following cell; they were previously created but never filled, so those
# calls received an empty list.
list_dist_files = []
list_temp_files = []
range_time = range(5) #41
for yy in range_time:
    file_suffix = 'yy'+str(yy).zfill(2)+'.nc'
    list_dist_files.append(inputpath_profiles+'dist_to_ice_front_only_contshelf_'+file_suffix)
    list_temp_files.append(inputpath_profiles+'T_S_theta_ocean_corrected_'+file_suffix)

# Open every file once and concatenate along 'time' in a single call
# instead of re-concatenating the growing dataset at every iteration.
dist_parts = [xr.open_mfdataset(dist_path) for dist_path in list_dist_files]
temp_parts = [xr.open_mfdataset(temp_path) for temp_path in list_temp_files]

if len(dist_parts) > 1:
    dist_to_front_file = xr.concat(dist_parts, dim='time').chunk({'x': 50, 'y': 50})
    T_S_ocean_file = xr.concat(temp_parts, dim='time').chunk({'x': 50, 'y': 50, 'depth': 50})
else:
    # single time step: nothing to concatenate
    dist_to_front_file = dist_parts[0].copy()
    T_S_ocean_file = temp_parts[0].copy()

In [None]:
# Switch between processing everything at once and processing year by year.
all_in_one = True # False if worker die, True if workers don't die
if all_in_one:
    # replace the 'time' coordinate of both datasets by the indices in range_time
    time_index = {'time': range_time}
    dist_to_front_file = xr.open_mfdataset(
        list_dist_files, chunks={'x': 25, 'y': 25}
    ).assign_coords(time_index)
    T_S_ocean_files = xr.open_mfdataset(
        list_temp_files, chunks={'x': 25, 'y': 25, 'depth': 50}
    ).assign_coords(time_index)
# Year-by-year alternative, kept disabled:
#else:
#    dist_to_front_file = xr.open_mfdataset(inputpath_profiles+'dist_to_ice_front_only_contshelf.nc',chunks={'x': 100, 'y': 100})
#    T_S_ocean_files = xr.open_mfdataset(inputpath_profiles+'T_S_theta_ocean_corrected_*.nc', concat_dim='time', chunks={'x': 100, 'y': 100, 'depth': 50}, parallel=True)
#    #T_S_ocean_1980 = xr.open_mfdataset(inputpath_profiles+'T_S_theta_ocean_corrected_1990.nc',chunks={'x': 100, 'y': 100, 'depth': 50})
#    T_S_ocean_1980 = xr.open_mfdataset(inputpath_profiles+'T_S_theta_ocean_corrected_2000.nc',chunks={'x': 100, 'y': 100, 'depth': 50})
dist_to_front = dist_to_front_file['dist_from_front']

Prepare sum

In [None]:
# Extract the distance-to-front variable from the dataset.
# NOTE(review): identical to the last statement of the preceding code cell.
dist_to_front = dist_to_front_file['dist_from_front']

In [None]:
# Boolean mask: True where the distance to the front does not exceed mask_domain_distkm.
mask_km = dist_to_front <= mask_domain_distkm

In [None]:
# Display the dataset for inspection.
# NOTE(review): the all_in_one branch above assigns `T_S_ocean_files` (plural); this is the
# dataset built in the file-loading loop - confirm which one is intended.
T_S_ocean_file

In [None]:
# Multiply the fields by the distance mask and sum over the horizontal dimensions
# (numerator of the mean computed further down).
ds_sum = (T_S_ocean_file * mask_km).sum(['x','y'])

In [None]:
# Display the summed dataset for inspection.
ds_sum

In [None]:
# Write the horizontal sums to disk, either as one file or split along 'time'.
if not all_in_one:
    # presumably one dataset (and one output file) per chunk along 'time' - see tspf.split_by_chunks
    yearly_datasets = list(tspf.split_by_chunks(ds_sum.unify_chunks(),'time'))
    paths = [
        tspf.create_filepath(ds, 'ds_sum_for_mean_contshelf_yy', inputpath_profiles, ds.time[0].values)
        for ds in yearly_datasets
    ]
    xr.save_mfdataset(datasets=yearly_datasets, paths=paths)
else:
    # compute everything in memory, then write a single file
    ds_sum = ds_sum.load()
    #ds_sum.to_netcdf(inputpath_profiles+'ds_sum_for_mean.nc')
    ds_sum.to_netcdf(inputpath_profiles+'ds_sum_for_mean_contshelf.nc')

Prepare number of points by which you divide

In [None]:
if all_in_one:
    ds_sum = xr.open_mfdataset(inputpath_profiles+'ds_sum_for_mean_contshelf.nc')
else:
    ds_sum = xr.open_mfdataset(inputpath_profiles+'ds_sum_for_mean_contshelf_*.nc', concat_dim='time', parallel=True).drop('profile_domain')

In [None]:
# Points with salinity > 0, combined with the distance mask.
# NOTE(review): `T_S_ocean_1980` is only assigned inside the commented-out `else` branch
# further up, so this cell raises NameError on a fresh kernel. The per-year loop earlier in
# the notebook builds the same mask from that year's own T/S file - confirm which dataset
# should be used here.
mask_depth = T_S_ocean_1980['salinity_ocean'].squeeze().drop('time') >0
mask_all = mask_km & mask_depth

In [None]:
# Count the selected points over the horizontal dimensions (denominator of the mean).
mask_sum = mask_all.sum(['x','y'])

In [None]:
# Load the point counts into memory.
mask_sum = mask_sum.load()

Make the mean

In [None]:
# Mean = sum over the selected points divided by the number of selected points.
ds_mean = ds_sum/mask_sum

In [None]:
# Preview only: the result is displayed but not assigned, so `ds_mean` itself is unchanged.
ds_mean.drop('profile_domain')

In [None]:
# Drop the existing 'profile_domain' and rename 'dist_from_front' to 'profile_domain'.
# NOTE(review): this overwrites `ds_mean` with a value derived from itself, so the cell is
# not idempotent - presumably a second run fails because 'dist_from_front' no longer exists.
#ds_mean = ds_mean.rename({'dist_from_front': 'profile_domain'})
ds_mean = ds_mean.drop('profile_domain').rename({'dist_from_front': 'profile_domain'})

In [None]:
# Write the mean profiles to disk.
# NOTE(review): the file name hard-codes '1980-2018' regardless of the time steps actually
# loaded (range_time) - confirm it matches the data.
#ds_mean.to_netcdf(inputpath_profiles+'T_S_mean_prof_km_1980-2018.nc')
ds_mean.to_netcdf(inputpath_profiles+'T_S_mean_prof_corrected_km_contshelf_1980-2018.nc')

FINISHED

In [None]:
# Compare time-mean temperature profiles of one ice shelf: old (dashed) vs new (solid),
# one colour per domain index.
# NOTE(review): `old_profiles` and `new_profiles` are not defined anywhere in the visible
# part of this notebook - they must be loaded before this cell can run.
kisf = 21

domain_colors = ['r', 'g', 'b', 'orange']
old_theta = old_profiles['theta_ocean'].sel(Nisf=kisf)
new_theta = new_profiles['theta_ocean'].sel(Nisf=kisf)

for domain_idx, domain_color in enumerate(domain_colors):
    old_theta.isel(dist_from_front=domain_idx).mean('time').plot(color=domain_color,linestyle='--')

for domain_idx, domain_color in enumerate(domain_colors):
    new_theta.isel(profile_domain=domain_idx).mean('time').plot(color=domain_color)
#new_profiles['theta_ocean'].sel(Nisf=kisf).isel(profile_domain=4).mean('time').plot(color='k')


In [None]:
# Compare time-mean salinity profiles of one ice shelf: old (dashed, 4 domains) vs
# new (solid, 5 domains), one colour per domain index.
# NOTE(review): `old_profiles` and `new_profiles` are not defined anywhere in the visible
# part of this notebook - they must be loaded before this cell can run.
kisf = 10

old_domain_colors = ['r', 'g', 'b', 'orange']
new_domain_colors = ['r', 'g', 'b', 'orange', 'k']
old_salinity = old_profiles['salinity_ocean'].sel(Nisf=kisf)
new_salinity = new_profiles['salinity_ocean'].sel(Nisf=kisf)

for domain_idx, domain_color in enumerate(old_domain_colors):
    old_salinity.isel(dist_from_front=domain_idx).mean('time').plot(color=domain_color,linestyle='--')

for domain_idx, domain_color in enumerate(new_domain_colors):
    new_salinity.isel(profile_domain=domain_idx).mean('time').plot(color=domain_color)