In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.gridspec import GridSpecFromSubplotSpec
from mpl_toolkits.axes_grid1 import AxesGrid
from matplotlib.patches import Rectangle
import warnings

warnings.filterwarnings("ignore")

In [None]:
# Load the time-course AnnData object. The path is relative to the kernel's
# current working directory.
ADATA_PATH = "../data/adata/timecourse.h5ad"
adata = sc.read_h5ad(ADATA_PATH)

In [None]:
import matplotlib.colors as clr

# "Zissou" palette, listed from low to high values (blue -> yellow -> red).
zissou = [
    "#3A9AB2",
    "#6FB2C1",
    "#91BAB6",
    "#A5C2A3",
    "#BDC881",
    "#DCCB4E",
    "#E3B710",
    "#E79805",
    "#EC7A05",
    "#EF5703",
    "#F11B00",
]

# Continuous colormap interpolating linearly between the palette colours.
colormap = clr.LinearSegmentedColormap.from_list("Zissou", zissou)
# Reversed variant (red -> yellow -> blue). It gets matplotlib's conventional
# "_r" suffix so the two colormaps are distinguishable by name (e.g. in reprs
# or if registered with matplotlib) instead of both being called "Zissou".
colormap_r = clr.LinearSegmentedColormap.from_list("Zissou_r", zissou[::-1])

In [None]:
# Batches to plot: one row per batch, in this order. "x"/"y" are the
# lower-left corner of the square zoom window (in spatial-embedding
# coordinates) that `plot_gradient` crops to and outlines.
# The commented-out "_r2" entries are alternative batches (presumably a
# second replicate) that can be swapped in.
batches = dict(
    day6_SI=dict(x=6200, y=6200),
    # day6_SI_r2=dict(x=5800, y=5500),
    day8_SI_Ctrl=dict(x=2400, y=2400),
    # day8_SI_r2=dict(x=3200, y=1500),
    day30_SI=dict(x=6400, y=2400),
    # day30_SI_r2=dict(x=6200, y=6200),
    day90_SI=dict(x=2400, y=2400),
    # day90_SI_r2=dict(x=1200, y=6200),
)

In [None]:
def plot_gradient(
    adata,
    gs,
    batches,
    gradient,
    labels,
    basis=("mde", "spatial"),
    cmap=None,
    dpi=600,
    width=500,
    fig=None,
):
    """Plot one ``adata.obs`` gradient as a batch-by-basis grid of embeddings.

    Inside the subplot slot ``gs`` this draws one row per batch and one column
    per entry of ``basis``. Below the grid it adds a shared horizontal
    colorbar whose two end ticks carry the text ``labels``.

    Parameters
    ----------
    adata : AnnData
        Annotated data. It is subset on ``adata.obs["batch"]`` and coloured by
        ``gradient`` via ``sc.pl.embedding``.
    gs : matplotlib.gridspec.SubplotSpec
        Slot of an outer GridSpec in which the panel grid is drawn.
    batches : dict
        Maps batch name -> ``{"x": ..., "y": ...}``. Rows follow dict order.
        ``x``/``y`` are the lower-left corner of the square zoom window and are
        only read when ``"zoom"`` is in ``basis``.
    gradient : str
        ``adata.obs`` key used for colouring; also used as the panel title.
    labels : sequence of str
        Exactly two labels, for the low and the high end of the colorbar.
    basis : sequence of str
        Columns to draw. ``"zoom"`` draws the ``"spatial"`` basis cropped to
        the zoom window. Any other value is passed to ``sc.pl.embedding``
        unchanged (``"mde"`` additionally gets fixed ``[-3, 3]`` limits).
    cmap : Colormap or str, optional
        Colormap forwarded to ``sc.pl.embedding``.
    dpi : int
        Forwarded as ``dpi_save`` to ``sc.set_figure_params``.
    width : float
        Side length of the square zoom window, in spatial-basis coordinates.
    fig : matplotlib.figure.Figure, optional
        Figure owning ``gs``. Defaults to the figure the GridSpec was created
        with, falling back to ``plt.gcf()``.

    Raises
    ------
    ValueError
        If ``labels`` does not have exactly two entries or ``batches`` is empty.
    """
    labels = list(labels)
    if len(labels) != 2:
        raise ValueError(
            f"labels must have exactly two entries (low, high), got {len(labels)}"
        )
    if not batches:
        raise ValueError("batches must contain at least one batch")

    if fig is None:
        # Resolve the owning figure from the GridSpec rather than relying on a
        # notebook-global `fig`, so the function works from a fresh namespace.
        fig = getattr(gs.get_gridspec(), "figure", None) or plt.gcf()

    sc.set_figure_params(vector_friendly=True, dpi_save=dpi)

    n_rows = len(batches)
    gs_inner = GridSpecFromSubplotSpec(
        n_rows + 1,
        len(basis),
        subplot_spec=gs,
        # One tall row per batch plus a thin final row for the colorbar.
        height_ratios=[8] * n_rows + [1],
        wspace=0.02,
        hspace=0.02,
    )

    # Invisible axes spanning the whole slot; it only carries the title.
    title_ax = fig.add_subplot(gs)
    title_ax.set_title(gradient)
    title_ax.axis("off")

    show_fov = "zoom" in basis
    mappable = None
    for row, (batch, window) in enumerate(batches.items()):
        # The subset is the same for every basis column, so compute it once.
        adata_batch = adata[adata.obs["batch"] == batch]
        for col, b in enumerate(basis):
            is_zoom = b == "zoom"
            ax = fig.add_subplot(gs_inner[row, col])
            sc.pl.embedding(
                adata_batch,
                basis="spatial" if is_zoom else b,
                ax=ax,
                show=False,
                title=batch,
                color=gradient,
                colorbar_loc=None,
                frameon=True,
                cmap=cmap,
                # sort_order=False,
                # Larger markers for the cropped zoom view.
                size=40 if is_zoom else None,
            )
            ax.axis("equal")
            ax.set_xlabel("")
            ax.set_ylabel("")
            ax.set_title("")
            if is_zoom:
                ax.set_xlim(window["x"], window["x"] + width)
                ax.set_ylim(window["y"], window["y"] + width)
            elif b == "mde":
                ax.set_xlim(-3, 3)
                ax.set_ylim(-3, 3)
            if show_fov and b == "spatial":
                # Outline the zoom window on the full spatial view.
                ax.add_patch(
                    Rectangle(
                        (window["x"], window["y"]),
                        width,
                        width,
                        edgecolor="black",
                        facecolor="none",
                    )
                )
            # Remember the scatter of the most recently drawn panel for the colorbar.
            mappable = ax.collections[0]

    cb_ax = fig.add_subplot(gs_inner[n_rows, :])
    cbar = fig.colorbar(mappable, cax=cb_ax, orientation="horizontal")
    # NOTE(review): sc.pl.embedding is called per subset without vmin/vmax, so
    # panels are presumably normalised independently and the bar reflects the
    # last panel only; confirm if absolute values must be comparable.
    # Replace the numeric ticks by the two text labels at the ends of the bar.
    lo, hi = cbar.ax.get_xbound()
    cbar.ax.set_xticks([lo, hi])
    cbar.ax.set_xticklabels(labels, ha="left")
    # Right-align the last label so it ends at the bar's right edge.
    cbar.ax.get_xticklabels()[-1].set_ha("right")

In [None]:
# Three gradient panels side by side. The middle panel uses the default basis
# (two columns, no zoom) and therefore gets the narrower width ratio.
fig = plt.figure(figsize=(24, 12))
gs = GridSpec(1, 3, figure=fig, width_ratios=[3, 2, 3], wspace=0.05)

# Per-panel settings, in left-to-right order; everything else is shared.
panels = [
    dict(
        gradient="crypt_villi_axis",
        basis=["mde", "spatial", "zoom"],
        labels=["top", "bottom"],
    ),
    dict(
        gradient="predicted_longitudinal",
        labels=["proximal\n", "distal"],
    ),
    dict(
        gradient="epithelial_distance_clipped",
        basis=["mde", "spatial", "zoom"],
        labels=["epithel", "lamina\npropria"],
    ),
]

for slot, panel in enumerate(panels):
    plot_gradient(adata, gs[slot], batches, cmap=colormap, **panel)

# fig.savefig("axis.pdf")
plt.show()