In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import matplotlib.pyplot as plt
import warnings
import functools
import stlearn as st
import scanpy as sc
from dask_image.imread import imread
import imagecodecs
from PIL import Image
import dask.array as da
import squidpy as sq
import pandas as pd
import numpy as np

# Silence library warnings so analysis/plot output stays readable
warnings.filterwarnings("ignore")

# Widen pandas display limits for inspecting wide result tables
pd.options.display.max_colwidth = 1000
pd.options.display.max_columns = 100
pd.options.display.max_rows = 500

# Cell-type annotation column & ligand-receptor pairs of interest
col_cell_type = "leiden_res1pt5_dist0_npc30"  # high resolution
l_r = {"CSF1": "CSF1R", "CSF2": ["CSF2RA", "CSF2RB"]}
plot_lr = ["CSF1_CSF1R", "CSF2_CSF2RA", "CSF2_CSF2RB"]  # "ligand_receptor"

# Grid & cell-cell interaction (CCI) parameters used by later cells
n_spots = 125  # grid rows & columns (n_row = n_col = n_spots)
organism = "human"  # species for the ligand-receptor database
resource = "connectomeDB2020_lit"  # ligand-receptor database name
distance = 100  # passed as `distance` to st.tl.cci.run
min_spots = 20  # passed as `min_spots` to st.tl.cci.run
n_pairs = 10000  # passed as `n_pairs` to st.tl.cci.run
n_top = 50  # number of top pairs shown in the summary plot
n_jobs = 20  # passed as `n_cpus` to st.tl.cci.run
layer = "counts"  # adata layer copied into X before normalization
adj_method = None  # if set, p-values are re-adjusted after the CCI run
pval_adj_cutoff = None  # if set, p-values are re-adjusted after the CCI run

# Statistics drawn per ligand-receptor pair in the result plots
stats = ["lr_scores", "p_vals", "p_adjs", "-log10(p_adjs)"]

# Optional overrides read later via kwargs.pop
# ("scale", "key_image", "spot_diameter_fullres")
kwargs = {}

# Data

In [None]:
# File Paths & Options
SPATIAL_KEY = "spatial"
library_id = "Inflamed-50452B"
dir_data = "/mnt/cho_lab/bbdata2/outputs/TUQ97N"
out_dir = ("/mnt/cho_lab/disk2/elizabeth/data/shared-xenium-library/"
           "outputs/TUQ97N/nebraska")
path_dir = os.path.join(out_dir, "pathology")
file_align = os.path.join(
    path_dir, f"alignment/{library_id}_alignment_files/matrix.csv")
file_image = (os.path.join(path_dir, f"{library_id.split('-')[1]}.ndpi"),
              os.path.join(path_dir, f"ome-tiff/{library_id}.ome.tif"))
image_kws = {}

# Find Files: every entry one level below each run directory of dir_data
# (non-directory entries in dir_data are skipped instead of crashing listdir)
files = [os.path.join(run, entry)
         for run in os.listdir(dir_data)
         if os.path.isdir(os.path.join(dir_data, run))
         for entry in os.listdir(os.path.join(dir_data, run))]

# Match on the library ID with its first "-"-delimited token removed, compared
# against the part of the 3rd "__"-delimited field that precedes any "-"
sample_id = "-".join(library_id.split("-")[1:])
matches = []
for candidate in files:
    fields = os.path.basename(candidate).split("__")
    if len(fields) > 2 and fields[2].split("-")[0] == sample_id:
        matches.append(candidate)
if not matches:
    raise FileNotFoundError(
        f"No output directory for sample '{sample_id}' found under {dir_data}")
file_path = matches[0]  # first match, in directory-listing order

# Some outputs keep the morphology image(s) in a sub-directory
img_dir = os.path.join(dir_data, file_path)
if not os.path.exists(os.path.join(img_dir, "morphology_focus.ome.tif")):
    img_dir = os.path.join(img_dir, "morphology_focus")

# Data: load the processed object & start with an empty image registry
adata = sc.read(os.path.join(out_dir, f"{library_id}.h5ad"))
adata.uns[SPATIAL_KEY] = {library_id: {"images": {}}}

# Images

In [None]:
# Full Resolution? (Not Compatible Yet with stlearn)
# img_files = {f for f in os.listdir(img_dir) if f.endswith("_focus.ome.tif")}
# channel_names = {0: "DAPI"} if len(img_files) == 1 else {
#     0: "DAPI", 1: "ATP1A1/CD45/E-Cadherin", 2: "18S",
#     3: "AlphaSMA/Vimentin", 4: "dummy"}
# image_kws["c_coords"] = list(channel_names.values())
# image_path = os.path.join(img_dir, "morphology_focus_{:04}.ome.tif".format(
#     0) if len(img_files) > 1 else "morphology_focus.ome.tif")
# image = imread(image_path)

# if "c_coords" in image_kws and "dummy" in image_kws["c_coords"]:
#     image = da.concatenate([image, da.zeros_like(image[0:1])], axis=0)
# adata.uns["spatial"][library_id]["images"] = {"hires": sq.im.ImageContainer(
#     image, library_id=library_id)}

# st.add.image(adata, library_id=library_id, quality=quality,
#              imgpath=image_path, scale=scale)

# Options (optional overrides come from `kwargs`; defaults otherwise)
# max_coor = np.max(adata.obsm["spatial"])
# scale = 2000 / max_coor
scale = kwargs.pop("scale", 1)
quality = kwargs.pop("key_image", "hires")
spot_diameter_fullres = kwargs.pop("spot_diameter_fullres", 15)

# Make Compatible with Hard-Coded Columns in stlearn Code
# (coordinates are set BEFORE sizing the canvas so the canvas is derived from
# the same scaled coordinates that will be plotted)
# NOTE(review): stlearn's own readers appear to map spatial[:, 0] -> imagecol
# and spatial[:, 1] -> imagerow; confirm the axis assignment below is intended
if "spatial" in adata.obsm:
    adata.obs.loc[:, "imagerow"] = adata.obsm["spatial"][:, 0] * scale
    adata.obs.loc[:, "imagecol"] = adata.obsm["spatial"][:, 1] * scale

# Stlearn Way (Not Full Resolution): blank white square canvas that covers
# the largest coordinate plus a 10% margin
max_size = np.max([adata.obs["imagecol"].max(), adata.obs["imagerow"].max()])
max_size = int(max_size + 0.1 * max_size)
image = Image.new("RGBA", (max_size, max_size), (255, 255, 255, 255))

# Register image & scale factors under this library, keyed by `quality`
# (always refreshed so they stay consistent with the `scale` applied above;
# any other pre-existing scalefactor entries are kept)
lib_uns = adata.uns["spatial"][library_id]
lib_uns["images"] = {quality: np.array(image)}
scalefactors = lib_uns.setdefault("scalefactors", {})
scalefactors["tissue_" + quality + "_scalef"] = scale
scalefactors["spot_diameter_fullres"] = spot_diameter_fullres
lib_uns["use_quality"] = quality

# Processing

In [None]:
# Pre-Process Data: copy the configured layer into X, then normalize
# (X is rebuilt from the layer every time this cell runs)
adata.X = adata.layers[layer].copy()
st.pp.normalize_total(adata)

# Create Spot Grid: n_spots x n_spots, labelled using the cell type column
grid = st.tl.cci.grid(adata, n_row=n_spots, n_col=n_spots,
                      use_label=col_cell_type)

# Plot: Compare Clusters to Created Spots (grid on the left, cells on the right)
fig, axes = plt.subplots(ncols=2, figsize=(20, 8))
st.pl.cluster_plot(grid, use_label=col_cell_type, size=10,
                   ax=axes[0], show_plot=False)
st.pl.cluster_plot(adata, use_label=col_cell_type,
                   ax=axes[1], show_plot=False)
axes[0].set_title("Grid: Dominant Spots")
axes[1].set_title(f"Cell {col_cell_type} Labels")
plt.show()

## Plot Cell Types

In [None]:
# For the first two cell types: grid proportions vs. dominant grid spots vs.
# the individual cells carrying that label
groups = list(grid.obs[col_cell_type].cat.categories)
for g in groups[:2]:
    # Per-spot values stored for this cell type become a plottable column
    grid.obs["Group"] = grid.uns[col_cell_type][g].values
    fig, axes = plt.subplots(ncols=3, figsize=(20, 8))
    ax_prop, ax_grid, ax_cell = axes
    st.pl.feat_plot(grid, feature="Group", ax=ax_prop, show_plot=False,
                    vmax=1, show_color_bar=False)
    for data, axis in ((grid, ax_grid), (adata, ax_cell)):
        st.pl.cluster_plot(data, use_label=col_cell_type, list_clusters=[g],
                           ax=axis, show_plot=False)
    ax_prop.set_title(f"Grid {g} Proportions (Maximum = 1)")
    ax_grid.set_title(f"Grid {g} Maximum Spots")
    ax_cell.set_title(f"Individual Cell {g}")
    plt.show()

# Analysis

In [None]:
# Load ligand-receptor pairs for the configured resource & species
lrs = st.tl.cci.load_lrs([resource], species=organism)

# Run the CCI analysis on the grid
st.tl.cci.run(
    grid, lrs, min_spots=min_spots, distance=distance,
    n_pairs=n_pairs, n_cpus=n_jobs)

# Optionally, adjust p-values: forward only the options that were actually
# set, so an unset option is left to adj_pvals rather than passed as None
adj_kws = {key: val for key, val in (("pval_adj_cutoff", pval_adj_cutoff),
                                     ("adj_method", adj_method))
           if val is not None}
if adj_kws:  # adjust p?
    st.tl.cci.adj_pvals(grid, correct_axis="spot", **adj_kws)
print(grid.uns["lr_summary"])

# Plots

## QC & Results

In [None]:
# QC Plots (best-effort: a failure here is reported but does not stop the
# results plots below)
try:
    fig, axes = st.pl.cci_check(grid, col_cell_type, figsize=(16, 5))
    fig.suptitle("CCI Check: Interactions Shouldn't Correlate Much "
                "with Cell Type Frequency if Well-Controlled for")
    st.pl.lr_diagnostics(grid, figsize=(10, 2.5))
except Exception as err:
    print(err, "\n\nQC Plots failed")

# Results Plots
st.pl.lr_summary(grid, n_top=n_top, figsize=(10, 3))  # summary plot

# Resolve `plot_lr` into a list of "ligand_receptor" names:
#   True -> top 3 pairs (default); number N -> top N pairs;
#   single name -> one-item list; None / False -> skip these plots
if plot_lr is True:
    plot_lr = 3  # top 3 = default
if isinstance(plot_lr, (int, float)) and not isinstance(plot_lr, bool):
    # bools are excluded above because they are also ints in Python
    plot_lr = list(grid.uns["lr_summary"].index.values[
        :int(plot_lr)])  # best pairs
elif isinstance(plot_lr, str):
    plot_lr = [plot_lr]

# Identity checks (not `in [None, False]`) so array-like values are safe
if plot_lr is not None and plot_lr is not False and len(
        plot_lr) > 0:  # if wanted these plots...
    # squeeze=False keeps `axes` 2-D even for a single pair or statistic
    fig, axes = plt.subplots(ncols=len(stats), nrows=len(plot_lr),
                             figsize=(12, 6), squeeze=False)
    for r, x in enumerate(plot_lr):  # iterate ligand-receptors
        for c, stat in enumerate(stats):  # iterate statistics
            st.pl.lr_result_plot(grid, use_result=stat, use_lr=x,
                                 show_color_bar=False, ax=axes[r, c])
            axes[r, c].set_title(f"{x} {stat}")

## Gene Expression

In [None]:
# Plot expression of every gene involved in the selected ligand-receptor
# pairs, on the grid (left) and on individual cells (right)
if plot_lr is not None and plot_lr is not False:
    # Split "ligand_receptor" names into genes; dict.fromkeys drops repeats
    # while keeping first-seen order, so a gene shared by several pairs
    # (e.g., one ligand with two receptors) is plotted only once
    genes = list(dict.fromkeys(
        gene for pair in plot_lr for gene in pair.split("_")))
    for g in genes:
        fig, axes = plt.subplots(ncols=2, figsize=(20, 5))
        st.pl.gene_plot(grid, gene_symbols=g, ax=axes[0],
                        show_color_bar=False, show_plot=False)
        st.pl.gene_plot(adata, gene_symbols=g, ax=axes[1],
                        show_color_bar=False, show_plot=False, vmax=80)
        axes[0].set_title(f"Grid {g} Expression")
        axes[1].set_title(f"Cell {g} Expression")
        plt.show()