In [1]:
# Use the interactive ipympl ("widget") Matplotlib backend so the figure can receive pick events.
%matplotlib widget

In [2]:
import logging
from pathlib import Path
import warnings

import ipywidgets as widgets
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from scipy.stats import zscore

# Third-party ipympl triggers this deprecation warning in its backend_agg startup code
warnings.filterwarnings("ignore", message="nbagg.transparent is deprecated")

# Turn on Matplotlib interactive mode
plt.ion()

In [12]:
import colorsys
import itertools

def distinct_cmap(n=33):
    """Build a ListedColormap of ``n`` visually distinct colors.

    Hues are generated by repeatedly halving the unit interval
    (0, 1/2, 1/4, 3/4, 1/8, 3/8, ...), and every hue is emitted twice:
    once with a darker and once with a brighter HSV value.

    Parameters
    ----------
    n : int
        Number of colors in the returned colormap.

    Returns
    -------
    matplotlib.colors.ListedColormap
    """
    # tweak these ratios to adjust the scheme
    saturation = 6/10
    brightness_levels = (6/10, 9/10)

    def hue_sequence():
        # 0 first, then the odd numerators over successive powers of two
        # (Zeno's dichotomy): 1/2, 1/4, 3/4, 1/8, 3/8, ...
        yield 0
        denominator = 1
        while True:
            for numerator in range(1, denominator, 2):
                yield numerator / denominator
            denominator *= 2

    def rgb_sequence():
        # two RGB colors per hue, darker one first
        for hue in hue_sequence():
            for brightness in brightness_levels:
                yield colorsys.hsv_to_rgb(hue, saturation, brightness)

    first_n_colors = list(itertools.islice(rgb_sequence(), n))
    return matplotlib.colors.ListedColormap(first_n_colors)

# Shared Output widget: GUI callbacks decorated with debug_view.capture send their
# output here, and MetaClusterGui.enable_debug_mode() displays it.
debug_view = widgets.Output(layout={'border': '1px solid black'})

class MetaClusterData():
    """Cluster data, pixel counts and the editable cluster -> metacluster mapping.

    Parameters
    ----------
    path : str or Path
        Directory containing the input CSVs; the output mapping is written here too.
    cluster_filename : str
        CSV with a ``cluster`` column, a ``hCluster_cap`` (metacluster) column and
        one column per remaining feature.
    pixelcount_filename : str
        CSV with a ``cluster`` column and a ``count`` column.
    output_mapping_filename : str
        File name the mapping is saved under every time ``remap`` is called.
    """
    def __init__(self, path, cluster_filename, pixelcount_filename, output_mapping_filename):
        self.path = Path(path)
        self.output_mapping_filename = output_mapping_filename
        clusters_raw = pd.read_csv(self.path / cluster_filename).sort_values('cluster')
        self.cluster_counts = pd.read_csv(self.path / pixelcount_filename).sort_values('cluster').set_index('cluster')
        # feature values only, indexed by cluster id
        self._clusters  = clusters_raw.set_index('cluster').drop(columns='hCluster_cap')
        # single-column frame: cluster id -> metacluster id
        self.mapping = clusters_raw[['cluster', 'hCluster_cap']].set_index('cluster')

    @property
    def clusters_with_metaclusters(self):
        """Cluster feature values plus the current ``hCluster_cap`` column, sorted by metacluster."""
        return self._clusters.join(self.mapping).sort_values(by='hCluster_cap')

    @property
    def clusters(self):
        """Cluster feature values ordered by current metacluster, without the mapping column."""
        return self.clusters_with_metaclusters.drop(columns='hCluster_cap')

    @property
    def metaclusters(self):
        """Per-metacluster feature values: the pixel-count-weighted mean of the member clusters."""
        weighted_clusters = self.clusters.multiply(self.cluster_counts['count'], axis=0)
        metacluster_counts = self.cluster_counts.join(self.mapping).groupby('hCluster_cap').aggregate('sum')
        weighted_metaclusters = weighted_clusters.join(self.mapping).groupby('hCluster_cap').aggregate('sum').divide(metacluster_counts['count'], axis=0)
        return weighted_metaclusters

    def cluster_in_metacluster(self, metacluster):
        """Return the list of cluster ids currently mapped to ``metacluster``."""
        return list(self.mapping[self.mapping['hCluster_cap'] == metacluster].index.values)

    def which_metacluster(self, cluster):
        """Return the metacluster id that ``cluster`` is currently mapped to."""
        # Look up on this instance; the previous code read a notebook-global `mcd`,
        # which broke (or silently used other data) for any differently named instance.
        return self.mapping.loc[cluster, 'hCluster_cap']

    def remap(self, cluster, metacluster):
        """Assign ``cluster`` to ``metacluster`` and immediately save the mapping CSV."""
        self.mapping.loc[cluster, 'hCluster_cap'] = metacluster
        self.mapping.to_csv(self.path / self.output_mapping_filename)

class MetaClusterGui():
    """Interactive Matplotlib/ipywidgets GUI for reassigning clusters to metaclusters.

    Left click toggles the selection of a cluster (or of a whole metacluster when a
    metacluster column or a color label is clicked). Right click remaps the current
    selection to the metacluster under the cursor. The "New metacluster" button moves
    the current selection into a freshly numbered metacluster.

    Parameters
    ----------
    metaclusterdata : MetaClusterData
        Data model that is displayed and modified.
    heatmapcolors : str
        Matplotlib colormap name used for both heatmaps.
    width : float
        Figure width passed to ``figsize``.
    """
    def __init__(self, metaclusterdata, heatmapcolors='seismic', width=17):
        self.width = width
        self.heatmapcolors = heatmapcolors
        self.mcd = metaclusterdata
        self.make_gui()

    @property
    def max_zscore(self):
        """Current value of the "Max Zscore" slider."""
        return self.zscore_clamp_slider.value

    @property
    def cmap(self):
        """Colormap with one distinct color per current metacluster."""
        return distinct_cmap(len(self.mcd.metaclusters))

    def preplot(self, df):
        """Z-score each column of ``df``, clip values above ``max_zscore`` and transpose."""
        return df.apply(zscore).clip(upper=self.max_zscore).T

    def make_gui(self):
        """Create the figure, its axes/images, and the control widgets, then draw once."""
        # Overall layout
        #  |    Cluster     | Meta |
        #  |    cc          |  mc  | counts of pixels
        #  |    c           |  m   | heatmap itself
        #  |    cs          |  ms  | selection markers
        #  |    cl          |  ml  | metacluster color labels
        self.fig, ((self.ax_cc, self.ax_mc), (self.ax_c, self.ax_m), (self.ax_cs, self.ax_ms), (self.ax_cl, self.ax_ml)) = plt.subplots(4,2,
            figsize=(self.width, 4.2),
            sharey=False,
            gridspec_kw={
                'width_ratios': [len(self.mcd.clusters), len(self.mcd.metaclusters)], # cluster plot bigger than metacluster plot
                'height_ratios': [5, len(self.mcd.clusters.columns), 1, 1]},
            )

        self.fig.canvas.toolbar_visible = False
        self.fig.canvas.header_visible = False
        self.fig.canvas.footer_visible = False

        self.fig.canvas.mpl_connect('pick_event', self.onpick)

        # heatmap axis
        self.ax_c.yaxis.set_tick_params(which='major', labelsize=8)
        self.ax_c.set_yticks(np.arange(len(self.mcd.clusters.columns))+0.5)
        self.ax_c.set_yticklabels(self.mcd.clusters.columns)
        self.ax_c.set_xticks(np.arange(len(self.mcd.clusters.index))+0.5)
        self.ax_m.set_xticks(np.arange(len(self.mcd.metaclusters.index))+0.5)
        self.ax_c.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_m.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)

        # heatmaps
        self.im_c = self.ax_c.imshow(self.mcd.clusters.T, cmap=self.heatmapcolors, aspect='auto', picker=True)
        self.im_m = self.ax_m.imshow(self.mcd.metaclusters.T, cmap=self.heatmapcolors, aspect='auto', picker=True)
        self.ax_m.yaxis.set_tick_params(which='both', left=True, labelleft=False)

        # xaxis metacluster color labels
        self.ax_cl.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_ml.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_cl.yaxis.set_tick_params(which='both', left=False, labelleft=False)
        self.ax_ml.yaxis.set_tick_params(which='both', left=False, labelleft=False)

        self.im_cl = self.ax_cl.imshow([self.mcd.clusters_with_metaclusters['hCluster_cap']], aspect='auto', picker=True)
        self.im_ml = self.ax_ml.imshow([self.mcd.metaclusters.index], aspect='auto', picker=True)

        # xaxis cluster selection labels
        self.ax_cs.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_ms.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_cs.yaxis.set_tick_params(which='both', left=False, labelleft=False)
        self.ax_ms.yaxis.set_tick_params(which='both', left=False, labelleft=False)
        # placeholder rows of zeros; im_cs gets its real data in update_gui
        self.im_cs = self.ax_cs.imshow([[0.0 for _ in self.mcd.clusters.columns]], cmap='Blues', aspect='auto', picker=True, vmin=-0.3, vmax=1)
        self.im_ms = self.ax_ms.imshow([[0.0 for _ in self.mcd.clusters.columns]], cmap='Blues', aspect='auto', picker=True, vmin=-0.3, vmax=1)

        # xaxis pixelcount graphs
        self.ax_cc.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_mc.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_cc.yaxis.set_tick_params(which='both', left=False, labelleft=False)
        self.ax_mc.yaxis.set_tick_params(which='both', left=False, labelleft=False)
        # Use the instance's data model (not a notebook-global `mcd`) so the GUI
        # works regardless of the variable name the data object is bound to.
        self.ax_cc.set_xlim(0, len(self.mcd.clusters))
        # NOTE(review): bar x-positions are derived from the cluster ids themselves,
        # which presumably assumes ids run 1..N -- confirm against the input data.
        self.rects_cc = self.ax_cc.bar([c - .5 for c in self.mcd.cluster_counts.index], self.mcd.cluster_counts['count'])


        # zscore adjuster
        self.zscore_clamp_slider = widgets.FloatSlider(
            value=3,
            min=1,
            max=10.0,
            step=0.5,
            description='Max Zscore:',
            disabled=False,
            continuous_update=True,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
            tooltip='Clamp/Clip zscore to a certain max value.',
        )
        # Only react to changes of the slider value, not to every trait change
        self.zscore_clamp_slider.observe(self.update_zscore, names='value')
        display(self.zscore_clamp_slider)

        # selection widget
        self.selected_clusters = set()
        self.selected_clusters_widget = widgets.Text(disabled=True)
        #display(self.selected_clusters_widget)

        # clear_selection button
        self.clear_selection_button = widgets.Button(
            description='Clear Selection',
            disabled=False,
            button_style='warning',
            tooltip='Clear currently selected clusters',
            icon='ban',
            )
        self.clear_selection_button.on_click(self.clear_selection)
        display(self.clear_selection_button)

        # new metacluster button
        self.new_metacluster_button = widgets.Button(
            description='New metacluster',
            disabled=False,
            button_style='success',
            tooltip='Create new metacluster from current selection',
            icon='plus',
            )
        self.new_metacluster_button.on_click(self.new_metacluster)
        display(self.new_metacluster_button)

        # initialize data, etc
        self.update_gui()
        # Tighten layout based on display
        self.fig.tight_layout()
        # make color labels touch heatmap; adjust this figure explicitly rather
        # than whichever figure pyplot currently considers active
        self.fig.subplots_adjust(hspace = .0)

    def enable_debug_mode(self):
        """Show the figure footer and display the callback debug output widget."""
        self.fig.canvas.footer_visible = True
        display(debug_view)

    def remap_current_selection(self, metacluster):
        """Remap every currently selected cluster to ``metacluster``."""
        for cluster in self.selected_clusters:
            print('remapping', cluster, metacluster)
            self.mcd.remap(cluster, metacluster)

    def update_gui(self):
        """Refresh every image, bar and widget from the current data and selection."""
        # clusters heatmap
        self.im_c.set_data(self.preplot(self.mcd.clusters))
        self.im_c.set_extent((0, len(self.mcd.clusters), 0, len(self.mcd.clusters.columns)))
        self.im_c.set_clim(0, self.max_zscore)

        # metaclusters heatmap
        # NOTE(review): this first set_extent is overwritten two lines below
        _, _, vstart, vend = self.im_m.get_extent()
        self.im_m.set_extent((0, len(self.mcd.metaclusters), vstart, vend))
        self.im_m.set_data(self.preplot(self.mcd.metaclusters))
        self.im_m.set_extent((0, len(self.mcd.metaclusters), 0, len(self.mcd.metaclusters.columns)))
        self.im_m.set_clim(0, self.max_zscore)

        # xaxis metacluster color labels
        # map each metacluster id to its 1-based position so colors stay contiguous
        metacluster_iloc = {mc:i+1 for (mc,i) in zip(self.mcd.metaclusters.index, range(len(self.mcd.metaclusters.index)))}
        self.im_cl.set_data([[metacluster_iloc[mc] for mc in self.mcd.clusters_with_metaclusters['hCluster_cap']]])
        self.im_cl.set_extent((0, len(self.mcd.clusters), 0, len(self.mcd.metaclusters)))
        self.im_cl.set_cmap(self.cmap)
        self.im_cl.set_clim(1, len(self.mcd.metaclusters))

        self.im_ml.set_data([[metacluster_iloc[mc] for mc in self.mcd.metaclusters.index]])
        self.im_ml.set_extent((0, len(self.mcd.metaclusters), 0, len(self.mcd.metaclusters)))
        self.im_ml.set_cmap(self.cmap)
        self.im_ml.set_clim(1,len(self.mcd.metaclusters))

        # xaxis cluster selection labels
        selection_mask = [[1 if c in self.selected_clusters else 0 for c in self.mcd.clusters.index]]
        self.im_cs.set_data(selection_mask)
        self.im_cs.set_extent((0, len(self.mcd.clusters), 0, len(self.mcd.metaclusters.columns)))
        #self.im_ms.set_extent((0, len(self.mcd.metaclusters), vstart, vend))
        #self.im_ms.set_data([self.mcd.metaclusters.index])
        #self.im_ms.set_extent((0, len(self.mcd.metaclusters), 0, len(self.mcd.metaclusters.columns)))

        # xaxis pixelcount graphs
        # pixel counts in the same (metacluster-sorted) order as the heatmap columns
        sorted_pixel_counts = self.mcd.clusters.join(self.mcd.cluster_counts)['count']
        for rect, h in zip(self.rects_cc, sorted_pixel_counts):
            rect.set_height(h)

        # selection textbox
        self.selected_clusters_widget.value = ', '.join(str(c) for c in self.selected_clusters)

    @debug_view.capture(clear_output=False)
    def update_zscore(self, e):
        """Slider callback: redraw with the new z-score clamp."""
        self.update_gui()

    @debug_view.capture(clear_output=False)
    def clear_selection(self, e):
        """Button callback: empty the selection and redraw."""
        self.selected_clusters.clear()
        self.update_gui()

    @debug_view.capture(clear_output=False)
    def new_metacluster(self, e):
        """Button callback: move the selection into a new metacluster (max existing id + 1)."""
        metacluster = max(self.mcd.mapping['hCluster_cap']) + 1
        print(metacluster)
        self.remap_current_selection(metacluster)
        self.update_gui()

    @debug_view.capture(clear_output=False)
    def onpick(self, e):
        """Pick-event callback: left button selects, right button remaps, then redraw."""
        # keep the last pick event on the instance for inspection while debugging
        self.e = e
        if e.mouseevent.name != 'button_press_event':
            return
        if e.mouseevent.xdata is None:
            # no data coordinate available, so there is no column to act on
            return
        if e.mouseevent.button == 1:
            self.onpick_select(e)
        elif e.mouseevent.button == 3:
            self.onpick_remap(e)
        self.update_gui()

    def onpick_select(self, e):
        """Toggle selection of the picked cluster, or of the whole picked metacluster."""
        selected_ix = int(e.mouseevent.xdata)
        if e.artist in [self.im_c, self.im_cs]:
            selected_cluster = self.mcd.clusters.index[selected_ix]
            # Toggle selection
            if selected_cluster in self.selected_clusters:
                self.selected_clusters.remove(selected_cluster)
            else:
                self.selected_clusters.add(selected_cluster)
        elif e.artist in [self.im_m, self.im_ml, self.im_ms]:
            self.select_metacluster(self.mcd.metaclusters.index[selected_ix])
        elif e.artist in [self.im_cl]:
            selected_cluster = self.mcd.clusters_with_metaclusters.index[selected_ix]
            self.select_metacluster(self.mcd.which_metacluster(cluster=selected_cluster))

    def select_metacluster(self, metacluster):
        """Toggle selection of all clusters in ``metacluster``."""
        clusters = self.mcd.cluster_in_metacluster(metacluster)
        # Toggle entire metacluster
        if all(c in self.selected_clusters for c in clusters):
            # remove whole metacluster
            self.selected_clusters.difference_update(clusters)
        else:
            # select whole metacluster
            self.selected_clusters.update(clusters)

    def onpick_remap(self, e):
        """Remap the current selection to the metacluster under the picked column."""
        selected_ix = int(e.mouseevent.xdata)
        metacluster = None
        if e.artist in [self.im_c, self.im_cs]:
            selected_cluster = self.mcd.clusters.index[selected_ix]
            metacluster = self.mcd.which_metacluster(cluster=selected_cluster)
        elif e.artist in [self.im_m, self.im_ml, self.im_ms]:
            metacluster = self.mcd.metaclusters.index[selected_ix]
        elif e.artist in [self.im_cl]:
            selected_cluster = self.mcd.clusters_with_metaclusters.index[selected_ix]
            metacluster = self.mcd.which_metacluster(cluster=selected_cluster)
        if metacluster is None:
            # unrecognised artist: do not overwrite the mapping with None
            return
        self.remap_current_selection(metacluster)


In [13]:
# Load the example cluster and pixel-count CSVs, then build and display the remapping GUI.
# Remappings made in the GUI are saved to ex2_output_mapping.csv in the same directory.
mcd = MetaClusterData("../data/example_dataset/metaclustering", "ex2_clusters_nozscore.csv", "ex2_clusters_pixelcount.csv", "ex2_output_mapping.csv")
mcg = MetaClusterGui(mcd, heatmapcolors='viridis', width=17)
# Uncomment to show the figure footer and the callback debug output widget.
#mcg.enable_debug_mode()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to  previous…

FloatSlider(value=3.0, description='Max Zscore:', max=10.0, min=1.0, readout_format='.1f', step=0.5)



Button(button_style='success', description='New metacluster', icon='plus', style=ButtonStyle(), tooltip='Creat…

# Developer's Maintenance Guide

This guide is an outline and a work in progress.

High level Architecture

- MetaClusterData - Data IO and state maintenance
- MetaClusterGui - Handles the Matplotlib figure, the ipywidgets controls, and the user-interaction callbacks

This metacluster remapper GUI uses the ipympl (ipywidgets-based) backend for Matplotlib. Additional widgets are used for extra input/output.

How to debug callbacks:
https://ipywidgets.readthedocs.io/en/latest/examples/Output%20Widget.html#Debugging-errors-in-callbacks-with-the-output-widget

In [7]:
# Save a copy of this notebook into the example dataset's metaclustering directory
!cp ./example_manually_adjust_metaclusters.ipynb ../data/example_dataset/metaclustering/example_manually_adjust_metaclusters.ipynb