# Setup

## Imports & Settings

In [1]:
%load_ext autoreload
%autoreload 2

import os
import re
import itertools
import scipy
import scanpy as sc
import seaborn as sb
import pandas as pd
import numpy as np
import corescpy as cr
# `cr` is only a local alias, so `from cr.pp import ...` raises
# ModuleNotFoundError; submodule names must be imported via the real package.
from corescpy.pp import (
    COL_SAMPLE_ID_O, COL_SAMPLE_ID, COL_SUBJECT, COL_INFLAMED, COL_STRICTURE,
    COL_CONDITION, COL_FFF, COL_TANGRAM, COL_SEGMENT, COL_SLIDE,
    COL_OBJECT, KEY_INFLAMED, KEY_UNINFLAMED, KEY_STRICTURE)

# Computing Resources
gpu = False  # passed on to the clustering options (RAPIDS UMAP, use_gpu)
sc.settings.n_jobs = 8
sc.settings.max_memory = 150

# Display
pd.options.display.max_colwidth = 1000
pd.options.display.max_columns = 100
pd.options.display.max_rows = 500
sc.settings.set_figure_params(dpi=100, frameon=False, figsize=(20, 20))

# Panel & Column Names (from Metadata & To Be Created)
panel = "TUQ97N"  # panel ID used to locate samples in the metadata
# Lower-case aliases of the column-name constants imported above
col_sample_id_o, col_sample_id = COL_SAMPLE_ID_O, COL_SAMPLE_ID
col_condition, col_subject = COL_CONDITION, COL_SUBJECT
col_inflamed, col_stricture = COL_INFLAMED, COL_STRICTURE
col_fff = COL_FFF  # column in which to store data file path
col_tangram = COL_TANGRAM  # for future Tangram imputation
col_segment = COL_SEGMENT  # column to signal old or new segmentation

Downloading data from `https://omnipathdb.org/queries/enzsub?format=json`
Downloading data from `https://omnipathdb.org/queries/interactions?format=json`
Downloading data from `https://omnipathdb.org/queries/complexes?format=json`
Downloading data from `https://omnipathdb.org/queries/annotations?format=json`
Downloading data from `https://omnipathdb.org/queries/intercell?format=json`
Downloading data from `https://omnipathdb.org/about?format=text`


## Options & Data

In [2]:
# Directories & Metadata
# load: reuse saved objects/results; reannotate: redo annotation & outputs
load, reannotate = True, True
run = None  # just look for samples in all runs
samples = "all"

# Main Directories
# Replace manually or mirror my file/directory tree in your home (`ddu`)
# os.getlogin() raises OSError without a controlling terminal (common for
# Jupyter servers); getpass.getuser() is the documented portable lookup.
import getpass
username = getpass.getuser()
ddu = os.path.expanduser("~")
ddm = "/mnt/cho_lab" if os.path.exists("/mnt/cho_lab") else "/mnt"  # Spark?
ddl = f"{ddm}/disk2/{username}/data/shared-xenium-library" if (
    "cho" in ddm) else os.path.join(ddu, "shared-xenium-library")
ddx = f"{ddm}/bbdata2"  # mounted drive Xenium folder
out_dir = os.path.join(ddl, "outputs", panel, "nebraska")  # None = no save
d_path = os.path.join(ddm, "disk2" if "cho" in ddm else "",
                      username, "data")  # other, e.g., Tangram data
file_mdf = os.path.join(ddl, "samples.csv")  # metadata
anf = pd.read_csv(os.path.join(ddu, "corescpy/examples/markers_lineages.csv"))
genes_subset = None
# genes_subset = list(anf.iloc[:, 0])

# Processing & Clustering Options
kws_pp = dict(cell_filter_pmt=None, cell_filter_ncounts=[15, None],
              cell_filter_ngene=[3, None], gene_filter_ncell=[3, None],
              gene_filter_ncounts=[3, None], custom_thresholds=None,
              kws_scale=dict(max_value=10, zero_center=True),
              method_norm="log")  # preprocessing keyword arguments
kws_cluster = dict(kws_umap=dict(method="rapids" if gpu else "umap"),
                   genes_subset=genes_subset,  # use only markers
                   use_gpu=gpu, use_highly_variable=False)
kws_clustering, col_assignment = {}, []
res_list = [1.5, 0.75, 0.5]
min_dist_list = [0, 0.3, 0.5]
n_comps_list = [30, 30, 30]
# Manual annotation keys: True = one per clustering key, a list = those keys,
# None = none (the default cell-label column is then the first Leiden column)
man_anns = None

# After this point, no more options to specify
# Just code to infer the data file path from your specifications
# and construct argument dictionaries and manipulate metadata and such.

# zip() would silently drop settings if the option lists differ in length
assert len(res_list) == len(min_dist_list) == len(n_comps_list), (
    "res_list, min_dist_list, and n_comps_list must have the same length")


def _fmt_param(value):
    """Render a numeric setting for use in column/file names (`.` -> `pt`)."""
    return str(value).replace(".", "pt")


# Construct Clustering Keyword Dictionary (one entry per parameter set)
for res, min_dist, n_comps in zip(res_list, min_dist_list, n_comps_list):
    kws = {**kws_cluster, "resolution": res, "n_comps": n_comps,
           "kws_umap": {**kws_cluster["kws_umap"], "min_dist": min_dist}}
    suff = f"res{_fmt_param(res)}_dist{_fmt_param(min_dist)}_npc{n_comps}"
    kws_clustering[suff] = kws  # key doubles as column/file-path suffix
    # marker-file column used to annotate clusters at this resolution
    col_assignment.append("group" if res >= 0.7 else "Bucket")
if man_anns is True:
    man_anns = list(kws_clustering.keys())
# Default cell labels; the clustering cell stores Leiden labels under
# `leiden_{key}`, so the bare key is not an existing column.
col_cell_type = f"leiden_{list(kws_clustering.keys())[0]}" if (
    man_anns is None) else f"annotation_{man_anns[-1]}"

# Read Metadata
metadata = cr.pp.get_metadata_cho(
    ddx, file_mdf, panel_id=panel, samples=samples)  # get metadata

# Annotation File: keep only marker rows labelled in every assignment column
assign = anf.dropna(subset=col_assignment).set_index(
    "gene").rename_axis("Gene")  # markers
# assign = assign[~assign.Quality.isin([-1])]  # drop low-quality markers

# Print Metadata & Make Output Directory (If Not Present)
print(metadata[[col_sample_id_o, col_subject, col_condition,
                col_inflamed, col_stricture, col_segment]])
if out_dir is not None:  # out_dir = None means "do not save outputs"
    os.makedirs(out_dir, exist_ok=True)

# Load Data
kws_init = dict(col_sample_id=col_sample_id, col_subject=col_subject,
                col_cell_type=col_cell_type)  # object creation arguments
selves = [None] * metadata.shape[0]  # to hold different samples
for i, x in enumerate(metadata.index.values):
    selves[i] = cr.Spatial(metadata.loc[x][col_fff], library_id=x, **kws_init)
    # Where to save the object (None when outputs are not being saved)
    out = None if out_dir is None else os.path.join(
        out_dir, selves[i]._library_id)
    # Previously checked `self._library_id`, which is undefined (NameError)
    if load and out is not None and os.path.exists(out):
        selves[i].update_from_h5ad(out)  # update from processed object file
    for j in metadata:  # iterate metadata columns & add to .obs
        selves[i].rna.obs.loc[:, j] = str(metadata.loc[x][j])
    selves[i].rna.obs.loc[:, COL_OBJECT] = out  # output path (to save object)

                Sample ID  Patient Condition  Inflamed Stricture segmentation
Sample                                                                       
Inflamed-50006A    50006A    50006  Inflamed  inflamed        no          new


<<< INITIALIZING SPATIAL CLASS OBJECT >>>

[34mINFO    [0m reading                                                                                                   
         [35m/mnt/cho_lab/bbdata2/outputs/TUQ97N/CHO-007/output-XETG00189__0022407__50006A-TUQ97N-EA__20240411__205514/[0m
         [95mcell_feature_matrix.h5[0m                                                                                    


Counts: Initial


	Observations: 297656

	Genes: 469







 AnnData object with n_obs × n_vars = 297656 × 469
    obs: 'cell_id', 'transcript_counts', 'control_probe_counts', 'control_codeword_counts', 'unassigned_codeword_counts', 'deprecated_codeword_counts', 'total_counts', 'cell_area', 'nucleus_area', 'region', 'z_level', 'nucleus_count'

# Clustering

## Processing, Leiden, Annotation

In [None]:
%%time

for s in selves:
    # Saved-object path recorded in `.obs` at load time (None = don't save)
    f_o = None if out_dir is None else str(s.rna.obs[COL_OBJECT].iloc[0])

    # Preprocessing
    # NOTE(review): this also runs when a processed object was loaded
    # (`load=True`); confirm `preprocess` is safe to re-apply.
    print("\n\n", kws_pp, "\n\n")
    _ = s.preprocess(**kws_pp, figsize=(15, 15))  # preprocess

    # Clustering at Different Resolutions & Minimum Distances & # of PCs
    for j, x in enumerate(kws_clustering):

        # Variables & Output Files
        print(f"\n\n{'=' * 80}\n{x}\n{'=' * 80}\n\n")
        cct, cca = f"leiden_{x}", f"label_{x}"  # Leiden & annotation columns
        annot = assign[[col_assignment[j]]]  # gene-annotation dictionary

        # Clustering, Markers, Annotation, & Writing Output
        if load is False or cct not in s.rna.obs:
            _ = s.cluster(**kws_clustering[x], key_added=cct, out_file=f_o)
        if load is False or f"rank_genes_groups_{cct}" not in s.rna.uns:
            _ = s.find_markers(col_cell_type=cct, kws_plot=False)  # DEGs
        if reannotate is True or load is False:  # annotate; Explorer files
            _ = s.annotate_clusters(annot, col_cell_type=cct,
                                    col_annotation=cca)  # annotate
            if out_dir is not None:  # Explorer files only when saving
                for c in [k for k in [cct, cca] if k in s.rna.obs]:
                    s.write_clusters(out_dir, col_cell_type=c, overwrite=True,
                                     file_prefix=f"{s._library_id}__",
                                     n_top="find_markers")

        # Write Object (checkpoint after each clustering configuration).
        # Parenthesized: `a or b and c` parses as `a or (b and c)`, which
        # ignored the missing-path check whenever load was False.
        if (load is False or reannotate is True) and f_o is not None:
            s.write(f_o)

## Plot Clusters Individually (in Same PDF)

In [None]:
for s in selves:
    for x in kws_clustering:
        print(f"\n\n{'=' * 80}\n{x}\n{'=' * 80}\n\n")
        # Leiden, automated-label, & manual-annotation columns for this key
        for c in [f"leiden_{x}", f"label_{x}", f"manual_{x}"]:
            if c not in s.rna.obs:
                print(f"\n\n{c} not in {s.rna.obs.columns}.\n\n")
                continue  # nothing to plot; plotting would raise KeyError
            if out_dir is None:  # no output directory = nothing is saved
                continue
            pfp = os.path.join(out_dir, "plots", s._library_id, f"{c}.pdf")
            if load is False or not os.path.exists(pfp):
                os.makedirs(os.path.dirname(pfp), exist_ok=True)
                s.plot_clusters(col_cell_type=c, out_dir=pfp, multi_pdf=True)

## Plot Clusters

In [None]:
# Spatial plots: Tangram column, then Leiden & label columns per setting
for spatial_obj in selves:
    spatial_obj.plot_spatial(color=col_tangram)
    for key in kws_clustering:
        cols = [f"leiden_{key}", f"label_{key}"]
        _ = spatial_obj.plot_spatial(color=cols)

# Analyze

## Centrality Scores

In [None]:
%%time

# Centrality scores for every sample, using the configured worker count
for spatial_obj in selves:
    spatial_obj.calculate_centrality(
        n_jobs=sc.settings.n_jobs)

## Neighborhood Enrichment Analysis

In [None]:
%%time

# Neighborhood enrichment for every sample (wide figure for many labels)
for spatial_obj in selves:
    _ = spatial_obj.calculate_neighborhood(
        figsize=(60, 30))

## Cell Type Co-Occurrence

In [None]:
%%time

# Cell-type co-occurrence for every sample
for spatial_obj in selves:
    _ = spatial_obj.find_cooccurrence(
        figsize=(60, 20),
        kws_plot=dict(wspace=3))

## Spatial Clustering

In [None]:
# Spatial clustering reuses the settings of the last clustering configuration
key_last = list(kws_clustering.keys())[-1]
cct = f"leiden_spatial_{key_last}"  # spatial Leiden column
cca = f"annotation_{cct}"  # annotation of the spatial clusters
for s in selves:
    _ = s.cluster_spatial(key_added=cct, **kws_clustering[key_last])
    _ = s.find_markers(col_cell_type=cct, kws_plot=False)
    _ = s.annotate_clusters(assign[[col_assignment[-1]]], col_cell_type=cct,
                            col_annotation=cca)
    for c in [cct, cca]:
        s.plot_spatial(c)
        if out_dir is not None:
            # Same `{library_id}__` prefix as the non-spatial cluster files
            s.write_clusters(out_dir, col_cell_type=c,
                             n_top=True, overwrite=True,
                             file_prefix=f"{s._library_id}__")
    if out_dir is not None:
        s.write(str(s.rna.obs[COL_OBJECT].iloc[0]))  # saved-object path

## Spatially-Variable Genes

In [None]:
%%time

# Spatially-variable genes (Moran's I) with shared plot options
kws = dict(
    kws_plot=dict(legend_fontsize="large"),
    figsize=(15, 15),
)
for spatial_obj in selves:
    _ = spatial_obj.find_svgs(genes=15, method="moran", n_perms=10, **kws)

## Gene Expression (GEX)

In [None]:
# for s in selves:
#     s.plot_spatial(color=["TNF", "IL23", col_cell_type])