In [None]:
import scanpy as sc
import scipy
from scipy.sparse import issparse
import pandas as pd
import numpy as np
import anndata as ad
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
import gc
import re
import psutil
from typing import Optional, Tuple, Dict, Any
import logging

# Configure the root logger to emit INFO-level messages and create a
# module-level logger for this notebook.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
# Load the female and male CellBender-corrected AnnData objects from .h5ad files;
# they are concatenated into one object further down.
# NOTE(review): absolute cluster paths are hard-coded here and repeated in later
# cells; consider a single configurable data-directory constant.
adata_corrected_female = sc.read_h5ad('/ocean/projects/cis240075p/asachan/datasets/TA_muscle/ERCC1_KO_mice/samples_2025/objects/cellbender_corrected_female_2025.h5ad')
adata_corrected_male = sc.read_h5ad('/ocean/projects/cis240075p/asachan/datasets/TA_muscle/ERCC1_KO_mice/samples_2025/objects/cellbender_corrected_male_2025.h5ad')

In [None]:
# Display the summary repr of the male object.
adata_corrected_male

In [None]:
# Display the summary repr of the female object.
adata_corrected_female

In [None]:
# Compare the gene names of the two pools as sets (order and duplicates are ignored),
# then report the number of genes in each (A = female, B = male).
female_gene_set = set(adata_corrected_female.var_names)
male_gene_set = set(adata_corrected_male.var_names)
same_set = female_gene_set == male_gene_set
print("Same gene set?", same_set)
print(
    "A genes:", adata_corrected_female.n_vars,
    "B genes:", adata_corrected_male.n_vars,
)

#### Concat the two pools and name their batches

In [None]:
# Coerce gene names to str and de-duplicate them on both objects before they
# are concatenated in the next cell.
for adx in (adata_corrected_female, adata_corrected_male):
    adx.var_names = adx.var_names.astype(str)
    adx.var_names_make_unique()

In [None]:
# Stack the two pools cell-wise; the dict keys ("male"/"female") are recorded in obs["batch"].
adata_cellbender_combined = ad.concat(
    {"male": adata_corrected_male, "female": adata_corrected_female},
    axis=0,
    join="inner",         # keep only genes present in both pools; "outer" would take the union (how missing entries are filled depends on anndata's `fill_value` and the array type)
    label="batch",
    index_unique=None     # barcodes are kept verbatim, so a barcode present in both pools appears twice; a later cell drops such duplicates
)

In [None]:
# Display the summary repr of the concatenated object.
adata_cellbender_combined

In [None]:
# Number of cells contributed by each pool (values of the "batch" label set during concat).
adata_cellbender_combined.obs['batch'].value_counts()

In [None]:
def remove_duplicates(adata, axis='obs', strategy='first'):
    """
    Remove duplicated index labels (cell barcodes or gene names).

    Parameters:
    - adata: AnnData-like object exposing `.obs` / `.var` DataFrames, `.n_obs`,
      and boolean-mask indexing (`adata[mask]`, `adata[:, mask]`) whose result
      has a `.copy()` method.
    - axis: 'obs' for cells; any other value is treated as 'var' (genes).
    - strategy: 'first', 'last', 'highest_counts' (for obs only).
      'highest_counts' keeps exactly one cell per barcode: the one with the
      largest `obs['nCount_RNA']` (ties go to the earliest occurrence). If that
      column is absent it falls back to 'first'. On the var axis, any strategy
      other than 'last' behaves like 'first'.

    Returns:
    - A copy of `adata` restricted to the kept cells/genes, in original order.

    Raises:
    - ValueError: if axis is 'obs' and `strategy` is not one of the three
      supported values.
    """
    if axis == 'obs':
        index = adata.obs.index
        if strategy == 'first':
            keep_mask = ~index.duplicated(keep='first')
        elif strategy == 'last':
            keep_mask = ~index.duplicated(keep='last')
        elif strategy == 'highest_counts':
            if 'nCount_RNA' in adata.obs.columns:
                # Work with integer positions, not labels: the labels are by
                # definition ambiguous for duplicated barcodes.
                ranking = pd.DataFrame({
                    'name': np.asarray(index),
                    'count': adata.obs['nCount_RNA'].to_numpy(),
                    'pos': np.arange(adata.n_obs),
                })
                # Highest count first; `pos` breaks ties towards the earliest
                # row. NaN counts sort last, so they win only if no other
                # value exists for that barcode.
                ranking = ranking.sort_values(['count', 'pos'], ascending=[False, True])
                best_positions = ranking.drop_duplicates('name', keep='first')['pos'].to_numpy()
                keep_mask = np.zeros(adata.n_obs, dtype=bool)
                keep_mask[best_positions] = True
            else:
                # Fall back to first if no count info
                keep_mask = ~index.duplicated(keep='first')
        else:
            raise ValueError(
                f"Unknown strategy {strategy!r} for axis='obs'; "
                "expected 'first', 'last' or 'highest_counts'"
            )

        return adata[keep_mask].copy()

    else:  # axis == 'var'
        index = adata.var.index
        if strategy == 'last':
            keep_mask = ~index.duplicated(keep='last')
        else:
            # 'first' and every other strategy default to keeping the first occurrence
            keep_mask = ~index.duplicated(keep='first')

        return adata[:, keep_mask].copy()

In [None]:
print("\n=== Removing Duplicates ===")
original_shape = adata_cellbender_combined.shape
# Remove cell duplicates (keeping first occurrence)
adata_subset = remove_duplicates(adata_cellbender_combined, axis='obs', strategy='first')
# Report the effect of the de-duplication and guard the invariant later cells rely on.
print(
    f"Shape before: {original_shape}, after: {adata_subset.shape} "
    f"({original_shape[0] - adata_subset.shape[0]} duplicated barcodes dropped)"
)
assert adata_subset.obs.index.is_unique, "cell barcodes are still duplicated after de-duplication"

In [None]:
# Flag every barcode that occurs more than once (all occurrences, not just the repeats)
corrected_barcodes = adata_subset.obs.index
dup_mask = corrected_barcodes.duplicated(keep=False)

# Tabulate how often each duplicated barcode appears
duplicates = corrected_barcodes[dup_mask]
duplicate_cells = duplicates.value_counts()

print("Number of duplicated barcodes in corrected data:", len(duplicate_cells))
print(duplicate_cells)
# Count the distinct gene names in the var index
unique_genes_corrected = adata_subset.var.index.unique()
n_unique_genes = len(unique_genes_corrected)
print(f"Unique genes in corrected data: {n_unique_genes}")

In [None]:
# Display the summary repr of the de-duplicated object.
adata_subset

In [None]:
# Save the concatenated, de-duplicated CellBender object; the next section reads this file back.
adata_subset.write_h5ad('/ocean/projects/cis240075p/asachan/datasets/TA_muscle/ERCC1_KO_mice/samples_2025/objects/cellbender_corrected_aggr_2025_dedup.h5ad')

## Combine the CellBender-corrected and uncorrected AnnData objects

In [None]:
# Load the de-duplicated CellBender object written above and the uncorrected
# aggregated object that it will be merged with.
adata_corrected = sc.read_h5ad('/ocean/projects/cis240075p/asachan/datasets/TA_muscle/ERCC1_KO_mice/samples_2025/objects/cellbender_corrected_aggr_2025_dedup.h5ad')
adata_uncorrected = sc.read_h5ad('/ocean/projects/cis240075p/asachan/datasets/TA_muscle/ERCC1_KO_mice/samples_2025/objects/ERCC1_KO_mice_aggr_updated.h5ad')

In [None]:
# Preview the cell metadata of the uncorrected object (first rows only, not the whole frame).
adata_uncorrected.obs.head()

In [None]:
# Use regex to replace a trailing "-2" through "-8" barcode suffix with "-1".
# NOTE(review): after this rewrite, barcodes that differed only in their suffix
# become identical; the cells below detect and drop such duplicates (first
# occurrence kept). Presumably done to match the "-1" suffixes of the CellBender
# barcodes - confirm that matching cells across suffixes is intended.
adata_uncorrected.obs.index = adata_uncorrected.obs.index.str.replace(r'-[2-8]$', '-1', regex=True)

In [None]:
# Check whether the Cell Ranger (uncorrected) data has duplicated barcodes
dup_mask = adata_uncorrected.obs.index.duplicated(keep=False)

# Extract duplicated barcodes and their counts
duplicates = adata_uncorrected.obs.index[dup_mask]
duplicate_cells = duplicates.value_counts()

# Messages previously said "corrected data" (copy-paste slip); this cell inspects the uncorrected object.
print("Number of duplicated barcodes in uncorrected data:", len(duplicate_cells))
print(duplicate_cells)
# Count the distinct gene names in the var index.
# The variable name is kept for compatibility; here it holds the uncorrected object's genes.
unique_genes_corrected = adata_uncorrected.var.index.unique()
print(f"Unique genes in uncorrected data: {len(unique_genes_corrected)}")

In [None]:
print("\n=== Removing Duplicates ===")
original_shape = adata_uncorrected.shape
# Remove cell duplicates (keeping first occurrence)
adata_uncorrected = remove_duplicates(adata_uncorrected, axis='obs', strategy='first')
# Report the effect of the de-duplication and guard the invariant later cells rely on.
print(
    f"Shape before: {original_shape}, after: {adata_uncorrected.shape} "
    f"({original_shape[0] - adata_uncorrected.shape[0]} duplicated barcodes dropped)"
)
assert adata_uncorrected.obs.index.is_unique, "cell barcodes are still duplicated after de-duplication"

In [None]:
# Re-check the Cell Ranger (uncorrected) data for duplicated barcodes after de-duplication
dup_mask = adata_uncorrected.obs.index.duplicated(keep=False)

# Extract duplicated barcodes and their counts
duplicates = adata_uncorrected.obs.index[dup_mask]
duplicate_cells = duplicates.value_counts()

# Messages previously said "corrected data" (copy-paste slip); this cell inspects the uncorrected object.
print("Number of duplicated barcodes in uncorrected data:", len(duplicate_cells))
print(duplicate_cells)
# Count the distinct gene names in the var index.
# The variable name is kept for compatibility; here it holds the uncorrected object's genes.
unique_genes_corrected = adata_uncorrected.var.index.unique()
print(f"Unique genes in uncorrected data: {len(unique_genes_corrected)}")

In [None]:
# Barcodes and gene names shared by the uncorrected and corrected objects
uncorrected_barcodes = adata_uncorrected.obs.index
uncorrected_genes = adata_uncorrected.var.index
common_cells = uncorrected_barcodes.intersection(adata_corrected.obs.index)
common_genes = uncorrected_genes.intersection(adata_corrected.var.index)
n_common_cells, n_common_genes = len(common_cells), len(common_genes)
print(f"Found {n_common_cells} common cells")
print(f"Found {n_common_genes} common genes")

In [None]:
# Restrict both objects to the shared cells and genes. Both are indexed with the
# same `common_cells` / `common_genes`, and each result is materialised with .copy().
adata_orig_subset = adata_uncorrected[common_cells, common_genes].copy()
adata_corrected_subset = adata_corrected[common_cells, common_genes].copy()

In [None]:
# Start the combined object from the uncorrected subset, carrying over everything it holds.
# NOTE(review): adata_orig_subset is already a copy, so this makes a second
# independent copy in memory.
adata_combined = adata_orig_subset.copy()

In [None]:
# add cellbender counts as layer
# The matrix is attached positionally, so cells and genes must be in the same
# order in both objects; fail loudly rather than silently pairing the wrong rows.
assert adata_combined.obs.index.equals(adata_corrected_subset.obs.index), "cell order differs between objects"
assert adata_combined.var.index.equals(adata_corrected_subset.var.index), "gene order differs between objects"
adata_combined.layers['cellbender'] = adata_corrected_subset.X.copy()

In [None]:
print("Transferring cell metadata (obs)...")
# Copy every obs column of the corrected subset into the combined object.
# The assignment aligns on the obs index. A column that already exists in
# adata_combined.obs is overwritten, exactly as before.
# The former NA/NaN pre-initialisation of new columns was dead code (always
# overwritten by the assignment that followed it) and has been dropped.
for col in adata_corrected_subset.obs.columns:
    adata_combined.obs[col] = adata_corrected_subset.obs[col]

In [None]:
# Preview the merged cell metadata (first rows only, not the whole frame).
adata_combined.obs.head()

In [None]:
# Copy per-cell multi-dimensional annotations (obsm), printing each entry's shape
print(f"Transferring {len(adata_corrected_subset.obsm)} obsm entries:")
for key, value in adata_corrected_subset.obsm.items():
    print(f"  - {key}: {value.shape}")
    adata_combined.obsm[key] = value.copy()

# Copy per-gene multi-dimensional annotations (varm), printing each entry's shape
print(f"Transferring {len(adata_corrected_subset.varm)} varm entries:")
for key, value in adata_corrected_subset.varm.items():
    print(f"  - {key}: {value.shape}")
    adata_combined.varm[key] = value.copy()

# Carry over unstructured annotations (uns); values are assigned as-is, without copying
print(f"Transferring {len(adata_corrected_subset.uns)} uns entries:")
for key, value in adata_corrected_subset.uns.items():
    print(f"  - {key}")
    adata_combined.uns[key] = value

In [None]:
# Display the summary repr of the merged object.
adata_combined

# [Pre-proc]

In [None]:
# Input for the pre-processing section.
# NOTE(review): no cell visible in this notebook writes `query_combined.h5ad`;
# presumably it is `adata_combined` from the section above saved elsewhere - confirm,
# otherwise a fresh Restart & Run All depends on a file produced outside this notebook.
infile = '/ocean/projects/cis240075p/asachan/datasets/TA_muscle/ERCC1_KO_mice/samples_2025/objects/query_combined.h5ad'

In [None]:
# Read the object to pre-process from `infile`.
adata = sc.read_h5ad(infile)

In [None]:
# Display the summary repr of the loaded object.
adata

In [None]:
# Use the CellBender layer as X so that the X-based QC below runs on corrected counts.
# No copy is made here.
adata.X = adata.layers['cellbender']

In [None]:
# Add boolean per-gene flags to compute QC metrics on the corrected counts
# (the `_cb` suffix keeps them separate from any existing uncorrected flags).
# Matching is case-insensitive: the previous lower-case-only prefixes for
# ribosomal ('rps'/'rpl') and hemoglobin ('hb') genes do not match capitalised
# symbols such as `Rps3` or `Hba-a1`, which would leave those flags all False.
gene_symbols_lower = adata.var_names.str.lower()
adata.var['mt_cb'] = gene_symbols_lower.str.startswith('mt-')
adata.var['ribo_cb'] = gene_symbols_lower.str.startswith(('rps', 'rpl'))
adata.var["hb_cb"] = gene_symbols_lower.str.contains("^hb[^(p)]")

# Computed on adata.X. With inplace=True the results are written into adata.obs / adata.var.
# NOTE(review): generic columns such as `total_counts` are presumably overwritten
# if they already exist from an earlier run on uncorrected counts - confirm.
sc.pp.calculate_qc_metrics(
    adata, qc_vars=["mt_cb", "ribo_cb", "hb_cb"], inplace=True, percent_top=[20], log1p=True
)
adata

In [None]:
# Keep only cells whose mitochondrial counts are below 300 in BOTH the corrected
# and the uncorrected data, so the two can be compared on the same scale
low_mt_corrected = adata.obs['total_counts_mt_cb'] < 300
low_mt_uncorrected = adata.obs['total_counts_mt'] < 300
adata = adata[(low_mt_corrected & low_mt_uncorrected).to_numpy()]

In [None]:
# Violin plots of mitochondrial counts per sample, uncorrected vs CellBender-corrected,
# one panel per metric with x tick labels rotated 90 degrees.
sc.pl.violin(
    adata,
    ["total_counts_mt", "total_counts_mt_cb"],
    groupby='sample_id',
    jitter=0.4,
    multi_panel=True,
    rotation=90
)

#### Filter cells based on CellBender cell probability and predicted doublets

In [None]:
# cell bender probability of being cell and not soup
is_confident_cell = (adata.obs['cell_probability'] > 0.99).to_numpy()
# remove the barcodes that are predicted to be doublets
is_not_doublet = (adata.obs['predicted_doublet'] == False).to_numpy()
# Materialise the selection with .copy(): later cells assign to adata.X and
# adata.layers, which on a view triggers an implicit copy and a warning.
adata = adata[is_confident_cell & is_not_doublet].copy()

In [None]:
# Count the remaining cells per value of the `sex` obs column.
adata.obs['sex'].value_counts()

In [None]:
# Side-by-side histograms of total counts per cell: CellBender layer vs uncorrected raw counts.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# `.sum(1)` on a sparse matrix returns an (n_cells, 1) matrix; flatten to a 1-D
# vector so histplot treats it as a single variable (also works for dense arrays).
cellbender_totals = np.asarray(adata.layers["cellbender"].sum(1)).ravel()
uncorrected_totals = np.asarray(adata.layers["counts"].sum(1)).ravel()
p1 = sns.histplot(cellbender_totals, bins=100, kde=False, ax=axes[0])
axes[0].set_title("Cellbender counts")
axes[0].set_xlabel("Total counts per cell")
p2 = sns.histplot(uncorrected_totals, bins=100, kde=False, ax=axes[1])
axes[1].set_title("Uncorrected raw counts")
axes[1].set_xlabel("Total counts per cell")
plt.show()

### Normalization

In [None]:
# Reset X to a fresh copy of the CellBender counts before normalising.
adata.X = adata.layers['cellbender'].copy()
# normalize cellbender counts; with inplace=False the result is returned as a
# dict and adata.X is left unchanged
scales_counts = sc.pp.normalize_total(adata, target_sum=None, inplace=False)
# log1p transform of the normalised matrix, stored as its own layer
# (adata.X still holds the counts assigned above)
adata.layers["log1p_norm_cb"] = sc.pp.log1p(scales_counts["X"], copy=True)

In [None]:
# Save the filtered object, including the new "log1p_norm_cb" layer.
adata.write_h5ad('/ocean/projects/cis240075p/asachan/datasets/TA_muscle/ERCC1_KO_mice/samples_2025/objects/query_norm_v1.h5ad')