In [None]:
import os
os.environ["JAX_CHECK_TRACER_LEAKS"] = "1"
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from ipywidgets import interact
from ml_collections import ConfigDict
from models.ETD_KT_CM_JAX_Vectorised import *
from filters import resamplers
from filters.filter import ParticleFilter
from jax import config
jax.config.update("jax_enable_x64", True)
import numpy as np

Initialisation of a twin experiment. 

In [20]:
# Overrides applied identically to the signal and the ensemble configuration.
common_overrides = dict(P=3, noise_magnitude=0.01, stochastic_advection_basis='sin')

# Both configurations start from the KDV_params_2_SALT preset and differ only in E.
signal_params = ConfigDict(KDV_params_2_SALT)
signal_params.update(E=1, **common_overrides)

ensemble_params = ConfigDict(KDV_params_2_SALT)
ensemble_params.update(E=128, **common_overrides)

# Root PRNG key; it is passed to the particle filter run further down.
key = jax.random.PRNGKey(0)

Now we instantiate the signal and ensemble models by calling the model class with each set of parameters, and build their initial conditions.

In [21]:
# One solver instance per configuration: signal first, ensemble second.
signal_model, ensemble_model = (
    ETD_KT_CM_JAX_Vectorised(params) for params in (signal_params, ensemble_params)
)

# Initial states, built from each model's `x` attribute and the configured `E`
# and `initial_condition` entries (same order as the models above).
initial_signal, initial_ensemble = (
    initial_condition(model.x, params.E, params.initial_condition)
    for model, params in (
        (signal_model, signal_params),
        (ensemble_model, ensemble_params),
    )
)

# Show the names under which resampling schemes are registered in `resamplers`.
available_resamplers = ", ".join(resamplers.keys())
print(available_resamplers)

multinomial, systematic, no_resampling, none, default


In [22]:
# --- Observation settings ---
observation_noise = 0.1               # passed to the filter as `sigma`
observation_spatial_frequency = 16    # stride between observed grid indices
number_of_observations_time = 32      # number of observation times

# Indices 0, 16, 32, ... along the first axis of the signal model's `x`.
n_grid_points = signal_model.x.shape[0]
observation_locations = np.arange(0, n_grid_points, observation_spatial_frequency)

# `nt` divided by the number of observation times, truncated to an integer;
# presumably the number of model steps between observations -- TODO confirm.
observation_temporal_frequency = int(
    ensemble_model.params.nt / number_of_observations_time
)

# Particle filter using the "systematic" resampler.
pf_systematic = ParticleFilter(
    n_particles=ensemble_params.E,
    n_steps=observation_temporal_frequency,
    n_dim=initial_signal.shape[-1],
    forward_model=ensemble_model,
    signal_model=signal_model,
    sigma=observation_noise,
    resampling="systematic",  # 'default' was the previously tried alternative
    observation_locations=observation_locations,
)

In [23]:
# Use as many assimilation steps as there are observation times.
da_steps = number_of_observations_time
# NOTE(review): `all` shadows the Python builtin of the same name from here on.
# The third argument looks like the scan length: the shapes printed in the next
# cell have a leading axis of 32 == da_steps -- TODO confirm against ParticleFilter.run.
final, all = pf_systematic.run(initial_ensemble, initial_signal, da_steps, key)


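Prepend the initial condition to the stored particle and signal trajectories, so that index 0 along the first axis is the initial state.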

In [24]:
# `all` is indexed here as (particle history, signal history, observation history).
# Prepend the initial states so that index 0 along the leading axis is the
# initial condition and index k is the state after the k-th assimilation step.
particles = jnp.concatenate([initial_ensemble[None, ...], all[0]], axis=0)
signal = jnp.concatenate([initial_signal[None, ...], all[1]], axis=0)

# Observations get no initial entry and are restricted to the observed grid
# indices. (An earlier concatenation into `observations` was removed: its
# result was overwritten by this assignment before ever being used.)
observations = all[2][:, :, observation_locations]

print(observations.shape)
print(particles.shape)

(32, 1, 16)
(33, 128, 256)


In [25]:
def plot(da_step):
    """Plot the signal, the particles and the observations at one assimilation step.

    Parameters
    ----------
    da_step : int
        Index along the leading axis of the notebook-level arrays `signal` and
        `particles`. Index 0 is the prepended initial condition, for which no
        observation is drawn.
    """
    plt.plot(signal_model.x, signal[da_step, 0, :], color='k', label='signal')

    # One line per particle. Only the first line is labelled so the legend shows
    # a single "particles" entry instead of one entry per particle.
    particle_lines = plt.plot(
        signal_model.x, particles[da_step, :, :].T, color='b', linewidth=0.1
    )
    if particle_lines:
        particle_lines[0].set_label('particles')

    if da_step > 0:
        # `observations` has no entry for the initial time, hence the -1 offset.
        plt.plot(
            signal_model.x[observation_locations],
            observations[da_step - 1, 0, :],
            'ro',
            label='observations',
        )

    plt.xlabel('x')
    plt.title(f'Assimilation step {da_step}')
    plt.legend()
    plt.show()


# Slider over every stored step, including the initial condition at 0.
interact(plot, da_step=(0, da_steps))


interactive(children=(IntSlider(value=16, description='da_step', max=32), Output()), _dom_classes=('widget-int…

<function __main__.plot(da_step)>

In [None]:
# Second, independent filter run: same configuration, different PRNG key.
ensemble_params_1 = ConfigDict(KDV_params_2_SALT)
ensemble_params_1.update(E=128, P=3, noise_magnitude=0.01, stochastic_advection_basis='sin')
ensemble_model_1 = ETD_KT_CM_JAX_Vectorised(ensemble_params_1)
initial_ensemble_1 = initial_condition(
    ensemble_model_1.x, ensemble_params_1.E, ensemble_params_1.initial_condition
)

# Build a fresh signal model too instead of reusing `signal_model` from the first
# run. Reusing it made this cell fail with jax.errors.UnexpectedTracerError, and
# the traceback places the leaked value in ParticleFilter.advance_signal ->
# ETD_KT_CM_JAX_Vectorised.run. Presumably the model keeps a traced value from the
# first run's scan -- TODO confirm and fix in the model class itself; a new
# instance is only a notebook-level workaround.
signal_model_1 = ETD_KT_CM_JAX_Vectorised(signal_params)

key2 = jax.random.PRNGKey(1)
pf_systematic_1 = ParticleFilter(
    n_particles=ensemble_params_1.E,
    n_steps=observation_temporal_frequency,
    n_dim=initial_signal.shape[-1],
    forward_model=ensemble_model_1,
    signal_model=signal_model_1,
    sigma=observation_noise,
    resampling="systematic",  # 'default' was the previously tried alternative
    observation_locations=observation_locations,
)

# NOTE: this rebinds `final` and `all` from the first run (and `all` shadows the builtin).
final, all = pf_systematic_1.run(initial_ensemble_1, initial_signal, da_steps, key2)


UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was scan_fn at /data/asl16/Particle_Filter/filters/filter.py:58 traced for scan.
------------------------------
The leaked intermediate value was created on line /data/asl16/Particle_Filter/models/ETD_KT_CM_JAX_Vectorised.py:137:12 (ETD_KT_CM_JAX_Vectorised.run). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/data/asl16/Particle_Filter/filters/filter.py:64:21 (ParticleFilter.run)
/data/asl16/Particle_Filter/filters/filter.py:61:45 (ParticleFilter.run.<locals>.scan_fn)
/data/asl16/Particle_Filter/filters/filter.py:43:17 (ParticleFilter.run_step)
/data/asl16/Particle_Filter/filters/filter.py:19:20 (ParticleFilter.advance_signal)
/data/asl16/Particle_Filter/models/ETD_KT_CM_JAX_Vectorised.py:137:12 (ETD_KT_CM_JAX_Vectorised.run)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError