In [None]:
import numpy as np
import matplotlib.pyplot as plt
import csv
import time
import scipy.stats
from jax.scipy.stats import norm
import scipy.optimize
import symnum.numpy as snp
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve
from jax import jit, grad, value_and_grad
from jax.lax import scan
import matplotlib.pyplot as plt
# `from jax import config` works on both old and current JAX; the former
# `jax.config` submodule import was deprecated and later removed.
from jax import config
config.update('jax_enable_x64', True)  # use float64 throughout (the model has terms of order dt**5)
config.update('jax_platform_name', 'cpu')

In [None]:
class gle_protein_folding:
    """Discretised model with an observed coordinate q and a latent vector h.

    The full state is ``x = (q, h)`` with ``h = x[1:]`` of length 3, and the
    parameters are ``θ = (c_1, c_2, τ_1, τ_2)``. One step of size ``dt`` maps
    ``x`` to a Gaussian with mean
    ``_drift_mean(q, dt) + _transition_matrix(q, θ, dt) @ h`` and covariance
    ``_covariance(θ, dt)``. The same three helpers are used for the simulation
    grid (``step_size_sim``) and the observation grid (``step_size_data``), so
    the simulator, the Gaussian forward filter and the contrast function all
    share one definition of the dynamics.

    NOTE(review): the transition matrix is written for unit mass; ``mass`` only
    rescales the potential-force term of the q-mean (as in the filter mean).
    """

    def __init__(self, param, initial_value, step_size_data, step_size_sim, num_data, num_simulation, mass, temp):
        """Store the model configuration.

        Args:
            param: parameter vector θ (stored; the methods take θ explicitly).
            initial_value: initial state x_0 = (q_0, h_0) of length 4, used by
                ``generate_sample_paths``.
            step_size_data: step size of the observation grid (filter/likelihood).
            step_size_sim: step size of the simulation grid.
            num_data: number of observations (stored only).
            num_simulation: number of steps simulated by ``generate_sample_paths``.
            mass: divides the potential-force term of the q-mean.
            temp: temperature scaling the noise covariance.
        """
        self.param = param
        self.initial_value = initial_value
        self.step_size_data = step_size_data
        self.step_size_sim = step_size_sim
        self.num_data = num_data
        self.num_simulation = num_simulation
        self.mass = mass
        self.temp = temp

    def grad_potential(self, q, a=1200, b=0.30, c=0.90, d=0.001):
        """Return U'(q) for the potential U(q) = a (q-b)^2 (q-c)^2 + d q^3."""
        return 2*a*((q-b)*((q-c)**2) + (q-c)*((q-b)**2)) + 3*d*(q**2)

    def hess_potential(self, q, a=1200, b=0.30, c=0.90, d=0.001):
        """Return U''(q) for the same potential as ``grad_potential``."""
        return 2*a*((q-c)**2 + 4*(q-c)*(q-b) + (q-b)**2) + 6*d*q

    # ------------------------------------------------------------------
    # Shared building blocks of the one-step transition
    # ------------------------------------------------------------------
    def _transition_matrix(self, q, θ, dt):
        """Return the 4x3 matrix A(q, θ, dt) multiplying h in the one-step mean.

        Row 0 is the q-row; rows 1-3 form the 3x3 h-block.
        """
        c_1, c_2, τ_1, τ_2 = θ
        hess_U = self.hess_potential(q)
        sum_ratio = c_1/τ_1 + c_2/τ_2
        return jnp.array([
            [dt - (hess_U + sum_ratio)*(dt**3)/6,
             (dt**2)/2 - (dt**3)/(6*τ_1),
             (dt**2)/2 - (dt**3)/(6*τ_2)],
            [1 - (hess_U + sum_ratio)*(dt**2)/2,
             dt - (dt**2)/(2*τ_1),
             dt - (dt**2)/(2*τ_2)],
            [-(c_1/τ_1)*dt, 1 - dt/τ_1, 0.0],
            [-(c_2/τ_2)*dt, 0.0, 1 - dt/τ_2],
        ])

    def _drift_mean(self, q, dt):
        """Return the h-independent part of the one-step mean (length 4)."""
        grad_U = self.grad_potential(q)
        return jnp.array([q - grad_U*(dt**2)/(2*self.mass), -grad_U*dt, 0.0, 0.0])

    def _covariance(self, θ, dt):
        """Return the 4x4 one-step noise covariance for step size ``dt``."""
        c_1, c_2, τ_1, τ_2 = θ
        Σ_33 = 2*self.temp*c_1/(τ_1**2)
        Σ_44 = 2*self.temp*c_2/(τ_2**2)
        sum_Σ = Σ_33 + Σ_44
        return jnp.array([
            [sum_Σ*(dt**5)/20, sum_Σ*(dt**4)/8, Σ_33*(dt**3)/6, Σ_44*(dt**3)/6],
            [sum_Σ*(dt**4)/8, sum_Σ*(dt**3)/3, Σ_33*(dt**2)/2, Σ_44*(dt**2)/2],
            [Σ_33*(dt**3)/6, Σ_33*(dt**2)/2, Σ_33*dt, 0.0],
            [Σ_44*(dt**3)/6, Σ_44*(dt**2)/2, 0.0, Σ_44*dt],
        ])

    # ------------------------------------------------------------------
    # Public pieces on the observation grid (step_size_data)
    # ------------------------------------------------------------------
    def calc_A_q(self, q, θ):
        """Return the q-row (length 3) of the observation-grid transition matrix."""
        return self._transition_matrix(q, θ, self.step_size_data)[0]

    def calc_A_h(self, q, θ):
        """Return the 3x3 h-block of the observation-grid transition matrix."""
        # Sliced from the shared matrix so every row has exactly 3 columns.
        return self._transition_matrix(q, θ, self.step_size_data)[1:]

    def calc_mu_q(self, q, θ):
        """Return the h-independent q-mean of one observation-grid step.

        ``θ`` is unused; it is accepted for signature symmetry with ``calc_A_q``.
        The value equals component 0 of ``mean_one_step`` minus ``calc_A_q @ h``.
        """
        return self._drift_mean(q, self.step_size_data)[0]

    def matrix_A(self, q, θ):
        """Return the 4x3 transition matrix on the observation grid."""
        return self._transition_matrix(q, θ, self.step_size_data)

    # This is used to generate the true trajectories of sample paths
    def mean_one_step_sim(self, x, θ):
        """Return the one-step mean of the full state ``x`` on the simulation grid."""
        q, h = x[0], x[1:]
        dt = self.step_size_sim
        return self._drift_mean(q, dt) + jnp.dot(self._transition_matrix(q, θ, dt), h)

    def mean_one_step(self, current_q, current_h, θ):
        """Return the one-step mean of (q, h) on the observation grid."""
        dt = self.step_size_data
        return (self._drift_mean(current_q, dt)
                + jnp.dot(self._transition_matrix(current_q, θ, dt), current_h))

    def covariance_one_step_sim(self, θ):
        """Return the 4x4 one-step noise covariance on the simulation grid."""
        return self._covariance(θ, self.step_size_sim)

    def covariance_one_step(self, θ):
        """Return the 4x4 one-step noise covariance on the observation grid."""
        return self._covariance(θ, self.step_size_data)

    def generate_sample_paths(self, θ, seed=20230704):
        """Simulate ``num_simulation`` steps from ``initial_value`` on the simulation grid.

        Args:
            θ: parameter vector (c_1, c_2, τ_1, τ_2).
            seed: passed to ``np.random.seed`` (reseeds NumPy's global RNG).

        Returns:
            Array of shape (num_simulation + 1, 4) whose first row is ``initial_value``.
        """
        np.random.seed(seed)
        seq_rvs = np.random.multivariate_normal(np.zeros(4), self.covariance_one_step_sim(θ), size=self.num_simulation)
        x_0 = self.initial_value

        @jit
        def step_func(x, noise):
            x_next = self.mean_one_step_sim(x, θ) + noise
            return x_next, x_next

        _, x_seq = scan(step_func, x_0, seq_rvs)

        return jnp.concatenate((x_0[None], x_seq))

    # ------------------------------------------------------------------
    # Gaussian forward filter for the latent h given observed q
    # ------------------------------------------------------------------
    def prediction_covariance(self, q, forward_filter_covariance, θ):
        """Return the blocks (Λ_qq, Λ_hq, Λ_hh) of the predictive covariance Σ + A P A^T.

        ``forward_filter_covariance`` is the current 3x3 filter covariance P of h.
        """
        Σ = self.covariance_one_step(θ)
        A = self.matrix_A(q, θ)
        pred_cov = Σ + A @ forward_filter_covariance @ A.T
        return pred_cov[0, 0], pred_cov[1:, 0], pred_cov[1:, 1:]

    def prediction_mean(self, q, forward_filter_mean, θ):
        """Return the predictive means (μ_q, μ_h) given q and the filter mean of h."""
        pred_mean = self.mean_one_step(q, forward_filter_mean, θ)
        return pred_mean[0], pred_mean[1:]

    def forward_filter_mean_cov_one_step(self, current_q, next_q, forward_filter_mean, forward_filter_covariance, θ):
        """Condition the predicted h on the next observation ``next_q``.

        Returns:
            (mean, covariance) of h after the Gaussian conditioning update.
        """
        μ_q, μ_h = self.prediction_mean(current_q, forward_filter_mean, θ)
        Λ_qq, Λ_hq, Λ_hh = self.prediction_covariance(current_q, forward_filter_covariance, θ)
        next_filter_mean = μ_h + ((next_q - μ_q)/Λ_qq)*Λ_hq
        next_filter_cov = Λ_hh - jnp.outer(Λ_hq, Λ_hq)/Λ_qq
        return next_filter_mean, next_filter_cov

    def forward_filter_mean_cov_paths_scan(self, q_paths, initial_mean, initial_cov, θ):
        """Run the forward filter along ``q_paths``.

        Returns:
            (means, covs) with ``len(q_paths)`` entries each; index 0 holds the
            initial mean/covariance, index k the filter state after q_paths[k].
        """
        @jit
        def step_func(filter_mean_cov, q_paths_current_next):
            filter_mean, filter_cov = filter_mean_cov
            q_current, q_next = q_paths_current_next
            filter_next = self.forward_filter_mean_cov_one_step(q_current, q_next, filter_mean, filter_cov, θ)
            return filter_next, filter_next

        _, (filter_mean, filter_cov) = scan(step_func, (initial_mean, initial_cov), (q_paths[:-1], q_paths[1:]))

        return (jnp.concatenate((initial_mean[None], filter_mean)),
                jnp.concatenate((initial_cov[None], filter_cov)))

    def get_contrast_function_scan(self, θ, q_paths, initial_mean, initial_cov):
        """Return -2 times the Gaussian log-likelihood of ``q_paths`` under θ.

        Each q_{k+1} is scored with the predictive mean and variance of q built
        from the filter state at step k. Intended as the objective for
        ``scipy.optimize.minimize``.
        """
        filter_mean_paths, filter_cov_paths = self.forward_filter_mean_cov_paths_scan(q_paths, initial_mean, initial_cov, θ)
        # θ-independent constant (-0.5*log(2π)); kept so reported objective values are unchanged.
        initial_log_likelihood = norm.logpdf(q_paths[0], loc=q_paths[0], scale=1.0)
        Σ = self.covariance_one_step(θ)

        @jit
        def step_func(loglikelihood, qset_and_filtermeancov):
            q_current, q_next, filter_mean, filter_cov = qset_and_filtermeancov
            A_q = self.calc_A_q(q_current, θ)
            q_mean = self.calc_mu_q(q_current, θ) + jnp.dot(A_q, filter_mean)
            q_scale = jnp.sqrt(jnp.dot(A_q @ filter_cov, A_q) + Σ[0, 0])
            # Carry only the running total; stacking every partial sum is wasted memory.
            return loglikelihood + norm.logpdf(q_next, q_mean, q_scale), None

        log_likelihood, _ = scan(
            step_func, initial_log_likelihood,
            (q_paths[:-1], q_paths[1:], filter_mean_paths[:-1, :], filter_cov_paths[:-1, :, :]))

        return -2*log_likelihood


Simulated sample path of the observed coordinate $q_t$

In [None]:
dt_simulation = 1e-4  # step size for synthetic data
dt_obs = 1e-3  # step size for the observation
T = 1500  # time length of the simulated path
# round() instead of int(): a float ratio such as 0.3/0.1 == 2.9999999999999996
# would otherwise be truncated one step short.
n_simulation = round(T / dt_simulation)
sub_interval = round(dt_obs / dt_simulation)  # simulation steps per observation
n_data = round(T / dt_obs)  # number of data
θ_true = [2.214*1e-1, 1.2, 0.007, 4.6]  # param θ = (c_1, c_2, τ_1, τ_2)
x_0 = jnp.array([0.0, 0.0, 0.0, 0.0])  # initial value
t_seq_sim = np.arange(n_simulation + 1) * dt_simulation
initial_mean = jnp.zeros(3)
initial_cov = jnp.eye(3)

seed = 20230707
model = gle_protein_folding(θ_true, x_0, dt_obs, dt_simulation, n_data, n_simulation, mass=1.0, temp=2.949)
x_seq_obs = model.generate_sample_paths(θ_true, seed)
q_paths_sim = x_seq_obs[:, 0]
t_seq_obs = t_seq_sim[::sub_interval]
q_paths_obs = q_paths_sim[::sub_interval]

fig, axs = plt.subplots(figsize=(6.0, 5.0))
axs.plot(t_seq_obs, q_paths_obs, linewidth=1.0)
axs.set_title("Simulated sample path", fontsize=16)
axs.set_ylabel(r"$q_t$", fontsize=16)
axs.set_xlabel(r'$\mathrm{time} \, (t)$', fontsize=16)
plt.show()

# The final state of this path serves as the initial value of the MLE experiments.
x_0 = x_seq_obs[-1]

Computation of the maximum likelihood estimator

In [None]:
# Settings for the repeated MLE experiments.
num_sampling = 50  # number of independent sample paths (one estimate each)
seed = 20230707  # base seed for the MLE loop
dt_simulation = 1e-4  # step size for synthetic data
dt_obs = 1e-3  # step size for the observation
T = 1500  # time length of each simulated path
n_simulation = round(T / dt_simulation)
sub_interval = round(dt_obs / dt_simulation)  # simulation steps per observation
n_data = round(T / dt_obs)  # number of data
t_seq_sim = np.arange(n_simulation + 1) * dt_simulation
t_seq_obs = t_seq_sim[::sub_interval]
θ_true = [2.214*1e-1, 1.2, 0.007, 4.6]  # param θ = (c_1, c_2, τ_1, τ_2)

# Start every experiment from the end of the burn-in path `x_seq_obs` simulated
# above. This cell does not overwrite `x_seq_obs`, so re-running it leaves x_0
# unchanged. (The path previously re-simulated here was identical to the first
# path the MLE loop generates with the same seed, so it is not recomputed.)
x_0 = x_seq_obs[-1]
x_seq_0 = x_0
model = gle_protein_folding(θ_true, x_0, dt_obs, dt_simulation, n_data, n_simulation, mass=1.0, temp=2.949)
initial_mean = jnp.zeros(3)
initial_cov = jnp.eye(3)

In [None]:
θ_output = np.zeros((num_sampling, len(θ_true)))  # one row of estimated θ per sample path
θ_0 = jnp.array([1e-1, 1.0, 1e-2, 10.0])  # Nelder-Mead starting point
burn_in_time = 500  # leading time span of each observed path discarded before fitting
inv_data = round(burn_in_time / dt_obs)  # number of leading observations discarded

for k in range(num_sampling):
    # Derive the per-path seed instead of mutating `seed`, so re-running this
    # cell reproduces exactly the same set of paths.
    seed_k = seed + k
    print("Compute the observations -- Start")
    x_seq_sim = model.generate_sample_paths(θ_true, seed_k)
    print("Compute the observations -- End")
    q_paths_sim = x_seq_sim[:, 0]
    q_paths_obs = q_paths_sim[::sub_interval][inv_data:]
    arg = (q_paths_obs, initial_mean, initial_cov)
    count = 0

    def cbf(X):
        """Nelder-Mead callback: print iteration count, current θ and its objective value."""
        global count
        count += 1
        f = model.get_contrast_function_scan(X, q_paths_obs, initial_mean, initial_cov)
        print('%d\t%f\t%f\t%f\t%f\t%f' % (count, X[0], X[1], X[2], X[3], f))

    print("Optimisation--Start")
    res = scipy.optimize.minimize(
        model.get_contrast_function_scan, θ_0, args=arg,
        method='Nelder-Mead',
        callback=cbf,
        options={"maxiter": 1000})
    print("Optimisation--End")
    print(res)
    print(k)
    θ_output[k] = res.x

In [None]:
# Tab-separated file, one row of estimated θ = (c_1, c_2, τ_1, τ_2) per sample path.
# `with` closes the file even if writing fails; newline='' is required by the
# csv module so its '\r\n' terminator is not translated to '\r\r\n' on Windows.
with open(f'MLE_GLE_protein_folding_partial={T}_dt_obs_{dt_obs}_dt_sim_{dt_simulation}.csv', 'w', newline='') as f:
    writer = csv.writer(f, delimiter='\t')
    writer.writerows(θ_output[:num_sampling])