# Reference Query Mapping

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 18.05.2023
- **Date of Last Modification:** 04.09.2023

Before running this notebook, the following two steps must have been completed successfully: <br>
- A model has been trained with a reference dataset with ```<root>/scripts/train_nichecompass_reference_model.py```. <br>
- A query has been mapped onto the reference model with ```<root>/scripts/map_query_on_nichecompass_reference_model.py```.

## 1. Setup

### 1.1 Import Libraries

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("../../utils")

In [None]:
import os
import warnings

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scarches as sca
import seaborn as sns
from scipy.spatial import distance
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment

from nichecompass.models import NicheCompass
from nichecompass.utils import create_new_color_dict

from analysis_utils import *
from reference_query_mapping_utils import *

### 1.2 Define Parameters

In [None]:
### Dataset ###
# Name of the dataset to analyse. It selects the parameter set in the next
# cell and is used to build the artifact, figure and model paths.
# NOTE(review): only "nanostring_cosmx_human_nsclc_modified" has a parameter
# configuration in this notebook; enabling the seqFISH option below would
# leave the model / analysis parameters undefined.
# dataset = "seqfish_mouse_organogenesis_imputed"
dataset = "nanostring_cosmx_human_nsclc_modified"

In [None]:
# Names of the model sub-folders below "<artifacts>/<dataset>/models/"
reference_model_label = "reference"
reference_query_model_label = "reference_query_mapping"

if dataset == "nanostring_cosmx_human_nsclc_modified":
    ### Model ###
    # Keys under which the model inputs / outputs are stored in the AnnData object
    gp_names_key = "nichecompass_gp_names"
    mapping_entity_key = "mapping_entity"
    sample_key = "batch"
    condition_key = "batch"
    cell_type_key = "cell_type"
    latent_key = "nichecompass_latent"
    spatial_key = "spatial"

    # Timestamp of the reference model run to load. Runs tried so far:
    #   "01092023_182943_6" - fov batch effect
    #   "01092023_182316_2" - much better
    #   "01092023_182316_1" - fov batch effect
    #   "02092023_122614_1" - smaller batch effect
    # The last assignment used to be misspelled ("load_timestmap"), so it was
    # silently ignored and "02092023_122614_1" stayed active.
    load_timestamp = "02092023_124452_2"

    ### Analysis ###
    # Leiden clustering of the latent space of the loaded model
    latent_leiden_resolution = 0.7
    latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"
    spot_size = 200

    # Leiden clustering restricted to the reference cells
    latent_reference_leiden_resolution = 0.7
    latent_reference_cluster_key = f"latent_leiden_reference_{latent_reference_leiden_resolution}"

    #spatial_reference_leiden_resolution = 0.15
    #spatial_reference_cluster_key = f"spatial_leiden_reference_{spatial_reference_leiden_resolution}"

    # Leiden clustering restricted to the query cells
    latent_query_leiden_resolution = 0.7
    latent_query_cluster_key = f"latent_leiden_query_{latent_query_leiden_resolution}"

    # Label transfer: the reference clusters are the labels that are transferred;
    # the remaining keys name the obs columns that hold the results
    label_key = latent_reference_cluster_key
    ground_truth_key = f"matched_ground_truth_latent_leiden_reference_{latent_reference_leiden_resolution}"
    transfer_label_key = f"transferred_{label_key}"
    transfer_label_uncertainty_key = f"transferred_{label_key}_uncertainty"
    transfer_label_evaluation_key = f"{transfer_label_key}_evaluation"

### 1.3 Run Notebook Setup

In [None]:
# Silence all Python warnings for the rest of the notebook session
warnings.filterwarnings("ignore")

### 1.4 Configure Paths

In [None]:
# Artifact root and the run-specific figure folder (created if missing)
artifacts_folder_path = "../../artifacts"
figure_folder_path = "/".join(
    [artifacts_folder_path, dataset, "figures", load_timestamp])
os.makedirs(figure_folder_path, exist_ok=True)

## 2. Reference Model

In [None]:
# Restore the trained reference model from its run folder. No AnnData object
# is passed in (adata=None); the file name of the stored one is given instead.
model_folder_path = (
    f"{artifacts_folder_path}/{dataset}/models/"
    f"{reference_model_label}/{load_timestamp}")
model = NicheCompass.load(
    dir_path=model_folder_path,
    adata=None,
    adata_file_name=f"{dataset}_{reference_model_label}.h5ad",
    gp_names_key=gp_names_key)

In [None]:
# Report how many gene programs exist in total and how many are active
active_gps = model.get_active_gps()
n_total_gps = len(model.adata.uns[gp_names_key])
print(f"Number of total gene programs: {n_total_gps}.")
print(f"Number of active gene programs: {len(active_gps)}.")

In [None]:
# Fetch the gene program summary and display the first 20 active programs
gp_summary_df = model.get_gp_summary()
is_active_gp = gp_summary_df["gp_active"] == True
gp_summary_df.loc[is_active_gp].iloc[:20]

In [None]:
# Display the unfiltered gene program summary
gp_summary_df

In [None]:
# Unique sample identifiers of the loaded model's cells; passed to the plotting helper
samples = model.adata.obs[sample_key].unique().tolist()

In [None]:
# Colour the cells by batch and plot them with the project helper for latent
# and physical space; `save_fig` / `file_path` are handed to the helper.
save_fig = True
file_path = f"{figure_folder_path}/batches_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Batches",
    cat_key=condition_key,
    cat_colors=None,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    size=360000 / len(model.adata),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# One colour per cell type category
cell_type_colors = create_new_color_dict(adata=model.adata,
                                         cat_key=cell_type_key)

# Cell types in latent and physical space (no restriction to specific groups)
save_fig = True
file_path = f"{figure_folder_path}/cell_types_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Cell Types",
    cat_key=cell_type_key,
    cat_colors=cell_type_colors,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    size=360000 / len(model.adata),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Display the cell-level annotations of the loaded model
model.adata.obs

In [None]:
# One colour per niche category (obs column "niche")
niche_colors = create_new_color_dict(
    adata=model.adata,
    cat_key="niche")

# Plot niches in latent and physical space.
# The figure gets its own file name; it previously reused
# "cell_types_latent_physical_space.svg" and overwrote the cell type figure.
save_fig = True
file_path = f"{figure_folder_path}/" \
            "niches_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Niches",
    cat_key="niche",
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=niche_colors,
    size=(360000 / len(model.adata)),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Compute latent Leiden clustering; the cluster labels are written to
# `model.adata.obs[latent_cluster_key]`.
# NOTE(review): `sc` (presumably scanpy) is not imported in the import cell;
# it is assumed to be provided by one of the star imports — confirm.
# Assumes a neighbor graph is already stored under `latent_key` — TODO confirm.
sc.tl.leiden(adata=model.adata,
             resolution=latent_leiden_resolution,
             key_added=latent_cluster_key,
             neighbors_key=latent_key)

In [None]:
# One colour per latent Leiden cluster
latent_cluster_colors = create_new_color_dict(adata=model.adata,
                                              cat_key=latent_cluster_key)

# Latent clusters in latent and physical space; the clustering resolution is
# part of the figure file name
save_fig = True
file_path = (f"{figure_folder_path}/res_{latent_leiden_resolution}_"
             "latent_clusters_latent_physical_space.svg")

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Latent Clusters",
    cat_key=latent_cluster_key,
    cat_colors=latent_cluster_colors,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Switch to the dataset name that is used in the paths of the reference-query
# mapping model loaded in the next section.
# NOTE(review): `figure_folder_path` was built from the previous dataset name
# and is not recomputed in this cell.
dataset = "nanostring_cosmx_human_nsclc"

## 3. Query Model

In [None]:
# Timestamp of the reference-query mapping run to load
load_timestamp = "02092023_233857_1"
model_folder_path = f"{artifacts_folder_path}/{dataset}/models/{reference_query_model_label}/{load_timestamp}"

# Point the figure output at this run. Without this, the plots below reuse the
# reference run's figure folder and overwrite its figures (same file names).
figure_folder_path = f"{artifacts_folder_path}/{dataset}/figures/{load_timestamp}"
os.makedirs(figure_folder_path, exist_ok=True)

# Load trained model
model = NicheCompass.load(dir_path=model_folder_path,
                          adata=None,
                          adata_file_name=f"{dataset}_{reference_query_model_label}.h5ad",
                          gp_names_key=gp_names_key)

In [None]:
# Total vs. active gene programs of the reference-query model
active_gps = model.get_active_gps()
total_gp_count = len(model.adata.uns[gp_names_key])
active_gp_count = len(active_gps)
print(f"Number of total gene programs: {total_gp_count}.")
print(f"Number of active gene programs: {active_gp_count}.")

In [None]:
# Gene program summary of the reference-query model; show the first 20 active programs
gp_summary_df = model.get_gp_summary()
active_rows = gp_summary_df["gp_active"] == True
gp_summary_df.loc[active_rows].iloc[:20]

In [None]:
# Display the unfiltered gene program summary of the reference-query model
gp_summary_df

In [None]:
# Refresh the sample list for the newly loaded model (the plots below iterate over it)
samples = model.adata.obs[sample_key].unique().tolist()

In [None]:
# Batches of the reference-query model in latent and physical space
save_fig = True
file_path = "/".join([figure_folder_path, "batches_latent_physical_space.svg"])

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Batches",
    cat_key=condition_key,
    cat_colors=None,
    groups=None,
    samples=samples,
    sample_key=sample_key,
    spot_size=spot_size,
    size=360000 / len(model.adata),
    file_path=file_path,
    save_fig=save_fig)

In [None]:
# Colour lookup for the cell types of the reference-query model
cell_type_colors = create_new_color_dict(adata=model.adata,
                                         cat_key=cell_type_key)

# Cell types in latent and physical space
save_fig = True
file_path = "/".join([figure_folder_path, "cell_types_latent_physical_space.svg"])

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Cell Types",
    cat_key=cell_type_key,
    cat_colors=cell_type_colors,
    groups=None,
    samples=samples,
    sample_key=sample_key,
    spot_size=spot_size,
    size=360000 / len(model.adata),
    file_path=file_path,
    save_fig=save_fig)

In [None]:
# Display the cell-level annotations of the reference-query model
model.adata.obs

In [None]:
# One colour per niche category (obs column "niche")
niche_colors = create_new_color_dict(
    adata=model.adata,
    cat_key="niche")

# Plot niches in latent and physical space.
# Uses a dedicated file name; the previous "cell_types_latent_physical_space.svg"
# made this figure overwrite the cell type figure.
save_fig = True
file_path = f"{figure_folder_path}/" \
            "niches_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Niches",
    cat_key="niche",
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=niche_colors,
    size=(360000 / len(model.adata)),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Resolution for clustering the reference-query model. The obs key is rebuilt
# as well: it embeds the resolution, and later cells look the clusters up as
# f"latent_leiden_{latent_leiden_resolution}", which would not exist if the
# key still carried the old resolution.
latent_leiden_resolution = 0.4
latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"

In [None]:
# Compute latent Leiden clustering of the reference-query model; labels are
# written to `model.adata.obs[latent_cluster_key]`.
# Assumes a neighbor graph is already stored under `latent_key` — TODO confirm.
sc.tl.leiden(adata=model.adata,
             resolution=latent_leiden_resolution,
             key_added=latent_cluster_key,
             neighbors_key=latent_key)

In [None]:
# Colour lookup for the latent clusters
latent_cluster_colors = create_new_color_dict(adata=model.adata,
                                              cat_key=latent_cluster_key)

# Latent clusters in latent and physical space
save_fig = True
file_path = (f"{figure_folder_path}/res_{latent_leiden_resolution}_"
             "latent_clusters_latent_physical_space.svg")

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Latent Clusters",
    cat_key=latent_cluster_key,
    cat_colors=latent_cluster_colors,
    groups=None,
    samples=samples,
    sample_key=sample_key,
    spot_size=spot_size,
    size=None,
    file_path=file_path,
    save_fig=save_fig)

In [None]:
# NOTE(review): `benchmarking_folder_path` is not assigned anywhere in the
# visible part of this notebook; unless a star import provides it, this cell
# raises a NameError — confirm or remove.
benchmarking_folder_path

In [None]:
# Persist the AnnData object of the reference-query model. The target folder
# is created first because no other cell in this notebook creates it.
reference_query_mapping_folder_path = f"{artifacts_folder_path}/reference_query_mapping"
os.makedirs(reference_query_mapping_folder_path, exist_ok=True)
model.adata.write(f"{reference_query_mapping_folder_path}/{dataset}_nichecompass_{load_timestamp}_adata.h5ad")

In [None]:
# NOTE(review): `timestamp` is not assigned anywhere in the visible part of
# this notebook (the run id is kept in `load_timestamp`); unless a star import
# provides it, this cell raises a NameError — confirm or remove.
timestamp

In [None]:
# Display the artifact root folder
artifacts_folder_path

## 2. Model

### 2.1 Load Model

In [None]:
# Display the folder the model is loaded from in the next cell
model_folder_path

In [None]:
# Restore the reference-query model from `model_folder_path`
model = NicheCompass.load(
    dir_path=model_folder_path,
    adata=None,
    adata_file_name=f"{dataset}_{reference_query_model_label}.h5ad",
    gp_names_key=gp_names_key)

In [None]:
"""
model.add_active_gp_expr_to_obs()
model.adata.write(f"{model_folder_path}/adata_active_gps.h5ad")
"""

## 3. Analysis

In [None]:
# Count all gene programs and the active subset
active_gps = model.get_active_gps()
all_gp_names = model.adata.uns[gp_names_key]
print(f"Number of total gene programs: {len(all_gp_names)}.")
print(f"Number of active gene programs: {len(active_gps)}.")

In [None]:
# Show the first 20 active gene programs of the summary table
gp_summary_df = model.get_gp_summary()
active_gp_rows = gp_summary_df["gp_active"] == True
gp_summary_df.loc[active_gp_rows].iloc[:20]

### 3.1 Analyze Latent Manifold

In [None]:
# Unique sample identifiers across reference and query cells of the loaded model
samples = model.adata.obs[sample_key].unique().tolist()

In [None]:
# Colour the cells by mapping entity (values of `mapping_entity_key`) in
# latent and physical space
save_fig = True
file_path = f"{figure_folder_path}/mapping_entities_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Mapping Entities",
    cat_key=mapping_entity_key,
    cat_colors=None,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    size=360000 / len(model.adata),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Batches in latent and physical space
save_fig = True
file_path = f"{figure_folder_path}/batches_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Batches",
    cat_key=condition_key,
    cat_colors=None,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    size=360000 / len(model.adata),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Colour lookup for the cell type categories
cell_type_colors = create_new_color_dict(adata=model.adata, cat_key=cell_type_key)

In [None]:
# Cell types in latent and physical space (all groups, helper-default point size)
save_fig = True
file_path = f"{figure_folder_path}/cell_types_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Cell Types",
    cat_key=cell_type_key,
    cat_colors=cell_type_colors,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Compute latent Leiden clustering on all cells (reference and query); labels
# are written to `model.adata.obs[latent_cluster_key]`.
# NOTE(review): `latent_cluster_key` embeds the resolution it was built with;
# make sure it matches the current `latent_leiden_resolution`.
sc.tl.leiden(adata=model.adata,
             resolution=latent_leiden_resolution,
             key_added=latent_cluster_key,
             neighbors_key=latent_key)

In [None]:
# Colour lookup for the latent cluster categories
latent_cluster_colors = create_new_color_dict(adata=model.adata, cat_key=latent_cluster_key)

In [None]:
# Latent clusters in latent and physical space; resolution is part of the file name
save_fig = True
file_path = (f"{figure_folder_path}/res_{latent_leiden_resolution}_"
             "latent_clusters_latent_physical_space.svg")

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Latent Clusters",
    cat_key=latent_cluster_key,
    cat_colors=latent_cluster_colors,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

### 3.2 Analyze Query Enrichments (vs Reference)

#### 3.2.1 Visualize Query-enriched Cell Types in Physical and Latent Space

In [None]:
# Split the integrated object into its reference and query cells
is_reference_cell = model.adata.obs[mapping_entity_key] == "reference"
is_query_cell = model.adata.obs[mapping_entity_key] == "query"
adata_reference = model.adata[is_reference_cell]
adata_query = model.adata[is_query_cell]

# Samples that belong to the reference
reference_samples = adata_reference.obs[sample_key].unique().tolist()

In [None]:
# A cell type counts as query-enriched if its proportion in the query is more
# than `enriched_query_cell_type_prop_thresh` times its proportion in the reference
enriched_query_cell_type_prop_thresh = 5.
query_enriched_cell_type_key = "query_enriched_cell_types"

# Per-cell-type proportions in reference and query, and their ratio (largest first)
reference_cell_type_counts = adata_reference.obs[cell_type_key].value_counts().sort_index()
query_cell_type_counts = adata_query.obs[cell_type_key].value_counts().sort_index()
cell_type_reference_proportions = reference_cell_type_counts / len(adata_reference)
cell_type_query_proportions = query_cell_type_counts / len(adata_query)
relative_cell_type_query_proportions = cell_type_query_proportions / cell_type_reference_proportions
relative_cell_type_query_proportions = relative_cell_type_query_proportions.sort_values(ascending=False)
display(relative_cell_type_query_proportions)

# Annotate every cell with its cell type if that type is query-enriched and
# with a common placeholder otherwise
is_enriched = relative_cell_type_query_proportions > enriched_query_cell_type_prop_thresh
query_enriched_cell_types = relative_cell_type_query_proportions[is_enriched].index.to_list()
model.adata.obs[query_enriched_cell_type_key] = "Cell types not enriched in query"
for cell_type in query_enriched_cell_types:
    cells_of_type = model.adata.obs[cell_type_key] == cell_type
    model.adata.obs.loc[cells_of_type, query_enriched_cell_type_key] = cell_type

In [None]:
# Query-enriched cell type annotations in physical and latent space
save_fig = True
file_path = f"{figure_folder_path}/query_enriched_cell_types_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Query Enriched Cell Types",
    cat_key=query_enriched_cell_type_key,
    cat_colors="coolwarm",
    groups=None,
    sample_key=sample_key,
    samples=samples,
    size=360000 / len(model.adata),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

#### 3.2.2 Visualize Query-enriched Niches / Latent Clusters in Physical and Latent Space

In [None]:
# Obs column that will hold the query-enrichment annotation of the latent clusters
query_enriched_latent_cluster_key = "query_enriched_latent_cluster"
# Minimum ratio of query proportion to reference proportion for a cluster to count as query-enriched
enriched_query_latent_cluster_prop_thresh = 5.

In [None]:
# Get query-enriched latent Leiden clusters.
# The clusters are read via `latent_cluster_key`, i.e. the obs column that the
# Leiden clustering above actually wrote. Rebuilding the column name from
# `latent_leiden_resolution` breaks as soon as the resolution is changed
# without updating the key.
latent_cluster_reference_proportions = adata_reference.obs[latent_cluster_key].value_counts().sort_index() / len(adata_reference)
latent_cluster_query_proportions = adata_query.obs[latent_cluster_key].value_counts().sort_index() / len(adata_query)
relative_latent_cluster_query_proportions = latent_cluster_query_proportions / latent_cluster_reference_proportions
relative_latent_cluster_query_proportions.sort_values(ascending=False, inplace=True)
display(relative_latent_cluster_query_proportions)

# Annotate every cell with "Cluster <id>" if its cluster is query-enriched and
# with a common placeholder otherwise
query_enriched_latent_clusters = relative_latent_cluster_query_proportions[relative_latent_cluster_query_proportions > enriched_query_latent_cluster_prop_thresh].index.to_list()
model.adata.obs[query_enriched_latent_cluster_key] = "Latent clusters not enriched in query"
for latent_cluster in query_enriched_latent_clusters:
    model.adata.obs.loc[model.adata.obs[latent_cluster_key] == latent_cluster, query_enriched_latent_cluster_key] = f"Cluster {latent_cluster}"

In [None]:
# Query-enriched latent clusters in physical and latent space
save_fig = True
file_path = f"{figure_folder_path}/query_enriched_latent_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Query Enriched Latent Clusters",
    cat_key=query_enriched_latent_cluster_key,
    cat_colors="coolwarm",
    groups=None,
    sample_key=sample_key,
    samples=samples,
    size=360000 / len(model.adata),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

### 3.3 Transfer Niche Labels from Reference to Query

#### 3.3.1 Define Reference Niches

In [None]:
# Compute latent neighbor graph just for reference (built from the latent
# representation stored under `latent_key` and saved under the same key).
# NOTE(review): `adata_reference` was created by subsetting `model.adata`;
# presumably it is turned into an independent copy on first write — confirm.
sc.pp.neighbors(adata_reference,
                use_rep=latent_key,
                key_added=latent_key)

# Compute latent Leiden clustering just for reference
sc.tl.leiden(adata=adata_reference,
             resolution=latent_reference_leiden_resolution,
             key_added=latent_reference_cluster_key,
             neighbors_key=latent_key)

# Add latent reference clusters to integrated adata object. The assignment is
# aligned on the cell index, so cells that are not part of the reference get
# missing values in this column.
model.adata.obs[latent_reference_cluster_key] = adata_reference.obs[latent_reference_cluster_key]

In [None]:
# Create color dict for latent reference clusters for plotting.
# The dict is extended later with extra entries for newly detected query niches.
latent_reference_cluster_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=latent_reference_cluster_key)

In [None]:
# Reference-only latent clusters in latent and physical space
save_fig = True
file_path = (f"{figure_folder_path}/res_{latent_reference_leiden_resolution}_"
             "latent_reference_clusters_latent_physical_space.svg")

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Latent Reference Clusters",
    cat_key=latent_reference_cluster_key,
    cat_colors=latent_reference_cluster_colors,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
"""
# Compute spatial neighbor graph just for reference
sc.pp.neighbors(adata_reference,
                use_rep=spatial_key,
                key_added=spatial_key)

# Compute spatial Leiden clustering just for reference
sc.tl.leiden(adata=adata_reference,
             resolution=spatial_reference_leiden_resolution,
             key_added=spatial_reference_cluster_key,
             neighbors_key=spatial_key)

# Add spatial reference clusters to integrated adata object
model.adata.obs[spatial_reference_cluster_key] = adata_reference.obs[spatial_reference_cluster_key]
"""

In [None]:
"""
# Create color dict for spatial reference clusters for plotting
spatial_reference_cluster_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=spatial_reference_cluster_key)
"""

In [None]:
"""
# Plot spatial reference clusters in latent and physical space
save_fig = True
file_path = f"{figure_folder_path}/res_{spatial_reference_leiden_resolution}_" \
            "spatial_reference_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Spatial Reference Clusters",
    cat_key=spatial_reference_cluster_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=spatial_reference_cluster_colors,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)
"""

#### 3.3.2 Define (Ground Truth) Query Niches

In [None]:
# Compute latent neighbor graph just for query (built from the latent
# representation stored under `latent_key` and saved under the same key)
sc.pp.neighbors(adata_query,
                use_rep=latent_key,
                key_added=latent_key)

# Compute latent Leiden clustering just for query; these clusters serve as the
# basis for the "ground truth" query niches below
sc.tl.leiden(adata=adata_query,
             resolution=latent_query_leiden_resolution,
             key_added=latent_query_cluster_key,
             neighbors_key=latent_key)

# Add latent query clusters to integrated adata object. The assignment is
# aligned on the cell index, so cells that are not part of the query get
# missing values in this column.
model.adata.obs[latent_query_cluster_key] = adata_query.obs[latent_query_cluster_key]

In [None]:
# Create color dict for latent query clusters for plotting
# (one entry per category of `latent_query_cluster_key`, presumably — see helper)
latent_query_cluster_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=latent_query_cluster_key)

In [None]:
# Plot latent query clusters in latent and physical space.
# The file name carries the resolution the query clustering was computed with
# (`latent_query_leiden_resolution`); it previously used the unrelated
# `latent_leiden_resolution` of the joint clustering.
save_fig = True
file_path = f"{figure_folder_path}/res_{latent_query_leiden_resolution}_" \
            "latent_query_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Latent Query Clusters",
    cat_key=latent_query_cluster_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=latent_query_cluster_colors,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Map query niches to reference niches using cell type proportions:
# every query niche is matched to the reference niche whose cell type
# proportion vector is closest in Euclidean distance.

# Compute cell type counts per niche (missing niche / cell type combinations
# count as 0 so that no NaN enters the distance computation)
reference_counts_df = model.adata[model.adata.obs[mapping_entity_key] == "reference"].obs.groupby(
    [latent_reference_cluster_key, cell_type_key]).size().unstack(fill_value=0)
query_counts_df = model.adata[model.adata.obs[mapping_entity_key] == "query"].obs.groupby(
    [latent_query_cluster_key, cell_type_key]).size().unstack(fill_value=0)

# Convert counts to per-niche proportions and use integer niche labels
reference_proportions_df = reference_counts_df.div(reference_counts_df.sum(axis=1), axis=0)
query_proportions_df = query_counts_df.div(query_counts_df.sum(axis=1), axis=0)
reference_proportions_df.index = reference_proportions_df.index.astype(int)
query_proportions_df.index = query_proportions_df.index.astype(int)
reference_proportions_df.sort_index(inplace=True)
query_proportions_df.sort_index(inplace=True)
query_proportions_df.index.name = "Query Niche"

# Calculate the Euclidean distance between each pair of niches in the query
# (rows) and the reference (columns)
distances = cdist(query_proportions_df.values,
                  reference_proportions_df.values,
                  metric="euclidean")

# For each query niche find the reference niche with the minimum Euclidean
# distance. `argmin` yields row positions, which are translated back to the
# actual niche labels instead of assuming that labels equal positions.
# (A one-to-one matching via `linear_sum_assignment` was tried as an
# alternative and is not used.)
query_latent_cluster_indices = query_proportions_df.index.to_numpy()
reference_latent_cluster_indices = reference_proportions_df.index.to_numpy()[
    np.argmin(distances, axis=1)]

# Determine new query niches (minimum distance more than 2 std above the
# mean), assign them labels starting at 1000, and add an artificial row of 0s
# to the reference proportions for each of them
min_distances = np.min(distances, axis=1)
thresh = np.mean(min_distances) + 2 * np.std(min_distances)
new_query_niche_indices = np.where(min_distances > thresh)[0]
for i, query_niche_idx in enumerate(new_query_niche_indices):
    new_niche_label = i + 1000
    reference_latent_cluster_indices[query_niche_idx] = new_niche_label
    reference_proportions_df.loc[new_niche_label] = np.zeros(len(reference_proportions_df.columns))
    latent_reference_cluster_colors[f"{new_niche_label}"] = "#000000" # Add black color for new query niches

# Query niche label -> matched reference niche label (both as strings)
mapping_dict = {str(k): str(v) for k, v in zip(query_latent_cluster_indices, reference_latent_cluster_indices)}

# Get matched reference proportions (one row per query niche)
matched_reference_proportions_df = reference_proportions_df.loc[reference_latent_cluster_indices, :]

# Replace query niche numbers with matched reference niche numbers
model.adata.obs[ground_truth_key] = model.adata.obs[latent_query_cluster_key].map(mapping_dict)

In [None]:
# Compare the cell type proportions of every query niche with those of its
# matched reference niche using the project's clustered stacked bar plot helper
plot_clustered_stacked(df_list=[matched_reference_proportions_df, query_proportions_df],
                       labels=["Reference", "Query"],
                       title="Cell Type Proportions in Query Niches and Matched Reference Niches",
                       H="//")

In [None]:
# Plot the matched ground truth niches of the query (query clusters relabelled
# with their matched reference niche) in latent and physical space.
# This figure has its own file name and plot label; it previously reused those
# of the latent query cluster plot and overwrote that figure on disk.
save_fig = True
file_path = f"{figure_folder_path}/res_{latent_query_leiden_resolution}_" \
            "matched_ground_truth_query_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Matched Ground Truth Query Clusters",
    cat_key=ground_truth_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=latent_reference_cluster_colors,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
"""
# Compute spatial neighbor graph just for query
sc.pp.neighbors(adata_query,
                use_rep=spatial_key,
                key_added=spatial_key)

# Compute spatial Leiden clustering just for query
spatial_query_leiden_resolution = 0.11
spatial_query_cluster_key = f"latent_leiden_query_{spatial_query_leiden_resolution}"
 
sc.tl.leiden(adata=adata_query,
             resolution=spatial_query_leiden_resolution,
             key_added=spatial_query_cluster_key,
             neighbors_key=spatial_key)

model.adata.obs[spatial_query_cluster_key] = adata_query.obs[spatial_query_cluster_key]
"""

In [None]:
"""
spatial_query_cluster_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=spatial_query_cluster_key)
"""

In [None]:
"""
# Spatial query clusters in latent and physical space
save_fig = False
file_path = f"{figure_folder_path}/res_{latent_leiden_resolution}_" \
            "spatial_query_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Latent Query Clusters",
    cat_key=spatial_query_cluster_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=spatial_query_cluster_colors,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)
"""

#### 3.3.3 Transfer Niche Labels from Reference to Query

In [None]:
# Prepare label transfer via scarches: fit the weighted kNN helper on the
# reference cells, using the embedding referenced by `latent_key` and 15 neighbors
knn_transformer = sca.utils.knn.weighted_knn_trainer(
    train_adata=adata_reference,
    train_adata_emb=latent_key,
    n_neighbors=15)

In [None]:
# Compute label transfer via scarches: transfer the `label_key` column of the
# reference obs to the query cells. Two results are returned and used below as
# DataFrames: the transferred labels and their uncertainties.
labels, uncert = sca.utils.knn.weighted_knn_transfer(
    query_adata=adata_query,
    query_adata_emb=latent_key,
    label_keys=label_key,
    knn_model=knn_transformer,
    ref_adata_obs=adata_reference.obs)

In [None]:
# Rename the result columns to the transfer keys and attach both tables to the
# integrated obs (joined on the cell index)
labels.rename(columns={label_key: transfer_label_key}, inplace=True)
uncert.rename(columns={label_key: transfer_label_uncertainty_key}, inplace=True)
model.adata.obs = model.adata.obs.join(labels).join(uncert)

In [None]:
# Drop the artificial all-zero reference rows (labels 1000, 1001, ...) again
placeholder_niche_labels = [1000 + offset for offset in range(len(new_query_niche_indices))]
reference_proportions_df.drop(placeholder_niche_labels, inplace=True)

# Cell type proportions of the query cells, grouped by transferred niche label
query_obs = model.adata[model.adata.obs[mapping_entity_key] == "query"].obs
transferred_query_counts_df = query_obs.groupby(
    [transfer_label_key, cell_type_key]).size().unstack()
transferred_niche_totals = transferred_query_counts_df.sum(axis=1)
transferred_query_proportions_df = transferred_query_counts_df.div(transferred_niche_totals, axis=0)

# Integer niche labels in ascending order
transferred_query_proportions_df.index = transferred_query_proportions_df.index.astype(int)
transferred_query_proportions_df = transferred_query_proportions_df.sort_index()
transferred_query_proportions_df.index.name = "Transferred Query Niche"

In [None]:
# Compare the cell type proportions of the reference niches with those of the
# query cells grouped by transferred niche label
plot_clustered_stacked(df_list=[reference_proportions_df, transferred_query_proportions_df],
                       labels=["Reference", "Query"],
                       title="Cell Type Proportions in Transferred Query Niches and Reference Niches",
                       H="//")

In [None]:
# Add evaluations: "Correct" if the transferred label equals the matched
# ground truth label, "Incorrect" otherwise; cells without a transferred label
# keep a missing value.
# Computed column-wise instead of with a row-wise `apply`, which is very slow
# on large obs tables; the resulting values are the same.
transferred_labels = model.adata.obs[transfer_label_key].astype(object)
ground_truth_labels = model.adata.obs[ground_truth_key].astype(object)
transfer_evaluation = pd.Series(
    np.where(transferred_labels == ground_truth_labels, "Correct", "Incorrect"),
    index=model.adata.obs.index,
    dtype=object)
model.adata.obs[transfer_label_evaluation_key] = transfer_evaluation.mask(
    transferred_labels.isna())

In [None]:
# Plot the distribution of the label transfer uncertainties (density histogram
# with a KDE overlay). `sns.distplot` is deprecated in seaborn; `histplot`
# with these arguments is its documented replacement.
sns.histplot(model.adata.obs[transfer_label_uncertainty_key],
             kde=True,
             stat="density")

In [None]:
# Set the evaluation of labels whose uncertainty exceeds the threshold to 'Uncertain'
# (rows with a missing uncertainty do not satisfy the comparison and are left unchanged)
uncertainty_threshold = 0.01

model.adata.obs[transfer_label_evaluation_key] = model.adata.obs[transfer_label_evaluation_key].mask(
    model.adata.obs[transfer_label_uncertainty_key] > uncertainty_threshold,
    "Uncertain")

In [None]:
# Query cells of the integrated object, now including the transferred labels,
# uncertainties and evaluations added to `model.adata.obs` above
adata_query_label_transfer = model.adata[model.adata.obs[mapping_entity_key] == "query"]

In [None]:
# Share of query cells whose transferred label was flagged as "Uncertain".
# The message now names the label that is actually used ("Uncertain", not
# "Unknown"), and the count uses the vectorized Series sum.
is_uncertain = adata_query_label_transfer.obs[transfer_label_evaluation_key] == "Uncertain"
print(f"Percentage of 'Uncertain', with uncertainty_threshold = {uncertainty_threshold}:")
print(f"{np.round(is_uncertain.sum() / adata_query_label_transfer.n_obs * 100, 2)}%")

In [None]:
# Ground truth niche labels of the query cells, in order of first appearance
label_cats = adata_query_label_transfer.obs[ground_truth_key].unique().tolist()

# Percentage of each evaluation outcome per ground truth niche
perc_correct = pd.crosstab(
    adata_query_label_transfer.obs[ground_truth_key],
    adata_query_label_transfer.obs[transfer_label_evaluation_key],
).loc[label_cats, :]
total_n_per_ct = adata_query_label_transfer.obs[ground_truth_key].value_counts()
perc_correct = perc_correct.div(perc_correct.sum(axis=1), axis="rows") * 100

# Add a bar (= row) for the entire query dataset. The indices are converted to
# plain lists first so that the extra "Overall" label can be added.
perc_correct.index = perc_correct.index.tolist()
total_n_per_ct.index = total_n_per_ct.index.tolist()
total_n_per_ct["Overall"] = adata_query_label_transfer.shape[0]
perc_correct.loc["Overall", :] = (
    adata_query_label_transfer.obs[transfer_label_evaluation_key].value_counts()
    / total_n_per_ct["Overall"]
    * 100
)

# Whether to append the number of cells to every tick label
incl_celln_in_label = True

# Bar order: ground truth niches first, the overall bar last
plot_label_cats = label_cats + ["Overall"]
perc_correct = perc_correct.loc[plot_label_cats, :]
with plt.rc_context(
    {
        "figure.figsize": (0.4 * len(label_cats), 3),
        "axes.spines.right": False,
        "axes.spines.top": False,
    }
):
    fig, ax = plt.subplots()
    perc_correct.plot(kind="bar", stacked=True, ax=ax)

    # Reverse the legend order. The name `labels` is deliberately not reused
    # here: it holds the transferred labels DataFrame from the label transfer.
    legend_handles, legend_labels = ax.get_legend_handles_labels()
    ax.legend(legend_handles[::-1], legend_labels[::-1], loc=(1.01, 0.60), frameon=False)

    # One tick per bar, including the "Overall" bar in both label variants
    cts_no_underscore = [str(ct).replace("_", " ") for ct in label_cats]
    tick_names = cts_no_underscore + ["Overall"]
    if incl_celln_in_label:
        tick_labels = [
            f"{tick_name} ({total_n_per_ct[ct]})"
            for tick_name, ct in zip(tick_names, plot_label_cats)
        ]
        x_axis_label = "Original Label (Number of Cells)"
    else:
        tick_labels = tick_names
        x_axis_label = "Original Label"
    plt.xticks(ticks=range(len(plot_label_cats)), labels=tick_labels)
    plt.xlabel(x_axis_label)
    ax.set_ylabel("% of Cells")
    plt.grid(False)

In [None]:
# Sankey diagram of the query cells: matched ground truth niche (left)
# vs. transferred niche label (right), drawn with the project's `sankey` helper
fig, ax = plt.subplots(figsize=(10, 5))
sankey(
    x=adata_query_label_transfer.obs[ground_truth_key],
    y=adata_query_label_transfer.obs[transfer_label_key],
    title="Original label vs. predicted annotation",
    title_left="Original label",
    title_right="Predicted annotation",
    ax=ax,
    fontsize="5",  # "xx-small",
    # Left side ordered by first appearance of the ground truth labels
    left_order=adata_query_label_transfer.obs[ground_truth_key].unique().tolist(),
    colorside="left",
    alpha=0.5,
)
plt.show()

#### 3.3.5 Compute Top n Accuracies

In [None]:
# For every row of each proportion table, rank the column labels from the
# largest to the smallest value in that row (ascending argsort, then flipped
# left-to-right).
# NOTE(review): presumably rows are niches and columns are cell types —
# confirm against where the proportion tables are built.
reference_sorted_indices = np.fliplr(
    np.argsort(reference_proportions_df.values, axis=1)
)
reference_sorted_cell_types = reference_proportions_df.columns.values[
    reference_sorted_indices
]

transferred_query_sorted_indices = np.fliplr(
    np.argsort(transferred_query_proportions_df.values, axis=1)
)
transferred_query_sorted_cell_types = transferred_query_proportions_df.columns.values[
    transferred_query_sorted_indices
]

In [None]:
# Top-n accuracy between the reference ranking and the label-transferred query
# ranking, evaluated for n = 1 .. number of rows of
# `transferred_query_proportions_df`.
# NOTE(review): n is bounded by the row count of the table, not by its column
# count — confirm this is the intended range for `calculate_top_accuracy`.
top_n_transfer_accuracies = [
    calculate_top_accuracy(
        reference_sorted_cell_types,
        transferred_query_sorted_cell_types,
        n=n,
    )
    for n in range(1, transferred_query_proportions_df.shape[0] + 1)
]

In [None]:
# For every row of the matched-reference and query proportion tables, rank the
# column labels from the largest to the smallest value in that row (ascending
# argsort, then flipped left-to-right).
matched_reference_sorted_indices = np.fliplr(
    np.argsort(matched_reference_proportions_df.values, axis=1)
)
matched_reference_sorted_cell_types = matched_reference_proportions_df.columns.values[
    matched_reference_sorted_indices
]

query_sorted_indices = np.fliplr(np.argsort(query_proportions_df.values, axis=1))
query_sorted_cell_types = query_proportions_df.columns.values[query_sorted_indices]

In [None]:
# Top-n accuracy between the matched-reference ranking and the query ranking,
# evaluated for n = 1 .. number of rows of `matched_reference_proportions_df`.
# NOTE(review): n is bounded by the row count of the table, not by its column
# count — confirm this is the intended range for `calculate_top_accuracy`.
top_n_matching_accuracies = [
    calculate_top_accuracy(
        matched_reference_sorted_cell_types,
        query_sorted_cell_types,
        n=n,
    )
    for n in range(1, matched_reference_proportions_df.shape[0] + 1)
]

In [None]:
# Line plot of top-n accuracy against n for both approaches; the x values are
# 1 .. len(accuracies) for each curve.
accuracy_curves = (
    ("NicheCompass Integrated Niches", top_n_transfer_accuracies),
    ("NicheCompass Matched Niches", top_n_matching_accuracies),
)
for curve_label, accuracies in accuracy_curves:
    sns.lineplot(
        x=np.arange(1, len(accuracies) + 1),
        y=accuracies,
        marker="o",
        label=curve_label,
    )

plt.xlabel("n")
plt.ylabel("Top n Accuracy")
plt.ylim(0, 1)  # accuracies are shown on a fixed 0..1 axis
plt.legend()
plt.show()

#### 3.3.6 Compute Jensen-Shannon Divergence

In [None]:
# Per-row ranking of column labels from the largest to the smallest value for
# the matched-reference and query proportion tables (ascending argsort, then
# flipped left-to-right).
# NOTE(review): the Jensen-Shannon cells directly below read the proportion
# tables themselves, not these rankings — check whether this cell is still
# needed in this section.
matched_reference_sorted_indices = np.fliplr(
    np.argsort(matched_reference_proportions_df.values, axis=1)
)
matched_reference_sorted_cell_types = matched_reference_proportions_df.columns.values[
    matched_reference_sorted_indices
]

query_sorted_indices = np.fliplr(np.argsort(query_proportions_df.values, axis=1))
query_sorted_cell_types = query_proportions_df.columns.values[query_sorted_indices]

In [None]:
# Jensen-Shannon score between the reference and the label-transferred query
# proportion vectors, compared row by row (row i against row i).
# A value close to 0 indicates that the two distributions are very similar;
# larger values indicate more dissimilar distributions.
# Scale: `scipy.spatial.distance.jensenshannon` returns the Jensen-Shannon
# *distance* (the square root of the divergence) and uses the natural
# logarithm unless `base` is given, so the maximum here is
# sqrt(ln 2) ~= 0.83 rather than 1.
# NOTE(review): assumes both tables share the same row and column order —
# confirm against where the tables are built.
reference_rows = reference_proportions_df.values
transferred_query_rows = transferred_query_proportions_df.values

transfer_jsd_scores = [
    distance.jensenshannon(reference_rows[niche], transferred_query_rows[niche])
    for niche in range(len(reference_rows))
]

average_transfer_jsd = np.mean(transfer_jsd_scores)

In [None]:
# Jensen-Shannon score between the matched-reference and the query proportion
# vectors, compared row by row (row i against row i).
# A value close to 0 indicates that the two distributions are very similar;
# larger values indicate more dissimilar distributions.
# Scale: `scipy.spatial.distance.jensenshannon` returns the Jensen-Shannon
# *distance* (the square root of the divergence) and uses the natural
# logarithm unless `base` is given, so the maximum here is
# sqrt(ln 2) ~= 0.83 rather than 1.
# NOTE(review): assumes both tables share the same row and column order —
# confirm against where the tables are built.
matched_reference_rows = matched_reference_proportions_df.values
query_rows = query_proportions_df.values

matching_jsd_scores = [
    distance.jensenshannon(matched_reference_rows[niche], query_rows[niche])
    for niche in range(len(matched_reference_rows))
]

average_matching_jsd = np.mean(matching_jsd_scores)

In [None]:
# Side-by-side violin plots of the per-row Jensen-Shannon scores for niche
# transfer (left) and niche matching (right), on a shared y-axis so the two
# distributions are directly comparable.
fig, axes = plt.subplots(1, 2, figsize=(5, 5), sharey=True)
plt.suptitle("JSD Scores between Query and Reference Niches", fontsize=16)

jsd_panels = (
    ("Niche Transfer", transfer_jsd_scores),
    ("Niche Matching", matching_jsd_scores),
)
for panel_ax, (panel_title, panel_scores) in zip(axes, jsd_panels):
    sns.violinplot(data=[panel_scores], ax=panel_ax)
    panel_ax.set_title(panel_title, fontsize=12)

# Only the left panel carries the y-label because the y-axis is shared.
axes[0].set_ylabel("Niche Cell Type Proportion JSD Scores")

# Leave room at the top of the figure for the suptitle.
plt.subplots_adjust(top=0.85)

plt.show()