In [None]:
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import AxesGrid
import numpy as np
import scipy
import seaborn as sns
from typing import List
import os
import scvelo as scv

In [None]:
import matplotlib.colors as clr

# Zissou palette (blue -> yellow -> red), slightly modified to make it less red.
zissou = [
    "#3A9AB2",
    "#6FB2C1",
    "#91BAB6",
    "#A5C2A3",
    "#BDC881",
    "#DCCB4E",
    "#E3B710",
    "#EC7A05",
    "#F11B00",
]


# Continuous colormaps interpolated between the palette colours.
# The reversed map gets its own name ("Zissou_r", following matplotlib's "_r"
# suffix convention) so that two different colormaps do not share one name.
colormap = clr.LinearSegmentedColormap.from_list("Zissou", zissou)
colormap_r = clr.LinearSegmentedColormap.from_list("Zissou_r", zissou[::-1])

In [None]:
adata = sc.read("../data/adata/human.h5ad")

# Batches analysed in this notebook, keyed by their `adata.obs["batch"]` label.
# NOTE(review): the "x"/"y" values are not used in the cells shown here —
# presumably per-batch spatial offsets; confirm against their consumer.
batches = {
    "human_05_r1": {"x": 2400, "y": 2400},
    "human_05_r2": {"x": 6400, "y": 2400},
    "human_09_r1": {"x": 2400, "y": 6400},
    "human_09_r2": {"x": 6400, "y": 6400},
}

# Subset anndata: keep cells that belong to one of the selected batches AND
# whose `peyers` flag equals 0. Both conditions are applied in a single mask and
# the result is materialised with `.copy()`, so later in-place preprocessing
# operates on a real AnnData object instead of a view of a view.
adata = adata[
    adata.obs["batch"].isin(list(batches)) & (adata.obs["peyers"] == 0)
].copy()
adata

In [None]:
# Normalize
# Total-count normalisation with a target sum of 1e4 per cell, followed by a
# log(1 + x) transform.
# NOTE(review): neither call's return value is used, so both presumably modify
# `adata` in place — re-running this cell without reloading the data would then
# apply the transform a second time. Confirm before re-executing out of order.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata

In [None]:
def filter_adata_expressed_in_n_cells(adata, percent=0.05):
    """
    Keep only the genes (columns) detected in a large enough share of cells.

    A gene counts as detected in a cell when its value in `adata.X` is greater
    than zero. The per-gene detection rate is the mean of that indicator over
    the cells (axis 0).

    Parameters:
    - adata: Object exposing an expression matrix `X` (cells x genes) and
      supporting `adata[:, mask]` column selection.
    - percent: Threshold on the detection rate, given as a FRACTION in [0, 1]
      (0.05 means 5 %). A gene is kept only if its rate is strictly greater.

    Returns:
    - The column-subset of `adata` containing the retained genes.
    """
    detection_rate = np.mean(adata.X > 0, axis=0)
    return adata[:, detection_rate > percent]

In [None]:
def scvelo_heatmap(
    adata: sc.AnnData,
    batches: List[str],
    key_name: str,
    key_value: object,
    sortby: str,
    highlight: List[str],
    n_bins: int = 5,
):
    """
    Create a heatmap to visualize gene expression trends in single-cell RNA-seq data,
    with options for subsetting, sorting, and highlighting genes.

    The input is subset to the given batches and to the cells where
    `adata.obs[key_name] == key_value`. Genes are then filtered with
    `filter_adata_expressed_in_n_cells` (default threshold: detected in more than 5 %
    of the remaining cells) and the result is passed to `scv.pl.heatmap`, ordered by
    `sortby` and drawn with the module-level `colormap`. Finally every y-axis (gene)
    label that is not listed in `highlight` is hidden.

    Parameters:
    - adata (sc.AnnData): Annotated data object containing single-cell RNA-seq data.
      It is not modified; the function works on a subset copy.
    - batches (List[str]): Batch identifiers (values of `adata.obs["batch"]`) to keep.
    - key_name (str): Column in `adata.obs` used for subsetting cells.
    - key_value: Value of `key_name` to subset to. Compared with `==`, so it may be
      a string label or e.g. an integer flag such as `1`.
    - sortby (str): Variable to sort the heatmap by (e.g., "crypt_villi_axis").
    - highlight (List[str]): Gene labels that stay visible on the y-axis.
    - n_bins (int, optional): Number of bins used to derive the smoothing window:
      `n_convolve = n_cells // n_bins`, but never less than 1 (default: 5).

    Returns:
    - s (seaborn.matrix.ClusterGrid): The object returned by `scv.pl.heatmap` with
      `show=False`; the heatmap axes are available as `s.ax_heatmap`.

    Raises:
    - ValueError: If `n_bins` is smaller than 1, if no cells match the requested
      batches / key value, or if no gene passes the expression filter.

    Example:
    ```
    scvelo_heatmap(adata, batches=list(batches.keys())[0:2],
               key_name="Subtype",
               key_value="Cd8_T-Cell_P14",
               sortby="crypt_villi_axis",
               highlight=highlight,
               n_bins=20)
    ```
    """
    # Fail early with a clear message instead of a ZeroDivisionError further down.
    if n_bins < 1:
        raise ValueError(f"`n_bins` must be a positive integer, got {n_bins}")

    print("Creating Heatmap for batches", " + ".join(batches))
    print(f"Subset to '{key_name}'=='{key_value}'")
    # Subset batches
    adata = adata[adata.obs["batch"].isin(batches)]
    # Subset to key
    adata = adata[adata.obs[key_name] == key_value]
    if adata.n_obs == 0:
        raise ValueError(
            f"No cells left after subsetting to batches {list(batches)} "
            f"and '{key_name}'=='{key_value}'"
        )
    # Filter to include only genes that are expressed in more than 5% of the cells
    # (the helper's default threshold)
    adata = filter_adata_expressed_in_n_cells(adata)
    if adata.n_vars == 0:
        raise ValueError("No genes left after filtering for expression in >5% of cells")
    # Materialise the view so the plotting code works on an independent object.
    adata = adata.copy()

    # Report highlighted genes that cannot be labelled because they are not
    # among the plotted genes (absent from the data or removed by the filter).
    plotted_genes = set(adata.var_names)
    not_plotted = [gene for gene in highlight if gene not in plotted_genes]
    if not_plotted:
        print("Highlighted genes not in heatmap:", ", ".join(not_plotted))

    # Smoothing window in cells. With fewer cells than bins the integer division
    # yields 0, which would be an empty window; clamp to 1 (no smoothing).
    n_cells = adata.n_obs
    n_convolve = max(1, n_cells // n_bins)
    print(f"Setting `n_convolve` to {n_convolve} ({n_bins} bins, {n_cells} cells) ")
    # Plot
    s = scv.pl.heatmap(
        adata,
        var_names=adata.var_names,
        sortby=sortby,
        n_convolve=n_convolve,
        show=False,
        yticklabels=True,
        rasterized=True,
        color_map=colormap,
    )
    ax = s.ax_heatmap

    # Loop through the y-axis tick labels and show/hide based on the 'highlight' list.
    # The tick-line list holds two entries per tick: index 2*i (first/left mark) is
    # always hidden, index 2*i + 1 (second/right mark) only stays visible for
    # highlighted genes. The list is fetched once instead of on every iteration.
    highlighted = set(highlight)
    tick_lines = ax.get_yticklines()
    for i, label in enumerate(ax.get_yticklabels()):
        tick_lines[2 * i].set_visible(False)
        if label.get_text() not in highlighted:
            label.set_visible(False)
            tick_lines[2 * i + 1].set_visible(False)

    ax.set_xlabel("")
    ax.set_title("Human Gene Expression Trends")
    return s

In [None]:
# Genes passed as `highlight=` to the heatmap calls in this notebook.
highlight = [
    "ITGAE",
    "GZMA",
    "MKI67",
    "KLRG1",
    "KLF2",
    "SLAMF6",
    "TCF7",
]

# 5e: expression trends sorted by `crypt_villi_axis`

In [None]:
# Heatmap of the cells with `CD8_column == 1` from every batch in `batches`,
# ordered by `crypt_villi_axis`; the smoothing window is derived from 10 bins.
s = scvelo_heatmap(
    adata,
    batches=list(batches),
    sortby="crypt_villi_axis",
    key_name="CD8_column",
    key_value=1,
    n_bins=10,
    highlight=highlight,
)

# 5f: expression trends sorted by `epithelial_distance_clipped`

In [None]:
# Heatmap of the cells with `CD8_column == 1` from every batch in `batches`,
# ordered by `epithelial_distance_clipped`; the smoothing window is derived
# from 10 bins.
s = scvelo_heatmap(
    adata,
    batches=list(batches),
    sortby="epithelial_distance_clipped",
    key_name="CD8_column",
    key_value=1,
    n_bins=10,
    highlight=highlight,
)