In [None]:
from math import ceil, exp
from typing import Callable
from scipy.io import loadmat
import numpy as np
import matplotlib.pyplot as plt

In [None]:
data = loadmat("carte_centreMetres.mat")
# MATLAB vectors come back 2-D; ravel() accepts both row (1, n) and column (n, 1) layouts,
# whereas indexing with [0] only works for row vectors.
h_mat, x_val, y_val = data["h_MNT"], np.ravel(data["x_MNT"]), np.ravel(data["y_MNT"])
# h_mnt locates grid cells with np.searchsorted, which requires increasing axes.
assert np.all(np.diff(x_val) > 0), "x_MNT must be strictly increasing"
assert np.all(np.diff(y_val) > 0), "y_MNT must be strictly increasing"

In [None]:
def h_mnt(
    x_val: np.ndarray,
    y_val: np.ndarray,
    h_mat: np.ndarray,
    x_float: float,
    y_float: float,
) -> float:
    """Bilinearly interpolate the terrain height at (x_float, y_float).

    Args:
        x_val: grid abscissae; must be strictly increasing (np.searchsorted).
        y_val: grid ordinates; must be strictly increasing.
        h_mat: heights on the grid, indexed as ``h_mat[ix, iy]`` (first axis
            follows ``x_val``, second follows ``y_val``).
            NOTE(review): the plotting cells pass ``h_mat`` to ``pcolormesh(x, y, C)``,
            which treats the first axis of C as y. One of the two conventions is
            transposed — confirm the layout of ``h_MNT``.
        x_float: x coordinate of the query point.
        y_float: y coordinate of the query point.

    Returns:
        The interpolated height. Queries outside the grid are clamped to the
        nearest grid edge instead of raising or wrapping around.
    """
    # Clamp the query into the grid: off-map points get the edge height.
    x_query = min(max(x_float, x_val[0]), x_val[-1])
    y_query = min(max(y_float, y_val[0]), y_val[-1])
    # Lower node of the enclosing cell, kept in [0, n - 2] so that both `floor`
    # and `floor + 1` are valid indices (no negative wrap-around, no overflow).
    x_floor = int(np.clip(np.searchsorted(x_val, x_query) - 1, 0, len(x_val) - 2))
    y_floor = int(np.clip(np.searchsorted(y_val, y_query) - 1, 0, len(y_val) - 2))
    alpha_x = (x_query - x_val[x_floor]) / (x_val[x_floor + 1] - x_val[x_floor])
    alpha_y = (y_query - y_val[y_floor]) / (y_val[y_floor + 1] - y_val[y_floor])
    # Interpolate along x on the two bounding y-lines, then along y between them.
    interp_y_floor = (
        alpha_x * h_mat[x_floor + 1, y_floor] + (1 - alpha_x) * h_mat[x_floor, y_floor]
    )
    interp_y_ceil = (
        alpha_x * h_mat[x_floor + 1, y_floor + 1]
        + (1 - alpha_x) * h_mat[x_floor, y_floor + 1]
    )
    return alpha_y * interp_y_ceil + (1 - alpha_y) * interp_y_floor


def currified_h_mnt(
    x_val: np.ndarray, y_val: np.ndarray, h_mat: np.ndarray
) -> Callable[[float, float], float]:
    """Bind a terrain grid to ``h_mnt``.

    Returns a function of the query point only, ``f(x_float, y_float)``, that
    interpolates ``h_mat`` on the ``(x_val, y_val)`` grid captured here.
    """

    def height_at(x_float: float, y_float: float) -> float:
        # Same call as h_mnt(x_val=..., y_val=..., h_mat=..., x_float=..., y_float=...).
        return h_mnt(x_val, y_val, h_mat, x_float, y_float)

    return height_at

In [None]:
# Time step (s) and constant-velocity transition: position += DT * velocity.
DT = 1
PHI = np.block(
    [
        [np.eye(3, dtype=int), DT * np.eye(3, dtype=int)],
        [np.zeros((3, 3), dtype=int), np.eye(3, dtype=int)],
    ]
)

# Aircraft: heading (rad), speed (km/h converted to m/s), flight length (m).
CAP_AVION = 45 * np.pi / 180
V_AVION = 300 * 1000 / 3600
LONG_TRAJET = 60 * 1000
NB_IT = ceil(LONG_TRAJET / (DT * V_AVION))
VX, VY = V_AVION * np.cos(CAP_AVION), V_AVION * np.sin(CAP_AVION)
ETAT_INIT = np.array([[150000], [150000], [8000], [VX], [VY], [0]])

# Measurement-noise std and initial uncertainty of each state component.
SIGMA_BRUIT = 30
SIGMA_X = SIGMA_Y = 3000
SIGMA_Z = 500
SIGMA_VX = SIGMA_VY = SIGMA_VZ = 5
P_0 = np.diag(
    np.array([SIGMA_X, SIGMA_Y, SIGMA_Z, SIGMA_VX, SIGMA_VY, SIGMA_VZ]) ** 2
)
N_PART = 5000

# True trajectory: straight, level flight at constant velocity.
x_avion = ETAT_INIT[0] + VX * DT * np.arange(NB_IT)
y_avion = ETAT_INIT[1] + VY * DT * np.arange(NB_IT)
z_avion = np.full(NB_IT, ETAT_INIT[2, 0])
etat_avion = np.vstack([x_avion, y_avion, z_avion])

In [None]:
rd = np.random.default_rng(20)
epsilon_bruit = rd.normal(loc=0, scale=SIGMA_BRUIT, size=NB_IT)
# Build the terrain-height function once instead of once per trajectory point.
h_terrain = currified_h_mnt(x_val=x_val, y_val=y_val, h_mat=h_mat)
# Simulated measurements: aircraft altitude minus terrain height, plus Gaussian noise.
h_mes = (
    z_avion
    - np.array([h_terrain(x_float=x, y_float=y) for x, y in zip(x_avion, y_avion)])
    + epsilon_bruit
)
# Initial cloud: N_PART draws from N(ETAT_INIT, P_0), shaped (N_PART, 6, 1).
epsilon: np.ndarray = rd.multivariate_normal(
    mean=6 * [0], cov=P_0, size=N_PART
).reshape((N_PART, 6, 1))
init_particles = ETAT_INIT + epsilon
assert h_mes.shape == (NB_IT,)
assert init_particles.shape == (N_PART, 6, 1)

In [None]:
fig, axe = plt.subplots()
axe.pcolormesh(x_val, y_val, h_mat)
axe.plot(etat_avion[0], etat_avion[1], color="red")
axe.scatter(
    init_particles[:, 0, 0], init_particles[:, 1, 0], color="pink", marker=".", s=0.5
)
# bottom=y_val[-1], top=y_val[0]: the first grid row is drawn at the top.
axe.set_ylim(y_val[-1], y_val[0])
axe.set(
    title="Terrain, true trajectory (red) and initial particles (pink)",
    xlabel="x (m)",
    ylabel="y (m)",
)
plt.show()

In [None]:
DIM_STATE = 6


def _update_weights(
    h_mnt: Callable[[float, float], float],
    h_mes_next: float,
    part_pred: np.ndarray,
    prev_weights: np.ndarray,
    sigma_bruit: float,
) -> np.ndarray:
    """Return the normalized weights after observing ``h_mes_next``.

    Each particle's Gaussian likelihood of the measured height above ground is
    combined with its previous weight in log space. The result is shifted by its
    maximum before exponentiating, so the weights cannot all underflow to 0
    (which would make the normalization 0 / 0 = NaN).
    """
    heights = np.array([h_mnt(state[0, 0], state[1, 0]) for state in part_pred])
    ecart = h_mes_next - (part_pred[:, 2, 0] - heights)
    # Particles whose previous weight underflowed to 0 get log-weight -inf.
    with np.errstate(divide="ignore"):
        log_w = np.log(prev_weights) - ecart**2 / (2 * sigma_bruit**2)
    new_w = np.exp(log_w - np.max(log_w))
    return new_w / np.sum(new_w)


def _regularized_resample(
    part_pred: np.ndarray,
    weights: np.ndarray,
    h_fin: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """Resample ``part_pred`` multinomially and jitter each copy with a Gaussian kernel.

    The kernel covariance is ``h_fin**2`` times the weighted empirical covariance
    of the predicted cloud. Copies of parent ``i`` are stored contiguously, in
    increasing ``i``.
    """
    n_part, dim = part_pred.shape[0], part_pred.shape[1]
    states = part_pred[:, :, 0]
    deviation = states - weights @ states
    cov_mat = (weights[:, None] * deviation).T @ deviation
    cov_mat = 0.5 * (cov_mat + cov_mat.T)  # remove rounding asymmetry before sampling
    counts = rng.multinomial(n=n_part, pvals=weights)
    parents = np.repeat(part_pred, counts, axis=0)
    # One draw for all copies instead of one SVD-backed call per parent.
    kern_dev = rng.multivariate_normal(
        mean=np.zeros(dim), cov=h_fin**2 * cov_mat, size=n_part
    )
    return parents + kern_dev.reshape((n_part, dim, 1))


def particle_filter(
    h_mnt: Callable[[float, float], float],
    h_mes: np.ndarray,
    init_particles: np.ndarray,
    phi: np.ndarray,
    sigma_bruit: float,
    rng: "np.random.Generator | None" = None,
    bandwidth_factor: float = 0.4,
    resample_threshold: float = 0.4,
    return_weights: bool = False,
):
    """Run a regularized particle filter on a sequence of height measurements.

    At each step the particles are propagated with ``phi``. Each particle is
    weighted by the Gaussian likelihood of ``h_mes[k + 1]`` against
    ``z - h_mnt(x, y)``. When the effective sample size drops below
    ``resample_threshold * n_part``, the cloud is resampled and jittered with a
    Gaussian kernel.

    Args:
        h_mnt: terrain height as a function of (x, y).
        h_mes: measured heights above ground, one per time step (shape (nb_it,)).
        init_particles: initial cloud, shape (n_part, dim, 1); state rows 0, 1 and 2
            are used as x, y and z.
        phi: state transition matrix, shape (dim, dim).
        sigma_bruit: standard deviation of the measurement noise.
        rng: generator used for resampling. Defaults to the notebook-level ``rd``
            so that existing runs stay reproducible.
        bandwidth_factor: multiplier applied to the rule-of-thumb kernel bandwidth.
        resample_threshold: resample when N_eff < resample_threshold * n_part.
        return_weights: if True, also return the (nb_it, n_part) weights array.

    Returns:
        Particles of shape (nb_it, n_part, dim, 1), or ``(particles, weights)``
        when ``return_weights`` is True.
    """
    if rng is None:
        rng = rd  # notebook-level generator defined in an earlier cell
    nb_it = h_mes.shape[0]
    n_part, dim = init_particles.shape[0], init_particles.shape[1]
    particles = np.empty(shape=(nb_it, n_part, dim, 1))
    particles[0] = init_particles
    weights = np.empty(shape=(nb_it, n_part))
    weights[0] = 1 / n_part
    # Silverman-style rule-of-thumb bandwidth for a Gaussian kernel in `dim` dimensions.
    h_opt = (4 / (dim + 2)) ** (1 / (dim + 4)) * n_part ** (-1 / (dim + 4))
    h_fin = bandwidth_factor * h_opt
    n_redis = resample_threshold * n_part
    for k in range(nb_it - 1):
        # (dim, dim) @ (n_part, dim, 1) broadcasts over the particle axis.
        part_pred = phi @ particles[k]
        weights[k + 1] = _update_weights(
            h_mnt, h_mes[k + 1], part_pred, weights[k], sigma_bruit
        )
        n_eff = 1 / np.sum(weights[k + 1] ** 2)
        if n_eff < n_redis:
            particles[k + 1] = _regularized_resample(
                part_pred, weights[k + 1], h_fin, rng
            )
            weights[k + 1] = 1 / n_part
        else:
            particles[k + 1] = part_pred
    if return_weights:
        return particles, weights
    return particles

In [None]:
def step_viewer(
    x_val: np.ndarray,
    y_val: np.ndarray,
    h_mat: np.ndarray,
    particles: np.ndarray,
    etat_avion: np.ndarray,
    xlim: "tuple[float, float]" = (100000, 200000),
    ylim: "tuple[float, float]" = (200000, 100000),
) -> Callable[[int], None]:
    """Return a function that plots the filter state at a given time step.

    Args:
        x_val, y_val, h_mat: terrain grid drawn as the background.
        particles: particle history, indexed as ``particles[k, i, component, 0]``.
        etat_avion: true states, indexed as ``etat_avion[component, k]``.
        xlim: (left, right) x-axis limits.
        ylim: (bottom, top) y-axis limits. The default puts the larger y at the
            bottom, matching the previous hard-coded view.
    """

    def view(k: int) -> None:
        """Draw the terrain, the particle cloud and the true position at step k."""
        _, axe = plt.subplots()
        axe.pcolormesh(x_val, y_val, h_mat)
        axe.scatter(
            particles[k, :, 0, 0],
            particles[k, :, 1, 0],
            color="pink",
            marker=".",
            s=0.5,
        )
        axe.scatter(etat_avion[0, k], etat_avion[1, k], color="red", s=1)
        # Explicit limits disable autoscaling, so setting them after plotting is equivalent.
        axe.set_xlim(*xlim)
        axe.set_ylim(*ylim)
        axe.set(title=f"Particle cloud at step {k}", xlabel="x (m)", ylabel="y (m)")

    return view

In [None]:
# Run the filter on the simulated measurements. Positional order:
# h_mnt, h_mes, init_particles, phi, sigma_bruit.
particles = particle_filter(
    currified_h_mnt(x_val=x_val, y_val=y_val, h_mat=h_mat),
    h_mes,
    init_particles,
    PHI,
    SIGMA_BRUIT,
)
# Plot factory: see_step(k) shows the particle cloud against the true position at step k.
see_step = step_viewer(x_val, y_val, h_mat, particles, etat_avion)

In [None]:
# Show a late step of the run; clamp so the index stays valid if NB_IT changes.
see_step(k=min(700, NB_IT - 1))