In [None]:
import scanpy as sc
import scvelo as scv
import pandas as pd
import numpy as np
import matplotlib.colors as clr
import warnings
from typing import List

# NOTE(review): this suppresses every warning category for the rest of the session
# (including deprecation warnings); narrow the filter if specific warnings matter.
warnings.filterwarnings("ignore")

In [None]:
# Colour stops of the "Zissou" palette, ordered from blue through yellow to red.
zissou = [
    "#3A9AB2",
    "#6FB2C1",
    "#91BAB6",
    "#A5C2A3",
    "#BDC881",
    "#DCCB4E",
    "#E3B710",
    "#E79805",
    "#EC7A05",
    "#EF5703",
    "#F11B00",
]

# Continuous colormap interpolated through the palette stops (blue -> red).
colormap = clr.LinearSegmentedColormap.from_list("Zissou", zissou)
# Reversed variant (red -> blue). It carries matplotlib's "_r" suffix so the two
# colormaps can be told apart by their `.name` instead of both being "Zissou".
colormap_r = clr.LinearSegmentedColormap.from_list("Zissou_r", zissou[::-1])

In [None]:
# Load the time-course dataset from an .h5ad file; the path is relative to the
# notebook's working directory.
adata = sc.read_h5ad("../data/adata/timecourse.h5ad")

In [None]:
# Batch identifiers grouped by time point (two identifiers per time point).
batches = {
    "day 6": ["day6_SI", "day6_SI_r2"],
    "day 8": ["day8_SI_Ctrl", "day8_SI_r2"],
    "day 30": ["day30_SI", "day30_SI_r2"],
    "day 90": ["day90_SI", "day90_SI_r2"],
}
# Flat list of every batch identifier, in the insertion order of `batches`.
all_batches = [batch_id for group in batches.values() for batch_id in group]

In [None]:
# Normalize each cell to a total of 1e4 counts, then apply a log1p transform.
# NOTE(review): both calls discard their return value, so they appear to modify
# `adata` in place — re-running this cell without reloading `adata` would apply
# the transformation a second time. Confirm and reload before re-executing.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

In [None]:
# Gene list taken from the KEGG pathway page https://www.genome.jp/pathway/mmu04060
# Read the table and keep only the rows annotated as ligands.
df_cytokines = pd.read_csv("../data/kegg_cytokines.csv").loc[
    lambda table: table["type"] == "ligand"
]
df_cytokines.head()

In [None]:
# Ligand genes that are actually present among the dataset's variables.
cytokine_genes = df_cytokines.loc[
    df_cytokines["gene"].isin(adata.var_names), "gene"
].to_list()

In [None]:
def filter_adata_expressed_in_n_cells(adata, percent=0.05):
    """
    Keep only the genes that are expressed in more than a given fraction of cells.

    A gene counts as expressed in a cell when its value in `adata.X` is > 0.

    Parameters:
    - adata (anndata): The anndata object containing the gene expression
    - percent (float): Minimum fraction of cells (0-1, e.g. 0.05 for 5%) in which a
      gene must be expressed to be kept; the comparison is strict (`>`)

    Return:
    - adata (anndata): `adata` subset to the retained genes, i.e. the result of
      `adata[:, keep]`; the object passed in is not modified by this function
    """
    expressed = adata.X > 0
    # For sparse / np.matrix inputs the mean over axis 0 comes back as a 2-D
    # (1, n_genes) matrix, which is not a valid 1-D column mask. Flatten it so the
    # same code path works for dense and sparse `X`.
    fraction_expressed = np.asarray(np.mean(expressed, axis=0)).ravel()
    keep = fraction_expressed > percent
    return adata[:, keep]


def scvelo_heatmap(
    adata: sc.AnnData,
    batches: List[str],
    sortby: str,
    highlight: List[str],
    n_bins: int = 5,
):
    """
    Create a heatmap to visualize gene expression trends in single-cell RNA-seq data,
    with options for subsetting, sorting, and highlighting genes.

    Parameters:
    - adata (sc.AnnData): Annotated data object containing single-cell RNA-seq data.
      Must have a "batch" column in `adata.obs`. All variables of `adata` are plotted.
    - batches (List[str]): Values of `adata.obs["batch"]` to keep.
    - sortby (str): Variable to sort the heatmap by (e.g., "crypt_villi_axis"); it is
      passed to `scv.pl.heatmap` and used in the plot title.
    - highlight (List[str]): Y tick labels (gene names) to keep visible; every other
      label is hidden.
    - n_bins (int, optional): Number of bins used to derive the smoothing window,
      `n_convolve = n_cells // n_bins` (default: 5). Must be >= 1.

    Returns:
    - s (seaborn.matrix.ClusterGrid): The object returned by `scv.pl.heatmap` with
      `show=False`; the heatmap axes are available as `s.ax_heatmap`.

    Raises:
    - ValueError: If `n_bins` is smaller than 1.

    This function subsets the input data based on specified batches, and creates a heatmap
    to visualize gene expression trends along a specified variable. The function also allows
    highlighting specific labels on the y-axis. The input `adata` is not modified; the
    subset is copied before plotting.

    Example:
    ```
    scvelo_heatmap(adata, batches=batches["day 6"] + batches["day 8"],
               sortby="crypt_villi_axis",
               highlight=highlight,
               n_bins=20)
    ```
    """
    if n_bins < 1:
        raise ValueError(f"`n_bins` must be a positive integer, got {n_bins}")

    print("Creating Heatmap for batches", " + ".join(batches))
    # Subset batches
    adata = adata[adata.obs["batch"].isin(batches)]
    # Optional step, currently disabled: keep only genes expressed in >5% of the cells
    # adata = filter_adata_expressed_in_n_cells(adata)
    adata = adata.copy()

    # Width of the smoothing window handed to scvelo: one bin's worth of cells
    n_convolve = len(adata) // n_bins
    print(f"Setting `n_convolve` to {n_convolve} ({n_bins} bins, {len(adata)} cells) ")
    # Plot
    s = scv.pl.heatmap(
        adata,
        var_names=adata.var_names,
        sortby=sortby,
        n_convolve=n_convolve,
        show=False,
        yticklabels=True,
        rasterized=True,
        color_map=colormap,
    )
    ax = s.ax_heatmap

    # Loop through the y-axis tick labels and show/hide them based on the 'highlight' list.
    # The tick-line list holds two entries per tick (indices 2*i and 2*i + 1); fetch it
    # once instead of rebuilding it on every iteration.
    highlighted = set(highlight)
    tick_labels = ax.get_yticklabels()
    tick_lines = ax.get_yticklines()
    for i, label in enumerate(tick_labels):
        if label.get_text() not in highlighted:
            label.set_visible(False)
            tick_lines[2 * i + 1].set_visible(False)
        # The first line of each pair is hidden for every tick
        tick_lines[2 * i].set_visible(False)

    ax.set_xlabel("")
    ax.set_title(f"Gene Expression Trends Along {sortby}")
    return s

In [None]:
# Heatmap of the cytokine ligand genes across all batches. Passing the same gene
# list as `highlight` keeps every y tick label visible.
s = scvelo_heatmap(
    adata[:, cytokine_genes],
    batches=all_batches,
    sortby="crypt_villi_axis",
    highlight=cytokine_genes,
    n_bins=20,
)