In [88]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [89]:
import numpy as np
import matplotlib.pyplot as plt

In [96]:
def get_sigma_points(mu, sigma, lambd, alpha, beta):
    """Generate the 2n+1 sigma points and their weights for the unscented transform.

    Parameters
    ----------
    mu : array_like, shape (n, 1) or (n,)
        Mean of the Gaussian. A 1-D vector is treated as a column vector.
    sigma : array_like, shape (n, n)
        Covariance matrix. ``(n + lambd) * sigma`` must be positive definite.
    lambd : float
        Scaling parameter lambda; sets the spread of the points and the weights.
    alpha, beta : float
        Only used for the zeroth covariance weight ``w_c[0]``.

    Returns
    -------
    sigma_points : ndarray, shape (n, 2n+1)
        Column 0 is ``mu``; column ``j`` (1 <= j <= n) is ``mu + L[:, j-1]`` and
        column ``n + j`` is ``mu - L[:, j-1]``, where ``L`` is the lower Cholesky
        factor of ``(n + lambd) * sigma``.
    w_m : ndarray, shape (2n+1,)
        Weights for recovering the mean.
    w_c : ndarray, shape (2n+1,)
        Weights for recovering the covariance.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``(n + lambd) * sigma`` is not positive definite.
    """
    # Normalise to a plain float ndarray column vector (also drops np.matrix subclassing).
    mu = np.asarray(mu, dtype=float).reshape(-1, 1)
    n = mu.shape[0]

    # np.linalg.cholesky returns lower-triangular L with L @ L.T == (n + lambd) * sigma,
    # so its columns are the symmetric offsets around the mean.
    mroot = np.asarray(np.linalg.cholesky((n + lambd) * np.asarray(sigma, dtype=float)))

    # Use every column of L in order for the "+" points and again for the "-" points;
    # this is correct for any dimension n (not only n == 2).
    sigma_points = np.hstack([mu, mu + mroot, mu - mroot])

    # All non-central points share the same weight 1 / (2 (n + lambda)).
    w_m = np.full(2 * n + 1, 1.0 / (2 * (n + lambd)))
    w_c = w_m.copy()
    w_m[0] = lambd / (n + lambd)
    w_c[0] = w_m[0] + (1 - alpha**2 + beta)

    return sigma_points, w_m, w_c

In [98]:
def recover_gaussian(sigma_points, w_m, w_c):
    """Rebuild a Gaussian (mean, covariance) from weighted sigma points.

    Parameters
    ----------
    sigma_points : ndarray, shape (d, m)
        One sigma point per column.
    w_m : sequence of float, length >= m
        Weight of each column in the mean.
    w_c : sequence of float, length >= m
        Weight of each column's outer product in the covariance.

    Returns
    -------
    mu : ndarray, shape (d, 1)
        Weighted mean of the columns.
    sigma : ndarray, shape (d, d)
        Weighted sum of outer products of the deviations from ``mu``.
    """
    dim = sigma_points.shape[0]
    num_points = sigma_points.shape[1]
    # Keep each point as a (d, 1) column so deviation * deviation.T is an outer product.
    columns = [sigma_points[:, k:k + 1] for k in range(num_points)]

    mu = sum((w_m[k] * col for k, col in enumerate(columns)), np.zeros((dim, 1)))

    deviations = [col - mu for col in columns]
    sigma = sum((w_c[k] * dev * dev.T for k, dev in enumerate(deviations)),
                np.zeros((dim, dim)))

    return mu, sigma

In [99]:
def transform_linear(points, offset=(1, 2)):
    """Translate points in place by adding a constant to each of the leading rows.

    Parameters
    ----------
    points : ndarray, shape (d, m)
        Points stored as columns; modified in place.
    offset : sequence of numbers, default (1, 2)
        ``offset[k]`` is added to every entry of row ``k``. Rows at index
        ``len(offset)`` and beyond are left untouched.

    Returns
    -------
    None
        The input array is mutated; nothing is returned.
    """
    for row, shift in enumerate(offset):
        points[row:row + 1, :] = points[row:row + 1, :] + shift

In [100]:
# Prior Gaussian: covariance and mean (mean as a plain ndarray column vector;
# np.matrix is discouraged by NumPy and not needed here).
sigma = np.eye(2) * 0.1
mu = np.array([[2.0], [1.0]])
# Keep the inputs, since `mu` / `sigma` are overwritten by the recovered Gaussian below.
prior_mu, prior_sigma = mu.copy(), sigma.copy()

# Scaling parameters consumed by get_sigma_points.
n = mu.shape[0]
alpha = 0.9
beta = 2
kappa = 1
lambd = alpha * alpha * (n + kappa) - n

sigma_points, w_m, w_c = get_sigma_points(mu, sigma, lambd, alpha, beta)
print(sigma_points)
print(w_m)
print(w_c)

# In-place shift: row 0 += 1, row 1 += 2.
transform_linear(sigma_points)

mu, sigma = recover_gaussian(sigma_points, w_m, w_c)
print(mu)
print(sigma)

# A pure translation must move the mean by the offset and leave the covariance unchanged.
np.testing.assert_allclose(mu, prior_mu + np.array([[1.0], [2.0]]))
np.testing.assert_allclose(sigma, prior_sigma, atol=1e-12)

[[ 2.         2.4929503  2.         1.5070497  2.       ]
 [ 1.         1.         1.4929503  1.         0.5070497]]
[ 0.17695473  0.20576132  0.20576132  0.20576132  0.20576132]
[ 2.36695473  0.20576132  0.20576132  0.20576132  0.20576132]
[[ 3.]
 [ 3.]]
[[ 0.1  0. ]
 [ 0.   0.1]]
