In [2]:
import pandas as pd
import numpy as np
from scipy import stats
from scipy.stats import ttest_ind, mannwhitneyu, chi2_contingency, fisher_exact, shapiro, levene, skew, kurtosis
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

def _normality_votes(sample, shapiro_alpha=0.10):
    """Return one boolean vote per normality criterion that applies to *sample*.

    Criteria (each contributes one vote when applicable):
      * Shapiro-Wilk p-value above ``shapiro_alpha`` (only for 3 <= n <= 5000).
      * |skewness| < 1.5 and |kurtosis| < 3, using scipy's default kurtosis
        definition (only for n >= 20).
      * Fewer than 5% of the points lie outside the 1.5 * IQR fences
        (only for n >= 20).
    """
    values = np.asarray(sample, dtype=float)
    n = len(values)
    votes = []

    # Criterion 1: Shapiro-Wilk test.
    if 3 <= n <= 5000:
        _, p_normal = shapiro(values)
        votes.append(bool(p_normal > shapiro_alpha))

    if n >= 20:
        # Criterion 2: shape of the distribution (skewness and kurtosis).
        shape_ok = abs(skew(values)) < 1.5 and abs(kurtosis(values)) < 3
        votes.append(bool(shape_ok))

        # Criterion 3: share of points outside the 1.5 * IQR fences.
        q1, q3 = np.percentile(values, [25, 75])
        iqr = q3 - q1
        n_outliers = np.sum((values < q1 - 1.5 * iqr) | (values > q3 + 1.5 * iqr))
        votes.append(bool(n_outliers / n < 0.05))

    return votes


def check_normality_and_variance(group0, group1, alpha=0.05,
                                 shapiro_alpha=0.10, min_pass_ratio=0.6):
    """
    Decide whether two samples look suitable for an equal-variance t-test.

    The decision combines Levene's test for equal variances with a vote over
    several normality criteria (see ``_normality_votes``) evaluated on each
    group separately.

    Parameters:
    group0, group1: one-dimensional numeric samples (sequence, array or Series)
    alpha: significance level for Levene's test; variances are treated as
        equal when the p-value is greater than ``alpha``
    shapiro_alpha: Shapiro-Wilk p-value a group must exceed to pass that
        criterion (deliberately looser than ``alpha`` by default)
    min_pass_ratio: minimum fraction of all applicable normality criteria
        (both groups pooled) that must pass

    Returns:
    bool: True when the normality vote passes AND the variances are
        homogeneous; False otherwise, including whenever either group has
        fewer than 10 observations.
    """
    n0, n1 = len(group0), len(group1)

    # Too few observations: callers fall back to a non-parametric test.
    if n0 < 10 or n1 < 10:
        return False

    # 1. Homogeneity of variance. Both groups have at least 10 observations
    #    here, so Levene's test is always applicable.
    _, p_var = levene(group0, group1)
    var_ok = bool(p_var > alpha)

    # 2. Normality: pool the votes of both groups.
    votes = (_normality_votes(group0, shapiro_alpha)
             + _normality_votes(group1, shapiro_alpha))

    # Defensive guard: with no applicable criterion, treat as non-normal.
    if not votes:
        return False

    normal_ok = (sum(votes) / len(votes)) >= min_pass_ratio

    return bool(normal_ok and var_ok)

def plot_variable_distribution(df, var, outcome_col='outcome', save_path=None):
    """
    Plot diagnostic charts for one variable, split by outcome group.

    Draws a 2x2 figure (box plot, density histogram with KDE, and one normal
    Q-Q plot per outcome group), then prints per-group summary statistics and
    the verdict of ``check_normality_and_variance``.

    Parameters:
    df: DataFrame containing ``var`` and ``outcome_col``
    var: name of the column to inspect
    outcome_col: name of the outcome column; rows with value 0 and 1 form the
        two groups
    save_path: directory in which to save ``<var>_distribution.png``. The
        directory is created if it does not exist. When falsy, the figure is
        shown with ``plt.show()`` instead.

    Returns:
    None. Nothing is drawn when there is no complete row or when either group
    is empty. Any exception is reported on stdout rather than raised.
    """
    import os  # local import: only needed to build the output path

    fig = None  # tracks a figure that still has to be closed on failure
    try:
        data = df[[outcome_col, var]].dropna()
        if len(data) == 0:
            return

        group0 = data[data[outcome_col] == 0][var]
        group1 = data[data[outcome_col] == 1][var]

        if len(group0) == 0 or len(group1) == 0:
            return

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle(f'变量分布检查: {var}', fontsize=16, fontweight='bold')

        # 1. Box plot
        plot_data = pd.melt(data, id_vars=[outcome_col], value_vars=[var])
        sns.boxplot(data=plot_data, x='variable', y='value', hue=outcome_col, ax=ax1)
        ax1.set_title(f'{var} - 箱线图')
        ax1.set_xlabel('')

        # 2. Density histogram with KDE, one layer per outcome group
        for outcome_val, color, label in zip([0, 1], ['skyblue', 'lightcoral'], ['Outcome=0', 'Outcome=1']):
            subset = data[data[outcome_col] == outcome_val][var]
            if len(subset) > 0:
                sns.histplot(subset, kde=True, label=label, ax=ax2,
                             alpha=0.6, color=color, stat='density')
        ax2.set_title(f'{var} - 分布直方图')
        ax2.legend()

        # 3. Q-Q plot - Outcome=0
        if len(group0) >= 3:
            stats.probplot(group0, dist="norm", plot=ax3)
            ax3.set_title(f'Outcome=0 - Q-Q图 (n={len(group0)})')

        # 4. Q-Q plot - Outcome=1
        if len(group1) >= 3:
            stats.probplot(group1, dist="norm", plot=ax4)
            ax4.set_title(f'Outcome=1 - Q-Q图 (n={len(group1)})')

        plt.tight_layout()

        if save_path:
            os.makedirs(save_path, exist_ok=True)
            fig.savefig(os.path.join(save_path, f"{var}_distribution.png"),
                        dpi=300, bbox_inches='tight')
            plt.close(fig)
        else:
            plt.show()
        # The figure has been saved-and-closed or handed to the backend;
        # the cleanup in ``finally`` must not touch it any more.
        fig = None

        # Print summary statistics for each group
        print(f"\n=== {var} 统计信息 ===")
        for group, label in zip([group0, group1], ['Outcome=0', 'Outcome=1']):
            if len(group) > 0:
                print(f"{label}: n={len(group)}, 均值={group.mean():.3f}, 标准差={group.std():.3f}")
                if len(group) >= 3:
                    print(f"  偏度={skew(group):.3f}, 峰度={kurtosis(group):.3f}")
                    if len(group) <= 5000:
                        _, p_shapiro = shapiro(group)
                        print(f"  Shapiro-Wilk检验 p值={p_shapiro:.4f}")

        # Combined normality / variance verdict
        normality_ok = check_normality_and_variance(group0, group1)
        print(f"综合正态性判断: {'适合t检验' if normality_ok else '适合Mann-Whitney U检验'}")

    except Exception as e:
        print(f"绘制变量 {var} 分布图时出错: {e}")
    finally:
        # Plotting failed part-way: release the half-built figure so repeated
        # calls do not accumulate open figures.
        if fig is not None:
            plt.close(fig)

def _continuous_row(df, var, outcome_col, progress, plot_distributions, plot_save_path):
    """Build the result row for one continuous variable, or None if it is skipped."""
    # Drop rows with a missing outcome or a missing value
    data = df[[outcome_col, var]].dropna()
    if len(data) == 0:
        print(f"跳过变量 {var}: 无有效数据")
        return None

    group0 = data[data[outcome_col] == 0][var]
    group1 = data[data[outcome_col] == 1][var]

    # Both groups are required for a comparison
    if len(group0) == 0 or len(group1) == 0:
        print(f"跳过变量 {var}: 某一组数据为空")
        return None

    print(f"\n处理连续变量 ({progress}): {var}")
    print(f"样本量: Outcome=0: {len(group0)}, Outcome=1: {len(group1)}")

    # Optional diagnostic plots
    if plot_distributions:
        plot_variable_distribution(df, var, outcome_col, plot_save_path)

    # Descriptive statistics: median [Q1, Q3] and mean ± SD
    desc0 = f"{group0.median():.2f} [{group0.quantile(0.25):.2f}, {group0.quantile(0.75):.2f}]"
    desc1 = f"{group1.median():.2f} [{group1.quantile(0.25):.2f}, {group1.quantile(0.75):.2f}]"

    mean_std0 = f"{group0.mean():.2f} ± {group0.std():.2f}"
    mean_std1 = f"{group1.mean():.2f} ± {group1.std():.2f}"

    # Equal-variance t-test when the assumption check passes, otherwise
    # the Mann-Whitney U test
    if check_normality_and_variance(group0, group1):
        stat, p_value = ttest_ind(group0, group1, equal_var=True)
        test_used = "t检验"
    else:
        stat, p_value = mannwhitneyu(group0, group1, alternative='two-sided')
        test_used = "Mann-Whitney U检验"

    return {
        '变量': var,
        '类型': '连续变量',
        'Outcome=0 (n={})'.format(len(group0)): mean_std0,
        'Outcome=1 (n={})'.format(len(group1)): mean_std1,
        '中位数[IQR] Outcome=0': desc0,
        '中位数[IQR] Outcome=1': desc1,
        '统计量': f"{stat:.4f}",
        'P值': f"{p_value:.4f}",
        '检验方法': test_used
    }


def _categorical_row(df, var, outcome_col, progress):
    """Build the result row for one categorical variable, or None if it is skipped."""
    print(f"\n处理分类变量 ({progress}): {var}")

    # Contingency table: categories of `var` (rows) x outcome levels (columns)
    contingency_table = pd.crosstab(df[var], df[outcome_col])

    # Group sizes are looked up by outcome LABEL (0 / 1), never by column
    # position, so a table that lacks one outcome level or contains extra
    # levels cannot be mis-read or raise an IndexError.
    total0 = contingency_table[0].sum() if 0 in contingency_table.columns else 0
    total1 = contingency_table[1].sum() if 1 in contingency_table.columns else 0

    if total0 == 0 or total1 == 0:
        return None

    # "category: count(percent%)" for every category, per outcome group
    desc0 = "; ".join(
        f"{category}: {count}({count / total0 * 100:.1f}%)"
        for category, count in contingency_table[0].items()
    )
    desc1 = "; ".join(
        f"{category}: {count}({count / total1 * 100:.1f}%)"
        for category, count in contingency_table[1].items()
    )

    row = {
        '变量': var,
        '类型': '分类变量',
        'Outcome=0 (n={})'.format(total0): desc0,
        'Outcome=1 (n={})'.format(total1): desc1,
        '中位数[IQR] Outcome=0': "N/A",
        '中位数[IQR] Outcome=1': "N/A",
    }

    # A single observed category cannot be tested
    if contingency_table.shape[0] < 2:
        row['统计量'] = "N/A"
        row['P值'] = "N/A"
        row['检验方法'] = "无法检验(单类别)"
        return row

    # Choose the test from the expected cell counts
    try:
        chi2, p_value, dof, expected = chi2_contingency(contingency_table)
        expected_flat = expected.flatten()

        cells_lt_5 = np.sum(expected_flat < 5)
        cells_lt_1 = np.sum(expected_flat < 1)
        total_cells = len(expected_flat)

        if cells_lt_1 > 0 or cells_lt_5 > 0.2 * total_cells:
            # Low expected counts
            if contingency_table.shape == (2, 2):
                if total0 + total1 < 1000:
                    _, p_value = fisher_exact(contingency_table)
                    stat = "N/A"
                    test_used = "Fisher精确检验"
                else:
                    stat = chi2
                    test_used = "卡方检验"
            else:
                stat = chi2
                test_used = "卡方检验(期望频数较低)"
        else:
            stat = chi2
            test_used = "卡方检验"

    except Exception as e:
        print(f"处理分类变量 {var} 的统计检验时出错: {e}")
        p_value = np.nan
        stat = "N/A"
        test_used = "无法计算"

    row['统计量'] = f"{stat:.4f}" if isinstance(stat, (int, float)) else stat
    row['P值'] = f"{p_value:.4f}" if not np.isnan(p_value) else "N/A"
    row['检验方法'] = test_used
    return row


def descriptive_statistics_analysis(df, outcome_col='outcome', plot_distributions=False,
                                    plot_save_path=None, continuous_vars=None,
                                    categorical_vars=None):
    """
    Compare variables between the outcome groups 0 and 1 and tabulate the results.

    Continuous variables are summarised as mean ± SD and median [Q1, Q3] and
    compared with an equal-variance t-test when
    ``check_normality_and_variance`` passes, otherwise with the Mann-Whitney U
    test. Categorical variables are summarised as count (percent) per category
    and compared with a chi-square test, or Fisher's exact test for 2x2 tables
    with low expected counts and fewer than 1000 observations.

    Parameters:
    df: DataFrame containing the data
    outcome_col: name of the outcome column; values 0 and 1 define the groups
    plot_distributions: whether to draw diagnostic plots for each continuous
        variable
    plot_save_path: directory for the diagnostic plots (passed through to
        ``plot_variable_distribution``)
    continuous_vars: optional list of continuous column names; defaults to the
        built-in list below
    categorical_vars: optional list of categorical column names; defaults to
        the built-in list below

    Returns:
    DataFrame with one row per analysed variable. Variables missing from
    ``df`` are ignored. A variable whose processing raises is reported on
    stdout and skipped.

    Note: the group-size column headers embed the per-variable n
    (``'Outcome=0 (n=...)'``), so variables with different numbers of valid
    rows produce different columns.
    """

    # Default variable lists
    if continuous_vars is None:
        continuous_vars = ['age','red_blood_cells', 'hemoglobin',
                    'rdw', 'hematocrit', 'neutrophils',
                      'lymphocytes', 'platelets', 'alt', 'ast',
                      'total_bilirubin', 'albumin', 'bun', 'creatinine',
                      'glucose', 'sodium', 'chloride', 'free_calcium',
                      'total_calcium', 'pao2', 'pco2', 'ph', 'lactate',
                      'anion_gap', 'inr', 'pt', 'ptt',
                      'heart_rate', 'resp_rate', 'temperature',
                      'sbp', 'dbp', 'mbp', 'spo2', 'weight_admit',
                      'ferritin', 'tibc', 'iron'
                      ]

    if categorical_vars is None:
        categorical_vars = ['gender',
                        'hypertension', 'diabetes', 'heart_failure', 'mi',
                        'stroke', 'sepsis', 'aki', 'ckd',
                        'hyperlipidemia', 'peripheral_vascular_disease',
                        'arterial_embolization', 'crrt_used',
                        'vasoactive_used','tibc_quantile_group','ferritin_quantile_group','iron_quantile_group',
                        ]

    # Keep only the variables that exist in the DataFrame
    continuous_vars = [var for var in continuous_vars if var in df.columns]
    categorical_vars = [var for var in categorical_vars if var in df.columns]

    print(f"找到连续变量: {len(continuous_vars)}个")
    print(f"找到分类变量: {len(categorical_vars)}个")

    results = []
    test_counts = {'t检验': 0, 'Mann-Whitney U检验': 0}

    # Continuous variables
    for position, var in enumerate(continuous_vars, start=1):
        try:
            row = _continuous_row(df, var, outcome_col,
                                  f"{position}/{len(continuous_vars)}",
                                  plot_distributions, plot_save_path)
        except Exception as e:
            print(f"处理连续变量 {var} 时出错: {e}")
            continue
        if row is not None:
            test_counts[row['检验方法']] += 1
            results.append(row)

    # Categorical variables
    for position, var in enumerate(categorical_vars, start=1):
        try:
            row = _categorical_row(df, var, outcome_col,
                                   f"{position}/{len(categorical_vars)}")
        except Exception as e:
            print(f"处理分类变量 {var} 时出错: {e}")
            continue
        if row is not None:
            results.append(row)

    # Assemble the result table
    result_df = pd.DataFrame(results)

    # Report how often each continuous test was used
    print(f"\n=== 检验方法使用统计 ===")
    print(f"t检验使用次数: {test_counts['t检验']}")
    print(f"Mann-Whitney U检验使用次数: {test_counts['Mann-Whitney U检验']}")
    total_continuous = test_counts['t检验'] + test_counts['Mann-Whitney U检验']
    if total_continuous > 0:
        print(f"t检验比例: {test_counts['t检验']/total_continuous*100:.1f}%")
        print(f"Mann-Whitney U检验比例: {test_counts['Mann-Whitney U检验']/total_continuous*100:.1f}%")

    return result_df

# Usage example
if __name__ == "__main__":
    # Load the data
    df = pd.read_csv("final_imputed_with_sofa_firsticu.csv")
    
    # Run the analysis (without drawing distribution plots)
    result = descriptive_statistics_analysis(df, 'survival_30', plot_distributions=False)
    
    # Show the full result table without truncation
    pd.set_option('display.max_rows', None)
    pd.set_option('display.max_columns', None)
    pd.set_option('display.width', 1000)
    print(result)
    
    # Save the result to Excel
    result.to_excel('描述性统计分析结果.xlsx', index=False)
    print("已保存")
    
    # Optional: draw the distribution plots for one specific variable,
    # e.g. inspect the distribution of the age variable
    # plot_variable_distribution(df, 'age', 'outcome')
    
    # Optional: draw the distribution plots for several continuous variables
    # import os
    # os.makedirs('distribution_plots', exist_ok=True)
    # for var in ['age', 'creatinine', 'hemoglobin']:  # a few key variables
    #     plot_variable_distribution(df, var, 'outcome', 'distribution_plots')

找到连续变量: 38个
找到分类变量: 17个

处理连续变量 (1/38): age
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (2/38): red_blood_cells
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (3/38): hemoglobin
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (4/38): rdw
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (5/38): hematocrit
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (6/38): neutrophils
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (7/38): lymphocytes
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (8/38): platelets
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (9/38): alt
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (10/38): ast
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (11/38): total_bilirubin
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (12/38): albumin
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (13/38): bun
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (14/38): creatinine
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (15/38): glucose
样本量: Outcome=0: 293, Outcome=1: 835

处理连续变量 (16/38): sodium
样本量: Outcome=0: 293, Outcome=1: