# Linear Integral Sheaf Metric for MNIST classification

In this notebook, we implement the RG(B) image analysis experiment: we classify images by feeding a $k$-NN algorithm with a distance matrix obtained by computing pairwise linear integral sheaf distances between image filtrations, as well as linear integral sheaf amplitudes of image filtrations.

In [103]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2



In [104]:
import numpy as np
from gtda.images import HeightFiltration, RadialFiltration
from keras.datasets import mnist
from sklearn.cluster import KMeans

from difftda import *

In [105]:
from helpers import *
from plot_helpers import *
from LISM import pipeline

## Loading the data

First, we load the data: hand-written $28\times 28$ MNIST digits. We keep the digits 0, 2, 8 and 9 of the training set and work with the first $N=100$ of these images.

In [106]:
# loading the data
(train_X, train_y), (test_X, test_y) = mnist.load_data()

# keep digits 0, 2, 8 and 9 in the training set and digits 0 and 8 in the (unused) test set,
# and rescale pixel intensities from [0, 255] to [0, 1]
train_filter = np.isin(train_y, [0, 2, 8, 9])
test_filter = np.isin(test_y, [0, 8])
train_X, train_y = train_X[train_filter] / 255, train_y[train_filter]
test_X, test_y = test_X[test_filter] / 255, test_y[test_filter]

# parameters
N = 100

# pick the first N images
images = train_X[:N]
labels = train_y[:N]

In a topological data analysis pipeline for MNIST data, it is common to binarize an image $I$ and assign a filtration $I_\mathrm{filt}$ to the binarized version $I_\mathrm{bin}$. Different filtrations capture different geometric information; here we use the *radial* and the *height* filtrations. To a given filtration $I_\mathrm{filt}$, one can associate a Wasserstein amplitude $A_W(I_\mathrm{filt})$.
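The binarize-then-filter step can be sketched in plain NumPy. This is a minimal illustration, not giotto-tda's implementation: the helper names `binarize`, `height_filtration` and `radial_filtration`, the threshold `0.4`, the (row, column) pixel coordinates and the rule that background pixels get a value larger than every foreground value are all assumptions made here; `HeightFiltration` and `RadialFiltration` may use different coordinate and background conventions.

```python
import numpy as np


def binarize(image, threshold=0.4):
    """Boolean mask of the pixels brighter than `threshold` (hypothetical default)."""
    return np.asarray(image) > threshold


def height_filtration(binary, direction):
    """Assign to each foreground pixel its height <(row, col), direction>."""
    direction = np.asarray(direction, dtype=float)
    direction = direction / np.linalg.norm(direction)
    rows, cols = np.indices(binary.shape)
    heights = rows * direction[0] + cols * direction[1]
    # Assumption: background pixels enter the filtration after every foreground pixel.
    return np.where(binary, heights, heights.max() + 1)


def radial_filtration(binary, center):
    """Assign to each foreground pixel its Euclidean distance to `center`."""
    rows, cols = np.indices(binary.shape)
    dist = np.hypot(rows - center[0], cols - center[1])
    # Assumption: background pixels enter the filtration after every foreground pixel.
    return np.where(binary, dist, dist.max() + 1)


# Toy 3x3 "image" with a single bright pixel in the middle.
image = np.zeros((3, 3))
image[1, 1] = 1.0
mask = binarize(image)
print(height_filtration(mask, [0, 1]))
print(radial_filtration(mask, (0, 0)))
```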

## Optimization visual summary 

We visualize the optimization process for the computation of a single LISM.
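`pipe.single_distance` optimizes the projection $p$ starting from `p_init`, and `opt['projections']` appears to hold the iterates. We do not know which optimizer it uses; as a hedged illustration of optimizing over $S^{d-1}$, the sketch below runs projected gradient ascent with finite-difference gradients, renormalizing $p$ after each step. The name `sphere_gradient_ascent`, the step size and the toy linear objective are hypothetical. A real LISM objective $p\mapsto d_W(I_p,J_p)$ is only piecewise smooth, so finite differences are a crude stand-in for the gradients a differentiable TDA library would provide.

```python
import numpy as np


def sphere_gradient_ascent(objective, p_init, lr=0.1, n_steps=200, h=1e-5):
    """Maximize `objective` over the unit sphere by projected gradient ascent."""
    p = np.asarray(p_init, dtype=float)
    p = p / np.linalg.norm(p)
    history = [p.copy()]
    for _ in range(n_steps):
        # Central finite-difference estimate of the gradient.
        grad = np.zeros_like(p)
        for i in range(p.size):
            e = np.zeros_like(p)
            e[i] = h
            grad[i] = (objective(p + e) - objective(p - e)) / (2 * h)
        # Ascent step followed by projection back onto S^{d-1}.
        p = p + lr * grad
        p = p / np.linalg.norm(p)
        history.append(p.copy())
    return p, history


# Toy objective: p -> <a, p> is maximized on the circle at a / |a| = (0.6, 0.8).
a = np.array([3.0, 4.0])
p_opt, history = sphere_gradient_ascent(lambda p: a @ p, p_init=[0.99, 0.11])
```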

In [142]:
filtrations = [RadialFiltration(center=np.array([13,0])),
               HeightFiltration(direction=np.array([0,1]))]

In [143]:
I = multifiltration(images[0],filtrations)
J = multifiltration(images[3],filtrations)

In [146]:
data = np.array([I, J])

meta_data = [data]

pipe = pipeline(1, meta_data, 392, dims=[1])

In [147]:
opt = pipe.single_distance(1,p_init=[0.99,0.11])

In [136]:
#plot_optim_filtration(opt, I, J)

In [148]:
plot_optim(opt)

### Estimating the LISM with a grid search

In [12]:
from LISM import grid_search, fast_grid_search

filtrations = [RadialFiltration(center=np.array([13,0])),
               HeightFiltration(direction=np.array([0,1]))]

I = multifiltration(images[0],filtrations)
J = multifiltration(images[1],filtrations)

In [360]:
p_opt = opt['projections'][-1]

In [61]:
dist, angle, fig = grid_search(I,J,0.01,True,1)

In [62]:
fig.show()

In [365]:
np.array([[np.cos(angle), np.sin(angle)]])-p_opt

array([[-0.00328686, -0.12391136],
       [ 0.1245471 ,  0.0039226 ]])

## Classification of MNIST data using the LISM method

For an image $I=(I_1,...,I_d)\in\mathbb{R}^{n\times n\times d}$ and a projection $p\in S^{d-1}$, one can obtain a projected image $I_p:=\sum_{i=1}^d p_iI_i\in\mathbb{R}^{n\times n}$. 
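Assuming the $d$ filtrations are stacked along the last axis (shape `(n, n, d)`, which is our guess for the output of `multifiltration`), the projected image $I_p$ is a single contraction over that axis. `project_multifiltration` is a hypothetical helper:

```python
import numpy as np


def project_multifiltration(I, p):
    """Return I_p = sum_i p_i I_i for I of shape (n, n, d) and p on S^{d-1}."""
    I = np.asarray(I, dtype=float)
    p = np.asarray(p, dtype=float)
    if I.shape[-1] != p.shape[0]:
        raise ValueError("p must have one entry per filtration")
    if not np.isclose(np.linalg.norm(p), 1.0):
        raise ValueError("p must lie on the unit sphere S^{d-1}")
    # Contract the filtration axis of I against p.
    return np.tensordot(I, p, axes=([2], [0]))


# d = 2: I_1 is constant 1, I_2 is constant 2, so I_p = 0.6 * 1 + 0.8 * 2 = 2.2 everywhere.
I = np.stack([np.ones((4, 4)), 2 * np.ones((4, 4))], axis=-1)
I_p = project_multifiltration(I, [0.6, 0.8])
```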

**Definition. (LISA)** One defines the *linear integral sheaf amplitude (LISA)* of $I$ as the maximum $$A_{\mathrm{ISA}}(I):=\max_{p\in S^{d-1}}A_W(I_p).$$

For example, take $d=2$, with $I_1$ the $(3,7)$-radial filtration of an original image $M\in\mathbb{R}^{n\times n}$ and $I_2$ the $(1,0)$-height filtration of the same image $M$.

**Definition. (LISM)** One defines the *linear integral sheaf metric (LISM)* between two images $I, J\in\mathbb{R}^{n\times n\times d}$ as the maximum
$$d_{\mathrm{ISM}}(I,J):=\max_{p\in S^{d-1}}d_W(I_p, J_p).$$

**Task.** We try to classify $N$ images given by a family of multifiltrations $\{I^i\}_{i=1}^{N}\subset\mathbb{R}^{n\times n\times d}$, with $N=100$.


### Method 2: LISM distance matrix

**Method 2.** We compute a distance matrix $D=(d_{ij})_{i,j=1}^{N}$, where $d_{ij}=d_{\mathrm{ISM}}(I^i,I^j)$.
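Since $d_{\mathrm{ISM}}$ is symmetric and vanishes on the diagonal, only the $N(N-1)/2$ entries above the diagonal need to be computed (4950 for $N=100$). A minimal sketch, where `pairwise_distance_matrix` and `dist_fn` are hypothetical names and `dist_fn` would be a LISM routine:

```python
import numpy as np


def pairwise_distance_matrix(items, dist_fn):
    """Fill a symmetric distance matrix, calling dist_fn once per unordered pair."""
    n = len(items)
    D = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            D[i, j] = D[j, i] = dist_fn(items[i], items[j])
    return D


# Toy usage with the absolute difference standing in for the LISM.
D = pairwise_distance_matrix([0.0, 1.0, 3.0], lambda x, y: abs(x - y))
```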

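For $d=2$ (bifiltrations), the sphere $S^1$ is a circle, so we compute each entry of the LISM distance matrix by a grid search over the angle $\theta$, with $p=(\cos\theta,\sin\theta)$.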
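For $d=2$ the grid search amounts to evaluating the objective at $p=(\cos\theta,\sin\theta)$ on a regular grid of angles and keeping the best one, consistent with how the notebook turns the returned `angle` back into a projection. The sketch below is generic: `objective(p)` can be $p\mapsto d_W(I_p,J_p)$ for the LISM or $p\mapsto A_W(I_p)$ for the LISA. We do not know the angular range `grid_search` scans; the full circle $[0,2\pi)$ used here, the name `circle_grid_search` and the toy objective are assumptions.

```python
import numpy as np


def circle_grid_search(objective, step=0.01, theta_min=0.0, theta_max=2 * np.pi):
    """Maximize objective(p) over p = (cos t, sin t) on a grid of angles t."""
    best_value, best_theta = -np.inf, None
    for theta in np.arange(theta_min, theta_max, step):
        p = np.array([np.cos(theta), np.sin(theta)])
        value = objective(p)
        if value > best_value:  # strict comparison keeps the first maximizer on ties
            best_value, best_theta = value, theta
    return best_value, best_theta


# Toy bifiltrations of shape (n, n, 2); the sup-norm of I_p - J_p stands in for d_W.
I = np.stack([np.ones((5, 5)), np.zeros((5, 5))], axis=-1)
J = np.zeros((5, 5, 2))
value, theta = circle_grid_search(lambda p: np.abs(I @ p - J @ p).max(), step=0.01)
```

With `step=0.01` over the full circle, this sketch makes 629 objective evaluations per pair of images.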

### Method 3: Wasserstein distance matrix

**Method 3.** What about $1$-filtrations and standard persistent homology? Here each of the $N$ images $\mathrm{Im}^i\in\mathbb{R}^{n\times n}$ is assigned a single filtration $\mathrm{Im}^i_\mathrm{filt}\in\mathbb{R}^{n\times n}$ (rather than a multifiltration in $\mathbb{R}^{n\times n\times d}$), and one computes the distance matrix $D=(d_{ij})_{i,j=1}^{N}$ with $d_{ij}=d_W(\mathrm{Im}^i_\mathrm{filt},\mathrm{Im}^j_\mathrm{filt})$, where $d_W$ is the $2$-Wasserstein distance between the corresponding persistence diagrams.
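The $q$-Wasserstein distance between two finite persistence diagrams can be computed as an optimal assignment in which each point may also be matched to the diagonal. The sketch below solves this with SciPy's `linear_sum_assignment` on the standard augmented cost matrix. The $\ell^\infty$ ground metric (so a point $(b,d)$ lies at distance $(d-b)/2$ from the diagonal) is one common convention and an assumption here; we do not claim it matches the backend used by `run_experiment_cub`. Points with infinite death times must be removed or capped beforehand.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def wasserstein_distance(X, Y, q=2.0):
    """q-Wasserstein distance between finite diagrams X, Y given as (birth, death) rows."""
    X = np.asarray(X, dtype=float).reshape(-1, 2)
    Y = np.asarray(Y, dtype=float).reshape(-1, 2)
    m, n = len(X), len(Y)
    # l-infinity cost between off-diagonal points, shape (m, n).
    pair = np.max(np.abs(X[:, None, :] - Y[None, :, :]), axis=2) ** q
    # Cost of sending each point to its nearest diagonal point.
    to_diag_X = ((X[:, 1] - X[:, 0]) / 2) ** q
    to_diag_Y = ((Y[:, 1] - Y[:, 0]) / 2) ** q
    # Finite "forbidden" cost, larger than any admissible total.
    big = 1.0 + pair.sum() + to_diag_X.sum() + to_diag_Y.sum()
    C = np.zeros((m + n, m + n))
    C[:m, :n] = pair                                  # point-to-point
    C[:m, n:] = big                                   # X_i may only use its own diagonal copy
    C[np.arange(m), n + np.arange(m)] = to_diag_X
    C[m:, :n] = big                                   # Y_j may only use its own diagonal copy
    C[m + np.arange(n), np.arange(n)] = to_diag_Y
    # The bottom-right block (diagonal-to-diagonal) stays at zero cost.
    rows, cols = linear_sum_assignment(C)
    return C[rows, cols].sum() ** (1.0 / q)
```

Passing an empty diagram as `Y` gives the distance from `X` to the empty diagram, which is one way of defining a Wasserstein amplitude $A_W$; whether this matches the amplitude used in this notebook is an assumption.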

# Results

Initializing the multifiltration:

In [58]:
filtrations = [RadialFiltration(center=np.array([6,20])),
               HeightFiltration(direction=np.array([0,1])),
               RadialFiltration(center=np.array([6,6]))]

We now compute the distance matrices. With no_multi=True, the helper run_experiment_cub skips the LISM matrix and only computes one Wasserstein distance matrix per filtration (D1, D2, D3).

In [59]:
(D1, D2, D3), labels_bb = run_experiment_cub(filtrations, images, labels, no_multi=True)

Computing Wasserstein matrix for filtration no.1...
100it [02:49,  1.70s/it]
100it [02:40,  1.60s/it]

Computing Wasserstein matrix for filtration no.2...
100it [02:30,  1.50s/it]
100it [02:48,  1.68s/it]

Computing Wasserstein matrix for filtration no.3...
100it [02:36,  1.57s/it]
100it [02:16,  1.36s/it]


### Multi-dimensional Scaling (MDS)

In [60]:
MDS_analysis([D1, D2, D3], labels_bb)
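We do not know which MDS variant MDS_analysis uses (it may, for example, rely on scikit-learn's metric MDS). For reference, classical (Torgerson) MDS embeds a distance matrix $D$ by double-centering the squared distances, $B=-\tfrac12 H D^{(2)} H$ with $H=\mathrm{Id}-\tfrac1N\mathbf{1}\mathbf{1}^\top$ and $D^{(2)}$ the entrywise square of $D$, and then using the top eigenvectors of $B$ scaled by the square roots of their eigenvalues. A minimal NumPy sketch (`classical_mds` is a hypothetical name):

```python
import numpy as np


def classical_mds(D, n_components=2):
    """Classical (Torgerson) MDS embedding of a distance matrix D."""
    D = np.asarray(D, dtype=float)
    n = D.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n              # centering matrix
    B = -0.5 * H @ (D ** 2) @ H                      # double-centred Gram matrix
    eigvals, eigvecs = np.linalg.eigh(B)             # eigenvalues in ascending order
    top = np.argsort(eigvals)[::-1][:n_components]
    scale = np.sqrt(np.clip(eigvals[top], 0.0, None))  # clip negative eigenvalues to zero
    return eigvecs[:, top] * scale


# Three collinear points are recovered exactly (up to sign and translation) in 1D.
x = np.array([0.0, 1.0, 3.0])
D = np.abs(x[:, None] - x[None, :])
X = classical_mds(D, n_components=1)
```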

### Heatmaps

In [61]:
heatmaps((D1,D2,D3))

### Performance

In [128]:
accuracy((D1,D2,D3), labels_bb, 4)

[0.56, 0.65, 0.72]
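The accuracy helper is not shown. One plausible reading, and it is only an assumption, is a leave-one-out $k$-NN classifier run on each precomputed distance matrix with $k=4$. A sketch of that technique (`knn_accuracy` is hypothetical, and ties in the vote go to the smallest label):

```python
import numpy as np


def knn_accuracy(D, labels, k):
    """Leave-one-out k-NN accuracy from a precomputed distance matrix."""
    D = np.array(D, dtype=float)            # copy, so the caller's matrix is untouched
    np.fill_diagonal(D, np.inf)             # a point may not vote for itself
    classes, y = np.unique(np.asarray(labels), return_inverse=True)
    correct = 0
    for i in range(D.shape[0]):
        neighbours = np.argsort(D[i])[:k]
        votes = np.bincount(y[neighbours], minlength=classes.size)
        correct += int(votes.argmax() == y[i])  # argmax picks the smallest label on ties
    return correct / D.shape[0]


# Two well-separated toy clusters labelled 0 and 8.
pts = np.array([0.0, 0.1, 0.2, 10.0, 10.1, 10.2])
D_toy = np.abs(pts[:, None] - pts[None, :])
acc = knn_accuracy(D_toy, [0, 0, 0, 8, 8, 8], k=2)
```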

### LISM

In [None]:
data = np.array([multifiltration(img, filtrations) for img in images[:3]])

meta_data = [data]

pipe = pipeline(1, meta_data, 392, dims=[1])

In [None]:
D_sheaf_1 = pipe.distance_matrix(1)

In [None]:
D_sheaf_0 = pipe.distance_matrix(1)

In [None]:
D_sheaf = np.maximum(D_sheaf_1, D_sheaf_0)