In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import AxesGrid
import numpy as np
from scipy.stats import gaussian_kde
import warnings
import matplotlib.colors as clr
from matplotlib.patches import Polygon
import seaborn as sns
from shapely.geometry import Point
from shapely.geometry.polygon import Polygon
import geopandas as gpd
import re

# Silence every warning for the rest of the kernel session.
# NOTE(review): this is a blanket filter. It may also hide warnings raised when
# cells below write columns into `.obs` of an AnnData subset, or density-estimation
# warnings from seaborn/scipy. Consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")

In [None]:
# Load the full timecourse AnnData object.
# The path is relative, so this assumes the kernel's working directory is the
# notebook's own folder.
adata = sc.read_h5ad("../data/adata/timecourse.h5ad")

In [None]:
# "Zissou" palette, ordered from blue (#3A9AB2) to red (#F11B00).
zissou = [
    "#3A9AB2",
    "#6FB2C1",
    "#91BAB6",
    "#A5C2A3",
    "#BDC881",
    "#DCCB4E",
    "#E3B710",
    "#E79805",
    "#EC7A05",
    "#EF5703",
    "#F11B00",
]

# Continuous colormaps interpolated through the palette: `colormap` runs
# blue -> red and `colormap_r` runs red -> blue. The reversed map gets its own
# name (matplotlib's `_r` convention) so the two maps are distinguishable by name.
colormap = clr.LinearSegmentedColormap.from_list("Zissou", zissou)
colormap_r = clr.LinearSegmentedColormap.from_list("Zissou_r", zissou[::-1])

In [None]:
def scatter_with_gaussian_kde(ax, x, y, **kwargs):
    """
    Plots a scatter plot colored by gaussian kde estimates.

    A 2-D Gaussian kernel density estimate is fitted on all (x, y) points and
    evaluated at each point. That value becomes the point's colour, so dense
    regions stand out.

    Parameters:
    - ax (matplotlib ax): The ax on which to plot the scatter plot.
    - x (array-like): The x values to perform the gaussian kde and scattering on.
    - y (array-like): The y values to perform the gaussian kde and scattering on.
    - **kwargs: Additional keyword arguments to pass to the scatter function.

    Returns:
    - None

    Notes:
    - If the density cannot be estimated, the points are still scattered, just
      without density colouring. This happens when scipy rejects the input, e.g.
      no points, too few points, or a singular covariance such as perfectly
      collinear points.
    """
    xy = np.vstack([x, y])
    try:
        density = gaussian_kde(xy)(xy)
    except (ValueError, np.linalg.LinAlgError):
        # gaussian_kde raises on empty/too-small/degenerate datasets (e.g. a
        # batch with no matching cells). Degrade to a plain scatter instead of
        # aborting the whole multi-panel figure.
        ax.scatter(x, y, **kwargs)
        return

    ax.scatter(x, y, c=density, **kwargs)


def transformation(x, a=0.1, b=0.1, c=0.5, d=2.5, f=4, w=1):
    """
    Custom biexponential transformation of x coordinates.

    Computes ``a * exp(b * (x - w)) - c * exp(-d * (x - w)) + f`` element-wise.

    Parameters:
    - x (scalar or array-like): Values to transform; converted with np.array.
    - a, b (float): Scale and rate of the growing exponential term.
    - c, d (float): Scale and rate of the decaying exponential term.
    - f (float): Constant offset added to the result.
    - w (float): Shift subtracted from x before both exponentials.

    Returns:
    - np.ndarray for array-like input (a numpy scalar for scalar input).
    """
    shifted = np.array(x) - w
    growth = a * np.exp(b * shifted)
    decay = c * np.exp(-d * shifted)
    return growth - decay + f

def draw_gates(ax, gates, transformation, type="edge"):
    """
    Draws IMAP gates on a matplotlib ax.

    Parameters:
    - ax (matplotlib ax): The ax on which to draw the gates.
    - gates (dict): Gate name -> spec. Each spec needs "edges" (list of [x, y]
      vertices, x on the untransformed scale). Depending on `type` it also needs
      "fill" (face colour) or "label_position" ({"x": ..., "y": ...}).
    - transformation (function): A function to transform the x values of the gates.
    - type (str): The type of gate to draw. "fill" draws filled polygons without
      outline; "edge" draws dark outlines plus the gate name as a text label.

    Returns:
    - None

    Raises:
    - ValueError: If `type` is neither "fill" nor "edge".
    """
    if type not in ("fill", "edge"):
        raise ValueError(f'type must be "fill" or "edge", got {type!r}')

    # Import the matplotlib patch under an explicit alias. This notebook imports
    # both matplotlib's and shapely's `Polygon` at module level, and the later
    # shapely import wins. A bare `Polygon(...)` here would therefore build a
    # shapely geometry, which accepts no facecolor/edgecolor and is not a patch.
    from matplotlib.patches import Polygon as MplPolygon

    for name, spec in gates.items():
        # Only the x coordinate is transformed; y is kept as given.
        points = [[transformation(vertex[0])] + vertex[1:] for vertex in spec["edges"]]

        if type == "fill":
            ax.add_patch(MplPolygon(points, facecolor=spec["fill"], edgecolor="none"))
        else:
            ax.add_patch(MplPolygon(points, facecolor="none", edgecolor="#222222"))
            ax.text(
                transformation(spec["label_position"]["x"]),
                spec["label_position"]["y"],
                name,
                fontsize=6,
                color="#222222",
            )

def classify_cells(adata, gates, transformation=transformation):
    """
    Classify cells based on the IMAP gates.

    Each cell becomes a point (transformed epithelial distance, crypt-villi axis
    position). The points are spatially joined against one polygon per gate,
    whose vertex x values are passed through `transformation` as well.

    Side effects:
    - Writes adata.obs["epithelial_distance_transformed"].
    - Sets adata.obs["gate"] to False. This is a placeholder; it is not derived
      from the join result.
    - Prints progress messages.

    Parameters:
    - adata (anndata): The anndata object containing the cells to classify.
      Its .obs must contain "epithelial_distance" and "crypt_villi_axis".
    - gates (dict): A dictionary containing the gates to classify the cells with
      (gate name -> spec with an "edges" list of [x, y] vertices).
    - transformation (function): A function to transform the x values of the gates.

    Returns:
    - result (geopandas dataframe): One row per cell, indexed by adata.obs.index.
      Column "index_right" holds the matching gate name, or NaN if the cell lies
      in no gate. If a cell matches several gates, only the first row returned
      by the join is kept.
    """
    # Import shapely's Polygon explicitly. The module-level name `Polygon` is
    # bound by more than one import in this notebook, so don't depend on which
    # one was last.
    from shapely.geometry import Polygon as ShapelyPolygon

    adata.obs["epithelial_distance_transformed"] = transformation(
        adata.obs["epithelial_distance"]
    )
    adata.obs["gate"] = False

    print("Creating polygons")
    polygons = gpd.GeoSeries(
        {
            name: ShapelyPolygon(
                [[transformation(vertex[0])] + vertex[1:] for vertex in spec["edges"]]
            )
            for name, spec in gates.items()
        }
    )
    gpd_poly = gpd.GeoDataFrame({"gates": polygons}, geometry="gates")

    print("Creating cells")
    # Pass the index explicitly so results align with adata.obs by cell name.
    cells = gpd.GeoSeries.from_xy(
        adata.obs["epithelial_distance_transformed"],
        adata.obs["crypt_villi_axis"],
        index=adata.obs.index,
    )
    gpd_cells = gpd.GeoDataFrame({"cells": cells}, geometry="cells")

    print("Joining cells and polygons")
    result = gpd.sjoin(gpd_cells, gpd_poly, how="left")
    # A left spatial join yields one row per (cell, intersecting gate) pair.
    # A cell on a shared boundary or inside overlapping gates would appear twice,
    # and index-aligned assignment back onto adata.obs would then fail.
    # Keep exactly one row per cell.
    result = result[~result.index.duplicated(keep="first")]
    return result

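Define IMAP gates: polygons in (epithelial distance, crypt–villi axis) space. Vertex x coordinates are given on the untransformed epithelial-distance scale; `transformation` is applied to them whenever the gates are drawn or used to classify cells.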

In [None]:
# Coordinates of the gates.
# Each gate maps a name to:
#   - "edges": polygon vertices as [x, y]. x is epithelial distance on the
#     untransformed scale (draw_gates / classify_cells apply `transformation`
#     to it); y is the crypt-villi axis position.
#   - "label_position": where draw_gates writes the gate name (x also untransformed).
#   - "fill": face colour of the filled background. It is the "stroke" colour
#     with a two-hex-digit alpha suffix appended (RRGGBBAA).
#   - "stroke": outline colour. NOTE(review): not read by draw_gates or
#     classify_cells in this notebook; outlines are drawn in a fixed colour.
# Top and Crypt share the same x breakpoints with a 0.02 gap in y between them.
# Cells inside that gap belong to neither gate.
gates = {
    "Top": {
        "edges": [
            [0.15, 0.5],
            [0.6, 0.7],
            [0.8, 0.7],
            [0.8, 1.03],
            [0.15, 1.03],
        ],
        "label_position": {"x": 0.16, "y": 0.9},
        "fill": "#3A9AB244",
        "stroke": "#3A9AB2",
    },
    "Crypt": {
        "edges": [
            [0.15, 0.48],
            [0.6, 0.68],
            [0.8, 0.68],
            [0.8, 0.25],
            [0.2, 0],
            [0.15, 0],
        ],
        "label_position": {"x": 0.16, "y": 0.05},
        "fill": "#F11B0044",
        "stroke": "#F11B00",
    },
    "Muscularis": {
        # The first vertex is repeated at the end, explicitly closing the ring.
        "edges": [[0.22, 0], [0.8, 0.23], [6, 0.23], [6, 0], [0.22, 0]],
        "label_position": {"x": 0.6, "y": 0.05},
        "fill": "#BDC88155",
        "stroke": "#BDC881",
    },
}

Plot the IMAPs of P14 cells (`Subtype == "Cd8_T-Cell_P14"`) for each selected batch

In [None]:
def plot_imaps(
    adata,
    batches,
    obs,
    values,
    ax_ticks=[0.15, 0.3, 0.6, 1, 6],
    transformation=transformation,
    gates=gates,
    dpi=600,
):
    """
    Draw a grid of IMAP panels: one row per batch, one column per entry of "values".

    Each panel shows the cells of one batch whose `obs` column equals one value.
    Cells are plotted as (transformed epithelial distance, crypt-villi axis).
    Layers, bottom to top:
    - filled gates,
    - KDE contour lines,
    - a KDE-coloured scatter,
    - gate outlines with their labels.

    Side effects:
    - Adds/overwrites adata.obs["epithelial_distance_transformed"].
    - Prints one progress line per value.

    Parameters:
    - adata (anndata): Cells to plot. Its .obs needs "batch",
      "epithelial_distance", "crypt_villi_axis" and the `obs` column.
    - batches (iterable): Batch labels to plot, one row each (for a dict, its keys).
    - obs (str): Name of the adata.obs column used to select cells.
    - values (list): Values of that column to plot, one column each.
    - ax_ticks (list): x-axis ticks on the untransformed scale. They are placed
      at their transformed positions but labelled with the raw values.
    - transformation (function): Maps untransformed epithelial distance to plot x.
    - gates (dict): Gate definitions handed to draw_gates.
    - dpi (int): Figure resolution.

    Returns:
    - None
    """
    n_rows = len(batches)
    n_cols = len(values)
    fig = plt.figure(figsize=(3 * n_cols, 3 * n_rows), dpi=dpi)

    # Cells are plotted on the transformed epithelial-distance scale.
    adata.obs["epithelial_distance_transformed"] = transformation(
        adata.obs["epithelial_distance"]
    )

    for col, value in enumerate(values):
        print("Plotting value: " + str(value))
        for row, batch_name in enumerate(batches):
            # Only .obs columns are plotted, so select the panel's rows from .obs.
            in_panel = (adata.obs["batch"] == batch_name) & (adata.obs[obs] == value)
            panel_obs = adata.obs[in_panel]
            ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1)

            draw_gates(ax, gates=gates, transformation=transformation, type="fill")
            sns.kdeplot(
                data=panel_obs,
                x="epithelial_distance_transformed",
                y="crypt_villi_axis",
                ax=ax,
                color="#444444",
                linewidths=0.5,
            )
            scatter_with_gaussian_kde(
                ax=ax,
                x=panel_obs["epithelial_distance_transformed"],
                y=panel_obs["crypt_villi_axis"],
                s=5,
            )

            # Ticks sit at transformed positions but display the raw distances.
            ax.set_xticks(transformation(ax_ticks))
            ax.set_xticklabels(ax_ticks)
            ax.set_xlabel("Epithelial Axis")
            ax.set_ylabel(f"{batch_name}\nCrypt-Villi Axis")
            ax.set_ylim(-0.02, 1.05)
            # Only the top row carries the column title.
            ax.set_title(f"{value}" if row == 0 else "")
            draw_gates(ax, gates=gates, transformation=transformation)

    fig.tight_layout()

In [None]:
# Batches to plot, one IMAP row each. plot_imaps iterates over the dict's keys
# only and matches them against adata.obs["batch"].
# NOTE(review): the {"x": ..., "y": ...} values are not read anywhere in this
# notebook section — presumably leftovers from another analysis; confirm or drop.
# Commented-out entries are excluded from the figure.
batches = {
    #  "day6_SI": {"x": 6200, "y": 6200},
    "day6_SI_r2": {"x": 5800, "y": 5500},
    # "day8_SI_Ctrl": {"x": 2400, "y": 2400},
    "day8_SI_r2": {"x": 3200, "y": 1500},
    # "day30_SI": {"x": 6400, "y": 2400},
    "day30_SI_r2": {"x": 6200, "y": 6200},
    "day90_SI": {"x": 2400, "y": 2400},
    # "day90_SI_r2": {"x": 1200, "y": 6200},
}

In [None]:
# One column (P14 CD8 T cells) by one row per selected batch.
# dpi=100 instead of the default 600 keeps the inline figure light.
plot_imaps(
    adata=adata,
    batches=batches,
    obs="Subtype",
    values=["Cd8_T-Cell_P14"],
    dpi=100,
)

Quantify the number of P14 cells falling into each gate, per batch

In [None]:
# Subset to P14 CD8 T cells.
# `.copy()` makes an independent AnnData up front. The cells below add columns
# to `.obs` of this subset, and a bare `adata[mask]` is a view of `adata`.
# Writing into a view makes anndata copy it implicitly (with a warning that the
# global filter at the top hides), so the explicit copy costs nothing extra.
adata_p14 = adata[adata.obs["Subtype"] == "Cd8_T-Cell_P14"].copy()

In [None]:
# Assign every P14 cell to the gate polygon it falls in (NaN = outside all gates).
classification = classify_cells(adata_p14, gates)

# Assignment aligns on the index; assumes classify_cells' result is indexed by
# the cells of adata_p14.obs.
adata_p14.obs["gate"] = classification["index_right"]
df = adata_p14.obs

# Absolute P14 counts: one row per batch, one column per gate. Cells outside
# every gate (NaN gate) are dropped by groupby's default dropna=True.
df_abs = df.groupby(["batch", "gate"]).size().unstack()
# Store the day as an integer (e.g. "day30_SI_r2" -> 30). Downstream plots then
# get a numeric, chronological axis. As strings, the days follow the row order
# of df_abs; string-sorted batch names would put day30 before day6.
df_abs["Day"] = [int(re.search(r"\d+", str(batch)).group()) for batch in df_abs.index]
df_abs

In [None]:
fig, ax = plt.subplots(figsize=(8, 5))

# Wide-form frame indexed by day: one line (and point cloud) per gate column.
sns_df = df_abs.set_index("Day")

sns.lineplot(sns_df, dashes=False, ax=ax)
sns.scatterplot(sns_df, ax=ax, markers="o", alpha=0.4, legend=False)
ax.set_yscale("log")
# Label through the explicit Axes rather than pyplot's implicit "current axes",
# so the cell does not depend on global pyplot state. The trailing `;` stops the
# Text object's repr from being printed as cell output.
ax.set_xlabel("Day")
ax.set_ylabel("Absolute cell numbers")
ax.set_title("Absolute cell numbers");