# Generate and Save Synthetic Data for Manifolds

In [2]:
#stdlib
import argparse
import os

#JAX
import jax.numpy as jnp

#numpy
import numpy as np

from ManLearn.VAE.VAE_MNIST import model as mnist_model
from ManLearn.VAE.VAE_MNIST import model_encoder as mnist_encoder
from ManLearn.VAE.VAE_MNIST import model_decoder as mnist_decoder
from ManLearn.VAE.VAE_MNIST import VAEOutput as mnist_output

from ManLearn.VAE.VAE_SVHN import model as svhn_model
from ManLearn.VAE.VAE_SVHN import model_encoder as svhn_encoder
from ManLearn.VAE.VAE_SVHN import model_decoder as svhn_decoder
from ManLearn.VAE.VAE_SVHN import VAEOutput as svhn_output

from ManLearn.VAE.VAE_CelebA import model as celeba_model
from ManLearn.VAE.VAE_CelebA import model_encoder as celeba_encoder
from ManLearn.VAE.VAE_CelebA import model_decoder as celeba_decoder
from ManLearn.VAE.VAE_CelebA import VAEOutput as celeba_output

from ManLearn.train_MNIST import load_dataset as load_mnist
from ManLearn.train_SVHN import load_dataset as load_svhn
from ManLearn.train_CelebA import load_dataset as load_celeba
from ManLearn.model_loader import load_model

#jaxgeometry
from jaxgeometry.manifolds import Euclidean, nSphere, Ellipsoid, Cylinder, S1, Torus, \
    H2, Landmarks, Heisenberg, SPDN, Latent, HypParaboloid
from jaxgeometry.setup import dts, dWs, hessianx
from jaxgeometry.statistics import score_matching
from jaxgeometry.statistics.score_matching.model_loader import load_model
from jaxgeometry.stochastics import Brownian_coords, product_sde, Brownian_sR
from jaxgeometry.stochastics.product_sde import tile

## Hyper-Parameters

In [3]:
# Settings shared by every simulation cell below.
#   N_sim   -- number of samples simulated in parallel (passed to tile / dW reshape)
#   n_steps -- number of integration steps (passed to dts)
#   T       -- end time of the integration (passed to dts)
N_sim, n_steps, T = 100, 100, 1.0

# Root directory under which each manifold's CSV files are written.
file_path = 'Data/'

## Euclidean

### Generate Data

In [4]:
# Simulate N_sim Brownian motions on R^d for each d in `dim`, started at the
# origin and integrated to time T, and save coordinates and charts under
# <file_path>/R<d>/.
dim = [2,3,5,10,20,50,100,200]
for d in dim:
    M = Euclidean(N=d)
    Brownian_coords(M)

    # Vectorised SDE that advances all N_sim samples together.
    (product, sde_product, chart_update_product) = product_sde(M,
                                                               M.sde_Brownian_coords,
                                                               M.chart_update_Brownian_coords)

    x0 = M.coords([0.]*d)
    x0s = tile(x0, N_sim)

    _dts = dts(T=T, n_steps=n_steps)
    # Noise laid out as (time step, sample, manifold dimension).
    dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
    (ts,xss,chartss,*_) = product(x0s,
                                  _dts,dW,jnp.repeat(1.,N_sim))

    # NOTE(review): dW is (time, sample, dim); if xss shares that layout,
    # xss[:, -1] is the full path of the last sample, not the N_sim endpoints
    # at time T (xss[-1]) -- confirm which is intended.
    xs = xss[:,-1]
    chart = chartss[:,-1]

    path = os.path.join(file_path, f'R{d}')
    os.makedirs(path, exist_ok=True)  # np.savetxt does not create missing directories
    np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
    # Save the chart array (not xs) so chart.csv matches its name.
    np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
Done


## S1

### Generate Data

In [5]:
# Simulate N_sim Brownian motions on the circle S1 (spherical coordinates),
# started at coordinate 0 and integrated to time T; save under <file_path>/S1/.
M = S1(use_spherical_coords=True)
Brownian_coords(M)

x0 = M.coords([0.])
x0s = tile(x0, N_sim)

# Vectorised SDE that advances all N_sim samples together.
(product, sde_product, chart_update_product) = product_sde(M,
                                                           M.sde_Brownian_coords,
                                                           M.chart_update_Brownian_coords)

_dts = dts(T=T, n_steps=n_steps)
# Noise laid out as (time step, sample, manifold dimension).
dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
(ts,xss,chartss,*_) = product(x0s,
                              _dts,dW,jnp.repeat(1.,N_sim))

# NOTE(review): dW is (time, sample, dim); if xss shares that layout,
# xss[:, -1] is the full path of the last sample, not the N_sim endpoints
# at time T (xss[-1]) -- confirm which is intended.
xs = xss[:,-1]
chart = chartss[:,-1]

path = os.path.join(file_path, 'S1')
os.makedirs(path, exist_ok=True)  # np.savetxt does not create missing directories
np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
# Save the chart array (not xs) so chart.csv matches its name.
np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
Done


## mSphere

### Generate Data

In [6]:
# Simulate N_sim Brownian motions on the sphere nSphere(N=d) for each d in
# `dim`, started at coordinates 0 and integrated to time T; save under
# <file_path>/S<d>/.
dim = [2,3,5,10,20,50,100, 200]
for d in dim:
    M = nSphere(N=d)
    Brownian_coords(M)

    # Vectorised SDE that advances all N_sim samples together.
    (product, sde_product, chart_update_product) = product_sde(M,
                                                               M.sde_Brownian_coords,
                                                               M.chart_update_Brownian_coords)

    x0 = M.coords([0.]*d)
    x0s = tile(x0, N_sim)

    _dts = dts(T=T, n_steps=n_steps)
    # Noise laid out as (time step, sample, manifold dimension).
    dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
    (ts,xss,chartss,*_) = product(x0s,
                                  _dts,dW,jnp.repeat(1.,N_sim))

    # NOTE(review): dW is (time, sample, dim); if xss shares that layout,
    # xss[:, -1] is the full path of the last sample, not the N_sim endpoints
    # at time T (xss[-1]) -- confirm which is intended.
    xs = xss[:,-1]
    chart = chartss[:,-1]

    path = os.path.join(file_path, f'S{d}')
    os.makedirs(path, exist_ok=True)  # np.savetxt does not create missing directories
    np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
    # Save the chart array (not xs) so chart.csv matches its name.
    np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
Done


## Ellipsoid

### Generate Data

In [7]:
# Ellipsoid: simulate N_sim Brownian motions for each d in `dim`, with
# parameters linspace(0.5, 1.0, d+1), started at coordinates 0 and integrated
# to time T; save under <file_path>/Ellipsoid<d>/.
dim = [2,3,5,10,20,50,100, 200]
for d in dim:
    M = Ellipsoid(N=d, params = jnp.linspace(0.5,1.0,d+1))
    Brownian_coords(M)

    x0 = M.coords([0.]*d)

    # Vectorised SDE that advances all N_sim samples together.
    (product, sde_product, chart_update_product) = product_sde(M,
                                                               M.sde_Brownian_coords,
                                                               M.chart_update_Brownian_coords)

    x0s = tile(x0, N_sim)

    _dts = dts(T=T, n_steps=n_steps)
    # Noise laid out as (time step, sample, manifold dimension).
    dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
    (ts,xss,chartss,*_) = product(x0s,
                                  _dts,dW,jnp.repeat(1.,N_sim))

    # NOTE(review): dW is (time, sample, dim); if xss shares that layout,
    # xss[:, -1] is the full path of the last sample, not the N_sim endpoints
    # at time T (xss[-1]) -- confirm which is intended.
    xs = xss[:,-1]
    chart = chartss[:,-1]

    path = os.path.join(file_path, f'Ellipsoid{d}')
    os.makedirs(path, exist_ok=True)  # np.savetxt does not create missing directories
    np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
    # Save the chart array (not xs) so chart.csv matches its name.
    np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
Done


## Cylinder

### Generate Data

In [8]:
# Cylinder: simulate N_sim Brownian motions started at coordinates (0, 0) and
# integrated to time T; save under <file_path>/Cylinder/.
M = Cylinder(params=(1.,jnp.array([0.,0.,1.]),jnp.pi/2.))
Brownian_coords(M)

x0 = M.coords([0.]*2)

# Vectorised SDE that advances all N_sim samples together.
(product, sde_product, chart_update_product) = product_sde(M,
                                                           M.sde_Brownian_coords,
                                                           M.chart_update_Brownian_coords)

x0s = tile(x0, N_sim)

_dts = dts(T=T, n_steps=n_steps)
# Noise laid out as (time step, sample, manifold dimension).
dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
(ts,xss,chartss,*_) = product(x0s,
                              _dts,dW,jnp.repeat(1.,N_sim))

# NOTE(review): dW is (time, sample, dim); if xss shares that layout,
# xss[:, -1] is the full path of the last sample, not the N_sim endpoints
# at time T (xss[-1]) -- confirm which is intended.
xs = xss[:,-1]
chart = chartss[:,-1]

path = os.path.join(file_path, 'Cylinder')
os.makedirs(path, exist_ok=True)  # np.savetxt does not create missing directories
np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
# Save the chart array (not xs) so chart.csv matches its name.
np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
Done


## Torus

### Generate Data

In [9]:
# Torus: simulate N_sim Brownian motions started at coordinates (0, 0) and
# integrated to time T; save under <file_path>/Torus/.
M = Torus()
Brownian_coords(M)

x0 = M.coords([0.]*2)

# Vectorised SDE that advances all N_sim samples together.
(product, sde_product, chart_update_product) = product_sde(M,
                                                           M.sde_Brownian_coords,
                                                           M.chart_update_Brownian_coords)

x0s = tile(x0, N_sim)

_dts = dts(T=T, n_steps=n_steps)
# Noise laid out as (time step, sample, manifold dimension).
dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
(ts,xss,chartss,*_) = product(x0s,
                              _dts,dW,jnp.repeat(1.,N_sim))

# NOTE(review): dW is (time, sample, dim); if xss shares that layout,
# xss[:, -1] is the full path of the last sample, not the N_sim endpoints
# at time T (xss[-1]) -- confirm which is intended.
xs = xss[:,-1]
chart = chartss[:,-1]

path = os.path.join(file_path, 'Torus')
os.makedirs(path, exist_ok=True)  # np.savetxt does not create missing directories
np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
# Save the chart array (not xs) so chart.csv matches its name.
np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
Done


## Landmarks

### Generate Data

In [10]:
# Landmarks: for each number of landmarks d in `dim` (in R^2), simulate N_sim
# Brownian motions started from d landmarks evenly spaced on [-5, 5] x {0},
# integrated to time T; save under <file_path>/Landmarks<d>/.
dim = [2,5,10,66,77]
for d in dim:
    M = Landmarks(N=d,m=2)
    Brownian_coords(M)

    # Interleave (x_i, 0) pairs into a flat coordinate vector.
    x0 = M.coords(jnp.vstack((jnp.linspace(-5.0,5.0,M.N),jnp.zeros(M.N))).T.flatten())

    # Vectorised SDE that advances all N_sim samples together.
    (product, sde_product, chart_update_product) = product_sde(M,
                                                               M.sde_Brownian_coords,
                                                               M.chart_update_Brownian_coords)

    x0s = tile(x0, N_sim)

    _dts = dts(T=T, n_steps=n_steps)
    # Noise laid out as (time step, sample, manifold dimension).
    dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
    (ts,xss,chartss,*_) = product(x0s,
                                  _dts,dW,jnp.repeat(1.,N_sim))

    # NOTE(review): dW is (time, sample, dim); if xss shares that layout,
    # xss[:, -1] is the full path of the last sample, not the N_sim endpoints
    # at time T (xss[-1]) -- confirm which is intended.
    xs = xss[:,-1]
    chart = chartss[:,-1]

    path = os.path.join(file_path, f'Landmarks{d}')
    os.makedirs(path, exist_ok=True)  # np.savetxt does not create missing directories
    np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
    # Save the chart array (not xs) so chart.csv matches its name.
    np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
Done


## SPDN

### Generate Data

In [12]:
# SPDN: for each d in `dim`, simulate N_sim Brownian motions on SPDN(N=d),
# started at coordinates 0 (d*(d+1)/2 of them) and integrated to time T;
# save under <file_path>/SPDN<d>/.
dim = [2, 5, 10,20]
for d in dim:
    M = SPDN(N=d)
    Brownian_coords(M)

    x0 = M.coords([0.]*(d*(d+1)//2))

    # Vectorised SDE that advances all N_sim samples together.
    (product, sde_product, chart_update_product) = product_sde(M,
                                                               M.sde_Brownian_coords,
                                                               M.chart_update_Brownian_coords)

    x0s = tile(x0, N_sim)

    _dts = dts(T=T, n_steps=n_steps)
    # Noise laid out as (time step, sample, manifold dimension).
    dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
    (ts,xss,chartss,*_) = product(x0s,
                                  _dts,dW,jnp.repeat(1.,N_sim))

    # NOTE(review): dW is (time, sample, dim); if xss shares that layout,
    # xss[:, -1] is the full path of the last sample, not the N_sim endpoints
    # at time T (xss[-1]) -- confirm which is intended.
    xs = xss[:,-1]
    chart = chartss[:,-1]

    path = os.path.join(file_path, f'SPDN{d}')
    os.makedirs(path, exist_ok=True)  # np.savetxt does not create missing directories
    np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
    # Save the chart array (not xs) so chart.csv matches its name.
    np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
Done


## Hyperbolic Paraboloid

### Generate Data

In [13]:
# Hyperbolic paraboloid: simulate N_sim Brownian motions started at
# coordinates (0, 0) and integrated to time T; save under
# <file_path>/HypParaboloid/.
M = HypParaboloid()
Brownian_coords(M)

x0 = M.coords([0.]*2)

# Vectorised SDE that advances all N_sim samples together.
(product, sde_product, chart_update_product) = product_sde(M,
                                                           M.sde_Brownian_coords,
                                                           M.chart_update_Brownian_coords)

x0s = tile(x0, N_sim)

_dts = dts(T=T, n_steps=n_steps)
# Noise laid out as (time step, sample, manifold dimension).
dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
(ts,xss,chartss,*_) = product(x0s,
                              _dts,dW,jnp.repeat(1.,N_sim))

# NOTE(review): dW is (time, sample, dim); if xss shares that layout,
# xss[:, -1] is the full path of the last sample, not the N_sim endpoints
# at time T (xss[-1]) -- confirm which is intended.
xs = xss[:,-1]
chart = chartss[:,-1]

path = os.path.join(file_path, 'HypParaboloid')
os.makedirs(path, exist_ok=True)  # np.savetxt does not create missing directories
np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
# Save the chart array (not xs) so chart.csv matches its name.
np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
Done


## VAE MNIST

### Generate Data

In [14]:
# VAE MNIST: the latent manifold is the image of a 2-D latent space under the
# trained decoder (embedded in R^(28*28)). Brownian motions start from the
# encoding of the first training image and are saved under <file_path>/VAE_MNIST/.
ds = load_mnist("train", 100, 2712)

# NOTE(review): `load_model` resolves to
# jaxgeometry.statistics.score_matching.model_loader.load_model, because that
# import shadows ManLearn.model_loader.load_model in the imports cell --
# confirm this is the intended loader for the VAE checkpoint.
state = load_model('ManLearn/models/MNIST/VAE/')
F = lambda x: mnist_decoder.apply(state.params, state.rng_key, x[0]).reshape(-1)

M = Latent(F=F,dim=2,emb_dim=28*28,invF=None)
Brownian_coords(M)

x0 = mnist_encoder.apply(state.params, state.rng_key, next(ds).image)
x0 = M.coords(x0[0])

# Vectorised SDE that advances all N_sim samples together.
(product, sde_product, chart_update_product) = product_sde(M,
                                                           M.sde_Brownian_coords,
                                                           M.chart_update_Brownian_coords)

x0s = tile(x0, N_sim)

_dts = dts(T=T, n_steps=n_steps)
# Noise laid out as (time step, sample, manifold dimension).
dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
(ts,xss,chartss,*_) = product(x0s,
                              _dts,dW,jnp.repeat(1.,N_sim))

# NOTE(review): dW is (time, sample, dim); if xss shares that layout,
# xss[:, -1] is the full path of the last sample, not the N_sim endpoints
# at time T (xss[-1]) -- confirm which is intended.
xs = xss[:,-1]
chart = chartss[:,-1]

path = os.path.join(file_path, 'VAE_MNIST')
os.makedirs(path, exist_ok=True)  # np.savetxt does not create missing directories
np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
# Save the chart array (not xs) so chart.csv matches its name.
np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

2023-10-02 19:59:45.089566: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1956] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


using M.Exp for Logarithm
Done


## VAE SVHN

### Generate Data

In [15]:
# VAE SVHN: the latent manifold is the image of a 32-D latent space under the
# trained decoder (embedded in R^(32*32*3)). Brownian motions start from the
# encoding of the first dataset image and are saved under <file_path>/VAE_SVHN/.
ds = load_svhn()

# NOTE(review): `load_model` resolves to
# jaxgeometry.statistics.score_matching.model_loader.load_model, because that
# import shadows ManLearn.model_loader.load_model in the imports cell --
# confirm this is the intended loader for the VAE checkpoint.
state = load_model('ManLearn/models/SVHN/VAE/')
F = lambda x: svhn_decoder.apply(state.params, state.rng_key, x[0]).reshape(-1)

M = Latent(F=F,dim=32,emb_dim=32*32*3,invF=None)
Brownian_coords(M)

x0 = svhn_encoder.apply(state.params, state.rng_key, next(ds).image)
x0 = M.coords(x0[0])

# Vectorised SDE that advances all N_sim samples together.
(product, sde_product, chart_update_product) = product_sde(M,
                                                           M.sde_Brownian_coords,
                                                           M.chart_update_Brownian_coords)

x0s = tile(x0, N_sim)

_dts = dts(T=T, n_steps=n_steps)
# Noise laid out as (time step, sample, manifold dimension).
dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
(ts,xss,chartss,*_) = product(x0s,
                              _dts,dW,jnp.repeat(1.,N_sim))

# NOTE(review): dW is (time, sample, dim); if xss shares that layout,
# xss[:, -1] is the full path of the last sample, not the N_sim endpoints
# at time T (xss[-1]) -- confirm which is intended.
xs = xss[:,-1]
chart = chartss[:,-1]

path = os.path.join(file_path, 'VAE_SVHN')
os.makedirs(path, exist_ok=True)  # np.savetxt does not create missing directories
np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
# Save the chart array (not xs) so chart.csv matches its name.
np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

2023-10-02 20:00:02.191492: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


using M.Exp for Logarithm
Done


## VAE CelebA

### Generate Data

In [None]:
# VAE CelebA: the latent manifold is the image of a 32-D latent space under the
# trained decoder (embedded in R^(64*64*3)). Brownian motions start from the
# encoding of the first dataset image and are saved under <file_path>/VAE_CelebA/.
ds = load_celeba()

# NOTE(review): `load_model` resolves to
# jaxgeometry.statistics.score_matching.model_loader.load_model, because that
# import shadows ManLearn.model_loader.load_model in the imports cell --
# confirm this is the intended loader for the VAE checkpoint.
state = load_model('ManLearn/models/CelebA/VAE/')
F = lambda x: celeba_decoder.apply(state.params, state.rng_key, x[0]).reshape(-1)

M = Latent(F=F,dim=32,emb_dim=64*64*3,invF=None)
Brownian_coords(M)

# Encode with the CelebA encoder so it matches the CelebA params and decoder
# (the SVHN encoder was used here before).
x0 = celeba_encoder.apply(state.params, state.rng_key, next(ds).image)
x0 = M.coords(x0[0])

# Vectorised SDE that advances all N_sim samples together.
(product, sde_product, chart_update_product) = product_sde(M,
                                                           M.sde_Brownian_coords,
                                                           M.chart_update_Brownian_coords)

x0s = tile(x0, N_sim)

_dts = dts(T=T, n_steps=n_steps)
# Noise laid out as (time step, sample, manifold dimension).
dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
(ts,xss,chartss,*_) = product(x0s,
                              _dts,dW,jnp.repeat(1.,N_sim))

# NOTE(review): dW is (time, sample, dim); if xss shares that layout,
# xss[:, -1] is the full path of the last sample, not the N_sim endpoints
# at time T (xss[-1]) -- confirm which is intended.
xs = xss[:,-1]
chart = chartss[:,-1]

path = os.path.join(file_path, 'VAE_CelebA')
os.makedirs(path, exist_ok=True)  # np.savetxt does not create missing directories
np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
# Save the chart array (not xs) so chart.csv matches its name.
np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
