In [None]:
import numpy as np
import matplotlib.pyplot as plt
import csv
import time
import scipy.stats
from jax.scipy.stats import norm
import scipy.optimize
import scipy
import sympy
import symnum
import symnum.numpy as snp
import numpy as np
import jax.numpy as jnp
import jax.random as jrand
import jax.scipy.optimize as jopt
from jax.scipy.linalg import cho_solve
from jax import jit, vmap, grad, value_and_grad
from jax.lax import scan
from jax.example_libraries.optimizers import adam
import matplotlib.pyplot as plt
from jax.config import config
config.update('jax_enable_x64', True)
config.update('jax_platform_name', 'cpu')
import simsde 
from simsde.operators import v_hat_k, subscript_k

In [None]:
# State layout: x[0] = rough (extended-space) component, x[1] = momentum, x[2] = position.

def drift_position(x, θ):
    """Drift of the position component: the position moves with the momentum x[1].

    θ is not used; it is accepted so every drift shares the (x, θ) signature.
    """
    momentum = x[1]
    return snp.array([momentum])

def drift_momentum(x, θ):
    """Drift of the momentum component.

    The force is -V'(q) for the double-well potential V(q) = (q^2 - D)^2 / 4,
    i.e. -(q^3 - D q), plus the coupling λ * x[0] to the rough component.
    """
    D, λ, *_ = θ
    rough, position = x[0], x[2]
    potential_force = -(position**3 - D*position)
    return snp.array([potential_force + λ*rough])

def diff_coeff_rough(x, θ):
    """1x1 diffusion matrix of the rough component, equal to the last parameter σ."""
    *_, noise_scale = θ
    return snp.array([[noise_scale]])

def drift_rough(x, θ):
    """Drift of the rough component: -λ * momentum - α * rough (D is not used here)."""
    _, λ, α, *_ = θ
    rough, momentum = x[0], x[1]
    return snp.array([- λ*momentum - α*rough])

def drift_smooth(x, θ):
    """Drift of the smooth components, stacked in state order (momentum, position)."""
    momentum_drift = drift_momentum(x, θ)
    position_drift = drift_position(x, θ)
    return snp.concatenate([momentum_drift, position_drift])

def drift_func(x, θ):
    """Drift of the full state x = (rough, momentum, position)."""
    rough_drift = drift_rough(x, θ)
    smooth_drift = drift_smooth(x, θ)
    return snp.concatenate([rough_drift, smooth_drift])

def diff_coeff(x, θ):
    """3x1 diffusion matrix of the full state.

    Only the rough component x[0] receives noise (scale σ = θ[-1]); the
    momentum and position rows are zero.
    """
    *_, noise_scale = θ
    column = [[noise_scale], [0], [0]]
    return snp.array(column)

# Dimensions of the state components, the parameter vector and the driving noise.
dim_r = 1   # rough component x[0]
dim_s2 = 1  # momentum x[1]
dim_s1 = 1  # position x[2]
dim_x = dim_r + dim_s2 + dim_s1  # full state (= 3)
dim_θ = 4   # θ = (D, λ, α, σ)
dim_w = 1   # number of columns of diff_coeff (driving Brownian motions)

In [None]:
# Contrast functions (local Gaussian approximations of the log transition density) for the weaker step-size condition

def m_and_Σ_p_3(y, x, θ, t):
    """Standardised increment ``m`` and its leading-order covariance ``Σ``.

    ``x`` is the state at the start of a step of length ``t`` and ``y`` the state
    at its end (callers pass ``(x_t, x_0, θ, t)``). Each component of ``y - x``
    has its deterministic expansion terms, built from the simsde operator
    ``v_hat_k(..., 0, ...)``, subtracted. The result is then scaled by
    t^(-1/2) (rough), t^(-3/2) (momentum) and t^(-5/2) (position).

    ``Σ`` is assembled from the rough diffusion matrix and the
    ``v_hat_k(..., 1, ...)`` terms. The block constants 1, 1/2, 1/6, 1/3, 1/8
    and 1/20 match the covariances of the scaled (W_t, ∫W, ∫∫W).

    NOTE(review): the ``snp.array([...])`` wrapping of C_s2 / C_s1 only yields
    correctly oriented blocks when every component is one-dimensional.
    """
    n_r = drift_rough(x, θ).shape[0]
    n_s2 = drift_momentum(x, θ).shape[0]
    # Operator factories, built once and reused below.
    L0 = v_hat_k(drift_func, diff_coeff_rough, 0, n_r)
    L1 = v_hat_k(drift_func, diff_coeff_rough, 1, n_r)

    def split(z):
        # (rough, momentum, position) slices of a full state vector.
        return z[:n_r], z[n_r : n_r + n_s2], z[n_r + n_s2 :]

    x_r, x_s2, x_s1 = split(x)
    y_r, y_s2, y_s1 = split(y)
    root_t = snp.sqrt(t)

    rough_residual = (y_r - x_r - drift_rough(x, θ) * t) / root_t
    momentum_residual = (
        y_s2 - x_s2 - drift_momentum(x, θ) * t
        - L0(drift_momentum)(x, θ) * (t**2) / 2
    ) / root_t**3
    position_residual = (
        y_s1 - x_s1 - drift_position(x, θ) * t
        - L0(drift_position)(x, θ) * (t**2) / 2
        - L0(L0(drift_position))(x, θ) * (t**3) / 6
    ) / root_t**5
    m = snp.concatenate([rough_residual, momentum_residual, position_residual])

    C_r = diff_coeff_rough(x, θ)
    C_s2 = snp.array([L1(drift_momentum)(x, θ)])
    C_s1 = snp.array([L1(L0(drift_position))(x, θ)])

    Σ_RR = C_r @ C_r.T
    Σ_RS2 = C_r @ C_s2.T / 2
    Σ_RS1 = C_r @ C_s1.T / 6
    Σ_S2S2 = C_s2 @ C_s2.T / 3
    Σ_S2S1 = C_s2 @ C_s1.T / 8
    Σ_S1S1 = C_s1 @ C_s1.T / 20

    block_rows = (
        [Σ_RR, Σ_RS2, Σ_RS1],
        [Σ_RS2.T, Σ_S2S2, Σ_S2S1],
        [Σ_RS1.T, Σ_S2S1.T, Σ_S1S1],
    )
    Σ = snp.concatenate(
        [snp.concatenate(row, axis=1) for row in block_rows],
        axis=0,
    )
    return m, Σ

def m_and_Σ_p_4(y, x, θ, t):
    """Higher-order variant of ``m_and_Σ_p_3``.

    It has the same scaling and the same covariance ``Σ``, but each mean
    expansion keeps one more ``v_hat_k(..., 0, ...)`` term:
    - rough: up to t^2/2
    - momentum: up to t^3/6
    - position: up to t^4/24

    NOTE(review): not referenced by the contrast functions in this notebook.
    """
    n_r = drift_rough(x, θ).shape[0]
    n_s2 = drift_momentum(x, θ).shape[0]
    # Operator factories, built once and reused below.
    L0 = v_hat_k(drift_func, diff_coeff_rough, 0, n_r)
    L1 = v_hat_k(drift_func, diff_coeff_rough, 1, n_r)

    def split(z):
        # (rough, momentum, position) slices of a full state vector.
        return z[:n_r], z[n_r : n_r + n_s2], z[n_r + n_s2 :]

    x_r, x_s2, x_s1 = split(x)
    y_r, y_s2, y_s1 = split(y)
    root_t = snp.sqrt(t)

    rough_residual = (
        y_r - x_r - drift_rough(x, θ) * t - L0(drift_rough)(x, θ) * t**2 / 2
    ) / root_t
    momentum_residual = (
        y_s2 - x_s2 - drift_momentum(x, θ) * t
        - L0(drift_momentum)(x, θ) * t**2 / 2
        - L0(L0(drift_momentum))(x, θ) * (t**3) / 6
    ) / root_t**3
    position_residual = (
        y_s1 - x_s1 - drift_position(x, θ) * t
        - L0(drift_position)(x, θ) * t**2 / 2
        - L0(L0(drift_position))(x, θ) * t**3 / 6
        - L0(L0(L0(drift_position)))(x, θ) * (t**4) / 24
    ) / root_t**5
    m = snp.concatenate([rough_residual, momentum_residual, position_residual])

    C_r = diff_coeff_rough(x, θ)
    C_s2 = snp.array([L1(drift_momentum)(x, θ)])
    C_s1 = snp.array([L1(L0(drift_position))(x, θ)])

    Σ_RR = C_r @ C_r.T
    Σ_RS2 = C_r @ C_s2.T / 2
    Σ_RS1 = C_r @ C_s1.T / 6
    Σ_S2S2 = C_s2 @ C_s2.T / 3
    Σ_S2S1 = C_s2 @ C_s1.T / 8
    Σ_S1S1 = C_s1 @ C_s1.T / 20

    block_rows = (
        [Σ_RR, Σ_RS2, Σ_RS1],
        [Σ_RS2.T, Σ_S2S2, Σ_S2S1],
        [Σ_RS1.T, Σ_S2S1.T, Σ_S1S1],
    )
    Σ = snp.concatenate(
        [snp.concatenate(row, axis=1) for row in block_rows],
        axis=0,
    )
    return m, Σ

def sym_Σ_1(x, θ, t):
    """First-order covariance correction Σ_1 (every entry carries a factor t).

    Used by ``contrast_function_p_3`` alongside the Σ from ``m_and_Σ_p_3``.

    NOTE(review): assumes all blocks are 1x1 (dim_r = dim_s2 = dim_s1 = 1).
    Products are elementwise, and the lower blocks reuse the upper blocks
    without transposition.
    """
    L0 = v_hat_k(drift_func, diff_coeff_rough, 0, dim_r)
    L1 = v_hat_k(drift_func, diff_coeff_rough, 1, dim_r)

    σ = diff_coeff_rough(x, θ)
    L1_mu_R = L1(drift_rough)(x, θ)
    L1_mu_S2 = L1(drift_momentum)(x, θ)
    L1L0_mu_S2 = L1(L0(drift_momentum))(x, θ)
    L1L0_mu_S1 = L1(L0(drift_position))(x, θ)
    L1L0L0_mu_S1 = L1(L0(L0(drift_position)))(x, θ)

    Σ_1_RR = snp.array(2*t*σ*L1_mu_R/2)
    Σ_1_RS2 = snp.array(t*(σ*L1L0_mu_S2/6 + L1_mu_R*L1_mu_S2/3))
    Σ_1_RS1 = snp.array(t*(σ*L1L0L0_mu_S1/24 + L1_mu_R*L1L0_mu_S1/8))
    Σ_1_S2S2 = snp.array([2*t*L1_mu_S2*L1L0_mu_S2/8])
    Σ_1_S2S1 = snp.array([t*(L1_mu_S2*L1L0L0_mu_S1/30 + L1L0_mu_S2*L1L0_mu_S1/20)])
    Σ_1_S1S1 = snp.array([2*t*L1L0_mu_S1*L1L0L0_mu_S1/72])

    block_rows = (
        [Σ_1_RR, Σ_1_RS2, Σ_1_RS1],
        [Σ_1_RS2, Σ_1_S2S2, Σ_1_S2S1],
        [Σ_1_RS1, Σ_1_S2S1, Σ_1_S1S1],
    )
    return snp.concatenate(
        [snp.concatenate(row, axis=1) for row in block_rows],
        axis=0,
    )

def contrast_function_p_2(
    drift_func_smooth_1, drift_func_smooth_2, drift_func_rough, diff_coeff_rough
):
    """Build the symbolic p = 2 contrast: a local Gaussian log-density.

    The returned function computes
    -(m^T Σ^{-1} m / 2 + Σ_i log L_ii), with L the Cholesky factor of Σ.
    This equals a Gaussian log-density up to an additive constant.

    The four arguments are not used: the inner function reaches the
    module-level drift/diffusion functions through ``m_and_Σ_p_3``. They are
    kept so all generators share one call signature.
    """

    def one_step_contrast_function(x_t, x_0, θ, t):
        m_raw, Σ_raw = m_and_Σ_p_3(x_t, x_0, θ, t)
        m_vec = sympy.Matrix(m_raw)
        Σ_mat = sympy.Matrix(Σ_raw)
        chol_Σ = Σ_mat.cholesky(hermitian=False)
        invΣ = Σ_mat.inverse_CH()

        quadratic_form = (m_vec.T @ invΣ @ m_vec)[0, 0]
        half_log_det = snp.log(chol_Σ.diagonal()).sum()
        return -(quadratic_form / 2 + half_log_det)

    return one_step_contrast_function

def contrast_function_p_3(
    drift_func_smooth_1, drift_func_smooth_2, drift_func_rough, diff_coeff_rough
):
    """Build the symbolic p = 3 contrast: Gaussian log-density with a Σ_1 correction.

    Uses first-order expansions in Σ_1 of a Gaussian with covariance Σ + Σ_1:
    (Σ + Σ_1)^{-1} ≈ Σ^{-1} - Σ^{-1} Σ_1 Σ^{-1}, and
    log det(Σ + Σ_1) ≈ log det Σ + tr(Σ^{-1} Σ_1).

    The four arguments are not used (module-level functions are reached through
    ``m_and_Σ_p_3`` / ``sym_Σ_1``). They are kept for the shared generator signature.
    """

    def one_step_contrast_function(x_t, x_0, θ, t):
        m_raw, Σ_raw = m_and_Σ_p_3(x_t, x_0, θ, t)
        Σ_1_raw = sym_Σ_1(x_0, θ, t)
        m_vec = sympy.Matrix(m_raw)
        Σ_mat = sympy.Matrix(Σ_raw)
        Σ_1_mat = sympy.Matrix(Σ_1_raw)
        chol_Σ = Σ_mat.cholesky(hermitian=False)
        invΣ = Σ_mat.inverse_CH()
        invΣ_Σ_1 = sympy.Matrix(invΣ @ Σ_1_mat)

        # Same product as (invΣ @ Σ_1) @ invΣ, reusing the matrix computed above.
        corrected_precision = invΣ - invΣ_Σ_1 @ invΣ
        quadratic_form = (m_vec.T @ corrected_precision @ m_vec)[0, 0]
        half_log_det = snp.log(chol_Σ.diagonal()).sum()
        return -(quadratic_form / 2 + half_log_det + invΣ_Σ_1.trace() / 2)

    return one_step_contrast_function

In [None]:
# Symbolic contrast-function generators, keyed by a human-readable name.
symolic_log_transition_density_generators = {
    'local_gaussian (p = 2)': contrast_function_p_2,
    'local_gaussian (p = 3)': contrast_function_p_3,
}
# Compile each symbolic contrast into a JAX function of (x_t, x_0, θ, t).
# Argument shapes are given as dim_x, dim_x, dim_θ and None (presumably a scalar t).
jax_log_transition_densities = dict(
    (
        name,
        symnum.numpify(dim_x, dim_x, dim_θ, None, numpy_module=jnp)(
            make_contrast(drift_position, drift_momentum, drift_rough, diff_coeff_rough)
        ),
    )
    for name, make_contrast in symolic_log_transition_density_generators.items()
)

In [None]:
def get_log_likelihood_functions(log_transition_density):
    """Wrap a one-step log transition density into a jitted path log-likelihood.

    ``log_transition_density(x_next, x_prev, θ, dt)`` is mapped over consecutive
    pairs of ``x_seq`` and the time increments of ``t_seq``, then summed.
    Returns a dict with the single key ``'θ'`` holding ``f(θ, x_seq, t_seq)``.
    """
    # Map over (x_next, x_prev, dt) along axis 0; θ is shared across all steps.
    batched_density = vmap(log_transition_density, in_axes=(0, 0, None, 0))

    @jit
    def log_likelihood_θ(θ, x_seq, t_seq):
        step_sizes = jnp.diff(t_seq)  # t_seq[1:] - t_seq[:-1]
        return batched_density(x_seq[1:], x_seq[:-1], θ, step_sizes).sum()

    return {'θ': log_likelihood_θ}
 

In [None]:
# Select the integrator used to simulate the synthetic data. A dict literal evaluates
# every entry, so both step functions are constructed and only "local_gaussian_ii" is kept.
# dim_n is the number of standard-normal draws consumed per step (see how n_seqs is
# drawn in the estimation loop). 3*dim_r presumably covers the rough, momentum and
# position increments of the local Gaussian scheme — TODO confirm against simsde.
dim_n, step_func = {
    "euler_maruyama": (
        dim_r,
        simsde.integrators.euler_maruyama_step(drift_func, diff_coeff),
    ),
    "local_gaussian_ii": (
        3*dim_r,
        simsde.integrators.hypoelliptic_ii_local_gaussian_step(
        drift_func, drift_rough, drift_position, drift_momentum, diff_coeff_rough)
    )
}["local_gaussian_ii"]

# Turn the symbolic step into a JAX function jax_step_func(x, θ, n, dt), with argument
# shapes given as dim_x, dim_θ, dim_n and () (a scalar dt).
jax_step_func = symnum.numpify(dim_x, dim_θ, dim_n, (), numpy_module=jnp)(step_func)

@jit
def simulate_diffusion(x_0, θ, t_seq, n_seq):
    """Simulate a path on the grid ``t_seq`` starting from ``x_0``.

    ``n_seq[i]`` is the noise vector fed to ``jax_step_func`` for the step from
    ``t_seq[i]`` to ``t_seq[i + 1]``. The returned array stacks ``x_0`` followed
    by every simulated state, so it has ``len(t_seq)`` rows.
    """

    def advance(state, noise_and_dt):
        # One integrator step; the new state is both the carry and the output.
        noise, dt = noise_and_dt
        new_state = jax_step_func(state, θ, noise, dt)
        return new_state, new_state

    dt_seq = t_seq[1:] - t_seq[:-1]
    _, path = scan(advance, x_0, (n_seq, dt_seq))
    return jnp.concatenate((x_0[None], path))

In [None]:
# Simulation settings
rng = np.random.default_rng(20230204)
dt_simulation = 1e-4  # step size of the fine grid used to simulate synthetic data
T = 1000  # total time horizon of the simulated path
θ_true = jnp.array([2.0, 1.0, 4.0, 4.0])  # true parameter θ = (D, λ, α, σ)
x_0 = jnp.array([0.0, 0.0, 0.0])  # initial state (rough, momentum, position)
# round(), not int(): int() truncates, so a quotient that lands just below an
# integer (e.g. 0.3 / 0.1 == 2.9999999999999996) would lose a grid step.
t_seq_sim = np.arange(round(T / dt_simulation) + 1) * dt_simulation

In [None]:
def compute_complete_maximum_likelihood_estimates(
    log_likelihood, t_seq, x_seqs, θ_0, optimizer=adam, n_steps=8000, step_size=1e-1
):
    """Maximise ``log_likelihood["θ"](θ, x_seqs, t_seq)`` over θ.

    Runs exactly ``n_steps`` updates of a ``jax.example_libraries``-style
    optimiser (default: Adam with learning rate ``step_size``), starting at
    ``θ_0``, and returns the final parameters.
    """
    init_fn, update_fn, params_fn = optimizer(step_size)
    objective_and_gradient = value_and_grad(log_likelihood["θ"])

    @jit
    def ascend(step_index, opt_state, x_seq, t_seq):
        objective, gradient = objective_and_gradient(params_fn(opt_state), x_seq, t_seq)
        # The optimiser minimises, so pass the negated gradient to maximise.
        return objective, update_fn(step_index, -gradient, opt_state)

    opt_state = init_fn(θ_0)
    for step_index in range(n_steps):
        _, opt_state = ascend(step_index, opt_state, x_seqs, t_seq)

    return params_fn(opt_state)


In [None]:
num_sampling = 100
seed = 20231024

contrast_type_items = {
    0: 'local_gaussian (p = 2)',
    1: 'local_gaussian (p = 3)',
}

# Estimates: one row per contrast function (indexed by the keys above), one column per replicate.
D_sample_complete = np.empty((len(contrast_type_items), num_sampling))
λ_sample_complete = np.empty((len(contrast_type_items), num_sampling))
α_sample_complete = np.empty((len(contrast_type_items), num_sampling))
σ_sample_complete = np.empty((len(contrast_type_items), num_sampling))

dt_obs_items = [0.005]
for dt_obs in dt_obs_items:  # step size for the observations
    # Number of fine simulation steps per observation interval. round(), not int():
    # int() truncates, so a quotient like 2.9999999999999996 would silently give 2.
    sub_interval = round(dt_obs / dt_simulation)
    if sub_interval < 1 or not np.isclose(sub_interval * dt_simulation, dt_obs):
        raise ValueError(
            f"dt_obs={dt_obs} must be a positive integer multiple of "
            f"dt_simulation={dt_simulation}"
        )
    # The observation grid and the optimiser start are the same for every replicate.
    t_seq_obs = t_seq_sim[::sub_interval]
    θ_0 = jnp.array([1.0, 1.0, 1.0, 1.0])

    for k in range(num_sampling):
        print("Compute the observations -- Start")
        # Fresh, reproducible noise per replicate; seed advances once per replicate.
        rng = np.random.default_rng(seed)
        n_seqs = rng.standard_normal((t_seq_sim.shape[0] - 1, dim_n))
        x_seqs_sim = simulate_diffusion(x_0, θ_true, t_seq_sim, n_seqs)
        x_seq_obs = x_seqs_sim[::sub_interval]
        print("Compute the observations -- End")

        for key, contrast_type in contrast_type_items.items():
            log_transition_density = jax_log_transition_densities[contrast_type]
            log_likelihood = get_log_likelihood_functions(log_transition_density)
            print("Optimisation Complete Observation Adam -- Start", contrast_type, dt_obs)
            complete_adam = compute_complete_maximum_likelihood_estimates(
                log_likelihood, t_seq_obs, x_seq_obs, θ_0
            )
            print("Optimisation Complete Observation Adam -- End", contrast_type, dt_obs)
            print(complete_adam)
            print(value_and_grad(log_likelihood["θ"])(complete_adam, x_seq_obs, t_seq_obs))
            print(k)
            D_sample_complete[key, k] = complete_adam[0]
            λ_sample_complete[key, k] = complete_adam[1]
            α_sample_complete[key, k] = complete_adam[2]
            σ_sample_complete[key, k] = complete_adam[3]

        seed += 1

    # Record in one tab-separated csv file per contrast function; rows are D, λ, α, σ.
    for key, contrast_type in contrast_type_items.items():
        csv_path = f'MLE_non-linear_final_{contrast_type}=T_{T}_dt_obs_{dt_obs}_dt_sim_{dt_simulation}.csv'
        # `with` closes the file even if a write fails; the csv module expects newline=''.
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f, delimiter='\t')
            for samples in (D_sample_complete, λ_sample_complete, α_sample_complete, σ_sample_complete):
                writer.writerow(samples[key, :])