In [None]:
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import warnings
import seaborn as sns

warnings.filterwarnings("ignore")

In [None]:
# Load the time-course AnnData object from disk.
# The path is relative, so it resolves against the kernel's working directory.
adata = sc.read_h5ad("../data/adata/timecourse.h5ad")

In [None]:
# Normalise every cell to a total of 1e4 counts, then apply log1p.
# Neither call's return value is used: both rely on scanpy updating `adata`
# in place, so re-running this cell on a live kernel would transform the
# already-transformed matrix a second time. Reload `adata` before re-running.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

In [None]:
import matplotlib.colors as clr

# Eleven colour stops running from blue through yellow to red.
zissou = [
    "#3A9AB2",
    "#6FB2C1",
    "#91BAB6",
    "#A5C2A3",
    "#BDC881",
    "#DCCB4E",
    "#E3B710",
    "#E79805",
    "#EC7A05",
    "#EF5703",
    "#F11B00",
]

# Continuous colormap interpolated through the stops above, plus its mirror image.
# The reversed map gets the "_r" suffix so the two objects do not share one name
# (they would clash if either were ever registered or looked up by name).
colormap = clr.LinearSegmentedColormap.from_list("Zissou", zissou)
colormap_r = clr.LinearSegmentedColormap.from_list("Zissou_r", zissou[::-1])

In [None]:
# Per-batch {"x": ..., "y": ...} values, keyed by batch label.
# NOTE(review): the meaning/units of x and y are not visible in this notebook
# section — presumably a per-sample coordinate; confirm against the caller.
batches = {
    name: {"x": x, "y": y}
    for name, (x, y) in {
        "day6_SI": (6200, 6200),
        "day6_SI_r2": (5800, 5500),
        "day8_SI_Ctrl": (2400, 2400),
        "day8_SI_r2": (3200, 1500),
        "day30_SI": (6400, 2400),
        "day30_SI_r2": (6200, 6200),
        "day90_SI": (2400, 2400),
        "day90_SI_r2": (1200, 6200),
    }.items()
}

In [None]:
# Functions to help with the IMAPS
def scatter_with_gaussian_kde_weights(ax, x, y, weights, **kwargs):
    """
    Scatter ``x`` against ``y`` on ``ax``, colouring each point by a weighted KDE.

    Parameters:
    - ax: object exposing a matplotlib-style ``scatter(x, y, c=..., **kwargs)``.
    - x, y: equally long 1-D coordinate sequences.
    - weights: per-point weights handed to ``scipy.stats.gaussian_kde``.
    - **kwargs: forwarded unchanged to ``ax.scatter`` (e.g. ``s``, ``cmap``).

    Returns:
    - None. The only effect is the ``ax.scatter`` call.
    """
    from scipy.stats import gaussian_kde

    # Stack into the (2, n_points) layout gaussian_kde expects.
    coords = np.vstack([x, y])
    density = gaussian_kde(coords, weights=weights)

    # Evaluate the fitted density at the very points it was fitted on.
    ax.scatter(x, y, c=density(coords), **kwargs)


# Custom biexponential transformation. Maybe not needed for IF data
def transformation(x, a=0.1, b=0.1, c=0.5, d=2.5, f=4, w=1):
    """
    Apply the biexponential transform ``a*exp(b*(x-w)) - c*exp(-d*(x-w)) + f``.

    Parameters:
    - x: scalar or array-like; converted with ``np.array`` first.
    - a, b: scale and rate of the growing exponential term.
    - c, d: scale and rate of the decaying exponential term.
    - f: constant offset added to the result.
    - w: shift subtracted from ``x`` before both exponentials.

    Returns:
    - NumPy scalar/array with the same shape as ``x``.
    """
    shifted = np.array(x) - w
    growing = a * np.exp(b * shifted)
    decaying = c * np.exp(-d * shifted)
    return growing - decaying + f

In [None]:
# Gate definitions, keyed by gate name. Each entry holds:
#   "edges"          - polygon vertices as [x, y] lists; x is in untransformed units
#                      (the consumers pass it through the transformation before use).
#                      They must stay lists: consumers build points with
#                      ``[transformed_x] + vertex[1:]``.
#   "label_position" - where the gate name is written, same coordinate convention.
#   "fill"           - stroke colour plus a two-digit hex alpha suffix.
#   "stroke"         - opaque colour associated with the gate.
gates = {
    name: {
        "edges": edges,
        "label_position": {"x": label_x, "y": label_y},
        "fill": stroke + fill_alpha,
        "stroke": stroke,
    }
    for name, edges, (label_x, label_y), stroke, fill_alpha in [
        (
            "Top",
            [[0.15, 0.5], [0.6, 0.7], [0.8, 0.7], [0.8, 1.03], [0.15, 1.03]],
            (0.16, 0.9),
            "#3A9AB2",
            "24",
        ),
        (
            "Crypt",
            [[0.15, 0.48], [0.6, 0.68], [0.8, 0.68], [0.8, 0.25], [0.2, 0], [0.15, 0]],
            (0.16, 0.05),
            "#F11B00",
            "24",
        ),
        (
            "Muscularis",
            [[0.22, 0], [0.8, 0.23], [6, 0.23], [6, 0], [0.22, 0]],
            (0.6, 0.05),
            "#BDC881",
            "35",
        ),
    ]
}


def draw_gates(ax, gates, transformation, type="edge"):
    """
    Draw every gate polygon on ``ax``.

    Parameters:
    - ax: matplotlib-style axes (needs ``add_patch`` and ``text``).
    - gates (dict): gate name -> dict with "edges", "label_position" and "fill".
    - transformation (callable): applied to x coordinates only; y is used as given.
    - type (str): "fill" draws the polygons filled with each gate's "fill" colour
      and no outline; "edge" draws unfilled outlines plus the gate name at its
      label position. Any other value draws nothing.
    """
    from matplotlib.patches import Polygon

    for name, spec in gates.items():
        # Move only the x coordinate into transformed space.
        vertices = [[transformation(vertex[0])] + vertex[1:] for vertex in spec["edges"]]

        if type == "fill":
            ax.add_patch(Polygon(vertices, facecolor=spec["fill"], edgecolor="none"))
            continue
        if type != "edge":
            continue

        ax.add_patch(Polygon(vertices, facecolor="none", edgecolor="#222222"))

        label = spec["label_position"]
        ax.text(
            transformation(label["x"]),
            label["y"],
            name,
            fontsize=6,
            color="#222222",
        )

In [None]:
# Create subplots
def plot_imaps(
    adata,
    batches,
    genes,
    ax_ticks=[0.15, 0.3, 0.6, 1, 6],
    transformation=transformation,
    gates=gates,
    dpi=100,
):
    """
    Draw a grid of IMAP panels: one row per batch, one column per gene.

    Every panel places the cells of one batch at
    (transformed ``epithelial_distance``, ``crypt_villi_axis``), colours them by
    an expression-weighted Gaussian KDE, overlays weighted density contours and
    draws the gate polygons (filled underneath, outlined and labelled on top).

    Parameters:
    - adata (AnnData): must provide ``obs["epithelial_distance"]``,
      ``obs["crypt_villi_axis"]`` and ``obs["batch"]``.
    - batches (sequence): batch labels compared against ``adata.obs["batch"]``;
      iterating a dict yields its keys, so a dict keyed by batch label works too.
    - genes (list): entries of ``adata.var.index``. The special value
      ``"Distribution"`` weights every cell equally instead of by expression.
    - ax_ticks (list): x tick labels in untransformed units; the tick positions
      are their transformed values. The list is only read, never modified.
    - transformation (callable): maps untransformed x values to plot coordinates.
    - gates (dict): gate definitions passed on to ``draw_gates``.
    - dpi (int): resolution of the created figure.

    Side effects:
    - Writes (or overwrites) ``adata.obs["epithelial_distance_transformed"]``.
    - Prints one progress line per gene.
    - Creates a new matplotlib figure; it is not returned.

    Raises:
    - KeyError: if a requested gene is not present in ``adata.var.index``.
    """
    # Fail before any figure is created: an unknown gene would otherwise yield an
    # empty weight vector and an unrelated error deep inside the KDE code.
    missing = [
        gene
        for gene in genes
        if gene != "Distribution" and gene not in adata.var.index
    ]
    if missing:
        raise KeyError(f"Genes not found in adata.var.index: {missing}")

    fig = plt.figure(figsize=(3 * len(genes), 3 * len(batches)), dpi=dpi)

    # Apply transformation
    adata.obs["epithelial_distance_transformed"] = transformation(
        adata.obs["epithelial_distance"]
    )

    for col, gene in enumerate(genes):
        print("Plotting value: " + str(gene))
        # Iterate over batches
        for i, bt in enumerate(batches):
            sub_adata = adata[adata.obs["batch"] == bt]

            if gene == "Distribution":
                gene_expression = np.ones(len(sub_adata))
            else:
                values = sub_adata[:, sub_adata.var.index == gene].X
                # Sparse matrices have no .flatten(); densify the single column first.
                if hasattr(values, "toarray"):
                    values = values.toarray()
                gene_expression = np.asarray(values).ravel()

            # Subplot indices are 1-based and run row by row.
            ax = fig.add_subplot(len(batches), len(genes), i * len(genes) + 1 + col)

            # Draw gates filled in background
            draw_gates(ax, gates=gates, transformation=transformation, type="fill")

            # Draw the density lines
            sns.kdeplot(
                data=sub_adata.obs,
                x="epithelial_distance_transformed",
                y="crypt_villi_axis",
                ax=ax,
                weights=gene_expression,
                color="#444444",
                linewidths=0.5,
            )

            # Colored scatter plot
            scatter_with_gaussian_kde_weights(
                ax=ax,
                x=sub_adata.obs["epithelial_distance_transformed"],
                y=sub_adata.obs["crypt_villi_axis"],
                weights=gene_expression,
                s=5,
                cmap="viridis" if gene == "Distribution" else colormap,
            )

            # Ticks sit at transformed positions but are labelled in original units
            ax.set_xticks(transformation(ax_ticks))
            ax.set_xticklabels(ax_ticks)

            # Label the axes
            ax.set_xlabel("Epithelial Axis")
            ax.set_ylabel("Crypt-Villi Axis")

            ax.set_ylim(-0.02, 1.05)

            # Only the top row carries the gene name as a title
            if i == 0:
                ax.set_title(f"{gene}")
            else:
                ax.set_title("")

            # Gate outlines and labels go on last so they stay visible
            draw_gates(ax, gates=gates, transformation=transformation)

    fig.tight_layout()

## Figure 2 h

In [None]:
# Plot IMAPs of four marker genes for cells whose Subtype is "Cd8_T-Cell_P14",
# restricted to the "day90_SI" batch.
# (Adding "Distribution" to `genes` would plot the IMAP without expression weights.)
plot_imaps(
    adata[adata.obs["Subtype"] == "Cd8_T-Cell_P14"],
    ["day90_SI"],
    genes=["Gzma", "Gzmb", "Itgae", "Tcf7"],
)

## Figure 2 i

In [None]:
# Take an explicit copy instead of a view: classify_cells() and the following cell
# write new columns into `.obs`, and writing to a view makes anndata silently
# turn it into a copy anyway.
adata_p14 = adata[adata.obs["Subtype"] == "Cd8_T-Cell_P14"].copy()

In [None]:
def classify_cells(adata, gates, transformation=transformation):
    """
    Classify cells based on the gates.

    Each cell is treated as the point
    (transformed ``epithelial_distance``, ``crypt_villi_axis``) and spatially
    joined against the gate polygons.

    Parameters:
    - adata (AnnData): must provide ``obs["epithelial_distance"]`` and
      ``obs["crypt_villi_axis"]``.
    - gates (dict): gate name -> dict whose "edges" entry lists [x, y] vertices
      with x in untransformed units.
    - transformation (callable): applied to the x coordinate of cells and vertices.

    Returns:
    - geopandas.GeoDataFrame: one row per cell, indexed like ``adata.obs``. The
      column ``index_right`` holds the name of the matching gate, or NaN for
      cells outside every gate (left join).

    Side effects:
    - Writes ``adata.obs["epithelial_distance_transformed"]`` and resets
      ``adata.obs["gate"]`` to False; the caller is expected to fill ``gate``
      from the returned frame.
    - Prints progress messages.
    """
    from shapely.geometry.polygon import Polygon
    import geopandas as gpd

    adata.obs["epithelial_distance_transformed"] = transformation(
        adata.obs["epithelial_distance"]
    )
    adata.obs["gate"] = False

    print("Creating polygons")
    polygons = {}
    for gate in gates:
        # Apply transformation to x values
        points = [
            [transformation(element[0])] + element[1:]
            for element in gates[gate]["edges"]
        ]
        polygons[gate] = Polygon(points)
    # The dict keys (gate names) become the index, which the join reports back.
    polygons = gpd.GeoSeries(polygons)
    gpd_poly = gpd.GeoDataFrame({"gates": polygons}, geometry="gates")

    print("Creating cells")
    cells = gpd.GeoSeries.from_xy(
        adata.obs["epithelial_distance_transformed"], adata.obs["crypt_villi_axis"]
    )
    gpd_cells = gpd.GeoDataFrame({"cells": cells}, geometry="cells")

    print("Joining cells and polygons")
    result = gpd.sjoin(
        gpd_cells,
        gpd_poly,
        how="left",
    )

    # A cell lying in (or on the shared border of) several gates comes back once
    # per match. Keep the first match only, so the index stays unique and the
    # result can be assigned straight back into adata.obs.
    result = result.loc[~result.index.duplicated(keep="first")]
    return result

In [None]:
def get_mean_expression(adata, genes):
    """
    Calculate the mean expression of specified genes for each batch and gate in the provided AnnData object.

    Parameters:
    - adata (AnnData): Annotated data matrix with observations (rows) and variables (columns).
      ``adata.obs`` must contain the columns "batch" and "gate".
    - genes (list): Gene names for which mean expression is calculated (any sequence is accepted).

    Returns:
    - pd.DataFrame: Long-format frame with one row per (batch, gate, gene) and the columns
      "batch", "Day", "gate", "group_size", "gene" and "expression". "group_size" is the
      number of cells in the batch/gate group. "Day" is the first run of digits in the
      batch label, kept as a string; it is NaN for labels that contain no digits.
    """
    keys = list(genes) + ["batch", "gate"]
    df = sc.get.obs_df(adata, keys=keys)

    # Build the grouping once and reuse it for both the means and the group sizes.
    by_batch_gate = df.groupby(["batch", "gate"])
    grouped = (
        by_batch_gate.mean()
        .join(by_batch_gate.size().rename("group_size"))
        .reset_index()
    )

    # First digit run of the label, e.g. "day30_SI_r2" -> "30". str.extract yields
    # NaN instead of raising when a label has no digits.
    grouped["Day"] = (
        grouped["batch"].astype(str).str.extract(r"(\d+)", expand=False)
    )

    # Melt into a longer form
    grouped = pd.melt(
        grouped,
        id_vars=["batch", "Day", "gate", "group_size"],
        var_name="gene",
        value_name="expression",
    )
    return grouped


def get_scaled_mean_expression(adata, genes):
    """
    Calculate the scaled mean expression of specified genes for each batch and gate in the provided AnnData object.

    The mean expression of every gene is min-max scaled across all of its
    batch/gate groups, so each gene ranges from 0 to 1.

    Parameters:
    - adata (AnnData): Annotated data matrix with observations (rows) and variables (columns).
    - genes (list): A list of gene names for which scaled mean expression is calculated.

    Returns:
    - pd.DataFrame: Same rows, columns and index as the output of ``get_mean_expression``,
      with "expression" replaced by its scaled value. A gene whose mean is identical in
      every group has no defined scale and comes back as NaN.
    """
    grouped = get_mean_expression(adata=adata, genes=genes)

    # transform() returns values aligned to the original rows. Unlike
    # groupby.apply, the result keeps the "gene" column and a plain index on
    # every pandas version.
    per_gene = grouped.groupby("gene")["expression"]
    gene_min = per_gene.transform("min")
    gene_max = per_gene.transform("max")

    df_scaled = grouped.copy()
    df_scaled["expression"] = (grouped["expression"] - gene_min) / (
        gene_max - gene_min
    )

    return df_scaled

In [None]:
# Spatially join the P14 cells against the gate polygons.
classification = classify_cells(adata_p14, gates)
# "index_right" is the index of the matched polygon, i.e. the gate name; the
# assignment aligns on the cell index. Cells outside every gate are NaN.
adata_p14.obs["gate"] = classification["index_right"]

In [None]:
df = get_scaled_mean_expression(adata_p14, ["Tcf7", "Itgae", "Gzma", "Gzmb"])
# Keep only batch/gate groups with at least 50 cells. Copy so the column
# assignment below works on an independent frame rather than on a slice.
df = df[df["group_size"] >= 50].copy()
# Fix the gate order for the legend. Casting to an explicit CategoricalDtype
# still works when a gate has no group left after the size filter, whereas
# reorder_categories raises if the category sets differ.
df["gate"] = df["gate"].astype(pd.CategoricalDtype(["Top", "Crypt", "Muscularis"]))
# NOTE(review): "Day" is a string column, so the x axis may come out in
# lexicographic rather than chronological order — confirm the intended ordering.
g = sns.FacetGrid(
    df,
    hue="gate",
    col="gene",
    col_wrap=4,
    palette={k: gates[k]["stroke"] for k in gates},
)
g.map(sns.lineplot, "Day", "expression", err_style="bars")
g.add_legend()