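# Generate and Save Synthetic Data for Manifolds

For each manifold below, this notebook simulates `N_sim` stochastic sample paths with `n_steps` time steps via `product_sde`. It saves the final-time states to `xs.csv` and `chart.csv` in a per-manifold folder under `Data/`.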

In [1]:
#stdlib
import os

#JAX
import jax.numpy as jnp
from jax import vmap

#numpy
import numpy as np

#argparse
import argparse

#jaxgeometry
from jaxgeometry.manifolds import Euclidean, nSphere, nEllipsoid, Cylinder, S1, Torus, \
    H2, Landmarks, Heisenberg, SPDN, Latent, HypParaboloid, Sym
from jaxgeometry.integration import dts, dWs
from jaxgeometry.autodiff import hessianx
from jaxgeometry.statistics import score_matching
from jaxgeometry.statistics.score_matching.model_loader import load_model
from jaxgeometry.stochastics import Brownian_coords, product_sde, Brownian_sR
from jaxgeometry.stochastics.product_sde import tile

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## Hyper-Parameters

In [2]:
# Simulation settings shared by the manifold cells below:
#   N_sim   - number of sample paths simulated per manifold
#   n_steps - number of time steps passed to dts(...)
#   T       - terminal time passed to dts(...)
N_sim, n_steps, T = 1000, 100, 0.5

# Root directory under which each manifold's CSV files are written.
file_path = 'Data/'

## Euclidean

### Generate Data

In [6]:
# Brownian motion in R^d started at the origin; the final-time states of the
# N_sim sample paths are written to <file_path>/R<d>/.
dim = [2,3,5,10,20,50]
for d in dim:
    M = Euclidean(N=d)
    # Registers M.sde_Brownian_coords / M.chart_update_Brownian_coords used below.
    Brownian_coords(M)

    # Lift the single-path SDE to a product process over N_sim copies.
    (product, sde_product, chart_update_product) = product_sde(M,
                                                               M.sde_Brownian_coords,
                                                               M.chart_update_Brownian_coords)

    x0 = M.coords([0.]*d)
    x0s = tile(x0, N_sim)

    _dts = dts(T=T, n_steps=n_steps)
    dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
    (ts,xss,chartss,*_) = product(x0s,
                                  _dts,dW,jnp.repeat(1.,N_sim))

    # Keep only the state at the final time step.
    xs = xss[-1]
    chart = chartss[-1]

    path = os.path.join(file_path, 'R' + str(d))
    # np.savetxt does not create missing directories; make sure it exists.
    os.makedirs(path, exist_ok=True)
    np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
    np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
Done


## S1

### Generate Data

In [3]:
# Brownian motion on the circle S1; final-time states are written to
# <file_path>/S1/.
M = S1()
Brownian_coords(M)

x0 = M.coords([0.])
x0s = tile(x0, N_sim)

(product, sde_product, chart_update_product) = product_sde(M,
                                                            M.sde_Brownian_coords,
                                                            M.chart_update_Brownian_coords)

# S1 uses a shorter horizon than the global T.
# NOTE(review): confirm the reason S1 needs T=0.1 rather than T.
T_S1 = 0.1
_dts = dts(T=T_S1, n_steps=n_steps)
dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
(ts,xss,chartss,*_) = product(x0s,
                              _dts,dW,jnp.repeat(1.,N_sim))

xs = xss[-1]
# chart.csv stores M.F applied to each final (coords, chart) pair rather than
# the raw chart (presumably the embedded point in R^2 — confirm).
chart = vmap(lambda x,y: M.F((x,y)))(xs,chartss[-1])

path = os.path.join(file_path, 'S1')
# np.savetxt does not create missing directories; make sure it exists.
os.makedirs(path, exist_ok=True)
np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
(1000, 2)
Done


## nSphere

### Generate Data

In [11]:
# Brownian motion on the d-sphere for several d; final-time states are written
# to <file_path>/S<d>/.
dim = [2,3,5,10,20]
for d in dim:
    M = nSphere(N=d)
    # Registers M.sde_Brownian_coords / M.chart_update_Brownian_coords used below.
    Brownian_coords(M)

    # Lift the single-path SDE to a product process over N_sim copies.
    (product, sde_product, chart_update_product) = product_sde(M,
                                                               M.sde_Brownian_coords,
                                                               M.chart_update_Brownian_coords)

    x0 = M.coords([0.]*d)
    x0s = tile(x0, N_sim)

    _dts = dts(T=T, n_steps=n_steps)
    dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
    (ts,xss,chartss,*_) = product(x0s,
                                  _dts,dW,jnp.repeat(1.,N_sim))

    # Keep only the state at the final time step.
    xs = xss[-1]
    chart = chartss[-1]

    path = os.path.join(file_path, 'S' + str(d))
    # np.savetxt does not create missing directories; make sure it exists.
    os.makedirs(path, exist_ok=True)
    np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
    np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
Done


## Ellipsoid

### Generate Data

In [12]:
# Brownian motion on d-dimensional ellipsoids; final-time states are written to
# <file_path>/Ellipsoid<d>/.
# The manifold class imported at the top of the notebook is nEllipsoid; the bare
# name `Ellipsoid` is not defined there and fails on a fresh kernel.
dim = [2,3,5,10,20]
for d in dim:
    # d+1 parameters evenly spaced in [0.5, 1.0]
    # (NOTE(review): presumably the semi-axis lengths — confirm).
    M = nEllipsoid(N=d, params = jnp.linspace(0.5,1.0,d+1))
    Brownian_coords(M)

    x0 = M.coords([0.]*d)

    # Lift the single-path SDE to a product process over N_sim copies.
    (product, sde_product, chart_update_product) = product_sde(M,
                                                               M.sde_Brownian_coords,
                                                               M.chart_update_Brownian_coords)

    x0s = tile(x0, N_sim)

    _dts = dts(T=T, n_steps=n_steps)
    dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
    (ts,xss,chartss,*_) = product(x0s,
                                  _dts,dW,jnp.repeat(1.,N_sim))

    # Keep only the state at the final time step.
    xs = xss[-1]
    chart = chartss[-1]

    path = os.path.join(file_path, 'Ellipsoid' + str(d))
    # np.savetxt does not create missing directories; make sure it exists.
    os.makedirs(path, exist_ok=True)
    np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
    np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
Done


## Cylinder

### Generate Data

In [13]:
# Brownian motion on a cylinder; final-time states are written to
# <file_path>/Cylinder/.
# NOTE(review): params presumably = (radius, axis direction, rotation angle);
# confirm against jaxgeometry's Cylinder.
M = Cylinder(params=(1.,jnp.array([0.,0.,1.]),jnp.pi/2.))
Brownian_coords(M)

x0 = M.coords([0.]*2)

# Lift the single-path SDE to a product process over N_sim copies.
(product, sde_product, chart_update_product) = product_sde(M,
                                                           M.sde_Brownian_coords,
                                                           M.chart_update_Brownian_coords)

x0s = tile(x0, N_sim)

_dts = dts(T=T, n_steps=n_steps)
dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
(ts,xss,chartss,*_) = product(x0s,
                              _dts,dW,jnp.repeat(1.,N_sim))

# Keep only the state at the final time step.
xs = xss[-1]
chart = chartss[-1]

path = os.path.join(file_path, 'Cylinder')
# np.savetxt does not create missing directories; make sure it exists.
os.makedirs(path, exist_ok=True)
np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
Done


## Torus

### Generate Data

In [14]:
# Brownian motion on the torus; final-time states are written to
# <file_path>/Torus/.
M = Torus()
Brownian_coords(M)

x0 = M.coords([0.]*2)

# Lift the single-path SDE to a product process over N_sim copies.
(product, sde_product, chart_update_product) = product_sde(M,
                                                           M.sde_Brownian_coords,
                                                           M.chart_update_Brownian_coords)

x0s = tile(x0, N_sim)

_dts = dts(T=T, n_steps=n_steps)
dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
(ts,xss,chartss,*_) = product(x0s,
                              _dts,dW,jnp.repeat(1.,N_sim))

# Keep only the state at the final time step.
xs = xss[-1]
chart = chartss[-1]

path = os.path.join(file_path, 'Torus')
# np.savetxt does not create missing directories; make sure it exists.
os.makedirs(path, exist_ok=True)
np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
Done


## Landmarks

### Generate Data

In [15]:
# Brownian motion on configurations of d landmarks in R^2; final-time states
# are written to <file_path>/Landmarks<d>/.
dim = [2,5,10,20]
for d in dim:
    M = Landmarks(N=d,m=2)
    Brownian_coords(M)

    # Lift the single-path SDE to a product process over N_sim copies.
    (product, sde_product, chart_update_product) = product_sde(M,
                                                               M.sde_Brownian_coords,
                                                               M.chart_update_Brownian_coords)

    # Default start: M.N landmarks evenly spaced from (-5, 0) to (5, 0),
    # flattened point by point as (x_1, y_1, x_2, y_2, ...).
    x0 = M.coords(jnp.vstack((jnp.linspace(-5.0,5.0,M.N),jnp.zeros(M.N))).T.flatten())

    if M.N >= 10:
        # Larger configurations start from landmarks read from the Papilionidae
        # data set instead.
        # NOTE(review): this path is under '../Data/' while outputs go to
        # file_path ('Data/'); confirm both resolve from the working directory.
        with open('../Data/landmarks/Papilonidae/Papilionidae_landmarks.txt', 'r') as the_file:
            all_data = [line.strip() for line in the_file.readlines()]

        # The first two lines supply the first and second coordinate of every
        # landmark; their first two tokens are skipped (presumably labels —
        # confirm against the file format).
        x1 = jnp.array([float(x) for x in all_data[0].split()[2:]])
        x2 = jnp.array([float(x) for x in all_data[1].split()[2:]])

        # Take every k-th point, then cap at M.N: when len(x) is not a multiple
        # of M.N the stride alone yields more than M.N points, which would not
        # match the 2*M.N coordinates M expects.
        x1_sub = x1[::len(x1)//M.N][:M.N]
        x2_sub = x2[::len(x2)//M.N][:M.N]
        x0 = M.coords(jnp.vstack((x1_sub,x2_sub)).T.flatten())

    x0s = tile(x0, N_sim)
    _dts = dts(T=T, n_steps=n_steps)
    dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
    (ts,xss,chartss,*_) = product(x0s,
                                  _dts,dW,jnp.repeat(1.,N_sim))

    # Keep only the state at the final time step.
    xs = xss[-1]
    chart = chartss[-1]

    path = os.path.join(file_path, 'Landmarks' + str(d))
    # np.savetxt does not create missing directories; make sure it exists.
    os.makedirs(path, exist_ok=True)
    np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
    np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
Done


## SPDN

### Generate Data

In [3]:
# Brownian motion on symmetric positive-definite d x d matrices; final-time
# states are written to <file_path>/SPDN<d>/.
dim = [2, 5, 10]
for d in dim:
    M = SPDN(N=d)
    Brownian_coords(M)

    # d*(d+1)//2 coordinates (the number of free entries of a symmetric d x d
    # matrix), all set to 10.
    x0 = M.coords([10.]*(d*(d+1)//2))

    # Lift the single-path SDE to a product process over N_sim copies.
    (product, sde_product, chart_update_product) = product_sde(M,
                                                               M.sde_Brownian_coords,
                                                               M.chart_update_Brownian_coords)

    x0s = tile(x0, N_sim)

    _dts = dts(T=T, n_steps=n_steps)
    dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
    (ts,xss,chartss,*_) = product(x0s,
                                  _dts,dW,jnp.repeat(1.,N_sim))

    xs = xss[-1]
    # chart.csv stores M.F applied to each final (coords, chart) pair rather
    # than the raw chart (presumably the point in the embedding space — confirm).
    chart = vmap(lambda x,c: M.F((x,c)))(xs,chartss[-1])

    path = os.path.join(file_path, 'SPDN' + str(d))
    # np.savetxt does not create missing directories; make sure it exists.
    os.makedirs(path, exist_ok=True)
    np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
    np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
using M.Exp for Logarithm
using M.Exp for Logarithm
Done


## Sym

### Generate Data

In [13]:
# Geodesic-random-walk (GRW) simulation on symmetric d x d matrices (Sym);
# intended output is <file_path>Sym<d>/xs.csv and chart.csv.
# NOTE(review): this cell currently raises
#   ValueError: Incompatible shapes for broadcasting: (4,) and requested shape (3,)
# x0 is built from d*(d+1)//2 = 3 coordinates, while the noise uses
# M.emb_dim = 4 (printed dW shape (100, 1000, 4)); possibly the mismatch
# happens in M.ExpEmbedded. TODO reconcile before relying on Data/Sym<d>/.
dim = [2]
from jaxgeometry.stochastics import GRW
from jaxgeometry.integration import integrator_ito
for d in dim:
    M = Sym(N=d)    
    Brownian_coords(M)

    # NOTE(review): N_dim is not read anywhere in this cell.
    N_dim = M.emb_dim
    x0 = M.coords([10.]*(d*(d+1)//2))
    
    # Registers M.sde_grw / M.chart_update_grw with steps taken by M.ExpEmbedded.
    GRW(M, f_fun = lambda x,v: M.ExpEmbedded(x[0], v))
    # The custom integrator steps with M.ExpEmbedded per sample; the inner
    # lambda receives (x[0], x[1], v) but only uses x and v.
    (product,sde_product,chart_update_product) = product_sde(M, 
                                                             M.sde_grw, 
                                                             M.chart_update_grw,
                                                             lambda a,b: integrator_ito(a,b,lambda x,v: vmap(lambda x,y,v: M.ExpEmbedded(x,v))(x[0],x[1],v)))

    x0s = tile(x0, N_sim)
    
    _dts = dts(T=T, n_steps=n_steps)
    # Noise lives in the embedding dimension (emb_dim), not M.dim.
    dW = dWs(N_sim*M.emb_dim,_dts).reshape(-1,N_sim,M.emb_dim)
    print(dW.shape)
    # The two components of x0s are passed in swapped order
    # (NOTE(review): presumably (chart, coords) — confirm what GRW expects).
    (ts,xss,chartss,*_) = product((x0s[1], x0s[0]),
                                  _dts,
                                  dW,
                                  jnp.repeat(1.,N_sim))
    #(ts,xss,chartss) = M.product_GRW(x0s,_dts,dW)
    
    # Because of the swapped input order above, xs is taken from chartss and
    # chart from xss — the reverse of the other manifold cells.
    xs = chartss[-1]
    chart = xss[-1]
    #vmap(lambda x,chart: M.F((x,chart)))(xs,chartss[-1])
    
    # NOTE(review): unlike os.makedirs-based saving, np.savetxt requires this
    # directory to exist already.
    path = ''.join((file_path, 'Sym', str(d), '/'))
    np.savetxt(''.join((path, 'xs.csv')), xs, delimiter=",")
    np.savetxt(''.join((path, 'chart.csv')), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
(100, 1000, 4)
(4,)


ValueError: Incompatible shapes for broadcasting: (4,) and requested shape (3,)

In [7]:
# Display the second saved Sym sample as a 2 x 2 matrix. Relies on `chart`
# left in the kernel by the Sym cell, so it is only meaningful after that cell.
jnp.reshape(chart[1], (2, 2))

Array([[ 9.946781  , -0.503417  ],
       [-0.42504025, 10.025158  ]], dtype=float32)

## Hyperbolic Paraboloid

### Generate Data

In [17]:
# Brownian motion on the hyperbolic paraboloid; final-time states are written
# to <file_path>/HypParaboloid/.
M = HypParaboloid()
Brownian_coords(M)

x0 = M.coords([0.]*2)

# Lift the single-path SDE to a product process over N_sim copies.
(product, sde_product, chart_update_product) = product_sde(M,
                                                           M.sde_Brownian_coords,
                                                           M.chart_update_Brownian_coords)

x0s = tile(x0, N_sim)

_dts = dts(T=T, n_steps=n_steps)
dW = dWs(N_sim*M.dim,_dts).reshape(-1,N_sim,M.dim)
(ts,xss,chartss,*_) = product(x0s,
                              _dts,dW,jnp.repeat(1.,N_sim))

# Keep only the state at the final time step.
xs = xss[-1]
chart = chartss[-1]

path = os.path.join(file_path, 'HypParaboloid')
# np.savetxt does not create missing directories; make sure it exists.
os.makedirs(path, exist_ok=True)
np.savetxt(os.path.join(path, 'xs.csv'), xs, delimiter=",")
np.savetxt(os.path.join(path, 'chart.csv'), chart, delimiter=",")
print("Done")

using M.Exp for Logarithm
Done
