In [None]:
import xarray as xr
import param
import numpy

from scipy.ndimage import measurements
from scipy.signal import convolve2d

import holoviews as hv
# Every colormap name HoloViews can list; used as the choices of the
# colormap selector in ValueChanger.
colormaps = hv.plotting.list_cmaps()

import hvplot.xarray
from holoviews.selection import link_selections
from holoviews import opts
# Default options applied to every hv.Image of the app.
# Values mirror holoviews.Store.custom_options for an xarray.Dataset.hvplot();
# `tools` adds lasso and box selection to the default hover tool.
opts.defaults(
    opts.Image(**{
        'aspect': 2,
        'colorbar': True,
        'height': 300,
        'logx': False,
        'logy': False,
        'responsive': True,
        'shared_axes': True,
        'show_grid': False,
        'show_legend': True,
        'tools': ['hover', 'lasso_select', 'box_select'],
    })
)

import panel as pn

from bokeh.models import FixedTicker

import io
import os

In [None]:
class ValueChanger(param.Parameterized):
    """Panel app to view and edit one variable of an uploaded NetCDF file.

    Workflow: upload a file with ``file``, pick a variable in ``attribute``,
    select cells on the linked plots, then press *Apply* to change the
    selected values. Every applied change is recorded in ``_undo_list`` so it
    can be undone, and *Save* downloads the edited dataset.

    NOTE: the widgets and panes below are class attributes, so they are
    shared by every instance of this class, and ``__init__`` changes the
    class-level defaults of the ``file``, ``ds``, ``loaded`` and
    ``attribute`` parameters. Use a single instance per session.
    """

    # How the replacement value is applied to the selected cells:
    # Absolute => set the cells to the value
    # Relatif => add the value to the cells
    # Percentage => scale the cells by (100 + value) / 100
    calculation_type = pn.widgets.RadioButtonGroup(options=['Absolute', 'Relatif', 'Percentage'], align='end')
    # Replacement value
    spinner = pn.widgets.IntInput(name='Replacement Value', value=0, align='start')

    # Buttons
    save = pn.widgets.FileDownload(label='Save', align='end', button_type='success')
    apply = pn.widgets.Button(name='Apply', align='end', button_type='primary')
    undo_button = pn.widgets.Button(name='↶ Undo', align='end', button_type='warning')
    # Mask: when enabled, the color limits collapse to mask_value (see _clims)
    mask = pn.widgets.Checkbox(name='Mask', max_width=100, align='end')
    mask_value = pn.widgets.IntInput(name='Mask Value', value=0)
    # Show extra graphs
    show_internal_oceans = pn.widgets.Checkbox(value=True, name='Show Internal Oceans', align='start')
    show_passage_problems = pn.widgets.Checkbox(value=True, name='Show Diffusion Passages ', align='start')
    # Variable we look at and modify. Declared as a String, but __init__
    # assigns a Select widget as its default; its `.value` is the name.
    attribute = param.String()
    # File upload widget (assigned as default in __init__)
    file = param.Parameter()
    # Colormap choice and value range
    colormap = pn.widgets.Select(name='Colormap', options=colormaps, value='terrain', max_width=200, align='end')
    colormap_min = pn.widgets.IntInput(name='Min Value', width=100)
    colormap_max = pn.widgets.IntInput(name='Max Value', width=100, align='end')
    colormap_range_slider = pn.widgets.RangeSlider(width=400, show_value=False)
    # Step between discrete color levels (<= 0 => continuous colormap)
    colormap_delta = pn.widgets.IntInput(name='Delta between values', value=0, align='end')
    # The xarray.Dataset being edited
    ds = param.Parameter()
    # Link the selections made on the different graphs together
    selection = link_selections.instance(unselected_alpha=0.4)

    # True once a file has been parsed successfully
    loaded = param.Parameter()

    # Parts of the display
    file_pane = pn.Row()
    graph_pane = pn.Column()
    options_pane = pn.Column()

    def __init__(self, **params):
        # These assignments change the class-level parameter defaults before
        # param initialises this instance, so the instance starts from them.
        self.param.file.default = pn.widgets.FileInput()
        self.param.ds.default = xr.Dataset()
        self.param.loaded.default = False
        self.param.attribute.default = pn.widgets.Select(name='Variable', max_width=200, align='end')
        super().__init__(**params)
        self.apply.on_click(self._apply_values)
        self.undo_button.on_click(self.undo)
        self.save.callback = self._save
        self.file_pane.append(self.file)

        # Names of the original 2-D coordinates of a curvilinear grid (None for a regular grid)
        self.curvilinear_coordinates = None
        # Copy of the dataset as loaded: the starting point when undoing
        self._original_ds = None
        # Actions applied since the file was loaded, oldest first
        self._undo_list = []

        # Keep the min/max inputs and the range slider in sync
        self.colormap_min.param.watch(self._colormap_callback, 'value')
        self.colormap_max.param.watch(self._colormap_callback, 'value')
        self.colormap_range_slider.param.watch(self._colormap_callback, 'value')

    def _colormap_callback(self, *events):
        """Propagate a change of the min input, max input or range slider to the others."""
        event = events[0]
        if event.obj == self.colormap_min:
            vals = list(self.colormap_range_slider.value)
            vals[0] = int(event.new)
            self.colormap_range_slider.value = tuple(vals)
        elif event.obj == self.colormap_max:
            vals = list(self.colormap_range_slider.value)
            vals[1] = int(event.new)
            self.colormap_range_slider.value = tuple(vals)
        elif event.obj == self.colormap_range_slider:
            vals = self.colormap_range_slider.value
            self.colormap_min.value = int(vals[0])
            self.colormap_max.value = int(vals[1])

    @pn.depends("file.value", watch=True)
    def _parse_file_input(self):
        """Load the uploaded file into ``self.ds`` and reset the editing state.

        Grids with 1-D coordinates are used as is. For curvilinear grids
        (2-D coordinates) the dimensions are added as index coordinates, and
        the original coordinates become data variables whose names are kept
        in ``self.curvilinear_coordinates`` so ``_save`` can restore them.

        Returns:
            True once the data is loaded; None when no file content is set.

        Raises:
            ValueError: if the coordinates have neither 1 nor 2 dimensions.
        """
        value = self.file.value
        if not value:
            # Nothing uploaded: keep whatever is currently loaded
            return
        self.loaded = False
        # We are dealing with a h5netcdf file ->
        # The reader can't read bytes so we need to write it to a file like object
        if value.startswith(b"\211HDF\r\n\032\n"):
            value = io.BytesIO(value)
        ds = xr.open_dataset(value)
        self.attribute.options = list(ds.keys())
        self.curvilinear_coordinates = None

        number_coordinates_in_system = len(list(ds.coords.variables.values())[0].dims)
        # Standard Grid
        if number_coordinates_in_system == 1:
            pass
        # Curvilinear coordinates
        elif number_coordinates_in_system == 2:
            dims = list(ds[list(ds.coords)[0]].dims)
            # Store the true coordinates for export
            self.curvilinear_coordinates = list(ds.coords)
            # Add the dimension into the coordinates this results in an ij indexing
            ds.coords[dims[0]] = ds[dims[0]]
            ds.coords[dims[1]] = ds[dims[1]]
            # Remove the curvilinear coordinates from the original coordinates
            ds = ds.reset_coords()
        else:
            raise ValueError("Unknown number of Coordinates")
        self.ds = ds
        self._original_ds = ds.copy(deep=True)
        # Actions recorded on a previously loaded file must never be replayed on this one
        self._undo_list = []
        min_value = int(ds[self.attribute.value].min())
        max_value = int(ds[self.attribute.value].max())
        self.colormap_min.value = min_value
        self.colormap_max.value = max_value
        self.colormap_range_slider.start = min_value
        self.colormap_range_slider.end = max_value
        self.loaded = True
        return True

    def _set_values(self, value, calculation_type, selection_expr, attribute=None):
        """Modify the cells of one variable that match ``selection_expr``.

        The dataset is flattened to a table (one row per cell, in the
        variable's dimension order), the matching rows are updated, and the
        column is reshaped back into the variable.

        Args:
            value: Number used by the calculation.
            calculation_type: 'Absolute' (set), 'Relatif' (add) or
                'Percentage' (scale by ``(100 + value) / 100``). Any other
                value leaves the data unchanged.
            selection_expr: HoloViews selection expression from link_selections.
            attribute: Variable to modify; defaults to the selected variable.
        """
        if attribute is None:
            attribute = self.attribute.value
        dims = list(self.ds[attribute].dims)
        hvds = hv.Dataset(self.ds.to_dataframe(dim_order=dims).reset_index())
        table = hvds.data
        selected = hvds.select(selection_expr).data.index
        # Use a single .loc[rows, column] call: chained `table[col].loc[rows] = ...`
        # may write into a temporary copy (and never reaches `table` under pandas Copy-on-Write)
        if calculation_type == 'Absolute':
            table.loc[selected, attribute] = value
        elif calculation_type == 'Relatif':
            table.loc[selected, attribute] += value
        elif calculation_type == 'Percentage':
            if numpy.issubdtype(table[attribute].dtype, numpy.integer):
                # Scaling yields non-integer values; avoid writing floats into an int column
                table[attribute] = table[attribute].astype(float)
            table.loc[selected, attribute] *= (100 + value) / 100.
        self.ds[attribute] = dims, table[attribute].values.reshape(*self.ds[attribute].shape)
        # Re-assign a fresh object so the @pn.depends('ds') methods are triggered
        self.ds = self.ds.copy(deep=True)

    def _save(self):
        """Return the edited dataset as an in-memory NetCDF file for download.

        Also sets the download name to ``<original name>_netcdf-editor<ext>``.
        For curvilinear grids the dimension coordinates added on load are
        dropped and the original coordinates are restored.
        """
        filename, extension = os.path.splitext(self.file.filename)
        self.save.filename = filename + "_netcdf-editor" + extension
        ds = self.ds
        # We need to remove the dimension coordinates and reset the curvilinear coordinates
        if self.curvilinear_coordinates is not None:
            ds = self.ds.drop_vars([*self.ds.dims]).set_coords([*self.curvilinear_coordinates])
        return io.BytesIO(ds.to_netcdf())

    def undo(self, event):
        """Revert the most recently applied action (Undo button callback).

        The dataset is reset to the copy taken on load and every remaining
        action is replayed in order. Replaying is used for every calculation
        type because no exact inverse exists in general:
        - 'Absolute' overwrites values that are not stored anywhere;
        - 'Percentage' is not undone by the opposite percentage
          (+10% then -10% gives 0.99x, not 1x);
        - a selection expressed on data values (e.g. made on the histogram)
          may match different cells once a 'Relatif' change has been applied.

        Raises:
            ValueError: if the popped action has an unknown calculation type.
        """
        # Nothing in the undo list
        if not self._undo_list:
            return

        undo_action = self._undo_list.pop()
        if undo_action['calculation_type'] not in ('Absolute', 'Percentage', 'Relatif'):
            raise ValueError("Can not undo action, unknown calculation type {}".format(undo_action['calculation_type']))

        # Start again from the data as loaded and re-apply the remaining steps
        self.ds = self._original_ds.copy(deep=True)
        for action in self._undo_list:
            self._apply_action(action)

    def _apply_action(self, action):
        """Apply one recorded action to ``self.ds``.

        Args:
            action: dict with 'selection_expr', 'calculation_type', 'value'
                and optionally 'attribute' (the variable it targets).

        Raises:
            ValueError: if the calculation type is unknown.
        """
        if action['calculation_type'] not in ('Absolute', 'Percentage', 'Relatif'):
            raise ValueError("Cannot apply step {}, unknown calculation_type {}".format(action, action['calculation_type']))
        self._set_values(
            value=action['value'],
            calculation_type=action['calculation_type'],
            selection_expr=action['selection_expr'],
            attribute=action.get('attribute'),
        )

    def _apply_values(self, event):
        """Apply button callback: record and apply the current change, then clear the selection."""
        if self.selection.selection_expr is None:
            return
        action = {
            'selection_expr': self.selection.selection_expr,
            'calculation_type': self.calculation_type.value,
            'value': self.spinner.value,
            # Remember the variable so undo replays hit the same one
            'attribute': self.attribute.value,
        }
        # Add the action to the list of undo actions
        self._undo_list.append(action)
        # Apply the action
        self._apply_action(action)
        self.selection.selection_expr = None

    def _get_ordered_coordinate_dimension_names(self):
        """Return the coordinate names of ``self.ds`` in the order used as image key dims.

        The first two names are swapped when the first contains 'lat' and the
        second 'lon', or when the second is 'x' or the first is 'y'
        (case-insensitive).
        """
        dimension_names = list(self.ds.coords)
        if 'lat' in dimension_names[0].lower() and 'lon' in dimension_names[1].lower():
            dimension_names = dimension_names[::-1]
        elif 'x' == dimension_names[1].lower() or 'y' == dimension_names[0].lower():
            dimension_names = dimension_names[::-1]
        return dimension_names

    @pn.depends("file.filename", watch=True)
    def _toggle_save(self):
        """Show the Save button next to the file input when a file name is set, hide it otherwise."""
        if self.file.filename and len(self.file_pane) == 1:
            self.file_pane.append(self.save)
        elif not self.file.filename and len(self.file_pane) == 2:
            self.file_pane.pop(1)

    def _calculate_internal_oceans(self):
        """Label the connected ocean regions (cells with value <= 0).

        Returns:
            Object array shaped like the variable: each ocean cell holds the
            label of its region (scipy default connectivity), land is NaN.
        """
        # scipy.ndimage.label supersedes the deprecated scipy.ndimage.measurements.label
        from scipy import ndimage

        # from scipy doc: Any non-zero values in `input` are counted as
        # features and zero values are considered the background.
        # This is why we choose ocean = True
        ocean = self.ds[self.attribute.value] <= 0

        labeled_array, num_features = ndimage.label(ocean)

        # Integer labels cannot hold NaN, so switch to an object array
        labeled_array = labeled_array.astype(object)
        # Land is the background and has label 0
        labeled_array[labeled_array == 0] = numpy.nan

        return labeled_array

    def _calculate_passage_problems(self):
        """Flag ocean cells in one-cell-wide channels, where only diffusion occurs.

        A cell is flagged when it and its two neighbours along one axis are
        ocean (value <= 0) while its two neighbours along the other axis are
        land. Cells on the array border are never flagged.

        Returns:
            Object array shaped like the variable: 2 for flagged cells,
            1 for other ocean cells, NaN for land.
        """
        # Template: 1 => Ocean, -1 => Land, 0 => Indifferent.
        # Ocean continues along axis 0, land on both sides along axis 1.
        template = numpy.array([[0, 1, 0],
                                [-1, 1, -1],
                                [0, 1, 0]])

        # Response of an exact match on +/-1 data: each non-zero template
        # entry contributes +1; the 0 wildcards contribute nothing
        perfect_match = numpy.sum(numpy.abs(template))

        # Recode like the template: ocean (<= 0) => 1, land => -1
        values = (self.ds[self.attribute.value].values <= 0).astype(int)
        values[values == 0] = -1

        # Mark on a copy: writing 2s into `values` itself would corrupt the
        # second convolution below
        potential_points = values.copy()

        # Channels running along axis 0
        convolved_along_0 = convolve2d(values, template, 'same')
        potential_points[convolved_along_0 == perfect_match] = 2

        # Channels running along axis 1
        convolved_along_1 = convolve2d(values, template.T, 'same')
        potential_points[convolved_along_1 == perfect_match] = 2

        potential_points = potential_points.astype(object)
        potential_points[potential_points == -1] = numpy.nan

        return potential_points

    @pn.depends("file.filename", watch=True)
    def _toggle_options_pane(self):
        """Rebuild the options pane: filled with the editing controls when a file name is set."""
        self.options_pane.clear()
        if self.file.filename is not None:
            self.options_pane.extend([
                pn.Row(self.attribute),
                pn.Row(self.colormap, pn.Column(pn.Row(self.colormap_min, pn.layout.HSpacer(), self.colormap_max), self.colormap_range_slider), self.colormap_delta),
                pn.Row(self.mask, self.mask_value),
                pn.Row(self.show_internal_oceans, self.show_passage_problems),
                pn.Row(self.calculation_type, self.spinner, self.apply, self.undo_button),
            ])

    def get_grid_style(self):
        """Bokeh grid style drawing black lines halfway between consecutive coordinate values.

        Assumes the selected variable is 2-D (it is unpacked into two dims).
        """
        ydim, xdim = self.ds[self.attribute.value].dims
        xvals = self.ds[xdim].values
        yvals = self.ds[ydim].values
        # Midpoints between cell centres outline the cells
        x_ticks = (xvals[1:] + xvals[:-1]) / 2
        y_ticks = (yvals[1:] + yvals[:-1]) / 2
        grid_style = {
            'grid_line_color': 'black', 'grid_line_width': 1,
            'xgrid_ticker': x_ticks, 'ygrid_ticker': y_ticks
        }
        return grid_style

    def _clims(self):
        """Color limits: (mask_value, mask_value) when masking, else the (min, max) inputs."""
        if self.mask.value:
            return self.mask_value.value, self.mask_value.value
        else:
            return self.colormap_min.value, self.colormap_max.value

    def _color_levels(self):
        """Discrete levels from min to max in steps of delta (max always last), or None if delta <= 0."""
        if self.colormap_delta.value <= 0:
            return None
        return list(range(self.colormap_min.value, self.colormap_max.value, self.colormap_delta.value)) + [self.colormap_max.value]

    def _colorbar_opts(self):
        """Colorbar options placing ticks on the discrete levels (at most ~9, plus 0 if in range)."""
        if self.colormap_delta.value <= 0:
            return {}
        ticks = self._color_levels()
        if len(ticks) > 8:
            # Thin out the ticks but always keep the last level exactly once
            thinned = ticks[::len(ticks) // 8]
            if thinned[-1] != ticks[-1]:
                thinned.append(ticks[-1])
            ticks = thinned
        # Add 0 to the ticks when the range straddles it and it is not already there
        if self.colormap_min.value * self.colormap_max.value < 0 and 0 not in ticks:
            ticks = numpy.insert(ticks, numpy.searchsorted(ticks, 0), 0).tolist()
        return {'ticker': FixedTicker(ticks=ticks)}

    @pn.depends('colormap.value', 'colormap_min.value', 'colormap_max.value', 'mask.value', 'mask_value.value', 'colormap_delta.value')
    def _opts(self, element):
        """Apply the current colormap, limits, levels and colorbar ticks to ``element``."""
        return element.opts(
            cmap=self.colormap.value,
            clim=self._clims(),
            color_levels=self._color_levels(),
            colorbar_opts=self._colorbar_opts(),
        )

    @pn.depends('ds')
    def load_attribute_map(self):
        """Image of the selected variable; re-evaluated whenever ``ds`` changes."""
        return hv.Image(
                    self.ds[self.attribute.value],
                    [*self._get_ordered_coordinate_dimension_names()])

    @pn.depends('ds')
    def load_passage_problems(self):
        """Image of the diffusion-only passages, labelled with the number of flagged cells."""
        passage_problems = self._calculate_passage_problems()
        # Flagged cells hold the value 2: count them (summing them would double the count)
        number_passage_problems = numpy.count_nonzero(passage_problems == 2)
        passage_problems = xr.DataArray(passage_problems, self.ds.coords)
        passage_problems_image = hv.Image(
            passage_problems,
            [*self._get_ordered_coordinate_dimension_names()],
            group='Passage_problems',
            label=f"Number Diffusive Passage cells: {number_passage_problems}"
        )
        return passage_problems_image

    @pn.depends('ds')
    def load_internal_oceans(self):
        """Image of the labelled ocean regions, labelled with the highest region label."""
        internal_oceans = self._calculate_internal_oceans()
        number_internal_oceans = numpy.nanmax(internal_oceans)
        internal_oceans = xr.DataArray(internal_oceans, self.ds.coords)
        internal_oceans_image = hv.Image(
            internal_oceans,
            [*self._get_ordered_coordinate_dimension_names()],
            group="Internal_Oceans",
            label=f'Number Oceans: {number_internal_oceans}'
        )
        return internal_oceans_image

    @pn.depends('show_internal_oceans.value',
                'show_passage_problems.value',
                'loaded',
                watch=True)
    def get_plots(self):
        """Build the linked layout (variable map, optional diagnostics, histogram) into ``graph_pane``.

        Does nothing until a file has been loaded.
        """
        if not self.loaded:
            return

        attribute_image = hv.DynamicMap(self.load_attribute_map).apply(self._opts).opts(
                    clipping_colors={'min': 'lightgray', 'max': 'black'},
                    tools = ['hover']
                )

        graphs = attribute_image

        if self.show_internal_oceans.value:
            internal_oceans = hv.DynamicMap(self.load_internal_oceans).opts(
                hv.opts.Image('Internal_Oceans', clipping_colors = {'NaN': '#dedede', 'max': 'red', 'min': '#ffffff'}, clim=(1.2, 1.5), colorbar=False, tools=[])
            )
            graphs += internal_oceans

        if self.show_passage_problems.value:
            passage_problems = hv.DynamicMap(self.load_passage_problems).opts(
                hv.opts.Image('Passage_problems', clipping_colors = {'NaN': '#dedede', 'max': 'red', 'min': '#ffffff'}, clim=(1.2, 1.5), colorbar=False, tools=[])
            )
            graphs += passage_problems

        layout = self.selection(graphs + self.ds[self.attribute.value].hvplot.hist())

        layout.opts(
            hv.opts.Histogram(tools=['hover']),
            hv.opts.Image(
                tools=['hover', 'box_select', 'lasso_select'],
                show_grid=True,
                gridstyle=self.get_grid_style(),
                alpha = 0.75
            )
        ).cols(2)

        self.graph_pane.clear()
        self.graph_pane.append(
            layout
        )

    def __repr__(self):
        return self.name

    def plot(self):
        """Top-level layout: file upload row, options pane and graph pane."""
        return pn.Column(
            self.file_pane,
            self.options_pane,
            self.graph_pane
        )

In [None]:
# Create the editor and mark its layout as servable, so `panel serve` on this
# notebook renders it with the title 'NetCDF Editor'.
vc = ValueChanger()
vc.plot().servable('NetCDF Editor')