In [None]:
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm 
from jax import jit, vmap, value_and_grad
from jax.example_libraries.optimizers import adam
from jax.lax import scan 
import matplotlib.pyplot as plt

In [None]:
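### Class for Kalman-type filtering (and the associated contrast function) under the interacting FHN model.
### Each particle is filtered separately (the particle index is fixed); particles are coupled only through the empirical mean.
### Observed component: X (voltage-like, rough); hidden component: Y (recovery-like, smooth).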

class ips_fhn_filtering:
    """Kalman-type filter and contrast function for the partially observed interacting FHN model.

    Each particle is handled on its own (its index is fixed); particles are coupled only
    through ``empirical_mean``. Per particle, the observed coordinate is the voltage-like
    variable and the hidden coordinate is the recovery-like variable. The parameter vector
    is unpacked as ``θ = (a, b, τ, σ, κ)``.

    One observation step of size ``dt`` is approximated by a Gaussian transition whose mean
    is affine in the hidden state::

        mean = [calc_mean_obs(...), calc_mean_hidden(...)] + matrix_A(θ) * hidden
        cov  = covariance_one_step(θ)

    Filter means/variances of the hidden state are carried as shape-(1,) arrays.
    """

    def __init__(self, dt_data):
        # Time step between consecutive observations.
        self.step_size_data = dt_data

    def calc_coeff_mat_obs(self):
        """Coefficient of the hidden state in the observed-coordinate mean, shape (1,).

        This equals the first entry of ``matrix_A``.
        """
        dt = self.step_size_data
        return jnp.array([-dt])

    def calc_mean_obs(self, obs, θ, empirical_mean):
        """Part of the one-step mean of the observed coordinate that does not involve the hidden state."""
        dt = self.step_size_data
        *_, κ = θ
        return obs + dt * (obs - obs**3 / 3 - κ * (obs - empirical_mean))

    def calc_mean_hidden(self, obs, θ, empirical_mean):
        """Part of the one-step mean of the hidden coordinate that does not involve the hidden state.

        Contains the ``dt`` term and a ``dt**2 / 2`` correction; the hidden state's own
        contribution is carried by ``matrix_A``.
        """
        dt = self.step_size_data
        a, b, τ, _, κ = θ
        obs_drift = obs - obs**3 / 3 - κ * (obs - empirical_mean)
        return ((obs + a) / τ) * dt + 0.5 * dt**2 * (obs_drift / τ - b * (obs + a) / (τ**2))

    def matrix_A(self, θ):
        """Coefficients of the hidden state in the one-step mean of (obs, hidden), shape (2,)."""
        a, b, τ, *_ = θ
        dt = self.step_size_data
        return jnp.array(
            [
                -dt,
                1 - dt * b / τ + 0.5 * (dt**2) * (-1 / τ + (b / τ) ** 2),
            ]
        )

    def mean_one_step(self, obs, hidden, θ, empirical_mean):
        """One-step predicted mean of (obs, hidden), shape (2,).

        ``hidden`` is the current filter mean of the hidden state; it enters linearly
        through ``matrix_A``.
        """
        mean_vec = jnp.array([
            self.calc_mean_obs(obs, θ, empirical_mean),
            self.calc_mean_hidden(obs, θ, empirical_mean),
        ])
        return mean_vec + self.matrix_A(θ) * hidden

    def covariance_one_step(self, θ):
        """One-step noise covariance of (obs, hidden): a 2x2 matrix scaled by ``σ**2``."""
        dt = self.step_size_data
        *_, τ, σ, κ = θ
        return (σ**2) * jnp.array([
            [dt, dt**2 / (2 * τ)],
            [dt**2 / (2 * τ), (dt**3) / (3 * τ**2)],
        ])

    def prediction_covariance(self, forward_filter_covariance, θ):
        """Blocks of the one-step predicted covariance Λ = Σ + A Aᵀ P of (obs, hidden).

        Args:
            forward_filter_covariance: filter variance P of the hidden state (shape (1,) or scalar).
            θ: parameter vector (a, b, τ, σ, κ).

        Returns:
            (Λ_oo, Λ_ho, Λ_hh): covariance of (obs, obs), (hidden, obs) and (hidden, hidden).
        """
        Σ = self.covariance_one_step(θ)
        A = self.matrix_A(θ)
        # A is a 1-D vector, so `A @ A.T` would be the inner product (a scalar added to
        # every entry). Propagating the hidden-state variance needs the outer product A Aᵀ.
        pred_cov = Σ + jnp.outer(A, A) * forward_filter_covariance
        pred_cov_oo = pred_cov[0, 0]  # covariance of (obs, obs)
        pred_cov_ho = pred_cov[1, 0]  # covariance of (hidden, obs)
        pred_cov_hh = pred_cov[1, 1]  # covariance of (hidden, hidden)
        return pred_cov_oo, pred_cov_ho, pred_cov_hh

    def prediction_mean(self, obs, forward_filter_mean, θ, empirical_mean):
        """One-step predicted means (μ_obs, μ_hidden) given the current observation and filter mean."""
        pred_mean = self.mean_one_step(obs, forward_filter_mean, θ, empirical_mean)
        pred_mean_obs = pred_mean[0]
        pred_mean_hidden = pred_mean[1]
        return pred_mean_obs, pred_mean_hidden

    def forward_filter_mean_cov_one_step(self, current_obs, next_obs, forward_filter_mean, forward_filter_covariance, θ, empirical_mean):
        """Predict one step ahead, then condition the hidden state on ``next_obs``.

        Returns:
            (mean, variance) of the updated hidden-state filter, each as a shape-(1,) array.
        """
        μ_o, μ_h = self.prediction_mean(current_obs, forward_filter_mean, θ, empirical_mean)
        Λ_oo, Λ_ho, Λ_hh = self.prediction_covariance(forward_filter_covariance, θ)
        # Gaussian conditioning of the hidden coordinate on the observed one.
        next_filter_mean = μ_h + ((next_obs - μ_o) / Λ_oo) * Λ_ho
        next_filter_cov = Λ_hh - (Λ_ho**2) / Λ_oo
        return jnp.array([next_filter_mean]), jnp.array([next_filter_cov])

    def forward_filter_mean_cov_paths_scan(self, obs_path, initial_mean, initial_cov, θ, empirical_mean_path):
        """Run the filter along one particle's observed path.

        Args:
            obs_path: observations of one particle, indexed by time.
            initial_mean, initial_cov: shape-(1,) filter mean/variance at the first observation.
            θ: parameter vector (a, b, τ, σ, κ).
            empirical_mean_path: empirical mean of the observed coordinate at each time.

        Returns:
            Filter means and variances at every observation time (including the initial ones),
            each with shape (len(obs_path), 1).
        """
        def step_func(filter_mean_cov, path_current_next):
            filter_mean, filter_cov = filter_mean_cov
            obs_current, obs_next, empirical_mean_current = path_current_next
            filter_next = self.forward_filter_mean_cov_one_step(obs_current, obs_next, filter_mean, filter_cov, θ, empirical_mean_current)
            return filter_next, filter_next

        # scan already traces and compiles step_func; no separate jit is needed.
        _, filter_mean_cov = scan(step_func, (initial_mean, initial_cov), (obs_path[:-1], obs_path[1:], empirical_mean_path[:-1]))
        filter_mean, filter_cov = filter_mean_cov

        return jnp.concatenate((initial_mean[None], filter_mean)), jnp.concatenate((initial_cov[None], filter_cov))

    def get_contrast_function_fixed_index(self, θ, obs_path, initial_mean, initial_cov, empirical_mean_path):
        """Contrast (-2 × approximate log-likelihood) of one particle's observed path.

        Each transition contributes the Gaussian log-density of the next observation under
        the one-step predictive distribution built from the current filter mean/variance.
        The initial term ``norm.logpdf(x0; x0, 1)`` is a constant that does not depend on θ.
        """
        filter_mean_path, filter_cov_path = self.forward_filter_mean_cov_paths_scan(obs_path, initial_mean, initial_cov, θ, empirical_mean_path)
        initial_log_likelihood = norm.logpdf(obs_path[0], loc=obs_path[0], scale=1.0)
        Σ = self.covariance_one_step(θ)
        A_q = self.calc_coeff_mat_obs()  # loop-invariant, hoisted out of the scan body

        def step_func(loglikelihood, qset_filtermeancov_empiricalmean):
            q_current, q_next, filter_mean, filter_cov, empirical_mean = qset_filtermeancov_empiricalmean
            q_mean = self.calc_mean_obs(q_current, θ, empirical_mean) + jnp.dot(A_q, filter_mean)
            # Predictive variance of the next observation: A_q P A_qᵀ + Σ_oo.
            scalar = jnp.dot(filter_cov * A_q, A_q)
            q_scale = jnp.sqrt(scalar + Σ[0, 0])
            loglikelihood_next = loglikelihood + norm.logpdf(q_next, q_mean, q_scale)
            return loglikelihood_next, loglikelihood_next

        _, log_likelihood_seq = scan(
            step_func,
            initial_log_likelihood,
            (obs_path[:-1], obs_path[1:], filter_mean_path[:-1], filter_cov_path[:-1], empirical_mean_path[:-1]),
        )

        return -2 * log_likelihood_seq[-1]

    def get_contrast_function(self, θ, obs_path_particles, initial_mean, initial_cov, empirical_mean_path):
        """Sum of per-particle contrasts; particles are taken along axis 1 of ``obs_path_particles``."""
        contrast_function = jnp.sum(
            vmap(self.get_contrast_function_fixed_index, (None, 1, None, None, None))(
                θ, obs_path_particles, initial_mean, initial_cov, empirical_mean_path
            )
        )

        return contrast_function

In [None]:
# Define the interaction term
def interaction_term(v_i, v_j):
    """Pairwise coupling felt by particle i from particle j: the difference ``v_j - v_i``."""
    difference = v_j - v_i
    return difference

# Function to run a single simulation using JAX and scan
def run_simulation(seed, num_particles, num_steps, dt, model_parameters, discard_steps=0):
    """Simulate the interacting FitzHugh-Nagumo particle system with an Euler scheme.

    All particles start from v = w = 0. Gaussian noise (scaled by sigma * sqrt(dt)) enters
    only the voltage-like variable v; the coupling of particle i is
    kappa / num_particles * sum_j interaction_term(v_i, v_j).

    Args:
        seed: integer seed for ``jax.random.PRNGKey``.
        num_particles: number of particles.
        num_steps: number of Euler steps.
        dt: simulation time step.
        model_parameters: dict with keys 'alpha', 'beta', 'tau', 'sigma', 'kappa'.
        discard_steps: number of leading steps dropped from the output.

    Returns:
        Array of shape (num_steps - discard_steps, num_particles, 2); entry [k, i] is
        (v, w) of particle i after step discard_steps + k + 1.
    """
    alpha, beta, tau, sigma, kappa = (
        model_parameters[name] for name in ('alpha', 'beta', 'tau', 'sigma', 'kappa')
    )

    def total_interaction(v_all):
        # Entry i is sum_j interaction_term(v_i, v_j), evaluated with nested vmaps.
        def for_particle(v_i):
            return jnp.sum(jax.vmap(lambda v_j: interaction_term(v_i, v_j))(v_all))
        return jax.vmap(for_particle)(v_all)

    def euler_step(state, _):
        v_now, w_now, rng_key = state
        coupling = total_interaction(v_now)

        rng_key, draw_key = jax.random.split(rng_key)
        noise = jax.random.normal(draw_key, shape=(num_particles,))

        dv = (v_now - v_now**3 / 3 - w_now + kappa * coupling / num_particles) * dt + sigma * jnp.sqrt(dt) * noise
        dw = ((v_now + alpha - beta * w_now) / tau) * dt

        v_next = v_now + dv
        w_next = w_now + dw
        return (v_next, w_next, rng_key), (v_next, w_next)

    start_state = (jnp.zeros(num_particles), jnp.zeros(num_particles), jax.random.PRNGKey(seed))
    _, (v_path, w_path) = jax.lax.scan(euler_step, start_state, jnp.arange(num_steps))

    trajectory = jnp.stack((v_path, w_path), axis=-1)  # (num_steps, num_particles, 2)
    return trajectory[discard_steps:, :, :]


In [None]:
# Plot the results

def showplots(x_seq_sim, x_seq_obs, t_seq_sim, t_seq_obs):
    """Plot the voltage-like variable v on the simulation grid and on the observation grid.

    Only v is plotted; the recovery-like variable w is not shown.

    Args:
        x_seq_sim: values of v on the simulation time grid ``t_seq_sim``. For a 2-D input
            (time, particle), matplotlib draws one series per particle.
        x_seq_obs: values of v on the (subsampled) observation time grid ``t_seq_obs``.
        t_seq_sim: time points matching the first axis of ``x_seq_sim``.
        t_seq_obs: time points matching the first axis of ``x_seq_obs``.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(t_seq_sim, x_seq_sim, '.', alpha=0.5, markersize=0.5)
    ax.plot(t_seq_obs, x_seq_obs, '.', alpha=0.5, markersize=0.5)
    ax.set_title('Voltage-like Variable (v) Over Time')
    ax.set_xlabel('Time')
    ax.set_ylabel('v')
    ax.grid()
    plt.show()
    # Release the figure so repeated calls (e.g. once per loop iteration) don't accumulate open figures.
    plt.close(fig)

In [None]:
# Simulation setting 
num_particles = 100  # Number of particles
dt_sim = 0.0005            # Time step for simulation 
dt_obs = 0.005          # Time step for observation (subsampling) 
num_steps = 60000     # Number of time steps for simulation
discard_steps = 0 # Steps to discard for initial transients  
initial_mean = jnp.array([0.0])  # Filter mean of the hidden variable at the first observation
initial_cov = jnp.array([1.0])   # Filter variance of the hidden variable at the first observation

# FitzHugh-Nagumo model parameters
α = 0.2  # Parameter controlling excitability
β = 0.8  # Parameter controlling recovery
τ = 1.5  # Timescale for recovery variable
σ = 0.5  # Noise strength
κ = 2.0  # Strength of interaction between particles

model_parameters = {'alpha': α, 'beta': β, 'tau': τ, 'sigma': σ, 'kappa': κ} 

num_iter = 100  # Number of independent simulate-and-estimate repetitions
θ_seq_lg = np.zeros((5, num_iter))  # LG-contrast estimates, one column per repetition
# NOTE(review): no cell shown here fills θ_seq_em, so it is exported to CSV as all zeros.
θ_seq_em = np.zeros((5, num_iter))

# Observation stride in simulation steps. Use round() rather than int(): a float ratio can
# land just below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() would truncate it.
sub_interval = round(dt_obs / dt_sim)
assert sub_interval >= 1 and np.isclose(sub_interval * dt_sim, dt_obs), \
    "dt_obs must be a positive integer multiple of dt_sim"

# run_simulation returns the state *after* each Euler step, so row k of its output
# corresponds to time (discard_steps + k + 1) * dt_sim.
t_seq_sim = np.arange(discard_steps + 1, num_steps + 1) * dt_sim
t_seq_obs = t_seq_sim[::sub_interval] 

ips_fhn = ips_fhn_filtering(dt_obs)

θ_true = jnp.array([model_parameters['alpha'], model_parameters['beta'], model_parameters['tau'], model_parameters['sigma'], model_parameters['kappa']]) 

# Preview simulation with the first seed (the estimation loop regenerates paths for every repetition).
xy_seq_particles_sim = run_simulation(seed=int(20250625), 
num_particles=num_particles, num_steps=num_steps, dt=dt_sim, model_parameters=model_parameters, discard_steps=discard_steps)

x_seq_particles_sim = xy_seq_particles_sim[:, :, 0]  # Voltage variable
x_seq_particles_obs = x_seq_particles_sim[::sub_interval, :]  # Downsampled for observation

empirical_mean_seq = x_seq_particles_obs[:, :].mean(axis=1)  # Mean of v across particles at each observation time

In [None]:
def compute_contrast_estimator(objective_func, obs_path_particles, θ_0, initial_mean, initial_cov, empirical_mean_seq, optimizer=adam, n_steps=5000, step_size=0.005):
    """Minimise ``objective_func`` over θ with a ``jax.example_libraries.optimizers`` optimiser.

    Args:
        objective_func: callable ``(θ, obs_path_particles, initial_mean, initial_cov,
            empirical_mean_seq) -> scalar``; it is differentiated with respect to θ only.
        obs_path_particles, initial_mean, initial_cov, empirical_mean_seq: data passed
            unchanged to ``objective_func`` at every step.
        θ_0: starting parameter vector.
        optimizer: optimiser factory returning ``(init, update, get_params)``.
        n_steps: number of optimiser updates.
        step_size: learning rate passed to ``optimizer``.

    Returns:
        The parameter vector after ``n_steps`` updates.
    """
    init_fn, update_fn, params_fn = optimizer(step_size)
    objective_value_and_grad = value_and_grad(objective_func)

    @jit
    def take_step(opt_state, data, mean0, cov0, emp_mean, step_index):
        # The objective value is not needed; only the gradient drives the update.
        _, gradient = objective_value_and_grad(params_fn(opt_state), data, mean0, cov0, emp_mean)
        return update_fn(step_index, gradient, opt_state)

    opt_state = init_fn(θ_0)
    for step_index in range(n_steps):
        opt_state = take_step(opt_state, obs_path_particles, initial_mean, initial_cov, empirical_mean_seq, step_index)

    return params_fn(opt_state)

In [None]:
# Repeat the experiment num_iter times: simulate fresh paths (seed 20250625 + i), then fit
# θ = (a, b, τ, σ, κ) by minimising the LG-based contrast over all particles.
for i in range(num_iter):
    print(f"Iteration {i+1}")
    print("Compute the sample paths")
    xy_seq_particles_sim = run_simulation(
        seed=int(20250625 + i),
        num_particles=num_particles,
        num_steps=num_steps,
        dt=dt_sim,
        model_parameters=model_parameters,
        discard_steps=discard_steps,
    )

    x_seq_particles_sim = xy_seq_particles_sim[..., 0]  # voltage-like variable v of every particle
    x_seq_particles_obs = x_seq_particles_sim[::sub_interval]  # subsampled onto the observation grid
    empirical_mean_seq = x_seq_particles_obs.mean(axis=1)  # mean of v across particles at each observation time

    showplots(x_seq_particles_sim, x_seq_particles_obs, t_seq_sim, t_seq_obs)

    θ_0 = np.array([1.0, 1.0, 1.0, 1.0, 0.0])  # initial guess for (a, b, τ, σ, κ)

    print("Optimising LG-based contrast estimator starts.")
    θ_seq_lg[:, i] = compute_contrast_estimator(
        ips_fhn.get_contrast_function,
        x_seq_particles_obs,
        θ_0,
        initial_mean,
        initial_cov,
        empirical_mean_seq,
    )
    print(θ_seq_lg[:, i])
    print("Optimising LG-based contrast estimator ends.")




In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Relative estimation error (θ̂ - θ_true) / θ_true of each parameter across the repetitions,
# shown as a boxplot with the individual estimates overlaid.
num_parameters = θ_seq_lg.shape[0]

# squeeze=False keeps `axes` indexable even when there is a single parameter.
fig, axes = plt.subplots(1, num_parameters, figsize=(15, 5), sharey=True, squeeze=False)
axes = axes[0]

reference_values = [model_parameters['alpha'], model_parameters['beta'], model_parameters['tau'], model_parameters['sigma'], model_parameters['kappa']]
parameter_labels = ['α', 'β', 'τ', 'σ', 'κ']  # Labels in the same order as reference_values

# Seeded generator so the horizontal jitter is reproducible across re-runs.
jitter_rng = np.random.default_rng(0)

for i in range(num_parameters):
    ax = axes[i]
    rel_error = (θ_seq_lg[i, :] - reference_values[i]) / reference_values[i]

    ax.boxplot(rel_error)
    # Set the tick label explicitly: boxplot's `labels=` keyword was renamed to
    # `tick_labels` in Matplotlib 3.9, so avoid depending on either spelling.
    ax.set_xticks([1])
    ax.set_xticklabels([parameter_labels[i]])
    ax.grid(True)

    # Overlay the individual estimates with slight horizontal jitter around the box at x = 1.
    x_positions = jitter_rng.normal(1, 0.05, size=len(rel_error))
    ax.scatter(x_positions, rel_error, alpha=0.6, color='red', s=10)

# Shared y-axis label placed in the left margin reserved by tight_layout below.
fig.text(0.04, 0.5, 'Relative Error', va='center', rotation='vertical')

fig.tight_layout(rect=[0.05, 0, 1, 0.95])
plt.show()


In [None]:
import pandas as pd

# Combine θ_seq_lg and θ_seq_em into a single DataFrame
# One column per (method, parameter index) and one row per repetition; LG columns come first.
# NOTE(review): θ_seq_em is not filled by any cell shown here, so the EM_θ_* columns are zeros.
lg_columns = {f"LG_θ_{k+1}": row for k, row in enumerate(θ_seq_lg)}
em_columns = {f"EM_θ_{k+1}": row for k, row in enumerate(θ_seq_em)}
data = {**lg_columns, **em_columns}

df = pd.DataFrame(data)

output_file = "ips_fhn_partial.csv"
df.to_csv(output_file, index=False)
print(f"Data successfully written to {output_file}")