In [10]:
import numpy as np
import pandas as pd
import scanpy as sc
import skbio
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import euclidean_distances
from scipy.spatial.distance import squareform
import gc # For garbage collection
from typing import Union, Dict, Optional

In [7]:
# Load an AnnData object from the cohort multiVI model directory.
# NOTE(review): going by the file name this looks like the sparse-RNA variant of the
# corrected dataset - confirm. The PERMANOVA cells visible in this notebook operate on
# `adata`, not on this object.
adata_corrected_sparse = sc.read_h5ad("/home/minhang/mds_project/data/cohort_adata/multiVI_model/adata_multivi_corrected_rna_final_sparseRNA.h5ad")

In [8]:
# Bare expression: displays the AnnData summary (dimensions plus the obs/var/uns/obsm/obsp keys).
adata_corrected_sparse

AnnData object with n_obs × n_vars = 192149 × 36601
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'library', 'exp.ID', 'percent.mt', 'nCount_ATAC', 'nFeature_ATAC', 'nCount_dsb', 'nFeature_dsb', 'nCount_ADT', 'nFeature_ADT', 'hash.ID', 'scDblFinder.score', 'scDblFinder.weighted', 'scDblFinder.cxds_score', 'Lane', 'patient', 'marker', 'Time', 'batch', 'Tech', 'sample', 'source', 'soup.singlet_posterior', '_indices', '_scvi_batch', '_scvi_labels', 'CN.label', 'predicted.annotation.score', 'predicted.annotation', 'predicted.pseudotime.score', 'predicted.pseudotime', 'timepoint_type'
    var: 'ID', 'modality', 'chr', 'start', 'end'
    uns: 'CN.label_colors', 'Tech_colors', 'neighbors', 'patient_colors', 'predicted.annotation_colors', 'sample_colors', 'timepoint_type_colors', 'umap'
    obsm: 'X_multivi', 'X_umap'
    obsp: 'connectivities', 'distances'

In [2]:
# Load the AnnData object that the PERMANOVA cells below read as `adata`
# (they use its `obs['source']` column and, via the helper, `obsm['X_multivi']`).
adata = sc.read_h5ad('/home/minhang/mds_project/data/cohort_adata/multiVI_model/adata_multivi_corrected_rna.h5ad')

In [3]:
# Distinct labels present in the 'predicted.annotation' column, shown as a plain list.
adata.obs['predicted.annotation'].unique().tolist()

['MEP-MkP',
 'HSC/MPP1',
 'HSC/MPP2',
 'MEP-EP',
 'LMPP',
 'EBM',
 'cDC',
 'NP1',
 'NP2',
 'CLP',
 'pDC2',
 'pDC1',
 'ncMono']

In [6]:
# Load a third AnnData object from the same multiVI model directory.
# NOTE(review): presumably the uncorrected input data, judging by the file/variable name -
# confirm. It is not referenced by any other cell visible in this notebook.
original_adata = sc.read_h5ad('/home/minhang/mds_project/data/cohort_adata/multiVI_model/adata.h5ad')

In [3]:
# --- Helper Function to run PERMANOVA ---
def _permanova_failure(message: str, n_groups: int = 0) -> pd.Series:
    """Build the error-shaped Series returned whenever PERMANOVA cannot produce a result."""
    return pd.Series({
        "error": message,
        "p-value": np.nan,
        "test statistic": np.nan,
        "n_groups_in_test": n_groups,
    })


def run_permanova_analysis(
    adata_input: sc.AnnData,
    grouping_column: str,
    n_subsample: Optional[int] = None,
    random_seed: int = 42,
    permutations: int = 999
) -> pd.Series:
    """
    Runs PERMANOVA on the 'X_multivi' embedding of an AnnData, optionally subsampling first.

    Pairwise Euclidean distances between cells are computed on
    ``adata_input.obsm['X_multivi']`` and handed to ``skbio.stats.distance.permanova``
    together with the labels from ``adata_input.obs[grouping_column]``.

    Args:
        adata_input: AnnData object. It is never modified; a subsampled copy is made
                     only when subsampling actually happens.
        grouping_column: Name of the column in adata_input.obs to use for grouping.
        n_subsample: Number of cells to subsample to. If None or >= adata_input.n_obs,
                     all cells in adata_input are used.
        random_seed: Random seed for subsampling.
        permutations: Number of permutations for PERMANOVA.

    Returns:
        On success, the Series returned by skbio's permanova with an extra
        'n_groups_in_test' entry. On failure, a Series with the keys 'error',
        'p-value' (NaN), 'test statistic' (NaN) and 'n_groups_in_test'.

    Raises:
        KeyError: if 'X_multivi' is missing from adata_input.obsm.
    """
    # Local import: the notebook's import cell only brings in `squareform` from this module.
    from scipy.spatial.distance import pdist

    print(f"    --- Analyzing grouping: '{grouping_column}' ---")

    if n_subsample is not None and adata_input.n_obs > n_subsample:
        print(f"    Subsampling from {adata_input.n_obs} to {n_subsample} cells...")
        # copy=True returns a new, subsampled AnnData and leaves the caller's object untouched,
        # so there is no need to duplicate the whole input up front.
        current_adata = sc.pp.subsample(
            adata_input, n_obs=n_subsample, random_state=random_seed, copy=True
        )
        print(f"    Subsampled to {current_adata.n_obs} cells.")
    else:
        # Everything below only reads from the object, so no defensive copy is required.
        current_adata = adata_input
        print(f"    Using all {current_adata.n_obs} cells.")

    if current_adata.n_obs == 0:
        print("    No cells in subset. Skipping PERMANOVA.")
        return _permanova_failure("No cells in subset")

    if grouping_column not in current_adata.obs.columns:
        print(f"    Error: Grouping column '{grouping_column}' not found. Skipping.")
        return _permanova_failure(f"Grouping column '{grouping_column}' not found")

    n_groups_in_test = current_adata.obs[grouping_column].nunique()
    if n_groups_in_test < 2:
        print(f"    Error: Grouping column '{grouping_column}' has {n_groups_in_test} unique groups (needs >= 2). Skipping PERMANOVA.")
        return _permanova_failure(f"Grouping '{grouping_column}' has < 2 unique groups", n_groups_in_test)

    # Warn about empty or singleton groups. The threshold used for the check now matches
    # the threshold used for the printed listing (< 2), so singletons are reported too.
    group_counts = current_adata.obs[grouping_column].value_counts()
    small_groups = group_counts[group_counts < 2]
    if not small_groups.empty:
        print(f"    Warning: At least one group in '{grouping_column}' is empty or has very few samples after subsetting/filtering.")
        print(small_groups)

    print("    Extracting multiVI embeddings...")
    embeddings = np.asarray(current_adata.obsm['X_multivi'])
    cell_ids = current_adata.obs_names.tolist()

    try:
        # pdist yields the condensed (upper-triangle) vector directly, in the same layout
        # squareform uses, so the full n x n matrix is never materialised in this function.
        print("    Calculating condensed Euclidean distance matrix...")
        condensed_distance_array = pdist(embeddings, metric='euclidean')

        print("    Creating skbio.DistanceMatrix object...")
        skbio_dm = skbio.stats.distance.DistanceMatrix(condensed_distance_array, ids=cell_ids)
        del condensed_distance_array  # Free memory
        gc.collect()

        print(f"    Running PERMANOVA with {permutations} permutations...")
        # skbio expects a list or array for grouping, not a pandas Series directly for some versions/setups
        grouping_variable_list = current_adata.obs[grouping_column].tolist()

        results = skbio.stats.distance.permanova(
            skbio_dm,
            grouping_variable_list,
            permutations=permutations
        )
        # Add n_groups to the results Series
        results['n_groups_in_test'] = n_groups_in_test
        print("    PERMANOVA completed.")
    except MemoryError:
        print("    MemoryError during distance computation, skbio.DistanceMatrix creation or PERMANOVA.")
        results = _permanova_failure("MemoryError", n_groups_in_test)
    except Exception as e:
        print(f"    An error occurred during PERMANOVA: {e}")
        results = _permanova_failure(str(e), n_groups_in_test)

    return results

In [4]:
# Run PERMANOVA for every (cell source, grouping column) pair and collect one summary row each.
grouping_cols_to_test = ['Tech', 'sample', 'batch']
cell_sources_to_test = ['donor', 'recipient']
# Recipient subsample size used only when there are no donor cells to match against.
RECIPIENT_FALLBACK_SUBSAMPLE = 50000
all_results_summary = []  # reset on every run so re-executing the cell does not duplicate rows

# Determine donor cell count for subsampling recipient cells
# Ensure 'source' column exists
if 'source' not in adata.obs.columns:
    raise ValueError("'source' column not found in adata.obs. Please ensure it's correctly populated.")

# Count matching rows directly on the obs column instead of building AnnData views just to read n_obs.
n_donor_cells_total = int((adata.obs['source'] == 'donor').sum())
n_recipient_cells_total = int((adata.obs['source'] == 'recipient').sum())

print(f"Total donor cells found: {n_donor_cells_total}")
print(f"Total recipient cells found: {n_recipient_cells_total}")

if n_donor_cells_total == 0:
    print("Warning: No donor cells found in the dataset based on 'source' column. PERMANOVA for donor cells will be skipped.")
if n_recipient_cells_total == 0:
     print("Warning: No recipient cells found in the dataset based on 'source' column. PERMANOVA for recipient cells will be skipped.")

# Target for subsampling recipient cells is the total number of donor cells
recipient_subsample_target_count = n_donor_cells_total
recipient_target_label = "matching donor count"
if n_donor_cells_total == 0 and n_recipient_cells_total > 0: # Edge case: no donors, but want to run for recipients
    print("No donor cells to match count. If you want to run for recipients, set a specific subsample size or use all.")
    # Defaulting to a placeholder if no donors, user might want to adjust
    recipient_subsample_target_count = min(RECIPIENT_FALLBACK_SUBSAMPLE, n_recipient_cells_total)
    recipient_target_label = "fallback size, no donor cells to match"


for source_type in cell_sources_to_test:
    print(f"\n===== PROCESSING CELL SOURCE: {source_type.upper()} =====")

    if source_type == 'donor' and n_donor_cells_total == 0:
        continue
    if source_type == 'recipient' and n_recipient_cells_total == 0:
        continue

    adata_filtered_by_source = adata[adata.obs['source'] == source_type].copy()
    gc.collect()

    current_subsample_size = None # Use all cells by default
    if source_type == 'recipient':
        if n_recipient_cells_total > recipient_subsample_target_count and recipient_subsample_target_count > 0:
            current_subsample_size = recipient_subsample_target_count
            print(f"  Recipient cells will be subsampled to: {current_subsample_size} ({recipient_target_label})")
        else:
            print(f"  Using all {n_recipient_cells_total} recipient cells (target subsample size not applicable or smaller).")
    elif source_type == 'donor':
        print(f"  Using all {n_donor_cells_total} donor cells.")

    for group_col_name in grouping_cols_to_test:
        # Check if grouping column actually exists in the AnnData object
        if group_col_name not in adata_filtered_by_source.obs.columns:
            print(f"  Skipping '{group_col_name}' for '{source_type}': column not found in .obs")
            all_results_summary.append({
                'cell_source': source_type,
                'grouping_variable': group_col_name,
                'n_cells_in_test': 0,
                'n_groups_in_test': 0,
                'F_statistic': np.nan,
                'p_value': np.nan,
                'notes': 'Grouping column not found in AnnData'
            })
            continue

        permanova_result_series = run_permanova_analysis(
            adata_filtered_by_source,
            grouping_column=group_col_name,
            n_subsample=current_subsample_size
        )

        # Number of cells the helper actually tested: the subsample size when one was
        # requested and is smaller than the available cells, otherwise all cells.
        if current_subsample_size is None:
            num_cells_tested = adata_filtered_by_source.n_obs
        else:
            num_cells_tested = min(current_subsample_size, adata_filtered_by_source.n_obs)

        # The helper signals failure through an 'error' entry; report 0 tested cells in that case.
        test_failed = 'error' in permanova_result_series

        all_results_summary.append({
            'cell_source': source_type,
            'grouping_variable': group_col_name,
            'n_cells_in_test': 0 if test_failed else num_cells_tested,
            'n_groups_in_test': permanova_result_series.get('n_groups_in_test', 0),
            'F_statistic': permanova_result_series.get('test statistic', np.nan),
            'p_value': permanova_result_series.get('p-value', np.nan),
            'notes': permanova_result_series.get('error', '')
        })
        gc.collect() # Clean up memory after each PERMANOVA run

# --- Display results in a table ---
results_summary_df = pd.DataFrame(all_results_summary)
print("\n\n===== PERMANOVA Results Summary Table =====")
print(results_summary_df)

Total donor cells found: 71450
Total recipient cells found: 112797

===== PROCESSING CELL SOURCE: DONOR =====
  Using all 71450 donor cells.
    --- Analyzing grouping: 'Tech' ---
    Using all 71450 cells.
    Extracting multiVI embeddings...
    Calculating Euclidean distance matrix...
    Converting to condensed distance matrix...
    Creating skbio.DistanceMatrix object...
    Running PERMANOVA with 999 permutations...
    PERMANOVA completed.
    --- Analyzing grouping: 'sample' ---
    Using all 71450 cells.
    Extracting multiVI embeddings...
    Calculating Euclidean distance matrix...
    Converting to condensed distance matrix...
    Creating skbio.DistanceMatrix object...
    Running PERMANOVA with 999 permutations...
    PERMANOVA completed.
    --- Analyzing grouping: 'batch' ---
    Using all 71450 cells.
    Extracting multiVI embeddings...
    Calculating Euclidean distance matrix...
    Converting to condensed distance matrix...
    Creating skbio.DistanceMatrix objec

In [5]:
# Bare expression: renders the summary table built by the PERMANOVA cell above
# (one row per cell source / grouping variable combination).
results_summary_df

Unnamed: 0,cell_source,grouping_variable,n_cells_in_test,n_groups_in_test,F_statistic,p_value,notes
0,donor,Tech,71450,2,1534.666593,0.001,
1,donor,sample,71450,41,139.174032,0.001,
2,donor,batch,71450,21,224.845215,0.001,
3,recipient,Tech,71450,2,256.422908,0.001,
4,recipient,sample,71450,46,295.191421,0.001,
5,recipient,batch,71450,20,413.941852,0.001,
