In [None]:
import xarray as xr
import param
import numpy

from scipy.ndimage import measurements
from scipy.signal import convolve2d

import holoviews as hv
colormaps = hv.plotting.list_cmaps()

import hvplot.xarray
from holoviews.selection import link_selections
from holoviews import opts
# Default options applied to every hv.Image built in this notebook.
# The values were taken from holoviews.Store.custom_options for an
# xarray.Dataset.hvplot(), presumably so plain hv.Image plots match the
# look of hvplot output.
opts.defaults(
    opts.Image(
        # Sizing: stretch to the available width, keeping a 2:1 aspect
        responsive=True,
        aspect=2,
        height=300,
        # Axes
        logx=False,
        logy=False,
        shared_axes=True,
        show_grid=False,
        # Decorations
        colorbar=True,
        show_legend=True,
        # Interaction: 'hover' is the hvplot default; the lasso and box
        # selection tools are added for picking cells to edit.
        tools=['hover', 'lasso_select', 'box_select'],
    )
)

import panel as pn

import io
import os

In [None]:
class ValueChanger(param.Parameterized):
    """Interactive Panel app to inspect and edit one variable of a NetCDF file.

    Workflow: upload a NetCDF file, pick a data variable, select cells through
    the linked selection (``self.selection``), then overwrite the selected
    values with an absolute value, an offset or a percentage change. The
    edited dataset can be downloaded with the Save button. Optional extra
    plots show internal oceans (connected regions whose value is <= 0) and
    one-cell-wide "diffusive" passages.

    NOTE(review): every widget and pane below is a *class* attribute, so all
    instances share them. Creating a second ``ValueChanger`` in the same
    kernel (e.g. re-running the last cell) appends another FileInput to the
    shared ``file_pane`` and registers another ``apply`` callback. Moving them
    into ``__init__`` would fix this, but it needs care with how the
    ``pn.depends`` specs below resolve.
    """

    # How we are going to modify the values
    # Absolute => Set to that value
    # Relatif => Base value + new value
    # Percentage => Base value + percentage
    calculation_type = pn.widgets.RadioButtonGroup(options=['Absolute', 'Relatif', 'Percentage'], align='end')
    # Replacement value (or offset / percentage, depending on calculation_type)
    spinner = pn.widgets.Spinner(name='Replacement Value', value=0, align='start')

    # Buttons
    save = pn.widgets.FileDownload(label='Save', align='end', button_type='success')
    apply = pn.widgets.Button(name='Apply', align='end', button_type='primary')
    # Mask: when checked, the color range collapses to [mask_value, mask_value]
    mask = pn.widgets.Checkbox(name='Mask', max_width=100, align='end')
    mask_value = pn.widgets.Spinner(name='Mask Value', value=0)
    # Show extra graphs
    show_internal_oceans = pn.widgets.Checkbox(name='Show Internal Oceans', align='start')
    show_passage_problems = pn.widgets.Checkbox(name='Show Diffusion Passages ', align='start')
    # Variable we want to look at and modify (a Select widget, set in __init__)
    attribute = param.String()
    # File upload (a FileInput widget, set in __init__)
    file = param.Parameter()
    # Choose colormap (a Select widget, set in __init__)
    colormap = param.String()
    colormap_min = pn.widgets.Spinner(name='Min Value')
    colormap_max = pn.widgets.Spinner(name='Max Value')
    nb_discrete_colormap = pn.widgets.Spinner(name='Number color intervals', start=1, value=1)
    # xarray.Dataset currently being viewed / edited
    ds = param.Parameter()
    # Link the viewing of multiple graphs together
    selection = link_selections.instance(unselected_alpha=0.4)

    # True once a dataset is fully loaded and ready to be plotted
    loaded = param.Parameter()

    # Parts of the display
    file_pane = pn.Row()
    graph_pane = pn.Column()
    options_pane = pn.Column()

    def __init__(self, **params):
        """Install widget defaults, then wire the Apply/Save callbacks.

        The FileInput and the two Select widgets are installed as *parameter
        defaults* before ``super().__init__`` runs. The instance's ``file``,
        ``attribute`` and ``colormap`` therefore hold these widgets.
        """
        self.param.file.default = pn.widgets.FileInput()
        self.param.ds.default = xr.Dataset()
        self.param.loaded.default = False
        self.param.attribute.default = pn.widgets.Select(max_width=200, align='end')
        self.param.colormap.default = pn.widgets.Select(name='Colormap', options=colormaps, value='terrain', max_width=200, align='end')
        super().__init__(**params)
        self.apply.on_click(self._apply_values)
        self.save.callback = self._save
        self.file_pane.append(self.file)
        # Names of the original curvilinear coordinates, or None for a standard grid
        self.curvilinear_coordinates = None

    @pn.depends("file.value", watch=True)
    def _parse_file_input(self):
        """Load the uploaded bytes into ``self.ds`` and prepare the widgets.

        For curvilinear grids (2-D coordinate variables), the dimension
        indices become the coordinates (ij indexing). The original
        curvilinear coordinates are demoted to data variables, and their
        names are remembered in ``self.curvilinear_coordinates`` so that
        ``_save`` can restore them.

        Raises:
            ValueError: if the first coordinate variable is neither 1-D nor 2-D.
        """
        value = self.file.value
        # Nothing to parse (no bytes uploaded / input cleared)
        if not value:
            return
        self.loaded = False
        # We are dealing with a h5netcdf file ->
        # The reader can't read bytes so we need to write it to a file like object
        if value.startswith(b"\211HDF\r\n\032\n"):
            value = io.BytesIO(value)
        ds = xr.open_dataset(value)
        # loaded is False here, so any attribute.value change caused by the new
        # options is ignored by _on_attribute_change.
        self.attribute.options = list(ds.keys())
        self.curvilinear_coordinates = None

        number_coordinates_in_system = len(list(ds.coords.variables.values())[0].dims)
        # Standard Grid
        if number_coordinates_in_system == 1:
            pass
        # Curvilinear coordinates
        elif number_coordinates_in_system == 2:
            dims = list(ds[list(ds.coords)[0]].dims)
            # Store the true coordinates for export
            self.curvilinear_coordinates = list(ds.coords)
            # Add the dimension into the coordinates this results in an ij indexing
            ds.coords[dims[0]] = ds[dims[0]]
            ds.coords[dims[1]] = ds[dims[1]]
            # Remove the curvilinear coordinates from the original coordinates
            ds = ds.reset_coords()
        else:
            raise ValueError("Unknown number of Coordinates")
        self.ds = ds
        self._reset_colormap_range()
        self.loaded = True
        return True

    def _reset_colormap_range(self):
        """Set the Min/Max spinners so they enclose the selected variable's values.

        floor/ceil are used rather than ``int()``. ``int()`` truncates toward
        zero, which moves a negative minimum or a positive maximum *inside*
        the data range. The extreme cells were then painted with the clipping
        colors, and a variable spanning e.g. [0.1, 0.9] got the empty range
        (0, 0). The spinners are left untouched when the variable has no
        finite value.
        """
        values = self.ds[self.attribute.value]
        vmin = float(values.min())
        vmax = float(values.max())
        if not (numpy.isfinite(vmin) and numpy.isfinite(vmax)):
            return
        self.colormap_min.value = int(numpy.floor(vmin))
        self.colormap_max.value = int(numpy.ceil(vmax))

    @pn.depends("attribute.value", watch=True)
    def _on_attribute_change(self):
        """Re-range and redraw when the user picks another variable.

        ``get_plots`` does not watch ``attribute.value``. Without this hook the
        map would keep showing the previous variable while Apply edits the
        newly selected one.
        """
        # While _parse_file_input runs, loaded is False and it sets the range itself.
        if not self.loaded:
            return
        # A selection made on the previous variable's plots must not carry over.
        self.selection.selection_expr = None
        # Same False -> True toggle as _set_values: the range updates trigger
        # get_plots while loaded is False (no-op), then it redraws exactly once.
        self.loaded = False
        self._reset_colormap_range()
        self.loaded = True

    def _set_values(self):
        """Apply the spinner value to the selected cells of the current variable.

        Depending on ``calculation_type``, the selected cells are set to the
        value ('Absolute'), offset by it ('Relatif'), or scaled by
        (100 + value) / 100 ('Percentage').

        NOTE(review): with no active selection, ``hvds.select(None)`` appears
        to return the whole dataset, so Apply would modify every cell. Confirm
        this is intended.
        """
        attribute = self.attribute.value
        dims = list(self.ds[attribute].dims)
        # Flatten to a table ordered by the variable's dims so it can be
        # reshaped back to the variable's shape afterwards.
        hvds = hv.Dataset(self.ds.to_dataframe(dim_order=dims).reset_index())
        frame = hvds.data
        selected = hvds.select(self.selection.selection_expr).data.index
        # Use frame.loc[rows, col] rather than frame[col].loc[rows]: the chained
        # form silently writes to a copy under pandas Copy-on-Write.
        if self.calculation_type.value == 'Absolute':
            frame.loc[selected, attribute] = self.spinner.value
        elif self.calculation_type.value == 'Relatif':
            frame.loc[selected, attribute] += self.spinner.value
        elif self.calculation_type.value == 'Percentage':
            frame.loc[selected, attribute] *= (100 + self.spinner.value) / 100.
        self.ds[attribute] = dims, frame[attribute].values.reshape(*self.ds[attribute].shape)
        # ds is modified inplace and not by reassignment this means that get plots isn't retriggered
        # To manually force get_plots to be triggered we change the value of loaded
        # to false -> nothing happens
        # then to true saying the plot is ready
        self.loaded = False
        self.loaded = True

    def _save(self):
        """Return the edited dataset as netCDF bytes for the FileDownload widget.

        Also sets the download name to ``<original>_netcdf-editor<ext>``. For
        curvilinear grids, the ij index coordinates added on load are dropped
        and the original curvilinear coordinates are restored as coordinates.
        """
        filename, extension = os.path.splitext(self.file.filename)
        self.save.filename = filename + "_netcdf-editor" + extension
        ds = self.ds
        # We need to remove the dimension coordinates and reset the curvilinear coordinates
        if self.curvilinear_coordinates is not None:
            ds = self.ds.drop_vars([*self.ds.dims]).set_coords([*self.curvilinear_coordinates])
        return io.BytesIO(ds.to_netcdf())

    def _apply_values(self, event):
        """Apply button callback: edit the selected cells, then clear the selection."""
        self._set_values()
        self.selection.selection_expr = None

    def _get_ordered_coordinate_dimension_names(self):
        """Return ``list(self.ds.coords)``, reversed when it looks like (lat, lon) or (y, x).

        The horizontal coordinate then comes first, which is the (x, y) kdims
        order hv.Image expects.
        """
        dimension_names = list(self.ds.coords)
        if 'lat' in dimension_names[0].lower() and 'lon' in dimension_names[1].lower():
            dimension_names = dimension_names[::-1]
        elif 'x' == dimension_names[1].lower() or 'y' == dimension_names[0].lower():
            dimension_names = dimension_names[::-1]
        return dimension_names

    @pn.depends("file.filename", watch=True)
    def _toggle_save(self):
        """Show the Save button only while a file is loaded."""
        if self.file.filename and len(self.file_pane) == 1:
            self.file_pane.append(self.save)
        elif not self.file.filename and len(self.file_pane) == 2:
            self.file_pane.pop(1)

    def _calculate_internal_oceans(self):
        """Label connected ocean regions (value <= 0) of the current variable.

        Returns:
            A float array with the shape of the variable. Each ocean cell holds
            its region label (1..n) and land cells are NaN.
        """
        from scipy import ndimage

        # Calculate a binary array of above and below sea level.
        # from scipy doc:  Any non-zero values in `input` are
        # counted as features and zero values are considered the background.
        # This is why we choose ocean = True
        ocean = self.ds[self.attribute.value] <= 0

        # Use scipy to calculate internal oceans
        labeled_array, _ = ndimage.label(ocean)

        # Integer arrays cannot hold NaN, so switch to float before marking
        # continents (label 0) as NaN. numpy.nan is used because the numpy.NaN
        # alias was removed in NumPy 2.0.
        labeled_array = labeled_array.astype(float)
        labeled_array[labeled_array == 0] = numpy.nan

        return labeled_array

    def _calculate_passage_problems(self):
        """Mark ocean cells that form one-cell-wide passages between land.

        Returns:
            A float array with the shape of the variable: 1 where the template
            (or its transpose) matches exactly, NaN elsewhere.
        """
        # Define template we are looking for passages
        # Where only diffusion occurs this means we are looking
        # for ocean passages one in width/height
        # 1 => Ocean
        # -1 => Land
        # 0 = Indifferent
        template = numpy.array([[0, 1, 0],
                                [-1, 1, -1],
                                [0, 1, 0]])

        # Theoretical max value when the template is found
        # Note that 0s are considered wildcards so they are not taken into
        # Account
        #TODO this only works on data arrays where the absolute values are 1
        perfect_match = numpy.sum(numpy.abs(template))

        # we recode the values of land to -1 as
        # we did in the template
        values = (self.ds[self.attribute.value].values <= 0).astype(int)
        values[values == 0] = -1

        # NaN everywhere except the matches found below
        #TODO This could potentially by a binary array??
        potential_points = numpy.full(values.shape, numpy.nan)

        # Mark points where there is only diffusion in longitude direction
        convolvedh = convolve2d(values, template, 'same')
        potential_points[convolvedh == perfect_match] = 1

        # Mark points where there is only diffusion in latitude direction
        convolvedv = convolve2d(values, template.T, 'same')
        potential_points[convolvedv == perfect_match] = 1

        return potential_points

    @pn.depends("file.filename", watch=True)
    def _toggle_options_pane(self):
        """Rebuild the options pane: populated while a file is loaded, empty otherwise."""
        self.options_pane.clear()
        if self.file.filename is not None:
            self.options_pane.extend([
                pn.Row(self.attribute),
                pn.Row(self.colormap, self.colormap_min, self.colormap_max, self.nb_discrete_colormap),
                pn.Row(self.mask, self.mask_value),
                pn.Row(self.show_internal_oceans, self.show_passage_problems),
                pn.Row(self.calculation_type, self.spinner, self.apply),
            ])

    @pn.depends('ds',
                'colormap.value',
                'mask.value',
                'mask_value.value',
                'show_internal_oceans.value',
                'show_passage_problems.value',
                'colormap_min.value',
                'colormap_max.value',
                'nb_discrete_colormap.value',
                'loaded',
                watch=True)
    def get_plots(self):
        """(Re)build the linked plot layout into ``self.graph_pane``.

        Does nothing until ``loaded`` is True. The layout always holds the
        selected variable's map and its histogram, linked through
        ``self.selection``. The internal-oceans map and the diffusive-passages
        overlay are added when their checkboxes are ticked.
        """
        if not self.loaded:
            return
        attribute = self.attribute.value
        data = self.ds[attribute]
        kdims = self._get_ordered_coordinate_dimension_names()

        attribute_image = hv.Image(data, list(kdims))
        # Grid lines are drawn midway between consecutive coordinate values
        ydim, xdim = data.dims
        xvals = self.ds[xdim].values
        yvals = self.ds[ydim].values
        x_ticks = (xvals[1:] + xvals[:-1]) / 2
        y_ticks = (yvals[1:] + yvals[:-1]) / 2
        # Setup a grid style
        grid_style = {'grid_line_color': 'black', 'grid_line_width': 1,
                      'xgrid_ticker': x_ticks, 'ygrid_ticker': y_ticks}

        # Color range: the Min/Max spinners, or a single value when masking
        if self.mask.value:
            color_range = (self.mask_value.value, self.mask_value.value)
        else:
            color_range = (self.colormap_min.value, self.colormap_max.value)
        attribute_image = attribute_image.opts(
            clipping_colors={'min': 'lightgray', 'max': 'black'}
        ).redim.range(**{attribute: color_range})

        graphs = attribute_image

        if self.show_internal_oceans.value:
            internal_oceans = self._calculate_internal_oceans()
            # Labels run 1..n, so the largest label is the number of oceans.
            # When every cell is land (all NaN), the count is 0.
            if numpy.isnan(internal_oceans).all():
                number_internal_oceans = 0
            else:
                number_internal_oceans = int(numpy.nanmax(internal_oceans))
            # Reuse the variable's own dims/coords so the axes line up with the
            # data regardless of the order of ds.coords.
            internal_oceans = xr.DataArray(internal_oceans, coords=data.coords, dims=data.dims)
            internal_oceans_image = hv.Image(
                internal_oceans,
                list(kdims),
                group="Internal_Oceans",
                label=f'Number Oceans: {number_internal_oceans}'
            ).opts(
                clipping_colors={'NaN': (0, 0, 0, 0.5)})
            graphs += internal_oceans_image

        if self.show_passage_problems.value:
            passage_problems = self._calculate_passage_problems()
            # Marked cells are 1 and the rest NaN, so nansum counts them.
            number_passage_problems = int(numpy.nansum(passage_problems))
            passage_problems = xr.DataArray(passage_problems, coords=data.coords, dims=data.dims)
            passage_problems_image = hv.Image(
                passage_problems,
                list(kdims),
                group='Problems',
                label=f"Number Diffusive passages {number_passage_problems}"
            )
            # NOTE: despite the 'continents' name this is True where value <= 0
            # (ocean); it is only used as a faint binary background.
            continent_background_image = hv.Image(
                (data <= 0).rename('continents'),
                list(kdims),
                group='Map',
            )
            graphs += continent_background_image * passage_problems_image

        self.graph_pane.clear()

        layout = self.selection(graphs + data.hvplot.hist())

        layout.opts(
            hv.opts.Histogram(tools=['hover']),
            hv.opts.Image(tools=['hover'])
        ).cols(2)

        # layout[0] is the selected variable's map.
        # Add symmetric=True here to center the colormap at 0.
        layout[0].opts(
            hv.opts.Image(
                cmap=self.colormap.value,
                show_grid=True,
                gridstyle=grid_style,
                alpha=0.7
            ))

        if self.nb_discrete_colormap.value > 1:
            layout[0].opts(
                hv.opts.Image(
                    color_levels=int(self.nb_discrete_colormap.value)
                ))

        if self.show_passage_problems.value:
            layout[('Overlay', 'I')].opts(
                hv.opts.Image('Map', cmap='binary_r', color_levels=2, alpha=0.1, colorbar=False, tools=[]),
                hv.opts.Image('Problems', color_levels=3, colorbar=False, clipping_colors={'NaN': (0, 0, 0, 0)}, tools=[])
            )

        if self.show_internal_oceans.value:
            layout.opts(
                hv.opts.Image('Internal_Oceans', clipping_colors={'NaN': '#dedede'}, colorbar=False)
            )

        self.graph_pane.append(
            layout
        )

    def plot(self):
        """Return the app layout: file row, options pane, then the graphs."""
        return pn.Column(
            self.file_pane,
            self.options_pane,
            self.graph_pane
        )

In [None]:
# Build the editor and register its layout with Panel under the title
# 'NetCDF Editor', so `panel serve` on this notebook serves it. Being the
# cell's last expression, the returned object is also displayed inline
# when the cell is run in Jupyter.
vc = ValueChanger()
vc.plot().servable('NetCDF Editor')