# Setup

## Imports & Constants

Update the values below if, for instance, the default data-storage directories or the column names in the metadata export change.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import functools
import getpass
import os
import re

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sb
import spatialdata
import spatialdata_io as sdio

import corescpy as cr

# Data root: the lab server mounts storage at /mnt/cho_lab; otherwise use /mnt.
ddm = "/mnt/cho_lab" if os.path.exists("/mnt/cho_lab") else "/mnt"  # Spark?
# NOTE(review): `ddu` was never defined, so the non-cho_lab branch below raised
# NameError. It is now read from XENIUM_DATA_ROOT and falls back to `ddm`.
# TODO confirm the intended default data root on machines without cho_lab.
ddu = os.environ.get("XENIUM_DATA_ROOT", ddm)
# getpass.getuser() instead of os.getlogin(): the latter raises OSError in
# kernels without a controlling terminal (Jupyter servers, containers).
ddl = (f"{ddm}/disk2/{getpass.getuser()}/data/shared-xenium-library"
       if "cho" in ddm else os.path.join(ddu, "shared-xenium-library"))
file_mdf = os.path.join(ddl, "samples.csv")  # metadata
path_marker_maps = os.path.join(
    os.path.expanduser("~"), "corescpy/examples/markers_lineages.csv")
col_inflamed, col_stricture = "Inflamed", "Stricture"  # in metadata file
key_uninfl, key_infl, key_stric = "Uninflamed", "Inflamed", "Stricture"
col_sample_id_o = "Sample ID"  # in metadata file
col_condition = "Condition"  # constructed from col_inflamed & col_stricture
col_sample_id, col_subject = "Sample", "Patient"

Downloading data from `https://omnipathdb.org/queries/enzsub?format=json`
Downloading data from `https://omnipathdb.org/queries/interactions?format=json`
Downloading data from `https://omnipathdb.org/queries/complexes?format=json`
Downloading data from `https://omnipathdb.org/queries/annotations?format=json`
Downloading data from `https://omnipathdb.org/queries/intercell?format=json`
Downloading data from `https://omnipathdb.org/about?format=text`


## Options & Data

These are analysis options; change them depending on what you want to analyze.

In [4]:
# Settings/Options
sc.settings.set_figure_params(dpi_save=200, dpi=200)
panel = "TUQ97N"
# os.path.join, not str(): str(a, b) treats `b` as an encoding and raises
# TypeError ("decoding str is not supported") when `a` is already a str.
out_dir = os.path.join(ddl, f"outputs/{panel}/nebraska")
samps = ["Uninflamed-50006B", "Inflamed-50006A", "Stricture-50006C",
         "Uninflamed-50217B", "Inflamed-50217A", "Stricture-50217C",
         "Uninflamed-50336C", "Inflamed-50336B", "Stricture-50336A",
         "Uninflamed-50452A", "Inflamed-50452B", "Stricture-50452C"]
col_cell_type = "bucket_res1pt5_dist0_npc30"
hue_order = ["Uninflamed", "Inflamed", "Stricture"]
palette = ["blue", "red", "yellow"]
layer = "counts"
threshold = 0

# Quantification Files/Options
c_m = "bucket"
c_ann = ["bucket", "bin", "annotation"]
directory = "/home/elizabeth/elizabeth/projects/senescence/analysis"  # TODO: make configurable
label = "snc" if "senescence" in directory else "csf"
fff_tx = os.path.join(directory, f"quantification_{label}_tx_cts")
fff_ct = os.path.join(directory, f"quantification_ncells_{label}")

# Gene Categories
assign = pd.read_csv(path_marker_maps).set_index("gene").rename_axis("Gene")
snc_genes = {
    "Cell Cycle Control": ["CDKN2A", "CDKN1A", "PLAUR", "TP53"],
    "SASP": ["IL6", "IL6ST", "IL1A", "CXCL8", "CCL2", "CEBPB",
             "NFKB1", "IGFBP7", "TGFB1"],
    # "DNA Repair": ["ERCC1", "ERCC4"],
    "Fibrosis": ["PLAUR", "IL1", "IL4", "IL6", "IL11", "IL12", "IL13", "IL17",
                 "IL22", "IL23", "IL33", "IL34", "IL36", "TL1A", "TGFβ",
                 "CXCR4", "CCL8", "CCL11", "CHMP1A", "TBX3",
                 "RNF168", "SUCNR1"]}

# Data
adatas = {s: sc.read(os.path.join(out_dir, s + ".h5ad")) for s in samps}

# Constructed Variables
# Restrict gene lists to genes measured in EVERY sample. The original used the
# var_names of whichever sample the loading loop happened to end on (leaked
# loop variable). Order-preserving filters replace set() so gene order in the
# dot plots is reproducible across kernel restarts (str hashing is randomized).
shared_genes = set(functools.reduce(
    pd.Index.intersection, (a.var_names for a in adatas.values())))
marker_genes_dict = assign["Bucket"].reset_index().groupby("Bucket")[
    "Gene"].apply(lambda x: list(dict.fromkeys(
        g for g in x if g in shared_genes))).to_dict()  # cell type markers
snc_genes_dict = {k: list(dict.fromkeys(g for g in v if g in shared_genes))
                  for k, v in snc_genes.items()}
hue = dict(hue=col_condition, hue_order=hue_order, palette=palette)
c_l = "leiden_" + "_".join(col_cell_type.split("_")[1:])  # Leiden version

# Annotations: per-sample Excel maps from Leiden cluster -> annotation columns
fmrs = {}
for x in samps:
    fann = os.path.join(out_dir, "annotation_dictionaries",
                        f"{x}___{c_l}_dictionary.xlsx")
    if os.path.exists(fann):
        fmr = pd.read_excel(fann).astype(str).replace("nan", np.nan)
        fmrs[x] = fmr.set_index(fmr.columns[0]).rename_axis(c_l)
if not fmrs:  # pd.concat({}) would fail with an opaque message
    raise FileNotFoundError(
        f"No annotation dictionaries found in {out_dir}/annotation_dictionaries")
fmrs = pd.concat(fmrs, names=["Sample", c_l])

## Samples

In [15]:
# Load Metadata
m_d = (pd.read_excel if file_mdf[-4:] == "xlsx" else pd.read_csv)(
    file_mdf, dtype={"Slide ID": str}).rename({
        "Name": col_subject, "Inflammation": col_inflamed}, axis=1)
# Condition is "Stricture" when the stricture column says so, otherwise the
# capitalized inflammation status. str() keeps a missing (NaN) stricture value
# from raising AttributeError on .lower().
m_d.loc[:, col_condition] = m_d.apply(lambda x: "Stricture" if str(x[
    col_stricture]).lower() in ["stricture", "yes"] else x[
        col_inflamed].capitalize(), axis=1)  # inflammation/stricture condition
m_d.loc[:, col_sample_id] = m_d[[col_condition, col_sample_id_o]].apply(
    "-".join, axis=1)
m_d = m_d.set_index(col_sample_id).loc[samps]
display(m_d[col_condition].value_counts())
# Not the cell's last expression, so it must be displayed explicitly
# (the original bare expression was silently discarded).
display(m_d.reset_index().set_index([col_subject, col_condition]).sort_index()[
    ["Age", "Sex", "Race", "Hispanic", "Diagnosis", "Disease_Status"]])

# Add Metadata & Annotations to AnnData Objects & Concatenate
for s in adatas:
    adatas[s].obs.loc[:, col_condition] = m_d.loc[s][col_condition]
    adatas[s].obs.loc[:, col_subject] = m_d.loc[s][col_subject]
    # Map Leiden cluster IDs to annotations; IDs absent from the map are kept.
    adatas[s].obs.loc[:, col_cell_type] = adatas[s].obs[c_l].astype(
        str).replace(dict(fmrs.loc[s][c_m]))  # annotations
# NOTE(review): AnnData.concatenate is deprecated in favor of anndata.concat;
# migrating changes var merging/obs naming, so it is left as-is for now.
adata = anndata.AnnData.concatenate(
    *[adatas[x] for x in adatas], join="outer", batch_key=col_sample_id,
    batch_categories=list(adatas.keys()), index_unique="-", fill_value=None,
    uns_merge="same")  # concatenate adata

Condition
Uninflamed    4
Inflamed      4
Stricture     4
Name: count, dtype: int64


# UMAPs

In [None]:
# One UMAP per sample: rows = subjects, columns = conditions.
subjects = m_d[col_subject].unique()
conditions = m_d[col_condition].unique()
lib_lookup = m_d.reset_index().set_index(
    [col_subject, col_condition])[col_sample_id]  # (subject, condition) -> sample
fig, axes = plt.subplots(len(subjects), len(conditions), dpi=200,
                         figsize=(60, 60), squeeze=False)
for r, s in enumerate(subjects):
    for c, g in enumerate(conditions):
        if (s, g) not in lib_lookup.index:  # subject lacks this condition
            axes[r, c].set_axis_off()
            continue
        lib = lib_lookup.loc[(s, g)]
        sc.pl.umap(adatas[lib], color=col_cell_type, title="",
                   frameon=False, show=False, ax=axes[r, c],
                   legend_loc="on data")
        # (Removed a stray `axes.flatten()[i].set_title(x)` that used
        # undefined/leaked `i` and `x` and was overwritten anyway.)
        axes[r, c].set_title(f"{lib} (Age {m_d['Age'].loc[lib]})")
plt.show()

# Spatial Plots

In [None]:
# NOTE(review): `selves` (per-sample corescpy spatial objects) is not defined
# anywhere in this notebook. TODO: construct it before running this cell.
for s in selves:
    # Separate name: the original rebound `m_d` here, clobbering the sample
    # metadata table that later cells index by subject/condition.
    # `metadata` was undefined. `m_d` (indexed by sample ID) is assumed to be
    # the intended table; confirm that `_library_id` matches those IDs.
    meta = m_d.loc[s._library_id]
    fig = s.plot_spatial(color=col_cell_type)
    fig.set_title(f"{s._library_id} (Age {meta['Age']})")
    fig.fig.set_dpi(200)

# Quantifications

## Native Plotting

### Facet=Cell Type; Rows=Condition

In [None]:
# Cell types present (pooled across samples) in every condition of hue_order.
# Computed here because it was previously only defined in a LATER cell, so this
# cell failed with NameError on a fresh kernel.
n_cells = pd.crosstab(adata.obs[col_cell_type], adata.obs[col_condition])
in_all_cond = [
    i for i in adata.obs[col_cell_type].unique() if i in n_cells.index
    and (n_cells.loc[i].reindex(hue_order, fill_value=0) > 0).all()]
for j, k in enumerate(["var", "group"]):  # scale per gene, then per condition
    # squeeze=False keeps a 2-D array even for a 1x1 grid, so flatten() works.
    fig, axes = plt.subplots(*cr.pl.square_grid(len(in_all_cond)),
                             figsize=(40, 20), squeeze=False)
    axes_flat = axes.flatten()
    for i, x in enumerate(in_all_cond):
        sc.pl.dotplot(
            adata[adata.obs[col_cell_type] == x], snc_genes_dict, layer=layer,
            groupby=col_condition, title=x, ax=axes_flat[i],
            use_raw=False, show=False, expression_cutoff=threshold,
            figsize=(40, 20), standard_scale=k)
    for ax in axes_flat[len(in_all_cond):]:  # hide unused grid cells
        ax.set_axis_off()
    fig.suptitle(f"GEX (Normalized by {['Gene', 'Condition'][j]})")
    plt.subplots_adjust(hspace=1, wspace=1, top=0.85)
    plt.show()  # fig.show() only warns on the non-GUI inline backend

### Facet=Condition; Rows=Cell Type

In [None]:
# One dot plot per condition, cell types as rows.
conditions_present = adata.obs[col_condition].unique()
for j, k in enumerate(["var", "group"]):  # scale per gene, then per cluster
    # squeeze=False keeps a 2-D array even for a 1x1 grid, so flatten() works.
    fig, axes = plt.subplots(*cr.pl.square_grid(len(conditions_present)),
                             figsize=(40, 20), squeeze=False)
    axes_flat = axes.flatten()
    for i, x in enumerate(conditions_present):
        sc.pl.dotplot(
            adata[adata.obs[col_condition] == x], snc_genes_dict, layer=layer,
            groupby=col_cell_type, ax=axes_flat[i],
            use_raw=False, show=False, expression_cutoff=threshold,
            figsize=(40, 20), standard_scale=k)
        axes_flat[i].set_title(x)
    for ax in axes_flat[len(conditions_present):]:  # hide unused grid cells
        ax.set_axis_off()
    fig.suptitle(f"GEX (Normalized by {['Gene', 'Cluster'][j]})")
    plt.subplots_adjust(hspace=1, wspace=0.5, top=0.85)
    plt.show()  # fig.show() only warns on the non-GUI inline backend

### By Sample

In [None]:
# One dot plot per sample: rows = subjects, columns = conditions.
subjects = m_d[col_subject].unique()
conditions = m_d[col_condition].unique()
lib_lookup = m_d.reset_index().set_index(
    [col_subject, col_condition])[col_sample_id]  # (subject, condition) -> sample
for j, k in enumerate(["var", "group"]):  # scale per gene, then per cluster
    fig, axes = plt.subplots(len(subjects), len(conditions), dpi=200,
                             figsize=(60, 60), squeeze=False)
    for r, s in enumerate(subjects):
        for c, g in enumerate(conditions):
            if (s, g) not in lib_lookup.index:  # subject lacks this condition
                axes[r, c].set_axis_off()
                continue
            lib = lib_lookup.loc[(s, g)]
            sc.pl.dotplot(
                adatas[lib], snc_genes_dict, layer=layer,
                groupby=col_cell_type, ax=axes[r, c],
                use_raw=False, show=False, expression_cutoff=threshold,
                figsize=(40, 20), standard_scale=k)
            axes[r, c].set_title(f"{lib} (Age {m_d['Age'].loc[lib]})")
    fig.suptitle(f"GEX (Normalized by {['Gene', 'Cluster'][j]})")
    plt.subplots_adjust(hspace=1, wspace=0.5, top=0.85)
    plt.show()  # fig.show() only warns on the non-GUI inline backend

### Transcripts

In [None]:
layer = "counts"
threshold = 0

# Cell types present (pooled across samples) in every condition of hue_order.
n_cells = pd.crosstab(adata.obs[col_cell_type], adata.obs[col_condition])
in_all_cond = [
    i for i in adata.obs[col_cell_type].unique() if i in n_cells.index
    and (n_cells.loc[i].reindex(hue_order, fill_value=0) > 0).all()]


def _tx_per_100_cells(ann, cell_type, gene, layer):
    """Return 100 * (summed counts of `gene`) / (number of `cell_type` cells).

    Counts come from `ann.layers[layer]` (or `ann.X` if `layer` is None).
    NOTE(review): the original summed `.X` even though `layer = "counts"` was
    set; raw counts are assumed to be intended here. Confirm.
    Returns NaN when the sample has no such cells or does not measure `gene`.
    """
    sub = ann[ann.obs[col_cell_type] == cell_type]
    if sub.n_obs == 0 or gene not in sub.var_names:
        return np.nan
    mat = sub[:, gene].X if layer is None else sub[:, gene].layers[layer]
    return 100 * np.sum(mat) / sub.n_obs


percs_txs = {}
# Iterate the panel-filtered gene lists: raw `snc_genes` contains names
# (e.g., "IL1", "TGFβ") absent from var_names, which raised KeyError.
for c, genes in snc_genes_dict.items():
    if not genes:  # no measured genes in this category
        continue
    stx = {}
    for g in genes:
        stx[g] = pd.concat(
            [pd.Series([_tx_per_100_cells(adatas[a], x, g, layer)
                        for x in in_all_cond], index=in_all_cond)  # was `all_cond` (undefined)
             for a in adatas],
            keys=list(adatas.keys()), names=["Sample"])
    percs_txs[c] = pd.concat(stx, names=["Gene"])
percs_txs = pd.concat(percs_txs, names=["Category"])

## Saved Quantifications

### Load

In [None]:
# Transcript Ratios
txs = pd.read_csv(f"{fff_tx}.csv")
# Read first, THEN reuse the index names: the original chain referenced
# `tx_cl.index.names` while `tx_cl` was still being defined (NameError on a
# fresh kernel; stale data otherwise).
tx_cl = pd.read_csv(fff_tx + "_by_cluster.csv", index_col=list(range(
    len(txs.index.names))))
idx_tx = list(tx_cl.index.names)
# Cluster IDs as str to match the annotation table, then swap in the
# current annotation columns (dropping any stale copies from the file).
tx_cl = tx_cl.reset_index().astype({c_l: str}).set_index(idx_tx)
tx_cl = tx_cl.drop(columns=sorted(set(c_ann).intersection(tx_cl.columns))).join(
    fmrs[c_ann], on=["Sample", c_l])
# De-duplicate, sort clusters numerically, then store IDs as "string" dtype.
tx_cl = tx_cl.reset_index().drop_duplicates().astype(
    {c_l: int}).set_index(idx_tx).sort_index().reset_index(
        ).astype({c_l: "string"}).set_index(idx_tx)

# Gene+ Cell Ratios
cts = pd.read_csv(fff_ct + ".csv", index_col=0)
cts_cl = pd.read_csv(fff_ct + "_by_cluster.csv", header=[0, 1],
                     index_col=[0, 1, 2])
# Annotation columns under a top-level "Annotation" header, matching the
# two-level column header of `cts_cl`.
cts_cl_s = pd.concat([fmrs[c_ann].reset_index().astype("string").set_index(
    fmrs.index.names)], axis=1, keys=["Annotation"])
cts_cl_s = cts_cl_s.reset_index().drop_duplicates().astype(str).set_index(
    cts_cl_s.index.names)
# Drop any stale annotation columns, then attach the current ones.
# NOTE: astype(str) also stringifies the numeric columns; "Percent" is
# converted back with astype(float) where it is plotted.
cts_cl = cts_cl[cts_cl.columns.difference(cts_cl_s.columns)].reset_index(
    ).drop_duplicates().astype(str).set_index(
        cts_cl.index.names).join(cts_cl_s)

### Plot

#### Transcripts

In [None]:
y_label = "Gene Transcript Counts: Percent of Total Counts"
# Each gene's share (in percent) of its row's total transcript counts.
dff = tx_cl.drop(columns=c_ann).div(tx_cl["total_counts"], axis=0).drop(
    columns="total_counts").rename_axis("Gene", axis=1).stack() * 100
# Attach subject/condition (joined on the "Sample" index level) + annotations.
# NOTE(review): the original joined an undefined `metadata_o`. `m_d` (sample
# metadata indexed by sample ID) is assumed to be the intended table.
dff = dff.to_frame(y_label).join(
    m_d[[col_subject, col_condition]], how="left").join(
        fmrs[c_ann]).reset_index().drop_duplicates()
# Values are already x100, so the sanity bound is 100 (the old bound of 1
# raised for any gene above 1% of counts).
if dff[y_label].max() > 100:
    raise ValueError(f"Percentages > 100: {dff[dff[y_label] > 100]}")
for x in [None, col_subject]:  # aggregate across subjects, or facet
    fig = sb.catplot(dff.dropna(subset=[c_m]), x=c_m, y=y_label,
                     kind="box" if x is None else "bar",
                     sharex=False, sharey=False, row=x, col="Gene",
                     margin_titles=True, **hue)
    fig.set_xticklabels(rotation=90)
    fig.set_titles(col_template="{col_name}", row_template="{row_name}")
    fig.fig.set_size_inches(60, 25)
    plt.subplots_adjust(hspace=0.8)
    fig.fig.tight_layout()
    plt.show()

#### Cells

In [None]:
# Long format: one row per (index levels..., Label) with a float "Percent".
vvr = cts_cl["Percent"].rename_axis("Label", axis=1).dropna(
    how="all").stack().to_frame("Percent").astype(float).join(fmrs[[c_m]])
# Condition = sample-name prefix (e.g., "Inflamed-50006A" -> "Inflamed").
vvr = vvr.join(vvr.groupby("Sample").apply(lambda x: x.name.split("-")[
    0]).to_frame(col_condition)).reset_index().drop_duplicates()
# Drop mixed annotations ("A|B"). na=False keeps unannotated (NaN) rows
# instead of raising TypeError as `"|" not in nan` did.
vvr = vvr[~vvr[c_m].astype("string").str.contains("|", regex=False, na=False)]
for x in vvr.Threshold.unique():
    fig = sb.catplot(vvr[vvr.Threshold == x], x=c_m, y="Percent",
                     sharey=False, kind="box", margin_titles=True,
                     col="Label", col_wrap=2, sharex=False, **hue)
    fig.set_titles(col_template="{col_name}", row_template="{row_name}")
    fig.set_xticklabels(rotation=90)
    fig.fig.set_size_inches(35, 25)
    fig.fig.suptitle(f"Threshold = {x}")
    fig.fig.set_dpi(100)
    fig.fig.tight_layout()
    plt.show()

# More Analysis

In [None]:
# Call corescpy's run_composition_analysis on each spatial object, passing the
# metadata condition column name.
# NOTE(review): `selves` is not defined anywhere in this notebook. TODO:
# construct it before running this cell.
for s in selves:
    s.run_composition_analysis(col_condition=col_condition)