In [None]:
import scanpy as sc
import os
import numpy as np
from tqdm.notebook import tqdm
import glob
from scipy.spatial import KDTree
import matplotlib.pyplot as plt
from tensorflow import keras

### Set the path to the reference dataset in which the crypt-villus axis was computed

In [None]:
# Directory of the reference dataset (must contain adatas/07_axes_defined_reference.h5ad).
# Set XENIUM_REFERENCE_PATH to run on another machine; the original location is the default.
xenium_path = os.environ.get(
    "XENIUM_REFERENCE_PATH", "D:/amonell/timecourse_final/day7_SI_DMSO"
)

In [None]:
# Load the reference AnnData in which the crypt-villus axis has already been defined.
reference_h5ad = os.path.join(xenium_path, "adatas", "07_axes_defined_reference.h5ad")
reference = sc.read(reference_h5ad)

In [None]:
# Every dataset directory in the time course. glob() returns matches in arbitrary,
# filesystem-dependent order, so sort for a reproducible processing order; keep only
# directories because each entry is later joined with "adatas/<file>.h5ad".
input_folders = sorted(
    path
    for path in glob.glob("D:/amonell/timecourse_final/day*")
    if os.path.isdir(path)
)

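### Calculate the epithelial-distance axis for all datasets

For every cell, the mean distance to its 5 nearest epithelial cells is divided by the mean distance to its 5 nearest cells of any class, and the result is scaled so that its 99th percentile equals 1.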

In [None]:
# Very high DPI keeps the per-cell scatter legible in the saved PNGs.
sc.set_figure_params(dpi=1000, dpi_save=1000)

# Number of nearest neighbours averaged for both the epithelial and the all-cell distances.
N_NEIGHBORS = 5

for input_file in tqdm(input_folders, desc="epithelial distance"):
    adata_path = os.path.join(input_file, "adatas", "06_before_decomposition_model.h5ad")
    ad = sc.read(adata_path)

    coords = ad.obsm["X_spatial"]
    epithelial_coords = ad[ad.obs["Class"].isin(["Epithelial"])].obsm["X_spatial"]

    # Mean distance to the k nearest epithelial cells, divided by the mean distance to
    # the k nearest cells of any class (normalises for local cell density).
    # NOTE(review): every cell is also a point in the all-cell tree, so it matches itself
    # at distance 0 there (epithelial cells likewise in the epithelial tree) — confirm
    # this is intended.
    dist_to_any, _ = KDTree(coords).query(coords, k=N_NEIGHBORS)
    dist_to_epi, _ = KDTree(epithelial_coords).query(coords, k=N_NEIGHBORS)
    distance_ratio = dist_to_epi.mean(axis=1) / dist_to_any.mean(axis=1)

    # Scale so the 99th percentile maps to 1, matching vmax=1 in the plot below.
    ad.obs["epithelial_distance"] = distance_ratio / np.percentile(distance_ratio, 99)

    fig = sc.pl.embedding(
        ad,
        basis="spatial",
        color="epithelial_distance",
        return_fig=True,
        show=False,
        vmax=1,
        cmap="viridis",
        size=4,
    )
    # Set equal aspect on the spatial panel itself rather than relying on pyplot's
    # "current axes", and do it before tight_layout so the layout accounts for it.
    fig.axes[0].axis("equal")
    fig.tight_layout()
    figure_dir = os.path.join(input_file, "figures", "axes")
    os.makedirs(figure_dir, exist_ok=True)
    fig.savefig(os.path.join(figure_dir, "spatial_epithelial.png"))
    # Close this specific figure so memory does not accumulate across datasets.
    plt.close(fig)

    # Overwrites the input file with the added "epithelial_distance" column.
    ad.write(adata_path)

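### Calculate the crypt-villus axis and normalize the longitudinal axis for all datasets

A small dense network is trained on the reference cells flagged `in_villi` to predict `normalized_crypt_villi_scaled` from the `Topic` columns. It is then applied to every dataset.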

In [None]:
# Restrict the reference to cells flagged in_villi; only these are used to train the axis model.
in_villi_mask = reference.obs["in_villi"]
reference = reference[in_villi_mask, :]

In [None]:
# Feature matrix: every obs column whose name contains "Topic", kept in column order.
is_topic_column = reference.obs.columns.str.contains("Topic")
training_data = reference.obs.loc[:, is_topic_column].values

In [None]:
# Regression target: the reference's scaled crypt-villus coordinate for each cell.
training_labels = reference.obs.loc[:, "normalized_crypt_villi_scaled"].values

In [None]:
# Regress each reference cell's crypt-villus coordinate from its Topic features.
# Seeding Python/NumPy/TF RNGs makes weight initialisation and batch shuffling repeatable
# across Restart & Run All (some GPU kernels may still add small nondeterminism).
RANDOM_SEED = 0
keras.utils.set_random_seed(RANDOM_SEED)

# Features and labels must describe the same cells, row for row.
assert len(training_data) == len(training_labels), "feature/label row count mismatch"

n_topic_features = training_data.shape[1]
model = keras.Sequential(
    [
        keras.Input(shape=(n_topic_features,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(32, activation="relu"),
        # Sigmoid bounds predictions to (0, 1); assumes the target is scaled to that
        # range — TODO confirm for normalized_crypt_villi_scaled.
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)

model.compile(optimizer="adam", loss="mean_squared_error")

# Keep the History object (per-epoch loss) for inspection instead of printing its repr.
history = model.fit(
    training_data, training_labels, epochs=10, batch_size=32, verbose=2
)

In [None]:
# Topic feature columns in the exact order the model was trained on (taken from the
# reference). Selecting each dataset's features by these names guarantees column-for-column
# alignment with the model, and raises KeyError if a topic is missing instead of silently
# feeding misaligned features.
topic_columns = reference.obs.columns[reference.obs.columns.str.contains("Topic")]

for input_file in tqdm(input_folders, desc="crypt-villus axis"):
    adata = sc.read(
        os.path.join(input_file, "adatas", "06_before_decomposition_model.h5ad")
    )
    testing_data = adata.obs[topic_columns].values
    predictions = model.predict(testing_data, verbose=0)
    # predict returns shape (n_cells, 1); flatten to one value per cell.
    adata.obs["crypt_villi_axis"] = predictions.ravel()

    # Prefer an existing "predicted_longitudinal" column, otherwise fall back to
    # "longitudinal". A value of -1.0 is treated as missing (NaN), and the result is
    # divided by 100000 (presumably a unit rescale — TODO confirm).
    if "predicted_longitudinal" in adata.obs.columns:
        longitudinal_source = "predicted_longitudinal"
    else:
        longitudinal_source = "longitudinal"
    adata.obs["predicted_longitudinal"] = (
        adata.obs[longitudinal_source].replace(-1.0, np.nan) / 100000
    )

    adata.write(os.path.join(input_file, "adatas", "07_axes_defined.h5ad"))

    fig = sc.pl.embedding(
        adata,
        basis="spatial",
        color="crypt_villi_axis",
        return_fig=True,
        show=False,
        vmax=1,
        cmap="viridis",
        size=4,
    )
    # Set equal aspect on the spatial panel itself rather than relying on pyplot's
    # "current axes", and do it before tight_layout so the layout accounts for it.
    fig.axes[0].axis("equal")
    fig.tight_layout()
    figure_dir = os.path.join(input_file, "figures", "axes")
    os.makedirs(figure_dir, exist_ok=True)
    fig.savefig(os.path.join(figure_dir, "spatial_crypt_villi.png"))
    # Close this specific figure so memory does not accumulate across datasets.
    plt.close(fig)