In [2]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy
from scipy.spatial import cKDTree
from scipy.spatial import Delaunay
from scipy.optimize import basinhopping
import copy
from sklearn.cluster import MiniBatchKMeans
from basic_functions import *
import time



In [3]:
#REAL trajectory
new_param = change_param(1000, "length")
ys_fixed = Euler_jnp(new_param)

subset_size = 20
#CELL CENTER
Xtrain_fixed = ys_fixed.transpose()
sample = MiniBatchKMeans(n_clusters=subset_size).fit(Xtrain_fixed).cluster_centers_

In [4]:


Dtri = Delaunay(sample)
vertices_cells = Dtri.simplices #indices of cell centers corresponding to barycentric vectices

#vertex coordinates of barycentric simplexes
vertices = sample[vertices_cells]

In [6]:
#input:simplex coordinates; output: barycentrix matrix
def precompute_barycentric_matrix(simplex):

    r1, r2, r3, r4 = simplex

    x1, y1, z1 = r1
    x2, y2, z2 = r2
    x3, y3, z3 = r3
    x4, y4, z4 = r4
    
    T = np.array(([1., 1., 1., 1.],
                    [x1, x2, x3, x4],
                    [y1, y2, y3, y4],
                    [z1, z2, z3, z4]))
    T_inverse = np.linalg.inv(T)

    return T_inverse

In [7]:

#which simplex a point in the real trajectory is in
simplex_idxs_fixed = np.zeros(new_param["length"])
for i in range(new_param["length"]):
    simplex_idxs_fixed[i] += Delaunay.find_simplex(Dtri, Xtrain_fixed[i])
    
simplex_idxs_fixed = jnp.array(simplex_idxs_fixed).astype(int)

In [8]:
#rho == 44.
new_param_44 = copy.deepcopy(new_param)
new_param_44["rho"] = 44.
#rho44 trajectory
ys_44 = Euler_jnp(new_param_44)
Xtrain_44 = ys_44.transpose()

#which simplex a point in the real trajectory is in
simplex_idxs_44 = np.zeros(new_param["length"])
for i in range(new_param["length"]):
    simplex_idxs_44[i] += Delaunay.find_simplex(Dtri, Xtrain_44[i])
    
simplex_idxs_44 = jnp.array(simplex_idxs_44).astype(int)

In [9]:
#rho == 41.
new_param_41 = copy.deepcopy(new_param)
new_param_41["rho"] = 41.
#rho41 trajectory
ys_41 = Euler_jnp(new_param_41)
Xtrain_41 = ys_41.transpose()

#which simplex a point in the real trajectory is in
simplex_idxs_41 = np.zeros(new_param["length"])
for i in range(new_param["length"]):
    simplex_idxs_41[i] += Delaunay.find_simplex(Dtri, Xtrain_41[i])
    
simplex_idxs_41 = jnp.array(simplex_idxs_41).astype(int)

In [10]:
#rho == 37.
new_param_37 = copy.deepcopy(new_param)
new_param_37["rho"] = 37.
#rho37 trajectory
ys_37 = Euler_jnp(new_param_37)
Xtrain_37 = ys_37.transpose()

#which simplex a point in the real trajectory is in
simplex_idxs_37 = np.zeros(new_param["length"])
for i in range(new_param["length"]):
    simplex_idxs_37[i] += Delaunay.find_simplex(Dtri, Xtrain_37[i])
    
simplex_idxs_37 = jnp.array(simplex_idxs_37).astype(int)

In [11]:

#input:trajectory
#output:jnp array of (weights, indices)
def car2bar_weight(ys, simplex_idxs = simplex_idxs_fixed):   
    weights_coords = []
    Xtrain = ys.transpose()
    for i in range(new_param["length"]):
        pt = Xtrain[i]
    #Cartesian coordinates
        x, y, z = pt
        vec = jnp.array((1,x,y,z))#vector for barycentric matrix computing
        simplex_index = simplex_idxs[i]
        
        #if not inside simplex
        if simplex_index == -1 :
            distance = jnp.linalg.norm(sample - pt, axis = 1)
            index = np.argmin(distance)
            weights_coords.append((jnp.array([1., 0., 0., 0.]), 
                                   jnp.array([index, 0, 0, 0])))
        
        #if inside or on simplex
        else:
            simplex_coords = vertices[simplex_index]
            mat = precompute_barycentric_matrix(simplex_coords)
            bar = jnp.dot(mat,vec)

            weights_coords.append((jnp.array((bar)), vertices_cells[simplex_index]))
    return jnp.array(weights_coords)

        


In [12]:
#generate idxs_tilda

def idxs_tilda(ys, simplex_idxs = simplex_idxs_fixed):
    weights_coords = car2bar_weight(ys, simplex_idxs)

    w = jnp.array([w for w, _ in weights_coords])
    ind = jnp.array([ind.astype(int) for _, ind in weights_coords])

    point_idxs_tilda = jnp.zeros((subset_size, new_param["length"]))
    for i in range(new_param["length"]):
        point_idxs_tilda = point_idxs_tilda.at[ind[i], i].add(w[i])
            
    return point_idxs_tilda
# def idxs_tilda(ys, simplex_idxs = simplex_idxs_fixed):
#     # start_time = time.time()

#     point_idxs_tilda = jnp.zeros((subset_size, new_param["length"]))
#     weights_coords = car2bar_weight(ys, simplex_idxs)
    
#     for i in range(new_param["length"]):
#         w, ind = weights_coords[i]
#         ind = ind.astype(int)
#         #for num in range(len(w)):
#         point_idxs_tilda = point_idxs_tilda.at[ind[0], i].add(w[0])
#         point_idxs_tilda = point_idxs_tilda.at[ind[1], i].add(w[1])
#         point_idxs_tilda = point_idxs_tilda.at[ind[2], i].add(w[2])
#         point_idxs_tilda = point_idxs_tilda.at[ind[3], i].add(w[3])
#         # if time.time() - start_time >= 10:
#         #             print(f"i = {i}")
#         #             start_time = time.time()
            
#     return point_idxs_tilda

In [15]:
#indices tilda of ground truth
point_idxs_tilda_fixed = idxs_tilda(ys_fixed)

In [16]:
#Markov matrix constructed by making the (i, j) entry the sum of the products of the kth entry in the ith row and the (k+1)th entry in the jth row of point_idxs_tilda
#Dot product of point_idxs_tilda minus the last row and point_idxs_tilda minus the first row should be the same computation
#markov matrix of ground truth
markov_matrix_fixed = np.dot(point_idxs_tilda_fixed[:, :-1], point_idxs_tilda_fixed[:, 1:].T)
row_sums_fixed = np.sum(point_idxs_tilda_fixed, axis = 1)
markov_matrix_fixed = markov_matrix_fixed / row_sums_fixed[:, None]

In [17]:
#construct markov matrix based on Delaunay
def modified_markov(ys, simplex_idxs = simplex_idxs_fixed):
    pt_idxs_tilda = idxs_tilda(ys, simplex_idxs)
    markov = jnp.dot(pt_idxs_tilda[:, :-1], pt_idxs_tilda[:, 1:].T)
    r_sums = jnp.sum(pt_idxs_tilda, axis = 1)
    r_sums[:,None].at[r_sums[:,None] == 0].add(1.)
    return markov / r_sums[:, None]

In [19]:
# #objective function with fixed simplices
# def param2dist_rho(param):
#     new_param = change_param(param, "rho")
#     ys = Euler_jnp(new_param)
#     return jnp.linalg.norm(modified_markov(ys) - markov_matrix_fixed)

In [20]:
#objective function for rho == 44
def param2dist_rho44(param):
    new_param = change_param(param, "rho")
    ys = Euler_jnp(new_param)
    return jnp.linalg.norm(modified_markov(ys, simplex_idxs_44) - markov_matrix_fixed)

In [21]:
#objective function for rho == 41
def param2dist_rho41(param):
    new_param = change_param(param, "rho")
    ys = Euler_jnp(new_param)
    return jnp.linalg.norm(modified_markov(ys, simplex_idxs_41) - markov_matrix_fixed)

In [22]:
#objective function for rho == 37
def param2dist_rho37(param):
    new_param = change_param(param, "rho")
    ys = Euler_jnp(new_param)
    return jnp.linalg.norm(modified_markov(ys, simplex_idxs_37) - markov_matrix_fixed)

In [23]:
gradient_fixed = jax.grad(param2dist_rho)

In [24]:
gradient_44 = jax.grad(param2dist_rho44)

In [25]:
gradient_41 = jax.grad(param2dist_rho41)

In [26]:
gradient_37 = jax.grad(param2dist_rho37)

In [None]:
gradient_44(44.)

In [24]:
gradient_41(41.)

Array(0.99724495, dtype=float32, weak_type=True)

In [39]:
gradient_37(37.)

Array(-0.01504948, dtype=float32, weak_type=True)

In [43]:
def finite_diff_approx(fun, new_params, delta_theta = 1e-5):
    theta_plus = new_params.copy()
    theta_plus += delta_theta

    theta_minus = new_params.copy()
    theta_minus -= delta_theta

    gradient = np.zeros(len(new_params))
    for i in range(len(new_params)):
        diff = fun(theta_plus[i]) - fun(theta_minus[i])
        gradient[i] = diff / (2 * delta_theta)

    return gradient

In [105]:
#finite_diff_approx(param2dist_rho, np.array([40.001]), delta_theta = 5e-6)

array([4.42429447])

In [56]:
finite_diff_approx(param2dist_rho44, np.array([44.]), delta_theta = 5e-5)

array([0.60558319])

In [54]:
finite_diff_approx(param2dist_rho41, np.array([41.]), delta_theta = 5e-5)

array([1.33514404])

In [57]:
finite_diff_approx(param2dist_rho37, np.array([37.]), delta_theta = 5e-5)

array([-0.00476837])