In [None]:
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import warnings
import seaborn as sns

warnings.filterwarnings("ignore")

In [None]:
import matplotlib.colors as clr

# Zissou palette, ordered blue -> yellow -> red. `colormap` is the continuous
# colormap used for the expression/score panels of the IMAP plots;
# `colormap_r` runs through the same colours in reverse order.
# NOTE(review): both colormaps are created with the name "Zissou".
zissou = [
    "#3A9AB2", "#6FB2C1", "#91BAB6", "#A5C2A3",
    "#BDC881", "#DCCB4E", "#E3B710", "#E79805",
    "#EC7A05", "#EF5703", "#F11B00",
]

colormap, colormap_r = (
    clr.LinearSegmentedColormap.from_list("Zissou", colours)
    for colours in (zissou, zissou[::-1])
)

In [None]:
# Load the AnnData object analysed throughout this notebook.
ADATA_PATH = "../data/adata/tgfb.h5ad"
adata = sc.read_h5ad(ADATA_PATH)

In [None]:
# Normalise every cell to 10,000 total counts, then log1p-transform.
# The matrix as loaded is stashed in a layer and restored on re-runs, so running
# this cell twice does not normalise / log-transform the data a second time
# (scanpy's "already log-transformed" warning is silenced by the
# warnings filter in the first cell, so the double transform would go unnoticed).
if "X_as_loaded" in adata.layers:
    adata.X = adata.layers["X_as_loaded"].copy()
else:
    adata.layers["X_as_loaded"] = adata.X.copy()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

In [None]:
# QC: number of genes with non-zero expression per Cd8_T-Cell_P14 cell.
is_p14 = (adata.obs["Subtype"] == "Cd8_T-Cell_P14").to_numpy()
# `adata.X > 0` may be a scipy.sparse matrix, whose row sums come back as an
# (n, 1) np.matrix; flatten to a 1-D vector so seaborn gets plain 1-D data.
n_expressed_genes = np.asarray((adata.X > 0).sum(axis=1)).ravel()
ax = sns.histplot(n_expressed_genes[is_p14])
ax.set_xlabel("Number of expressed genes");

In [None]:
# Gene signatures to score with UCell. Each file is header-less; its first
# column holds the gene names of the signature.
SIGNATURE_FILES = {
    "TRM": "../data/signatures/Core Trm signature_Milner et al Nature 2017_vIL.txt",
    "TGFb": "../data/signatures/TGFbeta.txt",
}
signatures = {
    name: pd.read_csv(path, header=None)[0].to_list()
    for name, path in SIGNATURE_FILES.items()
}
import ucell

ucell.add_scores(adata, signatures, maxRank=100, seed=42)

In [None]:
def scatter_with_gaussian_kde_weights(ax, x, y, weights, exp=1, **kwargs):
    """
    Scatter ``x`` vs ``y`` coloured by a weighted 2-D Gaussian KDE.

    The KDE is fitted on the points themselves with per-point weights
    ``weights ** exp`` and evaluated at every point, so the colour of each
    point is the weighted local density around it.

    Parameters:
    - ax: matplotlib Axes (anything with a ``scatter`` method).
    - x, y (array-like): Point coordinates, same length.
    - weights (array-like): Non-negative per-point weights, same length as x.
    - exp (float): Exponent applied to ``weights`` before fitting; default 1.
    - **kwargs: Forwarded to ``ax.scatter`` (e.g. s, cmap, vmin, vmax).

    Returns:
    - Whatever ``ax.scatter`` returns (a PathCollection for matplotlib Axes).

    Raises:
    - ValueError: If ``weights ** exp`` is not finite and non-negative with a
      positive sum (e.g. a value that is zero in every cell); the weighted
      KDE is undefined in that case.
    """
    from scipy.stats import gaussian_kde

    kde_weights = np.asarray(weights, dtype=float) ** exp
    # gaussian_kde normalises the weights by their sum; reject inputs where
    # that is meaningless instead of failing later with an obscure error.
    if (
        not np.all(np.isfinite(kde_weights))
        or np.any(kde_weights < 0)
        or kde_weights.sum() <= 0
    ):
        raise ValueError(
            "weights ** exp must be finite, non-negative and not all zero "
            "to compute a weighted KDE"
        )

    xy = np.vstack([x, y])
    z = gaussian_kde(xy, weights=kde_weights)(xy)

    return ax.scatter(x, y, c=z, **kwargs)


def transformation(x, a=0.1, b=0.1, c=0.5, d=2.5, f=4, w=1):
    """
    Custom biexponential transformation of the epithelial-distance axis.

    Computes ``a * exp(b * (x - w)) - c * exp(-d * (x - w)) + f`` element-wise
    and returns a NumPy array (``x`` is converted with ``np.array`` first).
    Maybe not needed for IF data.
    """
    shifted = np.array(x) - w
    growth = a * np.exp(b * shifted)
    decay = c * np.exp(-d * shifted)
    return growth - decay + f


def classify_cells(adata, gates, transformation=transformation):
    """
    Assign every cell to the IMAP gate polygon it falls into.

    Parameters:
    - adata (anndata): Cells to classify. ``adata.obs`` must contain the
      "epithelial_distance" and "crypt_villi_axis" columns.
    - gates (dict): Gate name -> {"edges": [[x, y], ...], ...}. The x value of
      every edge is passed through ``transformation``; y is used as-is.
    - transformation (function): Transform applied to the x values
      (epithelial distance) of both the cells and the gate edges.

    Side effects:
    - Writes ``adata.obs["epithelial_distance_transformed"]`` and sets
      ``adata.obs["gate"]`` to False.

    Returns:
    - result (geopandas GeoDataFrame): Left spatial join of the cell points
      against the gate polygons. The gate name is in the "index_right"
      column, NaN for cells that fall in no gate.
    """
    import geopandas as gpd
    from shapely.geometry.polygon import Polygon

    # Always go through `adata.obs` (no local alias): writing to the obs of an
    # AnnData view turns the view into an actual object, which would leave an
    # alias pointing at stale data.
    adata.obs["epithelial_distance_transformed"] = transformation(
        adata.obs["epithelial_distance"]
    )
    adata.obs["gate"] = False

    print("Creating polygons")
    gate_shapes = gpd.GeoSeries(
        {
            # Apply transformation to x values
            name: Polygon(
                [[transformation(edge[0])] + edge[1:] for edge in spec["edges"]]
            )
            for name, spec in gates.items()
        }
    )
    gate_frame = gpd.GeoDataFrame({"gates": gate_shapes}, geometry="gates")

    print("Creating cells")
    cell_points = gpd.GeoSeries.from_xy(
        adata.obs["epithelial_distance_transformed"], adata.obs["crypt_villi_axis"]
    )
    cell_frame = gpd.GeoDataFrame({"cells": cell_points}, geometry="cells")

    print("Joining cells and polygons")
    return gpd.sjoin(cell_frame, gate_frame, how="left")

In [None]:
# Gate polygons for the IMAP plots, keyed by gate name.
# - "edges": polygon vertices as [epithelial_distance, crypt_villi_axis]. The x
#   value is on the *untransformed* epithelial-distance scale; draw_gates() and
#   classify_cells() pass it through `transformation` before building the
#   polygon. The y value is used as-is.
# - "label_position": where draw_gates() writes the gate name (x is likewise
#   passed through `transformation`).
# - "fill": 8-digit (RGBA) hex colour used by draw_gates(type="fill").
# The gates do not tile the plane. For example, there are gaps at x 0.19-0.20
# and y 0.48-0.50, and nothing covers x > 0.8 above y = 0.28. Cells in such
# regions match no gate in classify_cells() (its left join leaves them NaN).
gates = {
    "Top-IE": {
        "edges": [
            [0.05, 0.5],
            [0.19, 0.5],
            [0.19, 1.03],
            [0.05, 1.03],
        ],
        "label_position": {"x": 0.16, "y": 0.9},
        "fill": "#3A9AB244",
    },
    "Top-LP": {
        "edges": [
            [0.20, 0.5],
            [0.8, 0.57],
            [0.8, 1.03],
            [0.20, 1.03],
        ],
        "label_position": {"x": 1, "y": 0.8},
        "fill": "#3A9AB244",
    },
    "Crypt-IE": {
        "edges": [
            [0.05, 0.48],
            [0.19, 0.48],
            [0.19, 0],
            [0.05, 0],
        ],
        "label_position": {"x": 0.16, "y": 0.05},
        "fill": "#F11B0044",
    },
    "Crypt-LP": {
        "edges": [
            [0.20, 0.48],
            [0.8, 0.55],
            [0.8, 0.3],
            [0.20, 0.3],
        ],
        "label_position": {"x": 1, "y": 0.25},
        "fill": "#F11B0044",
    },
    "Muscularis": {
        "edges": [
            [0.2, 0.28],
            [0.8, 0.28],
            [6, 0.28],
            [6, 0],
            [0.2, 0],
        ],
        "label_position": {"x": 0.6, "y": 0.05},
        "fill": "#BDC88155",
    },
}


def draw_gates(ax, gates, transformation, type="edge"):
    """
    Draw the gate polygons on a matplotlib Axes.

    Parameters:
    - ax (matplotlib Axes): Axes to draw on.
    - gates (dict): Gate name -> {"edges", "label_position", "fill"}. The x value
      of every edge and of the label position is passed through
      ``transformation``; y values are used as-is.
    - transformation (function): Transform applied to x coordinates.
    - type (str): "fill" draws each gate as a polygon filled with its "fill"
      colour and no outline. "edge" draws a dark outline plus the gate name
      at its "label_position".

    Raises:
    - ValueError: If ``type`` is neither "fill" nor "edge".
    """
    # Validate up front (before importing matplotlib) so a typo such as
    # type="edges" fails loudly instead of silently drawing nothing.
    if type not in ("fill", "edge"):
        raise ValueError(f"type must be 'fill' or 'edge', got {type!r}")

    from matplotlib.patches import Polygon

    for gate, spec in gates.items():
        # Apply transformation to x values
        points = [[transformation(edge[0])] + edge[1:] for edge in spec["edges"]]

        if type == "fill":
            ax.add_patch(Polygon(points, facecolor=spec["fill"], edgecolor="none"))
        else:
            ax.add_patch(Polygon(points, facecolor="none", edgecolor="#222222"))

            label = spec["label_position"]
            ax.text(
                transformation(label["x"]),
                label["y"],
                gate,
                fontsize=6,
                color="#222222",
            )

In [None]:
import numpy as np
import anndata as ad


def get_expression(adata: "ad.AnnData", key: str) -> np.ndarray:
    """
    Retrieves expression values for a given gene or observation annotation from an AnnData object.

    Gene names (``adata.var_names``) take precedence over ``adata.obs`` columns
    when a key exists in both.

    Args:
        adata: An AnnData object containing expression data. ``adata.X`` may be
            a dense array, an ``np.matrix`` or a scipy.sparse matrix.
        key: The name of the gene or observation annotation to retrieve.

    Returns:
        A 1-D NumPy array with one value per cell.

    Raises:
        ValueError: If the key is not found in either the var_names or obs columns of the AnnData object.
    """
    if key in adata.var_names:
        column = adata[:, key].X
        # scipy.sparse matrices have no .flatten(); densify them first.
        if hasattr(column, "toarray"):
            column = column.toarray()
        # np.array drops np.matrix, so ravel() really yields a 1-D vector
        # (np.matrix.flatten() would have stayed 2-D with shape (1, n)).
        return np.array(column).ravel()
    if key in adata.obs.columns:
        return np.array(adata.obs[key])
    raise ValueError(f"{key} not found in object")

In [None]:
# IMAP plotting: a grid of density maps, one row per batch and one column per gene.
def plot_expression_imaps(
    adata,
    batches,
    genes,
    ax_ticks=(0.15, 0.3, 0.6, 1, 6),
    transformation=transformation,
    gates=gates,
    dpi=60,
    exp=1,
    normalize_batches=True,
):
    """
    Plot IMAP density maps: one row per batch, one column per gene/score.

    Each panel places the cells of one batch on the
    (transformed epithelial distance, crypt-villi axis) plane. It shows:
    - the gates as filled background polygons;
    - seaborn KDE contour lines weighted by the value;
    - a scatter coloured by a Gaussian KDE weighted by ``value ** exp``;
    - the gate outlines and names on top.

    Parameters:
    - adata (anndata): Cells to plot. ``adata.obs`` needs "batch",
      "epithelial_distance" and "crypt_villi_axis".
    - batches (list): Values of ``adata.obs["batch"]`` to plot, one row each.
    - genes (list): Gene names or ``adata.obs`` columns, one column each. The
      special value "Distribution" weights every cell equally.
    - ax_ticks (sequence): Untransformed x positions for the tick marks; they
      are drawn at ``transformation(ax_ticks)`` and labelled with the raw values.
    - transformation (function): Applied to the epithelial distance and to the
      x values of the ticks and gates.
    - gates (dict): Gate polygons passed to ``draw_gates``.
    - dpi (int): Figure resolution.
    - exp (float): Exponent applied to the weights of the coloured KDE.
    - normalize_batches (bool): If True, all batches of one gene share a single
      colour scale (vmin/vmax over every batch's KDE values).

    Side effects:
    - Writes ``adata.obs["epithelial_distance_transformed"]``.
    """
    from scipy.stats import gaussian_kde

    def _cell_weights(sub_adata, gene):
        # Uniform weights for the overall cell distribution, otherwise the
        # per-cell expression / score of `gene`.
        if gene == "Distribution":
            return np.ones(len(sub_adata))
        return get_expression(sub_adata, gene)

    n_rows, n_cols = len(batches), len(genes)
    fig = plt.figure(figsize=(3 * n_cols, 3 * n_rows), dpi=dpi)

    # Apply transformation
    adata.obs["epithelial_distance_transformed"] = transformation(
        adata.obs["epithelial_distance"]
    )

    # Batch subsets and their KDE coordinates do not depend on the gene, so
    # build them once rather than once per gene and per pass.
    subsets = [adata[adata.obs["batch"] == bt] for bt in batches]
    coords = [
        np.vstack(
            [
                sub.obs["epithelial_distance_transformed"],
                sub.obs["crypt_villi_axis"],
            ]
        )
        for sub in subsets
    ]

    for col, gene in enumerate(genes):
        print("Plotting value: " + str(gene))
        weights = [_cell_weights(sub, gene) for sub in subsets]

        if normalize_batches:
            # Evaluate every batch's KDE first so that all rows of this column
            # share one colour scale.
            densities = np.concatenate(
                [
                    gaussian_kde(xy, weights=w**exp)(xy)
                    for xy, w in zip(coords, weights)
                ]
            )
            vmax = densities.max()
            vmin = densities.min()
        else:
            vmax = None
            vmin = None

        for row, (bt, sub_adata, gene_expression) in enumerate(
            zip(batches, subsets, weights)
        ):
            ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1)

            # Draw gates filled in background
            draw_gates(ax, gates=gates, transformation=transformation, type="fill")

            # Density contour lines.
            # NOTE(review): the contours use the raw weights, while the scatter
            # colours below use weights ** exp; confirm this is intended when
            # exp != 1.
            sns.kdeplot(
                data=sub_adata.obs,
                x="epithelial_distance_transformed",
                y="crypt_villi_axis",
                ax=ax,
                weights=gene_expression,
                color="#444444",
                linewidths=0.5,
            )

            # Colored scatter plot
            scatter_with_gaussian_kde_weights(
                ax=ax,
                x=sub_adata.obs["epithelial_distance_transformed"],
                y=sub_adata.obs["crypt_villi_axis"],
                weights=gene_expression,
                exp=exp,
                s=5,
                cmap="viridis" if gene == "Distribution" else colormap,
                vmax=vmax,
                vmin=vmin,
            )

            # Ticks sit at transformed positions but show the raw values.
            ax.set_xticks(transformation(ax_ticks))
            ax.set_xticklabels(ax_ticks)

            ax.set_xlabel("Epithelial Axis")
            ax.set_ylabel(f"{bt}\nCrypt-Villi Axis")
            ax.set_ylim(-0.02, 1.05)

            # Column title on the first row only.
            ax.set_title(f"{gene}" if row == 0 else "")
            draw_gates(ax, gates=gates, transformation=transformation)

    fig.tight_layout()

# Figure 4b & 4e

In [None]:
# IMAP density maps of Cd8_T-Cell_P14 cells for the WT and KO batches: overall
# cell distribution, then Itgae and Mki67 expression.
plot_expression_imaps(
    adata[adata.obs["Subtype"] == "Cd8_T-Cell_P14"],
    batches=["WT", "KO"],
    genes=["Distribution", "Itgae", "Mki67"],
    dpi=100,
)

In [None]:
# Cd8_T-Cell_P14 cells only. `.copy()` makes this an independent AnnData
# rather than a view. Obs columns are written to it below (by classify_cells
# and the "gate" assignment), which on a view would trigger anndata's implicit
# view-to-copy conversion.
adata_p14 = adata[adata.obs["Subtype"] == "Cd8_T-Cell_P14"].copy()

In [None]:
# Assign each P14 cell to the gate polygon it falls in (NaN when it lies
# outside every gate, since classify_cells uses a left spatial join).
classification = classify_cells(adata_p14, gates)

# A cell matching two gates would appear twice in the join and break the
# index-aligned assignment below with an opaque reindexing error.
assert classification.index.is_unique, "some cells fall into more than one gate"
adata_p14.obs["gate"] = classification["index_right"]
df = adata_p14.obs

In [None]:
# Percentage of gated cells per gate within each batch. Cells outside every gate
# (gate is NaN) are dropped by groupby and so are not part of the denominator.
GATE_ORDER = ["Muscularis", "Crypt-IE", "Crypt-LP", "Top-IE", "Top-LP"]
df_abs = df.groupby(["batch", "gate"]).size().unstack(fill_value=0)
df_rel = df_abs.div(df_abs.sum(axis=1), axis=0) * 100
# Order the columns by name (this is the stacking order of the bar plot) rather
# than by position. Positional selection silently misorders or fails when a gate
# has no cells.
df_rel = df_rel.reindex(columns=GATE_ORDER, fill_value=0)
df_rel

In [None]:
# Stacked bar chart of the gate composition of each batch (one bar per batch).
ax = df_rel.plot(
    kind="bar",
    stacked=True,
    color={
        "Muscularis": "#e9edd5",
        "Crypt-IE": "#fcc6c1",
        "Crypt-LP": "#fcc6c1",
        "Top-LP": "#cbe4eb",
        "Top-IE": "#cbe4eb",
    },
    figsize=(2, 5),
)
ax.set_ylabel("Cells per gate (%)");

# Extended data figure 6f

In [None]:
# Same IMAP layout for the UCell signature scores (TRM and TGFb). exp=2 squares
# the scores before the coloured KDE, weighting high-scoring cells more.
plot_expression_imaps(
    adata[adata.obs["Subtype"] == "Cd8_T-Cell_P14"],
    batches=["WT", "KO"],
    genes=["UCell_TRM", "UCell_TGFb"],
    dpi=100,
    exp=2,
)