# 3.a SpatialDE2 analysis

In [None]:
import pickle
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import SpatialDE
from scipy.stats import mannwhitneyu

# The project helper module lives outside this directory. Its path must be on
# sys.path *before* the `from helper_functions import ...` line executes;
# otherwise the insert has no effect on that import.
sys.path.insert(1, "../../helper_functions")
from helper_functions import select_slide  # noqa: E402

# Silence every warning category globally to keep notebook output short.
# Note this also hides any deprecation or data warnings raised by later cells.
warnings.filterwarnings("ignore")

# SpatialDE

In [None]:
# Sample IDs grouped by grade. `samples` is the combined processing order:
# every low-grade sample first, then every high-grade sample.
low_grade = "B22 B24 B60 B154 B156".split()
high_grade = "B175 B178 B4 B42 B123".split()
samples = [*low_grade, *high_grade]

In [None]:
# Run the SpatialDE test on each sample's slide of the merged dataset.
# The per-sample result tables are later filtered on their `padj` and `gene`
# columns. They are cached to vargenes.pkl so this cell need not be rerun.
vargenes = {}
adata = sc.read_h5ad("../../data/merged_samples.h5ad")
for sample in samples:
    print(f"Working on sample {sample}")
    slide = select_slide(adata, sample)
    # Only the first return value (the per-gene result table) is kept.
    vargenes[sample], _ = SpatialDE.test(slide, layer="counts", omnibus=True)

# A context manager guarantees the pickle is flushed and the handle closed,
# even if dump() raises. A bare open() leaves the file open until GC.
with open("vargenes.pkl", "wb") as fh:
    pickle.dump(vargenes, fh)

In [None]:
# Load each sample's own AnnData object (separate from the merged dataset),
# keyed by sample ID and inserted in `samples` order.
adatas = {s: sc.read_h5ad(f"../../data/{s}.h5ad") for s in samples}

In [None]:
# Reload the cached SpatialDE results written by the test cell, so a resumed
# session need not rerun the test. The file is produced locally by this
# notebook. Never unpickle files from untrusted sources.
with open("vargenes.pkl", "rb") as fh:
    vargenes = pickle.load(fh)

In [None]:
# For each sample: choose up to 2000 significant genes (ranked by total
# counts) and run SpatialDE tissue segmentation on them.
segmentations = {}
for sample, adata in adatas.items():
    # Adds per-gene QC columns to adata.var, including `total_counts`,
    # which is used for ranking below.
    sc.pp.calculate_qc_metrics(adata, inplace=True)
    genes = (
        vargenes[sample]
        # keep genes with padj < 0.05 and drop mitochondrial ("MT-") genes
        .query("padj < 0.05 and not gene.str.startswith('MT-').values")
        .set_index("gene")
        # attach each gene's total counts, aligned on gene name
        .assign(totalcounts=adata.var.total_counts)
        .sort_values("totalcounts", ascending=False)
        .iloc[:2000]
        .index.values
    )
    # The previous `segmentations[sample] = genes` store was dead code: it was
    # overwritten on the very next line, so it has been removed.
    # A fixed-seed RNG makes the segmentation reproducible. Later cells read
    # adata.obs["segmentation_labels"] from the saved objects, so this call
    # is presumably expected to write labels into `adata` in place (confirm).
    segmentations[sample] = SpatialDE.tissue_segmentation(
        adata, layer="counts", genes=genes, rng=np.random.default_rng(42)
    )

In [None]:
# Save every segmented sample (gzip, level 9) so later cells can start from
# these files instead of rerunning segmentation.
# NOTE(review): assumes the h5ad/ directory already exists.
for sample_id, ad in adatas.items():
    out_path = f"h5ad/{sample_id}.h5ad"
    ad.write_h5ad(out_path, compression="gzip", compression_opts=9)

In [None]:
# Start downstream analysis from the saved, segmented objects (in `samples` order).
adatas = {
    sample_id: sc.read_h5ad(f"h5ad/{sample_id}.h5ad") for sample_id in samples
}

In [None]:
# Prefix segment IDs with "Cluster " for readable plot labels. Then keep only
# segments with more than 3 spots in each sample.
# Note: rerunning this cell on already-prefixed labels would prefix them twice.
for name in list(adatas):
    ad = adatas[name]
    seg = "Cluster " + ad.obs["segmentation_labels"].astype(str)
    ad.obs["segmentation_labels"] = seg
    spots_per_segment = seg.value_counts()
    kept = spots_per_segment.index[spots_per_segment > 3]
    adatas[name] = ad[ad.obs["segmentation_labels"].isin(kept)].copy()

In [None]:
# Plot each sample's retained segments and save one SVG per sample, iterating
# in sorted sample-ID order. img_key=None draws the spots without an image.
for sample_id in sorted(adatas):
    sc.pl.spatial(
        adatas[sample_id],
        color="segmentation_labels",
        title=f"SpatialDE {sample_id}",
        img_key=None,
        size=1.5,
        show=False,
    )
    plt.tight_layout()
    out_file = f"figures/SpatialDE_segments_{sample_id}.svg"
    plt.savefig(out_file, dpi=300)

In [None]:
# Build one row per sample with:
#   - the number of segments with more than 3 spots (already guaranteed by the
#     filtering cell, so this re-check is redundant but harmless),
#   - the number of spots,
#   - the ecDNA status.
# Then normalise the segment count by the spot count.
ecdna_samples = ["B4", "B123", "B42"]
rows = []
for sample_id, ad in adatas.items():
    _, sizes = np.unique(ad.obs.segmentation_labels, return_counts=True)
    rows.append(
        {
            "sample": sample_id,
            "nregions": (sizes > 3).sum(),
            "nspots": ad.n_obs,
            "type": "ecDNA" if sample_id in ecdna_samples else "non-ecDNA",
        }
    )
df = pd.DataFrame(rows).sort_values("sample").reset_index(drop=True)
df["nregions/nspots"] = df["nregions"] / df["nspots"]

In [None]:
# Strip plot of regions-per-spot for each ecDNA group, overlaid with a black
# "_" marker at each group's point estimate. The point estimate is drawn
# without error bars or connecting lines.
metric = "nregions/nspots"
plt.figure(figsize=(6, 4))
sns.stripplot(
    data=df,
    x="type",
    y=metric,
    hue="type",
    palette="colorblind",
    size=8,
    alpha=1,
    legend=False,
)
sns.pointplot(
    data=df,
    x="type",
    y=metric,
    color="black",
    markers="_",
    scale=2.5,
    linestyles="",
    errorbar=None,
)

plt.xlabel("ecDNA status", size=14)
plt.ylabel("regions / spot", size=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
# NOTE(review): fixed upper limit; any sample above 0.008 regions/spot would be clipped.
plt.ylim(0, 0.008)
plt.title("Number of segmented regions per spot")
plt.tight_layout()
plt.savefig("figures/spatialDE_pointplot_ecDNA.svg", dpi=300)

In [None]:
# One-sided Mann-Whitney U test: are regions-per-spot values larger in ecDNA
# samples than in non-ecDNA samples? The call is left as the cell's final
# expression so the notebook displays the result.
ratio = df["nregions/nspots"]
is_ecdna = df["type"] == "ecDNA"
mannwhitneyu(ratio[is_ecdna], ratio[~is_ecdna], alternative="greater")

In [None]:
# Per sample: rank genes for each segment (Wilcoxon) and save a dotplot of the
# top 5 genes per segment, coloured by log fold change clipped to [-2, 2].
dotplot_kwargs = dict(
    var_group_rotation=0,
    values_to_plot="logfoldchanges",
    cmap="bwr",
    vmin=-2,
    vmax=2,
    n_genes=5,
    show=False,
    figsize=(17, 5),
)
for sample_id, ad in adatas.items():
    # Explicitly set the recorded log1p base to None. This is presumably a
    # workaround for reloaded objects lacking this key (confirm). It raises
    # KeyError if uns has no "log1p" entry at all.
    ad.uns["log1p"]["base"] = None
    sc.tl.rank_genes_groups(ad, "segmentation_labels", method="wilcoxon")
    sc.pl.rank_genes_groups_dotplot(ad, **dotplot_kwargs)
    plt.suptitle(sample_id)
    plt.tight_layout()
    plt.savefig(
        f"figures/{sample_id}_spatialDE_dotplot.svg", dpi=300, bbox_inches="tight"
    )