In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import copy
import os

# ==============================================================================
# 1. Environment setup (CJK text rendering and SVG font handling)
# ==============================================================================
# CJK-capable sans-serif fonts; matplotlib uses the first one it can find.
_CJK_FONT_CANDIDATES = ['Noto Sans CJK SC', 'Heiti TC', 'Heiti SC', 'PingFang SC', 'SimHei', 'Arial Unicode MS']

plt.rcParams['font.sans-serif'] = _CJK_FONT_CANDIDATES
# Draw the minus sign as ASCII '-' so it does not depend on the CJK font
# providing the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
# Keep text as real text in the SVG output instead of converting glyphs to paths.
plt.rcParams['svg.fonttype'] = 'none'

# Assigning font names to rcParams never fails for fonts that are not installed
# (matplotlib only warns later, at draw time), so a try/except around the
# assignments cannot detect the problem. Check the installed fonts explicitly.
from matplotlib import font_manager

_installed_font_names = {entry.name for entry in font_manager.fontManager.ttflist}
if not _installed_font_names.intersection(_CJK_FONT_CANDIDATES):
    print(f"注意：未找到任何候选中文字体 ({', '.join(_CJK_FONT_CANDIDATES)})。图表中的中文可能无法正常显示。")

# ==============================================================================
# 2. Benchmark results table (one row per model variant)
# ==============================================================================
# Column order of every row tuple below.
_RESULT_COLUMNS = (
    'Model',
    'Size (MB)',
    'Accuracy (GPU)',
    'Latency (GPU, ms)',
    'Peak GPU Mem (MB)',
    'Accuracy (CPU)',
    'Latency (CPU, ms)',
)

# np.nan marks a metric with no value for that variant (presumably not measured
# on that device -- confirm with whoever produced the numbers).
_RESULT_ROWS = (
    # Model                             Size     Acc GPU  Lat GPU  Mem GPU  Acc CPU  Lat CPU
    ('1. Base Baseline (12L, FP32)',     1253.16, 0.9300,  3.20,    428.26,  0.9300,  167.87),
    ('2. Base Baseline (12L, INT4)',     91.64,   0.9300,  9.31,    106.13,  np.nan,  np.nan),
    ('3. Base Baseline (12L, PTQ INT8)', 173.09,  np.nan,  np.nan,  np.nan,  0.9186,  87.91),
    ('4. Base Baseline (12L, QAT)',      418.63,  0.9255,  3.22,    474.50,  0.9255,  145.98),
    ('5. Base Pruned (8L, FP32)',        310.42,  0.9278,  2.25,    366.35,  0.9278,  107.95),
    ('6. Base Pruned (8L, FP16)',        155.66,  0.9266,  2.27,    213.82,  np.nan,  np.nan),
    ('7. Base Pruned (8L, INT4)',        74.78,   0.9278,  8.96,    113.18,  np.nan,  np.nan),
    ('8. Base Pruned (8L, PTQ INT8)',    145.9,   np.nan,  np.nan,  np.nan,  0.9232,  67.02),
)

# Column-oriented view of the table: column name -> list of values in row order.
data = {
    column: [row[position] for row in _RESULT_ROWS]
    for position, column in enumerate(_RESULT_COLUMNS)
}
df = pd.DataFrame(data)

# ==============================================================================
# 3. Normalization: map every metric onto [alpha, 1], where higher is better
# ==============================================================================
benefit_metrics = ['Accuracy (GPU)', 'Accuracy (CPU)']
cost_metrics = ['Size (MB)', 'Latency (GPU, ms)', 'Peak GPU Mem (MB)', 'Latency (CPU, ms)']
df_normalized = df.copy()
# Floor of the normalized scale: the worst value maps to alpha instead of 0,
# i.e. it stays off the centre of the radar chart.
alpha = 0.1

# Accuracy is scaled against this fixed window rather than the observed
# min/max, so its scale does not depend on which models are in the table.
semantic_min_acc = 0.90
semantic_max_acc = 0.94


def _normalize_benefit(values, lower, upper, floor):
    """Map *values* (higher is better) linearly from [lower, upper] onto [floor, 1].

    Values below *lower* end up at *floor*; values above *upper* are clipped to
    1.0. NaN entries stay NaN.
    """
    clipped = values.clip(lower=lower)
    scaled = floor + (1 - floor) * (clipped - lower) / (upper - lower)
    return scaled.clip(upper=1.0)


def _normalize_cost(values, floor):
    """Min-max scale *values* (lower is better) onto [floor, 1], inverted.

    The smallest value maps to 1 and the largest to *floor*. The min/max ignore
    NaN and NaN entries stay NaN. If every non-NaN value is equal, those entries
    become 0.5. An all-NaN column is returned unchanged.
    """
    low, high = values.min(), values.max()
    if pd.isna(low) or pd.isna(high):
        return values
    if low == high:
        # Only replace measured entries: the NaNs must survive so the per-device
        # dropna() later still excludes models that have no value here.
        return values.where(values.isna(), 0.5)
    return floor + (1 - floor) * (high - values) / (high - low)


for col in benefit_metrics:
    df_normalized[col] = _normalize_benefit(df_normalized[col], semantic_min_acc, semantic_max_acc, alpha)
for col in cost_metrics:
    df_normalized[col] = _normalize_cost(df_normalized[col], alpha)

# ==============================================================================
# 4. Split into per-device views
# ==============================================================================
# Radar axes for each device chart.
gpu_metrics = ['Size (MB)', 'Accuracy (GPU)', 'Latency (GPU, ms)', 'Peak GPU Mem (MB)']
cpu_metrics = ['Size (MB)', 'Accuracy (CPU)', 'Latency (CPU, ms)']

# Keep only models that have every plotted metric. Filtering on a subset
# (e.g. only accuracy and latency) would let a model with a missing
# 'Peak GPU Mem (MB)' through and leave a NaN gap in its polygon.
gpu_models_df = df_normalized.dropna(subset=gpu_metrics).reset_index(drop=True)
cpu_models_df = df_normalized.dropna(subset=cpu_metrics).reset_index(drop=True)

# ==============================================================================
# 5. Colour scheme and highlighted models
# ==============================================================================
# 12-layer baseline variants: neutral grey plus cool tones.
_baseline_palette = {
    '1. Base Baseline (12L, FP32)': '#888888',      # neutral grey
    '2. Base Baseline (12L, INT4)': '#56B4E9',      # sky blue
    '3. Base Baseline (12L, PTQ INT8)': '#029E73',  # teal green
    '4. Base Baseline (12L, QAT)': '#0173B2',       # calm blue
}

# 8-layer pruned variants: mostly warm tones so they stand out
# (the PTQ INT8 entry is purple rather than warm).
_pruned_palette = {
    '5. Base Pruned (8L, FP32)': '#D55E00',         # vermilion
    '6. Base Pruned (8L, FP16)': '#CC78BC',         # magenta
    '7. Base Pruned (8L, INT4)': '#E69F00',         # orange
    '8. Base Pruned (8L, PTQ INT8)': '#9467bd',     # purple
}

# Model name -> line/fill colour, baselines first, then pruned variants.
model_colors = {**_baseline_palette, **_pruned_palette}

# Every pruned model is drawn with emphasis in the radar charts.
highlight_models = list(_pruned_palette)

# ==============================================================================
# 6. Plotting
# ==============================================================================
def plot_radar_chart(df, metrics, title, output_dir='../figure'):
    """Draw a radar chart comparing models across *metrics* and save it as SVG.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per model; must contain a 'Model' column plus every column in
        *metrics*. Values are expected to be already normalized, since the
        radial axis is fixed to 0-1.05.
    metrics : list of str
        Column names used as radar axes, placed clockwise starting at the top.
    title : str
        Chart title; also used (sanitized) as the output file name stem.
    output_dir : str, optional
        Directory for the SVG file, created if missing. Defaults to '../figure'.

    Line colours come from the module-level ``model_colors`` mapping (black for
    unknown models). Models listed in ``highlight_models`` are drawn thicker,
    more opaque and on top of the others.
    """
    # Wrap long axis labels: break before the parenthesised unit and after "Peak".
    labels = [label.replace('(', '\n(').replace('Peak ', 'Peak\n') for label in metrics]
    num_vars = len(labels)
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
    angles += angles[:1]  # repeat the first angle so every polygon closes

    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(polar=True))

    for _, row in df.iterrows():
        model_name = row['Model']
        color = model_colors.get(model_name, 'black')
        values = row[metrics].values.flatten().tolist()
        values += values[:1]

        if model_name in highlight_models:
            linewidth, line_alpha, fill_alpha, zorder = 2.5, 1.0, 0.25, 10
        else:
            linewidth, line_alpha, fill_alpha, zorder = 1.5, 0.7, 0.1, 5

        ax.plot(angles, values, color=color, linewidth=linewidth, alpha=line_alpha, label=model_name, zorder=zorder)
        ax.fill(angles, values, color=color, alpha=fill_alpha)

    # First axis at 12 o'clock, remaining axes clockwise.
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    ax.set_rlabel_position(0)
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(labels, fontsize=12)
    ax.tick_params(axis='x', which='major', pad=25)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels(["0.2", "0.4", "0.6", "0.8", "1.0"], color="grey", size=10)
    ax.set_ylim(0, 1.05)

    # Use the explicit Axes/Figure rather than pyplot's "current figure" state,
    # so the title, legend, save and close always target this chart.
    ax.set_title(title, size=20, color='black', y=1.15)
    # Legend sits outside the axes on the right so all model names fit.
    ax.legend(loc='upper right', bbox_to_anchor=(1.6, 1.1), fontsize=11)

    fig.tight_layout()

    os.makedirs(output_dir, exist_ok=True)

    # Spaces -> '_', parentheses dropped; path separators also become '_'
    # so a title can never point outside output_dir.
    stem = title.translate(str.maketrans({' ': '_', '(': None, ')': None, '/': '_', '\\': '_'}))
    output_filename = os.path.join(output_dir, f'{stem}.svg')
    fig.savefig(output_filename, bbox_inches='tight', format='svg', transparent=True)
    plt.close(fig)
    print(f"图表已保存为: {output_filename}")


# ==============================================================================
# 7. Build and save both charts
# ==============================================================================
# (progress message, per-device models, radar axes, chart title) per chart.
_chart_jobs = (
    ("正在生成GPU性能雷达图...", gpu_models_df, gpu_metrics, 'GPU 综合性能对比 (Base模型)'),
    ("正在生成CPU性能雷达图...", cpu_models_df, cpu_metrics, 'CPU 综合性能对比 (Base模型)'),
)

for progress_message, models_frame, axis_metrics, chart_title in _chart_jobs:
    print(progress_message)
    plot_radar_chart(models_frame, axis_metrics, chart_title)

print("\n所有图表已生成完毕！")

正在生成GPU性能雷达图...
图表已保存为: ../figure/GPU_综合性能对比_Base模型.svg
正在生成CPU性能雷达图...
图表已保存为: ../figure/CPU_综合性能对比_Base模型.svg

所有图表已生成完毕！
