# Ambient RNA Removal Methods Comparison
## Quick Demo: DecontX, FastCAR, and CellBender

**Meeting prep for geometric data analysis research**

Methods covered:
- ✅ SoupX (already done by partner)
- 🆕 DecontX (cluster-based contamination)
- 🆕 FastCAR (sample-specific, DGE-optimized)
- 📋 CellBender (reference implementation)

---

## Setup and Imports

In [None]:
# Core libraries
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.sparse import csr_matrix, issparse
from scipy.stats import pearsonr
import warnings
warnings.filterwarnings(action='ignore')  # suppress all warnings to keep notebook output short

# Figure/logging defaults: quiet scanpy logging, scanpy figure params, seaborn grid
sc.settings.verbosity = 1
sc.settings.set_figure_params(dpi=80, facecolor='white', frameon=False)
sns.set_style(style='whitegrid')

startup_messages = (
    "✅ Libraries imported successfully!",
    f"Scanpy version: {sc.__version__}",
)
for message in startup_messages:
    print(message)

## 1. Generate Synthetic Data with Known Ground Truth

This is **critical** for validation - we know the true signal and true contamination!

In [None]:
def generate_synthetic_contaminated(
    n_cells=1000,
    n_genes=2000,
    n_cell_types=5,
    contamination_rate=0.15,
    seed=42,
    n_empty=200,
):
    """
    Generate synthetic scRNA-seq data with known ambient RNA contamination.

    Each cell n gets a contamination fraction ``phi_n`` drawn from a Beta
    distribution whose mean is ``contamination_rate``. Its observed counts are
    ``Poisson(native) + Poisson(ambient)``, with the ambient part scaled by
    ``phi_n / (1 - phi_n)`` times the cell's native total. In expectation,
    ``phi_n`` is therefore the fraction of the cell's *observed* UMIs that are
    ambient. This is the same mixture fraction that DecontX/FastCAR-style
    methods estimate (``theta = (1 - phi) * native + phi * ambient``).

    Parameters
    ----------
    n_cells, n_genes, n_cell_types : int
        Dataset dimensions (all must be >= 1).
    contamination_rate : float
        Mean per-cell contamination fraction, in [0, 1).
    seed : int
        Seed passed to ``np.random.seed`` (this reseeds NumPy's *global* RNG).
    n_empty : int
        Number of empty droplets (pure ambient counts) appended after the cells.

    Returns
    -------
    AnnData with:
    - X: Observed counts (signal + noise); cells first, then empty droplets
    - obs['cell_type'], obs['is_cell']
    - obs['true_contamination']: True contamination fraction (1.0 for empties)
    - uns['true_signal']: Ground truth (expected, pre-Poisson) signal for cells
    - uns['ambient_profile']: True (unnormalized) ambient profile
    - uns['true_profiles']: Per-cell-type expression profiles

    Raises
    ------
    ValueError
        On non-positive dimensions, negative ``n_empty`` or a rate outside [0, 1).
    """
    if n_cells < 1 or n_genes < 1 or n_cell_types < 1:
        raise ValueError("n_cells, n_genes and n_cell_types must all be >= 1")
    if not 0 <= contamination_rate < 1:
        raise ValueError(f"contamination_rate must be in [0, 1), got {contamination_rate}")
    if n_empty < 0:
        raise ValueError(f"n_empty must be >= 0, got {n_empty}")

    np.random.seed(seed)

    print(f"🧬 Generating synthetic data...")
    print(f"   Cells: {n_cells}, Genes: {n_genes}, Cell types: {n_cell_types}")
    print(f"   Contamination rate: {contamination_rate:.1%}")

    # Assign cell types
    cell_types = np.random.choice(n_cell_types, n_cells)

    # Distinct cell-type profiles: a block of high-expression marker genes per
    # type plus Gamma-distributed background expression on every gene.
    block = n_genes // n_cell_types
    profiles = np.zeros((n_cell_types, n_genes))
    for ct in range(n_cell_types):
        profiles[ct, ct * block:(ct + 1) * block] = np.random.gamma(5, 2, block)  # High expression
        profiles[ct, :] += np.random.gamma(1, 0.5, n_genes)  # Background expression

    # True biological signal (Poisson rates) with cell-to-cell size variability
    cell_size_factors = np.random.lognormal(0, 0.3, n_cells)
    S_true = profiles[cell_types] * cell_size_factors[:, np.newaxis]

    # Ambient profile: mixture of all cell types, weighted by their abundance
    cell_type_weights = np.bincount(cell_types, minlength=n_cell_types) / n_cells
    ambient = profiles.T @ cell_type_weights
    ambient_freq = ambient / ambient.sum()

    # phi ~ Beta(2, b) with b chosen so that E[phi] = 2 / (2 + b) = contamination_rate
    if contamination_rate > 0:
        beta_b = 2.0 * (1.0 - contamination_rate) / contamination_rate
        contamination_fractions = np.random.beta(2.0, beta_b, n_cells)
    else:
        contamination_fractions = np.zeros(n_cells)
    # phi / (1 - phi) diverges as phi -> 1; cap it.
    contamination_fractions = np.minimum(contamination_fractions, 0.99)

    # Observed counts = native counts + ambient counts, with the ambient total
    # set so that ambient / (native + ambient) == phi in expectation.
    native_totals = S_true.sum(axis=1)
    ambient_totals = contamination_fractions / (1.0 - contamination_fractions) * native_totals
    true_counts = np.random.poisson(S_true)
    ambient_counts = np.random.poisson(ambient_totals[:, np.newaxis] * ambient_freq)
    observed = true_counts + ambient_counts

    # Empty droplets (pure ambient, 10-100 expected UMIs) for ambient estimation
    empty_depths = np.random.uniform(10, 100, (n_empty, 1))
    empty_counts = np.random.poisson(empty_depths * ambient_freq)

    # Combine cells and empty droplets
    all_counts = np.vstack([observed, empty_counts])

    # Create AnnData object
    adata = sc.AnnData(X=csr_matrix(all_counts.astype(int)))
    adata.obs['cell_type'] = [f'Type_{ct}' for ct in cell_types] + ['Empty'] * n_empty
    adata.obs['is_cell'] = [True] * n_cells + [False] * n_empty
    adata.obs['true_contamination'] = list(contamination_fractions) + [1.0] * n_empty
    adata.var_names = [f'Gene_{i}' for i in range(n_genes)]

    # Store ground truth
    adata.uns['true_signal'] = S_true
    adata.uns['ambient_profile'] = ambient
    adata.uns['true_profiles'] = profiles

    print(f"✅ Generated {n_cells} cells + {n_empty} empty droplets")
    print(f"   Mean contamination: {contamination_fractions.mean():.1%}")

    return adata

# Build the synthetic benchmark (these values match the generator's defaults)
adata_synthetic = generate_synthetic_contaminated(
    n_cells=1000, n_genes=2000, n_cell_types=5, contamination_rate=0.15,
)

shape_msg = f"\n📊 Data shape: {adata_synthetic.shape}"
types_msg = f"   Cell types: {adata_synthetic.obs['cell_type'].unique()[:5]}"
print(shape_msg)
print(types_msg)

## 2. Load Real Datasets

In [None]:
# Load PBMC3K.
# `pbmc3k_processed()` carries the Louvain labels and UMAP embedding, but its X
# is log-normalized AND scaled (it contains negative values), which count-based
# ambient-RNA methods cannot work with. So take raw UMI counts from `pbmc3k()`
# for the same cells and copy the annotations across.
print("📥 Loading PBMC3K dataset...")
adata_pbmc_processed = sc.datasets.pbmc3k_processed()
adata_pbmc = sc.datasets.pbmc3k()
adata_pbmc.var_names_make_unique()
adata_pbmc = adata_pbmc[adata_pbmc_processed.obs_names].copy()
sc.pp.filter_genes(adata_pbmc, min_cells=3)  # drop genes detected in < 3 cells
adata_pbmc.obs['louvain'] = adata_pbmc_processed.obs['louvain']  # aligned on obs_names
adata_pbmc.obsm['X_umap'] = adata_pbmc_processed.obsm['X_umap']  # same cell order as above
if 'louvain_colors' in adata_pbmc_processed.uns:
    adata_pbmc.uns['louvain_colors'] = adata_pbmc_processed.uns['louvain_colors']
print(f"   Shape: {adata_pbmc.shape}")
print(f"   Cell types: {adata_pbmc.obs['louvain'].nunique()}")

# These are raw counts, but only for QC-passed cells: there are no empty
# droplets, so ambient estimation here is illustrative only.
print("\n⚠️  Note: PBMC3K is pre-filtered (no empty droplets). For real analysis, use the raw unfiltered matrix!")

# Optional: Load bone marrow data (best effort; failure is non-fatal)
try:
    print("\n📥 Loading bone marrow (Paul15) dataset...")
    adata_bm = sc.datasets.paul15()
    print(f"   Shape: {adata_bm.shape}")
    print(f"   Cell types: {adata_bm.obs['paul15_clusters'].nunique()}")
except Exception as err:  # download/parse errors vary; report instead of hiding them
    print("   ⚠️  Paul15 dataset not available, skipping")
    print(f"      Reason: {type(err).__name__}: {err}")
    adata_bm = None

## 3. Method Implementation: DecontX

**Key idea**: Contamination comes from OTHER cell populations, weighted by cluster size.

Mathematical model:
$$Y_n \sim \text{Multinomial}(N_n, \theta_n)$$
$$\theta_n = (1 - \phi_n)\pi_{c(n)} + \phi_n \eta_n$$

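where $\pi_k$ is the normalized expression profile of cluster $k$, $\phi_n$ is the contamination fraction of cell $n$, and $\eta_n = \sum_{k \neq c(n)} w_k \pi_k$ is the contamination profile: a mixture of the *other* clusters' profiles with weights $w_k$ proportional to cluster size, normalized so that $\sum_{k \neq c(n)} w_k = 1$.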

In [None]:
def _decontx_cluster_profiles(counts, codes, n_clusters, cell_weights):
    """Return per-cluster weighted count sums, each row normalized to sum to ~1."""
    profiles = np.zeros((n_clusters, counts.shape[1]))
    for k in range(n_clusters):
        members = codes == k
        summed = cell_weights[members] @ counts[members]
        profiles[k] = summed / (summed.sum() + 1e-10)
    return profiles


def _decontx_profile_correlation(freqs, codes, profiles):
    """Pearson correlation of each row of `freqs` with the `profiles` row of its cluster.

    Cells whose cluster row does not sum to > 0 score 0; zero-variance inputs
    give NaN (as ``np.corrcoef`` would).
    """
    scores = np.zeros(freqs.shape[0])
    for k, profile in enumerate(profiles):
        members = np.flatnonzero(codes == k)
        if members.size == 0 or not profile.sum() > 0:
            continue
        rows = freqs[members]
        centered_rows = rows - rows.mean(axis=1, keepdims=True)
        centered_profile = profile - profile.mean()
        numerator = centered_rows @ centered_profile
        denominator = np.sqrt((centered_rows ** 2).sum(axis=1) * (centered_profile @ centered_profile))
        with np.errstate(divide='ignore', invalid='ignore'):
            scores[members] = numerator / denominator
    return scores


def decontX_simple(adata, cluster_key='cell_type', max_iter=50, convergence_threshold=1e-4):
    """
    Simplified DecontX-style ambient RNA correction.

    Based on: Yang et al. (2020) Genome Biology 21:57. This is a heuristic
    simplification, not the published model. Each cell's contamination fraction
    ``phi`` (bounded to [0, 0.5]) is ``c / (n + c)``, where ``n`` is the Pearson
    correlation between the cell's expression frequencies and its own cluster's
    profile. ``c`` is the same correlation with the cell's contamination profile:
    a mixture of all OTHER clusters' profiles weighted by cluster size.
    Cluster profiles are then re-estimated with each cell weighted by
    ``1 - phi``, and the two steps alternate until ``phi`` stabilizes.

    Parameters
    ----------
    adata : AnnData
        Input data with raw counts in ``adata.X`` (dense or sparse). If
        ``adata.obs['is_cell']`` exists, only rows where it is True are modelled
        and corrected; other rows (e.g. empty droplets) are copied unchanged.
    cluster_key : str
        Key in adata.obs for cluster labels
    max_iter : int
        Maximum number of alternating estimation iterations.
    convergence_threshold : float
        Stop when the mean absolute change of ``phi`` falls below this value.

    Returns
    -------
    adata_corrected : AnnData
        Copy of ``adata`` whose ``X`` (CSR, floating point) holds the corrected
        counts ``max(0, Y - phi * total * contamination_profile)``, plus
        ``obs['decontX_contamination']`` (estimated ``phi`` for cells, 0.0
        for other rows).

    Raises
    ------
    ValueError
        If there are no cells to correct.
    """
    print("🔬 Running DecontX...")

    # Work in floating point: subtracting fractional contamination from an
    # integer matrix would silently truncate corrected counts towards zero.
    X = adata.X.toarray() if issparse(adata.X) else np.array(adata.X)
    X = X.astype(np.result_type(X.dtype, np.float32), copy=False)

    # Only use cells (not empty droplets)
    if 'is_cell' in adata.obs.columns:
        cell_mask = np.asarray(adata.obs['is_cell'].values, dtype=bool)
    else:
        cell_mask = np.ones(adata.n_obs, dtype=bool)
    X_cells = X[cell_mask]
    labels = np.asarray(adata.obs[cluster_key].values)[cell_mask]

    n_cells, n_genes = X_cells.shape
    if n_cells == 0:
        raise ValueError("decontX_simple: no cells to correct (is 'is_cell' all False?)")
    unique_clusters, codes = np.unique(labels, return_inverse=True)
    codes = codes.ravel()
    n_clusters = len(unique_clusters)

    print(f"   Cells: {n_cells}, Genes: {n_genes}, Clusters: {n_clusters}")

    # mixing[k, j] = weight of cluster j in the contamination profile of a cell
    # from cluster k: proportional to cluster size, excluding k itself. A lone
    # cluster has no other clusters, so its contamination profile stays zero.
    cluster_sizes = np.bincount(codes, minlength=n_clusters).astype(float)
    mixing = np.tile(cluster_sizes, (n_clusters, 1))
    np.fill_diagonal(mixing, 0.0)
    other_totals = mixing.sum(axis=1, keepdims=True)
    mixing = np.divide(mixing, other_totals, out=np.zeros_like(mixing), where=other_totals > 0)

    totals = X_cells.sum(axis=1)
    has_counts = totals > 0  # cells with no counts keep their initial phi
    obs_freq = X_cells / np.where(has_counts, totals, 1)[:, np.newaxis]

    cluster_profiles = _decontx_cluster_profiles(X_cells, codes, n_clusters, np.ones(n_cells))
    phi = np.full(n_cells, 0.1)  # deterministic start (same value as the "undetermined" fallback)

    for iteration in range(max_iter):
        phi_old = phi.copy()

        # E-step: score native vs. contamination profile for every cell
        contamination_profiles = mixing @ cluster_profiles
        native_score = _decontx_profile_correlation(obs_freq, codes, cluster_profiles)
        contam_score = _decontx_profile_correlation(obs_freq, codes, contamination_profiles)
        score_sum = native_score + contam_score
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio = np.clip(contam_score / (score_sum + 1e-10), 0.0, 0.5)
            # NaN or non-positive score sums give the 0.1 fallback
            updated = np.where(score_sum > 0, ratio, 0.1)
        phi = np.where(has_counts, updated, phi)

        # M-step: re-estimate cluster profiles, down-weighting contaminated cells
        cluster_profiles = _decontx_cluster_profiles(X_cells, codes, n_clusters, 1.0 - phi)

        # Check convergence
        change = np.abs(phi - phi_old).mean()
        if change < convergence_threshold:
            print(f"   ✅ Converged at iteration {iteration + 1}")
            break

    # Subtract each cell's estimated contamination using the final profiles
    contamination_profiles = mixing @ cluster_profiles
    correction = (phi * totals)[:, np.newaxis] * contamination_profiles[codes]
    X_corrected = X.copy()
    X_corrected[cell_mask] = np.maximum(X_cells - correction, 0)

    # Create output
    adata_corrected = adata.copy()
    adata_corrected.X = csr_matrix(X_corrected)
    adata_corrected.obs['decontX_contamination'] = 0.0
    adata_corrected.obs.loc[cell_mask, 'decontX_contamination'] = phi

    print(f"   Mean contamination: {phi.mean():.1%}")
    print(f"   Range: {phi.min():.1%} - {phi.max():.1%}")

    return adata_corrected

# Run DecontX on the synthetic benchmark (non-cell rows are copied through unchanged)
adata_decontx = decontX_simple(adata=adata_synthetic, cluster_key="cell_type")

## 4. Method Implementation: FastCAR

**Key idea**: Sample-specific ambient correction optimized for differential expression.

Mathematical model:
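$$S^{(s)}_{ng} = \max\left(0,\; Y^{(s)}_{ng} - \rho^{(s)}_n \, N^{(s)}_n \, A^{(s)}_g\right)$$

where $N^{(s)}_n = \sum_g Y^{(s)}_{ng}$ is the total count of cell $n$, $A^{(s)}_g$ is the ambient profile of sample $s$ normalized to sum to 1, and $\rho^{(s)}_n$ is the cell's estimated contamination fraction.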

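Reference: Berg et al. (2023) "FastCAR: fast correction for ambient RNA to facilitate differential gene expression analysis in single-cell RNA-sequencing datasets", BMC Genomics 24:722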

In [None]:
def fastCAR_correction(adata, empty_droplet_cutoff=100, marker_genes=None,
                       min_marker_genes=11, max_contamination=0.5):
    """
    Simplified FastCAR-style ambient RNA correction.

    Based on: FastCAR (Berg et al., 2023, BMC Genomics). NOTE(review): an
    earlier version of this docstring cited "Muskovic & Powell (2023) BMC
    Genomics 24:1", which looks like a mis-attribution; confirm the citation.

    An ambient profile ``A`` (normalized to sum to 1) is estimated from empty
    droplets. Each remaining cell n gets a contamination fraction ``rho_n``,
    and its counts are corrected as ``max(0, Y_n - rho_n * N_n * A)``, where
    ``N_n`` is the cell's total count. ``rho_n`` is observed / expected ambient
    counts over marker genes when at least ``min_marker_genes`` markers are
    available. Otherwise it is the Pearson correlation between the cell's
    profile and ``A``. Either way it is bounded to ``[0, max_contamination]``,
    and NaN becomes 0.

    Parameters
    ----------
    adata : AnnData
        Input data with raw counts. If ``obs['is_cell']`` exists it defines
        the empty droplets; otherwise droplets with fewer than
        ``empty_droplet_cutoff`` UMIs are treated as empty.
    empty_droplet_cutoff : int
        UMI threshold to identify empty droplets
    marker_genes : list, optional
        Known non-expressed genes for contamination estimation
    min_marker_genes : int
        Minimum number of marker genes needed to use the marker estimate
        (default 11, i.e. "more than 10").
    max_contamination : float
        Upper bound on the estimated per-cell contamination fraction.

    Returns
    -------
    adata_corrected : AnnData
        Copy of ``adata`` with corrected (floating point, CSR) counts,
        ``obs['fastcar_contamination']`` (1.0 for empty droplets) and
        ``uns['ambient_profile']``. Empty-droplet rows are not corrected.
    """
    print("🚗 Running FastCAR...")

    # Floating point: fractional corrections must not be truncated to integers.
    X = adata.X.toarray() if issparse(adata.X) else np.array(adata.X)
    X = X.astype(np.result_type(X.dtype, np.float32), copy=False)
    n_cells, n_genes = X.shape

    # Identify empty droplets
    total_counts = X.sum(axis=1)
    if 'is_cell' in adata.obs.columns:
        empty_mask = ~np.asarray(adata.obs['is_cell'].values, dtype=bool)
    else:
        empty_mask = total_counts < empty_droplet_cutoff

    print(f"   Identified {empty_mask.sum()} empty droplets")

    # Estimate ambient profile from empty droplets
    if empty_mask.any():
        ambient_source = empty_mask
    else:
        print("   ⚠️  No empty droplets found, using lowest 10% of cells")
        # `<=` guarantees at least one droplet is selected
        ambient_source = total_counts <= np.percentile(total_counts, 10)
    ambient_profile = X[ambient_source].sum(axis=0)
    ambient_profile = ambient_profile / (ambient_profile.sum() + 1e-10)

    # Marker selection inputs are identical for every cell: compute them once.
    if marker_genes is not None:
        wanted = set(marker_genes)
        provided_idx = np.array([j for j, g in enumerate(adata.var_names) if g in wanted], dtype=int)
        high_ambient = None
        if len(provided_idx) < min_marker_genes:
            print(f"   ⚠️  Only {len(provided_idx)} provided marker genes found "
                  f"(need {min_marker_genes}); using correlation estimate instead")
    else:
        provided_idx = None
        # Genes in the top 10% of the ambient profile
        high_ambient = ambient_profile > np.percentile(ambient_profile, 90)

    contamination_fractions = np.zeros(n_cells)
    contamination_fractions[empty_mask] = 1.0
    corrected_counts = X.copy()
    cell_mask = ~empty_mask

    for i in np.flatnonzero(cell_mask):
        cell_counts = X[i]
        total = cell_counts.sum()
        if total == 0:
            continue

        cell_freq = cell_counts / (total + 1e-10)

        if provided_idx is not None:
            marker_idx = provided_idx
        else:
            # NOTE(review): most genes have zero counts in a single cell, so the
            # 10th percentile of cell_freq is usually 0 and this strict `<`
            # selects no genes; the correlation fallback is then the path
            # actually taken. Kept to preserve the original estimator.
            low_cell = cell_freq < np.percentile(cell_freq, 10)
            marker_idx = np.flatnonzero(high_ambient & low_cell)

        if len(marker_idx) >= min_marker_genes:
            # Observed vs. expected-if-fully-ambient counts on the marker genes
            observed_marker = cell_counts[marker_idx].sum()
            expected_marker = total * ambient_profile[marker_idx].sum()
            rho = observed_marker / (expected_marker + 1e-10)
        else:
            rho = np.corrcoef(cell_freq, ambient_profile)[0, 1]

        # NaN (zero-variance input) -> 0; otherwise bound to [0, max_contamination]
        rho = 0.0 if np.isnan(rho) else float(np.clip(rho, 0.0, max_contamination))
        contamination_fractions[i] = rho

        # Correct counts: S = max(0, Y - rho * A * total)
        corrected_counts[i] = np.maximum(cell_counts - rho * total * ambient_profile, 0)

    # Create output
    adata_corrected = adata.copy()
    adata_corrected.X = csr_matrix(corrected_counts)
    adata_corrected.obs['fastcar_contamination'] = contamination_fractions
    adata_corrected.uns['ambient_profile'] = ambient_profile

    if cell_mask.any():
        cell_fracs = contamination_fractions[cell_mask]
        print(f"   Mean contamination (cells only): {cell_fracs.mean():.1%}")
        print(f"   Range: {cell_fracs.min():.1%} - {cell_fracs.max():.1%}")
    else:
        print("   ⚠️  No cells left after removing empty droplets")

    return adata_corrected

# Run FastCAR on the synthetic benchmark; 'is_cell' marks the empty droplets it uses for the ambient profile
adata_fastcar = fastCAR_correction(adata=adata_synthetic)

## 5. Compare Methods on Synthetic Data

**This is the gold standard** - we know the true signal!

In [None]:
def evaluate_correction(adata_corrected, adata_original, method_name, contamination_key=None):
    """
    Evaluate correction quality against the synthetic ground truth.

    Only rows with ``adata_original.obs['is_cell'] == True`` are scored. Both
    objects are assumed to have the same row order.

    Parameters
    ----------
    adata_corrected : AnnData
        Output of a correction method.
    adata_original : AnnData
        Synthetic input carrying ``uns['true_signal']`` and
        ``obs['true_contamination']``.
    method_name : str
        Label stored under ``'Method'`` in the result.
    contamination_key : str, optional
        Column of ``adata_corrected.obs`` holding per-cell contamination
        estimates. When None, ``'decontX_contamination'`` and then
        ``'fastcar_contamination'`` are tried; if neither exists the
        contamination correlation is None.

    Returns
    -------
    dict
        RMSE and Pearson correlation of corrected/observed counts vs. the true
        signal, the relative RMSE improvement (%, NaN if the observed RMSE is
        0), and the correlation of estimated vs. true contamination.
    """
    cell_mask = np.asarray(adata_original.obs['is_cell'].values, dtype=bool)

    def cell_rows(adata_obj):
        # Dense (cells x genes) copy of X restricted to real cells
        rows = adata_obj.X[cell_mask]
        return rows.toarray() if issparse(rows) else np.asarray(rows)

    corrected = cell_rows(adata_corrected)
    observed = cell_rows(adata_original)
    true_signal = adata_original.uns['true_signal']

    # 1. RMSE (lower is better)
    rmse_corrected = np.sqrt(np.mean((corrected - true_signal) ** 2))
    rmse_observed = np.sqrt(np.mean((observed - true_signal) ** 2))
    if rmse_observed > 0:
        rmse_improvement = (rmse_observed - rmse_corrected) / rmse_observed * 100
    else:
        rmse_improvement = np.nan

    # 2. Pearson correlation (higher is better)
    corr_corrected = pearsonr(corrected.ravel(), true_signal.ravel())[0]
    corr_observed = pearsonr(observed.ravel(), true_signal.ravel())[0]

    # 3. Contamination estimation accuracy
    true_contam = adata_original.obs.loc[cell_mask, 'true_contamination'].values
    if contamination_key is None:
        contamination_key = next(
            (key for key in ('decontX_contamination', 'fastcar_contamination')
             if key in adata_corrected.obs.columns),
            None,
        )
    if contamination_key is not None:
        est_contam = adata_corrected.obs.loc[cell_mask, contamination_key].values
        contam_corr = pearsonr(true_contam, est_contam)[0]
    else:
        contam_corr = None

    return {
        'Method': method_name,
        'RMSE (Corrected)': rmse_corrected,
        'RMSE (Observed)': rmse_observed,
        'RMSE Improvement': rmse_improvement,
        'Correlation (Corrected)': corr_corrected,
        'Correlation (Observed)': corr_observed,
        'Contamination Est. Corr': contam_corr,
    }

# Score each corrected matrix against the known ground truth
print("📊 Evaluating methods on synthetic data...\n")

results_decontx, results_fastcar = (
    evaluate_correction(adata_decontx, adata_synthetic, 'DecontX'),
    evaluate_correction(adata_fastcar, adata_synthetic, 'FastCAR'),
)

# One row per method
comparison_df = pd.DataFrame([results_decontx, results_fastcar])
print(comparison_df.to_string(index=False))
print("\n✅ Higher correlation and lower RMSE = better performance")

## 6. Visualize Results

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Only plot cells (not empty droplets)
cell_mask = adata_synthetic.obs['is_cell'].values

# Row 1: PCA of each matrix after library-size normalization + log1p
datasets = [
    (adata_synthetic[cell_mask].copy(), 'Raw (with contamination)'),
    (adata_decontx[cell_mask].copy(), 'DecontX corrected'),
    (adata_fastcar[cell_mask].copy(), 'FastCAR corrected')
]

for ax, (adata_plot, title) in zip(axes[0], datasets):
    sc.pp.normalize_total(adata_plot, target_sum=1e4)
    sc.pp.log1p(adata_plot)
    sc.pp.pca(adata_plot, n_comps=20)
    sc.pl.pca(adata_plot, color='cell_type', ax=ax, show=False, title=title)

# Row 2: Contamination estimates vs truth
true_contam = adata_synthetic.obs.loc[cell_mask, 'true_contamination'].values

ax = axes[1, 0]
ax.hist(true_contam, bins=30, alpha=0.7, edgecolor='black')
ax.set_xlabel('True Contamination Fraction')
ax.set_ylabel('Count')
ax.set_title('Ground Truth Contamination')
ax.axvline(true_contam.mean(), color='red', linestyle='--', label=f'Mean: {true_contam.mean():.1%}')
ax.legend()

estimate_panels = [
    (axes[1, 1], adata_decontx, 'decontX_contamination', 'DecontX', None),
    (axes[1, 2], adata_fastcar, 'fastcar_contamination', 'FastCAR', 'orange'),
]
for ax, adata_est, key, label, color in estimate_panels:
    est = adata_est.obs.loc[cell_mask, key].values
    # At least [0, 0.5], extended if truth or estimate goes beyond it
    lim = max(0.5, 1.05 * max(true_contam.max(), est.max()))
    ax.scatter(true_contam, est, alpha=0.5, s=20, color=color)
    ax.plot([0, lim], [0, lim], 'r--', label='Perfect estimate')
    ax.set_xlabel('True Contamination')
    ax.set_ylabel(f'{label} Estimate')
    ax.set_title(f'{label} (r={pearsonr(true_contam, est)[0]:.3f})')
    ax.legend()
    ax.set_xlim([0, lim])
    ax.set_ylim([0, lim])

plt.tight_layout()
# Relative path: written to the current working directory
plt.savefig('method_comparison_synthetic.png', dpi=300, bbox_inches='tight')
print("\n💾 Saved: method_comparison_synthetic.png")
plt.show()

## 7. Signal Recovery Analysis

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

cell_mask = adata_synthetic.obs['is_cell'].values


def _dense_cell_counts(adata_obj, mask):
    """Dense (cells x genes) copy of `adata_obj.X` restricted to rows in `mask`."""
    X = adata_obj.X
    return X.toarray()[mask] if issparse(X) else np.asarray(X)[mask]


# Get data
true_signal = adata_synthetic.uns['true_signal']
observed = _dense_cell_counts(adata_synthetic, cell_mask)
decontx = _dense_cell_counts(adata_decontx, cell_mask)
fastcar = _dense_cell_counts(adata_fastcar, cell_mask)

# Subsample entries for visualization with a dedicated RNG, so the subsample
# does not depend on how much of the global RNG earlier cells consumed.
rng = np.random.default_rng(0)
n_points = min(10000, true_signal.size)
idx = rng.choice(true_signal.size, n_points, replace=False)
true_sub = true_signal.ravel()[idx]

datasets = [
    (observed.ravel()[idx], 'Raw Observed'),
    (decontx.ravel()[idx], 'DecontX'),
    (fastcar.ravel()[idx], 'FastCAR')
]

for ax, (data, name) in zip(axes, datasets):
    ax.scatter(true_sub, data, alpha=0.1, s=1)
    ax.plot([0, true_sub.max()], [0, true_sub.max()], 'r--', linewidth=2, label='Perfect recovery')

    corr = pearsonr(true_sub, data)[0]
    ax.set_xlabel('True Signal')
    ax.set_ylabel(f'{name} Counts')
    ax.set_title(f'{name}\n(r = {corr:.3f})')
    ax.legend()
    # A 99th percentile of 0 (very sparse data) would give an empty axis range
    x_top = np.percentile(true_sub, 99)
    y_top = np.percentile(data, 99)
    ax.set_xlim([0, x_top if x_top > 0 else 1])
    ax.set_ylim([0, y_top if y_top > 0 else 1])

plt.tight_layout()
# Relative path: written to the current working directory
plt.savefig('signal_recovery.png', dpi=300, bbox_inches='tight')
print("💾 Saved: signal_recovery.png")
plt.show()

## 8. Apply to Real Data (PBMC3K)

In [None]:
print("🔬 Applying methods to PBMC3K...\n")

# Caveat: PBMC3K is pre-filtered, so both corrections here are a demonstration
# rather than a real analysis (that needs the raw matrix before any filtering).

adata_pbmc_decontx = decontX_simple(adata=adata_pbmc.copy(), cluster_key="louvain")
print()

adata_pbmc_fastcar = fastCAR_correction(adata=adata_pbmc.copy())
print()

print("✅ Methods applied to PBMC3K")

In [None]:
# Visualize PBMC results.
# The corrected objects are copies of adata_pbmc and inherit its stored UMAP
# coordinates, so plotting those directly would show three identical panels.
# Recompute PCA -> kNN graph -> UMAP from each matrix instead.
def _umap_from_matrix(adata_in, n_comps=30):
    """Return a copy of `adata_in` with PCA, neighbors and UMAP recomputed from its X."""
    embedded = adata_in.copy()
    # Normalize + log only non-negative (count-like) matrices; a matrix that
    # is already scaled (has negative values) goes straight to PCA.
    if embedded.X.min() >= 0:
        sc.pp.normalize_total(embedded, target_sum=1e4)
        sc.pp.log1p(embedded)
    sc.pp.pca(embedded, n_comps=n_comps)
    sc.pp.neighbors(embedded)
    sc.tl.umap(embedded)
    return embedded


fig, axes = plt.subplots(1, 3, figsize=(18, 5))

datasets_pbmc = [
    (adata_pbmc, 'Original PBMC3K'),
    (adata_pbmc_decontx, 'DecontX corrected'),
    (adata_pbmc_fastcar, 'FastCAR corrected')
]

for ax, (adata_plot, title) in zip(axes, datasets_pbmc):
    sc.pl.umap(_umap_from_matrix(adata_plot), color='louvain', ax=ax, show=False, title=title)

plt.tight_layout()
# Relative path: written to the current working directory
plt.savefig('pbmc_comparison.png', dpi=300, bbox_inches='tight')
print("💾 Saved: pbmc_comparison.png")
plt.show()

## 9. Summary Table for Meeting

In [None]:
# Create comprehensive summary
summary_data = {
    'Method': ['SoupX', 'DecontX', 'FastCAR', 'CellBender'],
    'Status': ['✅ Done (Partner)', '✅ Implemented', '✅ Implemented', '📋 Reference only'],
    'Speed': ['Fast', 'Medium', 'Very Fast', 'Slow (GPU)'],
    'Approach': [
        'Gene-specific propensity',
        'Cluster-weighted mixture',
        'Per-sample DGE-optimized',
        'Deep generative model'
    ],
    'Best For': [
        'Filtered data',
        'Cluster-based analysis',
        'Disease vs. control',
        'Heavy contamination'
    ],
    'Key Innovation': [
        'MT/ribo genes contaminate more',
        'Contamination from OTHER clusters',
        'Sample-specific correction',
        'Neural net learns manifold'
    ]
}

summary_df = pd.DataFrame(summary_data)
print("\n" + "="*80)
print("📋 METHOD SUMMARY FOR MEETING")
print("="*80 + "\n")
print(summary_df.to_string(index=False))

# Performance on synthetic data
print("\n" + "="*80)
print("📊 PERFORMANCE ON SYNTHETIC DATA (Ground Truth Known)")
print("="*80 + "\n")
print(comparison_df.to_string(index=False))

# Save to CSV in the current working directory (no machine-specific paths)
summary_df.to_csv('method_summary.csv', index=False)
comparison_df.to_csv('performance_comparison.csv', index=False)
print("\n💾 Saved: method_summary.csv, performance_comparison.csv")

## 10. CellBender Reference (Command-line)

In [None]:
print("📋 CellBender Usage Reference\n")
rule = "=" * 60
print(rule)
print("CellBender must be run via command line (requires GPU):\n")

# Shell + Python snippet shown to the reader (printed verbatim, not executed)
cellbender_cmd = """
# Installation
pip install cellbender

# Basic usage
cellbender remove-background \\
    --input raw_feature_bc_matrix.h5 \\
    --output cellbender_output.h5 \\
    --expected-cells 3000 \\
    --total-droplets-included 5000 \\
    --epochs 150 \\
    --cuda

# Load results in Python
import scanpy as sc
adata = sc.read_10x_h5('cellbender_output_filtered.h5')
"""

print(cellbender_cmd)
print(rule)
print("\n⚠️  CellBender is the most sophisticated but requires:")
for requirement in ("GPU (CUDA compatible)", "Raw unfiltered matrix", "30-60 minutes runtime"):
    print(f"   - {requirement}")
print("\n✅ DecontX and FastCAR run in seconds on CPU!")

## 11. Next Steps & Research Directions

## 12. Export All Results

In [None]:
# Save all corrected datasets to the current working directory
# (no machine-specific absolute paths)
print("💾 Exporting results...\n")

h5ad_outputs = {
    'synthetic_raw.h5ad': adata_synthetic,
    'synthetic_decontx.h5ad': adata_decontx,
    'synthetic_fastcar.h5ad': adata_fastcar,
    'pbmc_decontx.h5ad': adata_pbmc_decontx,
    'pbmc_fastcar.h5ad': adata_pbmc_fastcar,
}
for filename, adata_out in h5ad_outputs.items():
    adata_out.write_h5ad(filename)
    print(f"✅ Saved: {filename}")

print("\n🎉 All done! Ready for your meeting!")
print("\nFiles created:")
for created in (
    "method_comparison_synthetic.png",
    "signal_recovery.png",
    "pbmc_comparison.png",
    "method_summary.csv",
    "performance_comparison.csv",
    "*.h5ad files (all datasets)",
):
    print(f"  - {created}")