# Define Crypt-Villus Axis

Trains a small neural network on the reference dataset's topics (learned from spatial decomposition) and applies it to infer the crypt-villus axis for every sample.

**Pinned Environment:** [`envs/sc-cv_axis.yaml`](../../envs/sc-cv_axis.yaml)  

In [None]:
import scanpy as sc
from pathlib import Path
import os
import anndata as ad
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras

In [None]:
import tensorflow as tf

# Query the visible GPUs once through the stable `tf.config` API and reuse the
# result for both the count and the device listing.
gpu_devices = tf.config.list_physical_devices("GPU")

print("Num GPUs Available: ", len(gpu_devices))
print("-" * 50)
print(gpu_devices)

In [None]:
# Scanpy plotting defaults: 100 dpi for rendered figures, 200 dpi for saved ones.
sc.set_figure_params(
    dpi=100,
    dpi_save=200,
)

In [None]:
import sys

# Put the directory two levels above the working directory on the import path
# so that the project-local `config` package resolves (presumably the
# repository root — confirm if this notebook is moved).
sys.path.append(str(Path.cwd().resolve().parents[1]))

from config.paths import BASE_DIR

# Input directory: holds the reference dataset and one sub-directory per sample.
h5ad_dir = BASE_DIR / "axes/cv/02_before_decomposition"
# Output directory and file for the combined AnnData with the inferred axis.
h5ad_out = BASE_DIR / "data/h5ad/export_09"
output_data = h5ad_out / "cv_calculated-adata.h5ad"

h5ad_out.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the reference dataset that the axis model is trained on.
refdata_path = os.path.join(
    h5ad_dir, "02_refdata_before_decomposition_model.h5ad"
)
refdata = sc.read(refdata_path)

## Calculate the crypt-villus axis for all datasets

### Train model

In [None]:
# Keep only the observations flagged by the `in_villi` column.
villi_mask = refdata.obs["in_villi"]
refdata = refdata[villi_mask]

In [None]:
# Model features: every obs column whose name contains "Topic".
obs_columns = refdata.obs.columns
topic_columns = obs_columns[obs_columns.str.contains("Topic")]
training_data = refdata.obs[topic_columns].values

In [None]:
# Regression target taken from the reference data's obs table.
label_column = "normalized_crypt_villi_scaled"
training_labels = refdata.obs[label_column].values

In [None]:
# Define model: a small fully connected regressor that maps the topic columns
# to a single sigmoid output, i.e. a value in (0, 1).
n_features = training_data.shape[1]
model = keras.Sequential(
    [
        # Declare the input explicitly. Passing `input_shape` to the first
        # layer is the legacy form and triggers a warning under Keras 3;
        # `keras.Input` works on both Keras 2 and Keras 3.
        keras.Input(shape=(n_features,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)

# Compile the model
model.compile(optimizer="adam", loss="mean_squared_error")

# Train the model
# NOTE(review): no random seed is set in this cell, so the fitted weights (and
# the predictions derived from them) will differ between runs unless a seed is
# set elsewhere — confirm whether reproducibility is required.
model.fit(training_data, training_labels, epochs=10, batch_size=32, verbose=2)

In [None]:
# Predict with exactly the feature columns (and column order) taken from the
# reference data the model was fitted on. Re-selecting columns per sample by
# substring match could silently feed the features in a different order.
topic_columns = refdata.obs.columns[refdata.obs.columns.str.contains("Topic")]

adata_list = []

# Sorted so that the processing and concatenation order is reproducible;
# directory listing order is otherwise arbitrary.
for sample_dir in sorted(Path(h5ad_dir).iterdir()):
    sample_id = sample_dir.name
    h5ad_path = sample_dir / "02_before_decomposition_model.h5ad"

    # Entries without the expected per-sample file (including plain files that
    # live directly in `h5ad_dir`) are skipped.
    if not h5ad_path.is_file():
        print(f"Skipping {sample_id}, file not found.")
        continue

    adata = sc.read(h5ad_path)

    testing_data = adata.obs[topic_columns].values
    # The model ends in a single output unit, so `predict` returns shape
    # (n_obs, 1); flatten it so a plain 1-D column is stored in `obs`.
    adata.obs["crypt_villi_axis"] = model.predict(testing_data).ravel()

    # Plot and show
    fig = sc.pl.embedding(
        adata,
        basis="spatial",
        color="crypt_villi_axis",
        return_fig=True,
        show=False,
        vmax=1,
        cmap="viridis",
        size=4,
    )
    fig.tight_layout()
    plt.axis("equal")
    plt.show()
    # Close explicitly so figures do not accumulate across samples.
    plt.close(fig)

    # Track sample ID
    adata.obs["sample_id"] = sample_id
    adata_list.append(adata)

# Fail with a clear message instead of an opaque concatenation error when no
# sample file was found at all.
if not adata_list:
    raise FileNotFoundError(
        f"No '02_before_decomposition_model.h5ad' files found under {h5ad_dir}"
    )

# Concatenate all processed samples
combined_adata = ad.concat(adata_list)

## Export

In [None]:
# Write the combined per-sample result to the output file, then the filtered
# reference data into the input directory; both gzip-compressed.
exports = (
    (combined_adata, output_data),
    (refdata, os.path.join(h5ad_dir, "retrained-refdata.h5ad")),
)
for dataset, destination in exports:
    dataset.write_h5ad(destination, compression="gzip")