In [9]:
import math

import torch
from utils import *
from Kalman3 import *
import matplotlib.pyplot as plt

plt.rcParams['figure.dpi'] = 150

In [10]:
class ArrayModel:
    """
    Linear sensor array with optionally perturbed element positions.

    Sensor positions are stored in ``self.array`` as multiples of the nominal
    inter-element spacing (0, 1, ..., m - 1 for an unperturbed array).
    NOTE(review): the phase term ``pi * position * sin(theta)`` matches the
    usual half-wavelength-spacing model -- confirm this is the intended geometry.

    Args:
        m (int): Number of sensors in the array.
        perturbation (float, optional): Half-width of the uniform random offset
            added to each nominal position; every offset is drawn from
            ``[-perturbation, perturbation)``. Defaults to 0 (no perturbation).

    Attributes:
        m (int): Number of sensors.
        array (torch.Tensor): Sensor positions, shape (m,).

    Methods:
        get_steering_vector(theta): Computes the steering vector for given angles.
    """

    def __init__(self, m: int, perturbation: float = 0):
        self.m = m
        # torch.rand is always called (even when perturbation == 0), so building
        # an array consumes m draws from the global RNG and yields a float tensor.
        self.array = torch.arange(0, m, 1) + (torch.rand(m) * 2 - 1) * perturbation

    def get_steering_vector(self, theta: torch.Tensor) -> torch.Tensor:
        """
        Compute the steering vector for the given incident angles.

        Args:
            theta (torch.Tensor): Incident angles in radians, of size (D) or (B, D).

        Returns:
            torch.Tensor: Complex steering vectors of size (M, D) or (B, M, D),
            with entries ``exp(-1j * pi * position_m * sin(theta_d))``.

        Raises:
            ValueError: If ``theta`` is neither 1-D nor 2-D.
        """
        n_dims = theta.ndim
        if n_dims == 1:
            # (M, 1) * (1, D) -> (M, D)
            return torch.exp(-1j * torch.pi * self.array.reshape(-1, 1) * torch.sin(theta).reshape(1, -1))
        if n_dims == 2:
            # (1, M, 1) * (B, 1, D) -> (B, M, D)
            return torch.exp(-1j * torch.pi * self.array.reshape(1, -1, 1) * torch.sin(theta).unsqueeze(1))
        # Previously this fell through and returned None, which only surfaced
        # later as an obscure failure in the caller.
        raise ValueError(
            f"theta must have shape (D) or (B, D), got a {n_dims}-D tensor "
            f"with shape {tuple(theta.shape)}"
        )

In [11]:
# Scenario configuration: array geometry, time grid and the true DoA
# trajectories of two targets moving along straight lines at constant speed.

# Seed the global RNG so that "Restart Kernel -> Run All" reproduces the same
# random draws (array perturbation, source signals, sensor noise).
SEED = 0  # arbitrary value; change it to obtain a different realisation
torch.manual_seed(SEED)

d = 2 # 2 sources
m = 8 # 8 sensors
SNR = 5 # signal-to-noise ratio; used later as the amplitude ratio 10 ** (SNR / 20)
array = ArrayModel(m)

T = 40.0 # Observation duration
Ts = 1.0 # Sampling time
time = torch.arange(0, T, Ts) # sampling instants 0, Ts, ..., T - Ts
t = time.shape[0] # number of samples

v1 = 1500 # speed of airplane 1
x1 = -27000 + time * v1 # x coordinate of airplane 1
y1 = torch.tensor([27000.0]) # y coordinate of airplane 1 (constant, broadcast over time)
# atan2(x, y) = atan(x / y): the angle is measured from the y axis.
theta1 = torch.atan2(x1, y1) # DoA of plane 1

v2 = -1500 # speed of airplane 2
x2 = 32000 + time * v2 # x coordinate of airplane 2
y2 = torch.tensor([30000.0]) # y coordinate of airplane 2 (constant, broadcast over time)
theta2 = torch.atan2(x2, y2) # DoA of plane 2

thetas = torch.stack((theta1, theta2), dim=1) # Matrix of DoA of shape (nbSamples x nbSources)

In [12]:
# Synthetic data: each source follows a complex AR(1) process
#     x[k] = alpha * x[k - 1] + w[k],
# and the array observes y[k] = A(theta[k]) @ x[k] + n[k].
# math.sqrt is used explicitly so this cell does not depend on a `sqrt` name
# leaking out of one of the star imports.

alpha_abs = 0.99 # AR(1) pole magnitude (< 1, so the process is stationary)
alpha_angle = torch.rand(d) * 2 * torch.pi # one random pole phase per source
alpha = alpha_abs * torch.exp(1j * alpha_angle)

sigma = 0.01 # standard deviation of the complex innovation w[k]
sigma_x = sigma / math.sqrt(1 - alpha_abs ** 2) # stationary standard deviation of x[k]
sigma_noise = sigma_x / 10 ** (SNR / 20) # sensor-noise standard deviation for the requested SNR

x = torch.zeros(t, d, dtype=torch.complex64) # AR(1) model for sources, matrix of shape (nbSamples x nbSources)
# The first sample is drawn from the stationary distribution; dividing by
# sqrt(2) splits the power evenly between the real and imaginary parts.
x[0] = (torch.randn(d) + 1j * torch.randn(d)) / math.sqrt(2) * sigma_x
for i in range(1, t):
    x[i] = alpha * x[i - 1] + (torch.randn(d) + 1j * torch.randn(d)) / math.sqrt(2) * sigma

n = (torch.randn(t, m) + 1j * torch.randn(t, m)) / math.sqrt(2) * sigma_noise # noise in antenna array
steering = array.get_steering_vector(thetas) # shape (nbSamples x nbSensors x nbSources)
y = torch.einsum('tmd,td->tm', steering, x) + n # observation used for tracking, matrix of shape (nbSamples x nbSensors)

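State vector (for $d = 2$ sources; $\theta_i$ is the DoA of source $i$, $\dot{\theta}_i$ its rate of change and $x_i$ its complex signal)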

$$
\mathbf{x}(k) = \begin{bmatrix} 
\theta_1(k) \\
\theta_2(k) \\
\dot{\theta}_1(k) \\
\dot{\theta}_2(k) \\
\Re\{x_1(k)\} \\
\Re\{x_2(k)\} \\
\Im\{x_1(k)\} \\
\Im\{x_2(k)\} \\
\end{bmatrix}
$$

Observation vector (the complex array snapshot $y(k)$ stacked into its real and imaginary parts)

$$
\mathbf{y}(k) = \begin{bmatrix} 
\Re\{y(k)\} \\
\Im\{y(k)\} \\
\end{bmatrix}
$$

In [13]:
# Initial state, laid out as [theta, theta_dot, Re(x), Im(x)] with d entries
# each. It is taken from the true values at sample 1; the angular rate is the
# backward difference between samples 1 and 0.
x_init = torch.cat((thetas[1],
                    (thetas[1] - thetas[0]) / Ts,
                    torch.real(x[1]),
                    torch.imag(x[1])), dim=0)

P_init = torch.eye(4 * d) * 1 # initial state covariance

# Process-noise covariance, block diagonal in the state ordering above.
# Kinematic block for (theta, theta_dot): the Kronecker product with eye(d)
# applies the same 2x2 pattern to every source.
Q_kinematic = torch.kron(torch.tensor([[Ts ** 3 / 3, Ts ** 2 / 2],
                                       [Ts ** 2 / 2, Ts]]), torch.eye(d))
# Signal block for (Re(x), Im(x)): the AR(1) innovation is generated as
# (randn + 1j * randn) / sqrt(2) * sigma, so each real component has variance
# sigma ** 2 / 2. The previous value `sigma` was a standard deviation placed
# inside a covariance matrix.
# NOTE(review): assumes ExtendedKalmanFilter treats Q and R as covariances;
# its implementation is not visible in this notebook -- confirm.
Q_signal = (sigma ** 2 / 2) * torch.eye(2 * d)
Q = torch.block_diag(Q_kinematic, Q_signal) # noise matrix of dynamic equation

# Observation-noise covariance: the sensor noise is generated as
# (randn + 1j * randn) / sqrt(2) * sigma_noise, i.e. variance
# sigma_noise ** 2 / 2 for each real and imaginary component.
R = torch.eye(2 * m) * (sigma_noise ** 2 / 2) # noise matrix of observation equation

Filter = ExtendedKalmanFilter(array, d, Ts, alpha, x_init, P_init, Q, R)

In [14]:
# results = [x_init[:d]]

# for i in range(2, t):
#     yi = torch.cat((torch.real(y[i]), torch.imag(y[i])), dim=0)
#     results.append(Filter.step(yi))

# results = torch.stack(results, dim=0)

In [15]:
# plt.plot(time[1:], thetas[1:, 0], color='red', linewidth=0.5)
# plt.plot(time[1:], thetas[1:, 1], color='green', linewidth=0.5)
# plt.plot(time[1:], results[:, 0], color='red', marker='.', linewidth=0.5, markersize=1.0)
# plt.plot(time[1:], results[:, 1], color='green', marker='.', linewidth=0.5, markersize=1.0)
# plt.xlabel('time')
# plt.ylabel('theta')
# plt.ylim(-torch.pi/2, torch.pi/2)
# plt.show()