# Reference Query Mapping

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>).
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 18.05.2023
- **Date of Last Modification:** 05.07.2023

Before running this notebook, the following two steps must have been completed successfully: <br>
- A reference model has been trained on the reference dataset with ```<root>/scripts/train_nichecompass_reference_model.py```. <br>
- The query has been mapped onto the reference model with ```<root>/scripts/map_query_on_nichecompass_reference_model.py```.

## 1. Setup

### 1.1 Import Libraries

In [None]:
# Automatically reload edited Python modules (e.g. the helper utils imported
# below) before executing each cell
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Make modules in ../../utils importable; the `analysis_utils` and
# `reference_query_mapping_utils` imports below presumably live there
sys.path.append("../../utils")

In [None]:
import os
import warnings

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scarches as sca
import seaborn as sns
from scipy.spatial import distance
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment

from nichecompass.models import NicheCompass
from nichecompass.utils import create_new_color_dict
from analysis_utils import *
from reference_query_mapping_utils import Sankey

### 1.2 Define Parameters

In [None]:
### Dataset ###
dataset = "seqfish_mouse_organogenesis_imputed"

### Model ###
node_label_method = "one-hop-norm"
gp_names_key = "nichecompass_gp_names"
# obs column marking each cell as "reference" or "query"
mapping_entity_key = "mapping_entity"
sample_key = "sample"
condition_key = "batch"
cell_type_key = "celltype_mapped_refined"
latent_key = "nichecompass_latent"
spatial_key = "spatial"
# Used verbatim in the model folder path and in the stored AnnData file name, so
# it must match the label used when the query was mapped (do not "fix" the
# repeated "query" without renaming the saved artifacts as well)
reference_query_model_label = f"{node_label_method}_reference_query_query_mapping"
# Timestamp sub-folder of the saved model run to load
load_timestamp = "01072023_165203_1"

### Analysis ###
# Leiden resolution for clustering the joint (reference + query) latent space
latent_leiden_resolution = 0.2
latent_cluster_key = f"latent_leiden_{str(latent_leiden_resolution)}"
# Passed as `spot_size` to the physical-space plotting helper
spot_size = 0.03

# Leiden resolutions for clustering reference and query cells separately
latent_reference_leiden_resolution = 0.15
latent_reference_cluster_key = f"latent_leiden_reference_{latent_reference_leiden_resolution}"

latent_query_leiden_resolution = 0.15
latent_query_cluster_key = f"latent_leiden_query_{latent_query_leiden_resolution}"

# Label transfer: reference niches (`label_key`) are transferred to the query
# cells and evaluated against the query niches computed on the query cells alone
label_key = latent_reference_cluster_key
ground_truth_key = latent_query_cluster_key
transfer_label_key = f"transferred_{label_key}"
transfer_label_uncertainty_key = f"transferred_{label_key}_uncertainty"
transfer_label_evaluation_key = f"{transfer_label_key}_evaluation"

### 1.3 Run Notebook Setup

In [None]:
# Suppress all Python warnings for the rest of the notebook (this also hides
# deprecation warnings from pandas/scanpy/seaborn)
warnings.filterwarnings("ignore")

### 1.4 Configure Paths

In [None]:
# Define paths
# Model artifacts are keyed by dataset, model label and load timestamp; figures
# by dataset and load timestamp only.
artifacts_folder_path = "../../artifacts"
model_folder_path = f"{artifacts_folder_path}/{dataset}/models/{reference_query_model_label}/{load_timestamp}"
figure_folder_path = f"{artifacts_folder_path}/{dataset}/figures/{load_timestamp}"
# Create the figure folder up front so the plotting cells can save into it
os.makedirs(figure_folder_path, exist_ok=True)

## 2. Model

### 2.1 Load Model

In [None]:
# Load trained model from `model_folder_path`. With `adata=None` the AnnData is
# presumably read from `adata_file_name` inside that folder (see NicheCompass.load)
model = NicheCompass.load(dir_path=model_folder_path,
                          adata=None,
                          adata_file_name=f"{dataset}_{reference_query_model_label}.h5ad",
                          gp_names_key=gp_names_key)

## 3. Analysis

In [None]:
# Check number of active gene programs (vs. all GP names stored in adata.uns)
active_gps = model.get_active_gps()
print(f"Number of total gene programs: {len(model.adata.uns[gp_names_key])}.")
print(f"Number of active gene programs: {len(active_gps)}.")

In [None]:
# Summarize all gene programs and display the first 20 active ones
gp_summary_df = model.get_gp_summary()
gp_summary_df.loc[gp_summary_df["gp_active"].eq(True)].head(20)

### 3.1 Analyze Latent Manifold

In [None]:
# Names of all samples in the combined reference + query data
samples = model.adata.obs[sample_key].unique().tolist()

In [None]:
# Plot mapping entities (reference vs. query) in latent and physical space
save_fig = True
file_path = f"{figure_folder_path}/" \
            "mapping_entities_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Mapping Entities",
    cat_key=mapping_entity_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=None,
    # Marker size is scaled inversely with the total number of cells
    size=360000/len(model.adata),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Plot batches (`condition_key`) in latent and physical space
save_fig = True
file_path = f"{figure_folder_path}/" \
            "batches_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Batches",
    cat_key=condition_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=None,
    # Marker size is scaled inversely with the total number of cells
    size=360000/len(model.adata),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Assign a color to each cell type (used for the plot below)
cell_type_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=cell_type_key)

In [None]:
# Plot cell types in latent and physical space
save_fig = True
file_path = f"{figure_folder_path}/" \
            "cell_types_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Cell Types",
    cat_key=cell_type_key,
    groups=None, # alternative value used before: "ExE endoderm"
    sample_key=sample_key,
    samples=samples,
    cat_colors=cell_type_colors,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Compute latent Leiden clustering of all (reference + query) cells, using the
# neighbor graph stored under `latent_key` in the loaded AnnData.
# NOTE(review): `sc` (scanpy) is not imported explicitly in this notebook; it
# presumably comes from `from analysis_utils import *`.
sc.tl.leiden(adata=model.adata,
             resolution=latent_leiden_resolution,
             key_added=latent_cluster_key,
             neighbors_key=latent_key)

In [None]:
# Assign a color to each joint latent cluster
latent_cluster_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=latent_cluster_key)

In [None]:
# Latent clusters in latent and physical space
save_fig = True
file_path = f"{figure_folder_path}/res_{latent_leiden_resolution}_" \
            "latent_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Latent Clusters",
    cat_key=latent_cluster_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=latent_cluster_colors,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

### 3.2 Analyze Query Enrichments (vs Reference)

#### 3.2.1 Visualize Query-enriched Cell Types in Physical and Latent Space

In [None]:
# Split the combined data by mapping entity. These are AnnData views taken now;
# presumably obs columns added to `model.adata` afterwards are not reflected in them.
adata_reference = model.adata[model.adata.obs[mapping_entity_key] == "reference"]
adata_query = model.adata[model.adata.obs[mapping_entity_key] == "query"]

reference_samples = adata_reference.obs[sample_key].unique().tolist()

In [None]:
# A cell type counts as query-enriched if its share among query cells is more
# than `enriched_query_cell_type_prop_thresh` times its share among reference cells
enriched_query_cell_type_prop_thresh = 5.
query_enriched_cell_type_key = "query_enriched_cell_types"

# Get query-enriched cell types.
# If `cell_type_key` is categorical, cell types missing from the reference keep a
# proportion of 0 and get ratio inf (counted as enriched); otherwise they are NaN
# after index alignment and never pass the threshold.
cell_type_reference_proportions = adata_reference.obs[cell_type_key].value_counts().sort_index() / len(adata_reference)
cell_type_query_proportions = adata_query.obs[cell_type_key].value_counts().sort_index() / len(adata_query)
relative_cell_type_query_proportions = cell_type_query_proportions / cell_type_reference_proportions
relative_cell_type_query_proportions.sort_values(ascending=False, inplace=True)
display(relative_cell_type_query_proportions)

query_enriched_cell_types = relative_cell_type_query_proportions[relative_cell_type_query_proportions > enriched_query_cell_type_prop_thresh].index.to_list()
# Enriched cell types keep their name; all other cells share one catch-all label
model.adata.obs[query_enriched_cell_type_key] = "Cell types not enriched in query"
for cell_type in query_enriched_cell_types:
    model.adata.obs.loc[model.adata.obs[cell_type_key] == cell_type, query_enriched_cell_type_key] = cell_type

In [None]:
# Plot of query-enriched cell-type annotations in physical and latent space
save_fig = True
file_path = f"{figure_folder_path}/" \
            "query_enriched_cell_types_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Query Enriched Cell Types",
    cat_key=query_enriched_cell_type_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors="coolwarm",
    # Marker size is scaled inversely with the total number of cells
    size=360000/len(model.adata),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

#### 3.2.2 Visualize Query-enriched Niches (Latent Clusters) in Physical and Latent Space

In [None]:
# obs column for the query-enriched latent cluster labels, and the enrichment
# threshold (query share / reference share of a latent cluster)
query_enriched_latent_cluster_key = "query_enriched_latent_cluster"
enriched_query_latent_cluster_prop_thresh = 5.

In [None]:
# Get query-enriched latent Leiden clusters: clusters whose share among query
# cells exceeds `enriched_query_latent_cluster_prop_thresh` times their share
# among reference cells. Uses `latent_cluster_key` (the column added by the
# joint Leiden clustering) instead of rebuilding the key string.
# If the cluster column is categorical, clusters without reference cells keep a
# proportion of 0 and get ratio inf (counted as enriched).
latent_cluster_reference_proportions = adata_reference.obs[latent_cluster_key].value_counts().sort_index() / len(adata_reference)
latent_cluster_query_proportions = adata_query.obs[latent_cluster_key].value_counts().sort_index() / len(adata_query)
relative_latent_cluster_query_proportions = latent_cluster_query_proportions / latent_cluster_reference_proportions
relative_latent_cluster_query_proportions.sort_values(ascending=False, inplace=True)
display(relative_latent_cluster_query_proportions)

query_enriched_latent_clusters = relative_latent_cluster_query_proportions[
    relative_latent_cluster_query_proportions > enriched_query_latent_cluster_prop_thresh].index.to_list()
# Enriched clusters are labelled "Cluster <id>"; all other cells share one catch-all label
model.adata.obs[query_enriched_latent_cluster_key] = "Latent clusters not enriched in query"
for latent_cluster in query_enriched_latent_clusters:
    model.adata.obs.loc[model.adata.obs[latent_cluster_key] == latent_cluster,
                        query_enriched_latent_cluster_key] = f"Cluster {latent_cluster}"

In [None]:
# Plot of query-enriched latent clusters in physical and latent space
save_fig = True
file_path = f"{figure_folder_path}/" \
            "query_enriched_latent_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Query Enriched Latent Clusters",
    cat_key=query_enriched_latent_cluster_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors="coolwarm",
    # Marker size is scaled inversely with the total number of cells
    size=360000/len(model.adata),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

### 3.3 Transfer Niche Labels from Reference to Query

#### 3.3.1 Compute Reference Niches

In [None]:
# Compute latent neighbor graph just for reference
# (`adata_reference` is an AnnData view; scanpy presumably turns it into an
# actual copy when writing to it, so `model.adata` itself is not modified here)
sc.pp.neighbors(adata_reference,
                use_rep=latent_key,
                key_added=latent_key)

# Compute latent Leiden clustering just for reference
sc.tl.leiden(adata=adata_reference,
             resolution=latent_reference_leiden_resolution,
             key_added=latent_reference_cluster_key,
             neighbors_key=latent_key)

# Write the reference niches back (aligned on obs names); query cells get NaN
model.adata.obs[latent_reference_cluster_key] = adata_reference.obs[latent_reference_cluster_key]

In [None]:
# Colors for the reference niches; the query niche plot reuses them once query
# niches are relabeled with their matched reference niche
latent_reference_cluster_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=latent_reference_cluster_key)

In [None]:
# Latent reference clusters in latent and physical space.
# The file name carries the resolution actually used for the reference
# clustering (`latent_reference_leiden_resolution`), not the joint one.
save_fig = True
file_path = f"{figure_folder_path}/res_{latent_reference_leiden_resolution}_" \
            "latent_reference_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Latent Reference Clusters",
    cat_key=latent_reference_cluster_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=latent_reference_cluster_colors,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
"""
# Compute spatial neighbor graph just for reference
sc.pp.neighbors(adata_reference,
                use_rep=spatial_key,
                key_added=spatial_key)

# Compute spatial Leiden clustering just for reference
spatial_reference_leiden_resolution = 0.15
spatial_reference_cluster_key = f"spatial_leiden_reference_{spatial_reference_leiden_resolution}"
 
sc.tl.leiden(adata=adata_reference,
             resolution=spatial_reference_leiden_resolution,
             key_added=spatial_reference_cluster_key,
             neighbors_key=spatial_key)

model.adata.obs[spatial_reference_cluster_key] = adata_reference.obs[spatial_reference_cluster_key]
"""

In [None]:
"""
spatial_reference_cluster_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=spatial_reference_cluster_key)
"""

In [None]:
"""
# Spatial reference clusters in latent and physical space
save_fig = False
file_path = f"{figure_folder_path}/res_{latent_leiden_resolution}_" \
            "spatial_reference_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Latent Reference Clusters",
    cat_key=spatial_reference_cluster_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=spatial_reference_cluster_colors,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)
"""

#### 3.3.2 Compute Query Niches

In [None]:
# Compute latent neighbor graph just for query
sc.pp.neighbors(adata_query,
                use_rep=latent_key,
                key_added=latent_key)

# Compute latent Leiden clustering just for query
sc.tl.leiden(adata=adata_query,
             resolution=latent_query_leiden_resolution,
             key_added=latent_query_cluster_key,
             neighbors_key=latent_key)

# Write the query niches back (aligned on obs names); reference cells get NaN
model.adata.obs[latent_query_cluster_key] = adata_query.obs[latent_query_cluster_key]

# Map query niches to reference niches using cell type proportions: each query
# niche is matched to the reference niche whose cell type proportion vector is
# closest in Euclidean distance.

# Cell type counts per niche (rows: niches, columns: cell types)
reference_counts_df = model.adata[model.adata.obs[mapping_entity_key] == "reference"].obs.groupby(
    [latent_reference_cluster_key, cell_type_key]).size().unstack()
query_counts_df = model.adata[model.adata.obs[mapping_entity_key] == "query"].obs.groupby(
    [latent_query_cluster_key, cell_type_key]).size().unstack()

# Put both tables on the same cell type columns (missing combinations -> 0) so
# that the proportion vectors are comparable position by position, and drop
# niches without any cells (their proportions would be NaN).
cell_type_columns = reference_counts_df.columns.union(query_counts_df.columns)
reference_counts_df = reference_counts_df.reindex(columns=cell_type_columns).fillna(0)
query_counts_df = query_counts_df.reindex(columns=cell_type_columns).fillna(0)
reference_counts_df = reference_counts_df[reference_counts_df.sum(axis=1) > 0]
query_counts_df = query_counts_df[query_counts_df.sum(axis=1) > 0]

# Cell type proportions per niche
reference_proportions_df = reference_counts_df.div(reference_counts_df.sum(axis=1), axis=0)
query_proportions_df = query_counts_df.div(query_counts_df.sum(axis=1), axis=0)

# Euclidean distance between each pair of niches
# (rows: query niches, columns: reference niches)
distances = cdist(query_proportions_df.values,
                  reference_proportions_df.values,
                  metric="euclidean")

# Alternative: one-to-one matching with the Hungarian algorithm
# (`linear_sum_assignment(distances)`), which minimizes the total distance but
# needs at least as many reference as query niches.

# For each query niche, find the reference niche with the minimum Euclidean
# distance. Use the actual niche labels of the table rows/columns rather than
# assuming that position i corresponds to niche label str(i).
query_latent_cluster_indices = query_proportions_df.index.astype(str).to_numpy()
reference_latent_cluster_indices = reference_proportions_df.index.astype(str).to_numpy()[
    np.argmin(distances, axis=1)]

mapping_dict = dict(zip(query_latent_cluster_indices, reference_latent_cluster_indices))

# Relabel the query proportion rows with their matched reference niche
query_proportions_df.index = reference_latent_cluster_indices
query_proportions_df.sort_index(inplace=True)
query_proportions_df.index.name = "Niche"

# Replace the query niche labels in the combined AnnData by the matched reference niche
model.adata.obs[latent_query_cluster_key] = model.adata.obs[latent_query_cluster_key].map(mapping_dict)

In [None]:
# Latent query clusters in latent and physical space. Query niches now carry
# their matched reference niche labels, so the reference colors are reused.
# The file name carries the resolution actually used for the query clustering.
save_fig = True
file_path = f"{figure_folder_path}/res_{latent_query_leiden_resolution}_" \
            "latent_query_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Latent Query Clusters",
    cat_key=latent_query_cluster_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=latent_reference_cluster_colors,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
"""
# Compute spatial neighbor graph just for query
sc.pp.neighbors(adata_query,
                use_rep=spatial_key,
                key_added=spatial_key)

# Compute spatial Leiden clustering just for query
spatial_query_leiden_resolution = 0.11
spatial_query_cluster_key = f"latent_leiden_query_{spatial_query_leiden_resolution}"
 
sc.tl.leiden(adata=adata_query,
             resolution=spatial_query_leiden_resolution,
             key_added=spatial_query_cluster_key,
             neighbors_key=spatial_key)

model.adata.obs[spatial_query_cluster_key] = adata_query.obs[spatial_query_cluster_key]
"""

In [None]:
"""
spatial_query_cluster_colors = create_new_color_dict(
    adata=model.adata,
    cat_key=spatial_query_cluster_key)
"""

In [None]:
"""
# Spatial query clusters in latent and physical space
save_fig = False
file_path = f"{figure_folder_path}/res_{latent_leiden_resolution}_" \
            "spatial_query_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Latent Query Clusters",
    cat_key=spatial_query_cluster_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors=spatial_query_cluster_colors,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)
"""

#### 3.3.3 Transfer Niche Labels

In [None]:
# Prepare label transfer via scarches: fit a weighted kNN model on the reference
# cells' latent representation (`latent_key`) with 15 neighbors
knn_transformer = sca.utils.knn.weighted_knn_trainer(
    train_adata=adata_reference,
    train_adata_emb=latent_key,
    n_neighbors=15)

In [None]:
# Compute label transfer via scarches: transfer the reference niche labels
# (`label_key`) to the query cells. Returns the predicted labels and their
# uncertainties as DataFrames (presumably indexed by query obs names).
labels, uncert = sca.utils.knn.weighted_knn_transfer(
    query_adata=adata_query,
    query_adata_emb=latent_key,
    label_keys=label_key,
    knn_model=knn_transformer,
    ref_adata_obs=adata_reference.obs)

In [None]:
# Rename the scArches output columns to this notebook's transfer keys
labels.rename(columns={label_key: transfer_label_key}, inplace=True)
uncert.rename(columns={label_key: transfer_label_uncertainty_key}, inplace=True)

# Join results of label transfer to adata (aligned on the index; cells without a
# transferred label, e.g. reference cells, get NaN). Columns from a previous run
# are dropped first: `DataFrame.join` raises on overlapping column names, so
# without this the cell could not be re-executed.
model.adata.obs = (model.adata.obs
                   .drop(columns=[transfer_label_key, transfer_label_uncertainty_key], errors="ignore")
                   .join(labels)
                   .join(uncert))

In [None]:
# Add evaluations: compare each cell's transferred niche label with its
# ground-truth (query) niche label:
#   "Correct"      - transferred label equals the ground-truth label
#   NaN            - no transferred label (e.g. reference cells)
#   "Not in Query" - transferred label never occurs as a ground-truth label
#   "Incorrect"    - otherwise
# The set of ground-truth labels is built once instead of once per row.
ground_truth_labels = set(model.adata.obs[ground_truth_key].unique().tolist())


def _evaluate_label_transfer(row):
    """Return the evaluation category of one obs row (see comment above)."""
    transferred_label = row[transfer_label_key]
    if transferred_label == row[ground_truth_key]:
        return "Correct"
    if pd.isnull(transferred_label):
        return transferred_label
    if transferred_label not in ground_truth_labels:
        return "Not in Query"
    return "Incorrect"


model.adata.obs[transfer_label_evaluation_key] = model.adata.obs.apply(
    _evaluate_label_transfer, axis=1)

In [None]:
# Plot the distribution of label transfer uncertainties (NaN values, presumably
# the reference cells without a transferred label, are ignored).
# `sns.distplot` is deprecated; `histplot(..., kde=True, stat="density")` is its
# replacement for a normalized histogram with a density curve.
sns.histplot(model.adata.obs[transfer_label_uncertainty_key], kde=True, stat="density")

In [None]:
# Set evaluation to 'Unknown' for cells whose transfer uncertainty exceeds the
# threshold (NaN uncertainties compare False and keep their evaluation)
uncertainty_threshold = 0.001

model.adata.obs[transfer_label_evaluation_key] = model.adata.obs[transfer_label_evaluation_key].mask(
    model.adata.obs[transfer_label_uncertainty_key] > uncertainty_threshold,
    "Unknown")

In [None]:
# Query cells only, taken after the transfer and evaluation columns were added
adata_query_label_transfer = model.adata[model.adata.obs[mapping_entity_key] == "query"]

In [None]:
print(f"Percentage of 'Unknown', with uncertainty_threshold = {uncertainty_threshold}:")
print(f"{np.round(sum(adata_query_label_transfer.obs[transfer_label_evaluation_key] =='Unknown')/adata_query_label_transfer.n_obs*100,2)}%")

In [None]:
# Transferred labels present in the query, in order of first appearance
label_cats = adata_query_label_transfer.obs[transfer_label_key].unique().tolist()

In [None]:
# Percentage of query cells per evaluation outcome, for each transferred label
perc_correct = pd.crosstab(
    adata_query_label_transfer.obs[transfer_label_key],
    adata_query_label_transfer.obs[transfer_label_evaluation_key],
).loc[label_cats, :]
total_n_per_ct = adata_query_label_transfer.obs[transfer_label_key].value_counts()
perc_correct = perc_correct.div(perc_correct.sum(axis=1), axis="rows") * 100
# add a bar (=row) for the entire query dataset:
perc_correct.index = perc_correct.index.tolist()
total_n_per_ct.index = total_n_per_ct.index.tolist()
total_n_per_ct["Overall"] = adata_query_label_transfer.shape[0]
perc_correct.loc["Overall", :] = (
    adata_query_label_transfer.obs[transfer_label_evaluation_key].value_counts()
    / total_n_per_ct["Overall"]
    * 100
)

incl_celln_in_label = True
# Bar order: transferred labels in order of first appearance (`label_cats`),
# followed by the "Overall" bar
plot_label_cats = label_cats + ["Overall"]
perc_correct = perc_correct.loc[plot_label_cats, :]
# One tick name per bar (including "Overall"), with underscores shown as spaces
cts_no_underscore = [ct.replace("_", " ") for ct in plot_label_cats]
with plt.rc_context(
    {
        "figure.figsize": (0.4 * len(label_cats), 3),
        "axes.spines.right": False,
        "axes.spines.top": False,
    }
):
    fig, ax = plt.subplots()
    perc_correct.plot(kind="bar", stacked=True, ax=ax)
    # Reverse the legend so it matches the stacking order of the bars. Dedicated
    # names avoid overwriting the global `labels` from the label transfer.
    legend_handles, legend_labels = ax.get_legend_handles_labels()
    ax.legend(legend_handles[::-1], legend_labels[::-1], loc=(1.01, 0.60), frameon=False)

    # Bars are grouped by the transferred label (`transfer_label_key`)
    if incl_celln_in_label:
        tick_labels = [
            f"{ct_no_und} ({total_n_per_ct[ct]})"
            for ct_no_und, ct in zip(cts_no_underscore, plot_label_cats)
        ]
        plt.xlabel("Transferred label (n cells)")
    else:
        tick_labels = cts_no_underscore
        plt.xlabel("Transferred label")
    # Both branches provide exactly one tick label per bar
    plt.xticks(ticks=range(len(plot_label_cats)), labels=tick_labels)
    ax.set_ylabel("% of cells")
    plt.grid(False)

In [None]:
# Inspect the obs table including all columns added above
model.adata.obs

In [None]:
# Sankey diagram: query niche labels (left) vs. transferred reference niche labels (right)
# NOTE(review): `sankey` (lowercase) is not imported explicitly; the setup cell
# imports `Sankey` from reference_query_mapping_utils. Verify the function name
# (it may also come from `from analysis_utils import *`).
# NOTE(review): `left_order` is the list of transferred labels while the left side
# shows the ground-truth labels; confirm both label sets coincide.
fig, ax = plt.subplots(figsize=(10, 5))
sankey(
    x=adata_query_label_transfer.obs[ground_truth_key],
    y=adata_query_label_transfer.obs[transfer_label_key],
    title="Original label vs. predicted annotation",
    title_left="Original label",
    title_right="Predicted annotation",
    ax=ax,
    fontsize="5",  # "xx-small",
    left_order=label_cats,
    colorside="left",
    alpha=0.5,
)
plt.show()

In [None]:
# Prepare reference data for cell type composition comparison plot by unrolling
# Use non-transferred labels for reference
reference_counts_df = model.adata[model.adata.obs[mapping_entity_key] == "reference"].obs.groupby(
    [label_key, cell_type_key]).size().unstack()

# Keep the niche label as a regular column so that melt can use it as id variable
reference_counts_df[transfer_label_key] = reference_counts_df.index

# Long format: one row per (niche, cell type) with its cell count
unrolled_reference_counts_df = pd.melt(
    reference_counts_df,
    id_vars=transfer_label_key,
    var_name=cell_type_key,
    value_name="counts")

unrolled_reference_counts_df[mapping_entity_key] = "reference"

In [None]:
# Prepare query data for cell type composition comparison plot by unrolling
# Use transferred labels for query
query_counts_df = adata_query_label_transfer.obs.groupby(
    [transfer_label_key, cell_type_key]).size().unstack()

# Keep the niche label as a regular column so that melt can use it as id variable
query_counts_df[transfer_label_key] = query_counts_df.index

# Long format: one row per (niche, cell type) with its cell count
unrolled_query_counts_df = pd.melt(
    query_counts_df,
    id_vars=transfer_label_key,
    var_name=cell_type_key,
    value_name="counts")

unrolled_query_counts_df[mapping_entity_key] = "query"

In [None]:
# Combine the unrolled dfs
unrolled_combined_counts_df = pd.concat([unrolled_reference_counts_df,
                                         unrolled_query_counts_df]).reset_index(drop=True)

In [None]:
unrolled_combined_counts_df

In [None]:
# Adapted from https://stackoverflow.com/questions/22787209/how-to-have-clusters-of-stacked-bars


def plot_clustered_stacked(dfall, labels=None, title="multiple stacked bar plot",  H="//", **kwargs):
    """Plot several dataframes as clustered, stacked bar charts on one axis.

    Each dataframe is drawn as a stacked bar chart. Bars with the same index
    value are placed next to each other and told apart by their hatching.

    Parameters
    ----------
    dfall : list of pandas.DataFrame
        Dataframes with identical columns and index. Columns become the stacked
        segments; index values become the bar clusters and tick labels.
    labels : list of str, optional
        Names of the dataframes, shown in a second legend explaining the hatches.
    title : str
        Plot title.
    H : str
        Hatch pattern; dataframe ``k`` is drawn with ``H * k`` (the first one
        without hatching).
    **kwargs
        Forwarded to ``DataFrame.plot``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes containing the plot.
    """
    n_df = len(dfall)
    n_col = len(dfall[0].columns)
    n_ind = len(dfall[0].index)
    bar_width = 1 / float(n_df + 1)
    plt.figure(figsize=(10, 5))
    axe = plt.subplot(111)

    for df in dfall:  # one stacked bar plot per dataframe
        axe = df.plot(kind="bar",
                      linewidth=0,
                      stacked=True,
                      ax=axe,
                      legend=False,
                      grid=False,
                      **kwargs)

    # Each df.plot call adds n_col bar containers. Shift, narrow and hatch the
    # bars of the k-th dataframe so the dataframes sit side by side per cluster.
    handles, handle_labels = axe.get_legend_handles_labels()
    for df_idx in range(n_df):
        for container in handles[df_idx * n_col:(df_idx + 1) * n_col]:
            for rect in container.patches:
                rect.set_x(rect.get_x() + bar_width * df_idx)
                rect.set_hatch(H * df_idx)
                rect.set_width(bar_width)

    axe.set_xticks((np.arange(0, 2 * n_ind, 2) + bar_width) / 2.)
    # Tick labels from the shared index (rather than the leaked loop variable)
    axe.set_xticklabels(dfall[0].index, rotation=0)
    axe.set_title(title)

    # Invisible bars that serve as legend entries for each dataframe's hatch
    hatch_handles = [axe.bar(0, 0, color="gray", hatch=H * i) for i in range(n_df)]

    column_legend = axe.legend(handles[:n_col], handle_labels[:n_col], loc="best",
                               bbox_to_anchor=(0.8, 0.55, 0.5, 0.5), fontsize=8)
    if labels is not None:
        # A second legend call replaces the first one on the axes, so the first
        # one is re-added as a plain artist afterwards.
        axe.legend(hatch_handles, labels, loc="best",
                   bbox_to_anchor=(0.8, -0.5, 0.5, 0.5), fontsize=8)
        axe.add_artist(column_legend)
    return axe

In [None]:
# Cell type proportions per niche: reference cells by their own reference niche,
# query cells by their transferred niche label
reference_counts_df = model.adata[model.adata.obs[mapping_entity_key] == "reference"].obs.groupby(
    [label_key, cell_type_key]).size().unstack()

query_counts_df = adata_query_label_transfer.obs.groupby(
    [transfer_label_key, cell_type_key]).size().unstack()

reference_proportions_df = reference_counts_df.div(reference_counts_df.sum(axis=1), axis=0)
query_proportions_df = query_counts_df.div(query_counts_df.sum(axis=1), axis=0)

# Integer niche labels so that niches sort numerically (assumes the Leiden labels
# are integer strings)
reference_proportions_df.index = reference_proportions_df.index.astype(int)
reference_proportions_df.sort_index(inplace=True)
reference_proportions_df.index.name = "Niche"

query_proportions_df.index = query_proportions_df.index.astype(int)
query_proportions_df.sort_index(inplace=True)
query_proportions_df.index.name = "Niche"

In [None]:
# Side-by-side stacked bars of cell type proportions per niche
# NOTE(review): plot_clustered_stacked assumes both frames share index and
# columns; a reference niche without transferred query cells would break that.
plot_clustered_stacked([reference_proportions_df, query_proportions_df],["reference", "query"], title="Cell Type Proportions in Reference Niches and Transferred Query Niches")

In [None]:
# Jensen-Shannon distance between the cell type composition of each reference
# niche and of the query cells transferred to that niche. With base=2 the
# distance lies in [0, 1]:
#   close to 0 -> the two distributions are very similar
#   close to 1 -> the two distributions are very dissimilar
# (with scipy's default natural-log base the maximum would be sqrt(ln 2) ~ 0.83).
# Niches are paired by label instead of by row position, so a reference niche
# that received no query cells is skipped rather than shifting the pairing.
# Missing cell type entries are treated as proportion 0.
shared_niches = reference_proportions_df.index.intersection(query_proportions_df.index)
shared_cell_types = reference_proportions_df.columns.union(query_proportions_df.columns)
jsd_reference_df = reference_proportions_df.reindex(
    index=shared_niches, columns=shared_cell_types).fillna(0)
jsd_query_df = query_proportions_df.reindex(
    index=shared_niches, columns=shared_cell_types).fillna(0)

jsd_scores = [
    distance.jensenshannon(jsd_reference_df.loc[niche].values,
                           jsd_query_df.loc[niche].values,
                           base=2)
    for niche in shared_niches
]

average_jsd = np.mean(jsd_scores)

In [None]:
# NOTE(review): `df` is not defined at notebook level (it only exists as a loop
# variable inside `plot_clustered_stacked`), so this raises NameError. Presumably
# meant to list the cell type columns of one of the proportion tables.
df.columns.tolist()

In [None]:
# NOTE(review): `spatial_reference_cluster_key` / `spatial_query_cluster_key` are
# only defined inside the disabled (string-literal) spatial clustering cells, so
# these cells fail with NameError unless those cells are re-enabled.
# Cell type counts per spatial reference niche
spatial_reference_counts_df = model.adata[model.adata.obs[mapping_entity_key] == "reference"].obs.groupby(
    [spatial_reference_cluster_key, cell_type_key]).size().unstack()

In [None]:
# Cell type counts per spatial query niche
spatial_query_counts_df = model.adata[model.adata.obs[mapping_entity_key] == "query"].obs.groupby(
    [spatial_query_cluster_key, cell_type_key]).size().unstack()

In [None]:
spatial_query_counts_df

In [None]:
# Cell type counts and proportions per spatial niche for reference and query.
# NOTE(review): `spatial_reference_cluster_key` / `spatial_query_cluster_key` are
# only defined in the disabled (string-literal) spatial clustering cells.
spatial_reference_counts_df = model.adata[model.adata.obs[mapping_entity_key] == "reference"].obs.groupby(
    [spatial_reference_cluster_key, cell_type_key]).size().unstack()

spatial_query_counts_df = model.adata[model.adata.obs[mapping_entity_key] == "query"].obs.groupby(
    [spatial_query_cluster_key, cell_type_key]).size().unstack()

# Use the same cell type columns for reference and query (missing -> 0) so that
# the proportion vectors can be compared position by position.
spatial_cell_type_columns = spatial_reference_counts_df.columns.union(spatial_query_counts_df.columns)
spatial_reference_counts_df = spatial_reference_counts_df.reindex(columns=spatial_cell_type_columns).fillna(0)
spatial_query_counts_df = spatial_query_counts_df.reindex(columns=spatial_cell_type_columns).fillna(0)

spatial_reference_proportions_df = spatial_reference_counts_df.div(spatial_reference_counts_df.sum(axis=1), axis=0)
spatial_query_proportions_df = spatial_query_counts_df.div(spatial_query_counts_df.sum(axis=1), axis=0)

# Integer niche labels, sorted numerically
spatial_reference_proportions_df.index = spatial_reference_proportions_df.index.astype(int)
spatial_reference_proportions_df.sort_index(inplace=True)
spatial_reference_proportions_df.index.name = "Niche"

# Sort the query niches as well, so that the row order used for the distance
# matrix in the next cell is the same order the matching cells assume later.
spatial_query_proportions_df.index = spatial_query_proportions_df.index.astype(int)
spatial_query_proportions_df.sort_index(inplace=True)
spatial_query_proportions_df.index.name = "Niche"

In [None]:
# Calculate the Euclidean distance between each pair of niches in the reference and query
# (rows: spatial query niches, columns: spatial reference niches, each in the
# current row order of its proportion table)
distances = cdist(spatial_query_proportions_df.values,
                  spatial_reference_proportions_df.values,
                  metric="euclidean")

In [None]:
# Greedy one-to-one matching: walk through the query niches in row order and
# give each one the closest reference niche (column position in `distances`)
# that has not been taken yet.
selected_indices = []

for i in range(distances.shape[0]):
    # Reference niches not yet assigned to an earlier query niche
    available_indices = np.setdiff1d(np.arange(distances.shape[1]), selected_indices)
    if available_indices.size == 0:
        raise ValueError(
            "More query niches than reference niches; greedy one-to-one matching is not possible.")

    # Closest still-available reference niche, mapped back to its original column
    selected_indices.append(available_indices[np.argmin(distances[i, available_indices])])

# Print the selected indices
print(selected_indices)

In [None]:
# Optimal one-to-one matching of query niches to reference niches (Hungarian
# algorithm, minimizing the total distance). `linear_sum_assignment` is already
# imported in the setup cell.
spatial_query_rows, spatial_reference_cols = linear_sum_assignment(distances)
if len(spatial_query_rows) < distances.shape[0]:
    raise ValueError(
        "More spatial query niches than reference niches; not every query niche can be matched.")

# `distances` rows follow the current row order of `spatial_query_proportions_df`
# and its columns that of `spatial_reference_proportions_df`, so translate the
# positions into niche labels before the query frame is re-sorted below.
spatial_niche_mapping = dict(zip(spatial_query_proportions_df.index[spatial_query_rows],
                                 spatial_reference_proportions_df.index[spatial_reference_cols]))

spatial_query_proportions_df.sort_index(inplace=True)
spatial_query_proportions_df.index.name = "Niche"

# Matched reference niche for each query niche, in the (sorted) query row order
selected_indices = [spatial_niche_mapping[niche] for niche in spatial_query_proportions_df.index]
print(selected_indices)

In [None]:
# Assign indices of most similar reference clusters to query
# (requires `selected_indices` to hold one entry per query row, in row order)
spatial_query_proportions_df.index = selected_indices

In [None]:
# Rank cell types within each niche by decreasing proportion. Reference and query
# rows are first aligned on their shared niche labels (and a common set of cell
# type columns, missing -> 0), so that row i of both rankings refers to the same
# niche when they are compared position by position later on. The query rows were
# relabeled with matched reference niches and are not necessarily sorted.
# Assumes niche labels are unique within each proportion table.
spatial_shared_niches = spatial_reference_proportions_df.index.intersection(spatial_query_proportions_df.index)
spatial_shared_cell_types = spatial_reference_proportions_df.columns.union(spatial_query_proportions_df.columns)
spatial_reference_aligned_df = spatial_reference_proportions_df.reindex(
    index=spatial_shared_niches, columns=spatial_shared_cell_types).fillna(0)
spatial_query_aligned_df = spatial_query_proportions_df.reindex(
    index=spatial_shared_niches, columns=spatial_shared_cell_types).fillna(0)

spatial_reference_sorted_indices = spatial_reference_aligned_df.values.argsort(axis=1)[:, ::-1]
spatial_reference_sorted_cell_types = np.asarray(spatial_reference_aligned_df.columns)[spatial_reference_sorted_indices]

spatial_query_sorted_indices = spatial_query_aligned_df.values.argsort(axis=1)[:, ::-1]
spatial_query_sorted_cell_types = np.asarray(spatial_query_aligned_df.columns)[spatial_query_sorted_indices]

In [None]:
# Top-n accuracy of the spatial niche rankings for n = 1 .. number of query niches
# NOTE(review): `calculate_top_accuracy` is defined in a later cell, and the range
# is taken from the latent `query_proportions_df` rather than the spatial tables.
spatial_top_n_accuracies = []
for n in range(1, len(query_proportions_df.values) + 1):
    spatial_top_n_accuracies.append(calculate_top_accuracy(spatial_reference_sorted_cell_types, spatial_query_sorted_cell_types, n=n))

In [None]:
spatial_top_n_accuracies

In [None]:
# Closest reference niche (column position) for each query niche
min_distance_indices = np.argmin(distances, axis=1)

In [None]:
min_distance_indices

In [None]:
# NOTE(review): this treats the per-row argmin column positions as flat indices
# into `distances`, which does not yield the (row, column) pairs of the minima;
# `np.argmin(distances, axis=1)` above already holds the column for each row.
min_indices = np.unravel_index(np.argmin(distances, axis=1), distances.shape)

In [None]:
min_indices

In [None]:
# Closest query niche (row position) for each reference niche
np.argmin(distances, axis=0)

In [None]:
# Rank cell types within each niche by decreasing proportion. Reference and query
# rows are first aligned on their shared niche labels (and a common set of cell
# type columns, missing -> 0), so that row i of both rankings refers to the same
# niche when `calculate_top_accuracy` compares them position by position. A
# reference niche without transferred query cells would otherwise shift rows.
# Assumes niche labels are unique within each proportion table.
ranking_shared_niches = reference_proportions_df.index.intersection(query_proportions_df.index)
ranking_shared_cell_types = reference_proportions_df.columns.union(query_proportions_df.columns)
reference_ranking_df = reference_proportions_df.reindex(
    index=ranking_shared_niches, columns=ranking_shared_cell_types).fillna(0)
query_ranking_df = query_proportions_df.reindex(
    index=ranking_shared_niches, columns=ranking_shared_cell_types).fillna(0)

reference_sorted_indices = reference_ranking_df.values.argsort(axis=1)[:, ::-1]
reference_sorted_cell_types = np.asarray(reference_ranking_df.columns)[reference_sorted_indices]

query_sorted_indices = query_ranking_df.values.argsort(axis=1)[:, ::-1]
query_sorted_cell_types = np.asarray(query_ranking_df.columns)[query_sorted_indices]

In [None]:
# Inspect the spatial query niche proportions
spatial_query_proportions_df

In [None]:
def calculate_top_accuracy(list1, list2, n):
    """Compute the top-n overlap accuracy between two sets of ranked categories.

    For every group i, the first ``n`` categories of ``list1[i]`` and
    ``list2[i]`` are compared and the number of shared categories is counted.
    The accuracy is the total number of shared categories divided by the total
    number of categories that could have been shared.

    Parameters
    ----------
    list1, list2 : sequence of sequences
        Per-group category rankings (best first), e.g. lists of lists or the
        rows of a 2D NumPy array. Both must contain the same number of groups.
    n : int
        Number of top-ranked categories to compare per group (must be >= 1).

    Returns
    -------
    float
        Accuracy in [0, 1].

    Raises
    ------
    ValueError
        If ``n < 1``, if the inputs have a different number of groups, or if
        there is nothing to compare (no groups or only empty rankings).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}.")
    if len(list1) != len(list2):
        raise ValueError(
            "list1 and list2 must have the same number of groups, got "
            f"{len(list1)} and {len(list2)}.")

    top_n_matches = 0
    top_n_total = 0
    for group1, group2 in zip(list1, list2):
        group1_categories = list(group1[:n])
        group2_categories = list(group2[:n])
        top_n_matches += len(set(group1_categories) & set(group2_categories))
        # A ranking shorter than n can share at most its own length; dividing
        # by n in that case would push identical rankings below accuracy 1.
        top_n_total += max(len(group1_categories), len(group2_categories))

    if top_n_total == 0:
        raise ValueError("No categories to compare.")
    return top_n_matches / top_n_total

# Example lists of top categories per group
list1 = [['A', 'B', 'C'], ['X', 'Y', 'Z'], ['P', 'Q', 'R']]
list2 = [['A', 'B', 'C'], ['X', 'Z', 'Y'], ['R', 'P', 'Q']]

# Calculate top 1 accuracy
top1_accuracy = calculate_top_accuracy(list1, list2, n=1)
print("Top 1 Accuracy:", top1_accuracy)

# Calculate top 2 accuracy
top2_accuracy = calculate_top_accuracy(list1, list2, n=2)
print("Top 2 Accuracy:", top2_accuracy)

# Calculate top 3 accuracy
top3_accuracy = calculate_top_accuracy(list1, list2, n=3)
print("Top 3 Accuracy:", top3_accuracy)

In [None]:
# Top-n overlap between reference and query NicheCompass niche compositions for
# every n from 1 to the number of ranked cell types (columns of the rankings).
# The bound is the ranking length, not the number of niches: top-n lists cannot
# grow beyond the number of cell types.
n_ranked_cell_types = min(reference_sorted_cell_types.shape[1],
                          query_sorted_cell_types.shape[1])
top_accuracies = [
    calculate_top_accuracy(reference_sorted_cell_types,
                           query_sorted_cell_types,
                           n=n)
    for n in range(1, n_ranked_cell_types + 1)]

In [None]:
# Plot top-n accuracy against n. Explicit x values make the axis start at
# n = 1; a bare list would be drawn against its 0-based positions.
sns.lineplot(x=np.arange(1, len(top_accuracies) + 1),
             y=top_accuracies,
             marker="o",
             label="NicheCompass Niches")
sns.lineplot(x=np.arange(1, len(spatial_top_n_accuracies) + 1),
             y=spatial_top_n_accuracies,
             marker="o",
             label="Best Matching Spatial Niches")
plt.xlabel("n")
plt.ylabel("Top n Accuracy")
plt.ylim(0, 1)
plt.legend()
plt.show()

In [None]:
# Top-2 accuracy of the NicheCompass niches (overwrites the `top2_accuracy`
# computed on the toy example lists)
top2_accuracy = calculate_top_accuracy(reference_sorted_cell_types, query_sorted_cell_types, n=2)

In [None]:
# Display the per-niche reference cell-type rankings
reference_sorted_cell_types

In [None]:
# Print the ordered list of categories for each group
# NOTE(review): `sorted_cell_types` is not defined in this part of the notebook;
# presumably `query_sorted_cell_types` (derived from `query_proportions_df`)
# was intended — confirm before relying on this output.
for group, categories in zip(query_proportions_df.index, sorted_cell_types):
    print(f"{group}: {', '.join(categories)}")

In [None]:
# Display `sorted_cell_types` (NOTE(review): not defined in this part of the
# notebook; possibly a stale name for `query_sorted_cell_types`)
sorted_cell_types

In [None]:
# Display the query niche composition table (the top-n accuracies themselves
# are computed in earlier cells)
query_proportions_df

In [None]:
# Display `average_jsd` (presumably an average Jensen-Shannon divergence
# computed earlier in the notebook — TODO confirm)
average_jsd

In [None]:
save_fig = False
file_path = (f"{figure_folder_path}/"
             f"res_{latent_leiden_resolution}_"
             f"niche_composition.svg")

# Stacked bar chart: one bar per niche, one segment per cell type count
ax = df_counts.plot(kind="bar", stacked=True, figsize=(10,10))
legend_font = {'size': 10}
legend = ax.legend(bbox_to_anchor=(1, 1), loc="upper left", prop=legend_font)
legend.set_title("Cell Type Annotations", prop=legend_font)
ax.set_title("Cell Type Composition of Niches")
ax.set_xlabel("Niche")
ax.set_ylabel("Cell Type Counts")

# The legend lies outside the axes, so include it explicitly in the bounding box
if save_fig:
    plt.savefig(file_path, bbox_extra_artists=(legend,), bbox_inches="tight")

In [None]:
# Shape of the PCA embedding stored on the model's adata
model.adata.obsm["X_pca"].shape

In [None]:
# PCA embedding of the reference cells (used to fit the kNN label transfer below)
adata_reference.obsm["X_pca"]

In [None]:
# PCA embedding of the query cells (used to query the kNN label transfer below)
adata_query.obsm["X_pca"]

In [None]:
# Prepare label transfer via scarches: fit a weighted kNN model with 50
# neighbors on the reference cells' PCA embedding
knn_transformer = sca.utils.knn.weighted_knn_trainer(
    train_adata=adata_reference,
    train_adata_emb="X_pca",
    n_neighbors=50)

In [None]:
# Compute label transfer via scarches: predict `label_key` for every query cell
# from its nearest reference cells in PCA space; returns the predicted labels
# and their uncertainties as two tables.
# NOTE(review): `label_key` is not defined in this part of the notebook;
# presumably it is set in an earlier cell.
labels, uncert = sca.utils.knn.weighted_knn_transfer(
    query_adata=adata_query,
    query_adata_emb="X_pca",
    label_keys=label_key,
    knn_model=knn_transformer,
    ref_adata_obs=adata_reference.obs)

In [None]:
# Suffix the transferred labels and uncertainties with "_pca" to mark them as
# the results of the PCA-embedding label transfer
labels = labels.rename(columns={label_key: f"{transfer_label_key}_pca"})
uncert = uncert.rename(columns={label_key: f"{transfer_label_uncertainty_key}_pca"})

# Join results of label transfer to adata
model.adata.obs = model.adata.obs.join(labels).join(uncert)

In [None]:
# Display the PCA-based transferred labels
labels

In [None]:
def _evaluate_pca_label_transfer(row):
    """Classify one cell's PCA-based transferred label against its original label.

    Returns "Correct" on a match, the missing value itself when no label was
    transferred, and "Incorrect" otherwise.
    """
    predicted = row[f"{transfer_label_key}_pca"]
    if predicted == row[label_key]:
        return "Correct"
    if pd.isnull(predicted):
        return predicted
    return "Incorrect"


model.adata.obs[f"{transfer_label_evaluation_key}_pca"] = model.adata.obs.apply(
    _evaluate_pca_label_transfer, axis=1)

In [None]:
uncertainty_threshold = 0.2

# Cells whose PCA-based transfer uncertainty exceeds the threshold are relabeled
# "Unknown"; every other evaluation is kept as is
pca_evaluation_col = f"{transfer_label_evaluation_key}_pca"
is_uncertain = (
    model.adata.obs[f"{transfer_label_uncertainty_key}_pca"] > uncertainty_threshold)
model.adata.obs[pca_evaluation_col] = model.adata.obs[pca_evaluation_col].where(
    ~is_uncertain, "Unknown")

In [None]:
# Restrict the label-transfer evaluation to cells whose mapping entity is "query"
adata_query_label_transfer = model.adata[model.adata.obs[mapping_entity_key] == "query"]

In [None]:
# Share of query cells whose PCA-based label transfer was marked "Unknown"
# (uncertainty above `uncertainty_threshold`). Reads the "_pca" evaluation
# column, i.e. the one the reported threshold was applied to.
pca_unknown_share = (
    adata_query_label_transfer.obs[f"{transfer_label_evaluation_key}_pca"] == "Unknown"
).mean()
print(f"Percentage of 'Unknown', with uncertainty_threshold = {uncertainty_threshold}:")
print(f"{np.round(pca_unknown_share * 100, 2)}%")

In [None]:
# Labels enriched in the query: cell types when `label_key` is the cell-type
# annotation, latent clusters otherwise. Converted to a list so the
# concatenation below is list concatenation even if an array/Index is given.
query_enriched_labels = list(
    query_enriched_cell_types if label_key == cell_type_key
    else query_enriched_latent_clusters)

# All remaining labels, in order of first appearance. A set difference would
# give a hash-dependent order that changes between kernel sessions.
_query_enriched_label_set = set(query_enriched_labels)
not_query_enriched_labels = [
    label for label in model.adata.obs[label_key].unique().tolist()
    if label not in _query_enriched_label_set]

# Plot order: non-enriched labels first, query-enriched labels last
label_cats = not_query_enriched_labels + query_enriched_labels

In [None]:
# Display the cell metadata of the query cells
adata_query_label_transfer.obs

In [None]:
# Per original label: percentage of query cells whose PCA-based label transfer
# is "Correct", "Incorrect" or "Unknown", plus an "Overall" bar for the whole
# query. Every column used here is a "_pca" result, so the per-label bars and
# the "Overall" bar describe the same label transfer.
pca_evaluation = adata_query_label_transfer.obs[f"{transfer_label_evaluation_key}_pca"]
# Rows are the ORIGINAL labels, as the x-axis label states. Query-enriched
# labels may be rare or absent in the reference and thus never be predicted,
# so grouping by the transferred label could not show them.
original_labels = adata_query_label_transfer.obs[label_key]

perc_correct = pd.crosstab(original_labels, pca_evaluation)
perc_correct.index = perc_correct.index.tolist()
# Labels without query cells get an empty bar instead of raising a KeyError
perc_correct = perc_correct.reindex(label_cats, fill_value=0)
perc_correct = (
    perc_correct.div(perc_correct.sum(axis=1), axis="rows") * 100).fillna(0)

total_n_per_ct = original_labels.value_counts()
total_n_per_ct.index = total_n_per_ct.index.tolist()
total_n_per_ct = total_n_per_ct.reindex(label_cats, fill_value=0)

# add a bar (=row) for the entire query dataset:
total_n_per_ct["Overall"] = adata_query_label_transfer.shape[0]
perc_correct.loc["Overall", :] = (
    pca_evaluation.value_counts()
    / total_n_per_ct["Overall"]
    * 100
)

incl_celln_in_label = True
# set label order: follow `label_cats`, where query-enriched (new/unseen)
# labels come last, followed by the "Overall" bar
plot_label_cats = label_cats + ["Overall"]
perc_correct = perc_correct.loc[plot_label_cats, :]
tick_names = [str(ct).replace("_", " ") for ct in plot_label_cats]
with plt.rc_context(
    {
        "figure.figsize": (0.4 * len(label_cats), 3),
        "axes.spines.right": False,
        "axes.spines.top": False,
    }
):
    fig, ax = plt.subplots()
    perc_correct.plot(kind="bar", stacked=True, ax=ax)
    # Reverse the legend entries so they follow the stacking order of the bars.
    # `legend_labels` avoids clobbering the label-transfer `labels` table.
    handles, legend_labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1], legend_labels[::-1], loc=(1.01, 0.60), frameon=False)

    if incl_celln_in_label:
        plt.xticks(
            ticks=range(len(plot_label_cats)),
            labels=[
                f"{tick_name} ({total_n_per_ct[ct]})"
                for tick_name, ct in zip(tick_names, plot_label_cats)
            ],
        )
        plt.xlabel("Original label (n cells)")
    else:
        # One tick label per bar, including "Overall", so that the ticks and
        # labels have equal length
        plt.xticks(
            ticks=range(len(plot_label_cats)),
            labels=tick_names,
        )
        plt.xlabel("Original label")
    ax.set_ylabel("% of cells")
    plt.grid(False)

In [None]:
# Plot the cells in the "Unknown" category of `transfer_label_key` in latent and
# physical space, titled "Label Transfer Uncertainty".
# NOTE(review): "Unknown" is the value the cells above write into the "_pca"
# evaluation column; confirm that `transfer_label_key` (rather than
# f"{transfer_label_evaluation_key}_pca") actually contains that value.
# NOTE(review): the file name refers to query-enriched latent clusters, not to
# label-transfer uncertainty — confirm the intended name before saving.
save_fig = False
file_path = f"{figure_folder_path}/" \
            "query_enriched_latent_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Label Transfer Uncertainty",
    cat_key=transfer_label_key,
    groups="Unknown",
    sample_key=sample_key,
    samples=samples,
    cat_colors="coolwarm",
    size=360000/len(model.adata),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

## 4. Benchmarking

### 4.1 Batch Integration Baselines

In [None]:
### TO DO ###

#### 4.1.1 scVI

In [None]:
# NOTE(review): `scvi` is not imported in the setup cell shown at the top of
# this notebook; import it before running this cell.
scvi.settings.seed = random_seed

# Register the counts layer and the batch covariate with scvi-tools
scvi.model.SCVI.setup_anndata(
    adata_one_shot,
    layer=counts_key,
    batch_key=condition_key,
)

# Hyperparameters reported to work well on integration tasks
# (https://docs.scvi-tools.org/en/stable/tutorials/notebooks/harmonization.html)
scvi_model_kwargs = {"n_layers": 2, "n_latent": 30, "gene_likelihood": "nb"}
vae = scvi.model.SCVI(adata_one_shot, **scvi_model_kwargs)
vae.train()

# Store the learned latent representation; the next cell builds the UMAP from it
adata_one_shot.obsm["scvi_latent"] = vae.get_latent_representation()

In [None]:
# Use scVI latent space for UMAP generation
sc.pp.neighbors(adata_one_shot, use_rep="scvi_latent")
sc.tl.umap(adata_one_shot)

In [None]:
# Plot UMAP with batch annotations
fig = sc.pl.umap(adata_one_shot,
                 color=[cell_type_key],
                 legend_fontsize=12,
                 return_fig=True)
plt.title("scVI Integration: Latent Space Cell Type Annotations", size=20, pad=15)
fig.savefig(f"{figure_folder_path}/latent_cell_types_scvi.png",
            bbox_inches="tight")

#### 4.1.2 SageNet

In [None]:
#################################################################
# NOTE: This cell throws an error as the solver cannot solve this problem
#################################################################


# Construct gene interaction network for spatial references (every batch
# except the last two)
# NOTE(review): `glasso` and `sp` are not imported in the setup cell shown at
# the top of this notebook.
for adata_batch in adata_batch_list[:-2]:
    adata_batch.X = adata_batch.X.toarray() # convert to dense matrix as required by glasso
    print("Computing gene interaction network...")
    glasso(adata_batch, [0.25, 0.5])
    adata_batch.X = sp.csc_matrix(adata_batch.X) # convert back to sparse matrix
    print("Computing Leiden clusters...")
    # Cluster on the spatial neighborhood graph at three resolutions; results
    # are stored as "leiden_0.05", "leiden_0.1" and "leiden_0.5"
    for resolution in (0.05, 0.1, 0.5):
        sc.tl.leiden(adata_batch,
                     resolution=resolution,
                     random_state=random_seed,
                     key_added=f"leiden_{resolution}",
                     adjacency=adata_batch.obsp["spatial_connectivities"])

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
print(device)

In [None]:
# Define SageNet model object on the selected device
sg_obj = sca.models.sagenet(device=device)

In [None]:
# Train model on spatial references: one training call per batch (all but the
# last two), tagged "batch<index>", passing the three Leiden annotations as
# `comm_columns`
for batch_idx, adata_reference_batch in enumerate(adata_batch_list[:-2]):
    sg_obj.train(adata_reference_batch,
                 comm_columns=['leiden_0.05', 'leiden_0.1', 'leiden_0.5'],
                 tag=f'batch{batch_idx}',
                 epochs=15,
                 verbose=False,
                 importance=True)

In [None]:
# Save model. `exist_ok=True` lets the cell be re-run without failing with
# FileExistsError once the folder has been created.
sagenet_folder_path = model_artifacts_folder_path + "/sagenet"
os.makedirs(sagenet_folder_path, exist_ok=True)
sg_obj.save_as_folder(sagenet_folder_path)

In [None]:
# Load model: create a fresh SageNet object and restore it from the saved folder
sg_obj_load = sca.models.sagenet(device=device)
sg_obj_load.load_from_folder(model_artifacts_folder_path + "/sagenet")

In [None]:
# Load query: pass the full `adata` through the loaded SageNet model (presumably
# this writes the obsm["dist_map"] representation used by the next cell)
sg_obj_load.load_query_data(adata)

In [None]:
# Use SageNet cell-cell-distances (obsm["dist_map"]) for UMAP generation
sc.pp.neighbors(adata, use_rep="dist_map")
sc.tl.umap(adata)

In [None]:
# Plot UMAP with batch annotations (colored by `condition_key`) and save it
fig = sc.pl.umap(adata,
                 color=[condition_key],
                 legend_fontsize=12,
                 return_fig=True)
plt.title("SageNet Integration: Latent Space Batch Annotations", size=20, pad=15)
fig.savefig(f"{figure_folder_path}/latent_batches_sagenet.png",
            bbox_inches="tight")

In [None]:
# Plot UMAP with cell type annotations (colored by `cell_type_key`) and save it
fig = sc.pl.umap(adata,
                 color=[cell_type_key],
                 return_fig=True)
plt.title("SageNet: Latent Space Cell Type Annotations", size=20, pad=15)
fig.savefig(f"{figure_folder_path}/latent_cell_types_sagenet.png",
            bbox_inches="tight")

#### 4.1.3 BBKNN

In [None]:
# Train one NicheCompass model per batch, stack their latent representations
# and later batch-correct the stacked space with BBKNN.
# The per-batch model is kept in `bbknn_model` so that the main `model`, which
# later cells still use for metrics and conditional embeddings, is not
# overwritten.
latent_bbknn_list = []
for batch_idx, adata_batch in enumerate(adata_batch_list):
    # Initialize model
    bbknn_model = NicheCompass(adata_batch,
                               counts_key=counts_key,
                               adj_key=adj_key,
                               condition_key=condition_key,
                               cond_embed_injection=["encoder",
                                                     "gene_expr_decoder",
                                                     "graph_decoder"],
                               n_cond_embed=n_cond_embed,
                               gp_names_key=gp_names_key,
                               active_gp_names_key=active_gp_names_key,
                               gp_targets_mask_key=gp_targets_mask_key,
                               gp_sources_mask_key=gp_sources_mask_key,
                               latent_key=latent_key,
                               active_gp_thresh_ratio=0., # all gps will be active for concatenation across batches
                               gene_expr_recon_dist=gene_expr_recon_dist,
                               n_hidden_encoder=n_hidden_encoder,
                               log_variational=True)
    print("")

    # Train model
    bbknn_model.train(n_epochs=n_epochs,
                      n_epochs_all_gps=n_epochs, # all gps will be active for concatenation across batches
                      lr=lr,
                      lambda_edge_recon=lambda_edge_recon,
                      lambda_gene_expr_recon=lambda_gene_expr_recon,
                      verbose=True)
    print("")

    # Save trained model
    bbknn_model.save(dir_path=model_artifacts_folder_path + f"/bbknn_batch{batch_idx+1}",
                     overwrite=True,
                     save_adata=True,
                     adata_file_name=f"{dataset}.h5ad")

    latent_bbknn_current_batch = bbknn_model.get_latent_representation(
        adata=adata_batch,
        counts_key=counts_key,
        condition_key=condition_key,
        only_active_gps=False)

    latent_bbknn_list.append(latent_bbknn_current_batch)

# NOTE(review): assumes `adata_bbknn` holds the cells of `adata_batch_list`
# concatenated in the same batch order, so the stacked rows line up.
adata_bbknn.obsm[latent_key] = np.vstack(latent_bbknn_list)

# Store adata to disk
adata_bbknn.write(f"{model_artifacts_folder_path}/adata_bbknn.h5ad")

In [None]:
# Read artifacts of a previous run when a timestamp is configured, otherwise
# those of the current run
model_artifacts_load_folder_path = (
    f"../artifacts/{dataset}/batch_integration/{load_timestamp}"
    if load_timestamp is not None
    else model_artifacts_folder_path)

# Read adata from disk
adata_bbknn = sc.read_h5ad(f"{model_artifacts_load_folder_path}/adata_bbknn.h5ad")

In [None]:
# Compute batch-corrected latent nearest neighbor graph
bbknn.bbknn(adata=adata_bbknn, batch_key=condition_key, use_rep=latent_key)

# Also register the BBKNN graph under the latent kNN graph keys
for graph_name in ("connectivities", "distances"):
    adata_bbknn.obsp[f"{latent_knng_key}_{graph_name}"] = adata_bbknn.obsp[graph_name]

In [None]:
# Use batch-corrected latent space for UMAP generation (UMAP of the BBKNN graph
# computed in the previous cell)
sc.tl.umap(adata_bbknn)

In [None]:
# Plot UMAP with batch annotations (colored by `condition_key`) and save it
fig = sc.pl.umap(adata_bbknn,
                 color=[condition_key],
                 legend_fontsize=12,
                 return_fig=True)
plt.title("BBKNN Integration: Latent Space Batch Annotations", size=20, pad=15)
fig.savefig(f"{figure_folder_path}/latent_batches_bbknn.png",
            bbox_inches="tight")

In [None]:
# Plot UMAP with cell type annotations (colored by `cell_type_key`) and save it
fig = sc.pl.umap(adata_bbknn,
                 color=[cell_type_key],
                 return_fig=True)
plt.title("BBKNN Integration: Latent Space Cell Type Annotations", size=20, pad=15)
fig.savefig(f"{figure_folder_path}/latent_cell_types_bbknn.png",
            bbox_inches="tight")

In [None]:
# Compute spatial nearest neighbor graph from obsm[spatial_key], stored under
# `spatial_knng_key`
sc.pp.neighbors(adata_bbknn, use_rep=spatial_key, key_added=spatial_knng_key)

In [None]:
metrics_dict_bbknn = {}

# The spatial-conservation metrics share the same graph/annotation arguments
bbknn_spatial_metric_kwargs = dict(
    adata=adata_bbknn,
    cell_type_key=cell_type_key,
    spatial_knng_key=spatial_knng_key,
    latent_knng_key=latent_knng_key)

# NOTE(review): the one-shot metrics cell uses `compute_cas`/`compute_clisis`
# with keys "cas"/"clisis"; confirm which helpers/keys are current so the two
# metric dictionaries can be compared.
metrics_dict_bbknn["cad"] = compute_cad(**bbknn_spatial_metric_kwargs)
metrics_dict_bbknn["rclisi"] = compute_rclisi(**bbknn_spatial_metric_kwargs)

# Batch-mixing metrics: silhouette on the UMAP embedding, iLISI on the kNN graph
metrics_dict_bbknn["batch_asw"] = scib.me.silhouette_batch(
    adata=adata_bbknn,
    batch_key=condition_key,
    label_key=cell_type_key,
    embed="X_umap")
metrics_dict_bbknn["ilisi"] = scib.me.ilisi_graph(
    adata=adata_bbknn,
    batch_key=condition_key,
    type_="knn")

# kBET (scib.me.kBET with batch_key=condition_key, label_key=cell_type_key,
# type_="knn") is currently disabled.

print(metrics_dict_bbknn)

# Store to disk
with open(f"{model_artifacts_folder_path}/metrics_bbknn.pickle", "wb") as metrics_file:
    pickle.dump(metrics_dict_bbknn, metrics_file)

#### 3.2.4 Compute Metrics

In [None]:
# scib metrics read the neighbor graph from the default "connectivities" /
# "distances" / "neighbors" slots, so point those at the latent kNN graph
for graph_name in ("connectivities", "distances"):
    model.adata.obsp[graph_name] = model.adata.obsp[f"{latent_knng_key}_{graph_name}"]
model.adata.uns["neighbors"] = model.adata.uns[f"{latent_knng_key}"]

# Compute spatial nearest neighbor graph
sc.pp.neighbors(model.adata, use_rep=spatial_key, key_added=spatial_knng_key)

In [None]:
# Compute metrics
metrics_dict_oneshot = {}

# Spatial conservation metrics share the same graph/annotation arguments
oneshot_spatial_metric_kwargs = dict(
    adata=model.adata,
    cell_type_key=cell_type_key,
    spatial_knng_key=spatial_knng_key,
    latent_knng_key=latent_knng_key)
metrics_dict_oneshot["cas"] = compute_cas(**oneshot_spatial_metric_kwargs)
metrics_dict_oneshot["clisis"] = compute_clisis(**oneshot_spatial_metric_kwargs)

# Batch correction metrics: silhouette on the UMAP embedding, iLISI on the kNN graph
metrics_dict_oneshot["batch_asw"] = scib.me.silhouette_batch(
    adata=model.adata,
    batch_key=condition_key,
    label_key=cell_type_key,
    embed="X_umap")
metrics_dict_oneshot["ilisi"] = scib.me.ilisi_graph(
    adata=model.adata,
    batch_key=condition_key,
    type_="knn")

print(metrics_dict_oneshot)

# Store metrics to disk
with open(f"{model_artifacts_folder_path}/metrics_oneshot.pickle", "wb") as metrics_file:
    pickle.dump(metrics_dict_oneshot, metrics_file)

#### 3.2.5 Visualize Conditional Embedding

In [None]:
# Get conditional embeddings (one row per condition)
cond_embed = model.get_cond_embeddings()
# Condition labels used as hue, read from the configured `condition_key`
# instead of a hard-coded "batch" column.
# NOTE(review): `unique()` returns labels in order of first appearance; this
# must match the row order of `cond_embed` for the hue to be correct — confirm.
cond = model.adata.obs[condition_key].unique()

# Get top 2 principal components and plot them
# NOTE(review): `KernelPCA` is not imported in the setup cell shown at the top
# of this notebook.
pca = KernelPCA(n_components=2, kernel="linear")
cond_embed_pca = pca.fit_transform(cond_embed)
sns.scatterplot(x=cond_embed_pca[:, 0],
                y=cond_embed_pca[:, 1],
                hue=cond)
plt.title("One-Shot Integration Conditional Embeddings", pad=15)
plt.xlabel("Principal Component 1")
plt.xticks(fontsize=12)
plt.ylabel("Principal Component 2")
plt.yticks(fontsize=12)
plt.legend(bbox_to_anchor=(1.02, 0.75),
           loc=2,
           borderaxespad=0.,
           fontsize=12,
           frameon=False)
plt.savefig(f"{figure_folder_path}/cond_embed_oneshot.png",
            bbox_inches="tight")