In [1]:
%matplotlib widget

In [2]:
import ipywidgets as widgets
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from scipy.stats import zscore


In [3]:
# Switch matplotlib to interactive mode before the GUI figure is created.
plt.ion()

In [6]:
import numpy as np
class MetaClusterAdjust():
    """Interactive heatmap GUI for manually reassigning clusters to metaclusters.

    Two heatmaps are drawn side by side: one column per cluster (left) and one
    column per metacluster (right), with one row per marker column of the
    cluster table.  Clicking clusters toggles them in the current selection;
    clicking a metacluster either selects all of its clusters (when the
    selection is empty) or remaps the selected clusters to that metacluster.

    Attributes:
        clusters: per-cluster values, indexed by ``cluster``.
        cluster_counts: table indexed by ``cluster`` with a ``count`` column
            used as the averaging weight.
        mapping: table indexed by ``cluster`` with the ``hCluster_cap``
            metacluster assignment; modified in place by ``remap``.
        selected_clusters: list of currently selected cluster IDs.
    """

    def __init__(self,
                 clusters_path="../data/example_dataset/metaclustering/ex2_clusters_nozscore.csv",
                 counts_path="../data/example_dataset/metaclustering/ex2_clusters_pixelcount.csv"):
        """Load the cluster tables and build the GUI.

        Args:
            clusters_path: CSV with a ``cluster`` column, an ``hCluster_cap``
                column and the value columns to plot.
            counts_path: CSV with ``cluster`` and ``count`` columns.
        """
        clusters_raw = pd.read_csv(clusters_path).sort_values('cluster')
        self.cluster_counts = pd.read_csv(counts_path).sort_values('cluster').set_index('cluster')
        self.clusters = clusters_raw.set_index('cluster').drop(columns='hCluster_cap')
        self.mapping = clusters_raw[['cluster', 'hCluster_cap']].set_index('cluster')
        self.make_gui()
        self.update_gui()

    @property
    def metaclusters(self):
        """Count-weighted mean of the cluster values per metacluster.

        Recomputed from ``self.mapping`` on every access (could be cached and
        expired on mapping change).  Metaclusters with no clusters assigned do
        not appear in the result.
        """
        weighted_clusters = self.clusters.multiply(self.cluster_counts['count'], axis=0)
        metacluster_counts = self.cluster_counts.join(self.mapping).groupby('hCluster_cap').aggregate('sum')
        weighted_sums = weighted_clusters.join(self.mapping).groupby('hCluster_cap').aggregate('sum')
        return weighted_sums.divide(metacluster_counts['count'], axis=0)

    def remap(self, cluster, metacluster):
        """Assign ``cluster`` to ``metacluster`` in ``self.mapping``."""
        self.mapping.loc[cluster, 'hCluster_cap'] = metacluster

    @staticmethod
    def preplot(df):
        """Z-score each column of ``df``, cap values at 4, and transpose for plotting."""
        return df.apply(zscore).clip(upper=4).T

    @staticmethod
    def _id_at(index, xdata):
        """Return the label of ``index`` under horizontal data coordinate ``xdata``.

        Each heatmap column spans one unit starting at 0, so the integer part
        of ``xdata`` is the column position.  The position is clamped so a
        click exactly on the right edge still resolves to the last column.
        """
        position = min(max(int(xdata), 0), len(index) - 1)
        return index[position]

    def make_gui(self):
        """Create the figure, both heatmaps and the selection text widget."""
        metaclusters = self.metaclusters
        n_rows = len(self.clusters.columns)

        # graph
        self.fig, (self.ax_c, self.ax_m) = plt.subplots(
            1, 2,
            figsize=(16, 4),
            sharey=True,
            gridspec_kw={'width_ratios': [7, 2]},  # cluster plot bigger than metacluster plot
        )
        self.fig.canvas.toolbar_visible = False
        self.fig.canvas.header_visible = False
        self.ax_c.yaxis.set_tick_params(which='major', labelsize=8)

        self.fig.canvas.mpl_connect('pick_event', self.onpick)

        # imshow's default origin='upper' puts data row 0 at the TOP of the
        # extent while tick positions count up from the bottom, so the labels
        # are listed in reverse to line up with their heatmap rows.
        self.ax_c.set_yticks(0.5 + np.arange(n_rows))
        self.ax_c.set_yticklabels(self.clusters.columns[::-1])

        self.fig.tight_layout()

        self.im_c = self.ax_c.imshow(
            self.preplot(self.clusters), aspect='auto', picker=True,
            extent=(0, len(self.clusters), 0, n_rows))
        self.im_m = self.ax_m.imshow(
            self.preplot(metaclusters), aspect='auto', picker=True,
            extent=(0, len(metaclusters), 0, len(metaclusters.columns)))

        # selection widget
        self.selected_clusters = []
        self.selected_clusters_widget = widgets.Text()
        display(self.selected_clusters_widget)

    def update_gui(self):
        """Refresh the metacluster heatmap from the current mapping."""
        metaclusters = self.metaclusters
        _, _, vstart, vend = self.im_m.get_extent()
        # The number of non-empty metaclusters can change after a remap.
        self.im_m.set_extent((0, len(metaclusters), vstart, vend))
        self.im_m.set_data(self.preplot(metaclusters))
        self.fig.canvas.draw_idle()

    def onpick(self, e):
        """Handle a matplotlib pick event on either heatmap.

        A pick on the cluster heatmap toggles that cluster in the selection.
        A pick on the metacluster heatmap selects all clusters of that
        metacluster when nothing is selected; otherwise it remaps the selected
        clusters to it and clears the selection.
        """
        xdata = e.mouseevent.xdata
        if xdata is None:
            # No data coordinate available, so there is no column to resolve.
            return

        if e.artist is self.im_c:
            # Look the ID up by position rather than assuming IDs are 1..n.
            selected_cluster = self._id_at(self.clusters.index, xdata)
            if selected_cluster in self.selected_clusters:
                self.selected_clusters.remove(selected_cluster)
            else:
                self.selected_clusters.append(selected_cluster)
        elif e.artist is self.im_m:
            # Columns follow the index of the *current* metacluster table, so
            # IDs stay correct after a metacluster has been emptied.
            selected_metacluster = self._id_at(self.metaclusters.index, xdata)
            if len(self.selected_clusters) == 0:
                # grab selection based on metacluster
                members = self.mapping[self.mapping['hCluster_cap'] == selected_metacluster]
                self.selected_clusters = list(members.index.values)
            else:
                # remap to mc based on current selection
                for cluster in self.selected_clusters:
                    self.remap(cluster, selected_metacluster)
                self.update_gui()
                self.selected_clusters = []
        else:
            self.selected_clusters = ['failed assertion']
        self.selected_clusters_widget.value = ', '.join(str(c) for c in self.selected_clusters)

# Build the adjustment GUI; constructing the object reads the CSVs and shows the figure and widget.
mca = MetaClusterAdjust()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to  previous…

Text(value='')

In [7]:
# Inspect the metacluster heatmap's current extent.
mca.im_m.get_extent()

(0, 20, 0, 22)

In [8]:
!cp ./example_manually_adjust_metaclusters.ipynb ../data/example_dataset/metaclustering/example_manually_adjust_metaclusters.ipynb

cp: cannot stat './Manually adjust metaclusters.ipynb': No such file or directory
