In [None]:
import warnings
from itertools import product

import matplotlib.pyplot as plt
import scanpy as sc

warnings.filterwarnings("ignore")

In [None]:
# Fixed colour per cell-type label. It is passed as `palette` to the plotting
# function, so every panel and batch uses the same label -> colour mapping.
# Insertion order is preserved (dict built from an ordered list of pairs).
colors = dict(
    [
        ("Epithelial_Secretory", "#AA9228"),
        ("Epithelial_Absorptive", "#E3C300"),
        ("Monocyte", "#C37698"),
        ("T-Cell", "#008E74"),
        ("MAIT", "#63ABB9"),
        ("Myeloid", "#EF9684"),
        ("ILC", "#A0C6D3"),
        ("B-Cell", "#E2CEAB"),
        ("DC", "#FE757D"),
        ("Fibroblast", "#E17300"),
        ("Endothelial", "#E30133"),
        ("NK", "#4A7B89"),
        ("Epithelial_Progenitor", "#F7BC00"),
        ("Neuron", "#2A2446"),
        ("Erythroid", "#A5021D"),
        ("Eosinophil", "#782c4e"),
        ("Unknown", "#AAAAAA"),
    ]
)

In [None]:
# Panel kinds accepted in the `basis` argument of `plot_celltype_with_zoom`.
_PANEL_KINDS = ("mde", "spatial", "zoom", "P14", "leiden")


def _set_zoom_window(ax, corner, width):
    """Restrict `ax` to the square window [x, x + width] x [y, y + width]."""
    ax.set_xlim(corner["x"], corner["x"] + width)
    ax.set_ylim(corner["y"], corner["y"] + width)


def _pad_to_square(ax):
    """Widen the shorter of the x / y ranges of `ax` so both span the same length."""
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    diff = (xmax - xmin) - (ymax - ymin)
    if diff > 0:
        # x range is longer: grow the y range symmetrically
        ax.set_ylim(ymin - diff / 2, ymax + diff / 2)
    else:
        # y range is longer (or equal): grow the x range symmetrically
        ax.set_xlim(xmin + diff / 2, xmax - diff / 2)


def _plot_celltype_panel(adata_batch, kind, ax, batch, celltype, palette):
    """Draw an "mde", "spatial" or "zoom" panel coloured by `celltype`."""
    sc.pl.embedding(
        adata_batch,
        # "zoom" is the spatial embedding with a restricted window
        basis="spatial" if kind == "zoom" else kind,
        ax=ax,
        show=False,
        title=batch,
        color=celltype,
        frameon=True,
        palette=palette,
        size=40 if kind == "zoom" else None,
        legend_loc="on data" if kind == "mde" else "right margin",
        legend_fontweight="medium",
        legend_fontsize="small",
    )


def _plot_p14_panel(adata_batch, ax, batch):
    """Draw the spatial embedding with "Cd8_T-Cell_P14" cells highlighted in red.

    Adds an `obs["isP14"]` column to `adata_batch`, so pass a copy.
    """
    adata_batch.obs["isP14"] = [
        "P14" if cell == "Cd8_T-Cell_P14" else "notP14"
        for cell in adata_batch.obs["Subtype"]
    ]
    sc.pl.embedding(
        adata_batch,
        basis="spatial",
        ax=ax,
        show=False,
        title=batch,
        color="isP14",
        frameon=True,
        palette={"P14": "red", "notP14": "gray"},
        # P14 cells are drawn slightly larger than the grey background cells
        size=[60 if label == "P14" else 40 for label in adata_batch.obs["isP14"]],
        legend_loc=None,
        legend_fontweight="medium",
        legend_fontsize="small",
    )


def _plot_leiden_panel(adata_batch, ax):
    """Draw the spatial embedding coloured by `obs["leiden"]`."""
    sc.pl.embedding(
        adata_batch,
        basis="spatial",
        ax=ax,
        show=False,
        color="leiden",
        frameon=True,
        size=40,
        legend_loc="right margin",
        legend_fontweight="medium",
        legend_fontsize="small",
    )


def _collect_legend_entries(axes):
    """Merge legend handles/labels of all `axes`, keeping the first handle per label."""
    entries = {}
    for ax in axes:
        handles, labels = ax.get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            entries.setdefault(label, handle)
    return list(entries.values()), list(entries.keys())


def plot_celltype_with_zoom(
    adata: sc.AnnData,
    batches,
    celltype: str,
    dpi=600,
    width=500,
    palette=None,
    basis=("mde", "spatial", "zoom"),
):
    """Plot a grid of embeddings: one column per batch, one row per `basis` entry.

    Panel kinds (entries of `basis`):

    - "mde": the "mde" embedding coloured by `celltype`, labels drawn on the
      data, limits fixed to [-3, 3] on both axes.
    - "spatial": the "spatial" embedding coloured by `celltype`, limits padded
      so the x and y ranges have equal length.
    - "zoom": the "spatial" embedding coloured by `celltype` with larger points,
      restricted to a `width` x `width` window starting at the batch's x/y.
    - "P14": the same zoom window, with cells whose `obs["Subtype"]` equals
      "Cd8_T-Cell_P14" in red and all others in grey.
    - "leiden": the same zoom window, coloured by `obs["leiden"]`.

    Only the first row keeps titles (the batch names). Per-axis legends are removed
    and replaced by one figure legend that merges the labelled entries of all panels.
    This means categories present only in some batches are still listed.

    Parameters
    ----------
    adata
        Annotated data; cells are split into batches via `adata.obs["batch"]`.
    batches
        Mapping of batch name -> {"x": ..., "y": ...}. The values give the
        lower-left corner of the zoom window for that batch.
    celltype
        Column used to colour the "mde", "spatial" and "zoom" panels.
    dpi
        Passed to `sc.set_figure_params` as `dpi_save`.
    width
        Side length of the square zoom window, in spatial-coordinate units.
    palette
        Passed to scanpy as `palette` for the `celltype` panels.
    basis
        Iterable of panel kinds (see above); one figure row per entry.

    Raises
    ------
    ValueError
        If `batches` or `basis` is empty, or `basis` contains an unknown kind.
    """
    basis = list(basis)
    if not batches:
        raise ValueError("`batches` must contain at least one batch")
    if not basis:
        raise ValueError("`basis` must contain at least one panel kind")
    unknown = [kind for kind in basis if kind not in _PANEL_KINDS]
    if unknown:
        raise ValueError(
            f"Unknown panel kind(s) {unknown}; expected entries from {_PANEL_KINDS}"
        )

    sc.set_figure_params(vector_friendly=True, dpi_save=dpi)

    n_rows, n_cols = len(basis), len(batches)
    # squeeze=False keeps a 2-D Axes array even for a 1x1 grid, so `.flat` always works
    fig, axes = plt.subplots(
        nrows=n_rows, ncols=n_cols, figsize=(4 * n_cols, 3 * n_rows), squeeze=False
    )
    flat_axes = list(axes.flat)

    # Row-major order: product() iterates batches fastest, matching axes.flat
    for i, (kind, batch) in enumerate(product(basis, batches)):
        ax = flat_axes[i]
        in_batch = adata.obs["batch"] == batch

        if kind in ("mde", "spatial", "zoom"):
            _plot_celltype_panel(adata[in_batch], kind, ax, batch, celltype, palette)
            if kind == "zoom":
                _set_zoom_window(ax, batches[batch], width)
            elif kind == "mde":
                ax.set_xlim(-3, 3)
                ax.set_ylim(-3, 3)
            else:
                _pad_to_square(ax)
        elif kind == "P14":
            # copy: the helper writes a new obs column
            _plot_p14_panel(adata[in_batch].copy(), ax, batch)
            _set_zoom_window(ax, batches[batch], width)
        else:  # "leiden"
            _plot_leiden_panel(adata[in_batch].copy(), ax)
            _set_zoom_window(ax, batches[batch], width)

        ax.set_aspect("equal", "box")
        ax.set_xlabel("")
        ax.set_ylabel("")
        # Only the first row keeps its title (the batch name)
        if i >= n_cols:
            ax.set_title("")

    handles, labels = _collect_legend_entries(flat_axes)
    for ax in flat_axes:
        if ax.get_legend():
            ax.get_legend().remove()
    # Skip the figure legend entirely when no panel produced labelled handles
    if handles:
        fig.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5))
    fig.tight_layout()
    plt.show()

In [None]:
# Load the AnnData object (relative to the notebook's working directory).
adata = sc.read_h5ad(filename="../data/adata/tgfb.h5ad")

In [None]:
# Batch name -> lower-left corner (x, y) of the zoom window for that batch.
batches = dict(
    WT={"x": 1800, "y": 2200},
    KO={"x": 9850, "y": 9680},
)

In [None]:
# Colour every panel by the "Type" column using the fixed cell-type palette.
plot_celltype_with_zoom(adata, batches, celltype="Type", palette=colors)