# 细胞级特征注意力整合

In [None]:
import os
import numpy as np
from PIL import Image
import glob
import pandas as pd
from tqdm import tqdm

import openslide
from openslide.deepzoom import DeepZoomGenerator

class AttentionQuerier:
    """
    Utility that looks up a patch's value in every attention map of its slide,
    whatever the attention-map suffixes are.

    NOTE: ``AttentionQuerier`` is defined again later in this notebook; once both
    cells have run, that later definition is the one in effect.
    """
    def __init__(self, svs_root_dir, attn_root_dir, tile_size=224, overlap=0):
        """
        Initialize the AttentionQuerier.

        Args:
            svs_root_dir (str): Root directory containing the ``.svs`` files.
            attn_root_dir (str): Root directory containing one sub-folder per attention map.
            tile_size (int): Tile size used when the slides were tiled.
            overlap (int): Tile overlap used when the slides were tiled.

        Raises:
            FileNotFoundError: If either root directory does not exist.
        """
        self.svs_root_dir = svs_root_dir
        self.attn_root_dir = attn_root_dir
        self.tile_size = tile_size
        self.overlap = overlap

        # Fail fast if either root directory is missing.
        if not os.path.isdir(svs_root_dir):
            raise FileNotFoundError(f"SVS root directory not found: {svs_root_dir}")
        if not os.path.isdir(attn_root_dir):
            raise FileNotFoundError(f"Attention map root directory not found: {attn_root_dir}")

    def _find_attention_maps(self, svs_basename):
        """
        Find every ``<attn_root_dir>/<svs_basename>_<suffix>/attn.png`` file.

        Returns:
            dict: ``{suffix: path}`` in sorted path order,
            e.g. ``{'BCL2': '.../21-11167_BCL2/attn.png'}``.
        """
        prefix = f"{svs_basename}_"
        # Escape so characters such as '[' in directory/slide names are matched literally
        # instead of being interpreted as glob wildcards.
        search_pattern = os.path.join(
            glob.escape(self.attn_root_dir), f"{glob.escape(prefix)}*", "attn.png"
        )

        results = {}
        # Sort so the suffix order (and therefore any derived column order) is deterministic.
        for path in sorted(glob.glob(search_pattern)):
            # .../21-11167_BCL2/attn.png -> folder "21-11167_BCL2"
            folder_name = os.path.basename(os.path.dirname(path))
            # Strip only the leading "<svs_basename>_" prefix; str.replace would also
            # remove later occurrences inside the suffix itself.
            results[folder_name[len(prefix):]] = path

        return results

    def query_patch_attention(self, patch_path, method='mean'):
        """
        Look up a single patch's value in every available attention map.

        The parent folder of ``patch_path`` is taken as the slide base name, and the
        first three ``_``-separated fields of the file stem as DeepZoom level, column
        and row.

        Args:
            patch_path (str): Path of the patch image.
            method (str): ``'mean'`` or ``'max'`` over the attention pixels covered by the patch.

        Returns:
            dict | None: ``{suffix: score}`` with scores divided by 255; ``{}`` when the
            slide has no attention map; ``None`` when any error occurs (the error is
            printed, not raised). A patch that lies outside a map scores ``0.0`` there.
        """
        try:
            # 1. The patch's parent folder is the SVS base name.
            svs_basename = os.path.basename(os.path.dirname(patch_path))

            # 2. Find all attention maps belonging to this slide.
            found_maps = self._find_attention_maps(svs_basename)

            if not found_maps:
                print(f"警告: 未找到与 '{svs_basename}' 相关的attention map。")
                return {}

            # 3. Read the slide geometry once, then close the slide.
            svs_path = os.path.join(self.svs_root_dir, f"{svs_basename}.svs")
            if not os.path.exists(svs_path):
                raise FileNotFoundError(f"SVS file not found at: {svs_path}")

            slide = openslide.open_slide(svs_path)
            try:
                dz = DeepZoomGenerator(slide, self.tile_size, self.overlap)
                wsi_dims = slide.dimensions
                highest_zoom_level = dz.level_count - 1
            finally:
                slide.close()

            # 4. Decode "<level>_<col>_<row>" from the file name and map it to level-0 pixels.
            patch_filename_only = os.path.basename(patch_path)
            parts = os.path.splitext(patch_filename_only)[0].split('_')
            level, col, row = int(parts[0]), int(parts[1]), int(parts[2])

            # Assumes each DeepZoom level halves the resolution of the next one.
            downsample_factor = 2 ** (highest_zoom_level - level)
            x_start_wsi = col * self.tile_size * downsample_factor
            y_start_wsi = row * self.tile_size * downsample_factor
            x_end_wsi = x_start_wsi + (self.tile_size * downsample_factor)
            y_end_wsi = y_start_wsi + (self.tile_size * downsample_factor)

            # 5. Score the patch against every attention map found.
            final_scores = {}
            for suffix, attn_path in found_maps.items():
                print(f"-> 正在从 '{suffix}' 注意力图获取数值...")

                # Load as a single-band (grayscale) array; the file is closed afterwards.
                with Image.open(attn_path) as attn_map_img:
                    if len(attn_map_img.getbands()) > 1:
                        # RGB/RGBA maps would otherwise give a 3-D array.
                        attn_map_array = np.array(attn_map_img.convert('L'), dtype=np.float32)
                    else:
                        attn_map_array = np.array(attn_map_img, dtype=np.float32)
                attn_map_array = attn_map_array / 255.0

                # Ratio between slide resolution and attention-map resolution.
                attn_height, attn_width = attn_map_array.shape
                scale_factor_width = wsi_dims[0] / attn_width
                scale_factor_height = wsi_dims[1] / attn_height

                # Map the tile's level-0 box onto attention-map pixels. Always cover at
                # least one pixel: a tile narrower than one attention pixel would
                # otherwise produce an empty slice even though it lies inside the map.
                x_start_attn = int(x_start_wsi / scale_factor_width)
                y_start_attn = int(y_start_wsi / scale_factor_height)
                x_end_attn = max(int(x_end_wsi / scale_factor_width), x_start_attn + 1)
                y_end_attn = max(int(y_end_wsi / scale_factor_height), y_start_attn + 1)

                attention_region = attn_map_array[y_start_attn:y_end_attn, x_start_attn:x_end_attn]

                if attention_region.size == 0:
                    # The patch lies outside this attention map.
                    score = 0.0
                elif method == 'mean':
                    score = np.mean(attention_region)
                elif method == 'max':
                    score = np.max(attention_region)
                else:
                    raise ValueError("Method must be 'mean' or 'max'")

                final_scores[suffix] = score

            return final_scores

        except Exception as e:
            print(f"处理 {patch_path} 时发生错误: {e}")
            return None

# Disable Pillow's decompression-bomb pixel limit so that very large attention-map
# PNGs can be opened instead of raising an error.
Image.MAX_IMAGE_PIXELS = None

class AttentionQuerier:
    """
    Query a patch's value in every attention map of its slide, whatever the map suffixes are.

    File layout used by the code:
      * slides:          ``<svs_root_dir>/<svs_basename>.svs``
      * attention maps:  ``<attn_root_dir>/<svs_basename>_<suffix>/attn.png``
      * patches:         ``.../<svs_basename>/<level>_<col>_<row>.<ext>``

    The slide geometry and decoded attention maps of the most recently queried slide
    are kept in memory, so consecutive patches of the same slide neither re-open the
    slide nor re-decode the (potentially very large) PNGs.
    """
    def __init__(self, svs_root_dir, attn_root_dir, tile_size=224, overlap=0):
        """
        Args:
            svs_root_dir (str): Root directory containing the ``.svs`` files.
            attn_root_dir (str): Root directory containing one sub-folder per attention map.
            tile_size (int): DeepZoom tile size used when the patches were cut.
            overlap (int): DeepZoom tile overlap used when the patches were cut.

        Raises:
            FileNotFoundError: If either root directory does not exist.
        """
        self.svs_root_dir = svs_root_dir
        self.attn_root_dir = attn_root_dir
        self.tile_size = tile_size
        self.overlap = overlap
        # Single-entry cache: (svs_basename, (wsi_dims, highest_zoom_level, {suffix: array})).
        self._cache = None

        if not os.path.isdir(svs_root_dir):
            raise FileNotFoundError(f"SVS root directory not found: {svs_root_dir}")
        if not os.path.isdir(attn_root_dir):
            raise FileNotFoundError(f"Attention map root directory not found: {attn_root_dir}")

    def _find_attention_maps(self, svs_basename):
        """Return ``{suffix: path}`` for every ``<svs_basename>_<suffix>/attn.png``, in sorted path order."""
        prefix = f"{svs_basename}_"
        # Escape so characters such as '[' in names are matched literally, not as wildcards.
        search_pattern = os.path.join(
            glob.escape(self.attn_root_dir), f"{glob.escape(prefix)}*", "attn.png"
        )
        results = {}
        # Sorted so the suffix (and output-column) order is deterministic.
        for path in sorted(glob.glob(search_pattern)):
            folder_name = os.path.basename(os.path.dirname(path))
            # Strip only the leading prefix; str.replace would also remove later occurrences.
            results[folder_name[len(prefix):]] = path
        return results

    def _read_slide_geometry(self, svs_path):
        """Return ``(slide.dimensions, highest DeepZoom level index)``; the slide is always closed."""
        slide = openslide.open_slide(svs_path)
        try:
            dz = DeepZoomGenerator(slide, self.tile_size, self.overlap)
            return slide.dimensions, dz.level_count - 1
        finally:
            slide.close()

    @staticmethod
    def _load_attention_array(attn_path):
        """Decode an attention PNG into a 2-D array in its native dtype (not yet scaled)."""
        with Image.open(attn_path) as attn_img:
            # Multi-band images (RGB/RGBA/LA) are reduced to one grayscale band so the
            # array is 2-D; single-band images are used unchanged.
            if len(attn_img.getbands()) > 1:
                return np.array(attn_img.convert('L'))
            return np.array(attn_img)

    def _get_slide_context(self, svs_basename, patch_path):
        """
        Return ``(wsi_dims, highest_zoom_level, {suffix: array})`` for a slide, or ``None``.

        ``None`` means the slide has no attention maps or its SVS file is missing (the
        latter is reported with a warning). Only successful loads are cached.
        """
        cached = getattr(self, '_cache', None)
        if cached is not None and cached[0] == svs_basename:
            return cached[1]

        found_maps = self._find_attention_maps(svs_basename)
        if not found_maps:
            return None

        svs_path = os.path.join(self.svs_root_dir, f"{svs_basename}.svs")
        if not os.path.exists(svs_path):
            # Without the slide the patch coordinates cannot be mapped.
            print(f"警告: 找不到SVS文件 {svs_path}，跳过patch {patch_path}")
            return None

        # Drop the previous slide's maps first so at most one slide's maps are held at a time.
        self._cache = None
        wsi_dims, highest_zoom_level = self._read_slide_geometry(svs_path)
        attn_arrays = {
            suffix: self._load_attention_array(path) for suffix, path in found_maps.items()
        }
        context = (wsi_dims, highest_zoom_level, attn_arrays)
        self._cache = (svs_basename, context)
        return context

    def query_patch_attention(self, patch_path, method='mean'):
        """
        Return ``{suffix: score}`` for one patch across all attention maps of its slide.

        Args:
            patch_path (str): Patch path. Its parent folder names the slide, and the first
                three ``_``-separated fields of the file stem are DeepZoom level, column and row.
            method (str): ``'mean'`` or ``'max'`` over the attention pixels the patch covers.

        Returns:
            dict: Scores are pixel values divided by 255 (i.e. in [0, 1] for 8-bit maps).
            Returns ``{}`` when the slide has no attention maps, when its SVS file is missing,
            or when any error occurs (the error is printed). A patch lying outside an attention
            map scores ``np.nan`` for that map.
        """
        try:
            svs_basename = os.path.basename(os.path.dirname(patch_path))
            context = self._get_slide_context(svs_basename, patch_path)
            if context is None:
                return {}
            wsi_dims, highest_zoom_level, attn_arrays = context

            patch_filename_only = os.path.basename(patch_path)
            parts = os.path.splitext(patch_filename_only)[0].split('_')
            level, col, row = int(parts[0]), int(parts[1]), int(parts[2])

            # Tile footprint in level-0 pixels (assumes each DeepZoom level halves the next one).
            downsample_factor = 2 ** (highest_zoom_level - level)
            x_start_wsi = col * self.tile_size * downsample_factor
            y_start_wsi = row * self.tile_size * downsample_factor
            x_end_wsi = x_start_wsi + (self.tile_size * downsample_factor)
            y_end_wsi = y_start_wsi + (self.tile_size * downsample_factor)

            final_scores = {}
            for suffix, attn_raw in attn_arrays.items():
                attn_height, attn_width = attn_raw.shape
                scale_factor_width = wsi_dims[0] / attn_width
                scale_factor_height = wsi_dims[1] / attn_height

                x_start_attn = int(x_start_wsi / scale_factor_width)
                y_start_attn = int(y_start_wsi / scale_factor_height)
                # Always cover at least one attention pixel: a tile narrower than one pixel
                # would otherwise give an empty slice although it lies inside the map.
                x_end_attn = max(int(x_end_wsi / scale_factor_width), x_start_attn + 1)
                y_end_attn = max(int(y_end_wsi / scale_factor_height), y_start_attn + 1)

                # Scale only the covered region; the cached map keeps its native dtype.
                attention_region = (
                    attn_raw[y_start_attn:y_end_attn, x_start_attn:x_end_attn].astype(np.float32)
                    / 255.0
                )

                if attention_region.size == 0:
                    score = np.nan  # patch lies outside this attention map
                elif method == 'mean':
                    score = np.mean(attention_region)
                elif method == 'max':
                    score = np.max(attention_region)
                else:
                    raise ValueError("Method must be 'mean' or 'max'")

                final_scores[suffix] = score

            return final_scores
        except Exception as e:
            print(f"处理 {patch_path} 时发生严重错误，将返回空结果: {e}")
            return {}


def process_csv_with_attention(input_csv, output_csv, patch_column_name, method='mean', **kwargs):
    """
    Read a CSV, compute attention scores for each row's patch and append them as new columns.

    One ``attention_<suffix>`` column is added per attention-map suffix returned by the
    querier; rows without a score for a suffix get NaN.

    Args:
        input_csv (str): Input CSV path.
        output_csv (str): Output CSV path.
        patch_column_name (str): Name of the column holding the patch paths.
        method (str): Aggregation passed to ``AttentionQuerier.query_patch_attention``
            (``'mean'`` or ``'max'``).
        **kwargs: Forwarded to ``AttentionQuerier`` (svs_root_dir, attn_root_dir, tile_size, overlap).

    Returns:
        None. The result is written to ``output_csv``; a missing input file or column is
        reported and the function returns without writing anything.
    """
    # 1. Read the input CSV.
    print(f"正在读取输入文件: {input_csv}")
    try:
        df = pd.read_csv(input_csv)
    except FileNotFoundError:
        print(f"错误: 输入文件未找到 -> {input_csv}")
        return

    if patch_column_name not in df.columns:
        print(f"错误: 在CSV文件中找不到指定的列 '{patch_column_name}'。")
        print(f"可用的列有: {list(df.columns)}")
        return

    # 2. Build the querier.
    print("正在初始化Attention Querier...")
    querier = AttentionQuerier(**kwargs)

    # 3. Score every patch. A querier may return None on error; treat that as "no scores".
    print(f"开始处理 {len(df)} 个patch...")
    results_list = [
        querier.query_patch_attention(patch_path, method=method) or {}
        for patch_path in tqdm(df[patch_column_name], desc="查询Attention分数")
    ]

    # 4. Turn the list of dicts into a DataFrame; pandas creates one column per key and
    #    fills missing keys with NaN. Use df's own index so the join below lines rows up
    #    even when df does not carry a default 0..n-1 index.
    print("正在整合注意力分数...")
    df_attention = pd.DataFrame(results_list, index=df.index)

    # 5. Prefix the new columns for readability.
    df_attention = df_attention.add_prefix('attention_')

    # 6. Attach the scores to the original rows (join aligns on the shared index).
    print("正在合并原始数据和新分数...")
    df_final = df.join(df_attention)

    # 7. Save the result.
    print(f"正在将最终结果保存到: {output_csv}")
    df_final.to_csv(output_csv, index=False)

    print("\n处理完成！")
    print(f"结果已保存在: {output_csv}")
    print("\n新文件的前5行预览:")
    print(df_final.head())


if __name__ == '__main__':
    # Whole-slide images and the per-marker attention-map folders.
    SVS_ROOT_DIR = '/data/ceiling/data/DLBCL/WSI/TCH'
    ATTN_ROOT_DIR = '/data/ceiling/workspace/DLBLC/visualization/AMIL_MOE_Sampling/TCH'

    # Per-patch feature table in, same table with attention_<suffix> columns out.
    INPUT_CSV_PATH = '/data/cjy/nuclei/datasets/DLBCL/TCH/features/cell_athena_sna.csv' 
    OUTPUT_CSV_PATH = '/data/cjy/nuclei/datasets/DLBCL/TCH/features/cell_athena_sna_attn.csv'

    # pandas names a header-less first column 'Unnamed: 0'; it is expected to hold the patch paths.
    PATCH_COLUMN_NAME = 'Unnamed: 0' 

    querier_options = {
        'svs_root_dir': SVS_ROOT_DIR,
        'attn_root_dir': ATTN_ROOT_DIR,
        'tile_size': 224,
    }
    process_csv_with_attention(
        INPUT_CSV_PATH,
        OUTPUT_CSV_PATH,
        PATCH_COLUMN_NAME,
        **querier_options
    )

# 特征与Target 相关性分析

## 相关性文件生成

In [None]:
import pandas as pd
import numpy as np
import os

def run_correlation_analysis(csv_path, attention_columns, top_n=20, output_dir='correlation_reports'):
    """
    Correlate each attention column with every numeric feature and save one report per column.

    For each attention column, the Pearson correlation with every numeric column that is
    not itself listed in ``attention_columns`` is computed with ``DataFrame.corrwith``.
    Undefined (NaN) correlations are dropped. The ``top_n`` strongest positive and negative
    correlations are printed. The full list, sorted from most positive to most negative, is
    written to ``<output_dir>/correlation_report_<attn_col>.csv`` (column
    ``pearson_correlation``). Attention columns missing from the CSV are skipped with a warning.

    Args:
        csv_path (str): CSV containing the feature and attention columns.
        attention_columns (list[str]): Names of the attention columns to analyse.
        top_n (int): Number of features printed for each sign.
        output_dir (str): Directory for the reports; created if it does not exist.
    """
    print(f"步骤 1: 正在加载数据从 '{csv_path}'...")
    df = pd.read_csv(csv_path)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"已创建输出目录: '{output_dir}'")

    # Every numeric column except the attention columns themselves is a candidate feature.
    attention_set = set(attention_columns)
    feature_columns = [
        col for col in df.select_dtypes(include=np.number).columns
        if col not in attention_set
    ]

    print(f"自动检测到 {len(feature_columns)} 个数值型特征列进行分析。")

    for attn_col in attention_columns:
        if attn_col not in df.columns:
            print(f"\n警告: 在CSV中找不到列 '{attn_col}'，已跳过。")
            continue

        print("\n" + "="*80)
        print(f"相关性分析报告: {attn_col.upper()}")
        print("="*80)

        print(f"正在计算 '{attn_col}' 与所有特征的相关性...")

        correlations = df[feature_columns].corrwith(df[attn_col])
        correlations = correlations.dropna().sort_values(ascending=False)

        print(f"\n[+] 与 '{attn_col}' 【正相关性最强】的Top {top_n}个特征:")
        print("    (这些特征值越高，Attention值也倾向于越高)")
        print(correlations.head(top_n).to_string())

        print(f"\n[-] 与 '{attn_col}' 【负相关性最强】的Top {top_n}个特征:")
        print("    (这些特征值越高，Attention值反而倾向于越低)")
        # The strongest negative correlations come first in ascending order.
        print(correlations.sort_values(ascending=True).head(top_n).to_string())

        output_filename = os.path.join(output_dir, f'correlation_report_{attn_col}.csv')
        correlations.to_frame(name='pearson_correlation').to_csv(output_filename)
        print(f"\n>>> 完整的相关性分析报告已保存到: '{output_filename}'")

    print("\n" + "="*80)
    print("所有分析任务完成！")

if __name__ == '__main__':
    # Feature table written by the attention-integration cell (with attention_* columns).
    INPUT_CSV_WITH_ATTN = '/data/cjy/nuclei/datasets/DLBCL/TCH/features/cell_athena_sna_attn.csv'
    ATTENTION_COLUMNS_TO_ANALYZE = [f'attention_{marker}' for marker in ('BCL2', 'BCL6', 'MYC')]
    OUTPUT_DIRECTORY = 'attention_report'

    run_correlation_analysis(
        INPUT_CSV_WITH_ATTN,
        ATTENTION_COLUMNS_TO_ANALYZE,
        top_n=20,  # number of top features printed
        output_dir=OUTPUT_DIRECTORY,
    )

## TOP20 皮尔森相关系数绘图

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Global matplotlib font family and font sizes used by the figures below.
plt.rcParams.update({
    "font.family": "Arial",
    "axes.unicode_minus": False,  # draw minus signs with an ASCII hyphen
    "font.size": 14,              # base font size
    "axes.labelsize": 16,         # axis titles
    "xtick.labelsize": 14,        # x tick labels
    "ytick.labelsize": 14,        # y tick labels
    "legend.fontsize": 14,        # legend text
    "axes.titlesize": 18,         # figure title
})

def visualize_focused_correlation_heatmap(
    report_dir,
    attention_sources,
    top_n=20,
    output_path='/data114_4/chenjy/DLBCL/model/SI-MIL/Top_Cellular_Features_Correlation.pdf',
):
    """
    Draw a heatmap of the cellular-level features most correlated with attention.

    Loads ``<report_dir>/correlation_report_<source>.csv`` for each source; missing files
    are skipped. Keeps only features whose names contain one of the cellular-feature
    keywords. Ranks those features by their mean absolute correlation across sources and
    plots the ``top_n`` highest as an annotated heatmap with a fixed -1..1 colour range.

    Args:
        report_dir (str): Directory containing the correlation reports.
        attention_sources (list[str]): Attention column names, e.g. ``'attention_BCL2'``.
        top_n (int): Number of features to plot.
        output_path (str): File the figure is saved to.

    Returns:
        None. When no report or no matching feature is found, the problem is printed and
        the function returns before plotting.
    """
    print("步骤 1: 正在加载并整合所有分析报告...")
    all_corrs = []

    for source in attention_sources:
        file_path = os.path.join(report_dir, f'correlation_report_{source}.csv')
        try:
            df = pd.read_csv(file_path, index_col=0)
        except FileNotFoundError:
            print(f"警告: 未找到报告文件 '{file_path}'，已跳过。")
            continue
        # Each report has one correlation column; name it after its source.
        df.columns = [source]
        all_corrs.append(df)

    if not all_corrs:
        print("错误: 未加载任何有效的报告文件，无法继续。")
        return

    corr_df = pd.concat(all_corrs, axis=1)
    print(f"成功加载并整合了 {len(corr_df)} 个特征的相关性数据。")

    print("\n步骤 2: 正在根据关键词筛选'细胞级'特征...")
    CELLULAR_FEATURES_KEYWORDS = [
        'number of', 'mean of their', 'std of their',
        'skew of their', 'kurtosis of their', 'Infiltration of'
    ]
    focused_feature_index = [
        feature_name for feature_name in corr_df.index
        if any(keyword in str(feature_name) for keyword in CELLULAR_FEATURES_KEYWORDS)
    ]
    # .copy() gives an independent frame, so adding the ranking column below does not
    # write into a view of corr_df (avoids pandas' SettingWithCopyWarning).
    corr_df_focused = corr_df.loc[focused_feature_index].copy()

    print(f"已筛选出 {len(corr_df_focused)} 个细胞级特征进行下一步分析。")
    if len(corr_df_focused) == 0:
        print("错误: 未找到任何匹配的细胞级特征，请检查关键词或报告中的列名。")
        return

    print(f"\n步骤 3: 正在从已筛选的特征中，找出Top {top_n}个最重要的...")
    corr_df_focused['importance_score'] = corr_df_focused.abs().mean(axis=1)
    top_features_df = corr_df_focused.sort_values(by='importance_score', ascending=False).head(top_n)
    plot_data = top_features_df.drop(columns=['importance_score'])

    # 'attention_BCL2' -> 'BCL2' for the x-axis labels.
    simplified_columns = {col: col.replace('attention_', '') for col in plot_data.columns}
    plot_data = plot_data.rename(columns=simplified_columns)

    print("\n步骤 4: 正在生成聚焦于细胞级指标的热力图...")
    plt.figure(figsize=(10, 12))
    sns.heatmap(
        plot_data,
        annot=True,
        cmap='coolwarm',
        vmin=-1, vmax=1, center=0,
        fmt='.2f',
        linewidths=.5,
        cbar_kws={'label': 'Pearson Correlation Coefficient'}
    )

    plt.title(f'Top {top_n} Cellular Features Correlated with Attention', pad=10, weight='bold')
    plt.ylabel('Direct Cellular & Interaction Features')
    plt.xlabel('Attention Source')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    print("\n步骤 5: 图表已保存并显示。")
    plt.show()

if __name__ == '__main__':
    # Directory of correlation_report_<source>.csv files written by run_correlation_analysis.
    REPORT_DIRECTORY = 'attention_report'
    ATTENTION_SOURCES_TO_COMPARE = ['attention_' + marker for marker in ('BCL2', 'BCL6', 'MYC')]

    heatmap_options = dict(
        report_dir=REPORT_DIRECTORY,
        attention_sources=ATTENTION_SOURCES_TO_COMPARE,
        top_n=20,
    )
    visualize_focused_correlation_heatmap(**heatmap_options)

# 不同Target下 细胞特征重要性分析
## 相关文件生成

In [None]:
import pandas as pd
import numpy as np
import os

def run_and_save_individual_analysis_with_cohens_d(csv_path, attention_columns, quantile_threshold=0.9, top_n=20, output_dir='analysis_reports_cohens_d'):
    """
    Compare high- vs. baseline-attention rows for each attention column using Cohen's d.

    For each attention column, rows at or above its ``quantile_threshold`` quantile form the
    "high" group and all remaining rows the "baseline" group. Rows with NaN in the attention
    column or in any analysed feature are dropped first. The analysed features are the
    numeric/boolean columns not listed in ``attention_columns``.

    For every feature the report holds:
      * Cohen's d: difference of group means over the pooled standard deviation; NaN when
        both groups have zero variance.
      * Both group means.
      * ``abs_cohens_d``: the report is sorted by this column, descending.

    Each report is saved to ``<output_dir>/analysis_report_<attn_col>.csv``.

    Args:
        csv_path (str): CSV with feature and attention columns.
        attention_columns (list[str]): Attention columns to analyse; missing ones are skipped.
        quantile_threshold (float): Quantile of the attention column that starts the high group.
        top_n (int): Number of features printed per attention column.
        output_dir (str): Directory for the reports; created if it does not exist.

    Returns:
        None. A missing input file is reported and the function returns.
    """
    # --- 1. Load data and prepare the output directory ---
    print(f"步骤 1: 正在加载数据从 '{csv_path}'...")
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"错误: 文件未找到 -> {csv_path}")
        return

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"已创建输出目录: '{output_dir}'")

    # Only numeric/boolean columns are analysed, so only they take part in NaN filtering;
    # a missing value in e.g. a text column must not drop the whole row.
    attention_set = set(attention_columns)
    feature_columns = [
        col for col in df.select_dtypes(include=[np.number, 'bool']).columns
        if col not in attention_set
    ]

    # --- 2. Analyse each attention column ---
    for attn_col in attention_columns:
        if attn_col not in df.columns:
            print(f"\n警告: 在CSV中找不到列 '{attn_col}'，已跳过。")
            continue

        print("\n" + "="*80)
        print(f"               分析对象: {attn_col.upper()}")
        print("="*80)

        df_cleaned = df[[attn_col] + feature_columns].dropna()

        # --- Split into groups ---
        threshold_value = df_cleaned[attn_col].quantile(quantile_threshold)
        high_attention_mask = df_cleaned[attn_col] >= threshold_value

        high_group_df = df_cleaned.loc[high_attention_mask, feature_columns]
        baseline_group_df = df_cleaned.loc[~high_attention_mask, feature_columns]

        n_high, n_base = len(high_group_df), len(baseline_group_df)
        print(f"定义 '高分组' ({n_high}个样本) vs. '基线组' ({n_base}个样本)。")

        if n_high < 2 or n_base < 2:
            print("警告: 某个分组的样本太少，无法进行有效分析，已跳过。")
            continue

        # --- Statistics ---
        mean_high = high_group_df.mean(numeric_only=True)
        mean_base = baseline_group_df.mean(numeric_only=True)
        std_high = high_group_df.std(numeric_only=True)
        std_base = baseline_group_df.std(numeric_only=True)

        # Pooled standard deviation (sample std, ddof=1, weighted by n - 1).
        pooled_std = np.sqrt(((n_high - 1) * std_high**2 + (n_base - 1) * std_base**2) / (n_high + n_base - 2))

        # Cohen's d. It is undefined when the pooled std is zero; report NaN there instead
        # of dividing by a tiny epsilon, which turned any mean difference into ~±1e8.
        cohens_d = (mean_high - mean_base) / pooled_std.replace(0, np.nan)

        # --- Build and sort the result table ---
        result_df = pd.DataFrame({
            'cohens_d': cohens_d,
            'mean_in_high_group': mean_high,
            'mean_in_baseline_group': mean_base,
        })

        result_df['abs_cohens_d'] = result_df['cohens_d'].abs()
        result_df = result_df.sort_values(by='abs_cohens_d', ascending=False)

        # --- Print and save ---
        print(f"\n[!] 与 '{attn_col}' 相关的【效应量最大】的Top {top_n}个特征 (基于科恩d值):")
        print(result_df.drop(columns=['abs_cohens_d']).head(top_n).to_string())

        output_filename = os.path.join(output_dir, f'analysis_report_{attn_col}.csv')
        result_df.to_csv(output_filename)
        print(f"\n>>> 详细分析报告已成功保存到: '{output_filename}'")

    print("\n" + "="*80)
    print("所有分析任务完成！")


if __name__ == '__main__':
    INPUT_CSV_WITH_ATTN = '/data/cjy/nuclei/datasets/DLBCL/TCH/features/cell_athena_sna_attn.csv'
    ATTENTION_COLUMNS_TO_ANALYZE = [f'attention_{marker}' for marker in ('BCL2', 'BCL6', 'MYC')]
    # NOTE(review): the plotting cell reads its reports from
    # '/data114_4/chenjy/DLBCL/model/SI-MIL/attention_report_cohens_d'; confirm both paths
    # refer to the same storage.
    OUTPUT_DIRECTORY = '/data/cjy/nuclei/models/SI-MIL/attention_report_cohens_d/'

    # Rows at or above the 0.9 quantile of each attention column form the "high" group.
    run_and_save_individual_analysis_with_cohens_d(
        INPUT_CSV_WITH_ATTN,
        ATTENTION_COLUMNS_TO_ANALYZE,
        quantile_threshold=0.9,
        top_n=20,
        output_dir=OUTPUT_DIRECTORY,
    )

## TOP15 重要特征绘图

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

# --- Global font family: Arial when installed, otherwise a generic sans-serif ---
# Assigning an unknown family to rcParams does not raise (matplotlib only warns later,
# when text is drawn), so the font is looked up explicitly to decide which branch applies.
plt.rcParams['axes.unicode_minus'] = False  # draw minus signs with an ASCII hyphen
try:
    from matplotlib import font_manager
    font_manager.findfont(font_manager.FontProperties(family='Arial'), fallback_to_default=False)
    plt.rcParams['font.family'] = 'Arial'
    print("✅ 全局字体已成功设置为 Arial。")
except Exception as e:
    print(f"❌ 设置全局字体为 Arial 失败。错误: {e}")
    plt.rcParams['font.family'] = 'sans-serif'

# ---- Global font sizes ----
plt.rcParams.update({
    'font.size': 14,          # base font size
    'axes.labelsize': 16,     # axis titles
    'xtick.labelsize': 14,    # x tick labels
    'ytick.labelsize': 14,    # y tick labels
    'legend.fontsize': 14     # legend text
})


def show_focused_cohens_d_plot(
    report_dir,
    attention_sources,
    top_n=20,
    color_map=None,
    output_path='/data114_4/chenjy/DLBCL/model/SI-MIL/Top_Cellular_Features.pdf',
):
    """
    Plot the cellular-level features with the largest Cohen's d as a grouped horizontal bar chart.

    Steps:
      1. Read ``<report_dir>/analysis_report_<source>.csv`` for each source and keep its
         ``cohens_d`` column; missing files are skipped.
      2. Restrict to features whose names contain one of the cellular-feature keywords.
      3. Rank them by mean absolute d across sources.
      4. Draw the ``top_n`` highest, one bar per source, with the highest-ranked feature at
         the top.

    Font family and sizes come from the global rcParams.

    Args:
        report_dir (str): Directory containing the Cohen's d reports.
        attention_sources (list[str]): Attention column names, e.g. ``'attention_BCL2'``.
        top_n (int): Number of features to plot.
        color_map (dict | None): Bar colour per source name with the ``attention_`` prefix
            removed (e.g. ``'BCL2'``); sources without an entry use matplotlib's colour cycle.
        output_path (str): File the figure is saved to.

    Returns:
        None. When no report or no matching feature is found, the problem is printed and
        the function returns before plotting.
    """
    # --- 1. Load and merge all reports ---
    print("步骤 1: 正在加载基于科恩d值的分析报告...")
    all_metrics = []

    for source in attention_sources:
        file_path = os.path.join(report_dir, f'analysis_report_{source}.csv')
        try:
            df = pd.read_csv(file_path, index_col=0)
        except FileNotFoundError:
            print(f"警告: 未找到报告文件 '{file_path}'，已跳过。")
            continue
        all_metrics.append(df[['cohens_d']].rename(columns={'cohens_d': source}))

    if not all_metrics:
        print("错误: 未加载任何有效的报告文件，无法继续。")
        return

    metric_df = pd.concat(all_metrics, axis=1)

    # --- 2. Keep only cellular-level features ---
    print("\n步骤 2: 正在筛选'细胞级'特征...")
    CELLULAR_FEATURES_KEYWORDS = [
        'number of', 'mean of their', 'std of their',
        'skew of their', 'kurtosis of their', 'Infiltration of'
    ]
    focused_feature_index = [
        name for name in metric_df.index
        if any(kw in str(name) for kw in CELLULAR_FEATURES_KEYWORDS)
    ]
    # .copy() so the ranking column below is added to an independent frame
    # (avoids pandas' SettingWithCopyWarning).
    metric_df_focused = metric_df.loc[focused_feature_index].copy()

    print(f"已筛选出 {len(metric_df_focused)} 个细胞级特征进行下一步分析。")
    if len(metric_df_focused) == 0:
        print("错误: 未找到任何匹配的细胞级特征，请检查关键词或报告中的列名。")
        return

    # --- 3. Pick the top-N ---
    print(f"\n步骤 3: 正在从已筛选的特征中，找出Top {top_n}个最重要的...")
    metric_df_focused['importance_score'] = metric_df_focused.abs().mean(axis=1)
    top_features_df = metric_df_focused.sort_values(by='importance_score', ascending=False).head(top_n)
    # Reverse so the highest-ranked feature ends up at the top of the horizontal bar chart.
    plot_data = top_features_df.drop(columns=['importance_score']).iloc[::-1]

    # 'attention_BCL2' -> 'BCL2' for the legend and colour lookup.
    simplified_columns = {col: col.replace('attention_', '') for col in plot_data.columns}
    plot_data = plot_data.rename(columns=simplified_columns)

    # --- 4. Grouped horizontal bars ---
    print("\n步骤 4: 正在生成可视化图表...")
    n_sources = len(plot_data.columns)
    bar_width = 0.8 / n_sources
    index = np.arange(len(plot_data))

    fig, ax = plt.subplots(figsize=(10, 9))

    for i, source_name in enumerate(plot_data.columns):
        # Centre the group of bars on each feature's y position.
        bar_positions = index + (i - (n_sources - 1) / 2) * bar_width
        values = plot_data[source_name]
        bar_color = color_map.get(source_name) if color_map else None
        ax.barh(bar_positions, values, height=bar_width, label=source_name, alpha=0.85, color=bar_color)

    # --- 5. Labels and styling ---
    ax.set_xlabel("Standardized Mean Difference (Cohen's d)")
    ax.set_ylabel('Direct Cellular & Interaction Features')
    ax.set_title(f'Top {top_n} Cellular Features with Highest Effect Size', pad=10, weight='bold')

    ax.set_yticks(index)
    ax.set_yticklabels(plot_data.index)

    ax.axvline(0, color='black', linewidth=1.0, linestyle='-')
    ax.grid(axis='x', linestyle=':', alpha=0.7)
    ax.legend(title="Attention Source")

    plt.tight_layout()

    # --- 6. Save ---
    plt.savefig(output_path, dpi=300)
    print(f"\n步骤 5: 图表已保存至 '{output_path}'")


# ==============================================================================
#                                Usage example
# ==============================================================================
if __name__ == '__main__':
    REPORT_DIRECTORY = '/data114_4/chenjy/DLBCL/model/SI-MIL/attention_report_cohens_d' 
    ATTENTION_SOURCES_TO_COMPARE = ['attention_BCL2', 'attention_BCL6', 'attention_MYC']

    # Bar colours keyed by the marker name left after the 'attention_' prefix is stripped.
    COLOR_MAPPING = dict(BCL2='#78C8BB', MYC='#F4E07D', BCL6='#EB9797')

    show_focused_cohens_d_plot(
        REPORT_DIRECTORY,
        ATTENTION_SOURCES_TO_COMPARE,
        top_n=15,
        color_map=COLOR_MAPPING,
    )
