# Plot Latent Space and Expand Data Point Details
This notebook plots the latent space of the VAE model interactively. When the mouse cursor hovers over a data point, a box pops up showing the spectrum image and its associated information.

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
sns.set(style='white', context='notebook', rc={'figure.figsize':(24,20)})

import torchvision as tv
from specvae.dataset import MoNA
import specvae.vae as vae, specvae.utils as utils
import specvae.dataset as dt

import plotly.express as px
import pandas as pd

In [None]:
# Toggle to True to run on the first CUDA device when one is available.
use_cuda = False
cpu_device = torch.device('cpu')
if not (torch.cuda.is_available() and use_cuda):
    # CUDA is missing or was not requested: stay on the CPU.
    device = torch.device('cpu')
else:
    device = torch.device('cuda:0')
    print('GPU device count:', torch.cuda.device_count())
print('Device in use: ', device)

In [None]:
# Processing parameters:
# Dataset name; it forms the first part of the latent-space file name ("<dataset>-<model_name>.npz").
dataset = 'MoNA' # either 'HMDB' or 'MoNA'
# Model identifier; it forms the second part of the latent-space file name.
model_name = 'alt_specvae_2000-1538-30-1538-2000 (28-06-2021_14-05-29)'
# Upper bound on the number of data points kept for plotting (the arrays are truncated to this length).
datapoints = 10000

In [None]:
print("Load data")
filename = "%s-%s.npz" % (dataset, model_name)
filepath = utils.get_project_path() / '.data' / 'latent' / filename
# NOTE(security): allow_pickle=True unpickles object arrays, which can execute
# arbitrary code. Only load archives produced by this project.
# The context manager closes the underlying .npz file handle once the arrays
# have been read into memory.
with np.load(filepath, allow_pickle=True) as b:
    X, mode, energy, tax, ids, classes, spectra, images = (
        b[key] for key in ('X', 'mode', 'energy', 'tax', 'ids', 'classes', 'spectra', 'imgs')
    )
# Clamp negative energy values to zero.
energy[energy < 0] = 0
# The log transform is currently disabled, so log_energy is the same array object as energy.
# To re-enable it: log_energy = np.log(np.nan_to_num(energy) + 1)
log_energy = energy

In [None]:
# Sanity check: dimensions of the loaded array X.
X.shape

In [None]:
# Inspect the energy array; negative entries were clamped to 0 when the data was loaded.
energy

In [None]:
# Keep only the first `datapoints` rows of every per-sample array so that all
# arrays stay aligned with each other.
if datapoints < X.shape[0]:
    head = slice(datapoints)
    X, mode, energy, tax, ids, classes, spectra, images, log_energy = (
        arr[head]
        for arr in (X, mode, energy, tax, ids, classes, spectra, images, log_energy)
    )

In [None]:
# One text label per data point: mode 0 -> 'negative', any other value -> 'positive'.
colors = np.array(['negative' if m == 0 else 'positive' for m in mode])

In [None]:
# Sanity check: colors is built from mode, so it should hold one label per data point.
colors.shape

In [None]:
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper, LinearColorMapper, ColorBar, TapTool
from bokeh.palettes import Spectral10
from bokeh.palettes import RdBu3
from bokeh.models.callbacks import CustomJS

# Configure Bokeh to render its output inside the notebook.
output_notebook()

In [None]:
def plot_hover(d, title=""):
    """Show an interactive Bokeh scatter plot of a 2-D embedding.

    Points are coloured by ionization mode (``mode == 0`` gets the red end of
    ``RdBu3``, every other value the blue end) and grouped in the legend by
    the ``colors`` labels. Hovering over a point shows its image, ID, mode
    and energy. Tapping a point selects every point that shares its ``ids``
    value and paints those points red; the highlight from the previous tap
    is undone first.

    Besides ``d`` the function reads the notebook-level arrays ``images``,
    ``energy``, ``log_energy``, ``mode``, ``ids``, ``colors`` and ``tax``.
    They must have one entry per row of ``d``, in the same order.

    Args:
        d: 2-D array-like indexable as ``d[:, 0]`` (x) and ``d[:, 1]`` (y).
        title: Title displayed above the figure.

    Returns:
        None. The figure is displayed with ``bokeh.plotting.show``.
    """
    negative_color = RdBu3[2]  # red, used where mode == 0
    positive_color = RdBu3[0]  # blue, used for every other mode value
    point_colors = [negative_color if m == 0 else positive_color for m in mode]
    datasource = ColumnDataSource({
        'x': d[:, 0],
        'y': d[:, 1],
        'image': images,  # presumably image URLs / data URIs usable as <img src> - TODO confirm
        'energy': energy,
        'log_energy': log_energy,
        'mode': mode,
        'ids': ids,
        'label': colors,
        'tax': [str(t) for t in tax],
        'color': point_colors,
        # Untouched copy of the colours so the tap callback can undo its highlighting.
        'base_color': list(point_colors),
    })

    # width/height instead of plot_width/plot_height: the latter were removed in Bokeh 3.
    plot_figure = figure(
        title=title,
        width=1000,
        height=1000)
    plot_figure.add_tools(HoverTool(tooltips="""
    <div>
        <div>
            <img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
        </div>
        <div>
            <span style='font-size: 11px; color: #224499'>ID:</span>
            <span style='font-size: 11px'>@ids</span>
        </div>
        <div>
            <span style='font-size: 11px; color: #224499'>Mode:</span>
            <span style='font-size: 11px'>@mode</span>
        </div>
        <div>
            <span style='font-size: 11px; color: #224499'>Energy:</span>
            <span style='font-size: 11px'>@energy</span>
        </div>
    </div>
    """))

    cb_click = CustomJS(args=dict(source=datasource), code="""
        const data = source.data
        const selected = source.selected.indices
        const prop_name = 'ids'
        const n = data['x'].length
        const select_inds = []
        // Undo the highlight left by a previous tap so red points do not accumulate.
        for (let i = 0; i < n; ++i) {
            data['color'][i] = data['base_color'][i]
        }
        if (selected.length == 1) {
            // only consider the case where one glyph is selected by the user
            const selected_prop = data[prop_name][selected[0]]
            for (let i = 0; i < n; ++i) {
                if (data[prop_name][i] == selected_prop) {
                    // select and highlight every point whose id matches the
                    // glyph that was clicked
                    data['color'][i] = '#ff0000'
                    select_inds.push(i)
                }
            }
        }
        source.selected.indices = select_inds
        source.change.emit()
    """)
    plot_figure.add_tools(TapTool(callback=cb_click))

    plot_figure.circle(
        'x',
        'y',
        source=datasource,
        color='color',
        legend_group='label',
        line_alpha=0.6,
        fill_alpha=0.6,
        size=6)
    show(plot_figure)
    

In [None]:
pca_comp = 2
print("Compute PCA for n_components=%d" % pca_comp)
# With the default svd_solver='auto', scikit-learn may choose a randomized SVD;
# a fixed seed makes the projection reproducible across kernel restarts.
r = PCA(n_components=pca_comp, random_state=42)
pdata = r.fit_transform(X)

print("PCA:")
print("\t      explained_variance:", r.explained_variance_)
print("\texplained_variance_ratio:", r.explained_variance_ratio_)

In [None]:
# Give the figure a title so it is identifiable when the notebook is skimmed.
plot_hover(pdata, title="PCA projection (%d components)" % pdata.shape[1])

In [None]:
pca_comp = 10
print("Compute PCA for n_components=%d" % pca_comp)
# With the default svd_solver='auto', scikit-learn may choose a randomized SVD;
# a fixed seed makes the reduction reproducible across kernel restarts.
red = PCA(n_components=pca_comp, random_state=42)
# `data` is the input of the t-SNE and UMAP cells below.
# To skip this PCA pre-reduction, use `data = X` instead.
data = red.fit_transform(X)

print("PCA:")
print("\t      explained_variance:", red.explained_variance_)
print("\texplained_variance_ratio:", red.explained_variance_ratio_)

In [None]:
n_components = 2
print("Compute tSNE for n_components=%d" % n_components)
# t-SNE is stochastic; a fixed seed makes the embedding reproducible across kernel restarts.
r = TSNE(n_components=n_components, random_state=42)
# Runs on `data` (the PCA-reduced features), not on the raw X.
tdata = r.fit_transform(data)

print("TSNE:")
print("\t      kl_divergence:", r.kl_divergence_)

In [None]:
# Give the figure a title so it is identifiable when the notebook is skimmed.
plot_hover(tdata, title="t-SNE embedding (%d components)" % tdata.shape[1])

In [None]:
def draw_umap(n_neighbors=15, min_dist=0.1, n_components=2, metric='euclidean', title='', random_state=None):
    """Fit UMAP on the notebook-level ``data`` array and plot the embedding.

    The plot type depends on ``n_components``:

    * 1 - Plotly scatter of the single UMAP coordinate against the sample
      index, coloured by ``colors``.
    * 2 - interactive Bokeh plot via ``plot_hover``. The title is generated
      from ``n_neighbors``, ``min_dist`` and ``metric``; ``title`` is ignored.
    * 3 - Plotly 3-D scatter coloured by ``log_energy``.

    Besides its arguments the function reads the notebook-level arrays
    ``data``, ``colors``, ``mode``, ``energy``, ``ids`` and ``log_energy``,
    which must be row-aligned with each other.

    Args:
        n_neighbors: Passed to ``UMAP``.
        min_dist: Passed to ``UMAP``.
        n_components: Dimensionality of the embedding; must be 1, 2 or 3.
        metric: Passed to ``UMAP``.
        title: Title of the 1-D and 3-D Plotly figures.
        random_state: Passed to ``UMAP``. ``None`` (default) keeps the
            unseeded behaviour; pass an int for a reproducible embedding.

    Raises:
        ValueError: If ``n_components`` is not 1, 2 or 3.
    """
    # Validate before fitting, so an unsupported value does not cost a full
    # UMAP run that ends without drawing anything.
    if n_components not in (1, 2, 3):
        raise ValueError("n_components must be 1, 2 or 3, got {!r}".format(n_components))

    fit = UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=n_components,
        metric=metric,
        random_state=random_state
    )
    u = fit.fit_transform(data)
    hover = {'ionization mode': mode, 'collision energy': energy, 'InChIKey': ids}

    if n_components == 1:
        # A 1-D embedding has only column 0, so plot it against the sample index.
        fig = px.scatter(u, x=0, y=np.arange(len(u)), color=colors, template='plotly_white',
                         hover_data=hover, title=title)
        fig.show()
    elif n_components == 2:
        plot_hover(u, title="UMAP n={} dist={} metric={}".format(n_neighbors, min_dist, metric))
    else:
        fig = px.scatter_3d(u, x=0, y=1, z=2, color=log_energy, template='plotly_white',
                            hover_data=hover, title=title,
                            width=1000, height=1000, color_continuous_scale=px.colors.sequential.Turbo)
        fig.show()

In [None]:
# Use one variable for both the parameter and the title so they cannot disagree
# (the title previously said 0.25 while min_dist=0.5 was used).
min_dist = 0.5
draw_umap(min_dist=min_dist, title='min_dist = {}'.format(min_dist))

In [None]:
# UMAP metric: either a metric name or a callable; callables are labelled by their __name__.
m = "correlation"
if type(m) is str:
    name = m
else:
    name = m.__name__
draw_umap(
    n_components=2,
    metric=m,
    title='metric = {}'.format(name),
)

In [None]:
# 3-D UMAP embedding with the Euclidean metric.
draw_umap(
    n_components=3,
    metric='euclidean',
    title='metric = euclidean',
)