Environment: This notebook should be run in the `python_scvi_environment` environment inside the `docker_python_scvi` devcontainer.

In [None]:
import os
from collections import Counter

import anndata
import matplotlib.pyplot as plt
import numpy as np
import numpy as np
import pandas as pd
import scanpy as sc
import scvi
from scipy.stats import spearmanr
from scvi.data import cortex, smfish
from scvi.external import GIMVI
from sklearn.neighbors import KDTree
from tqdm.notebook import tqdm

In [None]:
train_size = 1

%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

In [None]:
# Root of the project data directory, relative to this notebook.
data_path = "/".join(("..", "..", "data"))

Loading in the preprocessed spatial data

In [None]:
# Preprocessed and filtered spatial (segmented kt56) AnnData.
spatial_file = os.path.join(
    data_path,
    "segmentations",
    "kt56",
    "adatas",
    "preprocessed_and_filtered_02.h5ad",
)
spatial_data = sc.read(spatial_file)

Loading in the preprocessed snRNA data

In [None]:
# Preprocessed snRNA-seq AnnData for the DC3000 condition.
seq_file = os.path.join(data_path, "DC3000_alone.h5ad")
seq_data = sc.read(filename=seq_file)

Finding intersecting genes between the modalities

In [None]:
# Split the spatial genes by whether the sequencing data also contains them,
# keeping the spatial gene order in both lists.
seq_gene_set = set(seq_data.var_names)
intersection = [gene for gene in spatial_data.var_names if gene in seq_gene_set]
non_intersecting = [
    gene for gene in spatial_data.var_names if gene not in seq_gene_set
]

In [None]:
# Restrict both modalities to the shared genes, in the same order in both objects.
spatial_data = spatial_data[:, intersection].copy()
seq_data = seq_data[:, intersection].copy()

Preparing for gimVI

In [None]:
# gimVI models raw counts, so use the counts layer as X *before* filtering cells
# and registering the data (both read X). The copy keeps X from aliasing
# layers["counts"].
seq_data.X = seq_data.layers["counts"].copy()

seq_gene_names = seq_data.var_names
n_genes = seq_data.n_vars
n_train_genes = int(n_genes * train_size)

# Deterministic (not random) train/test split of the shared genes: the first
# n_train_genes in spatial gene order are given to gimVI on the spatial side; the
# rest are held out and can only be imputed. With train_size = 1 all genes train.
spatial_gene_names = spatial_data.var.index.values
rand_train_genes = spatial_gene_names[:n_train_genes]
rand_test_genes = spatial_gene_names[n_train_genes:]

# Remove cells with no counts. This must happen BEFORE spatial_data_partial is
# derived from spatial_data: otherwise the model's spatial cells and spatial_data's
# cells differ, and latent / UMAP coordinates cannot be copied back onto spatial_data.
sc.pp.filter_cells(spatial_data, min_counts=1)
sc.pp.filter_cells(seq_data, min_counts=1)

spatial_data_partial = spatial_data[:, rand_train_genes].copy()
spatial_data_partial.obs["batch"] = "spatial"

# setup_anndata for spatial and sequencing data
GIMVI.setup_anndata(spatial_data_partial, labels_key="batch")
GIMVI.setup_anndata(seq_data, labels_key="SCT_snn_res.1")

Making count matrices compatible with gimVI

In [None]:
# Use the raw counts layer as X (gimVI expects counts). Copy it so that any later
# in-place operation on X cannot silently modify layers["counts"] as well.
seq_data.X = seq_data.layers["counts"].copy()

Training GIMVI

In [None]:
# Build gimVI with the sequencing data first and the spatial data second,
# using a 10-dimensional shared latent space, then train it.
N_EPOCHS = 50
model = GIMVI(seq_data, spatial_data_partial, n_latent=10)
model.train(max_epochs=N_EPOCHS)

Extracting gimVI latent representation

In [None]:
# Latent coordinates of both modalities in the shared gimVI space.
latent_seq, latent_spatial = model.get_latent_representation()

# Stack them into one AnnData: sequencing cells first, then spatial cells.
latent_representation = np.vstack((latent_seq, latent_spatial))
latent_adata = anndata.AnnData(latent_representation)

# Record which modality every row came from.
latent_labels = ["seq"] * len(latent_seq) + ["spatial"] * len(latent_spatial)
latent_adata.obs["labels"] = latent_labels

# Neighbour graph and UMAP computed directly on the latent coordinates.
sc.pp.neighbors(latent_adata, use_rep="X")
sc.tl.umap(latent_adata)

Assigning umap coordinates to the sc objects

In [None]:
# Split the joint UMAP back into its two modalities; sequencing rows come first.
n_seq_cells = seq_data.shape[0]
seq_data.obsm["X_umap_gimvi"] = latent_adata.obsm["X_umap"][:n_seq_cells]
spatial_data.obsm["X_umap_gimvi"] = latent_adata.obsm["X_umap"][n_seq_cells:]

In [None]:
# Tag each object with the modality it came from.
for modality_adata, modality_name in ((seq_data, "seq"), (spatial_data, "spatial")):
    modality_adata.obs["modality"] = modality_name

Plotting GIMVI results

In [None]:
# utility function for plotting spatial genes
def plot_gene_spatial(model, data_spatial, gene):
    """
    Plot the measured and the gimVI-imputed expression of one gene in space.

    Left panel ("Groundtruth"): the gene's values in ``data_spatial.X`` divided by
    each cell's total (+1). Right panel ("Imputed"): the model's normalized imputed
    values for the spatial cells. Both are drawn as ``log(1 + 100 * value)`` at the
    cells' ``obs["x"]`` / ``obs["y"]`` positions, lowest values first so that the
    strongest cells end up drawn on top.

    :param model: GIMVI model; ``model.adatas[0]`` is the sequencing AnnData whose
        gene order indexes the imputed values
    :param data_spatial: spatial anndata object with ``obs["x"]`` and ``obs["y"]``;
        its cells must match, in order, the spatial cells the model was built on
    :param gene: gene name, or an integer gene index valid for both datasets
    :raises ValueError: if ``gene`` is a name missing from either dataset, or if the
        number of imputed cells differs from the number of cells in ``data_spatial``
    """
    data_seq = model.adatas[0]
    data_fish = data_spatial

    # Imputed values are indexed in the sequencing gene space and the ground truth
    # in the spatial gene space, so resolve a gene name in each one separately.
    if isinstance(gene, str):
        seq_gene_id = list(data_seq.var_names).index(gene)
        fish_gene_id = list(data_fish.var_names).index(gene)
    else:
        seq_gene_id = fish_gene_id = gene

    # Plain numpy arrays so the argsort indices below are positional; indexing a
    # pandas Series with an integer array relies on a deprecated fallback.
    x_coord = np.asarray(data_fish.obs["x"])
    y_coord = np.asarray(data_fish.obs["y"])

    def to_1d(values):
        """Flatten a dense array, np.matrix or scipy sparse matrix to a 1-D ndarray."""
        if hasattr(values, "toarray"):
            values = values.toarray()
        return np.asarray(values).ravel()

    def order_by_strength(x, y, z):
        """Sort the points by ascending z so the strongest are drawn last (on top)."""
        ind = np.argsort(z)
        return x[ind], y[ind], z[ind]

    s = 20

    def transform(data):
        return np.log(1 + 100 * data)

    _, imputed = model.get_imputed_values(normalized=True)
    if imputed.shape[0] != len(x_coord):
        raise ValueError(
            f"model imputed {imputed.shape[0]} spatial cells but data_spatial has "
            f"{len(x_coord)}; pass the AnnData the model's spatial data was built from"
        )

    fig, (ax_gt, ax) = plt.subplots(1, 2)

    # Plot groundtruth: per-cell normalized measured expression. to_1d handles a
    # sparse X, whose column slice / row sums are 2-D (n, 1) objects.
    gene_values = to_1d(data_fish.X[:, fish_gene_id])
    total_values = to_1d(data_fish.X.sum(axis=1))
    x, y, z = order_by_strength(x_coord, y_coord, gene_values / (total_values + 1))
    ax_gt.scatter(x, y, c=transform(z), s=s, edgecolors="none", marker="s", cmap="Reds")
    ax_gt.set_title("Groundtruth")
    ax_gt.axis("off")

    # Plot imputed
    x, y, z = order_by_strength(x_coord, y_coord, to_1d(imputed[:, seq_gene_id]))
    ax.scatter(x, y, c=transform(z), s=s, edgecolors="none", marker="s", cmap="Reds")
    ax.set_title("Imputed")
    ax.axis("off")
    plt.tight_layout()
    plt.show()


# Measured vs. imputed spatial expression of ALD1.
plot_gene_spatial(model, spatial_data, gene="ALD1")

Getting the gimVI latent representations and saving them to the original objects

In [None]:
# Store each modality's gimVI latent representation on its own AnnData. A single
# call returns both (sequencing first, spatial second), so inference runs once
# instead of twice.
seq_latent, spatial_latent = model.get_latent_representation()
seq_data.obsm["X_gimvi"] = seq_latent
spatial_data.obsm["X_gimvi"] = spatial_latent

In [None]:
# Rebuild the joint graph with 30 neighbours and a tighter UMAP (min_dist=0.1).
sc.pp.neighbors(latent_adata, n_neighbors=30, use_rep="X")
sc.tl.umap(latent_adata, min_dist=0.1)

In [None]:
# Copy the recomputed joint UMAP onto each modality; rows [0, n_seq) are seq cells.
umap_joint = latent_adata.obsm["X_umap"]
seq_data.obsm["X_umap_gimvi"], spatial_data.obsm["X_umap_gimvi"] = (
    umap_joint[: seq_data.shape[0]],
    umap_joint[seq_data.shape[0] :],
)

Transferring pseudotime values from seq to spatial

In [None]:
# Transfer pseudotime from the sequencing cells to every spatial cell: the mean
# pseudotime of its nearest sequencing neighbours in the shared gimVI latent space.
n_neighbors = 10  # number of nearest sequencing neighbours to average over

# KD-tree over the sequencing cells' gimVI latent coordinates
seq_gimvi = seq_data.obsm["X_gimvi"]
kdtree = KDTree(seq_gimvi)

# Find the nearest seq cells for each spatial cell; row i of `indices` holds the
# positions (into seq_data) of the neighbours of spatial cell i.
distances, indices = kdtree.query(spatial_data.obsm["X_gimvi"], k=n_neighbors)

# Gather the neighbours' pseudotime (n_spatial x n_neighbors) and average each row.
# The result is kept 2-D (n_spatial x 1) under its historical name, because the
# next cell flattens `averaged_expression` into obs["pseudotime"].
seq_pseudotime = np.asarray(seq_data.obs["pseudotime"], dtype=float)
averaged_expression = seq_pseudotime[indices].mean(axis=1, keepdims=True)

In [None]:
# Store the transferred pseudotime as a flat per-cell column.
pseudotime_values = averaged_expression.flatten()
spatial_data.obs["pseudotime"] = pseudotime_values

Plotting results

In [None]:
# Transferred pseudotime and ALD1 on the gimVI UMAP (colour scale capped at 0.5).
sc.pl.embedding(
    spatial_data,
    color=["pseudotime", "ALD1"],
    basis="umap_gimvi",
    cmap="Blues",
    vmax=0.5,
)

In [None]:
# Expose the cells' (x, y) positions as an embedding so they can be plotted with
# basis="spatial".
spatial_data.obsm["X_spatial"] = spatial_data.obs.loc[:, ["x", "y"]].to_numpy()

In [None]:
# Spatial maps of transferred pseudotime and ALD1. The colour window must satisfy
# vmin <= vmax (matplotlib's colour normalisation raises otherwise), so the two
# bounds 0.2 / 0.3 are ordered here as [0.2, 0.3].
# NOTE(review): confirm that [0.2, 0.3] is the intended window.
sc.pl.embedding(
    spatial_data,
    basis="spatial",
    color=["pseudotime", "ALD1"],
    vmin=0.2,
    vmax=0.3,
    cmap="jet",
)

Saving out

In [None]:
# Write the integrated objects to <data_path>/integration/dc3000. makedirs creates
# missing parent directories and tolerates an existing directory; any other
# filesystem error (e.g. permissions) now surfaces instead of being swallowed.
out_dir = os.path.join(data_path, "integration", "dc3000")
os.makedirs(out_dir, exist_ok=True)

latent_adata.write(os.path.join(out_dir, "latent_adata.h5ad"))
spatial_data.write(os.path.join(out_dir, "spatial_data.h5ad"))
seq_data.write(os.path.join(out_dir, "seq_data.h5ad"))

Loading the saved spatial, snRNA and latent data back in

In [None]:
# Reload the spatial, sequencing and joint latent objects saved above.
integration_dir = os.path.join(data_path, "integration", "dc3000")
spatial_data = sc.read(os.path.join(integration_dir, "spatial_data.h5ad"))
seq_data = sc.read(os.path.join(integration_dir, "seq_data.h5ad"))
latent_adata = sc.read(os.path.join(integration_dir, "latent_adata.h5ad"))

In [None]:
# Harmonise cell-type names: capitalise them, and label undifferentiated or empty
# entries as "Unknown".
celltype_renames = {
    "epidermis": "Epidermis",
    "mesophyll": "Mesophyll",
    "undifferentiated": "Unknown",
    "vasculature": "Vasculature",
    "": "Unknown",
}
seq_data.obs["celltype"] = seq_data.obs["celltype"].replace(celltype_renames).to_list()

Transferring celltype labels from seq to spatial

In [None]:
# Transfer labels from the sequencing cells to every spatial cell by majority vote
# over its nearest sequencing neighbours in the shared gimVI latent space:
#   - "SCT_snn_res.1" cluster ids -> obs["DC3000_Cluster_Transfer"]
#   - "celltype" names            -> obs["celltype"]
n_neighbors = 20  # number of nearest sequencing neighbours that vote

# KD-tree over the sequencing cells' gimVI latent coordinates
seq_gimvi = seq_data.obsm["X_gimvi"]
kdtree = KDTree(seq_gimvi)

# Find the nearest seq cells for each spatial cell; row i of `indices` holds the
# positions (into seq_data) of the neighbours of spatial cell i.
distances, indices = kdtree.query(spatial_data.obsm["X_gimvi"], k=n_neighbors)

# Label arrays of the sequencing cells, extracted once instead of per spatial cell.
seq_clusters = seq_data.obs["SCT_snn_res.1"].to_numpy()
seq_celltypes = seq_data.obs["celltype"].to_numpy()


def majority_label(labels):
    """Return the most frequent label (ties resolved by Counter.most_common)."""
    return Counter(labels).most_common(1)[0][0]


cluster_votes = []
average_celltype = []
for neighbor_idx in tqdm(indices):
    cluster_votes.append(majority_label(seq_clusters[neighbor_idx]))
    average_celltype.append(majority_label(seq_celltypes[neighbor_idx]))

# Cluster ids are stored as floats, matching the previous float array; the column
# itself is assigned as a 1-D array.
averaged_expression = np.asarray(cluster_votes, dtype=float).reshape(-1, 1)

spatial_data.obs["DC3000_Cluster_Transfer"] = averaged_expression.ravel()
spatial_data.obs["celltype"] = average_celltype

Clustering, UMAP, and saving the spatial data

In [None]:
# Leiden clustering of the spatial cells on a neighbour graph of their gimVI latents.
sc.pp.neighbors(adata=spatial_data, use_rep="X_gimvi")
sc.tl.leiden(adata=spatial_data)

In [None]:
# UMAP of the spatial cells from the neighbour graph above.
# NOTE(review): n_components=15 gives a 15-D X_umap; only the first two components
# are usable for a 2-D plot — confirm this (vs. e.g. n_neighbors=15) is intended.
sc.tl.umap(adata=spatial_data, n_components=15)

In [None]:
# Overwrite the saved spatial object with the transferred labels and clustering.
spatial_out_file = os.path.join(data_path, "integration", "dc3000", "spatial_data.h5ad")
spatial_data.write(spatial_out_file)

In [None]:
# High-resolution gimVI UMAP of the spatial cells coloured by transferred cell type.
sc.set_figure_params(dpi=300)
fig = sc.pl.embedding(
    spatial_data,
    basis="X_umap_gimvi",
    color="celltype",
    title="24hr DC3000 Multiome",
    frameon=False,
    show=False,
    return_fig=True,
)
# equal x/y scaling on the current axes (presumably the embedding axes)
plt.axis("equal")

plt.show()