In [None]:
import os

# important for gpd.sjoin
os.environ["USE_PYGEOS"] = "0"

import scanpy as sc
import scvelo as scv
import geopandas as gpd

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.gridspec import GridSpecFromSubplotSpec
from mpl_toolkits.axes_grid1 import AxesGrid
from matplotlib.patches import Rectangle

# import mpl_scatter_density # adds projection='scatter_density'
import numpy as np
import warnings
import seaborn as sns
import igraph
import random

# Silence every warning for the rest of the session to keep cell output clean.
# NOTE(review): this is a blanket filter, so deprecation and data-quality
# warnings raised by the libraries used below are hidden as well.
warnings.filterwarnings("ignore")

##### Making the Visium figures - gene expression trends and spatial axes

In [None]:
import matplotlib.colors as clr

# Eleven-step "Zissou" palette, ordered from blue through yellow to red.
zissou = [
    "#3A9AB2",
    "#6FB2C1",
    "#91BAB6",
    "#A5C2A3",
    "#BDC881",
    "#DCCB4E",
    "#E3B710",
    "#E79805",
    "#EC7A05",
    "#EF5703",
    "#F11B00",
]

# Continuous colormap interpolated across the palette in the listed order.
colormap = clr.LinearSegmentedColormap.from_list("Zissou", zissou)
# Same palette in reverse order. It gets its own name (matplotlib's "_r"
# suffix convention) so the two colormaps can be told apart by `.name`.
colormap_r = clr.LinearSegmentedColormap.from_list("Zissou_r", zissou[::-1])

In [None]:
# Load the two Visium sections and stack them into a single AnnData object.
# NOTE(review): no `label=`/`keys=` is passed to `sc.concat`, so the
# `obs["batch"]` column used further down presumably already exists in both
# files -- confirm.
visium_paths = [
    "visium_with_axis_distal.h5ad",
    "visium_with_axis_proximal.h5ad",
]
adata = sc.concat([sc.read_h5ad(path) for path in visium_paths])

In [None]:
# Keep only genes that pass scanpy's `min_cells=10` threshold.
# The return value is discarded, so this relies on `adata` being filtered in
# place; re-running the cell on the already filtered object is harmless.
sc.pp.filter_genes(adata, min_cells=10)

In [None]:
# Scale every observation to a total of 1e4 counts, then apply log(1 + x).
# NOTE(review): both calls discard their return value and therefore appear to
# modify `adata` in place -- running this cell twice would transform the data
# twice. Re-load `adata` before re-running.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

In [None]:
# Gene list taken from https://www.genome.jp/pathway/mmu04060;
# only the rows annotated as ligands are kept.
df_cytokines = pd.read_csv("kegg_cytokines.csv").loc[
    lambda table: table["type"] == "ligand"
]
df_cytokines

In [None]:
# Restrict the ligand list to genes that are present in the Visium data,
# preserving the order of the table.
ligand_genes = df_cytokines["gene"]
present_in_visium = ligand_genes.isin(adata.var_names)
cytokine_genes = ligand_genes[present_in_visium].to_list()

In [None]:
from typing import List


def filter_adata_expressed_in_n_cells(adata, percent=0.05):
    """
    Keep only the genes that are detected in more than a given fraction of cells.

    A gene counts as detected in a cell when its value in `adata.X` is > 0.
    Despite the function name, the threshold is a fraction, not a cell count.

    Parameters:
    - adata: Object exposing a 2-D `X` (dense array or scipy sparse matrix,
      rows = cells, columns = genes) and supporting `adata[:, mask]` indexing.
    - percent (float, optional): Fraction of cells (0-1) that a gene must
      strictly exceed to be kept (default: 0.05).

    Returns:
    - The result of `adata[:, keep]`, i.e. `adata` restricted to the kept genes.
    """
    # `.mean(axis=0)` on a scipy sparse matrix returns a 2-D `np.matrix`;
    # flatten to a plain 1-D array so the boolean mask has one entry per gene
    # regardless of whether `X` is dense or sparse.
    detected_fraction = np.asarray((adata.X > 0).mean(axis=0)).ravel()
    keep = detected_fraction > percent
    return adata[:, keep]


def scvelo_heatmap(
    adata: sc.AnnData,
    batches: List[str],
    sortby: str,
    highlight: List[str],
    n_bins: int = 5,
    xenium_genes=None,
    save_path: str = "Visium_Cytokines.pdf",
):
    """
    Draw a smoothed gene-expression heatmap along an `.obs` variable with
    `scv.pl.heatmap`, label only selected genes, and save the figure.

    Parameters:
    - adata (sc.AnnData): Annotated data; must contain `obs["batch"]` and the
      `sortby` column. All genes in `adata.var_names` are plotted.
    - batches (List[str]): Values of `obs["batch"]` to keep.
    - sortby (str): Name of the `.obs` column used to order the observations
      (e.g., "crypt_villi_axis"); also shown in the title.
    - highlight (List[str]): Gene names whose y-axis labels stay visible; all
      other labels and their tick marks are hidden.
    - n_bins (int, optional): Controls the smoothing window, which is set to
      `len(adata) // n_bins` observations (default: 5).
    - xenium_genes (iterable of str, optional): Genes whose labels are drawn
      in red. When None, the notebook-level variable `in_xenium` is used.
    - save_path (str, optional): File the figure is written to
      (default: "Visium_Cytokines.pdf").

    Returns:
    - s: The object returned by `scv.pl.heatmap(..., show=False)`; it exposes
      `ax_heatmap` and `savefig` (a seaborn ClusterGrid).

    Raises:
    - ValueError: If `n_bins` is smaller than 1 or no observation belongs to
      the requested batches.

    Example:
    ```
    scvelo_heatmap(adata, batches=["proximal"],
               sortby="crypt_villi_axis",
               highlight=highlight,
               n_bins=20)
    ```
    """
    if n_bins < 1:
        raise ValueError(f"`n_bins` must be >= 1, got {n_bins}")

    print("Creating Heatmap for batches", " + ".join(batches))
    # Subset to the requested batches; copy so a real object, not a view, is
    # handed to the plotting function.
    # (The 5%-expression gene filter, filter_adata_expressed_in_n_cells, is
    # intentionally not applied here.)
    adata = adata[adata.obs["batch"].isin(batches)].copy()
    if len(adata) == 0:
        raise ValueError(f"No observations found for batches {batches}")

    # Smoothing window in observations; never below 1 so the window is valid
    # even when there are fewer observations than bins.
    n_convolve = max(1, len(adata) // n_bins)
    print(f"Setting `n_convolve` to {n_convolve} ({n_bins} bins, {len(adata)} cells) ")

    # Plot
    s = scv.pl.heatmap(
        adata,
        var_names=adata.var_names,
        sortby=sortby,
        n_convolve=n_convolve,
        show=False,
        yticklabels=True,
        rasterized=True,
        color_map=colormap,
        figsize=(8, 16),
    )
    ax = s.ax_heatmap
    ax.tick_params(axis="both", labelsize=10)  # Adjust font size of tick labels

    if xenium_genes is None:
        # Backward-compatible fallback to the notebook-level gene array.
        xenium_genes = in_xenium
    # Sets make the per-label membership checks below constant time.
    xenium_genes = set(xenium_genes)
    highlighted = set(highlight)

    # Walk the y-axis tick labels (one per gene). Every tick owns two tick
    # lines, stored at positions 2*i and 2*i + 1: the first is always hidden,
    # the second is hidden together with the label for non-highlighted genes.
    tick_lines = ax.get_yticklines()
    for i, label in enumerate(ax.get_yticklabels()):
        gene = label.get_text()
        if gene not in highlighted:
            label.set_visible(False)
            tick_lines[2 * i + 1].set_visible(False)
        tick_lines[2 * i].set_visible(False)
        if gene in xenium_genes:
            label.set_color("red")  # Set color to red

    ax.set_xlabel("")
    ax.set_title(f"Visium Gene Expression Trends Along {sortby}")
    ax.grid(False)
    s.savefig(save_path)
    return s

##### Get the list of genes imaged in the Xenium data

In [None]:
# Array of gene names (the `.var` index) of the Xenium dataset. It is read as
# a global by `scvelo_heatmap`, so this cell must run before that function is
# called.
in_xenium = sc.read("downsampled_mouse.h5ad").var.index.values

##### Plot the gene trends of all cytokines in the proximal gut

In [None]:
# Heatmap of all cytokine ligand genes for the "proximal" batch only.
# Every plotted gene is passed as `highlight`, so no y-axis label is hidden.
# `n_bins=3` makes the smoothing window one third of the selected observations.
s = scvelo_heatmap(
    adata[:, cytokine_genes],
    batches=["proximal"],
    sortby="crypt_villi_axis",
    highlight=cytokine_genes,
    n_bins=3,
)

##### Plot the crypt-villus axis on the Visium data

In [None]:
# Render one spatial plot per batch, coloured by the crypt-villus axis value,
# and write each figure to "<batch>_visium_cvaxis.pdf".
sc.set_figure_params(dpi=300)
batch_labels = adata.obs["batch"]
for batch in np.unique(batch_labels):
    batch_adata = adata[batch_labels == batch]
    fig = sc.pl.embedding(
        batch_adata,
        basis="spatial",
        color="crypt_villi_axis",
        title=f"{batch} Visium Crypt Villi Axis",
        return_fig=True,
    )
    # Same scale on x and y so the tissue is not distorted.
    ax = fig.gca()
    ax.axis("equal")
    output_name = f"{batch}_visium_cvaxis.pdf"
    fig.savefig(output_name)