In [1]:
import numpy as np
import pandas as pd
import csv

# Load the table of flattened Fisher matrices.
file_path = "./all_fisher_matrices.csv"  # update to the actual path
df = pd.read_csv(file_path)

# Build a lookup of 9x9 Fisher matrices keyed by "<Group>_<Detector>".
# Every column from the third one onward is treated as a matrix entry, so
# each row must carry exactly 81 values after its first two columns; they
# are laid out with NumPy's default (row-by-row) reshape order.
fisher_matrices = {}
for _, record in df.iterrows():
    label = f"{record['Group']}_{record['Detector']}"
    flat_entries = record.iloc[2:].values.astype(float)
    fisher_matrices[label] = flat_entries.reshape(9, 9)

def compute_covariance(fisher_matrix):
    """Return the covariance matrix associated with a Fisher matrix.

    The result is the Moore-Penrose pseudo-inverse (``np.linalg.pinv``) of
    ``fisher_matrix``, not a plain matrix inverse, so a singular input does
    not raise.
    """
    pseudo_inverse = np.linalg.pinv(fisher_matrix)
    return pseudo_inverse

def kl_divergence(cov1, mean1, cov2, mean2, neg_tol=1e-8):
    """Return the KL divergence KL(N1 || N2) between two Gaussians.

    Uses the closed form
    ``0.5 * (tr(cov2^-1 cov1) + ln(det cov2 / det cov1)
    + (mean2 - mean1)^T cov2^-1 (mean2 - mean1) - k)``
    where ``k`` is ``cov1.shape[0]``.

    Args:
        cov1: Covariance matrix of the first distribution (NumPy array).
        mean1: Mean of the first distribution (NumPy array).
        cov2: Covariance matrix of the second distribution (NumPy array).
        mean2: Mean of the second distribution (NumPy array).
        neg_tol: Relative tolerance below zero that is still treated as
            round-off and clamped to ``0.0``.

    Returns:
        The non-negative divergence, or ``np.nan`` when ``cov2`` cannot be
        inverted, when either determinant is not strictly positive, or when
        the computed value is negative by more than the tolerance.
    """
    try:
        inv_cov2 = np.linalg.inv(cov2)
    except np.linalg.LinAlgError:
        return np.nan

    trace_term = np.trace(inv_cov2 @ cov1)
    sign1, logdet1 = np.linalg.slogdet(cov1)
    sign2, logdet2 = np.linalg.slogdet(cov2)

    # A non-positive determinant means the matrix cannot be a valid
    # (positive-definite) covariance.
    if sign1 <= 0 or sign2 <= 0:
        return np.nan

    logdet_term = logdet2 - logdet1
    delta = mean2 - mean1
    mahalanobis_term = delta.T @ inv_cov2 @ delta
    dim = cov1.shape[0]

    result = 0.5 * (trace_term + logdet_term + mahalanobis_term - dim)

    # The divergence is non-negative for valid covariances. A clearly
    # negative value therefore signals an indefinite or badly conditioned
    # input (a positive determinant alone does not rule that out), so report
    # NaN instead of disguising it as "identical distributions".
    scale = abs(trace_term) + abs(logdet_term) + abs(mahalanobis_term) + dim
    if result < -neg_tol * scale:
        return np.nan

    # Only tiny negative values caused by floating-point round-off are clamped.
    return max(result, 0.0)

# Derive a covariance matrix from every full 9x9 Fisher matrix.
cov_matrices = {
    name: compute_covariance(fisher) for name, fisher in fisher_matrices.items()
}

# NOTE(review): the "mean" vector used later is the column-wise average of
# the covariance matrix itself, not a parameter estimate - confirm that this
# is the intended mean for the KL divergence.
means = {name: np.mean(cov, axis=0) for name, cov in cov_matrices.items()}

# Compute the pairwise KL divergences within each group and collect them.
results = []
groups = df['Group'].unique()

for group in groups:
    # Select this group's keys by exact match on the Group column. A string
    # prefix test would also pick up other groups whose name merely starts
    # with this one (e.g. "1" matching "10_..."), adding spurious pairs and
    # letting a single-detector group pass the size check below.
    detector_names = df.loc[df['Group'] == group, 'Detector']
    # dict.fromkeys removes repeated keys while keeping first-seen row order.
    group_detectors = list(dict.fromkeys(
        f"{group}_{detector_name}" for detector_name in detector_names
    ))
    if len(group_detectors) >= 2:
        kl_group_results = {}
        # KL divergence is not symmetric: each pair is evaluated once, in the
        # order the detectors appear in the input rows.
        for i, key1 in enumerate(group_detectors):
            for key2 in group_detectors[i + 1:]:
                kl_value = kl_divergence(cov_matrices[key1], means[key1], cov_matrices[key2], means[key2])
                kl_group_results[f'{key1} vs {key2}'] = kl_value

        # Pairs that were not computed for this group are reported as NaN.
        results.append({
            'group': group,
            'kl_12': kl_group_results.get(f'{group}_ET_1 vs {group}_ET_2', np.nan),
            'kl_13': kl_group_results.get(f'{group}_ET_1 vs {group}_ET_3', np.nan),
            'kl_23': kl_group_results.get(f'{group}_ET_2 vs {group}_ET_3', np.nan),
        })

# Write one row per group to a CSV file.
output_file = "kl_divergence_results_raw.csv"
# An explicit encoding keeps the output independent of the platform locale;
# newline='' lets the csv module control line endings itself.
with open(output_file, mode='w', newline='', encoding='utf-8') as file:
    writer = csv.DictWriter(file, fieldnames=['group', 'kl_12', 'kl_13', 'kl_23'])
    writer.writeheader()
    writer.writerows(results)

print(f"KL divergence results saved to {output_file}")

KL divergence results saved to kl_divergence_results_raw.csv
