# https://pyro.ai/examples/ekf.html

In [4]:
import os
import math

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, config_enumerate
from pyro.contrib.tracking.extended_kalman_filter import EKFState
from pyro.contrib.tracking.distributions import EKFDistribution
from pyro.contrib.tracking.dynamic_models import NcvContinuous
from pyro.contrib.tracking.measurements import PositionMeasurement

pyro.__version__

'1.8.6'

In [5]:
dt = 1e-2
num_frames = 10
dim = 4

# Continuous model
ncv = NcvContinuous(dim, 2.0)

# Truth trajectory
xs_truth = torch.zeros(num_frames, dim)
# initial direction
theta0_truth = 0.0
# initial state
with torch.no_grad():
    xs_truth[0, :] = torch.tensor([0.0, 0.0,  math.cos(theta0_truth), math.sin(theta0_truth)])
    for frame_num in range(1, num_frames):
        # sample independent process noise
        dx = pyro.sample('process_noise_{}'.format(frame_num), ncv.process_noise_dist(dt))
        xs_truth[frame_num, :] = ncv(xs_truth[frame_num-1, :], dt=dt) + dx
        
# Measurements
measurements = []
mean = torch.zeros(2)
# no correlations
cov = 1e-5 * torch.eye(2)
with torch.no_grad():
    # sample independent measurement noise
    dzs = pyro.sample('dzs', dist.MultivariateNormal(mean, cov).expand((num_frames,)))
    # compute measurement means
    zs = xs_truth[:, :2] + dzs

In [7]:
def model(data):
    # a HalfNormal can be used here as well
    R = pyro.sample('pv_cov', dist.HalfCauchy(2e-6)) * torch.eye(4)
    Q = pyro.sample('measurement_cov', dist.HalfCauchy(1e-6)) * torch.eye(2)
    # observe the measurements
    pyro.sample('track_{}'.format(i), EKFDistribution(xs_truth[0], R, ncv,
                                                      Q, time_steps=num_frames),
                obs=data)

guide = AutoDelta(model)  # MAP estimation

optim = pyro.optim.Adam({'lr': 2e-2})
svi = SVI(model, guide, optim, loss=Trace_ELBO(retain_graph=True))

pyro.set_rng_seed(0)
pyro.clear_param_store()

for i in range(250):
    loss = svi.step(zs)
    if not i % 10:
        print('loss: ', loss)

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


loss:  -15.20763874053955
loss:  -15.339868545532227
loss:  -15.413694381713867
loss:  -15.473196029663086
loss:  -15.507671356201172
loss:  -15.523503303527832
loss:  -15.5301513671875
loss:  -15.532791137695312
loss:  -15.533793449401855
loss:  -15.534193992614746
loss:  -15.534348487854004
loss:  -15.534411430358887
loss:  -15.534439086914062
loss:  -15.534448623657227
loss:  -15.534452438354492
loss:  -15.534453392028809
loss:  -15.534455299377441
loss:  -15.534454345703125
loss:  -15.534454345703125
loss:  -15.534454345703125
loss:  -15.534455299377441
loss:  -15.534455299377441
loss:  -15.534454345703125
loss:  -15.534453392028809
loss:  -15.534456253051758


In [9]:
R = guide()['pv_cov'] * torch.eye(4)
Q = guide()['measurement_cov'] * torch.eye(2)
ekf_dist = EKFDistribution(xs_truth[0], R, ncv, Q, time_steps=num_frames)
states= ekf_dist.filter_states(zs)
states

[<pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35346980>,
 <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35347c10>,
 <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35347d60>,
 <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35345d20>,
 <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c35344730>,
 <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354ffe80>,
 <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354ff0d0>,
 <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354fded0>,
 <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354ff9d0>,
 <pyro.contrib.tracking.extended_kalman_filter.EKFState at 0x7f7c354feb90>]