In [None]:
"""
Created on Mon March 06 18:36 2023

Try to apply script from Rieke from the CryoHackathon to better define the masks, but this time together with Nico's limits

@author: Clara Burgard
"""

In [None]:
import cc3d
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from scipy.interpolate import griddata
from tqdm.notebook import tqdm

import basal_melt_param.create_isf_mask_functions as isfmf
import basal_melt_param.plume_functions as pf
import basal_melt_param.useful_functions as uf

In [None]:
%matplotlib

In [None]:
# Identifier of the run; it is appended to the 'SMITH_' directory names in the path cell below
nemo_run = 'bi646'
# Limits passed to uf.cut_domain_stereo for both the x and the y direction
# (presumably metres in the stereographic projection -- TODO confirm)
map_lim = [-3000000,3000000]


In [None]:
# Directory holding the ice-shelf mask files of this run (read further down with xr.open_dataset)
outputpath_mask='/bettik/burgardc/DATA/NN_PARAM/interim/ANTARCTICA_IS_MASKS/SMITH_'+nemo_run+'/'
# Directory holding the land-sea mask and ice-shelf concentration files of this run
inputpath_data='/bettik/burgardc/DATA/NN_PARAM/interim/SMITH_'+nemo_run+'/'
# NOTE(review): inputpath_data2 is not used by any cell visible in this notebook
inputpath_data2='/bettik/burgardc/DATA/BASAL_MELT_PARAM/interim/NEMO_eORCA025.L121_OPM016_ANT_STEREO/'
# Directory of the reference ("orig") mask file opened as file_isf_orig
outputpath_mask_orig='/bettik/burgardc/SCRIPTS/basal_melt_param/data/interim/ANTARCTICA_IS_MASKS/nemo_5km_OPM016/'


# NOTE(review): outputpath_boxes and inputpath_raw are not used by any cell visible in this notebook
outputpath_boxes = '/bettik/burgardc/DATA/NN_PARAM/interim/BOXES/SMITH_'+nemo_run+'/'
inputpath_raw = '/bettik/burgardc/DATA/NN_PARAM/raw/'

In [None]:
file_isf = xr.open_dataset(outputpath_mask + 'nemo_5km_isf_masks_and_info_and_distance_oneFRIS_2040.nc')


In [None]:
file_isf['ISF_mask'].plot()

In [None]:
# Re-binds file_isf to the *_1970 mask file (an earlier cell opened the *_2040 one under the same name)
file_isf = xr.open_dataset(outputpath_mask + 'nemo_5km_isf_masks_and_info_and_distance_oneFRIS_1970.nc')
# Reference mask file from the nemo_5km_OPM016 directory
file_isf_orig = xr.open_dataset(outputpath_mask_orig + 'nemo_5km_isf_masks_and_info_and_distance_new_oneFRIS.nc')


# Land-sea mask, cut to the map_lim square in x and y
file_mask = xr.open_dataset(inputpath_data+'custom_lsmask_Ant_stereo_clean.nc')#, chunks={'x': chunk_size, 'y': chunk_size})
file_mask_cut = uf.cut_domain_stereo(file_mask, map_lim, map_lim)

# Ice-shelf concentration, cut to the same domain as the mask
file_conc = xr.open_dataset(inputpath_data+'isfdraft_conc_Ant_stereo.nc')
file_conc_cut = uf.cut_domain_stereo(file_conc, map_lim, map_lim)

In [None]:
def def_isf_mask(arr_def_ismask, file_msk, file_conc, lon, lat, FRIS_one=True, 
                 mouginot_basins=False, connectivity = 4, threshold = 4):
    
    """
    Define a mask for the individual ice shelves. 
    
    This function defines a mask for the individual ice shelves. I think it works for both stereographic and latlon grids but I have not tried the latter.
    
    Two modes are available:
    
    * ``mouginot_basins=False``: ice shelves are identified with the lon/lat boxes listed in ``arr_def_ismask``.
    * ``mouginot_basins=True``: ice shelves are identified with the basin IDs in ``arr_def_ismask['ID_isf']``,
      refined with a connected-component analysis of the ice-shelf concentration.
    
    Parameters
    ----------
    arr_def_ismask : np.array or xr.Dataset
        Array containing minlon,maxlon,minlat,maxlat,is_nb or xr.Dataset with drainage basins
        (the variables 'ID_isf', 'name_isf', 'name_reg' and 'Nisf_orig' and the coordinate 'Nisf' are used).
        NOTE: in the basin mode the 'name_isf' variable of this dataset is modified in place.
    file_msk : xr.DataArray
        Mask separating ocean (0), ice shelves (between 0 and 2, excluding 0 and 2), grounded ice (2) 
    file_conc : xr.DataArray
        Ice shelf concentration for each point (between 0 and 1)
    lon : xr.DataArray
        Longitude (depends on x,y for stereographic). Only used if ``mouginot_basins`` is False.
    lat : xr.DataArray
        Latitude (depends on x,y for stereographic). Only used if ``mouginot_basins`` is False.
    FRIS_one : Boolean 
        If True, Filchner-Ronne are considered as one ice-shelf
    mouginot_basins : Boolean 
        If True, arr_def_ismask is treated as a dataset of drainage basins, otherwise as an array of lon/lat boxes
    connectivity : int
        4 or 8 for 2D, defines what is considered a "connected" point (basin mode only)
    threshold : int
        Size of lonely pixel areas to remove (basin mode only)
        
    Returns
    -------
    mask_file : xr.DataArray or xr.Dataset
        Box mode: array showing the coverage of each ice shelf with the respective ID, open ocean is 1, land is 0.
        Basin mode: dataset merging that array (as 'ISF_mask', IDs shifted by +1) with 'name_isf', 'name_reg'
        and 'Nisf_orig'.
    """    
    
    if mouginot_basins:
        
        # only ice shelves
        isf_only_mask = file_conc > 0
        
        #find connected components (after removing patches smaller than `threshold`)
        dusted = cc3d.dust(isf_only_mask.values.astype(np.int64), 
                   threshold = threshold, 
                   connectivity = connectivity, 
                   in_place = False)
        
        labels_out = cc3d.connected_components(dusted, 
                                       connectivity = connectivity)
        
        labelled = xr.DataArray(labels_out, 
                        coords = {"y": file_conc.y, "x": file_conc.x}, 
                        dims = ["y", "x"],
                        name = "labels")
        
        # assign ID for basins
        isf_mask_basins = arr_def_ismask['ID_isf'].where(isf_only_mask > 0)
        # cut connected areas to area covered by basin stuff
        labelled_isf = labelled.where(np.isfinite(isf_mask_basins))
        
        # creating the mask
        new_mask = isf_mask_basins.copy()
        
        # merge a hard-coded list of basin IDs into their neighbours
        # (every condition is evaluated on the mask as it was before this statement)
        new_mask = new_mask.where(
            new_mask != 58, 57).where(
            new_mask != 151, 99).where(
            new_mask != 109, 107).where(
            new_mask != 116, 5).where(
            new_mask != 143, 97).where(
            new_mask != 137, 99)
        
        # NOTE: the name bookkeeping below writes into the caller's dataset
        arr_def_ismask['name_isf'].loc[{'Nisf': 57}] = 'Ross'
        arr_def_ismask['name_isf'].loc[{'Nisf': 58}] = np.nan
        
        if FRIS_one:
            new_mask = new_mask.where(new_mask != 104, 103)
            arr_def_ismask['name_isf'].loc[{'Nisf': 103}] = 'Filchner-Ronne'
            arr_def_ismask['name_isf'].loc[{'Nisf': 104}] = np.nan

        arr_def_ismask['name_isf'] = arr_def_ismask['name_isf'].dropna('Nisf')
        
        # do some fine-tuning for overlapping ice shelves   
        problem_regions = [2,3,8,9,10,13,23,26,27,28,29,32,34,38,44,46,50,57,59,60,
                   63,70,71,72,73,74,76,77,78,83,84,85,89,91,96,103]
        
        # connected-component labels run from 1 to labels_out.max() inclusive
        for conn_label in range(1, int(labels_out.max()) + 1):
            # basin IDs found inside this connected domain
            # (the original read an undefined name `summary_mask_basins` here)
            basins_conn_domain = arr_def_ismask['ID_isf'].where(labelled_isf == conn_label, drop=True)
            if basins_conn_domain.size == 0:
                # the connected domain lies entirely outside the basin coverage
                continue
            max_label = basins_conn_domain.max().values
            min_label = basins_conn_domain.min().values
            
            # for areas with two labels in problem regions, take the one with the most points
            if max_label != min_label:
                groups_isf = basins_conn_domain.groupby(basins_conn_domain)
                groups_labels = groups_isf.groups.keys()
                if groups_isf.count().ID_isf.count() > 1:
                    if any(x in problem_regions for x in list(groups_labels)):
                        dominant_isf = groups_isf.count().idxmax().values
                        # hard-coded exception: a dominant ID 12 is replaced by 14
                        if dominant_isf == 12:
                            dominant_isf = 14
                        new_mask = new_mask.where(labelled_isf != conn_label, dominant_isf)
            
        # other fine-tuning: if an ice shelf is split, keep the largest connected domain
        dx = abs(file_conc.x[1] - file_conc.x[0])
        dy = abs(file_conc.y[1] - file_conc.y[0])

        split_regions = [70,77,83,89,103] 

        for rreg in split_regions:
            # look where there are the same labels in several unconnected domains
            labels_same = list(new_mask.groupby(labelled_isf).groups) * (new_mask.groupby(labelled_isf).median() == rreg)
            labels_same = labels_same[labels_same>0]
            if labels_same.size == 0:
                # no connected domain carries this ID: nothing to split
                # (also avoids using an undefined or stale `largest_label` below)
                continue

            area_before = 0
            for conn_label in labels_same:
                # compute the area of the different unconnected areas
                conc_for_area = file_conc.where(labelled_isf == conn_label, drop=True)
                area_now = (conc_for_area * dx * dy).sum()
                if area_now >= area_before:
                    area_before = area_now
                    largest_label = conn_label

            # set the smaller areas to 159
            for small_label in (labels_same.where(labels_same != largest_label).dropna('labels')):
                new_mask = new_mask.where(labelled_isf != small_label, 159)

        # shift all IDs by one so that 0 and 1 are free for land and open ocean
        new_mask = new_mask + 1
        new_mask_info = arr_def_ismask.copy()
        new_mask_info['Nisf'] = new_mask_info['Nisf'] + 1
        
        # open ocean -> 1, grounded ice -> 0
        new_mask = new_mask.where(file_msk != 0, 1).where(file_msk != 2, 0)
        
        mask_file = xr.merge([new_mask.rename('ISF_mask'), 
                              new_mask_info['name_isf'], 
                              new_mask_info['name_reg'], 
                              new_mask_info['Nisf_orig']])
    
    else:
        
        # boxes whose 4th column (maxlat) equals -50 are the "general" ones and are applied first,
        # all other ("detail") boxes are applied afterwards and overwrite them
        arr_def_general = arr_def_ismask[arr_def_ismask[:, 3] == -50]
        arr_def_detail = arr_def_ismask[arr_def_ismask[:, 3] != -50]

        isf_yes = (file_msk > 0) & (file_msk < 2)
        isf_mask = file_msk.copy()
        for mm in arr_def_general:
            isf_mask = isf_mask.where(~(uf.in_range(lon, mm[0:2]) & uf.in_range(lat, mm[2:4])), int(mm[4]))
        for mm in arr_def_detail:
            isf_mask = isf_mask.where(~(uf.in_range(lon, mm[0:2]) & uf.in_range(lat, mm[2:4])), int(mm[4]))
        # keep the IDs on ice-shelf points only
        isf_mask = isf_mask.where(isf_yes)

        if FRIS_one:
            isf_mask = isf_mask.where(isf_mask != 21, 11) # Filchner (21) and Ronne (11) are combined
        
        # open ocean -> 1, grounded ice -> 0
        mask_file = isf_mask.where(file_msk != 0, 1).where(file_msk != 2, 0)
    
    return mask_file

In [None]:
# Value selected on the `time` coordinate of the mask and concentration files
# (presumably a year, like the 1970/2040 suffixes of the mask file names -- TODO confirm)
tt = 2040
file_msk = file_mask_cut['ls_mask012'].sel(time=tt)
file_conc = file_conc_cut['isfdraft_conc'].sel(time=tt).drop('time')
inputpath_metadata='/bettik/burgardc/SCRIPTS/basal_melt_param/data/raw/MASK_METADATA/'
# Ice-shelf box definitions handed to def_isf_mask as arr_def_ismask
arr_mask = isfmf.read_isfmask_info(inputpath_metadata+'lonlat_masks.txt')
# mouginot_basins=False selects the lon/lat-box branch, so mask_file is the mask array itself.
# NOTE(review): longitude/latitude come from whichever file_isf was opened last (the *_1970 file
# when the cells are run in order) while tt is 2040 -- confirm the grids are identical.
mask_file = def_isf_mask(arr_mask, file_msk, file_conc, file_isf.longitude, file_isf.latitude, FRIS_one=True, 
                 mouginot_basins=False)

In [None]:
file_conc.where(mask_file < 10).plot()

In [None]:
# NOTE(review): scattered_reg_all_mask is first assigned in a later cell, so this cell fails
# when the notebook is run top to bottom on a fresh kernel.
scattered_reg_all_mask.plot()

In [None]:
# Ice-shelf concentration on the points whose mask value is below 10
scattered_reg_all_conc = file_conc.where(mask_file < 10) 
# True where that concentration is non-zero
scattered_reg_all_mask = scattered_reg_all_conc > 0

new_mask = mask_file.copy()
# Mask values above 2 only; all other points become NaN
isf_only_mask = new_mask.where(new_mask > 2)


In [None]:
# NOTE(review): at notebook scope `labelled` is first assigned in a later cell (inside def_isf_mask
# it is only a local), so this cell fails when the notebook is run top to bottom on a fresh kernel.
labelled.plot()

In [None]:
new_mask = mask_file.copy()

In [None]:
new_mask = mask_file.copy()
# collapse every value strictly between 1 and 10 into the placeholder ID 4
new_mask = new_mask.where(~((new_mask > 1) & (new_mask < 10)), 4)

### SPECIAL REGIONS
# IDs 102 and 103 are merged into ID 75
new_mask = new_mask.where(new_mask != 102, 75)
new_mask = new_mask.where(new_mask != 103, 75)


###### THIS BLOCK IS TO SEPARATE SPLIT REGIONS
threshold = 1
connectivity = 4

#find connected components
dusted = cc3d.dust(new_mask.values.astype(np.int64), 
           threshold = threshold, 
           connectivity = connectivity, 
           in_place = False)

labels_out = cc3d.connected_components(dusted, 
                               connectivity = connectivity)

labelled_isf = xr.DataArray(labels_out, 
                coords = {"y": file_conc.y, "x": file_conc.x}, 
                dims = ["y", "x"],
                name = "labels")

# grid spacing for the area estimate below; defined here so that this cell does not
# depend on a later cell having been run first
dx = abs(file_conc.x[1] - file_conc.x[0]).values.astype(int)
dy = abs(file_conc.y[1] - file_conc.y[0]).values.astype(int)

# all values present in the mask; only IDs above 9 are treated as ice shelves here
all_isf_list = np.array(list(new_mask.groupby(new_mask).groups))
isf_labels = all_isf_list[all_isf_list>9]

for rreg in isf_labels:
    # look if one ice shelf is present in disconnected regions
    isf_group = new_mask.where(new_mask==rreg)
    label_group = labelled_isf.where(np.isfinite(isf_group))
    label_group_list = np.array(list(label_group.groupby(label_group).groups))
    label_group_list = label_group_list[label_group_list > 1]
    if label_group_list.size > 0:
        if label_group_list.min() != label_group_list.max():
            area_before = 0
            for conn_label in label_group_list:
                # compute the area of the different unconnected areas
                conc_for_area = file_conc.where(labelled_isf == conn_label, drop=True)
                area_now = (conc_for_area * dx * dy).sum()
                if area_now >= area_before:
                    area_before = area_now
                    largest_label = conn_label
            
            # set the smaller areas to 4 (random choice)
            for llabel in label_group_list:
                if llabel != largest_label:
                    new_mask = new_mask.where(labelled_isf != llabel, 4)
                    


In [None]:
plt.figure()
new_mask.plot(vmin=101,vmax=104)

In [None]:
file_conc.where(new_mask<5).plot()

In [None]:
new_mask.plot()

In [None]:
    # NOTE(review): indented scratch fragment mirroring the body of the split-region loop in
    # def_isf_mask. It reads `labels_same`, which no cell defines at notebook scope, and the
    # leading indentation means the cell cannot run on its own as written.
    area_before = 0
    for conn_label in labels_same:
        # compute the area of the different unconnected areas
        conc_for_area = file_conc.where(labelled_isf == conn_label, drop=True)
        area_now = (conc_for_area * dx * dy).sum()
        if area_now >= area_before:
            area_before = area_now
            largest_label = conn_label

    # set the smaller areas to 3
    for small_label in (labels_same.where(labels_same != largest_label).dropna('labels')):
        new_mask = new_mask.where(labelled_isf != small_label, 3)

In [None]:
labelled_isf.plot()

In [None]:
new_mask = mask_file.copy()
# collapse every value strictly between 1 and 10 into the placeholder ID 4
new_mask = new_mask.where(~((new_mask > 1) & (new_mask < 10)), 4)

### SPECIAL REGIONS
# IDs 102 and 103 are merged into ID 75
new_mask = new_mask.where(new_mask != 102, 75)
new_mask = new_mask.where(new_mask != 103, 75)

###### THIS BLOCK IS TO SEPARATE SPLIT REGIONS
threshold = 1
connectivity = 4

#find connected components
dusted = cc3d.dust(new_mask.values.astype(np.int64), 
           threshold = threshold, 
           connectivity = connectivity, 
           in_place = False)

labels_out = cc3d.connected_components(dusted, 
                               connectivity = connectivity)

labelled_isf = xr.DataArray(labels_out, 
                coords = {"y": file_conc.y, "x": file_conc.x}, 
                dims = ["y", "x"],
                name = "labels")

# grid spacing, needed both for the area estimate in the split-region loop and for the
# neighbourhood windows in the fill loop; defined before its first use
dx = abs(labelled_isf.x[1] - labelled_isf.x[0]).values.astype(int)
dy = abs(labelled_isf.y[1] - labelled_isf.y[0]).values.astype(int)

# all values present in the mask; only IDs above 9 are treated as ice shelves here
all_isf_list = np.array(list(new_mask.groupby(new_mask).groups))
isf_labels = all_isf_list[all_isf_list>9]

for rreg in isf_labels:
    # look if one ice shelf is present in disconnected regions
    isf_group = new_mask.where(new_mask==rreg)
    label_group = labelled_isf.where(np.isfinite(isf_group))
    label_group_list = np.array(list(label_group.groupby(label_group).groups))
    label_group_list = label_group_list[label_group_list > 1]
    if label_group_list.size > 0:
        if label_group_list.min() != label_group_list.max():
            area_before = 0
            for conn_label in label_group_list:
                # compute the area of the different unconnected areas
                conc_for_area = file_conc.where(labelled_isf == conn_label, drop=True)
                area_now = (conc_for_area * dx * dy).sum()
                if area_now >= area_before:
                    area_before = area_now
                    largest_label = conn_label
            
            # set the smaller areas to 4 (random choice)
            for llabel in label_group_list:
                if llabel != largest_label:
                    new_mask = new_mask.where(labelled_isf != llabel, 4)


###### THIS BLOCK IS TO FILL THE "NEW REGIONS"

threshold = 1
connectivity = 4

# points carrying the placeholder ID 4 that have a non-zero ice-shelf concentration
scattered_reg_all_conc = file_conc.where(new_mask == 4) 
scattered_reg_all_mask = scattered_reg_all_conc > 0

#find connected components
dusted = cc3d.dust(scattered_reg_all_mask.values.astype(np.int64), 
           threshold = threshold, 
           connectivity = connectivity, 
           in_place = False)

labels_out_conc = cc3d.connected_components(dusted, 
                               connectivity = connectivity)

labelled = xr.DataArray(labels_out_conc, 
                coords = {"y": file_conc.y, "x": file_conc.x}, 
                dims = ["y", "x"],
                name = "labels")

# filter that checks the point around (the four direct neighbours, centre excluded)
weights_filter = np.zeros((3,3))
weights_filter[0,1] = 1
weights_filter[1,0] = 1
weights_filter[1,2] = 1
weights_filter[2,1] = 1

# NOTE(review): weights_da is built but the plain array weights_filter is what is passed
# to pf.xr_nd_corr_v2 below -- confirm which one that helper expects
weights_da = xr.DataArray(data=weights_filter,dims=['y0','x0'])

# connected-component labels run from 1 to labels_out_conc.max() inclusive
for conn_label in range(1, int(labels_out_conc.max()) + 1):
    dom_region = labelled.where(labelled == conn_label, drop=True)
    # bounding box of the region, enlarged by one grid cell on each side
    dom_bounds_plus1 = np.array([dom_region.x.min().values - dx,dom_region.x.max().values + dx,dom_region.y.min().values - dy,dom_region.y.max().values + dy]).astype(int)
    dom_plus1_mask = scattered_reg_all_mask.sel(x=range(dom_bounds_plus1[0],dom_bounds_plus1[1]+1,dx), y=range(dom_bounds_plus1[2],dom_bounds_plus1[3]+1,dy))
    corr = pf.xr_nd_corr_v2(dom_plus1_mask, weights_filter)
    # points flagged by the filter that are not part of the scattered regions themselves
    only_contour = (corr ^ dom_plus1_mask)
    neighboring_pixels = new_mask.where(only_contour)
    if neighboring_pixels.max() > 9:
        neighbor_max = neighboring_pixels.where(neighboring_pixels > 9).max()
        neighbor_min = neighboring_pixels.where(neighboring_pixels > 9).min()
        # only fill the region if all neighbouring ice-shelf points carry the same ID
        if neighbor_max == neighbor_min:
            new_mask = new_mask.where(labelled != conn_label, neighbor_min)
    else:
        # no neighbouring ice shelf: report the region and the largest neighbouring value
        print(conn_label)
        print(neighboring_pixels.max().values)

In [None]:
labelled.plot(vmin=50,vmax=60)

In [None]:
# labels_out_conc is the raw array returned by cc3d.connected_components and has no .plot();
# `labelled` is the xr.DataArray built from it with the x/y coordinates, so plot that instead.
labelled.plot()

In [None]:
file_conc.plot()

In [None]:
file_conc.where(new_mask<4).plot()

In [None]:
# Collapse every value strictly between 1 and 10 into the placeholder ID 4 again.
# NOTE(review): this re-binds new_mask from itself, so the result depends on which cell produced new_mask last.
new_mask = new_mask.where(~((new_mask > 1) & (new_mask < 10)), 4)
#new_mask.where(new_mask == 4)

In [None]:
# Debug cell: repeats the body of the fill loop for one hard-coded connected region (label 59).
# It relies on labelled, scattered_reg_all_mask, weights_filter, dx, dy and new_mask left in the
# kernel by the fill cell.
conn_label = 59
dom_region = labelled.where(labelled == conn_label, drop=True)
# bounding box of the region, enlarged by one grid cell on each side
dom_bounds_plus1 = np.array([dom_region.x.min().values - dx,dom_region.x.max().values + dx,dom_region.y.min().values - dy,dom_region.y.max().values + dy]).astype(int)
dom_plus1_mask = scattered_reg_all_mask.sel(x=range(dom_bounds_plus1[0],dom_bounds_plus1[1]+1,dx), y=range(dom_bounds_plus1[2],dom_bounds_plus1[3]+1,dy))
corr = pf.xr_nd_corr_v2(dom_plus1_mask, weights_filter)
# points flagged by the filter that are not part of the scattered regions themselves
only_contour = (corr ^ dom_plus1_mask)
neighboring_pixels = new_mask.where(only_contour)

# largest and smallest ice-shelf ID (values above 9) among the neighbouring points
if neighboring_pixels.max() > 9:
    neighbor_max = neighboring_pixels.where(neighboring_pixels > 9).max()
    neighbor_min = neighboring_pixels.where(neighboring_pixels > 9).min()


In [None]:
file_isf['ground_mask'].plot()

In [None]:
neighboring_pixels.plot()

In [None]:
new_mask.plot(vmin=0,vmax=3)

In [None]:
plt.figure()
scattered_reg_all_mask.plot()

In [None]:
dom_region.plot()

In [None]:
new_mask.plot()

In [None]:
new_mask.where(dom_plus1_mask).plot()

In [None]:
labelled.where(labelled == conn_label).plot()

In [None]:
corr.where(dom_region != conn_label).plot()

In [None]:
labelled.sel(x=range(dom_bounds_plus1[0],dom_bounds_plus1[1]+1,dx), y=range(dom_bounds_plus1[2],dom_bounds_plus1[3]+1,dy)).plot()

In [None]:
mask_file.sel(x=range(dom_bounds_plus1[0],dom_bounds_plus1[1]+1,dx), y=range(dom_bounds_plus1[2],dom_bounds_plus1[3]+1,dy)).plot(vmax=10)

In [None]:
corr.plot()

In [None]:
mask_file.where(corr)

In [None]:
neighbouring_pixels = mask_file.where(corr)


In [None]:
neighbouring_pixels.where(neighbouring_pixels > 2).min()