In [None]:
%pip install squidpy==1.3.1

In [None]:
import os

# important for gpd.sjoin
os.environ["USE_PYGEOS"] = "0"

import scanpy as sc
import squidpy as sq
import numpy as np
import warnings
import seaborn as sns
import matplotlib.pyplot as plt
import igraph
import random
import math
from sklearn.preprocessing import MinMaxScaler


warnings.filterwarnings("ignore")

In [None]:
import matplotlib.colors as clr

# 11-color "Zissou" palette used for the continuous colormaps below.
zissou = [
    "#3A9AB2",
    "#6FB2C1",
    "#91BAB6",
    "#A5C2A3",
    "#BDC881",
    "#DCCB4E",
    "#E3B710",
    "#E79805",
    "#EC7A05",
    "#EF5703",
    "#F11B00",
]

# Forward and reversed continuous colormaps interpolated from the palette.
# The reversed map is named "Zissou_r" (matplotlib's convention for reversed
# colormaps) so the two objects are distinguishable by their `.name`.
colormap = clr.LinearSegmentedColormap.from_list("Zissou", zissou)
colormap_r = clr.LinearSegmentedColormap.from_list("Zissou_r", zissou[::-1])

In [None]:
def get_interaction_matrix(adata, cluster_key, spatial_key, normalized):
    """
    Compute the Squidpy cluster interaction matrix for one batch of cells.

    Parameters
    - adata (anndata): Cells whose cluster-to-cluster interactions are counted.
    - cluster_key (str): Column of adata.obs holding the discrete groups.
    - spatial_key (str): Key of adata.obsm holding the spatial coordinates.
    - normalized (bool): Forwarded to sq.gr.interaction_matrix.

    Returns
    - interaction_matrix (np.ndarray): The matrix Squidpy stores under
      adata.uns["<cluster_key>_interactions"], read from a private copy.
    """
    # Squidpy writes its neighbor graph and results into the object it receives;
    # operate on a copy so the caller's AnnData (typically a per-batch view) is
    # left untouched.
    working = adata.copy()
    sq.gr.spatial_neighbors(working, spatial_key=spatial_key)
    sq.gr.interaction_matrix(working, cluster_key=cluster_key, normalized=normalized)
    result_key = f"{cluster_key}_interactions"
    return working.uns[result_key]


def mean_interaction_matrix(
    adata, batches, cluster_key, spatial_key="X_spatial", normalized=True
):
    """
    Get the mean interaction matrix for a set of batches of cells using Squidpy.

    One interaction matrix is computed per batch (cells selected where
    adata.obs["batch"] equals the batch name) and the matrices are averaged
    element-wise.

    Parameters
    - adata (anndata): The anndata object containing the cells to find interactions between.
    - batches (iterable): The batch names to use in the interaction analysis.
    - cluster_key (str): The key in adata.obs to use for defining discrete interacting groups.
    - spatial_key (str): The key in adata.obsm to use for spatial coordinates.
    - normalized (bool): Whether to normalize the interaction matrix.

    Returns
    - mean_interactions (np.ndarray): The mean interaction matrix.

    Raises
    - ValueError: If `batches` is empty.
    """
    # Materialize once so generators, tuples and arrays all work, and so the
    # divisor below equals the number of matrices actually averaged.
    batches = list(batches)
    if not batches:
        # Without this guard the average fails with an opaque ZeroDivisionError.
        raise ValueError("`batches` must contain at least one batch name")

    interactions = [
        get_interaction_matrix(
            adata=adata[adata.obs["batch"] == b],
            spatial_key=spatial_key,
            cluster_key=cluster_key,
            normalized=normalized,
        )
        for b in batches
    ]

    # Element-wise mean; assumes every batch yields a matrix of the same shape.
    mean_interactions = sum(interactions) / len(interactions)

    return mean_interactions


def create_mean_interaction_graph(
    adata, batches, cluster_key, spatial_key="X_spatial", interaction_cutoff=0.1
):
    """
    Build a weighted igraph graph from the batch-averaged interaction matrix.

    Parameters
    - adata (anndata): Cells to find interactions between.
    - batches (list): Batch names averaged by mean_interaction_matrix.
    - cluster_key (str): Column of adata.obs defining the discrete groups.
    - spatial_key (str): Key of adata.obsm holding spatial coordinates.
    - interaction_cutoff (float): Matrix entries below this value are set to 0
      before the graph is built.

    Returns
    - g (igraph.Graph): Graph from igraph.Graph.Weighted_Adjacency whose
      vertices are labelled with the categories of adata.obs[cluster_key].
    """
    adjacency = mean_interaction_matrix(
        adata=adata,
        batches=batches,
        cluster_key=cluster_key,
        spatial_key=spatial_key,
    )
    # Zero out weak interactions so they do not contribute edge weight.
    adjacency[adjacency < interaction_cutoff] = 0

    graph = igraph.Graph.Weighted_Adjacency(adjacency)
    graph.vs["label"] = adata.obs[cluster_key].cat.categories
    return graph


def get_mean_expression(adata, batches, gene, cluster_key):
    """
    Mean expression of one gene per cluster, averaged over batches and max-scaled.

    The gene is averaged within every (batch, cluster) pair. Those per-batch
    means are then averaged per cluster, and the result is divided by its
    largest (non-NaN) value.

    Parameters:
    - adata (AnnData): Annotated data matrix with observations (rows) and variables (columns).
    - batches (list): Batch names (values of adata.obs["batch"]) to include.
    - gene (str): A gene name for which mean expression is calculated.
    - cluster_key (str): Column of adata.obs defining the clusters.

    Returns:
    - np.ndarray: One value per cluster (in groupby order), scaled so that the
      largest non-NaN value is 1. Clusters without data remain NaN.
    """
    adata = adata[adata.obs.batch.isin(batches)]
    df = sc.get.obs_df(adata, keys=[gene, "batch", cluster_key])

    per_batch = df.groupby(["batch", cluster_key]).mean().reset_index()
    per_cluster = (
        per_batch[[cluster_key, gene]].groupby(cluster_key).mean().reset_index()
    )
    expression = per_cluster[gene].to_numpy(dtype=float)
    # nanmax: a cluster without cells in the selected batches can yield a NaN
    # mean (e.g. a categorical cluster_key may keep empty groups in groupby);
    # plain np.max would then turn the entire vector into NaN.
    return expression / np.nanmax(expression)

In [None]:
def plot_graph(g, vertex_colors, ax, layout="kk", highlight="Cd8_T-Cell_P14"):
    """
    Plot an igraph graph with specified vertex colors and layout.

    Edges incident to the vertex labelled `highlight` are drawn in
    rgb(90, 10, 0); all other edges are drawn in black. Edge opacity is the
    "weight" edge attribute rescaled with `scale_numbers` (default range 0.4-0.9).

    Parameters
    - g (igraph.Graph): The graph to plot. It needs a "weight" edge attribute
      and, for highlighting, a "label" vertex attribute.
    - vertex_colors (list): The list of colors to use for the vertices.
    - ax (matplotlib.axes.Axes): The axes to plot the graph on.
    - layout (str or igraph.Layout): Layout name or a precomputed layout.
    - highlight (str): Label of the vertex whose edges are highlighted. If no
      vertex carries this label, no edges are highlighted.

    Returns
    - None
    """
    # Reseed Python's RNG so a stochastic layout (when `layout` is a name) is
    # reproducible; presumably igraph draws from `random` here.
    random.seed(42)

    # A missing "label" attribute (KeyError) or an absent label (ValueError)
    # just means there is nothing to highlight. Any other error is real and is
    # no longer hidden by a bare `except`.
    try:
        node = g.vs["label"].index(highlight)
        # Set gives O(1) membership checks in the per-edge loop below.
        highlight_edges = set(g.incident(node, "all"))
    except (KeyError, ValueError):
        highlight_edges = set()

    edge_rgb = [
        "90,10,0" if edge_id in highlight_edges else "0,0,0"
        for edge_id in range(len(g.es))
    ]
    edge_alpha = scale_numbers(g.es["weight"])
    igraph.plot(
        g,
        target=ax,
        layout=layout,
        edge_color=[f"rgba({c}, {w})" for c, w in zip(edge_rgb, edge_alpha)],
        edge_arrow_size=0.005,
        edge_width=1,
        vertex_color=vertex_colors,
        vertex_label_dist=-1,
        vertex_label_size=8,
    )


def scale_numbers(input_list, target_min=0.4, target_max=0.9):
    """
    Linearly rescale numbers so the minimum maps to target_min and the maximum to target_max.

    Parameters
    - input_list (iterable): The numbers to scale.
    - target_min (float): The minimum value of the target range.
    - target_max (float): The maximum value of the target range.

    Returns
    - scaled_list (list): The scaled numbers, in input order. An empty input
      gives an empty list. If all inputs are equal (including a single value),
      there is no range to scale, and every value maps to target_max.
    """
    values = list(input_list)
    if not values:
        return []

    min_value = min(values)
    max_value = max(values)
    value_range = max_value - min_value
    if value_range == 0:
        # No spread: avoid ZeroDivisionError and treat every value as maximal.
        return [target_max] * len(values)

    span = target_max - target_min
    return [((x - min_value) / value_range) * span + target_min for x in values]


def ceil_division(numerator, denominator):
    """
    Return ceil(numerator / denominator) as an int.

    When both arguments are Python ints, the result uses exact integer
    arithmetic, because float division loses precision above 2**53. Other
    numeric inputs fall back to math.ceil on the true quotient.

    Raises
    - ZeroDivisionError: If denominator is a zero Python int/float.
    """
    if isinstance(numerator, int) and isinstance(denominator, int):
        # -(-a // b) equals ceil(a / b) for ints without any float round-trip.
        return -(-numerator // denominator)
    return int(math.ceil(numerator / denominator))


def global_layout(adata, cluster_key, batches, layout="kk"):
    """
    Compute a single vertex layout from the mean interaction graph of `batches`.

    The returned layout can be passed to plot_graph for other graphs over the
    same categories, so vertex positions stay fixed across panels.

    Parameters
    - adata (anndata): Cells used to build the interaction graph.
    - cluster_key (str): Column of adata.obs defining the discrete groups.
    - batches (list): Batch names averaged into the reference graph.
    - layout (str): Name of the igraph layout algorithm to run.

    Returns
    - layout (igraph.Layout): The computed layout.
    """
    graph = create_mean_interaction_graph(adata, batches, cluster_key)
    # Fix Python's RNG right before layouting so repeated runs give the same
    # positions (presumably igraph draws from `random` for stochastic layouts).
    random.seed(42)
    return graph.layout(layout)

Figure 3b

In [None]:
# Time-course AnnData used by every analysis below (path relative to the notebook).
adata = sc.read_h5ad(filename="../data/adata/timecourse.h5ad")

In [None]:
# Shared layout computed from the day-8 graph; it is reused for every time
# point so vertex positions are comparable across panels.
reference_batches = ["day8_SI_Ctrl", "day8_SI_r2"]
gl = global_layout(adata=adata, cluster_key="Subtype", batches=reference_batches)
gl

In [None]:
# Replicate batches (values of adata.obs["batch"]) for each time point.
batches = dict(
    [
        ("day 6", ["day6_SI", "day6_SI_r2"]),
        ("day 8", ["day8_SI_Ctrl", "day8_SI_r2"]),
        ("day 30", ["day30_SI", "day30_SI_r2"]),
        ("day 90", ["day90_SI", "day90_SI_r2"]),
    ]
)

# Flat list of every batch above, in time-point order.
all_batches = list(np.concatenate(tuple(batches.values())))

Build the interaction graph for each time point, with nodes colored by cell type.

In [None]:
# Node color (hex) for each cell Type.
colors = dict(
    [
        ("Epithelial_Secretory", "#AA9228"),
        ("Epithelial_Absorptive", "#E3C300"),
        ("Monocyte", "#C37698"),
        ("T-Cell", "#008E74"),
        ("MAIT", "#63ABB9"),
        ("Myeloid", "#EF9684"),
        ("ILC", "#A0C6D3"),
        ("B-Cell", "#E2CEAB"),
        ("DC", "#FE757D"),
        ("Fibroblast", "#E17300"),
        ("Endothelial", "#E30133"),
        ("NK", "#4A7B89"),
        ("Epithelial_Progenitor", "#F7BC00"),
        ("Neuron", "#2A2446"),
        ("Erythroid", "#A5021D"),
        ("Eosinophil", "#782c4e"),
    ]
)

In [None]:
# Colors are defined per cell Type, not per Subtype, so build a Subtype -> Type
# mapping from the (Subtype, Type) combinations that actually occur, and use it
# to pick each Subtype's color.
subtype_to_type = (
    # observed=True: only combinations present in the data, instead of the full
    # cartesian product of categories (and no FutureWarning about the default).
    adata.obs.groupby(["Subtype", "Type"], observed=True)
    .size()
    .reset_index(name="count")
)
subtype_to_type = subtype_to_type[subtype_to_type["count"] > 0].set_index("Subtype")
# Every Subtype must map to exactly one Type. Otherwise `.loc[subtype]` returns
# several rows and the color lookup fails.
assert subtype_to_type.index.is_unique, "some Subtypes map to more than one Type"

In [None]:
ncols = len(batches)
# squeeze=False keeps `axes` 2-D even with a single time point, so axes[0, col]
# always works.
fig, axes = plt.subplots(1, ncols, figsize=(9 * ncols, 6 * 1), squeeze=False)

# Vertex colors depend only on the Subtype categories, not on the time point,
# so look each Subtype's Type color up once rather than once per panel.
node_colors = [
    colors[subtype_to_type["Type"].loc[t]]
    for t in adata.obs["Subtype"].cat.categories
]

for col, batch in enumerate(batches):
    print(f"Creating graph for {batch}")
    g = create_mean_interaction_graph(
        adata=adata, batches=batches[batch], cluster_key="Subtype"
    )

    ax = axes[0, col]
    # Shared layout `gl` keeps vertex positions identical across time points.
    plot_graph(
        g=g,
        vertex_colors=node_colors,
        ax=ax,
        layout=gl,
    )
    ax.set_title(f"{batch}: Subtype")
fig.tight_layout()

Figure 3c

In [None]:
# IMAP gates: polygons in (epithelial_distance, crypt_villi_axis) space.
# The x coordinates are raw epithelial distances; classify_cells transforms
# them the same way it transforms the cells before testing membership.
gates = dict(
    Top=dict(
        edges=[
            [0.15, 0.5],
            [0.6, 0.7],
            [0.8, 0.7],
            [0.8, 1.03],
            [0.15, 1.03],
        ],
        label_position=dict(x=0.16, y=0.9),
        fill="#3A9AB244",
        stroke="#3A9AB2",
    ),
    Crypt=dict(
        edges=[
            [0.15, 0.48],
            [0.6, 0.68],
            [0.8, 0.68],
            [0.8, 0.25],
            [0.2, 0],
            [0.15, 0],
        ],
        label_position=dict(x=0.16, y=0.05),
        fill="#F11B0044",
        stroke="#F11B00",
    ),
    Muscularis=dict(
        edges=[[0.22, 0], [0.8, 0.23], [6, 0.23], [6, 0], [0.22, 0]],
        label_position=dict(x=0.6, y=0.05),
        fill="#BDC88155",
        stroke="#BDC881",
    ),
)

In [None]:
# Custom biexponential transformation.
def transformation(x, a=0.1, b=0.1, c=0.5, d=2.5, f=4, w=1):
    """
    Apply a * exp(b * (x - w)) - c * exp(-d * (x - w)) + f element-wise.

    `x` is converted with np.array first, so scalars, lists and Series work.
    """
    shifted = np.array(x) - w
    growth = a * np.exp(b * shifted)
    decay = c * np.exp(-d * shifted)
    return growth - decay + f


def classify_cells(adata, gates, transformation=transformation):
    """
    Classify cells based on the IMAP gates.

    Each cell is placed at (transformed epithelial_distance, crypt_villi_axis).
    The x coordinate of every gate vertex goes through the same transformation.
    Cells are then left-joined spatially against the gate polygons.

    Side effects: writes adata.obs["epithelial_distance_transformed"] and sets
    adata.obs["gate"] to False.

    Parameters:
    - adata (anndata): The anndata object containing the cells to classify.
    - gates (dict): Gate name -> spec with an "edges" list of [x, y] vertices.
    - transformation (function): A function to transform the x values of the gates
      and of the cells.

    Returns:
    - result (geopandas GeoDataFrame): Result of a left spatial join of cells
      against the gate polygons, whose index holds the gate names. Cells
      outside every gate have no match.
    """
    from shapely.geometry import Point
    from shapely.geometry.polygon import Polygon
    import geopandas as gpd

    adata.obs["epithelial_distance_transformed"] = transformation(
        adata.obs["epithelial_distance"]
    )
    adata.obs["gate"] = False

    print("Creating polygons")
    # Transform only the x coordinate of each vertex; y is kept as-is.
    polygon_by_gate = {
        name: Polygon(
            [[transformation(vertex[0])] + vertex[1:] for vertex in spec["edges"]]
        )
        for name, spec in gates.items()
    }
    gpd_poly = gpd.GeoDataFrame(
        {"gates": gpd.GeoSeries(polygon_by_gate)}, geometry="gates"
    )

    print("Creating cells")
    cell_points = gpd.GeoSeries.from_xy(
        adata.obs["epithelial_distance_transformed"], adata.obs["crypt_villi_axis"]
    )
    gpd_cells = gpd.GeoDataFrame({"cells": cell_points}, geometry="cells")

    print("Joining cells and polygons")
    return gpd.sjoin(gpd_cells, gpd_poly, how="left")


# Gate every cell; the joined result is shown below.
classification = classify_cells(adata=adata, gates=gates)
classification

In [None]:
# Name of the matching gate per cell, taken from the spatial-join output.
# NOTE(review): pandas aligns this assignment on the index, so it relies on
# classification's index matching adata.obs_names — confirm for the geopandas version in use.
gate_per_cell = classification["index_right"]
adata.obs["gate"] = gate_per_cell

In [None]:
def make_name(gate, cell):
    """
    Return the Subtype_gate label for one cell.

    P14 cells ("Cd8_T-Cell_P14") are split by gate into "P14 top",
    "P14 crypt" and "P14 muscularis". A P14 cell in no known gate gets
    "P14 undeterminded". Every other cell keeps its Subtype label unchanged.
    """
    if cell != "Cd8_T-Cell_P14":
        return cell
    for gate_name, label in (
        ("Top", "P14 top"),
        ("Crypt", "P14 crypt"),
        ("Muscularis", "P14 muscularis"),
    ):
        if gate == gate_name:
            return label
    return "P14 undeterminded"


# Per-cell label that splits P14 cells by gate and keeps every other Subtype as is.
adata.obs["Subtype_gate"] = [
    make_name(gate, cell) for gate, cell in zip(adata.obs["gate"], adata.obs["Subtype"])
]
# Drop P14 cells that fell outside every gate. `.copy()` turns the boolean-mask
# view into a standalone AnnData, so the column assignment below modifies a real
# object. Otherwise anndata would implicitly convert the view and emit an
# ImplicitModificationWarning, which the global warnings filter hides here.
adata = adata[adata.obs["Subtype_gate"] != "P14 undeterminded"].copy()
adata.obs["Subtype_gate"] = adata.obs["Subtype_gate"].astype("category")

Heatmap of P14 interactions by location (top, crypt, muscularis)

In [None]:
# heatmap input: interaction matrix averaged over all time-point batches, with
# P14 cells split by gate via the Subtype_gate labels.
m = mean_interaction_matrix(
    adata=adata,
    batches=all_batches,
    cluster_key="Subtype_gate",
)

In [None]:
# P14 populations split by location; their rows are extracted from `m`.
population_of_intertest = ["P14 top", "P14 crypt", "P14 muscularis"]
# Rows of `m` are assumed to follow the Subtype_gate category order (the heatmap
# tick labels below make the same assumption).
subtype_gate_categories = list(adata.obs["Subtype_gate"].cat.categories)
positions = [subtype_gate_categories.index(label) for label in population_of_intertest]
m_p14 = m[positions]

In [None]:
# Scale each row from min to max. MinMaxScaler scales columns, so transpose
# before fitting and transpose the result back.
scaler = MinMaxScaler(feature_range=(0, 1))
normalized_array = np.transpose(scaler.fit_transform(np.transpose(m_p14)))

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 1))
# One row per P14 location and one column per Subtype_gate category. Each row
# was min-max scaled above, so colors compare partners within a row.
heatmap_style = dict(cmap=colormap, linecolor="white", linewidths=0.5)
sns.heatmap(
    normalized_array,
    xticklabels=adata.obs["Subtype_gate"].cat.categories,
    yticklabels=population_of_intertest,
    ax=ax,
    **heatmap_style,
)