# Doublet Detection

**Pinned Environment:** [`envs/sc-scrublet.yaml`](../../envs/sc-scrublet.yaml)  

In [None]:
import os
from pathlib import Path
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scrublet as scr
import sys

In [None]:
# Default (small) figure size for plots in this notebook; individual plots may override it.
plt.rcParams.update({"figure.figsize": (3, 3)})

In [None]:
# Make the directory two levels above the current working directory importable,
# so the project's shared `config` package can be found.
sys.path.append(str(Path.cwd().resolve().parents[1]))

from config.paths import BASE_DIR

# Input: the `03_peyers-removed` h5ad export. Outputs: per-sample Scrublet CSVs
# and the scored h5ad files for the `04_scrublet` export.
data_dir = BASE_DIR / "data" / "h5ad" / "export_01" / "03_peyers-removed"
scrublet_dir = BASE_DIR / "scrublet"
output_dir = BASE_DIR / "data" / "h5ad" / "export_01" / "04_scrublet"

# Create output locations up front (no-op if they already exist).
scrublet_dir.mkdir(exist_ok=True, parents=True)
output_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Import AnnData objects.
# Sorted so the sample order (and any positional fallback names) is reproducible;
# os.listdir returns entries in arbitrary, filesystem-dependent order.
sample_files = sorted(
    os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".h5ad")
)
if not sample_files:
    print(f"Warning: no .h5ad files found in {data_dir}")
adata_list = [sc.read_h5ad(f) for f in sample_files]

# Summary print statements
for i, adata in enumerate(adata_list, start=1):
    sample_id = adata.obs["sample_id"].unique()[0] if "sample_id" in adata.obs.columns else f"Sample_{i}"
    print(f"Sample {i}: {sample_id}")
    print(f"  n_obs: {adata.n_obs}")
    print(f"  n_vars: {adata.n_vars}")
    print("-" * 40)

In [None]:
# Function for doublet scoring and visualization


def _sample_id(adata, index):
    """Return the first ``sample_id`` in ``adata.obs``, or ``sample_{index + 1}`` if absent.

    The fallback is consistent with the export step's naming for samples
    without a ``sample_id`` column.
    """
    if "sample_id" in adata.obs.columns:
        return adata.obs["sample_id"].iloc[0]
    return f"sample_{index + 1}"


def _resolve_threshold(scrub, predicted_doublets, sample_id):
    """Return Scrublet's automatic doublet threshold, or a fallback if auto-detection failed.

    In scrublet's implementation, a failed automatic threshold leaves the
    predicted doublets as ``None`` and never sets ``threshold_``. Reading
    ``scrub.threshold_`` directly would therefore raise AttributeError instead
    of reaching the fallback, so the attribute is probed with ``getattr``.
    """
    auto_threshold = getattr(scrub, "threshold_", None)
    if predicted_doublets is not None and auto_threshold is not None:
        print(f"Scrublet auto-set threshold for {sample_id}: {auto_threshold:.3f}")
        return auto_threshold

    # **Fallback threshold: Use mean + 2*std from simulated doublet scores**
    # NOTE(review): simulated scores represent doublets, so mean + 2*std of them
    # likely sits above most simulated doublets and calls very few cells;
    # confirm this fallback is intended.
    sim_scores = scrub.doublet_scores_sim_
    fallback = sim_scores.mean() + 2 * sim_scores.std()
    print(
        f"Warning: Scrublet auto-threshold failed for {sample_id}. Using fallback threshold: {fallback:.3f}"
    )
    return fallback


def _plot_score_distributions(scrub, doublet_threshold, sample_id):
    """Plot observed vs. simulated doublet-score histograms with the threshold marked."""
    plt.figure(figsize=(10, 6))
    # stat="density" normalizes each histogram separately. The observed and
    # simulated sets differ in size, so raw counts would not be comparable,
    # and the y-axis is labelled "Density".
    sns.histplot(
        scrub.doublet_scores_obs_,
        bins=30,
        color="blue",
        label="Observed",
        kde=True,
        stat="density",
    )
    sns.histplot(
        scrub.doublet_scores_sim_,
        bins=30,
        color="red",
        label="Simulated",
        kde=True,
        stat="density",
    )
    plt.axvline(
        doublet_threshold,
        color="green",
        linestyle="--",
        label=f"Threshold: {doublet_threshold:.3f}",
    )
    plt.title(f"Scrublet Doublet Score Distribution for {sample_id}")
    plt.xlabel("Doublet Score")
    plt.ylabel("Density")
    plt.legend()
    plt.show()


def _save_scrublet_tables(adata, scrub, predicted_doublets, sample_id, scrublet_dir):
    """Write one sample's per-cell results and its simulated doublet scores to CSV files."""
    observed = scrub.doublet_scores_obs_
    scrublet_df = pd.DataFrame(
        {
            "Sample_ID": sample_id,
            "Cell_Barcode": adata.obs.index.to_numpy(),
            "Observed_Score": observed,
            # Per-cell score (previously the whole score array was stored in
            # every row). Identical to Observed_Score; kept for existing readers.
            "Doublet_Score": observed,
            "Predicted_Doublet": predicted_doublets,
        }
    )
    scrublet_csv_path = os.path.join(
        scrublet_dir, f"{sample_id}_scrublet_results-test.csv"
    )
    scrublet_df.to_csv(scrublet_csv_path, index=False)
    print(f"Saved Scrublet results for {sample_id} to {scrublet_csv_path}")

    # Simulated doublets are synthetic profiles, not barcodes in `adata`, and
    # their count differs from n_obs. Store them in their own table rather than
    # zipping them onto real cells.
    sim_df = pd.DataFrame(
        {"Sample_ID": sample_id, "Simulated_Score": scrub.doublet_scores_sim_}
    )
    sim_csv_path = os.path.join(
        scrublet_dir, f"{sample_id}_scrublet_simulated-scores-test.csv"
    )
    sim_df.to_csv(sim_csv_path, index=False)
    print(f"Saved simulated doublet scores for {sample_id} to {sim_csv_path}")


def run_scrublet(adata_list, scrublet_dir):
    """Score and call doublets with Scrublet for each AnnData in ``adata_list``.

    For every sample this:
      * runs ``scr.Scrublet`` on ``adata.X`` (expected doublet rate 0.05),
      * calls doublets with Scrublet's automatic threshold, or a fallback
        threshold when automatic detection fails,
      * stores ``doublet_scores`` and ``predicted_doublets`` in ``adata.obs``
        (the AnnData objects are modified in place),
      * shows observed vs. simulated score histograms,
      * writes a per-cell results CSV and a simulated-score CSV to ``scrublet_dir``.

    Parameters
    ----------
    adata_list : list of AnnData
        Samples to process. NOTE(review): Scrublet expects raw counts in
        ``adata.X``; confirm X was not normalized/log-transformed upstream.
    scrublet_dir : str or path-like
        Output directory for the CSVs; created if it does not exist.

    Returns
    -------
    list of AnnData
        ``adata_list`` itself (mutated in place), for convenience.
    """
    os.makedirs(scrublet_dir, exist_ok=True)

    for i, adata in enumerate(adata_list):
        sample_id = _sample_id(adata, i)
        print(f"Running Scrublet for {sample_id}...")

        # Run Scrublet with expected doublet rate (5% recommended for Xenium)
        scrub = scr.Scrublet(adata.X, expected_doublet_rate=0.05)
        doublet_scores, predicted_doublets = scrub.scrub_doublets(
            min_counts=2, min_cells=3, min_gene_variability_pctl=85, n_prin_comps=30
        )

        doublet_threshold = _resolve_threshold(scrub, predicted_doublets, sample_id)
        # Re-call with the resolved threshold so predictions are always defined,
        # even when automatic thresholding failed.
        predicted_doublets = scrub.call_doublets(threshold=doublet_threshold)

        # Store results in AnnData object
        adata.obs["doublet_scores"] = doublet_scores
        adata.obs["predicted_doublets"] = predicted_doublets

        _plot_score_distributions(scrub, doublet_threshold, sample_id)
        _save_scrublet_tables(adata, scrub, predicted_doublets, sample_id, scrublet_dir)

    return adata_list

In [None]:
# Score doublets for every loaded sample; each AnnData's .obs is annotated in place.
scrublet_result = run_scrublet(adata_list=adata_list, scrublet_dir=scrublet_dir)

## Export

In [None]:
# Write each scored sample to its own .h5ad file, named by its first sample_id
# when that column exists, otherwise by its 1-based position in adata_list.
for position, adata in enumerate(adata_list, start=1):
    has_sample_id = "sample_id" in adata.obs.columns
    sample_name = adata.obs["sample_id"].iloc[0] if has_sample_id else f"sample_{position}"

    output_file = output_dir / f"{sample_name}.h5ad"
    adata.write(output_file)
    print(f"Saved: {output_file}")