In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy
from scipy.spatial import cKDTree
from scipy.spatial import Delaunay
from scipy.optimize import basinhopping
import copy
from sklearn.cluster import MiniBatchKMeans

In [2]:
# Reference ("real") parameters of the Lorenz-63 run.
# `G` is an independent deep copy used as the working set that `change_param`
# edits, so changes to `G` leave `G_fixed` as it was.
G_fixed = {
    "thresh": 2e-5,
    "dt": 0.005,        # Euler integration time step (i.e. the sampling interval)
    "sigma": 10,        # Lorenz-63 sigma coefficient
    "beta": 1.5,        # Lorenz-63 beta coefficient
    "rho": 40,          # Lorenz-63 rho coefficient
    "alpha": 1e-4,
    "length": 100_000,  # trajectory length = number of Euler steps (1e6 and 1e4 were also tried)
}
G = copy.deepcopy(G_fixed)
#new params
def change_param(new, param = "rho"):
    """Set ``G[param] = new`` on the module-level working dict ``G``.

    The dict is modified in place and the same object is returned (not a copy),
    so every holder of ``G`` sees the change.
    """
    G.update({param: new})
    return G

# Define the Lorenz-63 System
@jax.jit
def Lorenz(y, G):
    """Lorenz-63 vector field f(y), so that dy/dt = f(y).

    ``y`` holds the three state components along its first axis; ``G`` supplies
    the coefficients ``"sigma"``, ``"rho"`` and ``"beta"``. The three derivatives
    are returned stacked and flattened into one array.
    """
    u, v, w = y[0], y[1], y[2]
    derivatives = [
        G["sigma"] * (v - u),      # d/dt of component 0
        u * (G["rho"] - w) - v,    # d/dt of component 1
        u * v - G["beta"] * w,     # d/dt of component 2
    ]
    return jnp.stack(derivatives).ravel()

# Initial state and run-wide settings.
start = jnp.ones(3)   # initial state y(0) = (1, 1, 1)
nt = 1                # transition step size
subset_size = 500     # number of cells (k-means cluster centres)


def Euler(G, y0=None):
    """Integrate the Lorenz-63 system with the explicit (forward) Euler method.

    Parameters
    ----------
    G : dict
        Parameter dict. ``G["dt"]`` is the step size, ``G["length"]`` the number
        of steps, and the whole dict is forwarded to ``Lorenz``.
    y0 : array-like, optional
        Initial state. Defaults to the notebook-level ``start``.

    Returns
    -------
    jax.Array
        The trajectory with one column per step. Column ``k`` is the state after
        ``k + 1`` steps; the initial state itself is not included. A zero-length
        run gives an array with zero columns.
    """
    y = start if y0 is None else jnp.asarray(y0)
    ys = []
    for _ in range(G["length"]):
        y = y + G["dt"] * Lorenz(y, G)
        ys.append(y)
    if not ys:
        # Keep the (state, time) layout rather than returning a bare (0,) array.
        return jnp.zeros((jnp.size(y), 0), dtype=y.dtype)
    return jnp.array(ys).transpose()



In [3]:
# Reference ("real") trajectory, integrated with the unmodified parameters.
ys_fixed = Euler(G=G_fixed)



In [4]:
# Cell centres: k-means cluster centres of the reference trajectory's points.
# MiniBatchKMeans is stochastic; a fixed seed makes the centres (and everything
# built on them: Delaunay simplices, cell indices) reproducible on Restart & Run All.
KMEANS_SEED = 0
Xtrain_fixed = ys_fixed.transpose()  # one row per trajectory point
sample = MiniBatchKMeans(n_clusters=subset_size, random_state=KMEANS_SEED).fit(Xtrain_fixed).cluster_centers_

In [5]:
#point indices with respect to voronoi cells

def idxs(ys, centers=None):
    """Index of the Voronoi cell (nearest cell centre) of every trajectory point.

    Parameters
    ----------
    ys : array-like of shape (3, n_points)
        Trajectory with one column per point (the layout ``Euler`` returns).
    centers : array-like of shape (n_cells, 3), optional
        Cell centres; defaults to the notebook global ``sample``.

    Returns
    -------
    jax.Array of shape (n_points,)
        For each point, the integer index into ``centers`` of the Euclidean-nearest centre.
    """
    if centers is None:
        centers = sample
    points = np.asarray(ys, dtype=np.float64).T  # one row per point
    # Use the number of points actually in `ys`, not G["length"]: the two differ
    # whenever a trajectory of a different length is passed in.
    if points.shape[0] == 0:
        nearest = np.zeros(0, dtype=np.intp)
    else:
        # k-d tree nearest-neighbour query instead of a per-point Python loop.
        _, nearest = cKDTree(np.asarray(centers, dtype=np.float64)).query(points)
    return jnp.array(nearest)

In [6]:
# Delaunay tetrahedralisation of the cell centres: each row of `vertices_cells`
# lists the indices (into `sample`) of the centres spanning one simplex.
vertices_cells = Delaunay(points=sample).simplices

In [7]:
# Coordinates of every simplex's vertices, indexed [simplex, vertex, coordinate].
vertices = np.take(sample, vertices_cells, axis=0)

In [8]:

# NOTE: implemented in NumPy for speed for now; we will probably need a JAX
# version when taking gradients.

# Cache for `_simplex_inverses`: the inverse barycentric matrices depend only on
# the triangulation, so they are computed once per vertex array, not per point.
_BARY_CACHE = {"key": None, "inv": None}


def _simplex_inverses(tri_vertices):
    """Return the stacked inverses of every simplex's 4x4 barycentric matrix.

    ``tri_vertices`` is indexed ``[simplex, vertex, coordinate]``. For each simplex
    the matrix is a row of ones followed by the x, y and z coordinates of its four
    vertices, so ``inverse @ (1, x, y, z)`` gives the barycentric weights of
    (x, y, z). Singular (flat) simplices get an all-NaN inverse so they can never
    be selected. The result is cached by the identity of ``tri_vertices``; build a
    new array (rather than editing it in place) when the triangulation changes.
    """
    if _BARY_CACHE["key"] is tri_vertices:
        return _BARY_CACHE["inv"]
    verts = np.asarray(tri_vertices, dtype=np.float64)
    T = np.ones((verts.shape[0], 4, 4))
    T[:, 1:, :] = verts.transpose(0, 2, 1)  # row 1+c, column j = coordinate c of vertex j
    inverses = np.empty_like(T)
    for k in range(T.shape[0]):
        try:
            inverses[k] = np.linalg.inv(T[k])
        except np.linalg.LinAlgError:
            # Degenerate simplex: NaN weights fail every `>= 0` test below.
            inverses[k] = np.nan
    _BARY_CACHE["key"] = tri_vertices
    _BARY_CACHE["inv"] = inverses
    return inverses


def car2bar_weight(pt, tri_vertices=None, tri_cells=None, centers=None):
    """Barycentric weights of a point with respect to the Delaunay simplices.

    Parameters
    ----------
    pt : array-like of length 3
        Cartesian coordinates (x, y, z) of the point.
    tri_vertices, tri_cells, centers : optional
        Per-simplex vertex coordinates, per-simplex vertex indices, and the cell
        centres. Default to the notebook globals ``vertices``, ``vertices_cells``
        and ``sample``.

    Returns
    -------
    (weights, indices)
        If the point lies inside or on a simplex (all four barycentric weights
        >= 0), the weights and that simplex's vertex indices, taken from the first
        such simplex in order. Otherwise ``(array([1.]), array([k]))`` with ``k``
        the index of the nearest cell centre.
    """
    if tri_vertices is None:
        tri_vertices = vertices
    if tri_cells is None:
        tri_cells = vertices_cells
    if centers is None:
        centers = sample

    p = np.asarray(pt, dtype=np.float64)
    vec = np.array((1.0, p[0], p[1], p[2]))  # homogeneous coordinates (1, x, y, z)

    if len(tri_vertices) > 0:
        bars = _simplex_inverses(tri_vertices) @ vec  # one row of 4 weights per simplex
        inside = np.all(bars >= 0.0, axis=1)  # inside or on the simplex
        if inside.any():
            first = int(np.argmax(inside))  # the simplex a linear scan would stop at
            return bars[first], tri_cells[first]

    # No simplex contains the point (e.g. outside the triangulation):
    # fall back to the nearest cell centre with weight 1.
    distance = np.linalg.norm(np.asarray(centers) - p, axis=1)
    index = np.argmin(distance)
    return np.array([1.]), np.array([index])
        


In [9]:
#generate idxs_tilda
def idxs_tilda(ys):
    """Soft (barycentric) assignment of every trajectory point to the cell centres.

    Parameters
    ----------
    ys : array-like of shape (3, n_points)
        Trajectory with one column per point.

    Returns
    -------
    np.ndarray of shape (subset_size, n_points)
        Column ``i`` holds the weights ``car2bar_weight`` returns for point ``i``,
        placed at the rows given by its returned indices; all other entries are 0.
    """
    points = np.asarray(ys).transpose()  # one row per point
    # Size by the trajectory actually passed in, not G["length"].
    n_points = points.shape[0]
    point_idxs_tilda = np.zeros((subset_size, n_points))
    for i in range(n_points):
        # One call per point: car2bar_weight is expensive and returns both parts.
        w, ind = car2bar_weight(points[i])
        np.add.at(point_idxs_tilda, (ind, i), w)
    return point_idxs_tilda