In [None]:
import os
import numpy as np
import pandas as pd
from scipy.stats import median_abs_deviation

import matplotlib.pyplot as plt
import seaborn as sns

import anndata as ad
import scanpy as sc
import doubletdetection

# scanpy settings
sc.settings.verbosity = 3  # verbose logging (scanpy level 3 also prints hints)
sc.settings.set_figure_params(dpi=150)
sc.logging.print_header()

# preferred seaborn plotting
sns.set_context('notebook', font_scale=1)
sns.set_style('ticks')

# Nine-step sequential palettes; later cells index into them (e.g. blues[6], reds[3]).
blues, reds, greens, grays = (
    sns.color_palette(palette_name, 9)
    for palette_name in ('Blues', 'Reds', 'Greens', 'Greys')
)

In [None]:
# Root directory that holds one sub-directory per sample; both the input
# `cellbender_with_metadata.h5ad` and the QC'd output are resolved under it.
base_path = "/raw_data/"

# Sample sub-directory names to process (placeholder — replace with real sample IDs).
samples = [
    'list all samples here',
]

In [None]:
# Read each sample's CellBender output into memory, preserving the order of `samples`.
adatas = []
for sample in samples:
    adata_path = os.path.join(base_path, sample, 'cellbender_with_metadata.h5ad')
    adata = sc.read_h5ad(adata_path)
    # n_obs / n_vars are the barcode and gene dimensions of the AnnData matrix.
    print(f'Sample {sample}: loaded {adata.n_obs} barcodes, {adata.n_vars} genes')
    adatas.append(adata)

In [None]:
# Drop barcodes with too few total counts and report how many were removed,
# per sample and overall.
#
# NOTE: this cell filters the AnnData objects in place, so it is not idempotent —
# re-running it without reloading the data reports the already-filtered numbers.
min_counts = 500  # barcodes with fewer total counts than this are removed

total_raw_cells = 0
total_low_counts = 0

for adata in adatas:
    print("")
    sample_name = adata.obs['sample'].iloc[0]
    print(f"sample: {sample_name}")

    n_cells = adata.shape[0]
    n_low_counts = int((adata.obs['total_counts'] < min_counts).sum())
    total_raw_cells += n_cells
    total_low_counts += n_low_counts
    print(f"number of cells with fewer than {min_counts} counts: {n_low_counts} of {n_cells} total cells ({n_low_counts/n_cells:.0%})")

    # Filters `adata` in place (the objects inside `adatas` are modified).
    # NOTE(review): filter_cells thresholds on counts it derives from adata.X, while the
    # number reported above comes from obs['total_counts'] — confirm the two agree.
    sc.pp.filter_cells(adata, min_counts=min_counts)

print("****************")
if total_raw_cells:
    # Use the accumulated totals; the per-sample variables only describe the last sample.
    print(f"Removed {total_low_counts} of {total_raw_cells} total raw cells ({total_low_counts/total_raw_cells:.0%})")
else:
    print("No cells were loaded; nothing to filter.")

In [None]:
# Flag barcodes whose mitochondrial read percentage exceeds a fixed threshold and
# draw one histogram per sample with the threshold marked.
metric = 'pct_counts_mito'

# Set the mitochondrial threshold
max_mito_pct = 8.0

# Number of samples
n_samples = len(adatas)

# One subplot per sample, four per row. The grid never shrinks below the original
# 4x4 layout, but gains rows when there are more than 16 samples so that
# `axs[i]` cannot run out of axes.
n_cols = 4
n_rows = max(4, -(-n_samples // n_cols))  # -(-a // b) is ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))

# Flatten the axes for easier iteration
axs = axs.flatten()

# Loop over each adata object and make QC plots
for i, adata in enumerate(adatas):
    # Extract the sample name
    sample_name = adata.obs['sample'].iloc[0]

    # Record the bounds used for this sample so later cells can read them back
    adata.uns[metric + '_min'] = 0.0
    adata.uns[metric + '_max'] = max_mito_pct

    # Per-barcode boolean flag; consumed later by filter_outliers
    adata.obs[metric + '_outlier'] = adata.obs[metric] > max_mito_pct

    # Plot in the ith subplot
    ax = axs[i]

    # Histogram of mitochondrial proportions
    sns.histplot(
        adata.obs[metric],
        bins=100,
        color=blues[6],
        ax=ax
    )
    # Dashed line at the threshold
    ax.axvline(
        x=adata.uns[metric + '_max'],
        color=reds[3],
        linestyle='--'
    )
    ax.set_xlabel('% mitochondrial reads')
    ax.set_yscale('log')
    ax.set_title(f"{sample_name}", fontsize=14)

# Hide unused subplots
for j in range(n_samples, len(axs)):
    axs[j].axis('off')

# Adjust layout
plt.tight_layout()

#plt.savefig(f"./qc_figs/{modality}_filtered_all_mito_percent.png", dpi=300, bbox_inches="tight")

plt.show()

In [None]:
def log1p(x):
    """Return ``log10(x + 1)``.

    Despite the name this is a base-10 transform, not the natural-log
    ``numpy.log1p``.
    """
    return np.log10(x + 1)


def inv_log1p(y):
    """Invert :func:`log1p`: return ``10**y - 1``."""
    return np.power(10, y) - 1


# Metrics that are summarised on the log10(x + 1) scale before computing median/MAD.
_LOG_SCALED_METRICS = ('total_counts', 'n_genes_by_counts', 'pct_counts_mito')


def is_outlier(adata, metric: str, nmads: int):
    """
    Flag cells/barcodes for which `metric` is outside
    of `nmads` median absolute deviations of the median.

    The bounds used are stored in `adata.uns[metric + '_min']` and
    `adata.uns[metric + '_max']` (on the original, untransformed scale) for
    future filtering and plotting.

    Parameters
    ----------
    adata
        Object exposing `obs[metric]` (per-barcode values) and a writable
        `uns` mapping.
    metric
        Column of `adata.obs` to test. `total_counts`, `n_genes_by_counts`
        and `pct_counts_mito` are log10(x + 1)-transformed first.
    nmads
        Half-width of the accepted band, in units of the MAD.

    Returns
    -------
    Boolean Series aligned with `adata.obs`; True marks an outlier.

    Notes
    -----
    * For `pct_counts_mito`, if both median and MAD are zero the statistics are
      recomputed on the nonzero values only.
    * If the MAD is still zero after that, a MAD band cannot be defined. Using
      `median +/- nmads * 0` would flag every barcode that differs from the
      median, so instead no barcode is flagged and the observed range is stored
      as the bounds.
    """

    print(f'checking for outliers with {metric}')

    use_log_scale = metric in _LOG_SCALED_METRICS

    M = adata.obs[metric]
    if use_log_scale:
        print(f'log1p transforming {metric}')
        M = log1p(M)

    median_val = np.nanmedian(M)
    print(f"Median: {median_val}")
    mad_val = median_abs_deviation(M, nan_policy='omit')
    print(f"MAD: {mad_val}")

    # If both median and MAD are zero, check for the pct_counts_mito case and recompute on nonzero values
    if median_val == 0 and mad_val == 0:
        print(f"Warning: Median and MAD are both zero for {metric}.")
        if metric == 'pct_counts_mito':
            # Subset M to nonzero values (NaN compares False and is dropped too)
            nonzero_M = M[M > 0]
            if len(nonzero_M) > 0:
                median_val = np.median(nonzero_M)
                mad_val = median_abs_deviation(nonzero_M)
                print("Recomputed on nonzero values:")
                print(f"New Median: {median_val}")
                print(f"New MAD: {mad_val}")
            else:
                print("No nonzero values found for pct_counts_mito.")

    if mad_val == 0:
        # Degenerate spread: fall back to the observed range so nothing is flagged.
        print(f"Warning: MAD is zero for {metric}. Cannot define a MAD band; "
              "flagging no cells and storing the observed range as bounds.")
        lower_bound = np.nanmin(M)
        upper_bound = np.nanmax(M)
    else:
        lower_bound = median_val - nmads * mad_val
        upper_bound = median_val + nmads * mad_val

    print(f"Lower bound: {lower_bound}")
    print(f"Upper bound: {upper_bound}")

    # Compare on the same (possibly log) scale the bounds were computed on.
    # NaN values compare False on both sides and are therefore never flagged.
    outlier = (M < lower_bound) | (M > upper_bound)

    # Store the bounds on the original scale of the metric.
    if use_log_scale:
        lower_bound = inv_log1p(lower_bound)
        upper_bound = inv_log1p(upper_bound)

    adata.uns[metric + '_min'] = lower_bound
    adata.uns[metric + '_max'] = upper_bound

    return outlier

In [None]:
# Flag MAD-based outliers of the intronic read percentage and plot one histogram
# per sample with the bounds marked.
metric = 'pct_counts_intronic'
nmads = 5

# Upper line drawn on the plots. Only the drawn line uses this value: the stored
# `adata.uns[metric + '_max']` and the outlier flags keep the MAD-based upper bound.
# NOTE(review): if an upper cut-off on intronic percentage is not wanted at all,
# the outlier flags themselves would also need to ignore the upper bound — confirm.
plot_upper_bound = 100

# Four subplots per row; at least the original 4x4 grid, more rows if there are
# more than 16 samples so that `axs[i]` cannot run out of axes.
n_cols = 4
n_rows = max(4, -(-len(adatas) // n_cols))  # -(-a // b) is ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
axs = axs.flatten()

n_bad_cells = 0
total_cells = 0

# loop over all adata objects
for i, adata in enumerate(adatas):
    sample_name = adata.obs['sample'].iloc[0]
    print(f"Sample {sample_name}")

    # is_outlier also writes the bounds into adata.uns. `adata` is the very object
    # held in `adatas`, so the flag column and bounds persist without copying.
    outliers = is_outlier(adata, metric, nmads)
    adata.obs[metric + '_outlier'] = outliers
    n_bad_cells += int(outliers.sum())
    total_cells += adata.shape[0]

    # Histogram of the selected metric
    sns.histplot(
        adata.obs[metric],
        bins=50, color=blues[6], log_scale=False, ax=axs[i], kde=True
    )

    # Add bounds
    lower_bound = adata.uns.get(f'{metric}_min', None)
    if lower_bound is not None:
        axs[i].axvline(
            x=lower_bound,
            color=reds[5], linestyle='--'
        )
    axs[i].axvline(
        x=plot_upper_bound,
        color=reds[5], linestyle='--'
    )

    # Add labels and title
    axs[i].set_xlabel(metric.replace('_', ' ').title())
    axs[i].set_ylabel('# Barcodes')
    axs[i].set_title(sample_name, fontsize=14)

    # Optional: Add grid lines
    axs[i].grid(visible=True, which='both', linestyle='--', linewidth=0.5)

# Hide unused subplots
for j in range(len(adatas), len(axs)):
    axs[j].axis('off')

# Report the totals accumulated above
if total_cells:
    print(f"{metric}: flagged {n_bad_cells} of {total_cells} cells as outliers ({n_bad_cells/total_cells:.0%})")

plt.tight_layout()
##plt.savefig(f"./qc_figs/{modality}_filtered_count_depth_plots.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Flag MAD-based outliers of total counts per barcode and plot one histogram per
# sample (log-scaled x axis) with the bounds marked.
metric = 'total_counts'
nmads = 3

# Four subplots per row; at least the original 4x4 grid, more rows if there are
# more than 16 samples so that `axs[i]` cannot run out of axes.
n_cols = 4
n_rows = max(4, -(-len(adatas) // n_cols))  # -(-a // b) is ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
axs = axs.flatten()

n_bad_cells = 0
total_cells = 0

# Loop over all adata objects
for i, adata in enumerate(adatas):
    sample_name = adata.obs['sample'].iloc[0]
    print(f"Sample {sample_name}")

    # is_outlier also writes the bounds into adata.uns. `adata` is the very object
    # held in `adatas`, so the flag column and bounds persist without copying.
    outliers = is_outlier(adata, metric, nmads)
    adata.obs[metric + '_outlier'] = outliers
    n_bad_cells += int(outliers.sum())
    total_cells += adata.shape[0]

    # Histogram of the selected metric
    sns.histplot(
        adata.obs[metric],
        bins=50, color=blues[6], log_scale=True, ax=axs[i], kde=True
    )

    # Add bounds
    for bound in (adata.uns.get(f'{metric}_min', None), adata.uns.get(f'{metric}_max', None)):
        if bound is not None:
            axs[i].axvline(
                x=bound,
                color=reds[5], linestyle='--'
            )

    # Add labels and title
    axs[i].set_xlabel(metric.replace('_', ' ').title())
    axs[i].set_ylabel('# Barcodes')
    axs[i].set_title(sample_name, fontsize=14)

    # Optional: Add grid lines
    axs[i].grid(visible=True, which='both', linestyle='--', linewidth=0.5)

# Hide unused subplots
for j in range(len(adatas), len(axs)):
    axs[j].axis('off')

# Report the totals accumulated above
if total_cells:
    print(f"{metric}: flagged {n_bad_cells} of {total_cells} cells as outliers ({n_bad_cells/total_cells:.0%})")

plt.tight_layout()
##plt.savefig(f"./qc_figs/{modality}_filtered_count_depth_plots.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Flag MAD-based outliers of the number of detected genes per barcode and plot one
# histogram per sample (log-scaled x axis) with the bounds marked.
metric = 'n_genes_by_counts'
nmads = 3

# Four subplots per row; at least the original 4x4 grid, more rows if there are
# more than 16 samples so that `axs[i]` cannot run out of axes.
n_cols = 4
n_rows = max(4, -(-len(adatas) // n_cols))  # -(-a // b) is ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
axs = axs.flatten()

n_bad_cells = 0
total_cells = 0

# Loop over all adata objects
for i, adata in enumerate(adatas):
    sample_name = adata.obs['sample'].iloc[0]
    print(f"Sample {sample_name}")

    # is_outlier also writes the bounds into adata.uns; `adata` is the object held
    # in `adatas`, so the flag column and bounds persist.
    outliers = is_outlier(adata, metric, nmads)
    adata.obs[metric + '_outlier'] = outliers
    n_bad_cells += int(outliers.sum())
    total_cells += adata.shape[0]

    # Histogram of the selected metric
    sns.histplot(
        adata.obs[metric],
        bins=50, color=blues[6], log_scale=True, ax=axs[i], kde=True
    )

    # Add bounds
    for bound in (adata.uns.get(f'{metric}_min', None), adata.uns.get(f'{metric}_max', None)):
        if bound is not None:
            axs[i].axvline(
                x=bound,
                color=reds[5], linestyle='--'
            )

    # Add labels and title
    axs[i].set_xlabel(metric.replace('_', ' ').title())
    axs[i].set_ylabel('# Barcodes')
    axs[i].set_title(sample_name, fontsize=14)

    # Optional: Add grid lines
    axs[i].grid(visible=True, which='both', linestyle='--', linewidth=0.5)

# Hide unused subplots
for j in range(len(adatas), len(axs)):
    axs[j].axis('off')

# Report the totals accumulated above
if total_cells:
    print(f"{metric}: flagged {n_bad_cells} of {total_cells} cells as outliers ({n_bad_cells/total_cells:.0%})")

plt.tight_layout()
#plt.savefig(f"./qc_figs/{modality}_filtered_count_depth_plots.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
def filter_outliers(adata, metric):
    """Return a copy of `adata` without the barcodes flagged as outliers for `metric`.

    Parameters
    ----------
    adata
        Object whose `obs` has a boolean column named `metric + '_outlier'`
        (True = outlier) and which supports boolean-mask indexing plus `.copy()`.
    metric
        Name of the QC metric whose outlier flags should be applied.

    Returns
    -------
    A new object containing only the barcodes whose flag is False; the input is
    not modified.
    """
    # Cast so that `~` is a logical NOT even if the column is not stored as bool dtype.
    is_outlier_flag = adata.obs[metric + '_outlier'].astype(bool)
    n_outliers = int(is_outlier_flag.sum())
    # Report the number of cells removed (not the number kept).
    print(f"Removing {n_outliers} outliers for {metric}.")
    return adata[~is_outlier_flag].copy()

In [None]:
# Apply the per-metric outlier flags computed in the cells above, one metric after
# another, and report how many barcodes each sample loses.
# Each metric listed here must already have its `<metric>_outlier` column in `obs`.
metrics = ['pct_counts_mito', 'total_counts', 'n_genes_by_counts']
for i, adata in enumerate(adatas):
    n_cells = adata.shape[0]
    sample_name = adata.obs['sample'].iloc[0]
    print(f"Sample {sample_name}")
    print(f"Total cells before filtering: {n_cells}")

    for metric in metrics:
        # filter_outliers returns a fresh copy; the flags of later metrics are
        # carried along in `obs`, so they are applied to the surviving cells.
        adata = filter_outliers(adata, metric)

    n_cells_left = adata.shape[0]
    print(f"After filtering, {n_cells_left} remain ({((n_cells - n_cells_left) / n_cells):.0%} removed).")

    # `adata` is already an independent copy returned by filter_outliers, so it is
    # stored directly instead of being deep-copied a second time.
    adatas[i] = adata

    print("")

In [None]:
# Run doubletdetection on every sample, save a convergence plot per sample and
# store the predictions and scores in `obs`.

# Thresholds shared by predict() and the convergence plot so the two cannot drift apart.
p_thresh = 1e-7
voter_thresh = 0.5

# The convergence plots are saved here; create the directory up front because
# saving into a missing directory fails.
convergence_dir = './figures/qc/doublet_convergence'
os.makedirs(convergence_dir, exist_ok=True)

for i, adata in enumerate(adatas):
    sample = adata.obs['sample'].iloc[0]
    print("")
    print("*************")
    print(f"Running doubletdetection on {sample} ({i+1}/{len(adatas)})")

    clf = doubletdetection.BoostClassifier(n_iters=40,
                                           clustering_algorithm="leiden",
                                           standard_scaling=True,
                                           n_top_var_genes=3000,
                                           n_jobs=8)  # change n_jobs to your number of cores if you can parallelize

    # NOTE(review): this requires `adata.raw` to be populated in the input file;
    # it raises AttributeError if `raw` is None — confirm for all samples.
    X = adata.raw.X
    doubletdetection_preds = clf.fit(X).predict(p_thresh=p_thresh, voter_thresh=voter_thresh)
    convergence_plot = doubletdetection.plot.convergence(clf,
                                                         show=True,
                                                         p_thresh=p_thresh,
                                                         voter_thresh=voter_thresh,
                                                         save=f'{convergence_dir}/{sample}_doublet_convergence.pdf')

    # NOTE(review): if predict() can return NaN for barcodes it could not call,
    # astype(bool) turns those NaNs into True (i.e. "doublet") — confirm this is intended.
    adata.obs['doubletdetection_predicted_doublet'] = doubletdetection_preds.astype(bool)
    adata.obs['doubletdetection_doublet_score'] = clf.doublet_score()

    # `adata` is the object already stored at adatas[i]; the new obs columns were
    # written onto it directly, so no deep copy is needed.
    adatas[i] = adata

In [None]:
# Shallow backup of the list: it holds references to the same AnnData objects,
# not copies of them. It preserves the pre-doublet-removal state only as long as
# later code rebinds `adatas[i]` to new objects instead of mutating these in place.
adatas_copy = list(adatas)

In [None]:
# Drop predicted doublets from every sample, keep the result in `adatas`, and write
# it next to the input file as `cellbender_qc_applied.h5ad`.
for i, adata in enumerate(adatas):
    sample = adata.obs['sample'].iloc[0]
    print(f"removing doublets from {sample}")

    # Indexing yields a view and .copy() materialises it, so one copy is enough and
    # the original object (still referenced by `adatas_copy`) is left untouched.
    is_singlet = ~adata.obs['doubletdetection_predicted_doublet'].astype(bool)
    adata = adata[is_singlet].copy()
    adatas[i] = adata

    outpath = os.path.join(base_path, sample, 'cellbender_qc_applied.h5ad')
    print(f"saving {sample} to {outpath}")
    # NOTE(review): the object written is the same one stored in `adatas`; writing may
    # convert string columns to categoricals on it in place — expected to be harmless, confirm.
    adata.write_h5ad(outpath)