In [1]:
from filter_forecast.particle_filter.global_settings import GlobalSettings
from filter_forecast.particle_filter.particle_cloud import ParticleCloud
from filter_forecast.particle_filter.transition import GaussianNoiseModel
from filter_forecast.particle_filter.parameters import ModelParameters

In [2]:
# Run configuration. A deliberately tiny cloud (10 particles) keeps the arrays
# displayed below small enough to read row by row.
days = 150  # NOTE(review): not referenced by any cell in this notebook view — remove if unused.

NUM_PARTICLES = 10
POPULATION = 10_000
LOCATION_CODE = "04"  # kept as a string so the leading zero is preserved
FINAL_DATE = "2024-07-22"

settings = GlobalSettings(
    num_particles=NUM_PARTICLES,
    population=POPULATION,
    location_code=LOCATION_CODE,
    final_date=FINAL_DATE,
)

In [3]:
# Build the particle cloud with a Gaussian-noise transition model that uses
# default ModelParameters().
transition_model = GaussianNoiseModel(model_params=ModelParameters())
particles = ParticleCloud(settings, transition=transition_model)

In [4]:
# Initial (time index 0) state of every particle: one row per particle.
# In the output, the first two columns of each row sum to the configured
# population of 10,000; NOTE(review): column meaning not shown here — confirm.
initial_states = particles.states[:, :, 0]
initial_states

Array([[9.9684014e+03, 3.1598352e+01, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [9.9747852e+03, 2.5214497e+01, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [9.9578643e+03, 4.2135590e+01, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [9.9683633e+03, 3.1636997e+01, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [9.9581709e+03, 4.1829376e+01, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [9.9888916e+03, 1.1108426e+01, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [9.9701348e+03, 2.9865690e+01, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [9.9722812e+03, 2.7718929e+01, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [9.9985850e+03, 1.4150971e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00],
       [9.9985723e+03, 1.4274806e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00]], dtype=float32)

In [5]:
# Time index 1 before any update has been run: the output shows it is all zeros.
states_t1_before_update = particles.states[:, :, 1]
states_t1_before_update

Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)

In [6]:
# Propagate every particle to time index 1; the next cell shows index 1 filled in.
particles.update_all_particles(
    t=1,
)

In [7]:
# Time index 1 after update_all_particles(t=1): every row is now populated.
states_t1_after_update = particles.states[:, :, 1]
states_t1_after_update

Array([[9.9638086e+03, 3.3032963e+01, 2.9718053e+00, 1.8983790e-01,
        1.9012001e-01],
       [9.9693604e+03, 2.8119396e+01, 2.3717227e+00, 1.5153477e-01,
        1.5181687e-01],
       [9.9533594e+03, 4.2428673e+01, 3.9623055e+00, 2.5306135e-01,
        2.5334346e-01],
       [9.9630703e+03, 3.3768150e+01, 2.9754379e+00, 1.9006978e-01,
        1.9035189e-01],
       [9.9498691e+03, 4.5950523e+01, 3.9335215e+00, 2.5122407e-01,
        2.5150618e-01],
       [9.9874697e+03, 1.1421137e+01, 1.0457522e+00, 6.6898346e-02,
        6.7180462e-02],
       [9.9667588e+03, 3.0257086e+01, 2.8089349e+00, 1.7944193e-01,
        1.7972404e-01],
       [9.9686719e+03, 2.8558285e+01, 2.6071396e+00, 1.6656137e-01,
        1.6684347e-01],
       [9.9982402e+03, 1.6206012e+00, 1.3457924e-01, 8.7383753e-03,
        9.0204878e-03],
       [9.9983867e+03, 1.4717107e+00, 1.3574329e-01, 8.8126762e-03,
        9.0947887e-03]], dtype=float32)

In [8]:
# Score each particle at t=1 against a hand-picked reported value of 2.
particles.compute_all_weights(
    t=1,
    reported_data=2,
)

In [9]:
# Raw weights at t=1. All entries are negative in the output, which is consistent
# with log-weights — NOTE(review): confirm against compute_all_weights.
weights_t1_raw = particles.weights[:, 1]
print(weights_t1_raw)

[-4.831617  -4.831645  -4.831572  -4.831617  -4.831573  -4.831709
 -4.8316245 -4.8316336 -4.8317547 -4.831754 ]


In [10]:
# Normalize the weights at t=1. Comparing the printed weights before and after,
# every entry is shifted by the same constant.
particles.normalize_weights(
    t=1,
)

In [11]:
# Normalized weights at t=1. Every entry is ~ -2.3026 = ln(1/10); if these are
# log-weights, the 10 particles are almost equally weighted.
weights_t1_normalized = particles.weights[:, 1]
print(weights_t1_normalized)

[-2.302552  -2.30258   -2.3025072 -2.302552  -2.302508  -2.302644
 -2.3025596 -2.3025687 -2.3026898 -2.3026893]


Check that resampling favors a heavily weighted particle: particle 0's weight at t=1 is raised far above the others before `resample` is called.

In [12]:
# Make particle 0 dominate at t=1. The weights shown above behave like
# log-weights: after normalize_weights every entry is ~ -2.302585 = ln(1/10).
# Setting entry 0 to 1 therefore gives it roughly e / (e + 9 * 0.1) ~ 75% of the
# probability mass.
particles.weights = particles.weights.at[0, 1].set(1)
print(particles.weights[:, 1])

# The override leaves the t=1 weights unnormalized (exp-sum ~ 3.6, assuming
# log-weights), whereas every earlier stage normalizes before the weights are used.
# In the cells above, normalize_weights shifted all entries by one common
# constant. It therefore preserves the relative weights, so particle 0 keeps
# its dominance while resample(t=1) receives a proper distribution.
particles.normalize_weights(t=1)

[ 1.        -2.30258   -2.3025072 -2.302552  -2.302508  -2.302644
 -2.3025596 -2.3025687 -2.3026898 -2.3026893]


In [13]:
# Resample at t=1 using the weights from the previous cell, where particle 0 was
# given a much larger weight than the others.
particles.resample(t=1)

# Display the post-resampling states at t=1 so the outcome can actually be
# checked. If particle 0 dominated, most rows should now equal particle 0's row
# from the states[:, :, 1] output displayed after update_all_particles.
# NOTE(review): assumes resample() rewrites particles.states[:, :, t] — confirm.
particles.states[:, :, 1]