# Setup

In [None]:
# Quietly (-q) install condacolab so the next cell can import it.
!pip install -q condacolab

In [None]:
import condacolab
# Run condacolab's installer; the line printed under this cell is its success message.
condacolab.install()

✨🍰✨ Everything looks OK!


In [None]:
# Create a conda environment named py3918 pinned to Python 3.9.18; -y skips the confirmation prompt.
!conda create -n py3918 python=3.9.18 -y


Channels:
 - conda-forge
Platform: linux-64
Collecting package metadata (repodata.json): - \ | / - \ | / - \ | / - \ | / - \ | / - done
Solving environment: | / done


    current version: 24.11.2
    latest version: 25.1.1

Please update conda by running

    $ conda update -n base -c conda-forge conda



## Package Plan ##

  environment location: /usr/local/envs/py3918

  added / updated specs:
    - python=3.9.18


The following NEW packages will be INSTALLED:

  _libgcc_mutex      conda-forge/linux-64::_libgcc_mutex-0.1-conda_forge 
  _openmp_mutex      conda-forge/linux-64::_openmp_mutex-4.5-2_gnu 
  bzip2              conda-forge/linux-64::bzip2-1.0.8-h4bc722e_7 
  ca-certificates    conda-forge/linux-64::ca-certificates-2025.1.31-hbcca054_0 
  ld_impl_linux-64   conda-forge/linux-64::ld_impl_linux-64-2.43-h712a8e2_2 
  libffi             conda-forge/linux-64::libffi-3.4.2-h7f98852_5 
  libgcc             conda-forge/linux-64::libgcc

In [None]:
# NOTE(review): each `!` command presumably runs in its own short-lived subshell, so this
# activation is unlikely to carry over to later cells or to the kernel; the next cell
# reaches the environment explicitly through `conda run -n py3918`. Confirm it is still needed.
!source activate py3918

In [None]:
# Print the Python version seen inside the py3918 environment.
!conda run -n py3918 python --version

Python 3.9.18



In [None]:
# Clone the forceatlas2 sources into ./forceatlas2.
# NOTE(review): no visible cell builds or imports from this checkout (fa2 is installed from
# conda-forge further down) — confirm whether the clone is still required.
!git clone https://github.com/bhargavchippada/forceatlas2.git


Cloning into 'forceatlas2'...
remote: Enumerating objects: 232, done.[K
remote: Counting objects: 100% (90/90), done.[K
remote: Compressing objects: 100% (16/16), done.[K
remote: Total 232 (delta 78), reused 74 (delta 74), pack-reused 142 (from 1)[K
Receiving objects: 100% (232/232), 502.31 KiB | 2.15 MiB/s, done.
Resolving deltas: 100% (134/134), done.


In [None]:
# Install Cython. NOTE(review): presumably a build dependency for the forceatlas2 clone — confirm.
!pip install cython



In [None]:
# Install fa2 from conda-forge. -y auto-confirms the package plan (as the `conda create`
# cell already does) so the cell cannot block waiting for an interactive prompt.
!conda install -y -c conda-forge fa2



Channels:
 - conda-forge
Platform: linux-64
Collecting package metadata (repodata.json): - \ | / - \ | / done
Solving environment: \ | / - done

## Package Plan ##

  environment location: /usr/local

  added / updated specs:
    - fa2


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    conda-24.11.3              |  py311h38be061_0         1.1 MB  conda-forge
    fa2-0.3.5                  |  py311hd4cff14_2          97 KB  conda-forge
    libblas-3.9.0              |28_h59b9bed_openblas          16 KB  conda-forge
    libcblas-3.9.0             |28_he106b2a_openblas          16 KB  conda-forge
    libgfortran-14.2.0         |       h69a702a_1          53 KB  conda-forge
    libgfortran5-14.2.0        |       hd5240d6_1         1.4 MB  conda-forge
    liblapack-3.9.0            |28_h7ac8fdf_openblas          16 KB  conda-forge
    libopenblas-0.3.28         |pthread

In [None]:
# scanpy plus the packages that later cells import directly (umap, colorcet, scenvi) but that
# no other visible cell installs.
# NOTE(review): versions are unpinned; pin them once a known-good combination is established.
!pip install scanpy umap-learn colorcet scenvi



In [None]:
import fa2
import os

# Select the GPU through environment variables: devices are ordered by PCI bus id and only
# device "0" is made visible.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0" # Change to -1 if you want to use CPU!

import warnings
# Suppress every warning for the rest of the session.
warnings.filterwarnings('ignore')

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import pandas as pd
import scanpy as sc
import colorcet
import sklearn.neighbors
import scipy.sparse
import umap.umap_ as umap
from fa2 import ForceAtlas2

# ENVI


In [None]:
def flatten(arr):
    """Collapse every axis after the first, giving a 2-D (n_rows, -1) array."""
    n_rows = arr.shape[0]
    flat = np.reshape(arr, (n_rows, -1))
    return flat

def force_directed_layout(affinity_matrix, cell_names=None, verbose=True, iterations=500, device='cpu'):
    """Compute a 2-D ForceAtlas2 layout from an affinity matrix.

    :param affinity_matrix: Sparse matrix representing affinities between cells
    :param cell_names: Optional row labels (e.g. a pandas Series of cell names) for the
        returned frame; when omitted rows are labelled 0..n-1
    :param verbose: Verbosity for force directed layout computation
    :param iterations: Number of iterations used by ForceAtlas
    :param device: Compute backend; only 'cpu' is implemented
    :return: Pandas data frame with columns 'x' and 'y', one row per cell
    :raises ValueError: if ``device`` is anything other than 'cpu'
    """
    # Fail fast: an unsupported device used to fall through and crash later with an
    # UnboundLocalError because `positions` was never assigned.
    if device != 'cpu':
        raise ValueError(f"Unsupported device {device!r}; only 'cpu' is implemented")

    n_cells = affinity_matrix.shape[0]

    # Random starting coordinates, one (x, y) pair per cell.
    init_coords = np.random.random((n_cells, 2))

    forceatlas2 = ForceAtlas2(
        # Behavior alternatives
        outboundAttractionDistribution=False,
        linLogMode=False,
        adjustSizes=False,
        edgeWeightInfluence=1.0,
        # Performance
        jitterTolerance=1.0,
        barnesHutOptimize=True,
        barnesHutTheta=1.2,
        multiThreaded=False,
        # Tuning
        scalingRatio=2.0,
        strongGravityMode=False,
        gravity=1.0,
        # Log
        verbose=verbose)

    positions = forceatlas2.forceatlas2(
        affinity_matrix, pos=init_coords, iterations=iterations)
    positions = np.array(positions)

    # Honour the documented `cell_names` argument; fall back to positional labels.
    row_labels = np.arange(n_cells) if cell_names is None else cell_names
    return pd.DataFrame(positions, index=row_labels, columns=['x', 'y'])

def run_diffusion_maps(data_df, n_components=10, knn=30, alpha=0):
    """Compute diffusion components using an adaptive anisotropic kernel.

    A dense input (NumPy array or DataFrame) is converted into a kNN-based kernel;
    a SciPy sparse input is used directly as the kernel.

    :param data_df: PCA projections of the data or adjacency matrix
    :param n_components: Number of diffusion components
    :param knn: Number of nearest neighbors for graph construction
    :param alpha: Normalization parameter for the diffusion operator
    :return: dict with keys "T" (kernel scaled by inverse row sums), "EigenVectors"
        (DataFrame, unit-norm columns), "EigenValues" (Series, descending) and "kernel"
    """
    N = data_df.shape[0]

    def _diagonal(values):
        # N x N sparse matrix carrying `values` on its diagonal.
        return scipy.sparse.csr_matrix((values, (range(N), range(N))), shape=[N, N])

    # Wrap NumPy input so the dense branch can rely on `.values` and `.index`.
    if type(data_df).__module__ == np.__name__:
        data_df = pd.DataFrame(data_df)

    dense_input = not scipy.sparse.issparse(data_df)

    if dense_input:
        print("Determing nearest neighbor graph...")
        holder = sc.AnnData(data_df.values)
        sc.pp.neighbors(holder, n_pcs=0, n_neighbors=knn)
        kNN = holder.obsp['distances']

        # Per-cell bandwidth: the distance at rank floor(knn / 3) among that cell's stored neighbours.
        adaptive_k = int(np.floor(knn / 3))
        row_bounds = zip(kNN.indptr[:-1], kNN.indptr[1:])
        adaptive_std = np.fromiter(
            (np.sort(kNN.data[start:stop])[adaptive_k - 1] for start, stop in row_bounds),
            dtype=float,
            count=N,
        )

        # Scale each stored distance by the bandwidth of its row, then symmetrise.
        rows, cols, dists = scipy.sparse.find(kNN)
        scaled = dists / adaptive_std[rows]
        W = scipy.sparse.csr_matrix((np.exp(-scaled), (rows, cols)), shape=[N, N])
        kernel = W + W.T
    else:
        kernel = data_df

    # Row sums of the kernel.
    D = np.ravel(kernel.sum(axis=1))

    if alpha > 0:
        # Anisotropic normalisation: D^-alpha . kernel . D^-alpha
        nonzero = D != 0
        D[nonzero] = D[nonzero] ** (-alpha)
        scaling = _diagonal(D)
        kernel = scaling.dot(kernel).dot(scaling)
        D = np.ravel(kernel.sum(axis=1))

    # Scale rows by inverse row sums (all-zero rows are left untouched).
    nonzero = D != 0
    D[nonzero] = 1 / D[nonzero]
    T = _diagonal(D).dot(kernel)

    # Eigen decomposition, keeping real parts and ordering by descending eigenvalue.
    eig_vals, eig_vecs = scipy.sparse.linalg.eigs(T, n_components, tol=1e-4, maxiter=1000)
    eig_vals = np.real(eig_vals)
    eig_vecs = np.real(eig_vecs)
    order = np.argsort(eig_vals)[::-1]
    eig_vals = eig_vals[order]
    eig_vecs = eig_vecs[:, order]

    # Give every eigenvector unit Euclidean norm.
    for col in range(eig_vecs.shape[1]):
        eig_vecs[:, col] = eig_vecs[:, col] / np.linalg.norm(eig_vecs[:, col])

    eigvec_df = pd.DataFrame(eig_vecs)
    if dense_input:
        eigvec_df.index = data_df.index

    return {
        "T": T,
        "EigenVectors": eigvec_df,
        "EigenValues": pd.Series(eig_vals),
        "kernel": kernel,
    }


def FDL(data, k = 30):
    """Build an adaptive-bandwidth kNN kernel from ``data`` and lay it out with ForceAtlas2.

    :param data: 2-D array, one row per cell
    :param k: Number of Euclidean nearest neighbours for the graph; the per-cell
        bandwidth uses floor(k / 3) neighbours
    :return: Layout data frame as returned by ``force_directed_layout``
    """
    def _distance_graph(n_neighbors):
        # Sparse graph of Euclidean distances from every row to its nearest rows.
        model = sklearn.neighbors.NearestNeighbors(
            n_neighbors=int(n_neighbors), metric='euclidean', n_jobs=5)
        return model.fit(data).kneighbors_graph(data, mode='distance')

    kNN = _distance_graph(k)

    # Bandwidth per cell: largest distance within the floor(k / 3) nearest rows.
    adaptive_k = int(np.floor(k / 3))
    row_max = _distance_graph(adaptive_k).max(axis=1)
    adaptive_std = np.ravel(row_max.todense())

    # Scale each stored distance by the bandwidth of its row and exponentiate.
    rows, cols, dists = scipy.sparse.find(kNN)
    scaled = dists / adaptive_std[rows]
    n_cells = data.shape[0]
    W = scipy.sparse.csr_matrix((np.exp(-scaled), (rows, cols)), shape=[n_cells, n_cells])

    # Symmetrise before handing the kernel to the layout routine.
    return force_directed_layout(W + W.T)

In [None]:
# Folder holding the MERFISH .h5ad files. Set the MERFISH_DATA_DIR environment variable to
# point elsewhere; the default is the location the notebook was originally written against.
sc_data_dir = os.environ.get('MERFISH_DATA_DIR', '/Users/anushka/Desktop/MERFISH data')
sc_data = sc.read_h5ad(os.path.join(sc_data_dir, 'sc_data.h5ad'))

In [None]:
# Folder holding the MERFISH .h5ad files. Set the MERFISH_DATA_DIR environment variable to
# point elsewhere; the default is the location the notebook was originally written against.
st_data_dir = os.environ.get('MERFISH_DATA_DIR', '/Users/anushka/Desktop/MERFISH data')
st_data = sc.read_h5ad(os.path.join(st_data_dir, 'st_data.h5ad'))

In [None]:
# Fresh-kernel guard: no visible cell defines `cell_type_palette`, so build a fallback
# (one colorcet Glasbey colour per cell type found in either dataset) when it is missing.
# NOTE(review): replace with the intended palette if one exists elsewhere.
try:
    cell_type_palette
except NameError:
    _cell_types = sorted(set(st_data.obs['cell_type']) | set(sc_data.obs['cell_type']))
    cell_type_palette = dict(zip(_cell_types, colorcet.glasbey))

# Select the slice once instead of re-evaluating the mask for every argument.
slice_mask = (st_data.obs['batch'] == 'mouse1_slice10').values
slice_xy = st_data.obsm['spatial'][slice_mask]

plt.figure(figsize=(10,10))

# Spatial column 1 goes on x and negated column 0 on y.
sns.scatterplot(x = slice_xy[:, 1],
                y = -slice_xy[:, 0], legend = True,
                hue = st_data.obs['cell_type'][slice_mask],
                s = 12, palette = cell_type_palette)
plt.axis('equal')
plt.axis('off')
plt.title("MERFISH Data")
plt.show()

In [None]:
# UMAP reducer for the expression-space embedding.
fit = umap.UMAP(n_neighbors = 100, min_dist = 0.8, n_components = 2)

# Store log(x + 1) as a layer and pick the 2048 most variable genes from it.
sc_data.layers['log'] = np.log(sc_data.X + 1)
sc.pp.highly_variable_genes(sc_data, layer = 'log', n_top_genes = 2048)

# Embed log(x + 1) expression restricted to the selected genes.
hvg_mask = sc_data.var['highly_variable']
hvg_log_expression = np.log(sc_data[:, hvg_mask].X + 1)
sc_data.obsm['UMAP_exp'] = fit.fit_transform(hvg_log_expression)

In [None]:
# Scatter the expression UMAP of the scRNA-seq cells, coloured by cell type.
umap_exp = sc_data.obsm['UMAP_exp']

fig = plt.figure(figsize = (10,10))
sns.scatterplot(
    x = umap_exp[:, 0],
    y = umap_exp[:, 1],
    hue = sc_data.obs['cell_type'],
    s = 16,
    palette = cell_type_palette,
    legend = True,
)
plt.tight_layout()
plt.axis('off')
plt.title('scRNA-seq Data')
plt.show()

In [None]:
import scenvi
# Construct the ENVI model from the spatial (st_data) and single-cell (sc_data) datasets.
envi_model = scenvi.ENVI(spatial_data = st_data, sc_data = sc_data)

In [None]:
# Train the model, then run its gene-imputation and niche-inference steps. The next cell
# reads the results back from envi_model.spatial_data / envi_model.sc_data.
envi_model.train()
envi_model.impute_genes()
envi_model.infer_niche_covet()
envi_model.infer_niche_celltype()

In [None]:
# Copy the model outputs back onto the notebook's own AnnData objects.
for obsm_key in ('envi_latent', 'COVET', 'COVET_SQRT', 'imputation', 'cell_type_niche'):
    st_data.obsm[obsm_key] = envi_model.spatial_data.obsm[obsm_key]
st_data.uns['COVET_genes'] =  envi_model.CovGenes

# The single-cell side has no 'imputation' entry to copy.
for obsm_key in ('envi_latent', 'COVET', 'COVET_SQRT', 'cell_type_niche'):
    sc_data.obsm[obsm_key] = envi_model.sc_data.obsm[obsm_key]
sc_data.uns['COVET_genes'] =  envi_model.CovGenes

In [None]:
# UMAP reducer for the joint ENVI latent space.
fit = umap.UMAP(n_neighbors = 100, min_dist = 0.3, n_components = 2)

# Embed both datasets together (spatial rows first), then split the result back by row count.
n_spatial_cells = st_data.shape[0]
joint_latent = np.concatenate(
    [st_data.obsm['envi_latent'], sc_data.obsm['envi_latent']], axis = 0)
latent_umap = fit.fit_transform(joint_latent)

st_data.obsm['latent_umap'] = latent_umap[:n_spatial_cells]
sc_data.obsm['latent_umap'] = latent_umap[n_spatial_cells:]

In [None]:
# Shared axis limits for the latent UMAP panels: trim `pre` percent of points at each
# extreme, then pad by `delta`.
lim_arr = np.concatenate([st_data.obsm['latent_umap'], sc_data.obsm['latent_umap']], axis = 0)

delta = 1
pre = 0.1

lower_q, upper_q = pre, 100 - pre
x_values, y_values = lim_arr[:, 0], lim_arr[:, 1]

xmin = np.percentile(x_values, lower_q) - delta
xmax = np.percentile(x_values, upper_q) + delta
ymin = np.percentile(y_values, lower_q) - delta
ymax = np.percentile(y_values, upper_q) + delta

In [None]:
# Side-by-side latent UMAPs sharing the axis limits computed in the previous cell.
fig = plt.figure(figsize = (13,5))
# Left panel: scRNA-seq cells coloured by cell type, legend suppressed.
plt.subplot(121)
sns.scatterplot(x = sc_data.obsm['latent_umap'][:, 0],
                y = sc_data.obsm['latent_umap'][:, 1], hue = sc_data.obs['cell_type'], s = 8, palette = cell_type_palette,
                legend = False)
plt.title("scRNA-seq Latent")
plt.xlim([xmin, xmax])
plt.ylim([ymin, ymax])
plt.axis('off')

# Right panel: MERFISH cells coloured by cell type; this panel carries the legend.
plt.subplot(122)
sns.scatterplot(x = st_data.obsm['latent_umap'][:, 0],
                y = st_data.obsm['latent_umap'][:, 1],  hue = st_data.obs['cell_type'], s = 8, palette = cell_type_palette, legend = True)


# NOTE(review): both `prop` and `fontsize` are supplied; confirm which one matplotlib honours.
legend = plt.legend(title = 'Cell Type', prop={'size': 12}, fontsize = '12',  markerscale = 3, ncol = 2, bbox_to_anchor = (1, 1))#, loc = 'lower left')
plt.setp(legend.get_title(),fontsize='12')
plt.title("MERFISH Latent")
plt.axis('off')
plt.tight_layout()
plt.xlim([xmin, xmax])
plt.ylim([ymin, ymax])
plt.show()

In [None]:
# Keep only Sst cells. `.copy()` turns the AnnData views into independent objects, so the
# `.obsm[...]` assignments in later cells write to real data rather than triggering an
# implicit view-to-copy conversion.
st_data_sst = st_data[st_data.obs['cell_type'] == 'Sst'].copy()
sc_data_sst = sc_data[sc_data.obs['cell_type'] == 'Sst'].copy()

In [None]:
# RGBA colour per label; used as the seaborn palette when plotting `cluster_label` of Sst cells.
# NOTE(review): 'Hpse' and 'Hspe' are both present — confirm both labels occur in the data
# and that neither is a typo of the other.
gran_sst_palette = {'Th': (0.0, 0.294118, 0.0, 1.0),
                    'Calb2': (0.560784, 0.478431, 0.0, 1.0),
                    'Chodl': (1.0, 0.447059, 0.4, 1.0),
                    'Myh8': (0.933333, 0.72549, 0.72549, 1.0),
                    'Crhr2': (0.368627, 0.494118, 0.4, 1.0),
                    'Hpse': (0.65098, 0.482353, 0.72549, 1.0),
                    'Hspe': (0.352941, 0.0, 0.643137, 1.0),
                    'Crh': (0.607843, 0.894118, 1.0, 1.0),
                    'Pvalb Etv1': (0.92549, 0.0, 0.466667, 1.0)}

In [None]:
# Force-directed layout on the flattened COVET_SQRT matrices of both Sst subsets
# (spatial rows first), split back by row count afterwards.
sst_covet_sqrt_flat = np.concatenate(
    [flatten(st_data_sst.obsm['COVET_SQRT']), flatten(sc_data_sst.obsm['COVET_SQRT'])],
    axis = 0,
)
FDL_COVET = np.asarray(FDL(sst_covet_sqrt_flat, k = 30))

st_data_sst.obsm['FDL_COVET'] = FDL_COVET[:st_data_sst.shape[0]]
sc_data_sst.obsm['FDL_COVET'] = FDL_COVET[st_data_sst.shape[0]:]

In [None]:
# Diffusion components on the same flattened COVET_SQRT input; the first eigenvector column
# is dropped and the sign is flipped when storing.
sst_covet_sqrt_stack = np.concatenate(
    [flatten(st_data_sst.obsm['COVET_SQRT']), flatten(sc_data_sst.obsm['COVET_SQRT'])],
    axis = 0,
)
sst_diffusion = run_diffusion_maps(sst_covet_sqrt_stack, knn = 30)
DC_COVET = np.asarray(sst_diffusion['EigenVectors'])[:, 1:]

st_data_sst.obsm['DC_COVET'] = -DC_COVET[:st_data_sst.shape[0]]
sc_data_sst.obsm['DC_COVET'] = -DC_COVET[st_data_sst.shape[0]:]

In [None]:
# NOTE(review): these two assignments repeat the last two lines of the preceding cell
# verbatim; re-running them stores the same values again.
st_data_sst.obsm['DC_COVET'] = -DC_COVET[:st_data_sst.shape[0]]
sc_data_sst.obsm['DC_COVET'] = -DC_COVET[st_data_sst.shape[0]:]

In [None]:
# Shared axis limits for the FDL panels: trim `pre` percent of points at each extreme,
# then pad by `delta`.
lim_arr = np.concatenate([st_data_sst.obsm['FDL_COVET'], sc_data_sst.obsm['FDL_COVET']], axis = 0)

delta = 1000
pre = 0.01

lower_q, upper_q = pre, 100 - pre
x_values, y_values = lim_arr[:, 0], lim_arr[:, 1]

xmin = np.percentile(x_values, lower_q) - delta
xmax = np.percentile(x_values, upper_q) + delta
ymin = np.percentile(y_values, lower_q) - delta
ymax = np.percentile(y_values, upper_q) + delta

In [None]:
# Two FDL panels sharing the axis limits computed in the previous cell.
plt.figure(figsize=(10,5))

# Left panel: scRNA-seq Sst cells coloured by `cluster_label` using gran_sst_palette.
plt.subplot(121)
sns.scatterplot(x = sc_data_sst.obsm['FDL_COVET'][:, 0],
                y = sc_data_sst.obsm['FDL_COVET'][:, 1],
                hue = sc_data_sst.obs['cluster_label'], s = 16,  palette= gran_sst_palette, legend = True)
plt.tight_layout()
plt.xlim([xmin, xmax])
plt.ylim([ymin, ymax])
plt.title('scRNA-seq Sst, COVET FDL')
legend = plt.legend(title = 'Sst subtype', prop={'size': 8}, fontsize = '8',  markerscale = 1, ncol = 2)
plt.axis('off')

# Right panel: MERFISH Sst cells coloured by the first stored DC_COVET column.
plt.subplot(122)
ax = sns.scatterplot(x = st_data_sst.obsm['FDL_COVET'][:, 0],
                y = st_data_sst.obsm['FDL_COVET'][:, 1],
                c = st_data_sst.obsm['DC_COVET'][:,0], s = 16,  cmap= 'cet_CET_D13', legend = False)
plt.tight_layout()
plt.xlim([xmin, xmax])
plt.ylim([ymin, ymax])
plt.axis('off')
plt.title('MERFISH Sst, COVET FDL')
plt.show()

In [None]:
# One panel per slice: all cells as a uniform background colour, Sst cells overlaid as
# triangles coloured by the first stored DC_COVET column.
slice_names = ['mouse1_slice212', 'mouse1_slice162', 'mouse1_slice71', 'mouse2_slice270', 'mouse1_slice40']
background_rgba = (207/255,185/255,151/255, 1)

fig = plt.figure(figsize=(25,5))

for panel, batch in enumerate(slice_names, start = 1):
    st_dataBatch = st_data[st_data.obs['batch'] == batch]
    st_dataPlotBatch = st_data_sst[st_data_sst.obs['batch'] == batch]
    all_xy = st_dataBatch.obsm['spatial']
    sst_xy = st_dataPlotBatch.obsm['spatial']

    plt.subplot(1, 5, panel)
    sns.scatterplot(x = all_xy[:, 0], y = all_xy[:, 1], color = background_rgba)
    sns.scatterplot(x = sst_xy[:, 0], y = sst_xy[:, 1], marker = '^',
                    c = st_dataPlotBatch.obsm['DC_COVET'][:, 0], s = 256,
                    cmap = 'cet_CET_D13', legend = False)
    plt.title(batch)
    plt.axis('off')
    plt.tight_layout()

plt.show()

In [None]:
# One row per scRNA-seq Sst cell: its subtype label and a depth value, taken as the
# negated first stored DC_COVET column.
depth_df = pd.DataFrame({
    'Subtype': sc_data_sst.obs['cluster_label'],
    'Depth': -sc_data_sst.obsm['DC_COVET'][:,0],
})

In [None]:
# Order subtypes from largest to smallest mean depth. `observed=True` restricts a categorical
# 'Subtype' to labels that actually occur, so unused categories cannot add NaN groups.
subtype_depth_order = (
    depth_df.groupby('Subtype', observed=True)['Depth']
    .mean()
    .sort_values(ascending=False)
    .index
)


In [None]:
# Boxen plot of depth per subtype, in the order computed in the previous cell.
plt.figure(figsize=(12,5))
# These seaborn settings are global and also affect figures drawn after this cell.
sns.set(font_scale=1.7)
sns.set_style("whitegrid")
sns.boxenplot(depth_df, x = 'Subtype', y = 'Depth',# bw = 1, width = 0.9,
          order = subtype_depth_order,
          palette = gran_sst_palette)
plt.tight_layout()
plt.show()

In [None]:
# Mean `cell_type_niche` vector per subtype, one row per subtype in depth order.
subtype_niche_means = []
for subtype in subtype_depth_order:
    subtype_cells = sc_data_sst[sc_data_sst.obs['cluster_label']==subtype]
    subtype_niche_means.append(subtype_cells.obsm['cell_type_niche'].mean(axis = 0))

subtype_canonical = pd.DataFrame(
    subtype_niche_means,
    index = subtype_depth_order,
    columns = sc_data.obsm['cell_type_niche'].columns,
)

In [None]:
# Zero out contributions below 0.2, drop cell types that are zero for every subtype, then
# rescale each row to sum to 1.
subtype_canonical = subtype_canonical.mask(subtype_canonical < 0.2, 0)
subtype_canonical = subtype_canonical.loc[:, ~(subtype_canonical == 0).all()]
subtype_canonical = subtype_canonical.div(subtype_canonical.sum(axis=1), axis=0)

# `stacked` takes a boolean (the string 'True' only worked because it is truthy); `ncol`
# matches the other legend calls in this notebook.
subtype_canonical.plot(kind = 'bar', stacked = True,
                       color = {col:cell_type_palette[col] for col in subtype_canonical.columns})
plt.legend(bbox_to_anchor = (1,1), ncol = 1, fontsize = 'x-small')
plt.title("Predicted Niche Composition")
plt.ylabel("Proportion")
plt.xlabel("Sst Subtype")
plt.show()


In [None]:
tick_genes = np.asarray(['Adamts18','Pamr1', 'Dkkl1', 'Hs6st2', 'Slit1', 'Ighm'])

# Restrict to one slice once; every panel reuses the same cells and coordinates.
slice_mask = st_data.obs['batch'] == 'mouse1_slice10'
slice_xy = st_data.obsm['spatial'][slice_mask]
slice_imputation = st_data[slice_mask].obsm['imputation']

plt.figure(figsize=(15,10))

for panel, gene in enumerate(tick_genes, start = 1):
    plt.subplot(2, 3, panel)

    # log(imputed expression + 0.1); colour range clipped to the 30th-95th percentile.
    cvec = np.log(slice_imputation[gene] + 0.1)
    colour_floor = np.percentile(cvec, 30)
    colour_ceiling = np.percentile(cvec, 95)

    sns.scatterplot(x = slice_xy[:, 1],
                    y = -slice_xy[:, 0], legend = False,
                    c = cvec, cmap = 'Reds',
                    vmax = colour_ceiling, vmin = colour_floor,
                    s = 24, edgecolor = 'k')
    plt.title(gene)
    plt.axis('equal')
    plt.axis('off')
    plt.tight_layout()
plt.show()

# COVET

In [None]:
import numpy as np
from sklearn.neighbors import NearestNeighbors
import numpy as np
from scipy.linalg import sqrtm

def calculate_covet(expression_matrix, spatial_coordinates, k=8):
    """Compute one niche covariance (COVET) matrix per cell.

    For every cell, the ``k`` nearest cells in ``spatial_coordinates`` (Euclidean; the query
    points are the fitted points themselves) form its niche. The niche expression is shifted
    by the global per-gene mean and ``shifted.T @ shifted / k`` is returned for that cell.

    :param expression_matrix: array-like of shape (n_cells, n_genes)
    :param spatial_coordinates: array-like of shape (n_cells, n_dims)
    :param k: Number of spatial neighbours per niche
    :return: ndarray of shape (n_cells, n_genes, n_genes)
    :raises ValueError: if either input is not 2-D or they disagree on the number of cells
    """
    # Coerce up front so DataFrames and nested lists support integer row indexing below.
    expression = np.asarray(expression_matrix)
    coordinates = np.asarray(spatial_coordinates)

    if expression.ndim != 2 or coordinates.ndim != 2:
        raise ValueError(
            "expression_matrix and spatial_coordinates must both be 2-D arrays "
            f"(got ndim={expression.ndim} and ndim={coordinates.ndim})")
    if expression.shape[0] != coordinates.shape[0]:
        raise ValueError(
            "expression_matrix and spatial_coordinates must describe the same cells "
            f"(got {expression.shape[0]} and {coordinates.shape[0]} rows)")

    # Find spatial nearest neighbors
    nn = NearestNeighbors(n_neighbors=k, metric='euclidean')
    nn.fit(coordinates)
    _, indices = nn.kneighbors(coordinates)

    # Shift by the global (not per-niche) mean.
    global_mean = np.mean(expression, axis=0)

    covet_matrices = []
    for neighbour_rows in indices:
        shifted_matrix = expression[neighbour_rows] - global_mean
        covet_matrices.append(np.dot(shifted_matrix.T, shifted_matrix) / k)

    return np.array(covet_matrices)


def aot_distance(covet1, covet2, epsilon=0.1, max_iter=100):
    """
    Calculate the Approximate Optimal Transport (AOT) distance between two COVET matrices.

    The distance is the Frobenius norm of the difference between the two matrix square
    roots, ``|| covet1^(1/2) - covet2^(1/2) ||_F``. Its square equals
    ``trace(covet1 + covet2 - 2 * covet1^(1/2) @ covet2^(1/2))``.

    The previous implementation always raised: it built a scalar ``K`` and then used it in
    ``np.diag(u) @ K @ np.diag(v)``, which matmul rejects.

    Args:
    covet1, covet2: Input COVET matrices (square, same shape)
    epsilon: Unused; accepted so existing calls keep working
    max_iter: Unused; accepted so existing calls keep working

    Returns:
    float: AOT distance between covet1 and covet2
    """
    def _psd_sqrt(matrix):
        # Symmetrise, clip negative eigenvalues to zero, and rebuild the square root.
        # (An element-wise np.maximum(matrix, 0) does not make a matrix positive semi-definite.)
        symmetric = (matrix + matrix.T) / 2.0
        eigenvalues, eigenvectors = np.linalg.eigh(symmetric)
        eigenvalues = np.clip(eigenvalues, 0.0, None)
        return (eigenvectors * np.sqrt(eigenvalues)) @ eigenvectors.T

    first = np.asarray(covet1, dtype=float)
    second = np.asarray(covet2, dtype=float)
    if first.shape != second.shape or first.ndim != 2 or first.shape[0] != first.shape[1]:
        raise ValueError(
            f"covet1 and covet2 must be square matrices of equal shape (got {first.shape} and {second.shape})")

    root_difference = _psd_sqrt(first) - _psd_sqrt(second)
    return float(np.linalg.norm(root_difference, ord='fro'))



def get_covet_knn_matrix(covet_matrices, k=8):
    """Return, for every cell, the indices of its ``k`` nearest cells under ``aot_distance``.

    A full symmetric distance matrix is filled by evaluating ``aot_distance`` once per
    unordered pair, so the number of distance evaluations grows quadratically with the
    number of cells.

    :param covet_matrices: Sequence of COVET matrices, one per cell
    :param k: Number of neighbours to return per cell
    :return: Neighbour index array from ``NearestNeighbors.kneighbors``
    """
    n_cells = len(covet_matrices)
    pairwise = np.zeros((n_cells, n_cells))

    # Visit the strict upper triangle and mirror each value across the diagonal.
    upper_rows, upper_cols = np.triu_indices(n_cells, k=1)
    for i, j in zip(upper_rows, upper_cols):
        separation = aot_distance(covet_matrices[i], covet_matrices[j])
        pairwise[i, j] = separation
        pairwise[j, i] = separation

    # Neighbour search directly on the precomputed distances.
    searcher = NearestNeighbors(n_neighbors=k, metric='precomputed')
    searcher.fit(pairwise)
    return searcher.kneighbors(return_distance=False)





In [None]:
# calculate_covet needs plain arrays whose rows describe the same cells. Only the spatial
# dataset has coordinates, so expression and positions are both taken from st_data.
# (Previously whole AnnData objects from two different datasets were passed.)
# NOTE(review): st_data holds several slices ('batch'); confirm whether neighbours should be
# searched within each slice separately.
covet_expression = st_data.X
expression_matrix = (covet_expression.toarray() if scipy.sparse.issparse(covet_expression)
                     else np.asarray(covet_expression))
spatial_coordinates = np.asarray(st_data.obsm['spatial'])




In [None]:
# One COVET matrix per row of the inputs, using calculate_covet's default k=8.
covet_matrices = calculate_covet(expression_matrix, spatial_coordinates)

In [None]:
# Nearest neighbours between COVET matrices (default k=8).
# NOTE(review): get_covet_knn_matrix evaluates a distance for every pair of cells, so this
# is presumably only practical on a small subset — confirm the intended input size.
knn_matrix = get_covet_knn_matrix(covet_matrices)

# VAE

In [None]:
import scanpy as sc
import tensorflow as tf
import numpy as np
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import layers, Model

In [None]:
# The VAE consumes a dense float32 cells x genes array, not the AnnData wrapper itself.
# NOTE(review): the decoder ends in a sigmoid and the loss is binary cross-entropy, which
# presumes inputs scaled to [0, 1] — confirm sc_data.X is scaled accordingly.
vae_expression = sc_data.X
expression_matrix = (vae_expression.toarray() if scipy.sparse.issparse(vae_expression)
                     else np.asarray(vae_expression)).astype(np.float32)
spatial_coordinates = np.asarray(st_data.obsm['spatial'])

In [None]:
import tensorflow as tf
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Define the VAE model
class VAE(tf.keras.Model):
    """Fully connected variational autoencoder.

    The encoder maps ``input_dim`` features through Dense(512) and Dense(256) layers to
    ``2 * latent_dim`` outputs, which are split into a mean and a log-variance. The decoder
    mirrors it (256 -> 512) and ends in a sigmoid layer of width ``input_dim``.
    """

    def __init__(self, input_dim, latent_dim=32):
        """Build the encoder and decoder stacks.

        :param input_dim: Number of input features
        :param latent_dim: Size of the latent space
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(input_dim,)),
            tf.keras.layers.Dense(512, activation='relu'),
            tf.keras.layers.Dense(256, activation='relu'),
            # Twice the latent size: first half is the mean, second half the log-variance.
            tf.keras.layers.Dense(latent_dim * 2)
        ])

        self.decoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(512, activation='relu'),
            tf.keras.layers.Dense(input_dim, activation='sigmoid')
        ])

    def encode(self, x):
        """Return ``(mean, logvar)`` obtained by splitting the encoder output in two."""
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Draw ``mean + eps * exp(logvar / 2)`` with standard-normal ``eps``."""
        # tf.shape is resolved at run time, so sampling also works when the static batch
        # dimension is unknown (mean.shape would then contain None).
        eps = tf.random.normal(shape=tf.shape(mean))
        return eps * tf.exp(logvar * 0.5) + mean

    def decode(self, z):
        """Map latent samples back to feature space through the decoder."""
        return self.decoder(z)

    def call(self, x):
        """Encode, sample and decode; returns ``(x_recon, mean, logvar)``."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        x_recon = self.decode(z)
        return x_recon, mean, logvar

# Define the loss function
def vae_loss(x, x_recon, mean, logvar):
    """Batch-averaged VAE loss: summed binary cross-entropy plus KL divergence.

    :param x: Input batch, shape (batch, features)
    :param x_recon: Reconstruction of ``x`` with values in (0, 1), same shape
    :param mean: Latent means, shape (batch, latent_dim)
    :param logvar: Latent log-variances, shape (batch, latent_dim)
    :return: Scalar tensor, mean over the batch of (reconstruction + KL)
    """
    x = tf.cast(x, x_recon.dtype)

    # Element-wise cross-entropy, summed over features. tf.keras.losses.binary_crossentropy
    # already averages over the feature axis and returns shape (batch,), so the former
    # reduce_sum(..., axis=1) on its result was an invalid reduction.
    clipped = tf.clip_by_value(x_recon, 1e-7, 1.0 - 1e-7)  # keep log() finite
    elementwise_bce = -(x * tf.math.log(clipped) + (1.0 - x) * tf.math.log(1.0 - clipped))
    reconstruction_loss = tf.reduce_sum(elementwise_bce, axis=1)

    # KL divergence between N(mean, exp(logvar)) and the standard normal, per sample.
    kl_loss = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar), axis=1)
    return tf.reduce_mean(reconstruction_loss + kl_loss)

# Training function
@tf.function
def train_step(model, x, optimizer):
    """Apply one gradient update of ``vae_loss`` on batch ``x`` and return the loss."""
    with tf.GradientTape() as grad_tape:
        # The model returns (x_recon, mean, logvar), in the order vae_loss expects.
        outputs = model(x)
        loss = vae_loss(x, *outputs)
    trainable = model.trainable_variables
    grads = grad_tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))
    return loss

# Main training loop
def train_vae(expression_matrix, latent_dim=32, epochs=100, batch_size=128):
    """Train a ``VAE`` on a dense expression matrix with Adam (learning rate 1e-3).

    :param expression_matrix: Dense array-like of shape (n_cells, n_features)
    :param latent_dim: Size of the latent space
    :param epochs: Number of passes over the data
    :param batch_size: Mini-batch size
    :return: The trained ``VAE`` instance
    """
    # One fixed dtype for every batch, matching what the loss casts to.
    features = np.asarray(expression_matrix, dtype=np.float32)
    n_samples, input_dim = features.shape

    vae = VAE(input_dim, latent_dim)
    optimizer = tf.keras.optimizers.Adam(1e-3)

    # The shuffle buffer spans the whole dataset; a fixed 1000-element buffer only shuffles
    # locally once there are more rows than that.
    dataset = (tf.data.Dataset.from_tensor_slices(features)
               .shuffle(max(n_samples, 1))
               .batch(batch_size))

    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0
        for batch in dataset:
            total_loss += float(train_step(vae, batch, optimizer))
            n_batches += 1

        # max(..., 1) avoids dividing by zero when the input has no rows.
        avg_loss = total_loss / max(n_batches, 1)
        print(f'Epoch {epoch+1}, Average Loss: {avg_loss:.4f}')

    return vae

# Get latent representations
def get_latent_representations(vae, expression_matrix):
    """Encode ``expression_matrix`` and return the latent means as a NumPy array."""
    latent_mean = vae.encode(expression_matrix)[0]
    return latent_mean.numpy()

def compute_knn_matrix(latent_representations, n_neighbors=15):
    """Return the sparse connectivity graph of cosine nearest neighbours among the rows.

    :param latent_representations: 2-D array, one row per cell
    :param n_neighbors: Number of neighbours per row
    :return: Sparse matrix from ``NearestNeighbors.kneighbors_graph(mode='connectivity')``
    """
    searcher = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine')
    searcher.fit(latent_representations)
    connectivity = searcher.kneighbors_graph(mode='connectivity')
    return connectivity





In [None]:
# Main execution
# Train with train_vae's defaults, encode every cell to its latent mean, then build the
# cosine kNN connectivity graph (default n_neighbors=15) on those means.
vae = train_vae(expression_matrix)
latent_representations = get_latent_representations(vae, expression_matrix)
knn_matrix = compute_knn_matrix(latent_representations)

print("KNN matrix shape:", knn_matrix.shape)