In [None]:
import numpy as np
import matplotlib.pyplot as plt
import csv
import time
import scipy.stats
from jax.scipy.stats import norm
import scipy.optimize
import scipy
import sympy
import symnum
import symnum.numpy as snp
import numpy as np
import jax.numpy as jnp
import jax.random as jrand
import jax.scipy.optimize as jopt
from jax.scipy.linalg import cho_solve
from jax.scipy.special import ndtr, ndtri
from jax import jit, vmap, grad, value_and_grad
from jax.lax import scan
from jax.example_libraries.optimizers import adagrad, adam
import matplotlib.pyplot as plt
from jax.config import config
config.update('jax_enable_x64', True)
config.update('jax_platform_name', 'cpu')
import simsde 
from simsde.operators import v_hat_k 

In [None]:
class fhn_lg:
    """One-step Gaussian approximation and filtering contrast for a (h, q) system.

    The parameter vector is θ = (ε, γ, β, σ). The component ``q`` is the one
    passed in as data (``q_paths``); ``h`` is tracked only through a scalar
    filter mean and variance. Over one step of size ``step_size_data`` the
    pair (q, h) is approximated as Gaussian with

    * mean ``[calc_mu_q(q), calc_mu_h(q)] + h * [a_q(q), 1 - dt]`` and
    * covariance ``covariance_one_step(θ)`` (independent of the state here).

    NOTE(review): the class name and the drift terms suggest a FitzHugh-Nagumo
    model with a local-Gaussian scheme -- confirm against the accompanying
    write-up.
    """

    def __init__(self, param, initial_value, step_size_data, step_size_sim, num_data, num_simulation):
        self.param = param # θ = (ε, γ, β, σ)
        self.initial_value = initial_value
        self.step_size_data = step_size_data
        self.step_size_sim = step_size_sim
        self.num_data = num_data
        self.num_simulation = num_simulation

    def _h_coefficients(self, q, θ):
        """Return the scalars (a_q, a_h) that multiply h in the one-step means of q and h."""
        dt = self.step_size_data
        ε, γ, β, σ = θ
        a_q = -dt/ε + (- (1 - 3 * q**2) / ε + 1) * (dt**2) / (2 * ε)
        a_h = 1 - dt
        return a_q, a_h

    def calc_A_q(self, q, θ): # θ = (ε, γ, β, σ)
        """Coefficient of h in the one-step mean of q, returned as an array of shape (1,)."""
        a_q, _ = self._h_coefficients(q, θ)
        return jnp.array([a_q])

    def calc_mu_q(self, q, θ, s=0.01):
        """Part of the one-step mean of q that does not involve h.

        ``s`` is the constant subtracted inside the cubic drift term
        ``(q - q**3 - s) / ε``; it defaults to 0.01.
        """
        dt = self.step_size_data
        ε, γ, β, σ = θ
        return (q + dt * (q - q**3 - s) / ε + 0.5 * (dt ** 2) * ((1 - 3 * q**2) * (q - q**3 - s) / ε**2 - (γ * q + β) / ε))

    def calc_mu_h(self, q, θ):
        """Part of the one-step mean of h that does not involve h itself."""
        dt = self.step_size_data
        ε, γ, β, σ = θ
        return (γ * q + β)*dt

    def matrix_A(self, q, θ):
        """Column vector (shape (2, 1)) of the coefficients of h in the means of (q, h)."""
        a_q, a_h = self._h_coefficients(q, θ)
        return jnp.array([
            [a_q],
            [a_h]
            ])

    def mean_one_step(self, current_q, current_h, θ):
        """One-step mean of (q, h) given the current q and a value (or filter mean) of h."""
        a_q, a_h = self._h_coefficients(current_q, θ)
        A = jnp.array([a_q, a_h])
        return jnp.array([self.calc_mu_q(current_q, θ), self.calc_mu_h(current_q, θ)]) + current_h * A

    def covariance_one_step(self, θ):
        """One-step 2x2 covariance of (q, h); rows/columns are ordered (q, h)."""
        dt = self.step_size_data
        ε, γ, β, σ = θ
        return (σ**2) * jnp.array([
            [ (dt**3) / (3 * ε**2), - (dt**2) / (2 * ε)],
            [- (dt**2) / (2 * ε), dt]
            ])

    def prediction_covariance(self, q, forward_filter_covariance, θ):
        """Predictive covariance of (q, h): one-step covariance plus A Aᵀ times the filter variance of h.

        Returns the three distinct entries ``(qq, hq, hh)``.
        """
        Σ = self.covariance_one_step(θ)
        A = self.matrix_A(q, θ)
        pred_cov = Σ + (A @ A.T)*forward_filter_covariance
        return pred_cov[0,0], pred_cov[1,0], pred_cov[1,1]

    def prediction_mean(self, q, forward_filter_mean, θ):
        """Predictive means ``(mean_q, mean_h)`` with h replaced by its current filter mean."""
        pred_mean = self.mean_one_step(q, forward_filter_mean, θ)
        return pred_mean[0], pred_mean[1]

    def forward_filter_mean_cov_one_step(self, current_q, next_q, forward_filter_mean, forward_filter_covariance, θ):
        """Condition the predicted h on the newly observed ``next_q`` (Gaussian conditioning).

        Returns the updated filter mean and variance of h, each wrapped in an
        array with one leading axis.
        """
        μ_q, μ_h = self.prediction_mean(current_q, forward_filter_mean, θ)
        Λ_qq, Λ_hq, Λ_hh = self.prediction_covariance(current_q, forward_filter_covariance, θ)
        next_filter_mean = μ_h + ((next_q - μ_q)/Λ_qq)*Λ_hq
        next_filter_cov = Λ_hh - (Λ_hq**2) / Λ_qq
        return jnp.array([next_filter_mean]), jnp.array([next_filter_cov])

    def forward_filter_mean_cov_paths_scan(self, q_paths, initial_mean, initial_cov, θ):
        """Run the filter along ``q_paths`` and return the filter means and variances at every time.

        The first entry of each returned array is the supplied initial value,
        so both outputs have the same leading length as ``q_paths``.
        """
        # No inner ``jit`` here: ``scan`` already traces its body, so the extra
        # decorator only created a new compiled wrapper on every call.
        def step_func(filter_mean_cov, q_paths_current_next):
            filter_mean, filter_cov = filter_mean_cov
            q_current, q_next = q_paths_current_next
            filter_next = self.forward_filter_mean_cov_one_step(q_current, q_next, filter_mean, filter_cov, θ)
            return filter_next, filter_next

        _, filter_mean_cov = scan(step_func, (initial_mean, initial_cov), (q_paths[:-1], q_paths[1:]))
        filter_mean, filter_cov = filter_mean_cov

        return jnp.concatenate((initial_mean[None], filter_mean)), jnp.concatenate((initial_cov[None], filter_cov))

    def _contrast_impl(self, θ, q_paths, initial_mean, initial_cov):
        """Traceable body of ``get_contrast_function_scan`` (see there for the contract)."""
        filter_mean_paths, filter_cov_paths = self.forward_filter_mean_cov_paths_scan(q_paths, initial_mean, initial_cov, θ)
        # Constant offset kept from the original formulation: log N(q_0; q_0, 1).
        initial_log_likelihood = norm.logpdf(q_paths[0], loc = q_paths[0], scale = 1.0)
        Σ_qq = self.covariance_one_step(θ)[0,0]

        def step_func(loglikelihood, qset_and_filtermeancov):
            q_current, q_next, filter_mean, filter_cov = qset_and_filtermeancov
            A_q = self.calc_A_q(q_current, θ)
            q_mean = self.calc_mu_q(q_current, θ) + jnp.dot(A_q, filter_mean)
            q_scale = jnp.sqrt(jnp.dot(filter_cov * A_q, A_q) + Σ_qq)
            # Only the running total is needed, so nothing is stacked per step.
            return loglikelihood + norm.logpdf(q_next, q_mean, q_scale), None

        log_likelihood, _ = scan(
            step_func,
            initial_log_likelihood,
            (q_paths[:-1], q_paths[1:], filter_mean_paths[:-1], filter_cov_paths[:-1]),
        )
        return -2*log_likelihood

    def get_contrast_function_scan(self, θ, q_paths, initial_mean, initial_cov):
        """Return -2 x the accumulated Gaussian log-likelihood of ``q_paths`` under θ.

        Each increment is ``log N(q_next; mu_q(q) + a_q * m, a_q**2 * c + Σ_qq)``
        where (m, c) are the filter mean and variance of h at the current time.

        The computation is compiled once per instance and reused; previously
        every call created new closures, so both scans were re-traced (and
        typically recompiled) for each objective evaluation of the optimiser.
        ``step_size_data`` is baked into the compiled function, so the cache is
        rebuilt if that attribute changes.
        """
        dt = self.step_size_data
        cached = self.__dict__.get("_contrast_cache")
        if cached is None or cached[0] != dt:
            cached = (dt, jit(self._contrast_impl))
            self._contrast_cache = cached
        return cached[1](jnp.asarray(θ), q_paths, initial_mean, initial_cov)


In [None]:
class fhn_p3:
    """One-step Gaussian approximation and filtering contrast with a state-dependent covariance.

    The parameter vector is θ = (ε, γ, β, σ). The component ``q`` is the one
    passed in as data (``q_paths``); ``h`` is tracked only through a scalar
    filter mean and variance. The one-step mean is

        ``[calc_mu_q(q), calc_mu_h(q)] + h * [a_q(q), 1 - dt]``

    and the one-step covariance ``covariance_one_step(q, θ)`` carries extra
    higher-order terms in ``dt`` that depend on the current ``q``.

    NOTE(review): the class name suggests a FitzHugh-Nagumo model with a
    higher-order ("p = 3") scheme -- confirm against the accompanying write-up.
    """

    def __init__(self, param, initial_value, step_size_data, step_size_sim, num_data, num_simulation):
        self.param = param # θ = (ε, γ, β, σ)
        self.initial_value = initial_value
        self.step_size_data = step_size_data
        self.step_size_sim = step_size_sim
        self.num_data = num_data
        self.num_simulation = num_simulation

    def _h_coefficients(self, q, θ):
        """Return the scalars (a_q, a_h) that multiply h in the one-step means of q and h."""
        dt = self.step_size_data
        ε, γ, β, σ = θ
        a_q = -dt/ε + (- (1 - 3 * q**2) / ε + 1) * (dt**2) / (2 * ε)
        a_h = 1 - dt
        return a_q, a_h

    def calc_A_q(self, q, θ): # θ = (ε, γ, β, σ)
        """Coefficient of h in the one-step mean of q, returned as an array of shape (1,)."""
        a_q, _ = self._h_coefficients(q, θ)
        return jnp.array([a_q])

    def calc_mu_q(self, q, θ, s=0.01):
        """Part of the one-step mean of q that does not involve h.

        ``s`` is the constant subtracted inside the cubic drift term
        ``(q - q**3 - s) / ε``; it defaults to 0.01.
        """
        dt = self.step_size_data
        ε, γ, β, σ = θ
        return (q + dt * (q - q**3 - s) / ε + 0.5 * (dt ** 2) * ((1 - 3 * q**2) * (q - q**3 - s) / ε**2 - (γ * q + β) / ε))

    def calc_mu_h(self, q, θ):
        """Part of the one-step mean of h that does not involve h itself."""
        dt = self.step_size_data
        ε, γ, β, σ = θ
        return (γ * q + β)*dt

    def matrix_A(self, q, θ):
        """Column vector (shape (2, 1)) of the coefficients of h in the means of (q, h)."""
        a_q, a_h = self._h_coefficients(q, θ)
        return jnp.array([
            [a_q],
            [a_h]
            ])

    def mean_one_step(self, current_q, current_h, θ):
        """One-step mean of (q, h) given the current q and a value (or filter mean) of h."""
        a_q, a_h = self._h_coefficients(current_q, θ)
        A = jnp.array([a_q, a_h])
        return jnp.array([self.calc_mu_q(current_q, θ), self.calc_mu_h(current_q, θ)]) + current_h * A

    def covariance_one_step(self, current_q, θ):
        """One-step 2x2 covariance of (q, h) at the current q; rows/columns are ordered (q, h).

        The three local coefficients below are the only state/parameter
        dependent ingredients; the entries are polynomials in ``dt`` built
        from them.
        """
        dt = self.step_size_data
        ε, γ, β, σ = θ
        L_1V_S0 = - σ / ε
        L_1V_R0 = - σ
        L_1L_0V_S0 = σ * (- (1 - 3 * current_q**2) / ε + 1) / ε

        cov_qq = (dt**3) * L_1V_S0**2 / 3 + dt**4 * L_1V_S0 * L_1L_0V_S0 / 4
        # The off-diagonal entry is shared by both positions (symmetric matrix).
        cov_qh = (dt**2) * L_1V_S0 * σ / 2 + dt**3 * (σ * L_1L_0V_S0 / 6 + L_1V_S0 * L_1V_R0 / 3)
        cov_hh = dt * σ ** 2 + dt**2 * σ * L_1V_R0

        return jnp.array(
            [
             [cov_qq, cov_qh],
             [cov_qh, cov_hh]
            ]
            )

    def prediction_covariance(self, q, forward_filter_covariance, θ):
        """Predictive covariance of (q, h): one-step covariance plus A Aᵀ times the filter variance of h.

        Returns the three distinct entries ``(qq, hq, hh)``.
        """
        Σ = self.covariance_one_step(q, θ)
        A = self.matrix_A(q, θ)
        pred_cov = Σ + (A @ A.T)*forward_filter_covariance
        return pred_cov[0,0], pred_cov[1,0], pred_cov[1,1]

    def prediction_mean(self, q, forward_filter_mean, θ):
        """Predictive means ``(mean_q, mean_h)`` with h replaced by its current filter mean."""
        pred_mean = self.mean_one_step(q, forward_filter_mean, θ)
        return pred_mean[0], pred_mean[1]

    def forward_filter_mean_cov_one_step(self, current_q, next_q, forward_filter_mean, forward_filter_covariance, θ):
        """Condition the predicted h on the newly observed ``next_q`` (Gaussian conditioning).

        Returns the updated filter mean and variance of h, each wrapped in an
        array with one leading axis.
        """
        μ_q, μ_h = self.prediction_mean(current_q, forward_filter_mean, θ)
        Λ_qq, Λ_hq, Λ_hh = self.prediction_covariance(current_q, forward_filter_covariance, θ)
        next_filter_mean = μ_h + ((next_q - μ_q)/Λ_qq)*Λ_hq
        next_filter_cov = Λ_hh - (Λ_hq**2) / Λ_qq
        return jnp.array([next_filter_mean]), jnp.array([next_filter_cov])

    def forward_filter_mean_cov_paths_scan(self, q_paths, initial_mean, initial_cov, θ):
        """Run the filter along ``q_paths`` and return the filter means and variances at every time.

        The first entry of each returned array is the supplied initial value,
        so both outputs have the same leading length as ``q_paths``.
        """
        # No inner ``jit`` here: ``scan`` already traces its body, so the extra
        # decorator only created a new compiled wrapper on every call.
        def step_func(filter_mean_cov, q_paths_current_next):
            filter_mean, filter_cov = filter_mean_cov
            q_current, q_next = q_paths_current_next
            filter_next = self.forward_filter_mean_cov_one_step(q_current, q_next, filter_mean, filter_cov, θ)
            return filter_next, filter_next

        _, filter_mean_cov = scan(step_func, (initial_mean, initial_cov), (q_paths[:-1], q_paths[1:]))
        filter_mean, filter_cov = filter_mean_cov

        return jnp.concatenate((initial_mean[None], filter_mean)), jnp.concatenate((initial_cov[None], filter_cov))

    def _contrast_impl(self, θ, q_paths, initial_mean, initial_cov):
        """Traceable body of ``get_contrast_function_scan`` (see there for the contract)."""
        filter_mean_paths, filter_cov_paths = self.forward_filter_mean_cov_paths_scan(q_paths, initial_mean, initial_cov, θ)
        # Constant offset kept from the original formulation: log N(q_0; q_0, 1).
        initial_log_likelihood = norm.logpdf(q_paths[0], loc = q_paths[0], scale = 1.0)

        def step_func(loglikelihood, qset_and_filtermeancov):
            q_current, q_next, filter_mean, filter_cov = qset_and_filtermeancov
            # The covariance depends on the current q, so it is evaluated per step.
            Σ_qq = self.covariance_one_step(q_current, θ)[0,0]
            A_q = self.calc_A_q(q_current, θ)
            q_mean = self.calc_mu_q(q_current, θ) + jnp.dot(A_q, filter_mean)
            q_scale = jnp.sqrt(jnp.dot(filter_cov * A_q, A_q) + Σ_qq)
            # Only the running total is needed, so nothing is stacked per step.
            return loglikelihood + norm.logpdf(q_next, q_mean, q_scale), None

        log_likelihood, _ = scan(
            step_func,
            initial_log_likelihood,
            (q_paths[:-1], q_paths[1:], filter_mean_paths[:-1], filter_cov_paths[:-1]),
        )
        return -2*log_likelihood

    def get_contrast_function_scan(self, θ, q_paths, initial_mean, initial_cov):
        """Return -2 x the accumulated Gaussian log-likelihood of ``q_paths`` under θ.

        Each increment is ``log N(q_next; mu_q(q) + a_q * m, a_q**2 * c + Σ_qq(q))``
        where (m, c) are the filter mean and variance of h at the current time.

        The computation is compiled once per instance and reused; previously
        every call created new closures, so both scans were re-traced (and
        typically recompiled) for each objective evaluation of the optimiser.
        ``step_size_data`` is baked into the compiled function, so the cache is
        rebuilt if that attribute changes.
        """
        dt = self.step_size_data
        cached = self.__dict__.get("_contrast_cache")
        if cached is None or cached[0] != dt:
            cached = (dt, jit(self._contrast_impl))
            self._contrast_cache = cached
        return cached[1](jnp.asarray(θ), q_paths, initial_mean, initial_cov)


In [None]:
# θ = (ε, γ, β, σ)    

def drift_func_rough(x, θ):
    """Drift of the noise-driven component: ``γ * x[1] - x[0] + β``, as a length-1 array.

    Only γ and β are read from θ; θ must have at least three entries.
    """
    _ε, γ, β, *_unused = θ
    rough_state, smooth_state = x[0], x[1]
    drift = γ * smooth_state - rough_state + β
    return snp.array([drift])

def drift_func_smooth(x, θ):
    """Drift of the noise-free component: ``(x[1] - x[1]**3 - x[0] - s) / ε`` with s = 0.01.

    θ is unpacked as exactly four entries (ε, γ, β, σ); only ε is used.
    """
    ε, _γ, _β, _σ = θ
    s = 0.01
    numerator = x[1] - x[1] ** 3 - x[0] - s
    return snp.array([numerator / ε])

def diff_coeff_rough(x, θ):
    """Diffusion coefficient of the noise-driven component: the 1x1 matrix [[σ]].

    σ is taken as the last entry of θ; ``x`` is not used.
    """
    *_leading, σ = θ
    row = [σ]
    return snp.array([row])

def drift_func(x, θ):
    """Full drift vector: the rough-component drift followed by the smooth-component drift."""
    rough_part = drift_func_rough(x, θ)
    smooth_part = drift_func_smooth(x, θ)
    return snp.concatenate((rough_part, smooth_part))

def diff_coeff(x, θ):
    """Full diffusion matrix: the rough block stacked on top of zero rows for the remaining components.

    Reads the module-level ``dim_x``, ``dim_r`` and ``dim_w`` at call time.
    """
    rough_block = diff_coeff_rough(x, θ)
    zero_block = snp.zeros((dim_x - dim_r, dim_w))
    return snp.concatenate((rough_block, zero_block), 0)

# Dimensions shared by the coefficient functions and the integrator set-up.
# Total state dimension: drift_func concatenates one rough and one smooth entry.
dim_x = 2
# Number of noise-driven (rough) components; diff_coeff pads dim_x - dim_r zero rows.
dim_r = 1
# Number of columns of the diffusion matrix (diff_coeff_rough returns a 1x1 block).
dim_w = 1
# Length of the parameter vector θ = (ε, γ, β, σ).
dim_θ = 4 

In [None]:
def get_log_likelihood_functions(log_transition_density):
    """Build a jitted log-likelihood of a fully observed path from a one-step log density.

    ``log_transition_density`` is called as ``(x_next, x_current, θ, dt)``; it is
    vectorised over consecutive pairs of ``x_seq`` and the matching time
    increments of ``t_seq`` (θ is shared), and the terms are summed.

    Returns a dict with the single key ``'θ'`` mapping to
    ``log_likelihood_θ(θ, x_seq, t_seq)``.
    """
    pairwise_density = vmap(log_transition_density, in_axes=(0, 0, None, 0))

    @jit
    def log_likelihood_θ(θ, x_seq, t_seq):
        x_next, x_current = x_seq[1:], x_seq[:-1]
        dt_seq = t_seq[1:] - t_seq[:-1]
        per_step_terms = pairwise_density(x_next, x_current, θ, dt_seq)
        return per_step_terms.sum()

    return {'θ': log_likelihood_θ}


In [None]:
# Choose the one-step integrator used to generate synthetic data. Both candidate
# entries are constructed, then the "local_gaussian" one is selected by key.
# dim_n is the number of standard-normal variates consumed per step (it sizes the
# noise array drawn before each simulation): dim_r for Euler-Maruyama and
# 2*dim_r for the hypoelliptic local Gaussian step.
dim_n, step_func = {
    "euler_maruyama": (
        dim_r,
        simsde.integrators.euler_maruyama_step(drift_func, diff_coeff),
    ),
    "local_gaussian": (
        2*dim_r,
        simsde.integrators.hypoelliptic_local_gaussian_step(
        drift_func_rough, drift_func_smooth, diff_coeff_rough)
    )
}["local_gaussian"]

# Turn the symbolic step into a function on jax.numpy arrays. The leading arguments
# give the sizes of the state, parameter and noise inputs; the empty tuple is
# presumably the shape of the scalar step size -- TODO confirm against symnum docs.
jax_step_func = symnum.numpify(dim_x, dim_θ, dim_n, (), numpy_module=jnp)(step_func)

@jit
def simulate_diffusion(x_0, θ, t_seq, n_seq):
    """Simulate a path on the time grid ``t_seq`` by repeatedly applying ``jax_step_func``.

    ``n_seq`` supplies one row of noise variates per step and the step sizes are
    the consecutive differences of ``t_seq``. The returned array starts with
    ``x_0`` followed by the state after each step.
    """
    dt_seq = t_seq[1:] - t_seq[:-1]

    def advance(state, noise_and_dt):
        noise, dt = noise_and_dt
        new_state = jax_step_func(state, θ, noise, dt)
        return new_state, new_state

    _last_state, trajectory = scan(advance, x_0, (n_seq, dt_seq))

    return jnp.concatenate((x_0[None], trajectory))

In [None]:
# setting
rng = np.random.default_rng(20230204)
dt_simulation = 1e-4 # step size for synthetic data
dt_obs = 0.005  # step size for the observation
T = 1000 # Time length of data step
# Step counts are rounded rather than truncated: a float quotient that lands just
# below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996) would otherwise lose a step.
n_simulation = int(round(T / dt_simulation))
sub_interval = int(round(dt_obs / dt_simulation))  # simulation steps per observation
n_data = int(round(T / dt_obs)) # number of data
θ_true = jnp.array([0.1, 1.5, 0.3, 0.6]) # param θ = (ε, γ, β, σ)
x_0 = jnp.array([0.0, 0.0]) # initial value
t_seq_sim = np.arange(n_simulation + 1) * dt_simulation  # simulation time grid, including t = 0
model_p3 = fhn_p3(θ_true, x_0, dt_obs, dt_simulation, n_data, n_simulation)
model_LG = fhn_lg(θ_true, x_0, dt_obs, dt_simulation, n_data, n_simulation)
# Initial filter mean and variance of the hidden component passed to the contrast functions.
initial_mean = jnp.array([0.0])
initial_cov = jnp.array([1.0])

In [None]:
num_sampling = 50
# NaN-initialised (rather than np.empty) so that an interrupted run leaves
# recognisable missing values instead of uninitialised memory.
ε_sample_partial_LG = np.full(num_sampling, np.nan)
γ_sample_partial_LG = np.full(num_sampling, np.nan)
β_sample_partial_LG = np.full(num_sampling, np.nan)
σ_sample_partial_LG = np.full(num_sampling, np.nan)
ε_sample_partial_p3 = np.full(num_sampling, np.nan)
γ_sample_partial_p3 = np.full(num_sampling, np.nan)
β_sample_partial_p3 = np.full(num_sampling, np.nan)
σ_sample_partial_p3 = np.full(num_sampling, np.nan)

seed = 20231124
θ_0 = [0.5, 0.5, 0.5, 0.5]  # starting point of every optimisation


def _minimise_contrast(model, label, θ_init, q_paths, filter_mean_0, filter_cov_0):
    """Minimise ``model.get_contrast_function_scan`` over θ with Nelder-Mead.

    Prints one tab-separated line per iteration (iteration number, the four
    parameters, the contrast value) between ``Optimisation-<label>--Start`` and
    ``Optimisation-<label>--End`` banners, then the SciPy result, which is also
    returned. The iteration counter is local to each run, so the numbering
    restarts at 1 for every model (it used to carry over from the previous run).
    """
    iteration = 0

    def log_iterate(X):
        nonlocal iteration
        iteration += 1
        f = model.get_contrast_function_scan(X, q_paths, filter_mean_0, filter_cov_0)
        print('%d\t%f\t%f\t%f\t%f\t%f' % (iteration, X[0], X[1], X[2], X[3], f))

    print(f"Optimisation-{label}--Start")
    res = scipy.optimize.minimize(
        model.get_contrast_function_scan,
        θ_init,
        args=(q_paths, filter_mean_0, filter_cov_0),
        method='Nelder-Mead',
        callback=log_iterate,
        options={"maxiter":5000}
    )
    print(f"Optimisation-{label}--End")
    print(res)
    return res


for k in range(num_sampling):
    # One fresh generator per replicate; the seed advances by one each time.
    rng = np.random.default_rng(seed)
    n_seqs = rng.standard_normal((t_seq_sim.shape[0] - 1, dim_n))
    print("Compute the observations -- Start")
    x_seqs_sim = simulate_diffusion(x_0, θ_true, t_seq_sim, n_seqs)
    x_seq_obs = x_seqs_sim[::sub_interval]  # keep every sub_interval-th simulated point
    print("Compute the observations -- End")
    inv_start = 0
    q_paths_obs = x_seq_obs[inv_start:, 1]  # only the second state component is observed
    t_seq_obs = t_seq_sim[::sub_interval][inv_start:]

    # Use an explicit figure and close it afterwards so that successive
    # replicates never draw onto the same axes or accumulate open figures.
    fig, ax = plt.subplots()
    ax.plot(t_seq_obs, q_paths_obs, "-", markersize=0.2)
    ax.set_xlabel("t")
    ax.set_ylabel("q_t")
    fig.savefig("Paths for q_t")  # NOTE: same file name every replicate; the last one wins
    plt.show()
    plt.close(fig)

    res_p3 = _minimise_contrast(model_p3, "p3", θ_0, q_paths_obs, initial_mean, initial_cov)
    print(k)

    res_LG = _minimise_contrast(model_LG, "LG", θ_0, q_paths_obs, initial_mean, initial_cov)
    print(k)

    ε_sample_partial_LG[k] = res_LG.x[0]
    γ_sample_partial_LG[k] = res_LG.x[1]
    β_sample_partial_LG[k] = res_LG.x[2]
    σ_sample_partial_LG[k] = res_LG.x[3]

    ε_sample_partial_p3[k] = res_p3.x[0]
    γ_sample_partial_p3[k] = res_p3.x[1]
    β_sample_partial_p3[k] = res_p3.x[2]
    σ_sample_partial_p3[k] = res_p3.x[3]
    seed += 1

In [None]:
# Write the estimates of each scheme to a tab-separated file: one row per
# parameter (ε, γ, β, σ), one column per replicate.
for label, rows in (
    ("LG", (ε_sample_partial_LG, γ_sample_partial_LG, β_sample_partial_LG, σ_sample_partial_LG)),
    ("p3", (ε_sample_partial_p3, γ_sample_partial_p3, β_sample_partial_p3, σ_sample_partial_p3)),
):
    # ``with`` closes the file even if a write fails; ``newline=''`` is what the
    # csv module requires so that no extra carriage returns are inserted.
    with open(f'MLE_FHN_partial_{label}={T}_dt_obs_{dt_obs}_dt_sim_{dt_simulation}.csv', 'w', newline='') as f:
        writer = csv.writer(f, delimiter='\t')
        writer.writerows(rows)