In [27]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy
from scipy.spatial import cKDTree
from scipy.spatial import Delaunay
from scipy.optimize import basinhopping
import copy
from sklearn.cluster import MiniBatchKMeans
from basic_functions import *
import time

In [28]:
#REAL trajectory
new_param = change_param(10000, "length")
ys_fixed = Euler_jnp(new_param)

subset_size = 500
#CELL CENTER
Xtrain_fixed = ys_fixed.transpose()
sample = MiniBatchKMeans(n_clusters=subset_size).fit(Xtrain_fixed).cluster_centers_



In [29]:
#point indices with respect to voronoi cells

def idxs(ys):
    point_idxs = []
    Xtrain = ys.transpose()#coordinates of points
    
    for i in range(new_param["length"]):
        
        distances = np.linalg.norm(sample - Xtrain[i], axis = 1)
        #Euclidean distance from the ith point to each cell
        
        idxs = np.argmin(distances)
        point_idxs.append(idxs)
    
    point_idxs = jnp.array(point_idxs)
    
    return point_idxs

In [39]:
#cell center of barycentric simplexes
vertices_cells = Delaunay(sample).simplices

#vertex coordinates of barycentric simplexes
vertices = sample[vertices_cells]

In [37]:
def precompute_barycentric_matrices(vertices):
    matrices = []
    for simplex in vertices:
        #ri: vertices of a Barycentric simplex
        r1, r2, r3, r4 = simplex

        x1, y1, z1 = r1
        x2, y2, z2 = r2
        x3, y3, z3 = r3
        x4, y4, z4 = r4
    
        T = np.array(([1., 1., 1., 1.],
                      [x1, x2, x3, x4],
                      [y1, y2, y3, y4],
                      [z1, z2, z3, z4]))
        T_inverse = np.linalg.inv(T)
        matrices.append(T_inverse)
    return matrices

In [43]:

#I use numpy here for now so this runs fast. But we probably need to jax when taking gradient.
#input:point coordinates
#output:(weights, indices)
def car2bar_weight(pt, barycentric_matrices):
                                 
    #Cartesian coordinates
    x, y, z = pt
    vec = np.array((1,x,y,z))#vector for barycentric matrix computing
    for i, mat in enumerate(barycentric_matrices):
        #ri: vertices of a Barycentric simplex
        bar = mat.dot(vec)

        if np.all(bar>= 0.):#to make sure the point is inside or on the simplex
            return bar, vertices_cells[i]
        
    distance = np.linalg.norm(sample - pt, axis = 1)
    index = np.argmin(distance)
    return np.array([1.]), np.array([index])
        


In [46]:
#generate idxs_tilda
def idxs_tilda(ys, barycentric_matrices):
    start_time = time.time()

    point_idxs_tilda = np.zeros((subset_size, new_param["length"]))
    
    Xtrain = ys.transpose()
    
    for i in range(new_param["length"]):
        w, ind = car2bar_weight(Xtrain[i], barycentric_matrices)
        for num in range(len(w)):
            point_idxs_tilda[ind[num], i] += w[num]
        if time.time() - start_time >= 10:
                    print(f"i = {i}")
                    start_time = time.time()
            
    return point_idxs_tilda

In [48]:
barycentric_matrices = precompute_barycentric_matrices(vertices)
point_idxs_tilda = idxs_tilda(ys_fixed, barycentric_matrices)

i = 1567
i = 3100
i = 4774
i = 6373
i = 8008
i = 9622


In [49]:
#Markov matrix constructed by making the (i, j) entry the sum of the products of the kth entry in the ith row and the (k+1)th entry in the jth row of point_idxs_tilda
#Dot product of point_idxs_tilda minus the last row and point_idxs_tilda minus the first row should be the same computation
markov_matrix = np.dot(point_idxs_tilda[:, :-1], point_idxs_tilda[:, 1:].T)
row_sums = np.sum(point_idxs_tilda, axis = 1)
markov_matrix = markov_matrix / row_sums[:, None]

  markov_matrix = markov_matrix / row_sums[:, None]
