In [1]:
%matplotlib notebook
import pandas as pd
import numpy as np
import scipy as sp
import plotly.plotly as py
import plotly.graph_objs as go
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import PIL.Image
import PIL
import ipywidgets
from ipywidgets import Image

import phate
import io
import sys
sys.path.append("..")
from blog_tools import data

In [2]:
# Load the digits dataset through the project-local blog_tools helper. Later
# cells use its .X (input to PHATE), .X_raw (per-sample images shown on hover)
# and .name (plot title).
dataset = data.digits()

In [3]:
def image_to_bytes(im, mode="L"):
    """Encode an image array as PNG bytes (e.g. for an ``ipywidgets.Image``).

    Parameters
    ----------
    im : array-like
        Pixel data handed to ``PIL.Image.fromarray``.
    mode : str, default "L"
        PIL image mode used to interpret ``im`` (default "L", 8-bit
        grayscale). This argument is now honoured instead of being ignored.

    Returns
    -------
    bytes
        The PNG-encoded image.
    """
    image = PIL.Image.fromarray(im, mode=mode)
    # Context manager releases the in-memory buffer once the bytes are taken.
    with io.BytesIO() as buffer:
        image.save(buffer, format='PNG')
        return buffer.getvalue()

In [4]:
# Hyperparameter grid: every (knn, decay, gamma) combination is embedded up
# front so the sliders further down can switch between them instantly.
knn = [2, 3, 5, 10, 20]
decay = [5, 10, 25, 40, 100, 'inf']  # 'inf' is passed to PHATE as decay=None
gamma = [-1, 0, 1]
# Fix PHATE's random_state so Restart & Run All reproduces the same embeddings.
RANDOM_STATE = 42
embeddings = {}
for k in knn:
    for d in decay:
        phate_op = phate.PHATE(knn=k, decay=None if d == 'inf' else d,
                               random_state=RANDOM_STATE, verbose=0)
        # One operator per (knn, decay); only gamma changes in the inner loop,
        # presumably so PHATE can reuse work already done for this graph —
        # TODO confirm against the PHATE docs.
        for g in gamma:
            phate_op.set_params(gamma=g)
            embeddings[(k, d, g)] = phate_op.fit_transform(dataset.X)

In [5]:
# Initial embedding for the figure: the same (knn, decay, gamma) combination
# the sliders default to, instead of whichever key happens to come first in
# the dict.
Y = embeddings[(5, 40, 1)]

In [6]:
# Shared styling for both axes: autoscaled, with only the axis line drawn
# (no grid, zero line, ticks or tick labels).
axis_layout = {
    'autorange': True,
    'showgrid': False,
    'zeroline': False,
    'showline': True,
    'ticks': '',
    'showticklabels': False,
}

# Widget-backed figure with one WebGL marker trace of the initial 2-D
# embedding; later cells restyle it and swap its data.
scatter_trace = {
    'type': 'scattergl',
    'x': Y[:, 0],
    'y': Y[:, 1],
    'mode': 'markers',
}
fig = go.FigureWidget(
    data=[scatter_trace],
    layout=go.Layout(xaxis=axis_layout, yaxis=axis_layout),
)

In [7]:
# Figure title, axis labels and hover behaviour.
fig.layout.title = dataset.name
fig.layout.titlefont.size = 22
fig.layout.titlefont.family = 'Rockwell'
fig.layout.xaxis.title = 'PHATE1'
fig.layout.yaxis.title = 'PHATE2'
fig.layout.hovermode = 'closest'

# Handle on the scatter trace: later cells update its data and attach the
# hover callback to it.
scatter = fig.data[0]
scatter.marker.opacity = 0.7
scatter.marker.size = 10

In [8]:
# Density contour lines drawn over the scatter points.
fig.add_histogram2dcontour(
    x=scatter.x, y=scatter.y, contours={'coloring': 'lines'}, showscale=False)
# plotly 3 returned the new trace from add_*, plotly >= 4 returns the figure;
# fetching the last trace works with both.
contour = fig.data[-1]

# Convert matplotlib's inferno colour list (RGB floats in [0, 1]) into a
# plotly colorscale of evenly spaced [position, "rgb(r, g, b)"] stops.
inferno_colors = plt.cm.inferno.colors
contour.colorscale = [
    [i, "rgb({}, {}, {})".format(round(r * 255), round(g * 255), round(b * 255))]
    for i, (r, g, b) in zip(np.linspace(0, 1, len(inferno_colors)), inferno_colors)
]
contour.reversescale = True
# Contours are decorative: do not let them produce hover labels.
contour.hoverinfo = 'skip'

In [9]:
def set_params(knn, decay, gamma):
    """Display the precomputed embedding for one hyperparameter combination.

    Used as the callback of the ``interactive`` sliders, which pass the
    selected values by keyword, so the parameter names must stay as they are.

    Parameters
    ----------
    knn, decay, gamma
        Together form the key into ``embeddings``. A combination that was
        not precomputed raises ``KeyError``.
    """
    Y = embeddings[(knn, decay, gamma)]
    # Send all four property changes to the front end as a single update
    # instead of one re-render per assignment.
    with fig.batch_update():
        scatter.x = contour.x = Y[:, 0]
        scatter.y = contour.y = Y[:, 1]

def slider(values, name, default=None):
    """Build a labelled, horizontal ``SelectionSlider`` over ``values``.

    The slider starts at ``default``, or at the first entry of ``values``
    when no default is given. ``continuous_update`` is disabled and the
    current value is shown as a readout.
    """
    initial = values[0] if default is None else default
    widget_kwargs = dict(
        options=values,
        value=initial,
        description=name,
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
    )
    return ipywidgets.SelectionSlider(**widget_kwargs)
    
from ipywidgets import interactive

# Wire one slider per hyperparameter to set_params; they start at
# knn=5, decay=40, gamma=1.
slider_widgets = {
    'knn': slider(knn, 'knn', default=5),
    'decay': slider(decay, 'decay', default=40),
    'gamma': slider(gamma, 'gamma', default=1),
}
params_slider = interactive(set_params, **slider_widgets)
params_slider

interactive(children=(SelectionSlider(continuous_update=False, description='knn', index=2, options=(2, 3, 5, 1…

In [10]:
# Give the three slider widgets (the first three children of the interactive
# container) a fixed width.
for slider_widget in params_slider.children[:3]:
    slider_widget.layout.width = '290px'

In [11]:
from ipywidgets import Image, Layout

# Image panel showing one raw sample; it starts on sample 0 and is replaced
# by the hover callback.
first_sample_png = image_to_bytes(dataset.X_raw[0])
image_widget = Image(
    value=first_sample_png,
    layout=Layout(width='250px', height='250px'),
)
display(image_widget)

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08\x08\x00\x00\x00\x00\xe1d\xe1…

In [12]:
def hover_fn(trace, points, state):
    """Show the raw image of the hovered sample in ``image_widget``.

    Registered as the scatter trace's ``on_hover`` callback. ``trace`` and
    ``state`` are supplied by plotly and unused here.
    """
    # Guard against an event that carries no points for this trace, which
    # would otherwise raise IndexError.
    if not points.point_inds:
        return
    ind = points.point_inds[0]
    image_widget.value = image_to_bytes(dataset.X_raw[ind])

scatter.on_hover(hover_fn)

In [13]:
from ipywidgets import HBox, VBox

# Dashboard layout: the figure first, then a column with the hover image
# stacked above the parameter sliders.
side_panel = VBox([image_widget, params_slider])
dash = HBox([fig, side_panel])
display(dash)

HBox(children=(FigureWidget({
    'data': [{'marker': {'opacity': 0.7, 'size': 10},
              'mode': 'mar…