In [None]:
import os

import torch
import matplotlib.pyplot as plt

from src.manifolds.deformed_gaussian_pullback_manifold.quadratic_banana_pullback_manifold import QuadraticBananaPullbackManifold
from src.riemannian_autoencoder.deformed_gaussian_riemannian_autoencoder import DeformedGaussianRiemannianAutoencoder

: 

In [None]:
# Build the banana manifold and draw the contours of its density on a
# 100 x 100 grid covering [-6, 6]^2.
banana_manifold = QuadraticBananaPullbackManifold()

xx = torch.linspace(-6.0, 6.0, 100)
yy = torch.linspace(-6.0, 6.0, 100)
# indexing="ij" is torch's current default. Passing it explicitly silences the
# UserWarning recent torch versions emit and pins the layout: x varies along
# dim 0, y along dim 1.
x_grid, y_grid = torch.meshgrid(xx, yy, indexing="ij")

# (100, 100, 2) grid holding the (x, y) coordinates of every grid node.
xy_grid = torch.stack((x_grid, y_grid), dim=-1)

# Evaluate the log-density on the flattened (N, 2) points, restore the grid
# shape, then exponentiate to get the density itself.
density_banana = torch.exp(banana_manifold.dg.log_density(xy_grid.reshape(-1, 2)).reshape(x_grid.shape))
plt.contour(x_grid, y_grid, density_banana)
plt.show()

: 

In [None]:
# Special points used throughout the notebook: x0 and x1 are mirror images of
# each other across the horizontal axis.
x0, x1 = torch.tensor([[2.0, 4.0],
                       [2.0, -4.0]])

: 

In [None]:
# test barycentre
# Stack the two special points into a (2, 2) batch (one point per row), compute
# their barycentre with the manifold routine, and plot it in orange next to the
# points.
x = torch.stack((x0, x1))

barycentre = banana_manifold.barycentre(x)

plt.contour(x_grid, y_grid, density_banana)
plt.scatter(torch.tensor([x0[0], x1[0]]), torch.tensor([x0[1], x1[1]]))
plt.scatter(barycentre[0], barycentre[1], color="orange")
barycentre_path = "results/quadratic_banana/barycentre.eps"
# savefig does not create missing folders; ensure the output directory exists
# so the cell also runs from a fresh checkout.
os.makedirs(os.path.dirname(barycentre_path), exist_ok=True)
plt.savefig(barycentre_path)
plt.show()

: 

In [None]:
# test inner
# Call banana_manifold.inner with both tangent arguments set to the 2x2
# identity (the standard basis). It is evaluated at two base points, the origin
# and x0, each passed with a leading batch dimension of size 1.
X = torch.eye(2)
basis = X[None]

inner_0 = banana_manifold.inner(torch.zeros(1, 2), basis, basis)
inner_x0 = banana_manifold.inner(x0[None], basis, basis)

for value in (inner_0, inner_x0):
    print(value)

: 

In [None]:
# test interpolation
# Evaluate the geodesic from x0 to x1 at 100 evenly spaced times in [0, 1].
# The result is indexed as [:, 0] / [:, 1] below, presumably one row per time
# in t. Plot it in orange over the density.
t = torch.linspace(0., 1., 100)

geodesic = banana_manifold.geodesic(x0, x1, t)

plt.contour(x_grid, y_grid, density_banana)
plt.plot(geodesic[:, 0], geodesic[:, 1], color="orange")
plt.scatter(torch.tensor([x0[0], x1[0]]), torch.tensor([x0[1], x1[1]]))
geodesic_path = "results/quadratic_banana/geodesic.eps"
# savefig does not create missing folders; ensure the output directory exists.
os.makedirs(os.path.dirname(geodesic_path), exist_ok=True)
plt.savefig(geodesic_path)
plt.show()

: 

In [None]:
# test logarithmic mapping
# log at base point x0 applied to x1 (passed with a batch dimension, then
# unbatched with [0]). The result is drawn as an orange arrow starting at x0.
logarithmic = banana_manifold.log(x0, x1[None])[0]

plt.contour(x_grid, y_grid, density_banana)
plt.arrow(x0[0], x0[1], logarithmic[0], logarithmic[1], head_width=0.2, color="orange")
plt.scatter(torch.tensor([x0[0], x1[0]]), torch.tensor([x0[1], x1[1]]))
logarithmic_path = "results/quadratic_banana/logarithmic.eps"
# savefig does not create missing folders; ensure the output directory exists.
os.makedirs(os.path.dirname(logarithmic_path), exist_ok=True)
plt.savefig(logarithmic_path)
plt.show()

: 

In [None]:
# test exponential mapping
# Map the tangent vector `logarithmic` (from the previous cell) back through
# exp at x0. The printed norm below measures how far exp_x0(log_x0(x1)) lands
# from x1.
exponential = banana_manifold.exp(x0, logarithmic[None])[0]

plt.contour(x_grid, y_grid, density_banana)
plt.scatter(x0[0], x0[1])
plt.arrow(x0[0], x0[1], logarithmic[0], logarithmic[1], head_width=0.2)
plt.scatter(exponential[0], exponential[1], color="orange")
exponential_path = "results/quadratic_banana/exponential.eps"
# savefig does not create missing folders; ensure the output directory exists.
os.makedirs(os.path.dirname(exponential_path), exist_ok=True)
plt.savefig(exponential_path)
plt.show()
print(f"The error between exp_x0(log_x0 (x1)) and x1 is {torch.norm(exponential - x1)}")

: 

In [None]:
# test distance
# Compare the straight-line (Euclidean) distance between x0 and x1 with the
# manifold distance. distance() receives both points with two leading singleton
# dimensions; the single entry of its result is extracted with [0, 0, 0].
l2_distance = (x0 - x1).norm()
distance = banana_manifold.distance(x0[None, None], x1[None, None])[0, 0, 0]

for value in (l2_distance, distance):
    print(value)

: 

In [None]:
# test parallel transport
# Call parallel_transport with base point x0, the tangent vector `logarithmic`
# (batched) and target point x1; presumably this transports log_x0(x1) from x0
# to x1. The original vector is drawn at x0 and the transported one in orange
# at x1.
parallel_transport = banana_manifold.parallel_transport(x0, logarithmic[None], x1)[0]

plt.contour(x_grid, y_grid, density_banana)
plt.scatter(torch.tensor([x0[0], x1[0]]), torch.tensor([x0[1], x1[1]]))
plt.arrow(x0[0], x0[1], logarithmic[0], logarithmic[1], head_width=0.2)
plt.arrow(x1[0], x1[1], parallel_transport[0], parallel_transport[1], head_width=0.2, color="orange")
parallel_transport_path = "results/quadratic_banana/parallel-transport.eps"
# savefig does not create missing folders; ensure the output directory exists.
os.makedirs(os.path.dirname(parallel_transport_path), exist_ok=True)
plt.savefig(parallel_transport_path)
plt.show()


: 

In [None]:
# Riemannian autoencoder
# Build a Riemannian autoencoder on the banana manifold, decode a 1-D sweep of
# 100 inputs in [-5, 5], and draw the decoded curve in orange over the density.
epsilon = 0.1  # NOTE(review): constructor hyper-parameter; meaning not visible here — confirm
banana_rae = DeformedGaussianRiemannianAutoencoder(banana_manifold, epsilon)

# (100, 1) column of inputs; presumably latent coordinates — TODO confirm.
p = torch.linspace(-5, 5, 100)[:, None]
rae_decode_p = banana_rae.decode(p)

plt.contour(x_grid, y_grid, density_banana)
plt.plot(rae_decode_p[:, 0], rae_decode_p[:, 1], color="orange")
rae_manifold_path = "results/quadratic_banana/rae_manifold.eps"
# savefig does not create missing folders; ensure the output directory exists.
os.makedirs(os.path.dirname(rae_manifold_path), exist_ok=True)
plt.savefig(rae_manifold_path)
plt.show()

: 