In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import math
from scipy import stats
from scipy.stats import gaussian_kde

# ==========================================
# 1. Paths and basic settings
# ==========================================
input_folder_path = '../../data/PDF_data_Visual/Long_dataframe'
output_folder_path = '../../data/Scatter'

# exist_ok=True makes this a no-op when the folder already exists and avoids
# the race between a separate existence check and the creation call.
os.makedirs(output_folder_path, exist_ok=True)

# Reference files that provide the X-axis variable (label -> CSV file name).
ref_files_map = {
    'Breadth': '4-1-overlapping_cluster_heatmap_Breadth_long.csv',
    'Intensity': '4-1-overlapping_cluster_heatmap_Intensity_long.csv',
}

# Every other CSV in the input folder is a Y-axis target, processed in name order.
all_files = [f for f in os.listdir(input_folder_path) if f.endswith('.csv')]
target_files = sorted(f for f in all_files if f not in ref_files_map.values())

# ==========================================
# 2. Helper: turn a file name into a display title
# ==========================================
def clean_name(filename):
    """Return a short display title derived from a data-file name.

    The transformation is applied in three steps:

    1. Remove every occurrence of ``_long.csv`` and then of ``.csv``.
    2. If the remaining name contains an underscore, discard everything up to
       and including the last hyphen that appears before the first underscore
       (this strips numeric prefixes such as ``4-1-``).
    3. Remove every occurrence of ``Paper_collab-``.
    """
    stem = filename.replace('_long.csv', '').replace('.csv', '')

    # head is the text before the first underscore; sep is '' when there is none.
    head, sep, _tail = stem.partition('_')
    if sep:
        cut = head.rfind('-')
        if cut >= 0:
            # head is a prefix of stem, so the index is valid in stem as well.
            stem = stem[cut + 1:]

    return stem.replace('Paper_collab-', '')

# ==========================================
# 3. Core plotting logic
# ==========================================
def _point_density(x_values, y_values, max_points=5000, seed=0):
    """Estimate the 2-D Gaussian-KDE density at every (x, y) point.

    Args:
        x_values, y_values: Equal-length 1-D numeric sequences without NaNs.
        max_points: When there are more points than this, the KDE is fitted on a
            random subsample of this size and then evaluated at all points,
            which keeps the fit fast on large inputs.
        seed: Seed for that subsample, so repeated runs draw identical figures.

    Returns:
        A 1-D array of densities aligned with the input points, or ``None`` when
        the KDE cannot be fitted or yields non-finite values (for example when a
        coordinate is constant, which makes the covariance matrix singular).
    """
    xy = np.vstack([x_values, y_values])
    fit_xy = xy
    if xy.shape[1] > max_points:
        rng = np.random.default_rng(seed)
        fit_xy = xy[:, rng.choice(xy.shape[1], max_points, replace=False)]
    try:
        density = gaussian_kde(fit_xy)(xy)
    except (np.linalg.LinAlgError, ValueError):
        return None
    return density if np.all(np.isfinite(density)) else None


def _subset_correlations(valid, max_x, max_k=13, min_rows=20):
    """Compute Pearson r between ``Weight_X`` and ``Weight_Y`` on subsets ``Weight_X > k``.

    Args:
        valid: DataFrame with ``Weight_X`` / ``Weight_Y`` columns and no NaNs.
        max_x: Integer maximum of the X values; k stays below it.
        max_k: Exclusive upper bound on k (13 gives columns '>0' ... '>12').
        min_rows: A subset needs strictly more rows than this to be correlated.

    Returns:
        Dict mapping ``'>k'`` to r for k in ``range(min(max_k, max_x))``. The
        value is NaN when the subset is too small or either column is constant.
    """
    results = {}
    for k in range(min(max_k, int(max_x))):
        subset = valid[valid['Weight_X'] > k]
        sx, sy = subset['Weight_X'], subset['Weight_Y']
        if len(subset) > min_rows and sx.std() != 0 and sy.std() != 0:
            r, _ = stats.pearsonr(sx, sy)
        else:
            r = np.nan
        results[f'>{k}'] = r
    return results


def _plot_correlation_heatmap(correlation_results, x_name):
    """Draw and save the per-target correlation table as an annotated heatmap.

    Args:
        correlation_results: List of dicts, each with a 'Variable' key plus
            'All (Pearson)', 'All (Spearman)' and any number of '>k' keys.
        x_name: Label of the X variable, used in the title and the file name.

    Returns:
        Path of the saved PNG inside ``output_folder_path``.

    Column order is 'All (Pearson)', 'All (Spearman)', then the '>k' columns
    by ascending k.
    """
    df_corr = pd.DataFrame(correlation_results).set_index('Variable')
    subset_cols = sorted(
        (c for c in df_corr.columns if c.startswith('>')),
        key=lambda c: int(c[1:]),
    )
    ordered = [c for c in ['All (Pearson)', 'All (Spearman)', *subset_cols]
               if c in df_corr.columns]
    df_corr = df_corr[ordered]

    fig, ax = plt.subplots(figsize=(24, len(df_corr) * 1.2 + 3))
    try:
        sns.heatmap(df_corr, annot=True, fmt=".2f", cmap='RdBu_r', center=0,
                    linewidths=1, cbar_kws={'label': 'Correlation Coefficient'},
                    annot_kws={"size": 14}, ax=ax)

        ax.set_xticklabels(ax.get_xticklabels(), fontsize=14, rotation=45, ha='right')
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)

        cbar = ax.collections[0].colorbar
        cbar.ax.tick_params(labelsize=14)
        cbar.set_label('Correlation Coefficient', fontsize=16)

        # Describe the thresholds actually present instead of a hard-coded range,
        # so the title cannot drift from the data again.
        if subset_cols:
            range_text = f'Subset Analysis from {subset_cols[0]} to {subset_cols[-1]}'
        else:
            range_text = 'No Subset Analysis'
        ax.set_title(f'Comprehensive Correlation Analysis ({x_name})\n{range_text}',
                     fontsize=24, pad=40)
        ax.set_xlabel('Correlation Threshold (Subset: Overlap > k)', fontsize=18)
        ax.set_ylabel('')

        save_path = os.path.join(output_folder_path, f'Correlation_Matrix_Detailed_{x_name}.png')
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    return save_path


def run_full_analysis(x_name, x_filename):
    """Compare one reference (X) variable against every target (Y) file and plot it.

    The X file and each file in the module-level ``target_files`` list are read
    from ``input_folder_path``. They are inner-joined on ``Source``/``Target``
    and compared via ``Weight_X``/``Weight_Y``. Both CSVs are presumably
    expected to carry a ``Weight`` column, which the merge suffixes; confirm
    against the data. Only rows where both weights are present are used.

    Files written to ``output_folder_path``:

    * ``Density_Scatter_<x_name>.png``: one KDE-coloured scatter per target.
    * ``Trend_Lines_<x_name>.png``: mean Y per distinct X value (X values with
      more than 2 observations) with a +/-1.96*SEM band.
    * ``Correlation_Matrix_Detailed_<x_name>.png``: whole-sample Pearson and
      Spearman plus Pearson on the subsets ``X > k`` (k = 0 ... 12 at most).

    Args:
        x_name: Label of the X variable, used in axis labels, titles and file names.
        x_filename: CSV file name of the X data, relative to ``input_folder_path``.

    A target that fails to load or plot is reported and skipped; the others are
    still processed. Nothing is plotted when ``target_files`` is empty.
    """
    print(f"\n{'='*60}\n正在处理 X 轴变量: {x_name}\n{'='*60}")

    # Load the X-axis (reference) data.
    df_x = pd.read_csv(os.path.join(input_folder_path, x_filename))
    df_x.columns = [c.strip() for c in df_x.columns]

    num_plots = len(target_files)
    if num_plots == 0:
        # plt.subplots(0, ...) would raise, and there is nothing to compare.
        print("No target files found; skipping.")
        return

    # Subplot grid layout.
    cols = 5
    rows = math.ceil(num_plots / cols)

    # === Style parameters (extra-large sizes) ===
    FIG_WIDTH = 40           # total figure width
    ROW_HEIGHT = 8           # height per subplot row

    FONT_TITLE = 24          # subplot title font size
    FONT_LABEL = 20          # axis-label font size
    FONT_TICK = 16           # tick-label font size
    FONT_TEXT = 18           # in-plot annotation font size

    SCATTER_SIZE = 25        # scatter marker size
    LINE_WIDTH = 4           # trend line width
    MARKER_SIZE = 12         # trend line marker size

    PAD_TITLE = 25           # gap between subplot title and axes

    # Two large canvases with one subplot per target file in each.
    fig_size = (FIG_WIDTH, ROW_HEIGHT * rows)
    print("1. 正在绘制热力散点图 (Density Scatter)...")
    fig_scatter, axes_scatter = plt.subplots(rows, cols, figsize=fig_size, constrained_layout=True)
    axes_scatter = axes_scatter.flatten()

    print("2. 正在绘制均值趋势图 (Trend Line)...")
    fig_trend, axes_trend = plt.subplots(rows, cols, figsize=fig_size, constrained_layout=True)
    axes_trend = axes_trend.flatten()

    correlation_results = []
    try:
        for i, filename in enumerate(target_files):
            clean_title = clean_name(filename)
            print(f"   - [{i+1}/{num_plots}] {clean_title}")

            try:
                # Read the target (Y) file and join it to X on the (Source, Target) pair.
                df_y = pd.read_csv(os.path.join(input_folder_path, filename))
                df_y.columns = [c.strip() for c in df_y.columns]
                merged = pd.merge(df_x, df_y, on=['Source', 'Target'], suffixes=('_X', '_Y'))

                # Use the same NaN-free rows for plots AND statistics: pearsonr
                # returns NaN as soon as any input value is NaN.
                valid = merged[['Weight_X', 'Weight_Y']].dropna()
                x_clean = valid['Weight_X']
                y_clean = valid['Weight_Y']

                # Upper end of the X tick range (15 when no usable rows remain).
                max_x = int(x_clean.max()) if not x_clean.empty else 15
                x_ticks = range(0, max_x + 2, max(1, max_x // 10))

                # --- Whole-sample correlations (annotated on both subplots) ---
                if len(valid) > 1:
                    pearson_val, _ = stats.pearsonr(x_clean, y_clean)
                    spearman_val, _ = stats.spearmanr(x_clean, y_clean)
                    corr_text = f"Pearson: {pearson_val:.2f}\nSpearman: {spearman_val:.2f}"
                else:
                    pearson_val = spearman_val = np.nan
                    corr_text = "N/A"

                # =========================================
                # A. Density-coloured scatter
                # =========================================
                ax_s = axes_scatter[i]
                if not valid.empty:
                    z = _point_density(x_clean, y_clean)
                    if z is not None:
                        # Draw high-density points last so they sit on top.
                        order = np.argsort(z)
                        ax_s.scatter(x_clean.iloc[order], y_clean.iloc[order], c=z[order],
                                     s=SCATTER_SIZE, cmap='Spectral_r', edgecolor='none', alpha=0.8)
                    else:
                        # KDE not possible (e.g. constant column): plain scatter.
                        ax_s.scatter(x_clean, y_clean, s=SCATTER_SIZE, c='#3182bd', alpha=0.5)

                ax_s.set_title(clean_title, fontsize=FONT_TITLE, fontweight='bold', pad=PAD_TITLE)
                ax_s.set_xlabel(x_name, fontsize=FONT_LABEL)
                ax_s.set_ylabel('Value', fontsize=FONT_LABEL)
                ax_s.tick_params(axis='both', which='major', labelsize=FONT_TICK)
                ax_s.set_xticks(x_ticks)
                ax_s.set_ylim(-0.05, 1.05)
                ax_s.text(0.03, 0.97, corr_text, transform=ax_s.transAxes, fontsize=FONT_TEXT,
                          verticalalignment='top',
                          bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))

                # =========================================
                # B. Mean trend with a 95% CI band (X values with > 2 observations)
                # =========================================
                ax_t = axes_trend[i]
                grouped = valid.groupby('Weight_X')['Weight_Y'].agg(['mean', 'sem', 'count'])
                grouped = grouped[grouped['count'] > 2]
                half_ci = 1.96 * grouped['sem']

                ax_t.plot(grouped.index, grouped['mean'], marker='o',
                          linewidth=LINE_WIDTH, markersize=MARKER_SIZE,
                          color='#d73027', label='Mean')
                ax_t.fill_between(grouped.index,
                                  grouped['mean'] - half_ci,
                                  grouped['mean'] + half_ci,
                                  alpha=0.2, color='#d73027')

                ax_t.set_title(clean_title, fontsize=FONT_TITLE, fontweight='bold', pad=PAD_TITLE)
                ax_t.set_xlabel(x_name, fontsize=FONT_LABEL)
                ax_t.set_ylabel('Mean Value', fontsize=FONT_LABEL)
                ax_t.grid(True, linestyle='--', alpha=0.4)
                ax_t.tick_params(axis='both', which='major', labelsize=FONT_TICK)
                ax_t.set_xticks(x_ticks)
                ax_t.text(0.03, 0.97, corr_text, transform=ax_t.transAxes, fontsize=FONT_TEXT,
                          verticalalignment='top',
                          bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))

                # =========================================
                # C. Collect correlations for the summary heatmap
                # =========================================
                res = {
                    'Variable': clean_title,
                    'All (Pearson)': pearson_val,
                    'All (Spearman)': spearman_val,
                }
                # Thresholds k = 0 .. 12 at most (columns '>0' .. '>12').
                res.update(_subset_correlations(valid, max_x, max_k=13))
                correlation_results.append(res)

            except Exception as e:
                # Deliberate best effort: report the failing target and continue.
                print(f"Error on {filename}: {e}")

        # Remove the unused cells of the last grid row (based on the file count,
        # so this never depends on the loop variable).
        for j in range(num_plots, len(axes_scatter)):
            fig_scatter.delaxes(axes_scatter[j])
            fig_trend.delaxes(axes_trend[j])

        # Save the scatter figure.
        fig_scatter.suptitle(f'Density Scatter Plots: {x_name}', fontsize=30, fontweight='bold', y=1.03)
        save_path_s = os.path.join(output_folder_path, f'Density_Scatter_{x_name}.png')
        fig_scatter.savefig(save_path_s, dpi=300, bbox_inches='tight')
        print(f"--> Saved Scatter Plot: {save_path_s}")

        # Save the trend figure.
        fig_trend.suptitle(f'Trend Analysis (Mean & 95% CI): {x_name}', fontsize=30, fontweight='bold', y=1.03)
        save_path_t = os.path.join(output_folder_path, f'Trend_Lines_{x_name}.png')
        fig_trend.savefig(save_path_t, dpi=300, bbox_inches='tight')
        print(f"--> Saved Trend Plot: {save_path_t}")
    finally:
        # Always release the two large canvases, even if saving failed.
        plt.close(fig_scatter)
        plt.close(fig_trend)

    # =========================================
    # D. Detailed correlation heatmap
    # =========================================
    if correlation_results:
        print("3. 正在绘制详细相关性热力图...")
        save_path_c = _plot_correlation_heatmap(correlation_results, x_name)
        print(f"--> Saved Heatmap: {save_path_c}")

# ==========================================
# 4. Run the full analysis once per reference (X-axis) variable
# ==========================================
for ref_label, ref_csv in ref_files_map.items():
    run_full_analysis(x_name=ref_label, x_filename=ref_csv)


正在处理 X 轴变量: Breadth
1. 正在绘制热力散点图 (Density Scatter)...
2. 正在绘制均值趋势图 (Trend Line)...
   - [1/15] norm_ideal_sim_2005_2023
   - [2/15] BACI_Trade_Intensity_Avg
   - [3/15] norm_geographical_proximity_avg
   - [4/15] Common_Official_Language
   - [5/15] Common_Ethno_Language
   - [6/15] AeroSCOPE_Flight_Intensity
   - [7/15] GDP_Diff
   - [8/15] GHG_Diff
   - [9/15] Rent_Diff
   - [10/15] Fuel_Ex_Diff
   - [11/15] Total
   - [12/15] Climate_Change_Policy_and_Economics
   - [13/15] Climate_Communication_and_Perception
   - [14/15] Sustainable_Development_and_Env_Policy
   - [15/15] Climate_Adaptation_and_Migration
--> Saved Scatter Plot: ../../data/Scatter\Density_Scatter_Breadth.png
--> Saved Trend Plot: ../../data/Scatter\Trend_Lines_Breadth.png
3. 正在绘制详细相关性热力图...
--> Saved Heatmap: ../../data/Scatter\Correlation_Matrix_Detailed_Breadth.png

正在处理 X 轴变量: Intensity
1. 正在绘制热力散点图 (Density Scatter)...
2. 正在绘制均值趋势图 (Trend Line)...
   - [1/15] norm_ideal_sim_2005_2023
   - [2/15] BACI_Trade_Int