# Linear Integral Sheaf Amplitudes (LISAs) 

In this notebook, we implement the RG(B) image analysis experiment. We classify images with a $k$-nearest-neighbour ($k$-NN) classifier whose input is either a distance matrix of pairwise linear integral sheaf distances between image filtrations, or the linear integral sheaf amplitudes of those filtrations.

In [1]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [155]:
import numpy                 as np
import tensorflow            as tf
import gudhi                 as gd
import plotly.express        as px
import plotly.graph_objects  as go

from difftda                              import *
from gudhi.wasserstein                    import wasserstein_distance
from sklearn.cluster                      import KMeans
from tqdm                                 import tqdm
from keras.datasets                       import mnist
from plotly.subplots                      import make_subplots

In [228]:
from helpers import *

### Classification of MNIST data with standard TDA

First, we load the data: hand-written $28\times 28$ images of the digits zero and eight.

In [157]:
# loading the data
(train_X, train_y), (test_X, test_y) = mnist.load_data()

# we look at digits '0' and '8' only
train_filter = np.where((train_y == 0 ) | (train_y == 8))
test_filter = np.where((test_y == 0) | (test_y == 8))
train_X, train_y = train_X[train_filter]/255, train_y[train_filter]
test_X, test_y = test_X[test_filter]/255, test_y[test_filter]

# parameters
N = 60

# pick the first N images
images = train_X[:N]
labels = train_y[:N]

In a topological data analysis pipeline for MNIST data, it is common to binarize an image $I$ and assign a filtration $I_\mathrm{filt}$ to the binarized version $I_\mathrm{bin}$. There are various types of filtrations, each capturing different geometric information about the digit. We will look at the *radial* and *height* filtrations. To a given filtration $I_\mathrm{filt}$, one can associate a Wasserstein amplitude $A_W(I_\mathrm{filt})$, i.e. the Wasserstein distance between the persistence diagram of $I_\mathrm{filt}$ and the empty diagram.
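The notebook's `bifiltration` helper builds these filtrations internally, and its code is not shown. The following is a minimal NumPy sketch of the idea under stated assumptions: the binarization threshold of `0.4`, the `(row, col)` pixel convention, and giving inactive pixels a value above every active one are illustrative choices, and all function names are hypothetical. The amplitude uses the $L^\infty$ ground metric (gudhi's default internal metric), under which a point $(b,d)$ is matched to the diagonal at cost $(d-b)/2$. Computing the persistence diagram itself, for example with a cubical complex from gudhi, is not shown.

```python
import numpy as np

def binarize(img, threshold=0.4):
    """Boolean mask of 'ink' pixels (threshold is an assumed value)."""
    return img > threshold

def height_filtration(mask, direction):
    """Hypothetical height filtration of a binarized image.

    Active pixels get <(row, col), direction>; inactive pixels get a value
    larger than every active one, so they enter the filtration last.
    """
    direction = np.asarray(direction, dtype=float)
    direction = direction / np.linalg.norm(direction)
    rows, cols = np.indices(mask.shape)
    heights = rows * direction[0] + cols * direction[1]
    filt = np.full(mask.shape, heights.max() + 1.0)
    filt[mask] = heights[mask]
    return filt

def radial_filtration(mask, center):
    """Hypothetical radial filtration: distance of active pixels to `center`."""
    rows, cols = np.indices(mask.shape)
    dist = np.hypot(rows - center[0], cols - center[1])
    filt = np.full(mask.shape, dist.max() + 1.0)
    filt[mask] = dist[mask]
    return filt

def wasserstein_amplitude(diagram, q=2.0):
    """q-Wasserstein distance from a (k, 2) diagram of (birth, death) pairs
    to the empty diagram, with the L-infinity ground metric.

    Points with an infinite coordinate are ignored in this sketch.
    """
    diagram = np.asarray(diagram, dtype=float).reshape(-1, 2)
    finite = np.isfinite(diagram).all(axis=1)
    costs = (diagram[finite, 1] - diagram[finite, 0]) / 2.0
    return float(np.sum(costs ** q) ** (1.0 / q))
```

For example, `radial_filtration(binarize(images[0]), (13, 13))` would give a $28\times 28$ array of sublevel values for the first digit.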

In [158]:
ind0 = np.argwhere(labels==0)
ind8 = np.argwhere(labels==8)
s1 = ind0.shape[0]
s2 = ind8.shape[0]

images0 = images[ind0]
images8 = images[ind8]

In [218]:
from LISM import *

### Illustration of the pipeline for computing a single distance $d_{\mathrm{ISM}}(I,J)$.

In [174]:
img1 = images[0]
img2 = images[1]

I, plot1 = bifiltration(img1,[13,13], [0,1])
J, plot2 = bifiltration(img2,[13,13], [0,1])

In [175]:
imgs_ = np.vstack((plot1,plot2)).reshape(8,28,28)
fig = px.imshow(imgs_, facet_col=0, facet_col_wrap=4)
fig.update_layout(height=400, width=800,title="Illustrating the filtrations")
names = ["original image", "binarized filtration", "radial filtration", "height filtration"]
for i, name in enumerate(names):
    fig.layout.annotations[i]['text'] = name
fig.show()

In [252]:
# initializing the projection p = (cos theta, sin theta) with a random angle theta in [0, pi/2)
theta = np.random.rand()*np.pi/2
p = tf.Variable(initial_value=np.array([np.cos(theta),np.sin(theta)]).reshape(2,1), dtype=np.float32, trainable=True)

# creating the model
model = CubicalModel_ISM(p, I, J, dim=1, card=392)

In [177]:
# dist = LISM_optimization(model, fast=True)

In [254]:
optimization = LISM_optimization(model)

 22%|██▏       | 11/51 [00:05<00:20,  1.93it/s]


In [221]:
p_opt = optimization['projections'][-1]
I_opt = tf.reshape(tf.tensordot(I,p_opt,1),shape=[28,28])
J_opt = tf.reshape(tf.tensordot(J,p_opt,1),shape=[28,28])

In [222]:
imgs_opt  = np.vstack((I_opt,J_opt)).reshape(2,28,28)
fig = px.imshow(imgs_opt, facet_col=0, facet_col_wrap=2)
fig.update_layout(height=290, width=600, title="Optimal LISM arguments (optimized bifiltrations)")
fig.show()

In [255]:
plot_optim(optimization,True)

In [212]:
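# change in the norm of the projection vector between consecutive iterations
# (p is an unconstrained trainable variable, so its norm is not fixed to 1)
projections = optimization['projections']
norms = [tf.norm(p) for p in projections]
diffs = [(norms[i+1]-norms[i]).numpy() for i in range(len(norms)-1)]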

### Estimating the LISM with a grid search

In [183]:
dist, angle, fig = grid_search(I,J,True,1)

In [184]:
fig.show()
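The internals of `grid_search` are not shown here. As a hedged sketch of the idea, assuming (as the random initialization of `theta` in $[0,\pi/2)$ suggests) that only the first quadrant of $S^1$ is searched, the maximum can be approximated by evaluating an objective on an evenly spaced grid of angles. `grid_search_max` and its `n_angles` default are hypothetical:

```python
import numpy as np

def grid_search_max(objective, n_angles=91, lo=0.0, hi=np.pi / 2):
    """Hypothetical grid search: maximize objective(theta) over a grid.

    Returns (best value, best angle). The answer is only accurate to
    about half the grid step (hi - lo) / (n_angles - 1).
    """
    thetas = np.linspace(lo, hi, n_angles)
    values = np.array([objective(t) for t in thetas])
    best = int(np.argmax(values))
    return float(values[best]), float(thetas[best])
```

To estimate $d_{\mathrm{ISM}}(I,J)$, the objective would map $\theta$ to $d_W(I_p, J_p)$ with $p=(\cos\theta,\sin\theta)$. Each evaluation then requires two persistence computations, which is why the grid resolution directly drives the running time.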

In [185]:
p_opt/np.linalg.norm(p_opt)

array([[0.5538462],
       [0.8326191]], dtype=float32)

In [186]:
np.array([[np.cos(angle), np.sin(angle)]])

array([[0.55702255, 0.83049737]])
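To quantify how closely the two estimates agree, one can convert both directions to angles. The values below are copied from the two outputs above:

```python
import numpy as np

# unit vectors printed above: gradient-based optimum and grid-search optimum
p_gradient = np.array([0.5538462, 0.8326191])
p_grid = np.array([0.55702255, 0.83049737])

# angle of each direction in the first quadrant, in degrees
angle_gradient = np.degrees(np.arctan2(p_gradient[1], p_gradient[0]))
angle_grid = np.degrees(np.arctan2(p_grid[1], p_grid[0]))
gap = abs(angle_gradient - angle_grid)
print(f"{angle_gradient:.2f} deg vs {angle_grid:.2f} deg (gap {gap:.2f} deg)")
```

With these values the two directions differ by roughly 0.2 degrees, so for this pair of images the gradient-based optimum and the grid search essentially agree.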

### Classification of MNIST data using the LISM method

For an image $I=(I_1,I_2)\in\mathbb{R}^{n\times n\times 2}$ and a projection $p\in S^1$, one can obtain a projected image $I_p:=p_1I_1+p_2I_2\in\mathbb{R}^{n\times n}$. 
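A minimal NumPy sketch of this projection follows. It assumes, as the notebook's `tf.tensordot(I, p_opt, 1)` followed by a reshape to $28\times 28$ suggests, that a bi-filtration is stored as an $(n^2, 2)$ array with one row per pixel. `stack_bifiltration` and `project` are hypothetical names, not part of the notebook's helpers:

```python
import numpy as np

def stack_bifiltration(filt1, filt2):
    """Stack two n x n filtrations into an (n*n, 2) array, one row per pixel."""
    return np.stack([filt1.ravel(), filt2.ravel()], axis=1)

def project(bifilt, theta, shape):
    """Projected image I_p = p_1 I_1 + p_2 I_2 with p = (cos theta, sin theta)."""
    p = np.array([np.cos(theta), np.sin(theta)])
    return np.tensordot(bifilt, p, axes=1).reshape(shape)
```

At $\theta=0$ the projection recovers $I_1$, at $\theta=\pi/2$ it recovers $I_2$, and intermediate angles interpolate between the two filtrations.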

**Definition. (LISA)** One defines the *linear integral sheaf amplitude (LISA)* of $I$ as the maximum $$A_W(I):=\max_{p\in S^1}A_W(I_p).$$ 

For example, $I_1$ can be the $(3,7)$-radial filtration of an original image $M\in\mathbb{R}^{n\times n}$ and $I_2$ can be the $(1,0)$-height filtration of the same image $M$.

**Definition. (LISM)** One defines the *linear integral sheaf metric (LISM)* between two images $I, J\in\mathbb{R}^{n\times n\times 2}$ as the maximum
$$d_{\mathrm{ISM}}(I,J):=\max_{p\in S^1}d_W(I_p, J_p)$$ 

**Task.** We try to classify $N$ images given by a family of bi-filtrations $\{I^i\}_{i=1}^{N}\subset\mathbb{R}^{n\times n\times 2}$.

**Method 1.** We compute an amplitude vector $A=(a_i)_{i=1}^{N}$ where $a_i=A_W(I^i)$ is the LISA of $I^i$.

**Method 2.** We compute a distance matrix $D=(d_{ij})_{i,j=1}^{N}$ where $d_{ij}=d_{\mathrm{ISM}}(I^i,I^j)$.

**Method 3.** For comparison with a one-parameter persistence method, each of the $N$ images $\mathrm{Im}^i\in\mathbb{R}^{n\times n}$ is assigned a single filtration $\mathrm{Im}^i_\mathrm{filt}\in\mathbb{R}^{n\times n}$ (rather than a bi-filtration in $\mathbb{R}^{n\times n\times 2}$), and we compute the distance matrix $D=(d_{ij})_{i,j=1}^{N}$ with $d_{ij}=d_W(\mathrm{Im}^i_\mathrm{filt},\mathrm{Im}^j_\mathrm{filt})$, where $d_W$ is the $2$-Wasserstein distance.
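The helper `get_accuracy` is not reproduced here. As a hypothetical stand-in, the sketch below builds a symmetric matrix from any pairwise distance by evaluating only the upper triangle (the same $N(N-1)/2$ evaluations the notebook's loop performs) and scores it with leave-one-out $k$-NN. It may differ from what `get_accuracy` actually computes; `pairwise_matrix` and `knn_accuracy` are illustrative names:

```python
import numpy as np
from collections import Counter

def pairwise_matrix(items, dist):
    """Symmetric matrix D[i, j] = dist(items[i], items[j]) with zero diagonal.

    Only the upper triangle is evaluated (N(N-1)/2 calls), then mirrored.
    """
    n = len(items)
    D = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            D[i, j] = dist(items[i], items[j])
    return D + D.T

def knn_accuracy(D, labels, k=1):
    """Leave-one-out k-NN accuracy from a precomputed distance matrix."""
    labels = np.asarray(labels)
    n = len(labels)
    correct = 0
    for i in range(n):
        d = D[i].astype(float)          # astype returns a copy
        d[i] = np.inf                   # exclude the query image itself
        neighbours = np.argsort(d)[:k]
        vote = Counter(labels[neighbours].tolist()).most_common(1)[0][0]
        correct += int(vote == labels[i])
    return correct / n
```

For Method 1, `items` would be the LISA values with `dist=lambda a, b: abs(a - b)`; for Methods 2 and 3, `dist` would be $d_{\mathrm{ISM}}$ or $d_W$.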


### Method 2 : LISM distance matrix

In [124]:
one_dist_time = 1.5  # approx. seconds per LISM evaluation
total_entries = N*(N-1)/2
computation_time = total_entries * one_dist_time
print("Computation of distance matrix will take approx. {} minutes.".format(computation_time/60))

Computation of distance matrix will take approx. 44.25 minutes.
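The estimate counts only the upper triangle of the matrix, since $D$ is symmetric with a zero diagonal:

$$\binom{N}{2}=\frac{60\cdot 59}{2}=1770,\qquad 1770\times 1.5\,\mathrm{s}=2655\,\mathrm{s}\approx 44.25\,\mathrm{min}.$$

The recorded run of the loop took 35 minutes, i.e. roughly $2100/1770\approx 1.19$ s per pair.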


In [146]:
# the bifiltration of an image does not depend on the pair, so compute each one once
bifilts = [bifiltration(img, [13,13], [1,0])[0].numpy() for img in images]

D = np.zeros((N,N))

# fill the upper triangle only (i < j); D is symmetric with a zero diagonal
for i, I_np in tqdm(enumerate(bifilts)):
    for j in range(i+1, N):
        D[i,j] = fast_grid_search(I_np, bifilts[j], 0)

D += np.transpose(D)

60it [35:00, 35.00s/it]


In [147]:
get_accuracy(D, N, train_y)

0.6333333333333333

In [148]:
dist_matrix_hist(D,N,train_y,2)

### Method 3 : distance matrix of pairwise $2$-Wasserstein distances (single filtration)

In [136]:
D_height = wasserstein_matrix(images, 0, "height", [1,0])

60it [00:16,  3.74it/s]


In [139]:
D_radial = wasserstein_matrix(images, 0, "radial", [13,13])

60it [00:16,  3.65it/s]


In [141]:
dist_matrix_hist(D_radial,N,train_y,1)

In [145]:
get_accuracy(D_height, N, train_y), get_accuracy(D_radial, N, train_y)

(0.6333333333333333, 0.65)