In [1]:
import os
os.environ["JAX_ENABLE_X64"] = "true"
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ipywidgets import interact
from ml_collections import ConfigDict
from models.ETD_KT_CM_JAX_Vectorised import *
from filters import resamplers
from filters.filter import ParticleFilter_tempered_jittered, EnsembleKalmanFilter, ParticleFilter
from jax import config
jax.config.update("jax_enable_x64", True)
import numpy as np

float64


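Initialisation of a twin experiment. The signal (the synthetic "truth") and the ensemble are built from the same base parameters, `KS_params_SALT`. They differ only in ensemble size: `E=1` for the signal and `E=128` for the ensemble.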

In [2]:
# Both parameter sets start from the same base configuration, so the signal
# and the ensemble share every setting except the overrides applied below.
signal_params = ConfigDict(KS_params_SALT)
print(signal_params)  # base configuration, shown before the overrides
ensemble_params = ConfigDict(KS_params_SALT)

# Overrides: a single signal realisation (E=1) versus a 128-member ensemble,
# both with P=32 and a constant stochastic advection basis.
# NOTE(review): the printed base config has no `stochastic_advection_basis` key
# (it lists `Advection_basis_name: sin`); confirm the model reads the key set here.
shared_overrides = dict(P=32, stochastic_advection_basis='constant')
signal_params.update(E=1, **shared_overrides)
ensemble_params.update(E=128, **shared_overrides)

Advection_basis_name: sin
E: 1
Forcing_basis_name: none
P: 1
S: 0
c_0: 0
c_1: 1
c_2: 1
c_3: 0.0
c_4: 1
dt: 0.25
equation_name: Kuramoto-Sivashinsky
initial_condition: Kassam_Trefethen_KS_IC
method: Dealiased_SETDRK4
noise_magnitude: 0.001
nt: 600
nx: 256
tmax: 150
xmax: 100.53096491487338
xmin: 0



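Now we specify the models by instantiating `ETD_KT_CM_JAX_Vectorised` once per parameter set. We then generate the initial signal and initial ensemble from the configured initial condition.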

In [3]:
# One model per parameter set: the signal model (E=1) and the ensemble model (E=128).
signal_model, ensemble_model = (
    ETD_KT_CM_JAX_Vectorised(cfg) for cfg in (signal_params, ensemble_params)
)

# Initial states on each model's grid with E members each, using the
# initial condition named in the corresponding config (signal first).
initial_signal, initial_ensemble = (
    initial_condition(model.x, cfg.E, cfg.initial_condition)
    for model, cfg in ((signal_model, signal_params), (ensemble_model, ensemble_params))
)

# Names of the resampling schemes registered in `filters.resamplers`.
available_resamplers = ", ".join(resamplers.keys())
print(available_resamplers)

multinomial, systematic, no_resampling, none, default


In [4]:
# --- Observation design -------------------------------------------------------
# Observe every 4th grid point of the signal; the noise level is passed to both
# filters as `sigma`.
observation_spatial_frequency = 4
observation_noise = 0.1
grid_size = signal_model.x.shape[0]
observation_locations = np.arange(0, grid_size, observation_spatial_frequency)

# nt observation times, so `observation_temporal_frequency` (passed to the
# filters as `n_steps`) works out to nt / nt = 1.
number_of_observations_time = int(ensemble_model.params.nt / 1)
observation_temporal_frequency = int(ensemble_model.params.nt / number_of_observations_time)

print(observation_locations)
print(observation_temporal_frequency)

# --- Filters ------------------------------------------------------------------
# Both filters share ensemble size, step count, state dimension, models,
# observation noise and observation locations.
common_filter_kwargs = dict(
    n_particles=ensemble_params.E,
    n_steps=observation_temporal_frequency,
    n_dim=initial_signal.shape[-1],
    forward_model=ensemble_model,
    signal_model=signal_model,
    sigma=observation_noise,
    observation_locations=observation_locations,
)

# NOTE(review): despite its name, `pf_systematic` uses the "default" resampler;
# "systematic" is listed separately in `resamplers` -- confirm which is intended.
pf_systematic = ParticleFilter(resampling="default", **common_filter_kwargs)
kal = EnsembleKalmanFilter(**common_filter_kwargs)

[  0   4   8  12  16  20  24  28  32  36  40  44  48  52  56  60  64  68
  72  76  80  84  88  92  96 100 104 108 112 116 120 124 128 132 136 140
 144 148 152 156 160 164 168 172 176 180 184 188 192 196 200 204 208 212
 216 220 224 228 232 236 240 244 248 252]
1


In [5]:
# Sanity check: call `run_step` once on each filter from the initial ensemble
# and signal. Both filters receive the same PRNG key.
key = jax.random.PRNGKey(0)
out, out_kal = (
    flt.run_step(initial_ensemble, initial_signal, key)
    for flt in (pf_systematic, kal)
)

In [6]:
# `run_step` returns three arrays; judging by their shapes below they are
# (particles, signal, observations).
print(len(out))
ps, ss, obser = out
print(*(arr.shape for arr in (ps, ss, obser)))

# The observation array spans the full grid but is nonzero only at observed
# points (the count matches the 64 observation locations in this run).
jnp.count_nonzero(obser)


3
(128, 256) (1, 256) (1, 256)


Array(64, dtype=int64)

In [7]:
# Full particle-filter run. Judging by the shapes printed further down, the
# third argument sets how many DA steps are stacked: `all[0]`, `all[1]`, `all[2]`
# each have `da_steps` entries along axis 0 (particles, signal, observations).
# NOTE(review): `all` shadows the Python builtin; later cells depend on this name.
da_steps = number_of_observations_time
key = jax.random.PRNGKey(0)
final, all = pf_systematic.run(
    initial_ensemble,
    initial_signal,
    da_steps,
    key,
)

In [8]:
# Same run for the ensemble Kalman filter, reusing `da_steps` and the PRNG key
# from the particle-filter run.
final_kal, all_kal = kal.run(
    initial_ensemble,
    initial_signal,
    da_steps,
    key,
)

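Prepend the initial condition to the particle and signal trajectories, so that index 0 is the initial state and index k is the state after DA step k. The observations are not prepended, because there is no observation at the initial time. The observation for step k is therefore stored at index k - 1.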

In [9]:
def _prepend_initial(initial, trajectory):
    """Stack `initial` (one time slice) in front of `trajectory` along axis 0."""
    return jnp.concatenate([initial[None, ...], trajectory], axis=0)


# Index 0 of each state trajectory is the initial condition; index k is the
# state after DA step k. Observations are not prepended: there is no
# observation at t=0, so observations[k - 1] belongs to step k.

# Particle filter
particles = _prepend_initial(initial_ensemble, all[0])
signal = _prepend_initial(initial_signal, all[1])
observations = all[2][:, :, observation_locations]  # keep only observed grid points
print(observations.shape)
print(particles.shape)

# Ensemble Kalman filter
particles_kal = _prepend_initial(initial_ensemble, all_kal[0])
signal_kal = _prepend_initial(initial_signal, all_kal[1])
observations_kal = all_kal[2][:, :, observation_locations]
print(observations_kal.shape)
print(particles_kal.shape)

# Each state trajectory has exactly one more time entry (t=0) than the observations.
assert particles.shape[0] == observations.shape[0] + 1
assert particles_kal.shape[0] == observations_kal.shape[0] + 1

(600, 1, 64)
(601, 128, 256)
(600, 1, 64)
(601, 128, 256)


In [None]:
def plot_all(da_step):
    """Plot the signal, both filter ensembles and the observations at one DA step.

    Reads the notebook globals ``signal_model``, ``signal``, ``particles``,
    ``particles_kal``, ``observations`` and ``observation_locations``.

    Parameters
    ----------
    da_step : int
        Index into the stacked trajectories: 0 is the initial condition and
        k is the state after data-assimilation step k.
    """
    x = signal_model.x
    _, ax = plt.subplots(figsize=(12, 6))

    # Ensembles first, as thin translucent lines, so neither hides the other
    # or the signal. Labelling only the first line of each ensemble gives one
    # legend entry per ensemble without plotting any member twice.
    kal_lines = ax.plot(x, particles_kal[da_step].T, color='g', linewidth=0.5, alpha=0.5)
    kal_lines[0].set_label('EnKF particles')
    pf_lines = ax.plot(x, particles[da_step].T, color='b', linewidth=0.5, alpha=0.5)
    pf_lines[0].set_label('PF particles')

    # Signal drawn above both ensembles.
    ax.plot(x, signal[da_step, 0, :], color='k', linewidth=2, label='Signal', zorder=4)

    # No observation exists at step 0; the observation for step k is stored at k - 1.
    if da_step > 0:
        ax.scatter(x[observation_locations], observations[da_step - 1, 0, :],
                   color='r', label='Observations', zorder=5)

    ax.set_xlabel('Spatial Domain', fontsize=14)
    ax.set_ylabel('Amplitude', fontsize=14)
    ax.set_title(f'Data Assimilation Step {da_step}', fontsize=16)
    ax.legend(fontsize=12)
    ax.grid(True, linestyle='--', alpha=0.7)
    plt.show()

# Slider over every stored time index: 0 (initial condition) .. last DA step.
# The bound comes from the trajectory length rather than `da_steps`, because
# out-of-range JAX indexing is clamped silently instead of raising.
# The trailing `;` suppresses the echo of the returned function.
interact(plot_all, da_step=(0, particles.shape[0] - 1));

interactive(children=(IntSlider(value=300, description='da_step', max=600), Output()), _dom_classes=('widget-i…

<function __main__.plot_all(da_step)>