# Setup 

In [1]:
%load_ext autoreload
%autoreload 2

import os
import matplotlib.pyplot
import seaborn as sns
import scanpy as sc
import bento as bt
import spatialdata_io as sdio
import spatialdata as sd
import pandas as pd
import numpy as np

pd.options.display.max_colwidth = 1000
pd.options.display.max_columns = 100
pd.options.display.max_rows = 500

# Directories
libid = "50452A"  # paired (un)inflamed
ddm = "/mnt/cho_lab" if os.path.exists("/mnt/cho_lab") else "/mnt"  # Spark?
ddl = f"{ddm}/disk2/{os.getlogin()}/data/shared-xenium-library" if (
    "cho" in ddm) else os.path.join(ddu, "shared-xenium-library")
dir_data = os.path.join(ddm, "bbdata2/outputs/TUQ97N")
out_dir = os.path.join(ddl, "outputs/TUQ97N/nebraska")
file_a = os.path.join(out_dir, "annotation_dictionaries/annotations_all.xlsx")
file_mdf = os.path.join(ddl, "samples.csv")  # metadata
c_m = "annotation"  # column in `file_a` to use for cell type labels
panel_id = "TUQ97N"

# Column Names
col_sample_id_o, col_sample_id, col_condition, col_subject,  = (
    cr.tl.COL_SAMPLE_ID_O, cr.tl.COL_SAMPLE_ID,
    cr.tl.COL_CONDITION, cr.tl.COL_SUBJECT)
col_inflamed, col_stricture = (cr.tl.COL_INFLAMED,
                               cr.tl.COL_STRICTURE)
col_fff = cr.tl.COL_FFF  # column in which to store data file path
col_tangram = cr.tl.COL_TANGRAM  # for future Tangram imputation
col_segment = cr.tl.COL_SEGMENT
key_uninfl, key_infl, key_stric = (
    cr.tl.KEY_UNINFLAMED, cr.tl.KEY_INFLAMED,
    cr.tl.KEY_STRICTURE)
run = None
genes = ["CDKN1A", "CDKN2A", "TP53", "PLAUR"]

# Optionally, Define Manual Annotation Versions
# should be stored in ("<out_dir>/annotations_dictionaries")
# in format <selves[i]._library_id>___leiden_<man_anns[i]>_dictionary.xlsx
# with first column = leiden cluster and second column = annotation
man_anns = True  # load manual annotations according to clustering kws
# man_anns = ["res0pt5_dist0pt5_npc30", "res0pt75_dist0pt3_npc30",
#             "res1pt5_dist0_npc30"]  # choose manual annotations to load
# man_anns = None  # do not load manual annotations

# Main Directories
# Replace manually or mirror my file/directory tree in your home (`ddu`)
ddu = os.path.expanduser("~")
ddm = "/mnt/cho_lab" if os.path.exists("/mnt/cho_lab") else "/mnt"  # Spark?
ddl = f"{ddm}/disk2/{os.getlogin()}/data/shared-xenium-library" if (
    "cho" in ddm) else os.path.join(ddu, "shared-xenium-library")
ddx = f"{ddm}/bbdata2"  # mounted drive Xenium folder
out_dir = os.path.join(ddl, "outputs", "TUQ97N", "nebraska")  # None = no save
d_path = os.path.join(ddm, "disk2" if "cho" in ddm else "",
                      os.getlogin(), "data")  # other, e.g., Tangram data
anf = pd.read_csv(os.path.join(ddu, "corescpy/examples/markers_lineages.csv"))
file_mdf = os.path.join(ddl, "samples.csv")  # metadata

# After this point, no more options to specify
# Just code to infer the data file path from your specifications
# and construct argument dictionaries and manipulate metadata and such.

# Read Metadata & Other Information
metadata = (pd.read_excel if file_mdf[-4:] == "xlsx" else pd.read_csv)(
    file_mdf, dtype={"Slide ID": str})
assign = anf.dropna(subset=["Bin"]).set_index(
    "gene").rename_axis("Gene")  # annotation to markers map

# Revise Metadata & Construct Variables from Options
metadata.loc[:, col_condition] = metadata.apply(lambda x: "Stricture" if x[
    col_stricture].lower() in ["stricture", "yes"] else x[
        col_inflamed].capitalize(), axis=1)  # inflamation/stricture condition
metadata.loc[:, col_sample_id] = metadata[[col_condition, col_sample_id_o]
                                          ].apply("-".join, axis=1)
metadata = metadata.set_index(col_sample_id)
col_fff = "file_path"  # column in metadata in which to store data file path
if "outputs" not in ddx and os.path.exists(os.path.join(ddx, "outputs")):
    ddx = os.path.join(ddx, "outputs")
run = [j for j in os.listdir(os.path.join(
    ddx, panel_id)) if os.path.isdir(os.path.join(ddx, panel_id, j))]
panel_id = [panel_id] * len(run)
fff = []
for i, x in enumerate(run):
    d_x = os.path.join(ddx, panel_id[i], x)
    fff += [os.path.join(d_x, y) for y in os.listdir(d_x)]
bff = np.array([os.path.basename(i) for i in fff])  # base path names
samps = np.array([i.split("__")[2].split("-")[0] for i in fff])
for x in metadata[col_sample_id_o]:
    m_f = metadata[metadata[col_sample_id_o] == x][
        "out_file"].iloc[0]  # ...use to find unconventionally-named files
    locx = np.where(samps == x)[0] if pd.isnull(
        m_f) else np.where(bff == m_f)[0]
    metadata.loc[metadata[col_sample_id_o] == x, col_fff] = fff[locx[0]] if (
        len(locx) > 0) else np.nan  # assign output file to metadata row
metadata = metadata.dropna(subset=[col_fff]).drop_duplicates()
print("\n\n", metadata[[col_sample_id_o, col_subject, col_condition,
                        col_inflamed, col_stricture]], "\n\n")

# Load Data & Make Marker Dictionary
libid = metadata.reset_index().set_index(col_sample_id_o).loc[
    libid][col_sample_id]
sdata = sdio.xenium(metadata.loc[libid][col_fff])
marker_genes_dict = dict(assign["Bucket"].reset_index().groupby(
    "Bucket").apply(lambda x: list(pd.unique(list(set(
        x.Gene).intersection(sdata.table.var_names))))))  # to dictionary
del sdata.table
sdata.table = sc.read(os.path.join(out_dir, libid) + ".h5ad")
sdata


  from .autonotebook import tqdm as notebook_tqdm




                   Sample ID  Patient   Condition    Inflamed Stricture
Sample                                                                
Uninflamed-50336C    50336C    50336  Uninflamed  uninflamed        no
Inflamed-50336B      50336B    50336    Inflamed    inflamed        no
Stricture-50336A     50336A    50336   Stricture    inflamed       yes
Stricture-50403C2   50403C2    50403   Stricture    inflamed       yes
Stricture-50403C1   50403C1    50403   Stricture    inflamed       yes
Uninflamed-50403B    50403B    50403  Uninflamed  uninflamed        no
Inflamed-50403A1    50403A1    50403    Inflamed    inflamed        no
Inflamed-50403A2    50403A2    50403    Inflamed    inflamed        no
Stricture-50217C     50217C    50217   Stricture    inflamed       yes
Uninflamed-50217B    50217B    50217  Uninflamed  uninflamed        no
Inflamed-50217A      50217A    50217    Inflamed    inflamed        no
Stricture-50006C     50006C    50006   Stricture    inflamed       yes
Uni

SpatialData object with:
├── Images
│     ├── 'morphology_focus': MultiscaleSpatialImage[cyx] (1, 57808, 45580), (1, 28904, 22790), (1, 14452, 11395), (1, 7226, 5697), (1, 3613, 2848)
│     └── 'morphology_mip': MultiscaleSpatialImage[cyx] (1, 57808, 45580), (1, 28904, 22790), (1, 14452, 11395), (1, 7226, 5697), (1, 3613, 2848)
├── Labels
│     ├── 'cell_labels': MultiscaleSpatialImage[yx] (57808, 45580), (28904, 22790), (14452, 11395), (7226, 5697), (3613, 2848)
│     └── 'nucleus_labels': MultiscaleSpatialImage[yx] (57808, 45580), (28904, 22790), (14452, 11395), (7226, 5697), (3613, 2848)
├── Points
│     └── 'transcripts': DataFrame with shape: (<Delayed>, 10) (3D points)
├── Shapes
│     ├── 'cell_boundaries': GeoDataFrame shape: (333825, 1) (2D shapes)
│     ├── 'cell_circles': GeoDataFrame shape: (333825, 2) (2D shapes)
│     └── 'nucleus_boundaries': GeoDataFrame shape: (333825, 1) (2D shapes)
└── Tables
      └── 'table': AnnData (312629, 469)
with coordinate systems:
▸ 'global

# Preparation

In [None]:
# sdata.points["transcripts"].compute()
# sdata = sd.bounding_box_query(
#     sdata, axes=["x", "y"], target_coordinate_system="global",
#     min_coordinate=[0, 0], max_coordinate=[600, 600])
# # sdata.table = sc.read_h5ad(os.path.join(out_dir, libid + ".h5ad"))
# sdata

In [3]:
kwargs = dict(points_key="transcripts", feature_key="feature_name",
              instance_key="cell_boundaries",
              shape_keys=["cell_boundaries", "nucleus_boundaries"])
sdata_p = bt.io.prep(sdata, **kwargs)  # for Bento compatibility
sdata_p



                   x             y          z feature_name     cell_id  \
0          37.685432  10092.416016  11.131549      IRF2BP2  UNASSIGNED   
1         147.899933  10138.629883  11.069950      SLC26A6  UNASSIGNED   
2         189.406891  10248.221680  11.275550         TAP1  UNASSIGNED   
3          13.936633  10114.642578  11.317411        PRDM1  UNASSIGNED   
4          17.336510  10194.420898  11.345085     SERPINA1  UNASSIGNED   
...              ...           ...        ...          ...         ...   
1289261  9502.740234   6507.323242  19.898563         AQP1  UNASSIGNED   
1289262  9658.059570   6506.693848  19.695395         NET1  UNASSIGNED   
1289263  9656.251953   6501.889648  19.689312       SP140L  UNASSIGNED   
1289264  9653.966797   6508.765137  19.835377         TP53  UNASSIGNED   
1289265  9509.833984   6512.308594  19.862207        IKZF3  UNASSIGNED   

           transcript_id fov_name  nucleus_distance         qv  \
0        282170761412612       Q3        392.

ValueError: Must have equal len keys and value when setting with an iterable

In [2]:
kwargs = dict(points_key="transcripts", feature_key="feature_name",
              instance_key="cell_boundaries",
              shape_keys=["cell_boundaries", "nucleus_boundaries"])
sdata_p = bt.io.prep(sdata, **kwargs)  # for Bento compatibility
sdata_p

Mapping points:   0%|          | 0/3 [00:00<?, ?it/s]

                   x             y          z feature_name     cell_id  \
0          37.685432  10092.416016  11.131549      IRF2BP2  UNASSIGNED   
1         147.899933  10138.629883  11.069950      SLC26A6  UNASSIGNED   
2         189.406891  10248.221680  11.275550         TAP1  UNASSIGNED   
3          13.936633  10114.642578  11.317411        PRDM1  UNASSIGNED   
4          17.336510  10194.420898  11.345085     SERPINA1  UNASSIGNED   
...              ...           ...        ...          ...         ...   
1289261  9502.740234   6507.323242  19.898563         AQP1  UNASSIGNED   
1289262  9658.059570   6506.693848  19.695395         NET1  UNASSIGNED   
1289263  9656.251953   6501.889648  19.689312       SP140L  UNASSIGNED   
1289264  9653.966797   6508.765137  19.835377         TP53  UNASSIGNED   
1289265  9509.833984   6512.308594  19.862207        IKZF3  UNASSIGNED   

           transcript_id fov_name  nucleus_distance         qv  \
0        282170761412612       Q3        392.

ValueError: Must have equal len keys and value when setting with an iterable

# Plotting

In [None]:
sdata.table.uns["spatialdata_attrs"]["instance_key"] = "cell_boundaries"
for s in sdata.shapes:
    sdata.shapes[s] = sdata.shapes[s].reset_index().rename_axis("index")

In [None]:
%%time

bt.pl.shapes(sdata, hue=col_cell_type, color_style="fill")

In [None]:
%%time

bt.pl.points(sdata, hue="feature_name", hue_order=genes)

In [None]:
# sdata.points["transcripts"].attrs["spatialdata_attrs"][
#     "instance_key"] = "cell_boundaries"

# Shape Statistics

In [None]:
%%time

bt.tl.shape_stats(sdata)
bt.pl.shape_stats(sdata)

In [None]:
%%time

# bt.tl.analyze_shapes(sdata, "cell_boundaries", "area")
bt.tl.analyze_shapes(sdata, ["cell_boundaries", "nucleus_boundaries"],
                     ["radius", "span", "perimeter"])
features = ["cell_boundaries_area", "cell_boundaries_aspect_ratio",
            "cell_boundaries_density", "cell_boundaries_perimeter",
            "nucleus_boundaries_area", "nucleus_boundaries_aspect_ratio",
            "nucleus_boundaries_density", "nucleus_boundaries_perimeter"]
bt.pl.shape_stats(sdata, cols=features)

In [None]:
kws_plot = dict(z_score=0, vmin=-2, vmax=2, cmap="RdBu_r",
                figsize=(8, 3), xticklabels=False)

sns.clustermap(data=sdata["cell_boundaries"][[
    "cell_boundaries_area", "cell_boundaries_aspect_ratio",
    "cell_boundaries_density", "cell_boundaries_perimeter",
    "cell_boundaries_span", "cell_boundaries_moment"]].T, **kws_plot)

# RNA Flux

## Assign Sub-Cellular Domains

Segmentation + PCA

In [None]:
resolution = 0.1
figsize = (12, 12)

bt.tl.flux(sdata, method="radius", res=resolution, recompute=True)
fig, axis = plt.subplots(figsize=figsize)
bt.pl.flux(sdata, res=resolution, ax=axis)

## Cluster Sub-Cellular Domains

In [None]:
min_count = 100
figsize = (12, 12)

bt.tl.fluxmap(sdata, res=resolution, min_count=min_count)
fig, axis = plt.subplots(figsize=figsize)
bt.pl.fluxmap(sdata, palette=bt.colors.bento6, ax=axis)

In [None]:
n_top = 5

fluxmap_names = [s for s in sdata.shapes.keys() if s.startswith("fluxmap")]
bt.tl.comp(sdata, points_key="transcripts", shape_names=fluxmap_names)
bt.pl.comp(sdata, annotate=n_top)

## Functional Enrichment

# Predict RNA Localization

# Co-Localization

In [None]:
bt.geo.overlay(sdata, s1="cell_boundaries", s2="nucleus_boundaries",
               name="cytoplasm", how="difference")
bt.tl.coloc_quotient(sdata, shapes=["cytoplasm", "nucleus_boundaries"])
bt.tl.colocation(sdata, ranks=range(1, 6))
bt.pl.colocation(sdata, rank=2)