In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy
from scipy.spatial import cKDTree
from scipy.spatial import Delaunay
from scipy.optimize import basinhopping
import copy
from sklearn.cluster import MiniBatchKMeans
from basic_functions import *
import time

In [2]:
# REAL trajectory: reference run that the reference Markov matrix is built from.
# NOTE(review): presumably change_param(10000, "length") returns the parameter
# set with the trajectory length set to 10000 — confirm in basic_functions.
new_param = change_param(10000, "length")
ys_fixed = Euler_jnp(new_param)

# Number of Voronoi cells (k-means cluster centers).
subset_size = 500

# CELL CENTERS: k-means centers of the trajectory points (one point per row).
Xtrain_fixed = ys_fixed.transpose()
# n_init=3 pins the value the installed scikit-learn used by default (see the
# FutureWarning it printed), so results survive the upcoming default change.
# A fixed random_state makes the centers, and everything derived from them,
# reproducible across kernel restarts.
sample = MiniBatchKMeans(
    n_clusters=subset_size, n_init=3, random_state=0
).fit(Xtrain_fixed).cluster_centers_

  super()._check_params_vs_input(X, default_n_init=3)


In [3]:
#point indices with respect to voronoi cells

def idxs(ys):
    point_idxs = []
    Xtrain = ys.transpose()#coordinates of points
    
    for i in range(new_param["length"]):
        
        distances = np.linalg.norm(sample - Xtrain[i], axis = 1)
        #Euclidean distance from the ith point to each cell
        
        idxs = np.argmin(distances)
        point_idxs.append(idxs)
    
    point_idxs = jnp.array(point_idxs)
    
    return point_idxs

In [4]:
# Delaunay triangulation of the cell centers; its simplices are the
# cells used for barycentric interpolation.
Dtri = Delaunay(sample)

# vertices_cells[k] holds the indices (into `sample`) of simplex k's vertices.
vertices_cells = Dtri.simplices

# vertices[k] holds the coordinates of those vertices, one vertex per row.
vertices = np.take(sample, vertices_cells, axis=0)

In [6]:
def precompute_barycentric_matrix(simplex):
    """Inverse of the homogeneous vertex matrix of a 3-D simplex.

    From the four vertices, build
    ``T = [[1, 1, 1, 1], [x1..x4], [y1..y4], [z1..z4]]``. Then
    ``T @ lam = [1, x, y, z]`` holds for the barycentric coordinates ``lam``
    of a point ``(x, y, z)``, so the returned ``T^-1`` maps the homogeneous
    point ``[1, x, y, z]`` to ``lam``.

    Parameters
    ----------
    simplex : sequence of four (x, y, z) points

    Returns
    -------
    numpy.ndarray, shape (4, 4)

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``T`` is exactly singular (e.g. coplanar vertices).
    """
    first, second, third, fourth = simplex
    # zip transposes the four vertices into the x, y and z rows of T.
    x_row, y_row, z_row = zip(first, second, third, fourth)
    homogeneous = np.array(((1., 1., 1., 1.), x_row, y_row, z_row))
    return np.linalg.inv(homogeneous)

In [7]:
# Index of the Delaunay simplex containing each point of the fixed (reference)
# trajectory; find_simplex returns -1 for points outside the triangulation.
# A single vectorized call replaces one find_simplex call per point. Only the
# first new_param["length"] points are used, exactly as the old loop did.
simplex_idxs = jnp.array(
    Dtri.find_simplex(np.asarray(Xtrain_fixed)[: new_param["length"]])
).astype(int)

In [8]:

#I use numpy here for now so this runs fast. But we probably need to jax when taking gradient.
#input:trajectory
#output:jnp array of (weights, indices)
def car2bar_weight(ys):   
    weights_coords = []
    Xtrain = ys.transpose()
    for i in range(new_param["length"]):
        pt = Xtrain[i]
    #Cartesian coordinates
        x, y, z = pt
        vec = jnp.array((1,x,y,z))#vector for barycentric matrix computing
        simplex_index = simplex_idxs[i]
        
        #if not inside simplex
        if simplex_index == -1 :
            distance = jnp.linalg.norm(sample - pt, axis = 1)
            index = np.argmin(distance)
            weights_coords.append((jnp.array([1., 0., 0., 0.]), 
                                   jnp.array([index, 0, 0, 0])))
        
        #if inside or on simplex
        else:
            simplex_coords = vertices[simplex_index]
            mat = precompute_barycentric_matrix(simplex_coords)
            bar = jnp.dot(mat,vec)

            weights_coords.append((jnp.array((bar)), vertices_cells[simplex_index]))
    return jnp.array(weights_coords)

        


In [9]:
def idxs_tilda(ys):
    """Soft (barycentric) cell-membership matrix of a trajectory.

    Entry ``[c, i]`` is the total weight that point ``i`` of ``ys`` puts on
    cell center ``c``, according to ``car2bar_weight``. Each column receives
    at most four contributions.

    Parameters
    ----------
    ys : array
        Trajectory, passed straight to ``car2bar_weight``.

    Returns
    -------
    jnp.ndarray, shape (subset_size, new_param["length"])
    """
    n_points = new_param["length"]
    weights_coords = car2bar_weight(ys)[:n_points]
    weights = weights_coords[:, 0, :]            # (n_points, 4) weights
    cells = weights_coords[:, 1, :].astype(int)  # (n_points, 4) cell indices
    # Column i repeated for each of the four weights of point i.
    columns = jnp.broadcast_to(
        jnp.arange(weights.shape[0])[:, None], cells.shape)
    # One scatter-add replaces 4 * n_points sequential `.at[].add` calls.
    # Outside jit, each of those calls produced a fresh copy of the whole
    # (subset_size, n_points) array. Repeated (cell, column) pairs are summed,
    # as before.
    return jnp.zeros((subset_size, n_points)).at[cells, columns].add(weights)

In [10]:
# Soft cell-membership matrix of the reference trajectory.
point_idxs_tilda = idxs_tilda(ys=ys_fixed)

In [11]:
# Reference Markov matrix: entry (i, j) sums, over time steps k, the weight of
# point k on cell i times the weight of point k+1 on cell j. That is the
# product of point_idxs_tilda without its last column and the transpose of
# point_idxs_tilda without its first column.
transition_counts = jnp.dot(point_idxs_tilda[:, :-1], point_idxs_tilda[:, 1:].T)
# Row normalisation uses the full row sums, matching modified_markov.
row_sums = jnp.sum(point_idxs_tilda, axis=1)
# A cell that never receives weight would give 0/0 = NaN; keep its row at 0.
row_sums = jnp.where(row_sums == 0, 1., row_sums)
markov_matrix = transition_counts / row_sums[:, None]

In [12]:
def modified_markov(ys):
    """Markov transition matrix between cells for an arbitrary trajectory.

    Same construction as the reference ``markov_matrix``: soft transition
    counts between consecutive time steps, divided by each cell's total
    weight.

    Parameters
    ----------
    ys : array
        Trajectory, passed to ``idxs_tilda``.

    Returns
    -------
    jnp.ndarray, shape (subset_size, subset_size)
    """
    pt_idxs_tilda = idxs_tilda(ys)
    markov = jnp.dot(pt_idxs_tilda[:, :-1], pt_idxs_tilda[:, 1:].T)
    r_sums = jnp.sum(pt_idxs_tilda, axis=1)
    # JAX arrays are immutable: `.at[...].add` returns a new array, and the
    # old code discarded it. Replacing zero sums by 1 keeps empty rows at 0
    # instead of 0/0 = NaN, which would also poison the gradient.
    r_sums = jnp.where(r_sums == 0, 1., r_sums)
    return markov / r_sums[:, None]

In [13]:
def param2dist_rho(param):
    """Frobenius distance between the Markov matrix at ``rho=param`` and the reference.

    A trajectory is generated with ``Euler_jnp`` from
    ``change_param(param, "rho")``. This presumably replaces only the rho
    parameter — confirm in basic_functions. The trajectory is turned into a
    Markov matrix with ``modified_markov`` and compared with the global
    reference ``markov_matrix``.
    """
    trajectory = Euler_jnp(change_param(param, "rho"))
    distance = jnp.linalg.norm(modified_markov(trajectory) - markov_matrix)
    return distance

In [14]:
# markov_test = jnp.dot(pt_idxs_tilda_test[:, :-1], pt_idxs_tilda_test[:, 1:].T)

In [15]:
# ys_37 = Euler_jnp(change_param(37., "rho"))

In [16]:
# pt_idxs_tilda_test = idxs_tilda(ys_37, barycentric_matrices)

In [17]:
# r_sums_test = jnp.sum(pt_idxs_tilda_test, axis = 1)

In [18]:
# Distance to the reference Markov matrix at rho = 37.
param2dist_rho(param=37.0)

Array(551876.94, dtype=float32)

In [19]:
# d(distance)/d(rho), differentiating with respect to the first (only) argument.
gradient = jax.grad(param2dist_rho, argnums=0)

In [20]:
# NOTE(review): the gradients printed in these cells are of order 1e23-1e27 and
# change sign between nearby rho values. This is presumably extreme sensitivity
# of the long Euler trajectory to rho; treat them as unreliable for
# gradient-based fitting — TODO confirm.
gradient(37.0)


Array(-3.348636e+27, dtype=float32, weak_type=True)

In [21]:
# Gradient of the distance at rho = 41.
gradient(41.0)

Array(5.4829443e+27, dtype=float32, weak_type=True)

In [22]:
gradient(36.)

Array(4.6359838e+24, dtype=float32, weak_type=True)

In [24]:
gradient(38.)

Array(-1.2811433e+23, dtype=float32, weak_type=True)

In [25]:
gradient(39.)

Array(-2.8085505e+26, dtype=float32, weak_type=True)