# Imports

In [9]:
"""
EmptyDrops - Python implementation of the EmptyDrops algorithm for droplet-based single-cell RNA sequencing.
Modified for Apple Silicon (M-series) compatibility with enhanced performance optimizations.

This module provides functionality to distinguish between droplets containing cells 
and ambient RNA in droplet-based single-cell RNA sequencing experiments.

Based on:
Lun A, Riesenfeld S, Andrews T, et al. (2019).
Distinguishing cells from empty droplets in droplet-based single-cell RNA sequencing data.
Genome Biol. 20, 63.
"""

import scanpy as sc
import numpy as np
from numba import typeof, jit, prange, cuda
from scipy.sparse import csr_matrix, issparse, spmatrix
from scipy.optimize import minimize, minimize_scalar
from scipy.special import gamma, gammaln, factorial, loggamma
import scipy.stats as ss
from scipy import stats
from statsmodels.stats.multitest import multipletests
from collections import Counter
from tqdm import tqdm
from typing import Tuple, Dict, Union, Optional, List
import pandas as pd
import warnings
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
import multiprocessing as mp
from functools import partial
from tqdm.auto import tqdm
import time
from datetime import timedelta
import os
import pickle
import hashlib
import matplotlib.pyplot as plt
import seaborn as sns
import psutil
import gc

# for running the emptydrops function
from sklearn.metrics import confusion_matrix, classification_report
import shutil

# system monitor

In [10]:
def get_system_stats():
    """Return memory %, CPU % and thread count for the current process.

    psutil's ``Process.cpu_percent()`` (no interval) measures CPU time since
    the previous call on the *same* ``Process`` object and returns 0.0 on
    the first call for that object. A fresh ``Process`` per call would
    therefore always report 0.0. The object is cached on this function and
    only re-created when the PID changes (e.g. inside a forked worker).
    The very first call still reports 0.0 CPU.

    Returns
    -------
    dict
        ``{'memory_percent': float, 'cpu_percent': float, 'num_threads': int}``
    """
    pid = os.getpid()
    process = getattr(get_system_stats, "_process", None)
    if process is None or process.pid != pid:
        process = psutil.Process(pid)
        get_system_stats._process = process
    return {
        'memory_percent': process.memory_percent(),
        'cpu_percent': process.cpu_percent(),
        'num_threads': process.num_threads()
    }

def check_system_resources():
    """Throttle when memory or CPU usage exceeds the configured limits.

    If ``memory_percent > MAX_MEMORY_PERCENT`` or
    ``cpu_percent > MAX_CPU_PERCENT``, the function prints a warning, forces
    a garbage collection, sleeps ``COOLING_PERIOD`` seconds and returns True.
    Otherwise it returns False without side effects.
    """
    usage = get_system_stats()
    mem_pct = usage['memory_percent']
    cpu_pct = usage['cpu_percent']
    if not (mem_pct > MAX_MEMORY_PERCENT or cpu_pct > MAX_CPU_PERCENT):
        return False
    print(f"\nHigh resource usage detected - Memory: {mem_pct:.1f}%, CPU: {cpu_pct:.1f}%")
    print("Cooling down...")
    gc.collect()
    time.sleep(COOLING_PERIOD)
    return True

def get_memory_usage():
    """Return ``psutil.Process.memory_percent()`` for the running process."""
    current = psutil.Process(os.getpid())
    usage_pct = current.memory_percent()
    return usage_pct

def check_memory_usage():
    """Return True (after forcing a GC) if memory use exceeds MAX_MEMORY_PERCENT."""
    over_limit = get_memory_usage() > MAX_MEMORY_PERCENT
    if over_limit:
        gc.collect()  # Force garbage collection
    return over_limit

In [11]:
try:
    # HAS_GPU only records whether `cupy` imports; it does not by itself
    # verify that a usable CUDA device exists.
    import cupy as cp
    HAS_GPU = True
except ImportError:
    HAS_GPU = False

# Rows per chunk in compute_multinom_prob; also empty_drops' default batch_size.
CHUNK_SIZE = 100  # Increased for better vectorization
# NOTE(review): MAX_WORKERS is not referenced in the visible code.
MAX_WORKERS = mp.cpu_count()  # Use all available cores
# Relative to the current working directory.
CACHE_DIR = "empty_drops_cache"
# Worker count for mp.Pool in compute_multinom_prob; empty_drops' default n_processes.
N_PROCESSES = mp.cpu_count()  # Use all cores
# Thresholds compared against psutil's per-process memory/CPU percentages.
MAX_MEMORY_PERCENT = 80  # Slightly increased but still safe
COOLING_PERIOD = 1  # Reduced cooling period (seconds passed to time.sleep)
MAX_CPU_PERCENT = 90  # Increased CPU usage threshold
# NOTE(review): SAVE_FREQUENCY is not referenced in the visible code.
SAVE_FREQUENCY = 5  # More frequent checkpoints

# Caching utilities

In [12]:
# Create the caching directory if it doesn't exist.
# Runs when this cell executes; CACHE_DIR is relative to the current working directory.
os.makedirs(CACHE_DIR, exist_ok=True)

def save_intermediate_results(key: str, data: dict):
    """Pickle ``data`` to ``CACHE_DIR/intermediate_<key>.pkl``.

    The pickle is first written to a ``.tmp`` sibling and then moved into
    place with ``os.replace``. An interruption mid-write therefore never
    leaves a truncated file at the final path.

    Parameters
    ----------
    key : str
        Suffix used in the file name.
    data : dict
        Any picklable object (annotated as dict).
    """
    temp_file = os.path.join(CACHE_DIR, f"intermediate_{key}.pkl")
    staging_file = temp_file + ".tmp"
    try:
        with open(staging_file, 'wb') as f:
            pickle.dump(data, f)
        os.replace(staging_file, temp_file)
    except BaseException:
        # Do not leave a half-written staging file behind; re-raise the error.
        if os.path.exists(staging_file):
            os.remove(staging_file)
        raise


def _get_cache_key(data_hash: str, params: dict) -> str:
    """Generate a unique cache key for the given data and parameters."""
    # Sort parameters to ensure consistent keys
    sorted_params = sorted(params.items())
    param_str = str(sorted_params)
    
    # Combine data hash and parameters
    key_str = f"{data_hash}_{param_str}"
    
    # Generate a hash of the combined string
    return hashlib.md5(key_str.encode()).hexdigest()

def _save_to_cache(key: str, data: dict):
    """Pickle ``data`` to ``CACHE_DIR/<key>.pkl`` atomically.

    The object is written to a ``.tmp`` sibling first and moved into place
    with ``os.replace``. Otherwise an interrupted write would leave a
    truncated pickle that ``_load_from_cache`` would keep failing on.
    """
    cache_file = os.path.join(CACHE_DIR, f"{key}.pkl")
    staging_file = cache_file + ".tmp"
    try:
        with open(staging_file, 'wb') as f:
            pickle.dump(data, f)
        os.replace(staging_file, cache_file)
    except BaseException:
        # Clean up the partial staging file, then propagate the error.
        if os.path.exists(staging_file):
            os.remove(staging_file)
        raise

def _load_from_cache(key: str) -> Optional[dict]:
    """Return the object pickled at ``CACHE_DIR/<key>.pkl``, or None if the file is absent.

    ``pickle.load`` can execute code embedded in the file, so only caches
    written by this program should live in CACHE_DIR.
    """
    path = os.path.join(CACHE_DIR, f"{key}.pkl")
    if not os.path.exists(path):
        return None
    with open(path, 'rb') as handle:
        return pickle.load(handle)

# multiprocessing

In [13]:
def _process_chunk(args):
    """Score one ``(chunk, prop, alpha)`` work item, always returning an array.

    A scalar result from ``_compute_multinom_prob_chunk`` is wrapped into a
    1-element array so callers can concatenate results uniformly.
    """
    data_chunk, proportions, conc = args
    log_probs = _compute_multinom_prob_chunk(data_chunk, proportions, conc)
    return np.array([log_probs]) if np.isscalar(log_probs) else log_probs

def _process_chunk_mp(args):
    """Pool worker: unpack ``(chunk, prop, alpha)`` and score the chunk."""
    data_chunk, proportions, conc = args
    scored = _compute_multinom_prob_chunk(data_chunk, proportions, conc)
    return scored

def debug_mp_info():
    """Print the core count, configured worker count and chunk size."""
    report = [
        "\nMultiprocessing Debug Info:",
        f"Number of CPU cores available: {mp.cpu_count()}",
        f"Number of processes to be used: {N_PROCESSES}",
        f"Chunk size: {CHUNK_SIZE}",
        "Process Pool starting...\n",
    ]
    for line in report:
        print(line)


@jit(nopython=True, parallel=True)
def _simulate_batch(totals: np.ndarray, ambient: np.ndarray, n_iter: int) -> np.ndarray:
    """Draw ``n_iter`` multinomial samples per row and store the sum of each draw.

    Parameters
    ----------
    totals : np.ndarray
        Number of trials for each row; rows with ``totals[i] <= 0`` stay 0.
        (Presumably integer-valued; numba's ``np.random.multinomial`` needs an int n — confirm.)
    ambient : np.ndarray
        Category probabilities passed to ``np.random.multinomial``.
    n_iter : int
        Number of draws per row.

    Returns
    -------
    np.ndarray
        ``(len(totals), n_iter)`` float array of draw sums.

    NOTE(review): the sum of a multinomial draw always equals its number of
    trials, so every non-zero entry is simply ``totals[i]``; the sampled
    count vectors are discarded. This function is not called in the visible
    code — confirm the intended statistic before relying on it.
    """
    n_cells = len(totals)
    results = np.zeros((n_cells, n_iter))
    
    # Rows are independent, so they are distributed across threads by prange.
    for i in prange(n_cells):
        if totals[i] <= 0:
            continue
        for j in range(n_iter):
            results[i, j] = np.random.multinomial(totals[i], ambient).sum()
    
    return results

if HAS_GPU:
    # NOTE(review): HAS_GPU is set by a `cupy` import attempt, and this kernel
    # is only defined here — it is never launched in the visible code.
    @cuda.jit
    def _cuda_simulate_kernel(totals, ambient, results):
        """CUDA kernel for parallel simulation on GPU.

        Intended layout: one thread per entry of ``totals`` (``cuda.grid(1)``),
        each writing ``results[idx, :]`` for its barcode.

        NOTE(review): as written this kernel likely fails to compile on first
        launch — confirm before use:
        - ``cuda.shared.array`` requires a compile-time constant shape, but
          ``ambient.shape[0]`` is a runtime value.
        - ``cuda.syncthreads()`` is inside branches that some threads of a
          block skip (out-of-range ``idx`` or zero totals); barriers must be
          reached by every thread of the block.
        - ``numba.cuda.random`` does not appear to provide a ``multinomial``
          function; also, a multinomial draw's sum always equals ``totals[idx]``.
        """
        idx = cuda.grid(1)
        if idx < len(totals):
            if totals[idx] > 0:
                # Use shared memory for ambient profile
                shared_ambient = cuda.shared.array(shape=(ambient.shape[0],), dtype=np.float64)
                if cuda.threadIdx.x == 0:
                    for i in range(ambient.shape[0]):
                        shared_ambient[i] = ambient[i]
                cuda.syncthreads()
                
                # Simulate multinomial using cuRAND
                for j in range(results.shape[1]):
                    results[idx, j] = cuda.random.multinomial(totals[idx], shared_ambient).sum()

def optimize_sparse_matrix(X):
    """Return ``X`` as a CSR matrix with sorted indices and merged duplicates.

    Non-sparse inputs are returned unchanged (the same object). For an input
    that is already CSR, ``tocsr()`` returns the same object, so the sorting
    and duplicate merging happen in place on the caller's matrix.
    """
    if not issparse(X):
        return X
    csr = X.tocsr()
    csr.sort_indices()
    csr.sum_duplicates()
    return csr

# numba gammaln

In [14]:
# Numba-compatible gammaln function for scalar inputs
@jit(nopython=True)
def _numba_gammaln_scalar(x):
    """Numba-compatible version of gammaln for scalar inputs.

    Uses a six-term Lanczos series; the coefficients appear to be the
    Numerical Recipes ``gammln`` set (confirm if precision matters).
    Only meaningful for ``x > 0``: the final ``log(... / x)`` divides by
    ``x``, so ``x == 0`` fails and ``x < 0`` yields a log of a negative number.
    """
    # Lanczos approximation for log gamma
    c = np.array([
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5
    ])
    
    y = x
    tmp = x + 5.5
    # (x + 0.5) * log(x + 5.5) - (x + 5.5)
    tmp = (x + 0.5) * np.log(tmp) - tmp
    ser = 1.000000000190015
    # Series sum c[j] / (x + j + 1) for j = 0..5 (y is incremented before use).
    for j in range(6):
        y = y + 1
        ser = ser + c[j] / y
    # 2.5066282746310005 is sqrt(2*pi).
    return tmp + np.log(2.5066282746310005 * ser / x)

# Vectorized version for array inputs
@jit(nopython=True)
def _numba_gammaln_array(x):
    """Numba-compatible version of gammaln for array inputs.

    Applies ``_numba_gammaln_scalar`` element-wise through ``.flat``, so any
    array shape is accepted; the result has ``x``'s shape and dtype float64.
    Elements must be > 0 (see ``_numba_gammaln_scalar``).
    NOTE(review): not called in the visible code.
    """
    result = np.zeros_like(x, dtype=np.float64)
    for i in range(x.size):
        result.flat[i] = _numba_gammaln_scalar(x.flat[i])
    return result

# estimate & show remaining time

In [15]:
class TimeEstimator:
    """Wrap a tqdm progress bar and remember when timing started.

    ``current`` mirrors the number of steps passed to ``update`` and
    ``total`` stores the expected step count; ``get_elapsed_time`` measures
    wall-clock seconds since construction or the last ``reset``.
    """

    def __init__(self, total, desc, position=0):
        layout = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
        self.pbar = tqdm(total=total, desc=desc, position=position,
                         leave=True, bar_format=layout)
        self.start_time = time.time()
        self.current = 0
        self.total = total

    def update(self, n=1):
        """Advance both the internal counter and the bar by ``n`` steps."""
        self.current += n
        self.pbar.update(n)

    def close(self):
        """Close the underlying tqdm bar."""
        self.pbar.close()

    def get_elapsed_time(self):
        """Seconds elapsed since construction or the most recent reset."""
        now = time.time()
        return now - self.start_time

    def reset(self):
        """Reset the progress bar to the beginning."""
        self.pbar.reset()
        self.current = 0
        self.start_time = time.time()

def format_time(seconds):
    """Render a duration as ``str(timedelta)`` after truncating to whole seconds.

    Durations under one day look like ``H:MM:SS``; longer ones get a
    ``"N day(s), "`` prefix.
    """
    whole_seconds = int(seconds)
    return f"{timedelta(seconds=whole_seconds)}"

class ProcessingRateMonitor:
    """Track processed megabytes and extrapolate the remaining time.

    The rate is the *instantaneous* rate from the most recent ``update``
    (megabytes in that update divided by seconds since the previous update).
    The remaining-time estimate therefore follows the latest chunk rather
    than the running average.
    """
    def __init__(self, total_size_mb):
        self.start_time = time.time()
        self.total_size_mb = total_size_mb
        self.processed_mb = 0
        self.last_update = self.start_time
        self.rate_mb_s = 0

    def update(self, size_mb):
        """Record ``size_mb`` additional processed megabytes and refresh the rate."""
        current_time = time.time()
        elapsed = current_time - self.last_update
        self.processed_mb += size_mb

        # A non-positive interval (coarse clock, or wall clock stepping back)
        # cannot give a meaningful rate; keep the previous one.
        if elapsed > 0:
            self.rate_mb_s = size_mb / elapsed

        self.last_update = current_time

    def get_stats(self):
        """Return a progress dict, or None until a positive rate is known.

        Keys: ``rate_mb_s``, ``processed_mb``, ``total_mb``,
        ``remaining_seconds`` (never negative) and ``percent_complete``
        (100.0 when ``total_size_mb`` is not positive).
        """
        if not self.rate_mb_s > 0:
            return None
        # Accumulated float chunk sizes can overshoot the total slightly.
        remaining_mb = max(self.total_size_mb - self.processed_mb, 0)
        if self.total_size_mb > 0:
            percent_complete = (self.processed_mb / self.total_size_mb) * 100
        else:
            # Nothing to process: report completion instead of dividing by zero.
            percent_complete = 100.0
        return {
            'rate_mb_s': self.rate_mb_s,
            'processed_mb': self.processed_mb,
            'total_mb': self.total_size_mb,
            'remaining_seconds': remaining_mb / self.rate_mb_s,
            'percent_complete': percent_complete
        }

# Good-Turing estimation

In [16]:
@jit(nopython=True)
def _compute_good_turing_freqs(counts: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute Good-Turing frequency estimates for observed counts.
    
    Parameters
    ----------
    counts : array-like
        Observed counts. Values are truncated with ``int()``. Must be
        non-empty, because ``counts.max()`` is undefined for an empty array.
        
    Returns
    -------
    tuple
        (freq_counts, gt_estimates), both of length ``int(max(counts)) + 1``:
        - freq_counts[r]: number of entries whose integer count is r (N_r),
          with interior zero gaps filled by interpolation (see below).
        - gt_estimates[r]: Good-Turing adjusted count
          ``(r + 1) * N_{r+1} / N_r`` for ``r < max_count`` with ``N_r > 0``;
          0 elsewhere. In particular ``gt_estimates[max_count]`` is always 0.
    """
    # Count frequencies of frequencies
    max_count = int(counts.max())
    freq_counts = np.zeros(max_count + 1)
    
    # Sequential counting of frequencies (the bound check always holds, since
    # every int(count) <= int(max)).
    for i in range(len(counts)):
        c = int(counts[i])
        if c <= max_count:
            freq_counts[c] += 1
            
    # Smooth frequencies using linear interpolation in log1p-space
    log_counts = np.log1p(np.arange(max_count + 1))
    log_freqs = np.log1p(freq_counts)
    
    # Replace zeros with interpolated values. Only gaps with a non-zero
    # neighbour on both sides are filled; leading and trailing zeros stay 0.
    # Entries filled earlier act as the left anchor for later ones, and they
    # lie on the same line (up to rounding).
    for i in range(1, max_count + 1):
        if freq_counts[i] == 0:
            # Find next non-zero frequency
            next_nonzero = i + 1
            while next_nonzero <= max_count and freq_counts[next_nonzero] == 0:
                next_nonzero += 1
                
            if next_nonzero <= max_count:
                # Linear interpolation in log space
                prev_nonzero = i - 1
                while prev_nonzero >= 0 and freq_counts[prev_nonzero] == 0:
                    prev_nonzero -= 1
                    
                if prev_nonzero >= 0:
                    slope = (log_freqs[next_nonzero] - log_freqs[prev_nonzero]) / (log_counts[next_nonzero] - log_counts[prev_nonzero])
                    log_freqs[i] = log_freqs[prev_nonzero] + slope * (log_counts[i] - log_counts[prev_nonzero])
                    freq_counts[i] = np.exp(log_freqs[i]) - 1
    
    # Calculate Good-Turing estimates r* = (r + 1) * N_{r+1} / N_r.
    # The loop stops at max_count - 1 because N_{max_count+1} is not tracked.
    gt_estimates = np.zeros_like(freq_counts)
    for i in range(max_count):
        if freq_counts[i] > 0:
            gt_estimates[i] = (i + 1) * freq_counts[i + 1] / freq_counts[i]
            
    return freq_counts, gt_estimates

def good_turing_ambient_pool(
    data: sc.AnnData,
    low_count_gene_sums: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Calculate the ambient RNA profile using Good-Turing estimation.

    Genes observed in the ambient pool get proportions proportional to their
    Good-Turing adjusted counts ``r* = (r + 1) * N_{r+1} / N_r`` (from
    ``_compute_good_turing_freqs``). When no positive estimate exists, the
    raw count ``r`` is used instead; this always happens for the largest
    observed count, whose ``N_{r+1}`` is 0.

    Genes unseen in the pool share the Good-Turing unseen mass
    ``P0 = N_1 / N`` equally. Here ``N_1`` is the number of genes seen
    exactly once and ``N`` the total ambient count. This follows the
    approach of edgeR's ``goodTuringProportions``. Proportions are finally
    normalised to sum to 1.

    Parameters
    ----------
    data : AnnData
        Barcodes x genes object; only ``data.var_names`` is read.
    low_count_gene_sums : array-like or sparse
        Per-gene summed counts over the low-count (ambient) barcodes, aligned
        with ``data.var_names``. 1-D arrays, ``np.matrix`` rows and sparse
        row vectors are all accepted.

    Returns
    -------
    tuple
        (gene_names, ambient_proportions, gene_expectations)
        - gene_names: ``data.var_names.values``
        - ambient_proportions: float64 proportions, one per gene, summing to 1
          (all zeros, with a warning, if the pool contains no counts)
        - gene_expectations: ``ambient_proportions * total ambient counts``
    """
    gene_names = data.var_names.values

    # Accept sparse row vectors, np.matrix rows and plain arrays alike; use
    # float64 so proportions written below are never truncated to integers.
    if issparse(low_count_gene_sums):
        low_count_gene_sums = low_count_gene_sums.toarray()
    counts = np.asarray(low_count_gene_sums, dtype=np.float64).ravel()

    nonzero_mask = counts > 0
    nonzero_counts = counts[nonzero_mask]

    if nonzero_counts.size == 0:
        warnings.warn("No non-zero counts found in ambient set")
        return gene_names, np.zeros_like(counts), np.zeros_like(counts)

    _, gt_estimates = _compute_good_turing_freqs(nonzero_counts)
    total_counts = np.sum(nonzero_counts)

    # Adjusted count per observed gene (index = integer count, as in the helper).
    adjusted = gt_estimates[nonzero_counts.astype(np.int64)]
    adjusted = np.where(adjusted > 0, adjusted, nonzero_counts)

    n_unseen = counts.size - nonzero_counts.size
    if n_unseen > 0:
        p0 = np.count_nonzero(nonzero_counts == 1) / total_counts
    else:
        p0 = 0.0

    # Assign with a single boolean-mask indexing operation; chained indexing
    # like ``a[mask][i] = v`` would write into a temporary copy.
    ambient_props = np.zeros_like(counts)
    ambient_props[nonzero_mask] = (1.0 - p0) * adjusted / np.sum(adjusted)
    if n_unseen > 0:
        ambient_props[~nonzero_mask] = p0 / n_unseen

    # Normalize proportions (also absorbs floating-point rounding)
    sum_props = np.sum(ambient_props)
    if sum_props > 0:
        ambient_props = ambient_props / sum_props
    else:
        warnings.warn("Sum of ambient proportions is zero, using uniform distribution")
        ambient_props = np.ones_like(ambient_props) / len(ambient_props)

    # Calculate expectations
    gene_expectations = ambient_props * total_counts

    return gene_names, ambient_props, gene_expectations

# multinomial probabilities

In [17]:
@jit(nopython=True)
def _compute_multinom_prob_chunk(data_chunk: np.ndarray, prop: np.ndarray, alpha: float) -> np.ndarray:
    """Compute the count-dependent part of each row's (Dirichlet-)multinomial log-probability.

    For each row ``x`` (one barcode's counts per gene) the result sums, over
    genes with ``x_g > 0``:

    * ``alpha == inf`` (multinomial):
      ``x_g * log(p_g) - lgamma(x_g + 1)``
    * finite ``alpha > 0`` (Dirichlet-multinomial):
      ``lgamma(alpha*p_g + x_g) - lgamma(x_g + 1) - lgamma(alpha*p_g)``

    This is intended to mirror the per-gene term of the reference R
    implementation. Terms that depend only on the row total (``lgamma(N+1)``
    and, for finite alpha, ``lgamma(alpha) - lgamma(N + alpha)``) are omitted
    because they are identical for all barcodes with the same total.

    Genes with ``x_g == 0`` contribute exactly 0 and are skipped. This also
    avoids ``0 * log(0) = nan`` for genes absent from the ambient profile. A
    positive count in a gene with ``p_g == 0`` is impossible under the
    ambient model, so that row's value is ``-inf``.

    NOTE(review): the Monte Carlo null (``permute_counter``, defined
    elsewhere) must score simulated barcodes with this same formula — confirm.

    Args:
        data_chunk: 2-D array of counts, one row per barcode, one column per gene.
        prop: Ambient proportions, one per column of ``data_chunk``.
        alpha: Dirichlet concentration; ``np.inf`` selects the plain multinomial.

    Returns:
        1-D float64 array with one log-probability term per row.

    Raises:
        ValueError: if ``alpha`` is not positive (including NaN).
    """
    if not alpha > 0:
        raise ValueError("alpha must be positive (use np.inf for a plain multinomial)")

    # Ensure consistent dtype
    data_chunk = data_chunk.astype(np.float64)
    prop = prop.astype(np.float64)

    n_rows, n_genes = data_chunk.shape
    log_probs = np.zeros(n_rows, dtype=np.float64)
    multinomial = alpha == np.inf

    for i in range(n_rows):
        acc = 0.0
        for j in range(n_genes):
            x = data_chunk[i, j]
            if x <= 0.0:
                continue
            p = prop[j]
            if p <= 0.0:
                acc = -np.inf
                break
            if multinomial:
                acc += x * np.log(p) - _numba_gammaln_scalar(x + 1.0)
            else:
                ap = alpha * p
                acc += (_numba_gammaln_scalar(ap + x)
                        - _numba_gammaln_scalar(x + 1.0)
                        - _numba_gammaln_scalar(ap))
        log_probs[i] = acc

    return log_probs


def compute_multinom_prob(
    data: Union[np.ndarray, spmatrix],
    prop: np.ndarray,
    alpha: float = np.inf,
    progress: bool = True
) -> np.ndarray:
    """Calculate per-barcode log-probability terms using multiprocessing.

    The (barcodes x genes) matrix is densified and split row-wise into about
    ``len(data) // CHUNK_SIZE`` chunks. Each chunk is scored by
    ``_compute_multinom_prob_chunk`` in an ``mp.Pool`` of ``N_PROCESSES``
    workers.

    Parameters
    ----------
    data : ndarray or sparse matrix
        Counts, one row per barcode and one column per gene. Sparse input is
        densified in full, so memory is rows * genes * itemsize.
    prop : ndarray
        Ambient proportions, one per column of ``data``.
    alpha : float, default np.inf
        Dirichlet concentration passed to the chunk scorer.
    progress : bool, default True
        Print diagnostics and show a tqdm bar.

    Returns
    -------
    np.ndarray
        One value per row of ``data``, in row order (``pool.imap`` preserves order).

    Raises
    ------
    ValueError
        If ``data`` is not 2-D or its column count differs from ``len(prop)``.
    """
    if issparse(data):
        data = data.tocsr().toarray()
    data = np.asarray(data)
    prop = np.asarray(prop, dtype=np.float64).ravel()
    if data.ndim != 2 or data.shape[1] != prop.shape[0]:
        # The numba kernel indexes prop by column and numba does not
        # bounds-check by default, so a mismatch would read garbage.
        raise ValueError(
            f"data must be 2-D with one column per entry of prop; "
            f"got data shape {data.shape} and {prop.shape[0]} proportions"
        )

    if progress:
        debug_mp_info()

    # Calculate total data size in MB
    data_size_mb = data.nbytes / (1024 * 1024)
    rate_monitor = ProcessingRateMonitor(data_size_mb)

    # Split data into chunks
    n_chunks = max(1, len(data) // CHUNK_SIZE)
    chunks = np.array_split(data, n_chunks)
    chunk_size_mb = data_size_mb / n_chunks

    if progress:
        print(f"Total data size: {data_size_mb:.2f} MB")
        print(f"Number of chunks: {n_chunks}")
        print(f"Chunk size: {chunk_size_mb:.2f} MB\n")

    # Prepare arguments for multiprocessing
    chunk_args = [(chunk, prop, alpha) for chunk in chunks]

    pbar = None
    if progress:
        pbar = tqdm(total=n_chunks, desc="Computing probabilities",
                    bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}] {postfix}')

    results = []
    try:
        # Process chunks using multiprocessing
        with mp.Pool(processes=N_PROCESSES) as pool:
            if progress:
                print(f"Process pool created with {N_PROCESSES} workers")

            for i, result in enumerate(pool.imap(_process_chunk_mp, chunk_args)):
                results.append(result)
                if progress:
                    rate_monitor.update(chunk_size_mb)
                    rate_stats = rate_monitor.get_stats()
                    if rate_stats:
                        pbar.set_postfix({
                            'Rate': f"{rate_stats['rate_mb_s']:.1f} MB/s",
                            'Remaining': f"{rate_stats['remaining_seconds']/60:.1f}min",
                            'Active_Processes': N_PROCESSES
                        })
                    pbar.update(1)

                    # Print periodic process status (i is 0-based; i + 1 chunks are done)
                    if i % 10 == 0:
                        print(f"\nProcessed {i + 1}/{n_chunks} chunks using {N_PROCESSES} processes")
    finally:
        # Close the bar even if a worker raises.
        if pbar is not None:
            pbar.close()

    if progress:
        print("\nMultiprocessing completed")

    return np.concatenate(results)

# estimate alpha

In [18]:
def estimate_alpha(
    mat: Union[np.ndarray, spmatrix],
    prop: np.ndarray,
    totals: np.ndarray,
    interval: Tuple[float, float] = (0.01, 10000),
    progress: bool = True
) -> float:
    """
    Estimate the Dirichlet-multinomial alpha parameter by maximum likelihood.

    Dropping terms that do not depend on alpha, the log-likelihood is

        n * lgamma(alpha) - sum_b lgamma(N_b + alpha)
        + sum_{(b, g): x_bg > 0} [lgamma(x_bg + alpha * p_g) - lgamma(alpha * p_g)]

    Here ``b`` runs over barcodes (rows of ``mat``), ``g`` over genes
    (columns), ``N_b = totals[b]`` and ``n = len(totals)``. Zero entries
    contribute nothing, so only stored nonzero entries are visited. Each
    entry is paired with the proportion of *its own* gene (``prop[indices]``).
    The maximum is found by bounded Brent search over ``interval``.

    Parameters
    ----------
    mat : array-like or sparse matrix
        Ambient counts, one row per barcode and one column per gene.
    prop : array-like
        Ambient proportions, one per gene. Genes with nonzero entries in
        ``mat`` should have ``prop > 0`` (``lgamma(0)`` is infinite).
    totals : array-like
        Total count per barcode (row sums of ``mat``).
    interval : tuple, optional (default: (0.01, 10000))
        Bounds of the search for alpha.
    progress : bool, optional (default: True)
        Whether to show a progress bar.

    Returns
    -------
    float
        The estimated alpha parameter.

    Raises
    ------
    ValueError
        If ``prop`` or ``totals`` do not match the columns / rows of ``mat``.
    """
    # csr_matrix accepts both dense arrays and any sparse format.
    mat = csr_matrix(mat)
    if not mat.has_canonical_format:
        # Each (barcode, gene) pair must be stored once for the sum above;
        # copy first so the caller's matrix is not modified.
        mat = mat.copy()
        mat.sum_duplicates()

    prop = np.asarray(prop, dtype=np.float64).ravel()
    totals = np.asarray(totals, dtype=np.float64).ravel()
    if prop.shape[0] != mat.shape[1]:
        raise ValueError(f"prop has {prop.shape[0]} entries but mat has {mat.shape[1]} columns")
    if totals.shape[0] != mat.shape[0]:
        raise ValueError(f"totals has {totals.shape[0]} entries but mat has {mat.shape[0]} rows")

    prop_nz = prop[mat.indices]          # proportion of each stored entry's gene
    counts_nz = mat.data.astype(np.float64)
    n_barcodes = len(totals)

    def loglik(alpha):
        prop_alpha = prop_nz * alpha
        return (gammaln(alpha) * n_barcodes -
                np.sum(gammaln(totals + alpha)) +
                np.sum(gammaln(counts_nz + prop_alpha)) -
                np.sum(gammaln(prop_alpha)))

    if progress:
        pbar = TimeEstimator(1, "Optimizing alpha")

    # Bounded Brent search is fully determined by its bounds (it takes no
    # starting point), so a single run is sufficient.
    result = minimize_scalar(
        lambda a: -loglik(a),
        bounds=interval,
        method='bounded'
    )

    if progress:
        pbar.update(1)
        pbar.close()

    return float(result.x)

# EmptyDrops()

In [19]:
def _hash_count_matrix(X) -> str:
    """Return an MD5 digest of a matrix's shape, sparse structure and values."""
    digest = hashlib.md5()
    digest.update(str(X.shape).encode())
    if issparse(X):
        X = X.tocsr()
        parts = (X.data, X.indices, X.indptr)
    else:
        parts = (X,)
    for part in parts:
        digest.update(np.ascontiguousarray(part).tobytes())
    return digest.hexdigest()


def empty_drops(
    data: sc.AnnData,
    lower: int = 100,
    retain: Optional[int] = None,
    barcode_args: Optional[dict] = None,
    test_ambient: bool = False,
    niters: int = 1000,
    ignore: Optional[int] = None,
    alpha: Optional[float] = np.inf,
    round: bool = True,
    by_rank: Optional[int] = None,
    known_empty: Optional[np.ndarray] = None,
    progress: bool = True,
    adaptive: bool = True,
    min_iters: int = 100,
    max_iters: int = 1000,
    early_stopping: bool = True,
    batch_size: int = CHUNK_SIZE,
    use_cache: bool = True,
    visualize: bool = True,
    confidence_level: float = 0.95,
    n_processes: int = N_PROCESSES,
    use_gpu: bool = HAS_GPU
) -> pd.DataFrame:
    """
    Distinguish cell-containing droplets from empty ones (EmptyDrops).

    Barcodes with total counts <= ``lower`` define the ambient profile.
    Every other barcode (optionally restricted to totals > ``ignore``) is
    tested against that profile with a Monte Carlo test (``permute_counter``).
    Benjamini-Hochberg FDR is then computed over the tested barcodes.

    Parameters
    ----------
    data : AnnData
        Barcodes x genes counts. Modified in place: genes with zero total
        counts are removed by ``sc.pp.filter_genes(data, min_counts=1)``
        (not on a cache hit).
    lower : int
        Barcodes with total <= lower are treated as empty/ambient.
    retain : int, optional
        Tested barcodes with total >= retain get PValue 0.
    niters : int
        Monte Carlo iterations; PValue = (n_above + 1) / (niters + 1).
    ignore : int, optional
        Barcodes with total <= ignore are not tested.
    alpha : float or None
        Dirichlet-multinomial concentration; np.inf for a plain multinomial,
        None to estimate it from the ambient barcodes.
    by_rank : int, optional
        If given, ``lower`` is replaced by the total of the (by_rank + 1)-th
        largest barcode, so roughly the top ``by_rank`` barcodes are tested.
    adaptive, min_iters, max_iters, early_stopping, batch_size
        Forwarded to ``permute_counter``.
    use_cache : bool
        Read/write results in CACHE_DIR, keyed on matrix content and the
        parameters below.
    visualize : bool
        Save diagnostic plots under ``empty_drops_visualizations/``.
    barcode_args, test_ambient, round, known_empty, confidence_level,
    n_processes, use_gpu
        Not used by the computation in this function. Except for
        barcode_args and known_empty, they only enter the cache key.

    Returns
    -------
    pd.DataFrame
        Indexed by barcode, with columns Total, IsEmpty, LogProb, PValue,
        FDR, Limited and 'FDR < 0.05'. Untested barcodes hold NaN in
        LogProb/PValue/FDR/Limited.

    Raises
    ------
    ValueError
        If ``by_rank`` is negative or not smaller than the number of barcodes.
    """
    start_time = time.time()

    if progress:
        print("Starting EmptyDrops analysis...")

    # Hash values *and* structure, so different matrices that happen to
    # share nonzero values do not collide in the cache.
    data_hash = _hash_count_matrix(data.X)

    # Create a dictionary of parameters
    params = {
        'lower': lower,
        'retain': retain,
        'test_ambient': test_ambient,
        'niters': niters,
        'ignore': ignore,
        'alpha': alpha,
        'round': round,
        'by_rank': by_rank,
        'adaptive': adaptive,
        'min_iters': min_iters,
        'max_iters': max_iters,
        'early_stopping': early_stopping,
        'batch_size': batch_size,
        'confidence_level': confidence_level,
        'n_processes': n_processes,
        'use_gpu': use_gpu
    }

    # Generate cache key
    cache_key = _get_cache_key(data_hash, params)

    # Check if cached results exist
    if use_cache:
        cached_results = _load_from_cache(cache_key)
        if cached_results is not None:
            if progress:
                print("Loading cached results...")
            return cached_results['results']

    # Create directory for visualizations if needed
    if visualize:
        os.makedirs('empty_drops_visualizations', exist_ok=True)

    # Filter genes with zero counts (np.asarray(...).ravel() works for dense and sparse X)
    if progress:
        print("Filtering genes...")
    gene_sums = np.asarray(data.X.sum(axis=0)).ravel()
    print(f"{(gene_sums == 0).sum()} genes filtered out since sum(counts) over the gene was 0.")
    sc.pp.filter_genes(data, min_counts=1)

    # Get total counts per barcode
    if progress:
        print("Calculating total counts per barcode...")
    totals = np.asarray(data.X.sum(axis=1)).ravel()

    # Visualize total counts distribution
    if visualize:
        plt.figure(figsize=(10, 6))
        sns.histplot(totals, bins=100, log_scale=True)
        plt.title('Distribution of Total UMI Counts')
        plt.xlabel('Total UMI Counts (log scale)')
        plt.ylabel('Count')
        plt.savefig('empty_drops_visualizations/total_counts_distribution.png')
        plt.close()
        print("Total counts distribution plot saved to 'empty_drops_visualizations/total_counts_distribution.png'")

    # Identify putative empty droplets
    if progress:
        print("Identifying empty droplets...")
    if by_rank is not None:
        # Rank-based threshold: everything at or below the total of the
        # (by_rank + 1)-th largest barcode is treated as ambient.
        if by_rank < 0 or by_rank >= len(totals):
            raise ValueError(f"by_rank must be in [0, {len(totals)}), got {by_rank}")
        lower = np.sort(totals)[::-1][by_rank]
    empty_mask = totals <= lower

    # Print statistics about empty droplets
    n_empty = np.sum(empty_mask)
    n_total = len(totals)
    pct_empty = n_empty / n_total * 100 if n_total else 0.0
    print(f"Identified {n_empty} empty droplets out of {n_total} total droplets ({pct_empty:.2f}%)")

    # Visualize empty droplet threshold
    if visualize:
        plt.figure(figsize=(10, 6))
        sns.histplot(totals, bins=100, log_scale=True)
        plt.axvline(x=lower, color='r', linestyle='--', label=f'Lower threshold ({lower})')
        if retain is not None:
            plt.axvline(x=retain, color='g', linestyle='--', label=f'Retain threshold ({retain})')
        plt.title('Empty Droplet Threshold')
        plt.xlabel('Total UMI Counts (log scale)')
        plt.ylabel('Count')
        plt.legend()
        plt.savefig('empty_drops_visualizations/empty_droplet_threshold.png')
        plt.close()
        print("Empty droplet threshold plot saved to 'empty_drops_visualizations/empty_droplet_threshold.png'")

    # Get ambient profile from empty droplets
    if progress:
        print("Calculating ambient profile...")
    ambient_data = data[empty_mask]
    ambient_totals = totals[empty_mask]

    # Calculate ambient proportions using Good-Turing estimation
    gene_names, ambient_proportions, gene_expectations = good_turing_ambient_pool(
        data, np.asarray(ambient_data.X.sum(axis=0)).ravel()
    )

    # Visualize ambient proportions
    if visualize:
        plt.figure(figsize=(10, 6))
        top_genes = np.argsort(ambient_proportions)[-20:]  # Top 20 genes
        sns.barplot(x=ambient_proportions[top_genes], y=gene_names[top_genes])
        plt.title('Top 20 Genes in Ambient Profile')
        plt.xlabel('Proportion')
        plt.ylabel('Gene')
        plt.tight_layout()
        plt.savefig('empty_drops_visualizations/ambient_profile.png')
        plt.close()
        print("Ambient profile plot saved to 'empty_drops_visualizations/ambient_profile.png'")

    # Estimate alpha if not specified
    if alpha is None:
        if progress:
            print("Estimating alpha parameter...")
        alpha = estimate_alpha(
            ambient_data.X,
            ambient_proportions,
            ambient_totals,
            progress=progress
        )
        print(f"Estimated alpha parameter: {alpha:.4f}")

    # Calculate probabilities for non-empty droplets
    if progress:
        print("Calculating probabilities for non-empty droplets...")
    non_empty_mask = ~empty_mask
    if ignore is not None:
        non_empty_mask &= totals > ignore

    obs_data = data[non_empty_mask]
    obs_totals = totals[non_empty_mask]

    # Calculate multinomial probabilities
    obs_probs = compute_multinom_prob(
        obs_data.X,
        ambient_proportions,
        alpha,
        progress=progress
    )

    # Visualize log probabilities
    if visualize:
        plt.figure(figsize=(10, 6))
        sns.histplot(obs_probs, bins=50)
        plt.title('Distribution of Log Probabilities')
        plt.xlabel('Log Probability')
        plt.ylabel('Count')
        plt.savefig('empty_drops_visualizations/log_probabilities.png')
        plt.close()
        print("Log probabilities plot saved to 'empty_drops_visualizations/log_probabilities.png'")

    # Convert arrays to correct types for Cython function
    print("\nConverting arrays to correct types...")
    print(f"obs_totals dtype before: {obs_totals.dtype}")
    obs_totals = obs_totals.astype(np.int64)
    print(f"obs_totals dtype after: {obs_totals.dtype}")

    print(f"obs_probs dtype before: {obs_probs.dtype}")
    obs_probs = obs_probs.astype(np.float64)
    print(f"obs_probs dtype after: {obs_probs.dtype}")

    print(f"ambient_proportions dtype before: {ambient_proportions.dtype}")
    ambient_proportions = ambient_proportions.astype(np.float64)
    print(f"ambient_proportions dtype after: {ambient_proportions.dtype}")
    print(f"ambient_proportions shape: {ambient_proportions.shape}")
    print(f"ambient_proportions sum: {ambient_proportions.sum()}")

    # Perform Monte Carlo testing
    if progress:
        print("\nPerforming Monte Carlo testing...")
    n_above = permute_counter(
        obs_totals,
        obs_probs,
        ambient_proportions,
        niters,
        alpha,
        progress=progress,
        adaptive=adaptive,
        min_iters=min_iters,
        max_iters=max_iters,
        early_stopping=early_stopping,
        batch_size=batch_size
    )

    # Calculate p-values
    if progress:
        print("Calculating p-values and FDR...")
    pvals = (n_above + 1) / (niters + 1)
    limited = n_above == 0

    # Create a full DataFrame with all barcodes
    full_results = pd.DataFrame(index=data.obs_names)
    full_results['Total'] = totals
    full_results['IsEmpty'] = empty_mask

    # Initialize columns with NaN; Limited is object dtype so boolean flags
    # can sit next to NaN without an incompatible-dtype assignment.
    full_results['LogProb'] = np.nan
    full_results['PValue'] = np.nan
    full_results['FDR'] = np.nan
    full_results['Limited'] = pd.Series(np.nan, index=full_results.index, dtype=object)

    # Fill in results for tested cells
    non_empty_barcodes = data.obs_names[non_empty_mask]
    full_results.loc[non_empty_barcodes, 'LogProb'] = obs_probs
    full_results.loc[non_empty_barcodes, 'PValue'] = pvals
    full_results.loc[non_empty_barcodes, 'Limited'] = limited

    # Handle retain threshold. Only tested barcodes are affected, so ambient
    # or ignored barcodes never enter the FDR correction below.
    if retain is not None:
        retain_mask = non_empty_mask & (totals >= retain)
        full_results.loc[data.obs_names[retain_mask], 'PValue'] = 0

    # Apply FDR correction only to cells that were tested
    tested_mask = ~full_results['PValue'].isna()
    if tested_mask.any():
        fdr = multipletests(full_results.loc[tested_mask, 'PValue'], method='fdr_bh')[1]
        full_results.loc[tested_mask, 'FDR'] = fdr

    # Visualize p-value distribution
    if visualize:
        plt.figure(figsize=(10, 6))
        sns.histplot(full_results.loc[tested_mask, 'PValue'], bins=50)
        plt.title('Distribution of P-values')
        plt.xlabel('P-value')
        plt.ylabel('Count')
        plt.savefig('empty_drops_visualizations/p_values.png')
        plt.close()
        print("P-value distribution plot saved to 'empty_drops_visualizations/p_values.png'")

    # Visualize FDR distribution
    if visualize:
        plt.figure(figsize=(10, 6))
        sns.histplot(full_results.loc[tested_mask, 'FDR'], bins=50)
        plt.title('Distribution of FDR')
        plt.xlabel('FDR')
        plt.ylabel('Count')
        plt.savefig('empty_drops_visualizations/fdr.png')
        plt.close()
        print("FDR distribution plot saved to 'empty_drops_visualizations/fdr.png'")

    # Flag significant barcodes (always part of the returned frame; NaN FDR -> False)
    full_results['FDR < 0.05'] = full_results['FDR'] < 0.05

    # Scatter plot of FDR vs Total UMI Counts. It sits inside the visualize
    # guard because the output directory is only created when visualize=True.
    if visualize:
        plt.figure(figsize=(10, 6))
        # Dict palette keeps False->red / True->blue even if only one level is present.
        sns.scatterplot(data=full_results[tested_mask], x='Total', y='FDR',
                        hue='FDR < 0.05', palette={False: 'red', True: 'blue'}, alpha=0.5)
        plt.title('FDR vs Total UMI Counts')
        plt.xlabel('Total UMI Counts')
        plt.ylabel('False Discovery Rate')
        plt.savefig('empty_drops_visualizations/final_results.png')
        plt.close()

    # Cache the results
    if use_cache:
        _save_to_cache(cache_key, {'results': full_results})

    if progress:
        total_time = time.time() - start_time
        print(f"EmptyDrops analysis complete! Total time: {format_time(total_time)}")

    return full_results

# RUN EmptyDrops()

In [20]:
def load_matrices():
    """Read the raw and filtered 10x HDF5 matrices from the working directory.

    Returns
    -------
    tuple
        ``(raw_adata, filtered_adata)``: the unfiltered matrix to analyse
        and the filtered one used for validation.
    """
    print("Loading matrices...")
    t0 = time.time()

    raw = sc.read_10x_h5('raw_feature_bc_matrix.h5')
    print(f"Raw matrix loaded: {raw.shape}")

    filtered = sc.read_10x_h5('filtered_feature_bc_matrix.h5')
    print(f"Filtered matrix loaded: {filtered.shape}")

    print(f"Loading completed in {time.time() - t0:.2f} seconds")
    return raw, filtered

def run_empty_drops_analysis(raw_adata, lower=150, retain=1000, **empty_drops_kwargs):
    """Run EmptyDrops on a raw matrix with settings aimed at background removal.

    The defaults below aim to minimise false positives (empty droplets
    called as cells):

    - ``lower=150``: total-count threshold ``empty_drops`` uses to identify
      empty droplets for the ambient profile.
    - ``retain=1000``: presumably barcodes with totals above this are kept
      as cells without testing (as in the EmptyDrops method) -- confirm
      against ``empty_drops``.
    - ``alpha=0.5``: passed straight to ``empty_drops``; assumed to be the
      Dirichlet-multinomial scaling parameter -- TODO confirm.
    - ``niters=1000`` Monte Carlo iterations with early stopping between
      ``min_iters=100`` and ``max_iters=1000``, ``batch_size=25``, and
      ``confidence_level=0.99``.

    No FDR cutoff is applied here. Callers threshold the returned ``'FDR'``
    column themselves (``validate_results`` uses 0.001 by default).

    Parameters
    ----------
    raw_adata : AnnData
        Raw (unfiltered) feature-barcode matrix.
    lower, retain : int, optional
        Forwarded to ``empty_drops``.
    **empty_drops_kwargs
        Any other ``empty_drops`` keyword arguments. These override the
        defaults listed above (e.g. ``niters=5000`` or ``progress=False``).

    Returns
    -------
    pandas.DataFrame
        The table returned by ``empty_drops``.
    """
    print("\nRunning EmptyDrops analysis with background-removal optimized parameters...")
    print(f"lower threshold: {lower}")
    print(f"retain threshold: {retain}")
    start_time = time.time()

    # Start from an empty cache directory so results from a previous parameter
    # set cannot be picked up. Only an already-existing directory is reset.
    cache_dir = "empty_drops_cache"
    if os.path.exists(cache_dir):
        print("Clearing cache directory...")
        shutil.rmtree(cache_dir)
        os.makedirs(cache_dir, exist_ok=True)

    params = {
        'lower': lower,
        'retain': retain,
        'test_ambient': True,
        'niters': 1000,            # full Monte Carlo iterations
        'alpha': 0.5,
        'progress': True,
        'use_cache': False,
        'min_iters': 100,
        'max_iters': 1000,
        'early_stopping': True,
        'batch_size': 25,          # small batches, chosen for an M3 machine
        'confidence_level': 0.99,
    }
    params.update(empty_drops_kwargs)

    # NOTE: when run from a notebook on macOS (the 'spawn' start method), the
    # process-pool workers inside empty_drops fail with
    # "Can't get attribute '_process_chunk_mp' on <module '__main__'>". Spawned
    # workers must import the worker function, so it needs to live in an
    # importable module rather than in a notebook cell.
    results = empty_drops(raw_adata, **params)

    analysis_time = time.time() - start_time
    print(f"Analysis completed in {analysis_time:.2f} seconds")

    return results

def _binary_classification_metrics(tp, fp, tn, fn):
    """Return ``(accuracy, precision, recall, f1)``; undefined ratios are reported as 0."""
    total = tp + fp + tn + fn
    accuracy = (tp + tn) / total if total > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return accuracy, precision, recall, f1


def validate_results(raw_adata, filtered_adata, results, fdr_threshold=0.001):
    """Validate EmptyDrops calls against the filtered matrix and save diagnostic plots.

    Barcodes present in ``filtered_adata`` are treated as ground-truth cells.
    Barcodes in ``results`` with ``FDR < fdr_threshold`` are treated as
    EmptyDrops cell calls. Every barcode in ``raw_adata`` forms the universe
    used for true negatives.

    Parameters
    ----------
    raw_adata : AnnData
        Raw matrix. Its ``obs_names`` define all barcodes, and its ``X``
        (sparse or dense) supplies per-barcode UMI totals for the histogram.
    filtered_adata : AnnData
        Filtered matrix whose ``obs_names`` are the reference cell barcodes.
    results : pandas.DataFrame
        EmptyDrops output indexed by barcode with an ``'FDR'`` column. Rows
        whose FDR is NaN never count as cells, because NaN comparisons are False.
    fdr_threshold : float, optional
        Strict upper bound on FDR for a barcode to be called a cell
        (default 0.001).

    Returns
    -------
    dict
        ``true_positives``, ``false_positives``, ``true_negatives``,
        ``false_negatives``, ``precision``, ``recall``, ``f1_score``,
        ``accuracy``, and ``confusion_matrix``. The confusion matrix is a
        2x2 array: rows are the true label (Empty, Cell) and columns are
        the predicted label (Empty, Cell).

    Side effects: prints a summary and writes ``confusion_matrix.png`` and
    ``umi_distribution.png`` to the working directory.
    """
    print("\nValidating results...")

    # Ground truth: barcodes kept in the filtered matrix.
    filtered_barcodes = set(filtered_adata.obs_names)
    raw_barcodes = set(raw_adata.obs_names)

    # EmptyDrops cell calls at the requested FDR.
    empty_drops_barcodes = set(results[results['FDR'] < fdr_threshold].index)

    true_positives = len(filtered_barcodes & empty_drops_barcodes)
    false_positives = len(empty_drops_barcodes - filtered_barcodes)
    true_negatives = len(raw_barcodes - filtered_barcodes - empty_drops_barcodes)
    false_negatives = len(filtered_barcodes - empty_drops_barcodes)

    cm = np.array([[true_negatives, false_positives],
                   [false_negatives, true_positives]])

    accuracy, precision, recall, f1_score = _binary_classification_metrics(
        true_positives, false_positives, true_negatives, false_negatives)

    print("\nValidation Results:")
    print(f"True Positives: {true_positives}")
    print(f"False Positives: {false_positives}")
    print(f"True Negatives: {true_negatives}")
    print(f"False Negatives: {false_negatives}")
    print(f"\nPrecision: {precision:.3f}")
    print(f"Recall: {recall:.3f}")
    print(f"F1 Score: {f1_score:.3f}")
    print(f"Accuracy: {accuracy:.3f}")

    # Confusion-matrix heatmap. The figure is closed even if plotting fails.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Empty', 'Cell'],
                    yticklabels=['Empty', 'Cell'],
                    ax=ax)
        ax.set_title('Confusion Matrix (Strict Background Removal)')
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        fig.savefig('confusion_matrix.png')
    finally:
        plt.close(fig)

    # Summing a sparse X along axis 1 gives an (n, 1) np.matrix, but a dense
    # X gives a 1-D ndarray with no `.A1`. np.asarray(...).ravel() handles both.
    umi_counts = np.asarray(raw_adata.X.sum(axis=1)).ravel()
    categories = np.where(raw_adata.obs_names.isin(empty_drops_barcodes), 'Cell', 'Empty')
    # log10(0) is undefined, so zero-count barcodes cannot be shown on the log axis.
    has_counts = umi_counts > 0
    umi_df = pd.DataFrame({
        'UMI_Counts': umi_counts[has_counts],
        'Category': categories[has_counts],
    })

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        sns.histplot(data=umi_df, x='UMI_Counts', hue='Category',
                     bins=100, log_scale=True, ax=ax)
        ax.set_title('UMI Count Distribution by Category')
        fig.savefig('umi_distribution.png')
    finally:
        plt.close(fig)

    return {
        'true_positives': true_positives,
        'false_positives': false_positives,
        'true_negatives': true_negatives,
        'false_negatives': false_negatives,
        'precision': precision,
        'recall': recall,
        'f1_score': f1_score,
        'accuracy': accuracy,
        'confusion_matrix': cm
    }

In [21]:
# Analysis pipeline (notebook cell, not a function):
#   1. load the raw and filtered 10x matrices,
#   2. run EmptyDrops on the raw matrix,
#   3. score its calls against the filtered matrix,
#   4. write the per-barcode results to CSV.
raw_adata, filtered_adata = load_matrices()
results = run_empty_drops_analysis(raw_adata)
metrics = validate_results(raw_adata, filtered_adata, results)

results.to_csv('empty_drops_results.csv')

for line in (
    "\nAnalysis complete! Results saved to 'empty_drops_results.csv'",
    "Confusion matrix saved as 'confusion_matrix.png'",
    "UMI distribution plot saved as 'umi_distribution.png'",
):
    print(line)

Loading matrices...


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


Raw matrix loaded: (722431, 22040)


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


Filtered matrix loaded: (2701, 22040)
Loading completed in 1.66 seconds

Running EmptyDrops analysis with background-removal optimized parameters...
lower threshold: 150
retain threshold: 1000
Clearing cache directory...
Starting EmptyDrops analysis...
Filtering genes...
2014 genes filtered out since sum(counts) over the gene was 0.


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


Calculating total counts per barcode...
Total counts distribution plot saved to 'empty_drops_visualizations/total_counts_distribution.png'
Identifying empty droplets...
Identified 710061 empty droplets out of 722431 total droplets (98.29%)
Empty droplet threshold plot saved to 'empty_drops_visualizations/empty_droplet_threshold.png'
Calculating ambient profile...




Ambient profile plot saved to 'empty_drops_visualizations/ambient_profile.png'
Calculating probabilities for non-empty droplets...

Multiprocessing Debug Info:
Number of CPU cores available: 8
Number of processes to be used: 8
Chunk size: 100
Process Pool starting...

Total data size: 944.98 MB
Number of chunks: 123
Chunk size: 7.68 MB



Computing probabilities:   0%|          | 0/123 [00:00<?] 

Process pool created with 8 workers


Process SpawnPoolWorker-2:
Traceback (most recent call last):
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/queues.py", line 368, in get
    return _ForkingPickler.loads(res)
AttributeError: Can't get attribute '_process_chunk_mp' on <module '__main__' (built-in)>
Process SpawnPoolWorker-1:
Traceback (most recent call last):
  File "/Library/Developer/CommandLineTools/Libr

KeyboardInterrupt: 