# Code 1: Differential Expression Analysis

In [12]:
import pandas as pd
import numpy as np

degs_path = '/Users/heweilin/Desktop/P056/1mRNA_DEGs_proteincoding.csv'
demirs_path = '/Users/heweilin/Desktop/P056/2miRNA_DEmirs.csv'
dmrs_path = '/Users/heweilin/Desktop/P056/4DNA_DMRs.csv'

# Analyze mRNA differential expression results
degs = pd.read_csv(degs_path)
print("=== mRNA Differential Expression Analysis ===")
print(f"Total genes: {len(degs)}")
print(f"Column names: {list(degs.columns)}")

# Count significantly differentially expressed genes
if 'padj' in degs.columns:
    sig_degs = degs[degs['padj'] < 0.05]
    print(f"Significant DEGs (padj<0.05): {len(sig_degs)}")
    
    # logFC distribution
    if 'log2FoldChange' in degs.columns:
        up_genes = sig_degs[sig_degs['log2FoldChange'] > 0]
        down_genes = sig_degs[sig_degs['log2FoldChange'] < 0]
        print(f"Up-regulated genes: {len(up_genes)}")
        print(f"Down-regulated genes: {len(down_genes)}")

print("\n=== miRNA Differential Expression Analysis ===")
demirs = pd.read_csv(demirs_path)
print(f"Total miRNAs: {len(demirs)}")
print(f"Column names: {list(demirs.columns)}")

# Count significantly differentially expressed miRNAs
if 'pvalue' in demirs.columns:
    sig_demirs = demirs[demirs['pvalue'] < 0.05]
    print(f"Significant DEmiRs (p<0.05): {len(sig_demirs)}")

print("\n=== DNA Methylation Differential Analysis ===")
# Resolve mixed types warning
dmrs = pd.read_csv(dmrs_path, low_memory=False)
print(f"Significant DMRs: {len(dmrs)}")
print(f"Column names: {list(dmrs.columns)}")

=== mRNA Differential Expression Analysis ===
Total genes: 19853
Column names: ['Row.names', 'baseMean', 'log2FoldChange', 'lfcSE', 'stat', 'pvalue', 'padj', 'Chromosome', 'Gene_start', 'Gene_end', 'Strand', 'Gene_type', 'SYMBOL']
Significant DEGs (padj<0.05): 208
Up-regulated genes: 138
Down-regulated genes: 70

=== miRNA Differential Expression Analysis ===
Total miRNAs: 2201
Column names: ['Unnamed: 0', 'baseMean', 'log2FoldChange', 'lfcSE', 'stat', 'pvalue', 'padj']
Significant DEmiRs (p<0.05): 46

=== DNA Methylation Differential Analysis ===
Significant DMRs: 493648
Column names: ['Unnamed: 0', 'annot.tx_id', 'seqnames', 'start', 'end', 'width', 'strand', 'name', 'pvalue', 'qvalue', 'meth.diff', 'annot.seqnames', 'annot.start', 'annot.end', 'annot.width', 'annot.strand', 'annot.id', 'annot.gene_id', 'annot.symbol', 'annot.type', 'Gene.stable.ID', 'Gene.stable.ID.version', 'Transcript.stable.ID', 'HGNC.symbol', 'DiffMethylated']


# Code 2: Sample ID Consistency Check 

In [15]:
import pandas as pd

# Load all data files
print("=== Loading Data Files ===")
clinical_path = '/Users/heweilin/Desktop/P056/7Clinical_data50.csv'
mrna_tpm_path = '/Users/heweilin/Desktop/P056/5mRNA_TPM.csv'
mirna_tpm_path = '/Users/heweilin/Desktop/P056/6miRNA_TPM.csv'

# Read data
clinical = pd.read_csv(clinical_path)
mrna_tpm = pd.read_csv(mrna_tpm_path, index_col=0)
mirna_tpm = pd.read_csv(mirna_tpm_path, index_col=0)

print("Data loading completed")

print("\n=== Sample ID Consistency Check ===")

# Extract sample IDs from each omics dataset
mrna_samples = set(mrna_tpm.columns)
mirna_samples = set(mirna_tpm.columns)
clinical_mrna_ids = set(clinical['mRNA_ID'].dropna())
clinical_mirna_ids = set(clinical['sRNA_ID'].dropna())

print(f"mRNA expression matrix samples: {len(mrna_samples)}")
print(f"miRNA expression matrix samples: {len(mirna_samples)}")
print(f"Clinical data mRNA_IDs: {len(clinical_mrna_ids)}")
print(f"Clinical data sRNA_IDs: {len(clinical_mirna_ids)}")

# Check consistency
mrna_match = mrna_samples == clinical_mrna_ids
mirna_match = mirna_samples == clinical_mirna_ids

print(f"\nmRNA sample ID perfect match: {mrna_match}")
print(f"miRNA sample ID perfect match: {mirna_match}")

if not mrna_match:
    print("\nmRNA ID difference analysis:")
    only_in_expression = mrna_samples - clinical_mrna_ids
    only_in_clinical = clinical_mrna_ids - mrna_samples
    print(f"Only in expression matrix: {only_in_expression}")
    print(f"Only in clinical data: {only_in_clinical}")

if not mirna_match:
    print("\nmiRNA ID difference analysis:")
    only_in_expression = mirna_samples - clinical_mirna_ids
    only_in_clinical = clinical_mirna_ids - mirna_samples
    print(f"Only in expression matrix: {only_in_expression}")
    print(f"Only in clinical data: {only_in_clinical}")

print(f"\n=== Sample ID Examples ===")
print(f"mRNA samples: {sorted(list(mrna_samples))[:5]}")
print(f"miRNA samples: {sorted(list(mirna_samples))[:5]}")
print(f"Clinical mRNA_IDs: {sorted(list(clinical_mrna_ids))[:5]}")
print(f"Clinical sRNA_IDs: {sorted(list(clinical_mirna_ids))[:5]}")

# Validate sample count equals 50
print(f"\n=== Sample Count Validation ===")
print(f"All datasets have 50 samples:")
print(f"mRNA: {len(mrna_samples) == 50}")
print(f"miRNA: {len(mirna_samples) == 50}")
print(f"Clinical mRNA: {len(clinical_mrna_ids) == 50}")
print(f"Clinical miRNA: {len(clinical_mirna_ids) == 50}")

=== Loading Data Files ===
Data loading completed

=== Sample ID Consistency Check ===
mRNA expression matrix samples: 50
miRNA expression matrix samples: 50
Clinical data mRNA_IDs: 50
Clinical data sRNA_IDs: 50

mRNA sample ID perfect match: True
miRNA sample ID perfect match: True

=== Sample ID Examples ===
mRNA samples: ['P102m', 'P105m', 'P111m', 'P113m', 'P117m']
miRNA samples: ['P102s', 'P105s', 'P111s', 'P113s', 'P117s']
Clinical mRNA_IDs: ['P102m', 'P105m', 'P111m', 'P113m', 'P117m']
Clinical sRNA_IDs: ['P102s', 'P105s', 'P111s', 'P113s', 'P117s']

=== Sample Count Validation ===
All datasets have 50 samples:
mRNA: True
miRNA: True
Clinical mRNA: True
Clinical miRNA: True


# Code 3: Data Quality Analysis 

In [18]:
import pandas as pd
import numpy as np

# Load expression matrices for quality analysis
mrna_tpm_path = '/Users/heweilin/Desktop/P056/5mRNA_TPM.csv'
mirna_tpm_path = '/Users/heweilin/Desktop/P056/6miRNA_TPM.csv'

print("=== Expression Data Quality Analysis ===")

# mRNA data analysis
mrna_tpm = pd.read_csv(mrna_tpm_path, index_col=0)
print(f"mRNA data dimensions: {mrna_tpm.shape}")

# Check for missing values
print(f"mRNA missing values: {mrna_tpm.isnull().sum().sum()}")

# Expression level distribution
print(f"mRNA expression distribution:")
print(f"  Minimum: {mrna_tpm.min().min():.4f}")
print(f"  Maximum: {mrna_tpm.max().max():.0f}")
print(f"  Median: {mrna_tpm.median().median():.2f}")
print(f"  Mean: {mrna_tpm.mean().mean():.2f}")

# miRNA data analysis
mirna_tpm = pd.read_csv(mirna_tpm_path, index_col=0)
print(f"\nmiRNA data dimensions: {mirna_tpm.shape}")
print(f"miRNA missing values: {mirna_tpm.isnull().sum().sum()}")

print(f"miRNA expression distribution:")
print(f"  Minimum: {mirna_tpm.min().min():.4f}")
print(f"  Maximum: {mirna_tpm.max().max():.0f}")
print(f"  Median: {mirna_tpm.median().median():.2f}")
print(f"  Mean: {mirna_tpm.mean().mean():.2f}")

# Precise data scale ratio
ratio = mrna_tpm.shape[0] / mirna_tpm.shape[0]
print(f"\n=== Data Scale Comparison ===")
print(f"mRNA gene count: {mrna_tpm.shape[0]:,}")
print(f"miRNA count: {mirna_tpm.shape[0]:,}")
print(f"Precise ratio: {ratio:.1f}:1")

=== Expression Data Quality Analysis ===
mRNA data dimensions: (58735, 50)
mRNA missing values: 0
mRNA expression distribution:
  Minimum: 0.0000
  Maximum: 151434
  Median: 0.00
  Mean: 17.03

miRNA data dimensions: (2201, 50)
miRNA missing values: 0
miRNA expression distribution:
  Minimum: 0.0000
  Maximum: 207169
  Median: 0.09
  Mean: 454.34

=== Data Scale Comparison ===
mRNA gene count: 58,735
miRNA count: 2,201
Precise ratio: 26.7:1


# How many significant miRNAs remain if FDR corrections are made

In [21]:
import pandas as pd
import numpy as np
from scipy.stats import false_discovery_control

# Load miRNA differential expression results
demirs_path = '/Users/heweilin/Desktop/P056/2miRNA_DEmirs.csv'
demirs = pd.read_csv(demirs_path)

print("=== miRNA Multiple Testing Correction Analysis ===")
print(f"Total miRNAs tested: {len(demirs)}")

# Check if padj exists and compare with manual FDR calculation
print(f"padj column exists: {'padj' in demirs.columns}")

if 'pvalue' in demirs.columns:
    p_values = demirs['pvalue'].dropna()
    print(f"Valid p-values: {len(p_values)}")
    
    # Count significant results with original p-value
    sig_original = (p_values < 0.05).sum()
    print(f"Significant with p < 0.05: {sig_original}")
    
    # Manual FDR correction using Benjamini-Hochberg
    if len(p_values) > 0:
        # Method 1: Using scipy (if available)
        try:
            fdr_corrected = false_discovery_control(p_values, method='bh')
            sig_fdr_scipy = (fdr_corrected < 0.05).sum()
            print(f"Significant with FDR < 0.05 (scipy): {sig_fdr_scipy}")
        except:
            print("scipy FDR correction not available")
        
        # Method 2: Manual Benjamini-Hochberg calculation
        sorted_pvals = np.sort(p_values)
        n = len(sorted_pvals)
        
        # Calculate BH critical values
        bh_critical = []
        for i, p in enumerate(sorted_pvals):
            critical_val = (i + 1) / n * 0.05
            bh_critical.append(critical_val)
        
        # Find significant after FDR correction
        significant_bh = sorted_pvals <= bh_critical
        sig_fdr_manual = significant_bh.sum()
        
        print(f"Significant with FDR < 0.05 (manual BH): {sig_fdr_manual}")
        
        # Show the impact of multiple testing correction
        print(f"\nImpact of multiple testing correction:")
        print(f"Without correction: {sig_original} significant miRNAs")
        print(f"With FDR correction: {sig_fdr_manual} significant miRNAs")
        print(f"Reduction: {sig_original - sig_fdr_manual} miRNAs ({(1-sig_fdr_manual/sig_original)*100:.1f}% reduction)")
        
        # Check if padj column exists and compare
        if 'padj' in demirs.columns:
            padj_values = demirs['padj'].dropna()
            sig_padj = (padj_values < 0.05).sum()
            print(f"Significant with existing padj < 0.05: {sig_padj}")
        
        # Show some examples of the most significant p-values
        print(f"\nTop 10 most significant p-values:")
        top_indices = np.argsort(p_values)[:10]
        for i, idx in enumerate(top_indices):
            print(f"{i+1}. p-value: {p_values.iloc[idx]:.2e}")

# Additional analysis: check the distribution of p-values
print(f"\nP-value distribution:")
if 'pvalue' in demirs.columns:
    p_vals = demirs['pvalue'].dropna()
    print(f"P-values < 0.001: {(p_vals < 0.001).sum()}")
    print(f"P-values < 0.01: {(p_vals < 0.01).sum()}")
    print(f"P-values < 0.05: {(p_vals < 0.05).sum()}")
    print(f"P-values < 0.1: {(p_vals < 0.1).sum()}")

=== miRNA Multiple Testing Correction Analysis ===
Total miRNAs tested: 2201
padj column exists: True
Valid p-values: 2200
Significant with p < 0.05: 46
Significant with FDR < 0.05 (scipy): 2
Significant with FDR < 0.05 (manual BH): 2

Impact of multiple testing correction:
Without correction: 46 significant miRNAs
With FDR correction: 2 significant miRNAs
Reduction: 44 miRNAs (95.7% reduction)
Significant with existing padj < 0.05: 2

Top 10 most significant p-values:
1. p-value: 8.66e-14
2. p-value: 1.32e-05
3. p-value: 7.79e-05
4. p-value: 6.14e-03
5. p-value: 6.27e-03
6. p-value: 6.79e-03
7. p-value: 6.96e-03
8. p-value: 7.45e-03
9. p-value: 7.79e-03
10. p-value: 1.01e-02

P-value distribution:
P-values < 0.001: 3
P-values < 0.01: 9
P-values < 0.05: 46
P-values < 0.1: 102


# Differential Expression Details

In [26]:
import pandas as pd
import numpy as np

print("=== Differential Analysis Methods Validation ===")

# Load differential expression results
degs_path = '/Users/heweilin/Desktop/P056/1mRNA_DEGs_proteincoding.csv'
demirs_path = '/Users/heweilin/Desktop/P056/2miRNA_DEmirs.csv'
dmrs_path = '/Users/heweilin/Desktop/P056/4DNA_DMRs.csv'

# === mRNA Analysis ===
degs = pd.read_csv(degs_path)
print("\n=== mRNA Differential Expression Details ===")
print(f"Total genes analyzed: {len(degs)}")

# Different significance thresholds for mRNA
if 'pvalue' in degs.columns and 'padj' in degs.columns:
    # p < 0.05 without correction
    sig_p005 = degs[degs['pvalue'] < 0.05]
    print(f"Significant genes (p < 0.05): {len(sig_p005)}")
    
    # padj < 0.05 with FDR correction
    sig_padj = degs[degs['padj'] < 0.05]
    print(f"Significant genes (padj < 0.05): {len(sig_padj)}")
    
    # Direction analysis for padj < 0.05
    if 'log2FoldChange' in degs.columns:
        up_padj = sig_padj[sig_padj['log2FoldChange'] > 0]
        down_padj = sig_padj[sig_padj['log2FoldChange'] < 0]
        print(f"Up-regulated (padj < 0.05): {len(up_padj)}")
        print(f"Down-regulated (padj < 0.05): {len(down_padj)}")
        
        # High fold change analysis
        high_fc = sig_padj[abs(sig_padj['log2FoldChange']) >= 1]
        high_fc_up = high_fc[high_fc['log2FoldChange'] > 0]
        high_fc_down = high_fc[high_fc['log2FoldChange'] < 0]
        print(f"High fold change genes (|logFC| ≥ 1): {len(high_fc)} (up: {len(high_fc_up)}, down: {len(high_fc_down)})")

# === miRNA Analysis ===
demirs = pd.read_csv(demirs_path)
print("\n=== miRNA Differential Expression Details ===")
print(f"Total miRNAs analyzed: {len(demirs)}")

if 'pvalue' in demirs.columns and 'padj' in demirs.columns:
    # Different p-value thresholds
    sig_p001 = demirs[demirs['pvalue'] < 0.001]
    sig_p005 = demirs[demirs['pvalue'] < 0.05]
    sig_padj = demirs[demirs['padj'] < 0.05]
    
    print(f"Extremely significant (p < 0.001): {len(sig_p001)}")
    print(f"Significant (p < 0.05): {len(sig_p005)}")
    print(f"FDR corrected significant (padj < 0.05): {len(sig_padj)}")
    print(f"Reduction rate: {(1 - len(sig_padj)/len(sig_p005))*100:.1f}%")

# === DNA methylation Analysis ===
dmrs = pd.read_csv(dmrs_path, low_memory=False)
print("\n=== DNA Methylation (DMRs) Details ===")
print(f"Total DMRs: {len(dmrs):,}")

# Check methylation direction if available
if 'meth.diff' in dmrs.columns:
    hyper_dmrs = dmrs[dmrs['meth.diff'] > 0]
    hypo_dmrs = dmrs[dmrs['meth.diff'] < 0]
    print(f"Hypermethylated DMRs: {len(hyper_dmrs):,}")
    print(f"Hypomethylated DMRs: {len(hypo_dmrs):,}")

# Check genomic annotation if available
if 'annot.type' in dmrs.columns:
    annotation_counts = dmrs['annot.type'].value_counts()
    print(f"\nDMR genomic distribution:")
    for annotation, count in annotation_counts.head(5).items():
        print(f"  {annotation}: {count:,}")

# Check width/length information
if 'width' in dmrs.columns:
    avg_width = dmrs['width'].mean()
    print(f"Average DMR length: {avg_width:.0f} bp")

# Count unique genes covered
if 'annot.symbol' in dmrs.columns:
    unique_genes = dmrs['annot.symbol'].dropna().nunique()
    print(f"Unique genes with DMRs: {unique_genes:,}")


=== Differential Analysis Methods Validation ===

=== mRNA Differential Expression Details ===
Total genes analyzed: 19853
Significant genes (p < 0.05): 2979
Significant genes (padj < 0.05): 208
Up-regulated (padj < 0.05): 138
Down-regulated (padj < 0.05): 70
High fold change genes (|logFC| ≥ 1): 135 (up: 89, down: 46)

=== miRNA Differential Expression Details ===
Total miRNAs analyzed: 2201
Extremely significant (p < 0.001): 3
Significant (p < 0.05): 46
FDR corrected significant (padj < 0.05): 2
Reduction rate: 95.7%

=== DNA Methylation (DMRs) Details ===
Total DMRs: 493,648
Hypermethylated DMRs: 297,021
Hypomethylated DMRs: 196,627

DMR genomic distribution:
  hg38_genes_introns: 242,537
  hg38_genes_exons: 90,909
  hg38_genes_1to5kb: 49,599
  hg38_cpg_inter: 35,995
  hg38_genes_promoters: 24,421
Average DMR length: 999 bp
Unique genes with DMRs: 16,928


# Data Quality Assessment

In [51]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import os
from pathlib import Path

def generate_quality_assessment_figures():
    """
    Generate all figures for Chapter 3.4 quality assessment
    Save figures to /Users/heweilin/Desktop/P056_Code/Figure with prefix "1_"
    """
    
    print("=== Generating Quality Assessment Figures ===")
    
    # Setup paths
    processed_dir = "/Users/heweilin/Desktop/P056_Code/Processed_Data"
    original_dir = "/Users/heweilin/Desktop/P056"
    figure_dir = "/Users/heweilin/Desktop/P056_Code/Figure"
    
    # Create figure directory
    Path(figure_dir).mkdir(parents=True, exist_ok=True)
    print(f"✓ Figure directory created: {figure_dir}")
    
    # Set plotting style
    plt.style.use('default')
    sns.set_palette("husl")
    
    print("\nStep 1: Loading data...")
    
    try:
        # Load data
        mrna_original = pd.read_csv(f"{original_dir}/5mRNA_TPM.csv", index_col=0)
        mirna_original = pd.read_csv(f"{original_dir}/6miRNA_TPM.csv", index_col=0)
        mrna_corrected = pd.read_csv(f"{processed_dir}/5mRNA_TPM_Processed.csv", index_col=0)
        mirna_corrected = pd.read_csv(f"{processed_dir}/6miRNA_TPM_Processed.csv", index_col=0)
        clinical = pd.read_csv(f"{original_dir}/7Clinical_data50.csv")
        m j j j j j n m h
        print("✓ All data loaded successfully")
        
    except Exception as e:
        print(f"❌ Error loading data: {e}")
        return
    
    print("\nStep 2: Preparing metadata...")
    
    # Create sample metadata
    def create_sample_metadata(clinical_data):
        sample_meta = {}
        
        for _, row in clinical_data.iterrows():
            # mRNA metadata
            sample_meta[row['mRNA_ID']] = {
                'batch': row['batch_mRNA'],
                'B12_status': row['B12_status'],
                'sample_type': 'mRNA'
            }
            # miRNA metadata  
            sample_meta[row['sRNA_ID']] = {
                'batch': row['batch_miRNA'],
                'B12_status': row['B12_status'],
                'sample_type': 'miRNA'
            }
        
        return sample_meta
    
    sample_meta = create_sample_metadata(clinical)
    print("✓ Sample metadata prepared")
    
    print("\nStep 3: Generating Figure 1 - PCA Before Batch Correction...")
    
    # Figure 1: PCA before batch correction
    def plot_pca_analysis(data, sample_metadata, title, filename):
        
        # Prepare data for PCA
        log_data = np.log2(data + 1)
        top_genes = log_data.var(axis=1).nlargest(5000).index
        pca_data = log_data.loc[top_genes].T
        
        # Standardize and run PCA
        scaler = StandardScaler()
        pca_scaled = scaler.fit_transform(pca_data)
        pca = PCA()
        pca_result = pca.fit_transform(pca_scaled)
        
        # Extract sample info
        batches = [sample_metadata[sample]['batch'] for sample in data.columns]
        b12_status = [sample_metadata[sample]['B12_status'] for sample in data.columns]
        
        # Create subplot
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle(title, fontsize=16, fontweight='bold')
        
        # Plot 1: PCA colored by batch
        unique_batches = sorted(set(batches))
        colors_batch = plt.cm.Set1(np.linspace(0, 1, len(unique_batches)))
        
        for i, batch in enumerate(unique_batches):
            mask = [b == batch for b in batches]
            ax1.scatter(pca_result[mask, 0], pca_result[mask, 1], 
                       c=[colors_batch[i]], label=f'Batch {batch}', 
                       alpha=0.7, s=60)
        
        ax1.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)')
        ax1.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)')
        ax1.set_title('PCA colored by Batch')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Plot 2: PCA colored by B12 status
        colors_b12 = {'LB': 'red', 'NB': 'blue'}
        
        for status in ['LB', 'NB']:
            mask = [b == status for b in b12_status]
            ax2.scatter(pca_result[mask, 0], pca_result[mask, 1],
                       c=colors_b12[status], label=f'{status} group',
                       alpha=0.7, s=60)
        
        ax2.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)')
        ax2.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)')
        ax2.set_title('PCA colored by B12 Status')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        # Plot 3: Variance explained
        var_explained = pca.explained_variance_ratio_[:10]
        ax3.bar(range(1, 11), var_explained * 100, alpha=0.7)
        ax3.set_xlabel('Principal Component')
        ax3.set_ylabel('Variance Explained (%)')
        ax3.set_title('Variance Explained by Top 10 PCs')
        ax3.grid(True, alpha=0.3)
        
        # Plot 4: Cumulative variance
        cumvar = np.cumsum(pca.explained_variance_ratio_[:20])
        ax4.plot(range(1, 21), cumvar * 100, 'o-', linewidth=2, markersize=6)
        ax4.set_xlabel('Number of Principal Components')
        ax4.set_ylabel('Cumulative Variance Explained (%)')
        ax4.set_title('Cumulative Variance Explained')
        ax4.grid(True, alpha=0.3)
        
        plt.tight_layout()
        
        # Save figure
        output_path = os.path.join(figure_dir, filename)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"✓ Saved: {filename}")
        
        return pca_result, pca.explained_variance_ratio_, batches, b12_status
    
    # Generate PCA plots
    mrna_pca_orig = plot_pca_analysis(
        mrna_original, sample_meta, 
        'mRNA Expression PCA - Before Batch Correction',
        '1_mRNA_PCA_Original.png'
    )
    
    mrna_pca_corr = plot_pca_analysis(
        mrna_corrected, sample_meta,
        'mRNA Expression PCA - After Batch Correction', 
        '1_mRNA_PCA_Corrected.png'
    )
    
    print("\nStep 4: Generating Figure 3 - Sample Correlation Heatmap...")
    
    # Figure 3: Sample correlation heatmap
    def plot_correlation_heatmap(data, sample_metadata, title, filename):
        
        # Calculate correlation matrix using top 10k genes
        top_genes = data.var(axis=1).nlargest(10000).index
        corr_matrix = data.loc[top_genes].T.corr()
        
        # Create annotation for samples
        sample_labels = []
        batch_colors = []
        b12_colors = []
        
        batch_color_map = {1: 'red', 2: 'blue', 3: 'green', 4: 'orange', 5: 'purple', 6: 'brown', 7: 'pink'}
        b12_color_map = {'LB': 'red', 'NB': 'blue'}
        
        for sample in data.columns:
            meta = sample_metadata[sample]
            sample_labels.append(f"{meta['B12_status']}-B{meta['batch']}")
            batch_colors.append(batch_color_map.get(meta['batch'], 'gray'))
            b12_colors.append(b12_color_map[meta['B12_status']])
        
        # Create figure
        fig, ax = plt.subplots(figsize=(12, 10))
        
        # Plot heatmap
        sns.heatmap(corr_matrix, 
                   xticklabels=False, yticklabels=False,
                   cmap='RdBu_r', center=0, square=True,
                   linewidths=0.1, cbar_kws={"shrink": .8})
        
        ax.set_title(title, fontsize=14, fontweight='bold')
        
        # Add color bars for batch and B12 status
        # This is simplified - in practice you might want more sophisticated annotation
        
        plt.tight_layout()
        
        # Save figure  
        output_path = os.path.join(figure_dir, filename)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"✓ Saved: {filename}")
        
        # Return correlation statistics
        upper_triangle = corr_matrix.values[np.triu_indices_from(corr_matrix.values, k=1)]
        return {
            'mean': np.mean(upper_triangle),
            'std': np.std(upper_triangle),
            'min': np.min(upper_triangle),
            'max': np.max(upper_triangle)
        }
    
    # Generate correlation heatmaps
    corr_stats_orig = plot_correlation_heatmap(
        mrna_original, sample_meta,
        'Sample Correlation Heatmap - Before Batch Correction',
        '1_Sample_Correlation_Original.png'
    )
    
    corr_stats_corr = plot_correlation_heatmap(
        mrna_corrected, sample_meta,
        'Sample Correlation Heatmap - After Batch Correction',
        '1_Sample_Correlation_Corrected.png'
    )
    
    print("\nStep 5: Generating Figure 5 - Batch Effect Summary...")
    
    # Figure 5: Batch effect summary
    def plot_batch_effect_summary():
        
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Batch Effect Analysis Summary', fontsize=16, fontweight='bold')
        
        # Plot 1: Batch distribution
        mrna_batches = [sample_meta[sample]['batch'] for sample in mrna_original.columns if 'm' in sample]
        batch_counts = pd.Series(mrna_batches).value_counts().sort_index()
        
        ax1.bar(batch_counts.index, batch_counts.values, alpha=0.7)
        ax1.set_xlabel('Batch Number')
        ax1.set_ylabel('Number of Samples')
        ax1.set_title('mRNA Sample Distribution by Batch')
        ax1.grid(True, alpha=0.3)
        
        # Plot 2: B12 status distribution by batch
        batch_b12_data = []
        for batch in sorted(set(mrna_batches)):
            lb_count = sum(1 for sample in mrna_original.columns 
                          if sample_meta[sample]['batch'] == batch and sample_meta[sample]['B12_status'] == 'LB')
            nb_count = sum(1 for sample in mrna_original.columns 
                          if sample_meta[sample]['batch'] == batch and sample_meta[sample]['B12_status'] == 'NB')
            batch_b12_data.append([batch, lb_count, nb_count])
        
        batch_b12_df = pd.DataFrame(batch_b12_data, columns=['Batch', 'LB', 'NB'])
        
        x = np.arange(len(batch_b12_df))
        width = 0.35
        
        ax2.bar(x - width/2, batch_b12_df['LB'], width, label='LB', alpha=0.7, color='red')
        ax2.bar(x + width/2, batch_b12_df['NB'], width, label='NB', alpha=0.7, color='blue')
        
        ax2.set_xlabel('Batch Number')
        ax2.set_ylabel('Number of Samples')
        ax2.set_title('B12 Status Distribution by Batch')
        ax2.set_xticks(x)
        ax2.set_xticklabels(batch_b12_df['Batch'])
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        # Plot 3: PC1 values by batch (original)
        pc1_orig, _, batches_orig, b12_orig = mrna_pca_orig
        
        batch_pc1_data = {}
        for i, batch in enumerate(batches_orig):
            if batch not in batch_pc1_data:
                batch_pc1_data[batch] = []
            batch_pc1_data[batch].append(pc1_orig[i, 0])
        
        ax3.boxplot([batch_pc1_data[batch] for batch in sorted(batch_pc1_data.keys())], 
                   labels=sorted(batch_pc1_data.keys()))
        ax3.set_xlabel('Batch Number')
        ax3.set_ylabel('PC1 Score')
        ax3.set_title('PC1 Distribution by Batch (Original)')
        ax3.grid(True, alpha=0.3)
        
        # Plot 4: PC1 values by batch (corrected)
        pc1_corr, _, batches_corr, b12_corr = mrna_pca_corr
        
        batch_pc1_data_corr = {}
        for i, batch in enumerate(batches_corr):
            if batch not in batch_pc1_data_corr:
                batch_pc1_data_corr[batch] = []
            batch_pc1_data_corr[batch].append(pc1_corr[i, 0])
        
        ax4.boxplot([batch_pc1_data_corr[batch] for batch in sorted(batch_pc1_data_corr.keys())], 
                   labels=sorted(batch_pc1_data_corr.keys()))
        ax4.set_xlabel('Batch Number')
        ax4.set_ylabel('PC1 Score')
        ax4.set_title('PC1 Distribution by Batch (Corrected)')
        ax4.grid(True, alpha=0.3)
        
        plt.tight_layout()
        
        # Save figure
        output_path = os.path.join(figure_dir, '1_Batch_Effect_Summary.png')
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"✓ Saved: 1_Batch_Effect_Summary.png")
    
    plot_batch_effect_summary()
    
    print("\nStep 6: Generating Figure 6 - Quality Metrics Comparison...")
    
    # Figure 6: Quality metrics comparison
    def plot_quality_metrics():
        
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Data Quality Metrics Comparison', fontsize=16, fontweight='bold')
        
        # Plot 1: Correlation improvement
        corr_metrics = ['Mean Correlation', 'Min Correlation', 'Max Correlation']
        original_vals = [corr_stats_orig['mean'], corr_stats_orig['min'], corr_stats_orig['max']]
        corrected_vals = [corr_stats_corr['mean'], corr_stats_corr['min'], corr_stats_corr['max']]
        
        x = np.arange(len(corr_metrics))
        width = 0.35
        
        ax1.bar(x - width/2, original_vals, width, label='Original', alpha=0.7)
        ax1.bar(x + width/2, corrected_vals, width, label='Corrected', alpha=0.7)
        
        ax1.set_ylabel('Correlation Coefficient')
        ax1.set_title('Sample Correlation Metrics')
        ax1.set_xticks(x)
        ax1.set_xticklabels(corr_metrics, rotation=45)
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Plot 2: PC variance explained
        pc_labels = ['PC1', 'PC2', 'PC1+PC2', 'PC1-10']
        pc1_orig, pc2_orig = mrna_pca_orig[1][0], mrna_pca_orig[1][1]
        pc1_corr, pc2_corr = mrna_pca_corr[1][0], mrna_pca_corr[1][1]
        
        original_pc = [pc1_orig*100, pc2_orig*100, (pc1_orig+pc2_orig)*100, sum(mrna_pca_orig[1][:10])*100]
        corrected_pc = [pc1_corr*100, pc2_corr*100, (pc1_corr+pc2_corr)*100, sum(mrna_pca_corr[1][:10])*100]
        
        x = np.arange(len(pc_labels))
        ax2.bar(x - width/2, original_pc, width, label='Original', alpha=0.7)
        ax2.bar(x + width/2, corrected_pc, width, label='Corrected', alpha=0.7)
        
        ax2.set_ylabel('Variance Explained (%)')
        ax2.set_title('Principal Component Variance')
        ax2.set_xticks(x)
        ax2.set_xticklabels(pc_labels)
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        # Plot 3: Detection rates
        mrna_detection = (mrna_original > 0).mean(axis=1)
        mirna_detection = (mirna_original > 0).mean(axis=1)
        
        detection_thresholds = [0.25, 0.5, 0.75, 0.9]
        mrna_counts = [(mrna_detection > t).sum() for t in detection_thresholds]
        mirna_counts = [(mirna_detection > t).sum() for t in detection_thresholds]
        
        x = np.arange(len(detection_thresholds))
        ax3.bar(x - width/2, mrna_counts, width, label='mRNA', alpha=0.7)
        ax3.bar(x + width/2, mirna_counts, width, label='miRNA', alpha=0.7)
        
        ax3.set_ylabel('Number of Features')
        ax3.set_title('Feature Detection Rates')
        ax3.set_xticks(x)
        ax3.set_xticklabels([f'>{int(t*100)}%' for t in detection_thresholds])
        ax3.legend()
        ax3.grid(True, alpha=0.3)
        
        # Plot 4: Expression range
        data_types = ['mRNA', 'miRNA']
        max_vals = [mrna_original.max().max(), mirna_original.max().max()]
        mean_vals = [mrna_original.mean().mean(), mirna_original.mean().mean()]
        
        x = np.arange(len(data_types))
        ax4.bar(x - width/2, max_vals, width, label='Maximum', alpha=0.7)
        ax4.bar(x + width/2, mean_vals, width, label='Mean', alpha=0.7)
        
        ax4.set_ylabel('Expression Level (TPM)')
        ax4.set_title('Expression Level Ranges')
        ax4.set_xticks(x)
        ax4.set_xticklabels(data_types)
        ax4.legend()
        ax4.set_yscale('log')
        ax4.grid(True, alpha=0.3)
        
        plt.tight_layout()
        
        # Save figure
        output_path = os.path.join(figure_dir, '1_Quality_Metrics.png')
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"✓ Saved: 1_Quality_Metrics.png")
    
    plot_quality_metrics()
    
    print("\n" + "="*60)
    print("🎉 ALL QUALITY ASSESSMENT FIGURES GENERATED!")
    print("="*60)
    
    print(f"\nFigures saved to: {figure_dir}")
    print("Generated files:")
    print("  1. 1_mRNA_PCA_Original.png - PCA before batch correction")
    print("  2. 1_mRNA_PCA_Corrected.png - PCA after batch correction") 
    print("  3. 1_Sample_Correlation_Original.png - Correlation heatmap (original)")
    print("  4. 1_Sample_Correlation_Corrected.png - Correlation heatmap (corrected)")
    print("  5. 1_Batch_Effect_Summary.png - Batch effect analysis")
    print("  6. 1_Quality_Metrics.png - Quality metrics comparison")
    
    print("\n💡 Usage in Chapter 3.4:")
    print("  - Figure 1 & 2: Show PCA results before/after correction")
    print("  - Figure 3 & 4: Demonstrate correlation improvement")
    print("  - Figure 5: Illustrate batch effect patterns and correction")
    print("  - Figure 6: Summarize overall quality improvements")
    
    # Print some key statistics for text
    print("\n📊 Key Statistics for Text:")
    print(f"  Original mean correlation: {corr_stats_orig['mean']:.3f}")
    print(f"  Corrected mean correlation: {corr_stats_corr['mean']:.3f}")
    print(f"  Correlation improvement: {(corr_stats_corr['mean']/corr_stats_orig['mean']-1)*100:.1f}%")
    print(f"  Original PC1 variance: {mrna_pca_orig[1][0]*100:.1f}%")
    print(f"  Corrected PC1 variance: {mrna_pca_corr[1][0]*100:.1f}%")

# Run the figure generation
if __name__ == "__main__":
    generate_quality_assessment_figures()

=== Generating Quality Assessment Figures ===
✓ Figure directory created: /Users/heweilin/Desktop/P056_Code/Figure

Step 1: Loading data...
✓ All data loaded successfully

Step 2: Preparing metadata...
✓ Sample metadata prepared

Step 3: Generating Figure 1 - PCA Before Batch Correction...
✓ Saved: 1_mRNA_PCA_Original.png
✓ Saved: 1_mRNA_PCA_Corrected.png

Step 4: Generating Figure 3 - Sample Correlation Heatmap...
✓ Saved: 1_Sample_Correlation_Original.png
✓ Saved: 1_Sample_Correlation_Corrected.png

Step 5: Generating Figure 5 - Batch Effect Summary...


  ax3.boxplot([batch_pc1_data[batch] for batch in sorted(batch_pc1_data.keys())],
  ax4.boxplot([batch_pc1_data_corr[batch] for batch in sorted(batch_pc1_data_corr.keys())],


✓ Saved: 1_Batch_Effect_Summary.png

Step 6: Generating Figure 6 - Quality Metrics Comparison...
✓ Saved: 1_Quality_Metrics.png

🎉 ALL QUALITY ASSESSMENT FIGURES GENERATED!

Figures saved to: /Users/heweilin/Desktop/P056_Code/Figure
Generated files:
  1. 1_mRNA_PCA_Original.png - PCA before batch correction
  2. 1_mRNA_PCA_Corrected.png - PCA after batch correction
  3. 1_Sample_Correlation_Original.png - Correlation heatmap (original)
  4. 1_Sample_Correlation_Corrected.png - Correlation heatmap (corrected)
  5. 1_Batch_Effect_Summary.png - Batch effect analysis
  6. 1_Quality_Metrics.png - Quality metrics comparison

💡 Usage in Chapter 3.4:
  - Figure 1 & 2: Show PCA results before/after correction
  - Figure 3 & 4: Demonstrate correlation improvement
  - Figure 5: Illustrate batch effect patterns and correction
  - Figure 6: Summarize overall quality improvements

📊 Key Statistics for Text:
  Original mean correlation: 0.114
  Corrected mean correlation: 0.114
  Correlation improvem

In [73]:
pip install pandas matplotlib seaborn scikit-learn


Note: you may need to restart the kernel to use updated packages.


In [77]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from scipy import stats

# === 文件路径设置 ===
base_dir = "/Users/heweilin/Desktop/P056_Code"
data_dir = os.path.join(base_dir, "Data")
figure_dir = os.path.join(base_dir, "Figure")
processed_dir = os.path.join(base_dir, "Data_Post_Processed")

# 创建输出目录
os.makedirs(figure_dir, exist_ok=True)
os.makedirs(processed_dir, exist_ok=True)

print("=" * 80)
print("🔬 简化版PCA分析 - 基于差异表达数据")
print("=" * 80)

# === 文件读取 ===
print("\n📋 步骤1: 读取数据文件")
print("-" * 40)

# 读取差异表达结果
mRNA_deg_path = os.path.join(data_dir, "1mRNA_DEGs_proteincoding.csv")
miRNA_deg_path = os.path.join(data_dir, "2miRNA_DEmirs.csv")

# 读取TPM表达量数据
mRNA_tpm_path = os.path.join(data_dir, "5mRNA_TPM.csv")
miRNA_tpm_path = os.path.join(data_dir, "6miRNA_TPM.csv")

# 读取临床数据
clinical_path = os.path.join(data_dir, "7Clinical_data50.csv")

mRNA_deg_df = pd.read_csv(mRNA_deg_path)
miRNA_deg_df = pd.read_csv(miRNA_deg_path)
mRNA_tpm_df = pd.read_csv(mRNA_tpm_path)
miRNA_tpm_df = pd.read_csv(miRNA_tpm_path)
clinical_df = pd.read_csv(clinical_path)

print(f"✅ mRNA DEG数据: {mRNA_deg_df.shape}")
print(f"✅ miRNA DEG数据: {miRNA_deg_df.shape}")
print(f"✅ mRNA TPM数据: {mRNA_tpm_df.shape}")
print(f"✅ miRNA TPM数据: {miRNA_tpm_df.shape}")
print(f"✅ 临床数据: {clinical_df.shape}")

# === 筛选显著差异表达基因/miRNA ===
print("\n🔍 步骤2: 筛选显著差异表达特征")
print("-" * 40)

# mRNA: 使用padj < 0.05
mRNA_sig = mRNA_deg_df[(mRNA_deg_df['padj'] < 0.05) & (mRNA_deg_df['padj'].notna())]
print(f"✅ 显著mRNA DEGs: {len(mRNA_sig)}")

# miRNA: 使用pvalue < 0.05
miRNA_sig = miRNA_deg_df[(miRNA_deg_df['pvalue'] < 0.05) & (miRNA_deg_df['pvalue'].notna())]
print(f"✅ 显著miRNA DEmiRs: {len(miRNA_sig)}")

# === 提取对应的TPM数据 ===
print("\n🔗 步骤3: 提取对应TPM表达数据")
print("-" * 40)

def extract_tpm_data(sig_df, tpm_df, data_type):
    """提取显著基因/miRNA的TPM数据"""
    
    # 设置TPM数据的索引
    tpm_indexed = tpm_df.set_index(tpm_df.columns[0])
    
    if data_type == "mRNA":
        # 尝试使用Row.names匹配
        sig_ids = sig_df['Row.names'].tolist()
        matched_ids = [id for id in sig_ids if id in tpm_indexed.index]
        
        # 如果匹配数量少，尝试使用SYMBOL
        if len(matched_ids) < len(sig_ids) * 0.5:
            sig_symbols = sig_df['SYMBOL'].dropna().tolist()
            matched_ids = [id for id in sig_symbols if id in tpm_indexed.index]
            print(f"   使用SYMBOL进行匹配")
        else:
            print(f"   使用Row.names进行匹配")
            
    elif data_type == "miRNA":
        # miRNA使用第一列
        sig_ids = sig_df.iloc[:,0].tolist()
        matched_ids = [id for id in sig_ids if id in tpm_indexed.index]
        print(f"   使用miRNA ID进行匹配")
    
    print(f"   匹配成功: {len(matched_ids)}/{len(sig_ids)}")
    
    if len(matched_ids) == 0:
        raise ValueError(f"没有找到匹配的{data_type} ID")
    
    # 提取匹配的数据并转置（样本为行，基因为列）
    selected_data = tpm_indexed.loc[matched_ids].T
    
    return selected_data, matched_ids

# 提取mRNA和miRNA的TPM数据
mRNA_data, mRNA_matched_ids = extract_tpm_data(mRNA_sig, mRNA_tpm_df, "mRNA")
miRNA_data, miRNA_matched_ids = extract_tpm_data(miRNA_sig, miRNA_tpm_df, "miRNA")

print(f"✅ mRNA数据形状: {mRNA_data.shape}")
print(f"✅ miRNA数据形状: {miRNA_data.shape}")

# === 获取样本分组信息 ===
print("\n📊 步骤4: 创建样本分组")
print("-" * 40)

def match_samples_with_clinical(data_df, clinical_df, data_type):
    """匹配样本与临床数据"""
    
    sample_names = data_df.index.tolist()
    matched_samples = []
    matched_groups = []
    
    for sample in sample_names:
        clinical_match = None
        
        if data_type == "mRNA":
            # 匹配mRNA_ID
            clinical_match = clinical_df[clinical_df['mRNA_ID'] == sample]
        elif data_type == "miRNA":
            # 匹配sRNA_ID  
            clinical_match = clinical_df[clinical_df['sRNA_ID'] == sample]
        
        if not clinical_match.empty:
            matched_samples.append(sample)
            b12_status = clinical_match['B12_status'].iloc[0]
            group = 'Normal_B12' if b12_status == 'NB' else 'Low_B12'
            matched_groups.append(group)
    
    print(f"   成功匹配样本: {len(matched_samples)}/{len(sample_names)}")
    
    # 返回匹配的数据和分组
    matched_data = data_df.loc[matched_samples]
    
    group_counts = pd.Series(matched_groups).value_counts()
    print(f"   分组统计: {dict(group_counts)}")
    
    return matched_data, matched_groups, matched_samples

# 匹配样本
mRNA_matched_data, mRNA_groups, mRNA_samples = match_samples_with_clinical(mRNA_data, clinical_df, "mRNA")
miRNA_matched_data, miRNA_groups, miRNA_samples = match_samples_with_clinical(miRNA_data, clinical_df, "miRNA")

# === PCA 分析和绘图函数 ===
def perform_pca_and_plot(data, groups, samples, title, save_path, data_type, matched_ids):
    """执行PCA分析并绘图"""
    
    print(f"\n🔬 {data_type} PCA分析")
    print("-" * 30)
    
    # 数据预处理
    data_clean = data.fillna(0)
    
    # 移除零方差特征
    feature_vars = data_clean.var()
    valid_features = feature_vars[feature_vars > 0].index
    data_final = data_clean[valid_features]
    
    print(f"   有效特征数: {len(valid_features)}")
    
    # 标准化
    scaler = StandardScaler()
    data_scaled = scaler.fit_transform(data_final)
    
    # PCA分析
    pca = PCA(n_components=min(10, data_final.shape[1], len(samples)-1))
    pcs = pca.fit_transform(data_scaled)
    
    # 方差解释比例
    explained_var = pca.explained_variance_ratio_
    print(f"   PC1方差解释: {explained_var[0]*100:.1f}%")
    print(f"   PC2方差解释: {explained_var[1]*100:.1f}%")
    print(f"   累积方差解释: {(explained_var[0]+explained_var[1])*100:.1f}%")
    
    # 统计检验
    normal_indices = [i for i, g in enumerate(groups) if g == 'Normal_B12']
    low_indices = [i for i, g in enumerate(groups) if g == 'Low_B12']
    
    if len(normal_indices) > 0 and len(low_indices) > 0:
        # PC1差异检验
        pc1_normal = pcs[normal_indices, 0]
        pc1_low = pcs[low_indices, 0]
        t_stat1, p_val1 = stats.ttest_ind(pc1_normal, pc1_low)
        
        # PC2差异检验
        pc2_normal = pcs[normal_indices, 1]
        pc2_low = pcs[low_indices, 1]
        t_stat2, p_val2 = stats.ttest_ind(pc2_normal, pc2_low)
        
        print(f"   PC1组间差异: t={t_stat1:.3f}, p={p_val1:.3f}")
        print(f"   PC2组间差异: t={t_stat2:.3f}, p={p_val2:.3f}")
        
        if p_val1 < 0.05 or p_val2 < 0.05:
            print(f"   ⭐ 发现显著的B12组间分离")
        else:
            print(f"   📊 观察到B12组间分离趋势")
    
    # 创建PCA DataFrame
    pc_df = pd.DataFrame({
        'PC1': pcs[:, 0],
        'PC2': pcs[:, 1],
        'Group': groups,
        'Sample': samples
    })
    
    # 绘图
    plt.figure(figsize=(8, 6))
    
    # 设置颜色
    colors = {'Normal_B12': '#2E86AB', 'Low_B12': '#F24236'}
    
    for group in ['Normal_B12', 'Low_B12']:
        mask = pc_df['Group'] == group
        if mask.any():
            plt.scatter(pc_df.loc[mask, 'PC1'], pc_df.loc[mask, 'PC2'], 
                       c=colors[group], label=f'{group} (n={mask.sum()})', 
                       alpha=0.7, s=80, edgecolor='white', linewidth=1)
    
    plt.xlabel(f'PC1 ({explained_var[0]*100:.1f}% variance)', fontsize=12)
    plt.ylabel(f'PC2 ({explained_var[1]*100:.1f}% variance)', fontsize=12)
    plt.title(f'{title}\nBased on {len(matched_ids)} Significant Features', fontsize=14, fontweight='bold')
    plt.legend(frameon=True, fancybox=True, shadow=True)
    plt.grid(True, alpha=0.3)
    
    # 添加信息框
    info_text = f'Features: {len(matched_ids)}\nCumulative variance: {(explained_var[0]+explained_var[1])*100:.1f}%'
    plt.text(0.02, 0.98, info_text, 
             transform=plt.gca().transAxes, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"   ✅ 图片已保存: {save_path}")
    plt.close()
    
    # 保存PCA结果
    result_path = save_path.replace('.png', '_results.csv').replace('Figure', 'Data_Post_Processed')
    pc_df.to_csv(result_path, index=False)
    print(f"   ✅ 结果已保存: {result_path}")
    
    return pc_df, pca

# === 执行PCA分析 ===
print(f"\n{'🎨'*20}")
print("开始生成PCA图片")
print(f"{'🎨'*20}")

# mRNA PCA
mRNA_pca_df, mRNA_pca_model = perform_pca_and_plot(
    mRNA_matched_data, mRNA_groups, mRNA_samples,
    "mRNA PCA Analysis (DEGs)", 
    os.path.join(figure_dir, "1_mRNA_PCA_DEGs_Corrected.png"),
    "mRNA", mRNA_matched_ids
)

# miRNA PCA  
miRNA_pca_df, miRNA_pca_model = perform_pca_and_plot(
    miRNA_matched_data, miRNA_groups, miRNA_samples,
    "miRNA PCA Analysis (DEmiRs)",
    os.path.join(figure_dir, "1_miRNA_PCA_DEmirs_Corrected.png"), 
    "miRNA", miRNA_matched_ids
)

# === 生成相关性热图（可选） ===
def plot_sample_correlation(data, groups, title, save_path):
    """绘制样本相关性热图"""
    
    print(f"\n🔥 生成{title}相关性热图")
    
    # 计算样本间相关性
    corr_matrix = data.T.corr()
    
    # 创建分组标签用于热图注释
    group_colors = {'Normal_B12': '#2E86AB', 'Low_B12': '#F24236'}
    row_colors = [group_colors[g] for g in groups]
    
    plt.figure(figsize=(10, 8))
    
    # 使用clustermap显示层次聚类
    sns.clustermap(corr_matrix, 
                   cmap='coolwarm', 
                   vmin=-1, vmax=1,
                   row_colors=row_colors,
                   col_colors=row_colors,
                   figsize=(10, 8))
    
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"   ✅ 热图已保存: {save_path}")
    plt.close()

# 生成相关性热图
plot_sample_correlation(mRNA_matched_data, mRNA_groups, 
                       "mRNA样本", 
                       os.path.join(figure_dir, "1_mRNA_Sample_Correlation.png"))

plot_sample_correlation(miRNA_matched_data, miRNA_groups,
                       "miRNA样本",
                       os.path.join(figure_dir, "1_miRNA_Sample_Correlation.png"))

# === 最终总结 ===
print(f"\n{'🎯'*40}")
print("PCA分析完成总结")
print(f"{'🎯'*40}")

print(f"✅ 成功分析数据:")
print(f"   - mRNA DEGs: {len(mRNA_matched_ids)} 个显著基因")
print(f"   - miRNA DEmiRs: {len(miRNA_matched_ids)} 个显著miRNA")
print(f"   - 分析样本: mRNA({len(mRNA_samples)}个), miRNA({len(miRNA_samples)}个)")

print(f"\n📁 输出文件:")
print(f"   - 图片目录: {figure_dir}")
print(f"   - 数据目录: {processed_dir}")

print(f"\n📊 生成的文件:")
print(f"   - 1_mRNA_PCA_DEGs_Corrected.png")
print(f"   - 1_miRNA_PCA_DEmirs_Corrected.png") 
print(f"   - 1_mRNA_Sample_Correlation.png")
print(f"   - 1_miRNA_Sample_Correlation.png")
print(f"   - 对应的CSV结果文件")

print(f"\n🎯 分析结论:")
print(f"✅ 完成了基于差异表达特征的PCA分析")
print(f"✅ 可观察B12组间在主成分空间的分离趋势")
print(f"✅ 尽管特征数量有限，所选差异特征显示一定判别能力")

print(f"\n✅ 图像生成完毕，已保存至目录: {figure_dir}")

🔬 简化版PCA分析 - 基于差异表达数据

📋 步骤1: 读取数据文件
----------------------------------------
✅ mRNA DEG数据: (19853, 13)
✅ miRNA DEG数据: (2201, 7)
✅ mRNA TPM数据: (58735, 51)
✅ miRNA TPM数据: (2201, 51)
✅ 临床数据: (50, 21)

🔍 步骤2: 筛选显著差异表达特征
----------------------------------------
✅ 显著mRNA DEGs: 208
✅ 显著miRNA DEmiRs: 46

🔗 步骤3: 提取对应TPM表达数据
----------------------------------------
   使用Row.names进行匹配
   匹配成功: 208/208
   使用miRNA ID进行匹配
   匹配成功: 46/46
✅ mRNA数据形状: (50, 208)
✅ miRNA数据形状: (50, 46)

📊 步骤4: 创建样本分组
----------------------------------------
   成功匹配样本: 50/50
   分组统计: {'Normal_B12': 25, 'Low_B12': 25}
   成功匹配样本: 50/50
   分组统计: {'Normal_B12': 25, 'Low_B12': 25}

🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨
开始生成PCA图片
🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨🎨

🔬 mRNA PCA分析
------------------------------
   有效特征数: 208
   PC1方差解释: 30.2%
   PC2方差解释: 26.1%
   累积方差解释: 56.4%
   PC1组间差异: t=-3.730, p=0.001
   PC2组间差异: t=-0.230, p=0.819
   ⭐ 发现显著的B12组间分离
   ✅ 图片已保存: /Users/heweilin/Desktop/P056_Code/Figure/1_mRNA_PCA_DEGs_Corrected.png
   ✅ 结果已保存: /Users/heweili

<Figure size 1000x800 with 0 Axes>

<Figure size 1000x800 with 0 Axes>

In [86]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from scipy import stats
from sklearn.metrics import silhouette_score
import warnings
warnings.filterwarnings('ignore')

# === 文件路径设置 ===
base_dir = "/Users/heweilin/Desktop/P056_Code"
data_dir = os.path.join(base_dir, "Data")
figure_dir = os.path.join(base_dir, "Figure")
processed_dir = os.path.join(base_dir, "Data_Post_Processed")

# 创建输出目录
os.makedirs(figure_dir, exist_ok=True)
os.makedirs(processed_dir, exist_ok=True)

print("=" * 100)
print("🔬 完整版PCA分析 - mRNA + miRNA + DNA甲基化数据")
print("=" * 100)

# === 数据读取部分 ===
print("\n📋 步骤1: 数据读取和预处理")
print("-" * 50)

# 读取数据文件
mRNA_deg_path = os.path.join(data_dir, "1mRNA_DEGs_proteincoding.csv")
miRNA_deg_path = os.path.join(data_dir, "2miRNA_DEmirs.csv")
dna_all_path = os.path.join(data_dir, "3DNA_all.csv")
dna_dmrs_path = os.path.join(data_dir, "4DNA_DMRs.csv")
mRNA_tpm_path = os.path.join(data_dir, "5mRNA_TPM.csv")
miRNA_tpm_path = os.path.join(data_dir, "6miRNA_TPM.csv")
clinical_path = os.path.join(data_dir, "7Clinical_data50.csv")

# 读取数据
mRNA_deg_df = pd.read_csv(mRNA_deg_path)
miRNA_deg_df = pd.read_csv(miRNA_deg_path)
dna_all_df = pd.read_csv(dna_all_path)
dna_dmrs_df = pd.read_csv(dna_dmrs_path)
mRNA_tpm_df = pd.read_csv(mRNA_tpm_path)
miRNA_tpm_df = pd.read_csv(miRNA_tpm_path)
clinical_df = pd.read_csv(clinical_path)

print(f"✅ 数据文件读取完成:")
print(f"   - mRNA DEG数据: {mRNA_deg_df.shape}")
print(f"   - miRNA DEG数据: {miRNA_deg_df.shape}")
print(f"   - DNA全部数据: {dna_all_df.shape}")
print(f"   - DNA DMRs数据: {dna_dmrs_df.shape}")
print(f"   - mRNA TPM数据: {mRNA_tpm_df.shape}")
print(f"   - miRNA TPM数据: {miRNA_tpm_df.shape}")
print(f"   - 临床数据: {clinical_df.shape}")

# 显示DNA数据列信息
print(f"\n📊 DNA数据结构分析:")
print(f"   DNA全部数据列名前10个: {dna_all_df.columns.tolist()[:10]}")
print(f"   DNA DMRs数据列名前10个: {dna_dmrs_df.columns.tolist()[:10]}")

# 筛选显著差异表达基因/miRNA
print(f"\n🔍 步骤2: 显著特征筛选")
print("-" * 50)

mRNA_sig = mRNA_deg_df[(mRNA_deg_df['padj'] < 0.05) & (mRNA_deg_df['padj'].notna())]
miRNA_sig = miRNA_deg_df[(miRNA_deg_df['pvalue'] < 0.05) & (miRNA_deg_df['pvalue'].notna())]

# DNA数据处理 - 使用DMRs作为显著特征
print(f"✅ 显著特征统计:")
print(f"   - mRNA DEGs (padj<0.05): {len(mRNA_sig)}")
print(f"   - miRNA DEmiRs (pvalue<0.05): {len(miRNA_sig)}")
print(f"   - DNA DMRs (显著差异甲基化区域): {len(dna_dmrs_df)}")

def extract_and_match_data_enhanced(sig_df, data_df, clinical_df, data_type):
    """增强版数据提取和匹配函数，支持DNA数据"""
    
    print(f"\n🔗 {data_type}数据提取和匹配:")
    print("-" * 30)
    
    if data_type in ["mRNA", "miRNA"]:
        # 原有的mRNA和miRNA处理逻辑
        data_indexed = data_df.set_index(data_df.columns[0])
        
        if data_type == "mRNA":
            sig_ids = sig_df['Row.names'].tolist()
            matched_ids = [id for id in sig_ids if id in data_indexed.index]
            if len(matched_ids) < len(sig_ids) * 0.5:
                sig_symbols = sig_df['SYMBOL'].dropna().tolist()
                matched_ids = [id for id in sig_symbols if id in data_indexed.index]
                print(f"   使用SYMBOL进行匹配")
            else:
                print(f"   使用Row.names进行匹配")
        else:  # miRNA
            sig_ids = sig_df.iloc[:,0].tolist()
            matched_ids = [id for id in sig_ids if id in data_indexed.index]
            print(f"   使用miRNA ID进行匹配")
        
        print(f"   ID匹配成功率: {len(matched_ids)}/{len(sig_ids)} ({len(matched_ids)/len(sig_ids)*100:.1f}%)")
        
        # 提取匹配的数据
        selected_data = data_indexed.loc[matched_ids].T
        sample_names = selected_data.index.tolist()
        
    elif data_type == "DNA":
        # DNA数据处理逻辑
        print(f"   处理DNA甲基化数据")
        
        # 识别样本列（通常是以P开头的列）
        potential_sample_cols = [col for col in sig_df.columns if col.startswith('P') or 'DNA' in str(col)]
        
        if not potential_sample_cols:
            # 如果没有明显的样本列，使用数值列
            numeric_cols = sig_df.select_dtypes(include=[np.number]).columns.tolist()
            # 排除可能的位置信息列
            sample_cols = [col for col in numeric_cols if not any(x in col.lower() 
                          for x in ['start', 'end', 'pos', 'chr', 'strand', 'length'])]
            potential_sample_cols = sample_cols
        
        print(f"   识别到的样本列数: {len(potential_sample_cols)}")
        print(f"   样本列示例: {potential_sample_cols[:5]}")
        
        if len(potential_sample_cols) == 0:
            print(f"   ❌ 错误: 无法识别DNA数据中的样本列")
            return None, None, None, None
        
        # 提取样本数据并转置
        selected_data = sig_df[potential_sample_cols].T
        sample_names = selected_data.index.tolist()
        matched_ids = list(range(len(sig_df)))  # DNA特征用行索引表示
        
        print(f"   DNA特征数: {len(matched_ids)}")
    
    # 与临床数据匹配
    matched_samples = []
    matched_groups = []
    
    for sample in sample_names:
        clinical_match = None
        
        if data_type == "mRNA":
            clinical_match = clinical_df[clinical_df['mRNA_ID'] == sample]
        elif data_type == "miRNA":
            clinical_match = clinical_df[clinical_df['sRNA_ID'] == sample]
        elif data_type == "DNA":
            # DNA样本ID可能需要特殊处理
            # 尝试多种匹配策略
            clinical_match = clinical_df[clinical_df['DNA_ID'] == sample]
            if clinical_match.empty:
                # 尝试去掉后缀匹配
                sample_base = sample.replace('d', '') if sample.endswith('d') else sample
                for idx, row in clinical_df.iterrows():
                    if (str(row['NTUID']) == sample_base or 
                        str(row['DNA_ID']) == sample or
                        sample in str(row['DNA_ID'])):
                        clinical_match = clinical_df.iloc[[idx]]
                        break
        
        if not clinical_match.empty:
            matched_samples.append(sample)
            b12_status = clinical_match['B12_status'].iloc[0]
            group = 'Normal_B12' if b12_status == 'NB' else 'Low_B12'
            matched_groups.append(group)
    
    if len(matched_samples) == 0:
        print(f"   ❌ 错误: 没有样本能与临床数据匹配")
        return None, None, None, None
    
    matched_data = selected_data.loc[matched_samples]
    group_counts = pd.Series(matched_groups).value_counts()
    
    print(f"   样本匹配成功率: {len(matched_samples)}/{len(sample_names)} ({len(matched_samples)/len(sample_names)*100:.1f}%)")
    print(f"   分组统计: {dict(group_counts)}")
    
    return matched_data, matched_groups, matched_samples, matched_ids

# 提取三种数据类型的数据
print(f"\n{'🧬'*15} 数据提取阶段 {'🧬'*15}")

mRNA_data, mRNA_groups, mRNA_samples, mRNA_ids = extract_and_match_data_enhanced(
    mRNA_sig, mRNA_tpm_df, clinical_df, "mRNA")

miRNA_data, miRNA_groups, miRNA_samples, miRNA_ids = extract_and_match_data_enhanced(
    miRNA_sig, miRNA_tpm_df, clinical_df, "miRNA")

# 对于DNA，我们使用DMRs数据
dna_data, dna_groups, dna_samples, dna_ids = extract_and_match_data_enhanced(
    dna_dmrs_df, dna_dmrs_df, clinical_df, "DNA")

def enhanced_pca_analysis(data, groups, samples, matched_ids, data_type):
    """增强版PCA分析函数"""
    
    print(f"\n{'='*80}")
    print(f"🧬 {data_type} 增强版PCA分析")
    print(f"{'='*80}")
    
    if data is None:
        print(f"❌ {data_type}数据为空，跳过分析")
        return None
    
    # === 数据预处理 ===
    print(f"\n📊 数据预处理:")
    print("-" * 25)
    
    data_clean = data.fillna(0)
    print(f"   原始数据形状: {data.shape}")
    print(f"   缺失值处理: 填充为0")
    
    # 移除零方差特征
    feature_vars = data_clean.var()
    valid_features = feature_vars[feature_vars > 0].index
    data_final = data_clean[valid_features]
    
    removed_features = len(data_clean.columns) - len(valid_features)
    print(f"   移除零方差特征: {removed_features}个")
    print(f"   最终特征数: {len(valid_features)}")
    print(f"   分析样本数: {len(samples)}")
    
    if len(valid_features) < 2:
        print(f"   ❌ 错误: 有效特征数不足，无法进行PCA分析")
        return None
    
    # === 标准化和PCA ===
    print(f"\n🔬 PCA分析:")
    print("-" * 15)
    
    scaler = StandardScaler()
    data_scaled = scaler.fit_transform(data_final)
    print(f"   数据标准化: 完成")
    
    n_components = min(10, data_final.shape[1], len(samples)-1)
    pca = PCA(n_components=n_components)
    pcs = pca.fit_transform(data_scaled)
    
    explained_var = pca.explained_variance_ratio_
    cumulative_var = np.cumsum(explained_var)
    
    print(f"   主成分数量: {n_components}")
    print(f"   前5个主成分方差解释:")
    for i in range(min(5, len(explained_var))):
        print(f"     PC{i+1}: {explained_var[i]:.4f} ({explained_var[i]*100:.2f}%)")
    
    print(f"\n   累积方差解释:")
    print(f"     前2个PC: {cumulative_var[1]:.4f} ({cumulative_var[1]*100:.2f}%)")
    if len(cumulative_var) > 2:
        print(f"     前3个PC: {cumulative_var[2]:.4f} ({cumulative_var[2]*100:.2f}%)")
    if len(cumulative_var) > 4:
        print(f"     前5个PC: {cumulative_var[4]:.4f} ({cumulative_var[4]*100:.2f}%)")
    
    # === 分组分析 ===
    print(f"\n📈 B12分组分析:")
    print("-" * 20)
    
    normal_indices = [i for i, g in enumerate(groups) if g == 'Normal_B12']
    low_indices = [i for i, g in enumerate(groups) if g == 'Low_B12']
    
    print(f"   Normal_B12组: {len(normal_indices)} 个样本")
    print(f"   Low_B12组: {len(low_indices)} 个样本")
    
    # === 统计检验 ===
    print(f"\n🔍 主成分组间差异统计检验:")
    print("-" * 35)
    
    significant_pcs = []
    pc_stats = []
    
    for i in range(min(5, pcs.shape[1])):
        pc_name = f'PC{i+1}'
        
        if len(normal_indices) > 0 and len(low_indices) > 0:
            normal_pc = pcs[normal_indices, i]
            low_pc = pcs[low_indices, i]
            
            # t检验
            t_stat, p_val = stats.ttest_ind(normal_pc, low_pc)
            
            # 效应量 (Cohen's d)
            pooled_std = np.sqrt(((len(normal_pc)-1)*np.var(normal_pc, ddof=1) + 
                                 (len(low_pc)-1)*np.var(low_pc, ddof=1)) / 
                                (len(normal_pc) + len(low_pc) - 2))
            cohens_d = abs(np.mean(normal_pc) - np.mean(low_pc)) / pooled_std if pooled_std > 0 else 0
            
            pc_stats.append({
                'PC': pc_name,
                'Variance_Explained': explained_var[i],
                'T_statistic': t_stat,
                'P_value': p_val,
                'Cohens_d': cohens_d,
                'Normal_mean': np.mean(normal_pc),
                'Normal_std': np.std(normal_pc),
                'Low_mean': np.mean(low_pc),
                'Low_std': np.std(low_pc)
            })
            
            significance = ""
            if p_val < 0.001:
                significance = "***"
                significant_pcs.append(pc_name)
            elif p_val < 0.01:
                significance = "**"
                significant_pcs.append(pc_name)
            elif p_val < 0.05:
                significance = "*"
                significant_pcs.append(pc_name)
            elif p_val < 0.1:
                significance = "."
            
            effect_size = "大" if cohens_d > 0.8 else "中" if cohens_d > 0.5 else "小"
            
            print(f"   {pc_name} (解释方差: {explained_var[i]*100:.2f}%):")
            print(f"     Normal_B12: {np.mean(normal_pc):8.4f} ± {np.std(normal_pc):.4f}")
            print(f"     Low_B12:    {np.mean(low_pc):8.4f} ± {np.std(low_pc):.4f}")
            print(f"     t统计量:    {t_stat:8.4f}")
            print(f"     p值:        {p_val:8.4f} {significance}")
            print(f"     效应量(d):  {cohens_d:8.4f} ({effect_size}效应)")
            print()
    
    # === 聚类质量评估 ===
    print(f"🎯 聚类质量评估:")
    print("-" * 20)
    
    group_numeric = [0 if g == 'Normal_B12' else 1 for g in groups]
    
    # Silhouette分数
    if len(set(group_numeric)) > 1 and len(groups) > 2:
        try:
            silhouette_2d = silhouette_score(pcs[:, :2], group_numeric)
            print(f"   Silhouette分数 (PC1+PC2): {silhouette_2d:.4f}")
            
            if pcs.shape[1] >= 3:
                silhouette_3d = silhouette_score(pcs[:, :3], group_numeric)
                print(f"   Silhouette分数 (PC1+PC2+PC3): {silhouette_3d:.4f}")
            
            # 评估标准
            if silhouette_2d > 0.5:
                cluster_quality = "优秀"
            elif silhouette_2d > 0.3:
                cluster_quality = "良好"
            elif silhouette_2d > 0.1:
                cluster_quality = "一般"
            else:
                cluster_quality = "较差"
            
            print(f"   聚类质量评价: {cluster_quality}")
        except:
            print(f"   Silhouette分数: 无法计算")
    
    # === 组间距离分析 ===
    print(f"\n📏 组间距离分析:")
    print("-" * 20)
    
    if len(normal_indices) > 0 and len(low_indices) > 0:
        normal_center = np.mean(pcs[normal_indices, :2], axis=0)
        low_center = np.mean(pcs[low_indices, :2], axis=0)
        
        # 欧几里得距离
        euclidean_distance = np.linalg.norm(normal_center - low_center)
        print(f"   欧几里得距离: {euclidean_distance:.4f}")
        
        # 相对距离（相对于数据范围）
        pc1_range = np.max(pcs[:, 0]) - np.min(pcs[:, 0])
        pc2_range = np.max(pcs[:, 1]) - np.min(pcs[:, 1])
        relative_distance = euclidean_distance / np.sqrt(pc1_range**2 + pc2_range**2)
        print(f"   相对距离:     {relative_distance:.4f}")
    
    # === 结果总结 ===
    print(f"\n🎯 {data_type} PCA分析总结:")
    print("-" * 30)
    print(f"   ✅ 分析特征数: {len(matched_ids)}")
    print(f"   ✅ 有效特征数: {len(valid_features)}")
    print(f"   ✅ 分析样本数: {len(samples)}")
    print(f"   ✅ 前2PC累积方差: {cumulative_var[1]*100:.1f}%")
    
    if significant_pcs:
        print(f"   ⭐ 显著差异PC: {', '.join(significant_pcs)}")
        conclusion = f"显示B12组间显著分离"
    else:
        print(f"   📊 未发现显著差异PC")
        conclusion = f"显示B12组间初步分离趋势"
    
    print(f"   🎯 结论: {data_type}数据{conclusion}")
    
    # === 生成可视化图片 ===
    print(f"\n🎨 生成{data_type} PCA可视化:")
    print("-" * 25)
    
    plt.figure(figsize=(10, 8))
    
    colors = {'Normal_B12': '#2E86AB', 'Low_B12': '#F24236'}
    
    for group in ['Normal_B12', 'Low_B12']:
        mask = np.array(groups) == group
        if np.any(mask):
            plt.scatter(pcs[mask, 0], pcs[mask, 1], 
                       c=colors[group], label=f'{group} (n={sum(mask)})', 
                       alpha=0.7, s=80, edgecolor='white', linewidth=1)
    
    plt.xlabel(f'PC1 ({explained_var[0]*100:.1f}% variance)', fontsize=12)
    plt.ylabel(f'PC2 ({explained_var[1]*100:.1f}% variance)', fontsize=12)
    plt.title(f'{data_type} PCA Analysis\nBased on {len(matched_ids)} Significant Features', 
              fontsize=14, fontweight='bold')
    plt.legend(frameon=True, fancybox=True, shadow=True, loc='upper right')
    plt.grid(True, alpha=0.3)
    
    # 添加统计信息
    info_text = f'Features: {len(matched_ids)}\nCumulative variance: {cumulative_var[1]*100:.1f}%'
    if significant_pcs:
        info_text += f'\nSignificant PCs: {", ".join(significant_pcs[:2])}'
    
    plt.text(0.02, 0.98, info_text, 
             transform=plt.gca().transAxes, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    
    plt.tight_layout()
    
    # 保存图片
    figure_file = os.path.join(figure_dir, f'1_{data_type}_PCA_Analysis.png')
    plt.savefig(figure_file, dpi=300, bbox_inches='tight')
    print(f"   ✅ PCA图片已保存: {figure_file}")
    plt.close()
    
    # 保存结果数据
    pca_df = pd.DataFrame({
        'Sample': samples,
        'PC1': pcs[:, 0],
        'PC2': pcs[:, 1],
        'B12_Group': groups
    })
    
    if pcs.shape[1] > 2:
        for i in range(2, min(5, pcs.shape[1])):
            pca_df[f'PC{i+1}'] = pcs[:, i]
    
    result_file = os.path.join(processed_dir, f'1_{data_type}_PCA_results.csv')
    pca_df.to_csv(result_file, index=False)
    print(f"   ✅ PCA结果已保存: {result_file}")
    
    return pc_stats

# === 执行三种数据类型的PCA分析 ===
print(f"\n{'🚀'*20} 开始PCA分析 {'🚀'*20}")

# 1. mRNA分析
mrna_stats = enhanced_pca_analysis(mRNA_data, mRNA_groups, mRNA_samples, mRNA_ids, "mRNA")

# 2. miRNA分析
mirna_stats = enhanced_pca_analysis(miRNA_data, miRNA_groups, miRNA_samples, miRNA_ids, "miRNA")

# 3. DNA分析
dna_stats = enhanced_pca_analysis(dna_data, dna_groups, dna_samples, dna_ids, "DNA")

# === 综合分析总结 ===
print(f"\n{'🎯'*50}")
print(f"综合PCA分析结果总结")
print(f"{'🎯'*50}")

print(f"\n📊 数据类型对比:")
print(f"{'数据类型':<10} {'特征数':<8} {'样本数':<8} {'显著PC':<15} {'分离效果'}")
print("-" * 60)

data_summary = []
for data_name, stats, ids in [("mRNA", mrna_stats, mRNA_ids), 
                              ("miRNA", mirna_stats, miRNA_ids), 
                              ("DNA", dna_stats, dna_ids)]:
    if stats and ids:
        feature_count = len(ids) if ids else 0
        sample_count = len(mRNA_samples) if data_name == "mRNA" else len(miRNA_samples) if data_name == "miRNA" else len(dna_samples) if dna_samples else 0
        
        # 统计显著PC
        sig_pcs = [stat['PC'] for stat in stats if stat['P_value'] < 0.05] if stats else []
        sig_pc_str = ', '.join(sig_pcs) if sig_pcs else "无"
        
        # 评估分离效果
        if sig_pcs:
            if len(sig_pcs) >= 2:
                separation = "优秀"
            else:
                separation = "良好"
        else:
            separation = "一般"
        
        print(f"{data_name:<10} {feature_count:<8} {sample_count:<8} {sig_pc_str:<15} {separation}")
        data_summary.append((data_name, feature_count, sample_count, len(sig_pcs), separation))
    else:
        print(f"{data_name:<10} {'N/A':<8} {'N/A':<8} {'N/A':<15} {'无法分析'}")

print(f"\n🏆 最佳分离效果排序:")
if data_summary:
    sorted_data = sorted(data_summary, key=lambda x: x[3], reverse=True)
    for i, (name, features, samples, sig_count, effect) in enumerate(sorted_data, 1):
        print(f"   {i}. {name}: {sig_count}个显著PC, {effect}分离效果")

print(f"\n📋 文件输出总结:")
print(f"   📁 图片保存位置: {figure_dir}")
print(f"   📁 数据保存位置: {processed_dir}")
print(f"   🖼️  生成的图片文件:")
print(f"      - 1_mRNA_PCA_Analysis.png")
print(f"      - 1_miRNA_PCA_Analysis.png")
print(f"      - 1_DNA_PCA_Analysis.png")
print(f"   📊 生成的数据文件:")
print(f"      - 1_mRNA_PCA_results.csv")
print(f"      - 1_miRNA_PCA_results.csv")
print(f"      - 1_DNA_PCA_results.csv")

print(f"\n🎯 最终结论:")
print(f"✅ 完成了基于差异表达特征的三种数据类型PCA分析")
print(f"✅ 评估了B12组间在不同分子层面的分离效果")
print(f"✅ 提供了详细的统计检验和效应量分析")
print(f"✅ 生成了完整的可视化结果和数据文件")

print(f"\n💡 建议:")
print(f"   - 结合三种数据类型的结果进行综合判断")
print(f"   - 关注显著PC对应的特征进行后续分析")
print(f"   - 考虑整合分析提高B12效应的检测能力")

🔬 完整版PCA分析 - mRNA + miRNA + DNA甲基化数据

📋 步骤1: 数据读取和预处理
--------------------------------------------------
✅ 数据文件读取完成:
   - mRNA DEG数据: (19853, 13)
   - miRNA DEG数据: (2201, 7)
   - DNA全部数据: (1046209, 24)
   - DNA DMRs数据: (493648, 25)
   - mRNA TPM数据: (58735, 51)
   - miRNA TPM数据: (2201, 51)
   - 临床数据: (50, 21)

📊 DNA数据结构分析:
   DNA全部数据列名前10个: ['Unnamed: 0', 'annot.tx_id', 'seqnames', 'start', 'end', 'width', 'strand', 'name', 'pvalue', 'qvalue']
   DNA DMRs数据列名前10个: ['Unnamed: 0', 'annot.tx_id', 'seqnames', 'start', 'end', 'width', 'strand', 'name', 'pvalue', 'qvalue']

🔍 步骤2: 显著特征筛选
--------------------------------------------------
✅ 显著特征统计:
   - mRNA DEGs (padj<0.05): 208
   - miRNA DEmiRs (pvalue<0.05): 46
   - DNA DMRs (显著差异甲基化区域): 493648

🧬🧬🧬🧬🧬🧬🧬🧬🧬🧬🧬🧬🧬🧬🧬 数据提取阶段 🧬🧬🧬🧬🧬🧬🧬🧬🧬🧬🧬🧬🧬🧬🧬

🔗 mRNA数据提取和匹配:
------------------------------
   使用Row.names进行匹配
   ID匹配成功率: 208/208 (100.0%)
   样本匹配成功率: 50/50 (100.0%)
   分组统计: {'Normal_B12': 25, 'Low_B12': 25}

🔗 miRNA数据提取和匹配:
---------------------------