# In [8]:
import pandas as pd
import numpy as np
import json
from pathlib import Path


class WorldBankDataProcessor:
    """Turn a World Bank indicator file into per-country average percentiles.

    Pipeline: extract the year columns -> compute a percentile rank for every
    year -> average the percentiles across years -> keep the target countries.
    Every step also writes its result to a CSV file in ``output_dir``.
    """

    def __init__(self, data_type, input_file, start_year=2005, end_year=2023):
        """Store the year range and derive all input/output paths.

        Paths are resolved relative to the parent of the current working
        directory, and the output directory is created if it does not exist.

        Args:
            data_type: Dataset label (e.g. 'GDP', 'Energy_import'); used as
                the prefix of the output file names.
            input_file: Name of the input file inside ``data_origin``.
            start_year: First year to keep (inclusive).
            end_year: Last year to keep (inclusive).
        """
        self.data_type = data_type
        self.start_year = start_year
        self.end_year = end_year
        # Year columns are matched by their string label, e.g. "2005".
        self.year_cols = [str(y) for y in range(start_year, end_year + 1)]

        self.current_dir = Path.cwd()
        self.data_dir = self.current_dir.parent / "data"
        self.data_origin_dir = self.current_dir.parent / "data_origin"
        self.output_dir = self.data_dir / "5-2-Countries_background"

        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.config_mapping_path = self.data_dir / "config_mappings.json"
        self.input_file_path = self.data_origin_dir / input_file
        self.cluster_path = self.data_dir / "4-2-Consensus_Policy_Cluster_Mapping.csv"

        self.out_year_extracted = self.output_dir / f"{data_type}_year_{start_year}_{end_year}.csv"
        self.out_percentile = self.output_dir / f"{data_type}_percentile_{start_year}_{end_year}.csv"
        self.out_final = self.output_dir / f"{data_type}_final_filtered.csv"

    def load_country_name_mapping(self):
        """Load the country-name mapping from the JSON config file.

        Returns:
            The value stored under the "country_names" key, or an empty dict
            when that key is absent.
        """
        with open(self.config_mapping_path, "r", encoding="utf-8") as f:
            config_data = json.load(f)
        return config_data.get("country_names", {})

    def load_cluster_countries(self):
        """Return the set of distinct values of the '国家' column in the cluster file.

        NOTE(review): step 3 matches these values against "Country Code", so
        the column presumably holds country codes rather than names - verify.
        """
        cluster_df = pd.read_csv(self.cluster_path, encoding='utf-8-sig')
        return set(cluster_df['国家'].unique())

    def step1_load_and_extract_years(self, country_name_map):
        """Step 1: load the input CSV and keep only the configured year range.

        The first four lines of the file are skipped, the requested year
        columns are coerced to numbers (unparseable cells become NaN), and a
        "Country Name_CN" column is added by mapping "Country Code" through
        ``country_name_map``. The result is written to ``out_year_extracted``.

        Args:
            country_name_map: Mapping applied to the "Country Code" column.

        Returns:
            DataFrame with the base columns, "Country Name_CN" and the year
            columns that exist in the file.
        """
        df = pd.read_csv(self.input_file_path, skiprows=4)

        # Keep only the requested year columns that are present in the file.
        year_cols_exist = [c for c in df.columns if c in self.year_cols]
        base_cols = ["Country Name", "Country Code", "Indicator Name"]
        df = df[base_cols + year_cols_exist].copy()

        # Nothing to convert when none of the requested years is present.
        if year_cols_exist:
            df[year_cols_exist] = df[year_cols_exist].apply(pd.to_numeric, errors="coerce")

        df.insert(1, "Country Name_CN", df["Country Code"].map(country_name_map))

        df.to_csv(self.out_year_extracted, index=False, encoding="utf-8-sig")
        return df

    def step2_calculate_percentile_by_year(self, df_year):
        """Step 2: compute, per year, each row's percentile among non-null rows.

        For every year column a "<year>_percentile" column is added holding
        ``100 * (number of non-null values <= this value) / (non-null count)``.
        Rows with a missing value get NaN for that year. The result is written
        to ``out_percentile``; the input frame is not modified.

        NOTE(review): every row of the input takes part in the ranking. World
        Bank bulk files usually also contain aggregate rows (regions, income
        groups) - confirm that ranking them alongside countries is intended.

        Args:
            df_year: DataFrame produced by step 1.

        Returns:
            Copy of ``df_year`` with the percentile columns appended.
        """
        df_percentile = df_year.copy()
        year_cols_exist = [c for c in df_percentile.columns if c in self.year_cols]

        for year in year_cols_exist:
            # rank(method="max") is the count of values <= the current one and
            # pct=True divides by the non-null count, so this is the vectorized
            # form of _calculate_percentile; NaN inputs stay NaN.
            df_percentile[f'{year}_percentile'] = (
                df_percentile[year].rank(method="max", pct=True) * 100
            )

        df_percentile.to_csv(self.out_percentile, index=False, encoding="utf-8-sig")
        return df_percentile

    def _calculate_percentile(self, value, valid_data):
        """Return the percentile (0-100) of one value within ``valid_data``.

        Kept as the scalar reference implementation; step 2 computes the same
        quantity for a whole column at once.

        Args:
            value: Value to rank; NaN yields NaN.
            valid_data: Series of non-null values to rank against.

        Returns:
            Share of ``valid_data`` that is <= ``value``, times 100.
        """
        if pd.isna(value):
            return np.nan

        rank = (valid_data <= value).sum()
        total = len(valid_data)
        return (rank / total) * 100

    def step3_average_percentile_and_filter(self, df_percentile, cluster_countries):
        """Step 3: average the yearly percentiles and keep the target countries.

        Adds an "avg_percentile" column (row mean of the percentile columns,
        ignoring NaN), sorts ascending by it, keeps the rows whose
        "Country Code" is in ``cluster_countries`` and writes them to
        ``out_final``.

        Args:
            df_percentile: DataFrame produced by step 2.
            cluster_countries: Collection of values matched against
                "Country Code".

        Returns:
            The sorted, filtered DataFrame.
        """
        df_result = df_percentile.copy()

        year_cols_exist = [c for c in df_result.columns if c in self.year_cols]
        percentile_cols = [f'{year}_percentile' for year in year_cols_exist]

        df_result['avg_percentile'] = df_result[percentile_cols].mean(axis=1, skipna=True)
        df_result = df_result.sort_values('avg_percentile')

        df_filtered = df_result[df_result['Country Code'].isin(cluster_countries)].copy()

        df_filtered.to_csv(self.out_final, index=False, encoding="utf-8-sig")
        return df_filtered

    def process_all(self):
        """Run steps 1-3 in order and return the final filtered DataFrame."""
        country_name_map = self.load_country_name_mapping()
        cluster_countries = self.load_cluster_countries()

        df_year = self.step1_load_and_extract_years(country_name_map)
        df_percentile = self.step2_calculate_percentile_by_year(df_year)
        df_final = self.step3_average_percentile_and_filter(df_percentile, cluster_countries)

        return df_final


def main():
    """Run the processing pipeline for the GDP and energy-import datasets.

    Returns:
        Dict mapping each dataset's ``data_type`` to the final filtered
        DataFrame returned by ``process_all``.
    """
    datasets = [
        {
            'data_type': 'GDP',
            'input_file': 'GDP_per-API_NY.GDP.PCAP.KD_DS2_en_csv_v2_130141.csv',
        },
        {
            'data_type': 'Energy_import',
            'input_file': 'Energy_import-API_EG.IMP.CONS.ZS_DS2_en_csv_v2_216046.csv',
        },
    ]

    results = {}

    for config in datasets:
        processor = WorldBankDataProcessor(
            data_type=config['data_type'],
            input_file=config['input_file'],
            start_year=2005,
            end_year=2023,
        )
        results[config['data_type']] = processor.process_all()

    # Hand the collected frames back instead of discarding them.
    return results


# Run the full pipeline only when this file is executed as a script.
if __name__ == "__main__":
   main()