In [1]:
from typing import Optional

import matplotlib.pyplot as plt
import mlx
import mlx.core as mx
import mlx.optimizers as optim
from tqdm import tqdm 
from src.state_space import StateSpaceModel
from src.kalman import KalmanFilter
from src.utils import sample_mvn

In [2]:
# Seed MLX's global PRNG so random draws made through it repeat across runs.
mx.random.seed(0)

In [3]:
# State-transition matrix: a 6x6 identity with extra ones at (i, 3 + i), so each of
# the first three state components is incremented by its partner in components 3-5.
# NOTE(review): presumably state = [position (3), velocity (3)] with a unit time step — confirm.
F = mx.eye(6)
for i in range(3):
    F[i, 3+i] = 1 

def generate_H(x: Optional[mx.array] = None) -> mx.array:
    """Build the 4x6 observation matrix associated with the vector ``x``.

    Rows 0-2 pick out state components 0-2 unchanged. Row 3 is zero in columns
    0-2 and holds ``x`` scaled to unit Euclidean norm in columns 3-5; when the
    norm is not positive, ``x`` is written unscaled.

    Args:
        x: Vector written into row 3, columns 3-5 (expected to have three
            entries). Defaults to a vector of three ones.

    Returns:
        The (4, 6) matrix described above.
    """
    direction = mx.ones((3,)) if x is None else x

    H = mx.zeros((4, 6))
    for k in range(3):
        H[k, k] = 1

    norm = mx.linalg.norm(direction)
    # Guard the division so a zero vector is passed through instead of producing NaNs.
    if norm > 0:
        H[3, 3:6] = direction / norm
    else:
        H[3, 3:6] = direction
    return H

# Observation matrix for the default direction (a vector of ones).
H = generate_H() 

# Alternative noise setting, currently disabled:
# Q_true = 0.1 * mx.eye(6)
# R_true = mx.diag(mx.array([10000., 10000., 10000., 25.]))
# Covariances handed to the ground-truth model below: isotropic 6x6 Q_true, diagonal 4x4 R_true.
Q_true = 0.1 * mx.eye(6)
R_true = mx.diag(mx.array([100., 150., 60., 70]))

# Mean and covariance passed to sample_mvn for the initial states and to the KalmanFilter.
init_state_mean = mx.zeros((6,))
init_state_cov = mx.eye(6)

# Model used to simulate the datasets.
# NOTE(review): argument order (F, H, Q, R) is inferred from the variable names — confirm against StateSpaceModel.
truth_model = StateSpaceModel(F, H, Q_true, R_true)

In [4]:
def generate_trajectory(init_state: mx.array,
                        model: StateSpaceModel,
                        t_steps: int) -> tuple[mx.array, mx.array]:
    """Roll ``model`` forward for ``t_steps`` steps starting at ``init_state``.

    Before every step ``model.update_H`` is called with ``generate_H`` of the
    first three components of the current state, so ``model`` is modified as a
    side effect of generating a trajectory.

    Args:
        init_state: State fed into the first step.
        model: Provides ``generate_noise``, ``update_H`` and ``step``.
        t_steps: Number of steps to simulate.

    Returns:
        ``(states, observations)``, each stacked along a new leading axis.
        ``states[t]`` is the state fed into step ``t`` (so ``states[0]`` is
        ``init_state``) and ``observations[t]`` is the observation that step
        returned.
    """
    proc_noise, obs_noise = model.generate_noise(t_steps)

    state_history, observation_history = [], []
    current = init_state
    for t in range(t_steps):
        state_history.append(current)
        # H is rebuilt from the current state before the model is stepped.
        model.update_H(generate_H(current[:3]))
        current, obs = model.step(current, proc_noise[t], obs_noise[t])
        observation_history.append(obs)

    return mx.stack(state_history), mx.stack(observation_history)


def generate_dataset(init_states: mx.array,
                     model: StateSpaceModel,
                     t_steps: int) -> tuple[mx.array, mx.array]:
    """Simulate one trajectory per row of ``init_states`` and stack the results.

    Args:
        init_states: Array whose leading axis indexes the initial states.
        model: Passed through to ``generate_trajectory`` for every trajectory.
        t_steps: Number of steps simulated per trajectory.

    Returns:
        ``(states, observations)``: the per-trajectory outputs of
        ``generate_trajectory`` stacked along a new leading axis, in the same
        order as ``init_states``.
    """
    trajectories = [
        generate_trajectory(init_states[n], model, t_steps)
        for n in range(init_states.shape[0])
    ]
    all_states = mx.stack([states for states, _ in trajectories])
    all_observations = mx.stack([observations for _, observations in trajectories])
    return all_states, all_observations

In [5]:
# N: number of initial states (hence trajectories) per dataset; T: time steps per trajectory.
N = 10
T = 20

# Separate initial states for the training and test sets.
# NOTE(review): sample_mvn(N, mean, cov) presumably draws N samples from N(mean, cov) — confirm in src.utils.
train_inits = sample_mvn(N, init_state_mean, init_state_cov)
test_inits = sample_mvn(N, init_state_mean, init_state_cov)

# Simulate both datasets with the ground-truth model.
x_train, y_train = generate_dataset(train_inits, truth_model, T)
x_test, y_test = generate_dataset(test_inits, truth_model, T)

In [6]:
@mx.custom_function
def linear_solve(A: mx.array, B: mx.array):
    """Solve ``A @ X = B`` for ``X`` on the CPU device.

    Intended for a symmetric positive semi-definite ``A``. Registered as an MLX
    custom function so that its derivative rule can be supplied separately.
    """
    cpu = mx.Device(mx.cpu)
    solution = mx.linalg.solve(A, B, stream=cpu)
    return solution

@linear_solve.vjp
def linear_solve_vjp(primals, cotangent, output):
    """Vector-Jacobian product for ``linear_solve``.

    With ``X = linear_solve(A, B)`` (passed in as ``output``) and incoming
    cotangent ``G``:

    * cotangent of ``B``: the solution of ``A.T @ Y = G``;
    * cotangent of ``A``: ``(-Y) @ X.T``.

    Returns:
        ``(grad_A, grad_B)``, in the same order as the primals.
    """
    A, _ = primals
    cpu = mx.Device(mx.cpu)
    B_bar = mx.linalg.solve(A.T, cotangent, stream=cpu)
    A_bar = (-B_bar) @ output.T
    return A_bar, B_bar

# @linear_solve.jvp
# def linear_solve_jvp(primals, tangents, output):
#     A, B = primals
#     dA, dB = tangents
#     dX = mx.linalg.solve(A, dB - dA @ output, stream=mx.Device(mx.cpu))
#     return dX

In [7]:
@mx.compile
def _kalman_predict(F: mx.array,
                    state_mean: mx.array,
                    state_cov: mx.array,
                    Q: mx.array) -> tuple[mx.array, mx.array]:
    """Kalman prediction step.

    Returns:
        ``(F @ state_mean, F @ state_cov @ F^T + Q)``.
    """
    predicted_mean = F @ state_mean
    propagated = F @ state_cov
    predicted_cov = propagated @ F.T + Q
    return predicted_mean, predicted_cov

@mx.compile
def _kalman_update(H: mx.array,
                   state_mean: mx.array,
                   state_cov: mx.array,
                   R: mx.array,
                   observation: mx.array) -> tuple[mx.array, mx.array]:
    """Kalman measurement-update step.

    Args:
        H: Observation matrix; its second dimension sets the size of the
            identity used in the covariance update.
        state_mean: Predicted state mean.
        state_cov: Predicted state covariance.
        R: Observation-noise covariance.
        observation: Observation to condition on.

    Returns:
        ``(new_state_mean, new_state_cov)`` with
        ``new_state_mean = state_mean + K @ (observation - H @ state_mean)`` and
        ``new_state_cov = (I - K @ H) @ state_cov``, where ``K`` is the gain.
    """
    d_x = H.shape[1]

    # Innovation covariance S = H P H^T + R.
    innovation_cov = H @ state_cov @ H.transpose() + R
    # linear_solve returns S^{-1} (H P); transposing gives the gain, which is the
    # usual P H^T S^{-1} whenever P and S are symmetric. Solving the system avoids
    # forming an explicit inverse of S.
    gain = linear_solve(innovation_cov, H @ state_cov).transpose()

    innovation = observation - H @ state_mean
    new_state_mean = state_mean + gain @ innovation
    new_state_cov = (mx.eye(d_x) - gain @ H) @ state_cov
    return new_state_mean, new_state_cov

@mx.compile
def kalman_step(F: mx.array,
                H: mx.array,
                Q: mx.array,
                R: mx.array,
                state_mean: mx.array,
                state_cov: mx.array,
                observation: mx.array) -> tuple[mx.array, mx.array]:
    """Run one full filter iteration: predict with ``F``/``Q``, then update with ``H``/``R``.

    Returns:
        The state mean and covariance produced by ``_kalman_update`` after
        conditioning the prediction on ``observation``.
    """
    pred_mean, pred_cov = _kalman_predict(F, state_mean, state_cov, Q)
    post_mean, post_cov = _kalman_update(H, pred_mean, pred_cov, R, observation)
    return post_mean, post_cov

In [8]:
@mx.compile
def cholvec_to_cov(v: mx.array, d: int) -> mx.array:
    """Map an unconstrained vector to a ``d x d`` matrix ``L @ L.T``.

    ``v`` holds the lower triangle of ``L`` row by row (entry ``(i, j)``, with
    ``j <= i``, sits at index ``i * (i + 1) // 2 + j``). Diagonal entries are
    stored as logarithms and exponentiated here, so the diagonal of ``L`` is
    always positive.

    Args:
        v: Packed vector with at least ``d * (d + 1) // 2`` elements.
        d: Dimension of the resulting square matrix.

    Returns:
        The ``(d, d)`` matrix ``L @ L.T``.

    Raises:
        ValueError: If ``v`` has fewer elements than the lower triangle needs.
    """
    n_required = d * (d + 1) // 2
    # MLX does not bounds-check indexing, so a too-short vector would otherwise be
    # read out of range without any error.
    if v.size < n_required:
        raise ValueError(
            f"cholvec_to_cov: expected at least {n_required} elements for d={d}, "
            f"got {v.size}"
        )

    L = mx.zeros((d, d))
    for i in range(d):
        row_start = i * (i + 1) // 2  # integer arithmetic; no float round-trip
        for j in range(i + 1):
            entry = v[row_start + j]
            L[i, j] = mx.exp(entry) if i == j else entry
    return L @ L.T

@mx.compile
def cov_to_cholvec(C: mx.array) -> mx.array:
    """Pack the Cholesky factor of ``C`` into a flat vector.

    The lower triangle of the factor (computed on the CPU device) is written
    row by row; diagonal entries are stored as their logarithms and
    off-diagonal entries as-is. This is the layout ``cholvec_to_cov`` reads.

    Returns:
        Vector of length ``d * (d + 1) / 2`` where ``d = C.shape[0]``.
    """
    chol = mx.linalg.cholesky(C, stream=mx.Device(mx.cpu))
    d = C.shape[0]
    v = mx.zeros((d * (d + 1) // 2,))

    # Row-major traversal of the lower triangle visits the packed slots in order.
    slot = 0
    for i in range(d):
        for j in range(i + 1):
            entry = chol[i, j]
            v[slot] = mx.log(entry) if i == j else entry
            slot += 1
    return v

In [9]:
# Identity starting values for the noise covariances; these seed the trainable parameters.
Q_init = mx.eye(6)
R_init = mx.eye(4)

# Filter model that shares F and H with the truth model but starts from Q_init / R_init.
model_okf = StateSpaceModel(F, H, Q_init, R_init)
okf = KalmanFilter(model_okf, init_state_mean, init_state_cov)

def loss_fn(params: dict, x_train: mx.array, y_train: mx.array):
    """Filtering loss for the noise covariances encoded in ``params``.

    ``params["w_q"]`` and ``params["w_r"]`` are decoded into Q and R with
    ``cholvec_to_cov``. Each trajectory is then filtered from the global
    ``okf``'s initial mean and covariance, and the squared error between the
    filtered mean and the true state is summed over all trajectories and time
    steps.

    Before each step ``okf.update_model_H`` is called with the first three
    components of the current observation, so ``okf`` is modified as a side
    effect.

    Args:
        params: Dict with keys ``"w_q"`` and ``"w_r"``.
        x_train: True states, indexed as ``[trajectory, step, component]``.
        y_train: Observations, indexed the same way.

    Returns:
        The summed squared error divided by (number of trajectories * number of steps).
    """
    Q = cholvec_to_cov(params["w_q"], okf.model.d_x)
    R = cholvec_to_cov(params["w_r"], okf.model.d_y)

    n_traj, n_steps, _ = x_train.shape
    total = mx.array(0)
    for traj in range(n_traj):
        mean, cov = okf.init_state_mean, okf.init_state_cov
        for step in range(n_steps):
            y_t = y_train[traj, step, :]  # (d_y,)
            okf.update_model_H(y_t[:3])
            mean, cov = kalman_step(
                okf.model.F, okf.model.H, Q, R,
                mean, cov, y_t,
            )
            residual = mean - x_train[traj, step]
            total += mx.square(residual).sum()
    return total / (n_traj * n_steps)


# Parameters passed to loss_fn: Cholesky-vector encodings of the initial Q and R.
params = {
    "w_q": cov_to_cholvec(Q_init),
    "w_r": cov_to_cholvec(R_init),
}


In [10]:
from typing import Callable
import copy

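# Finite-difference estimates are very sensitive to the step size `eps`:
# when `eps` is too small, floating-point round-off in (loss_plus - loss_minus)
# dominates and the error blows up. Agreement with autodiff at a moderate `eps`
# therefore suggests the gradients are correct to within floating-point precision.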

def finite_difference_grads(loss_fn: Callable,
                            params: dict,
                            x_train: mx.array,
                            y_train: mx.array,
                            eps: float = 1e-2) -> dict:
    """Approximate the gradient of ``loss_fn`` by central finite differences.

    For every element ``p`` of every array in ``params`` the estimate is
    ``(loss(p + eps) - loss(p - eps)) / (2 * eps)`` with all other elements
    held fixed. Parameters of any shape are handled: each one is perturbed
    through a flat index and the result is reshaped back.

    Args:
        loss_fn: Called as ``loss_fn(params, x_train, y_train)``.
        params: Flat dict mapping names to arrays. It is not modified; only the
            perturbed entry is replaced in a shallow copy.
        x_train: Forwarded to ``loss_fn``.
        y_train: Forwarded to ``loss_fn``.
        eps: Perturbation size.

    Returns:
        Dict with the same keys as ``params``; each value has the shape of the
        corresponding parameter.
    """
    fd_grads = {}

    for key, param in params.items():
        shape = param.shape
        n_elements = param.size
        grad_flat = mx.zeros((n_elements,), dtype=param.dtype)

        for i in range(n_elements):
            # One-hot perturbation built flat, then viewed in the parameter's shape,
            # so exactly one element moves regardless of the parameter's rank.
            perturb_flat = mx.zeros((n_elements,), dtype=param.dtype)
            perturb_flat[i] = eps
            perturb = mx.reshape(perturb_flat, shape)

            # A shallow copy is enough: only the entry under `key` is replaced.
            params_plus = dict(params)
            params_minus = dict(params)
            params_plus[key] = param + perturb
            params_minus[key] = param - perturb

            # Compute loss at perturbed points
            loss_plus = loss_fn(params_plus, x_train, y_train)
            loss_minus = loss_fn(params_minus, x_train, y_train)

            # Central difference approximation
            grad_flat[i] = (loss_plus - loss_minus) / (2 * eps)

        fd_grads[key] = mx.reshape(grad_flat, shape)

    return fd_grads

In [11]:
# Gradient check on the full training set (no sub-sampling is applied here;
# slice x_train / y_train to make the check cheaper).
x_sample = x_train
y_sample = y_train

# Compute autodiff gradients
loss, grads = mx.value_and_grad(loss_fn)(params, x_sample, y_sample)

# Compute finite difference gradients
fd_grads = finite_difference_grads(loss_fn, params, x_sample, y_sample)

# Compare gradients: norm of the difference, and that norm relative to the
# autodiff gradient norm (the 1e-8 term guards against division by zero).
for key in grads:
    abs_diff = mx.linalg.norm(grads[key] - fd_grads[key])
    rel_error = abs_diff / (mx.linalg.norm(grads[key]) + 1e-8)
    print(f"{key} | abs diff: {abs_diff:.3e} | rel error: {rel_error:.3e}")

w_q | abs diff: 1.811e-02 | rel error: 2.438e-04
w_r | abs diff: 1.493e-02 | rel error: 1.827e-04
