In [6]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2

# Linear Integral Sheaf Amplitudes (LISAs) 

In this notebook, we implement the RG(B) image analysis experiment, where we classify images by feeding a $k$-NN clustering algorithm with an input distance matrix that results from computing pairwise linear integral sheaf distances between image filtrations, as well as linear integral sheaf amplitudes of image filtrations.

In [7]:
import numpy                 as np
import tensorflow            as tf
import gudhi                 as gd
import gudhi.representations as sktda

from difftda                              import *
from gudhi.representations.kernel_methods import SlicedWassersteinKernel as swk
from gudhi.wasserstein                    import wasserstein_distance
from gudhi.representations                import pairwise_persistence_diagram_distances as ppdd
from sklearn.cluster                      import KMeans
from tqdm import tqdm
import plotly.graph_objects as go
import numpy as np
from keras.datasets import mnist
from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import binarize
import plotly.express as px
from skimage import io

In [8]:
from helpers import *
from plotly.subplots import make_subplots

### Classification of MNIST data with cubical Homology 

First, we load the data - these are hand-written $28\times 28$ digits of zeros and fours.

In [9]:
# loading the data
(train_X, train_y), (test_X, test_y) = mnist.load_data()

# we look at digits '0' and '8' only
train_filter = np.where((train_y == 0 ) | (train_y == 8))
test_filter = np.where((test_y == 0) | (test_y == 8))
train_X, train_y = train_X[train_filter]/255, train_y[train_filter]
test_X, test_y = test_X[test_filter]/255, test_y[test_filter]

# parameters
N = 1000

# pick the first N images
images = train_X[:N]
labels = train_y[:N]

# initialize distance matrix
D = np.zeros((N,N))
# initialize amplitude lists
Ag = []
Ar = []
Ai = []

# hyper-parameter
alpha=0.01

In a topological data analysis pipeline for MNIST data, it is common to binarize an image $I$ and assign a filtration $I_\mathrm{filt}$ to the binarized version $I_\mathrm{bin}$. There are various types of filtrations, each one bringing new information. We will look at *radial* and the *height* filtrations. To a given filtration $I_\mathrm{filt}$, one can associate a Wasserstein amplitude $A_W(I_\mathrm{filt})$. 

### Computing amplitude vectors

We consider the gray-scale, $(3,7)$-radial and $(0,1)$-height filtrations.

In [11]:
# wasserstein amplitudes with gray-scale filtration
Ag = wass_amplitudes(images, alpha, dim=1)
Ag = Ag/max(Ag)

# wasserstein amplitudes with radial filtration
images_r = list(map(rad_filt,images))
Ar = wass_amplitudes(images_r, alpha, dim=1)
Ar = Ar/max(Ar)

# wasserstein amplitudes with height filtration
images_h = list(map(hei_filt,images))
Ah = wass_amplitudes(images_h, alpha, dim=1)
Ah = Ah/max(Ah)

100%|██████████| 1000/1000 [00:04<00:00, 226.68it/s]
100%|██████████| 1000/1000 [00:02<00:00, 378.90it/s]
100%|██████████| 1000/1000 [00:03<00:00, 322.56it/s]


### Point-clouds of amplitude vectors

First, we plot a 3D scatterplot of the three amplitude vectors.

In [26]:
fig = go.Figure(data=[go.Scatter3d(x=Ag, y=Ar, z=Ah,
                                   mode='markers')])
fig.update_layout(
    title="3D scatterplot of amplitude vectors",
    height=350,
    width=700
)

fig.show()

Second, we look at 2D scatter-plots of pairwise amplitude vectors.

In [12]:
fig = make_subplots(rows=1, cols=3)
fig.add_trace(go.Scatter(x=Ag,y=Ar,mode="markers"),row=1,col=1)
fig.add_trace(go.Scatter(x=Ag,y=Ah,mode="markers"),row=1,col=2)
fig.add_trace(go.Scatter(x=Ar,y=Ah,mode="markers"),row=1,col=3)
fig.update_xaxes(title_text="Gray-scale amplitudes", row=1, col=1)
fig.update_xaxes(title_text="Gray-scale amplitudes", row=1, col=2)
fig.update_xaxes(title_text="Radial amplitudes", row=1, col=3)
fig.update_yaxes(title_text="Radial amplitudes", row=1, col=1)
fig.update_yaxes(title_text="Height amplitudes", row=1, col=2)
fig.update_yaxes(title_text="Height amplitudes", row=1, col=3)
fig.update_layout(
    title="Point-clouds of pairwise amplitude vectors",
    height=400,
    width=1000
)
fig.show()

One can also compute the pairwise $2$-Wasserstein distances.

In [311]:
for i,I in tqdm(enumerate(images_r)):
    for j,J in enumerate(images_h):

        if i<=j :

            cc1 = gd.CubicalComplex(dimensions=I.shape, top_dimensional_cells=I.flatten())
            cc2 = gd.CubicalComplex(dimensions=J.shape, top_dimensional_cells=J.flatten())
            dgm1 = np.array([p[1] for p in cc1.persistence() if p[0]==1])
            dgm2 = np.array([p[1] for p in cc2.persistence() if p[0]==1])

            d = alpha * np.square(wasserstein_distance(dgm1, dgm2, order=2))

            D[i,j] = d
        
D = D + D.T - np.diag(D.diagonal())

200it [02:17,  1.46it/s]


Now, we feed the amplitude vector or the distance matrix to a clustering algorithm.

In [314]:
clustering_accuracy(D,2,True)

0.64

In [13]:
Ag = np.array(Ag).reshape(N,1)
Ar = np.array(Ar).reshape(N,1)
Ah = np.array(Ah).reshape(N,1)

A = np.vstack((Ar,Ah)).reshape(2,N).T

In [14]:
kmeans = KMeans(n_clusters=2, random_state=0).fit(A)

In [15]:
pred = kmeans.labels_
truth = train_y[:N]
res = pred - truth

nb = np.count_nonzero(res==0) + np.count_nonzero(res==-7)
nb_correct = max(nb, N-nb)

nb_correct/N

0.821

*Nota bene.* Let $N=1000$ (number of images to classify). By feeding only two amplitude vectors to the $k$-means clustering algorithm, we obtain an accuracy of $82,1\%$. This is for the radial and $(1,0)$-height filtrations.

### Classification of MNIST data using the LISA method

In [119]:
from LISM import *

In [159]:
use_reg = True

In [160]:
I1_ = rad_filt(images[0]).reshape(28,28,1)
I2_ = hei_filt(images[0]).reshape(28,28,1)
    
#I_ = np.array([M.reshape(2,28).T for M in np.split(np.stack((I1,I2),axis=1),28)])

I1 = tf.Variable(initial_value=np.array(I1_, dtype=np.float32), trainable=False)
I2 = tf.Variable(initial_value=np.array(I2_, dtype=np.float32), trainable=False)

p_ = tf.constant([0.5, 0.5], shape=[2, 1])
p = tf.Variable(initial_value=np.array(p_, dtype=np.float32), trainable=True)
#p = p/tf.norm(p)

In [161]:
model = CubicalModel_ISM(p, I1, I2, dim=0, card=256)

lr = tf.keras.optimizers.schedules.InverseTimeDecay(initial_learning_rate=1e-3, decay_steps=10, decay_rate=.01)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
sigma = 0.001
lambda_ = 100

alpha = 10.

optimization = {'losses':[],
                            'amplitudes':[],
                            'projections':[],
                            'diagrams':[]
}

In [162]:
# reduce the nb of epochs to 30+1 for faster computations 
for epoch in tqdm(range(50+1)):
    
    with tf.GradientTape() as tape:
        
        dgm1, dgm2 = model.call()
    
        if use_reg:
            amplitude = alpha * tf.math.sqrt(wasserstein_distance(dgm1, dgm2, order=2, enable_autodiff=True))
            loss = - amplitude + lambda_*(tf.abs(tf.norm(p)-1))
        else:
            amplitude = alpha * tf.math.sqrt(wasserstein_distance(dgm1, dgm2, order=2, enable_autodiff=True))
            loss = - amplitude
        #model.p = model.p/tf.norm(model.p)
        
        
    gradients = tape.gradient(loss, model.trainable_variables)
    
    np.random.seed(epoch)
    gradients = [tf.convert_to_tensor(gradients[0])]
    gradients[0] = gradients[0] + np.random.normal(loc=0., scale=sigma, size=gradients[0].shape)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    optimization['losses'].append(loss.numpy())
    optimization['amplitudes'].append(amplitude.numpy())
    optimization['projections'].append(p.numpy())
    optimization['diagrams'].append((dgm1,dgm2))

100%|██████████| 51/51 [00:14<00:00,  3.51it/s]


In [163]:
amplitudes = optimization['amplitudes']
losses = optimization['losses']

fig = go.Figure()
fig.add_trace(go.Scatter(x=list(range(51)),y=losses,mode="lines",name="losses"))
fig.add_trace(go.Scatter(x=list(range(51)),y=amplitudes,mode="lines",name="amplitudes"))
fig.update_layout(
    title="Loss curves for optimizing over the projection",
    xaxis_title="Epochs",
    yaxis_title="Validation loss",
    height=350,
    width=500,
    legend_title="Curves"
)
fig.show()

In [164]:
projections = optimization['projections']
proj1 = [float(p[0]) for p in projections]
proj2 = [float(p[1]) for p in projections]

fig = go.Figure()
fig.add_trace(go.Scatter(x=list(range(51)),y=proj1,mode="lines",name="p1"))
fig.add_trace(go.Scatter(x=list(range(51)),y=proj2,mode="lines",name="p2"))
fig.update_layout(
    title="Projection evolution during optimization",
    xaxis_title="Epochs",
    yaxis_title="Projection coordinates",
    height=350,
    width=500,
    legend_title="Proj. coord."
)
fig.show()

In [165]:
norms = [tf.norm(p) for p in projections]

fig = go.Figure()
fig.add_trace(go.Scatter(x=list(range(51)),y=norms,mode="lines",name="norm"))
fig.update_layout(
    title="Control over the projection norm",
    xaxis_title="Epochs",
    yaxis_title="Projection norm",
    height=350,
    width=500,
    legend_title="Proj. norm"
)
fig.show()

For an image $I$ and a projection $p$, one can obtain a projected image $I_p$. Now, define $$A_W(I):=\max_{p}\left(A_W(p_1I_1+p_2I_2)\right),$$ where $I_1$ is the $(3,7)$-radial filtration of the image $I$ and $I_2$ is the $(1,0)$-height filtration of the image $I$. One can gather all such amplitudes, which we call *linear integral sheaf amplitudes (LISAs)*, and feed them as a single row to the $k$-means clustering algorithm.

In [None]:
# wasserstein ISM amplitudes 
Ai = []

# radial filtration
images_r = list(map(rad_filt,images))

# with height filtration
images_h = list(map(hei_filt,images))

for i,img in tqdm(enumerate(images_h)):

    I1 = images_r[i]
    
    I_ = np.array([M.reshape(2,28).T for M in np.split(np.stack((I1,img),axis=1),28)])

    I = tf.Variable(initial_value=np.array(I_, dtype=np.float32), trainable=False)
    
    amplitude, optimization = ISM_norm(I)
    
    Ai.append((amplitude, optimization))

In [13]:
amplitudes_ism = [p[0] for p in Ai]

In [16]:
Al = np.array(amplitudes_ism).reshape(N,1) 

kmeans = KMeans(n_clusters=2, random_state=0).fit(Al)

pred = kmeans.labels_
truth = train_y[:N]
res = pred - truth

nb = np.count_nonzero(res==0) + np.count_nonzero(res==-7)
nb_correct = max(nb, N-nb)

nb_correct/N

0.608

In [39]:
ind0 = np.argwhere(labels==0)
ind8 = np.argwhere(labels==8)
s1 = ind0.shape[0]
s2 = ind8.shape[0]

AISM0 = np.array(amplitudes_ism).reshape(N,1)[ind0]
AISM8 = np.array(amplitudes_ism).reshape(N,1)[ind8]


NameError: name 'amplitudes_ism' is not defined

In [38]:
df = pd.DataFrame()
df['AISM']=AISM0
#fig = px.histogram(df, x="total_bill", color="sex")
#fig.show()

NameError: name 'AISM0' is not defined

In [102]:
fig = make_subplots(rows=2, cols=1)

fig.add_trace(go.Scatter(
    x=AISM0.reshape(s1), y=np.zeros(s1), mode='markers', name = 'zeros', marker=dict(
            color='LightSkyBlue',
            size=7,
        ),
), row=1, col=1)

fig.add_trace(go.Scatter(
    x=AISM8.reshape(s2), y=np.zeros(s2), mode='markers', name = 'eights', marker=dict(
            color='MediumPurple',
            size=7,
        ),
), row=2, col=1)

fig.update_xaxes(showgrid=False)
fig.update_yaxes(showgrid=False, 
                 zeroline=True, zerolinecolor='black', zerolinewidth=1,
                 showticklabels=False)
fig.update_layout(height=400, plot_bgcolor='white', title="LISA distribution (optimal combination of radial and height filtrations)")

fig.show()

### Illustrating the pipeline for a single image

In [21]:
from plotly.subplots import make_subplots

In [158]:
img = 1-images[1] # then 0=black, 1=white   
I_bin = binarize(img,threshold=0.4)
I_rad = radial_filt(I_bin)
I_height = height_filt(I_bin)

In [162]:
imgs = np.vstack((img, I_bin)).reshape(2,28,28)
fig = px.imshow(imgs, facet_col=0, facet_col_wrap=2)
fig.update_layout(height=300, width=600, title="Gray-scale (left) and binarized (right) filtrations")
fig.show()

In [161]:
imgs = np.vstack((I_rad, I_height)).reshape(2,28,28)
fig = px.imshow(imgs, facet_col=0, facet_col_wrap=2)
fig.update_layout(height=300, width=600, title="Radial (left) and height (right) filtrations")
fig.show()

In [212]:
Igr_ = np.array([M.reshape(2,28).T for M in np.split(np.stack((img,I_rad),axis=1),28)])
Ihr_ = np.array([M.reshape(2,28).T for M in np.split(np.stack((I_height,I_rad),axis=1),28)])

Igr = tf.Variable(initial_value=np.array(Igr_, dtype=np.float32), trainable=False)
Ihr = tf.Variable(initial_value=np.array(Ihr_, dtype=np.float32), trainable=False)
amplitude, optimization = ISM_norm(Igr,more_info=True)
amplitude_, optimization_ = ISM_norm(Ihr,more_info=True)

In [215]:
p_opt = optimization['projections'][-1]
I_opt = tf.reshape(tf.tensordot(Igr,p_opt,1),shape=[28,28])

p_opt_ = optimization_['projections'][-1]
I_opt_ = tf.reshape(tf.tensordot(Ihr,p_opt_,1),shape=[28,28])

In [216]:
I_opts = np.vstack((I_opt, I_opt_)).reshape(2,28,28)
fig = px.imshow(I_opts, facet_col=0, facet_col_wrap=2)
fig.update_layout(height=300, width=600, title="Optimal combinations (left : gray&rad, right : height&rad)")
fig.show()

In [183]:
amplitudes = optimization['amplitudes']
losses = optimization['losses']

fig = go.Figure()
fig.add_trace(go.Scatter(x=list(range(51)),y=losses,mode="lines",name="losses"))
fig.add_trace(go.Scatter(x=list(range(51)),y=amplitudes,mode="lines",name="amplitudes"))
fig.update_layout(
    title="Loss curves for optimizing over the projection",
    xaxis_title="Epochs",
    yaxis_title="Validation loss",
    height=350,
    width=500,
    legend_title="Curves"
)
fig.show()



In [184]:
projections = optimization['projections']
proj1 = [float(p[0]) for p in projections]
proj2 = [float(p[1]) for p in projections]

fig = go.Figure()
fig.add_trace(go.Scatter(x=list(range(51)),y=proj1,mode="lines",name="p1"))
fig.add_trace(go.Scatter(x=list(range(51)),y=proj2,mode="lines",name="p2"))
fig.update_layout(
    title="Projection evolution during optimization",
    xaxis_title="Epochs",
    yaxis_title="Projection coordinates",
    height=350,
    width=500,
    legend_title="Proj. coord."
)
fig.show()

In [185]:
norms = [tf.norm(p) for p in projections]

fig = go.Figure()
fig.add_trace(go.Scatter(x=list(range(51)),y=norms,mode="lines",name="norm"))
fig.update_layout(
    title="Control over the projection norm",
    xaxis_title="Epochs",
    yaxis_title="Projection norm",
    height=350,
    width=500,
    legend_title="Proj. norm"
)
fig.show()

In [202]:
dgm = optimization['diagrams'][-1]
b = min(dgm[:,0])
d = max(dgm[:,1])
e = (d-b)/10

In [203]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=[b-e,d+e],y=[b-e,d+e],mode="lines",name="diagonal"))
fig.add_trace(go.Scatter(x=dgm[:,0],y=dgm[:,1],mode="markers",name="(b,d)"))
fig.update_layout(
    title="Optimized persistence diagram",
    xaxis_title="Birth",
    yaxis_title="Death",
    height=400,
    width=400,
    legend_title="coordinates"
)
fig.show()