*LAND* experiment of the paper: we were not able to fully reproduce this experiment.

In [1]:
import numpy as np
import torch
import torch.nn as nn
from core import VAE,RBFNN,VAE_RBF,utils,manifolds,geodesics
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets, transforms

import matplotlib as mpl
# Use a white figure background for every matplotlib figure created in this notebook
mpl.rcParams['figure.facecolor'] = 'white'

# Download and preprocess data

In [2]:
# Preprocessing pipeline applied to each MNIST image:
#   1. ToTensor  -> tensor with values in the [0, 1] interval
#   2. flatten   -> 1D vector
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(torch.flatten),
    ]
)

# dataset class to efficiently extract the relevant labels and not make a copy of the entire dataset
class FilteredMNIST(Dataset):
    """Dataset restricted to the samples whose label is in a given collection.

    The source dataset is iterated once at construction time and the matching
    ``(img, label)`` pairs are stored in memory.

    Attributes:
        data: list of the retained ``(img, label)`` tuples, in source order.
        imgs: list of the retained images only.
        targets: list of the retained labels only.
    """

    def __init__(self, mnist_dataset, labels):
        """Collect every ``(img, label)`` pair of ``mnist_dataset`` with ``label in labels``.

        Args:
            mnist_dataset: iterable yielding ``(img, label)`` pairs.
            labels: container of the labels to keep.
        """
        self.data = []
        for img, label in mnist_dataset:
            # Skip every sample whose label was not requested
            if label not in labels:
                continue
            self.data.append((img, label))
        self.imgs = [pair[0] for pair in self.data]
        self.targets = [pair[1] for pair in self.data]

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(img, label)`` pair stored at position ``idx``."""
        return self.data[idx]


# Download (if necessary) and load the MNIST training split.
# If the download takes too long, cancel the cell and run it again.
full_train_dataset = datasets.MNIST(
    root='mnist_data',
    train=True,
    transform=mnist_transform,
    download=True,
)
print("Original size of the dataset: ", len(full_train_dataset))

# Keep only the samples whose label is listed here
filtered_labels = [0, 1]
print("Filtered labels: ", filtered_labels)
train_dataset = FilteredMNIST(full_train_dataset, filtered_labels)
print("New size of the dataset: ", len(train_dataset))

# Shuffled mini-batch loader over the filtered dataset
batch_size = 256
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)



# # MNIST Dataset transformation
# fashionmnist_transform = transforms.Compose([
#     transforms.ToTensor(), # Converts to [0, 1] interval
#     transforms.Lambda(lambda x: torch.flatten(x)) # Flattens the image to a 1D vector
# ])

# # dataset class to efficiently extract the relevant labels and not make a copy of the entire dataset
# class FilteredMNIST(Dataset):
#     def __init__(self, mnist_dataset, labels):
#         # Extract only the data with the specified labels
#         self.data = [(img, label) for img, label in mnist_dataset if label in labels]
#         self.imgs = [img for img, label in self.data]
#         self.targets = [label for img, label in self.data]

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         return self.data[idx]


# # Download and load the MNIST dataset (if taking too long, cancel and try again it will worky)
# full_train_dataset = datasets.FashionMNIST(root='fashionmnist_data', train=True, transform=fashionmnist_transform, download=True)
# print("Original size of the dataset: ", len(full_train_dataset))
# filtered_labels = [0,1,7]
# print("Filtered labels: ", filtered_labels)
# train_dataset = FilteredMNIST(full_train_dataset, filtered_labels)
# print("New size of the dataset: ", len(train_dataset))

# batch_size = 256
# train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,num_workers=4)

Original size of the dataset:  60000
Filtered labels:  [0, 1]
New size of the dataset:  12665


# Load saved model

In [3]:
# Rebuild the VAE + RBF architecture, then restore the saved weights into it.
# The hyper-parameters below must match the ones used when the checkpoint was
# written, otherwise load_state_dict raises on missing/unexpected keys or shapes.
input_dim = train_dataset[0][0].shape[-1]  # size of the last dim of one (flattened) sample
hidden_dims = [64, 32]
latent_dim = 2
hidden_activation = nn.Tanh()
encoder_output_mu_activation = nn.Identity()
encoder_output_logvar_activation = nn.Softplus()
decoder_output_mu_activation = nn.Sigmoid()
decoder_output_logvar_activation = nn.Softplus()

model_VAE = VAE.VAE(input_dim=input_dim,
                    hidden_dims=hidden_dims,
                    latent_dim=latent_dim,
                    hidden_activation=hidden_activation,
                    encoder_output_mu_activation=encoder_output_mu_activation,
                    encoder_output_logvar_activation=encoder_output_logvar_activation,
                    decoder_output_mu_activation=decoder_output_mu_activation,
                    decoder_output_logvar_activation=decoder_output_logvar_activation)

# create RBF network
k_rbf = 64  # presumably the number of RBF centers -- confirm against core.RBFNN
a = 2
zeta_rbf = 1e-6  # small constant, later passed to the manifold as 'zeta'
model_rbfnn = RBFNN.RBFNN(a=a, k_rbf=k_rbf, zeta_rbf=zeta_rbf, W_dim=input_dim, latent_dim=latent_dim)

# create VAE-RBF model
model_VAE_RBF = VAE_RBF.VAE_RBF(model_VAE, model_rbfnn)

# map_location='cpu' lets the checkpoint load on a machine without CUDA even if it
# was saved from a GPU; the following cells call .numpy() on the parameters, which
# requires CPU tensors anyway.
model_VAE_RBF.load_state_dict(torch.load('torch_models/VAE_RBFNN_MNIST.pt', map_location='cpu'))

<All keys matched successfully>

# Encode and decode the whole dataset

In [4]:
# Push the whole dataset through the VAE, batch by batch, and keep every
# per-batch output of the model in its own accumulator list.
latent_means_list, latent_logvars_list, latent_z = [], [], []
output_means_list, output_logvars_list, labels_list = [], [], []

with torch.no_grad():
    for imgs, labels_ in train_loader:
        mu_x, log_var_x, z_enc, mu_z, log_var_z = model_VAE_RBF.VAE(imgs)
        # Draw a latent sample from the encoder outputs
        z_rep = VAE.VAE.reparametrization_trick(mu_z, log_var_z)

        # Latent-space quantities
        latent_means_list.append(mu_z)
        latent_logvars_list.append(log_var_z)
        latent_z.append(z_rep)
        # Decoder outputs and the matching labels
        output_means_list.append(mu_x)
        output_logvars_list.append(log_var_x)
        labels_list.append(labels_)

# Stack the per-batch pieces along the batch dimension to get one tensor each
latent_means, latent_logvars, latent_z, output_means, output_logvars, labels = [
    torch.cat(chunks, dim=0)
    for chunks in (
        latent_means_list,
        latent_logvars_list,
        latent_z,
        output_means_list,
        output_logvars_list,
        labels_list,
    )
]

## Construct the manifold

In [5]:
latent_means_np = latent_means.detach().numpy()

# G_model_parameters is the dictionary handed to the manifold class, which uses it
# to compute the metric tensor.
G_model_parameters = {'name': 'generator'}
w_counter = 0
b_counter = 0
for name, param in model_VAE_RBF.VAE.named_parameters():
    # The previous test, `('dec' or 'mu_dec') in name`, evaluated `'dec' or 'mu_dec'`
    # first (which is just 'dec'), so it was always equivalent to the check below.
    # NOTE(review): this selects EVERY parameter whose name contains 'dec'; if the
    # decoder has other heads (e.g. a log-variance output) they are included too --
    # confirm against the parameter names of core.VAE.
    if 'dec' in name:
        if 'weight' in name:
            key_name = 'W' + str(w_counter)
            G_model_parameters[key_name] = param.detach().numpy()
            w_counter += 1
        elif 'bias' in name:
            key_name = 'b' + str(b_counter)
            # Biases are stored as column vectors
            G_model_parameters[key_name] = param.detach().numpy().reshape(-1, 1)
            b_counter += 1

G_model_parameters['activation_fun_hidden'] = str(hidden_activation).lower()  # Get the name of activFun e.g. Tanh() -> tanh
G_model_parameters['activation_fun_output'] = str(decoder_output_mu_activation).lower()
G_model_parameters['Wrbf'] = model_VAE_RBF.RBF.W_rbf.detach().numpy()  # The weights for the RBFs (D x K)
G_model_parameters['Crbf'] = model_VAE_RBF.RBF.centers_rbf.detach().numpy()  # The centers for the RBFs (K x d)
G_model_parameters['Grbf'] = model_VAE_RBF.RBF.lambdas_k.detach().numpy()  # The precision for the RBFs (K x 1)
G_model_parameters['zeta'] = zeta_rbf  # A small value to prevent division by 0
G_model_parameters['beta'] = 1.0  # This scaling parameter of the metric is updated later

# Construct the manifold
manifold_latent = manifolds.MlpMeanInvRbfVar(G_model_parameters)
# Rescale the metric such that the maximum measure on the data is 1
beta_rbf = 1 / (np.sqrt(np.linalg.det(manifold_latent.metric_tensor(latent_means_np.T)).max()))
# NOTE(review): 'beta' is written to the dictionary AFTER the manifold was built; this
# only rescales the metric if MlpMeanInvRbfVar keeps a reference to this dictionary
# rather than copying its values -- confirm in core.manifolds.
G_model_parameters['beta'] = beta_rbf ** (2 / latent_dim)  # Rescale the pull-back metric

# Bounding box of the latent means, padded by 0.5 on every side.
# (The upper bound previously added `0.`, a no-op inconsistent with the lower bound.)
z1min, z2min = latent_means_np.min(0) - 0.5
z1max, z2max = latent_means_np.max(0) + 0.5

# LAND prediction

In [17]:
# Get 40 representative samples from the dataset, stratified by class:
# each class contributes a number of points proportional to its frequency.
n_samples = 40
labels_np = labels.detach().numpy()
p0 = np.sum(labels_np == 0) / len(labels_np)
p1 = np.sum(labels_np == 1) / len(labels_np)
n0 = round(p0 * n_samples)
n1 = n_samples - n0  # complement, so the total is always n_samples even when rounding ties
print(n0, n1)

indices_class_0 = np.where(labels_np == 0)[0]
indices_class_1 = np.where(labels_np == 1)[0]

# Stratified sample without replacement.
# The per-class counts n0 / n1 computed above are used here; they were previously
# printed but ignored in favour of a hard-coded 20 / 20 split.
sample_indices_class_0 = np.random.choice(indices_class_0, n0, replace=False)
sample_indices_class_1 = np.random.choice(indices_class_1, n1, replace=False)

# Combine the indices of the two classes to get the final sample
sample_indices = np.concatenate((sample_indices_class_0, sample_indices_class_1))
np.random.shuffle(sample_indices)  # Shuffle so the two classes are interleaved

# Latent means of the sampled points
sample_data = latent_means_np[sample_indices]

19 21


In [21]:
from core.geometric_methods import land_predict, land_mixture_model

data = sample_data

# Build the heuristic graph solver on 12 k-means centers of the sampled latent codes.
# n_init is pinned to 10 (the value scikit-learn was using by default here) so the
# result does not change with the library's default and the FutureWarning goes away.
GRAPH_DATA = KMeans(n_clusters=12, n_init=10, max_iter=100).fit(data).cluster_centers_
solver_graph = geodesics.SolverGraph(manifold_latent, data=GRAPH_DATA, kNN_num=5, tol=0.2)

# Parameters of the LAND mixture model
param = {}
param["means"] = KMeans(n_clusters=2, n_init=10).fit(data).cluster_centers_  # initial component means
param["K"] = 2  # number of clusters
# Sample count used to normalize the LAND distribution.
# NOTE(review): the earlier comment mentioned 40 samples while the value is 10 -- confirm.
param["S"] = 10
param['max_iter'] = 10
param['tol'] = 0.2  # 0.1
param['step_size'] = 0.01  # 0.1
param['mixing_param'] = 0  # [0, 1] how much between empirical covariance and identity

# NOTE(review): the model is fitted on GRAPH_DATA (the 12 k-means centers), not on
# `data` (the 40 sampled latent codes) -- confirm this is intended.
responsabilities = land_mixture_model(manifold_latent, solver_graph, GRAPH_DATA, param)

  super()._check_params_vs_input(X, default_n_init=10)


[Initialize Graph] [Processed point: 0/12]


  super()._check_params_vs_input(X, default_n_init=10)


[Initialize: 1/2] [Process point: 1/12]
[Initialize: 1/2] [Process point: 2/12]
[Initialize: 1/2] [Process point: 3/12]
[Initialize: 1/2] [Process point: 4/12]
[Initialize: 1/2] [Process point: 5/12]
[Initialize: 1/2] [Process point: 6/12]
[Initialize: 1/2] [Process point: 7/12]
[Initialize: 1/2] [Process point: 8/12]
[Initialize: 1/2] [Process point: 9/12]
[Initialize: 1/2] [Process point: 10/12]
[Initialize: 1/2] [Process point: 11/12]
[Initialize: 1/2] [Process point: 12/12]
[Initialize: 2/2] [Process point: 1/12]
[Initialize: 2/2] [Process point: 2/12]
[Initialize: 2/2] [Process point: 3/12]
[Initialize: 2/2] [Process point: 4/12]
[Initialize: 2/2] [Process point: 5/12]
[Initialize: 2/2] [Process point: 6/12]
[Initialize: 2/2] [Process point: 7/12]
[Initialize: 2/2] [Process point: 8/12]
[Initialize: 2/2] [Process point: 9/12]
[Initialize: 2/2] [Process point: 10/12]
[Initialize: 2/2] [Process point: 11/12]
[Initialize: 2/2] [Process point: 12/12]


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=2

Expmap failed!


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=2

Expmap failed!


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals


Expmap failed!
Expmap failed!
Expmap failed!
Expmap failed!


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=2

Expmap failed!


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=2

Expmap failed!


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=2

Expmap failed!


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=2

Expmap failed!


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=2

Expmap failed!


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=2

Expmap failed!


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals


Expmap failed!


  the requested tolerance from being achieved.  The error may be 
  underestimated.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up

Expmap failed!


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=2

Expmap failed!


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  the requested tolerance from being achieved.  The error may be 
  underestimated.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals


Expmap failed!


  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=200)  # , number of subintervals
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  curve_length_eval = integrate.quad(lambda t: local_length(manifold, curve, t), a, b, epsabs=tol, limit=2

KeyboardInterrupt: 

In [None]:
# Plot the density of the LAND mixture on a regular grid covering the sampled data.
# NOTE(review): `land_res_prior` is not defined anywhere in this notebook. It is
# presumably the fitted LAND result (a mapping with 'means', 'Sigmas', 'Weights'
# and 'Consts'); bind it to the output of the fitting cell before running this
# cell -- confirm what land_mixture_model returns.
z1min, z2min = data.min(0) - 0.5
z1max, z2max = data.max(0) + 0.5
N_Z_grid = 20
Z_grid = utils.my_meshgrid(z1min, z1max, z2min, z2max, N=N_Z_grid)
n_components = 2

# Logarithmic map of every grid point, seen from each component mean
Logmaps = np.zeros((n_components, Z_grid.shape[0], data.shape[1]))
for k in range(n_components):
    mean_k = land_res_prior['means'][k, :].reshape(-1, 1)
    for n in range(Z_grid.shape[0]):
        # If the solver fails then the logmap is overestimated (straight line).
        # The solver and manifold built earlier in this notebook are used here; the
        # names `solver` / `manifold` used before are not defined in this notebook.
        curve_bvp, logmap_bvp, curve_length_bvp, failed_bvp, solution_bvp \
            = geodesics.compute_geodesic(solver_graph, manifold_latent,
                                         mean_k,
                                         Z_grid[n, :].reshape(-1, 1))
        Logmaps[k, n, :] = logmap_bvp.ravel()

# Weighted sum of the per-component densities evaluated on the grid
pdf_vals = np.zeros((Z_grid.shape[0], 1))
for k in range(n_components):
    precision_k = np.linalg.inv(land_res_prior['Sigmas'][k, :, :])
    # Row-wise quadratic form v^T Sigma^-1 v; equal to the diagonal of
    # Logmaps[k] @ precision_k @ Logmaps[k].T without building the N x N matrix.
    quad_form_k = np.sum((Logmaps[k, :, :] @ precision_k) * Logmaps[k, :, :], axis=1)
    pdf_vals += land_res_prior['Weights'][k, 0] * (np.exp(-0.5 * quad_form_k) / land_res_prior['Consts'][k, 0]).reshape(-1, 1)

plt.imshow(pdf_vals.reshape(N_Z_grid, N_Z_grid), interpolation='bicubic',
           extent=(z1min, z1max, z2min, z2max), origin='lower')
plt.xlabel('z1')
plt.ylabel('z2')
plt.title('LAND mixture density in the latent space')
plt.show()

torch.Size([12665, 2])