In [None]:
import scanpy as sc
import numpy as np
from tqdm.notebook import tqdm
import scipy.stats as stats
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt
import shapely
import glob
from sklearn.neighbors import NearestNeighbors
from PIL import Image, ImageDraw
import numpy as np
from scipy.spatial import cKDTree
import json
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Embedding, Flatten, Dense, Concatenate
from tensorflow.keras.models import Model
from sklearn.preprocessing import OneHotEncoder
from core_functions.unrolling import *

## Training model on reference

In [None]:
# Folder of the reference sample whose "longitudinal" annotation is used to train the model.
# NOTE(review): absolute, machine-specific path.
reference_path = "/".join(
    ["D:/amonell/merscope_final", "SI-Ctrl-L-RAR-R-dist-1-VS120-NP_Beta8"]
)

In [None]:
# Topic labels used as `topics_contain` when rendering the images of the other samples.
# NOTE(review): "base_topicas" looks like a typo of "base_topics", but it must match the file on disk.
base_topics_file = os.path.join(reference_path, "unrolling", "base_topicas.npy")
base_topics = np.load(base_topics_file)

In [None]:
data_dir = "D:/amonell/merscope_final"

# Locate the reference sample's folder among the experiment folders by name.
# normpath strips a trailing separator, which would otherwise make basename() empty.
reference_name = os.path.basename(os.path.normpath(reference_path))
matching_folders = [
    folder
    for folder in glob.glob(os.path.join(data_dir, "SI-*"))
    if os.path.basename(folder) == reference_name
]
if not matching_folders:
    raise FileNotFoundError(
        f"Reference folder {reference_name!r} not found under {data_dir!r}"
    )
path_adata = matching_folders[0]

## Prepare Reference Image

In [None]:
adata = sc.read(os.path.join(path_adata, "adatas", "05_reference_unrolled.h5ad"))
all_spatial = adata.obsm["X_spatial"]

# Densify the expression matrix. Sparse matrices expose `.toarray()`. The former
# `adata.X.A` shortcut was deprecated in SciPy 1.11 and removed in 1.14. The bare
# `except` that wrapped it silently left X sparse on newer SciPy.
if hasattr(adata.X, "toarray"):
    adata.X = adata.X.toarray()
else:
    print("Adata already in array format")

# Cells from these topics are passed to create_base_image as the primary point set;
# all other cells go in as `other_spatial`.
# NOTE(review): these labels are specific to the reference sample.
topics_contain = ["1", "2"]
in_roll_topics = adata.obs["topic"].isin(topics_contain)
spatial_points = np.array(adata[in_roll_topics, :].obsm["X_spatial"])
other_spatial = np.array(adata[~in_roll_topics, :].obsm["X_spatial"])

# presumably drops points beyond the 99th percentile - see remove_outliers
spatial_points = remove_outliers(spatial_points, 99)

# `downsize` is stored in `uns` so that annotations drawn on this image can later
# be multiplied back to spatial units.
downsize = 10
base_image = create_base_image(spatial_points, other_spatial, downsize=downsize)

unrolling_dir = os.path.join(path_adata, "unrolling")
os.makedirs(unrolling_dir, exist_ok=True)
file_path = os.path.join(unrolling_dir, "roll_image_before_model.png")
base_image.save(file_path)
adata.uns["unrolling_downsize"] = downsize
adata.write(os.path.join(path_adata, "adatas", "05_reference_unrolled.h5ad"))

## Parse Reference Image Annotations

In [None]:
base_num_points = 100000
json_file_path = os.path.join(path_adata, "unrolling", "roll_image_before_model.json")

# Read the labelme annotation of the reference image and work out which cells are
# excluded by the annotated removal regions (see get_removal_indices).
removals, points, top_points, mid_points = extract_json_info(json_file_path)
dont_remove, index_set = get_removal_indices(adata, removals, all_spatial)

# identify_spiral yields (x_bottom, y_bottom, x_mid, y_mid, x_top, y_top), which is
# exactly the positional order get_distances_and_indices expects after `dont_remove`.
spiral_coords = identify_spiral(adata, points, top_points, mid_points, base_num_points)
(
    x_points_bottom, y_points_bottom,
    x_points_mid, y_points_mid,
    x_points_top, y_points_top,
) = spiral_coords

(
    all_points,
    distances_top, indices_top,
    distances_bottom, indices_bottom,
    distances_mid, indices_mid,
) = get_distances_and_indices(adata, dont_remove, *spiral_coords, base_num_points)

## Prepare to train the neural network

##### Prepare Training Inputs

In [None]:
# Topic labels are read from the tissue-cleared AnnData, the same source the per-roll
# prediction loop uses, so that training and prediction see identical labels.
# NOTE(review): assumes 04_tissue_cleared and 05_reference_unrolled list cells in the same order.
topic_presorted = sc.read(
    os.path.join(path_adata, "adatas", "04_tissue_cleared.h5ad")
).obs["topic"].values[dont_remove]

# Feature rows: z-scored distances (top, mid, bottom) followed by the raw indices.
zscored_distances = [
    stats.zscore(dist) for dist in (distances_top, distances_mid, distances_bottom)
]
index_features = [indices_top, indices_mid, indices_bottom]
model_input = np.array(zscored_distances + index_features)

##### Prepare Training Outputs

In [None]:
# Regression target: the reference's "longitudinal" value for every retained cell.
axis = adata.obs.loc[:, "longitudinal"].values[dont_remove]

##### Define and fit model

In [None]:
num_continuous_features = np.shape(model_input)[0]  # Number of continuous features
unique_categories = np.unique(topic_presorted)
embedding_dim = 3  # NOTE(review): unused - topics are one-hot encoded, no Embedding layer is built
batch_size = 32
num_epochs = 10
random_seed = 0

# Seed NumPy and TensorFlow so weight initialisation and batch shuffling are
# reproducible across Restart & Run All (GPU kernels may still be nondeterministic).
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

# Topic labels as a plain 1D array. If obs["topic"] is categorical, `.values` is a
# pandas Categorical; np.asarray gives an ndarray so `.reshape` is always available.
categorical_data = np.asarray(topic_presorted)

# One-hot encode the topics. `sparse` was renamed `sparse_output` in scikit-learn 1.2
# and removed in 1.4, so try the new keyword first and fall back for old releases.
# With the default handle_unknown="error", encoder.transform will raise on topics
# that never occur in the reference.
try:
    encoder = OneHotEncoder(sparse_output=False)
except TypeError:
    encoder = OneHotEncoder(sparse=False)
categorical_data_encoded = encoder.fit_transform(categorical_data.reshape(-1, 1))

# Define input for continuous data
continuous_input = Input(shape=(num_continuous_features,), name="continuous_input")

# Define input for one-hot encoded categorical data
categorical_input = Input(
    shape=(categorical_data_encoded.shape[1],), name="categorical_input"
)

# Concatenate the continuous and categorical inputs
concatenated_inputs = Concatenate()([continuous_input, categorical_input])

# Two hidden ReLU layers, then a single linear unit regressing the longitudinal value
x = Dense(64, activation="relu")(concatenated_inputs)
x = Dense(32, activation="relu")(x)
output = Dense(1, activation="linear", name="output")(x)

model = Model(inputs=[continuous_input, categorical_input], outputs=output)

# Continuous regression target -> mean squared error
model.compile(optimizer="adam", loss="mean_squared_error")

continuous_data = model_input.T  # (cells, features)
target_values = axis

model.fit(
    {
        "continuous_input": continuous_data,
        "categorical_input": categorical_data_encoded,
    },
    y=target_values,
    batch_size=batch_size,
    epochs=num_epochs,
)

##### Use model for prediction

In [None]:
# Use the model for prediction on new data
new_continuous_data = continuous_data
new_categorical_data_encoded = categorical_data_encoded
# predictions = model.predict({'continuous_input': new_continuous_data, 'categorical_input': new_categorical_data_encoded})
predictions = model.predict(
    {
        "continuous_input": new_continuous_data,
        "categorical_input": new_categorical_data_encoded,
    }
)
new_predicts = np.zeros(len(adata.obs.index))

In [None]:
# Cells outside `dont_remove` keep the 0 they were initialised with.
new_predicts[dont_remove] = np.ravel(predictions)
adata.obs["predicted_longitudinal"] = new_predicts

##### Visualize predictions

In [None]:
# Spatial map of the reference coloured by the model's in-sample predictions.
sc.pl.embedding(adata, color="predicted_longitudinal", basis="spatial")

# Unrolling the remaining Swiss rolls

Set `data_dir` to the folder that contains the experiment (`SI-*`) folders.

In [None]:
data_dir = "D:/amonell/merscope_final"
# Sorted so that positional selections of folders later in the notebook are reproducible;
# raw glob order depends on the filesystem.
input_folders = sorted(glob.glob(os.path.join(data_dir, "SI-*")))

In [None]:
for path_adata in input_folders:
    adata = sc.read(os.path.join(path_adata, "adatas", "04_tissue_cleared.h5ad"))
    all_spatial = adata.obsm["X_spatial"]

    # Densify the expression matrix. `.A` was removed from SciPy sparse matrices in
    # 1.14, and the bare `except` around it silently left X sparse on newer SciPy.
    if hasattr(adata.X, "toarray"):
        adata.X = adata.X.toarray()
    else:
        print("Adata already in array format")

    # Same topic split as the reference image, using the topics loaded from the reference.
    topics_contain = base_topics
    in_roll_topics = adata.obs["topic"].isin(topics_contain)
    spatial_points = np.array(adata[in_roll_topics, :].obsm["X_spatial"])
    other_spatial = np.array(adata[~in_roll_topics, :].obsm["X_spatial"])

    spatial_points = remove_outliers(spatial_points, 99)

    downsize = 10
    base_image = create_base_image(spatial_points, other_spatial, downsize=downsize)

    # makedirs is a no-op when the folder exists and, unlike the bare-except mkdir,
    # does not hide real failures such as a missing parent or a permission error.
    unrolling_dir = os.path.join(path_adata, "unrolling")
    os.makedirs(unrolling_dir, exist_ok=True)
    file_path = os.path.join(unrolling_dir, "roll_image_for_prediction.png")
    base_image.save(file_path)
    # Stored so that polygon coordinates drawn on this image can be scaled back later.
    adata.uns["unrolling_downsize"] = downsize
    adata.write(os.path.join(path_adata, "adatas", "04_tissue_cleared.h5ad"))

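Annotate each sample in labelme (install with `pip install labelme`, then launch `labelme`). Open `unrolling/roll_image_for_prediction.png`, draw the polygons, and save the annotation as `roll_image_for_prediction.json` in the same folder. Label each roll outline `rollN`; the label must be exactly five characters, e.g. `roll1`. Label that roll's other shapes `bottom_rollN`, `mid_rollN` and `top_rollN`, plus an optional `removals_rollN`. Our labeled rolls are in the `./labels/unrolling` folder.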

In [None]:
def extract_rolls(json_file_path, adata):
    """Group the shapes of a labelme annotation file by roll.

    A roll label is any 5-character label containing "roll" (e.g. "roll1").
    For each distinct roll label, the ``points`` of the shapes labelled
    ``removals_<roll>``, ``bottom_<roll>``, ``top_<roll>``, ``mid_<roll>``
    and ``<roll>`` itself are collected.

    Parameters
    ----------
    json_file_path : str
        Path to the labelme JSON file; only its "shapes" list is read.
    adata : AnnData
        Unused; kept so existing callers keep working.

    Returns
    -------
    list of list
        One entry per distinct roll label, in order of first appearance:
        ``[removals, bottom_points, top_points, mid_points, roll_shapes, roll_name]``.
        Each of the first five items is a list of the matching shapes' ``points``.
    """
    with open(json_file_path, "r") as json_file:
        shapes = json.load(json_file)["shapes"]

    # dict.fromkeys dedupes while keeping first-appearance order. Without it, a roll
    # outlined by more than one polygon with the same label would be returned, and
    # therefore processed and written, more than once.
    roll_names = list(
        dict.fromkeys(
            shape["label"]
            for shape in shapes
            if "roll" in shape["label"] and len(shape["label"]) == 5
        )
    )

    rollers = []
    for roll_name in roll_names:
        # Insertion order defines the position of each list in the returned entry.
        buckets = {
            f"removals_{roll_name}": [],
            f"bottom_{roll_name}": [],
            f"top_{roll_name}": [],
            f"mid_{roll_name}": [],
            roll_name: [],
        }
        for shape in shapes:
            bucket = buckets.get(shape["label"])
            if bucket is not None:
                bucket.append(shape["points"])
        rollers.append([*buckets.values(), roll_name])
    return rollers

In [None]:
# Folders to unroll in this run. As in the original run, only the folder at index 3
# is processed; set this to `input_folders` to unroll every annotated sample.
folders_to_unroll = input_folders[3:4]
base_num_points = 100000

for path_adata in folders_to_unroll:
    whole_adata = sc.read(os.path.join(path_adata, "adatas", "04_tissue_cleared.h5ad"))
    # The annotation image was rendered with this downsize factor, so polygon
    # coordinates are multiplied by it to return to X_spatial units.
    downsize = whole_adata.uns["unrolling_downsize"]
    whole_spatial = np.asarray(whole_adata.obsm["X_spatial"])

    json_file_path = os.path.join(
        path_adata, "unrolling", "roll_image_for_prediction.json"
    )
    rolls = extract_rolls(json_file_path, whole_adata)

    for removals, points, top_points, mid_points, roll_shapes, roll_name in rolls:
        if not roll_shapes:
            print(f"No outline labelled {roll_name!r} in {json_file_path}; skipping roll")
            continue
        roll_polygon = shapely.Polygon(np.array(roll_shapes[0]) * downsize)

        # Vectorised point-in-polygon test (shapely >= 2.0). Same per-cell result as
        # roll_polygon.contains(Point(x, y)): points on the boundary are excluded.
        keepers = shapely.contains_xy(
            roll_polygon, whole_spatial[:, 0], whole_spatial[:, 1]
        )
        if not keepers.any():
            print(f"No cells fall inside {roll_name!r} in {path_adata}; skipping roll")
            continue

        # .copy() gives an independent AnnData, so the obs columns added below are
        # not written to a view of whole_adata.
        adata = whole_adata[keepers, :].copy()
        all_spatial = adata.obsm["X_spatial"]

        dont_remove, index_set = get_removal_indices(adata, removals, all_spatial)

        (
            x_points_bottom,
            y_points_bottom,
            x_points_mid,
            y_points_mid,
            x_points_top,
            y_points_top,
        ) = identify_spiral(adata, points, top_points, mid_points, base_num_points)
        (
            all_points,
            distances_top,
            indices_top,
            distances_bottom,
            indices_bottom,
            distances_mid,
            indices_mid,
        ) = get_distances_and_indices(
            adata,
            dont_remove,
            x_points_bottom,
            y_points_bottom,
            x_points_mid,
            y_points_mid,
            x_points_top,
            y_points_top,
            base_num_points,
        )

        # Build the model inputs exactly as for the reference training set.
        topic_presorted = adata.obs["topic"].values[dont_remove]
        model_input = np.array(
            [
                stats.zscore(distances_top),
                stats.zscore(distances_mid),
                stats.zscore(distances_bottom),
                indices_top,
                indices_mid,
                indices_bottom,
            ]
        )

        # np.asarray guards against `.values` returning a pandas Categorical.
        categorical_data = np.asarray(topic_presorted)
        # With the encoder's default handle_unknown="error", this raises if the
        # sample contains topics that were absent from the reference.
        categorical_data_encoded = encoder.transform(categorical_data.reshape(-1, 1))

        continuous_data = model_input.T

        predictions = model.predict(
            {
                "continuous_input": continuous_data,
                "categorical_input": categorical_data_encoded,
            }
        )

        # Cells in `index_set` are flagged with -1 so downstream plots can drop them.
        new_longitudinal = np.zeros(adata.n_obs)
        new_longitudinal[dont_remove] = predictions.flatten()
        new_longitudinal[list(index_set)] = -1
        adata.obs["predicted_longitudinal"] = new_longitudinal
        sc.pl.embedding(adata, basis="spatial", color="predicted_longitudinal")

        adata.obs["not_removed_from_longitudinal"] = new_longitudinal != -1
        adata.obs["roll"] = f"roll_{roll_name}"
        adata.write(
            os.path.join(
                path_adata,
                "adatas",
                f"05_unrolled_roll_{roll_name}.h5ad",
            )
        )

## Plot Predictions

In [None]:
# Second copy of every figure, collected in one place.
# NOTE(review): absolute, user-specific path - adjust or make configurable before sharing.
longitudinal_export_dir = r"C:\Users\amonell\Downloads\merscope_longitudinal"
os.makedirs(longitudinal_export_dir, exist_ok=True)

sc.set_figure_params(dpi=300)
for path_adata in input_folders:
    axes_dir = os.path.join(path_adata, "figures", "axes")
    # Sorted for a deterministic order; raw glob order depends on the filesystem.
    rolls = sorted(glob.glob(os.path.join(path_adata, "adatas", "05_unrolled*.h5ad")))
    for roll in rolls:
        adata = sc.read(roll)
        roll_label = adata.obs["roll"].values[0]
        # Only cells that were not excluded (predicted value != -1) are plotted.
        fig = sc.pl.embedding(
            adata[adata.obs["not_removed_from_longitudinal"], :],
            basis="spatial",
            color="predicted_longitudinal",
            return_fig=True,
            show=False,
        )
        # makedirs also creates a missing `figures` parent. The previous bare-except
        # mkdir failed in that case and misleadingly reported "already exists".
        os.makedirs(axes_dir, exist_ok=True)
        fig.tight_layout()
        plt.axis("equal")
        fig.savefig(os.path.join(axes_dir, f"spatial_longitudinal_{roll_label}.png"))
        fig.savefig(
            os.path.join(
                longitudinal_export_dir,
                f"spatial_longitudinal_"
                + os.path.basename(path_adata)
                + f"_{roll_label}.png",
            )
        )
        # Release the figure; otherwise every 300-dpi figure stays open for the whole
        # loop. The figures are saved to disk above rather than displayed inline.
        plt.close(fig)