# Set-up

In [None]:
# imports
import os
import sys
import numpy as np
import pandas as pd
import scanpy as sc
import mudata

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.container import BarContainer

from matplotlib_venn import venn3

In [None]:
# Paths: the input MuData and the output directory are the same dated pipeline
# run directory, so define that directory once and derive both paths from it.
run_dir = "/cellar/users/aklie/data/datasets/tf_perturb_seq/datasets/Hon_WTC11-benchmark_TF-Perturb-seq/results/1_CRISPR_pipeline/2026_01_05"
path_mdata = os.path.join(run_dir, "inference_mudata.h5mu")
path_out = run_dir

os.makedirs(path_out, exist_ok=True)

# Load MuData

In [None]:
# Load the MuData stored at `path_mdata` and display its summary repr.
mdata = mudata.read_h5mu(path_mdata)
mdata

In [None]:
# Bar chart of the number of cells per batch (lane), one colour per bar.
# NOTE(review): seaborn >= 0.13 warns when `palette` is given without `hue`;
# consider `hue="batch", legend=False` once the seaborn version is pinned.
with sns.plotting_context("talk"):
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.countplot(x="batch", data=mdata.obs, ax=ax, palette="tab10")
    ax.set_title(f"Number of cells per lane (N={mdata.n_obs})")
    ax.set_xlabel("Lane")
    ax.tick_params(axis="x", labelrotation=45)  # long batch IDs overlap otherwise

    # Layout must be fixed before show(); calling tight_layout afterwards has no
    # effect on the figure that was already rendered.
    fig.tight_layout()
    plt.show()

# Transcriptome QC

In [None]:
# Gene-expression (transcriptome) modality of the MuData, used for the QC plots below.
gene = mdata.mod["gene"]
gene

In [None]:
# Distributions of the per-cell transcriptome QC metrics (all batches pooled).
# Each panel is annotated with its median: a red dashed line plus the value as
# text in the upper-right corner.
gene_qc_panels = [
    # (gene.obs column, median label format, panel title)
    ("total_gene_umis", "Median: {:.0f} UMIs", "Gene UMI counts"),
    ("num_expressed_genes", "Median: {:.0f} Genes", "Gene detected (> 0 UMI counts)"),
    ("percent_mito", "Median: {:.2f} %", "Mitochondrial %"),
]

with sns.plotting_context("talk", font_scale=1.2):
    fig, ax = plt.subplots(3, 1, figsize=(12, 10))

    for panel_ax, (col, median_fmt, title) in zip(ax, gene_qc_panels):
        values = gene.obs[col]
        median_val = values.median()  # computed once, reused for line and label
        sns.histplot(values, bins=100, ax=panel_ax)
        panel_ax.axvline(median_val, color="red", linestyle="--")
        panel_ax.text(0.95, 0.95, median_fmt.format(median_val), ha="right", va="top", transform=panel_ax.transAxes)
        panel_ax.set_title(title)

    plt.tight_layout()
    plt.show()

In [None]:
# Same transcriptome QC metrics split by batch. Each batch is drawn as a step
# histogram normalised to its own density (common_norm=False), so batches of
# different size are comparable. Each batch's median appears as a dashed line and
# as text in that batch's colour. Only the top panel keeps the batch legend.
# `batch_color_dict` is also reused by the guide QC cell for matching colours.
batch_colors = sns.color_palette("tab10", n_colors=gene.obs["batch"].nunique())
batch_color_dict = dict(zip(gene.obs["batch"].unique(), batch_colors))

gene_batch_panels = [
    # (gene.obs column, median label format, panel title, x of the median labels in axes coords)
    ("total_gene_umis", "Median: {:.0f} UMIs", "Gene UMI counts", 0.6),
    ("num_expressed_genes", "Median: {:.0f} Genes", "Gene detected (> 0 UMI counts)", 0.8),
    ("percent_mito", "Median: {:.2f} %", "Mitochondrial %", 0.6),
]

with sns.plotting_context("talk", font_scale=1):
    fig, ax = plt.subplots(3, 1, figsize=(10, 10))

    for panel_idx, (panel_ax, (col, median_fmt, title, text_x)) in enumerate(zip(ax, gene_batch_panels)):
        sns.histplot(data=gene.obs, x=col, bins=100, hue="batch", palette=batch_color_dict, ax=panel_ax, element="step", stat="density", common_norm=False)

        # One median line and one stacked text label per batch; `rank` sets the label's row.
        for rank, (batch, color) in enumerate(batch_color_dict.items()):
            median_val = gene.obs.loc[gene.obs["batch"] == batch, col].median()
            panel_ax.axvline(median_val, color=color, linestyle="--")
            panel_ax.text(text_x, 0.9 - 0.1 * rank, median_fmt.format(median_val), ha="right", va="top", transform=panel_ax.transAxes, color=color, fontsize=14)
        panel_ax.set_title(title)

        if panel_idx > 0:
            panel_ax.legend_.remove()  # the first panel's legend covers all three

    plt.tight_layout()
    plt.show()

# Guide QC

In [None]:
# Guide (sgRNA capture) modality of the MuData, used for the guide QC below.
guide = mdata.mod["guide"]
guide

In [None]:
# Per-cell metadata of the guide modality (includes the `batch` and `total_guide_umis` columns plotted below).
guide.obs

In [None]:
# Number of guides in each `label` category; `label` is the hue of the cells-per-guide panel below.
guide.var["label"].value_counts()

In [None]:
# Derive assignment counts from layers["guide_assignment"], where an entry > 0
# means the guide (column) is assigned to the cell (row):
#   guide.obs["n_guides_per_cell"] - guides assigned to each cell (row sums)
#   guide.var["n_cells_per_guide"] - cells each guide is assigned to (column sums)
# np.asarray(...).ravel() flattens the sums to 1-D whether the layer is a dense
# ndarray, a scipy sparse matrix (whose .sum returns np.matrix) or a sparse array.
# `.A1` would only work for the np.matrix case.
guide_assigned = guide.layers["guide_assignment"] > 0
guide.obs["n_guides_per_cell"] = np.asarray(guide_assigned.sum(axis=1)).ravel()
guide.var["n_cells_per_guide"] = np.asarray(guide_assigned.sum(axis=0)).ravel()
del guide_assigned  # temporary boolean mask, not needed further

In [None]:
# Guide QC figure:
#   1) total guide UMIs per cell, by batch      (dashed line = per-batch median)
#   2) guides assigned per cell, by batch       (dashed line = per-batch mean)
#   3) cells assigned per guide, by guide label (dashed line = per-label median)
# Batch colours come from `batch_color_dict` (defined in the per-batch transcriptome
# QC cell) so they match across figures. Guide labels get their own tab10 palette.
guide_label_colors = sns.color_palette("tab10", n_colors=guide.var["label"].nunique())
guide_label_color_dict = dict(zip(guide.var["label"].unique(), guide_label_colors))
guide_batches = list(guide.obs["batch"].unique())

with sns.plotting_context("talk", font_scale=1):
    fig, ax = plt.subplots(3, 1, figsize=(10, 10))

    # 1) Guide UMIs per cell (this panel keeps the batch legend)
    sns.histplot(data=guide.obs, x="total_guide_umis", bins=100, hue="batch", palette=batch_color_dict, ax=ax[0], element="step", stat="density", common_norm=False)
    for rank, batch in enumerate(guide_batches):
        median_val = guide.obs.loc[guide.obs["batch"] == batch, "total_guide_umis"].median()
        ax[0].axvline(median_val, color=batch_color_dict[batch], linestyle="--")
        ax[0].text(0.5, 0.9 - 0.1 * rank, f"Median: {median_val:.0f} UMIs", ha="right", va="top", transform=ax[0].transAxes, color=batch_color_dict[batch], fontsize=14)
    ax[0].set_title("Total guide UMI counts")

    # 2) Guides assigned per cell (batch legend removed; panel 1 shows it)
    sns.histplot(data=guide.obs, x="n_guides_per_cell", bins=30, hue="batch", palette=batch_color_dict, ax=ax[1], element="step", stat="density", common_norm=False)
    for rank, batch in enumerate(guide_batches):
        mean_val = guide.obs.loc[guide.obs["batch"] == batch, "n_guides_per_cell"].mean()
        ax[1].axvline(mean_val, color=batch_color_dict[batch], linestyle="--")
        ax[1].text(0.5, 0.9 - 0.1 * rank, f"Mean: {mean_val:.2f} guides", ha="right", va="top", transform=ax[1].transAxes, color=batch_color_dict[batch], fontsize=14)
    ax[1].set_title("Number of guides ASSIGNED per cell")
    ax[1].legend_.remove()

    # 3) Cells per guide, split by guide label (keeps its own label legend)
    sns.histplot(data=guide.var, x="n_cells_per_guide", bins=30, hue="label", palette=guide_label_color_dict, ax=ax[2], element="step", stat="density", common_norm=False)
    for rank, (label, color) in enumerate(guide_label_color_dict.items()):
        median_val = guide.var.loc[guide.var["label"] == label, "n_cells_per_guide"].median()
        ax[2].axvline(median_val, color=color, linestyle="--")
        ax[2].text(0.6, 0.9 - 0.1 * rank, f"Median: {median_val:.0f} Cells", ha="right", va="top", transform=ax[2].transAxes, color=color, fontsize=14)
    ax[2].set_title("Number of cells per guide assignment")

    plt.tight_layout()
    plt.show()

# Per-batch summary metrics

In [None]:
# Fixed batch (lane) order used by the per-batch summary cells below, plus a
# hand-picked hex colour per batch (the first four matplotlib tab10 colours).
# NOTE: this rebinds `batch_colors`, which earlier held a seaborn palette list.
batch_order = [
    "IGVFDS9332KWPJ",
    "IGVFDS8721BKRO",
    "IGVFDS9613DDRB",
    "IGVFDS6244NAXC",
]
batch_colors = dict(
    IGVFDS6244NAXC="#1f77b4",
    IGVFDS8721BKRO="#ff7f0e",
    IGVFDS9332KWPJ="#2ca02c",
    IGVFDS9613DDRB="#d62728",
)

In [None]:
# Median gene UMIs per cell for each batch, in `batch_order`: one labelled line
# per batch, then the bare values in the same order.
median_gene_umis = [
    (batch, gene.obs.loc[gene.obs["batch"] == batch, "total_gene_umis"].median())
    for batch in batch_order
]
for batch, median_umi in median_gene_umis:
    print(f"Batch {batch}: Median UMI/cell in gene: {median_umi:.0f}")
for _, median_umi in median_gene_umis:
    print(f"{median_umi:.0f}")

In [None]:
# Median guide UMIs per cell for each batch, in `batch_order`: one labelled line
# per batch, then the bare values in the same order.
median_guide_umis = [
    (batch, guide.obs.loc[guide.obs["batch"] == batch, "total_guide_umis"].median())
    for batch in batch_order
]
for batch, median_umi in median_guide_umis:
    print(f"Batch {batch}: Median UMI/cell in guide: {median_umi:.0f}")
for _, median_umi in median_guide_umis:
    print(f"{median_umi:.0f}")

In [None]:
# Number of cells in each batch (lane), in `batch_order`: labelled lines first,
# then the bare counts in the same order.
cells_per_batch = [
    (batch, int((mdata.obs["batch"] == batch).sum()))
    for batch in batch_order
]
for batch, n_cells in cells_per_batch:
    print(f"Batch {batch}: Number of cells: {n_cells}")
for _, n_cells in cells_per_batch:
    print(f"{n_cells}")

Batch IGVFDS9332KWPJ: Number of cells: 20528
Batch IGVFDS8721BKRO: Number of cells: 23190
Batch IGVFDS9613DDRB: Number of cells: 25384
Batch IGVFDS6244NAXC: Number of cells: 23451
20528
23190
25384
23451


In [21]:
# Cells with exactly one assigned guide in each batch, in `batch_order`:
# labelled lines first, then the bare counts in the same order.
has_one_guide = guide.obs["n_guides_per_cell"] == 1
one_guide_counts = [
    (batch, np.sum(has_one_guide & (guide.obs["batch"] == batch)))
    for batch in batch_order
]
for batch, n_cells_1_guide in one_guide_counts:
    print(f"Batch {batch}: {n_cells_1_guide} cells with exactly 1 guide assigned")
for _, n_cells_1_guide in one_guide_counts:
    print(f"{n_cells_1_guide}")

Batch IGVFDS9332KWPJ: 8652 cells with exactly 1 guide assigned
Batch IGVFDS8721BKRO: 11543 cells with exactly 1 guide assigned
Batch IGVFDS9613DDRB: 11675 cells with exactly 1 guide assigned
Batch IGVFDS6244NAXC: 12060 cells with exactly 1 guide assigned
8652
11543
11675
12060


In [22]:
# Cells with at least one assigned guide in each batch, in `batch_order`:
# labelled lines first, then the bare counts in the same order.
has_any_guide = guide.obs["n_guides_per_cell"] > 0
any_guide_counts = [
    (batch, np.sum(has_any_guide & (guide.obs["batch"] == batch)))
    for batch in batch_order
]
for batch, n_cells_gt0_guide in any_guide_counts:
    print(f"Batch {batch}: {n_cells_gt0_guide} cells with >0 guides assigned")
for _, n_cells_gt0_guide in any_guide_counts:
    print(f"{n_cells_gt0_guide}")

Batch IGVFDS9332KWPJ: 20363 cells with >0 guides assigned
Batch IGVFDS8721BKRO: 23089 cells with >0 guides assigned
Batch IGVFDS9613DDRB: 25347 cells with >0 guides assigned
Batch IGVFDS6244NAXC: 23361 cells with >0 guides assigned
20363
23089
25347
23361


In [23]:
# Mean number of guides assigned per cell in each batch, in `batch_order`:
# labelled lines first, then the bare values in the same order.
mean_guides = [
    (batch, guide.obs.loc[guide.obs["batch"] == batch, "n_guides_per_cell"].mean())
    for batch in batch_order
]
for batch, mean_guides_per_cell in mean_guides:
    print(f"Batch {batch}: Mean guides assigned per cell: {mean_guides_per_cell:.2f}")
for _, mean_guides_per_cell in mean_guides:
    print(f"{mean_guides_per_cell:.2f}")

Batch IGVFDS9332KWPJ: Mean guides assigned per cell: 3.66
Batch IGVFDS8721BKRO: Mean guides assigned per cell: 2.10
Batch IGVFDS9613DDRB: Mean guides assigned per cell: 1.92
Batch IGVFDS6244NAXC: Mean guides assigned per cell: 1.76
3.66
2.10
1.92
1.76


In [24]:
# Maximum number of guides assigned to a single cell in each batch. Both print
# loops follow `batch_order`, like the other summary cells, so the labelled lines
# and the bare values appear in the same order.
max_guides_by_batch = [
    (batch, guide.obs.loc[guide.obs["batch"] == batch, "n_guides_per_cell"].max())
    for batch in batch_order
]
for batch, max_guides_per_cell in max_guides_by_batch:
    print(f"Batch {batch}: Max guides assigned per cell: {max_guides_per_cell}")
for _, max_guides_per_cell in max_guides_by_batch:
    print(f"{max_guides_per_cell}")

Batch IGVFDS6244NAXC: Max guides assigned per cell: 14
Batch IGVFDS8721BKRO: Max guides assigned per cell: 39
Batch IGVFDS9332KWPJ: Max guides assigned per cell: 31
Batch IGVFDS9613DDRB: Max guides assigned per cell: 20
31
39
20
14


# DONE!

---