# Example 3 
In this experiment we run a Kuramoto–Sivashinsky (KS) ensemble driven by transport (SALT) noise and assimilate observations with the standard particle filter. In this twin experiment we observe degeneracy of the filter.

In [11]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact
from ml_collections import ConfigDict
from models.ETD_KT_CM_JAX_Vectorised import *
from filters import resamplers
from filters.filter import ParticleFilter
jax.config.update("jax_enable_x64", True)

## Initialisation

In [12]:
# The signal (truth) and the ensemble each get their own ConfigDict built from the
# same KS/SALT parameter set; only the ensemble's copy is modified further below.
signal_params, ensemble_params = ConfigDict(KS_params_SALT), ConfigDict(KS_params_SALT)
print(ensemble_params)


Advection_basis_name: sin
E: 1
Forcing_basis_name: none
P: 1
S: 0
c_0: 0
c_1: 1
c_2: 1
c_3: 0.0
c_4: 1
dt: 0.25
equation_name: Kuramoto-Sivashinsky
initial_condition: Kassam_Trefethen_KS_IC
method: Dealiased_SETDRK4
noise_magnitude: 0.001
nx: 256
tmax: 150
xmax: 100.53096491487338
xmin: 0



We now specify the ensemble-specific settings: the number of ensemble members (`E`), the SALT noise magnitude, and the number of noise basis functions (`P`).

In [13]:
# Ensemble configuration: E ensemble members, noise magnitude, P noise basis functions.
# NOTE(review): `stochastic_advection_basis` is not one of the keys printed above
# (the config exposes `Advection_basis_name`), so this appears to add a new key —
# confirm the model actually reads it, otherwise the setting is silently ignored.
ensemble_params.update({
    'E': 128,
    'noise_magnitude': 0.001,
    'P': 32,
    'stochastic_advection_basis': 'constant',
})


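Next we construct the signal (truth) and ensemble models by instantiating `ETD_KT_CM_JAX_Vectorised` with their respective parameters, and draw an initial condition for each.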

In [14]:
# Truth ("signal") and forecast ensemble share the model class but not the parameters.
signal_model, ensemble_model = (
    ETD_KT_CM_JAX_Vectorised(signal_params),
    ETD_KT_CM_JAX_Vectorised(ensemble_params),
)

# Initial states, one row per member (E=1 for the signal, E=128 for the ensemble).
initial_signal, initial_ensemble = (
    initial_condition(signal_model.x, signal_params.E, signal_params.initial_condition),
    initial_condition(ensemble_model.x, ensemble_params.E, ensemble_params.initial_condition),
)

# Names of the resampling schemes available to the filter's `resampling` argument.
available_resamplers = ", ".join(resamplers.keys())
print(available_resamplers)

multinomial, systematic, no_resampling, default


In [None]:
# Observation design.
obs_frequency = 16  # stride (in grid indices) between sparse observation sites
observation_noise = 1e-2  # passed to the filter as `sigma`
observation_locations_params = jnp.arange(0, len(signal_model.x), obs_frequency)
observation_temporal_frequency = 16  # passed as `n_steps` (presumably model steps per DA cycle)

# NOTE(review): `observation_locations_params` (every 16th grid point) is computed but
# never used — the filter below observes EVERY grid point. The interactive plot at the
# end of the notebook scatters observations against the full grid, so switching to the
# sparse locations would also require updating that plot.
pf_systematic = ParticleFilter(
    n_particles=ensemble_params.E,
    n_steps=observation_temporal_frequency,
    n_dim=initial_signal.shape[-1],
    forward_model=ensemble_model,
    signal_model=signal_model,
    sigma=observation_noise,
    seed=11,
    resampling='systematic',
    observation_locations=np.arange(signal_model.x.shape[0]),
)

In [16]:
# The third argument appears to set the number of assimilation cycles: each returned
# history has leading length signal_model.nmax (600 here, matching the outputs below).
# NOTE(review): if every cycle advances n_steps=observation_temporal_frequency model
# steps, 600 cycles integrate far beyond tmax — confirm whether
# signal_model.nmax // observation_temporal_frequency was intended.
# NOTE(review): `all` shadows the builtin all(); kept because later cells index into it.
n_da_cycles = signal_model.nmax
final, all = pf_systematic.run(initial_ensemble, initial_signal, n_da_cycles)

In [17]:
# Shapes of the particle, signal and observation histories, then the observations.
print(*(all[i].shape for i in range(3)))
print(all[2])

(600, 128, 256) (600, 1, 256) (600, 1, 256)
[[[ 0.84980152  0.81338298  0.845248   ...  0.84262877  0.73158369
    0.8681692 ]]

 [[ 0.70070055  0.66048015  0.6886556  ...  0.70547701  0.59037022
    0.72296852]]

 [[ 0.59744685  0.55424295  0.57942618 ...  0.61116035  0.49307336
    0.62269352]]

 ...

 [[-0.29801827 -0.03121329  0.22231525 ... -1.34053928 -1.14277201
   -0.63829948]]

 [[-1.21865242 -1.00382408 -0.63244445 ... -0.79241172 -1.29348617
   -1.29123143]]

 [[-0.55468955 -0.9992129  -1.15351342 ...  1.29244059  0.56966188
    0.04694245]]]


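Prepend the initial condition so that index 0 of each history is the initial state and index k is DA step k. There is no observation at the initial time, so that slot is left empty (NaN).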

In [18]:
# The filter's returned histories start after the initial state, so prepend it:
# index 0 becomes the initial condition and index k the k-th DA step.
particles = jnp.concatenate([initial_ensemble[None, ...], all[0]], axis=0)
signal = jnp.concatenate([initial_signal[None, ...], all[1]], axis=0)
# No observation exists at t=0. Pad with NaN rather than 0 so the scatter plot draws
# nothing at step 0 instead of a spurious row of "observations" at zero.
observations = jnp.concatenate(
    [jnp.full_like(initial_signal[None, ...], jnp.nan), all[2]], axis=0
)
print(particles.shape)

(601, 128, 256)


In [None]:
def plot(da_step):
    """Plot the signal, the particle ensemble and the observations at one DA step.

    Reads the notebook-level histories ``signal``, ``particles`` and ``observations``
    (leading axis = DA step, index 0 = prepended initial condition) and the grid
    ``signal_model.x``.

    Parameters
    ----------
    da_step : int
        Index along the leading axis of the stacked histories.
    """
    fig, ax = plt.subplots()
    ax.plot(signal_model.x, signal[da_step, 0, :], color='k', label='signal')
    particle_lines = ax.plot(
        signal_model.x, particles[da_step, :, :].T, color='b', linewidth=0.1
    )
    # Label a single member so the legend shows one "particles" entry instead of one
    # entry per ensemble member.
    if particle_lines:
        particle_lines[0].set_label('particles')
    ax.scatter(signal_model.x, observations[da_step, 0, :], color='r', label='observations')
    ax.set(xlabel='x', title=f'DA step {da_step}')
    ax.legend()
    plt.show()


# The slider spans every stored step, including the prepended initial condition.
# The trailing `;` suppresses the echoed function repr.
interact(plot, da_step=(0, particles.shape[0] - 1));

interactive(children=(IntSlider(value=300, description='da_step', max=600), Output()), _dom_classes=('widget-i…

<function __main__.plot(da_step)>