In [42]:
import pandas as pd         
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from pathlib import Path
import os
from general_pic_setup import setup_mpl_single2
import matplotlib as mpl
import matplotlib.cm as cm

mpl.rcdefaults()
setup_mpl_single2()
mpl.rcParams['axes.titlesize'] = 'medium'
mpl.rcParams['xtick.direction'] = 'out'

BORDER_WIDTH = 0.8

mpl.rcParams['xtick.major.width'] = BORDER_WIDTH
mpl.rcParams['ytick.major.width'] = BORDER_WIDTH
mpl.rcParams['xtick.minor.width'] = BORDER_WIDTH
mpl.rcParams['ytick.minor.width'] = BORDER_WIDTH
mpl.rcParams['axes.linewidth'] = BORDER_WIDTH
mpl.rcParams['lines.linewidth'] = 0.9

#nature_colors = ['#E64B35', "#6917C2", '#00A087']
nature_colors = ["#0C9C90","#C474EF", "#2652B1"]


# 设置Y值阈值
Y_VALUE_THRESHOLD = 0.03
LOW_VALUE_ALPHA = 0.03  # 低值曲线的透明度


def load_cluster_data(data_dir: Path):
    """加载聚类结果和详细时间序列数据"""
    cluster_assignment_path = data_dir / "3-1-L2_Policy_Clustering_pic" / "3-1_cluster_assignment_table.csv"
    cluster_df = pd.read_csv(cluster_assignment_path, index_col=0, encoding='utf-8-sig')
    
    cluster_detail_path = data_dir / "3-1-L2_Policy_Clustering_countries.csv"
    detail_df = pd.read_csv(cluster_detail_path, encoding='utf-8-sig')
    
    return cluster_df, detail_df


def is_low_value_series(values, threshold=Y_VALUE_THRESHOLD):
    """
    判断一条曲线是否长期处于低值状态
    
    参数:
        values: 曲线的Y值数组
        threshold: 判断阈值
    
    返回:
        True 如果大部分值都低于阈值，False 否则
    """
    # 计算有效值（非NaN）中有多少比例低于阈值
    valid_values = values[~np.isnan(values)]
    if len(valid_values) == 0:
        return False
    
    low_ratio = np.sum(valid_values <= threshold) / len(valid_values)
    
    # 如果至少80%的值都低于阈值，则认为是低值曲线
    return low_ratio >= 0.6


# MODIFIED: 添加 lighten_color 函数（仅用于 marker 填充）
def lighten_color(color, amount=0.7):
    """将颜色变浅, 用于圆圈填充"""
    try:
        c = mcolors.to_rgb(color)
        c = tuple([c[i] + (1 - c[i]) * amount for i in range(3)])
        return c
    except Exception:
        return color


def draw_cluster_in_box(ax, baseline, height, policy_name, y_min, y_max,
                       cluster_df, detail_df, layer_zorder=100):
    """在矩形框内绘制聚类小图：簇平均线 + 国家线 + 整体平均"""
    if cluster_df is None or detail_df is None or policy_name not in cluster_df.index:
        return
    
    cluster_assignments = cluster_df.loc[policy_name].dropna()
    if cluster_assignments.empty:
        return
    
    l2_data = detail_df[detail_df['L2政策中文名'] == policy_name].copy()
    if l2_data.empty:
        return
    
    country_matrices = {}
    for country in cluster_assignments.index:
        country_data = l2_data[l2_data['国家'] == country].sort_values('年份')
        if not country_data.empty:
            series = pd.Series(country_data['占比'].values, index=country_data['年份'].values)
            country_matrices[country] = series
    
    if not country_matrices:
        return
    
    years = sorted(l2_data['年份'].unique().tolist())
    
    cluster_dict = {}
    for country, cid in cluster_assignments.items():
        if country in country_matrices:
            cluster_dict.setdefault(int(cid), []).append(country)
    
    if not cluster_dict:
        return
    
    overall_sum = pd.Series(0.0, index=years, dtype=float)
    overall_cnt = pd.Series(0, index=years, dtype=float)
    
    for country, series in country_matrices.items():
        for y in years:
            if y in series.index and not pd.isna(series[y]):
                overall_sum[y] += series[y]
                overall_cnt[y] += 1
    
    overall_avg = overall_sum / overall_cnt.replace(0, np.nan)
    
    def map_to_box_y(values):
        values_clipped = np.clip(values, y_min, y_max)
        norm = (values_clipped - y_min) / (y_max - y_min)
        return baseline + norm * height
    
    # 绘制各簇的国家线条
    for idx, cid in enumerate(sorted(cluster_dict.keys())):
        countries = cluster_dict[cid]
        color = nature_colors[idx % len(nature_colors)]
        zorder_base = layer_zorder + idx * 10
        
        for country in countries:
            series = country_matrices.get(country)
            if series is None:
                continue
            
            vals = np.array([series.get(y, np.nan) for y in years], dtype=float)
            mask = ~np.isnan(vals)
            if mask.sum() < 2:
                continue
            
            x_valid = np.array(years)[mask]
            y_valid = vals[mask]
            
            # 判断是否为低值曲线
            is_low_value = is_low_value_series(y_valid, threshold=Y_VALUE_THRESHOLD)
            alpha = LOW_VALUE_ALPHA if is_low_value else 0.12
            
            y_mapped = map_to_box_y(y_valid)
            
            ax.plot(x_valid, y_mapped, color=color, alpha=alpha, 
                    linewidth=0.8, zorder=zorder_base)
    
    # 绘制整体平均虚线
    vals_overall = overall_avg.values.astype(float)
    mask_o = ~np.isnan(vals_overall)
    if mask_o.sum() >= 2:
        x_o = np.array(years)[mask_o]
        y_o = vals_overall[mask_o]
        y_o_mapped = map_to_box_y(y_o)
        
        ax.plot(x_o, y_o_mapped, color='#000000', linestyle='--',
                linewidth=1.2, alpha=0.9, dashes=(3, 2), zorder=layer_zorder + 50)
    
    # 绘制各簇的平均线
    for idx, cid in enumerate(sorted(cluster_dict.keys())):
        countries = cluster_dict[cid]
        color = nature_colors[idx % len(nature_colors)]
        zorder_base = layer_zorder + 100 + idx * 10
        
        cluster_sum = pd.Series(0.0, index=years, dtype=float)
        cluster_cnt = pd.Series(0, index=years, dtype=float)
        for country in countries:
            series = country_matrices.get(country)
            if series is None:
                continue
            for y in years:
                if y in series.index and not pd.isna(series[y]):
                    cluster_sum[y] += series[y]
                    cluster_cnt[y] += 1
        cluster_avg = cluster_sum / cluster_cnt.replace(0, np.nan)
        
        vals_avg = cluster_avg.values.astype(float)
        mask_avg = ~np.isnan(vals_avg)
        if mask_avg.sum() >= 2:
            x_avg = np.array(years)[mask_avg]
            y_avg = vals_avg[mask_avg]
            y_avg_mapped = map_to_box_y(y_avg)
            
            # MODIFIED: 计算 fill_color（浅化原色）
            fill_color = lighten_color(color, amount=0.7)
            
            # MODIFIED: 更新 marker 'o' 参数（仅此行改动，其他不变）
            ax.plot(x_avg, y_avg_mapped, color=color, linewidth=1.2, alpha=0.8,
                    zorder=zorder_base, marker='o', markersize=4, 
                    markerfacecolor=fill_color, markeredgecolor=color, markeredgewidth=0.8)


# 路径配置
script_dir = Path(os.getcwd())
data_dir = script_dir.parent / "data"
output_dir = data_dir / "2-2-Country_Policy_Trends"
output_dir.mkdir(parents=True, exist_ok=True)

# 数据加载
df = pd.read_parquet(data_dir / "2-1-country_proportions.parquet")
country_col = 'REF_AREA'
year_col = 'TIME_PERIOD'

df[year_col] = pd.to_numeric(df[year_col], errors='coerce')
df = df.dropna(subset=[year_col])
df = df[(df[year_col] >= 2005) & (df[year_col] <= 2023)].sort_values(year_col)

cluster_df, cluster_detail_df = load_cluster_data(data_dir)

selected_countries = ["JPN", "MEX", "ARG", "CHL"]

policies = {
    "Transport – market-based instruments": {
        "label": "Transport Policy",
        "y_range": (0, 0.15)
    },
    "Public Research, Development and Demonstration": {
        "label": "Research Development",
        "y_range": (0, 0.6)
    },
    "Electricity – non market-based instruments": {
        "label": "Electricity Non-MBI",
        "y_range": (0, 0.3)
    }
}

policy_short_mapping = {
    "Transport – market-based instruments": "Transport_MBI",
    "Public Research, Development and Demonstration": "Public_RD",
    "Electricity – non market-based instruments": "Electricity_NonMBI"
}

CURVE_WIDTH = 0.9
box_height = 0.4
gap_between = 0.01
FILL_ALPHA = 0.099

base_cmap = cm.get_cmap('YlGnBu')
n_countries = len(selected_countries)
color_positions = np.linspace(0.9, 0.4, n_countries)
discrete_colors = [base_cmap(pos) for pos in color_positions]
country_colors = {country: discrete_colors[i] for i, country in enumerate(selected_countries)}


def plot_ridge_chart(policy_name, policy_config, selected_countries, df,
                    cluster_df, cluster_detail_df):
    """绘制山脊图"""
    fig, ax = plt.subplots(figsize=(4, 9), dpi=100)
    y_min, y_max = policy_config["y_range"]
    
    valid_countries = [c for c in selected_countries
                       if c in df[country_col].values
                       and not df[df[country_col] == c][policy_name].isna().all()]
    
    layers = ["CLUSTER"] + valid_countries
    heights = [box_height + 0.02] + [box_height] * (len(layers) - 1)
    
    baselines = [0.0]
    for i in range(1, len(layers)):
        baselines.append(baselines[i-1] + heights[i-1] + gap_between)
    
    total_plot_height = sum(heights) + gap_between * (len(layers) - 1) if layers else 1.0
    
    for idx, code in enumerate(layers):
        baseline = baselines[idx]
        current_box_height = heights[idx]
        is_cluster = (code == "CLUSTER")
        
        ax.plot([2004.7, 2004.7], [baseline, baseline + current_box_height],
                color='black', linewidth=BORDER_WIDTH, zorder=idx - 1, solid_capstyle='butt')
        
        if is_cluster:
            ax.plot([2004.7, 2023.5], [baseline, baseline],
                    color='black', linewidth=BORDER_WIDTH, zorder=idx - 1, solid_capstyle='butt')
        
        if not is_cluster:
            df_country = df[df[country_col] == code]
            years = df_country[year_col].values
            values = df_country[policy_name].values
            
            mask = ~np.isnan(values)
            years, values = years[mask], values[mask]
            
            if len(values) > 1:
                values_clipped = np.clip(values, y_min, y_max)
                values_normalized = (values_clipped - y_min) / (y_max - y_min) * box_height
                country_color = country_colors.get(code, discrete_colors[-1])
                
                ax.fill_between(years, baseline, baseline + values_normalized,
                                color=country_color, alpha=FILL_ALPHA, zorder=idx + 50)
                ax.plot(years, baseline + values_normalized, color=country_color,
                        linewidth=CURVE_WIDTH, alpha=0.95, zorder=idx + 100)
        else:
            draw_cluster_in_box(ax, baseline + 0.02, current_box_height - 0.02, policy_name, 
                               y_min, y_max, cluster_df, cluster_detail_df, layer_zorder=idx + 200)
        
        label_y_offset = 0.2
        norm_y = (baseline + label_y_offset) / total_plot_height
        display_label = "Cluster" if is_cluster else code
        ax.text(-0.12, norm_y, display_label, transform=ax.transAxes,
                va='center', ha='right', fontsize=11, color='black')
        
        ax.text(2004.5, baseline + 0.02, f'{y_min}',
                va='center', ha='right', fontsize=8, color='black')
        ax.text(2004.5, baseline + current_box_height - 0.02, f'{y_max}',
                va='center', ha='right', fontsize=8, color='black')
    
    ax.set_xlim(2004.7, 2023.3)
    ax.set_xticks([2005, 2014, 2023])
    ax.set_xticklabels([2005, 2014, 2023], fontsize=8, rotation=60, ha='center')
    ax.set_xlabel('Year', fontsize=13, labelpad=3)
    
    ax.set_ylim(0, total_plot_height)
    ax.set_yticks([])
    ax.set_ylabel('Proportion', fontsize=13, labelpad=55)
    
    for spine_name in ['top', 'right', 'left', 'bottom']:
        ax.spines[spine_name].set_visible(False)
    
    ax.plot([2004.7, 2023.3], [0, 0], 
            color='black', linewidth=BORDER_WIDTH, 
            clip_on=False, zorder=1000, 
            solid_capstyle='butt')
    
    ax.tick_params(axis='x', 
                   colors='black', 
                   width=BORDER_WIDTH,
                   length=4, 
                   pad=2, 
                   top=False, 
                   bottom=True,
                   direction='out')
    
    ax.tick_params(axis='y', 
                   left=False, 
                   right=False)
    
    plt.tight_layout()
    
    policy_short = policy_short_mapping[policy_name]
    out_path = output_dir / f"{policy_short}_ridge_plot.png"
    fig.savefig(out_path, dpi=300, bbox_inches='tight', pad_inches=0.15)
    plt.close()


for policy_name, policy_config in policies.items():
    plot_ridge_chart(policy_name, policy_config, selected_countries, df,
                     cluster_df, cluster_detail_df)


  base_cmap = cm.get_cmap('YlGnBu')
