In [None]:
# openpyxl is presumably required as the pandas .xlsx engine for the optional Excel export
# of the per-gene correlation tables (the `file` argument of correlation_heatmap).
# NOTE(review): pin a version (e.g. `%pip install -q openpyxl==<version>`) for reproducibility.
%pip install openpyxl

In [None]:
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import scipy
import seaborn as sns
import warnings
import matplotlib.colors as clr

# Silences every warning for the rest of the kernel session, including warnings raised by
# scanpy/scipy/pandas in later cells.
# NOTE(review): consider narrowing this to the specific warnings that are expected.
warnings.filterwarnings("ignore")

# NOTE(review): project-local module; it is not referenced in the cells shown here — confirm it is still needed.
import data

In [None]:
# Load the time-course AnnData object used by every analysis below.
adata = sc.read_h5ad(filename="../data/adata/timecourse.h5ad")

In [None]:
# Normalize: scale each cell to 10,000 total counts, then log(1 + x).
# NOTE: both scanpy calls modify `adata` in place, so re-running this cell without
# re-loading the data applies the normalization / log transform a second time.
sc.pp.normalize_total(adata, target_sum=10_000.0)
sc.pp.log1p(adata)

In [None]:
# Core TRM gene signature from Milner et al., Nature 2017.
# The file has no header row; the gene symbols are taken from its first column.
trm_signature_table = None  # placeholder removed below; keeps the cell self-describing
trm_signature = (
    pd.read_csv(
        "../data/signatures/Core Trm signature_Milner et al Nature 2017_vIL.txt",
        header=None,
    )
    .iloc[:, 0]
    .tolist()
)
del trm_signature_table

Helper functions

In [None]:
def df_correlation(adata, var):
    """
    Calculate the Spearman correlation between an observation and the expression of every gene

    Cells whose value in ``adata.obs[var]`` is NaN are left out before correlating. Each gene is
    correlated with the observation on its own, so memory grows linearly with the number of genes
    (``scipy.stats.spearmanr(x, X)`` on a 2-D ``X`` would build a full gene-by-gene matrix).

    Parameters:
    - adata (anndata): The anndata object containing the gene expression (dense or sparse ``X``)
    - var (str): The numeric observation (column of ``adata.obs``) to correlate with

    Return:
    - results (pd.DataFrame): One row per gene, in ``adata.var_names`` order, with columns
      ``gene``, ``rho`` (Spearman correlation coefficient), ``p`` (two-sided p-value) and ``var``.
      Genes that are constant across the kept cells get NaN for ``rho`` and ``p``.
    """
    from scipy import sparse, stats

    values = adata.obs[var].to_numpy(dtype=float)
    mask = ~np.isnan(values)
    x = values[mask]

    expr = adata.X[mask]
    if sparse.issparse(expr):
        expr = expr.toarray()
    expr = np.asarray(expr, dtype=float)

    # Spearman's rho is Pearson's r computed on average ranks (the same ranking spearmanr uses).
    x_rank = stats.rankdata(x)
    expr_rank = stats.rankdata(expr, axis=0)
    x_centered = x_rank - x_rank.mean()
    expr_centered = expr_rank - expr_rank.mean(axis=0)

    with np.errstate(divide="ignore", invalid="ignore"):
        rho = (x_centered @ expr_centered) / np.sqrt(
            (x_centered @ x_centered) * (expr_centered**2).sum(axis=0)
        )
        # Guard against rounding pushing |rho| marginally above 1 (np.corrcoef clips likewise).
        rho = np.clip(rho, -1.0, 1.0)
        # Two-sided p-value from the t distribution with n - 2 degrees of freedom, as in spearmanr.
        dof = x.size - 2
        t = rho * np.sqrt((dof / ((rho + 1.0) * (1.0 - rho))).clip(0))
    p = 2 * stats.t.sf(np.abs(t), dof)

    results = pd.DataFrame({"gene": adata.var_names, "rho": rho, "p": p, "var": var})
    return results

def filter_adata_expressed_in_n_cells(adata, percent=0.05):
    """
    Filter the adata to only include genes detected (expression > 0) in more than a given fraction of cells

    Parameters:
    - adata (anndata): The anndata object containing the gene expression (dense or sparse ``X``)
    - percent (float): Fraction of cells, between 0 and 1 (0.05 means 5%), in which a gene must be
      expressed to be kept; the comparison is strict (``>``)

    Return:
    - adata (anndata): The filtered anndata object. It is taken from a copy, so the input is left
      untouched, and it carries a boolean ``"bin"`` layer marking the non-zero entries.
    """
    adata = adata.copy()
    adata.layers["bin"] = adata.X > 0
    # On a scipy sparse matrix ``.mean(axis=0)`` returns a (1, n_genes) np.matrix; flatten it to
    # the 1-D boolean mask that ``adata[:, keep]`` expects. Dense arrays are unaffected.
    fraction_expressing = np.asarray(adata.layers["bin"].mean(axis=0)).ravel()
    keep = fraction_expressing > percent
    return adata[:, keep]

def _annotate_highlighted_genes(ax, points, x_pos, sign, x_offset, y_offset):
    """
    Label highlighted genes on one side of zero for a single swarm column

    Parameters:
    - ax (matplotlib axis): The axis to draw on
    - points (pd.DataFrame): Rows with ``gene`` and ``rho`` columns, all on the side given by ``sign``
    - x_pos (int): Position of the swarm column on the x axis
    - sign (int): +1 for positive correlations (labels stacked upwards), -1 for negative (downwards)
    - x_offset (float): Horizontal distance from the column position to the label
    - y_offset (float): Minimum vertical spacing between consecutive labels
    """
    # Walk outwards from zero (smallest |rho| first). Each label sits at its gene's rho unless
    # that is within y_offset of the previous label, in which case it is pushed y_offset further out.
    ordered = points.sort_values(by=["rho"], ascending=sign > 0)
    label_y = 0
    for _, point in ordered.iterrows():
        rho = point["rho"]
        if sign * (rho - label_y) < y_offset:
            label_y += sign * y_offset
        else:
            label_y = rho
        # Leader line: short horizontal tick at the gene's rho, then a segment to the label.
        ax.plot(
            (x_pos + 0.1, x_pos + 0.15, x_pos + x_offset),
            (rho, rho, label_y),
            color="black",
            linewidth=0.5,
        )
        ax.text(
            x_pos + x_offset + 0.02,
            label_y,
            f'{point["gene"]}',
            fontsize=10,
            color="black",
            va="center_baseline",
        )

def correlation_violin(
    adata,
    title,
    vars=["crypt_villi_axis", "epithelial_distance_clipped", "predicted_longitudinal"],
    highlight=trm_signature,
    rho_cutoff=0.2,
):
    """
    Create a swarm plot of the per-gene Spearman correlation between gene expression and observations

    There is one swarm column per observation in ``vars``, colored by whether rho is above
    ``rho_cutoff``, below ``-rho_cutoff``, or in between. Genes with a NaN rho count as
    "no correlation". Each column is annotated with the number of genes beyond each cutoff.
    Highlighted genes beyond the cutoff are labelled. The name is historical: the violin
    layer is disabled.

    Parameters:
    - adata (anndata): The anndata object containing the gene expression
    - title (str): The title of the plot
    - vars (list): The observations to correlate with
    - highlight (list): The genes to label in the plot (default: ``trm_signature`` as it was when
      this cell ran)
    - rho_cutoff (float): The cutoff for the correlation coefficient to be considered correlated

    Return:
    - fig, ax: The matplotlib figure and axis
    """
    df = pd.concat(
        [df_correlation(adata, var) for var in vars],
        ignore_index=True,
    )

    # Declare all three levels up front so the legend order and palette mapping stay fixed even
    # when a class is empty (reorder_categories would raise in that case).
    categories = ["positive correlation", "no correlation", "negative correlation"]
    df["color"] = pd.Categorical(
        np.select(
            [df["rho"] > rho_cutoff, df["rho"] < -rho_cutoff],
            [categories[0], categories[2]],
            default=categories[1],
        ),
        categories=categories,
        ordered=True,
    )

    fig, ax = plt.subplots()
    ax.axhline(y=rho_cutoff, linestyle=":", color="#AAAAAA", linewidth=0.5)
    ax.axhline(y=-rho_cutoff, linestyle=":", color="#AAAAAA", linewidth=0.5)
    sns.swarmplot(
        df,
        x="var",
        y="rho",
        ax=ax,
        order=vars,
        s=2,
        hue="color",
        palette=["#E15759", "#EDC94884", "#4E79A7"],
    )
    ax.set_xlabel("")
    ax.legend(bbox_to_anchor=(1.02, 0.55), loc="upper left", borderaxespad=0)

    # pandas max/min skip NaN rho values, which np.max/np.min would propagate.
    rho_max = df["rho"].max()
    rho_min = df["rho"].min()

    for i, var in enumerate(vars):
        in_var = df["var"] == var
        n_positive = int((in_var & (df["rho"] > rho_cutoff)).sum())
        n_negative = int((in_var & (df["rho"] < -rho_cutoff)).sum())
        ax.text(i, 0.9 * rho_max, f"n = {n_positive}", ha="center")
        ax.text(i, 0.9 * rho_min, f"n = {n_negative}", ha="center")

    filtered_points = df[(df["rho"].abs() > rho_cutoff) & (df["gene"].isin(highlight))]

    for i, var in enumerate(vars):
        column_points = filtered_points[filtered_points["var"] == var]
        _annotate_highlighted_genes(
            ax, column_points[column_points["rho"] > 0], i, 1, x_offset=0.3, y_offset=0.04
        )
        _annotate_highlighted_genes(
            ax, column_points[column_points["rho"] < 0], i, -1, x_offset=0.3, y_offset=0.04
        )

    ax.set_title(title)
    ax.set_ylabel(r"Spearman correlation: $\rho$")

    return (fig, ax)

Figure 2e

In [None]:
# P14 CD8 T cells only, restricted to genes detected in enough of them.
p14_cells = adata[adata.obs["Subtype"] == "Cd8_T-Cell_P14"]
# Genes to label: the Milner TRM core signature plus additional markers of interest.
p14_highlight = [
    *trm_signature,
    "Itgae", "Ifng", "Il18r1", "Il7r", "Klrg1",
    "Klf2", "Ldlr", "Slamf6", "Tgfbr2", "Tcf7",
]
fig, ax = correlation_violin(
    adata=filter_adata_expressed_in_n_cells(p14_cells),
    title="Spearman correlation of all P14 cells (Subtype)",
    rho_cutoff=0.05,
    highlight=p14_highlight,
)

Figure 2f

In [None]:
def _percent_correlated(correlations, rho_cutoff, p_cutoff):
    """
    Percent of genes that are both significant (p < p_cutoff) and correlated (|rho| > rho_cutoff)

    The denominator is every gene in ``correlations``, not only the significant ones.
    An empty table gives 0.
    """
    n_genes = len(correlations)
    if n_genes == 0:
        return 0.0
    hits = (correlations["p"] < p_cutoff) & (correlations["rho"].abs() > rho_cutoff)
    return hits.sum() / n_genes * 100

def _write_correlation_sheets(sheets, file):
    """
    Write each table in ``sheets`` (sheet name -> DataFrame) to its own Excel sheet

    ``file`` may be an open ``pd.ExcelWriter``, which is left open for the caller. It may also be
    a path; then one writer is opened for all sheets, because calling ``to_excel(path, ...)`` once
    per sheet would rewrite the workbook each time and keep only the last sheet.
    """
    def dump(writer):
        for name, table in sheets.items():
            table.to_excel(writer, sheet_name=str(name))

    if isinstance(file, pd.ExcelWriter):
        dump(file)
    else:
        with pd.ExcelWriter(file) as writer:
            dump(writer)

def correlation_heatmap(
    adata,
    obs,
    types=None,
    vars=["crypt_villi_axis", "epithelial_distance_clipped", "predicted_longitudinal"],
    rho_cutoff=0.2,
    file=None,
    p_cutoff=0.05,
):
    """
    Create a clustered heatmap of how many genes correlate with each observation, per cell type

    For each cell type, genes are filtered with ``filter_adata_expressed_in_n_cells`` and then
    correlated with every observation in ``vars`` using ``df_correlation``. Each heatmap cell is
    the percent of those genes with p < ``p_cutoff`` and |rho| > ``rho_cutoff``.

    Parameters:
    - adata (anndata): The anndata object containing the gene expression
    - obs (str): The ``adata.obs`` column holding the cell-type annotation; cells where it is
      missing are dropped
    - types (list): The cell types to include (default: every value present in ``adata.obs[obs]``)
    - vars (list): The observations to correlate with
    - rho_cutoff (float): Minimum absolute Spearman rho for a gene to count as correlated
    - file (str or pd.ExcelWriter): Optional destination for the per-gene correlation tables,
      one sheet per cell type
    - p_cutoff (float): Significance threshold on the correlation p-value. It is kept separate
      from ``rho_cutoff``; calls that pass ``rho_cutoff=0.05`` give the same result as before.

    Return:
    - cg (seaborn ClusterGrid): The clustermap
    """
    print("Making heatmap for ", obs)
    adata = adata[~adata.obs[obs].isnull()]
    if types is None:
        types = adata.obs[obs].unique()

    rows = {}  # cell type -> percent of correlated genes, one value per entry of vars
    sheets = {}  # cell type -> per-gene correlation table (only collected when exporting)
    for t in types:
        print("... calculating correlation for ", t)
        a = filter_adata_expressed_in_n_cells(adata[adata.obs[obs] == t])
        per_var = [df_correlation(a, var) for var in vars]
        rows[t] = [_percent_correlated(df, rho_cutoff, p_cutoff) for df in per_var]
        if file:
            sheets[t] = pd.concat(per_var, ignore_index=True)

    if file:
        _write_correlation_sheets(sheets, file)

    # Build the table once, so it gets a float dtype, instead of growing it row by row.
    result = pd.DataFrame.from_dict(rows, orient="index", columns=list(vars))

    cg = sns.clustermap(
        result.transpose(),
        annot=False,
        fmt=".1f",
        linewidth=0.5,
        cmap=clr.LinearSegmentedColormap.from_list(
            "ZissouBlues",
            ["#FFFFFF", "#edcfc7", "#d8a191", "#bf735f", "#a34630", "#830f00"],
        ),
        dendrogram_ratio=(1 / (len(types) + 1), 0.2),
        figsize=(3 + len(types) * 0.5, 4),
    )
    plt.setp(cg.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
    cg.fig.suptitle(
        f"Percent of significantly correlated genes\nrho={rho_cutoff}, annotation={obs}",
        y=1.1,
    )
    return cg

In [None]:
# Subtype labels to include in the heatmap, in this order.
# NOTE(review): "Fibroblast_Pdgfrb+ " has a trailing space — presumably it matches the label
# stored in adata.obs["Subtype"]; verify before changing it.
types = [
    "Goblet", "Enterocyte_1", "Monocyte",
    "Cd8_T-Cell_P14", "Cd8_T-Cell_aa+", "Cd8_T-Cell_ab+",
    "MAIT", "T-Cell gd", "Enterocyte_2",
    "Macrophage", "ILC", "Cd4_T-Cell",
    "B-Cell", "Enteroendocrine", "cDC1",
    "Early_Enterocyte", "Enterocyte_3", "DC2",
    "Lymphatic", "Tuft", "NK-Cell",
    "Fibroblast", "Transit_Amplifying", "Fibroblast_Pdgfrb+ ",
    "Vascular Endothelial", "ISC", "Paneth",
    "Neuron", "Fibroblast_Ncam1", "Fibroblast_Pdgfra+",
]

# To also save the per-gene correlation tables as an Excel file (one sheet per cell type):
# with pd.ExcelWriter("tables/correlation_subtype.xlsx") as writer:
#     fig = correlation_heatmap(adata, obs="Subtype", rho_cutoff=0.05, types=types, file=writer)
fig = correlation_heatmap(
    adata,
    obs="Subtype",
    types=types,
    rho_cutoff=0.05,
    file=None,
)