# Setup

## Imports & Settings

In [1]:
%load_ext autoreload
%autoreload 2

import os
import re
import itertools
import matplotlib.pyplot as plt
import seaborn as sb
import scanpy as sc
import pandas as pd
import numpy as np
import corescpy as cr

# Computing Resources
gpu = False  # passed on as `use_gpu` / RAPIDS UMAP in the clustering options
sc.settings.n_jobs = 4
sc.settings.max_memory = 200

# Display (pandas tables; scanpy & matplotlib figures)
pd.set_option("display.max_colwidth", 1000)
pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 500)
sc.settings.set_figure_params(dpi=100, frameon=False, figsize=(30, 30))
plt.rcParams["figure.figsize"] = [15, 15]

# Panel Information: samples processed with the old segmentation
# (old processing arguments), grouped by subject
old_seg = [
    "50452A", "50452B", "50452C",  # CHO-001
    "50618B5", "50564A4",  # CHO-002
    "49377A2",  # CHO-003
    "49464A4",  # subject unknown (TODO confirm)
    "49696A4", "49559A5",  # CHO-004
    "50115A2", "50007B2",  # CHO-005
    "49471A4", "50445A3",  # CHO-006
]

# Panel & Column Names (from Metadata & To Be Created)
panel = "TUQ97N"
col_sample_id_o = "Sample ID"  # sample ID column in the metadata file
col_sample_id = "Sample"  # new "<condition>-<sample ID>" column
col_subject = "Patient"  # in metadata file
col_inflamed = "Inflamed"  # in metadata file
col_stricture = "Stricture"  # in metadata file
col_condition = "Condition"  # constructed from col_inflamed & col_stricture
col_fff = "file_path"  # column in metadata in which to store data file path
col_tangram = "tangram_prediction"  # for future Tangram imputation annotation
col_segment = "segmentation"  # "old" / "new" segmentation flag
key_uninfl = "Uninflamed"
key_infl = "Inflamed"
key_stric = "Stricture"

Downloading data from `https://omnipathdb.org/queries/enzsub?format=json`
Downloading data from `https://omnipathdb.org/queries/interactions?format=json`
Downloading data from `https://omnipathdb.org/queries/complexes?format=json`
Downloading data from `https://omnipathdb.org/queries/annotations?format=json`
Downloading data from `https://omnipathdb.org/queries/intercell?format=json`
Downloading data from `https://omnipathdb.org/about?format=text`


## Samples

In [None]:
# Paths. `ddu` is defined here as well so this cell runs on a fresh kernel
# (it was previously only set in the later "Options & Data" cell).
ddu = os.path.expanduser("~")
ddm = "/mnt/cho_lab" if os.path.exists("/mnt/cho_lab") else "/mnt"  # Spark?
ddl = f"{ddm}/disk2/{os.getlogin()}/data/shared-xenium-library" if (
    "cho" in ddm) else os.path.join(ddu, "shared-xenium-library")
file_mdf = os.path.join(ddl, "samples.csv")  # metadata

# Read sample metadata & harmonize column names
m_d = (pd.read_excel if file_mdf[-4:] == "xlsx" else pd.read_csv)(
    file_mdf, dtype={"Slide ID": str}).rename({
        "Name": col_subject, "Inflammation": col_inflamed}, axis=1)

# Flag samples processed with the old segmentation
m_d.loc[:, col_segment] = "new"
m_d.loc[m_d[col_sample_id_o].isin(old_seg), col_segment] = "old"

# Condition: stricture if flagged, else the capitalized inflammation status
# (`str()` so a blank stricture cell does not crash `.lower()`)
m_d.loc[:, col_condition] = m_d.apply(lambda x: key_stric if str(x[
    col_stricture]).lower() in ["stricture", "yes"] else x[
        col_inflamed].capitalize(), axis=1)  # inflammation/stricture condition

# New sample identifier: "<Condition>-<Sample ID>"
m_d.loc[:, col_sample_id] = m_d[[col_condition, col_sample_id_o]].apply(
    "-".join, axis=1)
m_d = m_d.set_index(col_sample_id)
print(m_d[[col_subject, col_condition]].reset_index(0)[
    col_condition].value_counts())  # number of samples per condition

# Sorted sample IDs of subjects having uninflamed, inflamed & stricture samples
samps_paired = m_d.groupby(col_subject).apply(
    lambda x: list(x.reset_index()[col_sample_id].sort_values()) if all((
        i in list(x.reset_index()[col_condition]) for i in [
            key_uninfl, key_infl, key_stric])) else np.nan).dropna(
                ).explode()
# print(list(samps_paired.sort_index()))
# m_d.reset_index().set_index(col_sample_id).loc[samps_paired.to_list()]
m_d.reset_index().set_index([col_subject, col_condition]).sort_index()

## Options & Data

In [None]:
# Directories & Metadata
load = True  # load previously saved objects (<out_dir>/<library_id>.h5ad)
reannotate = True
# run = "CHO-011"
# samples = "all"
run = None  # just look for samples in all runs
samples = ["50452A", "50452B", "50006A", "50006B",
           "50217A", "50217B", "50336B", "50336C"]  # paired (un)inflamed
# samples = ["50006B", "50006A",  "50006C",
#            "50217B", "50217A", "50217C",
#            "50564A4",
#            "50452A", "50452B", "50452C",
#            "50336C", "50336B",  "50336A"]  # all
# samples = ["50006C", "50217C", "50452C", "50336A"]  # paired strictures
dir_ax = "/home/elizabeth/elizabeth/projects/senescence/analysis"


# Optionally, Define Manual Annotation Versions
# should be stored in ("<out_dir>/annotations_dictionaries")
# in format <selves[i]._library_id>___leiden_<man_anns[i]>_dictionary.xlsx
# with first column = leiden cluster and second column = annotation
man_anns = True  # load manual annotations according to clustering kws
# man_anns = ["res0pt5_dist0pt5_npc30", "res0pt75_dist0pt3_npc30",
#             "res1pt5_dist0_npc30"]  # choose manual annotations to load
# man_anns = None  # do not load manual annotations

# Main Directories
# Replace manually or mirror my file/directory tree in your home (`ddu`)
ddu = os.path.expanduser("~")
ddm = "/mnt/cho_lab" if os.path.exists("/mnt/cho_lab") else "/mnt"  # Spark?
if "cho" in ddm:  # lab-server layout under <ddm>/disk2/<user>
    ddl = f"{ddm}/disk2/{os.getlogin()}/data/shared-xenium-library"
    d_path = os.path.join(ddm, "disk2", os.getlogin(), "data")
else:  # mirrored tree in the home directory / directly under <ddm>
    ddl = os.path.join(ddu, "shared-xenium-library")
    d_path = os.path.join(ddm, os.getlogin(), "data")
# (`d_path` holds other data, e.g., Tangram reference data)
ddx = f"{ddm}/bbdata2"  # mounted drive Xenium folder
out_dir = os.path.join(ddl, "outputs", "TUQ97N", "nebraska")  # None = no save
anf = pd.read_csv(os.path.join(ddu, "corescpy/examples/markers_lineages.csv"))
file_mdf = os.path.join(ddl, "samples.csv")  # metadata

# Annotation & Tangram Imputation
# NOTE(review): `col_assignment` is re-assigned to a list in the clustering
# options below, so this value is not used downstream.
col_assignment = "Bin"  # which column from annotation file to use
# col_cell_type_sc, file_sc = "ClusterAnnotation", str(
#     f"{d_path}/2023-05-12_CombinedCD-v2_ileal_new.h5ad")
col_cell_type_sc = "cell_type"
file_sc = f"{d_path}/elmentaite_ileal.h5ad"
# file_sc = None  # to skip Tangram imputation/label transfer

# Processing & Clustering Options
# Shared arguments: cluster on the marker genes only (first column of `anf`)
kws_cluster = dict(kws_umap=dict(method="rapids" if gpu else "umap"),
                   genes_subset=list(anf.iloc[:, 0]),  # use only markers
                   use_gpu=gpu, use_highly_variable=False)

# One parameter set per (resolution, UMAP min_dist, n PCs) combination, keyed
# by a file-path-friendly suffix such as "res0pt5_dist0pt5_npc30"
kws_clustering, col_assignment = {}, []
for res, min_dist, n_comps in zip([0.5, 0.75, 1.5], [0.5, 0.3, 0],
                                  [30, 30, 30]):
    kws = {**kws_cluster, "resolution": res, "kws_umap": {
        **kws_cluster["kws_umap"], "min_dist": min_dist}, "n_comps": n_comps}
    suff = (f"res{str(res).replace('.', 'pt')}"
            f"_dist{str(min_dist).replace('.', 'pt')}"
            f"_npc{n_comps}")  # file path suffix
    kws_clustering[suff] = kws
    # annotation-file column matching this granularity
    col_assignment.append("group" if res >= 0.7 else "Bucket")
if man_anns is True:
    man_anns = list(kws_clustering.keys())

# Default cell labels: last clustering, or its manual annotation if loaded
# (`False` now treated like `None` instead of failing on `False[-1]`)
col_cell_type = list(kws_clustering.keys())[-1] if (
    man_anns is None or man_anns is False) else f"manual_{man_anns[-1]}"

# After this point, no more options to specify
# Just code to infer the data file path from your specifications
# and construct argument dictionaries and manipulate metadata and such.

# Read Metadata & Other Information
metadata = (pd.read_excel if file_mdf[-4:] == "xlsx" else pd.read_csv)(
    file_mdf, dtype={"Slide ID": str}).rename({
        "Name": col_subject, "Inflammation": col_inflamed}, axis=1)
metadata.loc[:, col_segment] = "new"
metadata.loc[metadata[col_sample_id_o].isin(old_seg), col_segment] = "old"

# Revise Metadata & Construct Variables from Options
# (`str()` so a blank stricture cell does not crash `.lower()`)
metadata.loc[:, col_condition] = metadata.apply(lambda x: key_stric if str(x[
    col_stricture]).lower() in ["stricture", "yes"] else x[
        col_inflamed].capitalize(), axis=1)  # inflammation/stricture condition
metadata.loc[:, col_sample_id] = metadata[[col_condition, col_sample_id_o]
                                          ].apply("-".join, axis=1)
metadata_o = metadata.copy()  # copy before subsetting to `samples`
if samples not in ["all", None]:  # subset by sample ID?
    metadata = metadata.set_index(col_sample_id_o).loc[samples].reset_index()
metadata = metadata.set_index(col_sample_id)

# Match each sample to a Xenium output directory
fff = np.array(cr.pp.construct_file(run=run, directory=ddx, panel_id=panel))
bff = np.array([os.path.basename(i) for i in fff])  # base path names
# Sample ID = first "-"-separated token of the 3rd "__"-separated field of the
# directory's own name (not the full path, so "__" in a parent directory
# cannot shift the fields)
samps = np.array([os.path.basename(os.path.normpath(i)).split("__")[2].split(
    "-")[0] for i in fff])
for x in metadata[col_sample_id_o]:
    m_f = metadata[metadata[col_sample_id_o] == x][
        "out_file"].iloc[0]  # ...use to find unconventionally-named files
    locx = np.where(samps == x)[0] if pd.isnull(
        m_f) else np.where(bff == m_f)[0]
    metadata.loc[metadata[col_sample_id_o] == x, col_fff] = fff[locx[0]] if (
        len(locx) > 0) else np.nan  # assign output file to metadata row
metadata = metadata.dropna(subset=[col_fff]).drop_duplicates()  # found only

# Annotation File: marker genes with a value in every `col_assignment` column
assign = anf.dropna(subset=col_assignment).set_index(
    "gene").rename_axis("Gene")  # markers
# assign = assign[~assign.Quality.isin([-1])]  # drop low-quality markers

# Print Metadata & Make Output Directory (If Not Present)
print(metadata)
if out_dir is not None:  # out_dir = None means "do not save"
    os.makedirs(out_dir, exist_ok=True)

# Load Data: one `cr.Spatial` object per metadata row, with every metadata
# field copied (as a string) onto each cell's `.obs`
kws_init = dict(col_sample_id=col_sample_id, col_subject=col_subject,
                col_cell_type=col_cell_type)  # object creation arguments
selves = [None] * metadata.shape[0]  # to hold different samples
for i, x in enumerate(metadata.index.values):
    row = metadata.loc[x]
    obj = cr.Spatial(row[col_fff], library_id=x, **kws_init)
    selves[i] = obj
    for column in metadata:  # iterate metadata columns
        obj.rna.obs.loc[:, column] = str(row[column])  # add to object
    # Where the object is saved (and re-loaded from, with ".h5ad")
    obj.rna.obs.loc[:, "out_file"] = os.path.join(out_dir, obj._library_id)
    if load is not True:
        continue
    saved = obj.rna.obs.out_file.iloc[0]
    if os.path.exists(str(saved) + ".h5ad"):
        obj.update_from_h5ad(saved)
    print(obj.rna)

# Marker Gene Dictionary (for Scanpy Plotting): Bucket -> its marker genes that
# are present in the first sample's gene panel
panel_genes = selves[0].rna.var_names
marker_genes_dict = {
    bucket: list(pd.unique(list(set(grp.Gene).intersection(panel_genes))))
    for bucket, grp in assign["Bucket"].reset_index().groupby("Bucket")}

## Manual Annotations

### All Annotations

In [None]:
# ann = {}
# for x in [f for f in os.listdir(dir_ann) if os.path.isdir(
#         os.path.join(dir_ann, f)) is False]:
#     ann[x] = pd.read_excel(os.path.join(
#         out_dir, "annotation_dictionaries", x), index_col=0)
#     # ann[x] = ann[x].assign(Sample=x.split("___")[0])
# ann = pd.concat(ann, names=["File"])
# ann

### Load Manual Annotations

In [None]:
# Map each Leiden clustering's cluster IDs to the manual labels in the
# per-sample annotation dictionaries, adding categorical `<label>_<clustering>`
# columns ("annotation", "bin", "bucket") to `.obs`
if man_anns is not None and man_anns is not False:
    for s in selves:
        for r in man_anns:  # iterate Leiden clusterings
            f_ann = os.path.join(out_dir, "annotation_dictionaries", str(
                f"{s._library_id}___leiden_{r}_dictionary.xlsx"))  # file
            if os.path.exists(f_ann) is False:
                print(f"{f_ann} file NOT found.")
                continue
            print(f"{f_ann} file found.")
            if f"leiden_{r}" not in s.rna.obs:
                print(f"leiden_{r} not found in adata for {s._library_id}")
                continue
            # Blank Excel cells must stay missing (not the string "nan") so
            # those clusters fall back to their Leiden ID below
            fmr = pd.read_excel(f_ann).astype(str).replace("nan", np.nan)
            fmr = fmr.set_index(fmr.columns[0])  # cluster ID -> labels
            leiden = s.rna.obs[f"leiden_{r}"]
            leiden_ids = leiden.astype(int).astype(str)
            for x in ["annotation", "bin", "bucket"]:
                labels = leiden_ids.map(fmr[x].dropna().to_dict())
                labels = labels.fillna(leiden.astype(
                    str))  # missing annotations replaced with Leiden
                # Replace (rather than update in place) so re-running the cell
                # doesn't try to write new values into an existing categorical
                s.rna.obs[f"{x}_{r}"] = labels.astype("category")
                # s.plot_spatial(f"{x}_{r}")
                # s.plot_spatial(f"{x}_{r}")

### Cluster Cell Counts

#### Plot

In [None]:
# Percent of each sample's cells in each manual "bin" cluster (last clustering)
c_m = f"bin_{man_anns[-1]}"
lib_ids = [s._library_id for s in selves]
n_clusts = pd.concat([s.rna.obs[c_m].value_counts() for s in selves],
                     keys=lib_ids, names=["Sample"])  # cell counts
p_clusts = pd.concat([100 * s.rna.obs[c_m].value_counts() / s.rna.n_obs
                      for s in selves], keys=lib_ids, names=["Sample"])
p_clusts = p_clusts.unstack(1).replace(np.nan, 0).stack().to_frame("Percent")
all_samps = p_clusts.Percent.unstack().T.apply(
    lambda x: x if all(x > 0) else np.nan, axis=1).dropna().index
for x in [col_subject, col_condition]:  # add per-sample subject & condition
    p_clusts = p_clusts.join(pd.Series([
        s.rna.obs[x].iloc[0] for s in selves], index=pd.Index(
            lib_ids, name=p_clusts.index.names[0])).to_frame(x))
print(f"Clusters Present in All Samples: {', '.join(list(all_samps))}")
fig = sb.catplot(p_clusts, x=c_m, y="Percent", sharex=False,
                 kind="bar", margin_titles=True,
                 col=col_subject, hue=col_condition,
                 hue_order=[key_uninfl, key_infl, key_stric],
                 palette=["blue", "red", "yellow"])
fig.set_xticklabels(rotation=90)
fig.set_yticklabels(rotation=30)
fig.fig.tight_layout()
fig.fig.suptitle(f"Percent of Cells per Cluster ({c_m})")
fig.fig.set_size_inches(30, 30)
fig.fig.show()
# Cells per cluster (rows) x sample (columns); fill before casting to int
n_clusts.unstack().T.fillna(0).astype(int)

#### Write Cluster Ns to Annotation Dictionary

In [None]:
# Add/refresh an "n_cells" column (cells per Leiden cluster) in each sample's
# annotation dictionary files
if out_dir is not None and man_anns is not None and man_anns is not False:
    for s in selves:
        for r in man_anns:  # iterate Leiden clusterings
            f_dict = os.path.join(out_dir, "annotation_dictionaries", str(
                f"{s._library_id}___leiden_{r}_dictionary.xlsx"))  # file
            if os.path.exists(f_dict) is False:
                print(f"{f_dict} file NOT found.")
                continue
            print(f"{f_dict} file found.")
            if f"leiden_{r}" not in s.rna.obs:
                print(f"leiden_{r} not found in adata for {s._library_id}")
                continue
            fmr = pd.read_excel(f_dict)
            if "n_cells" in fmr:
                fmr = fmr.drop("n_cells", axis=1)  # replaced below
            i_x = fmr.columns[0]
            # Cast only the cluster-ID column to str (to match the Leiden
            # labels); casting everything wrote blank cells back as "nan"
            fmr[i_x] = fmr[i_x].astype(str)
            fmr = fmr.set_index(i_x).join(s.rna.obs[f"leiden_{r}"].astype(
                str).value_counts().to_frame("n_cells").rename_axis(i_x))
            fmr.to_excel(f_dict)

### Write Manual Annotations (Xenium Explorer Files)

In [None]:
# Export each manual-annotation column as Xenium Explorer cluster files, then
# save every (loaded) object
if man_anns not in [None, False] and out_dir is not None and load is True:
    for s in selves:
        candidates = (f"{x}_{r}" for r in man_anns
                      for x in ["annotation", "bin", "bucket"])
        for col in (c for c in candidates if c in s.rna.obs):
            s.write_clusters(out_dir, col_cell_type=col, overwrite=True,
                             file_prefix=f"{s._library_id}__")
        s.write(s.rna.obs.out_file.iloc[0])

## Tangram Imputation

In [None]:
%%time

if file_sc is not None and load is False:
    # Tangram label transfer from the single-cell reference (`file_sc`);
    # skipped (and the reference not read) when loading saved objects
    adata_sc = sc.read(file_sc)  # read whole tx'ome data for imputation
    for s in selves:
        out = s.impute(
            adata_sc.copy(), col_cell_type=col_cell_type_sc,
            mode="clusters", markers=None, plot=False, plot_density=False,
            plot_genes=None, col_annotation=col_tangram, out_file=None)
        # Write each sample's output next to its OWN object (previously
        # every sample was written to the first sample's path)
        out[0].write_h5ad(os.path.splitext(
            s.rna.obs.out_file.iloc[0])[0] + "___tangram.h5ad")  # write
        s.write(s.rna.obs.out_file.iloc[0])
        s.write_clusters(out_dir, file_prefix=f"{s._library_id}___",
                         col_cell_type=col_tangram,
                         overwrite=True, n_top=True)
if file_sc is not None:
    for s in selves:  # plot every sample that has Tangram labels
        if col_tangram in s.rna.obs:
            s.plot_spatial(color=col_tangram)

## Plot Clusters

In [None]:
# Spatial plots of Tangram labels (if present) and of each Leiden clustering
for s in selves:
    if col_tangram in s.rna.obs:  # absent unless Tangram was run or loaded
        s.plot_spatial(color=col_tangram)
    for x in kws_clustering:
        # NOTE(review): this notebook creates `annotation_`/`bin_`/`bucket_`
        # label columns, not `label_<clustering>` -- confirm it exists
        _ = s.plot_spatial(color=[f"leiden_{x}", f"label_{x}"])

# General Analysis

## Centrality Scores

In [None]:
%%time

# Centrality scores per sample; the returned figure is saved as a PDF
for s in selves:
    centrality, fig = s.calculate_centrality(n_jobs=sc.settings.n_jobs)
    pdf_path = os.path.join(dir_ax, f"{s._library_id}_centrality.pdf")
    fig.savefig(pdf_path)

## Neighborhood Enrichment Analysis

In [None]:
%%time

# Neighborhood enrichment per sample; the returned figure is saved as a PDF
for s in selves:
    _, fig = s.calculate_neighborhood(figsize=(60, 30))
    # `_library_id` (used everywhere else in this notebook), not `library_id`
    fig.savefig(os.path.join(dir_ax, f"{s._library_id}_neighborhood.pdf"))

## Cell Type Co-Occurrence

In [None]:
%%time

# Cell type co-occurrence per sample, saved as a PDF
for s in selves:
    out = s.find_cooccurrence(figsize=(60, 20), kws_plot=dict(wspace=3))
    # Save THIS plot: the original saved `fig` left over from the previous
    # cell. NOTE(review): return type of `find_cooccurrence` not visible here;
    # use a returned Figure if there is one, else the current figure.
    fig = out[-1] if isinstance(out, tuple) and len(out) > 0 and isinstance(
        out[-1], plt.Figure) else plt.gcf()
    fig.savefig(os.path.join(dir_ax, f"{s._library_id}_cooccurrence.pdf"))

## Spatially-Variable Genes

In [None]:
%%time

# Top 15 spatially-variable genes per sample (Moran's I, 10 permutations)
kws = {"kws_plot": {"legend_fontsize": "large"}, "figsize": (15, 15)}
for s in selves:
    _ = s.find_svgs(method="moran", genes=15, n_perms=10, **kws)

## Receptor-Ligand Interactions

In [None]:
%%time

# Receptor-ligand interactions per sample (top 25, p < 0.01, n.s. removed)
for s in selves:
    kss = ktt = None  # presumably None = no source/target restriction
    _ = s.calculate_receptor_ligand(
        col_condition=False, p_threshold=0.01, remove_ns=True,
        figsize=(30, 20), top_n=25, key_sources=kss, key_targets=ktt)
    # s.calculate_receptor_ligand_spatial()

## GEX

In [None]:
# Spatial expression of the CSF genes alongside the default cell labels
csf_genes = ("CSF1", "CSF2", "CSF3")
for s in selves:
    s.plot_spatial(color=[*csf_genes, col_cell_type])

# Spatial Clustering

In [None]:
# Spatial Leiden clustering with the last clustering's parameters, then marker
# detection, marker-based annotation, plotting & (optionally) saving
clust_last = list(kws_clustering)[-1]
cct = f"leiden_spatial_{clust_last}"
for s in selves:
    _ = s.cluster_spatial(key_added=cct, **kws_clustering[clust_last])
    _ = s.find_markers(col_cell_type=cct, kws_plot=False)
    _ = s.annotate_clusters(assign[[col_assignment[-1]]], col_cell_type=cct,
                            col_annotation=f"annotation_{cct}")
    saving = out_dir is not None
    for col in (cct, f"annotation_{cct}"):
        s.plot_spatial(col)
        if saving:
            s.write_clusters(out_dir, col_cell_type=col, overwrite=True,
                             n_top=True, file_prefix=f"{s._library_id}___")
    if saving:
        s.write(str(s.rna.obs.out_file.iloc[0]))

# Counts

In [None]:
# Write Results Files? Read & Join Prior Results if Present?
write = True
join_old = True

# Cell Type Columns
c_l = "leiden_res1pt5_dist0_npc30"
c_ann = ["bucket", "bin", "annotation"]
c_m = "bucket"  # focus for certain plots

# For Plots
hue_order = ["Uninflamed", "Inflamed", "Stricture"]
palette = ["blue", "red", "yellow"]
hue = dict(hue=col_condition, hue_order=hue_order, palette=palette)

# For Markers
p_threshold = 1e-15
lfc_threshold = 1.5
n_top = 10

# Senescence
# label = "snc"
# c_t = "Senescence_Proxy"
# g_t = ["CDKN1A", "TP53", "PLAUR"]
# g_c = ["CDKN1A", "TP53", "PLAUR"]
# directory = "/home/elizabeth/elizabeth/projects/senescence/analysis"

# CSF2RB
label = "csf"
c_t = "ILC_Proxy"
g_t = ["CSF1", "CSF2", "CSF3"]
g_c = ["IL7R", "KLRB1", "RORC"]
c_l = "leiden_res1pt5_dist0_npc30"
directory = "/home/elizabeth/elizabeth/projects/csf2rb/analysis"

fmrs = {}
for x in metadata_o[col_sample_id].unique():
    fann = os.path.join(out_dir, "annotation_dictionaries", str(
        f"{x}___{c_l}_dictionary.xlsx"))
    if os.path.exists(fann):
        fmr = pd.read_excel(fann).astype(str).replace("nan", np.nan)
        fmrs[x] = fmr.set_index(fmr.columns[0]).rename_axis(c_l)
fmrs = pd.concat(fmrs, names=["Sample", c_l])

# Markers
mks = pd.concat([cr.ax.make_marker_genes_df(
    s.rna, c_l, key_added=f"rank_genes_groups_{c_l}", p_threshold=p_threshold,
    lfc_threshold=lfc_threshold) for s in selves], keys=[
        s._library_id for s in selves], names=["Sample"])
mks_strings = mks.groupby(mks.index.names[:-1]).apply(lambda x: ", ".join([
    str(f"{i} (lfc={round(x.loc[x.name].loc[i]['logfoldchanges'], 2)}; "
        f"p_adj={x.loc[x.name].loc[i]['pvals_adj']})")
    for i in x.reset_index().sort_values("pvals_adj").iloc[
        :min(n_top, x.shape[0])].names.unique()])).to_frame("Markers")

## Transcripts

### Quantify

In [None]:
# Count transcripts of the `g_t` genes per sample (and per `c_l` cluster when
# that column exists), merge with previously saved rows, join the annotation
# labels, and optionally write both tables to Excel
tx_cts, tx_cts_cl = {}, {}
for s in selves:
    tx_cts[s._library_id], tx_cts_cl[s._library_id] = s.quantify_transcripts(
        g_t, col_cell_type=c_l if c_l in s.rna.obs else None, layer="counts")
tx_cts = pd.concat(tx_cts, keys=tx_cts, names=["Sample", "Gene"])
tx_cts_cl = pd.concat(tx_cts_cl, names=["Sample", c_l])
fff = os.path.join(directory, f"quantification_{label}_tx_cts")
if join_old is True and os.path.exists(fff + ".xlsx"):
    # keep only saved rows that were not recomputed in this run
    tx_os = pd.read_excel(fff + ".xlsx", index_col=np.arange(
        len(tx_cts.index.names)))
    tx_cts = pd.concat([tx_cts, tx_os.loc[
        tx_os.index.difference(tx_cts.index)]])
if join_old is True and os.path.exists(fff + "_by_cluster.xlsx"):
    # index depth of the by-cluster file comes from the by-cluster table
    # (was `tx_cts`, which only matched by coincidence)
    tx_os_cl = pd.read_excel(fff + "_by_cluster.xlsx", index_col=np.arange(
        len(tx_cts_cl.index.names))).reset_index().astype(
            {c_l: str}).set_index(tx_cts_cl.index.names)
    tx_cts_cl = pd.concat([tx_cts_cl, tx_os_cl.loc[
        tx_os_cl.index.difference(tx_cts_cl.index)]])
# replace any stale annotation columns with the current dictionaries
tx_cts_cl = tx_cts_cl.reset_index().astype({c_l: str}).set_index(
    tx_cts_cl.index.names).drop(set(c_ann).intersection(
        tx_cts_cl.columns), axis=1).join(
            fmrs[c_ann], on=["Sample", c_l])
# de-duplicate & sort clusters numerically, then keep the IDs as strings
tx_cts_cl = tx_cts_cl.reset_index().drop_duplicates().astype(
    {c_l: int}).set_index(tx_cts_cl.index.names).sort_index().reset_index(
        ).astype({c_l: "string"}).set_index(tx_cts_cl.index.names)
tx_cts = tx_cts.reset_index().drop_duplicates().set_index(tx_cts.index.names)
if write is True:
    tx_cts.to_excel(fff + ".xlsx")
    tx_cts_cl.join(mks_strings).to_excel(fff + "_by_cluster.xlsx")
tx_cts_cl

### Validate

In [None]:
# Compare transcript counts from `quantify_transcripts` with scanpy's
# per-gene `n_counts` and the summed per-cell `n_counts`
for s in selves:
    mine = tx_cts.loc[s._library_id].stack().to_frame("Mine")
    per_gene = s.rna.var.loc[g_t].rename_axis("Gene")["n_counts"]
    scanpy_cts = per_gene.to_frame("n_transcripts").assign(
        total_counts=s.rna.obs["n_counts"].sum())
    val = mine.join(scanpy_cts.stack().to_frame("Scanpy"))
    print("Comparison\n\n", val["Mine"].compare(val["Scanpy"]))
val

### Plot

#### Overall (Multi-Gene)

In [None]:
# Long table of transcript counts per gene, summed per label of each
# annotation column in `c_ann`, with subject & condition attached
dff = tx_cts_cl.join(
    metadata_o[[col_sample_id, col_subject, col_condition]].set_index(
        col_sample_id), how="left").set_index([
            col_subject, col_condition] + c_ann, append=True).rename_axis(
                "Gene", axis=1).stack()
dff = [dff.groupby(dff.index.names.difference(set(c_ann).difference(
    [y]))).sum() for y in c_ann]
dff = pd.concat([x.to_frame("n_transcripts").reset_index().rename({
    c_ann[i]: "Cluster"}, axis=1).set_index(x.index.names.difference([c_ann[
        i]])).set_index("Cluster", append=True) for i, x in enumerate(dff)],
                keys=c_ann, names=["Annotation"])

# One bar-plot grid per subject (rows = annotation column)
for x in dff.reset_index()[col_subject].unique():
    fig = sb.catplot(dff.reset_index()[dff.reset_index()[col_subject] == x],
                     x="Cluster", y="n_transcripts", kind="bar", sharex=False,
                     sharey=False, row="Annotation",
                     margin_titles=True, **hue)
    fig.set_xticklabels(rotation=90)
    fig.set_titles(col_template="{col_name}", row_template="{row_name}")
    fig.fig.set_size_inches(25, 25)
    plt.subplots_adjust(hspace=0.8)
    fig.fig.suptitle(x)
    plt.show()
# fig.fig.tight_layout()

# Box plot across all subjects for the first annotation column; titled
# explicitly (previously it reused the last subject's name from the loop)
fig = sb.catplot(dff.reset_index()[dff.reset_index()[
    "Annotation"] == c_ann[0]], x="Cluster", y="n_transcripts", **hue,
                 kind="box", sharex=False, sharey=False)
fig.set_xticklabels(rotation=90)
fig.set_titles(col_template="{col_name}", row_template="{row_name}")
fig.fig.set_size_inches(25, 25)
plt.subplots_adjust(hspace=0.8)
fig.fig.suptitle(f"All Subjects ({c_ann[0]})")
plt.show()

#### By Cluster

In [None]:
# Per-gene transcripts as a percent of each cluster's total counts
y_label = "Gene Transcript Counts: Percent of Total Counts"
dff = tx_cts_cl.set_index(c_ann, append=True).apply(lambda x: tx_cts_cl[
    x.name] / tx_cts_cl["total_counts"]).drop(
        "total_counts", axis=1).rename_axis("Gene", axis=1).stack()
dff = dff * 100  # fraction -> percent
dff = dff.to_frame(y_label).join(
    metadata_o[[col_sample_id, col_subject, col_condition]].set_index(
        col_sample_id), how="left").join(fmrs[c_ann]).reset_index(
            ).drop_duplicates()
# Sanity check on the percent scale (was `> 1`, which flags valid values)
if dff[y_label].max() > 100:
    raise ValueError(f"Percentages > 100: {dff[dff[y_label] > 100]}")
fig = sb.catplot(dff, x=c_m, y=y_label, kind="bar",
                 sharex=False, sharey=False, col="Gene", hue=c_m)
fig.set_xticklabels(rotation=90)
fig.fig.set_size_inches(60, 25)
fig.fig.set_dpi(200)
plt.subplots_adjust(hspace=0.8)
fig.fig.tight_layout()
plt.show()

#### By Gene & Condition

In [None]:
# Per-gene transcripts as a percent of each cluster's total counts
y_label = "Gene Transcript Counts: Percent of Total Counts"
dff = tx_cts_cl.set_index(c_ann, append=True).apply(lambda x: tx_cts_cl[
    x.name] / tx_cts_cl["total_counts"]).drop(
        "total_counts", axis=1).rename_axis("Gene", axis=1).stack()
dff = dff * 100  # fraction -> percent
dff = dff.to_frame(y_label).join(
    metadata_o[[col_sample_id, col_subject, col_condition]].set_index(
        col_sample_id), how="left").join(fmrs[c_ann]).reset_index(
            ).drop_duplicates()
# Sanity check on the percent scale (was `> 1`, which flags valid values)
if dff[y_label].max() > 100:
    raise ValueError(f"Percentages > 100: {dff[dff[y_label] > 100]}")
for x in [None, col_subject]:  # aggregate across subjects, or facet
    fig = sb.catplot(dff.dropna(subset=[c_m]), x=c_m, y=y_label,
                     kind="box" if x is None else "bar",
                     sharex=False, sharey=False, row=x, col="Gene",
                     margin_titles=True, **hue)
    fig.set_xticklabels(rotation=90)
    fig.set_titles(col_template="{col_name}", row_template="{row_name}")
    fig.fig.set_size_inches(60, 25)
    plt.subplots_adjust(hspace=0.8)
    fig.fig.tight_layout()
    plt.show()

## Cells

### Quantify

In [None]:
threshold = [1, 2, 3]
lfc_threshold, p_threshold = 1.5, 1e-15

# Count cells expressing the genes of interest at each threshold, per sample
# (`cts`) and per `c_l` cluster (`cts_cl`)
cts, cts_cl = {}, {}
for s in selves:
    if label == "csf":  # was `=`, a SyntaxError
        # per-cluster counts for `g_c`; overall counts for `g_t`
        _, cts_cl[s._library_id], _ = s.quantify_cells(
            g_c, threshold=threshold, n_combos="all", col_cell_type=c_l,
            layer="counts", inplace=True)
        cts[s._library_id], _, _ = s.quantify_cells(
            g_t, threshold=threshold, n_combos="all", col_cell_type=None,
            layer="counts", inplace=True)
    else:
        cts[s._library_id], cts_cl[s._library_id], _ = s.quantify_cells(
            g_c, threshold=threshold, n_combos="all", col_cell_type=c_l,
            layer="counts", inplace=True)
cts, cts_cl = [pd.concat(x, names=["Sample"]) for x in [cts, cts_cl]]
fff = os.path.join(directory, f"quantification_ncells_{label}")
# for s in selves:
#     if c_l in s.rna.obs and f"rank_genes_groups_{c_l}" in s.rna.uns:
#         mks =
#         cts_cl.loc[s._library_id] = cts_cl.loc[s._library_id].join(
#             mks.groupby(c_l).apply(lambda x: ", ".join(mks.loc[x.name].head(
#                 10).index) if (x.name in mks.index) else "").to_frame(
#                     "Meta").rename_axis("Metric", axis=1).stack().to_frame(
#                         "Markers").unstack())
if join_old is True and os.path.exists(fff + ".xlsx"):
    cts_o = pd.read_excel(fff + ".xlsx").set_index(list(cts.index.names))
    # drop only stale annotation columns, as for `cts_cl_o` below
    # (the previous `for x in cts_o` loop dropped every column)
    for x in c_ann:
        if x in cts_o:
            cts_o = cts_o.drop(x, axis=1)
    cts = pd.concat([cts, cts_o.loc[cts_o.index.difference(cts.index)]])
if join_old is True and os.path.exists(fff + "_by_cluster.xlsx"):
    cts_cl_o = pd.read_excel(fff + "_by_cluster.xlsx", header=[0, 1],
                             index_col=[0, 1, 2])
    for x in c_ann:
        if x in cts_cl_o:
            cts_cl_o = cts_cl_o.drop(x, axis=1)
    cts_cl = pd.concat([cts_cl, cts_cl_o.loc[
        cts_cl_o.index.difference(cts_cl.index)]]).dropna(how="all")
# Current annotation labels under a top-level "Annotation" column
cts_cl_s = pd.concat([fmrs[c_ann].reset_index().astype("string").set_index(
    fmrs.index.names)], axis=1, keys=["Annotation"])
cts_cl_s = cts_cl_s.reset_index().drop_duplicates().astype(str).set_index(
    cts_cl_s.index.names)
cts_cl = cts_cl[cts_cl.columns.difference(cts_cl_s.columns)].reset_index(
    ).drop_duplicates().astype(str).set_index(
        cts_cl.index.names).join(cts_cl_s)
cts_pct = cts_cl["Percent"].astype(float).stack()
if any(cts_pct > 100):
    # report the offending values (the old `print` after `raise` never ran)
    raise ValueError(f"Percentages > 100:\n{cts_pct[cts_pct > 100]}")
if write is True:
    cts.to_excel(fff + ".xlsx")
    cts_cl.join(pd.concat([mks_strings], keys=["Markers"], axis=1)).to_excel(
        fff + "_by_cluster.xlsx")
cts_cl

### Validate

In [None]:
# Compare cell counts from `quantify_cells` (first level-1 entry `1` of each
# sample's table -- presumably threshold 1; confirm) with scanpy's `n_cells`
# and the total number of cells
for s in selves:
    mine = cts.loc[s._library_id].loc[1].stack().to_frame("Mine")
    per_gene = s.rna.var.loc[g_c].rename_axis("Gene")["n_cells"]
    scanpy_cts = per_gene.to_frame("Count").assign(
        Total=s.rna.n_obs).stack().to_frame("Scanpy")
    val = mine.join(scanpy_cts)
    print("Comparison\n\n", val["Mine"].compare(val["Scanpy"]))
val

### Plot

In [None]:
# Long table of per-cluster percents with the `c_m` label and the condition
# (the prefix of the "<Condition>-<Sample ID>" sample name)
vvr = cts_cl.copy()["Percent"].rename_axis("Label", axis=1).dropna(
    how="all").stack().to_frame("Percent").astype(float).join(fmrs[[c_m]])
vvr = vvr.join(vvr.groupby("Sample").apply(lambda x: x.name.split("-")[
    0]).to_frame(col_condition)).reset_index().drop_duplicates()
# No mixed labels ("a|b"); unlabeled (NaN) clusters are dropped instead of
# raising a TypeError
vvr = vvr[vvr[c_m].notna() & ~vvr[c_m].astype(str).str.contains(
    "|", regex=False)]
for x in vvr.Threshold.unique():  # one figure per expression threshold
    fig = sb.catplot(vvr[vvr.Threshold == x], x=c_m, y="Percent",
                     sharey=False, kind="box", margin_titles=True,
                     row="Label", sharex=False, **hue)
    fig.set_titles(col_template="{col_name}", row_template="{row_name}")
    fig.set_xticklabels(rotation=90)
    fig.fig.set_size_inches(35, 25)
    fig.fig.suptitle(f"Threshold = {x}")
    fig.fig.set_dpi(100)
    fig.fig.tight_layout()
    plt.show()

# Spatial Distance

In [None]:
# Spatial distance analysis for every sample (previously only the `s` left
# over from an earlier loop, i.e., whichever sample happened to be last)
for s in selves:
    _ = s.calculate_spatial_distance(f"{'/'.join(g_c)}+", col_cell_type=c_t,
                                     genes="CSF2RB")

# Coordinates

In [None]:
# self.adata.labels["cell_labels"]["scale0"]
# selves[1].adata.shapes["cell_boundaries"]

# Spatial Plots

In [None]:
# Spatial plot of the "bucket" labels (finest clustering) for each sample,
# titled with the sample ID and its metadata "Age" value
for s in selves:
    m_d = metadata.loc[s._library_id]  # NOTE(review): overwrites Samples-cell `m_d`
    fig = s.plot_spatial(color="bucket_res1pt5_dist0_npc30")
    fig.set_title(f"{s._library_id} (Age {m_d['Age']})")
    # NOTE(review): `.set_title` (Axes-like) and `.fig` (FacetGrid-like) are
    # not both available on one object type -- confirm what plot_spatial returns
    fig.fig.set_dpi(200)

# GEX Plots

In [None]:
# Heatmap, matrix & dot plots of the `g_c` genes by `c_m` label, per sample
# (all samples for one plot kind, then the next kind)
for k, s in itertools.product(["heat", "matrix", "dot"], selves):
    s.plot(g_c, kind=k, col_cell_type=c_m, title=s._library_id)

# Workspace

In [None]:
# Percent of cells per Leiden cluster co-expressing each receptor pair
# (every gene in the pair >= `threshold` counts)
coex = [["CSF2RA", "CSF2RB"], ["IL3RA", "CSF2RB"], ["IL5RA", "CSF2RB"]]
# Local name, so the notebook-wide `col_cell_type` option is not clobbered
col_coex_ct = "leiden_res1pt5_dist0_npc30"
threshold = 1
layer = "counts"
coex_names = ["-".join(i) for i in coex]

percs = {}
for s in selves:
    # NOTE(review): switches `s` to the counts layer in place; this persists
    # for any later cell using these objects
    s.get_layer(layer, inplace=True)
    for i in coex:
        for g in i:
            # 1-D column of per-cell counts for gene `g`
            s.rna.obs.loc[:, g] = s.rna[:, g].X.toarray().ravel()
        s.rna.obs.loc[:, "-".join(i)] = s.rna.obs.apply(
            lambda x: all((x[g] >= threshold for g in i)), axis=1)
    percs[s._library_id] = pd.concat([s.rna.obs.groupby(col_coex_ct).apply(
        lambda x: 100 * x.loc[:, c].mean()) for c in coex_names],
                                     keys=coex_names, names=["Coexpression"])
# Sort clusters by percent within each (sample, pair), then attach annotation,
# marker strings, subject & condition
percs = pd.concat(percs, names=["Sample"]).to_frame("Percent of Cells")
percs = percs.groupby(["Sample", "Coexpression"]).apply(
    lambda x: x.iloc[:, 0].sort_values(ascending=False)).reset_index([
        2, 3], drop=True).to_frame(percs.columns[0]).join(fmrs[[
            c_ann[-1]]]).join(mks_strings).join(metadata[[
                col_subject, col_condition]]).reset_index().set_index([
                    col_condition, "Coexpression",
                    col_subject, col_coex_ct])

In [None]:
# Save the co-expression percentages table
file_percs = "/home/elizabeth/elizabeth/projects/csf2rb/analysis/percs_coex.xlsx"
percs.to_excel(file_percs)

In [None]:
# Annotations of clusters with >= 3% co-expressing cells, per condition & pair
# (column names via the config constants rather than repeated literals)
percs.groupby([col_condition, "Coexpression", col_subject]).apply(
    lambda x: x[x["Percent of Cells"] >= 3]).reset_index([
        -2, -3, -4], drop=True).groupby([
            col_condition, "Coexpression"]).apply(
                lambda x: x["annotation"].unique())

In [None]:
# NOTE(review): the next cell repeats these lines verbatim before plotting;
# this cell is redundant.
thresh_ce = 3  # co-expression % threshold (not used in this cell)
cct = "bin"  # annotation-dictionary column used to group cell types
percs_p = percs.join(fmrs[[cct]], lsuffix="_original")  # add `cct` labels
# `cct` labels that occur in every condition present in `percs_p`
in_all_cond = [i for i in percs_p[cct].unique() if all(
    (any(percs_p[percs_p[cct] == i].reset_index(col_condition)[
        col_condition] == x) for x in percs_p.reset_index(col_condition)[
            col_condition].unique()))]  # cell types present in all conditions
percs_p = percs_p[percs_p[cct].isin(in_all_cond)]  # only if cell type in all

In [None]:
thresh_ce = 3  # keep cell types with any value >= this percent
cct = "bin"  # annotation-dictionary column used to group cell types
# Annotation substrings of the cell types to plot
cell_type_keywords = ("B Cell", "T Cell", "Endothelial", "Glia", "Mono", "DC",
                      "Myeloid", "Macrophage", "Mast", "Plasma")

percs_p = percs.join(fmrs[[cct]], lsuffix="_original")
in_all_cond = [i for i in percs_p[cct].unique() if all(
    (any(percs_p[percs_p[cct] == i].reset_index(col_condition)[
        col_condition] == x) for x in percs_p.reset_index(col_condition)[
            col_condition].unique()))]  # cell types present in all conditions
percs_p = percs_p[percs_p[cct].isin(in_all_cond)]  # only if cell type in all

percs_p = percs_p[percs_p[cct].isin(percs_p[percs_p[
    "Percent of Cells"] >= thresh_ce][cct].unique())]
# Missing (NaN) annotations are excluded instead of raising a TypeError
percs_p = percs_p[percs_p.annotation.apply(
    lambda x: isinstance(x, str) and any(k in x for k in cell_type_keywords))]
fig = sb.catplot(percs_p, y="Percent of Cells", row="Coexpression", x=cct,
                 hue=col_condition, kind="box", sharex=False)
fig.set_titles(col_template="{col_name}", row_template="{row_name}")
fig.set_xticklabels(rotation=90)
fig.fig.set_size_inches(25, 20)
fig.fig.suptitle("Coexpression Levels" + (
    f" (Cell Types with Any >= {thresh_ce}%)" if thresh_ce > 0 else ""))
fig.fig.tight_layout()

In [None]:
# B/T cell clusters only (NaN annotations excluded instead of raising)
pbt = percs[percs.annotation.apply(
    lambda x: isinstance(x, str) and ("B Cell" in x or "T Cell" in x))]
# pbt.loc["Uninflamed"]
pbt[pbt["Percent of Cells"] >= 3].sort_index()