load data
plot data

In [1]:
import numpy as np
import pickle
import pandas as pd
import time
from umap import UMAP

In [2]:
from tfumap.paths import ensure_dir, MODEL_DIR, DATA_DIR

In [3]:
from tfumap.paths import FIGURE_DIR, save_fig

In [4]:
save_loc = DATA_DIR / 'knn_classifier' 

In [5]:
datasets = [
    'cassins_dtw',
    'cifar10',
    'fmnist',
    'macosko2015',
    'mnist'
]

In [6]:
projection_speeds = pd.concat([pd.read_pickle(save_loc / (dataset + '.pickle')) for dataset in datasets])
projection_speeds[:3]

Unnamed: 0,method_,dimensions,dataset,1NN_acc,5NN_acc
0,network,2,cassins_dtw,0.991,0.995
1,network,64,cassins_dtw,0.988,0.991
2,autoencoder,2,cassins_dtw,0.989,0.993


In [7]:
# load parametric tsne, vae, ae
for dataset in datasets:
    for n_components in ['2', '64']:
        save_loc =  DATA_DIR / 'knn_classifier' / str(n_components) / (dataset + '.pickle')
        try:
            metric_df = pd.read_pickle(save_loc)
            projection_speeds = pd.concat([projection_speeds, metric_df])
        except FileNotFoundError:
            print(save_loc)
            
        vae_save_loc = (
            DATA_DIR
            / "knn_classifier"
            / 'vae'
            / "train"
            / str(n_components)
            / (dataset + ".pickle")
        )
        try:
            metric_df = pd.read_pickle(vae_save_loc)
            # display(metric_df)
            projection_speeds = pd.concat([projection_speeds, metric_df])
        except FileNotFoundError:
            print(vae_save_loc)
            
        ae_save_loc = (
            DATA_DIR
            / "knn_classifier"
            / 'ae_only'
            / "train"
            / str(n_components)
            / (dataset + ".pickle")
        )
        try:
            metric_df = pd.read_pickle(ae_save_loc)
            # display(metric_df)
            projection_speeds = pd.concat([projection_speeds, metric_df])
        except FileNotFoundError:
            print(ae_save_loc)

In [9]:
projection_speeds

Unnamed: 0,method_,dimensions,dataset,1NN_acc,5NN_acc
0,network,2,cassins_dtw,0.9910,0.9950
1,network,64,cassins_dtw,0.9880,0.9910
2,autoencoder,2,cassins_dtw,0.9890,0.9930
3,autoencoder,64,cassins_dtw,0.9940,0.9970
4,umap-learn,2,cassins_dtw,0.9860,0.9890
...,...,...,...,...,...
0,vae,2,mnist,0.7241,0.7649
0,ae_only,2,mnist,0.7647,0.7926
0,parametric-tsne,64,mnist,0.9697,0.9734
0,vae,64,mnist,0.9785,0.9791


In [12]:
metrics_df = projection_speeds[["method_","dimensions","dataset","1NN_acc"]].set_index(['dataset', 'dimensions'])
metrics_df = metrics_df.pivot_table(
    index=["dataset", "dimensions"],
    columns="method_",
    values="1NN_acc",
    aggfunc="first",
)
metrics_df 

Unnamed: 0_level_0,method_,PCA,TSNE,ae_only,autoencoder,network,parametric-tsne,umap-learn,vae
dataset,dimensions,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
cassins_dtw,2,0.626,0.988,0.874,0.989,0.991,0.986,0.986,0.73
cassins_dtw,64,0.995,,0.995,0.994,0.988,0.995,0.985,0.98
cifar10,2,0.1436,0.2457,0.1696,0.1592,0.1512,0.1675,0.1689,0.1665
cifar10,64,0.3829,,0.379,0.2223,0.2139,0.3426,0.2375,0.3949
fmnist,2,0.4467,0.7825,0.6816,0.7083,0.6941,0.6834,0.7144,0.6646
fmnist,64,0.8398,,0.8671,0.7772,0.7431,0.83,0.7682,0.8747
macosko2015,2,0.808525,0.971658,0.94287,0.858067,0.964294,0.966079,0.966525,0.854497
macosko2015,64,0.975898,,0.975006,0.957599,0.968311,0.977237,0.972104,0.966972
mnist,2,0.3765,0.9411,0.7647,0.9403,0.9402,0.9118,0.9317,0.7241
mnist,64,0.9707,,0.9748,0.9481,0.9518,0.9697,0.9449,0.9785


In [14]:
def can_float(x):
    try:
        float(x)
        if np.isnan(float(x)):
            return False
        return True
    except:
        return False

In [15]:
metric_string = (
    metrics_df[["TSNE", 'parametric-tsne', "umap-learn", "network", "autoencoder", "ae_only", "vae", "PCA"]]
    .round(4)
    .to_latex()
    .replace("cassins\_dtw", "Cassin's")
    .replace("cifar10", "CIFAR10")
    .replace("fmnist", "FMNIST")
    .replace("mnist", "MNIST")
    .replace("macosko2015", "Retina")
    .replace("autoencoder", "UMAP/AE")
    .replace("ae\_only", "AE")
    .replace("network", "P. UMAP")
    .replace("umap-learn", "UMAP")
    .replace("vae", "VAE")
    .replace("pca", "PCA")
    .replace("parametric-tsne", "P. t-SNE")
    .replace("TSNE", "t-SNE")
    .replace("NaN", "-")
)

In [16]:
lines = metric_string.split('\n')
skip = 1
for line in lines:
    line_elements = line.split(' ')
    floatables = [can_float(le) for (le) in line_elements]
    floats = [float(j) for i, j in zip(floatables, line_elements) if i]
    if len(floats)> 1:
        best = np.argmax(floats[skip:])
        replace_element = np.where(floatables)[0][skip + best]
        line_elements[replace_element] = '\\textbf{' +line_elements[replace_element]+ '}'
    print(' '.join(line_elements))

\begin{tabular}{llrrrrrrrr}
\toprule
      & method\_ &    t-SNE &  P. t-SNE &  UMAP &  P. UMAP &  UMAP/AE &  AE &     VAE &     PCA \\
dataset & dimensions &         &                  &             &          &              &          &         &         \\
\midrule
Cassin's & 2  &  0.9880 &           0.9860 &      0.9860 &   \textbf{0.9910} &       0.9890 &   0.8740 &  0.7300 &  0.6260 \\
      & 64 &     - &           \textbf{0.9950} &      0.9850 &   0.9880 &       0.9940 &   0.9950 &  0.9800 &  0.9950 \\
CIFAR10 & 2  &  \textbf{0.2457} &           0.1675 &      0.1689 &   0.1512 &       0.1592 &   0.1696 &  0.1665 &  0.1436 \\
      & 64 &     - &           0.3426 &      0.2375 &   0.2139 &       0.2223 &   0.3790 &  \textbf{0.3949} &  0.3829 \\
FMNIST & 2  &  \textbf{0.7825} &           0.6834 &      0.7144 &   0.6941 &       0.7083 &   0.6816 &  0.6646 &  0.4467 \\
      & 64 &     - &           0.8300 &      0.7682 &   0.7431 &       0.7772 &   0.8671 &  \textbf{0.8747} &  0.8

In [18]:
metrics_df = projection_speeds[["method_","dimensions","dataset","5NN_acc"]].set_index(['dataset', 'dimensions'])
metrics_df = metrics_df.pivot_table(
    index=["dataset", "dimensions"],
    columns="method_",
    values="5NN_acc",
    aggfunc="first",
)


metric_string = (
    metrics_df[["TSNE", 'parametric-tsne', "umap-learn", "network", "autoencoder", "ae_only", "vae", "PCA"]]
    .round(4)
    .to_latex()
    .replace("cassins\_dtw", "Cassin's")
    .replace("cifar10", "CIFAR10")
    .replace("fmnist", "FMNIST")
    .replace("mnist", "MNIST")
    .replace("macosko2015", "Retina")
    .replace("autoencoder", "UMAP/AE")
    .replace("ae\_only", "AE")
    .replace("network", "P. UMAP")
    .replace("umap-learn", "UMAP")
    .replace("vae", "VAE")
    .replace("pca", "PCA")
    .replace("parametric-tsne", "P. t-SNE")
    .replace("TSNE", "t-SNE")
    .replace("NaN", "-")
)

lines = metric_string.split('\n')
skip = 1
for line in lines:
    line_elements = line.split(' ')
    floatables = [can_float(le) for (le) in line_elements]
    floats = [float(j) for i, j in zip(floatables, line_elements) if i]
    if len(floats)> 1:
        best = np.argmax(floats[skip:])
        replace_element = np.where(floatables)[0][skip + best]
        line_elements[replace_element] = '\\textbf{' +line_elements[replace_element]+ '}'
    print(' '.join(line_elements))

\begin{tabular}{llrrrrrrrr}
\toprule
      & method\_ &    t-SNE &  P. t-SNE &  UMAP &  P. UMAP &  UMAP/AE &  AE &     VAE &     PCA \\
dataset & dimensions &         &                  &             &          &              &          &         &         \\
\midrule
Cassin's & 2  &  0.9910 &           0.9930 &      0.9890 &   \textbf{0.9950} &       0.9930 &   0.9090 &  0.7740 &  0.6910 \\
      & 64 &     - &           0.9950 &      0.9860 &   0.9910 &       \textbf{0.9970} &   0.9930 &  0.9880 &  0.9920 \\
CIFAR10 & 2  &  \textbf{0.2608} &           0.2017 &      0.1936 &   0.1722 &       0.1833 &   0.2007 &  0.1941 &  0.1503 \\
      & 64 &     - &           0.3556 &      0.2694 &   0.2519 &       0.2477 &   0.3728 &  \textbf{0.3777} &  0.3769 \\
FMNIST & 2  &  \textbf{0.8039} &           0.7361 &      0.7608 &   0.7407 &       0.7561 &   0.7339 &  0.7161 &  0.5055 \\
      & 64 &     - &           0.8479 &      0.8059 &   0.7878 &       0.8028 &   0.8756 &  \textbf{0.8830} &  0.8