In [None]:
import numpy as np
import matplotlib.pyplot as plt
import symnum
import symnum.numpy as snp
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, vmap, value_and_grad
from jax.example_libraries.optimizers import adam
import matplotlib.pyplot as plt
# from jax.config import config
# config.update('jax_enable_x64', True)
# config.update('jax_platform_name', 'cpu')
import density_ips_fhn

In [None]:
import jax
import jax.numpy as jnp

# Define the interaction term
def interaction_term(v_i, v_j):
    """Pairwise coupling felt by particle ``i`` from particle ``j``.

    This is the linear kernel ``v_j - v_i``: positive when ``v_j`` exceeds ``v_i``.
    Works element-wise on scalars or arrays.
    """
    difference = v_j - v_i
    return difference

# Function to run a single simulation using JAX and scan
def run_simulation(seed, num_particles, num_steps, dt, model_parameters, discard_steps=0):
    """Simulate a mean-field coupled stochastic FitzHugh-Nagumo particle system.

    Each particle i is advanced by Euler-Maruyama with step ``dt``:

        v_i <- v_i + (v_i - v_i**3 / 3 - w_i + κ/N * Σ_j (v_j - v_i)) dt + σ sqrt(dt) ξ_i
        w_i <- w_i + ((v_i + α - β w_i) / τ) dt

    with ξ_i standard normal. All particles start from v = w = 0.

    Args:
        seed: integer seed for ``jax.random.PRNGKey``.
        num_particles: number of particles N.
        num_steps: number of Euler-Maruyama steps.
        dt: simulation step size.
        model_parameters: dict with keys 'alpha', 'beta', 'tau', 'sigma', 'kappa'.
        discard_steps: number of leading recorded steps to drop.

    Returns:
        Array of shape ``(num_steps - discard_steps, num_particles, 2)``, where
        ``[..., 0]`` is v and ``[..., 1]`` is w. Entry k holds the state *after*
        step ``discard_steps + k + 1``; the initial zero state is not recorded.
    """
    key = jax.random.PRNGKey(seed)

    α = model_parameters['alpha']
    β = model_parameters['beta']
    τ = model_parameters['tau']
    σ = model_parameters['sigma']
    κ = model_parameters['kappa']

    # Initial state: voltage-like v and recovery-like w, all zero.
    v0 = jnp.zeros(num_particles)
    w0 = jnp.zeros(num_particles)
    sqrt_dt = jnp.sqrt(dt)

    def step(carry, _):
        v, w, key = carry

        # Σ_j interaction_term(v_i, v_j) = Σ_j (v_j - v_i) = Σ_j v_j - N v_i.
        # The closed form is O(N) per step instead of materialising all N² pairs.
        interaction = jnp.sum(v) - num_particles * v

        key, subkey = jax.random.split(key)
        noise = jax.random.normal(subkey, shape=(num_particles,))

        dv = (v - v**3 / 3 - w + κ * interaction / num_particles) * dt + σ * sqrt_dt * noise
        dw = ((v + α - β * w) / τ) * dt

        v = v + dv
        w = w + dw
        return (v, w, key), (v, w)

    # No per-step input is needed, so scan over `length` steps with xs=None.
    _, (v_history, w_history) = jax.lax.scan(step, (v0, w0, key), None, length=num_steps)

    vw_pairs = jnp.stack((v_history, w_history), axis=-1)  # (num_steps, num_particles, 2)
    return vw_pairs[discard_steps:, :, :]

In [None]:
# Plot the results

def showplots(x_seq_sim, x_seq_obs, t_seq_sim, t_seq_obs):
    """Scatter the simulated and observed paths of each state component.

    Draws two figures, one for v (component 0) and one for w (component 1).
    Each figure shows every particle on the fine simulation grid, followed by
    every particle on the subsampled observation grid.

    Args:
        x_seq_sim: array indexed as [time, particle, component] on the simulation grid.
        x_seq_obs: same layout on the observation grid.
        t_seq_sim: times matching axis 0 of ``x_seq_sim``.
        t_seq_obs: times matching axis 0 of ``x_seq_obs``.
    """
    panels = (
        (0, 'Voltage-like Variable (v) Over Time', 'v'),
        (1, 'Recovery-like Variable (w) Over Time', 'w'),
    )
    marker_style = dict(alpha=0.5, markersize=0.5)

    for component, title, y_label in panels:
        plt.figure(figsize=(12, 6))
        for times, paths in ((t_seq_sim, x_seq_sim), (t_seq_obs, x_seq_obs)):
            plt.plot(times, paths[:, :, component], '.', **marker_style)
        plt.title(title)
        plt.xlabel('Time')
        plt.ylabel(y_label)
        plt.grid()
        plt.show()


In [None]:
def drift_func_rough(x, θ, empirical_mean):
    """Symbolic drift of the rough (voltage-like) component v.

    Returns the 1-element array ``[v - v**3/3 - w - κ (v - empirical_mean)]``.
    Here ``v, w = x[0], x[1]``, and κ is the last entry of θ.
    """
    *_, coupling = θ
    v, w = x[0], x[1]
    mean_field_pull = coupling * (v - empirical_mean)
    return snp.array([v - v**3 / 3 - w - mean_field_pull])
    
def drift_func_smooth(x, θ):
    """Symbolic drift of the smooth (recovery-like) component w.

    Returns ``[(v + α - β w) / τ]`` with ``v, w = x[0], x[1]``.
    α, β and τ are the first three entries of θ.
    """
    alpha, beta, tau, *_ = θ
    v, w = x[0], x[1]
    return snp.array([(v + alpha - beta * w) / tau])

def diff_coeff_rough(x, θ):
    """Symbolic 1x1 diffusion coefficient ``[[σ]]`` of the rough component.

    σ is the second-to-last entry of θ. ``x`` is not used (constant noise).
    """
    *_, sigma, _kappa = θ
    return snp.array([[sigma]])
    
def diff_coeff(x, θ):
    """Symbolic 2x1 diffusion matrix ``[[σ], [0]]`` for the full (v, w) state.

    Noise enters the first component only. σ is the second-to-last entry of θ.
    """
    *_, sigma, _kappa = θ
    noise_on_v, noise_on_w = [sigma], [0]
    return snp.array([noise_on_v, noise_on_w])


In [None]:
dim_x = 2  # state dimension: (v, w)
dim_θ = 5  # parameter vector length: (α, β, τ, σ, κ)

symolic_log_transition_density_generators = {
    'local_gaussian': density_ips_fhn.local_gaussian_log_transition_density,
    'euler_maruyama': density_ips_fhn.euler_maruyama_log_transition_density_rough,
}

# Build each symbolic log transition density from the model's drift/diffusion functions,
# then compile it to a JAX-backed function via symnum.numpify. The argument sizes are
# (dim_x, dim_x, dim_θ, scalar, scalar); get_log_likelihood_functions calls the result as
# (next state, previous state, θ, time step, empirical mean).
jax_log_transition_densities = dict(
    (
        name,
        symnum.numpify(dim_x, dim_x, dim_θ, None, None, numpy_module=jnp)(
            build_density(drift_func_rough, drift_func_smooth, diff_coeff_rough)
        ),
    )
    for name, build_density in symolic_log_transition_density_generators.items()
)


In [None]:
def get_log_likelihood_functions(log_transition_density):
    """Build the particle-system log-likelihood (contrast) from a transition density.

    Args:
        log_transition_density: callable
            ``(x_next, x_prev, θ, dt, empirical_mean) -> scalar`` giving the log
            transition density of one particle over one time step.

    Returns:
        Dict with the single key ``'θ'``. It maps to a jitted function
        ``(θ, x_seq_particles, t_seq, empirical_mean_seq) -> scalar``, which sums the
        log transition densities over all consecutive time pairs and all particles.
        ``x_seq_particles`` is mapped over axis 1 (particles). Each pair uses the
        empirical mean at its *earlier* time.
    """
    # Batch the density over time pairs; θ is shared by every pair.
    batched_density = vmap(log_transition_density, (0, 0, None, 0, 0))

    @jit
    def log_likelihood_θ_fixed_index(θ, x_seq, t_seq, empirical_mean_seq):
        # Log-likelihood contribution of a single particle over the whole time grid.
        x_next, x_prev = x_seq[1:], x_seq[:-1]
        dt_seq = t_seq[1:] - t_seq[:-1]
        mean_prev = empirical_mean_seq[:-1]
        return batched_density(x_next, x_prev, θ, dt_seq, mean_prev).sum()

    # Evaluate the single-particle likelihood for every particle (axis 1), sharing θ and time data.
    per_particle = vmap(log_likelihood_θ_fixed_index, (None, 1, None, None))

    @jit
    def log_likelihood_θ(θ, x_seq_particles, t_seq, empirical_mean_seq):
        # Full likelihood: sum of the per-particle contributions.
        return jnp.sum(per_particle(θ, x_seq_particles, t_seq, empirical_mean_seq))

    return {'θ': log_likelihood_θ}
 

In [None]:
def compute_complete_maximum_likelihood_estimates(
    log_likelihood, t_seq, x_seq_particles, empirical_mean_seq, θ_0, optimizer=adam, n_steps=8000, step_size=0.01):
    """Maximise ``log_likelihood['θ']`` over θ by gradient ascent.

    Runs ``n_steps`` updates of a ``jax.example_libraries.optimizers``-style
    optimizer, by default Adam with learning rate ``step_size``, starting from
    ``θ_0``. The optimizer minimises, so it is fed the *negated* gradient to
    perform ascent.

    Args:
        log_likelihood: dict whose ``'θ'`` entry is
            ``(θ, x_seq_particles, t_seq, empirical_mean_seq) -> scalar``.
        t_seq, x_seq_particles, empirical_mean_seq: data passed through unchanged.
        θ_0: initial parameter vector.
        optimizer: optimizer factory returning ``(init, update, get_params)``.
        n_steps: number of optimisation steps.
        step_size: learning rate given to ``optimizer``.

    Returns:
        The parameters after the last step (not necessarily the best iterate seen).
    """
    opt_init, opt_update, opt_params = optimizer(step_size)
    objective_and_grad = value_and_grad(log_likelihood["θ"])

    @jit
    def optimizer_step(step_index, state, x_seq_particles, t_seq, empirical_mean_seq):
        value, grad = objective_and_grad(
            opt_params(state), x_seq_particles, t_seq, empirical_mean_seq
        )
        return value, opt_update(step_index, -grad, state)

    opt_state = opt_init(θ_0)
    for step_index in range(n_steps):
        _, opt_state = optimizer_step(
            step_index, opt_state, x_seq_particles, t_seq, empirical_mean_seq
        )

    return opt_params(opt_state)


In [None]:
# Simulation setting
num_particles = 100  # Number of particles
dt_sim = 0.0005  # Time step for simulation
dt_obs = 0.005  # Time step for observation (subsampling)
num_steps = 60000  # Number of time steps for simulation
discard_steps = 0  # Steps to discard for initial transients

# FitzHugh-Nagumo model parameters
α = 0.2  # Parameter controlling excitability
β = 0.8  # Parameter controlling recovery
τ = 1.5  # Timescale for recovery variable
σ = 0.5  # Noise strength
κ = 2.0  # Strength of interaction between particles

model_parameters = {'alpha': α, 'beta': β, 'tau': τ, 'sigma': σ, 'kappa': κ}

num_iter = 100  # Number of independent replications (a fresh seed for each)
# Rows follow the θ layout (α, β, τ, σ, κ); column i holds replication i.
θ_seq_lg = np.zeros((dim_θ, num_iter))
θ_seq_em = np.zeros((dim_θ, num_iter))  # NOTE: no EM-based fit is run below, so this stays all-zero.

log_likelihood_lg = get_log_likelihood_functions(jax_log_transition_densities['local_gaussian'])

# Observation stride, in simulation steps. Use round() rather than int(): the float
# ratio can land just below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996), and
# int() would silently truncate it to the wrong stride.
sub_interval = int(round(dt_obs / dt_sim))
if sub_interval < 1 or not np.isclose(sub_interval * dt_sim, dt_obs):
    raise ValueError("dt_obs must be a positive integer multiple of dt_sim")
t_seq_sim = np.arange(discard_steps, num_steps) * dt_sim
t_seq_obs = t_seq_sim[::sub_interval]

# Initial guess for θ = (α, β, τ, σ, κ); the same for every replication.
θ_0_lg = np.array([1.0, 1.0, 1.0, 1.0, 0.0])

for i in range(num_iter):
    print(f"Iteration {i+1}")
    print("Compute the sample paths")
    # Run simulation for each iteration, with a distinct seed per replication
    x_seq_particles_sim = run_simulation(
        seed=int(20250625 + i), num_particles=num_particles, num_steps=num_steps,
        dt=dt_sim, model_parameters=model_parameters, discard_steps=discard_steps,
    )
    x_seq_particles_obs = x_seq_particles_sim[::sub_interval, :, :]  # Downsampled for observation

    showplots(x_seq_particles_sim, x_seq_particles_obs, t_seq_sim, t_seq_obs)

    # Mean of v across particles at each observation time (axis 1 is the particle axis).
    empirical_mean_seq = x_seq_particles_obs[:, :, 0].mean(axis=1)

    print("Optimising LG-based contrast estimator starts.")
    θ_seq_lg[:, i] = compute_complete_maximum_likelihood_estimates(
        log_likelihood_lg, t_seq_obs, x_seq_particles_obs, empirical_mean_seq, θ_0_lg
    )
    print(θ_seq_lg[:, i])

    print("Likelihood value and gradient with estimated parameters:")
    print(value_and_grad(log_likelihood_lg["θ"])(
            θ_seq_lg[:, i], x_seq_particles_obs, t_seq_obs, empirical_mean_seq
        ))
    print("Optimising LG-based contrast estimator ends.")




In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Number of parameters
num_parameters = θ_seq_lg.shape[0]

# One subplot per parameter showing the relative errors of θ_seq_lg.
# squeeze=False keeps `axes` indexable even when num_parameters == 1.
fig, axes = plt.subplots(1, num_parameters, figsize=(15, 5), sharey=True, squeeze=False)
axes = axes[0]
fig.suptitle("Relative Errors of θ_seq_lg Parameters", fontsize=16)

reference_values = [model_parameters['alpha'], model_parameters['beta'], model_parameters['tau'], model_parameters['sigma'], model_parameters['kappa']]
parameter_labels = ['α', 'β', 'τ', 'σ', 'κ']  # Same order as the rows of θ_seq_lg

# Seeded generator: the horizontal jitter is cosmetic but should not change between re-runs.
jitter_rng = np.random.default_rng(0)

for i in range(num_parameters):
    # Relative error (estimate - truth) / truth for every replication of this parameter
    data = (θ_seq_lg[i, :] - reference_values[i]) / reference_values[i]

    # Set the tick label explicitly instead of passing `labels=` to boxplot: that keyword
    # was renamed `tick_labels` in Matplotlib 3.9 and is deprecated.
    axes[i].boxplot(data)
    axes[i].set_xticks([1])
    axes[i].set_xticklabels([parameter_labels[i]])
    axes[i].grid(True)

    # Overlay the individual estimates, jittered horizontally around the box at x=1
    x_positions = jitter_rng.normal(1, 0.05, size=len(data))
    axes[i].scatter(x_positions, data, alpha=0.6, color='red', s=10, label='Data Points')

# Add a shared y-axis label
fig.text(0.04, 0.5, 'Relative Error', va='center', rotation='vertical')

plt.tight_layout(rect=[0.05, 0, 1, 0.95])  # Leave room for the suptitle and shared label
plt.show()


In [None]:
# import matplotlib.pyplot as plt
# import numpy as np

# # Create a figure for the boxplot of σ
# fig, ax = plt.subplots(figsize=(4, 4))
# fig.suptitle("Boxplots of Diffusion Parameter", fontsize=16)

# reference_value = model_parameters['sigma']  # True value of σ

# # Data for LG and EM relative errors for σ
# lg_data = (θ_seq_lg[3, :] - reference_value) / reference_value
# em_data = (θ_seq_em[3, :] - reference_value) / reference_value

# # Create boxplot for σ
# ax.boxplot([lg_data, em_data], labels=['LG', 'EM'])
# ax.grid(True)

# # Overlay scattered points for LG and EM
# lg_positions = np.random.normal(1, 0.05, size=len(lg_data))  # Add slight jitter for LG
# em_positions = np.random.normal(2, 0.05, size=len(em_data))  # Add slight jitter for EM
# ax.scatter(lg_positions, lg_data, alpha=0.6, color='red', s=10, label='LG Data Points')
# ax.scatter(em_positions, em_data, alpha=0.6, color='blue', s=10, label='EM Data Points')

# # Add y-axis label
# ax.set_ylabel('Relative Error')

# plt.tight_layout()
# plt.show()



In [None]:
import pandas as pd

# Collect the LG estimates into a DataFrame: one column per θ entry, one row per replication.
data = {f"LG_θ_{row + 1}": estimates for row, estimates in enumerate(θ_seq_lg)}

# EM estimates are not exported (no EM-based fit is run); re-enable if that changes:
# data.update({
#     f"EM_θ_{i+1}": θ_seq_em[i, :] for i in range(θ_seq_em.shape[0])
# })

df = pd.DataFrame(data)

output_file = "theta_sequences_ips_fhn_fourth_trial.csv"
df.to_csv(output_file, index=False)

print(f"Data successfully written to {output_file}")