In [26]:
import scanpy as sc
import os
import numpy as np
from tqdm.notebook import tqdm
import glob
from scipy.spatial import KDTree
import matplotlib.pyplot as plt
from tensorflow import keras

### Load the reference dataset (set its path below) in which the crypt-villus axis was calculated

In [27]:
# Reference AnnData in which the crypt-villus axis was already computed; its
# Topic columns and axis labels are used further down to train the axis model.
reference_dir = '/mnt/sata1/Analysis_Alex/timecourse_replicates/unrolling_meta'
reference = sc.read(os.path.join(reference_dir, 'reference_prep_decomposition_model.h5ad'))



In [28]:
# Every per-day dataset folder under the timecourse root.
# glob's result order depends on the filesystem, so sort it to process datasets
# in a deterministic order. Only directories are kept, because later cells build
# '<folder>/adatas/...' paths from each entry.
input_folders = sorted(
    path
    for path in glob.glob('/mnt/sata1/Analysis_Alex/timecourse_replicates/day*')
    if os.path.isdir(path)
)

### Calculate the epithelial axis for all datasets

In [29]:

# High-resolution rendering/saving for the spatial figures (global scanpy setting).
sc.set_figure_params(dpi=1000, dpi_save=1000)

# Number of nearest neighbours averaged per cell for both distance estimates.
N_NEIGHBORS = 5

for input_file in input_folders:
    adata_path = os.path.join(input_file, 'adatas', '09_before_decomposition_model.h5ad')
    ad = sc.read(adata_path)
    coords = ad.obsm['X_spatial']
    points_epi = ad[ad.obs['Class'].isin(['Epithelial'])].obsm['X_spatial']

    # With fewer than N_NEIGHBORS epithelial cells, KDTree.query pads the missing
    # neighbours with inf distances, which would silently turn the axis into inf/NaN.
    if len(points_epi) < N_NEIGHBORS:
        raise ValueError(
            f'{input_file}: only {len(points_epi)} epithelial cells, '
            f'need at least {N_NEIGHBORS} to compute epithelial_distance'
        )

    all_tree = KDTree(coords)
    epi_tree = KDTree(points_epi)
    # NOTE: each cell is in all_tree, so its own point (distance 0) is among its
    # neighbours there; the same holds in epi_tree for epithelial cells only.
    distances_all, _ = all_tree.query(coords, k=N_NEIGHBORS)
    distances_epi, _ = epi_tree.query(coords, k=N_NEIGHBORS)

    # Mean distance to the nearest epithelial cells relative to the mean distance
    # to the nearest cells of any class (presumably to normalise for local density).
    distance_ratio = distances_epi.mean(axis=1) / distances_all.mean(axis=1)
    # Rescale so the 99th percentile maps to 1; the plot clips above that via vmax=1.
    ad.obs['epithelial_distance'] = distance_ratio / np.percentile(distance_ratio, 99)

    fig = sc.pl.embedding(ad, basis='spatial', color='epithelial_distance', return_fig=True, show=False, vmax=1, cmap='viridis', size=4)
    fig.tight_layout()
    plt.axis('equal')
    figure_dir = os.path.join(input_file, 'figures', 'axes')
    os.makedirs(figure_dir, exist_ok=True)
    fig.savefig(os.path.join(figure_dir, 'spatial_epithelial.png'))
    plt.close(fig)

    # Overwrite the input file in place; the prediction cell re-reads this same file.
    ad.write(adata_path)



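### Calculate the crypt-villus axis for all datasets
A small neural network is trained on the reference's Topic columns to predict each cell's crypt-villus position, then applied to every dataset.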

In [30]:
# Restrict the reference to cells flagged as in_villi; all columns (genes) are kept.
reference = reference[reference.obs['in_villi'], :]


In [31]:
# Model inputs: every reference obs column whose name contains 'Topic', in column order.
training_data = reference.obs.loc[:, reference.obs.columns.str.contains('Topic')].values

In [32]:
# Regression target: the reference's scaled crypt-villus coordinate per cell
# (presumably within [0, 1] to match the sigmoid output — TODO confirm).
training_labels = reference.obs.loc[:, 'normalized_crypt_villi_scaled'].values

In [33]:
# Seed Python `random`, NumPy and the Keras backend so re-running the notebook
# trains the same weights (GPU kernels may still be nondeterministic).
SEED = 42
keras.utils.set_random_seed(SEED)

# Small MLP regressor: topic features -> position on the crypt-villus axis.
# The sigmoid output restricts predictions to the open interval (0, 1).
model = keras.Sequential([
    keras.Input(shape=(training_data.shape[1],)),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid'),
])

model.compile(optimizer='adam', loss='mean_squared_error')

# Keep the per-epoch History instead of echoing its repr as the cell output.
history = model.fit(training_data, training_labels, epochs=10, batch_size=32, verbose=2)

Epoch 1/10


6527/6527 - 4s - loss: 0.0224 - 4s/epoch - 547us/step
Epoch 2/10
6527/6527 - 3s - loss: 0.0216 - 3s/epoch - 513us/step
Epoch 3/10
6527/6527 - 3s - loss: 0.0213 - 3s/epoch - 515us/step
Epoch 4/10
6527/6527 - 3s - loss: 0.0211 - 3s/epoch - 515us/step
Epoch 5/10
6527/6527 - 3s - loss: 0.0210 - 3s/epoch - 515us/step
Epoch 6/10
6527/6527 - 3s - loss: 0.0209 - 3s/epoch - 516us/step
Epoch 7/10
6527/6527 - 3s - loss: 0.0208 - 3s/epoch - 516us/step
Epoch 8/10
6527/6527 - 3s - loss: 0.0207 - 3s/epoch - 510us/step
Epoch 9/10
6527/6527 - 3s - loss: 0.0206 - 3s/epoch - 508us/step
Epoch 10/10
6527/6527 - 3s - loss: 0.0205 - 3s/epoch - 514us/step


<keras.src.callbacks.History at 0x7f5b045fda50>

In [34]:
# Feature columns exactly as used to build the training matrix: the model was fit
# on these columns in this order, so every dataset's test matrix must match it.
topic_columns = reference.obs.columns[reference.obs.columns.str.contains('Topic')]

for input_file in input_folders:
    adata = sc.read(os.path.join(input_file, 'adatas', '09_before_decomposition_model.h5ad'))
    # Drop every obs column whose name contains '_x'. NOTE(review): presumably
    # meant for pandas merge suffixes, but it also matches any other name that
    # contains '_x' (e.g. a coordinate column) — confirm this is intended.
    adata.obs = adata.obs.drop(adata.obs.columns[adata.obs.columns.str.contains('_x')], axis=1)

    # Select topic features by name in the training order rather than trusting
    # this file's column order to coincide with the reference's.
    missing_topics = topic_columns.difference(adata.obs.columns)
    if len(missing_topics) > 0:
        raise KeyError(f'{input_file}: topic columns used for training are missing: {list(missing_topics)}')
    testing_data = adata.obs[topic_columns].values

    # The model ends in Dense(1), so predict() yields shape (n_cells, 1); flatten it.
    adata.obs['crypt_villi_axis'] = model.predict(testing_data, verbose=0).ravel()

    # Prefer 'predicted_longitudinal', falling back to 'longitudinal' (KeyError if
    # neither exists). -1.0 is treated as missing. NOTE(review): the /100000
    # rescaling is kept as-is; document its unit.
    if 'predicted_longitudinal' in adata.obs.columns:
        longitudinal_source = 'predicted_longitudinal'
    else:
        longitudinal_source = 'longitudinal'
    adata.obs['predicted_longitudinal'] = adata.obs[longitudinal_source].replace(-1.0, np.nan) / 100000

    adata.write(os.path.join(input_file, 'adatas', '10_axes_defined.h5ad'))

    fig = sc.pl.embedding(adata, basis='spatial', color='crypt_villi_axis', return_fig=True, show=False, vmax=1, cmap='viridis', size=4)
    fig.tight_layout()
    plt.axis('equal')
    figure_dir = os.path.join(input_file, 'figures', 'axes')
    os.makedirs(figure_dir, exist_ok=True)
    fig.savefig(os.path.join(figure_dir, 'spatial_crypt_villi.png'))
    plt.close(fig)















