In [None]:
import anndata
import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
from scipy.stats import spearmanr
from scvi.data import cortex, smfish
from scvi.external import GIMVI

##### Load in the downsampled Xenium object

In [None]:
# Read the downsampled Xenium AnnData from disk.
spatial_data = sc.read("downsampled_mouse.h5ad")
# Give every cell the same "batch" value; this column is what
# GIMVI.setup_anndata later receives as labels_key.
spatial_data.obs["batch"] = "xen"
# Drop cells with fewer than 10 total counts (modifies spatial_data in place).
sc.pp.filter_cells(spatial_data, min_counts=10)

##### Load in the Visium object

In [None]:
# Read the combined Visium object and keep only the spots whose "batch"
# column equals "proximal"; these act as the sequencing dataset for GIMVI.
visium_object = sc.read("visium_combined.h5ad")
is_proximal = visium_object.obs["batch"] == "proximal"
seq_data = visium_object[is_proximal]

##### Let's use almost all the genes to train the GIMVI model

In [None]:
# Fraction of the shared genes that goes into the "training" gene set when the
# random gene split is drawn further down (n_train_genes = int(n_genes * train_size)).
train_size = 0.99

# Notebook display settings: white figure background and retina-resolution
# inline figures.
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

##### Subset adatas to intersecting genes

In [None]:
# Partition the spatial gene panel, in its original order, into genes that are
# also present in the sequencing data and genes that are not.
seq_gene_lookup = set(seq_data.var_names)
intersection = [gene for gene in spatial_data.var_names if gene in seq_gene_lookup]
non_intersecting = [
    gene for gene in spatial_data.var_names if gene not in seq_gene_lookup
]

In [None]:
# only use genes in both datasets
# Both objects are sliced with the same gene list, so their columns end up in
# the same order; .copy() turns each slice into an independent AnnData rather
# than a view of the object it was cut from.
seq_data = seq_data[:, intersection].copy()
spatial_data = spatial_data[:, intersection].copy()

##### Prepare for GIMVI and run GIMVI

In [None]:
# Gene bookkeeping for a train / held-out split over the shared genes.
seq_gene_names = seq_data.var_names
n_genes = seq_data.n_vars
n_train_genes = int(n_genes * train_size)

# Draw the training gene indices without replacement; every index that was not
# drawn becomes a held-out ("test") gene, kept in ascending order.
all_gene_idx = range(n_genes)
rand_train_gene_idx = np.random.choice(all_gene_idx, n_train_genes, replace=False)
drawn_idx = set(rand_train_gene_idx)
rand_test_gene_idx = [idx for idx in all_gene_idx if idx not in drawn_idx]
rand_train_genes = seq_gene_names[rand_train_gene_idx]
rand_test_genes = seq_gene_names[rand_test_gene_idx]

# NOTE(review): in the cells visible here the split above is computed but never
# applied -- neither AnnData is subset to rand_train_genes, so GIMVI receives
# every shared gene. The draw is also unseeded, so it differs between runs.

# Remove cells that have no counts left on the shared genes (in place),
# spatial data first, then sequencing data.
for dataset in (spatial_data, seq_data):
    sc.pp.filter_cells(dataset, min_counts=1)

# Register both objects with GIMVI, using the "batch" column as labels_key.
for dataset in (spatial_data, seq_data):
    GIMVI.setup_anndata(dataset, labels_key="batch")

In [None]:
# create our model: sequencing data is passed first, spatial data second,
# and the shared latent space has 10 dimensions (n_latent=10)
model = GIMVI(seq_data, spatial_data, n_latent=10)

# train for 200 epochs
# NOTE(review): kappa=10 is forwarded to GIMVI.train unchanged; its meaning is
# defined by scvi-tools -- confirm against the GIMVI documentation.
# NOTE(review): no random seed is set in the visible cells, so the trained
# model (and everything derived from it) may differ between runs.
model.train(200, kappa=10)

##### Transfer the crypt villus axis using nearest neighbors in the latent space

In [None]:
# Encode both datasets with the trained model: the first returned array is
# unpacked as the sequencing cells, the second as the spatial cells.
latent_seq, latent_spatial = model.get_latent_representation()

# Stack the two embeddings row-wise (sequencing rows first) and wrap the
# result in an AnnData so scanpy can operate on it.
latent_representation = np.concatenate((latent_seq, latent_spatial), axis=0)
latent_adata = anndata.AnnData(latent_representation)

# Record which dataset every row came from, in the same order as the stack.
n_seq_rows = latent_seq.shape[0]
n_spatial_rows = latent_spatial.shape[0]
latent_labels = ["seq"] * n_seq_rows + ["spatial"] * n_spatial_rows
latent_adata.obs["labels"] = latent_labels

# Neighbour graph and UMAP computed directly on the latent coordinates (.X).
sc.pp.neighbors(latent_adata, use_rep="X")
sc.tl.umap(latent_adata)

# Split the joint UMAP back onto the original objects: the leading
# seq_data.shape[0] rows go to the sequencing data, the remainder to the
# spatial data.
joint_umap = latent_adata.obsm["X_umap"]
n_seq_cells = seq_data.shape[0]
seq_data.obsm["X_umap"] = joint_umap[:n_seq_cells]
spatial_data.obsm["X_umap"] = joint_umap[n_seq_cells:]

In [None]:
import numpy as np
from sklearn.neighbors import KDTree
from tqdm.notebook import tqdm  # noqa: F401 -- unused in this cell; left in scope in case other cells rely on it

# Work on a copy of the joint latent AnnData so latent_adata itself stays untouched.
adata = latent_adata.copy()

# Step 1: split the joint latent space back into its spatial and seq rows
spatial_cells = adata[adata.obs["labels"] == "spatial"].copy()
seq_cells = adata[adata.obs["labels"] == "seq"].copy()

# The crypt_villi values must come from exactly the cells (and row order) that
# produced the spatial latent rows, i.e. from spatial_data as it was handed to
# GIMVI. spatial_data was filtered a second time (min_counts=1 on the shared
# genes) before training, so re-reading the file and applying only the
# min_counts=10 filter can yield extra cells and shift every row after the
# first dropped one -- the neighbour indices would then pick the wrong labels.
if spatial_cells.n_obs != spatial_data.n_obs:
    raise ValueError(
        f"latent spatial rows ({spatial_cells.n_obs}) do not match "
        f"spatial_data cells ({spatial_data.n_obs}); cannot transfer crypt_villi"
    )
crypt_villi_values = spatial_data.obs["crypt_villi"].to_numpy(dtype=float)

# Step 2: build a KD tree over the latent coordinates of the spatial cells
spatial_latent = spatial_cells.X
kdtree = KDTree(spatial_latent)

# Step 3: for every seq cell, find its nearest spatial cells in latent space
n_neighbors = 20  # number of spatial neighbours averaged per seq cell
distances, indices = kdtree.query(seq_cells.X, k=n_neighbors)

# Step 4: average the crypt_villi value of those neighbours for each seq cell.
# indices has one row per seq cell and n_neighbors columns, so fancy-indexing
# the label vector and taking the row mean replaces the per-cell Python loop.
averaged_expression = crypt_villi_values[indices].mean(axis=1)

In [None]:
# Attach the neighbour-averaged crypt_villi value of each sequencing cell to
# the latent-space subset.
# NOTE(review): the column name says "normalized ... scaled", but the value
# stored here is the plain neighbour mean; no normalisation is visible.
seq_cells.obs["normalized_crypt_villi_scaled"] = averaged_expression

In [None]:
# Put the UMAP coordinates of the spatial rows onto spatial_data. These are the
# same rows of the joint embedding that were assigned to
# spatial_data.obsm["X_umap"] right after the UMAP was computed.
spatial_data.obsm["X_umap"] = spatial_cells.obsm["X_umap"]

In [None]:
# UMAP of the spatial cells, coloured by the crypt_villi column of spatial_data.obs.
sc.pl.umap(spatial_data, color="crypt_villi")

In [None]:
# Copy the transferred axis onto seq_data. .values makes the assignment
# positional: seq_cells was built from the bare latent matrix and does not
# carry seq_data's obs index, so an index-aligned assignment would not match.
seq_data.obs["crypt_villi_axis"] = seq_cells.obs["normalized_crypt_villi_scaled"].values

In [None]:
# UMAP of the sequencing cells, coloured by the transferred crypt-villus axis.
# NOTE(review): vmax=4 hard-codes the upper end of the colour scale -- confirm
# it matches the range of crypt_villi in the spatial data.
sc.pl.umap(seq_data, color=["crypt_villi_axis"], vmax=4)

##### Add the imputed crypt villus axis to the proximal Visium data

In [None]:
# Reload the combined Visium object and take an independent copy of the
# proximal spots, so the in-place edits below never operate on an AnnData view.
visium_object = sc.read("visium_combined.h5ad")
seq_data_full = visium_object[visium_object.obs["batch"] == "proximal"].copy()

# Drop spots without any counts (in place).
sc.pp.filter_cells(seq_data_full, min_counts=1)

# Keep only the spots that are also present in seq_data, i.e. the ones that
# went through GIMVI and received a crypt-villus value.
kept_spots = seq_data_full.obs_names.isin(seq_data.obs_names)
seq_data_full = seq_data_full[kept_spots].copy()

# obs and obsm are transplanted row-by-row below, so both objects must list
# the same spots in the same order; otherwise the axis values would be
# attached to the wrong spots without any error.
if not seq_data_full.obs_names.equals(seq_data.obs_names):
    raise ValueError(
        "seq_data_full and seq_data do not contain the same spots in the same "
        "order; refusing to copy obs/obsm positionally"
    )

# Carry over the annotations (including crypt_villi_axis) and the embeddings.
# obs is copied so the two AnnData objects do not share one DataFrame.
seq_data_full.obs = seq_data.obs.copy()
seq_data_full.obsm = seq_data.obsm

seq_data_full.write("visium_with_axis_proximal.h5ad")