In [12]:
import os
import json
import math
import glob

# Smoothing constant. NOTE(review): not referenced anywhere in the code
# visible here; kl_divergence and js_divergence each carry their own
# ``epsilon=1e-6`` default of the same value.
EPSILON = 1e-6

# Baseline probability table (example values).
# NOTE(review): the values below sum to roughly 7.46, so as a whole this is
# not a single normalized distribution. The "A-B" keys presumably denote
# transitions from state A to state B, with each value conditional on A;
# confirm before interpreting KL/JS results computed against this table.
baseline_probs = {
        "1-4": 0.042063,
        "1-7": 0.140043,
        "1-8": 0.835920,
        "10-7": 0.809524,
        "2-1": 0.058135,
        "2-7": 0.958475,
        "3-1": 1.000000,
        "4-5": 0.222222,
        "4-7": 0.166667,
        "5-7": 0.222222,
        "7-2": 1.000000,
        "8-9": 1.000000,
        "9-7": 0.978543,
        "9-9": 0.030040,
}

def load_probs(path):
    """Load the ``"probabilities"`` mapping from a JSON file.

    Args:
        path: Path of a JSON file whose top-level object has a
            ``"probabilities"`` key.

    Returns:
        The value stored under ``"probabilities"``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the top-level object has no ``"probabilities"`` key.
    """
    # Decode explicitly as UTF-8 instead of the platform's locale encoding,
    # so files with non-ASCII content load the same way on every system.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return data["probabilities"]

def kl_divergence(p, q, epsilon=1e-6):
    """Return the KL sum ``sum(p[k] * log(p[k] / q[k]))`` over the keys of ``p``.

    Only keys of ``p`` with a strictly positive value contribute. Keys that
    appear only in ``q`` are ignored. Natural logarithm is used.

    Args:
        p: Mapping from key to probability (the reference side of the ratio).
        q: Mapping from key to probability (the comparison side).
        epsilon: Substitute used for ``q[k]`` when the key is missing from
            ``q`` or its value is not strictly positive, so that the ratio
            and the logarithm stay defined.

    Returns:
        The sum, clamped below at ``0.0``. For two distributions normalized
        over the same keys the clamp only absorbs rounding error. If the
        inputs are not normalized, the raw sum can be genuinely negative and
        this function then reports ``0.0``.
    """
    divergence = 0.0
    for key, p_val in p.items():
        if p_val <= 0:
            continue
        q_val = q.get(key, epsilon)
        # A key present in q with a zero (or negative) value would otherwise
        # raise ZeroDivisionError (or a math domain error); treat it the same
        # way as a missing key.
        if q_val <= 0:
            q_val = epsilon
        divergence += p_val * math.log(p_val / q_val)
    return max(divergence, 0.0)

def js_divergence(p, q, epsilon=1e-6):
    """Return the Jensen-Shannon sum of two key -> probability mappings.

    For every key found in either mapping, the mixture value is
    ``0.5 * (p[k] + q[k])`` with absent keys counted as ``0.0``. The result
    is the mean of the two KL-style sums of ``p`` and ``q`` against that
    mixture, using the natural logarithm.

    Args:
        p: First mapping from key to probability.
        q: Second mapping from key to probability.
        epsilon: Lower bound applied to each mixture value before it is used
            as a divisor.

    Returns:
        ``0.5 * (KL(p || m) + KL(q || m))`` as a float. No clamping or
        normalization is applied.
    """
    support = set(p.keys()).union(q.keys())

    p_side = 0.0
    q_side = 0.0
    for key in support:
        p_mass = p.get(key, 0.0)
        q_mass = q.get(key, 0.0)
        # Mixture value for this key, floored so it is a safe divisor.
        mixture = max(0.5 * (p_mass + q_mass), epsilon)

        # Entries with no positive mass contribute nothing to their side.
        if p_mass > 0:
            p_side += p_mass * math.log(p_mass / mixture)
        if q_mass > 0:
            q_side += q_mass * math.log(q_mass / mixture)

    return 0.5 * (p_side + q_side)

# Gather every baseline_short*.json file in the working directory, in
# sorted name order.
short_files = glob.glob("baseline_short*.json")
short_files.sort()
print(f"Found {len(short_files)} sample files.")

# Divergence of each sample against baseline_probs, in the same order as
# short_files.
kl_values, js_values = [], []
for path in short_files:
    sample_probs = load_probs(path)
    kl = kl_divergence(sample_probs, baseline_probs)
    js = js_divergence(sample_probs, baseline_probs)
    kl_values.append(kl)
    js_values.append(js)

    # Per-file report, followed by a blank separator line.
    print(f"{os.path.basename(path)}:")
    print(f"  KL divergence = {kl:.6f}")
    print(f"  JS divergence = {js:.6f}")
    print()

# Summary: arithmetic means over all samples, or a notice if none matched.
if not kl_values:
    print("No baseline_short*.json files found.")
else:
    avg_kl = sum(kl_values) / len(kl_values)
    avg_js = sum(js_values) / len(js_values)
    print(f"📊 Average KL divergence across {len(kl_values)} samples: {avg_kl:.6f}")
    print(f"📊 Average JS divergence across {len(js_values)} samples: {avg_js:.6f}")


Found 7 sample files.
baseline_short20250706_010315.json:
  KL divergence = 0.166307
  JS divergence = 0.098913

baseline_short20250706_010457.json:
  KL divergence = 0.000000
  JS divergence = 0.251218

baseline_short20250706_010733.json:
  KL divergence = 0.000000
  JS divergence = 0.160734

baseline_short20250706_010948.json:
  KL divergence = 0.497799
  JS divergence = 0.091561

baseline_short20250706_011131.json:
  KL divergence = 0.023872
  JS divergence = 0.226675

baseline_short20250706_201505.json:
  KL divergence = 0.000000
  JS divergence = 0.069858

baseline_short20250706_202218.json:
  KL divergence = 0.027253
  JS divergence = 0.227540

📊 Average KL divergence across 7 samples: 0.102176
📊 Average JS divergence across 7 samples: 0.160929
