# Linear Integral Sheaf Metric for MNIST classification

In this notebook, we implement the image analysis experiment: we classify $N$ images with $k$ labels by feeding an $N\times N$ LISM distance matrix to a $k$-NN clustering algorithm. The distance matrix is obtained by computing pairwise linear integral sheaf distances between image filtrations.

In [146]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2



In [147]:
import numpy as np
from gtda.images import RadialFiltration
from gtda.images import HeightFiltration

from difftda import *
from sklearn.cluster import KMeans
from keras.datasets import mnist
import n_sphere

In [233]:
from helpers import *
from plot_helpers import *
from LISM import pipeline

## Loading the data

First, we load the data: hand-written $28\times 28$ MNIST digits. The training subset keeps the four digits 3, 6, 8 and 9.

In [346]:
# loading the data
(train_X, train_y), (test_X, test_y) = mnist.load_data()

# training subset: digits 3, 6, 8 and 9; test subset: digits 0 and 8
train_filter = np.where((train_y == 9) | (train_y == 3) | (train_y == 6) | (train_y == 8))
test_filter = np.where((test_y == 0) | (test_y == 8))
train_X, train_y = train_X[train_filter]/255, train_y[train_filter]
test_X, test_y = test_X[test_filter]/255, test_y[test_filter]

# parameters
N = 100

# pick the first N images
images = train_X[:N]
labels = train_y[:N]

In a topological data analysis pipeline for MNIST data, it is common to binarize an image $I$ and assign a filtration $I_\mathrm{filt}$ to the binarized version $I_\mathrm{bin}$. There are various types of filtrations, each one capturing different information. We will look at the *radial* and *height* filtrations. To a given filtration $I_\mathrm{filt}$, one can associate a Wasserstein amplitude $A_W(I_\mathrm{filt})$.
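To make the binarize-then-filter step concrete, here is a minimal numpy sketch. It is a simplified stand-in and not giotto-tda's implementation: the threshold `0.4`, the (row, column) axis convention, and the value given to background pixels are all assumptions made for illustration.

```python
import numpy as np

def binarize(image, threshold=0.4):
    """Return a boolean mask of foreground pixels (threshold is an assumption)."""
    return np.asarray(image) > threshold

def height_filtration(binary, direction):
    """Simplified height filtration: each foreground pixel gets its height
    along `direction`; background pixels get a value above every height."""
    direction = np.asarray(direction, dtype=float)
    direction = direction / np.linalg.norm(direction)
    rows, cols = np.indices(binary.shape)
    # height of each pixel along the (row, column) direction, shifted to start at 0
    heights = direction[0] * rows + direction[1] * cols
    heights = heights - heights.min()
    background = heights.max() + 1.0
    return np.where(binary, heights, background)

img = np.array([[0.0, 0.9, 0.1],
                [0.0, 0.8, 0.7],
                [0.2, 0.0, 0.0]])
filtered = height_filtration(binarize(img), direction=[0, 1])
print(filtered)
```

Foreground pixels enter the sublevel sets in order of their height along the chosen direction, so different directions reveal different features of the same digit.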

## Optimization visual summary 

We visualize the optimization process for the computation of a single LISM.

In [268]:
filtrations = [HeightFiltration(direction=np.array([0,1])),
              RadialFiltration(center=np.array([6,6]))]

In [283]:
I = multifiltration(images[7],filtrations)
J = multifiltration(images[16],filtrations)

In [284]:
pipe = pipeline(object=1, meta_data=[np.array((I,J))], card=392)

opt = pipe.single_distance(dim=1, p_init=[0.2,0.8])

50


In [360]:
plot_optim(opt)

In [286]:
opt['projections'][-1]

array([[0.7086699],
       [0.3745885]], dtype=float32)

In [237]:
compute_distance_cub(I,J,dim=1)

0.06549221944293494

### Estimating the LISM with a grid search

In [207]:
from LISM import grid_search, fast_grid_search

In [208]:
p_opt = opt['projections'][-1]

In [282]:
I = multifiltration(images[7],filtrations)
J = multifiltration(images[16],filtrations)

dist, angle, fig = grid_search(I,J,0.05,True,1)

fig.show()

## Classification of MNIST data using the LISM method

For an image $I=(I_1,...,I_d)\in\mathbb{R}^{n\times n\times d}$ and a projection $p\in S^{d-1}$, one can obtain a projected image $I_p:=\sum_{i=1}^d p_iI_i\in\mathbb{R}^{n\times n}$. 
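The projection $I_p$ is a weighted sum over the last axis. A minimal numpy sketch, assuming the multifiltration is stored as an array of shape `(n, n, d)` (the helper name `project_image` is hypothetical):

```python
import numpy as np

def project_image(I, p):
    """Compute I_p = sum_i p_i * I_i for I of shape (n, n, d) and p of shape (d,)."""
    I = np.asarray(I, dtype=float)
    p = np.asarray(p, dtype=float)
    if I.shape[-1] != p.shape[0]:
        raise ValueError("last axis of I must match the length of p")
    # matrix-vector product over the last axis: (n, n, d) @ (d,) -> (n, n)
    return I @ p

# two constant 3x3 layers, projected along a unit vector p on the circle S^1
I = np.stack([np.ones((3, 3)), 2 * np.ones((3, 3))], axis=-1)
p = np.array([0.6, 0.8])
I_p = project_image(I, p)
```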

**Definition. (LISA)** One defines the *linear integral sheaf amplitude (LISA)* of $I$ as the maximum $$A_W(I):=\max_{p\in S^{d-1}}A_W(I_p).$$ 

For example, take $d=2$, with $I_1$ the $(3,7)$-radial filtration of an original image $M\in\mathbb{R}^{n\times n}$ and $I_2$ the $(1,0)$-height filtration of the same image $M$.

**Definition. (LISM)** One defines the *linear integral sheaf metric (LISM)* between two images $I, J\in\mathbb{R}^{n\times n\times d}$ as the maximum
$$d_{\mathrm{ISM}}(I,J):=\max_{p\in S^{d-1}}d_W(I_p, J_p),$$
where $d_W$ is the Wasserstein distance.

**Task.** We try to classify $N=100$ images given by a family of multifiltrations $\{I^i\}_{i=1}^{100}\subset\mathbb{R}^{n\times n\times d}$.
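Once a distance matrix is available, a nearest-neighbour rule can be evaluated directly on it. The sketch below is one possible leave-one-out $k$-NN evaluation in plain numpy. It is an illustration only and not the `accuracy` helper used later in this notebook; the function name and the default `k` are assumptions.

```python
import numpy as np

def knn_loo_accuracy(D, labels, k=3):
    """Leave-one-out k-NN accuracy from a precomputed distance matrix D."""
    D = np.array(D, dtype=float)      # copy, so the caller's matrix is untouched
    labels = np.asarray(labels)
    np.fill_diagonal(D, np.inf)       # a point may not vote for itself
    correct = 0
    for i in range(len(labels)):
        neighbours = np.argsort(D[i])[:k]
        values, counts = np.unique(labels[neighbours], return_counts=True)
        predicted = values[np.argmax(counts)]   # majority vote among neighbours
        correct += int(predicted == labels[i])
    return correct / len(labels)

# toy example: two well-separated groups of points on a line
points = np.array([0.0, 0.1, 0.2, 5.0, 5.1, 5.2])
toy_labels = np.array([0, 0, 0, 1, 1, 1])
toy_D = np.abs(points[:, None] - points[None, :])
toy_acc = knn_loo_accuracy(toy_D, toy_labels, k=1)
```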


### Method 2 : LISM distance matrix

**Method 2.** We compute a distance matrix $D=(d_{ij})_{i,j=1}^{100}$ where $d_{ij}=d_{\mathrm{ISM}}(I^i,I^j)$.
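Because the distance is symmetric, only the upper triangle of $D$ needs to be computed. A minimal sketch, where `distance` is any callable supplied by the user (a hypothetical placeholder for the LISM computation):

```python
import numpy as np

def pairwise_matrix(items, distance):
    """Fill a symmetric matrix D with D[i, j] = distance(items[i], items[j])."""
    n = len(items)
    D = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            # compute each pair once and mirror it across the diagonal
            D[i, j] = D[j, i] = distance(items[i], items[j])
    return D

# toy usage with numbers and the absolute difference as distance
D_toy = pairwise_matrix([0.0, 1.0, 3.0], lambda a, b: abs(a - b))
```

For $N=100$ items this amounts to $N(N-1)/2 = 4950$ distance evaluations.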

For the case of $2$-filtrations, we compute the LISM distance matrix via grid search.
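For $d=2$ the sphere $S^{1}$ is a circle, so the maximum can be approximated by scanning angles $\theta$ with $p=(\cos\theta,\sin\theta)$. The sketch below illustrates the idea. It is not the `grid_search` function from `LISM`: the interpretation of `step` as an angular step in radians is an assumption, and the Frobenius norm is used as a stand-in for $d_W$ so that the example stays self-contained.

```python
import numpy as np

def frobenius_distance(A, B):
    """Stand-in for d_W (assumption): Frobenius norm of the difference."""
    return float(np.linalg.norm(A - B))

def grid_search_max(I, J, distance=frobenius_distance, step=0.05):
    """Approximate max over p in S^1 of distance(I_p, J_p) on an angular grid.

    I and J have shape (n, n, 2); returns (best value, best angle)."""
    best_value, best_angle = -np.inf, None
    for theta in np.arange(0.0, 2 * np.pi, step):
        p = np.array([np.cos(theta), np.sin(theta)])
        value = distance(I @ p, J @ p)   # project both images along p
        if value > best_value:
            best_value, best_angle = value, float(theta)
    return best_value, best_angle

# toy check: only the first layer differs, so the maximum sits at theta = 0
A = np.ones((2, 2))
I_toy = np.stack([A, np.zeros((2, 2))], axis=-1)
J_toy = np.zeros((2, 2, 2))
value, angle = grid_search_max(I_toy, J_toy)
```

A smaller `step` gives a finer approximation at a proportionally higher cost, since every grid point requires one full distance computation.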

### Method 3 : Wasserstein distance matrix 

**Method 3.** What about $1$-filtrations and standard persistent homology? Here each of the $N$ images $\mathrm{Im}\in\mathbb{R}^{n\times n}$ is assigned a single filtration $\mathrm{Im}_\mathrm{filt}\in\mathbb{R}^{n\times n}$ (and not a multifiltration in $\mathbb{R}^{n\times n\times d}$). One then computes a distance matrix $D=(d_{ij})_{i,j=1}^{100}$ with $d_{ij}=d_W(I^i,I^j)$, where $d_W$ is the $2$-Wasserstein distance.
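For reference, a Wasserstein distance between two persistence diagrams can be computed as an optimal matching in which points may also be matched to the diagonal. The sketch below uses `scipy.optimize.linear_sum_assignment` and assumes the $\ell^\infty$ ground norm, which is one common convention; the libraries used in this notebook may follow a different one, so values need not agree.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def wasserstein_distance(X, Y, q=2):
    """q-Wasserstein distance between diagrams X and Y (rows are (birth, death)).

    Assumption: l-infinity ground norm, so a point (b, d) lies at distance
    (d - b) / 2 from the diagonal."""
    X = np.asarray(X, dtype=float).reshape(-1, 2)
    Y = np.asarray(Y, dtype=float).reshape(-1, 2)
    m, n = len(X), len(Y)
    if m + n == 0:
        return 0.0
    cost = np.zeros((m + n, m + n))
    # cost of matching a point of X to a point of Y
    cost[:m, :n] = np.max(np.abs(X[:, None, :] - Y[None, :, :]), axis=2) ** q
    # cost of sending a point of X (resp. Y) to the diagonal
    cost[:m, n:] = ((X[:, 1] - X[:, 0]) / 2)[:, None] ** q
    cost[m:, :n] = ((Y[:, 1] - Y[:, 0]) / 2)[None, :] ** q
    # diagonal-to-diagonal matches (bottom-right block) cost nothing
    rows, cols = linear_sum_assignment(cost)
    return float(cost[rows, cols].sum() ** (1.0 / q))

d_same = wasserstein_distance([(0, 2)], [(0, 2)])
d_shift = wasserstein_distance([(0, 2)], [(0, 3)])
d_empty = wasserstein_distance([(0, 2)], [])
```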

# Results

Initializing the multifiltration:

In [347]:
filtrations = [RadialFiltration(center=np.array([13,13])),
              HeightFiltration(direction=np.array([0,1]))]

Computing the distance matrices. If no_multi=True, the method run_experiment_cub does not compute the LISM matrix.

In [348]:
D1_0, fig, images_bb, labels_bb = distance_matrix_cub(images, train_y, dim=0, filtrations=[filtrations[0]], step=0.05, suivi=False)

100it [02:55,  1.76s/it]


In [349]:
D2_0, fig, images_bb, labels_bb = distance_matrix_cub(images, train_y, dim=0, filtrations=[filtrations[1]], step=0.05, suivi=False)

100it [02:21,  1.42s/it]


In [350]:
D1_1, fig, images_bb, labels_bb = distance_matrix_cub(images, train_y, dim=1, filtrations=[filtrations[0]], step=0.05, suivi=False)

100it [05:52,  3.53s/it]


In [351]:
D2_1, fig, images_bb, labels_bb = distance_matrix_cub(images, train_y, dim=1, filtrations=[filtrations[1]], step=0.05, suivi=False)

100it [02:32,  1.53s/it]


In [352]:
D1 = np.maximum(D1_0,D1_1)
D2 = np.maximum(D2_0,D2_1)

In [330]:
# D_sliced_0, fig, images_bb, labels_bb = matrix_sliced_conv_dist_cub(images, labels, dim=0, filtrations=filtrations, n=100)

In [353]:
D_sheaf_0, fig, images_bb, labels_bb = distance_matrix_cub(images, train_y, dim=0, filtrations=filtrations, step=0.05, suivi=False)

100it [28:19, 17.00s/it]


In [354]:
D_sheaf_1, fig, images_bb, labels_bb = distance_matrix_cub(images, train_y, dim=1, filtrations=filtrations, step=0.05, suivi=False)

100it [42:43, 25.64s/it]


In [355]:
D_sheaf = np.maximum(D_sheaf_0,D_sheaf_1)

### Multi-dimensional Scaling (MDS)

In [356]:
MDS_analysis((D1, D2, D_sheaf), labels_bb)

### Heatmaps

In [359]:
heatmaps((D1, D2, D_sheaf))

In [358]:
accuracy((D1, D2,D_sheaf), labels_bb, 4)

[0.6, 0.73, 0.81]

In [262]:
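inf_vals = np.count_nonzero(D_sheaf_0 == np.inf)

if inf_vals > 0:
    print("Re-assigning {} values in LISM matrix.".format(inf_vals))
    inf_indices = np.where(D_sheaf_0 == np.inf)
    # np.maximum is binary (a third positional argument is treated as `out`),
    # so take the elementwise maximum with reduce instead.
    # D3_0 is assumed to come from a third filtration computed elsewhere.
    D_max = np.maximum.reduce([D1_0, D2_0, D3_0])
    D_sheaf_0[inf_indices] = D_max[inf_indices]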

------------------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------------------------------

## Sliced Convolution Distance

In [311]:
D_sliced_0, fig, images_bb, labels_bb = matrix_sliced_conv_dist_cub(images, labels, 0, filtrations, n=100)

100it [58:26, 35.07s/it]


In [316]:
with open('data/distance_3_filtration_MNIST_dim1_sliced', 'rb') as f:
    D_sliced_1 = np.load(f)

In [317]:
D_sliced = np.maximum(D_sliced_0, D_sliced_1)