In [9]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ipywidgets import interact
from ml_collections import ConfigDict
from models.ETD_KT_CM_JAX_Vectorised import *
from filters import resamplers
from filters.filter import ParticleFilter

Specify initial setup of the signal, by first loading parameters from the stochastic traveling wave.

In [10]:
signal_params = ConfigDict(KDV_params_traveling)
print(signal_params)

Advection_basis_name: constant
E: 30
Forcing_basis_name: none
P: 1
S: 0
c_0: 0
c_1: 1
c_2: 0.0
c_3: 1
c_4: 0.0
dt: 0.0001
equation_name: KdV
initial_condition: traveling_wave
method: Dealiased_SETDRK4
noise_magnitude: 1.0
nx: 64
tmax: 1.0
xmax: 3.141592653589793
xmin: -3.141592653589793



Next, we specify a signal, by choosing a deterministic solver, 

In [11]:
signal_params.update(E=1,method='ETDRK4',nx = 128,P=0,S=0)


Now we continue to define a stochastic ensemble

In [12]:

ensemble_params = ConfigDict(KDV_params_traveling)
ensemble_params.update(method='StrangSplit_ETDRK4_SSP33',nx = 128, S=0, P=1)
ensemble_params.update(E=10)
ensemble_params.update(sigma=5)
print(ensemble_params)


Advection_basis_name: constant
E: 10
Forcing_basis_name: none
P: 1
S: 0
c_0: 0
c_1: 1
c_2: 0.0
c_3: 1
c_4: 0.0
dt: 0.0001
equation_name: KdV
initial_condition: traveling_wave
method: StrangSplit_ETDRK4_SSP33
noise_magnitude: 1.0
nx: 128
sigma: 5
tmax: 1.0
xmax: 3.141592653589793
xmin: -3.141592653589793



Now we specify the models, by calling the class.

In [13]:
signal_model = ETD_KT_CM_JAX_Vectorised(signal_params)
ensemble_model = ETD_KT_CM_JAX_Vectorised(ensemble_params)

initial_signal = initial_condition(signal_model.x, signal_params.E, signal_params.initial_condition)
initial_ensemble = initial_condition(ensemble_model.x, ensemble_params.E, ensemble_params.initial_condition) 

available_resamplers = ", ".join(resamplers.keys())
print(available_resamplers)


multinomial, systematic, no_resampling, default


In [14]:
pf_systematic = ParticleFilter(
    n_particles = ensemble_params.E,
    n_steps = 1,
    n_dim = initial_signal.shape[-1],
    forward_model = ensemble_model,
    signal_model = signal_model,
    sigma = 0.5,
    seed = 11,
    resampling='systematic',
)
# pf_no_resampling = ParticleFilter(
#     n_particles = ensemble_params.E,
#     n_steps = 1,
#     n_dim = initial_signal.shape[-1],
#     forward_model = ensemble_model,
#     signal_model = signal_model,
#     sigma = 0.01,
#     seed = 11,
#     resampling='no_resampling',
# )

In [15]:
final, all = pf_systematic.run(initial_ensemble, initial_signal, 10000) #the final input is scan length? 

ValueError: Method ETDRK4 not recognised

In [None]:
print(type(all))
print(len(all))
print(all[0].shape) # particles
print(all[1].shape) # signal
#out shape is (n_total, n_particles, n_dim)

<class 'tuple'>
3
(10000, 10, 128)
(10000, 1, 128)


In [None]:
print(initial_signal.shape)
signal = jnp.concatenate([initial_signal[None,...], all[1]], axis=0)
print(signal.shape)
particles = jnp.concatenate([initial_ensemble[None,...], all[0]], axis=0)
print(particles.shape)

(1, 128)
(10001, 1, 128)
(10001, 10, 128)


In [None]:
def plot(da_step):
    plt.plot(signal_model.x, signal[da_step,0,:], color='red')
    plt.plot(signal_model.x, particles[da_step,:,:].T, color='k', linewidth=0.1)
    plt.legend(['signal', 'particles'])
    plt.show()

interact(plot, da_step=(0, 10001))

interactive(children=(IntSlider(value=5000, description='da_step', max=10001), Output()), _dom_classes=('widge…

<function __main__.plot(da_step)>