# Setup

## Imports and Configuration

In [None]:
%load_ext autoreload
%autoreload 2

import scanpy as sc
import os
import re
import json
from datetime import datetime
from anndata import AnnData
import seaborn as sb
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import corescpy as cr

# Display & plotting defaults
pd.options.display.max_columns = 100
pd.options.display.max_rows = 500
sc.settings.set_figure_params(dpi=100, frameon=False, figsize=(20, 20))

palette = "tab20"  # default categorical palette

# Column names: some come from the metadata file, others are created later
col_cell_type = "Annotation"  # for eventual cluster annotation column
col_sample_id_o = cr.tl.COL_SAMPLE_ID_O  # sample ID as given in metadata
col_sample_id = cr.tl.COL_SAMPLE_ID  # constructed "<condition>-<ID>" key
col_condition = cr.tl.COL_CONDITION
col_subject = cr.tl.COL_SUBJECT
col_inflamed = cr.tl.COL_INFLAMED
col_stricture = cr.tl.COL_STRICTURE
col_fff = cr.tl.COL_FFF  # column in which to store data file path
col_tangram = cr.tl.COL_TANGRAM  # for future Tangram imputation
col_segment = cr.tl.COL_SEGMENT

# Condition category keys
key_uninfl = cr.tl.KEY_UNINFLAMED
key_infl = cr.tl.KEY_INFLAMED
key_stric = cr.tl.KEY_STRICTURE


def construct_file(sample=None, run="CHO-001", slide=None,
                   date=None, timestamp=None,
                   panel_id="TUQ97N", prefix="output-XETG00189",
                   project_owner="EA", directory=None):
    """Construct Xenium output directory path(s) from run/sample information.

    Paths have the layout ``<directory>/<panel_id>/<run>/<stem>`` where
    ``<stem>`` is ``{prefix}__{slide}__{sample}-{panel_id}-{project_owner}``,
    followed by ``__{date}__{timestamp}`` when both are given. If
    ``directory`` does not contain "outputs" but has an ``outputs``
    sub-folder, that sub-folder is used as the root instead.

    Args:
        sample (str or list, optional): Sample ID(s). If None, every entry
            in each ``<directory>/<panel_id>/<run>`` folder is returned.
        run (str or list): Run ID(s); a scalar is repeated for every sample.
        slide (str or list): Slide ID(s); required when ``sample`` is given.
        date, timestamp (str or list, optional): Name suffix parts. If either
            is None, the full name is found by searching the run folder for
            the single entry that contains the constructed stem.
        panel_id, prefix, project_owner (str or list): Other name parts; a
            string is repeated for every sample.
        directory (str): Root directory to search.

    Returns:
        list of str: The constructed paths.

    Raises:
        TypeError: If ``directory`` is None, or ``slide`` is None while
            ``sample`` is given.
        ValueError: If a stem matches zero or several entries in its folder.
    """
    if directory is None:
        raise TypeError("`directory` must be specified.")
    if "outputs" not in directory and os.path.exists(
            os.path.join(directory, "outputs")):
        directory = os.path.join(directory, "outputs")

    # No sample given: list everything found in each run folder
    if sample is None:
        run = [run] if isinstance(run, str) else run
        panel_id = [panel_id] * len(run) if isinstance(
            panel_id, str) else panel_id
        fff = []
        for i, run_id in enumerate(run):
            d_x = os.path.join(directory, panel_id[i], run_id)
            fff += [os.path.join(d_x, y) for y in os.listdir(d_x)]
        return fff

    # Broadcast scalar (string) arguments to one value per sample
    if isinstance(sample, str):
        sample = [sample]
    n_samples = len(sample)
    panel_id, prefix, project_owner, slide, date, timestamp = [
        [x] * n_samples if isinstance(x, str) else list(x) if x else x
        for x in [panel_id, prefix, project_owner, slide, date, timestamp]]
    if slide is None:
        raise TypeError("`slide` must be specified when `sample` is given.")
    run = [run] * n_samples if isinstance(run, (str, int, float)) else run
    block = ["-".join(parts) for parts in zip(sample, panel_id, project_owner)]
    stems = [f"{prefix[i]}__{slide[i]}__{block[i]}" for i in range(n_samples)]

    if date is None or timestamp is None:
        # Incomplete name: find the unique folder entry containing the stem
        fff = []
        for i, stem in enumerate(stems):
            ddd = os.path.join(directory, panel_id[i], run[i])
            entries = os.listdir(ddd)  # list once per folder
            hits = [d for d in entries if stem in d]
            if len(hits) != 1:
                raise ValueError(
                    f"{stem} found in {len(hits)} file paths (expected "
                    f"exactly 1):\n\n{entries}")
            fff.append(os.path.join(ddd, hits[0]))
    else:
        fff = [os.path.join(directory, panel_id[i], run[i],
                            f"{stem}__{date[i]}__{timestamp[i]}")
               for i, stem in enumerate(stems)]
    return fff


def perform_qc_concatenated(selves):
    """Concatenate per-sample objects, add QC metrics, and pair-plot them.

    Args:
        selves (list): Objects (presumably ``cr.Spatial``) that each expose
            ``.rna`` (an AnnData) and a ``._columns`` dict containing
            "col_sample_id". The first object's column names are used.

    Returns:
        AnnData: Outer-join concatenation of every ``.rna`` with one batch
        category per object (that object's first sample ID), boolean
        ``.var`` columns "mt", "ribo" and "hb", and the module-level
        ``col_condition`` column in ``.obs`` cast to categorical.
    """
    hue = selves[0]._columns["col_sample_id"]  # batch key and plot hue
    ids = [str(s.rna.obs[hue].iloc[0]) for s in selves]  # one ID per object
    # Gene-symbol regexes (upper-case human / lower-case mouse symbols).
    # "[^(P)]" keeps HBP* genes (e.g., HBP1) out of the hemoglobin set.
    # NB: these must be regexes -- str.startswith would treat "^HB[^(P)]"
    # literally and never match.
    patterns = {"mt": r"^(?:MT-|mt-)",
                "ribo": r"^(?:RPS|RPL|rps|rpl)",
                "hb": r"^(?:HB[^(P)]|hb[^(p)])"}
    names = {"mt": "Mitochondrial", "ribo": "Ribosomal", "hb": "Hemoglobin"}
    adata = AnnData.concatenate(
        *[x.rna for x in selves], join="outer", batch_key=hue,
        batch_categories=ids, index_unique=None, uns_merge="unique")

    # Flag MT/ribosomal/hemoglobin genes; keep categories present in data
    qc_vars = []
    for k, pattern in patterns.items():
        flags = np.asarray(adata.var_names.str.contains(pattern, regex=True),
                           dtype=bool)
        adata.var[k] = flags
        if flags.any():
            qc_vars.append(k)
    pct_n = [f"pct_counts_{k}" for k in qc_vars]  # "% counts" variables
    if any(c not in adata.obs.columns for c in pct_n):
        # NOTE(review): assumes adata.X still holds raw counts -- confirm.
        # percent_top=None because the default top-N cutoffs fail when the
        # panel has fewer genes than the largest cutoff.
        sc.pp.calculate_qc_metrics(adata, qc_vars=qc_vars, percent_top=None,
                                   log1p=True, inplace=True)
    adata.obs = adata.obs.astype({col_condition: "category"})

    # Gather available QC metrics in a fixed order and give them readable
    # names for plotting
    candidates = ["n_genes_by_counts", "total_counts", "log1p_total_counts",
                  "cell_area", "nucleus_area"]
    ctm = [c for c in candidates if c in adata.obs.columns]
    vam = pct_n + ctm + [hue]  # QC variable names
    pretty = {"total_counts": "Total Counts in Cell",
              "cell_area": "Cell Area", "nucleus_area": "Nucleus Area",
              "n_genes_by_counts": "Genes Detected in Cell",
              "log1p_total_counts": "Log-Normalized Total Counts",
              **{f"pct_counts_{k}": f"% {names[k]} Counts" for k in qc_vars}}
    mets_df = adata.obs[vam].rename_axis("Metric", axis=1).rename(
        pretty, axis=1)
    sb.pairplot(mets_df, hue=hue, height=3, diag_kind="hist",
                plot_kws=dict(marker=".", linewidth=0.05))  # pair plot
    return adata

## Options

In [None]:
# Main Directories
# Replace manually or mirror my file/directory tree in your home (`ddu`)
ddu = os.path.expanduser("~")
ddm = "/mnt/cho_lab" if os.path.exists("/mnt/cho_lab") else "/mnt"  # Spark?
# Shared Xenium library: lab disk when mounted, else a folder of the same
# name in the home directory (the folder name is a string literal)
ddl = f"{ddm}/disk2/elizabeth/data/shared-xenium-library" if (
    "cho" in ddm) else os.path.join(ddu, "shared-xenium-library")
ddx = f"{ddm}/bbdata0/xenium"  # mounted drive Xenium folder
out_dir = os.path.join(ddl, "outputs", "TUQ97N", "nebraska")  # None = no save
d_path = "/mnt/cho_lab/disk2/elizabeth/data"  # for other data
file_ann = os.path.join(ddu, "corescpy/examples/annotation_guide.xlsx")  # AG
col_assignment = "Bin"  # which column from annotation file to use
file_mdf = os.path.join(ddl, "Xenium_Samples_03152024.xlsx")  # metadata

# Tangram
col_tangram = "tangram_prediction"
# col_cell_type_sc, file_sc = "ClusterAnnotation", str(
#     f"{d_path}/2023-05-12_CombinedCD-v2_ileal_new.h5ad")
col_cell_type_sc, file_sc = "cell_type", f"{d_path}/elmentaite_ileal.h5ad"

# Directories & Metadata
run = "CHO-005"
samples = "all"  # "all"/None = every sample, or a list of sample IDs
# samples = ["50452A", "50452B", "50452C"]
# samples = ["50564A4", "50618B5"]

# Input/Output Options
plot_all_qc = False
reload = False  # True = load previously processed & clustered results

# Computing Resources
gpu = False
sc.settings.n_jobs = 4
sc.settings.max_memory = 150

# Processing & Clustering Options
resolution = 0.5
resolution_sub = 0.5
min_dist = 0.3
n_comps = 50
# custom_thresholds = {col_qscore: [, None]}
genes_subset, use_highly_variable = True, False  # genes to use in clustering
kws_pp = dict(cell_filter_pmt=None, cell_filter_ncounts=[50, None],
              cell_filter_ngene=[30, None], gene_filter_ncell=[3, None],
              gene_filter_ncounts=[3, None], custom_thresholds=None,
              kws_scale=dict(max_value=10, zero_center=True))

# After this point, no more options to specify
# Just code to infer the data file path from your specifications
# and construct argument dictionaries and manipulate metadata and such.

# Read Metadata & Other Information
annot_df = pd.read_excel(file_ann)
metadata = pd.read_excel(file_mdf, dtype={"Slide ID": str})
if samples not in ["all", None]:  # subset by sample ID?
    metadata = metadata.set_index(col_sample_id_o).loc[samples].reset_index()

# Construct Clustering Argument Keyword Dictionaries
if genes_subset is True:  # True -> genes in 1st column of annotation file
    genes_subset = list(annot_df.iloc[:, 0])
kws_umap = dict(min_dist=min_dist, method="rapids" if gpu else "umap")
kws_cluster = dict(use_gpu=gpu, kws_umap=kws_umap, kws_neighbors=None,
                   use_highly_variable=use_highly_variable, n_comps=n_comps,
                   genes_subset=genes_subset, resolution=resolution)
kws_subcluster = dict(method_cluster="leiden", resolution=resolution_sub)

# Revise Metadata & Construct Variables from Options
metadata.loc[:, col_stricture] = metadata[col_stricture].replace(
    {"yes": "Stricture", "no": "None"})  # stricture column
metadata.loc[:, col_condition] = metadata.apply(lambda x: "Stricture" if x[
    col_stricture] == "Stricture" else x[col_inflamed].capitalize(), axis=1)
metadata.loc[:, col_sample_id] = metadata[[col_condition, col_sample_id_o]
                                          ].apply("-".join, axis=1)
metadata = metadata.set_index(col_sample_id)

# Map original sample ID -> Xenium output directory for this run. Names look
# like "<prefix>__<slide>__<sample>-<panel>-<owner>__...", so parse only the
# basename (a "__" elsewhere in the path must not shift the fields) and skip
# entries that do not follow that pattern. The first match per sample wins.
sample_to_file = {}
for path in construct_file(run=run, directory=ddx):
    fields = os.path.basename(path).split("__")
    if len(fields) > 2:
        sample_to_file.setdefault(fields[2].split("-")[0], path)
metadata[col_fff] = metadata[col_sample_id_o].map(sample_to_file)
metadata = metadata.dropna(subset=[col_fff])  # drop samples without data
file_path_dict = dict(zip(metadata.index.values, metadata[col_fff]))
# NOTE(review): col_batch was never defined; each library is one sample, so
# the sample ID column is used as the batch column -- confirm.
col_batch = col_sample_id
kws_init = dict(col_batch=col_batch, col_subject=col_subject,
                col_sample_id=col_sample_id, col_cell_type=col_cell_type)
if out_dir is not None:  # out_dir = None means "do not save"
    os.makedirs(out_dir, exist_ok=True)

# Annotation File: marker genes for each category of `col_assignment`
assign = pd.read_excel(file_ann, index_col=0).dropna(
    subset=[col_assignment]).rename_axis("Gene")
marker_genes_dict = {
    key: list(pd.unique(gene_names)) for key, gene_names in
    assign.reset_index().groupby(col_assignment)["Gene"]}  # to marker dict

# Display the revised metadata table (last expression in the cell)
metadata

# Genes of interest (kept commented out for reference)
# genes = ["CDKN1A", "CDKN2A", "TP53", "PLAUR", "PTGER4", "FTL", "IL6ST"]
# cell_types = ["ILC3", "LTi-like NCR+ ILC3", "LTi-like NCR- ILC3",
#               "ILCP", "Macrophages", "Stem cells"]
# palette = ["r", "tab:pink", "m", "b", "tab:brown", "cyan"]
# Higher in inflamed vs. uninflamed:
# OSM
# IL13
# IL1B
# IL6
# TNF
# S100A8
# S100A9
# ------------------------------
# Higher in stricture vs. inflamed/uninflamed:
# PDGFRA
# IL6ST
# PTPN1
# IFNG

# Data

## Loading

In [None]:
%%time

# Load Spatial Data
# File-name suffix encoding the clustering options ("." -> "pt")
suff = (f"res{str(resolution).replace('.', 'pt')}"
        f"_dist{str(min_dist).replace('.', 'pt')}_npc{n_comps}")  # file end
selves, paths_he, file_mks, out_files = [], [], [], []
for x in metadata.index.values:
    self = cr.Spatial(metadata.loc[x][col_fff], library_id=x, **kws_init)
    for col in metadata:  # copy each metadata field (as str) onto the cells
        self.rna.obs.loc[:, col] = str(metadata.loc[x][col])
    selves += [self]
    paths_he += [os.path.join(metadata.loc[x][
        col_fff], "aux_outputs/image_he.ome.tif")]  # H&E paths
    if out_dir is not None:  # output paths only exist when saving
        # out_files += [os.path.join(out_dir, f"{x}.zarr")]
        out_files += [os.path.join(out_dir, f"{x}__{suff}")]
        file_mks += [os.path.join(out_dir, f"{x}__{suff}_markers.csv")]

# Reload Processed & Clustered Data (Optionally)
if reload is True:
    if out_dir is None:
        raise ValueError(
            "`reload` requires `out_dir` (where results were saved).")
    for i, s in enumerate(selves):
        mks = file_mks[i] if os.path.exists(file_mks[i]) else None
        s.update_from_h5ad(file=out_files[i], file_path_markers=mks)
        print(s.adata)

## Plot

In [None]:
# Rebuild the per-sample output file stems (empty list when not saving)
out_files = [] if out_dir is None else [
    os.path.join(out_dir, f"{sample_key}__{suff}")
    for sample_key in metadata.index.values]

In [None]:
# Show the name of the cluster-annotation column
col_cell_type

In [None]:
# Plot Leiden clusters in space (only when clustered results were reloaded)
if reload is True:
    for i, s in enumerate(selves):
        s.plot_spatial(color="leiden")
        # s.plot_spatial(color=col_cell_type)

# Annotation & Sub-Clustering

In [None]:
# Annotate Leiden clusters using the annotation guide (`file_ann`, column
# `col_assignment`); labels are requested under `col_cell_type`, then plotted
for s in selves:
    # _ = s.cluster(**kws_cluster)
    _ = s.annotate_clusters(file_ann, col_cell_type="leiden",
                            col_annotation=col_cell_type,
                            col_assignment=col_assignment)  # annotate & write
    s.plot_spatial(color=col_cell_type)

In [None]:
# cat_epi = ["BEST2+ Goblet Cell", "BEST4+ Epithelial", "EC Cells (NPW+)",
#            "EC Cells (TAC1+)", "Enterocyte", "Epithelial", "Goblet Cell",
#            "I Cells (CCK+)", "Paneth", "D Cells (SST+)",
#            "Stem Cells", "N Cells (NTS+)", "K Cells (GIP+)",
#            # "Colonocyte",
#            "L Cells (PYY+)", "Tuft", "Microfold Cell"]
# cat_epi = list(np.array()[np.where(["epithelial" in x.lower()
#                                     for x in marker_genes_dict])[0]])
# file_ann_epi = pd.concat([pd.Series(marker_genes_dict[x]) for x in cat_epi],
#                          keys=cat_epi, names=["Type"]).to_frame(
#                              "Gene").reset_index(0).set_index("Gene")
file_ann_epi = file_ann
# key_cell_type, col_ann = ["7"], "Subclustering_Epi"
key_cell_type, col_ann = None, "leiden_sub"
# Sub-cluster (restricted to `key_cell_type` of `col_cell_type`), annotate,
# save, and plot each sample. The index must come from this loop: a stale
# `i` from another cell would write every sample to the same file.
for i, s in enumerate(selves):
    s.subcluster(restrict_to=(col_cell_type, key_cell_type), copy=False,
                 key_added=col_ann, **kws_subcluster)
    s.annotate_clusters(kws_annotation={"model": file_ann_epi,
                                        "col_assignment": col_assignment},
                        col_annotation=[col_ann, col_cell_type])
    if out_files:  # only write when output paths were constructed
        s.write(out_files[i])  # write object
    s.plot_spatial(color=col_ann)

# Imputation

I used the following code to create an ileal subset of the Elmentaite data and to run the marker-finding step (`sc.tl.rank_genes_groups`), so that this step does not have to be repeated on every iteration.

```
adata_sc = sc.read("/mnt/cho_lab/disk2/elizabeth/data/elmentaite.h5ad",
                   col_gene_symbols="feature_name")
adata_sc.var = adata_sc.var.reset_index().set_index("feature_name")
adata_sc = adata_sc[adata_sc.obs.tissue == "small intestine"]
adata_sc = adata_sc[adata_sc.obs.Age_group.isin(
    ["Adult", "Pediatric", "Pediatric_IBD"])]
adata_sc = adata_sc[adata_sc.obs["cell_type"].isin(adata_sc.obs[
    "cell_type"].value_counts().loc[lambda x: x >= 2].index)]  # >= 2 cells
sc.tl.rank_genes_groups(adata_sc, groupby="cell_type",  use_raw=False,
                        key_added="rank_genes_groups_cell_type")
adata_sc.write("/mnt/cho_lab/disk2/elizabeth/data/elmentaite_ileal.h5ad")
```

In [None]:
# Read scRNA-seq reference data (path & label column set in Options)
adata_sc = sc.read(file_sc)

# Impute GEX & Predict Labels from scRNA-seq Integration (presumably Tangram)
# Only the first sample is processed; the per-sample loop is commented out
# for i, s in enumerate(selves):
i, s = 0, selves[0]
out = s.impute(
    adata_sc.copy(), col_cell_type=col_cell_type_sc,
    # mode="cells",
    mode="clusters",
    markers=None,  # use all overlap
    # markers=200,
    plot=False, plot_density=False, plot_genes=None,
    col_annotation=col_tangram,
    out_file=os.path.dirname(out_files[i]) if out_files else None)
# adata_sp_new, sdata, adata_sc_n, ad_map, df_compare, fig = out

In [None]:
# Unpack the imputation results; plot predicted labels (col_tangram) in space
adata_sp_new, adata_sp, adata_sc_n, ad_map, df_compare, fig = out
cr.pl.plot_spatial(adata_sp_new, color=col_tangram)  # plot

In [None]:
# Plot integration diagnostics for the imputed sample. The cell-type column
# refers to the scRNA-seq reference (`adata_sc`), and predicted labels were
# stored under `col_tangram` by the imputation call (`col_annotation` is
# not defined at this point in the notebook).
figs = cr.pl.plot_integration_spatial(
    adata_sp, adata_sp_new=adata_sp_new, adata_sc=adata_sc,
    col_cell_type=col_cell_type_sc, ad_map=ad_map, cmap="magma",
    df_compare=df_compare, plot_genes=None, perc=0.01,
    col_annotation=col_tangram)
figs

In [None]:
# Dot & matrix plots of marker-gene expression by predicted (Tangram) label
# for selected cell types. Cell types missing from the annotation-file
# marker dictionary are reported and skipped instead of raising KeyError.
cts = ["Enterocyte", "Tuft", "Goblet", "Paneth", "Stem"]
missing_cts = [x for x in cts if x not in marker_genes_dict]
if missing_cts:
    print(f"No markers in the annotation file for: {missing_cts}")

cr.pl.plot_gex(adata_sp_new, col_tangram, kind=["dot", "matrix"],
               marker_genes_dict={x: marker_genes_dict[x] for x in cts
                                  if x in marker_genes_dict})

# Write

In [None]:
# Write cluster/annotation results for every sample and labeling column.
# NOTE(review): file_prefix is None for the Tangram column -- confirm how
# write_clusters names those files; also only selves[0] was imputed above,
# so col_tangram may be absent on the other samples.
for c in ["leiden", col_cell_type, col_tangram]:
    for i, s in enumerate(selves):
        s.write_clusters(
            out_dir, file_prefix=out_files[i] if c != col_tangram else None,
            col_cell_type=c, overwrite=True)

# Workspace

In [None]:
# Keep copies of the scRNA-seq reference and spatial data for the
# step-by-step workspace cells below
adata_sc_cp = adata_sc.copy()
adata_sp_cp = adata_sp.copy()

In [None]:
    # Workspace: step-by-step Tangram mapping for debugging. (IPython strips
    # the common leading indentation of this cell before running it.)
    import tangram as tg
    from corescpy.processing.spatial_pp import project_genes_m

    kwargs = {}
    # Work on fresh copies so the backups stay pristine: the code below
    # writes into adata_sp.obs and hands both objects to Tangram.
    adata_sc, adata_sp = adata_sc_cp.copy(), adata_sp_cp.copy()

    col_cell_type = "cell_type"  # scRNA-seq label column (= col_cell_type_sc)
    col_annotation = "tangram_prediction"  # column for predicted labels
    mode = "clusters"
    device = "gpu"
    markers = 100  # int: random subset size; None: all overlapping genes
    gene_to_lowercase = False
    num_epochs = 100
    learning_rate = 0.1
    density_prior = None
    perc = 0.01
    seed = 0

    key = f"rank_genes_groups_{col_cell_type}"  # marker results in .uns

    if device == "gpu":
        device = "cuda:0"
    kws = {"suffix": kwargs.pop("suffix", None)
           }  # to construct .obs; suffix, density columns for each cell type

    if mode == "clusters":  # if mapping ~ clusters rather than cells...
        kwargs["cluster_label"] = col_cell_type  # ...must give label column

    if isinstance(markers, (int, float)) or markers is None:
        # if markers not a list of pre-specified genes
        mks = set(np.unique(pd.DataFrame(adata_sc.uns[key]["names"]).melt(
            ).value.values)).intersection(set(adata_sp.var_names))
        if isinstance(markers, (int, float)):  # if markers = #...
            # ...reproducible random subset of overlapping markers: sort
            # first (set order varies), seed the draw, cap at what exists
            markers = list(pd.Series(sorted(mks)).sample(
                min(int(markers), len(mks)), random_state=seed))
        else:  # if markers = None...
            markers = list(mks)  # ...use all overlapping genes
    tg.pp_adatas(adata_sc, adata_sp, genes=markers,
                 gene_to_lowercase=gene_to_lowercase)  # preprocess
    if "uniform_density" not in adata_sp.obs:  # issue with Tangram?
        adata_sp.obs["uniform_density"] = np.ones(adata_sp.X.shape[
            0]) / adata_sp.X.shape[0]  # uniform density calculation -> .obs
    if "rna_count_based_density" not in adata_sp.obs:  # issue with Tangram?
        ct_spot = np.array(adata_sp.X.sum(axis=1)).squeeze()  # cts per spot
        adata_sp.obs["rna_count_based_density"] = ct_spot / np.sum(ct_spot)

In [None]:
    # Fit the Tangram mapping of scRNA-seq data onto the spatial data; with
    # mode="clusters", kwargs carries the "cluster_label" column set above
    ad_map = tg.map_cells_to_space(
        adata_sc, adata_sp, mode=mode, device=device, random_state=seed,
        learning_rate=learning_rate, num_epochs=num_epochs,
        density_prior=density_prior, **kwargs)  # map cells on spatial spots

In [None]:
    # Transfer scRNA-seq labels (col_cell_type) onto the spatial data
    tg.project_cell_annotations(
        ad_map, adata_sp, annotation=col_cell_type)  # clusters -> space
    print(adata_sp)

In [None]:
    # Project gene expression into space with the learned map; the cluster
    # label is only passed when mapping clusters rather than cells
    c_l = col_cell_type if mode == "clusters" else None
    adata_sp_new = project_genes_m(ad_map, adata_sc, cluster_label=c_l,
                                   gene_to_lowercase=gene_to_lowercase)  # GEX
    print(adata_sc)
    print(adata_sp_new)

In [None]:
    # Align predicted cell-type scores to the projected object's cells,
    # compare measured vs. projected expression, then attach normalized
    # densities (.obsm["tangram"]) and predicted labels (.obs)
    adata_sp_new.obsm["tangram_ct_pred"] = adata_sp.obsm[
        "tangram_ct_pred"].loc[adata_sp_new.obs.index]
    df_compare = tg.compare_spatial_geneexp(adata_sp_new, adata_sp, adata_sc)
    tmp, dfp, preds = cr.pp.construct_obs_spatial_integration(
        adata_sp_new.copy(), adata_sc.copy(), col_cell_type, perc=perc,
        col_annotation=col_annotation, **kws)  # normalized densities; labels
    adata_sp_new.obsm["tangram"] = tmp.obs[dfp.columns]
    adata_sp_new.obs = adata_sp_new.obs.join(preds)
    print(adata_sc)
    print(adata_sp_new)

In [None]:
# Re-plot integration diagnostics from the workspace results.
# NOTE(review): relies on col_cell_type / col_annotation as reassigned in
# the workspace setup cell ("cell_type" / "tangram_prediction").
figs = cr.pl.plot_integration_spatial(
    adata_sp, adata_sp_new=adata_sp_new, adata_sc=adata_sc,
    col_cell_type=col_cell_type, ad_map=ad_map, cmap="magma",
    df_compare=df_compare, plot_genes=None, perc=0.01,
    col_annotation=col_annotation)
figs

In [None]:
# Inspect the annotation guide.
# NOTE(review): `adf` is not displayed here because it is not the last
# expression in the cell.
adf = pd.read_excel(file_ann)
adf

# Attach the H&E image to the first sample.
# NOTE(review): `paths_he_align` is not defined anywhere in this notebook,
# so the last line raises NameError -- define the alignment file paths (or
# drop `file_align`, if add_image allows) before running.
i = 0
self = selves[i]
self.add_image(paths_he[i], name="H_E", file_align=paths_he_align[i])