In [None]:
import numpy as np
import matplotlib.pyplot as plt
import csv
import time
import scipy.stats
from jax.scipy.stats import norm
import scipy.optimize
import scipy
import sympy
import symnum
import symnum.numpy as snp
import numpy as np
import jax.numpy as jnp
import jax.random as jrand
import jax.scipy.optimize as jopt
from jax.scipy.linalg import cho_solve
from jax import jit, vmap, grad, value_and_grad
from jax.lax import scan
from jax.example_libraries.optimizers import adam
import matplotlib.pyplot as plt
from jax.config import config
config.update('jax_enable_x64', True)
config.update('jax_platform_name', 'cpu')
import simsde 
from simsde.operators import v_hat_k, subscript_k

In [None]:
def drift_func_rough(x, θ):
    """Drift of the rough (noise-driven) coordinate.

    Uses the first two parameters γ, α of θ and returns the length-1 array
    [γ * x[1] - x[0] + α].
    """
    γ, α, *_unused = θ
    rough_drift = γ * x[1] - x[0] + α
    return snp.array([rough_drift])

def drift_func_smooth(x, θ):
    """Drift of the smooth (not directly noise-driven) coordinate.

    Uses ε = θ[2] and returns the length-1 array
    [1/ε * (x[1] - x[1]**3 - x[0] - s)] with the fixed constant s = 0.01.
    """
    *_leading, ε, _σ = θ
    s = 0.01  # fixed model constant, not estimated
    cubic_term = x[1] - x[1] ** 3 - x[0] - s
    return snp.array([1 / ε * cubic_term])

def diff_coeff_rough(x, θ):
    """Diffusion coefficient of the rough coordinate: the 1x1 matrix [[σ]], σ = θ[-1]."""
    *_leading, noise_scale = θ
    coeff_rows = [[noise_scale]]
    return snp.array(coeff_rows)

def drift_func(x, θ):
    """Full drift vector: rough drift components followed by smooth ones."""
    rough_and_smooth = (drift_func_rough(x, θ), drift_func_smooth(x, θ))
    return snp.concatenate(rough_and_smooth)

def diff_coeff(x, θ):
    """Full diffusion matrix: rough rows from diff_coeff_rough, zero rows for the smooth part."""
    rough_rows = diff_coeff_rough(x, θ)
    # The smooth coordinates receive no direct noise.
    smooth_rows = snp.zeros((dim_x - dim_r, dim_w))
    return snp.concatenate((rough_rows, smooth_rows), 0)

# Problem dimensions:
#   dim_x – full state, dim_r – rough (noise-driven) part,
#   dim_w – driving Brownian motion, dim_θ – parameters θ = (γ, α, ε, σ).
dim_x, dim_r, dim_w, dim_θ = 2, 1, 1, 4

In [None]:
# contrast function for weaker step size condition

def m_Σ_p_2(y, x, θ, t, drift_func_smooth, drift_func_rough, diff_coeff_rough):
    """Standardised increment and leading-order covariance for the p = 2 contrast.

    The rough part of y - x is corrected by its drift and divided by sqrt(t).
    The smooth part is corrected by its drift plus the t**2 / 2 term from the
    V̂_0 operator and divided by sqrt(t)**3.

    Returns:
        m: concatenation of the standardised rough and smooth residuals.
        Σ: block matrix [[C_r C_r^T, C_r C_s^T / 2], [(C_r C_s^T / 2)^T, C_s C_s^T / 3]]
           with C_r = diff_coeff_rough(x, θ) and C_s = [V̂_1 applied to drift_func_smooth].
    """

    def drift_func(x, θ):
        return snp.concatenate((drift_func_rough(x, θ), drift_func_smooth(x, θ)))

    dim_r = drift_func_rough(x, θ).shape[0]
    x_r, x_s = x[:dim_r], x[dim_r:]
    y_r, y_s = y[:dim_r], y[dim_r:]

    # simsde operators (k = 0 and k = 1); each maps a coefficient function to a new one.
    L0_op = v_hat_k(drift_func, diff_coeff_rough, 0, dim_r)
    L1_op = v_hat_k(drift_func, diff_coeff_rough, 1, dim_r)

    sqrt_t = snp.sqrt(t)
    rough_residual = y_r - x_r - drift_func_rough(x, θ) * t
    smooth_residual = (
        y_s - x_s - drift_func_smooth(x, θ) * t
        - L0_op(drift_func_smooth)(x, θ) * (t**2) / 2
    )
    m = snp.concatenate([rough_residual / sqrt_t, smooth_residual / sqrt_t**3])

    C_r = diff_coeff_rough(x, θ)
    C_s = snp.array([L1_op(drift_func_smooth)(x, θ)])

    cross_block = C_r @ C_s.T / 2
    top_row = snp.concatenate([C_r @ C_r.T, cross_block], axis=1)
    bottom_row = snp.concatenate([cross_block.T, C_s @ C_s.T / 3], axis=1)
    Σ = snp.concatenate([top_row, bottom_row], axis=0)

    return m, Σ

def sym_Σ1(x, θ, t, drift_func_smooth, drift_func_rough, diff_coeff_rough):
    """Symbolic first-order covariance correction Σ_1.

    Uses the same Σ_1 formulas as ``sym_Σ1_Σ2``. There, Σ_1 is the coefficient
    of t in the expansion Σ + t Σ_1 + t**2 Σ_2 that ``contrast_function_p_4`` uses.

    Args:
        x: state (rough coordinates first, then smooth).
        θ: model parameters.
        t: time step. It is not used; it is kept for signature parity with the
            other symbolic helpers.
        drift_func_smooth, drift_func_rough, diff_coeff_rough: model
            coefficient functions of (x, θ).

    Returns:
        Block matrix [[RR, RS], [RS^T, SS]] over rough (R) / smooth (S) coordinates.
    """
    dim_r = drift_func_rough(x, θ).shape[0]

    def drift_func(x, θ):
        return snp.concatenate((drift_func_rough(x, θ), drift_func_smooth(x, θ)))

    # simsde operators V̂_0 / V̂_1. Each maps a coefficient function of (x, θ)
    # to another such function. Build each operator once and reuse it.
    L0_op = v_hat_k(drift_func, diff_coeff_rough, 0, dim_r)
    L1_op = v_hat_k(drift_func, diff_coeff_rough, 1, dim_r)

    L1_drift_smooth = L1_op(drift_func_smooth)(x, θ)
    L1_drift_rough = L1_op(drift_func_rough)(x, θ)
    L1L0_drift_smooth = L1_op(L0_op(drift_func_smooth))(x, θ)

    C_r = diff_coeff_rough(x, θ)
    C_s = snp.array([L1_drift_smooth])

    Σ_1_RR = snp.array(C_r * L1_drift_rough)
    Σ_1_RS = snp.array(C_r * L1L0_drift_smooth / 6 + C_s * L1_drift_rough / 3)
    Σ_1_SS = snp.array(C_s * L1L0_drift_smooth / 4)

    Σ = snp.concatenate(
        [
            snp.concatenate([Σ_1_RR, Σ_1_RS], axis=1),
            snp.concatenate([Σ_1_RS.T, Σ_1_SS], axis=1),
        ],
        axis=0,
    )

    return Σ

def sym_Σ1_Σ2(x, θ, t, drift_func_smooth, drift_func_rough, diff_coeff_rough):
    """Symbolic first- and second-order covariance corrections Σ_1 and Σ_2.

    ``contrast_function_p_4`` uses these as the t and t**2 coefficients of the
    covariance expansion Σ + t Σ_1 + t**2 Σ_2. The leading term Σ comes from
    ``m_p4_Σ_0``, so no Σ_0 is built here.

    Args:
        x: state (rough coordinates first, then smooth).
        θ: model parameters.
        t: time step. It is not used; it is kept for signature parity with the
            other symbolic helpers.
        drift_func_smooth, drift_func_rough, diff_coeff_rough: model
            coefficient functions of (x, θ).

    Returns:
        (Σ_1, Σ_2), each a block matrix [[RR, RS], [RS^T, SS]] over the
        rough (R) and smooth (S) coordinates.
    """
    dim_r = drift_func_rough(x, θ).shape[0]

    def drift_func(x, θ):
        return snp.concatenate((drift_func_rough(x, θ), drift_func_smooth(x, θ)))

    # simsde operators V̂_0 / V̂_1 (presumably the drift generator and the first
    # noise-direction derivative). Each maps a coefficient function of (x, θ)
    # to another such function. Every term below is computed exactly once.
    L0_op = v_hat_k(drift_func, diff_coeff_rough, 0, dim_r)
    L1_op = v_hat_k(drift_func, diff_coeff_rough, 1, dim_r)

    L1_drift_smooth = L1_op(drift_func_smooth)(x, θ)
    L1_drift_rough = L1_op(drift_func_rough)(x, θ)
    L1L0_drift_smooth = L1_op(L0_op(drift_func_smooth))(x, θ)
    L1L0_drift_rough = L1_op(L0_op(drift_func_rough))(x, θ)
    L1L0L0_drift_smooth = L1_op(L0_op(L0_op(drift_func_smooth)))(x, θ)

    C_r = diff_coeff_rough(x, θ)
    C_s = snp.array([L1_drift_smooth])

    Σ_1_RR = snp.array(C_r * L1_drift_rough)
    Σ_1_RS = snp.array(C_r * L1L0_drift_smooth / 6 + C_s * L1_drift_rough / 3)
    Σ_1_SS = snp.array(C_s * L1L0_drift_smooth / 4)

    Σ_2_RR = snp.array(C_r * L1L0_drift_rough / 3 + L1_drift_rough ** 2 / 3)
    Σ_2_RS = snp.array(
        C_s * L1L0_drift_rough / 6
        + L1L0_drift_smooth * L1_drift_rough / 8
        + C_r * L1L0L0_drift_smooth / 24
    )
    Σ_2_SS = snp.array(C_s * L1L0L0_drift_smooth / 15 + L1L0_drift_smooth ** 2 / 20)

    def assemble(rr, rs, ss):
        # Lay out the blocks as [[RR, RS], [RS^T, SS]].
        return snp.concatenate(
            [
                snp.concatenate([rr, rs], axis=1),
                snp.concatenate([rs.T, ss], axis=1),
            ],
            axis=0,
        )

    Σ_1 = assemble(Σ_1_RR, Σ_1_RS, Σ_1_SS)
    Σ_2 = assemble(Σ_2_RR, Σ_2_RS, Σ_2_SS)

    return Σ_1, Σ_2

def m_p4_Σ_0(y, x, θ, t, drift_func_smooth, drift_func_rough, diff_coeff_rough):
    """Standardised increment and leading covariance for the p = 4 contrast.

    Compared with ``m_Σ_p_2``, the mean uses one more order in t:
      rough:  drift * t + V̂_0(drift_rough) * t**2 / 2, then divided by sqrt(t);
      smooth: drift * t + V̂_0(drift_smooth) * t**2 / 2
              + V̂_0(V̂_0(drift_smooth)) * t**3 / 6, then divided by sqrt(t)**3.

    Returns:
        m: concatenated standardised rough and smooth residuals.
        Σ: block matrix [[C_r C_r^T, C_r C_s^T / 2], [(.)^T, C_s C_s^T / 3]].
    """
    dim_r = drift_func_rough(x, θ).shape[0]
    x_r, x_s = x[:dim_r], x[dim_r:]
    y_r, y_s = y[:dim_r], y[dim_r:]

    def drift_func(x, θ):
        return snp.concatenate((drift_func_rough(x, θ), drift_func_smooth(x, θ)))

    # simsde operators for k = 0 and k = 1.
    L0_op = v_hat_k(drift_func, diff_coeff_rough, 0, dim_r)
    L1_op = v_hat_k(drift_func, diff_coeff_rough, 1, dim_r)
    L0_drift_smooth = L0_op(drift_func_smooth)

    sqrt_t = snp.sqrt(t)
    rough_residual = (
        y_r - x_r - drift_func_rough(x, θ) * t
        - L0_op(drift_func_rough)(x, θ) * t**2 / 2
    )
    smooth_residual = (
        y_s - x_s - drift_func_smooth(x, θ) * t
        - L0_drift_smooth(x, θ) * t**2 / 2
        - L0_op(L0_drift_smooth)(x, θ) * t**3 / 6
    )
    m = snp.concatenate([rough_residual / sqrt_t, smooth_residual / sqrt_t**3])

    C_r = diff_coeff_rough(x, θ)
    C_s = snp.array([L1_op(drift_func_smooth)(x, θ)])

    cross_block = C_r @ C_s.T / 2
    top_row = snp.concatenate([C_r @ C_r.T, cross_block], axis=1)
    bottom_row = snp.concatenate([cross_block.T, C_s @ C_s.T / 3], axis=1)
    Σ = snp.concatenate([top_row, bottom_row], axis=0)
    return m, Σ

def contrast_function_p_2(
    drift_func_smooth, drift_func_rough, diff_coeff_rough):
    """Build the symbolic p = 2 local-Gaussian one-step contrast.

    The returned function ``one_step_contrast_function(x_t, x_0, θ, t)``
    evaluates a Gaussian log-density of the standardised increment ``m`` with
    covariance ``Σ`` (both from ``m_Σ_p_2``):

        -(mᵀ Σ⁻¹ m / 2 + log det(Σ) / 2),

    up to additive terms that do not depend on θ. It builds sympy expressions
    and is meant to be compiled with ``symnum.numpify``.
    """

    def one_step_contrast_function(x_t, x_0, θ, t):
        m, Σ = m_Σ_p_2(x_t, x_0, θ, t, drift_func_smooth, drift_func_rough, diff_coeff_rough)
        m, Σ = sympy.Matrix(m), sympy.Matrix(Σ)
        chol_Σ = Σ.cholesky(hermitian=False)
        invΣ = Σ.inverse_CH()

        # With Σ = L Lᵀ, the sum of log(diag(L)) equals log det(Σ) / 2.
        return -(
            (m.T @ invΣ @ m)[0, 0]
            / 2
            + snp.log(chol_Σ.diagonal()).sum()
        )

    return one_step_contrast_function

def contrast_function_p_4(drift_func_smooth, drift_func_rough, diff_coeff_rough):
    """Build the symbolic p = 4 local-Gaussian one-step contrast.

    The covariance is expanded as Σ(t) ≈ Σ + t Σ_1 + t**2 Σ_2. Here Σ comes from
    ``m_p4_Σ_0`` and Σ_1, Σ_2 come from ``sym_Σ1_Σ2``. Inverse and log-determinant
    are replaced by their second-order Taylor expansions:

        Σ(t)⁻¹       ≈ Σ⁻¹ + t G_1 + t**2 G_2
        log det Σ(t) ≈ log det Σ + t H_1 + t**2 H_2

    The returned ``one_step_contrast_function(x_t, x_0, θ, t)`` is the
    corresponding Gaussian log-density of ``m``, up to additive terms that do
    not depend on θ. It builds sympy expressions for ``symnum.numpify``.
    """

    def one_step_contrast_function(x_t, x_0, θ, t):
        m, Σ = m_p4_Σ_0(x_t, x_0, θ, t, drift_func_smooth, drift_func_rough, diff_coeff_rough)
        Σ_1, Σ_2 = sym_Σ1_Σ2(x_0, θ, t, drift_func_smooth, drift_func_rough, diff_coeff_rough)
        m, Σ, Σ_1, Σ_2 = sympy.Matrix(m), sympy.Matrix(Σ), sympy.Matrix(Σ_1), sympy.Matrix(Σ_2)
        chol_Σ = Σ.cholesky(hermitian=False)
        invΣ = Σ.inverse_CH()
        # Second-order coefficients of the expansion of (Σ + tΣ_1 + t²Σ_2)⁻¹.
        G_1 = - invΣ @ Σ_1 @ invΣ
        G_2 = -(G_1 @ Σ_1 + invΣ @ Σ_2) @ invΣ
        # First- and second-order coefficients of the log-determinant expansion.
        H_1 = sympy.Matrix(invΣ @ Σ_1).trace()
        H_2 = sympy.Matrix(G_1 @ Σ_1 / 2 + invΣ @ Σ_2).trace()

        # The sum of log(diag(chol)) equals log det(Σ) / 2, so the correction
        # terms are also halved.
        return -(
            (m.T @ (invΣ + t * G_1 + t**2 * G_2) @ m)[0, 0]
            / 2
            + snp.log(chol_Σ.diagonal()).sum()
            + (t * H_1 + t**2 * H_2)/2
        )

    return one_step_contrast_function

    

In [None]:
# Symbolic contrast-function builders, keyed by the contrast name used below.
symbolic_log_transition_density_generators = {
    'local_gaussian (p = 2)': contrast_function_p_2,
    'local_gaussian (p = 4)': contrast_function_p_4,
}
# Backward-compatible alias for the original (misspelled) name.
symolic_log_transition_density_generators = symbolic_log_transition_density_generators

# Compile each symbolic contrast into a JAX function of (x_t, x_0, θ, t).
# The argument shapes are: x_t and x_0 of length dim_x, θ of length dim_θ, and scalar t.
jax_log_transition_densities = {
    key: symnum.numpify(dim_x, dim_x, dim_θ, None, numpy_module=jnp)(
        generator(drift_func_smooth, drift_func_rough, diff_coeff_rough)
    )
    for key, generator in symbolic_log_transition_density_generators.items()
}

In [None]:
def get_log_likelihood_functions(log_transition_density):
    """Turn a one-step log transition density into a path log-likelihood.

    Args:
        log_transition_density: callable ``(x_t, x_0, θ, t) -> scalar``.

    Returns:
        Dict ``{'θ': log_likelihood_θ}``. ``log_likelihood_θ(θ, x_seq, t_seq)``
        is jit-compiled and returns the sum of
        ``log_transition_density(x_seq[i + 1], x_seq[i], θ, t_seq[i + 1] - t_seq[i])``
        over all consecutive pairs.
    """
    # Vectorise over consecutive pairs; θ is shared (not mapped).
    batched_density = vmap(log_transition_density, in_axes=(0, 0, None, 0))

    @jit
    def log_likelihood_θ(θ, x_seq, t_seq):
        time_steps = t_seq[1:] - t_seq[:-1]
        per_step_terms = batched_density(x_seq[1:], x_seq[:-1], θ, time_steps)
        return per_step_terms.sum()

    return {'θ': log_likelihood_θ}
 

In [None]:
# Integrator used to generate the synthetic data. Each entry maps to
# (number of standard-normal draws per step, factory for the symbolic step).
# The factories are lazy, so only the selected integrator is constructed.
integrator_name = "local_gaussian"
integrators = {
    "euler_maruyama": (
        dim_r,
        lambda: simsde.integrators.euler_maruyama_step(drift_func, diff_coeff),
    ),
    "local_gaussian": (
        2 * dim_r,
        lambda: simsde.integrators.hypoelliptic_local_gaussian_step(
            drift_func_rough, drift_func_smooth, diff_coeff_rough
        ),
    ),
}
dim_n, make_step_func = integrators[integrator_name]
step_func = make_step_func()

# Compiled step: (x of length dim_x, θ of length dim_θ, n of length dim_n, scalar dt) -> next x.
jax_step_func = symnum.numpify(dim_x, dim_θ, dim_n, (), numpy_module=jnp)(step_func)

@jit
def simulate_diffusion(x_0, θ, t_seq, n_seq):
    """Simulate a path on the time grid ``t_seq`` starting from ``x_0``.

    Step i uses the noise row ``n_seq[i]`` and dt = ``t_seq[i + 1] - t_seq[i]``.
    Returns ``x_0`` stacked with the ``len(t_seq) - 1`` subsequent states.
    """

    def scan_body(x, noise_and_dt):
        # Named to avoid shadowing the module-level `step_func`.
        n, dt = noise_and_dt
        x_next = jax_step_func(x, θ, n, dt)
        return x_next, x_next

    _, x_seq = scan(scan_body, x_0, (n_seq, t_seq[1:] - t_seq[:-1]))

    return jnp.concatenate((x_0[None], x_seq))

In [None]:
# Simulation settings
rng = np.random.default_rng(20230204)
dt_simulation = 1e-4  # step size used to simulate the synthetic data
T = 5000  # time horizon of each simulated path
θ_true = jnp.array([1.5, 0.3, 0.1, 0.6])  # true parameters θ = (γ, α, ε, σ)
x_0 = jnp.array([0.0, 0.0])  # initial state
# Use round() before int(): T / dt_simulation can evaluate to just below an
# integer in floating point, and int() alone would truncate it to one step too few.
t_seq_sim = np.arange(int(round(T / dt_simulation)) + 1) * dt_simulation

In [None]:
def compute_complete_maximum_likelihood_estimates(
    log_likelihood, t_seq, x_seqs, θ_0, optimizer=adam, n_steps=8000, step_size=1e-2
):
    """Maximise ``log_likelihood['θ']`` over θ with a JAX example-library optimiser.

    Args:
        log_likelihood: dict whose ``'θ'`` entry is ``f(θ, x_seq, t_seq) -> scalar``.
        t_seq: observation times, passed through to the log-likelihood.
        x_seqs: observed path, passed through as ``x_seq``.
        θ_0: initial parameter value.
        optimizer: optimiser factory returning ``(init, update, get_params)``.
        n_steps: number of optimiser iterations.
        step_size: optimiser learning rate.

    Returns:
        The parameters after ``n_steps`` updates.
    """
    opt_init, opt_update, opt_params = optimizer(step_size)
    objective_and_gradient = value_and_grad(log_likelihood["θ"])

    @jit
    def ascent_step(step_index, opt_state, x_seq, t_seq):
        objective, gradient = objective_and_gradient(opt_params(opt_state), x_seq, t_seq)
        # The optimiser minimises, so pass it the negated gradient to ascend.
        return objective, opt_update(step_index, -gradient, opt_state)

    opt_state = opt_init(θ_0)
    for step_index in range(n_steps):
        _, opt_state = ascent_step(step_index, opt_state, x_seqs, t_seq)

    return opt_params(opt_state)


In [None]:
num_sampling = 50  # number of independent synthetic data sets
seed = 20231138  # seed of the first data set; incremented once per data set

# Row index in the *_sample_complete arrays -> contrast name, which is a key of
# jax_log_transition_densities. The keys must be 0 .. len - 1.
contrast_type_items = {
    0: 'local_gaussian (p = 2)',
    1: 'local_gaussian (p = 4)',
}
n_contrasts = len(contrast_type_items)

# Estimates of θ = (γ, α, ε, σ): one row per contrast, one column per data set.
γ_sample_complete = np.empty((n_contrasts, num_sampling))
α_sample_complete = np.empty((n_contrasts, num_sampling))
ε_sample_complete = np.empty((n_contrasts, num_sampling))
σ_sample_complete = np.empty((n_contrasts, num_sampling))

# Optimiser starting point, shared by every fit.
θ_0 = jnp.array([1.0, 1.0, 1.0, 1.0])

# The log-likelihood wrappers depend only on the contrast, not on the data,
# so they are built once rather than once per (data set, contrast) pair.
log_likelihoods = {
    contrast_name: get_log_likelihood_functions(jax_log_transition_densities[contrast_name])
    for contrast_name in contrast_type_items.values()
}

# NOTE(review): the estimate arrays have no dt_obs axis. With more than one
# entry here, later step sizes overwrite earlier estimates, and `dt_obs` keeps
# only the last value after the loop.
dt_obs_items = [0.02]
for dt_obs in dt_obs_items:
    # Observe every `sub_interval`-th simulated state. round() guards against
    # the float ratio landing just below an integer before truncation.
    sub_interval = int(round(dt_obs / dt_simulation))
    t_seq_obs = t_seq_sim[::sub_interval]
    for k in range(num_sampling):
        print("Compute the observations -- Start")
        rng = np.random.default_rng(seed)
        n_seqs = rng.standard_normal((t_seq_sim.shape[0] - 1, dim_n))
        x_seqs_sim = simulate_diffusion(x_0, θ_true, t_seq_sim, n_seqs)
        x_seq_obs = x_seqs_sim[::sub_interval]
        print("Compute the observations -- End")

        # `contrast_name` rather than `type`, so the builtin is not shadowed.
        for key, contrast_name in contrast_type_items.items():
            log_likelihood = log_likelihoods[contrast_name]
            print("Optimisation Complete Observation Adam -- Start", contrast_name, dt_obs)
            complete_adam = compute_complete_maximum_likelihood_estimates(
                log_likelihood, t_seq_obs, x_seq_obs, θ_0
            )
            print("Optimisation Complete Observation Adam -- End", contrast_name, dt_obs)
            print(complete_adam)
            # Log-likelihood value and gradient at the estimate (sanity check).
            print(value_and_grad(log_likelihood["θ"])(complete_adam, x_seq_obs, t_seq_obs))
            print(k)
            γ_sample_complete[key, k] = complete_adam[0]
            α_sample_complete[key, k] = complete_adam[1]
            ε_sample_complete[key, k] = complete_adam[2]
            σ_sample_complete[key, k] = complete_adam[3]

        seed += 1

In [None]:
# Write one tab-separated file per contrast. The rows hold the estimates of
# γ, α, ε and σ; the columns correspond to the synthetic data sets.
for key, contrast_name in contrast_type_items.items():
    file_name = f'MLE_FHN_complete_{contrast_name}=T_{T}_dt_obs_{dt_obs}_dt_sim_{dt_simulation}.csv'
    # `with` closes the file even if a write fails. The csv module requires
    # newline='' so that it can control line endings itself.
    with open(file_name, 'w', newline='') as f:
        writer = csv.writer(f, delimiter='\t')
        writer.writerow(γ_sample_complete[key, :])
        writer.writerow(α_sample_complete[key, :])
        writer.writerow(ε_sample_complete[key, :])
        writer.writerow(σ_sample_complete[key, :])