In [47]:
import torch
from system_config import Nt, delta, desired_angle_rad
n_indices = torch.arange(Nt, dtype=torch.float32)
desired_angle_rad_torch = torch.tensor(desired_angle_rad, dtype=torch.float32)
phase = 1j* 2 * torch.pi * delta * torch.sin(desired_angle_rad_torch) * n_indices
a_phi_0 = torch.exp(phase)  # shape: (Nt, 1)
a_dot_phi_0 = ((1j * 2 * torch.pi * delta * torch.cos(desired_angle_rad_torch) * n_indices) * a_phi_0)

a_phi_0 = a_phi_0.unsqueeze(1)  # shape: (Nt, 1)
a_dot_phi_0 = a_dot_phi_0.unsqueeze(1)  # shape: (Nt, 1)

In [53]:
A_dot = a_dot_phi_0 @ a_phi_0.transpose(0, 1) + a_phi_0 @ a_dot_phi_0.transpose(0, 1)
print(A_dot.shape)

torch.Size([64, 64])


In [72]:
R_N = torch.eye(Nt)  # noise covariance matrix
R_N_inv = torch.linalg.inv(R_N)  # inverse of noise covariance matrix
print(R_N_inv.shape)

torch.Size([64, 64])


In [82]:

from utility import initialize, get_data_tensor
from system_config import initial_normalization, data_source, snr_dB_list

import numpy as np

H_train, H_test0 = get_data_tensor(data_source)
H_shuffeld = torch.transpose(H_train, 0, 1)[np.random.permutation(len(H_train[0]))]
H = torch.transpose(H_shuffeld[0:4], 0, 1)
snr_dB_train = np.random.choice(snr_dB_list)
snr_train = 10 ** (snr_dB_train / 10)
rate_init, F, W = initialize(H, snr_train, initial_normalization)

In [83]:
print(f'Shape of the metrixes: F.shape = {F.shape}, W.shape = {W.shape}')

Shape of the metrixes: F.shape = torch.Size([1, 4, 64, 4]), W.shape = torch.Size([1, 4, 4, 4])


In [117]:
def get_grad_F_CRB(F, W, xi_0, A_dot, R_N_inv):

    # match the data type of A_dot and R_N_inv with F
    A_dot = A_dot.to(F.dtype)
    R_N_inv = R_N_inv.to(F.dtype)

    # reshape A_dot and R_N_inv for batch processing
    A_dot = A_dot.unsqueeze(0).unsqueeze(0) # [1, 1, Nt, Nt]
    R_N_inv = R_N_inv.unsqueeze(0).unsqueeze(0) # [1, 1, Nr, Nr]

    A_dot_H = A_dot.conj().transpose(-2, -1)
    W_H = W.conj().transpose(-2, -1)
    F_H = F.conj().transpose(-2, -1)
    
    M = A_dot_H @ R_N_inv @ A_dot

    inner_mat = W_H @ F_H @ M @ F @ W
    batch_trace = (torch.diagonal(inner_mat, dim1=-2, dim2=-1).sum(-1)) ** 2
    
    numerator = M @ F @ W @ W_H
    denominator = (2 * (torch.abs(torch.tensor(xi_0))**2) * batch_trace).view(1, -1, 1, 1)
    
    grad_F_CRB = numerator / denominator
    
    return grad_F_CRB

In [118]:
from system_config import xi_0
grad_F_CRB = get_grad_F_CRB(F, W, xi_0, A_dot, R_N_inv)

In [None]:
print(grad_F_CRB.shape)

tensor([[[[ 8.3058e+03+1.0464e+04j,  3.8136e+03-7.1531e+02j,
            5.7495e+03+3.6689e+03j,  1.8971e+03-1.1728e+04j],
          [ 3.3853e+03-1.3310e+04j, -2.9606e+03-2.6708e+03j,
           -5.7972e+02-6.9783e+03j, -1.0771e+04+5.7501e+03j],
          [-1.2985e+04+5.5168e+03j, -3.3794e+02+4.0805e+03j,
           -5.3357e+03+4.8110e+03j,  1.1399e+04+5.2243e+03j],
          ...,
          [-3.2172e+04-1.6613e+04j, -8.4656e+03+6.0747e+03j,
           -1.7739e+04-2.5593e+03j,  8.2360e+03+3.0884e+04j],
          [ 6.3334e+03+3.6030e+04j,  1.0064e+04+3.0870e+03j,
            8.7971e+03+1.5824e+04j,  1.9787e+04-2.5520e+04j],
          [ 2.5086e+04-2.7139e+04j, -3.6777e+03-9.9778e+03j,
            7.3350e+03-1.6751e+04j, -3.2620e+04-2.8946e+02j]],

         [[-2.0059e+03-5.0146e+02j,  4.6715e+03+4.6464e+03j,
            6.4630e+02+1.2073e+03j, -1.1100e+03+3.7702e+03j],
          [ 8.3579e+02+1.9505e+03j,  8.9362e+02-6.7124e+03j,
            5.8508e+02-1.2794e+03j,  3.7728e+03-1.4396e+03j],

In [None]:
def get_grad_W_CRB(F, W, xi_0, A_dot, R_N_inv):

    A_dot = A_dot.to(F.dtype)
    R_N_inv = R_N_inv.to(F.dtype)

    A_dot = A_dot.unsqueeze(0).unsqueeze(0) # [1, 1, Nt, Nt]
    R_N_inv = R_N_inv.unsqueeze(0).unsqueeze(0) # [1, 1, Nr, Nr]

    A_dot_H = A_dot.conj().transpose(-2, -1)
    W_H = W.conj().transpose(-2, -1)
    F_H = F.conj().transpose(-2, -1)


    M = A_dot_H @ R_N_inv @ A_dot
    inner_mat = W_H @ F_H @ M @ F @ W
    batch_trace = (torch.diagonal(inner_mat, dim1=-2, dim2=-1).sum(-1)) ** 2
    
    numerator = F_H @ M @ F @ W
    denominator = 2 * (torch.abs(torch.tensor(xi_0))**2) * batch_trace.view(1, -1, 1, 1)
    grad_W_CRB = numerator / denominator
    return grad_W_CRB

In [121]:
from system_config import xi_0
get_grad_W_CRB = get_grad_W_CRB(F, W, xi_0, A_dot, R_N_inv)

In [130]:
print(get_grad_W_CRB)

tensor([[[[ 4.5998e+05-6.8725e+04j,  1.2370e+04-1.6614e+05j,
            3.2411e+05-1.1673e+05j, -2.9971e+05-1.2474e+05j],
          [ 1.8984e+04+2.3601e+05j,  8.4064e+04+1.2040e+04j,
            4.8660e+04+1.6877e+05j,  7.3723e+04-1.4801e+05j],
          [ 5.0835e+05+2.8842e+04j,  5.0868e+04-1.7547e+05j,
            3.7676e+05-5.3326e+04j, -2.9367e+05-2.0099e+05j],
          [-1.9420e+05+1.4727e+03j, -1.4721e+04+6.7976e+04j,
           -1.4016e+05+2.9430e+04j,  1.1660e+05+6.9017e+04j]],

         [[ 7.8749e+04+2.3223e+04j, -2.4906e+04-2.1747e+05j,
           -9.6019e+03-8.4896e+04j,  1.1225e+05-8.3736e+04j],
          [-3.8953e+04+1.4118e+05j,  3.9750e+05-5.4251e+04j,
            1.5436e+05-2.1119e+04j,  1.5845e+05+2.0061e+05j],
          [-3.4812e+04+8.1223e+04j,  2.4122e+05+1.7909e+02j,
            9.3736e+04+5.7081e+01j,  7.7927e+04+1.3258e+05j],
          [ 6.8791e+04+6.3731e+04j,  8.7483e+04-2.4133e+05j,
            3.3881e+04-9.3815e+04j,  1.6098e+05-3.0214e+04j]],

         [[ 

In [126]:
from utility import normalize
def get_crb(H, F, W, xi_0, A_dot, R_N_inv, Pt):
    F, W = normalize(F, W, H, Pt)
    
    A_dot = A_dot.to(F.dtype)
    R_N_inv = R_N_inv.to(F.dtype)

    A_dot = A_dot.unsqueeze(0).unsqueeze(0) # [1, 1, Nt, Nt]
    R_N_inv = R_N_inv.unsqueeze(0).unsqueeze(0) # [1, 1, Nr, Nr]

    A_dot_H = A_dot.conj().transpose(-2, -1)
    W_H = W.conj().transpose(-2, -1)
    F_H = F.conj().transpose(-2, -1)
    
    M = A_dot_H @ R_N_inv @ A_dot
    inner_mat = W_H @ F_H @ M @ F @ W
    batch_trace = (torch.diagonal(inner_mat, dim1=-2, dim2=-1).sum(-1))

    denominator = 2 * (torch.abs(torch.tensor(xi_0))**2) * batch_trace.view(1, -1, 1, 1)

    crb = 1 / denominator
    
    return crb
    

In [128]:
crb = get_crb(H, F, W, xi_0, A_dot, R_N_inv, snr_train)
print(crb)

tensor([[[[208735.7969+0.0460j]],

         [[130972.7031-0.0163j]],

         [[348946.6562+0.0271j]],

         [[ 36266.1445-0.0021j]]]])
