# Reference Query Mapping

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>).
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 18.05.2023
- **Date of Last Modification:** 01.07.2023

Before proceeding with this notebook, the following two steps need to have been completed successfully: <br>
- A model has been trained on a reference dataset with ```<root>/scripts/train_nichecompass_reference_model.py```. <br>
- A query has been mapped onto the reference model with ```<root>/scripts/map_query_on_nichecompass_reference_model.py```.

## 1. Setup

### 1.1 Import Libraries

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("../../utils")

In [None]:
import os
import warnings

from nichecompass.models import NicheCompass
from nichecompass.utils import create_new_color_dict
from analysis_utils import *

### 1.2 Define Parameters

In [None]:
### Dataset ###
# Dataset name; it is part of the artifact, model and figure paths
dataset = "seqfish_mouse_organogenesis_imputed"

### Model ###
node_label_method = "one-hop-norm"
# Key under which the gene program names are stored in `adata.uns`
gp_names_key = "nichecompass_gp_names"
# Columns of `adata.obs` that are read by the analysis cells below
mapping_entity_key = "mapping_entity"
sample_key = "sample"
condition_key = "batch"
cell_type_key = "celltype_mapped_refined"
# Passed as `neighbors_key` to the latent Leiden clustering
latent_key = "nichecompass_latent"
# Label and timestamp identify the model folder that is loaded
reference_query_model_label = f"{node_label_method}_reference_query_query_mapping"
load_timestamp = "01072023_165203_1"

### Analysis ###
latent_leiden_resolution = 0.2
# Column that receives the latent Leiden cluster labels; an f-string already
# converts the resolution to text, so no explicit `str()` call is needed
latent_cluster_key = f"latent_leiden_{latent_leiden_resolution}"
spot_size = 0.03

### 1.3 Run Notebook Setup

In [None]:
# Suppress all Python warnings from here on to keep the notebook output clean
warnings.filterwarnings("ignore")

### 1.4 Configure Paths

In [None]:
# Define paths
# Root folder of all artifacts, given as a relative path (a plain string is
# enough here because nothing is interpolated)
artifacts_folder_path = "../../artifacts"
# Folder from which the reference-query model is loaded
model_folder_path = f"{artifacts_folder_path}/{dataset}/models/{reference_query_model_label}/{load_timestamp}"
# Folder into which the figures are saved; it is created if it does not exist
figure_folder_path = f"{artifacts_folder_path}/{dataset}/figures/{load_timestamp}"
os.makedirs(figure_folder_path, exist_ok=True)

## 2. Model

### 2.1 Load Model

In [None]:
# Load trained model from `model_folder_path`
# No AnnData object is passed in (`adata=None`); presumably the file named by
# `adata_file_name` is read from the model folder instead -- TODO confirm
# against the NicheCompass.load documentation
model = NicheCompass.load(dir_path=model_folder_path,
                          adata=None,
                          adata_file_name=f"{dataset}_{reference_query_model_label}.h5ad",
                          gp_names_key=gp_names_key)

## 3. Analysis

In [None]:
# Check number of active gene programs
# The total is the number of gene program names stored in `adata.uns`, the
# active count is the length of what `get_active_gps()` returns
active_gps = model.get_active_gps()
print(f"Number of total gene programs: {len(model.adata.uns[gp_names_key])}.")
print(f"Number of active gene programs: {len(active_gps)}.")

In [None]:
# Retrieve the gene program summary and show the first 20 rows whose
# "gp_active" entry equals True
gp_summary_df = model.get_gp_summary()
gp_summary_df.loc[gp_summary_df["gp_active"] == True].head(20)

### 3.1 Analyze Latent Manifold

In [None]:
# Distinct sample labels; passed as `samples` to the plotting calls below
samples = model.adata.obs[sample_key].unique().tolist()

In [None]:
# Show the mapping entities (reference / query) in latent and physical space
# and save the figure as SVG into the figure folder
save_fig = True
file_path = (f"{figure_folder_path}/"
             "mapping_entities_latent_physical_space.svg")

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    cat_key=mapping_entity_key,
    plot_label="Mapping Entities",
    groups=None,
    cat_colors=None,
    sample_key=sample_key,
    samples=samples,
    size=360000 / len(model.adata),  # size shrinks as the number of observations grows
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Show the batches (`condition_key`) in latent and physical space and save
# the figure as SVG into the figure folder
save_fig = True
file_path = (f"{figure_folder_path}/"
             "batches_latent_physical_space.svg")

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    cat_key=condition_key,
    plot_label="Batches",
    groups=None,
    cat_colors=None,
    sample_key=sample_key,
    samples=samples,
    size=360000 / len(model.adata),  # size shrinks as the number of observations grows
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Color dictionary for the cell type column; used when plotting cell types
cell_type_colors = create_new_color_dict(adata=model.adata, cat_key=cell_type_key)

In [None]:
# Show the cell types in latent and physical space, colored with
# `cell_type_colors`, and save the figure as SVG into the figure folder
save_fig = True
file_path = (f"{figure_folder_path}/"
             "cell_types_latent_physical_space.svg")

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    cat_key=cell_type_key,
    plot_label="Cell Types",
    groups=None,
    cat_colors=cell_type_colors,
    sample_key=sample_key,
    samples=samples,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

In [None]:
# Run Leiden clustering on the neighbor graph stored under `latent_key` and
# write the cluster labels to the `latent_cluster_key` column
sc.tl.leiden(adata=model.adata,
             neighbors_key=latent_key,
             resolution=latent_leiden_resolution,
             key_added=latent_cluster_key)

In [None]:
# Color dictionary for the latent cluster column; used when plotting clusters
latent_cluster_colors = create_new_color_dict(adata=model.adata, cat_key=latent_cluster_key)

In [None]:
# Show the latent Leiden clusters in latent and physical space; the Leiden
# resolution is part of the file name of the saved SVG
save_fig = True
file_path = (f"{figure_folder_path}/res_{latent_leiden_resolution}_"
             "latent_clusters_latent_physical_space.svg")

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    cat_key=latent_cluster_key,
    plot_label="Latent Clusters",
    groups=None,
    cat_colors=latent_cluster_colors,
    sample_key=sample_key,
    samples=samples,
    size=None,
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

### 3.2 Analyze Query Enrichments (vs Reference)

#### 3.2.1 Visualize Query-enriched Cell Types in Physical and Latent Space

In [None]:
# Split the observations by mapping entity into a reference and a query subset
# NOTE(review): boolean indexing of an AnnData object presumably returns views
# that reflect `model.adata.obs` as of this point -- confirm before relying on
# columns added to `model.adata.obs` later
adata_reference = model.adata[model.adata.obs[mapping_entity_key] == "reference"]
adata_query = model.adata[model.adata.obs[mapping_entity_key] == "query"]

In [None]:
query_enriched_cell_type_key = "query_enriched_cell_types"
# A cell type counts as query-enriched if its query share is more than this
# many times its reference share
enriched_query_cell_type_prop_thresh = 5.

# Share of observations per cell type, separately for reference and query
cell_type_reference_proportions = (
    adata_reference.obs[cell_type_key].value_counts().sort_index()
    / len(adata_reference))
cell_type_query_proportions = (
    adata_query.obs[cell_type_key].value_counts().sort_index()
    / len(adata_query))

# Query share relative to reference share, largest ratio first
relative_cell_type_query_proportions = (
    cell_type_query_proportions / cell_type_reference_proportions
).sort_values(ascending=False)
display(relative_cell_type_query_proportions)

# Enriched cell types keep their own label; every other observation gets one
# shared placeholder label
query_enriched_cell_types = relative_cell_type_query_proportions[
    relative_cell_type_query_proportions > enriched_query_cell_type_prop_thresh
].index.to_list()
model.adata.obs[query_enriched_cell_type_key] = "Cell types not enriched in query"
for cell_type in query_enriched_cell_types:
    model.adata.obs.loc[
        model.adata.obs[cell_type_key] == cell_type,
        query_enriched_cell_type_key] = cell_type

In [None]:
# Show the query-enriched cell type labels in latent and physical space and
# save the figure as SVG into the figure folder
save_fig = True
file_path = (f"{figure_folder_path}/"
             "query_enriched_cell_types_latent_physical_space.svg")

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    cat_key=query_enriched_cell_type_key,
    plot_label="Query Enriched Cell Types",
    groups=None,
    cat_colors="coolwarm",
    sample_key=sample_key,
    samples=samples,
    size=360000 / len(model.adata),  # size shrinks as the number of observations grows
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

#### 3.2.2 Visualize Query-enriched Niches / Latent Clusters in Physical and Latent Space

In [None]:
# Column of `adata.obs` that receives the query-enriched latent cluster labels
query_enriched_latent_cluster_key = "query_enriched_latent_cluster"
# A latent cluster counts as query-enriched if its query share is more than
# this many times its reference share
enriched_query_latent_cluster_prop_thresh = 5.

In [None]:
# Get query-enriched latent leiden clusters
# `latent_cluster_key` is the column the latent Leiden clustering writes to; it
# is used directly instead of rebuilding the same column name in every lookup

# Share of observations per latent cluster, separately for reference and query
latent_cluster_reference_proportions = (
    adata_reference.obs[latent_cluster_key].value_counts().sort_index()
    / len(adata_reference))
latent_cluster_query_proportions = (
    adata_query.obs[latent_cluster_key].value_counts().sort_index()
    / len(adata_query))

# Query share relative to reference share, largest ratio first
relative_latent_cluster_query_proportions = latent_cluster_query_proportions / latent_cluster_reference_proportions
relative_latent_cluster_query_proportions.sort_values(ascending=False, inplace=True)
display(relative_latent_cluster_query_proportions)

# Enriched clusters are labeled "Cluster <id>"; every other observation gets
# one shared placeholder label
query_enriched_latent_clusters = relative_latent_cluster_query_proportions[
    relative_latent_cluster_query_proportions > enriched_query_latent_cluster_prop_thresh
].index.to_list()
model.adata.obs[query_enriched_latent_cluster_key] = "Latent clusters not enriched in query"
for latent_cluster in query_enriched_latent_clusters:
    model.adata.obs.loc[
        model.adata.obs[latent_cluster_key] == latent_cluster,
        query_enriched_latent_cluster_key] = f"Cluster {latent_cluster}"

In [None]:
# Plot of query-enriched latent cluster annotations in physical and latent space
save_fig = True
file_path = f"{figure_folder_path}/" \
            "query_enriched_latent_clusters_latent_physical_space.svg"

plot_category_in_latent_and_physical_space(
    adata=model.adata,
    plot_label="Query Enriched Latent Clusters",
    cat_key=query_enriched_latent_cluster_key,
    groups=None,
    sample_key=sample_key,
    samples=samples,
    cat_colors="coolwarm",
    size=360000/len(model.adata),
    spot_size=spot_size,
    save_fig=save_fig,
    file_path=file_path)

## 4. Benchmarking

### 4.1 Batch Integration Baselines

In [None]:
### TO DO ###

#### 4.1.1 scVI

In [None]:
# Train an scVI model as batch integration baseline and store its latent
# representation in `adata_one_shot.obsm["scvi_latent"]`
# NOTE(review): `scvi`, `random_seed`, `adata_one_shot` and `counts_key` are
# not defined in the visible part of this notebook -- confirm they are in
# scope before running this cell
scvi.settings.seed = random_seed

# Setup adata
scvi.model.SCVI.setup_anndata(adata_one_shot,
                              layer=counts_key,
                              batch_key=condition_key)

# Initialize model
# Use hyperparameters that are known to work well on integration tasks (https://docs.scvi-tools.org/en/stable/tutorials/notebooks/harmonization.html)
vae = scvi.model.SCVI(adata_one_shot,
                      n_layers=2,
                      n_latent=30,
                      gene_likelihood="nb")

# Train model
vae.train()

adata_one_shot.obsm["scvi_latent"] = vae.get_latent_representation()

In [None]:
# Use scVI latent space for UMAP generation
# The neighbor graph is computed on `obsm["scvi_latent"]`
sc.pp.neighbors(adata_one_shot, use_rep="scvi_latent")
sc.tl.umap(adata_one_shot)

In [None]:
# Plot UMAP with cell type annotations and save it to the figure folder
fig = sc.pl.umap(adata_one_shot,
                 color=[cell_type_key],
                 legend_fontsize=12,
                 return_fig=True)
plt.title("scVI Integration: Latent Space Cell Type Annotations", size=20, pad=15)
fig.savefig(f"{figure_folder_path}/latent_cell_types_scvi.png",
            bbox_inches="tight")

#### 4.1.2 SageNet

In [None]:
###################################################################
# This cell throws an error as the solver cannot solve this problem
###################################################################
# (The banner above used to be bare text, which made the whole cell a syntax
# error; it is now a regular comment.)

# Leiden resolutions for the spatial clusterings; the labels are stored under
# "leiden_0.05", "leiden_0.1" and "leiden_0.5"
leiden_resolutions = [.05, .1, .5]

# Construct gene interaction network for spatial references
# (all batches except the last two)
for adata_batch in adata_batch_list[:-2]:
    adata_batch.X = adata_batch.X.toarray() # convert to dense matrix as required by glasso
    print("Computing gene interaction network...")
    glasso(adata_batch, [0.25, 0.5])
    adata_batch.X = sp.csc_matrix(adata_batch.X) # convert back to sparse matrix
    print("Computing Leiden clusters...")
    # One clustering per resolution, each on the spatial connectivity graph
    for leiden_resolution in leiden_resolutions:
        sc.tl.leiden(adata_batch,
                     resolution=leiden_resolution,
                     random_state=random_seed,
                     key_added=f"leiden_{leiden_resolution}",
                     adjacency=adata_batch.obsp["spatial_connectivities"])

In [None]:
# Use the first CUDA device when CUDA is available, otherwise fall back to CPU
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
print(device)

In [None]:
# Define model object
# NOTE(review): `sca` is not imported in the visible part of this notebook --
# confirm it is in scope
sg_obj = sca.models.sagenet(device=device)

In [None]:
# Train model on spatial references (all batches except the last two)
# Each batch is trained on its three Leiden clusterings and tagged with its
# position in the list ("batch0", "batch1", ...)
for i, adata_batch in enumerate(adata_batch_list[:-2]):
    sg_obj.train(adata_batch,
                 comm_columns=['leiden_0.05', 'leiden_0.1', 'leiden_0.5'],
                 tag=f'batch{i}',
                 epochs=15,
                 verbose=False,
                 importance=True)

In [None]:
# Save model
# `exist_ok=True` keeps the cell re-runnable when the folder already exists
# (same convention as for the figure folder)
os.makedirs(model_artifacts_folder_path + "/sagenet", exist_ok=True)
sg_obj.save_as_folder(model_artifacts_folder_path + "/sagenet")

In [None]:
# Load model
# A fresh sagenet object is created and filled from the folder the model was
# saved to
sg_obj_load = sca.models.sagenet(device=device)
sg_obj_load.load_from_folder(model_artifacts_folder_path + "/sagenet")

In [None]:
# Load query
# NOTE(review): `adata` is not defined in the visible part of this notebook --
# confirm which object is meant to be the query here
sg_obj_load.load_query_data(adata)

In [None]:
# Use SageNet cell-cell-distances for UMAP generation
# NOTE(review): "dist_map" is presumably written to `adata.obsm` by
# `load_query_data` -- TODO confirm
sc.pp.neighbors(adata, use_rep="dist_map")
sc.tl.umap(adata)

In [None]:
# Plot UMAP with batch annotations and save it to the figure folder
fig = sc.pl.umap(adata,
                 color=[condition_key],
                 legend_fontsize=12,
                 return_fig=True)
plt.title("SageNet Integration: Latent Space Batch Annotations", size=20, pad=15)
fig.savefig(f"{figure_folder_path}/latent_batches_sagenet.png",
            bbox_inches="tight")

In [None]:
# Plot UMAP with cell type annotations and save it to the figure folder
fig = sc.pl.umap(adata,
                 color=[cell_type_key],
                 return_fig=True)
plt.title("SageNet: Latent Space Cell Type Annotations", size=20, pad=15)
fig.savefig(f"{figure_folder_path}/latent_cell_types_sagenet.png",
            bbox_inches="tight")

#### 4.1.3 BBKNN

In [None]:
# Train one NicheCompass model per batch, collect the per-batch latent
# representations and stack them into `adata_bbknn.obsm[latent_key]`.
# NOTE(review): `adata_bbknn` is not created in this cell and the stacked rows
# presumably have to follow the same order as its observations -- confirm
# where it is defined and how it is ordered.
latent_bbknn_list = []
for i, adata_batch in enumerate(adata_batch_list):
    # Initialize model
    # The per-batch model gets its own name so that the loaded reference-query
    # `model`, which other cells read via `model.adata`, is not overwritten
    bbknn_model = NicheCompass(adata_batch,
                               counts_key=counts_key,
                               adj_key=adj_key,
                               condition_key=condition_key,
                               cond_embed_injection=["encoder",
                                                     "gene_expr_decoder",
                                                     "graph_decoder"],
                               n_cond_embed=n_cond_embed,
                               gp_names_key=gp_names_key,
                               active_gp_names_key=active_gp_names_key,
                               gp_targets_mask_key=gp_targets_mask_key,
                               gp_sources_mask_key=gp_sources_mask_key,
                               latent_key=latent_key,
                               active_gp_thresh_ratio=0., # all gps will be active for concatenation across batches
                               gene_expr_recon_dist=gene_expr_recon_dist,
                               n_hidden_encoder=n_hidden_encoder,
                               log_variational=True)
    print("")

    # Train model
    bbknn_model.train(n_epochs=n_epochs,
                      n_epochs_all_gps=n_epochs, # all gps will be active for concatenation across batches
                      lr=lr,
                      lambda_edge_recon=lambda_edge_recon,
                      lambda_gene_expr_recon=lambda_gene_expr_recon,
                      verbose=True)
    print("")

    # Save trained model (folders are numbered starting at 1)
    bbknn_model.save(dir_path=model_artifacts_folder_path + f"/bbknn_batch{i+1}",
                     overwrite=True,
                     save_adata=True,
                     adata_file_name=f"{dataset}.h5ad")

    # Latent representation of the current batch over all gene programs
    latent_bbknn_current_batch = bbknn_model.get_latent_representation(
        adata=adata_batch,
        counts_key=counts_key,
        condition_key=condition_key,
        only_active_gps=False)

    latent_bbknn_list.append(latent_bbknn_current_batch)

adata_bbknn.obsm[latent_key] = np.vstack(latent_bbknn_list)

# Store adata to disk
adata_bbknn.write(f"{model_artifacts_folder_path}/adata_bbknn.h5ad")

In [None]:
# Choose the folder to read from: the run identified by `load_timestamp` if
# one is set, otherwise the current model artifacts folder
model_artifacts_load_folder_path = (
    f"../artifacts/{dataset}/batch_integration/{load_timestamp}"
    if load_timestamp is not None
    else model_artifacts_folder_path)

# Read the stored BBKNN adata from that folder
adata_bbknn = sc.read_h5ad(f"{model_artifacts_load_folder_path}/adata_bbknn.h5ad")

In [None]:
# Compute batch-corrected latent nearest neighbor graph
bbknn.bbknn(adata=adata_bbknn,
            batch_key=condition_key,
            use_rep=latent_key)

# Keep copies of the resulting "connectivities" and "distances" graphs under
# keys prefixed with `latent_knng_key`, which the metrics cells read
# NOTE(review): `latent_knng_key` is not defined in the visible part of this
# notebook -- confirm it is in scope
adata_bbknn.obsp[f"{latent_knng_key}_connectivities"] = (
    adata_bbknn.obsp["connectivities"])

adata_bbknn.obsp[f"{latent_knng_key}_distances"] = (
    adata_bbknn.obsp["distances"])

In [None]:
# Use batch-corrected latent space for UMAP generation
# No neighbors key is passed, so the default neighbor graph of `adata_bbknn`
# is used
sc.tl.umap(adata_bbknn)

In [None]:
# Plot UMAP with batch annotations and save it to the figure folder
fig = sc.pl.umap(adata_bbknn,
                 color=[condition_key],
                 legend_fontsize=12,
                 return_fig=True)
plt.title("BBKNN Integration: Latent Space Batch Annotations", size=20, pad=15)
fig.savefig(f"{figure_folder_path}/latent_batches_bbknn.png",
            bbox_inches="tight")

In [None]:
# Plot UMAP with cell type annotations and save it to the figure folder
fig = sc.pl.umap(adata_bbknn,
                 color=[cell_type_key],
                 return_fig=True)
plt.title("BBKNN Integration: Latent Space Cell Type Annotations", size=20, pad=15)
fig.savefig(f"{figure_folder_path}/latent_cell_types_bbknn.png",
            bbox_inches="tight")

In [None]:
# Compute spatial nearest neighbor graph
# The graph is built on the representation named by `spatial_key` and stored
# under `spatial_knng_key`
# NOTE(review): neither key is defined in the visible part of this notebook --
# confirm they are in scope
sc.pp.neighbors(adata_bbknn, use_rep=spatial_key, key_added=spatial_knng_key)

In [None]:
# Compute the benchmarking metrics for the BBKNN baseline, print them and
# store them as a pickle file in the model artifacts folder
# NOTE(review): `compute_cad` / `compute_rclisi` are used here while the
# metrics cell for `model.adata` uses `compute_cas` / `compute_clisis` --
# confirm which helper names are current
metrics_dict_bbknn = {}

# Metrics that take both the spatial and the latent nearest neighbor graph
metrics_dict_bbknn["cad"] = compute_cad(
    adata=adata_bbknn,
    cell_type_key=cell_type_key,
    spatial_knng_key=spatial_knng_key,
    latent_knng_key=latent_knng_key)

metrics_dict_bbknn["rclisi"] = compute_rclisi(
    adata=adata_bbknn,
    cell_type_key=cell_type_key,
    spatial_knng_key=spatial_knng_key,
    latent_knng_key=latent_knng_key)
    
# Batch silhouette width, computed on the UMAP embedding
metrics_dict_bbknn["batch_asw"] = scib.me.silhouette_batch(
    adata=adata_bbknn,
    batch_key=condition_key,
    label_key=cell_type_key,
    embed="X_umap")

# knn output
metrics_dict_bbknn["ilisi"] = scib.me.ilisi_graph(
    adata=adata_bbknn,
    batch_key=condition_key,
    type_="knn")

# The kBET metric below is disabled: it is wrapped in a string literal and
# therefore never executed
"""
metrics_dict_bbknn["kbet"] = scib.me.kBET(
    adata=adata_bbknn,
    batch_key=condition_key,
    label_key=cell_type_key,
    type_="knn")
"""

print(metrics_dict_bbknn)

# Store to disk
with open(f"{model_artifacts_folder_path}/metrics_bbknn.pickle", "wb") as f:
    pickle.dump(metrics_dict_bbknn, f)

#### 4.1.4 Compute Metrics

In [None]:
# Store computed latent nearest neighbor graph in connectivities
# as required by scib metrics
# The graphs and the neighbors entry stored under `latent_knng_key` are copied
# to the default "connectivities", "distances" and "neighbors" slots
# NOTE(review): `latent_knng_key`, `spatial_key` and `spatial_knng_key` are
# not defined in the visible part of this notebook -- confirm they are in scope
model.adata.obsp["connectivities"] = (
    model.adata.obsp[f"{latent_knng_key}_connectivities"])
model.adata.obsp["distances"] = (
    model.adata.obsp[f"{latent_knng_key}_distances"])
model.adata.uns["neighbors"] = (
    model.adata.uns[f"{latent_knng_key}"])

# Compute spatial nearest neighbor graph
sc.pp.neighbors(model.adata,
                use_rep=spatial_key,
                key_added=spatial_knng_key)

In [None]:
# Compute metrics
# The results are collected in a dict, printed and stored as a pickle file in
# the model artifacts folder
metrics_dict_oneshot = {}

# Spatial conservation metrics
# Both take the spatial and the latent nearest neighbor graph keys
metrics_dict_oneshot["cas"] = compute_cas(
    adata=model.adata,
    cell_type_key=cell_type_key,
    spatial_knng_key=spatial_knng_key,
    latent_knng_key=latent_knng_key)
metrics_dict_oneshot["clisis"] = compute_clisis(
    adata=model.adata,
    cell_type_key=cell_type_key,
    spatial_knng_key=spatial_knng_key,
    latent_knng_key=latent_knng_key)
    
# Batch correction metrics
# The batch silhouette width is computed on the UMAP embedding
# NOTE(review): "X_umap" must already be present in `model.adata.obsm`; no
# UMAP is computed for `model.adata` in the visible cells -- confirm
metrics_dict_oneshot["batch_asw"] = scib.me.silhouette_batch(
    adata=model.adata,
    batch_key=condition_key,
    label_key=cell_type_key,
    embed="X_umap")
metrics_dict_oneshot["ilisi"] = scib.me.ilisi_graph(
    adata=model.adata,
    batch_key=condition_key,
    type_="knn")

print(metrics_dict_oneshot)

# Store metrics to disk
with open(f"{model_artifacts_folder_path}/metrics_oneshot.pickle", "wb") as f:
    pickle.dump(metrics_dict_oneshot, f)

#### 4.1.5 Visualize Conditional Embedding

In [None]:
# Get conditional embeddings
cond_embed = model.get_cond_embeddings()
# Condition labels used as hue; the column is looked up via `condition_key`
# (the notebook's condition column) instead of a hard-coded column name
# NOTE(review): this assumes the order of `unique()` matches the row order of
# `cond_embed` -- confirm against `get_cond_embeddings`
cond = model.adata.obs[condition_key].unique()

# Get top 2 principal components and plot them
pca = KernelPCA(n_components=2, kernel="linear")
cond_embed_pca = pca.fit_transform(cond_embed)
sns.scatterplot(x=cond_embed_pca[:, 0],
                y=cond_embed_pca[:, 1],
                hue=cond)
plt.title("One-Shot Integration Conditional Embeddings", pad=15)
plt.xlabel("Principal Component 1")
plt.xticks(fontsize=12)
plt.ylabel("Principal Component 2")
plt.yticks(fontsize=12)
# Place the legend outside the axes, to the right of the plot
plt.legend(bbox_to_anchor=(1.02, 0.75),
           loc=2,
           borderaxespad=0.,
           fontsize=12,
           frameon=False)
plt.savefig(f"{figure_folder_path}/cond_embed_oneshot.png",
            bbox_inches="tight")