In [1]:
import numpy as np
import pickle

# 指定.pkl文件的路径
file_path = "./R_ET_1_0.pkl"

# 打开并读取.pkl文件中的数据
with open(file_path, 'rb') as f:
    [SNR, FM, wholeFM, SNRw] = pickle.load(f)

# 提取 9x9 的 Fisher 矩阵，并仅保留前 7 个参数的子矩阵
def extract_relevant_submatrix(matrix, size=7):
    """提取前 size 个参数的子矩阵"""
    return matrix[:size, :size]

fisher_matrix_1 = extract_relevant_submatrix(FM)
print(fisher_matrix_1)

[[ 6.25162357e+02  5.50588656e+03  4.95517563e+05  1.59372497e+03
  -6.30320337e+02 -6.58566336e+05 -3.91101952e+02]
 [ 5.50588656e+03  1.25797291e+05  1.14203147e+07  2.04696721e+02
   1.39120379e+03 -3.99264456e+07 -8.73721741e+03]
 [ 4.95517563e+05  1.14203147e+07  1.03698696e+09  6.39836649e+02
   1.34072920e+05 -3.58463400e+09 -7.92854858e+05]
 [ 1.59372497e+03  2.04696721e+02  6.39836649e+02  6.62828596e+03
  -2.84814829e+03 -7.40254013e+03 -2.88889625e+02]
 [-6.30320337e+02  1.39120379e+03  1.34072920e+05 -2.84814829e+03
   1.25901745e+03 -2.47882681e+06 -8.34020492e+01]
 [-6.58566336e+05 -3.99264456e+07 -3.58463400e+09 -7.40254013e+03
  -2.47882681e+06  3.76758877e+11  1.47345779e+07]
 [-3.91101952e+02 -8.73721741e+03 -7.92854858e+05 -2.88889625e+02
  -8.34020492e+01  1.47345779e+07  1.23944815e+03]]


In [2]:
# 处理第二个文件
file_path = "./R_ET_2_0.pkl"
with open(file_path, 'rb') as f:
    [SNR, FM, wholeFM, SNRw] = pickle.load(f)
fisher_matrix_2 = extract_relevant_submatrix(FM)
print(fisher_matrix_2)

[[ 7.04889104e+02  3.83210915e+03  4.16671716e+05 -1.70010064e+03
  -9.10510532e+02 -3.89619537e+05  4.39438430e+02]
 [ 3.83210915e+03  1.09653058e+05  9.84620778e+06  2.56776730e+03
   2.91769957e+03 -3.47029700e+07  9.82933138e+03]
 [ 4.16671716e+05  9.84620778e+06  8.93588681e+08  5.83301204e+02
   1.33330583e+05 -3.10804759e+09  9.04640633e+05]
 [-1.70010064e+03  2.56776730e+03  5.83301204e+02  5.73130357e+03
   3.24137321e+03 -6.70079438e+03 -3.10432162e+02]
 [-9.10510532e+02  2.91769957e+03  1.33330583e+05  3.24137321e+03
   1.87305443e+03 -2.46106977e+06  9.48538336e+01]
 [-3.89619537e+05 -3.47029700e+07 -3.10804759e+09 -6.70079438e+03
  -2.46106977e+06  3.29279258e+11 -1.66555844e+07]
 [ 4.39438430e+02  9.82933138e+03  9.04640633e+05 -3.10432162e+02
   9.48538336e+01 -1.66555844e+07  1.85450516e+03]]


In [3]:
# 处理第三个文件
file_path = "./R_ET_3_0.pkl"
with open(file_path, 'rb') as f:
    [SNR, FM, wholeFM, SNRw] = pickle.load(f)
fisher_matrix_3 = extract_relevant_submatrix(FM)
print(fisher_matrix_3)

[[ 1.71097482e+03 -2.01574273e+03 -6.53431125e+04  2.41872229e+02
  -3.23282821e+03  1.29717641e+06 -6.89603632e+01]
 [-2.01574273e+03  3.53993636e+03  1.77150424e+05 -1.71567840e+02
   3.91288660e+03 -1.44958229e+06 -1.49652191e+03]
 [-6.53431125e+04  1.77150424e+05  1.11239413e+07  3.65512731e+00
   1.32416748e+05 -3.70691826e+07 -1.26772612e+05]
 [ 2.41872229e+02 -1.71567840e+02  3.65512731e+00  6.96489314e+01
  -4.46673580e+02 -5.18076640e+01 -4.74031233e+02]
 [-3.23282821e+03  3.91288660e+03  1.32416748e+05 -4.46673580e+02
   6.11764085e+03 -2.44623778e+06 -1.31962575e+01]
 [ 1.29717641e+06 -1.44958229e+06 -3.70691826e+07 -5.18076640e+01
  -2.44623778e+06  3.71338138e+09  2.19779602e+06]
 [-6.89603632e+01 -1.49652191e+03 -1.26772612e+05 -4.74031233e+02
  -1.31962575e+01  2.19779602e+06  6.10459399e+03]]


In [4]:
# 计算协方差矩阵
def compute_covariance(fisher_matrix):
    return np.linalg.pinv(fisher_matrix)

# 计算协方差矩阵
cov_matrix_1 = compute_covariance(fisher_matrix_1)
cov_matrix_2 = compute_covariance(fisher_matrix_2)
cov_matrix_3 = compute_covariance(fisher_matrix_3)

# 打印验证
print("Shape of fisher_matrix_1:", fisher_matrix_1.shape)
print("Shape of cov_matrix_1:", cov_matrix_1.shape)
print("Shape of fisher_matrix_2:", fisher_matrix_2.shape)
print("Shape of cov_matrix_2:", cov_matrix_2.shape)
print("Shape of fisher_matrix_3:", fisher_matrix_3.shape)
print("Shape of cov_matrix_3:", cov_matrix_3.shape)

Shape of fisher_matrix_1: (7, 7)
Shape of cov_matrix_1: (7, 7)
Shape of fisher_matrix_2: (7, 7)
Shape of cov_matrix_2: (7, 7)
Shape of fisher_matrix_3: (7, 7)
Shape of cov_matrix_3: (7, 7)


In [5]:
# 优化均值计算
mean1 = np.mean(cov_matrix_1, axis=0)
mean2 = np.mean(cov_matrix_2, axis=0)
mean3 = np.mean(cov_matrix_3, axis=0)

In [6]:
# 计算 KL 散度
def kl_divergence(cov1, mean1, cov2, mean2, matrix_name=""):
    try:
        inv_cov2 = np.linalg.inv(cov2)
    except np.linalg.LinAlgError:
        print(f"{matrix_name} inversion failed, skipping KL computation.")
        return np.nan

    term1 = np.trace(inv_cov2 @ cov1)
    print(f"{matrix_name} KL term1 (Trace): {term1}")

    sign1, logdet1 = np.linalg.slogdet(cov1)
    sign2, logdet2 = np.linalg.slogdet(cov2)
    if sign1 <= 0 or sign2 <= 0:
        print(f"{matrix_name} determinant is non-positive.")
        return np.nan

    term2 = logdet2 - logdet1
    print(f"{matrix_name} KL term2 (Log determinant difference): {term2}")

    term3 = (mean2 - mean1).T @ inv_cov2 @ (mean2 - mean1)
    print(f"{matrix_name} KL term3 (Mean difference): {term3}")

    result = 0.5 * (term1 + term2 + term3 - cov1.shape[0])
    print(f"{matrix_name} KL result: {result}")

    return max(result, 0)

In [7]:
# 打印 KL 散度
kl_12 = kl_divergence(cov_matrix_1, mean1, cov_matrix_2, mean2, "matrix_12")
kl_13 = kl_divergence(cov_matrix_1, mean1, cov_matrix_3, mean3, "matrix_13")
kl_23 = kl_divergence(cov_matrix_2, mean2, cov_matrix_3, mean3, "matrix_23")

print("\n==== KL Divergence ====")
print(f"KL Divergence between matrix 1 and matrix 2: {kl_12}")
print(f"KL Divergence between matrix 1 and matrix 3: {kl_13}")
print(f"KL Divergence between matrix 2 and matrix 3: {kl_23}")

matrix_12 KL term1 (Trace): 96242.48920798408
matrix_12 KL term2 (Log determinant difference): 0.13381746615884538
matrix_12 KL term3 (Mean difference): 1196032.7000441225
matrix_12 KL result: 646134.1615347863
matrix_13 KL term1 (Trace): 96242.78836101791
matrix_13 KL term2 (Log determinant difference): 16.60423275815829
matrix_13 KL term3 (Mean difference): 1198834.3842477845
matrix_13 KL result: 647543.3884207803
matrix_23 KL term1 (Trace): 133644.30918240803
matrix_23 KL term2 (Log determinant difference): 16.470415291999444
matrix_23 KL term3 (Mean difference): 2860636.1256417725
matrix_23 KL result: 1497144.9526197363

==== KL Divergence ====
KL Divergence between matrix 1 and matrix 2: 646134.1615347863
KL Divergence between matrix 1 and matrix 3: 647543.3884207803
KL Divergence between matrix 2 and matrix 3: 1497144.9526197363
