In [1]:
import jax.numpy as jnp
from utils_v2 import ExtendedKalmanFilter
import math
import matplotlib.pyplot as plt
import jax.random as jrandom
from time import time

In [2]:
sigma_x = 0.01
sigma_t = 0.01

x_values = jnp.arange(0, 1.5, sigma_x)
t_values = jnp.arange(0, 2, sigma_t)
N = len(x_values)
M = len(t_values)
v = 0.001

In [3]:
u = jnp.sin(x_values * 2 * math.pi)
f = jnp.cos(x_values * 4 * math.pi)
nu = jnp.zeros(x_values.shape)

In [4]:
x_0 = jnp.expand_dims(jnp.r_[u, f, u, f, nu], axis=-1)

In [5]:
A = -2 * jnp.ones(u.shape[0])
A = jnp.diag(A)
A += jnp.diag(jnp.ones(u.shape[0] - 1), k=-1) + jnp.diag(jnp.ones(u.shape[0] - 1), k=1)
A /= sigma_x**2


In [6]:
B = jnp.diag(jnp.ones(u.shape[0] - 1), k=-1) + jnp.diag(jnp.ones(u.shape[0] - 1), k=1)
B /= 2 * sigma_x


In [7]:
INVS.shape

NameError: name 'INVS' is not defined

In [None]:
INVS = jnp.linalg.inv(jnp.eye(A.shape[0]) - 0.5 * sigma_t * v * A)

In [None]:
def new_u(
    u: jnp.ndarray, f: jnp.ndarray, prev_u: jnp.ndarray, prev_f: jnp.ndarray, v: float
) -> jnp.ndarray:
    new_u = INVS
    temp = (jnp.eye(A.shape[0]) + 0.5 * sigma_t * v * A) @ u

    temp -= sigma_t * B @ (1.5 * u**2 - 0.5 * prev_u**2)
    temp += sigma_t * (1.5 * f - 0.5 * prev_f)
    new_u @= temp
    return new_u

In [None]:
def new_state(
    x_k: jnp.ndarray,
    cnrl: jnp.ndarray = None,
) -> jnp.ndarray:
    start = 0
    end = u.shape[0]
    u_k = x_k[start:end]
    start += f.shape[0]
    end += f.shape[0]
    f_k = x_k[start:end]
    start += u.shape[0]
    end += u.shape[0]
    prev_u_k = x_k[start:end]
    start += f.shape[0]
    end += f.shape[0]
    prev_f_k = x_k[start:end]
    nu = x_k[end:]

    new_x = jnp.r_[new_u(u_k, f_k, prev_u_k, prev_f_k, v), cnrl, u_k, f_k, nu]
    return new_x

In [None]:
eye = jnp.eye(f.shape[0])
zero = jnp.zeros(A.shape)

"""
def jaccobian_f(x_k: jnp.ndarray) -> jnp.ndarray:
    Fu = INVS @ (jnp.eye(A.shape[0]) + (0.5 * sigma_t * v * A)) @ x_k[N:2*N] - 3 * INVS @ B @ (x_k[N:2*N] * x_k[N:2*N]) * sigma_t
    
    Ff = (sigma_t * 1.5 * INVS) @ x_k[3*N:4*N]

    Fpu=INVS @ B @ (x_k[2*N:3*N] * x_k[2*N:3*N]) * sigma_t
    

    Fpf=(-sigma_t * 1.5 * INVS) @ x_k[3*N:4*N]
    zero=jnp.zeros((150,1))
    eye=jnp.ones((150,1))

    F=jnp.block([[Fu,Ff,Fpu,Fpf,zero],
                [zero,eye,zero,zero,zero],
                [zero,zero,zero,zero,zero],
                [zero,zero,zero,eye,zero],
                [zero,zero,zero,zero,eye]])
    

    return F
"""


def jaccobian_f(x_k: jnp.ndarray) -> jnp.ndarray:
    Fuu = (
        jnp.eye(A.shape[0])
        + 0.5 * sigma_t * v * A
        - 3 * B @ jnp.diag(x_k[: A.shape[0]].squeeze()) * sigma_t
    )
    Fuu += 0.5 * B @ jnp.diag(x_k[A.shape[0] : 2 * A.shape[0]].squeeze()) * sigma_t
    Fuu @= INVS

    eye = jnp.eye(f.shape[0])
    zero = jnp.zeros(A.shape)

    Fuf = jnp.eye(f.shape[0]) * sigma_t
    Fpu = INVS @ B @ (x_k[2 * N : 3 * N] * x_k[2 * N : 3 * N]) * sigma_t

    Fpf = (-sigma_t * 1.5 * INVS) @ x_k[3 * N : 4 * N]
    print(Fuu.shape, Fuf.shape, Fpu.shape, Fpf.shape)

    F = jnp.block(
        [
            [Fuu, Fuf, Fpu, Fpf, zero],
            [zero, eye, zero, zero, zero],
            [zero, zero, eye, zero, zero],
            [zero, zero, zero, eye, zero],
            [zero, zero, zero, zero, eye],
        ]
    )
    return F

In [None]:
print((INVS @ B @ (x_0[2 * N : 3 * N] * x_0[2 * N : 3 * N]) * sigma_t).shape)

(150, 1)


In [None]:
def h(x_k: jnp.ndarray) -> jnp.ndarray:
    block = jnp.block(
        [
            jnp.eye(A.shape[0]),
            jnp.zeros(A.shape),
            jnp.zeros(A.shape),
            jnp.zeros(A.shape),
            jnp.zeros(A.shape),
        ]
    )
    return block @ x_k


In [None]:
def jaccobian_h(x_k: jnp.ndarray) -> jnp.ndarray:
    block = jnp.block(
        [
            jnp.eye(A.shape[0]),
            jnp.zeros(A.shape),
            jnp.zeros(A.shape),
            jnp.zeros(A.shape),
            jnp.zeros(A.shape),
        ]
    )
    return block


In [None]:
R = jnp.eye(u.shape[0]) * 0.05
Q = jnp.eye(5 * u.shape[0]) * 0.05
P_0 = jnp.eye(5 * u.shape[0]) * 0.05

In [None]:
def plot(ekf: ExtendedKalmanFilter, iterations: int, plot_each: int = None) -> None:
    plot_each = iterations if plot_each is None else plot_each
    plot_each = min(plot_each, iterations)
    fig, ax = plt.subplots()
    ax.plot(x_values, x_0[:N], label="True state")
    for i in range(1, iterations + 1):
        Bb = jnp.block([ekf.A[0:N], eye, eye, eye, eye]).T
        Bb = ekf.A[0:N].T
        K = (jnp.linalg.inv(ekf.R + Bb.T @ ekf.P @ Bb)) @ Bb.T @ ekf.P @ ekf.A
        ekf.predict(-1 * K @ ekf.x_k)
        ekf.update(x_0[:N])
        if i % plot_each == 0:
            ax.plot(x_values, ekf.x_k[:N], label=f"EKF estimation after {i} iterations")

    plt.xlabel("x_n values", fontsize=15)
    plt.ylabel("u_n values", fontsize=15)
    plt.title("Training evolution", fontsize=15)
    ax.legend(fontsize=15)
    fig.set_size_inches(17, 10, forward=True)
    plt.show()


In [None]:
key = jrandom.PRNGKey(int(time()))
mean = 0
std_dev = 0.05
w_k = jrandom.normal(key, shape=(5 * A.shape[0], 1)) * std_dev + mean
Z = jrandom.normal(key, shape=(A.shape[0], 1)) * std_dev + mean
ekf = ExtendedKalmanFilter(
    x_0, new_state, h, R, Q, Z, w_k, P_0, jaccobian_f, jaccobian_h
)
plot(ekf, iterations=20, plot_each=5)

(150, 150) (150, 150) (150, 1) (150, 1)


TypeError: Cannot concatenate arrays with shapes that differ in dimensions other than the one being concatenated: concatenating along dimension 0 for shapes (150, 452), (150, 750), (150, 750), (150, 750), (150, 750).