In [None]:
import pandas as pd
import numpy as np
import json
from pathlib import Path


class WorldBankDataProcessor:
    """World Bank indicator pipeline: fill gaps -> extract years -> rank -> filter countries.

    Reads a raw indicator CSV from ``<cwd>/../data_origin`` and writes one CSV per
    step (UTF-8 with BOM) into ``<cwd>/../data/5-2-Countries_background``.
    """

    # Year columns read from the raw CSV in step 1 span
    # FIRST_DATA_YEAR .. max(DEFAULT_LAST_DATA_YEAR, end_year).
    FIRST_DATA_YEAR = 1960
    DEFAULT_LAST_DATA_YEAR = 2023

    def __init__(self, data_type, input_file, start_year=2005, end_year=2023,
                 rank_ascending=True):
        """Configure file paths and the analysed year range for one indicator.

        Args:
            data_type: Dataset label used as the output file prefix
                (e.g. 'GDP', 'Energy_import').
            input_file: File name of the raw CSV inside the ``data_origin`` folder.
            start_year: First year (inclusive) kept in step 2 and ranked in step 3.
            end_year: Last year (inclusive) kept in step 2 and ranked in step 3.
            rank_ascending: If True, smaller values receive better (smaller) ranks;
                if False, larger values rank first.

        Side effect: creates the output directory if it does not exist.
        """
        self.data_type = data_type
        self.start_year = start_year
        self.end_year = end_year
        self.year_cols = [str(y) for y in range(start_year, end_year + 1)]
        self.rank_ascending = rank_ascending

        self.current_dir = Path.cwd()
        self.data_dir = self.current_dir.parent / "data"
        self.data_origin_dir = self.current_dir.parent / "data_origin"
        self.output_dir = self.data_dir / "5-2-Countries_background"

        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.config_mapping_path = self.data_dir / "config_mappings.json"
        self.input_file_path = self.data_origin_dir / input_file
        self.cluster_path = self.data_dir / "4-2-Consensus_Policy_Cluster_Mapping.csv"

        self.out_filled = self.output_dir / f"{data_type}_filled_all_countries.csv"
        self.out_year_extracted = self.output_dir / f"{data_type}_year_{start_year}_{end_year}.csv"
        self.out_ranked = self.output_dir / f"{data_type}_ranked_{start_year}_{end_year}.csv"
        self.out_final = self.output_dir / f"{data_type}_final_filtered.csv"

    def load_country_name_mapping(self):
        """Return the ``country_names`` dict from ``config_mappings.json`` ({} if absent).

        The dict is used in step 1 to map ``Country Code`` to a Chinese name.
        """
        with open(self.config_mapping_path, "r", encoding="utf-8") as f:
            config_data = json.load(f)
        return config_data.get("country_names", {})

    def load_cluster_countries(self):
        """Return the set of unique values in the ``国家`` column of the cluster file.

        Step 4 matches these against ``Country Code``, so they are expected to be
        country codes.
        """
        cluster_df = pd.read_csv(self.cluster_path, encoding='utf-8-sig')
        return set(cluster_df['国家'].unique())

    def _source_year_columns(self):
        """Return the year column names ('1960', '1961', ...) to read from the raw CSV.

        The upper bound is at least DEFAULT_LAST_DATA_YEAR and is extended to
        ``end_year``. Without that, a later ``end_year`` would be dropped here and
        silently missing from every following step.
        """
        last_year = max(self.DEFAULT_LAST_DATA_YEAR, self.end_year)
        return [str(y) for y in range(self.FIRST_DATA_YEAR, last_year + 1)]

    @staticmethod
    def _fill_missing_with_row_mean(df, cols):
        """Return a copy of ``df[cols]`` with each NaN replaced by its row's mean.

        The mean is taken over the non-missing values of ``cols`` in that row.
        Rows without any valid value remain entirely NaN.
        """
        values = df[cols]
        row_means = values.mean(axis=1, skipna=True)
        # Series.fillna with a Series aligns on the row index, so each gap in a
        # column receives the mean of its own row.
        return values.apply(lambda col: col.fillna(row_means))

    def step1_load_and_fill(self, country_name_map):
        """Step 1: load every row of the raw CSV and fill missing year values.

        Missing values are replaced by the mean of that row over all available
        year columns (not only start_year..end_year).

        Args:
            country_name_map: Mapping Country Code -> Chinese name, stored in the
                inserted ``Country Name_CN`` column (NaN when unmapped).

        Returns:
            The filled DataFrame (also written to ``self.out_filled``).
        """
        # The first 4 lines of the raw file are skipped before the header row.
        df = pd.read_csv(self.input_file_path, skiprows=4)

        candidate_years = set(self._source_year_columns())
        year_cols_exist = [c for c in df.columns if c in candidate_years]

        base_cols = ["Country Name", "Country Code", "Indicator Name"]
        df = df[base_cols + year_cols_exist].copy()

        df[year_cols_exist] = df[year_cols_exist].apply(pd.to_numeric, errors="coerce")
        df.insert(1, "Country Name_CN", df["Country Code"].map(country_name_map))

        if year_cols_exist:
            df[year_cols_exist] = self._fill_missing_with_row_mean(df, year_cols_exist)

        df.to_csv(self.out_filled, index=False, encoding="utf-8-sig")
        return df

    def step2_extract_years(self, df_filled):
        """Step 2: keep only the start_year..end_year columns that are present.

        Args:
            df_filled: Output of step 1.

        Returns:
            DataFrame with the base columns plus the selected year columns
            (also written to ``self.out_year_extracted``).
        """
        year_cols_exist = [c for c in df_filled.columns if c in self.year_cols]
        base_cols = ["Country Name", "Country Name_CN", "Country Code", "Indicator Name"]
        df_year = df_filled[base_cols + year_cols_exist].copy()

        df_year.to_csv(self.out_year_extracted, index=False, encoding="utf-8-sig")
        return df_year

    def step3_rank_by_year(self, df_year):
        """Step 3: add a ``<year>_rank`` column for every year column.

        Ranks use ``method='min'`` (ties share the smallest rank), and NaN values
        get a NaN rank. NOTE(review): ranks are computed over every row in the
        input file, so any aggregate (non-country) rows it contains also occupy
        rank positions. Confirm this is intended.

        Args:
            df_year: Output of step 2.

        Returns:
            DataFrame with the rank columns appended (also written to ``self.out_ranked``).
        """
        df_ranked = df_year.copy()
        year_cols_exist = [c for c in df_ranked.columns if c in self.year_cols]

        for year in year_cols_exist:
            df_ranked[f'{year}_rank'] = df_ranked[year].rank(
                ascending=self.rank_ascending,
                method='min'
            )

        df_ranked.to_csv(self.out_ranked, index=False, encoding="utf-8-sig")
        return df_ranked

    def step4_calculate_average_and_filter(self, df_ranked, cluster_countries):
        """Step 4: average the yearly ranks and keep only the target countries.

        ``avg_rank`` is the row mean of the ``<year>_rank`` columns, skipping NaN.
        Rows are sorted by it (NaN last) before filtering.

        Args:
            df_ranked: Output of step 3.
            cluster_countries: Set of target country codes matched against ``Country Code``.

        Returns:
            The filtered DataFrame (also written to ``self.out_final``).
        """
        df_result = df_ranked.copy()

        year_cols_exist = [c for c in df_result.columns if c in self.year_cols]
        rank_cols = [f'{year}_rank' for year in year_cols_exist]

        df_result['avg_rank'] = df_result[rank_cols].mean(axis=1)
        df_result = df_result.sort_values('avg_rank')

        df_filtered = df_result[df_result['Country Code'].isin(cluster_countries)].copy()

        df_filtered.to_csv(self.out_final, index=False, encoding="utf-8-sig")
        return df_filtered

    def process_all(self):
        """Run steps 1-4 in order and return the final filtered DataFrame."""
        country_name_map = self.load_country_name_mapping()
        cluster_countries = self.load_cluster_countries()

        df_filled = self.step1_load_and_fill(country_name_map)
        df_year = self.step2_extract_years(df_filled)
        df_ranked = self.step3_rank_by_year(df_year)
        return self.step4_calculate_average_and_filter(df_ranked, cluster_countries)


def main():
    """Run the full pipeline for the GDP and energy-import datasets.

    Returns:
        dict mapping each dataset's ``data_type`` to its final filtered DataFrame.
        The original built this dict and then discarded it.
    """
    datasets = [
        {
            'data_type': 'GDP',
            'input_file': 'GDP_per-API_NY.GDP.PCAP.KD_DS2_en_csv_v2_130141.csv',
            # Larger values rank first.
            'rank_ascending': False
        },
        {
            'data_type': 'Energy_import',
            'input_file': 'Energy_import-API_EG.IMP.CONS.ZS_DS2_en_csv_v2_216046.csv',
            # Smaller values rank first.
            'rank_ascending': True
        }
    ]

    results = {}

    for config in datasets:
        processor = WorldBankDataProcessor(
            data_type=config['data_type'],
            input_file=config['input_file'],
            start_year=2005,
            end_year=2023,
            rank_ascending=config['rank_ascending']
        )
        results[config['data_type']] = processor.process_all()

    return results


if __name__ == "__main__":
    main()

In [3]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from pathlib import Path
import numpy as np


class ScatterPlotGeneratorByClusters:
    """Scatter-plot generator: GDP vs energy-import average ranks, coloured by cluster.

    For every K in the consensus cluster mapping it writes three files to
    ``<data>/5-2-Countries_background/scatter_fig``: a merged CSV, a plain
    scatter PNG, and a PNG with every country annotated.
    """

    def __init__(self):
        """Resolve input/output paths (creating the output folder) and configure fonts."""
        self.current_dir = Path.cwd()
        self.data_dir = self.current_dir.parent / "data"
        self.input_dir = self.data_dir / "5-2-Countries_background"
        self.output_dir = self.input_dir / "scatter_fig"
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.gdp_file = self.input_dir / "GDP_final_filtered.csv"
        self.energy_file = self.input_dir / "Energy_import_final_filtered.csv"
        self.cluster_file = self.data_dir / "4-2-Consensus_Policy_Cluster_Mapping.csv"

        self.setup_chinese_fonts()

    def setup_chinese_fonts(self):
        """Configure matplotlib for Chinese text and set ``self.font_cn``.

        Uses the first candidate family that matplotlib can locate. If none is
        found, it falls back to DejaVu Sans, which presumably lacks CJK glyphs.
        """
        candidates = ['SimHei', 'Microsoft YaHei', 'STHeiti', 'Heiti TC', 'Arial Unicode MS']

        # Applied regardless of which font ends up being used.
        plt.rcParams['axes.unicode_minus'] = False
        plt.rcParams['figure.dpi'] = 300

        for font_name in candidates:
            try:
                # With fallback_to_default=False, findfont raises ValueError for
                # a family that is not installed. Try the next candidate rather
                # than letting the constructor crash.
                font_path = fm.findfont(fm.FontProperties(family=font_name),
                                        fallback_to_default=False)
            except ValueError:
                continue
            if font_path and Path(font_path).exists():
                plt.rcParams['font.sans-serif'] = [font_name]
                self.font_cn = fm.FontProperties(fname=font_path)
                return

        plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
        self.font_cn = fm.FontProperties()

    def _legend_font(self, size):
        """Return a copy of ``self.font_cn`` sized ``size`` points.

        ``Axes.legend`` ignores ``fontsize`` when ``prop`` is supplied, so the
        size must be carried by the FontProperties object itself.
        """
        prop = self.font_cn.copy()
        prop.set_size(size)
        return prop

    def load_data(self):
        """Read the GDP, energy-import and cluster-mapping CSVs (UTF-8 with BOM)."""
        gdp_df = pd.read_csv(self.gdp_file, encoding='utf-8-sig')
        energy_df = pd.read_csv(self.energy_file, encoding='utf-8-sig')
        cluster_df = pd.read_csv(self.cluster_file, encoding='utf-8-sig')
        return gdp_df, energy_df, cluster_df

    def merge_data(self, gdp_df, energy_df, cluster_df, k_value):
        """Join GDP and energy average ranks for countries clustered at ``k_value``.

        Returns an empty DataFrame when ``cluster_df`` has no rows for ``k_value``.
        Otherwise returns the columns Country Code, Country Name_CN, GDP_avg_rank,
        Energy_avg_rank and Cluster. Only countries that are present in both rank
        tables, mapped to a cluster, and have both ranks are kept.
        """
        k_clusters = cluster_df[cluster_df['K值'] == k_value].copy()

        if len(k_clusters) == 0:
            return pd.DataFrame()

        country_to_cluster = dict(zip(k_clusters['国家'], k_clusters['共识聚类ID']))

        gdp_data = gdp_df[['Country Code', 'Country Name_CN', 'avg_rank']].copy()
        gdp_data.columns = ['Country Code', 'Country Name_CN', 'GDP_avg_rank']

        energy_data = energy_df[['Country Code', 'avg_rank']].copy()
        energy_data.columns = ['Country Code', 'Energy_avg_rank']

        merged_df = gdp_data.merge(energy_data, on='Country Code', how='inner')
        merged_df['Cluster'] = merged_df['Country Code'].map(country_to_cluster)
        merged_df = merged_df[merged_df['Cluster'].notna()].copy()
        merged_df = merged_df.dropna(subset=['GDP_avg_rank', 'Energy_avg_rank'])

        return merged_df

    def _plot_cluster_points(self, ax, merged_df, marker_size, edge_width):
        """Scatter every cluster of ``merged_df`` on ``ax`` with its own colour and legend label."""
        cluster_ids = sorted(merged_df['Cluster'].unique())
        colors = plt.cm.tab10(np.linspace(0, 1, len(cluster_ids)))

        for color, cluster_id in zip(colors, cluster_ids):
            cluster_data = merged_df[merged_df['Cluster'] == cluster_id]
            ax.scatter(
                cluster_data['GDP_avg_rank'],
                cluster_data['Energy_avg_rank'],
                c=[color],
                marker='o',
                s=marker_size,
                alpha=0.75,
                label=f'簇 {int(cluster_id)} ({len(cluster_data)}国)',
                edgecolors='white',
                linewidths=edge_width
            )

    def _finalize_and_save(self, fig, ax, legend_size, file_name):
        """Add legend and grid, invert both axes, then save ``fig`` to ``output_dir`` and close it.

        Inverting both axes puts the smallest ranks at the right and the top.
        Returns the output path.
        """
        ax.legend(prop=self._legend_font(legend_size), loc='best', framealpha=0.95,
                  edgecolor='gray', fancybox=True, shadow=True)
        ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
        ax.invert_xaxis()
        ax.invert_yaxis()

        fig.tight_layout()

        output_path = self.output_dir / file_name
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        return output_path

    def create_scatter_plot_by_cluster(self, merged_df, k_value):
        """Save the per-cluster scatter for ``k_value``.

        Returns the PNG path, or None if ``merged_df`` is empty.
        """
        if len(merged_df) == 0:
            return None

        fig, ax = plt.subplots(figsize=(16, 11))
        self._plot_cluster_points(ax, merged_df, marker_size=150, edge_width=1.5)

        ax.set_xlabel('GDP平均排名 (2005-2023)',
                      fontproperties=self.font_cn, fontsize=14, fontweight='bold')
        ax.set_ylabel('能源进口平均排名 (2005-2023)',
                      fontproperties=self.font_cn, fontsize=14, fontweight='bold')
        ax.set_title(f'K={k_value} 共识聚类国家背景',
                     fontproperties=self.font_cn, fontsize=17, fontweight='bold', pad=20)

        return self._finalize_and_save(
            fig, ax, 11, f'K{k_value}_GDP_vs_Energy_Scatter_by_Cluster.png')

    def create_detailed_scatter_all_labels(self, merged_df, k_value):
        """Save the scatter with every country annotated (code plus Chinese name if known).

        Returns the PNG path, or None if ``merged_df`` is empty.
        """
        if len(merged_df) == 0:
            return None

        fig, ax = plt.subplots(figsize=(22, 18))
        self._plot_cluster_points(ax, merged_df, marker_size=180, edge_width=2)

        for _, row in merged_df.iterrows():
            country_cn = row['Country Name_CN'] if pd.notna(row['Country Name_CN']) else ''
            label_text = f"{row['Country Code']}"
            if country_cn:
                label_text += f"\n{country_cn}"

            ax.annotate(
                label_text,
                xy=(row['GDP_avg_rank'], row['Energy_avg_rank']),
                xytext=(10, 10),
                textcoords='offset points',
                fontsize=7,
                alpha=0.90,
                bbox=dict(boxstyle='round,pad=0.4', facecolor='yellow', alpha=0.5,
                          edgecolor='gray', linewidth=0.5)
            )

        # The labels previously had no size and fell back to the small rcParams
        # default on this 22x18 figure; use the size of the other detailed plot.
        ax.set_xlabel('GDP平均排名 (2005-2023)',
                      fontproperties=self.font_cn, fontsize=15, fontweight='bold')
        ax.set_ylabel('能源进口平均排名 (2005-2023)',
                      fontproperties=self.font_cn, fontsize=15, fontweight='bold')
        ax.set_title(f'K={k_value} 共识聚类：各国GDP与能源进口关系详细分析',
                     fontproperties=self.font_cn, fontsize=19, fontweight='bold', pad=20)

        return self._finalize_and_save(
            fig, ax, 12, f'K{k_value}_GDP_vs_Energy_Scatter_All_Labels.png')

    def save_merged_data(self, merged_df, k_value):
        """Write ``merged_df`` to ``K{k}_GDP_Energy_Merged_with_Clusters.csv`` and return its path."""
        output_path = self.output_dir / f'K{k_value}_GDP_Energy_Merged_with_Clusters.csv'
        merged_df.to_csv(output_path, index=False, encoding='utf-8-sig')
        return output_path

    def generate_for_all_k(self):
        """For every K in the cluster file, save the merged CSV and both scatter plots."""
        gdp_df, energy_df, cluster_df = self.load_data()
        k_values = sorted(cluster_df['K值'].unique())

        for k_value in k_values:
            merged_df = self.merge_data(gdp_df, energy_df, cluster_df, k_value)

            if len(merged_df) == 0:
                continue

            self.save_merged_data(merged_df, k_value)
            self.create_scatter_plot_by_cluster(merged_df, k_value)
            self.create_detailed_scatter_all_labels(merged_df, k_value)


def main():
    """Create the generator and render scatter plots for every K in the cluster file."""
    ScatterPlotGeneratorByClusters().generate_for_all_k()


if __name__ == "__main__":
    main()


In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from pathlib import Path
import numpy as np
from scipy import stats


class ScatterPlotGeneratorByClusters:
    """Scatter-plot generator: GDP vs energy-import average ranks, grouped by cluster.

    Each cluster is drawn with its member points and a large centroid marker
    labelled with the cluster ID. Clusters with at least 3 countries also get a
    shaded 95% covariance ellipse.
    """

    # Both axes are inverted (see _finalize_and_save), so smaller ranks lie to
    # the right on x and toward the top on y. "→" points right on the x label.
    # The y label is rotated 90° counter-clockwise, so "→" points up there. The
    # original "←" pointed toward the larger ranks on both axes.
    X_LABEL = 'GDP平均排名 (2005-2023)\n→ 排名越小GDP越高'
    Y_LABEL = '能源进口平均排名 (2005-2023)\n→ 排名越小进口越少'

    def __init__(self):
        """Resolve input/output paths (creating the output folder) and configure fonts."""
        self.current_dir = Path.cwd()
        self.data_dir = self.current_dir.parent / "data"
        self.input_dir = self.data_dir / "5-2-Countries_background"
        self.output_dir = self.input_dir / "scatter_fig"
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.gdp_file = self.input_dir / "GDP_final_filtered.csv"
        self.energy_file = self.input_dir / "Energy_import_final_filtered.csv"
        self.cluster_file = self.data_dir / "4-2-Consensus_Policy_Cluster_Mapping.csv"

        self.setup_chinese_fonts()

    def setup_chinese_fonts(self):
        """Configure matplotlib for Chinese text and set ``self.font_cn``.

        Uses the first candidate family that matplotlib can locate. If none is
        found, it falls back to DejaVu Sans, which presumably lacks CJK glyphs.
        """
        candidates = ['SimHei', 'Microsoft YaHei', 'STHeiti', 'Heiti TC', 'Arial Unicode MS']

        plt.rcParams['axes.unicode_minus'] = False
        plt.rcParams['figure.dpi'] = 300

        for font_name in candidates:
            try:
                # fallback_to_default=False makes findfont raise ValueError for
                # a missing family; skip to the next candidate instead of crashing.
                font_path = fm.findfont(fm.FontProperties(family=font_name),
                                        fallback_to_default=False)
            except ValueError:
                continue
            if font_path and Path(font_path).exists():
                plt.rcParams['font.sans-serif'] = [font_name]
                self.font_cn = fm.FontProperties(fname=font_path)
                return

        plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
        self.font_cn = fm.FontProperties()

    def _legend_font(self, size):
        """Return a copy of ``self.font_cn`` sized ``size`` points.

        ``Axes.legend`` ignores ``fontsize`` when ``prop`` is supplied.
        """
        prop = self.font_cn.copy()
        prop.set_size(size)
        return prop

    def load_data(self):
        """Read the GDP, energy-import and cluster-mapping CSVs (UTF-8 with BOM)."""
        gdp_df = pd.read_csv(self.gdp_file, encoding='utf-8-sig')
        energy_df = pd.read_csv(self.energy_file, encoding='utf-8-sig')
        cluster_df = pd.read_csv(self.cluster_file, encoding='utf-8-sig')
        return gdp_df, energy_df, cluster_df

    def merge_data(self, gdp_df, energy_df, cluster_df, k_value):
        """Join GDP and energy average ranks for countries clustered at ``k_value``.

        Returns an empty DataFrame when ``cluster_df`` has no rows for ``k_value``.
        Otherwise returns the columns Country Code, Country Name_CN, GDP_avg_rank,
        Energy_avg_rank and Cluster, limited to countries present in both tables,
        mapped to a cluster, and with both ranks.
        """
        k_clusters = cluster_df[cluster_df['K值'] == k_value].copy()

        if len(k_clusters) == 0:
            return pd.DataFrame()

        country_to_cluster = dict(zip(k_clusters['国家'], k_clusters['共识聚类ID']))

        gdp_data = gdp_df[['Country Code', 'Country Name_CN', 'avg_rank']].copy()
        gdp_data.columns = ['Country Code', 'Country Name_CN', 'GDP_avg_rank']

        energy_data = energy_df[['Country Code', 'avg_rank']].copy()
        energy_data.columns = ['Country Code', 'Energy_avg_rank']

        merged_df = gdp_data.merge(energy_data, on='Country Code', how='inner')
        merged_df['Cluster'] = merged_df['Country Code'].map(country_to_cluster)
        merged_df = merged_df[merged_df['Cluster'].notna()].copy()
        merged_df = merged_df.dropna(subset=['GDP_avg_rank', 'Energy_avg_rank'])

        return merged_df

    @staticmethod
    def _ellipse_points(x, y, confidence=0.95, n_points=100):
        """Return ``(xs, ys)`` outlining the ``confidence`` covariance ellipse of the points.

        The ellipse is centred on the sample mean. Its semi-axes lie along the
        eigenvectors of the sample covariance and have lengths
        sqrt(eigenvalue * chi2.ppf(confidence, 2)). That is the Mahalanobis
        contour which would enclose ``confidence`` of a bivariate-normal sample.
        It is a spread (data) ellipse, not a confidence region for the mean.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)

        cov = np.cov(x, y)
        # eigh is the solver for symmetric matrices. Clipping removes tiny
        # negative eigenvalues from round-off (e.g. collinear points) that
        # would otherwise turn sqrt into NaN.
        eigvals, eigvecs = np.linalg.eigh(cov)
        radii = np.sqrt(np.clip(eigvals, 0.0, None))
        scale = np.sqrt(stats.chi2.ppf(confidence, df=2))

        theta = np.linspace(0, 2 * np.pi, n_points)
        cos_t = radii[0] * np.cos(theta)
        sin_t = radii[1] * np.sin(theta)
        xs = x.mean() + scale * (cos_t * eigvecs[0, 0] + sin_t * eigvecs[0, 1])
        ys = y.mean() + scale * (cos_t * eigvecs[1, 0] + sin_t * eigvecs[1, 1])
        return xs, ys

    def _draw_cluster(self, ax, cluster_data, cluster_id, color, *, point_size,
                      point_width, center_size, center_width, center_fontsize):
        """Draw one cluster: shaded ellipse (3+ points), member points, centroid marker and ID text."""
        # Ellipses are only drawn for 3 or more points. With 2 points the
        # covariance is degenerate.
        if len(cluster_data) >= 3:
            ex, ey = self._ellipse_points(cluster_data['GDP_avg_rank'].values,
                                          cluster_data['Energy_avg_rank'].values)
            ax.fill(ex, ey, color=color, alpha=0.15, zorder=1)

        ax.scatter(
            cluster_data['GDP_avg_rank'],
            cluster_data['Energy_avg_rank'],
            c=[color],
            marker='o',
            s=point_size,
            alpha=0.75,
            label=f'簇 {int(cluster_id)} ({len(cluster_data)}国)',
            edgecolors='white',
            linewidths=point_width,
            zorder=2
        )

        mean_gdp = cluster_data['GDP_avg_rank'].mean()
        mean_energy = cluster_data['Energy_avg_rank'].mean()

        ax.scatter(
            mean_gdp,
            mean_energy,
            c=[color],
            marker='o',
            s=center_size,
            alpha=0.9,
            edgecolors='black',
            linewidths=center_width,
            zorder=3
        )
        ax.text(
            mean_gdp,
            mean_energy,
            str(int(cluster_id)),
            fontsize=center_fontsize,
            fontweight='bold',
            ha='center',
            va='center',
            color='white',
            zorder=4
        )

    def _draw_all_clusters(self, ax, merged_df, **style):
        """Draw every cluster of ``merged_df`` with a distinct tab10 colour; ``style`` goes to _draw_cluster."""
        cluster_ids = sorted(merged_df['Cluster'].unique())
        colors = plt.cm.tab10(np.linspace(0, 1, len(cluster_ids)))
        for color, cluster_id in zip(colors, cluster_ids):
            cluster_data = merged_df[merged_df['Cluster'] == cluster_id]
            self._draw_cluster(ax, cluster_data, cluster_id, color, **style)

    def _finalize_and_save(self, fig, ax, legend_size, file_name):
        """Add legend and grid, invert both axes, then save ``fig`` to ``output_dir`` and close it.

        After inversion the smallest ranks sit at the right and the top.
        Returns the output path.
        """
        ax.legend(prop=self._legend_font(legend_size), loc='best', framealpha=0.95,
                  edgecolor='gray', fancybox=True, shadow=True)
        ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
        ax.invert_xaxis()
        ax.invert_yaxis()

        fig.tight_layout()

        output_path = self.output_dir / file_name
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        return output_path

    def create_scatter_plot_by_cluster(self, merged_df, k_value):
        """Save the per-cluster scatter (ellipses and centroids) for ``k_value``.

        Returns the PNG path, or None if ``merged_df`` is empty.
        """
        if len(merged_df) == 0:
            return None

        fig, ax = plt.subplots(figsize=(16, 11))
        self._draw_all_clusters(ax, merged_df, point_size=150, point_width=1.5,
                                center_size=500, center_width=2.5, center_fontsize=16)

        ax.set_xlabel(self.X_LABEL,
                      fontproperties=self.font_cn, fontsize=14, fontweight='bold')
        ax.set_ylabel(self.Y_LABEL,
                      fontproperties=self.font_cn, fontsize=14, fontweight='bold')
        ax.set_title(f'K={k_value} 共识聚类：各国GDP与能源进口关系分析',
                     fontproperties=self.font_cn, fontsize=17, fontweight='bold', pad=20)

        return self._finalize_and_save(
            fig, ax, 11, f'K{k_value}_GDP_vs_Energy_Scatter_by_Cluster.png')

    def create_detailed_scatter_all_labels(self, merged_df, k_value):
        """Save the cluster scatter with every country annotated (ellipses and centroids).

        Returns the PNG path, or None if ``merged_df`` is empty.
        """
        if len(merged_df) == 0:
            return None

        fig, ax = plt.subplots(figsize=(22, 18))
        self._draw_all_clusters(ax, merged_df, point_size=180, point_width=2,
                                center_size=600, center_width=3, center_fontsize=18)

        # Annotate each country with its code plus its Chinese name when known.
        for _, row in merged_df.iterrows():
            country_cn = row['Country Name_CN'] if pd.notna(row['Country Name_CN']) else ''
            label_text = f"{row['Country Code']}"
            if country_cn:
                label_text += f"\n{country_cn}"

            ax.annotate(
                label_text,
                xy=(row['GDP_avg_rank'], row['Energy_avg_rank']),
                xytext=(10, 10),
                textcoords='offset points',
                fontsize=7,
                alpha=0.90,
                bbox=dict(boxstyle='round,pad=0.4', facecolor='yellow', alpha=0.5,
                          edgecolor='gray', linewidth=0.5),
                zorder=5
            )

        ax.set_xlabel(self.X_LABEL,
                      fontproperties=self.font_cn, fontsize=15, fontweight='bold')
        ax.set_ylabel(self.Y_LABEL,
                      fontproperties=self.font_cn, fontsize=15, fontweight='bold')
        ax.set_title(f'K={k_value} 共识聚类：各国GDP与能源进口关系详细分析',
                     fontproperties=self.font_cn, fontsize=19, fontweight='bold', pad=20)

        return self._finalize_and_save(
            fig, ax, 12, f'K{k_value}_GDP_vs_Energy_Scatter_All_Labels.png')

    def save_merged_data(self, merged_df, k_value):
        """Write ``merged_df`` to ``K{k}_GDP_Energy_Merged_with_Clusters.csv`` and return its path."""
        output_path = self.output_dir / f'K{k_value}_GDP_Energy_Merged_with_Clusters.csv'
        merged_df.to_csv(output_path, index=False, encoding='utf-8-sig')
        return output_path

    def generate_for_all_k(self):
        """For every K in the cluster file, save the merged CSV and both scatter plots."""
        gdp_df, energy_df, cluster_df = self.load_data()
        k_values = sorted(cluster_df['K值'].unique())

        for k_value in k_values:
            merged_df = self.merge_data(gdp_df, energy_df, cluster_df, k_value)

            if len(merged_df) == 0:
                continue

            self.save_merged_data(merged_df, k_value)
            self.create_scatter_plot_by_cluster(merged_df, k_value)
            self.create_detailed_scatter_all_labels(merged_df, k_value)


def main():
    """Create the generator and render ellipse/centroid scatter plots for every K."""
    ScatterPlotGeneratorByClusters().generate_for_all_k()


if __name__ == "__main__":
    main()


In [5]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from pathlib import Path
import numpy as np


class ScatterPlotGeneratorByClusters:
    """Subplot generator: one GDP-vs-energy-import panel per cluster for one K (default 7).

    Panels are laid out 4 per row with at least 3 rows, which reproduces the
    original 3x4 grid. Extra rows are added when a K has more than 12 clusters.
    """

    def __init__(self):
        """Resolve input/output paths (creating the output folder) and configure fonts."""
        self.current_dir = Path.cwd()
        self.data_dir = self.current_dir.parent / "data"
        self.input_dir = self.data_dir / "5-2-Countries_background"
        self.output_dir = self.input_dir / "scatter_fig"
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.gdp_file = self.input_dir / "GDP_final_filtered.csv"
        self.energy_file = self.input_dir / "Energy_import_final_filtered.csv"
        self.cluster_file = self.data_dir / "4-2-Consensus_Policy_Cluster_Mapping.csv"

        self.setup_chinese_fonts()

    def setup_chinese_fonts(self):
        """Configure matplotlib for Chinese text and set ``self.font_cn``.

        Uses the first candidate family that matplotlib can locate. If none is
        found, it falls back to DejaVu Sans, which presumably lacks CJK glyphs.
        """
        candidates = ['SimHei', 'Microsoft YaHei', 'STHeiti', 'Heiti TC', 'Arial Unicode MS']

        plt.rcParams['axes.unicode_minus'] = False
        plt.rcParams['figure.dpi'] = 300

        for font_name in candidates:
            try:
                # fallback_to_default=False makes findfont raise ValueError for
                # a missing family; skip to the next candidate instead of crashing.
                font_path = fm.findfont(fm.FontProperties(family=font_name),
                                        fallback_to_default=False)
            except ValueError:
                continue
            if font_path and Path(font_path).exists():
                plt.rcParams['font.sans-serif'] = [font_name]
                self.font_cn = fm.FontProperties(fname=font_path)
                return

        plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
        self.font_cn = fm.FontProperties()

    def load_data(self):
        """Read the GDP, energy-import and cluster-mapping CSVs (UTF-8 with BOM)."""
        gdp_df = pd.read_csv(self.gdp_file, encoding='utf-8-sig')
        energy_df = pd.read_csv(self.energy_file, encoding='utf-8-sig')
        cluster_df = pd.read_csv(self.cluster_file, encoding='utf-8-sig')
        return gdp_df, energy_df, cluster_df

    def merge_data(self, gdp_df, energy_df, cluster_df, k_value):
        """Join GDP and energy average ranks for countries clustered at ``k_value``.

        Returns an empty DataFrame when ``cluster_df`` has no rows for ``k_value``.
        Otherwise returns the columns Country Code, Country Name_CN, GDP_avg_rank,
        Energy_avg_rank and Cluster, limited to countries present in both tables,
        mapped to a cluster, and with both ranks.
        """
        k_clusters = cluster_df[cluster_df['K值'] == k_value].copy()

        if len(k_clusters) == 0:
            return pd.DataFrame()

        country_to_cluster = dict(zip(k_clusters['国家'], k_clusters['共识聚类ID']))

        gdp_data = gdp_df[['Country Code', 'Country Name_CN', 'avg_rank']].copy()
        gdp_data.columns = ['Country Code', 'Country Name_CN', 'GDP_avg_rank']

        energy_data = energy_df[['Country Code', 'avg_rank']].copy()
        energy_data.columns = ['Country Code', 'Energy_avg_rank']

        merged_df = gdp_data.merge(energy_data, on='Country Code', how='inner')
        merged_df['Cluster'] = merged_df['Country Code'].map(country_to_cluster)
        merged_df = merged_df[merged_df['Cluster'].notna()].copy()
        merged_df = merged_df.dropna(subset=['GDP_avg_rank', 'Energy_avg_rank'])

        return merged_df

    def plot_single_cluster_in_subplot(self, ax, cluster_data, cluster_id, all_data_stats, color):
        """Draw one cluster's points, country labels and mean cross-hairs on ``ax``.

        ``all_data_stats`` holds gdp_min/gdp_max/energy_min/energy_max over all
        clusters, so every panel shares the same axis range.
        """
        n_countries = len(cluster_data)

        ax.scatter(
            cluster_data['GDP_avg_rank'],
            cluster_data['Energy_avg_rank'],
            c=[color],
            marker='o',
            s=120,
            alpha=0.75,
            edgecolors='white',
            linewidths=1.5
        )

        # Label each point with the country code plus its Chinese name when known.
        for _, row in cluster_data.iterrows():
            country_cn = row['Country Name_CN'] if pd.notna(row['Country Name_CN']) else ''
            label_text = f"{row['Country Code']}"
            if country_cn:
                label_text += f"\n{country_cn}"

            ax.annotate(
                label_text,
                xy=(row['GDP_avg_rank'], row['Energy_avg_rank']),
                xytext=(5, 5),
                textcoords='offset points',
                fontsize=6,
                alpha=0.85,
                bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.5,
                          edgecolor='gray', linewidth=0.5)
            )

        # Dashed cross-hairs at the cluster's mean ranks.
        gdp_mean = cluster_data['GDP_avg_rank'].mean()
        energy_mean = cluster_data['Energy_avg_rank'].mean()
        ax.axvline(x=gdp_mean, color='red', linestyle='--', linewidth=1, alpha=0.4)
        ax.axhline(y=energy_mean, color='blue', linestyle='--', linewidth=1, alpha=0.4)

        ax.set_title(f'簇{int(cluster_id)} ({n_countries}国)',
                     fontproperties=self.font_cn, fontsize=11, fontweight='bold', pad=8)

        # Limits are given max-first, which inverts both axes so the smallest
        # ranks sit at the right/top. The 5-rank padding keeps edge points visible.
        ax.set_xlim(all_data_stats['gdp_max'] + 5, all_data_stats['gdp_min'] - 5)
        ax.set_ylim(all_data_stats['energy_max'] + 5, all_data_stats['energy_min'] - 5)

        ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
        ax.tick_params(labelsize=8)

    @staticmethod
    def _grid_shape(n_panels, n_cols=4, min_rows=3):
        """Return ``(rows, cols)`` for a subplot grid that holds ``n_panels`` panels.

        Keeps at least ``min_rows`` rows (the original 3x4 layout). More rows are
        added when needed, instead of running out of axes past 12 panels.
        """
        needed_rows = -(-n_panels // n_cols)  # ceiling division
        return max(min_rows, needed_rows), n_cols

    def create_subplots_for_all_clusters(self, merged_df, k_value):
        """Save a grid with one panel per cluster of ``k_value``.

        Returns the PNG path, or None if ``merged_df`` is empty.
        """
        if len(merged_df) == 0:
            return None

        all_data_stats = {
            'gdp_min': merged_df['GDP_avg_rank'].min(),
            'gdp_max': merged_df['GDP_avg_rank'].max(),
            'energy_min': merged_df['Energy_avg_rank'].min(),
            'energy_max': merged_df['Energy_avg_rank'].max()
        }

        cluster_ids = sorted(merged_df['Cluster'].unique())
        n_clusters = len(cluster_ids)

        n_rows, n_cols = self._grid_shape(n_clusters)
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows))
        axes = axes.flatten()

        # Colour is keyed on the cluster ID so a cluster keeps its colour
        # regardless of how many clusters are present.
        colors = plt.cm.tab10(np.linspace(0, 1, 10))

        for ax, cluster_id in zip(axes, cluster_ids):
            cluster_data = merged_df[merged_df['Cluster'] == cluster_id]
            color = colors[int(cluster_id) % 10]
            self.plot_single_cluster_in_subplot(ax, cluster_data, cluster_id, all_data_stats, color)

        # Hide unused panels.
        for ax in axes[n_clusters:]:
            ax.axis('off')

        fig.suptitle(f'K={k_value} 共识聚类：各簇GDP与能源进口关系分析',
                     fontproperties=self.font_cn, fontsize=20, fontweight='bold', y=0.995)

        # Shared axis labels for the whole grid.
        fig.text(0.5, 0.02, 'GDP平均排名 (2005-2023)',
                 ha='center', fontproperties=self.font_cn, fontsize=14, fontweight='bold')
        fig.text(0.02, 0.5, '能源进口平均排名 (2005-2023)',
                 va='center', rotation='vertical', fontproperties=self.font_cn, fontsize=14, fontweight='bold')

        fig.tight_layout(rect=[0.03, 0.03, 1, 0.99])

        output_path = self.output_dir / f'K{k_value}_All_Clusters_Subplots_{n_rows}x{n_cols}.png'
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

        print(f"✓ 已生成K={k_value}所有簇的子图合集: {output_path.name}")
        return output_path

    def save_merged_data(self, merged_df, k_value):
        """Write ``merged_df`` to ``K{k}_GDP_Energy_Merged_with_Clusters.csv``, report it, and return its path."""
        output_path = self.output_dir / f'K{k_value}_GDP_Energy_Merged_with_Clusters.csv'
        merged_df.to_csv(output_path, index=False, encoding='utf-8-sig')
        print(f"✓ 已保存合并数据: {output_path.name}")
        return output_path

    def generate_for_k7(self, k_value=7):
        """Save the merged CSV and the per-cluster subplot grid for ``k_value`` (default 7)."""
        print(f"开始生成K={k_value}的散点图（3行4列子图）...")

        gdp_df, energy_df, cluster_df = self.load_data()

        merged_df = self.merge_data(gdp_df, energy_df, cluster_df, k_value)

        if len(merged_df) == 0:
            print(f"警告: K={k_value}没有找到数据")
            return

        self.save_merged_data(merged_df, k_value)

        cluster_ids = sorted(merged_df['Cluster'].unique())
        print(f"\nK={k_value}共有{len(cluster_ids)}个簇: {[int(x) for x in cluster_ids]}")

        self.create_subplots_for_all_clusters(merged_df, k_value)

        print("\n✅ 完成！生成了包含所有簇的子图")
        print(f"输出目录: {self.output_dir}")


def main():
    """Create the generator and render the K=7 per-cluster subplot grid."""
    ScatterPlotGeneratorByClusters().generate_for_k7()


if __name__ == "__main__":
    main()


开始生成K=7的散点图（3行4列子图）...
✓ 已保存合并数据: K7_GDP_Energy_Merged_with_Clusters.csv

K=7共有7个簇: [1, 2, 3, 4, 5, 6, 7]
✓ 已生成K=7所有簇的子图合集: K7_All_Clusters_Subplots_3x4.png

✅ 完成！生成了包含所有簇的子图
输出目录: f:\Desktop\CAMPF_Supplementary\data\5-2-Countries_background\scatter_fig
