In [None]:
import scanpy as sc
import os
import numpy as np
import pandas as pd
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import pylab as pl
from scipy.spatial import cKDTree
import glob
import pickle

In [None]:
# Locations: outputs are written relative to `output_folder`; the per-sample
# AnnData files are read from <data_folder>/integration/<sample>/.
output_folder = r"./"
data_folder = "../../data"

# glob returns matches in arbitrary (filesystem-dependent) order; sorting makes the
# row order of the concatenated objects reproducible from run to run.
spatial_datas = sorted(
    glob.glob(os.path.join(data_folder, "integration", "*", "spatial_data.h5ad"))
)
seq_datas = sorted(
    glob.glob(os.path.join(data_folder, "integration", "*", "seq_data.h5ad"))
)

# Skip every file whose path contains "dc3000".
spatial_datas = [i for i in spatial_datas if "dc3000" not in i]
seq_datas = [i for i in seq_datas if "dc3000" not in i]

# Fail early with a readable message rather than handing an empty list to sc.concat.
if not spatial_datas or not seq_datas:
    raise FileNotFoundError(
        "No spatial_data.h5ad / seq_data.h5ad files found under "
        + os.path.join(data_folder, "integration")
    )

# One concatenated AnnData per modality.
spatial_adatas = [sc.read(fl) for fl in spatial_datas]
ad_sp_only = sc.concat(spatial_adatas)

seq_adatas = [sc.read(fl) for fl in seq_datas]
ad_sc_only = sc.concat(seq_adatas)

In [None]:
# Sub-folder name that the cells below join into every path they save to
# (colour pickle and figure PDFs).
figure_output_folder = "6ac"

In [None]:
# Global scanpy figure settings: dpi=300 for rendered figures, dpi_save=400 for
# figures written to disk.
sc.set_figure_params(dpi=300, dpi_save=400)

In [None]:
# Independent copies of the two concatenated objects.
# NOTE(review): `ad_sp` is re-read from disk in a later cell and `ad_sc` is only used
# for the "batch" assignment in the next cell — confirm these copies are still needed.
ad_sp = ad_sp_only.copy()
ad_sc = ad_sc_only.copy()

In [None]:
# Label every row of the sequencing copy with batch "rep1".
# NOTE(review): this writes to the copy `ad_sc`, while the combined object built below
# is concatenated from `ad_sc_only`, so this label never reaches it — confirm intent.
ad_sc.obs["batch"] = "rep1"

In [None]:
# Tag each source object with a constant "modality" column before the two are
# concatenated, so rows stay attributable to their origin.
ad_sc_only.obs["modality"] = "seq"
ad_sp_only.obs["modality"] = "spatial"

In [None]:
# Combine both modalities into one AnnData (sequencing rows first), using
# sc.concat's default settings.
ad_sc_sp = sc.concat([ad_sc_only, ad_sp_only])

In [None]:
# Fixed palette of 24 hex colours; a later cell pairs entry i with the i-th
# category of the "SCT_snn_res.1" cluster column.
# fmt: off
color_list = [
    "#bcbd22", "#9467bd", "#2ca02c", "#10492b",
    "#ff9896", "#e377c2", "#17becf", "#c7c7c7",
    "#6a0ed1", "#8c564b", "#ff7f0e", "#f7b6d2",
    "#cdd5c3", "#1f77b4", "#c5b0d5", "#98df8a",
    "#c49c94", "#aec7e8", "#dbdb8d", "#7f7f7f",
    "#9edae5", "#d62728", "#ffbb78", "#816cfd",
]
# fmt: on

In [None]:
# Store the palette in .uns under "<obs column>_colors", the key scanpy consults
# when colouring a categorical obs column.
ad_sc_sp.uns["SCT_snn_res.1_colors"] = color_list

In [None]:
# Store the cluster labels as a pandas Categorical; the following cells rely on
# the `.cat.categories` accessor of this column.
ad_sc_sp.obs["SCT_snn_res.1"] = pd.Categorical(ad_sc_sp.obs["SCT_snn_res.1"])

In [None]:
# Category labels of the cluster column.
# NOTE(review): not referenced again in the visible part of this notebook.
unique_clusters = ad_sc_sp.obs["SCT_snn_res.1"].cat.categories

In [None]:
# Map each cluster category to the palette entry at the same position.
# Indexing (rather than zip) keeps the IndexError if the palette is shorter
# than the number of categories.
cluster_categories = ad_sc_sp.obs["SCT_snn_res.1"].cat.categories
cluster_palette = ad_sc_sp.uns["SCT_snn_res.1_colors"]
dic = {
    category: cluster_palette[position]
    for position, category in enumerate(cluster_categories)
}

In [None]:
# Create the figure directory (parents included, no error if it already exists)
# and write the colour mapping into that same directory, so the created and the
# written location cannot drift apart when `output_folder` is changed.
figure_dir = os.path.join(output_folder, figure_output_folder)
os.makedirs(figure_dir, exist_ok=True)

# Specify the file path
file_path = os.path.join(figure_dir, "colors.pkl")

# Save the cluster -> colour dictionary to a binary file using pickle
with open(file_path, "wb") as pickle_file:
    pickle.dump(dic, pickle_file)

### Figure 6d part 1

In [None]:
# Load the spatial AnnData of the "9hr_avr" sample; this rebinds `ad_sp`.
spatial_9hr_path = os.path.join(
    data_folder, "integration", "9hr_avr", "spatial_data.h5ad"
)
ad_sp = sc.read(spatial_9hr_path)

In [None]:
# Copy the "x"/"y" obs columns into obsm["X_spatial"]; the plots below request
# this embedding via basis="spatial".
ad_sp.obsm["X_spatial"] = ad_sp.obs[["x", "y"]].values

In [None]:
# Three-colour palette stored under "<obs column>_colors" for the
# "predicted.celltype" column plotted in the next cell.
ad_sp.uns["predicted.celltype_colors"] = [
    "#FF007F",  # Bright Pink
    "#008000",  # Green
    "#FFD700",  # Gold
]

In [None]:
# Spatial map of the 9hr_avr cells coloured by predicted cell type, saved as PDF.
sc.set_figure_params(facecolor="white", figsize=(10, 10), dpi=300)

is_9hr_avr = ad_sp.obs.batch == "9hr_avr"
celltype_plot_options = dict(
    basis="spatial",
    color=["predicted.celltype"],
    vmax=0.1,
    frameon=False,
    size=40,
    return_fig=True,
)
fig = sc.pl.embedding(ad_sp[is_9hr_avr], **celltype_plot_options)

# Rebuild the legend on the current axes and force its labels to black.
l = plt.legend()
for legend_label in l.get_texts():
    legend_label.set_color("black")

# Title is drawn in white — presumably to hide it on the white background; confirm.
plt.title("celltype", {"color": "white"})

celltype_pdf = os.path.join(figure_output_folder, "Figure6d_celltype_sp.pdf")
fig.savefig(celltype_pdf)
plt.show()

In [None]:
# UMAP of the same object, coloured by the "SCT_snn_res.1" clusters.
fig = sc.pl.embedding(
    ad_sp,
    basis="umap",
    color=["SCT_snn_res.1"],
    vmax=0.1,
    frameon=False,
    size=10,
    return_fig=True,
)

# Rebuild the legend on the current axes and force its labels to black.
l = plt.legend()
for text in l.get_texts():
    text.set_color("black")
plt.title("celltype", {"color": "white"})
# Saved under its own name: writing to "Figure6d_celltype_sp.pdf" here would
# overwrite the spatial cell-type panel produced by the previous cell.
fig.savefig(os.path.join(figure_output_folder, "Figure6d_cluster_umap.pdf"))
plt.show()

### Figure 6e

In [None]:
# For every cluster, draw the 9hr_avr cells in spatial coordinates with a 0/1
# indicator ("cluster_map") of membership in that cluster, and save one PDF per
# cluster under <figure_output_folder>/cluster_projections.

# Create the output directory once, up front. exist_ok only tolerates an already
# existing directory; any other failure (e.g. permissions) is raised instead of
# being silently swallowed.
cluster_projection_dir = os.path.join(figure_output_folder, "cluster_projections")
os.makedirs(cluster_projection_dir, exist_ok=True)

# Loop-invariant figure settings.
sc.set_figure_params(facecolor="white", figsize=(10, 10), dpi=300)

cluster_labels = np.asarray(ad_sp.obs["SCT_snn_res.1"])
for cluster_to_map in np.unique(cluster_labels):
    # 1 for cells of the current cluster, 0 for all others
    ad_sp.obs["cluster_map"] = (cluster_labels == cluster_to_map).astype(int)

    fig = sc.pl.embedding(
        ad_sp[ad_sp.obs.batch == "9hr_avr"],
        basis="spatial",
        color=["cluster_map"],
        # limits wider than [0, 1] so neither value sits at an extreme of "Purples"
        vmax=1.3,
        vmin=-0.4,
        frameon=False,
        size=10,
        return_fig=True,
        cmap="Purples",
    )

    l = plt.legend()
    for text in l.get_texts():
        text.set_color("black")
    plt.title(f"Mapped cluster {cluster_to_map}", {"color": "black"})

    fig.savefig(
        os.path.join(cluster_projection_dir, f"Figure6e_cluster_{cluster_to_map}.pdf")
    )
    # Close explicitly so figures do not accumulate across iterations.
    plt.close(fig)

In [None]:
# Read the spatial AnnData of each listed sample and concatenate them in this
# order; `ad_sp` is rebound to the combined object.
sample_names = ["mock", "4hr_avr", "6hr_avr", "9hr_avr", "avrrpt24"]

ad_list = []
for adata_name in sample_names:
    sample_path = os.path.join(
        data_folder, "integration", adata_name, "spatial_data.h5ad"
    )
    # Note: this rebinds `ad_sp_only`, which an earlier cell used for the
    # concatenation of all spatial files.
    ad_sp_only = sc.read(sample_path)
    ad_list.append(ad_sp_only)

ad_sp = sc.concat(ad_list)

In [None]:
# Split the combined spatial object into one independent AnnData per batch.
batches = np.unique(ad_sp.obs["batch"])
adata_list = []
for batch in batches:
    # Copy before modifying: subsetting returns a view of `ad_sp`, and writing to
    # a view forces an implicit copy. An explicit copy up front makes each
    # per-batch object independent from the start.
    experiment = ad_sp[ad_sp.obs["batch"] == batch, :].copy()
    # Prefix cell names with the batch name (no separator is inserted).
    experiment.obs_names = batch + experiment.obs_names
    # x/y obs columns become the "spatial" embedding used for plotting.
    experiment.obsm["X_spatial"] = experiment.obs[["x", "y"]].values
    adata_list.append(experiment)

In [None]:
def _smooth_by_category(points, values, categories, n_neighbors=50, category="Mesophyll"):
    """Replace each point's value by the mean over nearby points of one category.

    For every point the `n_neighbors` nearest points (Euclidean, the point itself
    included) are looked up; the result is the plain (unweighted) mean of `values`
    over those neighbours whose entry in `categories` equals `category`. Points
    with no such neighbour keep their original value.

    Parameters
    ----------
    points : array-like, shape (n, d)
        Coordinates of the points.
    values : array-like, shape (n,)
        Numeric value per point.
    categories : array-like, shape (n,)
        Category label per point.
    n_neighbors : int
        Neighbours considered per point; capped at the number of points so the
        neighbour query never returns out-of-range indices.
    category : str
        Label whose members contribute to the mean.

    Returns
    -------
    numpy.ndarray, shape (n,)
        Smoothed values, in the order of `points`.
    """
    points = np.asarray(points)
    # Plain arrays: all indexing below is positional, independent of any pandas index.
    values = np.asarray(values, dtype=float)
    categories = np.asarray(categories)

    n_points = len(points)
    smoothed = values.copy()
    if n_points == 0:
        return smoothed

    n_query = min(n_neighbors, n_points)
    # One vectorised query for all points; reshape because k == 1 yields a 1-D result.
    neighbor_idx = cKDTree(points).query(points, k=n_query)[1].reshape(n_points, n_query)

    in_category = categories[neighbor_idx] == category
    counts = in_category.sum(axis=1)
    sums = np.where(in_category, values[neighbor_idx], 0.0).sum(axis=1)

    has_neighbors = counts > 0
    smoothed[has_neighbors] = sums[has_neighbors] / counts[has_neighbors]
    return smoothed


# Define the number of nearest neighbors to consider
n_neighbors = 50  # Change this value as needed

# Create the output directory once; only "already exists" is tolerated, other
# failures are raised instead of being silently swallowed.
smoothed_pseudo_dir = os.path.join(figure_output_folder, "smoothed_spatial_pseudo")
os.makedirs(smoothed_pseudo_dir, exist_ok=True)

# For each per-batch AnnData: normalise cell names, smooth pseudotime over nearby
# "Mesophyll" cells, then plot and save the smoothed values in spatial coordinates.
for k in range(len(adata_list)):
    adata = adata_list[k]

    # Replace "-" by "_" and prefix names that start with a digit with "X".
    # NOTE(review): looks like an identifier-safe renaming step — confirm purpose.
    adata.obs.index = adata.obs.index.str.replace("-", "_")
    if adata.obs.index.values[0][0].isnumeric():
        adata.obs.index = "X" + np.array(adata.obs.index.values)

    adata.obs["smoothed_pseudo"] = _smooth_by_category(
        adata.obsm["X_spatial"],
        adata.obs["pseudotime"],
        adata.obs["celltype"],
        n_neighbors=n_neighbors,
        category="Mesophyll",
    )

    # .iloc: obs is indexed by cell name, so an integer key must be positional.
    batch_name = adata.obs["batch"].iloc[0]

    # Same colour scale for every batch.
    fig = sc.pl.embedding(
        adata,
        basis="spatial",
        color="smoothed_pseudo",
        cmap="jet",
        size=100,
        vmax=0.65,
        vmin=0.1,
        return_fig=True,
    )

    fig.savefig(
        os.path.join(
            smoothed_pseudo_dir,
            f"smoothed_spatial_pseudo_{batch_name}.pdf",
        )
    )
    plt.show()
    plt.close(fig)