In [4]:
import json
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def load_probs(json_file):
    """Read a JSON file and return its ``'probabilities'`` entry.

    Args:
        json_file: Path to a JSON document whose top level is an object that
            contains a ``'probabilities'`` key.

    Returns:
        The value stored under ``'probabilities'``. In this notebook it is fed
        to ``build_matrix``, which iterates it as a ``{"old-new": value}`` dict.

    Raises:
        KeyError: if the document has no ``'probabilities'`` key.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON is UTF-8 by spec (RFC 8259). Without an explicit encoding, open()
    # falls back to the locale default (e.g. GBK on Chinese Windows) and can
    # mis-decode or reject non-ASCII content.
    with open(json_file, encoding='utf-8') as fh:
        data = json.load(fh)
    return data['probabilities']

def build_matrix(probs, max_state=10):
    """Convert a ``{"old-new": value}`` mapping into a square transition matrix.

    Row index is the old state and column index is the new state, both limited
    to ``0..max_state`` inclusive. Keys whose states fall outside that range are
    ignored, and cells with no matching key stay 0.

    Args:
        probs: Mapping whose keys look like ``"<old>-<new>"`` (two integers
            joined by ``-``) and whose values can be stored in a float array.
        max_state: Largest state index kept; the matrix has ``max_state + 1``
            rows and columns.

    Returns:
        numpy.ndarray of shape ``(max_state + 1, max_state + 1)`` (float64).

    Raises:
        ValueError: if a key does not split into exactly two integers.
    """
    size = max_state + 1
    transition = np.zeros((size, size))
    for key, prob in probs.items():
        src, dst = (int(part) for part in key.split('-'))
        in_range = (0 <= src <= max_state) and (0 <= dst <= max_state)
        if not in_range:
            continue
        transition[src, dst] = prob
    return transition

def plot_heatmap(matrix, title, figsize=(10,8), save_path=None, show=False):
    """Draw ``matrix`` as an annotated heatmap, optionally save and/or show it.

    Rows are labelled "Old State" and columns "New State"; each cell is
    annotated with its value to three decimals. The figure is always closed
    before returning, so a loop over many matrices does not accumulate open
    figures.

    Args:
        matrix: 2-D array-like passed straight to ``seaborn.heatmap``.
        title: Figure title.
        figsize: Figure size in inches, ``(width, height)``.
        save_path: If truthy, the figure is written there and a message is
            printed.
        show: If True, call ``plt.show()`` before closing, so the plot is
            displayed (useful when ``save_path`` is None, which otherwise
            produces no visible result). Defaults to False to keep the
            original save-only behaviour.
    """
    fig, ax = plt.subplots(figsize=figsize)
    try:
        sns.heatmap(matrix, annot=True, fmt=".3f", cmap="YlGnBu",
                    cbar_kws={'label': 'Transition Probability'}, ax=ax)
        ax.set_title(title)
        ax.set_xlabel("New State")
        ax.set_ylabel("Old State")
        ax.tick_params(axis='x', labelrotation=0)
        ax.tick_params(axis='y', labelrotation=0)
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, bbox_inches='tight')
            print(f"Saved heatmap to {save_path}")
        if show:
            plt.show()
    finally:
        # Close even when drawing or saving raises; otherwise each failure
        # leaks an open figure.
        plt.close(fig)

# Find every JSON file in the target folder (current directory by default)
# and render one heatmap per probability dump.
folder_path = "."
json_files = sorted(name for name in os.listdir(folder_path) if name.endswith('.json'))

idx = 0  # 1-based number shown in each title; only counts files actually plotted
for jf in json_files:
    full_path = os.path.join(folder_path, jf)
    try:
        probs = load_probs(full_path)
    except KeyError:
        # The folder scan picks up every *.json; skip unrelated ones visibly
        # instead of aborting the whole batch.
        print(f"Skipping {full_path}: no 'probabilities' key")
        continue
    matrix = build_matrix(probs)
    idx += 1

    # Only this title is meant to be edited by hand between runs.
    # NOTE(review): it says "Attack" even for baseline__* files (see the saved
    # output of this cell) — confirm the label matches the data being plotted.
    title = f"TCP State Transition Probabilities [{idx}] Attack"
    # Keep the original timestamped stem. splitext swaps only the final
    # extension (str.replace would hit every ".json" in the name), and the PNG
    # is written next to its source JSON rather than into the CWD.
    save_file = os.path.join(folder_path, os.path.splitext(jf)[0] + '.png')
    plot_heatmap(matrix, title, save_path=save_file)

Saved heatmap to baseline__probs_sliding20250704_125843.png
Saved heatmap to baseline__probs_sliding20250704_131445.png
Saved heatmap to baseline__probs_sliding20250704_132535.png


In [None]:
# KL-divergence threshold design: threshold = largest normal-traffic KL + margin.
# NOTE(review): none of the helpers/inputs below are defined in this notebook;
# they must come from earlier cells or an imported module. Fail fast with a
# readable message instead of a bare NameError halfway through the cell.
_required_names = ["load_baseline_probs", "load_attack_probs", "attack_files",
                   "compute_kl_divergence", "baseline_kl_samples"]
_missing_names = [name for name in _required_names if name not in globals()]
if _missing_names:
    raise NameError(f"Define these before running this cell: {', '.join(_missing_names)}")

# Safety margin above the largest normal KL; choose empirically (e.g. 0.05 or 0.1).
margin = 0.05

baseline_probs = load_baseline_probs()
attack_samples = [load_attack_probs(file) for file in attack_files]

# Sanity check: KL(p || p) should be ~0. A noticeably non-zero value points at
# smoothing/normalisation inside compute_kl_divergence.
baseline_kl = compute_kl_divergence(baseline_probs, baseline_probs)
attack_kls = [compute_kl_divergence(baseline_probs, attack) for attack in attack_samples]
if not attack_kls:
    raise ValueError("attack_files is empty; at least one attack sample is needed to validate the threshold")

max_normal_kl = max(baseline_kl_samples)  # largest KL measured on sampled normal traffic
max_attack_kl = max(attack_kls)
# The threshold only separates all attacks from normal traffic if it lies
# below the *smallest* attack KL, so report that as well.
min_attack_kl = min(attack_kls)

threshold = max_normal_kl + margin

print(f"设计阈值为 {threshold}, 正常最大KL={max_normal_kl}, 攻击最大KL={max_attack_kl}")
print(f"自检KL={baseline_kl}, 攻击最小KL={min_attack_kl}, 阈值可区分全部攻击样本: {threshold < min_attack_kl}")
