# Code 1: Differential Expression Analysis

In [12]:
import pandas as pd
import numpy as np

degs_path = '/Users/heweilin/Desktop/P056/1mRNA_DEGs_proteincoding.csv'
demirs_path = '/Users/heweilin/Desktop/P056/2miRNA_DEmirs.csv'
dmrs_path = '/Users/heweilin/Desktop/P056/4DNA_DMRs.csv'

# Analyze mRNA differential expression results
degs = pd.read_csv(degs_path)
print("=== mRNA Differential Expression Analysis ===")
print(f"Total genes: {len(degs)}")
print(f"Column names: {list(degs.columns)}")

# Count significantly differentially expressed genes
if 'padj' in degs.columns:
    sig_degs = degs[degs['padj'] < 0.05]
    print(f"Significant DEGs (padj<0.05): {len(sig_degs)}")
    
    # logFC distribution
    if 'log2FoldChange' in degs.columns:
        up_genes = sig_degs[sig_degs['log2FoldChange'] > 0]
        down_genes = sig_degs[sig_degs['log2FoldChange'] < 0]
        print(f"Up-regulated genes: {len(up_genes)}")
        print(f"Down-regulated genes: {len(down_genes)}")

print("\n=== miRNA Differential Expression Analysis ===")
demirs = pd.read_csv(demirs_path)
print(f"Total miRNAs: {len(demirs)}")
print(f"Column names: {list(demirs.columns)}")

# Count significantly differentially expressed miRNAs
if 'pvalue' in demirs.columns:
    sig_demirs = demirs[demirs['pvalue'] < 0.05]
    print(f"Significant DEmiRs (p<0.05): {len(sig_demirs)}")

print("\n=== DNA Methylation Differential Analysis ===")
# Resolve mixed types warning
dmrs = pd.read_csv(dmrs_path, low_memory=False)
print(f"Significant DMRs: {len(dmrs)}")
print(f"Column names: {list(dmrs.columns)}")

=== mRNA Differential Expression Analysis ===
Total genes: 19853
Column names: ['Row.names', 'baseMean', 'log2FoldChange', 'lfcSE', 'stat', 'pvalue', 'padj', 'Chromosome', 'Gene_start', 'Gene_end', 'Strand', 'Gene_type', 'SYMBOL']
Significant DEGs (padj<0.05): 208
Up-regulated genes: 138
Down-regulated genes: 70

=== miRNA Differential Expression Analysis ===
Total miRNAs: 2201
Column names: ['Unnamed: 0', 'baseMean', 'log2FoldChange', 'lfcSE', 'stat', 'pvalue', 'padj']
Significant DEmiRs (p<0.05): 46

=== DNA Methylation Differential Analysis ===
Significant DMRs: 493648
Column names: ['Unnamed: 0', 'annot.tx_id', 'seqnames', 'start', 'end', 'width', 'strand', 'name', 'pvalue', 'qvalue', 'meth.diff', 'annot.seqnames', 'annot.start', 'annot.end', 'annot.width', 'annot.strand', 'annot.id', 'annot.gene_id', 'annot.symbol', 'annot.type', 'Gene.stable.ID', 'Gene.stable.ID.version', 'Transcript.stable.ID', 'HGNC.symbol', 'DiffMethylated']


# Code 2: Sample ID Consistency Check 

In [15]:
import pandas as pd

# Load all data files
print("=== Loading Data Files ===")
clinical_path = '/Users/heweilin/Desktop/P056/7Clinical_data50.csv'
mrna_tpm_path = '/Users/heweilin/Desktop/P056/5mRNA_TPM.csv'
mirna_tpm_path = '/Users/heweilin/Desktop/P056/6miRNA_TPM.csv'

# Read data
clinical = pd.read_csv(clinical_path)
mrna_tpm = pd.read_csv(mrna_tpm_path, index_col=0)
mirna_tpm = pd.read_csv(mirna_tpm_path, index_col=0)

print("Data loading completed")

print("\n=== Sample ID Consistency Check ===")

# Extract sample IDs from each omics dataset
mrna_samples = set(mrna_tpm.columns)
mirna_samples = set(mirna_tpm.columns)
clinical_mrna_ids = set(clinical['mRNA_ID'].dropna())
clinical_mirna_ids = set(clinical['sRNA_ID'].dropna())

print(f"mRNA expression matrix samples: {len(mrna_samples)}")
print(f"miRNA expression matrix samples: {len(mirna_samples)}")
print(f"Clinical data mRNA_IDs: {len(clinical_mrna_ids)}")
print(f"Clinical data sRNA_IDs: {len(clinical_mirna_ids)}")

# Check consistency
mrna_match = mrna_samples == clinical_mrna_ids
mirna_match = mirna_samples == clinical_mirna_ids

print(f"\nmRNA sample ID perfect match: {mrna_match}")
print(f"miRNA sample ID perfect match: {mirna_match}")

if not mrna_match:
    print("\nmRNA ID difference analysis:")
    only_in_expression = mrna_samples - clinical_mrna_ids
    only_in_clinical = clinical_mrna_ids - mrna_samples
    print(f"Only in expression matrix: {only_in_expression}")
    print(f"Only in clinical data: {only_in_clinical}")

if not mirna_match:
    print("\nmiRNA ID difference analysis:")
    only_in_expression = mirna_samples - clinical_mirna_ids
    only_in_clinical = clinical_mirna_ids - mirna_samples
    print(f"Only in expression matrix: {only_in_expression}")
    print(f"Only in clinical data: {only_in_clinical}")

print(f"\n=== Sample ID Examples ===")
print(f"mRNA samples: {sorted(list(mrna_samples))[:5]}")
print(f"miRNA samples: {sorted(list(mirna_samples))[:5]}")
print(f"Clinical mRNA_IDs: {sorted(list(clinical_mrna_ids))[:5]}")
print(f"Clinical sRNA_IDs: {sorted(list(clinical_mirna_ids))[:5]}")

# Validate sample count equals 50
print(f"\n=== Sample Count Validation ===")
print(f"All datasets have 50 samples:")
print(f"mRNA: {len(mrna_samples) == 50}")
print(f"miRNA: {len(mirna_samples) == 50}")
print(f"Clinical mRNA: {len(clinical_mrna_ids) == 50}")
print(f"Clinical miRNA: {len(clinical_mirna_ids) == 50}")

=== Loading Data Files ===
Data loading completed

=== Sample ID Consistency Check ===
mRNA expression matrix samples: 50
miRNA expression matrix samples: 50
Clinical data mRNA_IDs: 50
Clinical data sRNA_IDs: 50

mRNA sample ID perfect match: True
miRNA sample ID perfect match: True

=== Sample ID Examples ===
mRNA samples: ['P102m', 'P105m', 'P111m', 'P113m', 'P117m']
miRNA samples: ['P102s', 'P105s', 'P111s', 'P113s', 'P117s']
Clinical mRNA_IDs: ['P102m', 'P105m', 'P111m', 'P113m', 'P117m']
Clinical sRNA_IDs: ['P102s', 'P105s', 'P111s', 'P113s', 'P117s']

=== Sample Count Validation ===
All datasets have 50 samples:
mRNA: True
miRNA: True
Clinical mRNA: True
Clinical miRNA: True


# Code 3: Data Quality Analysis 

In [18]:
import pandas as pd
import numpy as np

# Load expression matrices for quality analysis
mrna_tpm_path = '/Users/heweilin/Desktop/P056/5mRNA_TPM.csv'
mirna_tpm_path = '/Users/heweilin/Desktop/P056/6miRNA_TPM.csv'

print("=== Expression Data Quality Analysis ===")

# mRNA data analysis
mrna_tpm = pd.read_csv(mrna_tpm_path, index_col=0)
print(f"mRNA data dimensions: {mrna_tpm.shape}")

# Check for missing values
print(f"mRNA missing values: {mrna_tpm.isnull().sum().sum()}")

# Expression level distribution
print(f"mRNA expression distribution:")
print(f"  Minimum: {mrna_tpm.min().min():.4f}")
print(f"  Maximum: {mrna_tpm.max().max():.0f}")
print(f"  Median: {mrna_tpm.median().median():.2f}")
print(f"  Mean: {mrna_tpm.mean().mean():.2f}")

# miRNA data analysis
mirna_tpm = pd.read_csv(mirna_tpm_path, index_col=0)
print(f"\nmiRNA data dimensions: {mirna_tpm.shape}")
print(f"miRNA missing values: {mirna_tpm.isnull().sum().sum()}")

print(f"miRNA expression distribution:")
print(f"  Minimum: {mirna_tpm.min().min():.4f}")
print(f"  Maximum: {mirna_tpm.max().max():.0f}")
print(f"  Median: {mirna_tpm.median().median():.2f}")
print(f"  Mean: {mirna_tpm.mean().mean():.2f}")

# Precise data scale ratio
ratio = mrna_tpm.shape[0] / mirna_tpm.shape[0]
print(f"\n=== Data Scale Comparison ===")
print(f"mRNA gene count: {mrna_tpm.shape[0]:,}")
print(f"miRNA count: {mirna_tpm.shape[0]:,}")
print(f"Precise ratio: {ratio:.1f}:1")

=== Expression Data Quality Analysis ===
mRNA data dimensions: (58735, 50)
mRNA missing values: 0
mRNA expression distribution:
  Minimum: 0.0000
  Maximum: 151434
  Median: 0.00
  Mean: 17.03

miRNA data dimensions: (2201, 50)
miRNA missing values: 0
miRNA expression distribution:
  Minimum: 0.0000
  Maximum: 207169
  Median: 0.09
  Mean: 454.34

=== Data Scale Comparison ===
mRNA gene count: 58,735
miRNA count: 2,201
Precise ratio: 26.7:1


# How many significant miRNAs remain if FDR corrections are made

In [21]:
import pandas as pd
import numpy as np
from scipy.stats import false_discovery_control

# Load miRNA differential expression results
demirs_path = '/Users/heweilin/Desktop/P056/2miRNA_DEmirs.csv'
demirs = pd.read_csv(demirs_path)

print("=== miRNA Multiple Testing Correction Analysis ===")
print(f"Total miRNAs tested: {len(demirs)}")

# Check if padj exists and compare with manual FDR calculation
print(f"padj column exists: {'padj' in demirs.columns}")

if 'pvalue' in demirs.columns:
    p_values = demirs['pvalue'].dropna()
    print(f"Valid p-values: {len(p_values)}")
    
    # Count significant results with original p-value
    sig_original = (p_values < 0.05).sum()
    print(f"Significant with p < 0.05: {sig_original}")
    
    # Manual FDR correction using Benjamini-Hochberg
    if len(p_values) > 0:
        # Method 1: Using scipy (if available)
        try:
            fdr_corrected = false_discovery_control(p_values, method='bh')
            sig_fdr_scipy = (fdr_corrected < 0.05).sum()
            print(f"Significant with FDR < 0.05 (scipy): {sig_fdr_scipy}")
        except:
            print("scipy FDR correction not available")
        
        # Method 2: Manual Benjamini-Hochberg calculation
        sorted_pvals = np.sort(p_values)
        n = len(sorted_pvals)
        
        # Calculate BH critical values
        bh_critical = []
        for i, p in enumerate(sorted_pvals):
            critical_val = (i + 1) / n * 0.05
            bh_critical.append(critical_val)
        
        # Find significant after FDR correction
        significant_bh = sorted_pvals <= bh_critical
        sig_fdr_manual = significant_bh.sum()
        
        print(f"Significant with FDR < 0.05 (manual BH): {sig_fdr_manual}")
        
        # Show the impact of multiple testing correction
        print(f"\nImpact of multiple testing correction:")
        print(f"Without correction: {sig_original} significant miRNAs")
        print(f"With FDR correction: {sig_fdr_manual} significant miRNAs")
        print(f"Reduction: {sig_original - sig_fdr_manual} miRNAs ({(1-sig_fdr_manual/sig_original)*100:.1f}% reduction)")
        
        # Check if padj column exists and compare
        if 'padj' in demirs.columns:
            padj_values = demirs['padj'].dropna()
            sig_padj = (padj_values < 0.05).sum()
            print(f"Significant with existing padj < 0.05: {sig_padj}")
        
        # Show some examples of the most significant p-values
        print(f"\nTop 10 most significant p-values:")
        top_indices = np.argsort(p_values)[:10]
        for i, idx in enumerate(top_indices):
            print(f"{i+1}. p-value: {p_values.iloc[idx]:.2e}")

# Additional analysis: check the distribution of p-values
print(f"\nP-value distribution:")
if 'pvalue' in demirs.columns:
    p_vals = demirs['pvalue'].dropna()
    print(f"P-values < 0.001: {(p_vals < 0.001).sum()}")
    print(f"P-values < 0.01: {(p_vals < 0.01).sum()}")
    print(f"P-values < 0.05: {(p_vals < 0.05).sum()}")
    print(f"P-values < 0.1: {(p_vals < 0.1).sum()}")

=== miRNA Multiple Testing Correction Analysis ===
Total miRNAs tested: 2201
padj column exists: True
Valid p-values: 2200
Significant with p < 0.05: 46
Significant with FDR < 0.05 (scipy): 2
Significant with FDR < 0.05 (manual BH): 2

Impact of multiple testing correction:
Without correction: 46 significant miRNAs
With FDR correction: 2 significant miRNAs
Reduction: 44 miRNAs (95.7% reduction)
Significant with existing padj < 0.05: 2

Top 10 most significant p-values:
1. p-value: 8.66e-14
2. p-value: 1.32e-05
3. p-value: 7.79e-05
4. p-value: 6.14e-03
5. p-value: 6.27e-03
6. p-value: 6.79e-03
7. p-value: 6.96e-03
8. p-value: 7.45e-03
9. p-value: 7.79e-03
10. p-value: 1.01e-02

P-value distribution:
P-values < 0.001: 3
P-values < 0.01: 9
P-values < 0.05: 46
P-values < 0.1: 102


# Differential Expression Details

In [26]:
import pandas as pd
import numpy as np

print("=== Differential Analysis Methods Validation ===")

# Load differential expression results
degs_path = '/Users/heweilin/Desktop/P056/1mRNA_DEGs_proteincoding.csv'
demirs_path = '/Users/heweilin/Desktop/P056/2miRNA_DEmirs.csv'
dmrs_path = '/Users/heweilin/Desktop/P056/4DNA_DMRs.csv'

# === mRNA Analysis ===
degs = pd.read_csv(degs_path)
print("\n=== mRNA Differential Expression Details ===")
print(f"Total genes analyzed: {len(degs)}")

# Different significance thresholds for mRNA
if 'pvalue' in degs.columns and 'padj' in degs.columns:
    # p < 0.05 without correction
    sig_p005 = degs[degs['pvalue'] < 0.05]
    print(f"Significant genes (p < 0.05): {len(sig_p005)}")
    
    # padj < 0.05 with FDR correction
    sig_padj = degs[degs['padj'] < 0.05]
    print(f"Significant genes (padj < 0.05): {len(sig_padj)}")
    
    # Direction analysis for padj < 0.05
    if 'log2FoldChange' in degs.columns:
        up_padj = sig_padj[sig_padj['log2FoldChange'] > 0]
        down_padj = sig_padj[sig_padj['log2FoldChange'] < 0]
        print(f"Up-regulated (padj < 0.05): {len(up_padj)}")
        print(f"Down-regulated (padj < 0.05): {len(down_padj)}")
        
        # High fold change analysis
        high_fc = sig_padj[abs(sig_padj['log2FoldChange']) >= 1]
        high_fc_up = high_fc[high_fc['log2FoldChange'] > 0]
        high_fc_down = high_fc[high_fc['log2FoldChange'] < 0]
        print(f"High fold change genes (|logFC| ≥ 1): {len(high_fc)} (up: {len(high_fc_up)}, down: {len(high_fc_down)})")

# === miRNA Analysis ===
demirs = pd.read_csv(demirs_path)
print("\n=== miRNA Differential Expression Details ===")
print(f"Total miRNAs analyzed: {len(demirs)}")

if 'pvalue' in demirs.columns and 'padj' in demirs.columns:
    # Different p-value thresholds
    sig_p001 = demirs[demirs['pvalue'] < 0.001]
    sig_p005 = demirs[demirs['pvalue'] < 0.05]
    sig_padj = demirs[demirs['padj'] < 0.05]
    
    print(f"Extremely significant (p < 0.001): {len(sig_p001)}")
    print(f"Significant (p < 0.05): {len(sig_p005)}")
    print(f"FDR corrected significant (padj < 0.05): {len(sig_padj)}")
    print(f"Reduction rate: {(1 - len(sig_padj)/len(sig_p005))*100:.1f}%")

# === DNA methylation Analysis ===
dmrs = pd.read_csv(dmrs_path, low_memory=False)
print("\n=== DNA Methylation (DMRs) Details ===")
print(f"Total DMRs: {len(dmrs):,}")

# Check methylation direction if available
if 'meth.diff' in dmrs.columns:
    hyper_dmrs = dmrs[dmrs['meth.diff'] > 0]
    hypo_dmrs = dmrs[dmrs['meth.diff'] < 0]
    print(f"Hypermethylated DMRs: {len(hyper_dmrs):,}")
    print(f"Hypomethylated DMRs: {len(hypo_dmrs):,}")

# Check genomic annotation if available
if 'annot.type' in dmrs.columns:
    annotation_counts = dmrs['annot.type'].value_counts()
    print(f"\nDMR genomic distribution:")
    for annotation, count in annotation_counts.head(5).items():
        print(f"  {annotation}: {count:,}")

# Check width/length information
if 'width' in dmrs.columns:
    avg_width = dmrs['width'].mean()
    print(f"Average DMR length: {avg_width:.0f} bp")

# Count unique genes covered
if 'annot.symbol' in dmrs.columns:
    unique_genes = dmrs['annot.symbol'].dropna().nunique()
    print(f"Unique genes with DMRs: {unique_genes:,}")


=== Differential Analysis Methods Validation ===

=== mRNA Differential Expression Details ===
Total genes analyzed: 19853
Significant genes (p < 0.05): 2979
Significant genes (padj < 0.05): 208
Up-regulated (padj < 0.05): 138
Down-regulated (padj < 0.05): 70
High fold change genes (|logFC| ≥ 1): 135 (up: 89, down: 46)

=== miRNA Differential Expression Details ===
Total miRNAs analyzed: 2201
Extremely significant (p < 0.001): 3
Significant (p < 0.05): 46
FDR corrected significant (padj < 0.05): 2
Reduction rate: 95.7%

=== DNA Methylation (DMRs) Details ===
Total DMRs: 493,648
Hypermethylated DMRs: 297,021
Hypomethylated DMRs: 196,627

DMR genomic distribution:
  hg38_genes_introns: 242,537
  hg38_genes_exons: 90,909
  hg38_genes_1to5kb: 49,599
  hg38_cpg_inter: 35,995
  hg38_genes_promoters: 24,421
Average DMR length: 999 bp
Unique genes with DMRs: 16,928


# Data Quality Assessment

In [51]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import os
from pathlib import Path

def generate_quality_assessment_figures():
    """
    Generate all figures for Chapter 3.4 quality assessment
    Save figures to /Users/heweilin/Desktop/P056_Code/Figure with prefix "1_"
    """
    
    print("=== Generating Quality Assessment Figures ===")
    
    # Setup paths
    processed_dir = "/Users/heweilin/Desktop/P056_Code/Processed_Data"
    original_dir = "/Users/heweilin/Desktop/P056"
    figure_dir = "/Users/heweilin/Desktop/P056_Code/Figure"
    
    # Create figure directory
    Path(figure_dir).mkdir(parents=True, exist_ok=True)
    print(f"✓ Figure directory created: {figure_dir}")
    
    # Set plotting style
    plt.style.use('default')
    sns.set_palette("husl")
    
    print("\nStep 1: Loading data...")
    
    try:
        # Load data
        mrna_original = pd.read_csv(f"{original_dir}/5mRNA_TPM.csv", index_col=0)
        mirna_original = pd.read_csv(f"{original_dir}/6miRNA_TPM.csv", index_col=0)
        mrna_corrected = pd.read_csv(f"{processed_dir}/5mRNA_TPM_Processed.csv", index_col=0)
        mirna_corrected = pd.read_csv(f"{processed_dir}/6miRNA_TPM_Processed.csv", index_col=0)
        clinical = pd.read_csv(f"{original_dir}/7Clinical_data50.csv")
        m j j j j j n m h
        print("✓ All data loaded successfully")
        
    except Exception as e:
        print(f"❌ Error loading data: {e}")
        return
    
    print("\nStep 2: Preparing metadata...")
    
    # Create sample metadata
    def create_sample_metadata(clinical_data):
        sample_meta = {}
        
        for _, row in clinical_data.iterrows():
            # mRNA metadata
            sample_meta[row['mRNA_ID']] = {
                'batch': row['batch_mRNA'],
                'B12_status': row['B12_status'],
                'sample_type': 'mRNA'
            }
            # miRNA metadata  
            sample_meta[row['sRNA_ID']] = {
                'batch': row['batch_miRNA'],
                'B12_status': row['B12_status'],
                'sample_type': 'miRNA'
            }
        
        return sample_meta
    
    sample_meta = create_sample_metadata(clinical)
    print("✓ Sample metadata prepared")
    
    print("\nStep 3: Generating Figure 1 - PCA Before Batch Correction...")
    
    # Figure 1: PCA before batch correction
    def plot_pca_analysis(data, sample_metadata, title, filename):
        
        # Prepare data for PCA
        log_data = np.log2(data + 1)
        top_genes = log_data.var(axis=1).nlargest(5000).index
        pca_data = log_data.loc[top_genes].T
        
        # Standardize and run PCA
        scaler = StandardScaler()
        pca_scaled = scaler.fit_transform(pca_data)
        pca = PCA()
        pca_result = pca.fit_transform(pca_scaled)
        
        # Extract sample info
        batches = [sample_metadata[sample]['batch'] for sample in data.columns]
        b12_status = [sample_metadata[sample]['B12_status'] for sample in data.columns]
        
        # Create subplot
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle(title, fontsize=16, fontweight='bold')
        
        # Plot 1: PCA colored by batch
        unique_batches = sorted(set(batches))
        colors_batch = plt.cm.Set1(np.linspace(0, 1, len(unique_batches)))
        
        for i, batch in enumerate(unique_batches):
            mask = [b == batch for b in batches]
            ax1.scatter(pca_result[mask, 0], pca_result[mask, 1], 
                       c=[colors_batch[i]], label=f'Batch {batch}', 
                       alpha=0.7, s=60)
        
        ax1.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)')
        ax1.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)')
        ax1.set_title('PCA colored by Batch')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Plot 2: PCA colored by B12 status
        colors_b12 = {'LB': 'red', 'NB': 'blue'}
        
        for status in ['LB', 'NB']:
            mask = [b == status for b in b12_status]
            ax2.scatter(pca_result[mask, 0], pca_result[mask, 1],
                       c=colors_b12[status], label=f'{status} group',
                       alpha=0.7, s=60)
        
        ax2.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)')
        ax2.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)')
        ax2.set_title('PCA colored by B12 Status')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        # Plot 3: Variance explained
        var_explained = pca.explained_variance_ratio_[:10]
        ax3.bar(range(1, 11), var_explained * 100, alpha=0.7)
        ax3.set_xlabel('Principal Component')
        ax3.set_ylabel('Variance Explained (%)')
        ax3.set_title('Variance Explained by Top 10 PCs')
        ax3.grid(True, alpha=0.3)
        
        # Plot 4: Cumulative variance
        cumvar = np.cumsum(pca.explained_variance_ratio_[:20])
        ax4.plot(range(1, 21), cumvar * 100, 'o-', linewidth=2, markersize=6)
        ax4.set_xlabel('Number of Principal Components')
        ax4.set_ylabel('Cumulative Variance Explained (%)')
        ax4.set_title('Cumulative Variance Explained')
        ax4.grid(True, alpha=0.3)
        
        plt.tight_layout()
        
        # Save figure
        output_path = os.path.join(figure_dir, filename)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"✓ Saved: {filename}")
        
        return pca_result, pca.explained_variance_ratio_, batches, b12_status
    
    # Generate PCA plots
    mrna_pca_orig = plot_pca_analysis(
        mrna_original, sample_meta, 
        'mRNA Expression PCA - Before Batch Correction',
        '1_mRNA_PCA_Original.png'
    )
    
    mrna_pca_corr = plot_pca_analysis(
        mrna_corrected, sample_meta,
        'mRNA Expression PCA - After Batch Correction', 
        '1_mRNA_PCA_Corrected.png'
    )
    
    print("\nStep 4: Generating Figure 3 - Sample Correlation Heatmap...")
    
    # Figure 3: Sample correlation heatmap
    def plot_correlation_heatmap(data, sample_metadata, title, filename):
        
        # Calculate correlation matrix using top 10k genes
        top_genes = data.var(axis=1).nlargest(10000).index
        corr_matrix = data.loc[top_genes].T.corr()
        
        # Create annotation for samples
        sample_labels = []
        batch_colors = []
        b12_colors = []
        
        batch_color_map = {1: 'red', 2: 'blue', 3: 'green', 4: 'orange', 5: 'purple', 6: 'brown', 7: 'pink'}
        b12_color_map = {'LB': 'red', 'NB': 'blue'}
        
        for sample in data.columns:
            meta = sample_metadata[sample]
            sample_labels.append(f"{meta['B12_status']}-B{meta['batch']}")
            batch_colors.append(batch_color_map.get(meta['batch'], 'gray'))
            b12_colors.append(b12_color_map[meta['B12_status']])
        
        # Create figure
        fig, ax = plt.subplots(figsize=(12, 10))
        
        # Plot heatmap
        sns.heatmap(corr_matrix, 
                   xticklabels=False, yticklabels=False,
                   cmap='RdBu_r', center=0, square=True,
                   linewidths=0.1, cbar_kws={"shrink": .8})
        
        ax.set_title(title, fontsize=14, fontweight='bold')
        
        # Add color bars for batch and B12 status
        # This is simplified - in practice you might want more sophisticated annotation
        
        plt.tight_layout()
        
        # Save figure  
        output_path = os.path.join(figure_dir, filename)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"✓ Saved: {filename}")
        
        # Return correlation statistics
        upper_triangle = corr_matrix.values[np.triu_indices_from(corr_matrix.values, k=1)]
        return {
            'mean': np.mean(upper_triangle),
            'std': np.std(upper_triangle),
            'min': np.min(upper_triangle),
            'max': np.max(upper_triangle)
        }
    
    # Generate correlation heatmaps
    corr_stats_orig = plot_correlation_heatmap(
        mrna_original, sample_meta,
        'Sample Correlation Heatmap - Before Batch Correction',
        '1_Sample_Correlation_Original.png'
    )
    
    corr_stats_corr = plot_correlation_heatmap(
        mrna_corrected, sample_meta,
        'Sample Correlation Heatmap - After Batch Correction',
        '1_Sample_Correlation_Corrected.png'
    )
    
    print("\nStep 5: Generating Figure 5 - Batch Effect Summary...")
    
    # Figure 5: Batch effect summary
    def plot_batch_effect_summary():
        
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Batch Effect Analysis Summary', fontsize=16, fontweight='bold')
        
        # Plot 1: Batch distribution
        mrna_batches = [sample_meta[sample]['batch'] for sample in mrna_original.columns if 'm' in sample]
        batch_counts = pd.Series(mrna_batches).value_counts().sort_index()
        
        ax1.bar(batch_counts.index, batch_counts.values, alpha=0.7)
        ax1.set_xlabel('Batch Number')
        ax1.set_ylabel('Number of Samples')
        ax1.set_title('mRNA Sample Distribution by Batch')
        ax1.grid(True, alpha=0.3)
        
        # Plot 2: B12 status distribution by batch
        batch_b12_data = []
        for batch in sorted(set(mrna_batches)):
            lb_count = sum(1 for sample in mrna_original.columns 
                          if sample_meta[sample]['batch'] == batch and sample_meta[sample]['B12_status'] == 'LB')
            nb_count = sum(1 for sample in mrna_original.columns 
                          if sample_meta[sample]['batch'] == batch and sample_meta[sample]['B12_status'] == 'NB')
            batch_b12_data.append([batch, lb_count, nb_count])
        
        batch_b12_df = pd.DataFrame(batch_b12_data, columns=['Batch', 'LB', 'NB'])
        
        x = np.arange(len(batch_b12_df))
        width = 0.35
        
        ax2.bar(x - width/2, batch_b12_df['LB'], width, label='LB', alpha=0.7, color='red')
        ax2.bar(x + width/2, batch_b12_df['NB'], width, label='NB', alpha=0.7, color='blue')
        
        ax2.set_xlabel('Batch Number')
        ax2.set_ylabel('Number of Samples')
        ax2.set_title('B12 Status Distribution by Batch')
        ax2.set_xticks(x)
        ax2.set_xticklabels(batch_b12_df['Batch'])
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        # Plot 3: PC1 values by batch (original)
        pc1_orig, _, batches_orig, b12_orig = mrna_pca_orig
        
        batch_pc1_data = {}
        for i, batch in enumerate(batches_orig):
            if batch not in batch_pc1_data:
                batch_pc1_data[batch] = []
            batch_pc1_data[batch].append(pc1_orig[i, 0])
        
        ax3.boxplot([batch_pc1_data[batch] for batch in sorted(batch_pc1_data.keys())], 
                   labels=sorted(batch_pc1_data.keys()))
        ax3.set_xlabel('Batch Number')
        ax3.set_ylabel('PC1 Score')
        ax3.set_title('PC1 Distribution by Batch (Original)')
        ax3.grid(True, alpha=0.3)
        
        # Plot 4: PC1 values by batch (corrected)
        pc1_corr, _, batches_corr, b12_corr = mrna_pca_corr
        
        batch_pc1_data_corr = {}
        for i, batch in enumerate(batches_corr):
            if batch not in batch_pc1_data_corr:
                batch_pc1_data_corr[batch] = []
            batch_pc1_data_corr[batch].append(pc1_corr[i, 0])
        
        ax4.boxplot([batch_pc1_data_corr[batch] for batch in sorted(batch_pc1_data_corr.keys())], 
                   labels=sorted(batch_pc1_data_corr.keys()))
        ax4.set_xlabel('Batch Number')
        ax4.set_ylabel('PC1 Score')
        ax4.set_title('PC1 Distribution by Batch (Corrected)')
        ax4.grid(True, alpha=0.3)
        
        plt.tight_layout()
        
        # Save figure
        output_path = os.path.join(figure_dir, '1_Batch_Effect_Summary.png')
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"✓ Saved: 1_Batch_Effect_Summary.png")
    
    plot_batch_effect_summary()
    
    print("\nStep 6: Generating Figure 6 - Quality Metrics Comparison...")
    
    # Figure 6: Quality metrics comparison
    def plot_quality_metrics():
        
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Data Quality Metrics Comparison', fontsize=16, fontweight='bold')
        
        # Plot 1: Correlation improvement
        corr_metrics = ['Mean Correlation', 'Min Correlation', 'Max Correlation']
        original_vals = [corr_stats_orig['mean'], corr_stats_orig['min'], corr_stats_orig['max']]
        corrected_vals = [corr_stats_corr['mean'], corr_stats_corr['min'], corr_stats_corr['max']]
        
        x = np.arange(len(corr_metrics))
        width = 0.35
        
        ax1.bar(x - width/2, original_vals, width, label='Original', alpha=0.7)
        ax1.bar(x + width/2, corrected_vals, width, label='Corrected', alpha=0.7)
        
        ax1.set_ylabel('Correlation Coefficient')
        ax1.set_title('Sample Correlation Metrics')
        ax1.set_xticks(x)
        ax1.set_xticklabels(corr_metrics, rotation=45)
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Plot 2: PC variance explained
        pc_labels = ['PC1', 'PC2', 'PC1+PC2', 'PC1-10']
        pc1_orig, pc2_orig = mrna_pca_orig[1][0], mrna_pca_orig[1][1]
        pc1_corr, pc2_corr = mrna_pca_corr[1][0], mrna_pca_corr[1][1]
        
        original_pc = [pc1_orig*100, pc2_orig*100, (pc1_orig+pc2_orig)*100, sum(mrna_pca_orig[1][:10])*100]
        corrected_pc = [pc1_corr*100, pc2_corr*100, (pc1_corr+pc2_corr)*100, sum(mrna_pca_corr[1][:10])*100]
        
        x = np.arange(len(pc_labels))
        ax2.bar(x - width/2, original_pc, width, label='Original', alpha=0.7)
        ax2.bar(x + width/2, corrected_pc, width, label='Corrected', alpha=0.7)
        
        ax2.set_ylabel('Variance Explained (%)')
        ax2.set_title('Principal Component Variance')
        ax2.set_xticks(x)
        ax2.set_xticklabels(pc_labels)
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        # Plot 3: Detection rates
        mrna_detection = (mrna_original > 0).mean(axis=1)
        mirna_detection = (mirna_original > 0).mean(axis=1)
        
        detection_thresholds = [0.25, 0.5, 0.75, 0.9]
        mrna_counts = [(mrna_detection > t).sum() for t in detection_thresholds]
        mirna_counts = [(mirna_detection > t).sum() for t in detection_thresholds]
        
        x = np.arange(len(detection_thresholds))
        ax3.bar(x - width/2, mrna_counts, width, label='mRNA', alpha=0.7)
        ax3.bar(x + width/2, mirna_counts, width, label='miRNA', alpha=0.7)
        
        ax3.set_ylabel('Number of Features')
        ax3.set_title('Feature Detection Rates')
        ax3.set_xticks(x)
        ax3.set_xticklabels([f'>{int(t*100)}%' for t in detection_thresholds])
        ax3.legend()
        ax3.grid(True, alpha=0.3)
        
        # Plot 4: Expression range
        data_types = ['mRNA', 'miRNA']
        max_vals = [mrna_original.max().max(), mirna_original.max().max()]
        mean_vals = [mrna_original.mean().mean(), mirna_original.mean().mean()]
        
        x = np.arange(len(data_types))
        ax4.bar(x - width/2, max_vals, width, label='Maximum', alpha=0.7)
        ax4.bar(x + width/2, mean_vals, width, label='Mean', alpha=0.7)
        
        ax4.set_ylabel('Expression Level (TPM)')
        ax4.set_title('Expression Level Ranges')
        ax4.set_xticks(x)
        ax4.set_xticklabels(data_types)
        ax4.legend()
        ax4.set_yscale('log')
        ax4.grid(True, alpha=0.3)
        
        plt.tight_layout()
        
        # Save figure
        output_path = os.path.join(figure_dir, '1_Quality_Metrics.png')
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"✓ Saved: 1_Quality_Metrics.png")
    
    plot_quality_metrics()
    
    print("\n" + "="*60)
    print("🎉 ALL QUALITY ASSESSMENT FIGURES GENERATED!")
    print("="*60)
    
    print(f"\nFigures saved to: {figure_dir}")
    print("Generated files:")
    print("  1. 1_mRNA_PCA_Original.png - PCA before batch correction")
    print("  2. 1_mRNA_PCA_Corrected.png - PCA after batch correction") 
    print("  3. 1_Sample_Correlation_Original.png - Correlation heatmap (original)")
    print("  4. 1_Sample_Correlation_Corrected.png - Correlation heatmap (corrected)")
    print("  5. 1_Batch_Effect_Summary.png - Batch effect analysis")
    print("  6. 1_Quality_Metrics.png - Quality metrics comparison")
    
    print("\n💡 Usage in Chapter 3.4:")
    print("  - Figure 1 & 2: Show PCA results before/after correction")
    print("  - Figure 3 & 4: Demonstrate correlation improvement")
    print("  - Figure 5: Illustrate batch effect patterns and correction")
    print("  - Figure 6: Summarize overall quality improvements")
    
    # Print some key statistics for text
    print("\n📊 Key Statistics for Text:")
    print(f"  Original mean correlation: {corr_stats_orig['mean']:.3f}")
    print(f"  Corrected mean correlation: {corr_stats_corr['mean']:.3f}")
    print(f"  Correlation improvement: {(corr_stats_corr['mean']/corr_stats_orig['mean']-1)*100:.1f}%")
    print(f"  Original PC1 variance: {mrna_pca_orig[1][0]*100:.1f}%")
    print(f"  Corrected PC1 variance: {mrna_pca_corr[1][0]*100:.1f}%")

# Run the figure generation
if __name__ == "__main__":
    generate_quality_assessment_figures()

=== Generating Quality Assessment Figures ===
✓ Figure directory created: /Users/heweilin/Desktop/P056_Code/Figure

Step 1: Loading data...
✓ All data loaded successfully

Step 2: Preparing metadata...
✓ Sample metadata prepared

Step 3: Generating Figure 1 - PCA Before Batch Correction...
✓ Saved: 1_mRNA_PCA_Original.png
✓ Saved: 1_mRNA_PCA_Corrected.png

Step 4: Generating Figure 3 - Sample Correlation Heatmap...
✓ Saved: 1_Sample_Correlation_Original.png
✓ Saved: 1_Sample_Correlation_Corrected.png

Step 5: Generating Figure 5 - Batch Effect Summary...


  ax3.boxplot([batch_pc1_data[batch] for batch in sorted(batch_pc1_data.keys())],
  ax4.boxplot([batch_pc1_data_corr[batch] for batch in sorted(batch_pc1_data_corr.keys())],


✓ Saved: 1_Batch_Effect_Summary.png

Step 6: Generating Figure 6 - Quality Metrics Comparison...
✓ Saved: 1_Quality_Metrics.png

🎉 ALL QUALITY ASSESSMENT FIGURES GENERATED!

Figures saved to: /Users/heweilin/Desktop/P056_Code/Figure
Generated files:
  1. 1_mRNA_PCA_Original.png - PCA before batch correction
  2. 1_mRNA_PCA_Corrected.png - PCA after batch correction
  3. 1_Sample_Correlation_Original.png - Correlation heatmap (original)
  4. 1_Sample_Correlation_Corrected.png - Correlation heatmap (corrected)
  5. 1_Batch_Effect_Summary.png - Batch effect analysis
  6. 1_Quality_Metrics.png - Quality metrics comparison

💡 Usage in Chapter 3.4:
  - Figure 1 & 2: Show PCA results before/after correction
  - Figure 3 & 4: Demonstrate correlation improvement
  - Figure 5: Illustrate batch effect patterns and correction
  - Figure 6: Summarize overall quality improvements

📊 Key Statistics for Text:
  Original mean correlation: 0.114
  Corrected mean correlation: 0.114
  Correlation improvem

In [63]:
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy import stats

# 设置输出目录
output_data_dir = "/Users/heweilin/Desktop/P056_Code/Data_Post_Processed"
output_figure_dir = "/Users/heweilin/Desktop/P056_Code/Figure"
data_dir = "/Users/heweilin/Desktop/P056_Code/Data"

# 创建输出目录
os.makedirs(output_data_dir, exist_ok=True)
os.makedirs(output_figure_dir, exist_ok=True)

# 读取临床数据
print("=" * 80)
print("正在读取临床数据...")
clinical_data = pd.read_csv(os.path.join(data_dir, "7Clinical_data50.csv"))
print(f"临床数据形状: {clinical_data.shape}")
print(f"临床数据列名: {clinical_data.columns.tolist()}")

# 寻找B12相关列
b12_columns = [col for col in clinical_data.columns if 'B12' in str(col) or 'b12' in str(col) or 'B-12' in str(col)]
print(f"B12相关列: {b12_columns}")

# 查看前几行数据了解结构
print("\n临床数据前5行:")
print(clinical_data.head())

def get_b12_groups(clinical_data):
    """
    从临床数据中提取B12分组信息
    """
    # 寻找B12列
    b12_columns = [col for col in clinical_data.columns if 'B12' in str(col) or 'b12' in str(col)]
    
    if not b12_columns:
        # 如果没有找到明确的B12列，查看所有数值列
        print("未找到明确的B12列，显示所有列供参考:")
        for i, col in enumerate(clinical_data.columns):
            print(f"  {i}: {col}")
        
        # 尝试使用第一个数值列
        numeric_cols = clinical_data.select_dtypes(include=[np.number]).columns.tolist()
        if numeric_cols:
            b12_column = numeric_cols[0]
            print(f"使用数值列 '{b12_column}' 作为B12数据")
        else:
            raise ValueError("无法找到合适的B12数据列")
    else:
        b12_column = b12_columns[0]
        print(f"使用B12列: {b12_column}")
    
    return b12_column

def perform_degs_pca_analysis(data_file, data_type, tpm_file):
    """
    使用差异表达数据进行PCA分析
    
    Parameters:
    - data_file: 差异表达结果文件 (DEGs或DEmiRs)
    - data_type: 数据类型 (mRNA 或 miRNA)
    - tpm_file: 对应的TPM表达量文件
    """
    print("=" * 80)
    print(f"开始 {data_type} 差异表达数据 PCA 分析")
    print("=" * 80)
    
    # 读取差异表达结果
    deg_data = pd.read_csv(data_file)
    print(f"差异表达数据形状: {deg_data.shape}")
    print(f"差异表达数据列名: {deg_data.columns.tolist()}")
    
    # 读取TPM数据
    if tpm_file.endswith('.xlsx'):
        tpm_data = pd.read_excel(tpm_file)
    else:
        tpm_data = pd.read_csv(tpm_file)
    
    print(f"TPM数据形状: {tpm_data.shape}")
    print(f"TPM数据列名前10个: {tpm_data.columns.tolist()[:10]}")
    
    # 识别显著差异表达的基因/miRNA
    if data_type == "mRNA":
        # 对于mRNA，通常使用padj < 0.05
        if 'padj' in deg_data.columns:
            significant_mask = (deg_data['padj'] < 0.05) & (deg_data['padj'].notna())
            print(f"使用padj < 0.05筛选显著DEGs")
        elif 'pvalue' in deg_data.columns:
            significant_mask = (deg_data['pvalue'] < 0.05) & (deg_data['pvalue'].notna())
            print(f"使用pvalue < 0.05筛选显著DEGs")
        else:
            # 寻找p值相关列
            p_cols = [col for col in deg_data.columns if 'p' in col.lower()]
            print(f"可用的p值列: {p_cols}")
            if p_cols:
                p_col = p_cols[0]
                significant_mask = (deg_data[p_col] < 0.05) & (deg_data[p_col].notna())
                print(f"使用 {p_col} < 0.05筛选显著DEGs")
            else:
                print("未找到p值列，使用所有基因")
                significant_mask = pd.Series([True] * len(deg_data))
    
    elif data_type == "miRNA":
        # 对于miRNA，根据readme使用pvalue
        if 'pvalue' in deg_data.columns:
            significant_mask = (deg_data['pvalue'] < 0.05) & (deg_data['pvalue'].notna())
            print(f"使用pvalue < 0.05筛选显著DEmiRs")
        elif 'padj' in deg_data.columns:
            significant_mask = (deg_data['padj'] < 0.05) & (deg_data['padj'].notna())
            print(f"使用padj < 0.05筛选显著DEmiRs")
        else:
            # 寻找p值相关列
            p_cols = [col for col in deg_data.columns if 'p' in col.lower()]
            print(f"可用的p值列: {p_cols}")
            if p_cols:
                p_col = p_cols[0]
                significant_mask = (deg_data[p_col] < 0.05) & (deg_data[p_col].notna())
                print(f"使用 {p_col} < 0.05筛选显著DEmiRs")
            else:
                print("未找到p值列，使用所有miRNA")
                significant_mask = pd.Series([True] * len(deg_data))
    
    # 获取显著差异表达的基因/miRNA列表
    significant_degs = deg_data[significant_mask]
    print(f"显著差异表达的{data_type}数量: {len(significant_degs)}")
    
    if len(significant_degs) == 0:
        print(f"警告: 没有找到显著的差异表达{data_type}!")
        return None, None
    
    # 获取基因/miRNA ID列
    if data_type == "mRNA":
        # 通常第一列是基因ID，或者寻找包含'gene'的列
        id_cols = [col for col in significant_degs.columns if any(x in col.lower() for x in ['gene', 'id', 'symbol', 'ensembl'])]
        if id_cols:
            gene_id_col = id_cols[0]
        else:
            gene_id_col = significant_degs.columns[0]
        
        significant_ids = significant_degs[gene_id_col].tolist()
        print(f"使用列 '{gene_id_col}' 作为基因ID")
        
    elif data_type == "miRNA":
        # miRNA通常第一列是miRNA ID
        mirna_id_col = significant_degs.columns[0]
        significant_ids = significant_degs[mirna_id_col].tolist()
        print(f"使用列 '{mirna_id_col}' 作为miRNA ID")
    
    print(f"前5个显著{data_type} IDs: {significant_ids[:5]}")
    
    # 从TPM数据中提取显著差异表达的基因/miRNA
    # TPM数据通常第一列是ID，其余列是样本
    tpm_data_indexed = tpm_data.set_index(tpm_data.columns[0])
    
    # 匹配显著差异表达的基因/miRNA
    matched_ids = []
    for sig_id in significant_ids:
        if sig_id in tpm_data_indexed.index:
            matched_ids.append(sig_id)
    
    print(f"在TPM数据中找到的显著{data_type}数量: {len(matched_ids)}")
    
    if len(matched_ids) == 0:
        print(f"错误: TPM数据中没有找到任何显著差异表达的{data_type}!")
        return None, None
    
    # 提取匹配的TPM数据
    selected_tpm_data = tpm_data_indexed.loc[matched_ids]
    
    # 转置数据，使样本为行，基因/miRNA为列
    pca_input_data = selected_tpm_data.T
    sample_names = pca_input_data.index.tolist()
    
    print(f"PCA输入数据形状 (样本 x 特征): {pca_input_data.shape}")
    print(f"样本名称前5个: {sample_names[:5]}")
    
    # 获取B12分组信息
    b12_column = get_b12_groups(clinical_data)
    
    # 匹配样本与临床数据
    matched_samples = []
    matched_clinical_data = []
    
    for sample in sample_names:
        # 尝试不同的匹配策略
        clinical_match = None
        
        # 直接匹配
        for idx, row in clinical_data.iterrows():
            if sample == str(row.iloc[0]) or sample in str(row.iloc[0]) or str(row.iloc[0]) in sample:
                clinical_match = row
                break
        
        # 尝试在所有列中匹配
        if clinical_match is None:
            for idx, row in clinical_data.iterrows():
                if any(sample == str(val) or sample in str(val) or str(val) in sample 
                       for val in row if pd.notna(val)):
                    clinical_match = row
                    break
        
        if clinical_match is not None:
            matched_samples.append(sample)
            matched_clinical_data.append(clinical_match)
    
    print(f"成功匹配临床数据的样本数: {len(matched_samples)} / {len(sample_names)}")
    
    if len(matched_samples) == 0:
        print("错误: 没有样本能够与临床数据匹配!")
        return None, None
    
    # 获取匹配样本的数据
    matched_pca_data = pca_input_data.loc[matched_samples]
    matched_clinical_df = pd.DataFrame(matched_clinical_data, index=matched_samples)
    
    # 获取B12值和分组
    b12_values = matched_clinical_df[b12_column].values
    print(f"B12值统计: min={np.min(b12_values):.2f}, max={np.max(b12_values):.2f}, mean={np.mean(b12_values):.2f}, std={np.std(b12_values):.2f}")
    
    # 基于中位数创建B12分组
    b12_median = np.median(b12_values)
    b12_groups = ['High_B12' if x >= b12_median else 'Low_B12' for x in b12_values]
    group_counts = pd.Series(b12_groups).value_counts()
    
    print(f"B12分组 (基于中位数 {b12_median:.2f}):")
    for group, count in group_counts.items():
        print(f"  {group}: {count} 样本")
    
    # 数据预处理和PCA
    print(f"\n使用{len(matched_ids)}个显著差异表达的{data_type}进行PCA分析...")
    
    # 处理缺失值
    pca_data_clean = matched_pca_data.fillna(0)
    
    # 移除方差为0的特征
    feature_vars = pca_data_clean.var()
    valid_features = feature_vars[feature_vars > 0].index
    pca_data_final = pca_data_clean[valid_features]
    
    print(f"移除零方差特征后的数据形状: {pca_data_final.shape}")
    
    # 标准化
    scaler = StandardScaler()
    data_scaled = scaler.fit_transform(pca_data_final)
    
    # PCA分析
    pca = PCA()
    pca_result = pca.fit_transform(data_scaled)
    
    # 方差解释比例
    explained_variance_ratio = pca.explained_variance_ratio_
    cumulative_variance = np.cumsum(explained_variance_ratio)
    
    print(f"\n{data_type} PCA方差解释结果:")
    print("-" * 40)
    for i in range(min(10, len(explained_variance_ratio))):
        print(f"  PC{i+1}: {explained_variance_ratio[i]:.4f} ({explained_variance_ratio[i]*100:.2f}%)")
    
    print(f"\n累积方差解释:")
    print(f"  前2个主成分: {cumulative_variance[1]:.4f} ({cumulative_variance[1]*100:.2f}%)")
    print(f"  前5个主成分: {cumulative_variance[4]:.4f} ({cumulative_variance[4]*100:.2f}%)")
    
    # 统计检验: B12组间在主成分上的差异
    print(f"\n{data_type} B12组间主成分差异检验:")
    print("-" * 50)
    
    significant_pcs = []
    for i in range(min(5, pca_result.shape[1])):
        pc_name = f'PC{i+1}'
        high_b12_pc = pca_result[np.array(b12_groups) == 'High_B12', i]
        low_b12_pc = pca_result[np.array(b12_groups) == 'Low_B12', i]
        
        # t检验
        t_stat, p_value = stats.ttest_ind(high_b12_pc, low_b12_pc)
        
        print(f"{pc_name} (解释方差: {explained_variance_ratio[i]*100:.2f}%):")
        print(f"  High_B12组: mean={np.mean(high_b12_pc):.4f}, std={np.std(high_b12_pc):.4f} (n={len(high_b12_pc)})")
        print(f"  Low_B12组:  mean={np.mean(low_b12_pc):.4f}, std={np.std(low_b12_pc):.4f} (n={len(low_b12_pc)})")
        print(f"  t检验: t={t_stat:.4f}, p={p_value:.4f}")
        
        if p_value < 0.05:
            print(f"  *** {pc_name} 显示B12组间显著分离 (p < 0.05) ***")
            significant_pcs.append(pc_name)
        elif p_value < 0.1:
            print(f"  ** {pc_name} 显示B12组间边际显著分离 (p < 0.1) **")
        else:
            print(f"  {pc_name} B12组间无显著分离 (p >= 0.1)")
        print()
    
    # 创建结果DataFrame
    pca_df = pd.DataFrame(
        pca_result[:, :10],  # 保存前10个主成分
        columns=[f'PC{i+1}' for i in range(min(10, pca_result.shape[1]))],
        index=matched_samples
    )
    pca_df['B12_Group'] = b12_groups
    pca_df['B12_Value'] = b12_values
    pca_df[f'Significant_{data_type}_Count'] = len(matched_ids)
    
    # 分析总结
    print(f"\n{data_type} PCA分析总结:")
    print("=" * 50)
    print(f"✅ 使用了 {len(matched_ids)} 个显著差异表达的{data_type}")
    print(f"✅ 分析了 {len(matched_samples)} 个样本")
    print(f"✅ B12分组: High_B12 ({group_counts.get('High_B12', 0)}个) vs Low_B12 ({group_counts.get('Low_B12', 0)}个)")
    print(f"✅ 前2个主成分解释了 {cumulative_variance[1]*100:.1f}% 的方差")
    
    if significant_pcs:
        print(f"✅ 发现B12组间显著分离的主成分: {', '.join(significant_pcs)}")
        print("🎯 结论: DEGs/DEmiRs显示B12组间存在分离趋势")
    else:
        print("ℹ️  主成分分析未显示显著的B12组间分离")
        print("🎯 结论: 需要更多数据或其他方法来观察B12效应")
    
    # 保存结果
    output_file = os.path.join(output_data_dir, f"1_{data_type}_DEGs_PCA_results.csv")
    pca_df.to_csv(output_file)
    print(f"\n📁 PCA结果已保存: {output_file}")
    
    # 保存方差解释
    variance_df = pd.DataFrame({
        'Principal_Component': [f'PC{i+1}' for i in range(len(explained_variance_ratio))],
        'Explained_Variance_Ratio': explained_variance_ratio,
        'Cumulative_Variance': cumulative_variance
    })
    variance_file = os.path.join(output_data_dir, f"1_{data_type}_DEGs_PCA_variance.csv")
    variance_df.to_csv(variance_file, index=False)
    print(f"📁 方差解释已保存: {variance_file}")
    
    # 生成可视化代码模板
    print(f"\n📊 {data_type} PCA可视化代码模板:")
    print("-" * 40)
    print(f"""
# {data_type} DEGs PCA 可视化
plt.figure(figsize=(12, 8))

# 设置颜色
colors = {{'High_B12': '#FF6B6B', 'Low_B12': '#4ECDC4'}}

# PCA散点图
for group in ['High_B12', 'Low_B12']:
    mask = np.array(b12_groups) == group
    plt.scatter(pca_result[mask, 0], pca_result[mask, 1], 
               c=colors[group], label=f'{{group}} (n={{sum(mask)}})', 
               alpha=0.7, s=60, edgecolor='white', linewidth=0.5)

plt.xlabel(f'PC1 ({explained_variance_ratio[0]*100:.1f}% variance)')
plt.ylabel(f'PC2 ({explained_variance_ratio[1]*100:.1f}% variance)')
plt.title(f'{data_type} PCA Analysis\\nBased on {{len(matched_ids)}} Significant DE{data_type}')
plt.legend(frameon=True, fancybox=True, shadow=True)
plt.grid(True, alpha=0.3)

# 添加统计信息
plt.text(0.02, 0.98, f'Total variance explained: {{cumulative_variance[1]*100:.1f}}%', 
         transform=plt.gca().transAxes, verticalalignment='top',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

plt.tight_layout()
plt.savefig(os.path.join(output_figure_dir, f'1_{{data_type}}_DEGs_PCA.png'), 
           dpi=300, bbox_inches='tight')
plt.close()
    """)
    
    return pca_df, variance_df

# 执行分析
print("开始执行基于差异表达数据的PCA分析...")
print("🔬 使用显著DEGs和DEmiRs观察B12组间分离趋势")

# 1. mRNA DEGs PCA分析
print("\n" + "🧬" * 20)
try:
    deg_file = os.path.join(data_dir, "1mRNA_DEGs_proteincoding.csv")
    tpm_file = os.path.join(data_dir, "5mRNA_TPM.csv")
    
    mrna_pca, mrna_variance = perform_degs_pca_analysis(deg_file, "mRNA", tpm_file)
    
except Exception as e:
    print(f"❌ mRNA DEGs PCA分析出错: {e}")
    import traceback
    traceback.print_exc()

# 2. miRNA DEmiRs PCA分析  
print("\n" + "🧬" * 20)
try:
    demirs_file = os.path.join(data_dir, "2miRNA_DEmirs.csv")
    mirna_tpm_file = os.path.join(data_dir, "6miRNA_TPM.csv")
    
    mirna_pca, mirna_variance = perform_degs_pca_analysis(demirs_file, "miRNA", mirna_tpm_file)
    
except Exception as e:
    print(f"❌ miRNA DEmiRs PCA分析出错: {e}")
    import traceback
    traceback.print_exc()

# 最终总结
print("\n" + "🎯" * 30)
print("PCA分析完成总结")
print("🎯" * 30)
print("✅ 完成了基于差异表达数据的PCA分析")
print("✅ 分别分析了mRNA DEGs和miRNA DEmiRs")
print("✅ 评估了B12高低组在主成分空间的分离情况")
print("✅ 进行了统计检验验证组间差异显著性")
print(f"📁 所有结果保存在: {output_data_dir}")
print(f"🖼️  图片将保存在: {output_figure_dir}")
print("\n📊 分析结果包括:")
print("   - PCA坐标矩阵")
print("   - 方差解释比例")
print("   - B12分组信息")
print("   - 统计检验结果")
print("   - 可视化代码模板")

正在读取临床数据...
临床数据形状: (50, 21)
临床数据列名: ['NTUID', 'DNA_ID', 'mRNA_ID', 'sRNA_ID', 'age', 'BMI', 'BMI_cat', 'B12_status', 'B12_mol', 'B12supplem', 'ethnicity', 'parity', 'V1_se_EmploymentStatus', 'V1_se_HouseholdIncome', 'v1p_MultivitTab', 'v1p_FolicAcid', 'smoking', 'v3n_Gender', 'batch_miRNA', 'batch_mRNA', 'batch_DNA']
B12相关列: ['B12_status', 'B12_mol', 'B12supplem']

临床数据前5行:
   NTUID DNA_ID mRNA_ID sRNA_ID   age        BMI     BMI_cat B12_status  \
0    102  P102d   P102m   P102s  36.5  20.703735     healthy         NB   
1    105  P105d   P105m   P105s  33.5  38.404033      obese2         NB   
2    111  P111d   P111m   P111s  32.6  22.519433     healthy         NB   
3    113  P113d   P113m   P113s  31.5  25.970116  overweight         NB   
4    117  P117d   P117m   P117s  37.0  26.370238  overweight         LB   

    B12_mol  B12supplem  ... parity  V1_se_EmploymentStatus  \
0  292.6170         1.0  ...      1                     NaN   
1  362.2104         1.0  ...      1          

Traceback (most recent call last):
  File "/var/folders/6r/vh9hm9y50sng_kp683fc7t8m0000gn/T/ipykernel_13045/3724079657.py", line 394, in <module>
    mirna_pca, mirna_variance = perform_degs_pca_analysis(demirs_file, "miRNA", mirna_tpm_file)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/6r/vh9hm9y50sng_kp683fc7t8m0000gn/T/ipykernel_13045/3724079657.py", line 223, in perform_degs_pca_analysis
    print(f"B12值统计: min={np.min(b12_values):.2f}, max={np.max(b12_values):.2f}, mean={np.mean(b12_values):.2f}, std={np.std(b12_values):.2f}")
                           ^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Unknown format code 'f' for object of type 'str'
