In [2]:
import scanpy as sc
import pandas as pd
import numpy as np

from typing import List, Optional, Tuple
try:
    from pydeseq2.dds import DeseqDataSet
    from pydeseq2.ds import DeseqStats
except ImportError:
    raise ImportError("Please install pyDESeq2: pip install pyDESeq2")

import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

In [3]:
def pseudobulk_and_deseq(
    adata,
    pseudobulk_group_cols: List[str],
    condition_col: str,
    comparisons: Optional[List[Tuple[str, str]]] = None,
    design_factors: Optional[List[str]] = None,
    min_samples_per_group: int = 5,
    override_min_samples: bool = False,
    pseudobulk_func: str = "sum",
    min_cells_in_group: int = 1,
    # Additional parameters passed to DESeqDataSet constructor
    dds_kwargs: dict = {},
    # Additional parameters passed to DeseqStats constructor
    dds_stats_kwargs: dict = {}
):
    """
    Pseudobulk single-cell data from an AnnData object, then run pyDESeq2
    differential expression. This version automatically adds 'log2_ncells'
    as a covariate in the DESeq2 model. Now with added logging/printing
    for debugging and clarity.
    Parameters
    ----------
    adata : anndata.AnnData
        Single-cell data. Must have raw counts in `adata.X`.
    pseudobulk_group_cols : List[str]
        Columns in `adata.obs` that define each pseudobulk group. For example,
        ['sample_id', 'cell_type', 'condition'].
    condition_col : str
        The metadata column (in `adata.obs`) that defines the primary condition
        for differential analysis. Must be included in `pseudobulk_group_cols`
        if it is part of the grouping.
    comparisons : Optional[List[Tuple[str, str]]]
        List of (conditionA, conditionB) pairs to compare. If None,
        you must specify a comparison or handle it otherwise.
    design_factors : Optional[List[str]]
        Columns in the pseudobulk-level metadata to use in the DESeq design,
        e.g. ['condition', 'batch']. If None, defaults to [condition_col].
        Note that 'log2_ncells' is automatically added.
    min_samples_per_group : int
        Minimum required samples in each condition group for DESeq to run well.
        DESeq2 typically recommends at least 5. If you have fewer, you can
        set override_min_samples=True to proceed anyway.
    override_min_samples : bool
        If False and if any group has fewer than `min_samples_per_group` samples,
        raises an error. If True, prints a warning and proceeds.
    pseudobulk_func : str
        Aggregation function to produce pseudobulk. Options: "sum", "mean", etc.
        Typically "sum" is recommended for raw counts. Must be a valid pandas
        groupby function.
    min_cells_in_group : int
        If a group has fewer than this number of cells, we discard it from
        the pseudobulk matrix.
    dds_kwargs : dict
        Additional keyword arguments passed to `DeseqDataSet`.
    dds_stats_kwargs : dict
        Additional keyword arguments passed to `DeseqStats`.
    Returns
    -------
    pd.DataFrame
        A DataFrame containing DE results for all comparisons. Includes columns:
        log2FoldChange, pvalue, padj, etc. Also has a 'comparison' column
        indicating which conditionA_vs_conditionB the result corresponds to.
    """

    print("========== Starting pseudobulk_and_deseq ==========")
    print(f"Grouping columns: {pseudobulk_group_cols}")
    print(f"Condition column: {condition_col}")
    print(f"Comparisons: {comparisons}")
    print(f"Design factors (initial): {design_factors}")
    print(f"min_samples_per_group: {min_samples_per_group}, override_min_samples: {override_min_samples}")
    print(f"Aggregation function: {pseudobulk_func}")
    print(f"min_cells_in_group: {min_cells_in_group}")
    print("---------------------------------------------------")
    # 0. Basic checks

    print("Step 0: Checking columns in adata.obs...")
    for col in pseudobulk_group_cols:
        if col not in adata.obs.columns:
            raise ValueError(f"'{col}' not found in adata.obs.")
    if condition_col not in adata.obs.columns:
        raise ValueError(f"'{condition_col}' not found in adata.obs.")
    if design_factors is not None:
        for dfactor in design_factors:
            if dfactor not in adata.obs.columns:
                raise ValueError(f"Design factor '{dfactor}' not found in adata.obs.")
    else:
        design_factors = [condition_col]
    if comparisons is None:
        raise ValueError("No comparisons provided. Please provide a list of (condA, condB) tuples.")
    print(f"Total cells in adata: {adata.n_obs}")
    print("---------------------------------------------------")

    # 1. Create a single grouping key
    grouping_key = "pseudobulk_group_key"
    while grouping_key in adata.obs.columns:
        grouping_key += "_dup"
    adata.obs[grouping_key] = (
        adata.obs[pseudobulk_group_cols]
        .astype(str)
        .apply(lambda x: "--".join(x), axis=1)
    )
    print(f"Created grouping key '{grouping_key}' in adata.obs.")
    print("Example grouping key values:")
    print(adata.obs[grouping_key].head(5))
    print("---------------------------------------------------")

    # 2. Convert to a cell-level counts DataFrame
    print("Step 2: Creating cell_counts_df from adata.X...")
    if not hasattr(adata.X, "toarray"):
        counts_matrix = adata.layers['raw']  # already dense
        print("Data is already dense.")
    else:
        counts_matrix = adata.layers['raw'].toarray()  # convert sparse to dense
        print("Converted sparse matrix to dense.")

    # Round only the non-integer values to the next highest integer
    non_integer_mask = (counts_matrix % 1 != 0)
    counts_matrix[non_integer_mask] = np.ceil(counts_matrix[non_integer_mask])

    # Verify that all values are now integers
    assert np.all(counts_matrix % 1 == 0), "There are still non-integer values!"
    print("Non-integer values rounded up successfully.")

    gene_names = adata.var_names.tolist()
    cell_counts_df = pd.DataFrame(
        counts_matrix,
        columns=gene_names,
        index=adata.obs.index
    )
    cell_counts_df[grouping_key] = adata.obs[grouping_key].values
    print(f"cell_counts_df shape: {cell_counts_df.shape}")
    print("---------------------------------------------------")
    
    # 3. Group and aggregate
    print("Step 3: Grouping and aggregating...")
    grouped = cell_counts_df.groupby(grouping_key)
    group_sizes = grouped.size()
    valid_groups = group_sizes[group_sizes >= min_cells_in_group].index
    print(f"Number of total groups: {len(group_sizes)}")
    print(f"Number of valid groups (>= {min_cells_in_group} cells): {len(valid_groups)}")
    if len(valid_groups) == 0:
        raise ValueError("No groups found with at least min_cells_in_group cells.")
    cell_counts_df = cell_counts_df[cell_counts_df[grouping_key].isin(valid_groups)]
    grouped = cell_counts_df.groupby(grouping_key)
    if not hasattr(pd.core.groupby.generic.DataFrameGroupBy, pseudobulk_func):
        raise ValueError(f"Invalid pseudobulk_func '{pseudobulk_func}'. "
                         "Must be a valid pandas groupby agg method.")
    pseudobulk_df = grouped.aggregate(pseudobulk_func)
    print("Aggregated pseudobulk dataframe shape:", pseudobulk_df.shape)
    print("---------------------------------------------------")
    
    # 3.1 Also store the actual number of cells in each group
    n_cells_per_group = grouped.size()
    if grouping_key in pseudobulk_df.columns:
        pseudobulk_df.drop(columns=grouping_key, inplace=True, errors="ignore")

    # 4. Build pseudobulk-level metadata
    print("Step 4: Building metadata for pseudobulk samples...")
    meta_list = []
    for idx in pseudobulk_df.index:
        split_vals = idx.split("--")
        meta_dict = {}
        for col, val in zip(pseudobulk_group_cols, split_vals):
            meta_dict[col] = val
        meta_list.append(meta_dict)
    pseudo_metadata = pd.DataFrame(meta_list, index=pseudobulk_df.index)
    print(f"pseudo_metadata shape: {pseudo_metadata.shape}")
    print("pseudo_metadata head:")
    print(pseudo_metadata.head(3))
    print("---------------------------------------------------")
    
    # 4.1 Add number of cells and log2 number of cells
    pseudo_metadata["n_cells"] = n_cells_per_group.reindex(pseudo_metadata.index)
    pseudo_metadata["log2_ncells"] = np.log2(pseudo_metadata["n_cells"].astype(float))
    print("Added 'n_cells' and 'log2_ncells' to pseudo_metadata.")
    print(pseudo_metadata[["n_cells", "log2_ncells"]].head(3))
    print("---------------------------------------------------")

    # 5. For each comparison, run DESeq
    results_list = []
    print("Step 5: Running DESeq comparisons...")
    for condA, condB in comparisons:
        print(f"  -> Comparison: {condB} vs {condA}")
        keep_idx = pseudo_metadata[condition_col].isin([condA, condB])
        sub_counts = pseudobulk_df[keep_idx].copy()
        sub_meta = pseudo_metadata[keep_idx].copy()
        group_count_condA = (sub_meta[condition_col] == condA).sum()
        group_count_condB = (sub_meta[condition_col] == condB).sum()
        print(f"     Condition '{condA}' sample count: {group_count_condA}")
        print(f"     Condition '{condB}' sample count: {group_count_condB}")
        if not override_min_samples:
            if group_count_condA < min_samples_per_group:
                raise ValueError(
                    f"Condition '{condA}' has only {group_count_condA} samples. "
                    f"Minimum required is {min_samples_per_group}. "
                    f"Set override_min_samples=True to bypass."
                )
            if group_count_condB < min_samples_per_group:
                raise ValueError(
                    f"Condition '{condB}' has only {group_count_condB} samples. "
                    f"Minimum required is {min_samples_per_group}. "
                    f"Set override_min_samples=True to bypass."
                )
        else:
            if (group_count_condA < min_samples_per_group) or (group_count_condB < min_samples_per_group):
                print(
                    f"     WARNING: {condA} has {group_count_condA} samples, "
                    f"{condB} has {group_count_condB} samples. "
                    f"Continuing because override_min_samples=True."
                )
        # Build the local design
        if condition_col not in design_factors:
            design_factors_local = [condition_col] + design_factors
        else:
            design_factors_local = design_factors[:]
        if "log2_ncells" not in design_factors_local:
            design_factors_local.append("log2_ncells")

        reference_level = [condition_col, condA]  # sets condA as the reference
        print("     Using design factors:", f"~ {' + '.join(design_factors_local)}")
        print("     sub_counts shape:", sub_counts.shape)
        print("     sub_meta shape:", sub_meta.shape)
        print("     reference level:", reference_level)

        # Initialize DeseqDataSet
        print("     Initializing DeseqDataSet...")
        dds = DeseqDataSet(
            counts=sub_counts,
            metadata=sub_meta,
            design_factors = design_factors_local,
            ref_level=reference_level, 
            **dds_kwargs
        )
        print("     Running dds.deseq2()...")
        dds.deseq2()
        print("     Initializing DeseqStats and computing results...")
        stat_res = DeseqStats(dds, **dds_stats_kwargs)
        stat_res.summary()
        res_df = stat_res.results_df.copy()
        res_df["comparison"] = f"{condB}_vs_{condA}"
        res_df["condA"] = condA
        res_df["condB"] = condB
        res_df["gene"] = res_df.index
        print("     DE results shape:", res_df.shape)
        print("     DE results head:")
        print(res_df.head(3))
        results_list.append(res_df)
    final_results = pd.concat(results_list, axis=0)
    print("All comparisons finished. final_results shape:", final_results.shape)
    print("========== pseudobulk_and_deseq Completed ==========")
    return final_results

In [4]:
adata = sc.read_h5ad('/nfs/data/COST_IBD/versions/IBD/03_00_00/03_00_03_sub/adata_core_cellxgene_cleaned_coarse_annotation.h5ad')
adata

AnnData object with n_obs × n_vars = 921278 × 42811
    obs: 'sample', 'total_counts', 'n_genes', 'batch', 'n_counts', 'cell_type', 'n_genes_by_counts', 'pct_counts_mt', 'total_counts_mt', 'label:scanvi', 'scvi-global-0.5_leiden', 'scvi-global-1.0_leiden', 'scvi-global-2_leiden', 'scvi-global-1.5_leiden', 'scanvi-global-0.5_leiden', 'scanvi-global-2_leiden', 'scanvi-global-1.0_leiden', 'scanvi-global-1.5_leiden', 'dataset', 'patient', 'sex', 'condition', 'tissue', 'n_cells (after filtering)', 'disease status', 'disease location', 'sequencing protocol', 'biopsy type', 'tissue depth', 'montreal classification_age at diagnosis', 'age group (at sample collection)', 'montreal classification_cd location ', 'montreal classification_ cd behavior', 'ses-cd ', 'montreal classification_ uc extensity', 'montreal classification_ uc severity', 'mayo score-uc', 'medication_antidiarrheal (lomotil, imodium, dto)', 'medication_5_asa', 'medication_biologics ', 'medication_biologics (monoclonal antibody m

In [4]:
#random_indices = np.random.choice(adata.n_obs, size=2000, replace=False)
#adata_subset = adata[random_indices].copy()

adata_subset = sc.read_h5ad('test_subset_ibd.h5ad')

In [5]:
# Suppose your AnnData has:
#   adata.obs['sample_id'] for sample
#   adata.obs['cell_type'] for cell types
#   adata.obs['condition'] for conditions
#   adata.X is raw UMI counts
# Example: pseudobulk by sample + cell_type + condition,
# run DESeq comparing multiple condition pairs,
# with some additional design factor (like batch).
design_factors = ['condition', 'smoking']  # batch must exist in adata.obs
comparisons = [
    ("HC", "UC"),
    ("HC", "CD")
]
res = pseudobulk_and_deseq(
    adata=adata_subset,
    pseudobulk_group_cols=["sample", "annotation:coarse:cleaned", "condition", "smoking"],
    condition_col="condition",
    comparisons=comparisons,
    design_factors=design_factors,
    min_samples_per_group=5,
    override_min_samples=False,
    pseudobulk_func="sum",          # sum is typical for raw counts
    min_cells_in_group=3,           # minimum number of cells per group before aggregating
    dds_kwargs={},  # example advanced parameter
    dds_stats_kwargs={"alpha": 0.05} # e.g. FDR threshold
)
# Check the output
print(res.head())

Grouping columns: ['sample', 'annotation:coarse:cleaned', 'condition', 'smoking']
Condition column: condition
Comparisons: [('HC', 'UC'), ('HC', 'CD')]
Design factors (initial): ['condition', 'smoking']
min_samples_per_group: 5, override_min_samples: False
Aggregation function: sum
min_cells_in_group: 3
---------------------------------------------------
Step 0: Checking columns in adata.obs...
Total cells in adata: 10000
---------------------------------------------------
Created grouping key 'pseudobulk_group_key' in adata.obs.
Example grouping key values:
original_index
smillie_N7.LPB.TAACTCACTGCTAG       Smillie_N7_B--B_plasma--UC--Never
kong_N119540_L1-CAGCAATTCCTTATAC       N119540_1--T_NK_ILC--CD--Never
kong_N175041_N1-GAAGCAGGTTGATTGC        N175041_1--stromal--CD--Never
smillie_N24.LPA.TGCTACCTCCCGACTT    Smillie_N24_A--stromal--UC--Never
kong_N176196_L2-TACCCGTGTATGGAGC       N176196_2--T_NK_ILC--CD--Never
Name: pseudobulk_group_key, dtype: object
----------------------------

                be converted to hyphens ('-').
  dds = DeseqDataSet(
Fitting size factors...


     Running dds.deseq2()...


... done in 0.54 seconds.

Fitting dispersions...
... done in 102.53 seconds.

Fitting dispersion trend curve...
  self._fit_parametric_dispersion_trend(vst)
... done in 0.64 seconds.

Fitting MAP dispersions...
... done in 152.08 seconds.

Fitting LFCs...
... done in 293.06 seconds.

Calculating cook's distance...
... done in 1.03 seconds.

Replacing 24878 outlier genes.

Fitting dispersions...
... done in 91.23 seconds.

Fitting MAP dispersions...
... done in 135.91 seconds.

Fitting LFCs...
... done in 281.06 seconds.



     Initializing DeseqStats and computing results...


Running Wald tests...
... done in 20.05 seconds.



Log2 fold change & Wald test p-value: log2-ncells 2.0 vs 1.584962500721156
                     baseMean  log2FoldChange     lfcSE      stat    pvalue  \
5S_rRNA              0.000000             NaN       NaN       NaN       NaN   
7SK                  0.414558       -0.158692  6.667469 -0.023801  0.981011   
7SK.1                0.006479       -0.127310  6.668549 -0.019091  0.984768   
7SK.2                1.876246        0.997783  6.642888  0.150203  0.880604   
7SK_ENSG00000232512  0.000000             NaN       NaN       NaN       NaN   
...                       ...             ...       ...       ...       ...   
hsa-mir-490          0.000000             NaN       NaN       NaN       NaN   
hsa-mir-6080         0.030987       -0.220023  6.667754 -0.032998  0.973676   
hsa-mir-8072         0.199855       -0.210286  5.197445 -0.040459  0.967727   
pL63_blaR            0.030112       -0.135003  6.668162 -0.020246  0.983847   
snoU13               0.021866       -0.123285  6.668961 

                be converted to hyphens ('-').
  dds = DeseqDataSet(
Fitting size factors...


     Running dds.deseq2()...


... done in 0.57 seconds.

Fitting dispersions...
... done in 116.40 seconds.

Fitting dispersion trend curve...
  self._fit_parametric_dispersion_trend(vst)
... done in 0.95 seconds.

Fitting MAP dispersions...
... done in 169.10 seconds.

Fitting LFCs...
... done in 270.88 seconds.

Calculating cook's distance...
... done in 1.53 seconds.

Replacing 26420 outlier genes.

Fitting dispersions...
... done in 100.23 seconds.

Fitting MAP dispersions...
... done in 156.52 seconds.

Fitting LFCs...
... done in 242.31 seconds.



     Initializing DeseqStats and computing results...


Running Wald tests...
... done in 19.31 seconds.



Log2 fold change & Wald test p-value: log2-ncells 2.0 vs 1.584962500721156
                     baseMean  log2FoldChange     lfcSE      stat    pvalue  \
5S_rRNA              0.000000             NaN       NaN       NaN       NaN   
7SK                  0.064560       -0.364848  5.622763 -0.064888  0.948263   
7SK.1                0.000000             NaN       NaN       NaN       NaN   
7SK.2                0.000000             NaN       NaN       NaN       NaN   
7SK_ENSG00000232512  0.000000             NaN       NaN       NaN       NaN   
...                       ...             ...       ...       ...       ...   
hsa-mir-490          0.000000             NaN       NaN       NaN       NaN   
hsa-mir-6080         0.008352       -0.364808  5.622797 -0.064880  0.948269   
hsa-mir-8072         0.074415       -0.372757  1.866777 -0.199679  0.841731   
pL63_blaR            0.011997       -0.365023  5.622749 -0.064919  0.948239   
snoU13               0.005510       -0.367616  5.622781 

In [8]:
res.to_csv('output.csv')

In [13]:
adata_subset.obs[["sample", "annotation:coarse:cleaned", "condition", "smoking"]].head()

pseudobulk_group_cols=["sample", "annotation:coarse:cleaned", "condition", "smoking"]
condition_col="condition"
design_factors = ['condition', 'smoking'] 

Unnamed: 0_level_0,sample,annotation:coarse:cleaned,condition,smoking
original_index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
smillie_N7.LPB.TAACTCACTGCTAG,Smillie_N7_B,B_plasma,UC,Never
kong_N119540_L1-CAGCAATTCCTTATAC,N119540_1,T_NK_ILC,CD,Never
kong_N175041_N1-GAAGCAGGTTGATTGC,N175041_1,stromal,CD,Never
smillie_N24.LPA.TGCTACCTCCCGACTT,Smillie_N24_A,stromal,UC,Never
kong_N176196_L2-TACCCGTGTATGGAGC,N176196_2,T_NK_ILC,CD,Never


In [41]:
# 1. Create a single grouping key
grouping_key = "pseudobulk_group_key"
while grouping_key in adata_subset.obs.columns:
    grouping_key += "_dup"
adata_subset.obs[grouping_key] = (
    adata_subset.obs[pseudobulk_group_cols]
    .astype(str)
    .apply(lambda x: "--".join(x), axis=1)
)
print(f"Created grouping key '{grouping_key}' in adata.obs.")
print("Example grouping key values:")
print(adata_subset.obs[grouping_key].head(5))

True

In [None]:
# 2. Convert to a cell-level counts DataFrame
print("Step 2: Creating cell_counts_df from adata.X...")
if not hasattr(adata_subset.X, "toarray"):
    counts_matrix = adata_subset.layers['raw']  # already dense
    print("Data is already dense.")
else:
    counts_matrix = adata_subset.layers['raw'].toarray()  # convert sparse to dense
    print("Converted sparse matrix to dense.")

# Round only the non-integer values to the next highest integer
non_integer_mask = (counts_matrix % 1 != 0)
counts_matrix[non_integer_mask] = np.ceil(counts_matrix[non_integer_mask])

# Verify that all values are now integers
assert np.all(counts_matrix % 1 == 0), "There are still non-integer values!"
print("Non-integer values rounded up successfully.")

gene_names = adata_subset.var_names.tolist()
cell_counts_df = pd.DataFrame(
    counts_matrix,
    columns=gene_names,
    index=adata_subset.obs.index
)
cell_counts_df[grouping_key] = adata_subset.obs[grouping_key].values
print(f"cell_counts_df shape: {cell_counts_df.shape}")
print("---------------------------------------------------")

In [None]:
# 3. Group and aggregate
print("Step 3: Grouping and aggregating...")
grouped = cell_counts_df.groupby(grouping_key)
group_sizes = grouped.size()
valid_groups = group_sizes[group_sizes >= min_cells_in_group].index
print(f"Number of total groups: {len(group_sizes)}")
print(f"Number of valid groups (>= {min_cells_in_group} cells): {len(valid_groups)}")
if len(valid_groups) == 0:
    raise ValueError("No groups found with at least min_cells_in_group cells.")
cell_counts_df = cell_counts_df[cell_counts_df[grouping_key].isin(valid_groups)]
grouped = cell_counts_df.groupby(grouping_key)
if not hasattr(pd.core.groupby.generic.DataFrameGroupBy, pseudobulk_func):
    raise ValueError(f"Invalid pseudobulk_func '{pseudobulk_func}'. "
                        "Must be a valid pandas groupby agg method.")
pseudobulk_df = grouped.aggregate(pseudobulk_func)
print("Aggregated pseudobulk dataframe shape:", pseudobulk_df.shape)
print("---------------------------------------------------")

# 3.1 Also store the actual number of cells in each group
n_cells_per_group = grouped.size()
if grouping_key in pseudobulk_df.columns:
    pseudobulk_df.drop(columns=grouping_key, inplace=True, errors="ignore")

In [None]:
# 4. Build pseudobulk-level metadata
print("Step 4: Building metadata for pseudobulk samples...")
meta_list = []
for idx in pseudobulk_df.index:
    split_vals = idx.split("--")
    meta_dict = {}
    for col, val in zip(pseudobulk_group_cols, split_vals):
        meta_dict[col] = val
    meta_list.append(meta_dict)
pseudo_metadata = pd.DataFrame(meta_list, index=pseudobulk_df.index)
print(f"pseudo_metadata shape: {pseudo_metadata.shape}")
print("pseudo_metadata head:")
print(pseudo_metadata.head(3))
print("---------------------------------------------------")

# 4.1 Add number of cells and log2 number of cells
pseudo_metadata["n_cells"] = n_cells_per_group.reindex(pseudo_metadata.index)
pseudo_metadata["log2_ncells"] = np.log2(pseudo_metadata["n_cells"].astype(float))
print("Added 'n_cells' and 'log2_ncells' to pseudo_metadata.")
print(pseudo_metadata[["n_cells", "log2_ncells"]].head(3))
print("---------------------------------------------------")

In [None]:
# 5. For each comparison, run DESeq
results_list = []
print("Step 5: Running DESeq comparisons...")
for condA, condB in comparisons:
    print(f"  -> Comparison: {condB} vs {condA}")
    keep_idx = pseudo_metadata[condition_col].isin([condA, condB])
    sub_counts = pseudobulk_df[keep_idx].copy()
    sub_meta = pseudo_metadata[keep_idx].copy()
    group_count_condA = (sub_meta[condition_col] == condA).sum()
    group_count_condB = (sub_meta[condition_col] == condB).sum()
    print(f"     Condition '{condA}' sample count: {group_count_condA}")
    print(f"     Condition '{condB}' sample count: {group_count_condB}")
    if not override_min_samples:
        if group_count_condA < min_samples_per_group:
            raise ValueError(
                f"Condition '{condA}' has only {group_count_condA} samples. "
                f"Minimum required is {min_samples_per_group}. "
                f"Set override_min_samples=True to bypass."
            )
        if group_count_condB < min_samples_per_group:
            raise ValueError(
                f"Condition '{condB}' has only {group_count_condB} samples. "
                f"Minimum required is {min_samples_per_group}. "
                f"Set override_min_samples=True to bypass."
            )
    else:
        if (group_count_condA < min_samples_per_group) or (group_count_condB < min_samples_per_group):
            print(
                f"     WARNING: {condA} has {group_count_condA} samples, "
                f"{condB} has {group_count_condB} samples. "
                f"Continuing because override_min_samples=True."
            )
    # Build the local design
    if condition_col not in design_factors:
        design_factors_local = [condition_col] + design_factors
    else:
        design_factors_local = design_factors[:]
    if "log2_ncells" not in design_factors_local:
        design_factors_local.append("log2_ncells")

    reference_level = [condition_col, condA]  # sets condA as the reference
    print("     Using design factors:", f"~ {' + '.join(design_factors_local)}")
    print("     sub_counts shape:", sub_counts.shape)
    print("     sub_meta shape:", sub_meta.shape)
    print("     reference level:", reference_level)

    # Initialize DeseqDataSet
    print("     Initializing DeseqDataSet...")
    dds = DeseqDataSet(
        counts=sub_counts,
        metadata=sub_meta,
        design_factors = design_factors_local,
        ref_level=reference_level, 
        **dds_kwargs
    )
    print("     Running dds.deseq2()...")
    dds.deseq2()
    print("     Initializing DeseqStats and computing results...")
    stat_res = DeseqStats(dds, **dds_stats_kwargs)
    stat_res.summary()
    res_df = stat_res.results_df.copy()
    res_df["comparison"] = f"{condB}_vs_{condA}"
    res_df["condA"] = condA
    res_df["condB"] = condB
    res_df["gene"] = res_df.index
    print("     DE results shape:", res_df.shape)
    print("     DE results head:")
    print(res_df.head(3))
    results_list.append(res_df)
final_results = pd.concat(results_list, axis=0)

In [44]:
mask = (adata_subset.layers['raw'].toarray() % 1 != 0)

rows, cols = np.where(mask)

sample_ids = adata_subset.obs_names[rows]  # Sample names
gene_names = adata_subset.var_names[cols]  # Gene names
non_integer_values = adata_subset.layers['raw'].toarray()[rows, cols]

df_non_integer = pd.DataFrame({
    "sample_id": sample_ids,
    "gene": gene_names,
    "value": non_integer_values
})

# Display the first few cases
print(df_non_integer.head())

                              sample_id      gene     value
0  parikh_GSM3214209_TCAGGATAGAAGGTGA-1  C18orf32  0.333333
1  parikh_GSM3214209_TCAGGATAGAAGGTGA-1     C1QBP  1.500000
2  parikh_GSM3214209_TCAGGATAGAAGGTGA-1     CLDN7  1.500000
3  parikh_GSM3214209_TCAGGATAGAAGGTGA-1    CORO1B  0.500000
4  parikh_GSM3214209_TCAGGATAGAAGGTGA-1     DGAT1  3.500000


In [29]:
random_indices2 = np.random.choice(adata.n_obs, size=1000, replace=False)
adata_subset2 = adata[random_indices].copy()

In [32]:
adata_subset2.write_h5ad('test_subset_ibd.h5ad')