### 国家政策趋势_签到热力图

In [6]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from joblib import Parallel, delayed
from matplotlib.colors import ListedColormap
from pathlib import Path
from collections import defaultdict
from typing import List, Tuple

# ==========================================
# 1. Basic configuration
# ==========================================
# Paths are resolved against the working directory at run time: the data
# folder is a sibling ("../data") of the current directory.
CURRENT_DIR = Path.cwd()
DATA_DIR = CURRENT_DIR.parent.joinpath("data")
INPUT_FILE = DATA_DIR.joinpath("3-3-Visualizing_Models_by_Country.csv")
OUTPUT_DIR = DATA_DIR.joinpath("3-3-Visualizing_Models_by_Country_Output")

# Subplot grid and figure geometry
NROWS = 7
NCOLS = 7
DPI = 300
# Figure width widened to 45 so the long labels on the left have enough room
FIGSIZE = (45, 28)
TITLE_PAD = 15

# Colour definitions. Matrix cell codes index CMAP_COLORS:
#   0 = empty (white), 1 = Starting (blue), 2 = Trend (orange),
#   3 = Ending (purple), 4 = separator column (grey).
# FILL_COLORS is keyed by field position (0, 1, 2), i.e. cell code minus one.
FILL_COLORS = dict(enumerate(("#4C78A8", "#F58518", "#9A60B4")))
CMAP_COLORS = ["#FFFFFF", *FILL_COLORS.values(), "#E9E9E9"]

# Label colour for each L1 category
L1_STYLE_COLORS = {
    'Incentive-based': '#2E86AB',
    'Regulatory': '#E27D60',
    'Research and Development (R&D)': '#85C7F2',
    'Commitment-based': '#61C0BF',
}

# Canonical level order per field (verb forms for Trend, no "Share" suffix)
FIELD_ORDER = {
    "Starting": ["Low", "Medium", "High"],
    "Trend": ["Decline", "Stable", "Fluctuate", "Rise"],
    "Ending": ["Low", "Medium", "High"],
}
# Same levels as sets, for fast membership checks
ALLOWED = {field: set(levels) for field, levels in FIELD_ORDER.items()}

# ==========================================
# 2. 逻辑函数
# ==========================================
def setup_mpl_style():
    """Set the global Matplotlib text style to bold Times New Roman (serif).

    Mutates ``matplotlib.rcParams`` in place, so it affects every figure
    created afterwards in the same process.

    NOTE(review): 'font.serif' is reduced to a single font. If any rendered
    label contains characters that font lacks (e.g. CJK text), glyphs may be
    missing -- confirm against the actual data.
    """
    matplotlib.rcParams.update({
        'font.family': 'serif',
        'font.serif': ['Times New Roman'],
        'font.weight': 'bold',
        'axes.labelweight': 'bold',
        'axes.titleweight': 'bold',
        # Draw minus signs with the ASCII hyphen instead of U+2212
        'axes.unicode_minus': False,
    })

def clean_data_in_memory(df: pd.DataFrame) -> pd.DataFrame:
    """Return a cleaned copy of ``df`` with normalised category labels.

    The input frame is not modified. On the copy:

    * 'Starting', 'Trend', 'Ending' and 'L2政策' are cast to ``str`` and
      stripped of surrounding whitespace;
    * the literal text " Share" is removed from 'Ending' (``case=False`` is
      requested for the match);
    * "High (Later)" / "High (Early)" collapse to "High" in both 'Ending'
      and 'Starting';
    * 'Trend' gerunds ("Rising", "Declining", "Fluctuating") are mapped to
      the verb forms used by FIELD_ORDER.

    Args:
        df: Frame containing at least the four columns listed above.

    Returns:
        A new DataFrame with the cleaned columns.
    """
    cleaned = df.copy()

    # Coerce every label column to a whitespace-trimmed string
    for name in ('Starting', 'Trend', 'Ending', 'L2政策'):
        cleaned[name] = cleaned[name].astype(str).str.strip()

    # Strip the " Share" marker from Ending before collapsing the High variants
    ending_no_share = cleaned['Ending'].str.replace(' Share', '', case=False, regex=False)

    # Whole-value replacements: the (Later)/(Early) qualifiers both mean High
    collapse_high = {"High (Later)": "High", "High (Early)": "High"}
    cleaned['Ending'] = ending_no_share.replace(collapse_high)
    cleaned['Starting'] = cleaned['Starting'].replace(collapse_high)

    # Bring Trend to the verb forms that FIELD_ORDER expects
    gerund_to_verb = {"Rising": "Rise", "Declining": "Decline", "Fluctuating": "Fluctuate"}
    cleaned['Trend'] = cleaned['Trend'].replace(gerund_to_verb)

    return cleaned

def build_matrix(country: str, df: pd.DataFrame, policies: List[str], fields: List[str], total_w: int) -> Tuple[str, np.ndarray]:
    """Build the integer cell-code matrix for one country.

    Each row is one policy (in the order of ``policies``). Columns are the
    levels of each field in ``fields`` laid out left to right, with a single
    separator column between consecutive fields.

    Cell codes: 0 = empty, ``field position + 1`` = the level observed for
    that field, 4 = separator column. Because 4 is reserved for separators,
    only the first three fields get a code distinct from it.

    Args:
        country: Value of the '国家' column to select.
        df: Frame with columns '国家', 'L2政策' and the field columns.
            'L2政策' must be unique within a country, because
            ``to_dict('index')`` rejects a non-unique index.
        policies: Policy codes defining the row order.
        fields: Field names (keys of FIELD_ORDER) defining the column groups.
        total_w: Matrix width; must be at least the sum of the level counts
            plus ``len(fields) - 1`` separator columns.

    Returns:
        ``(country, matrix)`` with ``matrix.shape == (len(policies), total_w)``.
    """
    country_rows = df[df['国家'] == country].set_index('L2政策')
    by_policy = country_rows.to_dict('index')
    matrix = np.zeros((len(policies), total_w), dtype=int)

    # A separator follows every field except the last one. Deriving this from
    # len(fields) (rather than a fixed 2) avoids writing past the final column
    # when fewer than three fields are requested.
    last_field = len(fields) - 1

    for i, pol in enumerate(policies):
        row_data = by_policy.get(pol)
        if row_data is None:
            # Policy absent for this country: the whole row stays 0,
            # including the separator cells.
            # NOTE(review): confirm that is the intended look.
            continue

        c_ptr = 0  # left edge of the current field's column group
        for idx_f, fld in enumerate(fields):
            levels = FIELD_ORDER[fld]
            val = row_data.get(fld, "")
            if val in ALLOWED[fld]:
                matrix[i, c_ptr + levels.index(val)] = idx_f + 1
            c_ptr += len(levels)
            if idx_f < last_field:
                matrix[i, c_ptr] = 4  # separator column
                c_ptr += 1
    return country, matrix

def plot_heatmap(df: pd.DataFrame, type_label: str):
    """Draw an NROWS x NCOLS grid of per-country heatmaps and save it as PNG.

    Each panel is one country: rows are policies, columns are the levels of
    Starting / Trend / Ending separated by grey spacer columns. The figure is
    written to ``OUTPUT_DIR`` as
    ``3-3-Visualizing_Models_by_Country_<type_label>.png``.

    Args:
        df: Cleaned rows for one 'Type'. Needs the columns '国家', 'L2政策',
            'L2政策中文名', 'L1分类' and the three field columns.
        type_label: Suffix for the output filename (whitespace is stripped).

    Notes:
        At most ``NROWS * NCOLS`` countries (first in sorted order) are
        drawn; if more are present a warning is printed instead of dropping
        them silently.
    """
    df_unique = df.drop_duplicates(subset=['国家', 'L2政策'])

    # Per-policy metadata (display name, L1 category) for sorting and colouring
    l2_meta = df_unique.groupby('L2政策')[['L2政策中文名', 'L1分类']].first().to_dict('index')

    # Vertical ordering: bucket policies by markers in the L2 code, then sort
    # each bucket by display name.
    grouped = defaultdict(list)
    for p in df_unique['L2政策'].unique():
        if 'CROSS_SEC' in p:
            grp = 'CROSS'
        elif '_INT_' in p:
            grp = 'INT'
        else:
            grp = 'SEC'
        grouped[grp].append(p)

    policies_ordered = []
    for g in ['CROSS', 'INT', 'SEC']:
        policies_ordered.extend(sorted(grouped[g], key=lambda x: l2_meta[x]['L2政策中文名']))

    # Y labels and their colours are the same for every panel; compute once.
    y_labels = [l2_meta[p]['L2政策中文名'] for p in policies_ordered]
    y_colors = [L1_STYLE_COLORS.get(l2_meta[p]['L1分类'], "#333333") for p in policies_ordered]

    # X-axis layout: one column per level plus one spacer between fields
    target_fields = ["Starting", "Trend", "Ending"]
    n_separators = len(target_fields) - 1
    total_width = sum(len(FIELD_ORDER[f]) for f in target_fields) + n_separators

    xtl, xtc = [], []
    for idx, f in enumerate(target_fields):
        xtl.extend(FIELD_ORDER[f])
        xtc.extend([FILL_COLORS[idx]] * len(FIELD_ORDER[f]))
        if idx < n_separators:
            xtl.append("")
            xtc.append("#000000")

    # Country selection: capped by the number of panels in the grid, so the
    # cap follows NROWS/NCOLS instead of a separate magic number.
    max_panels = NROWS * NCOLS
    all_countries = sorted(df_unique['国家'].unique())
    countries = all_countries[:max_panels]
    if len(all_countries) > max_panels:
        print(f"⚠️ {type_label.strip()}: {len(all_countries) - max_panels} of "
              f"{len(all_countries)} countries do not fit the {NROWS}x{NCOLS} grid and were skipped")

    # Build the per-country matrices in parallel
    results = Parallel(n_jobs=-1)(delayed(build_matrix)(c, df_unique, policies_ordered, target_fields, total_width) for c in countries)
    country_data = dict(results)

    # Plot
    cmap = ListedColormap(CMAP_COLORS)
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 / 1xN grid
    fig, axes = plt.subplots(NROWS, NCOLS, figsize=FIGSIZE, squeeze=False)
    try:
        # Generous horizontal spacing and left margin for the long y labels
        fig.subplots_adjust(wspace=1.2, hspace=0.6, left=0.1)

        for idx, ax in enumerate(axes.flatten()):
            if idx >= len(countries):
                ax.axis("off")
                continue
            cty = countries[idx]
            V = country_data.get(cty)
            if V is None:
                ax.axis("off")
                continue

            r, c = V.shape
            # vmin=0, vmax=4 pin the cell codes to the five colormap entries
            ax.pcolormesh(np.arange(c + 1), np.arange(r + 1), V.astype(float), cmap=cmap, vmin=0, vmax=4,
                          shading="flat", edgecolors="#B0B0B0", linewidth=0.6)

            ax.set_aspect('equal')
            ax.set_xlim(0, c)
            ax.set_ylim(r, 0)  # inverted so the first policy is at the top

            # Y tick labels: full names (no truncation), coloured by L1 category
            ax.set_yticks(np.arange(r) + 0.5)
            ax.set_yticklabels(y_labels, fontsize=10, fontweight='bold')
            for tick, color in zip(ax.get_yticklabels(), y_colors):
                tick.set_color(color)

            # X tick labels: coloured by field; spacer columns have empty text
            ax.set_xticks(np.arange(c) + 0.5)
            ax.set_xticklabels(xtl, fontsize=9, rotation=75, ha='right', fontweight='bold')
            for t, color in zip(ax.get_xticklabels(), xtc):
                if t.get_text():
                    t.set_color(color)

            ax.tick_params(length=0)
            for s in ax.spines.values():
                s.set_visible(False)
            ax.set_title(cty, fontsize=16, pad=TITLE_PAD, fontweight='bold')

        # Save
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        save_path = OUTPUT_DIR / f"3-3-Visualizing_Models_by_Country_{type_label.strip()}.png"
        # bbox_inches='tight' grows the saved canvas to include the long labels
        fig.savefig(save_path, dpi=DPI, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving failed
        plt.close(fig)
    print(f"✅ 生成完毕: {save_path.name}")

# ==========================================
# 3. 执行
# ==========================================
if __name__ == "__main__":
    # Script entry point: configure fonts, load and clean the input CSV, then
    # write one heatmap figure per distinct value of the 'Type' column.
    setup_mpl_style()
    # "utf-8-sig" decodes UTF-8 and drops a leading BOM if the file has one
    df_raw = pd.read_csv(INPUT_FILE, encoding="utf-8-sig")
    df_clean = clean_data_in_memory(df_raw)
    
    # groupby yields (Type value, sub-frame); the value is passed on as the
    # label that plot_heatmap uses in the output filename.
    for label, sub_df in df_clean.groupby('Type'):
        plot_heatmap(sub_df, str(label))

✅ 生成完毕: 3-3-Visualizing_Models_by_Country_Breadth.png
✅ 生成完毕: 3-3-Visualizing_Models_by_Country_Intensity.png
