#### This notebook calculates and plots MECR (mutually exclusive co-expression rate) values for the following datasets:
* Dataset 1: Xenium (rep1)
* Dataset 1: CosMx
* External reference scRNAseq data (GSE250487)

#### Required input files:
* Xenium (rep1) filtered transcripts file
* CosMx filtered transcripts file
* External reference scRNAseq cell-based data object

Note: the `_r` suffix denotes a filtered transcript file, i.e. one from which filtered-out transcripts have been removed.

Environment: Please create and activate the conda environment defined in `default_env.yaml` before running this notebook.

In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
import squidpy as sq

import gzip
import anndata

import os

import scipy
import scipy.sparse

## CosMx and Xenium MECR Plots

In [None]:
## Load in filtered transcript files
# One filtered ("_r") transcript table per spatial platform.
Xenium_transcripts_r = pd.read_csv(
    '/path/Xenium_transcripts_r.csv'
)
CosMx_transcripts_r = pd.read_csv(
    '/path/CosMx_transcripts_r.csv'
)

Functions

In [None]:
## Compute mean and max transcript counts per gene

def compute_gene_counts(
    transcripts_df,
    genes,
    platform_name=None,
    gene_col='feature_name',
    cell_col='cell_id',
    transcript_col='transcript_id'
):
    """
    Count unique transcripts per cell for each gene and summarise the counts.

    For every gene in `genes`, the rows of `transcripts_df` whose `gene_col`
    equals that gene are grouped by `cell_col`. The number of distinct
    `transcript_col` values in each group is that cell's count. Only cells
    with at least one transcript of the gene appear in its table, so the
    printed and returned mean/max are taken over those cells only (cells
    with zero transcripts of the gene are not averaged in).

    NOTE(review): every distinct `cell_col` value is treated as a cell. If the
    input still contains transcripts not assigned to a cell (e.g. under a
    placeholder id), they would be counted as one large pseudo-cell. Confirm
    the filtered files exclude them.

    Parameters
    ----------
    transcripts_df : pandas.DataFrame
        One row per transcript.
    genes : iterable of str
        Genes to count.
    platform_name : str or None
        Optional label used in the printed lines and as a "platform" column.
    gene_col, cell_col, transcript_col : str
        Column names holding gene name, cell id and transcript id.

    Returns
    -------
    gene_counts : dict
        gene -> DataFrame with columns [cell_col, "<gene>_transcript_count"].
    summary_df : pandas.DataFrame
        One row per gene with "gene", "mean_transcripts_per_cell" and
        "max_transcripts_per_cell", plus "platform" if `platform_name` is set.
    """
    prefix = "" if platform_name is None else f"{platform_name} "
    per_gene = {}
    summaries = []

    for gene in genes:
        value_col = f"{gene}_transcript_count"
        is_gene = transcripts_df[gene_col] == gene

        per_cell = (
            transcripts_df.loc[is_gene]
            .groupby(cell_col)[transcript_col]
            .nunique()
            .reset_index()
            .set_axis([cell_col, value_col], axis=1)
        )
        per_gene[gene] = per_cell

        avg = per_cell[value_col].mean()
        peak = per_cell[value_col].max()
        print(f"Mean {prefix}{gene} transcript counts per cell: {avg}")
        print(f"Max  {prefix}{gene} transcript counts per cell: {peak}\n")

        summary = {
            "gene": gene,
            "mean_transcripts_per_cell": avg,
            "max_transcripts_per_cell": peak,
        }
        if platform_name is not None:
            summary["platform"] = platform_name
        summaries.append(summary)

    return per_gene, pd.DataFrame(summaries)

In [None]:
## Compute MECR value for gene pairs

def compute_mecr_from_counts(
    gene_counts,
    gene_pairs,
    platform_name=None,
    cell_col='cell_id'
):
    """
    Compute MECR for every (gene1, gene2) pair from per-gene count tables.

    The two per-gene tables are outer-joined on `cell_col`, and a cell missing
    from one table gets a count of 0 for that gene. MECR is the number of cells
    with >= 1 transcript of BOTH genes divided by the number of cells with
    >= 1 transcript of EITHER gene. It is NaN when no cell has either gene.

    Parameters
    ----------
    gene_counts : dict
        gene -> DataFrame with columns [cell_col, "<gene>_transcript_count"]
        (the layout produced by compute_gene_counts).
    gene_pairs : list of (str, str)
        Gene pairs to evaluate.
    platform_name : str or None
        Optional label used in the printed lines and as a "platform" column.
    cell_col : str
        Cell id column used for the join.

    Returns
    -------
    pandas.DataFrame
        One row per pair with "gene1", "gene2", "both_genes",
        "at_least_one_gene", "MECR", plus "platform" if `platform_name` is set.
    """
    prefix = "" if platform_name is None else f"{platform_name} "
    records = []

    for first, second in gene_pairs:
        joined = gene_counts[first].merge(
            gene_counts[second], on=cell_col, how='outer'
        ).fillna(0)

        has_first = joined[f"{first}_transcript_count"] >= 1.0
        has_second = joined[f"{second}_transcript_count"] >= 1.0

        n_both = (has_first & has_second).sum()
        n_any = (has_first | has_second).sum()

        if n_any > 0:
            ratio = n_both / n_any
        else:
            ratio = float("nan")

        print(f"MECR value for {prefix}{first}_{second}: {ratio}")

        record = {
            "gene1": first,
            "gene2": second,
            "both_genes": n_both,
            "at_least_one_gene": n_any,
            "MECR": ratio,
        }
        if platform_name is not None:
            record["platform"] = platform_name
        records.append(record)

    return pd.DataFrame(records)

In [None]:
## Define the gene pairs to evaluate and the genes they contain

gene_pairs = [
    ("COL3A1", "PIGR"),
    ("CD8A", "OSM"),
    ("LYVE1", "MZB1"),
    ("CPA3", "MS4A1"),
    ("CD19", "IL1B")
]

# Derive the gene list from the pairs (first-seen order, no duplicates) so the
# genes that get counted can never drift out of sync with the pairs evaluated.
genes = list(dict.fromkeys(gene for pair in gene_pairs for gene in pair))

Xenium

In [None]:
# Xenium: per-gene per-cell counts (default gene column "feature_name"),
# then MECR for each gene pair.
X_gene_counts, X_gene_summary = compute_gene_counts(
    Xenium_transcripts_r, genes, platform_name="Xenium"
)
X_mecr_results = compute_mecr_from_counts(
    X_gene_counts, gene_pairs, platform_name="Xenium"
)

CosMx

In [None]:
# CosMx: this table keeps the gene name in the "target" column, so gene_col
# must be overridden (the default "feature_name" is what the Xenium table uses).
C_gene_counts, C_gene_summary = compute_gene_counts(
    CosMx_transcripts_r,
    genes,
    platform_name="CosMx",
    gene_col="target",
)
C_mecr_results = compute_mecr_from_counts(
    C_gene_counts, gene_pairs, platform_name="CosMx"
)

Plot

In [None]:
def plot_gene_pair(
    gene_counts,
    gene1,
    gene2,
    platform_name="Platform",
    mecr_results=None,
    cell_col="cell_id",
    figsize=(6, 6),
    xlim=None,
    ylim=None,
    xticks=None,
    yticks=None,
    save_path=None
):
    """
    Scatter-plot per-cell transcript counts of two genes, annotated with MECR.

    The two per-gene tables are outer-joined on `cell_col`. A cell present in
    only one table gets 0 for the other gene. Cells absent from both tables
    are not plotted.

    Parameters
    ----------
    gene_counts : dict
        dict[gene] -> DataFrame with columns [cell_col, <gene>_transcript_count]
        (e.g. X_gene_counts or C_gene_counts).
    gene1, gene2 : str
        Genes plotted on the x and y axis respectively.
    platform_name : str
        Plot title, e.g. "Xenium" or "CosMx".
    mecr_results : pd.DataFrame or None
        Table with "gene1", "gene2" and "MECR" columns. The first row matching
        (gene1, gene2) in that order supplies the annotated value. If None, or
        if no row matches, MECR is computed from the merged counts as
        (#cells with >= 1 of both genes) / (#cells with >= 1 of either gene),
        NaN when that denominator is 0.
    cell_col : str
        Cell id column used to join the two per-gene tables.
    figsize : tuple
        Figure size passed to plt.figure.
    xlim, ylim : tuple or None
        Axis limits. When None, matplotlib autoscales.
    xticks, yticks : sequence or None
        Explicit tick positions. When None, matplotlib chooses them.
    save_path : str or None
        If given, the figure is also saved to this path at 300 dpi.
    """
    col1 = f"{gene1}_transcript_count"
    col2 = f"{gene2}_transcript_count"

    # --- Merge per-gene count dataframes ---
    merged_df = gene_counts[gene1].merge(
        gene_counts[gene2], on=cell_col, how="outer"
    ).fillna(0)

    # --- Get or compute MECR ---
    # Use the precomputed value when available. A single fallback path covers
    # both "no table given" and "pair not in table".
    mecr_value = None
    if mecr_results is not None:
        match = mecr_results.loc[
            (mecr_results["gene1"] == gene1) &
            (mecr_results["gene2"] == gene2),
            "MECR"
        ]
        if not match.empty:
            mecr_value = match.iloc[0]

    if mecr_value is None:
        has1 = merged_df[col1] >= 1
        has2 = merged_df[col2] >= 1
        both = (has1 & has2).sum()
        either = (has1 | has2).sum()
        mecr_value = both / either if either > 0 else np.nan

    # --- Plot ---
    plt.figure(figsize=figsize)

    plt.scatter(merged_df[col1], merged_df[col2], s=20)

    plt.xlabel(f"{gene1} transcript counts per cell", fontsize=12)
    plt.ylabel(f"{gene2} transcript counts per cell", fontsize=12)
    plt.title(platform_name, fontsize=14)

    # Axis limits
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)

    # Custom ticks
    if xticks is not None:
        plt.xticks(xticks)
    if yticks is not None:
        plt.yticks(yticks)

    plt.tick_params(axis="both", labelsize=11)

    # MECR annotation in the top-right corner (axes-fraction coordinates)
    textstr = f"MECR = {mecr_value:.2f}"
    plt.text(
        0.95, 0.95, textstr,
        transform=plt.gca().transAxes,
        fontsize=12,
        verticalalignment="top",
        horizontalalignment="right",
        bbox=dict(facecolor="white", alpha=0.6)
    )

    plt.tight_layout()

    if save_path is not None:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")

    plt.show()

Xenium

In [None]:
# Xenium: COL3A1 (x) vs PIGR (y); uncomment save_path to export a PDF
plot_gene_pair(
    X_gene_counts,
    "COL3A1",
    "PIGR",
    platform_name="Xenium",
    mecr_results=X_mecr_results,
    xlim=(-5, 250),
    ylim=(-5, 250),
    xticks=list(range(0, 251, 50)),
    yticks=list(range(0, 251, 50)),
    # save_path="/path/MECR_Xenium_COL3A1_PIGR.pdf"
)

In [None]:
# Xenium: CD8A (x) vs OSM (y); uncomment save_path to export a PDF
plot_gene_pair(
    X_gene_counts,
    "CD8A",
    "OSM",
    platform_name="Xenium",
    mecr_results=X_mecr_results,
    xlim=(-1, 20),
    ylim=(-1, 20),
    xticks=list(range(0, 21, 5)),
    yticks=list(range(0, 21, 5)),
    # save_path="/path/MECR_Xenium_CD8A_OSM.pdf"
)

In [None]:
# Xenium: LYVE1 (x) vs MZB1 (y); uncomment save_path to export a PDF
plot_gene_pair(
    X_gene_counts,
    "LYVE1",
    "MZB1",
    platform_name="Xenium",
    mecr_results=X_mecr_results,
    xlim=(-1, 60),
    ylim=(-1, 60),
    xticks=list(range(0, 61, 15)),
    yticks=list(range(0, 61, 15)),
    # save_path="/path/MECR_Xenium_LYVE1_MZB1.pdf"
)

In [None]:
# Xenium: CPA3 (x) vs MS4A1 (y); uncomment save_path to export a PDF
plot_gene_pair(
    X_gene_counts,
    "CPA3",
    "MS4A1",
    platform_name="Xenium",
    mecr_results=X_mecr_results,
    xlim=(-1, 40),
    ylim=(-1, 40),
    xticks=list(range(0, 41, 10)),
    yticks=list(range(0, 41, 10)),
    # save_path="/path/MECR_Xenium_CPA3_MS4A1.pdf"
)

In [None]:
# Xenium: CD19 (x) vs IL1B (y); uncomment save_path to export a PDF
plot_gene_pair(
    X_gene_counts,
    "CD19",
    "IL1B",
    platform_name="Xenium",
    mecr_results=X_mecr_results,
    xlim=(-1, 40),
    ylim=(-1, 40),
    xticks=list(range(0, 41, 10)),
    yticks=list(range(0, 41, 10)),
    # save_path="/path/MECR_Xenium_CD19_IL1B.pdf"
)

CosMx

In [None]:
# CosMx: COL3A1 (x) vs PIGR (y); uncomment save_path to export a PDF
plot_gene_pair(
    C_gene_counts,
    "COL3A1",
    "PIGR",
    platform_name="CosMx",
    mecr_results=C_mecr_results,
    xlim=(-5, 250),
    ylim=(-5, 250),
    xticks=list(range(0, 251, 50)),
    yticks=list(range(0, 251, 50)),
    # save_path="/path/MECR_CosMx_COL3A1_PIGR.pdf"
)

In [None]:
# CosMx: CD8A (x) vs OSM (y); uncomment save_path to export a PDF
plot_gene_pair(
    C_gene_counts,
    "CD8A",
    "OSM",
    platform_name="CosMx",
    mecr_results=C_mecr_results,
    xlim=(-1, 20),
    ylim=(-1, 20),
    xticks=list(range(0, 21, 5)),
    yticks=list(range(0, 21, 5)),
    # save_path="/path/MECR_CosMx_CD8A_OSM.pdf"
)

In [None]:
# CosMx: LYVE1 (x) vs MZB1 (y); uncomment save_path to export a PDF
plot_gene_pair(
    C_gene_counts,
    "LYVE1",
    "MZB1",
    platform_name="CosMx",
    mecr_results=C_mecr_results,
    xlim=(-1, 60),
    ylim=(-1, 60),
    xticks=list(range(0, 61, 15)),
    yticks=list(range(0, 61, 15)),
    # save_path="/path/MECR_CosMx_LYVE1_MZB1.pdf"
)

In [None]:
# CosMx: CPA3 (x) vs MS4A1 (y); uncomment save_path to export a PDF
plot_gene_pair(
    C_gene_counts,
    "CPA3",
    "MS4A1",
    platform_name="CosMx",
    mecr_results=C_mecr_results,
    xlim=(-1, 40),
    ylim=(-1, 40),
    xticks=list(range(0, 41, 10)),
    yticks=list(range(0, 41, 10)),
    # save_path="/path/MECR_CosMx_CPA3_MS4A1.pdf"
)

In [None]:
# CosMx: CD19 (x) vs IL1B (y); uncomment save_path to export a PDF
plot_gene_pair(
    C_gene_counts,
    "CD19",
    "IL1B",
    platform_name="CosMx",
    mecr_results=C_mecr_results,
    xlim=(-1, 40),
    ylim=(-1, 40),
    xticks=list(range(0, 41, 10)),
    yticks=list(range(0, 41, 10)),
    # save_path="/path/MECR_CosMx_CD19_IL1B.pdf"
)

## scRNAseq MECR Plots

In [None]:
## Load the external reference scRNAseq object (download from GEO, GSE250487)
RNAseq = sc.read_h5ad(
    '/path/XAUT1_biopsy_scRNAseq.h5ad'
)

RNAseq  # last expression in the cell, so the notebook displays the AnnData summary

In [None]:
## Extract expression for each gene

genes = ["COL3A1", "PIGR", "CD8A", "OSM", "LYVE1", "MZB1", "CPA3", "MS4A1", "CD19", "IL1B"]

# Slice all genes once (column j of the subset corresponds to genes[j]) rather
# than re-slicing the AnnData object separately for every gene.
expr_matrix = RNAseq[:, genes].X
if scipy.sparse.issparse(expr_matrix):
    expr_matrix = expr_matrix.toarray()
# np.asarray also turns a dense np.matrix into an ndarray. Without it,
# column slices (or .flatten()) of an np.matrix stay 2-D.
expr_matrix = np.asarray(expr_matrix)

# gene -> DataFrame with columns ["cell_id", "<gene>_expression"], one row per cell
RNA_expr = {
    gene: pd.DataFrame({
        "cell_id": RNAseq.obs_names,
        f"{gene}_expression": expr_matrix[:, j],
    })
    for j, gene in enumerate(genes)
}

In [None]:
## Compute MECR for each gene pair
# MECR = (#cells where both genes are detected) / (#cells where at least one is).
# "Detected" here means expression > 0, since these values may be normalised
# rather than integer counts (the spatial platforms use >= 1 transcript).

gene_pairs = [
    ("COL3A1", "PIGR"),
    ("CD8A", "OSM"),
    ("LYVE1", "MZB1"),
    ("CPA3", "MS4A1"),
    ("CD19", "IL1B")
]

RNA_mecr_list = []

for g1, g2 in gene_pairs:
    col1 = f"{g1}_expression"
    col2 = f"{g2}_expression"

    # validate="one_to_one" fails loudly on duplicated cell ids. Otherwise the
    # merge would silently pair every duplicate with every other and distort MECR.
    merged = RNA_expr[g1].merge(
        RNA_expr[g2], on="cell_id", how="outer", validate="one_to_one"
    ).fillna(0)

    detected1 = merged[col1] > 0
    detected2 = merged[col2] > 0
    both = (detected1 & detected2).sum()
    either = (detected1 | detected2).sum()

    mecr = both / either if either > 0 else np.nan

    print(f"MECR for RNAseq {g1}_{g2}: {mecr}")

    # Same row layout as compute_mecr_from_counts (with a platform label), so
    # the RNAseq results can be compared or concatenated with the spatial ones.
    RNA_mecr_list.append({
        "gene1": g1,
        "gene2": g2,
        "both_genes": both,
        "at_least_one_gene": either,
        "MECR": mecr,
        "platform": "RNAseq",
    })

RNA_mecr = pd.DataFrame(RNA_mecr_list)

In [None]:
def plot_rnaseq_gene_pair(
    RNA_expr,
    RNA_mecr,
    gene1,
    gene2,
    platform_name="RNAseq",
    xlim=None,
    ylim=None,
    xticks=None,
    yticks=None,
    figsize=(6,6),
    save_path=None
):
    """
    Scatter-plot per-cell expression of two genes, annotated with MECR.

    Parameters
    ----------
    RNA_expr : dict
        gene -> DataFrame with columns ["cell_id", "<gene>_expression"].
    RNA_mecr : pd.DataFrame or None
        Table with "gene1", "gene2" and "MECR" columns. The first row matching
        (gene1, gene2) in that order supplies the annotated value. If None, or
        if no row matches, MECR is computed from the merged expression as
        (#cells with expression > 0 for both genes) /
        (#cells with expression > 0 for either gene), NaN when that is 0.
        Previously a missing pair raised IndexError.
    gene1, gene2 : str
        Genes plotted on the x and y axis respectively.
    platform_name : str
        Plot title.
    xlim, ylim : tuple or None
        Axis limits. When None, matplotlib autoscales.
    xticks, yticks : sequence or None
        Explicit tick positions. When None, matplotlib chooses them.
    figsize : tuple
        Figure size passed to plt.figure.
    save_path : str or None
        If given, the figure is also saved to this path at 300 dpi.

    Raises
    ------
    pandas.errors.MergeError
        If "cell_id" is not unique in either gene's table. A many-to-many join
        would silently duplicate points.
    """
    col1 = f"{gene1}_expression"
    col2 = f"{gene2}_expression"

    # Merge expression DataFrames for g1 and g2
    merged_df = RNA_expr[gene1].merge(
        RNA_expr[gene2],
        on="cell_id",
        how="outer",
        validate="one_to_one",
    ).fillna(0)

    # Look up MECR value from RNA_mecr, falling back to computing it here
    mecr_val = None
    if RNA_mecr is not None:
        match = RNA_mecr.loc[
            (RNA_mecr["gene1"] == gene1) & (RNA_mecr["gene2"] == gene2),
            "MECR"
        ]
        if not match.empty:
            mecr_val = match.iloc[0]

    if mecr_val is None:
        detected1 = merged_df[col1] > 0
        detected2 = merged_df[col2] > 0
        both = (detected1 & detected2).sum()
        either = (detected1 | detected2).sum()
        mecr_val = both / either if either > 0 else np.nan

    # Plotting
    plt.figure(figsize=figsize)

    plt.scatter(merged_df[col1], merged_df[col2], s=20)

    plt.xlabel(f"{gene1} expression per cell", fontsize=12)
    plt.ylabel(f"{gene2} expression per cell", fontsize=12)
    plt.title(platform_name, fontsize=14)

    # Axis limits & ticks
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    if xticks is not None:
        plt.xticks(xticks)
    if yticks is not None:
        plt.yticks(yticks)

    plt.tick_params(axis="both", labelsize=11)

    # MECR annotation box in the top-right corner (axes-fraction coordinates)
    plt.text(
        0.95, 0.95,
        f"MECR = {mecr_val:.2f}",
        transform=plt.gca().transAxes,
        fontsize=12,
        ha="right", va="top",
        bbox=dict(facecolor="white", alpha=0.6)
    )

    plt.tight_layout()

    # Optional saving
    if save_path is not None:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")

    plt.show()

In [None]:
# scRNAseq: COL3A1 (x) vs PIGR (y); uncomment save_path to export a PDF
plot_rnaseq_gene_pair(
    RNA_expr,
    RNA_mecr,
    "COL3A1",
    "PIGR",
    xlim=(-1, 10),
    ylim=(-1, 10),
    xticks=[0, 2.5, 5, 7.5, 10],
    yticks=[0, 2.5, 5, 7.5, 10],
    # save_path="/path/MECR_RNAseq_COL3A1_PIGR.pdf"
)

In [None]:
# scRNAseq: CD8A (x) vs OSM (y); uncomment save_path to export a PDF
plot_rnaseq_gene_pair(
    RNA_expr,
    RNA_mecr,
    "CD8A",
    "OSM",
    xlim=(-1, 10),
    ylim=(-1, 10),
    xticks=[0, 2.5, 5, 7.5, 10],
    yticks=[0, 2.5, 5, 7.5, 10],
    # save_path="/path/MECR_RNAseq_CD8A_OSM.pdf"
)

In [None]:
# scRNAseq: LYVE1 (x) vs MZB1 (y); uncomment save_path to export a PDF
plot_rnaseq_gene_pair(
    RNA_expr,
    RNA_mecr,
    "LYVE1",
    "MZB1",
    xlim=(-1, 10),
    ylim=(-1, 10),
    xticks=[0, 2.5, 5, 7.5, 10],
    yticks=[0, 2.5, 5, 7.5, 10],
    # save_path="/path/MECR_RNAseq_LYVE1_MZB1.pdf"
)

In [None]:
# scRNAseq: CPA3 (x) vs MS4A1 (y); uncomment save_path to export a PDF
plot_rnaseq_gene_pair(
    RNA_expr,
    RNA_mecr,
    "CPA3",
    "MS4A1",
    xlim=(-1, 10),
    ylim=(-1, 10),
    xticks=[0, 2.5, 5, 7.5, 10],
    yticks=[0, 2.5, 5, 7.5, 10],
    # save_path="/path/MECR_RNAseq_CPA3_MS4A1.pdf"
)

In [None]:
# scRNAseq: CD19 (x) vs IL1B (y); uncomment save_path to export a PDF
plot_rnaseq_gene_pair(
    RNA_expr,
    RNA_mecr,
    "CD19",
    "IL1B",
    xlim=(-1, 10),
    ylim=(-1, 10),
    xticks=[0, 2.5, 5, 7.5, 10],
    yticks=[0, 2.5, 5, 7.5, 10],
    # save_path="/path/MECR_RNAseq_CD19_IL1B.pdf"
)