In [2]:
import os

# GPU selection for CUDA-backed libraries:
#   CUDA_DEVICE_ORDER=PCI_BUS_ID   -> number devices by PCI bus so "0" is stable across runs
#   CUDA_VISIBLE_DEVICES="0"       -> expose only that device; use "-1" to force CPU
# NOTE(review): presumably this has to run before TensorFlow initialises the GPU
# (TensorFlow is imported in a later cell) — keep this cell first.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

import warnings

# Silence every warning for the whole session (this also hides library deprecation notices).
warnings.filterwarnings('ignore')

In [1]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import pandas as pd
import scanpy as sc
import colorcet
import sklearn.neighbors
import scipy.sparse
import umap.umap_ as umap
from fa2 import ForceAtlas2

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Directory containing the MERFISH .h5ad inputs. Set the MERFISH_DATA_DIR environment
# variable to point elsewhere instead of editing a machine-specific absolute path here.
DATA_DIR = os.environ.get("MERFISH_DATA_DIR", "/Users/anushka/Desktop/MERFISH data")

# NOTE(review): from the file names, sc_data is presumably the single-cell dataset and
# st_data the spatial-transcriptomics one — confirm against the data provenance.
sc_data = sc.read_h5ad(os.path.join(DATA_DIR, "sc_data.h5ad"))
st_data = sc.read_h5ad(os.path.join(DATA_DIR, "st_data.h5ad"))

In [12]:
import scanpy as sc
import tensorflow as tf
import numpy as np
from sklearn.neighbors import NearestNeighbors
from tensorflow.keras import layers, Model

In [9]:
# The VAE cells below feed this into tf.data / Keras layers, which need a plain numeric
# matrix — not the AnnData container. Pull out `.X`, densify it if it is stored sparse,
# and use float32 (Keras' default float type).
sc_X = sc_data.X
expression_matrix = np.asarray(
    sc_X.toarray() if scipy.sparse.issparse(sc_X) else sc_X,
    dtype=np.float32,
)
# NOTE(review): the decoder ends in a sigmoid and is trained with binary cross-entropy,
# which presumes values in [0, 1]; raw or log-normalised counts may need rescaling — confirm.

# TODO confirm: this is still the whole AnnData object. If actual (x, y) coordinates are
# needed, they probably live in st_data.obsm or st_data.obs. Not used further in this section.
spatial_coordinates = st_data

In [16]:
import tensorflow as tf
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Define the VAE model
class VAE(tf.keras.Model):
    """Variational autoencoder built from two small fully-connected networks.

    The encoder maps each ``input_dim``-wide row to ``2 * latent_dim`` values, which
    ``encode`` splits into a posterior mean and log-variance. The decoder maps a
    ``latent_dim`` code back to ``input_dim`` values squashed into (0, 1) by a sigmoid.
    NOTE(review): because of that sigmoid, reconstruction targets are presumably
    expected to lie in [0, 1] — confirm the input data are scaled accordingly.

    Args:
        input_dim: number of features per input row.
        latent_dim: size of the latent code (default 32).
    """

    def __init__(self, input_dim, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(input_dim,)),
            tf.keras.layers.Dense(512, activation='relu'),
            tf.keras.layers.Dense(256, activation='relu'),
            # Mean and log-variance, concatenated along the feature axis.
            tf.keras.layers.Dense(latent_dim * 2)
        ])

        self.decoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(512, activation='relu'),
            tf.keras.layers.Dense(input_dim, activation='sigmoid')
        ])

    def encode(self, x):
        """Run the encoder and return ``(mean, logvar)``, each ``latent_dim`` wide."""
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Draw ``z = mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``."""
        # Use the runtime shape: the static `mean.shape` can have an unknown (None) batch
        # dimension when the model is traced symbolically, which tf.random.normal rejects.
        # Matching mean's dtype keeps the product valid under non-float32 policies.
        eps = tf.random.normal(shape=tf.shape(mean), dtype=mean.dtype)
        return eps * tf.exp(logvar * 0.5) + mean

    def decode(self, z):
        """Map latent codes ``z`` back to input space."""
        return self.decoder(z)

    def call(self, x):
        """Full forward pass; returns ``(x_recon, mean, logvar)``."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        x_recon = self.decode(z)
        return x_recon, mean, logvar

# Define the loss function
def vae_loss(x, x_recon, mean, logvar):
    """Negative ELBO averaged over the batch.

    Per row: the binary cross-entropy between ``x`` and ``x_recon`` summed over
    features, plus the closed-form KL divergence of N(mean, exp(logvar)) from N(0, I).

    Args:
        x: target batch, same shape as ``x_recon``.
        x_recon: decoder output (sigmoid probabilities).
        mean, logvar: encoder outputs for the same batch.

    Returns:
        Scalar tensor: mean over rows of (reconstruction + KL).
    """
    # binary_crossentropy already reduces (averages) over the last axis and returns one
    # value per row, so a further reduce_sum(axis=1) on that rank-1 result fails.
    # Multiply the per-row mean by the feature count to recover the per-row sum.
    n_features = tf.cast(tf.shape(x_recon)[-1], x_recon.dtype)
    reconstruction_loss = tf.keras.losses.binary_crossentropy(x, x_recon) * n_features
    kl_loss = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar), axis=1)
    return tf.reduce_mean(reconstruction_loss + kl_loss)

# Training function
@tf.function
def train_step(model, x, optimizer):
    """Apply one gradient update of ``model`` on batch ``x``; return the batch loss.

    NOTE(review): compiled with tf.function, and ``model``/``optimizer`` are Python
    objects — passing different instances may trigger retracing.
    """
    with tf.GradientTape() as tape:
        reconstruction, z_mean, z_logvar = model(x)
        batch_loss = vae_loss(x, reconstruction, z_mean, z_logvar)
    params = model.trainable_variables
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))
    return batch_loss

# Main training loop
def train_vae(expression_matrix, latent_dim=32, epochs=100, batch_size=128,
              shuffle_buffer=1000, seed=None):
    """Train a ``VAE`` on ``expression_matrix`` and return the fitted model.

    Prints the mean per-batch loss after every epoch.

    Args:
        expression_matrix: 2-D numeric data, one sample per row. Accepts a NumPy array,
            a SciPy sparse matrix (densified here), or an AnnData-like object exposing
            the matrix as ``.X`` (recognised by having both ``.X`` and ``.var_names``).
        latent_dim: size of the VAE latent code.
        epochs: number of passes over the data.
        batch_size: rows per optimisation step.
        shuffle_buffer: buffer size for tf.data shuffling (previously hard-coded 1000).
        seed: optional int; when given, seeds TensorFlow's global RNG and the shuffle
            so the run is repeatable.

    Returns:
        The trained ``VAE`` instance.

    Raises:
        ValueError: if the data are not 2-D or have no rows.
    """
    matrix = expression_matrix
    # AnnData containers cannot be sliced by tf.data; use the matrix they wrap.
    if hasattr(matrix, "X") and hasattr(matrix, "var_names"):
        matrix = matrix.X
    if scipy.sparse.issparse(matrix):
        matrix = matrix.toarray()
    # Cast once up front to float32 (Keras' default float type).
    matrix = np.asarray(matrix, dtype=np.float32)
    if matrix.ndim != 2 or matrix.shape[0] == 0:
        raise ValueError(
            f"expression_matrix must be a non-empty 2-D matrix, got shape {matrix.shape}"
        )

    if seed is not None:
        tf.random.set_seed(seed)

    vae = VAE(matrix.shape[1], latent_dim)
    optimizer = tf.keras.optimizers.Adam(1e-3)

    dataset = (
        tf.data.Dataset.from_tensor_slices(matrix)
        .shuffle(shuffle_buffer, seed=seed)
        .batch(batch_size)
    )

    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0
        for batch in dataset:
            # float() pulls the scalar out so the running sum stays a plain Python float.
            total_loss += float(train_step(vae, batch, optimizer))
            n_batches += 1

        # Average over batches actually seen rather than relying on len(dataset).
        avg_loss = total_loss / n_batches
        print(f'Epoch {epoch+1}, Average Loss: {avg_loss:.4f}')

    return vae

# Get latent representations
def get_latent_representations(vae, expression_matrix, batch_size=None):
    """Return the encoder's posterior means for every row of ``expression_matrix``.

    The mean (not a reparameterised sample) is used, so the embedding is deterministic
    for a given model.

    Args:
        vae: a model exposing ``encode(x) -> (mean, logvar)``.
        expression_matrix: 2-D data, one sample per row. Accepts a NumPy array, a SciPy
            sparse matrix (densified here), or an AnnData-like object exposing the matrix
            as ``.X`` (recognised by having both ``.X`` and ``.var_names``).
        batch_size: optional number of rows to encode per call. ``None`` (the default)
            encodes everything in one call; set it to bound peak memory on large inputs.

    Returns:
        NumPy array with one row of latent means per input row.
    """
    matrix = expression_matrix
    if hasattr(matrix, "X") and hasattr(matrix, "var_names"):
        matrix = matrix.X
    if scipy.sparse.issparse(matrix):
        matrix = matrix.toarray()
    matrix = np.asarray(matrix, dtype=np.float32)

    if batch_size is None or matrix.shape[0] <= batch_size:
        mean, _ = vae.encode(matrix)
        return mean.numpy()

    chunks = []
    for start in range(0, matrix.shape[0], batch_size):
        mean, _ = vae.encode(matrix[start:start + batch_size])
        chunks.append(mean.numpy())
    return np.concatenate(chunks, axis=0)

def compute_knn_matrix(latent_representations, n_neighbors=15):
    """Build a k-nearest-neighbour connectivity graph over ``latent_representations``.

    Neighbours are ranked by cosine distance; the return value is scikit-learn's
    ``kneighbors_graph(mode='connectivity')`` result for the fitted points, with one
    row per input point marking its ``n_neighbors`` nearest neighbours.
    """
    knn_index = NearestNeighbors(metric='cosine', n_neighbors=n_neighbors).fit(latent_representations)
    # No query array is passed, so neighbours are computed for the fitted points
    # themselves (scikit-learn then leaves each point out of its own neighbour list).
    return knn_index.kneighbors_graph(mode='connectivity')





In [None]:
# Main execution
# Fix TensorFlow's global seed so weight initialisation, reparameterisation noise and
# dataset shuffling repeat across Restart & Run All.
# NOTE(review): some GPU kernels may still be nondeterministic.
SEED = 0
tf.random.set_seed(SEED)

# Train the VAE, embed each row via its posterior mean, then build the kNN graph.
vae = train_vae(expression_matrix)
latent_representations = get_latent_representations(vae, expression_matrix)
knn_matrix = compute_knn_matrix(latent_representations)

print("KNN matrix shape:", knn_matrix.shape)