In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import os
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from scipy.spatial import KDTree
from sklearn.decomposition import NMF
import glob
import sys
from core_functions.initial_neighborhoods import *

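##### We need to downsample the Xenium data before integrating it with Visium, because Xenium measures individual cells whereas Visium spots are not at single-cell resolution. Pooling Xenium cells into spatial bins brings the two datasets to a comparable resolution.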

In [None]:
# Cleaned, combined Xenium object (per the filename: all replicates plus reference,
# without Peyer's patches). NOTE: absolute, machine-specific path — adjust locally.
path_to_xenium_final_object = os.path.join(
    '/mnt/sata1/Analysis_Alex/timecourse_replicates/analysis/cleaned',
    'full_xenium_replicates_and_reference_no_peyers.h5ad',
)
xenium_object = sc.read(filename=path_to_xenium_final_object)

##### Choose the Xenium batch whose conditions most closely match the Visium samples. In our case, the day 30 post-infection Xenium sample is the closest match, so we proceed with the day 30, replicate 2 Xenium experiment (`day30_SI_r2`) for integration.


In [None]:
# Xenium batch used for integration: day 30, replicate 2 (see the note above).
batch_to_use = "day30_SI_r2"

In [None]:
# Combined Visium object, read from the current working directory.
path_to_visium_object = 'visium_combined.h5ad'
visium_object = sc.read(filename=path_to_visium_object)

In [None]:
def create_binned_data_xenium(adata, bins, centers, unique_bins):
    '''
    Aggregate a Xenium object into spatial bins so it is no longer at single-cell resolution.

    Parameters
    ----------
    adata : AnnData object
        The Xenium object to be binned. If ``adata.obs`` has a
        ``crypt_villi_axis`` column, its per-bin mean is reported; otherwise
        every bin gets 0.
    bins : array-like
        Bin label for each cell in ``adata`` (same order as ``adata.obs``).
        Cells whose label is not in ``unique_bins`` do not contribute to any bin.
    centers : array-like
        The center of each bin, stored as ``obsm['spatial']``. It must have one
        row per entry of ``unique_bins`` (presumably shape (n_bins, 2) — TODO confirm).
    unique_bins : sequence
        The bin labels to emit, in output order. They also become the
        observation names (converted to strings).

    Returns
    -------
    adata_filtered : AnnData object
        One observation per bin. ``X`` holds the summed expression of the cells
        in the bin. ``obs['crypt_villi']`` holds their mean ``crypt_villi_axis``
        (NaN for a bin with no cells). Bins whose summed expression contains
        NaN are dropped.
    '''
    # adata.X is frequently a scipy sparse matrix. np.array() would wrap it in a
    # 0-d object array whose 2-D indexing fails, so densify via .toarray().
    X = adata.X
    expression = X.toarray() if hasattr(X, 'toarray') else np.asarray(X)
    n_vars = expression.shape[1]
    bins = np.asarray(bins)

    # Look the column up once. A missing column falls back to 0 for every bin,
    # preserving the previous best-effort behaviour.
    if 'crypt_villi_axis' in adata.obs.columns:
        axis_values = np.asarray(adata.obs['crypt_villi_axis'].values)
    else:
        axis_values = None

    expression_rows = []
    crypt_villi = []
    # Single pass over the bins: build each membership mask once and reuse it
    # for both the expression sum and the crypt-villi mean.
    for label in unique_bins:
        in_bin = bins == label
        # Summing zero rows yields an all-zero vector, so empty bins need no special case here.
        expression_rows.append(expression[in_bin].sum(axis=0))
        if axis_values is None:
            crypt_villi.append(0)
        elif in_bin.any():
            crypt_villi.append(np.mean(axis_values[in_bin]))
        else:
            # Empty bin: nothing to average (np.mean would also give NaN, plus a warning).
            crypt_villi.append(np.nan)

    # reshape keeps a valid (n_bins, n_vars) shape even when unique_bins is empty.
    expression_matrix = np.array(expression_rows).reshape(len(expression_rows), n_vars)
    obs = pd.DataFrame(
        {'crypt_villi': np.array(crypt_villi)},
        # AnnData requires string obs names; convert up front instead of relying
        # on its implicit (warning-emitting) conversion.
        index=pd.Index(unique_bins).astype(str),
    )
    binned = sc.AnnData(
        X=expression_matrix,
        obs=obs,
        var=pd.DataFrame(index=adata.var.index.tolist()),
    )
    binned.obsm['spatial'] = np.array(centers)

    # Filter out bins whose summed expression contains NaN.
    # NOTE(review): empty bins sum to all-zero rows (not NaN), so they are kept
    # here with crypt_villi = NaN. Filter them explicitly if that is unwanted.
    has_nan = np.isnan(binned.X.sum(axis=1))
    adata_filtered = binned[~has_nan].copy()
    return adata_filtered

In [None]:
# Number of bins along each side of the square grid used to downsample Xenium.
# 60 x 60 = 3600 bins gives a spot resolution similar to that of the Visium data.
n_bins_square = 60

##### Perform the downsampling: bin the cells of the chosen batch on a square grid and aggregate each bin.

In [None]:
# Only the batch chosen above is binned, so select it directly instead of
# looping over every batch and keeping the first result.
adata = xenium_object[xenium_object.obs['batch'] == batch_to_use]
if adata.n_obs == 0:
    raise ValueError(
        f"batch_to_use={batch_to_use!r} not found in xenium_object.obs['batch'] "
        f"(available: {list(np.unique(xenium_object.obs['batch']))})"
    )

# Assumes obsm['X_spatial'] holds x in column 0 and y in column 1 — confirm.
spatial_points = np.array(adata.obsm['X_spatial'][:, :2])
binned_points_low, binned_centers_low = create_grid_bins(spatial_points, n_bins_square)

# Give every grid cell a sequential integer id and label each Xenium cell with
# the id of the grid cell that contains it. Cells the grid does not assign stay
# at -1, so they are not silently pooled into bin 0.
# NOTE(review): assumes binned_points_low[row][col] is a collection of cell
# indices for that grid cell — confirm against create_grid_bins.
cell_bin_ids = np.full(len(spatial_points), -1)
bin_ids = []
for grid_row in binned_points_low:
    for cell_indices in grid_row:
        bin_id = len(bin_ids)
        cell_bin_ids[cell_indices] = bin_id
        bin_ids.append(bin_id)

# Despite the plural name (kept for the cells below), this is a single AnnData.
downsampled_adatas = create_binned_data_xenium(adata, cell_bin_ids, binned_centers_low, bin_ids)

In [None]:
# Spatial map of the binned Xenium data, coloured by mean crypt-villi axis per bin.
sc.pl.embedding(
    downsampled_adatas,
    basis='spatial',
    color='crypt_villi',
)

In [None]:
# Persist the binned Xenium object for the Visium integration step.
downsampled_adatas.write_h5ad('downsampled_mouse.h5ad')