In [10]:
"""α、β多样性分析"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import mannwhitneyu, ttest_ind
from skbio.diversity import alpha_diversity, beta_diversity
from skbio.stats.distance import permanova
from skbio.stats.ordination import pcoa
from matplotlib.patches import Ellipse
import os
import warnings
warnings.filterwarnings('ignore')

class MicrobiomeDiversityAnalyzer:
    """
    Microbiome diversity analyzer (alpha and beta diversity, Control vs. MI).

    Typical workflow::

        analyzer = MicrobiomeDiversityAnalyzer(output_dir)
        analyzer.load_data(data_filtered, sample_metadata)
        analyzer.run_complete_analysis()

    ``data_filtered`` maps a domain name (e.g. 'Bacteria') to a DataFrame of
    features (rows) x samples (columns). ``sample_metadata`` is indexed by
    sample ID and must contain a 'Group' column; in the statistics every
    sample whose group is not 'Control' is treated as the MI group.
    """

    # Metrics that SciPy defines on boolean vectors; abundance tables must be
    # reduced to presence/absence before computing them, otherwise SciPy's
    # non-boolean fallback treats any two unequal abundances as a mismatch.
    _PRESENCE_ABSENCE_METRICS = frozenset({
        'jaccard', 'dice', 'kulczynski1', 'rogerstanimoto',
        'russellrao', 'sokalmichener', 'sokalsneath', 'yule',
    })

    def __init__(self, output_dir="diversity_results"):
        """
        Initialize the analyzer and create the output directory tree.

        Parameters:
        output_dir: directory for results; 'alpha_diversity' and
                    'beta_diversity' subdirectories are created inside it.
        """
        self.pcoa_results = None
        self.beta_stats = None
        self.beta_results = None
        self.alpha_stats = None
        self.alpha_results = None
        self.domains = None
        self.sample_metadata = None
        self.data_filtered = None
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Sub-directories for the two kinds of figures
        self.alpha_dir = os.path.join(output_dir, "alpha_diversity")
        self.beta_dir = os.path.join(output_dir, "beta_diversity")
        os.makedirs(self.alpha_dir, exist_ok=True)
        os.makedirs(self.beta_dir, exist_ok=True)

    def load_data(self, data_filtered, sample_metadata):
        """
        Store the abundance tables and sample metadata and print a summary.

        Parameters:
        data_filtered: dict mapping domain name -> DataFrame (features x samples)
        sample_metadata: DataFrame indexed by sample ID with a 'Group' column
        """
        self.data_filtered = data_filtered
        self.sample_metadata = sample_metadata
        self.domains = list(data_filtered.keys())

        # Convert numpy scalars to plain ints so the summary prints cleanly.
        group_counts = {k: int(v) for k, v in sample_metadata['Group'].value_counts().items()}

        print("✅ 数据加载成功")
        print(f"   微生物类别: {self.domains}")
        print(f"   样本数量: {len(sample_metadata)}")
        print(f"   分组分布: {group_counts}")

        # Later steps look every sample up in sample_metadata; warn early about gaps.
        for domain, df in data_filtered.items():
            missing = df.columns.difference(sample_metadata.index)
            if len(missing) > 0:
                print(f"   ⚠️ {domain}: {len(missing)} 个样本缺少分组信息: {list(missing)[:5]}")

    def validate_data(self):
        """
        Print basic quality metrics per domain and convert each table to
        relative abundance (columns summing to 1) unless every sample sum is
        already within [0.9, 1.1].
        """
        print("\n=== 数据质量验证 ===")

        # Iterate over a snapshot because tables may be replaced below.
        for domain, df in list(self.data_filtered.items()):
            print(f"\n{domain}:")
            print(f"  特征数量: {df.shape[0]}")
            print(f"  样本数量: {df.shape[1]}")
            print(f"  数据范围: {df.min().min():.4f} - {df.max().max():.4f}")

            # Fraction of zero entries in the whole table
            zero_ratio = (df == 0).sum().sum() / df.size
            print(f"  零值比例: {zero_ratio:.2%}")

            # Relative-abundance check on per-sample totals
            sample_sums = df.sum(axis=0)
            if sample_sums.between(0.9, 1.1).all():
                print("  ✅ 相对丰度数据")
            else:
                print("  ⚠️ 非相对丰度数据，将进行转换")
                empty = sample_sums.index[sample_sums == 0]
                if len(empty) > 0:
                    print(f"  ⚠️ {len(empty)} 个样本总丰度为0，转换后为NaN: {list(empty)[:5]}")
                self.data_filtered[domain] = self._convert_to_relative_abundance(df)

    @staticmethod
    def _convert_to_relative_abundance(df):
        """Divide each column (sample) by its total; all-zero samples become NaN."""
        return df / df.sum(axis=0).replace(0, np.nan)

    @staticmethod
    def _significance_label(p_value):
        """Map a p-value to '***' (<0.001), '**' (<0.01), '*' (<0.05) or 'ns' (otherwise, incl. NaN)."""
        for threshold, label in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
            if p_value < threshold:
                return label
        return 'ns'

    @staticmethod
    def _compare_groups(control_values, mi_values, min_n=3, sig_level=0.05):
        """
        Summarize two groups of values and test their difference.

        A two-sample t-test is used when both groups have at least ``min_n``
        values and |skewness| < 2 in each; otherwise Mann-Whitney U. With
        fewer than ``min_n`` values in either group no test is run and the
        p-value is NaN (not 1.0, which would look like a real test result).

        Returns:
        dict with control/mi mean and std, p_value, test_used, significant.
        """
        if len(control_values) >= min_n and len(mi_values) >= min_n:
            # Crude normality screen: strongly skewed data uses the rank-based test.
            roughly_normal = abs(control_values.skew()) < 2 and abs(mi_values.skew()) < 2
            if roughly_normal:
                _, p_value = ttest_ind(control_values, mi_values)
                test_used = "t-test"
            else:
                _, p_value = mannwhitneyu(control_values, mi_values)
                test_used = "Mann-Whitney U"
        else:
            test_used = "样本量不足"
            p_value = np.nan

        return {
            'control_mean': control_values.mean(),
            'control_std': control_values.std(),
            'mi_mean': mi_values.mean(),
            'mi_std': mi_values.std(),
            'p_value': p_value,
            'test_used': test_used,
            'significant': bool(p_value < sig_level),
        }

    def calculate_alpha_diversity(self, metrics=None):
        """
        Compute alpha-diversity indices per domain and compare Control vs. MI.

        Parameters:
        metrics: list of scikit-bio alpha metric names; defaults to
                 ['shannon', 'simpson', 'observed_otus', 'pielou_e']

        Returns:
        (alpha_results, alpha_stats): dicts keyed by domain, then metric.
        On error the metric's stats entry is None (and its result entry is
        absent when the index itself could not be computed).
        """
        if metrics is None:
            metrics = ['shannon', 'simpson', 'observed_otus', 'pielou_e']

        print("\n=== α多样性分析 ===")

        self.alpha_results = {}
        self.alpha_stats = {}

        # Group membership does not depend on domain or metric: resolve it once.
        control_mask = self.sample_metadata['Group'] == 'Control'
        control_samples = self.sample_metadata.index[control_mask]
        mi_samples = self.sample_metadata.index[~control_mask]

        for domain, df in self.data_filtered.items():
            print(f"\n分析 {domain}...")

            domain_results = {}
            domain_stats = {}

            for metric in metrics:
                try:
                    # scikit-bio expects samples as rows
                    alpha_div = alpha_diversity(metric, df.T.values, ids=df.T.index)
                    domain_results[metric] = alpha_div

                    # reindex tolerates metadata samples absent from this table
                    stats = self._compare_groups(
                        alpha_div.reindex(control_samples).dropna(),
                        alpha_div.reindex(mi_samples).dropna(),
                    )
                    domain_stats[metric] = stats

                    print(f"  {metric}: p={stats['p_value']:.4f} ({stats['test_used']})")

                except Exception as e:
                    print(f"  {metric}计算失败: {e}")
                    domain_stats[metric] = None

            self.alpha_results[domain] = domain_results
            self.alpha_stats[domain] = domain_stats

        return self.alpha_results, self.alpha_stats

    def plot_alpha_diversity(self, figsize=(15, 10)):
        """
        Draw one violin plot per alpha metric for each domain (English labels)
        and save one PNG per domain into ``alpha_dir``. Significance stars are
        added when the group comparison has p < 0.05.
        """
        print("\n=== 生成α多样性图表 ===")

        # The attribute always exists (initialized to None), so test its value.
        if self.alpha_results is None:
            print("请先计算α多样性指数")
            return

        for domain in self.domains:
            domain_results = self.alpha_results.get(domain)
            if not domain_results:
                continue

            metrics = list(domain_results)
            n_metrics = len(metrics)
            n_cols = min(3, n_metrics)
            n_rows = (n_metrics + n_cols - 1) // n_cols
            # squeeze=False always yields a 2-D array, so no single-axes special case
            fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
            axes = axes.flatten()

            for ax, metric in zip(axes, metrics):
                alpha_values = domain_results[metric]
                stats = self.alpha_stats.get(domain, {}).get(metric)

                df_plot = pd.DataFrame({
                    'Group': self.sample_metadata.loc[alpha_values.index, 'Group'].to_numpy(),
                    'Value': alpha_values.to_numpy(),
                    'Metric': metric.upper(),
                })

                sns.violinplot(data=df_plot, x='Group', y='Value', ax=ax,
                               palette={'Control': 'lightblue', 'MI': 'lightcoral'},
                               inner='box', alpha=0.7)

                if stats and stats['p_value'] < 0.05:
                    y_max = df_plot['Value'].max()
                    ax.text(0.5, y_max * 1.05, self._significance_label(stats['p_value']),
                            ha='center', va='bottom', fontsize=14, fontweight='bold')

                ax.set_title(f'{domain}\n{metric.upper()} Index', fontsize=12, fontweight='bold')
                ax.set_xlabel('')
                ax.set_ylabel('Alpha Diversity Value', fontsize=10)

            # Hide grid cells that received no metric
            for ax in axes[n_metrics:]:
                ax.set_visible(False)

            fig.tight_layout()
            fig.savefig(os.path.join(self.alpha_dir, f'alpha_diversity_{domain}.png'),
                        dpi=300, bbox_inches='tight')
            plt.close(fig)

            print(f"✅ 已保存 {domain} 的α多样性图表")

    @classmethod
    def _beta_input(cls, df, metric):
        """
        Return the samples x features matrix to pass to ``beta_diversity``.

        Binary metrics (e.g. 'jaccard') receive presence/absence (0/1);
        all other metrics receive the abundances unchanged.
        """
        matrix = df.T.values
        if metric in cls._PRESENCE_ABSENCE_METRICS:
            matrix = (matrix > 0).astype(int)
        return matrix

    @staticmethod
    def _permanova_r_squared(pseudo_f, n_samples, n_groups):
        """
        Convert a one-way PERMANOVA pseudo-F into R² (variation explained by grouping).

        pseudo-F = (SS_between / (a-1)) / (SS_within / (N-a)) and
        R² = SS_between / SS_total, hence R² = F(a-1) / (F(a-1) + (N-a)).
        Returns NaN when the degrees of freedom make R² undefined.
        """
        df_between = n_groups - 1
        df_within = n_samples - n_groups
        if df_between <= 0 or df_within <= 0 or np.isnan(pseudo_f):
            return np.nan
        if np.isinf(pseudo_f):
            # zero within-group variation: grouping explains everything
            return 1.0
        explained = pseudo_f * df_between
        return float(explained / (explained + df_within))

    def calculate_beta_diversity(self, metrics=None, permutations=999):
        """
        Compute beta diversity, PERMANOVA and PCoA for every domain.

        Parameters:
        metrics: list of distance metrics; defaults to
                 ['braycurtis', 'jaccard', 'euclidean']. Binary metrics such
                 as 'jaccard' are computed on presence/absence.
        permutations: number of PERMANOVA permutations (skbio default 999)

        Returns:
        (beta_results, beta_stats, pcoa_results): dicts keyed by domain, then
        metric. Stats hold the pseudo-F, the derived R², the p-value and the
        sample count; failed metrics map to None in stats and PCoA results.
        """
        if metrics is None:
            metrics = ['braycurtis', 'jaccard', 'euclidean']

        print("\n=== β多样性分析 ===")

        self.beta_results = {}
        self.beta_stats = {}
        self.pcoa_results = {}

        for domain, df in self.data_filtered.items():
            print(f"\n分析 {domain}...")

            domain_results = {}
            domain_stats = {}
            domain_pcoa = {}

            for metric in metrics:
                try:
                    dm = beta_diversity(metric, self._beta_input(df, metric), ids=df.T.index)
                    domain_results[metric] = dm

                    # Align the grouping vector to the distance-matrix ID order explicitly;
                    # an array grouping is interpreted positionally.
                    grouping = self.sample_metadata.loc[list(dm.ids), 'Group'].to_numpy()
                    permanova_result = permanova(dm, grouping, permutations=permutations)
                    # skbio's PERMANOVA "test statistic" is the pseudo-F, not R²
                    pseudo_f = permanova_result['test statistic']
                    p_value = permanova_result['p-value']
                    r_squared = self._permanova_r_squared(
                        pseudo_f, len(grouping), len(np.unique(grouping)))

                    pcoa_result = pcoa(dm)
                    variance_explained = pcoa_result.proportion_explained
                    # positional access that works for both Series and arrays
                    var_values = np.asarray(variance_explained)

                    domain_stats[metric] = {
                        'r_squared': r_squared,
                        'pseudo_f': pseudo_f,
                        'p_value': p_value,
                        'significant': p_value < 0.05,
                        'n_samples': len(dm.ids),
                    }

                    domain_pcoa[metric] = {
                        'pcoa_result': pcoa_result,
                        'variance_explained': variance_explained,
                        'pc1_var': var_values[0] * 100,
                        'pc2_var': var_values[1] * 100,
                    }

                    print(f"  {metric}: pseudo-F={pseudo_f:.4f}, R²={r_squared:.4f}, "
                          f"p={p_value:.4f}")

                except Exception as e:
                    print(f"  {metric}计算失败: {e}")
                    domain_stats[metric] = None
                    domain_pcoa[metric] = None

            self.beta_results[domain] = domain_results
            self.beta_stats[domain] = domain_stats
            self.pcoa_results[domain] = domain_pcoa

        return self.beta_results, self.beta_stats, self.pcoa_results

    def plot_beta_diversity(self):
        """
        Draw a PCoA scatter (PC1 vs PC2) per domain and metric, with
        dispersion ellipses for groups of more than 2 samples and the
        PERMANOVA result in the title; save each figure into ``beta_dir``.
        """
        print("\n=== 生成β多样性图表 ===")

        # The attribute always exists (initialized to None), so test its value.
        if self.beta_results is None:
            print("请先计算β多样性")
            return

        colors = {'Control': '#3498db', 'MI': '#e74c3c'}
        markers = {'Control': 'o', 'MI': 's'}

        for domain in self.domains:
            if not self.beta_results.get(domain):
                continue

            for metric, dm in self.beta_results[domain].items():
                pcoa_info = self.pcoa_results.get(domain, {}).get(metric)
                stats = self.beta_stats.get(domain, {}).get(metric)
                if pcoa_info is None or stats is None:
                    continue

                pcoa_result = pcoa_info['pcoa_result']
                sample_ids = list(dm.ids)

                df_plot = pcoa_result.samples.loc[sample_ids, ['PC1', 'PC2']].reset_index(drop=True)
                df_plot['Group'] = self.sample_metadata.loc[sample_ids, 'Group'].to_numpy()
                df_plot['Sample'] = sample_ids

                fig, ax = plt.subplots(figsize=(8, 6))

                for group in ['Control', 'MI']:
                    group_data = df_plot[df_plot['Group'] == group]
                    ax.scatter(group_data['PC1'], group_data['PC2'],
                               c=colors[group], marker=markers[group],
                               label=group, s=60, alpha=0.7, edgecolors='white', linewidth=0.5)

                for group in ['Control', 'MI']:
                    group_data = df_plot[df_plot['Group'] == group]
                    if len(group_data) > 2:
                        self._add_confidence_ellipse(ax, group_data, colors[group])

                ax.legend()
                ax.set_xlabel(f'PC1 ({pcoa_info["pc1_var"]:.1f}%)', fontsize=12)
                ax.set_ylabel(f'PC2 ({pcoa_info["pc2_var"]:.1f}%)', fontsize=12)

                significance = self._significance_label(stats['p_value'])
                ax.set_title(f'{domain} - {metric.upper()} PCoA\n'
                             f'PERMANOVA: R²={stats["r_squared"]:.3f}, '
                             f'F={stats.get("pseudo_f", np.nan):.2f}, '
                             f'p={stats["p_value"]:.3f} ({significance})',
                             fontweight='bold')

                ax.grid(True, alpha=0.3)
                fig.tight_layout()

                fig.savefig(os.path.join(self.beta_dir, f'pcoa_{domain}_{metric}.png'),
                            dpi=300, bbox_inches='tight')
                plt.close(fig)

                print(f"✅ 已保存 {domain}_{metric} 的PCoA图")

    @staticmethod
    def _add_confidence_ellipse(ax, data, color, alpha=0.2, n_std=2.0):
        """
        Shade the covariance ellipse of a group's PC1/PC2 scores on ``ax``.

        Semi-axes are ``n_std`` standard deviations along the principal
        directions of the 2x2 covariance (default 2, as before). This is a
        dispersion ellipse, not a formal confidence region. Degenerate
        covariances (NaN/inf) are skipped silently.
        """
        cov = np.cov(data['PC1'], data['PC2'])
        if not np.all(np.isfinite(cov)):
            return

        try:
            # Covariance is symmetric: eigh gives real eigenpairs.
            eigvals, eigvecs = np.linalg.eigh(cov)
        except np.linalg.LinAlgError:
            return

        # Clip tiny negative eigenvalues produced by rounding before sqrt.
        radii = np.sqrt(np.clip(eigvals, 0, None))
        angle = np.degrees(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))

        ell = Ellipse(xy=(data['PC1'].mean(), data['PC2'].mean()),
                      width=2 * n_std * radii[0], height=2 * n_std * radii[1],
                      angle=angle, color=color, alpha=alpha)
        ax.add_patch(ell)

    def generate_summary_report(self):
        """
        Write an Excel workbook summarizing alpha and beta statistics plus
        per-sample alpha values. Steps that were not run are skipped; when
        nothing is available no file is written.
        """
        print("\n=== 生成分析报告 ===")

        # Attributes are None until the corresponding step has run.
        alpha_stats = self.alpha_stats or {}
        alpha_results = self.alpha_results or {}
        beta_stats = self.beta_stats or {}

        alpha_summary = []
        beta_summary = []
        detailed_data = []

        for domain in self.domains:
            for metric, stats in alpha_stats.get(domain, {}).items():
                if stats:
                    alpha_summary.append({
                        'Domain': domain,
                        'Metric': metric,
                        'Control_Mean': stats['control_mean'],
                        'Control_STD': stats['control_std'],
                        'MI_Mean': stats['mi_mean'],
                        'MI_STD': stats['mi_std'],
                        'P_Value': stats['p_value'],
                        'Test_Used': stats['test_used'],
                        'Significant': stats['significant'],
                    })

        for domain in self.domains:
            for metric, stats in beta_stats.get(domain, {}).items():
                if stats:
                    beta_summary.append({
                        'Domain': domain,
                        'Metric': metric,
                        'R_Squared': stats['r_squared'],
                        'Pseudo_F': stats.get('pseudo_f', np.nan),
                        'P_Value': stats['p_value'],
                        'Significant': stats['significant'],
                        'N_Samples': stats['n_samples'],
                    })

        for domain in self.domains:
            for metric, alpha_values in alpha_results.get(domain, {}).items():
                for sample, value in alpha_values.items():
                    detailed_data.append({
                        'Domain': domain,
                        'Sample': sample,
                        'Metric': metric,
                        'Value': value,
                        'Group': self.sample_metadata.loc[sample, 'Group'],
                    })

        sheets = {
            'Alpha_Diversity': alpha_summary,
            'Beta_Diversity': beta_summary,
            'Detailed_Alpha_Data': detailed_data,
        }
        sheets = {name: rows for name, rows in sheets.items() if rows}
        if not sheets:
            # An ExcelWriter with no sheets raises on close; skip cleanly instead.
            print("⚠️ 没有可写入的结果，未生成报告")
            return

        report_path = os.path.join(self.output_dir, 'diversity_analysis_summary.xlsx')
        with pd.ExcelWriter(report_path) as writer:
            for name, rows in sheets.items():
                pd.DataFrame(rows).to_excel(writer, sheet_name=name, index=False)

        print(f"✅ 分析报告已保存至: {report_path}")

    def run_complete_analysis(self):
        """
        Run the full pipeline: validation, alpha diversity (+ plots),
        beta diversity (+ plots) and the Excel summary report.
        """
        print("开始完整的微生物组多样性分析...")

        self.validate_data()
        self.calculate_alpha_diversity()
        self.plot_alpha_diversity()
        self.calculate_beta_diversity()
        self.plot_beta_diversity()
        self.generate_summary_report()

        print("\n🎉 多样性分析完成！")
        print(f"   结果保存在: {self.output_dir}")
        print(f"   α多样性图表: {self.alpha_dir}")
        print(f"   β多样性图表: {self.beta_dir}")

def load_your_data(filtered_data_path='E:/Python/MI_Analysis/data_figures/filtered_data/filtered_data.xlsx',
                   reference_sheet='Bacteria'):
    """
    Read the filtered abundance workbook and build the sample metadata.

    Every sheet becomes one entry of the returned dict (first column used as
    the row index). All tables are restricted to the samples present in
    every sheet, ordered as in ``reference_sheet``. A sample is labelled 'MI'
    when its ID contains the substring 'MI', otherwise 'Control'.

    Parameters:
    filtered_data_path: path to the Excel workbook
    reference_sheet: sheet whose column order defines the sample order

    Returns:
    (data_filtered, metadata): dict of DataFrames, and a DataFrame indexed by
    'SampleID' with a single 'Group' column.

    Raises:
    FileNotFoundError if the workbook does not exist; KeyError if
    ``reference_sheet`` is not one of its sheets.
    """
    print("正在读取已过滤的数据...")
    with pd.ExcelFile(filtered_data_path) as xls:
        data_filtered = {
            sheet_name: pd.read_excel(xls, sheet_name=sheet_name, index_col=0)
            for sheet_name in xls.sheet_names
        }

    # Keep only samples present in every sheet (previously any sample missing
    # from a non-reference sheet raised KeyError).
    reference_columns = data_filtered[reference_sheet].columns
    shared = set.intersection(*(set(df.columns) for df in data_filtered.values()))
    common_samples = [s for s in reference_columns if s in shared]
    dropped = len(reference_columns) - len(common_samples)
    if dropped:
        print(f"⚠️ {dropped} 个样本不在所有表中，已剔除")
    data_filtered = {domain: df[common_samples] for domain, df in data_filtered.items()}

    # NOTE(review): substring match on 'MI' — confirm no Control sample ID contains 'MI'.
    metadata = pd.DataFrame(
        {'Group': ['MI' if 'MI' in str(s) else 'Control' for s in common_samples]},
        index=pd.Index(common_samples, name='SampleID'),
    )

    print(f"✅ 数据读取完成: {len(common_samples)} 个样本")

    return data_filtered, metadata

In [12]:
# Run the full diversity pipeline on the filtered workbook.
diversity_output_dir = "E:/Python/MI_Analysis/data_figures/filtered_data/diversity_results"
data_filtered, metadata = load_your_data()
analyzer = MicrobiomeDiversityAnalyzer(output_dir=diversity_output_dir)
analyzer.load_data(data_filtered, metadata)
analyzer.run_complete_analysis()

正在读取已过滤的数据...
✅ 数据读取完成: 85 个样本
✅ 数据加载成功
   微生物类别: ['Bacteria', 'Fungi', 'Virus', 'Archaea']
   样本数量: 85
   分组分布: {'Control': np.int64(47), 'MI': np.int64(38)}
开始完整的微生物组多样性分析...

=== 数据质量验证 ===

Bacteria:
  特征数量: 335
  样本数量: 85
  数据范围: 0.0000 - 100.0000
  零值比例: 44.57%
  ⚠️ 非相对丰度数据，将进行转换

Fungi:
  特征数量: 194
  样本数量: 85
  数据范围: 0.0000 - 1.2471
  零值比例: 20.33%
  ⚠️ 非相对丰度数据，将进行转换

Virus:
  特征数量: 152
  样本数量: 85
  数据范围: 0.0000 - 13.0107
  零值比例: 51.26%
  ⚠️ 非相对丰度数据，将进行转换

Archaea:
  特征数量: 365
  样本数量: 85
  数据范围: 0.0000 - 0.1433
  零值比例: 54.80%
  ⚠️ 非相对丰度数据，将进行转换

=== α多样性分析 ===

分析 Bacteria...
  shannon: p=0.5567 (Mann-Whitney U)
  simpson: p=0.5807 (Mann-Whitney U)
  observed_otus: p=0.6908 (t-test)
  pielou_e: p=0.5076 (t-test)

分析 Fungi...
  shannon: p=0.2047 (Mann-Whitney U)
  simpson: p=0.6617 (Mann-Whitney U)
  observed_otus: p=0.0166 (Mann-Whitney U)
  pielou_e: p=0.1011 (Mann-Whitney U)

分析 Virus...
  shannon: p=0.8873 (t-test)
  simpson: p=0.9347 (t-test)
  observed_otus: p=0.9336 (t-test