In [10]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact
from ml_collections import ConfigDict
from models.ETD_KT_CM_JAX_Vectorised import *
from filters import resamplers
from filters.filter import ParticleFilter
# Enable 64-bit floats via the `jax.config` attribute of the already-imported
# `jax` package. `from jax.config import config` relied on a submodule that was
# deprecated and later removed from JAX, so it fails on current releases.
jax.config.update("jax_enable_x64", True)


Initialisation

In [11]:
# Wrap KS_params in a ConfigDict to use as the signal's parameter set, then print
# it so the values used are recorded in the notebook output.
# NOTE(review): KS_params is not defined in this notebook; presumably it comes
# from the `models.ETD_KT_CM_JAX_Vectorised` star import — confirm.
signal_params = ConfigDict(KS_params)
print(signal_params)

Advection_basis_name: none
E: 1
Forcing_basis_name: none
P: 0
S: 0
c_0: 0
c_1: 1
c_2: 1
c_3: 0.0
c_4: 1
dt: 0.25
equation_name: Kuramoto-Sivashinsky
initial_condition: Kassam_Trefethen_KS_IC
method: Dealiased_ETDRK4
nx: 256
sigma: 0.0
tmax: 150
xmax: 100.53096491487338
xmin: 0



In [12]:
# Wrap KS_params_SALT in a ConfigDict to use as the ensemble's parameter set, then
# print it so the values used are recorded in the notebook output.
# NOTE(review): KS_params_SALT is not defined in this notebook; presumably it comes
# from the `models.ETD_KT_CM_JAX_Vectorised` star import — confirm.
ensemble_params = ConfigDict(KS_params_SALT)
print(ensemble_params)

Advection_basis_name: sin
E: 1
Forcing_basis_name: none
P: 1
S: 0
c_0: 0
c_1: 1
c_2: 1
c_3: 0.0
c_4: 1
dt: 0.25
equation_name: Kuramoto-Sivashinsky
initial_condition: Kassam_Trefethen_KS_IC
method: Dealiased_SETDRK4
nx: 256
sigma: 0.001
tmax: 150
xmax: 100.53096491487338
xmin: 0



Next, we specify the signal by choosing a deterministic solver, and configure the stochastic ensemble by overriding its `E`, `sigma` and `P` parameters.

In [13]:
# Overrides applied to the signal's parameter set.
signal_overrides = dict(E=1, method='Dealiased_ETDRK4', nx=256, P=0, S=0)
signal_params.update(**signal_overrides)

# Overrides applied to the ensemble's parameter set.
# NOTE(review): `stochastic_advection_basis` is not among the keys printed for
# ensemble_params above (the printed key is `Advection_basis_name`) — confirm the
# model actually reads this key, otherwise this setting has no effect.
ensemble_overrides = dict(
    E=128,
    sigma=0.0001,
    P=32,
    stochastic_advection_basis='constant',
)
ensemble_params.update(**ensemble_overrides)


Now we continue to define a stochastic ensemble

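Now we instantiate the signal and ensemble models from their parameter sets, build the corresponding initial conditions, and list the available resamplers.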

In [14]:
# Instantiate one model per parameter set: the signal and the ensemble.
signal_model = ETD_KT_CM_JAX_Vectorised(signal_params)
ensemble_model = ETD_KT_CM_JAX_Vectorised(ensemble_params)

# Build the initial state for each model from its own grid, its member count E,
# and the initial-condition name stored in its parameter set.
initial_signal = initial_condition(signal_model.x, signal_params.E, signal_params.initial_condition)
initial_ensemble = initial_condition(ensemble_model.x, ensemble_params.E, ensemble_params.initial_condition)

# Show which resampling schemes can be passed to the particle filter.
available_resamplers = ", ".join(resamplers.keys())
print(available_resamplers)

multinomial, systematic, no_resampling, default


In [15]:
# Indices into the signal model's grid, taken every 128 points starting at 0.
observation_indices = np.arange(0, signal_model.x.shape[0], 128)

# Particle filter pairing the ensemble model (forward model) with the signal model.
# NOTE(review): the variable is named `pf_systematic` but the filter is built with
# resampling='default'; 'systematic' is listed as a separate option — confirm intent.
# NOTE(review): `sigma = 0.1` here is a ParticleFilter argument, distinct from the
# `sigma` stored in ensemble_params; its exact meaning is not visible here.
pf_systematic = ParticleFilter(
    n_particles=ensemble_params.E,
    n_steps=1,
    n_dim=initial_signal.shape[-1],
    forward_model=ensemble_model,
    signal_model=signal_model,
    sigma=0.1,
    seed=11,
    resampling='default',
    observation_locations=observation_indices,
)

In [16]:
# Run the filter from the initial ensemble and signal.
# NOTE(review): the third argument (signal_model.nmax) is presumably the number of
# steps / scan length — the original author was also unsure; confirm against
# ParticleFilter.run.
# NOTE(review): `all` shadows the Python builtin of the same name for the rest of
# the notebook; the next cell reads `all[0]` and `all[1]`, so renaming it requires
# changing both cells together.
final, all = pf_systematic.run(initial_ensemble, initial_signal, signal_model.nmax)

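Prepend the initial condition to the stored particle and signal trajectories, so that index 0 along the first axis corresponds to the initial state.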

In [17]:
# Give each initial state a leading axis of length 1 so it can be stacked in front
# of the corresponding history returned by the filter run (concatenate joins
# along axis 0 by default).
particles = jnp.concatenate((jnp.expand_dims(initial_ensemble, 0), all[0]))
signal = jnp.concatenate((jnp.expand_dims(initial_signal, 0), all[1]))
print(particles.shape)

(601, 128, 256)


In [18]:
def plot(da_step):
    """Plot the signal and all ensemble members at one stored step.

    Parameters
    ----------
    da_step : int
        Index along the leading axis of the module-level `signal` and
        `particles` arrays.
    """
    fig, ax = plt.subplots()
    ax.plot(signal_model.x, signal[da_step, 0, :], color='k', label='signal')
    # A 2-D y array yields one line per column; label only the first line so the
    # legend shows a single 'particles' entry instead of one per ensemble member.
    particle_lines = ax.plot(signal_model.x, particles[da_step, :, :].T, color='b', linewidth=0.1)
    particle_lines[0].set_label('particles')
    ax.set_xlabel('x')
    ax.set_title(f'da_step = {da_step}')
    legend = ax.legend(loc='upper right')
    # The particle lines are very thin; thicken the legend samples so they are visible.
    for legend_line in legend.get_lines():
        legend_line.set_linewidth(1.5)
    plt.show()

# Bound the slider by the number of stored steps so it can never index past the
# data; the trailing semicolon suppresses the echoed function repr.
interact(plot, da_step=(0, particles.shape[0] - 1));

interactive(children=(IntSlider(value=300, description='da_step', max=600), Output()), _dom_classes=('widget-i…

<function __main__.plot(da_step)>