# Setup

## Imports and Configuration

In [1]:
%load_ext autoreload
%autoreload 2

import os
import re
import json
from datetime import datetime
import functools
import matplotlib.pyplot as plt
import seaborn as sb
from anndata import AnnData
import scanpy as sc
import pandas as pd
import numpy as np
import corescpy as cr

# Raise pandas display limits so wide/long tables are shown in full
pd.options.display.max_columns = 100
pd.options.display.max_rows = 200
# Default Scanpy figure parameters (resolution, no frame, figure size)
sc.settings.set_figure_params(dpi=100, frameon=False, figsize=(20, 20))

# Name of the default color palette (only referenced in commented-out code
# within the visible cells)
palette = "tab20"


def construct_file(sample, slide, date=None, timestamp=None,
                   panel_id="TUQ97N", prefix="output-XETG00189",
                   project_owner="EA", run="CHO-001", directory=None):
    """
    Construct output directory path(s) from sample information.

    Paths have the layout
    ``<directory>/<panel_id>/<run>/<prefix>__<slide>__<sample>-<panel_id>-
    <project_owner>__<date>__<timestamp>``. If ``date`` or ``timestamp``
    is missing, the run directory is searched for the single entry whose
    name contains the stem (everything before ``__<date>``).

    Args:
        sample (str or list): Sample ID(s).
        slide (str or list): Slide ID(s); a string is repeated for
            every sample.
        date (str or list, optional): Date(s) in the directory name.
        timestamp (str or list, optional): Timestamp(s) in the
            directory name.
        panel_id (str or list): Panel ID(s).
        prefix (str or list): Directory name prefix(es).
        project_owner (str or list): Project owner code(s).
        run (str, int, float, or list): Run name(s); a scalar is
            repeated for every sample.
        directory (str): Base directory. If it does not contain the
            substring "outputs" and has an "outputs" sub-directory,
            that sub-directory is used instead.

    Returns:
        list: One path per sample.

    Raises:
        TypeError: If ``directory`` is not provided.
        ValueError: If searching and a stem matches zero or more than
            one entry in its run directory.
    """
    if directory is None:  # fail early with an explicit message
        raise TypeError("`directory` must be specified (got None).")
    if isinstance(sample, str):
        sample = [sample]
    if "outputs" not in directory and os.path.exists(
            os.path.join(directory, "outputs")):
        directory = os.path.join(directory, "outputs")
    print(directory)
    n_samples = len(sample)

    def _per_sample(value):
        """Repeat a string per sample; list-ify other truthy iterables."""
        if isinstance(value, str):
            return [value] * n_samples
        return list(value) if value else value

    panel_id, prefix, project_owner, slide, date, timestamp = [
        _per_sample(v) for v in (panel_id, prefix, project_owner,
                                 slide, date, timestamp)]
    run = [run] * n_samples if isinstance(run, (str, int, float)) else run
    block = ["-".join(i) for i in zip(sample, panel_id, project_owner)]
    fff = [f"{prefix[i]}__{slide[i]}__{block[i]}" for i in range(n_samples)]
    if date is None or timestamp is None:
        for i, stem in enumerate(fff):  # iterate current file stems
            # str() so numeric run names are usable as path components
            ddd = os.path.join(directory, panel_id[i], str(run[i]))
            print(ddd)
            entries = os.listdir(ddd)  # list once; reuse for match & error
            matches = [d for d in entries if stem in d]
            if len(matches) != 1:
                raise ValueError(
                    f"{stem} found in 0 or multiple file paths"
                    f"\n\n{entries}")
            fff[i] = os.path.join(ddd, matches[0])
    else:
        fff = [os.path.join(directory, panel_id[i], str(run[i]),
                            f"{stem}__{date[i]}__{timestamp[i]}")
               for i, stem in enumerate(fff)]
    return fff

ModuleNotFoundError: No module named 'tangram'

## Options

In [None]:
# Column Names
# Sample ID column as named in the metadata file vs. the name used here
col_sample_id_original, col_sample_id = "Sample ID", "Sample"
col_subject, col_batch = "Patient", "Slide"
col_path = "file_path"
# Date columns are disabled; to enable, use:
# col_date_time, col_date = "Date Sectioned", "Date"
col_date_time = col_date = None
col_inflamed = "Inflamed"
col_stricture = "Stricture"
col_condition = "Condition"
# Mapping used to rename metadata file columns to the names above
meta_rn = dict(zip(["Name", "Slide ID", "Inflammation Status"],
                   [col_subject, col_batch, col_inflamed]))

# Directories & Metadata
# Replace manually or mirror my file/directory tree in your home
include_stricture = True  # keep stricture samples in the metadata?
run = "CHO-001"
samples = ["50452A", "50452B", "50452C"]
# samples = ["50452A", "50452B"]
# samples = ["50452A", "50452B", "50452C", "50564A4", "50618B5"]
# run = ["CHO-001"] * 3 + ["CHO-002"] * 3
panel_id = "TUQ97N"
prefix = "output-XETG00189"
project_owner = "EA"
ddu = os.path.expanduser("~")  # user home directory
ddl = "/mnt/cho_lab/disk2/elizabeth/data/shared-xenium-library"
ddd = os.path.join(ddl, "outputs", panel_id)  # outputs for this panel
panel = os.path.join(ddu, "projects/senescence/ProposedGenePanel.xlsx")
file_ann = os.path.join(ddu, "corescpy/examples/annotation_guide.xlsx")

# Input/Output Options
reload = True
out_dir = os.path.join(ddd, "nebraska")  # set to None to avoid saving
if out_dir is not None:  # honor the documented "None = no saving" option
    os.makedirs(out_dir, exist_ok=True)  # no error if it already exists

# Computing Resources
gpu = False
sc.settings.n_jobs = 4
sc.settings.max_memory = 150

# Read Metadata & Other Information
annot_df = pd.read_excel(file_ann)
file_metadata = os.path.join(ddl, "Xenium_Samples_02092024.xlsx")
# Read slide IDs as strings, then rename columns to the names used here
metadata = pd.read_excel(file_metadata, dtype={"Slide ID": str}).rename(
    columns=meta_rn)
if samples not in ["all", None]:  # subset by sample ID?
    metadata = metadata.set_index(col_sample_id_original)
    metadata = metadata.loc[samples].reset_index()

# Processing & Clustering Options
resolution, min_dist, n_comps = 0.5, 1, 20
# col_qscore = ?
# custom_thresholds = {col_qscore: [, None]}
custom_thresholds = None
genes_subset = list(annot_df.iloc[:, 0])  # first column of annotation file

# Preprocessing options
kws_pp = {
    "cell_filter_pmt": None,
    "cell_filter_ncounts": [50, None],
    "cell_filter_ngene": [30, None],
    "gene_filter_ncell": [3, None],
    "gene_filter_ncounts": [3, None],
    "custom_thresholds": custom_thresholds,
    "kws_scale": {"max_value": 10, "zero_center": True},
}

# UMAP options ("rapids" method only if GPU use is enabled)
kws_umap = {"min_dist": min_dist}
if gpu is True:
    kws_umap["method"] = "rapids"

# Clustering options
kws_cluster = {
    "use_gpu": gpu,
    "kws_umap": kws_umap,
    "kws_neighbors": None,
    "use_highly_variable": False,
    "n_comps": n_comps,
    "genes_subset": genes_subset,
    "resolution": resolution,
}
# kws_subcluster = dict(use_gpu=gpu, kws_umap=dict(min_dist=0.3),
#                       kws_neighbors=None, use_highly_variable=False,
#                       n_comps=n_comps, genes_subset=genes_subset,
#                       method_cluster="leiden", resolution=0.5)
kws_subcluster = {"method_cluster": "leiden", "resolution": 0.5}

# Revise Metadata & Construct Variables from Options
# Derive the stricture column from the sample location if not provided
if col_stricture not in metadata.columns:
    metadata.loc[:, col_stricture] = metadata["Sample Location"].apply(
        lambda x: "Stricture" if "stricture" in x.lower() else "None")
# Condition = "Stricture" if strictured; otherwise the inflammation status
metadata.loc[:, col_condition] = metadata.apply(
    lambda x: "Stricture" if "stricture" in x[col_stricture].lower() else x[
        col_inflamed].capitalize(), axis=1)
if col_date_time:
    # Format dates as YYYYMMDD
    metadata.loc[:, col_date] = metadata[col_date_time].apply(
        lambda x: x.strftime("%Y%m%d"))
    dates = list(metadata[col_date])
else:
    dates = None
# New sample ID = <condition>-<original sample ID>; used as the index
metadata.loc[:, col_sample_id] = metadata.apply(
    lambda x: f"{x[col_condition]}-{x[col_sample_id_original]}", axis=1)
metadata = metadata.set_index(col_sample_id)
metadata.loc[:, col_path] = construct_file(
    list(metadata[col_sample_id_original]), list(metadata[col_batch]),
    dates, panel_id=panel_id, prefix=prefix,
    project_owner=project_owner, run=run, directory=ddl)
col_cell_type = "Annotation"
# Use `col_path` (not a hard-coded name) so the column option stays in sync
file_path_dict = dict(zip(metadata.index.values, metadata[col_path]))
kws_init = dict(col_batch=col_batch, col_subject=col_subject,
                col_sample_id=col_sample_id, col_cell_type=col_cell_type)

# Annotation File
col_assignment = "Bin"
assign = pd.read_excel(file_ann, index_col=0).dropna(
    subset=[col_assignment])  # drop rows without an assigned bin
sources = assign[col_assignment].unique()


def _format_bin_label(label):
    """
    Tidy one bin label for display.

    Underscores become spaces and "glia" becomes "Glia". If the result
    has more than one space-separated word, each word is capitalized
    unless it is empty, starts with "(", is all upper-case, or is
    "IgG"/"IgA"; single-word labels are otherwise left as they are.
    """
    label = re.sub("glia", "Glia", re.sub("_", " ", label))
    words = label.split(" ")
    if len(words) <= 1:
        return label
    return " ".join([
        w if (not w or w[0] == "(" or w.isupper() or w in ["IgG", "IgA"])
        else w.capitalize() for w in words])


rename = {j: _format_bin_label(j) for j in sources}
assign.loc[:, col_assignment] = assign[col_assignment].replace(rename)
assign = assign.rename_axis("Gene")
# Select the "Gene" column before applying so the grouping column is not
# passed to the function (avoids the pandas GroupBy.apply deprecation)
marker_genes_dict = dict(assign.reset_index().groupby(col_assignment)[
    "Gene"].apply(lambda x: list(pd.unique(x))))  # to marker dictionary

# Subset if Desired
if include_stricture is False:
    # Use the configured column name rather than a hard-coded attribute
    metadata = metadata[metadata[col_stricture] != "Stricture"]
metadata

# Genes
# genes = ["CDKN1A", "CDKN2A", "TP53", "PLAUR", "PTGER4", "FTL", "IL6ST"]
# cell_types = ["ILC3", "LTi-like NCR+ ILC3", "LTi-like NCR- ILC3",
#               "ILCP", "Macrophages", "Stem cells"]
# palette = ["r", "tab:pink", "m", "b", "tab:brown", "cyan"]
# High in inflamed vs. uninflamed
# OSM
# IL13
# IL1B
# IL6
# TNF
# S100A8
# S100A9
# ------------------------------
# High in stricture vs. inflamed/uninflamed
# PDGFRA
# IL6ST
# PTPN1
# IFNG

## Re-Name Files in Standard Form Using Object

In [None]:
# fff = list(pd.Series([x if "h5ad" in x else np.nan for x in os.listdir(
#     out_dir)]).dropna())
# frn = []
# for x in fff:
#     ann = sc.read(os.path.join(out_dir, x))
#     kwu = re.sub("{", "", re.sub("}", "", str(ann.obs.iloc[0].kws_umap))
#                  ).split(", ")
#     kwu = np.array(kwu)[np.where(["min_dist" in x for x in kwu])[
#         0]][0].split(": ")[1]
#     frn += [str(f"{str(ann.obs.iloc[0][col_sample_id])}__"
#                 f"res{str(ann.obs.iloc[0].resolution)}"
#                 f"_dist{kwu}_npc{ann.varm['PCs'].shape[1]}")]
# frd = dict(zip([os.path.join(out_dir, i) for i in fff], [os.path.join(
#     out_dir, re.sub("[.]", "pt", re.sub("[.]h5ad", "", i)) + ".h5ad")
#                                                          for i in frn]))
# frd
# # for x in frd:
# #     os.system(f"mv {x} {frd[x]}")

# Data

## Loading

In [None]:
%%time

# Load Spatial Data
# File suffix encodes the clustering options, with "." written as "pt"
suff = str(f"res{re.sub('[.]', 'pt', str(resolution))}_dist"
           f"{re.sub('[.]', 'pt', str(min_dist))}_npc{n_comps}")  # file end
selves, paths_he, file_mks, out_files = [], [], [], []
for x in metadata.index.values:
    self = cr.Spatial(metadata.loc[x][col_path], library_id=x, **kws_init)
    for i in metadata:  # add metadata for subject
        self.rna.obs.loc[:, i] = str(metadata.loc[x][i])  # add metadata
    selves += [self]
    paths_he += [os.path.join(metadata.loc[x][
        col_path], "aux_outputs/image_he.ome.tif")]  # H&E paths
    if out_dir is not None:  # output paths only exist if saving is enabled
        # out_files += [os.path.join(out_dir, f"{x}.zarr")]
        out_files += [os.path.join(out_dir, f"{x}__{suff}.zarr")]
        file_mks += [os.path.join(out_dir, f"{x}__{suff}_markers.csv")]

# Reload Processed & Clustered Data (Optionally)
# Skipped when out_dir is None because there are no files to reload from
if reload is True and out_dir is not None:
    for i, s in enumerate(selves):
        # Pass the markers file only if it exists
        mks = file_mks[i] if os.path.exists(file_mks[i]) else None
        s.update_from_h5ad(file=out_files[i], file_path_markers=mks)
        print(s.adata)

## Plot

In [None]:
# Plot each object colored by the "leiden" column, only if the data were
# reloaded (NOTE: leaves `i` and `s` bound to the last index/object)
if reload is True:
    for i, s in enumerate(selves):
        s.plot_spatial(color="leiden")
        # s.plot_spatial(color=col_cell_type)

In [None]:
# Cluster the first object using the sub-clustering options, passing
# `restrict_to=(col_cell_type, [1])` (presumably limits clustering to
# label 1 of that column; confirm against corescpy)
selves[0].cluster(**kws_subcluster, restrict_to=(col_cell_type, [1]))

In [None]:
# Display the `rna` attribute of the first object
selves[0].rna

# Annotation & Sub-Clustering

In [None]:
# Select the first object and add its image from the H&E paths list
i = 0
self = selves[i]
# NOTE(review): `paths_he_align` is not defined in any visible cell; this
# line raises NameError unless it is defined elsewhere. TODO confirm
self.add_image(paths_he[i], name="H_E", file_align=paths_he_align[i])

In [None]:
# Cluster, annotate (from the "leiden" column into the cell type column),
# and plot every object
kws_annotate = dict(col_cell_type="leiden", col_annotation=col_cell_type,
                    col_assignment=col_assignment)
for s in selves:
    _ = s.cluster(**kws_cluster)
    _ = s.annotate_clusters(file_ann, **kws_annotate)  # annotate & write
    s.plot_spatial(color=col_cell_type)

In [None]:
# Plot the "Annotation" column for `s` (NOTE: `s` is whatever object a
# previously run loop left bound, i.e., the last one iterated)
s.plot_spatial(color="Annotation", palette=None)

In [None]:
# Read the annotation file and display it
adf = pd.read_excel(file_ann)
adf

In [None]:
# cat_epi = ["BEST2+ Goblet Cell", "BEST4+ Epithelial", "EC Cells (NPW+)",
#            "EC Cells (TAC1+)", "Enterocyte", "Epithelial", "Goblet Cell",
#            "I Cells (CCK+)", "Paneth", "D Cells (SST+)",
#            "Stem Cells", "N Cells (NTS+)", "K Cells (GIP+)",
#            # "Colonocyte",
#            "L Cells (PYY+)", "Tuft", "Microfold Cell"]
# cat_epi = list(np.array()[np.where(["epithelial" in x.lower()
#                                     for x in marker_genes_dict])[0]])
# file_ann_epi = pd.concat([pd.Series(marker_genes_dict[x]) for x in cat_epi],
#                          keys=cat_epi, names=["Type"]).to_frame(
#                              "Gene").reset_index(0).set_index("Gene")
file_ann_epi = file_ann  # annotate sub-clusters with the main guide
# key_cell_type, col_ann = ["7"], "Subclustering_Epi"
key_cell_type = None
col_ann = "leiden_sub"
restrict = (col_cell_type, key_cell_type)
# Sub-cluster, annotate, and plot every object (`ann` keeps the result
# returned for the last object)
for s in selves:
    ann = s.subcluster(restrict_to=restrict, copy=False,
                       key_added="leiden_sub", **kws_subcluster)
    s.annotate_clusters(
        kws_annotation={"model": file_ann_epi, "col_assignment": "Bin"},
        col_annotation=[col_ann, col_cell_type])
    s.plot_spatial(color=col_ann)

In [None]:
# Write each object to its own output file. Pairing objects with paths
# avoids reusing a stale index `i`, which would write every object to the
# same file.
for s, file_out in zip(selves, out_files):
    s.write(file_out)

In [None]:
# Display `ann` (the value last assigned from `subcluster` in a prior cell)
ann

In [None]:
# Plot the sub-cluster column for every object
key_cell_type = None
col_ann = "leiden_sub"
for s in selves:
    s.plot_spatial(color=col_ann)

In [None]:
# Display the keys of the marker gene dictionary
marker_genes_dict.keys()

In [None]:
# Annotate `s` (the object left bound by a previously run loop) from the
# cluster and sub-cluster columns, then plot the sub-cluster annotation
# NOTE(review): the visible sub-clustering cell uses key_added="leiden_sub";
# confirm that a "leiden_subcluster" column exists
for col_clusters, col_label in [
        ("leiden", "Annotation"),
        ("leiden_subcluster", "Annotation_Subcluster")]:
    s.annotate_clusters(file_ann, col_cell_type=col_clusters,
                        col_annotation=col_label)
s.plot_spatial(color="Annotation_Subcluster")

In [None]:
pd.options.display.max_rows = 500

# One "Marker" column frame per dictionary entry, stacked under its key
marker_frames = [pd.Series(genes).to_frame("Marker")
                 for genes in marker_genes_dict.values()]
dff = pd.concat(marker_frames, keys=list(marker_genes_dict))
# Show only the markers present among the variable names of `self.rna`
in_data = dff.Marker.isin(self.rna.var_names)
dff[in_data]