In [1]:
import numpy as np
import matplotlib.pyplot as plt
import csv
import time
import scipy.stats
from jax.scipy.stats import norm
import scipy.optimize
import simsde
import symnum
import symnum.numpy as snp
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve
from jax import jit, vmap, grad, value_and_grad
from jax.lax import scan
from jax.example_libraries.optimizers import adam
import matplotlib.pyplot as plt
# The configuration object is exposed as an attribute of the top-level `jax`
# package; importing it through the `jax.config` submodule path fails on
# newer JAX releases.
from jax import config
config.update('jax_enable_x64', True)  # use 64-bit floats/ints in JAX arrays
config.update('jax_platform_name', 'cpu')  # run all computations on the CPU backend
 

In [2]:
class hypo_gle_ho:
    """Gaussian one-step approximation and filter for a 3-dimensional linear model.

    The state is ``x = (q, h)`` where ``q`` is a scalar and ``h`` has two
    components.  One step of size ``dt`` is modelled as

        x_next ~ Normal(offset(q) + A(θ, dt) @ h,  Σ(θ, dt))

    with ``θ = (D, λ, α, σ)``.  Methods ending in ``_sim`` use
    ``step_size_sim``; all others use ``step_size_data``.

    The filtering methods treat ``q`` as observed and ``h`` as hidden and
    propagate the Gaussian law of ``h`` given the observed ``q`` values.

    Attributes:
        param: parameter vector θ = (D, λ, α, σ) (stored, not read by any method).
        initial_value: initial state ``x_0`` used by ``generate_sample_paths``.
        step_size_data: step size used by the filter and the contrast function.
        step_size_sim: step size used to generate sample paths.
        num_data: number of observations (stored, not read by any method).
        num_simulation: number of simulated steps in ``generate_sample_paths``.
    """

    def __init__(self, param, initial_value, step_size_data, step_size_sim, num_data, num_simulation):
        self.param = param # θ = (D, λ, α, σ)
        self.initial_value = initial_value
        self.step_size_data = step_size_data
        self.step_size_sim = step_size_sim
        self.num_data = num_data
        self.num_simulation = num_simulation

    # ------------------------------------------------------------------
    # Shared building blocks (single source of truth for the formulas)
    # ------------------------------------------------------------------
    @staticmethod
    def _transition_matrix(θ, dt):
        """Return the 3x2 matrix A(θ, dt) multiplying ``h`` in the one-step mean."""
        D, λ, α, σ = θ
        return jnp.array([
            [dt - (D + (λ**2)) * (dt**3) / 6, λ * (dt**2) / 2 - λ * α * (dt**3) / 6],
            [1 - (D + (λ**2)) * (dt**2) / 2, λ * dt - λ * α * (dt**2) / 2],
            [-λ * dt, 1 - α * dt],
        ])

    @staticmethod
    def _noise_covariance(θ, dt):
        """Return the 3x3 one-step covariance Σ(θ, dt)."""
        _, λ, _, σ = θ
        return (σ**2) * jnp.array([
            [(λ**2) * (dt**5) / 20, (λ**2) * (dt**4) / 8, λ * (dt**3) / 6],
            [(λ**2) * (dt**4) / 8, (λ**2) * (dt**3) / 3, λ * (dt**2) / 2],
            [λ * (dt**3) / 6, λ * (dt**2) / 2, dt],
        ])

    def _one_step_mean(self, q, h, θ, dt):
        """Return the length-3 one-step mean given ``q``, ``h`` and step size ``dt``."""
        D = θ[0]
        # Part of the mean that depends on q only; the remainder is linear in h.
        offset = jnp.array([q - D * q * (dt**2) / 2, -D * q * dt, 0])
        return offset + jnp.dot(self._transition_matrix(θ, dt), h)

    # ------------------------------------------------------------------
    # Public pieces of the one-step transition (data step size)
    # ------------------------------------------------------------------
    def calc_A_q(self, θ): # θ = (D, λ, α, σ)
        """Return the first row of ``matrix_A(θ)`` (the row producing ``q``), shape (2,)."""
        return self._transition_matrix(θ, self.step_size_data)[0]

    def calc_A_h(self, θ):
        """Return the last two rows of ``matrix_A(θ)`` (the rows producing ``h``), shape (2, 2)."""
        return self._transition_matrix(θ, self.step_size_data)[1:]

    def calc_mu_q(self, q, θ):
        """Return the ``q``-dependent part of the one-step mean of the next ``q``."""
        D = θ[0]
        dt = self.step_size_data
        return q - D*q*(dt**2)/2

    def matrix_A(self, θ):
        """Return the 3x2 transition matrix for the data step size."""
        return self._transition_matrix(θ, self.step_size_data)

    # This is used to generate the true trajectories of sample paths
    # x = (q, h)
    def mean_one_step_sim(self, x, θ):
        """Return the one-step mean from the full state ``x = (q, h)``, simulation step size."""
        return self._one_step_mean(x[0], x[1:], θ, self.step_size_sim)

    def mean_one_step(self, current_q, current_h, θ):
        """Return the one-step mean from ``current_q`` and ``current_h``, data step size."""
        return self._one_step_mean(current_q, current_h, θ, self.step_size_data)

    def covariance_one_step_sim(self, θ):
        """Return the 3x3 one-step covariance for the simulation step size."""
        return self._noise_covariance(θ, self.step_size_sim)

    def covariance_one_step(self, θ):
        """Return the 3x3 one-step covariance for the data step size."""
        return self._noise_covariance(θ, self.step_size_data)

    # ------------------------------------------------------------------
    # Simulation
    # ------------------------------------------------------------------
    def generate_sample_paths(self, θ, seed=20230606):
        """Simulate ``num_simulation`` steps starting from ``initial_value``.

        Args:
            θ: parameter vector (D, λ, α, σ).
            seed: seed of the noise sequence.

        Returns:
            Array of shape ``(num_simulation + 1, 3)`` whose first row is
            ``initial_value``.
        """
        # A private legacy generator yields the same draws as
        # `np.random.seed(seed)` + `np.random.multivariate_normal`, but leaves
        # the global NumPy random state untouched.
        rng = np.random.RandomState(seed)
        seq_rvs = rng.multivariate_normal(
            np.zeros(3), self.covariance_one_step_sim(θ), size=self.num_simulation
        )
        x_0 = self.initial_value

        @jit
        def step_func(x, noise):
            x_next = self.mean_one_step_sim(x, θ) + noise
            return x_next, x_next

        _, x_seq = scan(step_func, x_0, seq_rvs)

        return jnp.concatenate((x_0[None], x_seq))

    # ------------------------------------------------------------------
    # Filtering of h given observations of q
    # ------------------------------------------------------------------
    def prediction_covariance(self, forward_filter_covariance, θ):
        """Return the blocks of the predictive covariance ``Σ + A P Aᵀ``.

        Args:
            forward_filter_covariance: 2x2 covariance ``P`` of the current ``h``.
            θ: parameter vector.

        Returns:
            ``(qq, hq, hh)``: scalar variance of the next ``q``, length-2
            cross-covariance between next ``h`` and next ``q``, and 2x2
            covariance of the next ``h``.
        """
        Σ = self.covariance_one_step(θ)
        A = self.matrix_A(θ)
        pred_cov = Σ + A @ forward_filter_covariance @ A.T
        return pred_cov[0, 0], pred_cov[1:, 0], pred_cov[1:, 1:]

    def prediction_mean(self, q, forward_filter_mean, θ):
        """Return ``(mean of next q, mean of next h)`` given ``q`` and the filter mean of ``h``."""
        pred_mean = self.mean_one_step(q, forward_filter_mean, θ)
        return pred_mean[0], pred_mean[1:]

    def forward_filter_mean_cov_one_step(self, current_q, next_q, forward_filter_mean, forward_filter_covariance, θ):
        """Update the Gaussian law of ``h`` after observing ``next_q``.

        Standard Gaussian conditioning of the predicted ``(q, h)`` on the
        observed ``next_q``.

        Returns:
            ``(next_filter_mean, next_filter_cov)`` with shapes (2,) and (2, 2).
        """
        μ_q, μ_h = self.prediction_mean(current_q, forward_filter_mean, θ)
        Λ_qq, Λ_hq, Λ_hh = self.prediction_covariance(forward_filter_covariance, θ)
        next_filter_mean = μ_h + ((next_q - μ_q)/Λ_qq)*Λ_hq
        # Schur complement: Λ_hh - Λ_hq Λ_hqᵀ / Λ_qq
        next_filter_cov = Λ_hh - jnp.outer(Λ_hq, Λ_hq) / Λ_qq
        return next_filter_mean, next_filter_cov

    def forward_filter_mean_cov_paths_scan(self, q_paths, initial_mean, initial_cov, θ):
        """Run the filter along an observed path of ``q``.

        Args:
            q_paths: 1-D array of observed ``q`` values.
            initial_mean: length-2 mean of ``h`` at the first observation.
            initial_cov: 2x2 covariance of ``h`` at the first observation.
            θ: parameter vector.

        Returns:
            ``(means, covs)`` with one entry per element of ``q_paths``; the
            first entries are ``initial_mean`` and ``initial_cov``.
        """
        @jit
        def step_func(filter_mean_cov, q_paths_current_next):
            filter_mean, filter_cov = filter_mean_cov
            q_current, q_next = q_paths_current_next
            filter_next = self.forward_filter_mean_cov_one_step(q_current, q_next, filter_mean, filter_cov, θ)
            return filter_next, filter_next

        _, filter_mean_cov = scan(step_func, (initial_mean, initial_cov), (q_paths[:-1], q_paths[1:]))
        filter_mean, filter_cov = filter_mean_cov

        return jnp.concatenate((initial_mean[None], filter_mean)), jnp.concatenate((initial_cov[None], filter_cov))

    def get_contrast_function_scan(self, θ, q_paths, initial_mean, initial_cov):
        """Return the contrast ``-2 * log-likelihood`` of the observed ``q`` path.

        Each observed transition contributes the Gaussian log-density of the
        next ``q`` with mean ``calc_mu_q(q) + A_q · m`` and variance
        ``A_q P A_qᵀ + Σ[0, 0]``, where ``(m, P)`` is the filter law of ``h`` at
        the current time.

        Args:
            θ: parameter vector (D, λ, α, σ).
            q_paths: 1-D array of observed ``q`` values.
            initial_mean: length-2 initial filter mean of ``h``.
            initial_cov: 2x2 initial filter covariance of ``h``.

        Returns:
            Scalar contrast value.
        """
        filter_mean_paths, filter_cov_paths = self.forward_filter_mean_cov_paths_scan(q_paths, initial_mean, initial_cov, θ)
        # Term for the first observation; it does not depend on θ.
        initial_log_likelihood = norm.logpdf(q_paths[0], loc = q_paths[0], scale = 1.0)
        A_q = self.calc_A_q(θ)
        Σ = self.covariance_one_step(θ)

        @jit
        def step_func(loglikelihood, qset_and_filtermeancov):
            q_current, q_next, filter_mean, filter_cov = qset_and_filtermeancov
            q_mean = self.calc_mu_q(q_current, θ) + jnp.dot(A_q, filter_mean)
            q_variance = jnp.dot(A_q @ filter_cov, A_q) + Σ[0, 0]
            q_scale = jnp.sqrt(q_variance)
            loglikelihood_next = loglikelihood + norm.logpdf(q_next, q_mean, q_scale)
            return loglikelihood_next, None

        # The final carry is the accumulated log-likelihood, so the running
        # sums do not need to be stacked and indexed.
        total_log_likelihood, _ = scan(
            step_func,
            initial_log_likelihood,
            (q_paths[:-1], q_paths[1:], filter_mean_paths[:-1, :], filter_cov_paths[:-1, :, :]),
        )

        return -2*total_log_likelihood


In [3]:
def drift_position(x, θ):
    """Drift of the position coordinate: the state entry ``x[1]`` as a length-1 array.

    ``θ`` is accepted for a uniform signature but is not used.
    """
    velocity = x[1]
    return snp.array([velocity])

def drift_momentum(x, θ):
    """Drift of the momentum coordinate, ``-D * x[2] + λ * x[0]``, as a length-1 array.

    Only the first two entries of ``θ`` (``D`` and ``λ``) are used.
    """
    D, λ, *_ = θ
    # the potential function q -> V(q) is assumed to be V(q) = q^2 /2
    potential_term = - D*x[2]
    coupling_term = λ*x[0]
    return snp.array([potential_term + coupling_term])

def diff_coeff_rough(x, θ):
    """Diffusion coefficient of the rough coordinate: the 1x1 array ``[[σ]]``.

    ``σ`` is the last entry of ``θ``; the state ``x`` is not used.
    """
    *_, σ = θ
    single_row = [σ]
    return snp.array([single_row])

def drift_rough(x, θ):
    """Drift of the rough coordinate, ``-λ * x[1] - α * x[0]``, as a length-1 array.

    Uses the second and third entries of ``θ`` (``λ`` and ``α``).
    """
    _, λ, α, *_ = θ
    coupling_term = λ*x[1]
    damping_term = α*x[0]
    return snp.array([- coupling_term - damping_term])

def drift_smooth(x, θ):
    """Drift of the two smooth coordinates: momentum drift followed by position drift."""
    momentum_part = drift_momentum(x, θ)
    position_part = drift_position(x, θ)
    return snp.concatenate((momentum_part, position_part))

def drift_func(x, θ):
    """Full length-3 drift: rough drift followed by the smooth (momentum, position) drifts."""
    rough_part = drift_rough(x, θ)
    smooth_part = drift_smooth(x, θ)
    return snp.concatenate((rough_part, smooth_part))

def diff_coeff(x, θ):
    """Full 3x1 diffusion coefficient: ``σ`` in the first row, zeros in the other two.

    ``σ`` is the last entry of ``θ``; the state ``x`` is not used.
    """
    *_, σ = θ
    rows = [[σ], [0], [0]]
    return snp.array(rows)

# Problem dimensions.
# dim_x: length of the state vector (drift_func returns 3 entries).
# dim_θ: length of the parameter vector θ = (D, λ, α, σ).
dim_x, dim_θ = 3, 4
# Sizes of the individual components; each of them is 1 in this model.
# dim_r sizes the noise vectors of the integrators below.
# NOTE(review): dim_s1, dim_s2 and dim_w presumably are the two smooth
# components and the driving Wiener process — confirm against simsde.
dim_s1 = dim_s2 = dim_r = dim_w = 1

In [4]:
def _numpify_log_transition_density(symbolic_generator):
    """Build the symbolic log transition density and convert it to a JAX function.

    The generator receives the position, momentum and rough drifts and the
    rough diffusion coefficient; the result is numpified with
    ``jax.numpy`` as the array module.
    """
    symbolic_density = symbolic_generator(
        drift_position, drift_momentum, drift_rough, diff_coeff_rough
    )
    to_jax = symnum.numpify(dim_x, dim_x, dim_θ, None, numpy_module=jnp)
    return to_jax(symbolic_density)


# Label -> simsde generator of a symbolic log transition density.
# (The spelling of this name is kept as-is because other cells may refer to it.)
symolic_log_transition_density_generators = {
    'local_gaussian (p = 2)': simsde.densities.local_gaussian_log_transition_density_ii,
}
# Label -> JAX-callable log transition density.
jax_log_transition_densities = {
    key: _numpify_log_transition_density(generator)
    for key, generator in symolic_log_transition_density_generators.items()
}

In [5]:
def get_log_likelihood_functions(log_transition_density):
    """Wrap a pairwise log transition density into a path log-likelihood.

    Args:
        log_transition_density: callable ``(x_next, x_current, θ, dt) -> scalar``.

    Returns:
        Dict with the single key ``'θ'`` mapping to a jitted function
        ``(θ, x_seq, t_seq) -> scalar`` that sums the log transition densities
        over all consecutive pairs of ``x_seq``, using the time increments of
        ``t_seq``.
    """

    @jit
    def log_likelihood_θ(θ, x_seq, t_seq):
        x_next, x_current = x_seq[1:], x_seq[:-1]
        dt_seq = t_seq[1:] - t_seq[:-1]
        # Map over the pairs and their time steps; θ is shared by all terms.
        per_transition = vmap(log_transition_density, (0, 0, None, 0))
        return per_transition(x_next, x_current, θ, dt_seq).sum()

    functions = {'θ': log_likelihood_θ}
    return functions


In [6]:
# Available integrators: name -> (size of the noise vector per step, symbolic step function).
# Both entries are constructed; only the selected one is used afterwards.
_integrator_choices = {
    "euler_maruyama": (
        dim_r,
        simsde.integrators.euler_maruyama_step(drift_func, diff_coeff),
    ),
    "local_gaussian_ii": (
        3*dim_r,
        simsde.integrators.hypoelliptic_ii_local_gaussian_step(
            drift_func, drift_rough, drift_position, drift_momentum, diff_coeff_rough
        ),
    ),
}
dim_n, step_func = _integrator_choices["local_gaussian_ii"]

# JAX-callable version of the selected step: (x, θ, n, dt) -> next state.
_numpify_step = symnum.numpify(dim_x, dim_θ, dim_n, (), numpy_module=jnp)
jax_step_func = _numpify_step(step_func)

@jit
def simulate_diffusion(x_0, θ, t_seq, n_seq):
    """Simulate a path by repeatedly applying ``jax_step_func``.

    Args:
        x_0: initial state.
        θ: parameter vector passed to every step.
        t_seq: time grid; consecutive differences are the step sizes.
        n_seq: noise inputs, one row per step (``len(t_seq) - 1`` rows).

    Returns:
        Array of states with ``x_0`` as the first row followed by one row per step.
    """
    dt_seq = t_seq[1:] - t_seq[:-1]

    def advance(x_current, noise_and_dt):
        noise, dt = noise_and_dt
        x_new = jax_step_func(x_current, θ, noise, dt)
        # carry the new state forward and also record it
        return x_new, x_new

    _, x_path = scan(advance, x_0, (n_seq, dt_seq))

    return jnp.concatenate((x_0[None], x_path))

Computation of maximum likelihood estimates

In [7]:
# setting
rng = np.random.default_rng(20230204)
dt_simulation = 1e-4 # step size for synthetic data
dt_obs = 1e-3  # step size for the observation
T = 200 # Time length of data step
# round() instead of int(): a float ratio can land just below the intended
# integer (e.g. 0.3 / 0.1 == 2.9999999999999996) and int() would truncate it.
n_simulation = round(T / dt_simulation)  # number of simulation steps
sub_interval = round(dt_obs / dt_simulation)  # simulation steps per observation
n_data = round(T / dt_obs) # number of data
θ_true = jnp.array([1.0, 2.0, 4.0, 4.0]) # param θ = (D, λ, α, σ)
x_0 = jnp.array([0.0, 0.0, 0.0]) # initial value
t_seq_sim = np.arange(n_simulation + 1) * dt_simulation  # simulation time grid, 0 ... T
model = hypo_gle_ho(θ_true, x_0, dt_obs, dt_simulation, n_data, n_simulation)
# initial mean and covariance of the two hidden components used by the filter
initial_mean = jnp.array([0.0, 0.0])
initial_cov = jnp.array([[1.0, 0.0], [0.0, 1.0]])

In [8]:
def compute_contrast_estimator(q_obs, θ_0, initial_mean, initial_cov, optimizer=adam, n_steps=2000, step_size= 0.2):
    """Minimise ``model.get_contrast_function_scan`` over θ by gradient steps.

    Args:
        q_obs: observed path passed to the contrast function.
        θ_0: starting value of the parameter vector.
        initial_mean: initial filter mean passed to the contrast function.
        initial_cov: initial filter covariance passed to the contrast function.
        optimizer: factory ``step_size -> (init, update, get_params)``.
        n_steps: number of optimizer updates.
        step_size: step size handed to ``optimizer``.

    Returns:
        Parameter vector after ``n_steps`` updates.
    """
    init_state, apply_update, read_params = optimizer(step_size)

    @jit
    def optimizer_step(state, q_obs, initial_mean, initial_cov, step_index):
        θ_current = read_params(state)
        value, gradient = value_and_grad(model.get_contrast_function_scan)(
            θ_current, q_obs, initial_mean, initial_cov
        )
        # the contrast is minimised, so the gradient is used as is
        return value, apply_update(step_index, gradient, state)

    state = init_state(θ_0)

    for step_index in range(n_steps):
        _, state = optimizer_step(state, q_obs, initial_mean, initial_cov, step_index)

    θ_final = read_params(state)
    return θ_final


# JAX-callable log transition density registered under the label 'local_gaussian (p = 2)'.
log_transition_density = jax_log_transition_densities["local_gaussian (p = 2)"]
# Dict of log-likelihood functions built from that density; key 'θ' maps to
# the function (θ, x_seq, t_seq) -> summed log transition densities.
log_likelihood = get_log_likelihood_functions(log_transition_density)

def compute_complete_maximum_likelihood_estimates(
    t_seq, x_seqs, θ_0, optimizer=adam, n_steps=2000, step_size=0.2
):
    """Maximise ``log_likelihood["θ"]`` over θ by gradient steps.

    Args:
        t_seq: observation times.
        x_seqs: observed states; forwarded as the ``x_seq`` argument of the
            log-likelihood.
        θ_0: starting value of the parameter vector.
        optimizer: factory ``step_size -> (init, update, get_params)``.
        n_steps: number of optimizer updates.
        step_size: step size handed to ``optimizer``.

    Returns:
        Parameter vector after ``n_steps`` updates.
    """
    init_state, apply_update, read_params = optimizer(step_size)

    @jit
    def optimizer_step(step_index, state, x_seq, t_seq):
        θ_current = read_params(state)
        value, gradient = value_and_grad(log_likelihood["θ"])(
            θ_current, x_seq, t_seq
        )
        # the optimizer descends, so feed it the negated gradient to maximise
        ascent_direction = -gradient
        return value, apply_update(step_index, ascent_direction, state)

    state = init_state(θ_0)

    for step_index in range(n_steps):
        _, state = optimizer_step(step_index, state, x_seqs, t_seq)

    θ_final = read_params(state)
    return θ_final


In [9]:
num_sampling = 50
# One array of estimates per parameter, for the fully observed ("complete")
# and the position-only ("partial") setting.
D_sample_complete, λ_sample_complete, α_sample_complete, σ_sample_complete = (
    np.empty(num_sampling) for _ in range(4)
)
D_sample_partial, λ_sample_partial, α_sample_partial, σ_sample_partial = (
    np.empty(num_sampling) for _ in range(4)
)
seed = 20230624

for k in range(num_sampling):
    # Fresh noise for every replication; the seed is advanced at the end of the loop body.
    rng = np.random.default_rng(seed)
    n_seqs = rng.standard_normal((t_seq_sim.shape[0] - 1, dim_n))

    # Simulate on the fine grid, then keep every `sub_interval`-th point as data.
    print("Compute the observations -- Start")
    x_seqs_sim = simulate_diffusion(x_0, θ_true, t_seq_sim, n_seqs)
    x_seq_obs = x_seqs_sim[::sub_interval]
    print("Compute the observations -- End")
    t_seq_obs = t_seq_sim[::sub_interval]
    q_paths_obs = x_seq_obs[:, 2]  # third state component only

    # Both optimisations start from the same initial guess.
    θ_0 = jnp.array([3.0, 3.0, 3.0, 3.0])

    print("Optimisation Complete Observation Adam -- Start")
    complete_adam = compute_complete_maximum_likelihood_estimates(t_seq_obs, x_seq_obs, θ_0)
    print("Optimisation Complete Observation Adam -- End")
    print(complete_adam)
    print(value_and_grad(log_likelihood["θ"])(complete_adam, x_seq_obs, t_seq_obs))

    print("Optimisation Partial Observation Adam -- Start")
    partial_adam = compute_contrast_estimator(q_paths_obs, θ_0, initial_mean, initial_cov)
    print("Optimisation Partial Observation Adam -- End")
    print(partial_adam)
    print(value_and_grad(model.get_contrast_function_scan)(partial_adam, q_paths_obs, initial_mean, initial_cov))
    print(k)

    # Store the four estimated parameters (D, λ, α, σ) of this replication.
    (D_sample_complete[k], λ_sample_complete[k],
     α_sample_complete[k], σ_sample_complete[k]) = complete_adam
    (D_sample_partial[k], λ_sample_partial[k],
     α_sample_partial[k], σ_sample_partial[k]) = partial_adam
    seed += 1

Compute the observations -- Start
Compute the observations -- End
Optimisation Complete Observation Adam -- Start
Optimisation Complete Observation Adam -- End
[1.00008442 1.99981564 4.28380191 4.00012563]
(Array(5163086.65748791, dtype=float64), Array([-5.03665552e-04,  2.96984544e-02, -1.01830988e-10, -3.58488217e+01],      dtype=float64))
0
Compute the observations -- Start
Compute the observations -- End
Optimisation Complete Observation Adam -- Start
Optimisation Complete Observation Adam -- End
[1.00011341 2.00000162 3.69715559 4.00710152]
(Array(5162571.98027649, dtype=float64), Array([-1.07322134e-02,  2.14428814e-01, -1.14961207e-10, -3.19263166e+02],      dtype=float64))
1
Compute the observations -- Start
Compute the observations -- End
Optimisation Complete Observation Adam -- Start
Optimisation Complete Observation Adam -- End
[1.00005713 2.00015797 4.10820577 3.99858506]
(Array(5163700.91045112, dtype=float64), Array([ 2.66730226e-02,  2.13219463e-01, -1.70609749e-10, -2.

In [10]:
def _write_tsv_rows(path, rows):
    """Write each sequence in ``rows`` as one tab-separated line of the file ``path``."""
    # The context manager closes the file even if writing fails, and
    # newline='' lets the csv writer control line endings itself, as the csv
    # module documentation requires.
    with open(path, 'w', newline='') as f:
        csv.writer(f, delimiter='\t').writerows(rows)


# One line per parameter (D, λ, α, σ), one column per replication.
_write_tsv_rows(
    f'MLE_GLE_HO_complete={T}_dt_obs_{dt_obs}_dt_sim_{dt_simulation}.csv',
    [D_sample_complete, λ_sample_complete, α_sample_complete, σ_sample_complete],
)

_write_tsv_rows(
    f'MLE_GLE_HO_partial={T}_dt_obs_{dt_obs}_dt_sim_{dt_simulation}.csv',
    [D_sample_partial, λ_sample_partial, α_sample_partial, σ_sample_partial],
)