In [None]:
"""
Created on Wed Feb 22 10:37 2023

Apply Rieke's connected-components script from the CryoHackathon to better define the ice-shelf masks

@author: Clara Burgard
"""

In [None]:
import cc3d
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from scipy.interpolate import griddata
from tqdm.notebook import tqdm

import basal_melt_param.create_isf_mask_functions as isfmf
import basal_melt_param.plume_functions as pf
import basal_melt_param.useful_functions as uf


In [None]:
# Select the Qt5 matplotlib backend for the figures of this notebook
%matplotlib qt5

READ IN DATA

In [None]:
# Identifier of the NEMO run; it is inserted into the data paths below
nemo_run = 'bi646'
# [min, max] bounds handed to uf.cut_domain_stereo for both x and y
map_lim = [-3_000_000, 3_000_000]


In [None]:
# Directories that depend on the selected NEMO run
outputpath_mask = f'/bettik/burgardc/DATA/NN_PARAM/interim/ANTARCTICA_IS_MASKS/SMITH_{nemo_run}/'
inputpath_data = f'/bettik/burgardc/DATA/NN_PARAM/interim/SMITH_{nemo_run}/'
outputpath_boxes = f'/bettik/burgardc/DATA/NN_PARAM/interim/BOXES/SMITH_{nemo_run}/'

# Directories that are independent of the run
inputpath_data2 = '/bettik/burgardc/DATA/BASAL_MELT_PARAM/interim/NEMO_eORCA025.L121_OPM016_ANT_STEREO/'
inputpath_raw = '/bettik/burgardc/DATA/NN_PARAM/raw/'

In [None]:
# Existing mask file, read from the mask output directory
file_isf = xr.open_dataset(
    f'{outputpath_mask}nemo_5km_isf_masks_and_info_and_distance_oneFRIS_yy54.nc'
)

# Land/sea mask, cropped to map_lim in x and y
# (chunked reading was tried before: chunks={'x': chunk_size, 'y': chunk_size})
file_mask = xr.open_dataset(f'{inputpath_data}custom_lsmask_Ant_stereo_clean.nc')
file_mask_cut = uf.cut_domain_stereo(file_mask, map_lim, map_lim)

# 3D variables file, cropped to the same bounds
file_TS_orig = xr.open_dataset(
    f'{inputpath_data}3D_variables_of_interest_allyy_Ant_stereo_2000.nc'
)
file_TS_cut = uf.cut_domain_stereo(file_TS_orig, map_lim, map_lim)

In [None]:
# File holding 'isfdraft_conc', cropped to map_lim in x and y
file_conc = xr.open_dataset(f'{inputpath_data}isfdraft_conc_Ant_stereo.nc')
file_conc_cut = uf.cut_domain_stereo(file_conc, map_lim, map_lim)

In [None]:
# Boolean mask: True wherever 'isfdraft_conc' is greater than zero
isf_only_mask = file_conc_cut['isfdraft_conc'] > 0

In [None]:
# 2D mask of the first time step, without the leftover scalar 'time' coordinate.
# drop_vars() replaces the deprecated DataArray.drop() for removing a coordinate.
isf_only_mask_00 = isf_only_mask.isel(time=0).drop_vars('time')

In [None]:
# Pixel connectivity used for dusting and labelling (4 or 8 for 2D)
connectivity = 8
# Size threshold passed to cc3d.dust to remove lonely pixel areas (Rieke used 25)
threshold = 5

In [None]:
# Step 1: remove lonely pixel areas below the size threshold.
# The mask has no time dimension any more and must be cast to an integer type for cc3d.
dusted = cc3d.dust(
    isf_only_mask_00.astype(np.int64).values,
    threshold=threshold,
    connectivity=connectivity,
    in_place=False,
)
# Step 2: give every remaining connected component its own integer label
labels_out = cc3d.connected_components(dusted, connectivity=connectivity)


In [None]:
# Wrap the label array in a DataArray that shares the y/x coordinates of file_conc_cut
labelled = xr.DataArray(
    data=labels_out,
    dims=("y", "x"),
    coords=dict(y=file_conc_cut.y, x=file_conc_cut.x),
    name="labels",
)

In [None]:
# Quick-look plot of the connected-component labels
labelled.plot()

In [None]:
# Dataset providing 'ID_isf' and 'name_isf', both used in the cells below
summary_mask_basins = xr.open_dataset('/bettik/burgardc/DATA/NN_PARAM/interim/basins_mask_extrap_50km.nc')


In [None]:
# Keep ID_isf only where isf_only_mask_00 is True; every other pixel becomes NaN
isf_mask_basins = summary_mask_basins['ID_isf'].where(isf_only_mask_00 > 0)

In [None]:
# Keep component labels only where isf_mask_basins holds a finite ID; NaN elsewhere
labelled_isf = labelled.where(np.isfinite(isf_mask_basins))

In [None]:
# Quick-look plot of the component labels restricted to pixels with a finite ID
labelled_isf.plot()

In [None]:
# Quick-look plot of the 2D boolean mask
isf_only_mask_00.plot()

In [None]:
# Independent copy of isf_mask_basins
# NOTE(review): this variable is not read anywhere else in the visible notebook
combined_test_isf_mask = isf_mask_basins.copy()

In [None]:
# Highest connected-component label found
labels_out.max()

In [None]:
# Quick-look plot of the IDs restricted to isf_only_mask_00
isf_mask_basins.plot()

In [None]:
# basins_conn_area was only assigned in a cell further down, so this cell raised a
# NameError on a fresh kernel. It is defined here with that same expression:
# ID_isf on the pixels of labelled_isf that carry label 0.
basins_conn_area = summary_mask_basins['ID_isf'].where(labelled_isf == 0, drop=True)
# Group those pixels by their ID value
groups = basins_conn_area.groupby(basins_conn_area)

In [None]:
# Number of distinct ID_isf values present in the grouping
groups.count().ID_isf.count()

MERGE ROSS AND FRIS

In [None]:
# Start the corrected mask from an independent copy of isf_mask_basins
new_mask = isf_mask_basins.copy()

In [None]:
# Relabel ID 58 as 57 (named 'Ross' in the next cell) ...
new_mask = new_mask.where(new_mask != 58, 57)
# ... and ID 104 as 103 (named 'Filchner-Ronne' in the next cell)
new_mask = new_mask.where(new_mask != 104, 103)

In [None]:
# Give the kept IDs (57, 103) a single name and blank out the names of the IDs 58 and 104
# NOTE(review): if 'name_isf' has a fixed-width string dtype, np.nan would be stored as
# text instead of a missing value — confirm the dtype.
summary_mask_basins['name_isf'].loc[{'Nisf': 57}] = 'Ross'
summary_mask_basins['name_isf'].loc[{'Nisf': 58}] = np.nan
summary_mask_basins['name_isf'].loc[{'Nisf': 103}] = 'Filchner-Ronne'
summary_mask_basins['name_isf'].loc[{'Nisf': 104}] = np.nan
# NOTE(review): assigning the shortened array back into the Dataset presumably re-aligns it
# to the Dataset's existing Nisf index, so the dropped entries may come back as missing
# values — verify that 58 and 104 are really gone afterwards.
summary_mask_basins['name_isf'] = summary_mask_basins['name_isf'].dropna('Nisf')

In [None]:
# IDs that trigger a reassignment when they appear in a connected component
# (presumably the ice shelves whose masks are badly separated — TODO confirm)
problem_regions = [70,71,72,73,74,76,77,78,83,84,85,89,91]

In [None]:
# For every connected component that contains several IDs, at least one of which is
# listed in problem_regions, assign the whole component to the ID covering most pixels.
#
# cc3d numbers the components 1..N (0 is background), so N equals labels_out.max() and
# the range has to include it. The former upper bound skipped the last component.
for conn_label in range(1, int(labels_out.max()) + 1):
    basins_conn_domain = summary_mask_basins['ID_isf'].where(labelled_isf == conn_label, drop=True)
    # A component without any pixel left in labelled_isf gives an empty selection;
    # min()/max() can raise on it and there is nothing to reassign anyway.
    if basins_conn_domain.size == 0:
        continue
    max_label = basins_conn_domain.max().values
    min_label = basins_conn_domain.min().values

    # Different min and max means more than one ID inside this component
    if max_label != min_label:
        groups_isf = basins_conn_domain.groupby(basins_conn_domain)
        groups_labels = groups_isf.groups.keys()
        # Pixel count per ID, computed once and reused below
        isf_counts = groups_isf.count()
        if isf_counts.ID_isf.count() > 1:
            if any(x in problem_regions for x in list(groups_labels)):
                # ID with the largest pixel count inside this component
                dominant_isf = isf_counts.idxmax().values
                # Overwrite the pixels of this component, leave all others untouched
                new_mask = new_mask.where(labelled_isf != conn_label, dominant_isf)


In [None]:
# Quick-look plot of the corrected mask
new_mask.plot()

In [None]:
# Ground mask computed by isfmf.def_ground_mask from 'ls_mask012' at the first time step
# NOTE(review): the meaning of the arguments 40 and 120 is not visible here — confirm
new_ground_mask = isfmf.def_ground_mask(file_mask_cut['ls_mask012'].isel(time=0), 40, 120)

In [None]:
# Thin black contour lines of the ground mask, drawn with a high zorder
plt.contour(
    new_ground_mask.x,
    new_ground_mask.y,
    new_ground_mask,
    levels=[0, 1],
    linewidths=0.5,
    colors='black',
    zorder=10,
)
# Show only the pixels of the corrected mask that carry ID 99
new_mask.where(new_mask == 99).plot()

In [None]:
# Number of pixels per ID in the corrected mask
all_isf_areas = new_mask.groupby(new_mask).count()
# IDs that cover more than 100 pixels
large_isf_Nisf = all_isf_areas.where(all_isf_areas > 100, drop=True).ID_isf

In [None]:
# List ID and name of every large region, leaving out the IDs 12, 96 and 158
for idx in large_isf_Nisf:
    if idx in [12, 96, 158]:
        continue
    print(idx.values, summary_mask_basins['name_isf'].sel(Nisf=idx).values)

In [None]:
# IDs found inside connected component 33, cropped to that component
summary_mask_basins['ID_isf'].where(labelled_isf == 33, drop=True).plot()

In [None]:
# Step-by-step check of one connected component with the same logic as the loop
# over all components, printing the intermediate values.
# NOTE(review): new_mask is reset to a copy of isf_mask_basins here, which discards
# the changes earlier cells made to new_mask — confirm this is intended.
conn_label = 47
new_mask = isf_mask_basins.copy()
basins_conn_domain = summary_mask_basins['ID_isf'].where(labelled_isf == conn_label, drop=True)
min_label = basins_conn_domain.min().values
max_label = basins_conn_domain.max().values

if max_label != min_label:
    groups_isf = basins_conn_domain.groupby(basins_conn_domain)
    groups_labels = groups_isf.groups.keys()
    # More than one ID present, and at least one of them is a problem region
    if groups_isf.count().ID_isf.count() > 1 and any(label in problem_regions for label in groups_labels):
        print(conn_label)
        print(min_label,max_label)
        # ID with the largest pixel count inside this component
        dominant_isf = groups_isf.count().idxmax().values
        print(dominant_isf)
        new_mask = new_mask.where(labelled_isf != conn_label, dominant_isf)
            


In [None]:
# Plot the mask values on the pixels of the selected connected component only
new_mask.where(labelled_isf == conn_label).plot()

In [None]:
# ID with the largest pixel count in the current grouping
groups_isf.count().idxmax()

In [None]:
# IDs found inside connected component 47, cropped to that component
basins_conn_domain = summary_mask_basins['ID_isf'].where(labelled_isf == 47, drop=True)

In [None]:
# Quick-look plot of the IDs inside the selected component
basins_conn_domain.plot()

In [None]:
# True when the selected component contains more than one ID.
# The names max_area/min_area were never assigned in this notebook; the values are
# stored as max_label/min_label.
max_label != min_label

In [None]:
# Number of distinct IDs in the current grouping
groups_isf.count().ID_isf.count()

In [None]:
# Print the label of every group (the keys of the groups mapping)
for ii in groups.groups:
    print(ii)

In [None]:
# Test membership on the groups mapping, which is keyed by group label.
# Iterating the GroupBy object itself yields (label, group) pairs, so a bare
# `4 in groups` could never be true.
if 4 in groups.groups:
    print('yes')

In [None]:
# IDs on the pixels of labelled_isf that carry label 0, cropped to those pixels
basins_conn_area = summary_mask_basins['ID_isf'].where(labelled_isf == 0, drop=True)

In [None]:
# Group the label-0 pixels by their ID value
groups = basins_conn_area.groupby(basins_conn_area)

In [None]:
# Display the GroupBy summary
groups

In [None]:
# Quick-look plot of the IDs on the label-0 pixels.
# The former name `file_conc_cutbasins_conn_area` was two identifiers fused together
# and raised a NameError.
basins_conn_area.plot()


In [None]:
# Plot 'isfdraft_conc' at the first time step on the pixels of the selected component.
# basins_conn_domain holds float IDs with NaN outside the component; used directly as a
# condition, NaN counts as truthy and nothing is masked, so test for non-missing values.
file_conc_cut['isfdraft_conc'].isel(time=0).where(basins_conn_domain.notnull(), drop=True).plot()

In [None]:
# Plot the boolean mask on the pixels of the selected component.
# basins_conn_domain holds float IDs with NaN outside the component; used directly as a
# condition, NaN counts as truthy and nothing is masked, so test for non-missing values.
isf_only_mask_00.where(basins_conn_domain.notnull(), drop=True).plot()

In [None]:
# Quick-look plot of the IDs on the label-0 pixels
basins_conn_area.plot()

In [None]:
# Plot 'isfdraft_conc' at the first time step where ID_isf is below 2
file_conc_cut['isfdraft_conc'].isel(time=0).where(summary_mask_basins['ID_isf'] < 2).plot()

In [None]:
# Print every Nisf value together with the name stored for it
for idx in summary_mask_basins['name_isf'].Nisf:
    print(idx.values, summary_mask_basins['name_isf'].sel(Nisf=idx).values)

In [None]:
# Quick-look plot of the uncorrected IDs restricted to isf_only_mask_00
isf_mask_basins.plot()