# Linear Integral Sheaf Metric for MNIST classification

In this notebook, we implement the RG(B) image analysis experiment, where we classify images by feeding a $k$-NN clustering algorithm with an input distance matrix that results from computing pairwise linear integral sheaf distances between image filtrations, as well as linear integral sheaf amplitudes of image filtrations.

In [7]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [8]:
import numpy                 as np
from gtda.images import RadialFiltration
from gtda.images import HeightFiltration

from difftda                              import *
from gudhi.wasserstein                    import wasserstein_distance
from sklearn.cluster                      import KMeans
from tqdm import tqdm
import numpy as np
from keras.datasets import mnist
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [9]:
from helpers import *
from LISM import pipeline

## Loading the data

First, we load the data - these are hand-written $28\times 28$ digits of zeros and fours.

In [72]:
# loading the data
(train_X, train_y), (test_X, test_y) = mnist.load_data()

# we look at digits '0' and '8' only
train_filter = np.where((train_y == 6) | (train_y == 9))
test_filter = np.where((test_y == 0) | (test_y == 8))
train_X, train_y = train_X[train_filter]/255, train_y[train_filter]
test_X, test_y = test_X[test_filter]/255, test_y[test_filter]

# parameters
N = 50

# pick the first N images
images = train_X[:N]
labels = train_y[:N]

In a topological data analysis pipeline for MNIST data, it is common to binarize an image $I$ and assign a filtration $I_\mathrm{filt}$ to the binarized version $I_\mathrm{bin}$. There are various types of filtrations, each one bringing new information. We will look at *radial* and the *height* filtrations. To a given filtration $I_\mathrm{filt}$, one can associate a Wasserstein amplitude $A_W(I_\mathrm{filt})$. 

## Optimization visual summary 

We visualize the optimization process for the computation of a single LISM.

In [442]:
filtrations = [RadialFiltration(center=np.array([13,0])),
               HeightFiltration(direction=np.array([0,1])),
               HeightFiltration(direction=np.array([1,0]))]

In [443]:
I = multifiltration(images[0],filtrations)
J = multifiltration(images[1],filtrations)

In [444]:
data = np.array([I, J])

meta_data = [data]

pipe = pipeline(1, meta_data, 392, dims=[1])

opt = pipe.single_distance(1)

 12%|█▏        | 6/51 [00:14<01:46,  2.36s/it]


In [357]:
#plot_optim_filtration(opt, I, J)

In [445]:
plot_optim(opt,True)

### Estimating the LISM with a grid search

In [359]:
from LISM import grid_search, fast_grid_search

filtrations = [RadialFiltration(center=np.array([13,0])),
               HeightFiltration(direction=np.array([0,1]))]

I = multifiltration(images[0],filtrations)
J = multifiltration(images[1],filtrations)

In [360]:
p_opt = opt['projections'][-1]

In [361]:
dist, angle, fig = grid_search(I,J,0.1,True,1)

In [362]:
fig.show()

In [365]:
np.array([[np.cos(angle), np.sin(angle)]])-p_opt

array([[-0.00328686, -0.12391136],
       [ 0.1245471 ,  0.0039226 ]])

## Classification of MNIST data using the LISM method

For an image $I=(I_1,...,I_d)\in\mathbb{R}^{n\times n\times d}$ and a projection $p\in S^{d-1}$, one can obtain a projected image $I_p:=\sum_{i=1}^d p_iI_i\in\mathbb{R}^{n\times n}$. 

**Definition. (LISA)** One defines the *linear integral sheaf amplitude (LISA)* of $I$ as the maximum $$A_W(I):=\max_{p\in S^{d-1}}A_W(I_p).$$ 

For example, take $d=2$, with $I_1$ the $(3,7)$-radial filtration of an orginal image $M\in\mathbb{R}^{n\times n}$ and $I_2$ the $(1,0)$-height filtration of the image $M\in\mathbb{R}^{n\times n}$. 

**Definition. (LISM)** One defines the *linear integral sheaf metric (LISM)* between two images $I, J\in\mathbb{R}^{n\times n\times d}$ as the maximum
$$d_{\mathrm{ISM}}(I,J):=\max_{p\in S^{d-1}}d_W(I_p, J_p)$$ 

**Task.** We try to classify $N$ images given by a family of multifiltrations $\{I^i\}_{i=1}^{100}\subset\mathbb{R}^{n\times n\times d}$. 


### Method 2 : LISM distance matrix

**Method 2.** We compute a distance matrix $D=(d_{ij})_{i,j=1}^{100}$ where $d_{ij}=d_{\mathrm{ISM}}(I^i,I^j)$.

For the case of $2$-filtrations, we compute the LISM distance matrix via grid search.

### Method 3 : Wasserstein distance matrix 

**Method 3.** What about $1$-filtrations and standard persistent homology? One can compute a distance matrix $D=(d_{ij})_{i,j=1}^{100}$, where $d_{ij}=d_W(I^i,I^j)$, where each of the $N$ images $\mathrm{Im}\in\mathbb{R}^{n\times n}$ is assigned a filtration $\mathrm{Im}_\mathrm{filt}\in\mathbb{R}^{n\times n}$ (and not a multifiltration in $\mathbb{R}^{n\times n\times d}$) and $d_W$ is the $2$-Wasserstein distance.

### Results

In [73]:
multifilt = [RadialFiltration(center=np.array([13,0])),
             HeightFiltration(direction=np.array([0,1]))]

In [74]:
(D1, D2), D_sheaf, labels_bb = run_experiment(multifilt, images, labels)

0it [00:00, ?it/s]

Computing Wasserstein matrix for filtration no.1...


50it [00:40,  1.23it/s]
50it [00:38,  1.31it/s]
0it [00:00, ?it/s]

Computing Wasserstein matrix for filtration no.2...


50it [00:38,  1.32it/s]
50it [00:40,  1.24it/s]
0it [00:00, ?it/s]

Computing LISM matrix...


50it [03:08,  3.77s/it]
50it [02:29,  2.99s/it]


In [46]:
labels_bb

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6,
       6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 8, 8, 8, 8, 8, 8, 8,
       8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8], dtype=uint8)

In [64]:
kmeans = KMeans(n_clusters=4, random_state=0)
kmeans.fit(D_sheaf)
pred = kmeans.labels_
pred

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0,
       0, 0, 0, 0, 3, 1, 0, 0, 1, 1, 1, 1, 0, 1, 3, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 3, 3, 2, 2, 3, 3, 3,
       2, 3, 0, 0, 0, 0, 3, 0, 3, 3, 1, 3], dtype=int32)

In [68]:
L = np.zeros_like(labels_bb)

for i in range(N):
    if labels_bb[i] == 0:
        L[i] = 0
    elif labels_bb[i] == 3:
        L[i] = 1
    elif labels_bb[i] == 6:
        L[i] = 2
    elif labels_bb[i] == 8:
        L[i] = 3

In [69]:
L

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], dtype=uint8)

In [70]:
res = L - pred
nb = np.count_nonzero(res==0)
nb/N

0.61

### Multi-dimensional Scaling (MDS)

In [75]:
MDS_analysis([D1, D2, D_sheaf], labels_bb)

In [77]:
heatmaps((D1,D2,D_sheaf))