In [None]:
#Loading Modules
import numpy as np
from sklearn.gaussian_process.kernels import Matern, RBF
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
from scipy.linalg import expm, cholesky
from numpy.linalg import inv, det
import scipy
import ot
import itertools
import os
from scipy.stats import multivariate_normal
from untils import *
from alg_tools import *
from evaluation_metric import *
from scipy.linalg import block_diag
from scipy.special import logsumexp as lse

from matplotlib.collections import LineCollection
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D
from matplotlib.pyplot import cm
import timeit
# Seed NumPy's legacy global RNG so the np.random calls made later in the notebook are repeatable.
np.random.seed(42)
# Names of the available datasets; the selected one is loaded from 'G_data/<name>.npz'.
G_data_list = ['Tris_data','Nbody_data','1dGP_data','2dGP_data']
# Position in G_data_list; the same index also selects the per-experiment defaults in DFPL.
exp_index = 1 #change this index to reproduce results on different datasets

# Parameter Setting (Tunable)

In [None]:

# Per-experiment defaults; every list is indexed by exp_index (same order as G_data_list).
DFPL = {'Mi_1.5':[[32],[12],[200],[32]], #Default Mi list for each experiment when nu = 1.5
        'Mi_2.5':[[8,4],[4,3],[40,5],[8,4]], #Default Mi list for each experiment when nu = 2.5
        'c':[1,4,1,1], #Default c for each experiment: scales the data standard deviation; read for either value of nu
       }
max_iter = 5 #number of iterations of message passing
t_end = 2 #end time (observation times are spread evenly over [0, t_end])
nu = 1.5 #order of Matern kernel: 1.5 or 2.5 (any other value raises in the precomputation setup)
l = 3 #length parameter for Matern kernel; default: 3 for nu=1.5, 2 for nu=2.5
samples_pre_compute = 20 #number of samples used for approximation of the precomputation tensor \Gamma

#-------------------------parameters--for--sampling--trajectories--------------------------------#
n_trajectories = 50 #number of trajectories to sample for each starting observation
load_newest = False #True: load the newest message ('message_sample_new.npz'); False: load the message with the smallest norm changes ('message_sample_best.npz')



# Load Dataset

In [None]:
# Select the dataset to reproduce results 0-3
G_data = G_data_list[exp_index]

# The archive holds one unnamed array (key 'arr_0'); the context manager closes the file handle.
with np.load('G_data/{}.npz'.format(G_data)) as archive:
    GD = archive['arr_0']

# A 2-D array is promoted to 3-D by adding a trailing axis of length one.
if len(GD.shape) == 2:
    GD = np.expand_dims(GD,axis=-1)
# GD is indexed as (time step, observation, dimension); T is one less than the number of time steps.
T,N,data_D = GD.shape[0]-1,GD.shape[1], GD.shape[2]
# NOTE(review): this prints T-1 (= GD.shape[0]-2) under the label 'Timestep' -- confirm that is the intended figure.
print('Timestep:{},Number of observations:{},dimension:{}'.format(T-1,N,data_D))
#Parameter Setting
n=N

# lo switches on a variant that removes one time step (index lot-1) from both the time grid and the data.
lo = False
lot = 3
time_grid = np.linspace(0,t_end,T+1)
if not lo:
    dt_list = time_grid
    D = GD+0  # +0 makes a copy so D never aliases GD
else:
    dt_list = np.concatenate([time_grid[0:lot-1],time_grid[lot:]],axis=0)
    D = np.concatenate([GD[0:lot-1],GD[lot:]],axis=0)
    T -= 1

color_ob = cm.hot(np.linspace(0,1,T+2))

if data_D >= 2:
    # One scatter per time step, coloured by time, using the first two dimensions.
    for i in range(T+1):
        plt.scatter(D[i,:,0],D[i,:,1],color=color_ob[i],label=str(i))
    plt.xlabel('Dim 1')
    plt.ylabel('Dim 2')
else:
    # One line (plus markers) per observation index against the time-step index.
    time_axis = list(range(T+1))
    for i in range(N):
        plt.plot(time_axis,D[:,i,0],label=str(i))
        plt.scatter(time_axis,D[:,i,0],label=str(i))
    plt.xlabel('t')
    plt.ylabel('Position X')
print('Data loaded successfully! Visualization of ground truth:')



# Precomputation Stage (no need to tune)

In [None]:
#Setting parameters
save_name = G_data[:-5]  # drop the trailing '_data' from the dataset name
c = DFPL['c'][exp_index]
save_dir = 'Experiments/{}_c={}_nu={}'.format(save_name,c,nu)
# makedirs (unlike mkdir) also creates the parent 'Experiments' folder when it is missing.
os.makedirs(save_dir, exist_ok=True)

# The two supported kernel orders differ only in the state dimension d and the default Mi list.
if nu == 1.5:
    Mi_list = DFPL['Mi_1.5'][exp_index]
    d = 2
elif nu == 2.5:
    Mi_list = DFPL['Mi_2.5'][exp_index]
    d = 3
else:
    raise Exception("Set nu to be 1.5 or 2.5!")

# Every data dimension shares the same Mi list.
total_indices = [Mi_list for dim in range(data_D)]
lam = np.sqrt(2*nu)/l

# A: ones on the superdiagonal; the last row holds the negated coefficients of (s + lam)^d.
A = np.zeros((d,d))
for i in range(d-1):
    A[i,i+1] = 1
if d == 2:
    A[d-1,:] = [-lam**2, -2*lam]
else:
    A[d-1,:] = [-lam**3, -3*lam**2, -3*lam]

# L selects the last state component.
L = np.zeros((d,1))
L[-1,0] = 1

q_list = []
mu_0_list = []
sigma_0_list = []
for dim in range(data_D):
    # Magnitude: standard deviation of this data dimension scaled by c.
    sigma = np.sqrt(np.var(D[:,:,dim]))*c
    q_list.append(2*(sigma**2)*np.sqrt(np.pi)*scipy.special.gamma(nu+0.5)*lam**(2*nu)/scipy.special.gamma(nu))
    # Initial state: zero mean; covariance from the project helper solve_stationary
    # (presumably the stationary covariance for A and q -- confirm in its definition).
    mu_0 = np.zeros(d)
    sigma_0  = solve_stationary(A,q_list[dim]).reshape((d,d))
    mu_0_list.append(mu_0)
    sigma_0_list.append(sigma_0)

In [None]:
start = timeit.default_timer()
cond_mean_list = []
cond_cov_list = [[] for dim in range(data_D)]
cov_list = [[] for dim in range(data_D)]
cond_cov_margin_list = [[] for dim in range(data_D)]
cond_mean_margin = np.zeros((data_D,T,N,d-1))
# Index permutation applied to rows and columns of each covariance (project helper);
# it depends only on d, so it is computed once.
perm = index_perm(d)
for dim in range(data_D):
    COV_X_list = []
    COV_XY_list = []
    for i in range(T):
        # Covariance returned by the project helper for the interval [dt_list[i], dt_list[i+1]].
        Cov = compute_cov_explicit(A, L, q_list[dim], dt_list[i], dt_list[i+1],sigma_0_list[dim])
        cov_list[dim].append(Cov)
        s_Cov = Cov[perm, :][:, perm]
        # After the permutation the first two entries form the "X" block and the rest the "Y" block.
        COV_X = s_Cov[0:2, 0:2]
        COV_Y = s_Cov[2:2*d, 2:2*d]
        COV_XY = s_Cov[0:2, 2:2*d]
        # Covariance of the Y block conditioned on the X block (Schur complement).
        cond_COV = COV_Y - COV_XY.T @ np.linalg.inv(COV_X) @ COV_XY
        cond_cov_list[dim].append(cond_COV)
        COV_X_list.append(COV_X+0)
        COV_XY_list.append(COV_XY+0)
        # Covariance of entries 1..d-1 of Cov conditioned on entry 0.
        cond_cov_margin = Cov[1:d,1:d] - Cov[1:d,0:1] @ Cov[0:1,1:d] /Cov[0,0]
        cond_cov_margin_list[dim].append(cond_cov_margin)
        # Row j is (-Cov[1:d,0]/Cov[0,0]) * D[i,j,dim], for all observations j at once.
        # NOTE(review): the leading minus sign is kept from the original code; the usual Gaussian
        # conditional mean would use +Cov[1:d,0]/Cov[0,0] -- confirm against how wave_pdf uses it.
        cond_mean_margin[dim,i,:,:] = np.outer(D[i,:,dim], -Cov[1:d,0]/Cov[0,0])
    cond_Means = compute_conditional_mean(np.stack(COV_X_list,axis=0), np.stack(COV_XY_list,axis=0), D[:,:,dim], d)
    cond_mean_list.append(cond_Means)

print('Precomputation for Gamma tensor start!')
phi_phi_pdf_list = []
Kp = [np.prod(total_indices[dim]) for dim in range(data_D)]
for dim in range(data_D):
    print('Computing for dimension {}'.format(dim))
    phi_save_dir = save_dir+'/joint_phi_phi_{}'.format(dim)
    phi_phi_pdf = wave_pdf(T,N,total_indices[dim],d,cov_list[dim],cond_cov_list[dim],cond_mean_list[dim],cond_cov_margin_list[dim],cond_mean_margin[dim],phi_save_dir,samples_pre_compute)
    phi_phi_pdf_list.append(phi_phi_pdf)


final_phi_phi_list = []
for dim in range(data_D):
    final_phi_path = save_dir+'/final_phi_{}.npz'.format(dim)
    # Build the path first so the log line shows the real file name rather than a literal '{}'.
    print('Saving precomputation tensor for dimension {} to {}'.format(dim,final_phi_path))
    Cond_xx = condition_xx(D[:,:,dim],cov_list[dim],d)
    # The (T,N,N) term from condition_xx is broadcast over the two trailing axes of the wave_pdf tensor.
    final_phi_phi_list.append(phi_phi_pdf_list[dim] + Cond_xx.reshape((T,N,N,1,1)))
    np.savez(final_phi_path,final_phi_phi_list[dim])

print('Precomputation done! Time spent: {} seconds'.format(timeit.default_timer()-start))

# Message Passing Algorithm

In [None]:
#Initialization
phi_phi_list = []
print('Loading precompuated tensor')
start = timeit.default_timer()
for i in range(data_D):
    # Each archive holds one unnamed array (key 'arr_0'); close the file once it is read.
    with np.load((save_dir+'/final_phi_{}.npz'.format(i))) as archive:
        phi_phi_list.append(archive['arr_0'])
print('Loading done!')
save_freq = 1
error_thre = 1e-10
errors = message_passing(phi_phi_list,data_D,T,N,total_indices,max_iter,save_dir,save_freq,error_thre)
print('Message passing done! Time spent: {} seconds'.format(timeit.default_timer()-start))

# Size the x axis from the errors actually returned: if message_passing stops before max_iter
# (presumably possible via error_thre), a fixed range(max_iter) axis would raise a shape mismatch.
plt.plot(np.arange(1,len(errors)+1),errors)
plt.title('Norm changes in messages against iterations')
plt.yscale('log')

# Generating Trajectories
The following only matches observations; matched observations are connected by straight lines.

In [None]:
def _joint_log_density(C_z_right,C_z_left,phi_phi_list,data_D,N,Kp_shape,t):
    """Build the (N, N, K, K) tensor for transition ``t``, where K is the product of ``Kp_shape``.

    The per-dimension tensors ``phi_phi_list[dim][t]`` are summed with their two trailing
    axes combined across dimensions, then the two messages for step ``t`` are added:
    the right message is broadcast along axes 0 and 2, the left message along axes 1 and 3.
    """
    right_t = C_z_right[t].reshape(N,-1)
    left_t = C_z_left[t].reshape(N,-1)
    messages_t = right_t.reshape((N,1,-1,1)) + left_t.reshape((1,N,1,-1))

    combined_t = phi_phi_list[0][t]
    Kpp = Kp_shape[0]
    for dim in range(1,data_D):
        Kpp *= Kp_shape[dim]
        # Outer sum over the trailing index pairs, then flatten each pair into one axis.
        combined_t = (np.expand_dims(combined_t,axis=(-3,-1))
                      + np.expand_dims(phi_phi_list[dim][t],axis=(-4,-2))).reshape(N,N,Kpp,Kpp)
    return combined_t + messages_t


def generate_trajectories_sample(C_z_right,C_z_left,phi_phi_list,data_D,T,N,Kp_shape,sample,save_traj_dir):
    """Sample observation-index trajectories step by step and save them to disk.

    ``sample`` trajectories are started from each of the ``N`` observations at time 0.
    At every step the next observation index and the next latent index are drawn with
    ``np.random.choice`` from normalised slices of the joint tensor, whose values are
    treated as log-weights (they are normalised with logsumexp and exponentiated).

    Parameters
    ----------
    C_z_right, C_z_left : indexable by time step; entry ``t`` is reshaped to ``(N, -1)``.
    phi_phi_list : list with one array per data dimension; ``phi_phi_list[dim][t]`` is
        assumed to have shape ``(N, N, Kp_shape[dim], Kp_shape[dim])`` -- TODO confirm.
    data_D : number of data dimensions (entries of ``phi_phi_list`` that are combined).
    T : number of transitions; trajectories have ``T + 1`` entries.
    N : number of observations per time step.
    Kp_shape : per-dimension sizes whose product is the number of latent indices.
    sample : number of trajectories per starting observation.
    save_traj_dir : existing directory that receives ``trajectories.npz`` and
        ``y_trajectories.npz``.

    Returns
    -------
    (total_trajectory, total_yt_index_recorder) : two integer arrays of shape
        ``(sample * N, T + 1)`` with the observation index and the latent index per step.
        Column 0 of the latent-index array is left at zero, as in the original code.
    """
    n_paths = sample*N
    Kp = int(np.prod(Kp_shape))

    total_yt_index_recorder = np.zeros((n_paths,T+1)).astype('int')
    total_trajectory = np.zeros((n_paths,T+1)).astype('int')
    # Rows are grouped by starting observation: rows [k*sample, (k+1)*sample) start at k.
    total_trajectory[:,0] = np.arange(n_paths)//sample
    xt_index = total_trajectory[:,0].astype('int')

    joint_d_t = _joint_log_density(C_z_right,C_z_left,phi_phi_list,data_D,N,Kp_shape,0)
    print('Generating velocity for initial observations')
    yt_index = sample_yt1_ini_sample_log_md_sample(joint_d_t,N,sample).astype('int')

    for t in range(T):
        print('Matching observations for time step {}'.format(t))
        # Rebuilt even for t == 0 so that nothing the initial sampler might have done to
        # its input can leak into the matching step.
        joint_d_t = _joint_log_density(C_z_right,C_z_left,phi_phi_list,data_D,N,Kp_shape,t)
        # Sum out the last axis (next latent index) in the log domain.
        m_t1 = lse(joint_d_t,axis=-1)
        yt1_index = np.zeros(n_paths).astype('int')
        xt1_index = np.zeros(n_paths).astype('int')
        for i in range(n_paths):
            # Next observation index given the current observation and latent index.
            log_p_x = m_t1[xt_index[i],:,yt_index[i]]
            log_p_x = log_p_x - lse(log_p_x)
            xt1_index[i] = np.random.choice(N, p = np.exp(log_p_x))
            # Next latent index given the chosen pair of observations.
            log_p_y = joint_d_t[xt_index[i],xt1_index[i],yt_index[i],:]
            log_p_y = log_p_y - lse(log_p_y)
            yt1_index[i] = np.random.choice(Kp, p = np.exp(log_p_y))
        total_yt_index_recorder[:,t+1] = yt1_index
        total_trajectory[:,t+1] = xt1_index
        yt_index = yt1_index
        xt_index = xt1_index

    print('Saving results to {}'.format(save_traj_dir))
    np.savez(save_traj_dir+'/trajectories',np.array(total_trajectory))
    np.savez(save_traj_dir+'/y_trajectories',np.array(total_yt_index_recorder))
    return np.array(total_trajectory),np.array(total_yt_index_recorder)

In [None]:
#Eval Mode
save_traj_dir = 'Results/{}_c={}_nu={}'.format(save_name,c,nu)
# makedirs (unlike mkdir) also creates the parent 'Results' folder when it is missing.
os.makedirs(save_traj_dir, exist_ok=True)
# Reload the precomputed tensors from disk so this cell does not depend on the message-passing cell's variables.
phi_phi_list =[]
for i in range(data_D):
    with np.load(save_dir+'/final_phi_{}.npz'.format(i)) as archive:
        phi_phi_list.append(archive['arr_0'])

if not load_newest:
    mess_n = '/message_sample_best.npz'
else:
    mess_n = '/message_sample_new.npz'

# Open the message archive once; the left/right messages are stored under 'arr_2' and 'arr_3'.
with np.load(save_dir+mess_n) as messages:
    C_z_left,C_z_right = messages['arr_2'],messages['arr_3']
print('Message loaded successfully!')
n = N
total_yt_index_recorder = []
Kp_shape = [np.prod(i) for i in total_indices]

print('Start to generate matching of observations')

x_traj,y_traj = generate_trajectories_sample(C_z_right,C_z_left,phi_phi_list,data_D,T,N,Kp_shape,n_trajectories,save_traj_dir)

# One colour per starting observation; left panel = ground truth, right panel = sampled matchings.
color = cm.rainbow(np.linspace(0, 1, N))
print('Visualizing Results:')
fig,ax = plt.subplots(nrows=1,ncols=2,sharey=True,figsize=(12,5))
time_steps = np.arange(T+1)
if data_D > 1:
    for i in range(n_trajectories*n):
        # s_trajectory[j] is the observation index visited at time step j.
        s_trajectory = x_traj[i,:]
        ax[1].plot(D[time_steps,s_trajectory,0],D[time_steps,s_trajectory,1],linewidth=0.05,c = color[i//n_trajectories])
    ax[1].set_title('sampled matching')

    for i in range(n):
        ax[0].plot(D[:,i,0],D[:,i,1],linewidth=2,c = color[i])
    ax[0].set_title('ground truth')
else:
    for i in range(N):
        ax[0].plot(time_steps,D[:,i,0],label=str(i))
        ax[0].scatter(time_steps,D[:,i,0],label=str(i))
    ax[0].set_xlabel('t')
    ax[0].set_ylabel('Position X')
    ax[0].set_title('ground truth')
    for i in range(n_trajectories*n):
        s_trajectory = x_traj[i,:]
        ax[1].plot(time_steps,D[time_steps,s_trajectory,0],linewidth=0.05,c = color[i//n_trajectories])
    ax[1].set_title('sampled matching')

# Quantitative Evaluation

In [None]:
# Regroup the sampled trajectories: D_sample[k] collects, for every starting observation,
# its k-th sampled trajectory, laid out as (time step, starting observation).
n_per_start = x_traj.shape[0]//N
first_rows = np.arange(N) * x_traj.shape[0]//N  # row of the first trajectory of each starting observation
D_sample = np.empty((n_per_start, T+1,N))
for k in range(n_per_start):
    D_sample[k] = x_traj[first_rows + k].T
D_sample = D_sample.astype(int)

# obo_evaluation is a project helper; its result is read as a mapping from metric name to value.
eva_results = obo_evaluation(D, D_sample)
print('Showing Quantitative Results:')
for metric_name, metric_value in eva_results.items():
    print(metric_name+': '+("%.4f" % float(metric_value)))