# In [1]:
import pandas as pd   
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
from general_pic_setup import setup_mpl_single2
import matplotlib as mpl
from matplotlib.patches import Patch, Rectangle  # Patch + Rectangle

# Apply the shared figure style first, then override tick direction and
# title size for the figures produced by this script.
setup_mpl_single2()
mpl.rcParams.update({
    'ytick.direction': 'out',
    'xtick.direction': 'out',
    'axes.titlesize': 'Medium',
})


class ScatterPlotGeneratorByClusters:
    """Draw one scatter panel per consensus cluster plus a summary bar chart.

    Each cluster gets its own panel in a 3x3 grid plotting per-capita GDP
    percentile (x) against net-energy-import percentile (y). The two
    right-hand cells of the last row hold a bar chart of per-cluster means.
    """

    # Unit offset direction, horizontal and vertical alignment for a
    # leader-line label, keyed by the direction of the line's final segment,
    # so the text always sits beyond the end of the line.
    _LEADER_LABEL_LAYOUT = {
        'left': ((-1, 0), 'right', 'center'),
        'right': ((1, 0), 'left', 'center'),
        'up': ((0, 1), 'center', 'bottom'),
        'down': ((0, -1), 'center', 'top'),
    }
    # Layout used when there are no segments or the last direction is unknown.
    _LEADER_LABEL_DEFAULT = ((0, 0), 'right', 'center')

    # The 3x3 grid reserves gs[2, 1:3] for the bar chart, leaving 7 cells
    # for scatter panels.
    _MAX_SCATTER_PANELS = 7

    def __init__(self):
        """Resolve input/output paths (creating the output dir) and plot settings."""
        self.current_dir = Path.cwd()
        self.data_dir = self.current_dir.parent / "data"
        self.input_dir = self.data_dir / "5-2-Countries_background"
        self.output_dir = self.input_dir / "scatter_fig"
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.gdp_file = self.input_dir / "GDP_final_filtered.csv"
        self.energy_file = self.input_dir / "Energy_import_final_filtered.csv"
        self.cluster_file = self.data_dir / "4-2-Consensus_Policy_Cluster_Mapping.csv"

        # Cluster colors, indexed by int(cluster_id) modulo the list length.
        self.nature_colors = ['#E64B35', "#6917C2", '#00A087', '#3C5488', '#F39B7F',
                              '#DC0000', '#7E6148', '#B09C85', '#E18727', '#20854E', '#0072B5']
        self.marker_size = 200

        self.bar_colors = {
            'gdp_bar': '#86A9B2',
            'gdp_bar_edge': "#000000",
            'gdp_scatter': '#B6CACF',
            'gdp_scatter_edge': '#5D7E83',
            'energy_bar': '#C7B587',
            'energy_bar_edge': "#000000",
            'energy_scatter': '#DCD5B2',
            'energy_scatter_edge': '#95885B'
        }

        # Leader-line configuration: {cluster_id: {country_code: config}}.
        # Offsets and segment lengths are in data units (percentile points).
        self.leader_line_config = {
            1: {
                'SVK': {
                    'start_position': 'top',
                    'start_offset_x': -1.1,
                    'start_offset_y': 2,
                    'segments': [{'direction': 'left', 'length': 31}]
                },
                'DEU': {
                    'start_position': 'top',
                    'start_offset_x': -1.8,
                    'start_offset_y': 1.5,
                    'segments': [
                        {'direction': 'left', 'length': 38},
                        {'direction': 'up', 'length': 10},
                        {'direction': 'left', 'length': 8.5}
                    ]
                },
                'GBR': {
                    'start_position': 'bottom',
                    'start_offset_x': -2.35,
                    'start_offset_y': 0,
                    'segments': [
                        {'direction': 'left', 'length': 7.5},
                        {'direction': 'down', 'length': 20},
                        {'direction': 'left', 'length': 12}
                    ]
                }
            }
        }

        # Manual label placement rules: {cluster_id: {country_code: direction}}.
        # A direction of None hides the label; 'leader_line' uses
        # leader_line_config; other values are keys of
        # _get_manual_offset_and_alignment.
        self.manual_adjustments = {
            1: {
                'CAN': 'top', 'DEU': 'leader_line', 'SVK': 'leader_line',
                'PRT': 'bottom_left_cluster1-PRT', 'HUN': 'bottom_left_cluster1-HUN',
                'FRA': 'bottom_left_cluster1-FRA', 'GBR': 'leader_line',
                'POL': 'left', 'DNK': 'bottom', 'NOR': 'top', 'ESP': 'top',
                'BEL': 'top_right_close', 'CZE': 'bottom_left_cluster1',
                'NZL': 'bottom_right_cluster1-NZL'
            },
            2: {
                'TUR': 'left', 'KOR': 'top_left', 'NLD': 'bottom_left',
                'JPN': 'top_right', 'CHE': 'top', 'AUS': 'top', 'ITA': 'bottom_left'
            },
            3: {
                # NOTE(review): 'ITU' is not an ISO 3166-1 alpha-3 country code
                # (ITA is listed under cluster 2); confirm this entry is not a typo.
                'BGR': 'bottom_right', 'ITU': 'left', 'GRC': 'bottom_right_cluster3',
                'LTU': 'left', 'LUX': 'left'
            },
            4: {'CRI': 'left'},
            5: {'HRV': 'left'},
            7: {'ISL': 'left'}
        }

    def load_data(self):
        """Read the GDP, energy-import and cluster-mapping CSVs.

        Returns
        -------
        (gdp_df, energy_df, cluster_df) : tuple of DataFrames.
        """
        gdp_df = pd.read_csv(self.gdp_file, encoding='utf-8-sig')
        energy_df = pd.read_csv(self.energy_file, encoding='utf-8-sig')
        cluster_df = pd.read_csv(self.cluster_file, encoding='utf-8-sig')
        return gdp_df, energy_df, cluster_df

    def merge_data(self, gdp_df, energy_df, cluster_df, k_value):
        """Join GDP and energy percentiles with the cluster assignment for ``k_value``.

        Countries missing from the cluster mapping, or with a missing GDP or
        energy percentile, are dropped.

        Returns
        -------
        DataFrame with columns 'Country Code', 'Country Name_CN',
        'GDP_avg_percentile', 'Energy_avg_percentile', 'Cluster'; an empty
        DataFrame when the mapping has no rows for ``k_value``.
        """
        k_clusters = cluster_df[cluster_df['K值'] == k_value].copy()
        if len(k_clusters) == 0:
            return pd.DataFrame()

        country_to_cluster = dict(zip(k_clusters['国家'], k_clusters['共识聚类ID']))

        gdp_data = gdp_df[['Country Code', 'Country Name_CN', 'avg_percentile']].copy()
        gdp_data.columns = ['Country Code', 'Country Name_CN', 'GDP_avg_percentile']

        energy_data = energy_df[['Country Code', 'avg_percentile']].copy()
        energy_data.columns = ['Country Code', 'Energy_avg_percentile']

        merged_df = gdp_data.merge(energy_data, on='Country Code', how='inner')
        merged_df['Cluster'] = merged_df['Country Code'].map(country_to_cluster)
        merged_df = merged_df[merged_df['Cluster'].notna()].copy()
        merged_df = merged_df.dropna(subset=['GDP_avg_percentile', 'Energy_avg_percentile'])

        return merged_df

    def get_start_point(self, x, y, position, offset_x, offset_y):
        """Return the leader-line start point: the data point shifted by the offsets.

        ``position`` is accepted for configuration compatibility but does not
        currently affect the result.
        """
        return (x + offset_x, y + offset_y)

    def calculate_leader_line_path(self, x, y, start_position, start_offset_x, start_offset_y, segments):
        """Compute the vertices of an axis-aligned leader line.

        Starts at the offset start point and walks each segment
        ({'direction': 'left'|'right'|'up'|'down', 'length': float}) in data
        units. An unknown direction adds a vertex without moving.

        Returns
        -------
        list of (x, y) tuples, starting point first.
        """
        start_x, start_y = self.get_start_point(x, y, start_position, start_offset_x, start_offset_y)
        points = [(start_x, start_y)]
        current_x, current_y = start_x, start_y

        for segment in segments:
            direction = segment['direction']
            length = segment['length']

            if direction == 'left':
                current_x -= length
            elif direction == 'right':
                current_x += length
            elif direction == 'up':
                current_y += length
            elif direction == 'down':
                current_y -= length

            points.append((current_x, current_y))

        return points

    def draw_leader_line(self, ax, points):
        """Draw the polyline through ``points`` with a small dot at its end."""
        xs = [p[0] for p in points]
        ys = [p[1] for p in points]
        ax.plot(xs, ys, color='black', linewidth=1, alpha=0.6, zorder=2,
                solid_capstyle='round', solid_joinstyle='round')
        ax.plot(xs[-1], ys[-1], 'o', color='black', markersize=3, alpha=0.6, zorder=2)

    def apply_manual_adjustments(self, cluster_data, cluster_id):
        """Build label specs for countries with a manual placement rule.

        Returns
        -------
        list of dicts (keys 'x', 'y', 'code', 'use_leader_line', and for
        non-leader labels 'offset', 'ha', 'va'), or None when the cluster
        has no rules or none of its countries match.
        """
        if cluster_id not in self.manual_adjustments:
            return None

        adjustments = self.manual_adjustments[cluster_id]
        manual_positions = []

        for _, row in cluster_data.iterrows():
            country_code = row['Country Code']
            if country_code in adjustments:
                direction = adjustments[country_code]
                if direction is None:
                    continue

                x = row['GDP_avg_percentile']
                y = row['Energy_avg_percentile']

                if direction == 'leader_line':
                    manual_positions.append({
                        'x': x, 'y': y, 'code': country_code, 'use_leader_line': True
                    })
                else:
                    offset, ha, va = self._get_manual_offset_and_alignment(direction)
                    manual_positions.append({
                        'x': x, 'y': y, 'code': country_code,
                        'offset': offset, 'ha': ha, 'va': va, 'use_leader_line': False
                    })

        return manual_positions if manual_positions else None

    def _get_manual_offset_and_alignment(self, direction):
        """Map a manual direction key to ((dx, dy) in points, ha, va).

        Unknown keys fall back to a plain right-hand label.
        """
        offset_distance = 8
        offset_distance_far = 15
        offset_distance_close = 5
        offset_distance_slight = 4

        offsets = {
            'right': ((offset_distance, 0), 'left', 'center'),
            'left': ((-offset_distance, 0), 'right', 'center'),
            'left_far': ((-offset_distance_far, 0), 'right', 'center'),
            'top': ((0, offset_distance), 'center', 'bottom'),
            'bottom': ((0, -offset_distance), 'center', 'top'),
            'bottom_slight': ((0, -offset_distance_slight), 'center', 'top'),
            'bottom_right': ((offset_distance-6, -offset_distance), 'left', 'top'),
            'bottom_left': ((-offset_distance+6, -offset_distance), 'right', 'top'),
            'top_right': ((offset_distance, offset_distance), 'left', 'bottom'),
            'top_right_close': ((offset_distance_close, offset_distance_close), 'left', 'bottom'),
            'top_left': ((-offset_distance+3, offset_distance-6), 'right', 'bottom'),
            'bottom_left_cluster1': ((-offset_distance+3, -offset_distance), 'right', 'top'),
            'bottom_left_cluster1-FRA': ((-offset_distance+3, -offset_distance+14), 'right', 'top'),
            'bottom_left_cluster1-HUN': ((-offset_distance, -offset_distance+12), 'right', 'top'),
            'bottom_left_cluster1-PRT': ((-offset_distance+2, -offset_distance+16), 'right', 'top'),
            'top_right_cluster1': ((offset_distance, offset_distance), 'left', 'bottom'),
            'bottom_right_cluster1-NZL': ((offset_distance+1, -offset_distance+11), 'left', 'top'),
            'bottom_right_cluster1': ((offset_distance+6, -offset_distance+14), 'left', 'top'),
            'bottom_right_cluster3': ((offset_distance, -offset_distance+14), 'left', 'top')
        }

        return offsets.get(direction, ((offset_distance, 0), 'left', 'center'))

    def get_smart_label_positions(self, cluster_data, cluster_id, max_labels=None):
        """Combine manual label rules with automatic placement for one cluster.

        Countries whose rule is None are hidden. Countries with a manual rule
        use it; the rest go through ``_iterative_label_placement``. When no
        manual rule matches and ``max_labels`` is set, only a representative
        subset of points is labelled.
        """
        hide_codes = set()
        if cluster_id in self.manual_adjustments:
            for code, direction in self.manual_adjustments[cluster_id].items():
                if direction is None:
                    hide_codes.add(code)

        manual_positions = self.apply_manual_adjustments(cluster_data, cluster_id)

        if manual_positions is not None:
            manual_codes = {pos['code'] for pos in manual_positions}
            exclude_codes = manual_codes | hide_codes
            remaining_data = cluster_data[~cluster_data['Country Code'].isin(exclude_codes)]

            if len(remaining_data) > 0:
                coords = remaining_data[['GDP_avg_percentile', 'Energy_avg_percentile', 'Country Code']].values
                auto_positions = self._iterative_label_placement(coords)
                return manual_positions + auto_positions
            else:
                return manual_positions

        if hide_codes:
            cluster_data = cluster_data[~cluster_data['Country Code'].isin(hide_codes)]

        coords = cluster_data[['GDP_avg_percentile', 'Energy_avg_percentile', 'Country Code']].values

        if max_labels and len(coords) > max_labels:
            indices = self._select_representative_points(coords, max_labels)
            coords = coords[indices]

        positions = self._iterative_label_placement(coords)
        return positions

    def _iterative_label_placement(self, coords, max_iterations=5):
        """Greedily pick a label direction per point to reduce label overlaps.

        ``coords`` rows are (x, y, code). Every label starts on the right;
        each pass moves a conflicting label to the direction with the fewest
        conflicts, stopping early when a pass changes nothing. Labels that
        still conflict afterwards are dropped from the result.
        """
        n_points = len(coords)
        x_threshold = 10
        y_threshold = 6
        label_directions = ['right'] * n_points

        for _ in range(max_iterations):
            conflicts_resolved = 0

            for i in range(n_points):
                x, y, code = coords[i]
                current_direction = label_directions[i]
                current_label_pos = self._get_label_position(x, y, current_direction)

                has_conflict = False
                conflict_directions = set()

                for j in range(n_points):
                    if i == j:
                        continue

                    x2, y2, _ = coords[j]
                    other_direction = label_directions[j]
                    other_label_pos = self._get_label_position(x2, y2, other_direction)

                    dx_label = abs(current_label_pos[0] - other_label_pos[0])
                    dy_label = abs(current_label_pos[1] - other_label_pos[1])

                    if dx_label < x_threshold and dy_label < y_threshold:
                        has_conflict = True
                        if current_label_pos[0] > other_label_pos[0]:
                            conflict_directions.add('left')
                        else:
                            conflict_directions.add('right')
                        if current_label_pos[1] > other_label_pos[1]:
                            conflict_directions.add('bottom')
                        else:
                            conflict_directions.add('top')

                if has_conflict:
                    best_direction = self._find_best_alternative_direction(
                        x, y, coords, label_directions, i, conflict_directions, x_threshold, y_threshold
                    )
                    if best_direction != current_direction:
                        label_directions[i] = best_direction
                        conflicts_resolved += 1

            if conflicts_resolved == 0:
                break

        positions = []
        for i, (x, y, code) in enumerate(coords):
            direction = label_directions[i]

            if self._still_has_conflict(x, y, direction, coords, label_directions, i, x_threshold, y_threshold):
                continue

            offset, ha, va = self._get_offset_and_alignment(direction)
            positions.append({
                'x': x, 'y': y, 'code': code,
                'offset': offset, 'ha': ha, 'va': va, 'use_leader_line': False
            })

        return positions

    def _get_label_position(self, x, y, direction):
        """Approximate label anchor in data units (8 units from the point)."""
        label_distance = 8
        positions = {
            'right': (x + label_distance, y),
            'left': (x - label_distance, y),
            'top': (x, y + label_distance),
            'bottom': (x, y - label_distance)
        }
        return positions.get(direction, (x + label_distance, y))

    def _find_best_alternative_direction(self, x, y, coords, label_directions, current_idx,
                                         conflict_directions, x_threshold, y_threshold):
        """Return the direction (excluding ``conflict_directions`` when possible) with the fewest overlaps."""
        available_directions = ['right', 'left', 'top', 'bottom']
        available_directions = [d for d in available_directions if d not in conflict_directions]

        if not available_directions:
            available_directions = ['right', 'left', 'top', 'bottom']

        direction_scores = {}
        for direction in available_directions:
            label_pos = self._get_label_position(x, y, direction)
            conflict_count = 0

            for j, (x2, y2, _) in enumerate(coords):
                if j == current_idx:
                    continue

                other_direction = label_directions[j]
                other_label_pos = self._get_label_position(x2, y2, other_direction)

                dx = abs(label_pos[0] - other_label_pos[0])
                dy = abs(label_pos[1] - other_label_pos[1])

                if dx < x_threshold and dy < y_threshold:
                    conflict_count += 1

            direction_scores[direction] = conflict_count

        best_direction = min(direction_scores, key=direction_scores.get)
        return best_direction

    def _still_has_conflict(self, x, y, direction, coords, label_directions, current_idx,
                            x_threshold, y_threshold):
        """Return True if the label at ``direction`` still overlaps any other label."""
        label_pos = self._get_label_position(x, y, direction)

        for j in range(len(coords)):
            if j == current_idx:
                continue

            x2, y2, _ = coords[j]
            other_direction = label_directions[j]
            other_label_pos = self._get_label_position(x2, y2, other_direction)

            dx = abs(label_pos[0] - other_label_pos[0])
            dy = abs(label_pos[1] - other_label_pos[1])

            if dx < x_threshold and dy < y_threshold:
                return True

        return False

    def _get_offset_and_alignment(self, direction):
        """Map an automatic direction to ((dx, dy) in points, ha, va)."""
        offset_distance = 8
        offsets = {
            'right': ((offset_distance, 0), 'left', 'center'),
            'left': ((-offset_distance, 0), 'right', 'center'),
            'top': ((0, offset_distance), 'center', 'bottom'),
            'bottom': ((0, -offset_distance), 'center', 'top')
        }
        return offsets.get(direction, ((offset_distance, offset_distance), 'left', 'bottom'))

    def _select_representative_points(self, coords, n_select, seed=0):
        """Pick up to ``n_select`` distinct row indices of ``coords`` to label.

        The four diagonal extremes (min/max of x+y and x-y) come first,
        de-duplicated and in that order. Any remaining slots are filled with
        a seeded random sample, so the same input always yields the same
        selection.

        Parameters
        ----------
        coords : array whose first two columns are x and y.
        n_select : number of indices wanted.
        seed : seed for the random fill (default 0, for reproducible figures).

        Returns
        -------
        list of int, with exactly min(n_select, len(coords)) unique indices.
        """
        n_points = len(coords)
        if n_points <= n_select:
            return list(range(n_points))

        # coords may be an object array (mixed with country codes); cast first.
        x_vals = coords[:, 0].astype(float)
        y_vals = coords[:, 1].astype(float)

        selected = []
        for idx in (np.argmin(x_vals + y_vals), np.argmax(x_vals + y_vals),
                    np.argmin(x_vals - y_vals), np.argmax(x_vals - y_vals)):
            idx = int(idx)
            if idx not in selected:
                selected.append(idx)
        # n_select may be smaller than the number of extremes.
        selected = selected[:n_select]

        n_more = n_select - len(selected)
        if n_more > 0:
            chosen = set(selected)
            remaining = [i for i in range(n_points) if i not in chosen]
            rng = np.random.default_rng(seed)
            picks = rng.choice(remaining, size=min(n_more, len(remaining)), replace=False)
            selected.extend(int(i) for i in picks)

        return selected

    def _draw_leader_line_label(self, ax, pos, cluster_id):
        """Draw the configured leader line and label for ``pos['code']``.

        The text is placed 2 data units beyond the line end, aligned away
        from the line according to the final segment's direction.

        Returns
        -------
        bool : False (nothing drawn) when no leader line is configured for
        this cluster/country, True otherwise.
        """
        config = self.leader_line_config.get(cluster_id, {}).get(pos['code'])
        if config is None:
            return False

        segments = config['segments']
        points = self.calculate_leader_line_path(
            pos['x'], pos['y'], config['start_position'],
            config['start_offset_x'], config['start_offset_y'], segments
        )
        self.draw_leader_line(ax, points)

        gap = 2
        last_dir = segments[-1]['direction'] if segments else None
        (ux, uy), ha, va = self._LEADER_LABEL_LAYOUT.get(last_dir, self._LEADER_LABEL_DEFAULT)
        end_x, end_y = points[-1]
        ax.text(end_x + ux * gap, end_y + uy * gap, pos['code'], ha=ha, va=va,
                alpha=0.9, color='black', zorder=4)
        return True

    def plot_single_cluster_in_subplot(self, ax, all_data, cluster_id, all_data_stats, color, label):
        """Draw one cluster's scatter panel on ``ax``.

        The cluster's points are drawn in ``color`` over a faint gray backdrop
        of all other countries, with reference lines at 50 on both axes and
        country-code labels from ``get_smart_label_positions``.

        Parameters
        ----------
        ax : Axes to draw on.
        all_data : merged DataFrame from ``merge_data`` (all clusters).
        cluster_id : cluster to highlight.
        all_data_stats : global min/max dict; currently unused.
        color : fill color for this cluster's points.
        label : panel tag (e.g. '(a)') drawn in the top-left corner.
        """
        cid = int(cluster_id)
        cluster_data = all_data[all_data['Cluster'] == cluster_id]
        other_data = all_data[all_data['Cluster'] != cluster_id]

        ax.axis('scaled')
        n_countries = len(cluster_data)

        ax.axvline(x=50, color='darkgray', linestyle='-', linewidth=1.5, alpha=0.7, zorder=0)
        ax.axhline(y=50, color='darkgray', linestyle='-', linewidth=1.5, alpha=0.7, zorder=0)

        if len(other_data) > 0:
            ax.scatter(
                other_data['GDP_avg_percentile'], other_data['Energy_avg_percentile'],
                c='lightgray', marker='o', s=120, alpha=0.3,
                edgecolors='gray', linewidths=0.5, zorder=1
            )

        ax.scatter(
            cluster_data['GDP_avg_percentile'], cluster_data['Energy_avg_percentile'],
            c=[color], marker='o', s=self.marker_size, alpha=0.8,
            edgecolors='white', linewidths=2, zorder=3
        )

        max_labels = 15 if n_countries > 15 else None
        label_positions = self.get_smart_label_positions(cluster_data, cid, max_labels)

        for pos in label_positions:
            if pos.get('use_leader_line', False):
                if self._draw_leader_line_label(ax, pos, cid):
                    continue
                # Leader line requested but not configured: fall back to a
                # plain right-hand label instead of silently dropping it.
                offset, ha, va = self._get_offset_and_alignment('right')
            else:
                offset, ha, va = pos['offset'], pos['ha'], pos['va']
            ax.annotate(
                pos['code'], xy=(pos['x'], pos['y']), xytext=offset,
                textcoords='offset points', alpha=0.9, color='black',
                ha=ha, va=va, zorder=4
            )

        ax.set_xlim(-2, 102)
        ax.set_ylim(-2, 102)

        tick_positions = [0, 25, 50, 75, 100]
        ax.set_xticks(tick_positions)
        ax.set_yticks(tick_positions)

        for spine in ['top', 'right']:
            ax.spines[spine].set_visible(False)

        ax.set_title(f'Group {cid} ({n_countries} countries)', pad=20)

        # Panel tag in the top-left corner (axes coordinates).
        ax.text(0.02, 0.98, label, transform=ax.transAxes,
                verticalalignment='top', horizontalalignment='left')

    def create_cluster_bar_plot_in_ax(self, ax, merged_df, label, legend_xy=(1.05, 0.5), jitter_seed=0):
        """Draw per-cluster mean GDP and energy percentiles as paired bars on ``ax``.

        Each bar has an upper-only SEM error bar and a jittered strip of the
        individual country values.

        Parameters
        ----------
        ax : Axes to draw on.
        merged_df : merged DataFrame from ``merge_data``; nothing is drawn if empty.
        label : panel tag drawn near the top-right corner.
        legend_xy : legend anchor in axes coordinates.
        jitter_seed : seed for the horizontal jitter (reproducible figures).
        """
        if len(merged_df) == 0:
            return

        cluster_ids = sorted(merged_df['Cluster'].unique())

        bar_width = 0.35
        within_cluster_gap = 0.06
        between_cluster_gap = 1.2
        x_positions = [i * between_cluster_gap for i in range(len(cluster_ids))]

        # GDP bar left of the tick, energy bar right of it.
        half_step = bar_width / 2 + within_cluster_gap / 2
        series = (
            ('GDP_avg_percentile', 'gdp', -half_step),
            ('Energy_avg_percentile', 'energy', half_step),
        )
        rng = np.random.default_rng(jitter_seed)

        ax.set_ylabel('Average Percentile (%)', labelpad=15)
        for x_base, cluster_id in zip(x_positions, cluster_ids):
            cluster_data = merged_df[merged_df['Cluster'] == cluster_id]
            for column, key, shift in series:
                values = cluster_data[column]
                x_center = x_base + shift
                mean = values.mean()
                sem = values.sem()  # NaN for a single country; errorbar then draws nothing

                ax.bar(x_center, mean, bar_width,
                       color=self.bar_colors[f'{key}_bar'], alpha=1.0,
                       edgecolor=self.bar_colors[f'{key}_bar_edge'], linewidth=1.5)
                ax.errorbar(x_center, mean,
                            yerr=[[0], [sem]], fmt='none',
                            ecolor=self.bar_colors[f'{key}_bar_edge'],
                            capsize=5, capthick=1.5, elinewidth=1.5, zorder=5)

                jitter_x = rng.normal(x_center, bar_width / 6, len(values))
                ax.scatter(jitter_x, values.values,
                           color=self.bar_colors[f'{key}_scatter'], s=80, alpha=0.8,
                           edgecolors=self.bar_colors[f'{key}_scatter_edge'], linewidths=1, zorder=3)

        ax.set_xticks(x_positions)
        ax.set_xticklabels([f'Cluster {int(cid)}' for cid in cluster_ids], fontsize=19.5)
        ax.set_ylim(0, 105)
        ax.grid(False)

        for spine in ['top', 'right']:
            ax.spines[spine].set_visible(False)

        legend_elements = [
            Patch(facecolor=self.bar_colors['gdp_bar'], edgecolor=self.bar_colors['gdp_bar_edge'],
                  label='Per Capita GDP'),
            Patch(facecolor=self.bar_colors['energy_bar'], edgecolor=self.bar_colors['energy_bar_edge'],
                  label='Net Energy Imports')
        ]
        ax.legend(
            handles=legend_elements,
            loc='center left',
            bbox_to_anchor=legend_xy,
            bbox_transform=ax.transAxes,
            ncol=1,
            frameon=False,
            fontsize=mpl.rcParams['xtick.labelsize']
        )

        # Panel tag near the top-right corner, nudged slightly right.
        ax.text(0.96, 0.98, label, transform=ax.transAxes,
                verticalalignment='top', horizontalalignment='left')

    def create_subplots_for_all_clusters(self, merged_df, k_value,
                                         legend_xy=(0.4, 0.9),
                                         box_xy=(-0.03, -0.05),
                                         box_width=1.06,
                                         box_height=1.12,
                                         bar_pos_offset=(0.0, 0.0),
                                         bar_width_shrink=0.04):
        """Lay out one scatter panel per cluster plus the bar chart on a 3x3 grid and save a PNG.

        Clusters fill the grid row by row. The bar chart occupies the two
        right-hand cells of the last row and is framed by a rectangle.

        Parameters
        ----------
        legend_xy : bar-chart legend position (ax.transAxes).
        box_xy : lower-left corner of the frame rectangle (ax_bar.transAxes).
        box_width / box_height : frame size (ax_bar.transAxes).
        bar_pos_offset : (dx, dy) shift of the bar chart, in figure coordinates.
        bar_width_shrink : horizontal shrink of the bar chart (figure coordinates).

        Returns
        -------
        Path of the saved PNG, or None if ``merged_df`` is empty.

        Raises
        ------
        ValueError : more clusters than the grid has free cells for (7).
        """
        if len(merged_df) == 0:
            return None

        cluster_ids = sorted(merged_df['Cluster'].unique())
        n_clusters = len(cluster_ids)
        if n_clusters > self._MAX_SCATTER_PANELS:
            raise ValueError(
                f'At most {self._MAX_SCATTER_PANELS} clusters fit in the 3x3 layout, got {n_clusters}'
            )

        all_data_stats = {
            'gdp_min': merged_df['GDP_avg_percentile'].min(),
            'gdp_max': merged_df['GDP_avg_percentile'].max(),
            'energy_min': merged_df['Energy_avg_percentile'].min(),
            'energy_max': merged_df['Energy_avg_percentile'].max()
        }

        # Panel tags (a), (b), ...: one per cluster, then one for the bar chart.
        labels = [f'({chr(ord("a") + i)})' for i in range(n_clusters + 1)]

        fig = plt.figure(figsize=(18, 18))
        try:
            gs = fig.add_gridspec(3, 3, hspace=0.29, wspace=0.29,
                                  left=0.08, right=0.98, bottom=0.08, top=0.98)

            for idx, cluster_id in enumerate(cluster_ids):
                ax = fig.add_subplot(gs[idx // 3, idx % 3])
                color = self.nature_colors[int(cluster_id) % len(self.nature_colors)]
                self.plot_single_cluster_in_subplot(ax, merged_df, cluster_id, all_data_stats, color, labels[idx])

            ax_bar = fig.add_subplot(gs[2, 1:3])
            self.create_cluster_bar_plot_in_ax(ax_bar, merged_df, labels[n_clusters], legend_xy=legend_xy)

            # Shift and narrow the bar chart to make room for its frame.
            pos = ax_bar.get_position()
            dx, dy = bar_pos_offset
            ax_bar.set_position([pos.x0 + dx, pos.y0 + dy, pos.width - bar_width_shrink, pos.height])

            # Black frame around the bar chart, drawn outside the axes area.
            rect = Rectangle(
                box_xy,
                box_width,
                box_height,
                transform=ax_bar.transAxes,
                fill=False,
                linewidth=1.5,
                edgecolor='black',
                clip_on=False,
                zorder=10
            )
            ax_bar.add_patch(rect)

            fig.text(0.008, 0.5, 'Average Net Energy Imports Quantile (%)',
                     ha='center', va='center', rotation='vertical',
                     fontsize=30)

            fig.text(0.5, 0.015, 'Average Per Capita GDP Quantile (%)',
                     ha='center', va='center', fontsize=30)

            output_path = self.output_dir / f'K{k_value}_All_Clusters_Subplots_3x3_with_Bar.png'
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)

        return output_path

    def save_merged_data(self, merged_df, k_value):
        """Write the merged table to CSV in the output directory and return its path."""
        output_path = self.output_dir / f'K{k_value}_GDP_Energy_Merged_with_Clusters.csv'
        merged_df.to_csv(output_path, index=False, encoding='utf-8-sig')
        return output_path

    def generate_for_k7(self):
        """Load the data, merge it for K=7, save the CSV and draw the 3x3 figure."""
        gdp_df, energy_df, cluster_df = self.load_data()
        k_value = 7
        merged_df = self.merge_data(gdp_df, energy_df, cluster_df, k_value)

        if len(merged_df) == 0:
            return

        self.save_merged_data(merged_df, k_value)
        self.create_subplots_for_all_clusters(
            merged_df, k_value,
            legend_xy=(0.42, 0.93),      # legend position
            box_xy=(-0.15, -0.13),       # frame lower-left corner
            box_width=1.18,              # frame width
            box_height=1.22,             # frame height
            bar_pos_offset=(0.04, 0.0),  # bar-chart shift (dx, dy)
            bar_width_shrink=0.04        # narrow by roughly one y-label width
        )


def main():
    """Script entry point: build the K=7 merged CSV and cluster figure."""
    ScatterPlotGeneratorByClusters().generate_for_k7()


if __name__ == "__main__":
    main()



# Captured kernel output from running the cell above:
#
# A module that was compiled using NumPy 1.x cannot be run in
# NumPy 2.2.6 as it may crash. To support both 1.x and 2.x
# versions of NumPy, modules must be compiled with NumPy 2.0.
# Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.
#
# If you are a user of the module, the easiest solution will be to
# downgrade to 'numpy<2' or try to upgrade the affected module.
# We expect that some modules will need time to support NumPy 2.
#
# Traceback (most recent call last):
#   File "<frozen runpy>", line 198, in _run_module_as_main
#   File "<frozen runpy>", line 88, in _run_code
#   File "d:\anaconda3\Lib\site-packages\ipykernel_launcher.py", line 18, in <module>
#     app.launch_new_instance()
#   File "d:\anaconda3\Lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
#     app.start()
#   File "d:\anaconda3\Lib\site-packages\ipykernel\kernelapp.py", line 739, in start
#     self.io_loop.start()
#   File "d:\anaconda3\Lib\site-packages\tornado\platform\asyncio.py", line 211, in
#   [... remainder of the traceback was truncated in the capture ...]
#
# AttributeError: _ARRAY_API not found