In [None]:
"""
Created on Wed Feb 22 10:37 2023

Try to apply script from Rieke from the CryoHackathon to better define the masks

@author: Clara Burgard
"""

In [None]:
import cc3d
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from scipy.interpolate import griddata
from tqdm.notebook import tqdm

import basal_melt_param.create_isf_mask_functions as isfmf
import basal_melt_param.plume_functions as pf
import basal_melt_param.useful_functions as uf


In [None]:
%matplotlib qt5

READ IN DATA

In [None]:
# Identifier of the NEMO run; it is inserted into the run-specific input/output paths below.
nemo_run = 'bi646'
# [min, max] bounds handed to uf.cut_domain_stereo for both axes
# (presumably metres on the stereographic x/y grid — TODO confirm).
map_lim = [-3000000,3000000]


In [None]:
# NOTE(review): absolute, machine-specific paths; they have to be adapted before
# the notebook can run anywhere else.
# Directory holding the ice-shelf mask files of this run (read further down).
outputpath_mask='/bettik/burgardc/DATA/NN_PARAM/interim/ANTARCTICA_IS_MASKS/SMITH_'+nemo_run+'/'
# Directory holding the run-specific input fields (land-sea mask, T/S, ice-shelf concentration).
inputpath_data='/bettik/burgardc/DATA/NN_PARAM/interim/SMITH_'+nemo_run+'/'
# Not referenced in the cells visible in this notebook.
inputpath_data2='/bettik/burgardc/DATA/BASAL_MELT_PARAM/interim/NEMO_eORCA025.L121_OPM016_ANT_STEREO/'

# Not referenced in the cells visible in this notebook.
outputpath_boxes = '/bettik/burgardc/DATA/NN_PARAM/interim/BOXES/SMITH_'+nemo_run+'/'
inputpath_raw = '/bettik/burgardc/DATA/NN_PARAM/raw/'

In [None]:
# Existing ice-shelf mask file; used below for its 'isf_name'/'Nisf' info and for
# the longitude/latitude fields passed to def_isf_mask.
file_isf = xr.open_dataset(outputpath_mask + 'nemo_5km_isf_masks_and_info_and_distance_oneFRIS_1970.nc')

# Land-sea mask (variable 'ls_mask012' is read from it later), cut to the map_lim domain.
file_mask = xr.open_dataset(inputpath_data+'custom_lsmask_Ant_stereo_clean.nc')#, chunks={'x': chunk_size, 'y': chunk_size})
file_mask_cut = uf.cut_domain_stereo(file_mask, map_lim, map_lim)

# 3D variables file, cut to the same domain.
# Not referenced again in the cells visible in this notebook.
file_TS_orig = xr.open_dataset(inputpath_data + '3D_variables_of_interest_allyy_Ant_stereo_2000.nc')
file_TS_cut = uf.cut_domain_stereo(file_TS_orig, map_lim, map_lim)

In [None]:
file_isf.Nisf.where(file_isf['isf_name'] == 'Bach', drop=True).values[0]

In [None]:
# Ice-shelf concentration file (variable 'isfdraft_conc' is read from it later),
# cut to the map_lim domain.
# NOTE(review): the name file_conc is re-bound to a single time slice further down.
file_conc = xr.open_dataset(inputpath_data+'isfdraft_conc_Ant_stereo.nc')
file_conc_cut = uf.cut_domain_stereo(file_conc, map_lim, map_lim)

In [None]:
def def_isf_mask(arr_def_ismask, file_msk, file_conc, lon, lat, FRIS_one=True, 
                 mouginot_basins=False, connectivity = 4, threshold = 4):
    
    """
    Define a mask for the individual ice shelves. 
    
    Two modes are available:

    * ``mouginot_basins=True``: ice-shelf points (``file_conc > 0``) are grouped into
      connected components with cc3d, each point receives the basin ID found in
      ``arr_def_ismask['ID_isf']``, a fixed list of IDs is merged, components that
      straddle several basins in known "problem regions" are assigned to the basin
      with the most points, and for a fixed list of split ice shelves only the
      connected component with the largest area keeps its ID.
    * ``mouginot_basins=False``: ice-shelf points are labelled from lon/lat boxes.

    It is meant to work for both stereographic and latlon grids, but only the
    former has been tried.
    
    Parameters
    ----------
    arr_def_ismask : np.array or xr.Dataset
        If ``mouginot_basins`` is False: array whose rows contain 
        minlon,maxlon,minlat,maxlat,is_nb. 
        If ``mouginot_basins`` is True: dataset with the drainage basins; the 
        variables 'ID_isf', 'name_isf', 'name_reg', 'Nisf_orig' and the coordinate 
        'Nisf' are read. The dataset passed in is not modified.
    file_msk : xr.DataArray
        Mask separating ocean (0), ice shelves (between 0 and 2, excluding 0 and 2), grounded ice (2) 
    file_conc : xr.DataArray
        Ice shelf concentration for each point (between 0 and 1), with dimensions (y, x)
    lon : xr.DataArray
        Longitude (depends on x,y for stereographic); only used if ``mouginot_basins`` is False
    lat : xr.DataArray
        Latitude (depends on x,y for stereographic); only used if ``mouginot_basins`` is False
    FRIS_one : Boolean 
        If True, Filchner-Ronne are considered as one ice-shelf
    mouginot_basins : Boolean 
        If True, arr_def_ismask is interpreted as the drainage-basin dataset and the 
        connected-component approach is used; if False, the lon/lat boxes are used
    connectivity : int
        4 or 8 for 2D, defines what is considered a "connected" point
    threshold : int
        Size of lonely pixel areas to remove
        
    Returns
    -------
    mask_file : xr.Dataset or xr.DataArray
        If ``mouginot_basins`` is True: dataset with 'ISF_mask' (basin ID + 1 on ice-shelf 
        points, open ocean is 1, land is 0) and the info variables 'name_isf', 
        'name_reg', 'Nisf_orig' (with 'Nisf' shifted by 1 as well). 
        If ``mouginot_basins`` is False: array showing the coverage of each ice shelf 
        with the respective ID, open ocean is 1, land is 0.
    """    
    
    if mouginot_basins:
        
        # Work on a private deep copy: the name changes below are done through
        # in-place .loc assignment and would otherwise alter the caller's dataset.
        basins_info = arr_def_ismask.copy(deep=True)
        
        # only ice shelves
        isf_only_mask = file_conc > 0
        
        # remove lonely pixels (data type needs to be int), then find connected components
        dusted = cc3d.dust(isf_only_mask.values.astype(np.int64), 
                   threshold = threshold, 
                   connectivity = connectivity, 
                   in_place = False)
        
        labels_out = cc3d.connected_components(dusted, 
                                       connectivity = connectivity)
        
        labelled = xr.DataArray(labels_out, 
                        coords = {"y": file_conc.y, "x": file_conc.x}, 
                        dims = ["y", "x"],
                        name = "labels")
        
        # assign ID for basins (NaN outside of the ice shelves)
        isf_mask_basins = basins_info['ID_isf'].where(isf_only_mask)
        # cut connected areas to area covered by basin stuff
        labelled_isf = labelled.where(np.isfinite(isf_mask_basins))
        
        # creating the mask
        new_mask = isf_mask_basins.copy()
        
        # basin IDs that are merged into another ID: {ID to remove: ID to keep}
        # (58 -> 57 is the Ross merge, see the renaming just below; the meaning of
        # the other pairs is not visible here)
        merged_ids = {58: 57, 151: 99, 109: 107, 116: 5, 143: 97, 137: 99}
        for old_id, kept_id in merged_ids.items():
            new_mask = new_mask.where(new_mask != old_id, kept_id)
                    
        basins_info['name_isf'].loc[{'Nisf': 57}] = 'Ross'
        basins_info['name_isf'].loc[{'Nisf': 58}] = np.nan
        
        if FRIS_one:
            new_mask = new_mask.where(new_mask != 104, 103)
            basins_info['name_isf'].loc[{'Nisf': 103}] = 'Filchner-Ronne'
            basins_info['name_isf'].loc[{'Nisf': 104}] = np.nan

        basins_info['name_isf'] = basins_info['name_isf'].dropna('Nisf')
        
        # do some fine-tuning for overlapping ice shelves   
        problem_regions = [2,3,8,9,10,13,23,26,27,28,29,32,34,38,44,46,50,57,59,60,
                   63,70,71,72,73,74,76,77,78,83,84,85,89,91,96,103]
        
        # cc3d numbers the components 1..max (0 is background), so the upper
        # bound has to be inclusive
        for conn_label in range(1, int(labels_out.max()) + 1):
            # NOTE(review): the IDs are read from the un-merged 'ID_isf' field, so a 
            # dominant ID can be one that was merged away above — confirm this is intended
            basins_conn_domain = basins_info['ID_isf'].where(labelled_isf == conn_label, drop=True)
            # component lies completely outside of the basin field: nothing to decide
            if basins_conn_domain.size == 0:
                continue
            max_label = basins_conn_domain.max().values
            min_label = basins_conn_domain.min().values
            
            # for areas with two labels in problem regions, take the one with the most points
            if max_label != min_label:
                groups_isf = basins_conn_domain.groupby(basins_conn_domain)
                groups_labels = groups_isf.groups.keys()
                points_per_isf = groups_isf.count()
                if points_per_isf.ID_isf.count() > 1:
                    if any(x in problem_regions for x in list(groups_labels)):
                        dominant_isf = points_per_isf.idxmax().values
                        # special case kept from the original script: 12 is replaced by 14
                        if dominant_isf == 12:
                            dominant_isf = 14
                        new_mask = new_mask.where(labelled_isf != conn_label, dominant_isf)
            
        # other fine-tuning: if an ice shelf is split, keep the largest connected domain
        dx = abs(file_conc.x[1] - file_conc.x[0])
        dy = abs(file_conc.y[1] - file_conc.y[0])

        split_regions = [70,77,83,89,103] 
        # ID given to the smaller, disconnected parts of a split ice shelf
        leftover_id = 159

        for rreg in split_regions:
            # look where there are the same labels in several unconnected domains
            grouped_by_label = new_mask.groupby(labelled_isf)
            labels_same = list(grouped_by_label.groups) * (grouped_by_label.median() == rreg)
            labels_same = labels_same[labels_same>0]
            
            # this ice shelf is not present in the domain: nothing to split
            # (also avoids using an undefined/stale largest_label below)
            if labels_same.size == 0:
                continue

            area_before = 0
            for conn_label in labels_same:
                # compute the area of the different unconnected areas
                conc_for_area = file_conc.where(labelled_isf == conn_label, drop=True)
                area_now = (conc_for_area * dx * dy).sum()
                if area_now >= area_before:
                    area_before = area_now
                    largest_label = conn_label

            # set the smaller areas to the leftover ID
            for small_label in (labels_same.where(labels_same != largest_label).dropna('labels')):
                new_mask = new_mask.where(labelled_isf != small_label, leftover_id)

        # shift all IDs by one so that 0 and 1 are free for land and open ocean
        new_mask = new_mask + 1
        new_mask_info = basins_info
        new_mask_info['Nisf'] = new_mask_info['Nisf'] + 1
    
    else:
        
        # rows with maxlat == -50 are the general boxes, they are applied first so
        # that the detailed boxes can overwrite them
        arr_def_general = arr_def_ismask[arr_def_ismask[:, 3] == -50]
        arr_def_detail = arr_def_ismask[arr_def_ismask[:, 3] != -50]

        isf_yes = (file_msk > 0) & (file_msk < 2)
        isf_mask = file_msk.copy()
        for box_group in (arr_def_general, arr_def_detail):
            for mm in box_group:
                in_box = uf.in_range(lon, mm[0:2]) & uf.in_range(lat, mm[2:4])
                isf_mask = isf_mask.where(~in_box, int(mm[4]))
        # keep the IDs on ice-shelf points only
        new_mask = isf_mask.where(isf_yes)

        if FRIS_one:
            new_mask = new_mask.where(new_mask != 21, 11) # Filchner (21) and Ronne (11) are combined
    
    # open ocean is 1, land is 0
    new_mask = new_mask.where(file_msk != 0, 1).where(file_msk != 2, 0)
    
    if mouginot_basins:
        mask_file = xr.merge([new_mask.rename('ISF_mask'), 
                              new_mask_info['name_isf'], 
                              new_mask_info['name_reg'], 
                              new_mask_info['Nisf_orig']])
    else:
        mask_file = new_mask
    
    return mask_file

In [None]:
# Select the fields of one time step as 2D inputs for def_isf_mask.
tt = 1970
file_msk = file_mask_cut['ls_mask012'].sel(time=tt)
# drop_vars replaces the deprecated DataArray.drop for removing the scalar time coordinate
file_conc = file_conc_cut['isfdraft_conc'].sel(time=tt).drop_vars('time')
# Basin dataset; 'ID_isf' and 'name_isf' (indexed by 'Nisf') are read from it below.
summary_mask_basins = xr.open_dataset('/bettik/burgardc/DATA/NN_PARAM/interim/basins_mask_extrap_50km.nc')


In [None]:
summary_mask_basins['name_isf'].sel(Nisf=23)

In [None]:
mask_file = def_isf_mask(summary_mask_basins, file_msk, file_conc, file_isf.longitude, file_isf.latitude, FRIS_one=True, 
                 mouginot_basins=True, connectivity = 4, threshold = 4)

In [None]:
# Check whether any entry of the mask info still refers to the original ID 103.
# The size is tested explicitly: the truth value of a DataArray is ambiguous
# (it raises for more than one element and is deprecated for empty arrays).
nisf_of_orig_103 = mask_file.Nisf.where(mask_file['Nisf_orig'] == 103, drop=True)
if nisf_of_orig_103.size == 0:
    print('lol')

In [None]:
# Plot the final ice-shelf mask. `mask` was never defined in this notebook; the mask
# returned by def_isf_mask is stored in mask_file['ISF_mask'] and already has
# open ocean set to 1 and land set to 0.
mask_file['ISF_mask'].plot()

In [None]:
isf_only_mask = file_conc_cut['isfdraft_conc'] > 0

In [None]:
# 2D ice-shelf presence mask of one time step (index 30);
# drop_vars replaces the deprecated DataArray.drop
isf_only_mask_00 = isf_only_mask.isel(time=30).drop_vars('time')

In [None]:
#level of connectivity (4 or 8 for 2D)
connectivity = 4
#size of lonely pixel areas to remove # Rieke put 25
threshold = 4

In [None]:
#remove lonely pixels
#data is sliced to exclude time, data type needs to be int!!
dusted = cc3d.dust(isf_only_mask_00.values.astype(np.int64), 
                   threshold = threshold, 
                   connectivity = connectivity, 
                   in_place = False)
#find connected components
labels_out = cc3d.connected_components(dusted, 
                                       connectivity = connectivity)


In [None]:
labelled = xr.DataArray(labels_out, 
                        coords = {"y": file_conc_cut.y, "x": file_conc_cut.x}, 
                        dims = ["y", "x"],
                        name = "labels")

In [None]:
labelled.plot()

In [None]:
summary_mask_basins = xr.open_dataset('/bettik/burgardc/DATA/NN_PARAM/interim/basins_mask_extrap_50km.nc')


In [None]:
isf_mask_basins = summary_mask_basins['ID_isf'].where(isf_only_mask_00 > 0)

In [None]:
labelled_isf = labelled.where(np.isfinite(isf_mask_basins))

In [None]:
labelled_isf.plot()

In [None]:
isf_only_mask_00.plot()

In [None]:
combined_test_isf_mask = isf_mask_basins.copy()

In [None]:
labels_out.max()

In [None]:
isf_mask_basins.plot()

In [None]:
# basins_conn_area is only defined in a cell further down, so on a fresh top-to-bottom
# run this cell failed with a NameError. Define it here with the same expression
# (basin IDs on the points carrying connected-component label 0).
basins_conn_area = summary_mask_basins['ID_isf'].where(labelled_isf == 0, drop=True)
groups = basins_conn_area.groupby(basins_conn_area)

In [None]:
groups.count().ID_isf.count()

MERGE ROSS AND FRIS

In [None]:
new_mask = isf_mask_basins.copy()

In [None]:
# Merge basin IDs: every key is replaced by its value in the mask.
# None of the replacement IDs is itself replaced, so applying the
# substitutions one after the other gives the same result as one chained call.
id_replacements = {58: 57, 104: 103, 151: 99, 109: 107, 116: 5, 143: 97, 137: 99}
for removed_id, kept_id in id_replacements.items():
    new_mask = new_mask.where(new_mask != removed_id, kept_id)

In [None]:
# Names after the merge: the kept IDs get the name of the combined ice shelf,
# the removed IDs are set to NaN so that dropna can discard them.
merged_names = {57: 'Ross', 58: np.nan, 103: 'Filchner-Ronne', 104: np.nan}
for nisf, isf_name in merged_names.items():
    summary_mask_basins['name_isf'].loc[{'Nisf': nisf}] = isf_name
summary_mask_basins['name_isf'] = summary_mask_basins['name_isf'].dropna('Nisf')

In [None]:
problem_regions = [2,3,8,9,10,13,23,26,27,28,29,32,34,38,44,46,50,57,59,60,
                   63,70,71,72,73,74,76,77,78,83,84,85,89,91,96,103]

In [None]:
# For every connected ice-shelf area that contains several basin IDs, at least one of
# them in problem_regions, assign the whole area to the ID covering the most points.
# cc3d numbers the components 1..max (0 is background), so the upper bound is inclusive.
for conn_label in range(1, int(labels_out.max()) + 1):
    basins_conn_domain = summary_mask_basins['ID_isf'].where(labelled_isf == conn_label, drop=True)
    # component lies completely outside of the basin field: nothing to decide
    if basins_conn_domain.size == 0:
        continue
    max_label = basins_conn_domain.max().values
    min_label = basins_conn_domain.min().values

    if max_label != min_label:
        groups_isf = basins_conn_domain.groupby(basins_conn_domain)
        groups_labels = groups_isf.groups.keys()
        points_per_isf = groups_isf.count()
        if points_per_isf.ID_isf.count() > 1:
            if any(x in problem_regions for x in list(groups_labels)):
                dominant_isf = points_per_isf.idxmax().values
                # special case: 12 is replaced by 14
                if dominant_isf == 12:
                    dominant_isf = 14
                new_mask = new_mask.where(labelled_isf != conn_label, dominant_isf)


In [None]:
conc_00 = file_conc_cut['isfdraft_conc'].isel(time=30)
dx = abs(conc_00.x[1] - conc_00.x[0])
dy = abs(conc_00.y[1] - conc_00.y[0])

In [None]:
# If an ice shelf is split into several unconnected areas, keep its ID only on the
# area with the largest ice-shelf area and set the smaller ones to 159.
split_regions = [70,77,83,89] 

for rreg in split_regions:
    # connected-component labels whose median basin ID is rreg (label 0 is excluded)
    grouped_by_label = new_mask.groupby(labelled_isf)
    labels_same = list(grouped_by_label.groups) * (grouped_by_label.median() == rreg)
    labels_same = labels_same[labels_same>0]

    # rreg is not present: skip, otherwise largest_label would be undefined
    # (or left over from the previous region)
    if labels_same.size == 0:
        continue

    area_before = 0
    for conn_label in labels_same:
        # area of this connected part, weighted by the ice-shelf concentration
        conc_for_area = conc_00.where(labelled_isf == conn_label, drop=True)
        area_now = (conc_for_area * dx * dy).sum()
        if area_now >= area_before:
            area_before = area_now
            largest_label = conn_label

    for small_label in (labels_same.where(labels_same != largest_label).dropna('labels')):
        new_mask = new_mask.where(labelled_isf != small_label, 159)

In [None]:
# NOTE(review): unfinished experiment. The loop variable kisf is never used, the
# condition is hard-coded to ID 70, and the groupby result in the body is discarded,
# so this cell has no effect besides the computation time.
for kisf in summary_mask_basins['Nisf']:
    if (new_mask.groupby(labelled_isf).median() == 70).sum() > 1:
        new_mask.groupby(labelled_isf)

In [None]:
(new_mask.groupby(labelled_isf).median() == 70)

In [None]:
(new_mask.groupby(labelled_isf).where(new_mask.groupby(labelled_isf).max() == 70)).plot()

In [None]:
labels_same = list(new_mask.groupby(labelled_isf).groups) * (new_mask.groupby(labelled_isf).median() == 70)
labels_same = labels_same[labels_same>0]


In [None]:
new_mask.plot()

In [None]:
# Test of the "keep the largest connected area" step for the labels in labels_same.
dx = abs(conc_00.x[1] - conc_00.x[0])
dy = abs(conc_00.y[1] - conc_00.y[0])

area_before = 0
# stays None if labels_same is empty, so no stale or undefined label is used below
largest_label = None
for conn_label in labels_same:
    # area of this connected part, weighted by the ice-shelf concentration
    conc_for_area = conc_00.where(labelled_isf == conn_label, drop=True)
    area_now = (conc_for_area * dx * dy).sum()
    if area_now >= area_before:
        area_before = area_now
        largest_label = conn_label

test_mask = new_mask.copy()
if largest_label is not None:
    # all but the largest connected part are set to 159
    for small_label in (labels_same.where(labels_same != largest_label).dropna('labels')):
        print(small_label.values)
        test_mask = test_mask.where(labelled_isf != small_label, 159)

In [None]:
test_mask.plot()

In [None]:
conc_00.where(labelled_isf == 80).plot()

In [None]:
largest_label

In [None]:
new_ground_mask = isfmf.def_ground_mask(file_mask_cut['ls_mask012'].isel(time=30), 40, 120)

In [None]:
plt.contour(new_ground_mask.x,new_ground_mask.y,new_ground_mask,levels=[0,1],linewidths=0.5,colors='black',zorder=10)
file_conc_cut['isfdraft_conc'].isel(time=0).where(new_mask != 99).plot()

In [None]:
plt.contour(new_ground_mask.x,new_ground_mask.y,new_ground_mask,levels=[0,1],linewidths=0.5,colors='black',zorder=10)
summary_mask_basins['ID_isf'].plot(vmin=26,vmax=28)

In [None]:
plt.contour(new_ground_mask.x,new_ground_mask.y,new_ground_mask,levels=[0,1],linewidths=0.5,colors='black',zorder=10)
new_mask.plot(vmin=26,vmax=28)

In [None]:
plt.contour(new_ground_mask.x,new_ground_mask.y,new_ground_mask,levels=[0,1],linewidths=0.5,colors='black',zorder=10)
conc_00.where(labelled_isf == 80).plot()

In [None]:
file_conc_cut['isfdraft_conc'].isel(time=0).where(new_mask != 151).plot()

In [None]:
all_isf_areas = new_mask.groupby(new_mask).count()
large_isf_Nisf = all_isf_areas.where(all_isf_areas > 100, drop=True).ID_isf

In [None]:
for idx in large_isf_Nisf:
    if idx not in [12, 96, 158]:
        print(idx.values, summary_mask_basins['name_isf'].sel(Nisf=idx).values)

In [None]:
summary_mask_basins['ID_isf'].where(labelled_isf == 33, drop=True).plot()

In [None]:
# Debug version of the problem-region step for one single connected area (label 47),
# with the intermediate values printed.
conn_label = 47
# NOTE(review): this resets new_mask to the un-merged basin IDs, i.e. it discards the
# merging and fine-tuning done in the cells above.
new_mask = isf_mask_basins.copy()
basins_conn_domain = summary_mask_basins['ID_isf'].where(labelled_isf == conn_label, drop=True)
max_label = basins_conn_domain.max().values
min_label = basins_conn_domain.min().values
    
# only areas containing more than one basin ID need a decision
if max_label != min_label:
    groups_isf = basins_conn_domain.groupby(basins_conn_domain)
    groups_labels = groups_isf.groups.keys()
    if groups_isf.count().ID_isf.count() > 1:
        if any(x in problem_regions for x in list(groups_labels)):
            print(conn_label)
            print(min_label,max_label)
            # basin ID covering the most points of this connected area
            dominant_isf = groups_isf.count().idxmax().values
            print(dominant_isf)
            new_mask = new_mask.where(labelled_isf != conn_label, dominant_isf)
            


In [None]:
new_mask.where(labelled_isf == conn_label).plot()

In [None]:
groups_isf.count().idxmax()

In [None]:
basins_conn_domain = summary_mask_basins['ID_isf'].where(labelled_isf == 47, drop=True)

In [None]:
basins_conn_domain.plot()

In [None]:
# max_area / min_area are not defined anywhere in this notebook; the values
# computed in the cells above are max_label / min_label.
max_label != min_label

In [None]:
groups_isf.count().ID_isf.count()

In [None]:
for ii in groups.groups:
    print(ii)

In [None]:
# Membership has to be tested on the group keys: iterating a GroupBy object
# yields (label, group) pairs, so `4 in groups` could never be True.
if 4 in groups.groups:
    print('yes')

In [None]:
basins_conn_area = summary_mask_basins['ID_isf'].where(labelled_isf == 0, drop=True)

In [None]:
groups = basins_conn_area.groupby(basins_conn_area)

In [None]:
groups

In [None]:
# `file_conc_cutbasins_conn_area` is not a defined name (two names were fused together);
# plot the basin IDs of the selected connected area.
basins_conn_area.plot()


In [None]:
file_conc_cut['isfdraft_conc'].isel(time=0).where(basins_conn_domain, drop=True).plot()

In [None]:
isf_only_mask_00.where(basins_conn_domain, drop=True).plot()

In [None]:
basins_conn_area.plot()

In [None]:
file_conc_cut['isfdraft_conc'].isel(time=0).where(summary_mask_basins['ID_isf'] < 2).plot()

In [None]:
for idx in summary_mask_basins['name_isf'].Nisf:
    print(idx.values, summary_mask_basins['name_isf'].sel(Nisf=idx).values)

In [None]:
isf_mask_basins.plot()