Randomised Riemannian Hamiltonian Monte Carlo for Bayesian Estimation as a Constrained Distribution Problem


1) Dynamics

In [1]:
import numba
import numpy as np
from numpy.linalg import inv
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import time
import random

# Data Extraction

In [2]:
import os
from os import sys

In [3]:
data_matrix = np.zeros((120,2000))

file_location = './cov_shrink_simulations/'
# Bin-pair tags in the file names, in the order their blocks are stacked per column.
bin_pairs = ('01_01', '01_02', '02_02')


def _load_xi_pm(pair, k):
    """Load realisation `k` for one bin pair and return xi_p stacked on top of xi_m.

    xi_p / xi_m are columns 1 and 2 of the file (after skipping one header row).
    Values are kept as strings here; numpy converts them to float when they are
    written into `data_matrix`.
    """
    # Realisation numbers are zero-padded to four digits (0001 ... 2000).
    file_name = ('xi_athena_final_shrinkage_lognormal_2zbin_shapenoise_'
                 + pair + '_' + str(k).zfill(4) + '.dat')
    path = os.path.join(file_location, file_name)
    dat_file = np.genfromtxt(fname=path, skip_header=1, dtype='unicode', invalid_raise=False)
    return np.concatenate((dat_file[:, 1], dat_file[:, 2]))


# Column k-1 of data_matrix = [xi_p, xi_m] for 01_01, then 01_02, then 02_02.
for k in range(1, 2001):
    data_matrix[:, k-1] = np.concatenate([_load_xi_pm(pair, k) for pair in bin_pairs])
        


In [4]:
# Astronomical Data
# data_matrix has 120 rows (xi_p / xi_m blocks for the three bin-pair files);
# only the first N rows are modelled.
N = 60

p =  N//6   # number of columns of X, i.e. rank of the X diag(d_1) X^T term
d =  N      # dimension of each data vector
m =  p

In [5]:
# The first n columns (realisations) are used to estimate moments and in the likelihood.
n = N*2//3
# Copy: slicing gives a view, and standardising in place would otherwise
# silently overwrite the raw data_matrix as well.
Data = data_matrix[:N,:2000].copy()

mean = np.mean(Data[:,:n],axis = 1)
var = np.var(Data[:,:n],axis = 1)

# Standardise every column with the per-row moments of the first n columns.
Data = (Data - mean[:, None]) / np.sqrt(var)[:, None]

# Mean of the standardised training columns (≈ 0); used to centre data later.
mean = np.mean(Data[:,:n],axis = 1)

# Preconditioning

In [6]:
# Sample mean and biased (1/n) sample covariance of the first n columns of Data.
train_cols = Data[:, :n]
mean2 = train_cols.mean(axis=1)
centred = train_cols - mean2[:, None]
S_small = centred @ centred.T / n



In [7]:
# Initial guess for Σ ≈ A D_1 A^T + D_2: the p leading eigenpairs of the
# off-diagonal part of S_small, with D_2 fixing up the diagonal.
D_2 = np.diag(np.diag(S_small))
Pmat = (S_small - D_2)

# Pmat is symmetric, so eigh gives real eigenvalues and orthonormal eigenvectors.
# np.linalg.eig makes no ordering guarantee, so "the first p" of its output were
# not necessarily the leading ones. eigh sorts ascending; reverse to descending.
eigvals, eigvecs = np.linalg.eigh(Pmat)
EIG = (eigvals[::-1], eigvecs[:, ::-1])

A = EIG[1][:, :p].copy()       # (d, p), C-contiguous like the original zeros-filled array
D_1 = np.diag(EIG[0][:p])

Approx = np.matmul(A,np.matmul(D_1,np.transpose(A)))
# NOTE(review): this diagonal can go negative if Approx's diagonal exceeds S_small's.
D_2 = np.diag(np.diag(S_small)) - np.diag(np.diag(Approx))
Approx += D_2

In [8]:
# Problem sizes (requires p <= d): X is d x p, d_1 has p entries, d_2 has d entries.
d = N
p = N // 6
m = p

# Length of the flat parameter vector q = [vec(X), d_1, d_2].
dimension = d * p + p + d
# One constraint X_i . X_j = delta_ij for every column pair i <= j.
num_of_constraints = p * (p + 1) // 2

# Standard deviations of the Gaussian priors on d_1 and d_2.
σ_1, σ_2 = 2, 2
#Matrix in distribution

@numba.jit(nopython=True)
def vec_to_matrix(q):
    """Unflatten the first d*p entries of q into a (d, p) matrix, column by column.

    Entry q[col*d + row] becomes X[row, col] (column-major order).
    """
    out = np.zeros((d, p))
    pos = 0
    for col in range(p):
        for row in range(d):
            out[row, col] = q[pos]
            pos += 1
    return out

@numba.jit(nopython=True)
def matrix_to_vec(X):
    """Flatten a (d, p) matrix into a length-d*p vector in column-major order.

    Inverse of vec_to_matrix: X[row, col] goes to position col*d + row.
    """
    flat = np.zeros(d * p)
    for col in range(p):
        for row in range(d):
            flat[col * d + row] = X[row, col]
    return flat

@numba.jit(nopython=True)
def dot_product(v1, v2):
    """Return sum_k v1[k] * v2[k], summed left to right over the length of v1.

    Assumes v2 has at least len(v1) entries.
    """
    total = 0
    k = 0
    while k < len(v1):
        total += v1[k] * v2[k]
        k += 1
    return total

@numba.jit(nopython=True)
def matmul(matrix1, matrix2):
    """Naive dense product matrix1 @ matrix2, returned as a new float64 array.

    The inner dimension is taken from matrix2.shape[0].
    """
    rows = matrix1.shape[0]
    inner = matrix2.shape[0]
    cols = matrix2.shape[1]
    out = np.zeros((rows, cols))
    for r in range(rows):
        for c in range(cols):
            acc = 0.0
            for k in range(inner):
                acc += matrix1[r, k] * matrix2[k, c]
            out[r, c] = acc
    return out

@numba.jit(nopython=True)
def matrix_vec_multiplication(A, x):
    """Return A @ x as a new float64 vector of length len(A), one row at a time."""
    nrows = len(A)
    ncols = len(x)
    out = np.zeros(nrows)
    for r in range(nrows):
        row = A[r]
        acc = 0.0
        for c in range(ncols):
            acc += row[c] * x[c]
        out[r] = acc
    return out


@numba.jit(nopython=True)
def g_ij(q, i, j):
    """Orthonormality constraint between columns i and j of X = vec_to_matrix(q).

    Returns ||X[:, i]||^2 - 1 when i == j, otherwise X[:, i] . X[:, j].
    The constraint manifold is where every g_ij vanishes.
    """
    X = vec_to_matrix(q[:d * p])
    col_i = X[:, i]
    if i != j:
        return dot_product(col_i, X[:, j])
    return np.linalg.norm(col_i) ** 2 - 1


@numba.jit(nopython=True)
def G(q):
    """Jacobian of all constraints g_ij (i <= j) with respect to q.

    Returns an array of shape (num_of_constraints, dimension). Row
    p*i - i*(i-1)/2 + (j - i) holds the gradient of g_ij; the (i, j) pairs are
    enumerated with i outer and j = i..p-1 inner. Only the vec(X) part of q enters
    the constraints, so the columns for d_1 and d_2 stay zero.
    """
    X = vec_to_matrix(q[:d * p])
    jac = np.zeros((num_of_constraints, dimension))
    row = 0
    for i in range(p):
        for j in range(i, p):
            # d(X_i . X_j)/dX_i = X_j and d(X_i . X_j)/dX_j = X_i. When i == j both
            # land in the same slot and sum to 2 X_i, the gradient of ||X_i||^2 - 1.
            jac[row, d * i:d * (i + 1)] = X[:, j]
            jac[row, d * j:d * (j + 1)] += X[:, i]
            row += 1
    return jac


In [9]:
@numba.jit(nopython=True)
def potential_derv_fast(q):
    """Gradient of the potential U with respect to q = [vec(X), d_1, d_2].

    The model covariance is Σ = X diag(d_1) X^T + diag(d_2), with X = vec_to_matrix(q)
    of shape (d, p). The gradient goes through the chain rule via

        M = dU/dΣ = 0.5*n*Σ^{-T} - 0.5 * sum_r (Σ^{-T} x_r)(Σ^{-1} x_r)^T,
        x_r = Data[:, r] - mean,  r = 0..n-1,

    plus the Gaussian-prior terms d_1/σ_1^2 and d_2/σ_2^2. The result has the same
    length and layout as q. It can be cross-checked by numerically differentiating U.
    """
    X = vec_to_matrix(q[:d*p])
    d_1 = q[d*p:d*p+p]
    d_2 = q[d*p+p:]

    D_1 = np.diag(d_1)
    D_2 = np.diag(d_2)

    Σ = matmul(matmul(X,D_1),np.transpose(X)) + D_2

    Σ_inv_T = np.transpose(np.linalg.inv(Σ))

    # M = dU/dΣ: log-det term first.
    M_kl = 0.5*n*Σ_inv_T

    # Data term. Both projections of x_r depend only on r, so compute them once
    # per sample (O(n d^2)) rather than once per (k, l, r) (O(n d^3)). Each M entry
    # still accumulates over r in the same order, so results are bit-identical.
    a = np.empty(d)
    b = np.empty(d)
    for r in range(n):
        x_r = Data[:,r] - mean
        for k in range(d):
            a[k] = dot_product(x_r, Σ_inv_T[k,:])    # (Σ^{-T} x_r)_k
            b[k] = dot_product(Σ_inv_T[:,k], x_r)    # (Σ^{-1} x_r)_k
        for k in range(d):
            for l in range(d):
                M_kl[k,l] -= 0.5*a[k]*b[l]

    # dU/dX[i, j] = sum_{k,l} M[k, l] dΣ[k, l]/dX[i, j]. Only entries in row i or
    # column i of Σ depend on X[i, j]. They are visited in the same (k, l) order as a
    # full double loop, so the summation order is unchanged.
    dUdX = np.zeros((d,p))
    for i in range(d):
        for j in range(p):
            for k in range(d):
                if k == i:
                    # Row i of Σ: every column contributes; the diagonal one twice.
                    for l in range(d):
                        if l == i:
                            dΣ_kl_dX_ij = 2*X[i,j]*d_1[j]
                        else:
                            dΣ_kl_dX_ij = d_1[j]*X[l,j]
                        dUdX[i,j] += M_kl[k,l]*dΣ_kl_dX_ij
                else:
                    # Off row i, only column l == i depends on X[i, j].
                    dΣ_kl_dX_ij = X[k,j]*d_1[j]
                    dUdX[i,j] += M_kl[k,i]*dΣ_kl_dX_ij

    # dU/dd_1[j]: dΣ[k, l]/dd_1[j] = X[k, j] X[l, j], plus the prior term.
    dUd1 = np.zeros(p)
    for j in range(p):
        for k in range(d):
            for l in range(d):
                dΣ_kl_dD1_jj = X[k,j]*X[l,j]
                dUd1[j] += M_kl[k,l]*dΣ_kl_dD1_jj
        dUd1[j] += d_1[j]/(σ_1)**2

    # dU/dd_2[j]: only Σ[j, j] depends on d_2[j], plus the prior term.
    dUd2 = np.zeros(d)
    for j in range(d):
        dUd2[j] += M_kl[j,j]*1.
        dUd2[j] += d_2[j]/(σ_2)**2

    pot_derv = np.zeros(int(d*p+p+d))
    pot_derv[:d*p] = matrix_to_vec(dUdX)
    pot_derv[d*p:d*p+p] = dUd1
    pot_derv[d*p+p:] = dUd2

    return pot_derv

In [10]:
# Starting point with orthonormal X (first p columns of the d x d identity) and
# |N(0, σ^2)| draws for d_1 and d_2. No seed is set beforehand, so the draws vary
# between runs.
d_1 = abs(np.random.normal(0,σ_1,p))
d_2 = abs(np.random.normal(0,σ_2,d))
q_initial = np.concatenate((matrix_to_vec(np.eye(d,p)), d_1, d_2))
x_init = q_initial

In [11]:
# Constrained gradient flow (RATTLE/SHAKE-style projection onto the constraint manifold)
@numba.jit(nopython=True)
def Riemannian_GD(x0,t,dt,max_elim_iters):
    """Projected gradient descent of U on the manifold {q : g_ij(q) = 0 for all i <= j}.

    Each step does three things:
      1. Projects the gradient of U onto the tangent space at the current point
         (via the constraint Jacobian G).
      2. Takes an explicit step of size dt along the projected gradient.
      3. Pulls the result back onto the manifold with SHAKE-style nonlinear
         Gauss-Seidel sweeps over the constraints.

    Parameters
    ----------
    x0 : starting point, a flat vector of length `dimension`.
    t : total time; floor(t/dt) steps are taken, with a small round-off tolerance.
    dt : step size.
    max_elim_iters : maximum number of constraint sweeps per step.

    Returns
    -------
    The final point. The step index is printed every 100 steps.
    """
    # The tolerance stops round-off in t/dt (e.g. t = k*dt) from dropping a step.
    num_steps = int(np.floor(t/dt + 1e-9))

    qn = x0
    G_q = G(qn)
    pderv = potential_derv_fast(qn)

    residual_list = np.zeros(num_of_constraints)
    for step in range(num_steps):

        gram = matmul(G_q,G_q.T)
        gram_inv = np.linalg.inv(gram)
        proj_matrix = np.eye(dimension) - matmul(G_q.T,matmul(gram_inv,G_q))

        # Project the gradient onto the tangent space of the manifold at qn.
        pderv = matrix_vec_multiplication(proj_matrix,pderv)

        # Unconstrained explicit step along the projected gradient.
        Q = qn - dt*pderv

        # SHAKE: correct one constraint at a time along its normal at the old point,
        # until all residuals are below tolerance or the sweep budget runs out.
        # The loop indices ci/cj must not reuse the step counter's name.
        for sweep in range(max_elim_iters):
            for ci in range(p):
                for cj in range(ci,p):
                    g_Q = g_ij(Q,ci,cj)
                    index = int(ci*p - 0.5*ci*(ci-1) + cj-ci)   # same row order as G

                    residual_list[index] = g_Q
                    if abs(g_Q) < 1e-20:
                        continue
                    G_Q = G(Q)

                    dlambda = g_Q/dot_product(G_Q[index,:],G_q[index,:])
                    Q = Q - G_q[index,:]*dlambda
            if np.all(np.abs(residual_list)<1e-10):
                break

        # Accept the corrected point and refresh gradient and Jacobian for the next step.
        qn = Q
        pderv = potential_derv_fast(qn)
        G_q = G(qn)

        if step%100==0:
            print(step)

    return qn

3) Gaussian Sampling on Tangent Space

In [12]:
@numba.jit(nopython=True)
def U(q):
    """Potential energy: negative log-posterior of q = [vec(X), d_1, d_2].

    Likelihood: the first n columns of Data, centred by `mean`, are treated as
    i.i.d. N(0, Σ) with Σ = X diag(d_1) X^T + diag(d_2), a d x d matrix.

    Priors: independent N(0, σ_1^2) on each entry of d_1 and N(0, σ_2^2) on each
    entry of d_2. X gets no prior term (a uniform prior only adds a constant).

    Returns +inf when det(Σ) <= 0: Σ is then certainly not positive definite, and the
    point gets zero density.

    The normalising constants do not affect potential_derv_fast.
    """
    X = vec_to_matrix(q[:d*p])

    d_1 = q[d*p:d*p+p]
    d_2 = q[d*p+p:]

    D_1 = np.diag(d_1)
    D_2 = np.diag(d_2)

    Σ = matmul(matmul(X,D_1),np.transpose(X)) + D_2

    # slogdet avoids det() overflow/underflow for large d and exposes the sign.
    sign, logdet = np.linalg.slogdet(Σ)
    if sign <= 0.0:
        return np.inf
    Σ_inv = np.linalg.inv(Σ)

    # Likelihood: n samples of dimension d, so the constant is 0.5*n*d*log(2π).
    pot = 0.5*n*logdet + 0.5*n*d*np.log((2*np.pi))
    for i in range(n):
        x_i = Data[:,i]-mean
        Bx = matrix_vec_multiplication(Σ_inv,x_i)
        pot += 0.5*dot_product(x_i,Bx)

    # Priors: one Gaussian normalising constant per component (p of d_1, d of d_2).
    pot += 0.5*p*np.log(2*np.pi*(σ_1)**2)
    pot += 0.5*dot_product(d_1,d_1)/(σ_1)**2

    pot += 0.5*d*np.log(2*np.pi*(σ_2)**2)
    pot += 0.5*dot_product(d_2,d_2)/(σ_2)**2

    return pot

@numba.jit(nopython=True)
def hamiltonian(x, v):
    """Total energy H(x, v) = U(x) + |v|^2 / 2 (identity mass matrix)."""
    potential = U(x)
    kinetic = 0.5 * dot_product(v, v)
    return potential + kinetic

In [13]:
@numba.jit(nopython=True)
def f(q):
    """Model covariance Σ = X diag(d_1) X^T + diag(d_2) for q = [vec(X), d_1, d_2].

    This is the same Σ that U and potential_derv_fast build internally.
    """
    end_x = d * p
    end_d1 = end_x + p
    X = vec_to_matrix(q[:end_x])
    X_scaled = matmul(X, np.diag(q[end_x:end_d1]))
    low_rank = matmul(X_scaled, X.T)
    return low_rank + np.diag(q[end_d1:])

In [14]:
#Testing
# Data-driven starting point from the eigen-decomposition initialisation:
# X = A, d_1 = diag(D_1), d_2 = diag(D_2).
d_1 = np.diag(D_1)
d_2 = np.diag(D_2)
q_initial = np.concatenate((matrix_to_vec(A), d_1, d_2))
x_init = q_initial
# Overrides the starting point built just above with a hard-coded parameter vector.
# It appears to follow the layout [vec(X), d_1, d_2]: its last 10 and 60 entries are
# positive and match p = 10 and d = 60.
# NOTE(review): provenance is not recorded; presumably output of an earlier run.
# TODO: save it to a file and document where it came from.
x_init = [0.0469471231212426, 0.062054729118645506, 0.10913131772105193, -0.027378951115851053, 0.03014722081747145, 0.10087743323235587, 0.09111447065097171, 0.10837575566660887, 0.1699019584406265, 0.1059886088874678, 0.2499207296524493, 0.22524161185599115, 0.20382349971087854, 0.17284349111282096, 0.17326450364785495, 0.10581136373250787, -0.027529851366716266, -0.04742961177935156, -0.11748667755346122, -0.04194168222678295, -0.016494638370806836, -0.020844065791741133, -0.02061626834729403, 0.024004971051419115, -0.03875888396674657, -0.004591560130064922, 0.070160000406579, 0.03818766226969645, -0.0015952381842810946, 0.025189705041020548, -0.030437123559845108, -0.10705318115249758, -0.013431018972954243, 0.10919189743303016, 0.10265997165056316, 0.07800049892581161, 0.08715900235609217, 0.20311487869163292, 0.11767482776472347, 0.12402460888669307, 0.14434452068567066, 0.10009184486444897, 0.1458666900612036, 0.044054481159667425, 0.1742903216626856, 0.11185886439387553, 0.13818178551729793, 0.20313636238774366, 0.21387223602131405, 0.18972953502778467, 0.2421331539865, 0.2448351455361014, 0.22499355987756023, 0.1960073296827896, 0.19229648250279952, 0.1528185693599717, -0.034442463578631045, -0.06031151041138012, -0.11138080348801215, -0.10944012518268502, -0.09509245997368744, -0.0604473476808728, 0.2620987013032502, 0.0861126155887287, 0.0750959949163359, 0.1481680713974297, 0.00720089543612883, 0.11589025033601447, 0.09029252031134799, 0.07714227657762941, 0.1223602549092192, 0.17246104969325504, 0.08821787118839788, -0.1578800814280233, -0.16860914798510127, -0.14815400645118376, -0.1415149849650896, -0.11026186161712741, 0.20744666219618305, 0.15881704976083894, -0.03291992870160949, -0.07878595098342023, 0.2007967081024905, 0.05982458551903851, 0.148132512149727, -0.041726001881530854, -0.0031432389875344145, -0.056760701458795436, -0.017393710118665227, -0.06556380167027966, 0.11093184983640006, -0.029164544060536514, 0.0673676560858933, 
-0.009291610751052802, -0.08491674870852656, 0.13375846910709424, 0.19015380343657007, 0.11495451390319067, -0.052049129535754295, -0.20031595174660283, -0.07268389281364553, -0.0877792681804858, 0.19523528276459234, 0.0503016084519598, -0.005624534445405909, 0.11580465241498458, 0.0036150164467321617, 0.1225245701908238, 0.019091469632692027, -0.09581687137135568, 0.07182661987678737, 0.09798005202301407, -0.06224325431420575, -0.1974011498700592, -0.19393738350925982, -0.14496844800719452, -0.21852704616142005, -0.23179806350389973, 0.19647191247183945, 0.24901737924538586, -0.035174098156411704, 0.34524593818823496, 0.21887742488030856, 0.3871592034128386, 0.2576652318826079, 0.03868339222722682, 0.07420748895227415, -0.019227036739347947, -0.10736912944353576, -0.19677520649413324, 0.06083826763098696, 0.03547856337958604, 0.01794171725123831, 0.004776590592955343, 0.05663503742359467, -0.045680971592645864, 0.16949747215749364, -0.1109408793313156, 0.023895386731305465, -0.12236056054592415, 0.031486855260139904, -0.027250432625084007, -0.13464171968872696, -0.15627112506795848, -0.0032370145838797773, 0.03201248827425901, 0.07548842268559094, -0.16000891758857197, -0.18177903285204433, 0.03613988156183945, -0.15394490984205972, 0.09791479853846893, 0.04309434057755858, -0.01393030947818297, 0.16567121082439779, -0.1046618259378864, -0.007395862687944379, -0.08793422378758844, -0.10307950800200863, -0.05313225398138815, 0.06255806322681574, 0.26685906817453087, 0.1337680840896862, 0.33603285092048163, -0.0026747125427053586, -0.01488951524766375, 0.15748159280372157, -0.13023996158548573, -0.14497739664923984, -0.014523459583070218, -0.04381501999807571, -0.025855761798349006, -0.05489253715194365, -0.020261096905887958, 0.04423525740989661, -0.018052721072824362, 0.030465246270668846, -0.017909813780319903, 0.07185300853931452, 0.027675889939823548, -0.12461366181681753, 0.19753813021465608, -0.203890589907934, -0.05799021772845055, 0.1484657834544044, 
0.15235994621803853, -0.2611373653634501, 0.1499955259897282, -0.07243914295270998, -0.06137112648893198, 0.06884327663113692, 0.17611319217087515, 0.20764577499629833, 0.2827785206442503, 0.05710325815015804, -0.06994404409788517, 0.20569970267857457, -0.11447148032995415, 0.02519080052929643, -0.008462381480386024, -0.036211559581425294, 0.12120500647974837, 0.16994741153349274, -0.08194210827511345, -0.09419205828684624, -0.10337214531412929, -0.136754748166695, -0.05261598143769275, 0.06673894686982618, -0.106176529933809, -0.13992761934562703, -0.157262716478751, 0.009589283153823274, -0.16388557935966303, 0.027547099109504716, 0.09433807666001907, 0.1819383235305401, 0.10799955790332647, -0.01817061299780716, 0.1429109144306416, -0.038468300264238156, 0.07981970189629163, -0.277953279578314, -0.16584895253763446, 0.07047068960920465, 0.053598259400019305, -0.19113110749692225, -0.01854871871711896, -0.11873581470253118, -0.12706785493457512, -0.07899102392518985, -0.006475396221630598, 0.0180120784552541, 0.05973103770629178, -0.13993797452740625, -0.01908593470637134, 0.16151402403554369, 0.0011264052978587956, 0.13560792181135134, 0.1244135260977217, 0.03539217093841894, 0.060292775354526595, -0.12202832701032114, 0.18226550586047477, -0.2861828849709008, 0.08853813798751095, 0.13670644672560242, 0.08833589926421558, -0.2449250332679615, 0.0470237317281852, 0.11424986837397358, -0.1470704016141171, -0.012013888372895461, 0.07328402635352284, 0.08102719066013483, -0.005318055795376968, -0.15544378926092617, -0.09419543305896934, -0.3031112292976079, 0.18576830900810715, -0.009532908448873568, 0.005780348480476695, -0.020021469807552046, -0.05046485430392778, -0.047366913565957194, -0.0998469547575838, -0.1942361947287148, 0.13048292735291334, -0.05556974208099308, -0.18868867618447208, 0.18384742099940224, 0.021156078519275222, -0.0606619894879785, -0.0743467560128009, -0.08731241986126007, -0.06460357180058539, -0.052531219246618334, 0.10508752718734793, 
0.16422802541101736, 0.13085536487493402, -0.07287182302869773, 0.08867659365162389, -0.05169006033115277, 0.23543340434905022, 0.11669437523259803, 0.12286638068711522, 0.007164714328351103, -0.10709022415441977, -0.16312427917277003, -0.17544985781159778, 0.09786687849816175, -0.1337358715524218, 0.02248446743550362, 0.05672409742206027, -0.01428083291286783, -0.05373533761963948, -0.26837546702103465, -0.1180355068818506, -0.1644923426250211, 0.15583446237372714, -0.06206417434073317, -0.185034746003223, 0.0490807554883775, 0.24429662091782467, 0.08166094377952332, 0.18997650592173393, -0.06070667773230823, 0.013280616246362186, 0.4521804896260935, 0.05912864379499513, 0.0466422801121356, 0.05523017669001712, -0.12506199264404483, 0.046014695980848015, 0.0036118925612533236, 0.16567453402516877, 0.049285229980040854, -0.16492441495612672, -0.07716193207699587, 0.1439637159765537, 0.024973685059972613, 0.1728685674368888, -0.07580273928895018, 0.02176058706184306, -0.11489492447300632, -0.018938163358547956, -0.0005027290495490253, -0.01645477750817589, -0.009895295960327696, -0.019227519880147974, -0.11041667812041696, -0.1032628128387768, 0.11765910708282394, 0.13628832730130772, 0.06336219307690417, -0.20249266869313956, -0.10301429234526956, -0.03804446301498932, -0.06793720918921865, 0.1680591737905561, -0.1817569004233648, -0.259793082054087, -0.03848884053216338, 0.16348059365470224, -0.024137547378596715, -0.04393904055854429, -0.13257930179986147, -0.12654586324138697, 0.2601110738186565, -0.17005420101541138, -0.06986882440635202, -0.11244393671689998, -0.055878229810800706, -0.024104713349626012, 0.05275883767870355, 0.02480854253825336, 0.10454370342193989, -0.005818207110458306, -0.13142140130969632, 0.1672395512062872, -0.13799311101203843, 0.11417771990622695, 0.0019145473412551163, 0.03901726151107209, 0.165449046133923, 0.05743456884132148, 0.12512289907548013, -0.14619399389368024, 0.18890164672985474, 0.06626441621229799, -0.25805221801612493, 
-0.0402558402542185, 0.07904287210679031, 0.26803323510696453, -0.1577436771859533, 0.16934771679557736, -0.054565658265405485, 0.008972511150012796, 0.005430462692886791, -0.18815617031092902, -0.14065215191967526, -0.040960809611617535, 0.1807586528879763, 0.1817628714693753, -0.06101599281805794, -0.08202864050828264, 0.14966462364801356, 0.16072062807927123, 0.07989775182731144, -0.26292338129430026, 0.045923929325936266, 0.1985080320800919, -0.2005262508348421, 0.009384341794824188, -0.22901983962504885, 0.08271458623666005, -0.059281556149027276, -0.12566464183894926, 0.019525187340730325, -0.05259942840017021, -0.12198240539263634, 0.11339508193711166, 0.06250465770934022, 0.07998298600363112, 0.19479073503901426, 0.02173263919513066, 0.11751094951189225, 0.027754836806568522, 0.07536669222808505, -0.08898193470009243, -0.2415350455268749, -0.07799868850910237, 0.024403751664361553, 0.13775524548647355, -0.13002143191657495, 0.060562363164743435, -0.045106818945715166, -0.10291258114409926, 0.01719264901558477, -0.1149633772164586, 0.044708745080435514, -0.23818926104091648, -0.010091500761512438, 0.24278348490175478, -0.1522798000046354, 0.148257132660078, 0.05457598516615979, -0.07596460566894564, -0.22462834449508912, 0.1585309087939724, 0.013098094087812899, 0.14484497151583983, 0.15815597007922763, 0.027385577102170214, 0.033459429329791096, 0.1451129885521328, 0.23500942346325177, 0.1565695930839693, -0.02352905333140257, -0.16307493679825807, -0.1915633180516557, -0.006955427283766557, 0.0962573091667635, 0.1574690867009003, 0.0015456520132516405, 0.17027687987483545, -0.11122054924911902, -0.15515060626377927, 0.05845907570379761, 0.028236504047863202, 0.047220688558432075, 0.0006617509738183362, 0.06965773904413154, -0.03235683003789166, -0.08040782097737012, 0.16943560622825932, 0.17522540237117767, -0.17399388602678742, 0.004047879527536031, -0.0380592523610379, -0.15190566371542982, -0.1988866224347917, -0.00222250116135341, 0.28682156948648924, 
-0.12987350174370732, 0.19023612743374904, -0.05387408276437565, -0.030195391857639954, -0.13954952567724457, 0.033104369202685446, 0.06421918791131023, 0.10859791610450238, -0.022779341277924078, -0.04701891488364476, -0.047094782632125284, 0.08111739481080707, 0.168214246409796, 0.11612108131362221, -0.08263982282662811, -0.17923655868887306, -0.12960702573739355, -0.19022888324393525, 0.37288205003578906, -0.15909864251628342, -0.007964968025143078, -0.10440360750248019, -0.04077039083653685, -0.019214837291509534, -0.18439708290980086, 0.020394515483308655, 0.023358404753433655, 0.09775040401408441, 0.08987348998801994, 0.25783099100971574, 0.08967329400578275, -0.035683591729343205, -9.974978854704584e-05, -0.04684202881887063, -0.08809051129598955, -0.06353456835903105, -0.2370516403056846, -0.16587639029214582, -0.16592905131873154, -0.16221548461446433, -0.03911759992159406, -0.2041103518910936, 0.055565622220813, -0.18944040789011507, -0.15935446769637937, -0.15819537015761007, 0.028918795131019952, -0.003535026864993597, -0.030925674617115324, -0.0363334682852739, 0.023251481232040815, -0.050012823936025144, -0.036870395700682554, -0.11194873899815679, -0.06603653756485761, 0.069201888884983, -0.01398272900148939, -0.25107629001760345, 0.26855118994384536, -0.24638698215247498, 0.02626202710326957, -0.0483641049875665, -0.13999190663051247, -0.14409781089503323, -0.10523343730483757, 0.023286072715166393, -0.14486471179411378, -0.02300022179165623, -0.03547518264439605, 0.11695742891384613, 0.06508224515050222, -0.0029052003626357558, -0.1189006358230278, 0.015072646323189148, -0.12781937341443902, 0.0739698555405985, 0.19929686220638831, 0.03305699640663018, 0.13595232899804255, -0.22286630602724186, -0.01396117978634729, 0.19112476780445578, 0.14471826217498804, 0.011416970303831415, -0.037328821040828165, 0.03410561591579016, 0.1364979339878212, -0.24451748895138412, -0.09373777099835229, 0.09808070636159502, -0.08815983197762296, 0.029508753205472443, 
0.26047671544318257, 0.12217522651786893, -0.0428263743028032, 0.10795716364707314, -0.12078092612413605, 0.10598347140840193, -0.0956293720789686, 0.058072068140358556, -0.05130134632824287, -0.022234449022860155, -0.03620239043836225, 0.07470318040651928, 0.1252159144254915, 0.031927564052125344, 0.05872614175766498, 0.03454647227188273, 0.08158926235120163, 0.15466156869825864, 0.17125092481918805, -0.04748264257963765, 0.08937844976311937, -0.16141295710373746, -0.11036681874429563, 0.06928866060129736, 0.10870985573146506, 0.11935068738436735, 0.21708843991068036, -0.10923470807029845, 0.1785641739093059, 0.3218603928766934, -0.032214637531121276, 0.11108146479460469, 0.06835234375305244, -0.05690841416355265, 0.10239908939172516, -0.28529241078446654, -0.19923609916283808, -0.015778997791029682, -0.15239964684501794, -0.06218128805799238, 0.1265808833262374, 0.03353696246909184, 0.16528179620376798, 0.02830722694986739, 12.396128914656721, 5.469998776379874, 3.586704844293246, 2.5082282700092287, 1.7154817248716343, 2.2450114983105935, 1.21867442357063, 1.5411545893688685, 1.65746917130889, 1.470615064993922, 1.3572625856067404, 0.01615390633274867, 0.0006999178869180375, 0.00416209422679892, 0.7841884677206779, 0.7477373013168231, 0.7984139889851865, 1.054204045892748, 0.05433074092622328, 1.159916043842654, 0.19014056179856584, 0.045832892521257156, 0.3201793216638904, 0.03931474460609613, 0.3242987276267584, 0.5465840371191334, 0.19100382267567068, 0.8691700258630074, 0.40485787213778135, 0.33785442372057145, 0.7126959369635301, 0.9664103489728181, 0.6339769430512832, 1.0141588887914337, 1.0748739441907604, 0.6913128252419397, 0.9025632598578965, 0.7492453585357475, 1.2978925969472899, 0.4316326580346962, 0.9408428595755961, 0.9403276419110967, 1.0380108848767915, 0.790303352872922, 0.6531236845623875, 0.5223799599304365, 0.7318858163549629, 0.6023262340367901, 0.636190399426384, 0.5007810448951223, 0.8537894307956284, 4.6011698161112734e-05, 
0.00423141482424736, 0.0027618404266846774, 0.5215106518907134, 0.648552304494129, 0.5842362999320564, 0.4415010116834816, 0.07070677408478944, 0.42769749268961643, 0.07859873331868927, 0.00019975018980990854, 0.15717116255724153, 0.013183870194999914, 0.016533809117277014, 0.34734427432721954, 0.11369233973750424, 0.4985108681900809, 0.31569236911807197, 0.4692549161092322]
x_init = np.array(x_init)

4) Simulation

In [15]:
#Initialise
# Simulation settings (presumably passed to RGD): number of steps and step size.
num_of_iterations, dt = 100, 1e-4

In [16]:
@numba.jit(nopython=True)
def RGD(num_of_iterations,dt,x_init,max_elim_iters=1000):
    """Run `num_of_iterations` constrained gradient steps of size dt from x_init.

    This is a thin wrapper over Riemannian_GD with total time t = num_of_iterations*dt.
    max_elim_iters caps the constraint-correction sweeps per step; its default of 1000
    is the value previously hard-coded here.
    """
    t = num_of_iterations*dt
    return Riemannian_GD(x_init,t,dt,max_elim_iters)