In [None]:
import sys
import numpy             as np
import tensorflow        as tf
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import gudhi.representations as gdr
import gudhi.tensorflow as gdtf

In [None]:
# Two toy persistence diagrams given as (birth, death) rows; they deliberately have
# different numbers of points (4 and 2), which later requires a ragged batch.
diagrams = [
    np.array([[0., 4.], [1., 2.], [3., 8.], [6., 8.]]),
    np.array([[3., 6.], [4., 5.]]),
]

In [None]:
# Scatter the first diagram together with the diagonal birth == death.
# The diagonal spans the full coordinate range of the points instead of a fixed
# [0, 6] (this diagram has deaths up to 8), so it stays correct for any data.
first_diagram = np.asarray(diagrams[0])
diag_lo, diag_hi = first_diagram.min(), first_diagram.max()
plt.scatter(first_diagram[:, 0], first_diagram[:, 1])
plt.plot([diag_lo, diag_hi], [diag_lo, diag_hi])
plt.xlabel("birth")
plt.ylabel("death")
plt.title("Diagram 0")
plt.show()

In [None]:
# Rescale the diagram coordinates: a MinMaxScaler is applied to coordinates [0, 1]
# (presumably fitted jointly over all diagrams -- see gudhi's DiagramScaler docs),
# which maps the points into the unit square (MinMaxScaler's default range).
diagrams = gdr.DiagramScaler(
    use=True,
    scalers=[([0, 1], MinMaxScaler())],
).fit_transform(diagrams)

In [None]:
# Show every MinMax-scaled diagram (diagram 0 included, so it can be compared with
# its unscaled plot) together with the diagonal of the unit square.
plt.figure()
for idx, dgm in enumerate(diagrams):
    plt.scatter(dgm[:, 0], dgm[:, 1], label=f"diagram {idx}")
plt.plot([0., 1.], [0., 1.], color="gray")
plt.xlabel("birth")
plt.ylabel("death")
plt.title("Diagrams after scaling")
plt.legend()
plt.show()

In [None]:
# Batch all diagrams into a single float32 ragged tensor of shape
# (num_diagrams, None, 2). The middle dimension is ragged because the diagrams hold
# different numbers of points. Each diagram becomes a 1-row ragged tensor and the
# rows are concatenated; the comprehension handles any number of diagrams instead
# of hard-coding two.
diagrams = tf.concat(
    [
        tf.RaggedTensor.from_tensor(tf.constant(dgm[None, :], dtype=tf.float32))
        for dgm in diagrams
    ],
    axis=0,
)

In [None]:
# Inspect the batched diagrams. The shape shows the ragged points dimension (None);
# the bare last expression uses the notebook's rich display instead of print().
print(diagrams.shape)
diagrams

In [None]:
# Assemble a PersLay layer from interchangeable components and apply it to the
# batched diagrams. Keep exactly one option per component active; each
# visualization cell below inspects one specific phi / weight choice.
# The components are built outside the tape: only the forward pass needs recording.

# rho: presumably post-processes the pooled vector (tf.identity = no post-processing).
rho = tf.identity
#rho = tf.keras.layers.Dense(10)
#rho = tf.keras.layers.Conv2D(3,5)

# phi: per-point transformation.
#phi = gdtf.GaussianPerslayPhi((100, 100), ((-.5, 1.5), (-.5, 1.5)), .1)
phi = gdtf.TentPerslayPhi(np.array(np.arange(-1.,2.,.001), dtype=np.float32))
#phi = gdtf.FlatPerslayPhi(np.array(np.arange(-1.,2.,.001), dtype=np.float32), 100.)

# weight: per-point weight function. The grid option is initialised randomly, so a
# fixed seed keeps the weights (and every downstream plot) reproducible across
# kernel restarts. Its bounds pad the unit square slightly because the diagrams
# were MinMax-scaled into [0, 1].
#weight = gdtf.GaussianMixturePerslayWeight(np.array([[.5],[.5],[5],[5]], dtype=np.float32))
#weight = gdtf.PowerPerslayWeight(1.,0.)
GRID_SEED = 0
grid_rng = np.random.default_rng(GRID_SEED)
weight = gdtf.GridPerslayWeight(
    np.array(grid_rng.uniform(size=[100, 100]), dtype=np.float32),
    ((-0.01, 1.01), (-0.01, 1.01)),
)

# perm_op: permutation-invariant pooling over the points of each diagram -- either
# a reduction function or a 'topK' string (K = 3 here).
#perm_op = tf.math.reduce_sum
perm_op = 'top3'

perslay = gdtf.Perslay(phi=phi, weight=weight, perm_op=perm_op, rho=rho)

# Record the forward pass so gradients w.r.t. trainable parameters can be taken.
with tf.GradientTape() as tape:
    vectors = perslay(diagrams)

# Gradient example; `phi.variance` presumably exists only for GaussianPerslayPhi.
#print(tape.gradient(vectors, phi.variance))

In [None]:
# GridPerslayWeight: display the weight grid with the first diagram on top.
# Assumes grid axis 0 is indexed by the first (birth) coordinate and axis 1 by the
# second (death) coordinate -- TODO confirm against gudhi's GridPerslayWeight.
# imshow draws array rows vertically, so the grid is transposed (rows = death) and
# drawn with origin="lower" and an extent in diagram coordinates. Points can then be
# plotted at their raw coordinates: no pixel-index arithmetic, no off-by-one rows,
# and it also works for non-square grids.
grid_values = weight.grid.numpy()
((xm, xM), (ym, yM)) = weight.grid_bnds
first_diagram = np.asarray(diagrams[0])
plt.figure()
plt.imshow(grid_values.T, cmap="Purples", origin="lower",
           extent=(xm, xM, ym, yM), zorder=1)
plt.colorbar()
plt.scatter(first_diagram[:, 0], first_diagram[:, 1], s=10, color="red", zorder=2)
plt.xlabel("birth")
plt.ylabel("death")
plt.title("GridPerslayWeight")
plt.show()

In [None]:
# GaussianMixturePerslayWeight (only valid when that weight option is selected,
# since it reads weight.W): rows 0-1 of W are the Gaussian centres, rows 2-3 the
# per-axis spread parameters. Evaluate the mixture on a dense grid covering
# [-0.5, 1.5)^2 and overlay the first diagram.
# NOTE(review): the squared offsets are *multiplied* by varis**2 here; confirm this
# matches how GaussianMixturePerslayWeight uses those rows (it may divide instead).
means = weight.W[:2,:].numpy()
varis = weight.W[2:,:].numpy()
x = np.arange(-.5, 1.5, .001)
y = np.arange(-.5, 1.5, .001)
xx, yy = np.meshgrid(x, y)
z = np.zeros(xx.shape)
for mu_x, mu_y, s_x, s_y in zip(means[0], means[1], varis[0], varis[1]):
    z += np.exp(-((xx - mu_x)**2 * s_x**2 + (yy - mu_y)**2 * s_y**2))
plt.contourf(xx, yy, z)
plt.scatter(diagrams[0][:,0], diagrams[0][:,1], s=50, color="red")
plt.show()

In [None]:
# TentPerslayPhi: plot the first diagram's output, one curve per top-K rank.
# Assumes the K pooled values for each tent sample are stored contiguously, so the
# row reshapes to (num_samples, K). K comes from a 'topK' perm_op; any other
# pooling yields one value per sample (K = 1).
# The reshaped array gets its own name so `vectors` is not clobbered: the cell
# stays re-runnable and the other visualization cells still see the layer output.
top_k = int(perm_op[3:]) if isinstance(perm_op, str) and perm_op.startswith("top") else 1
tent_curves = np.reshape(np.asarray(vectors[0, :]), [-1, top_k])
plt.figure()
for k in range(top_k):
    plt.plot(tent_curves[:, k], linewidth=5.)
plt.show()

In [None]:
# FlatPerslayPhi: draw the first diagram's PersLay output vector as a single curve.
# Expects `vectors` to hold the layer output computed with the FlatPerslayPhi option.
fig, ax = plt.subplots()
ax.plot(vectors[0, :], linewidth=5.)
plt.show()

In [None]:
# GaussianPerslayPhi: show channel 0 of the first diagram's output image with a
# colorbar. The rows are flipped so that array row 0 ends up at the bottom.
# Expects `vectors` to hold the layer output computed with the GaussianPerslayPhi option.
fig, ax = plt.subplots()
image = ax.imshow(np.flip(vectors[0, :, :, 0], 0), cmap="Purples")
cb = fig.colorbar(image)
cb.ax.tick_params(labelsize=14)
plt.show()