# Autotalker Batch Integration Mouse Organogenesis Imputed

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 20.01.2023
- **Date of Last Modification:** 16.02.2023

## 1. Setup

### 1.1 Import Libraries

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%reload_ext autoreload

In [3]:
import sys
sys.path.append("../../autotalker")

In [4]:
import argparse
import os
import pickle
import random
import warnings
from copy import deepcopy
from datetime import datetime

import anndata as ad
import matplotlib
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import scanpy as sc
import scib
import scipy.sparse as sp
import seaborn as sns
import squidpy as sq
import torch
from matplotlib.pyplot import rc_context
from sklearn.decomposition import KernelPCA

from autotalker.benchmarking import compute_rclisi, compute_cad
from autotalker.models import Autotalker
from autotalker.utils import (add_gps_from_gp_dict_to_adata,
                              extract_gp_dict_from_mebocost_es_interactions,
                              extract_gp_dict_from_nichenet_ligand_target_mx,
                              extract_gp_dict_from_omnipath_lr_interactions,
                              filter_and_combine_gp_dict_gps,
                              get_unique_genes_from_gp_dict)

  res = Downloader(opt).maybe_download(
  res = Downloader(opt).maybe_download(
  res = Downloader(opt).maybe_download(
  res = Downloader(opt).maybe_download(
  res = Downloader(opt).maybe_download(
  return UNKNOWN_SERVER_VERSION
  warn(


### 1.2 Define Parameters

In [5]:
## Dataset
# Dataset name; used below to build the input file name and the figure /
# artifact output folders
dataset = "seqfish_mouse_organogenesis"
# Batch labels matched against `adata.obs["batch"]` when the data is split;
# batches 1-4 are used as reference and batches 5-6 as query further below
batch1 = "embryo1_z2"
batch2 = "embryo1_z5"
batch3 = "embryo2_z2"
batch4 = "embryo2_z5"
batch5 = "embryo3_z2"
batch6 = "embryo3_z5"
# Passed as `n_neighs` to the spatial neighbor graph computation; the trailing
# values are presumably alternatives that were tried - TODO confirm
n_neighbors = 4 # 12, 20
# NOTE(review): not referenced in the visible part of the notebook and set to
# 1 (trailing comment suggests 4000) - confirm whether this is a debug value
n_hvg_genes = 1 # 4000

## Model
# AnnData Keys
counts_key = "log_normalized_counts" # raw counts not available
cell_type_key = "celltype_mapped_refined"
adj_key = "spatial_connectivities"
spatial_key = "spatial"
gp_names_key = "autotalker_gp_names"
active_gp_names_key = "autotalker_active_gp_names"
gp_targets_mask_key = "autotalker_gp_targets"
gp_sources_mask_key = "autotalker_gp_sources"
latent_key = "autotalker_latent"
condition_key = "batch"
mapping_entity_key = "mapping_entity"

# Architecture
# These values are forwarded unchanged to the `Autotalker` constructor
active_gp_thresh_ratio = 0.03 # 0.01
gene_expr_recon_dist = "nb" # zinb
n_cond_embed = 3
log_variational = False # log normalized counts as input

# Trainer
# These values are forwarded to `model.train` for the reference model
# NOTE(review): `n_epochs = 1` looks like a smoke-test setting - confirm
# before using the results
n_epochs = 1
n_epochs_all_gps = 20
lr = 0.001
# NOTE(review): not referenced in the visible part of the notebook;
# presumably used when mapping the query - TODO confirm
query_cond_embed_lr = 0.01
lambda_edge_recon = 0.01
lambda_gene_expr_recon = 0.0033
edge_batch_size = 512 # 128
node_batch_size = 32 # 16

# Benchmarking
# Key prefixes under which the spatial and latent nearest neighbor graphs are
# stored (passed as `key_added` to `sc.pp.neighbors` below)
spatial_knng_key = "autotalker_spatial_knng"
latent_knng_key = "autotalker_latent_knng"

## Others
# NOTE(review): defined but not passed to any seeding call in the visible part
# of the notebook - confirm that runs are actually seeded
random_seed = 42
# Timestamp of a previous run whose artifacts should be loaded; with `None`
# the artifacts folder of the current run is used instead
load_timestamp = None

### 1.3 Run Notebook Setup

In [6]:
sc.set_figure_params(figsize=(6, 6))
sns.set_style("whitegrid", {'axes.grid' : False})

  IPython.display.set_matplotlib_formats(*ipython_format)


In [7]:
# Ignore future warnings and user warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

In [8]:
# Get time of notebook execution for timestamping saved artifacts
now = datetime.now()
current_timestamp = now.strftime("%d%m%Y_%H%M%S")

### 1.4 Configure Paths and Create Directories

In [9]:
# Output folders are namespaced by dataset and by the timestamp of this run
run_subfolder = f"{dataset}/batch_integration/{current_timestamp}"
figure_folder_path = f"../figures/{run_subfolder}"
model_artifacts_folder_path = f"../artifacts/{run_subfolder}"

# Input folders
gp_data_folder_path = "../datasets/gp_data" # gene program data
srt_data_folder_path = "../datasets/srt_data" # spatially resolved transcriptomics data
srt_data_gold_folder_path = f"{srt_data_folder_path}/gold"

# Input files with prior knowledge used to build the gene programs
nichenet_ligand_target_mx_file_path = f"{gp_data_folder_path}/nichenet_ligand_target_matrix.csv"
omnipath_lr_interactions_file_path = f"{gp_data_folder_path}/omnipath_lr_interactions.csv"

# Make sure both output folders exist before anything is written to them
for output_folder_path in (figure_folder_path, model_artifacts_folder_path):
    os.makedirs(output_folder_path, exist_ok=True)

## 2. Data

### 2.1 Load Data

In [None]:
adata_original = ad.read_h5ad(f"{srt_data_gold_folder_path}/{dataset}.h5ad")

In [None]:
# Use log normalized counts as raw counts are not available
adata_original.layers[counts_key] = adata_original.X

In [None]:
adata_original.obs

### 2.2 Prepare Data & GP Mask

#### 2.2.1 Filter Genes Based on GP Mask & HVG

In [None]:
nichenet_gp_dict = extract_gp_dict_from_nichenet_ligand_target_mx(
    keep_target_ratio=0.01,
    load_from_disk=False,
    save_to_disk=False,
    file_path=nichenet_ligand_target_mx_file_path)

In [None]:
omnipath_gp_dict = extract_gp_dict_from_omnipath_lr_interactions(
    min_curation_effort=0,
    load_from_disk=False,
    save_to_disk=False,
    file_path=omnipath_lr_interactions_file_path)

In [None]:
mebocost_gp_dict = extract_gp_dict_from_mebocost_es_interactions(
    dir_path=f"{gp_data_folder_path}/metabolite_enzyme_sensor_gps/",
    species="mouse",
    genes_uppercase=True)

In [None]:
# Merge the three gene program dictionaries into one; for duplicate keys the
# later source wins (NicheNet < OmniPath < MEBOCOST)
combined_gp_dict = {**nichenet_gp_dict,
                    **omnipath_gp_dict,
                    **mebocost_gp_dict}

In [None]:
# Filter and combine gene programs
combined_new_gp_dict = filter_and_combine_gp_dict_gps(
    gp_dict=combined_gp_dict,
    gp_filter_mode="subset",
    combine_overlap_gps=True,
    overlap_thresh_source_genes=0.9,
    overlap_thresh_target_genes=0.9,
    overlap_thresh_genes=0.9,
    verbose=True)

print(f"Number of gene programs before filtering and combining: {len(combined_gp_dict)}.")
print(f"Number of gene programs after filtering and combining: {len(combined_new_gp_dict)}.")

In [None]:
adata = adata_original.copy()

In [None]:
# Split adata into one independent AnnData copy per batch
# (batches 1-4 serve as reference, batches 5-6 as query)
batch_labels = (batch1, batch2, batch3, batch4, batch5, batch6)
adata_batch_list = [adata[adata.obs["batch"] == batch_label].copy()
                    for batch_label in batch_labels]

# Keep the individual per-batch names available; they refer to the very same
# objects that are stored in the list
(adata_batch1,
 adata_batch2,
 adata_batch3,
 adata_batch4,
 adata_batch5,
 adata_batch6) = adata_batch_list

### 2.3 Compute Spatial Neighbor Graphs

In [None]:
for batch_adata in adata_batch_list:
    # Each batch gets its own spatial neighborhood graph, computed from the
    # coordinates of that batch only
    sq.gr.spatial_neighbors(batch_adata,
                            coord_type="generic",
                            spatial_key=spatial_key,
                            n_neighs=n_neighbors)
    # Symmetrize the adjacency matrix: keep an edge if it exists in either
    # direction
    batch_adjacency = batch_adata.obsp[adj_key]
    batch_adata.obsp[adj_key] = batch_adjacency.maximum(batch_adjacency.T)

### 2.4 Combine Data for One-Shot Batch Integration

In [None]:
adata_one_shot = ad.concat(adata_batch_list, join="inner")

# Combine spatial neighborhood graphs as disconnected components: the
# per-batch adjacency matrices are placed on the diagonal of one large sparse
# matrix, in the same order in which the batches were concatenated, and all
# off-diagonal (between-batch) blocks stay empty. `block_diag` derives every
# block size from the matrices themselves, so this works for any number of
# batches.
connectivities = sp.block_diag(
    [batch_adata.obsp[adj_key] for batch_adata in adata_batch_list],
    format="csr")

adata_one_shot.obsp[adj_key] = connectivities

### 2.5 Combine Data for Query-to-Reference Mapping

In [None]:
# Reference: all batches except the last two
reference_batch_list = adata_batch_list[:-2]
adata_reference = ad.concat(reference_batch_list, join="inner")

# Combine spatial neighborhood graphs as disconnected components: the
# per-batch adjacency matrices go on the diagonal of one sparse matrix, in
# concatenation order, with empty between-batch blocks. Building the matrix
# from the same list that is concatenated keeps rows and observations aligned.
connectivities = sp.block_diag(
    [batch_adata.obsp[adj_key] for batch_adata in reference_batch_list],
    format="csr")

adata_reference.obsp[adj_key] = connectivities

In [None]:
# Query: the last two batches
query_batch_list = adata_batch_list[-2:]
adata_query = ad.concat(query_batch_list, join="inner")

# Combine spatial neighborhood graphs as disconnected components: the
# per-batch adjacency matrices go on the diagonal of one sparse matrix, in
# concatenation order, with empty between-batch blocks.
connectivities = sp.block_diag(
    [batch_adata.obsp[adj_key] for batch_adata in query_batch_list],
    format="csr")

adata_query.obsp[adj_key] = connectivities

In [None]:
adata_reference_query = adata_one_shot.copy()

# Label every observation as "reference" or "query" depending on which of the
# two datasets its batch belongs to
reference_batches = adata_reference.obs[condition_key].unique().tolist()
query_batches = adata_query.obs[condition_key].unique().tolist()
adata_reference_query.obs[mapping_entity_key] = None
for entity_label, entity_batches in (("reference", reference_batches),
                                     ("query", query_batches)):
    belongs_to_entity = adata_reference_query.obs[condition_key].isin(entity_batches)
    adata_reference_query.obs.loc[belongs_to_entity, mapping_entity_key] = entity_label

### 2.6 Add GP Mask to Data

In [None]:
# NOTE(review): the loop variable rebinds the notebook-level name `adata`;
# after this cell `adata` refers to `adata_reference_query`. Later cells rely
# on `adata`, so the name is kept as is - confirm this is intended.
for adata in adata_batch_list + [adata_one_shot, adata_reference, adata_query, adata_reference_query]:
    # Add the gene program dictionary as binary masks to the adata for model training
    add_gps_from_gp_dict_to_adata(
        gp_dict=combined_new_gp_dict,
        adata=adata,
        genes_uppercase=True,
        gp_targets_mask_key=gp_targets_mask_key,
        gp_sources_mask_key=gp_sources_mask_key,
        gp_names_key=gp_names_key,
        min_genes_per_gp=1,
        min_source_genes_per_gp=0,
        min_target_genes_per_gp=0,
        max_genes_per_gp=None,
        max_source_genes_per_gp=None,
        max_target_genes_per_gp=None,
        filter_genes_not_in_masks=False)

# Determine dimensionality of hidden encoder (one unit per gene program)
one_shot_gp_names = list(adata_one_shot.uns[gp_names_key])
n_hidden_encoder = len(one_shot_gp_names)

# Summarize gene programs; the number of printed examples is capped at the
# number of available gene programs so that `random.sample` cannot raise
n_example_gps = min(5, len(one_shot_gp_names))
print(f"Number of gene programs with probed genes: {len(one_shot_gp_names)}.")
print(f"Example gene programs: {random.sample(one_shot_gp_names, n_example_gps)}.")
print(f"Number of gene program target genes: {adata_one_shot.varm[gp_targets_mask_key].sum()}.")
print(f"Number of gene program source genes: {adata_one_shot.varm[gp_sources_mask_key].sum()}.")

## 3. Model Training

### 3.2 One-Shot Batch Integration

#### 3.2.1 Initialize, Train & Save Model

In [None]:
model.adata.obsp["autotalker_recon_adj"] = model.get_recon_adj()

In [None]:
model.adata.uns["autotalker_recon_adj_best_acc_threshold"]

In [None]:
(model.adata.obsp["autotalker_recon_adj"] > model.adata.uns["autotalker_recon_adj_best_acc_threshold"]).sum(1)

In [None]:
model.adata.obsp["autotalker_recon_adj"][1, :500]

In [None]:
adata.layers["counts"] = adata.X

In [None]:
counts_key = "counts"

In [None]:
adata_test = adata_one_shot[17500:18000]

In [None]:
adata_test.obs["sample"].unique()

In [None]:
adata = adata[adata.obs["sample"] == "embryo1"]

In [None]:
torch.cuda.empty_cache()
import gc
gc.collect()

In [None]:
# Initialize model
# The batch label (`condition_key`) is injected as a conditional embedding
# into the encoder and into both decoders.
# NOTE(review): this cell sits in the "One-Shot Batch Integration" section but
# is given `adata` (whatever the preceding cells left bound to that name)
# rather than `adata_one_shot` - confirm which dataset is intended.
# NOTE(review): `n_layers_encoder=1` is set here but not for the reference
# model further below - confirm whether both should agree.
model = Autotalker(adata,
                   counts_key=counts_key,
                   adj_key=adj_key,
                   condition_key=condition_key,
                   cond_embed_injection=["encoder",
                                         "gene_expr_decoder",
                                         "graph_decoder"],
                   n_cond_embed=n_cond_embed,
                   gp_names_key=gp_names_key,
                   active_gp_names_key=active_gp_names_key,
                   gp_targets_mask_key=gp_targets_mask_key,
                   gp_sources_mask_key=gp_sources_mask_key,
                   latent_key=latent_key,
                   active_gp_thresh_ratio=active_gp_thresh_ratio,
                   gene_expr_recon_dist=gene_expr_recon_dist,
                   n_layers_encoder=1,
                   n_hidden_encoder=n_hidden_encoder,
                   log_variational=log_variational)

In [None]:
# Train model
# NOTE(review): the number of epochs (10) and both batch sizes (1024 / 64) are
# hard-coded here, whereas the reference model below is trained with the
# `n_epochs`, `edge_batch_size` and `node_batch_size` parameters - confirm
# whether this cell should use the parameters as well.
model.train(10,
            n_epochs_all_gps=n_epochs_all_gps,
            lr=lr,
            lambda_edge_recon=lambda_edge_recon,
            lambda_gene_expr_recon=lambda_gene_expr_recon,
            edge_batch_size=1024,
            node_batch_size=64,
            verbose=True)

In [None]:
# Save trained model
model.save(dir_path=model_artifacts_folder_path + "/test",
           overwrite=True,
           save_adata=True,
           adata_file_name=f"{dataset}.h5ad")

In [10]:
model_artifacts_folder_path = '../artifacts/seqfish_mouse_organogenesis/batch_integration/20022023_121527'
# Load trained model

model = Autotalker.load(dir_path=model_artifacts_folder_path + "/test",
                        adata=None,
                        adata_file_name=f"{dataset}.h5ad",
                        gp_names_key="autotalker_gp_names")

--- INITIALIZING NEW NETWORK MODULE: VARIATIONAL GENE PROGRAM GRAPH AUTOENCODER ---
LOSS -> include_edge_recon_loss: True, include_gene_expr_recon_loss: True, gene_expr_recon_dist: nb
NODE LABEL METHOD -> one-hop-attention
ACTIVE GP THRESHOLD RATIO -> 0.03
LOG VARIATIONAL -> False
CONDITIONAL EMBEDDING INJECTION -> ['encoder', 'gene_expr_decoder', 'graph_decoder']
GRAPH ENCODER -> n_input: 351, n_cond_embed_input: 3, n_layers: 1, n_hidden: 489, n_latent: 489, n_addon_latent: 0, conv_layer: gcnconv, n_attention_heads: 0, dropout_rate: 0.0
COSINE SIM GRAPH DECODER -> n_cond_embed_input: 3, n_output: 489, dropout_rate: 0.0
MASKED GENE EXPRESSION DECODER -> n_input: 489, n_cond_embed_input: 3, n_addon_input: 0, n_output: 702


In [11]:
n_neighs_adj = (model.adata.obsp[adj_key][1: 4].sum(axis=1))

In [12]:
np.array(n_neighs_adj.astype(int))

array([[5],
       [4],
       [5]])

In [15]:
test = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]])

In [19]:
test

tensor([[1, 2, 3, 4, 5, 6],
        [1, 2, 3, 4, 5, 6],
        [1, 2, 3, 4, 5, 6]])

In [20]:
np.arange(3)

array([0, 1, 2])

In [21]:
np.array(n_neighs_adj.astype(int))

array([[5],
       [4],
       [5]])

In [23]:
test[np.arange(3), np.array(n_neighs_adj.astype(int)).flatten()]

tensor([6, 5, 6])

In [39]:
test = model.get_recon_adj(edge_thresh=None) #device="cpu")

ok
torch.Size([2048, 52568])
[5 5 4 ... 5 4 4]
(2048,)
9467.0
9467.0


In [None]:
model.adata.obsp["spatial_connectivities"][i: i+node_batch_size].sum(axis=1).item()

In [None]:
test.sum()

In [None]:
model.adata.uns["autotalker_recon_adj_best_acc_threshold"]

In [None]:
model.adata.uns["autotalker_recon_adj_best_acc_threshold"]

In [None]:
from autotalker.utils import get_recon_adj_from_edge_probs

In [None]:
sim_matrix(test, test).shape

In [None]:
model.adata.obsp["autotalker_agg_alpha"] = model.get_gene_expr_node_label_agg_att_weights()

In [None]:
model.adata.obsp["autotalker_agg_alpha"].nonzero()

In [None]:
from autotalker.utils import aggregate_node_label_agg_att_weights_per_cell_type

In [None]:
model.adata.obs[cell_type_key]

In [None]:
model.adata.obsp["autotalker_agg_alpha"]

In [None]:
from chord import Chord

In [None]:
import pandas as pd
import holoviews as hv
from holoviews import opts, dim
from bokeh.sampledata.les_mis import data

hv.extension('bokeh')
hv.output(size=200)

In [None]:
df = aggregate_node_label_agg_att_weights_per_cell_type(model.adata,
                                                   cell_type_key=cell_type_key).groupby(cell_type_key).sum()

In [None]:
df

In [None]:
from autotalker.utils import aggregate_node_label_agg_att_weights_per_cell_type, create_cell_type_chord_plot_from_df

In [None]:
df = aggregate_node_label_agg_att_weights_per_cell_type(model.adata,
                                                        cell_type_key=cell_type_key).groupby(cell_type_key).sum()

In [None]:
df

In [None]:
create_cell_type_chord_plot_from_df(adata=model.adata,
                                    df=df,
                                    link_threshold=0.2,
                                    cell_type_key=cell_type_key)

In [None]:
sorted_cell_types = sorted(model.adata.obs[cell_type_key].unique().tolist())
n_sorted_cell_types = len(sorted_cell_types)

# One link per non-zero entry of the cell type x cell type table `df`;
# source and target are positional indices into `sorted_cell_types`
links = pd.DataFrame([
    {"source": source_idx,
     "target": target_idx,
     "value": df.iloc[source_idx, target_idx]}
    for source_idx in range(n_sorted_cell_types)
    for target_idx in range(n_sorted_cell_types)
    if df.iloc[source_idx, target_idx] != 0])

# One node per cell type; all nodes share the same group
nodes = hv.Dataset(
    pd.DataFrame([{"name": cell_type, "group": 1}
                  for cell_type in sorted_cell_types]),
    "index")

# Only links with a value of at least 5 are drawn
chord = hv.Chord((links, nodes)).select(value=(5, None))
chord_options = opts.Chord(cmap='Category20',
                           edge_cmap='Category20',
                           edge_color=dim('source').str(),
                           labels='name',
                           node_color=dim('index').str())
chord.opts(chord_options)

In [None]:
model.adata.obs['celltype_mapped_refined'] = model.adata.obs['celltype_mapped_refined'].astype("category")

In [None]:
import squidpy as sq
sq.gr.nhood_enrichment(model.adata, cluster_key=cell_type_key, connectivity_key="autotalker_agg_alpha")

In [None]:
model.adata.obsp['autotalker_agg_alpha_connectivities'] = model.adata.obsp['autotalker_agg_alpha']

In [None]:
# Aggregate the attention weights stored in `obsp["autotalker_agg_alpha"]`
# per observation (rows) and per cell type of the neighbor (columns).

# All sizes are taken from `model.adata`, i.e. from the object the weights
# belong to, so that the result always matches the weight matrix
n_obs = model.adata.n_obs
n_cell_types = model.adata.obs[cell_type_key].nunique()

# Encode cell types as consecutive integers in alphabetical order
cell_type_label_encoder = {
    cell_type: code for code, cell_type in enumerate(
        sorted(model.adata.obs[cell_type_key].unique().tolist()))}

# COO format yields row index, column index and value of every stored entry
# as three aligned arrays (explicitly stored zeros simply add 0 below)
agg_alpha_coo = sp.coo_matrix(model.adata.obsp["autotalker_agg_alpha"])
nz_alpha_idx = (agg_alpha_coo.row, agg_alpha_coo.col)
nz_alpha = agg_alpha_coo.data

# Integer cell type code of every observation, looked up by position for the
# neighbor (column) of each stored weight
cell_type_codes = model.adata.obs[cell_type_key].map(
    cell_type_label_encoder).to_numpy(dtype=int)
neighbor_cell_type_index = cell_type_codes[nz_alpha_idx[1]]

# `np.add.at` accumulates repeated (row, cell type) pairs instead of
# overwriting them
cell_type_agg_alpha = np.zeros((n_obs, n_cell_types))
np.add.at(cell_type_agg_alpha, (nz_alpha_idx[0], neighbor_cell_type_index), nz_alpha)
cell_type_agg_alpha

In [None]:
cell_type_agg_alpha

In [None]:
neighbor_cell_type_index.shape

In [None]:
nz_alpha

In [None]:
import numpy as np

# Toy example: scatter-add `values` into a dense matrix at (row, col)
rows = np.array([0, 1, 1, 2])
cols = np.array([0, 1, 1, 2])
values = np.array([1, 2, 3, 4])

# Distinct column indices and, for each entry of `cols`, its position among them
unique_cols, col_indices = np.unique(cols, return_inverse=True)

# The result is just large enough to hold the largest row and column index
n_result_rows = rows.max() + 1
n_result_cols = cols.max() + 1
result = np.zeros((n_result_rows, n_result_cols), dtype=values.dtype)

# Entries that share the same (row, col) position are summed up
np.add.at(result, (rows, cols), values)

print(result)  # Output: [[1 0 0]
               #          [0 5 0]
               #          [0 0 4]]

In [None]:
values[col_indices]

In [None]:
result

In [None]:
neighbor_cell_type_index

In [None]:
cell_type_agg_alpha

In [None]:
model.adata.obsp["autotalker_agg_alpha"][6,7]

In [None]:
model.adata.obsp["autotalker_agg_alpha"][6,20]

In [None]:
model.adata.obsp["autotalker_agg_alpha"][6,21]

In [None]:
model.adata.obs[cell_type_key][model.adata.obsp["autotalker_agg_alpha"].nonzero()[1]]

#### 3.2.2 Load & Save Model with UMAP Features

In [None]:
load_timestamp = "16022023_173451"

In [None]:
model_artifacts_load_folder_path

In [None]:
# Artifacts are read from an earlier run if a timestamp is given and from the
# folder of the current run otherwise
model_artifacts_load_folder_path = (
    model_artifacts_folder_path
    if load_timestamp is None
    else f"../artifacts/{dataset}/batch_integration/{load_timestamp}")

# Load trained model together with the AnnData object saved next to it
oneshot_model_dir_path = model_artifacts_load_folder_path + "/oneshot"
model = Autotalker.load(dir_path=oneshot_model_dir_path,
                        adata=None,
                        adata_file_name=f"{dataset}.h5ad",
                        gp_names_key="autotalker_gp_names")

In [None]:
model.adata.obsp["autotalker_agg_alpha"] = model.get_gene_expr_node_label_agg_att_weights()

In [None]:
# Store active gene programs
model.adata.uns[active_gp_names_key] = model.get_active_gps()

In [None]:
# Compute latent nearest neighbor graph
sc.pp.neighbors(model.adata,
                use_rep=latent_key,
                key_added=latent_knng_key)

# Use Autotalker latent space for UMAP generation
sc.tl.umap(model.adata,
           neighbors_key=latent_knng_key)

# Compute spatial nearest neighbor graph
sc.pp.neighbors(model.adata,
                use_rep=spatial_key,
                key_added=spatial_knng_key)

In [None]:
# Save trained model
model.save(dir_path=model_artifacts_folder_path + "/oneshot",
           overwrite=True,
           save_adata=True,
           adata_file_name=f"{dataset}.h5ad")

In [None]:
# Save adata
model.adata.write(f"{model_artifacts_folder_path}/{dataset}_oneshot.h5ad")

#### 3.2.3 Visualize Latent Space

In [None]:
# Plot UMAP with batch annotations
fig = sc.pl.umap(model.adata,
                 color=[condition_key],
                 legend_fontsize=12,
                 return_fig=True)
plt.title("One-Shot Integration: Latent Space Batch Annotations", size=20, pad=15)
fig.savefig(f"{figure_folder_path}/latent_batches_oneshot.png",
            bbox_inches="tight")

In [None]:
# Plot UMAP with cell type annotations
fig = sc.pl.umap(model.adata,
                 color=[cell_type_key],
                 return_fig=True)
plt.title("One-Shot Integration: Latent Space Cell Type Annotations", size=20, pad=15)
fig.savefig(f"{figure_folder_path}/latent_cell_types_oneshot.png",
            bbox_inches="tight")

#### 3.2.4 Compute Metrics

In [None]:
# Make sure the latent nearest neighbor graph is available under the
# `latent_knng_key` prefix. `sc.pp.neighbors(..., key_added=latent_knng_key)`
# already stores it there and does not write the default "connectivities" /
# "distances" entries, so an existing prefixed graph is kept untouched.
latent_connectivities_key = f"{latent_knng_key}_connectivities"
latent_distances_key = f"{latent_knng_key}_distances"
if latent_connectivities_key not in model.adata.obsp:
    if "connectivities" in model.adata.obsp and "distances" in model.adata.obsp:
        # Only a default-keyed graph exists: store it in latent connectivities
        model.adata.obsp[latent_connectivities_key] = (
            model.adata.obsp["connectivities"])
        model.adata.obsp[latent_distances_key] = (
            model.adata.obsp["distances"])
    else:
        # No graph available at all: compute it from the latent representation
        sc.pp.neighbors(model.adata,
                        use_rep=latent_key,
                        key_added=latent_knng_key)

# Compute spatial nearest neighbor graph
sc.pp.neighbors(model.adata, use_rep=spatial_key, key_added=spatial_knng_key)

In [None]:
# Compute metrics
# All metrics of the one-shot integration are collected in one dictionary,
# printed and pickled into the artifacts folder of the current run
metrics_dict_oneshot = {}

# Spatial conservation metrics
# Both metrics are given the spatial and the latent nearest neighbor graph keys
metrics_dict_oneshot["cad"] = compute_cad(
    adata=model.adata,
    cell_type_key=cell_type_key,
    spatial_knng_key=spatial_knng_key,
    latent_knng_key=latent_knng_key)
metrics_dict_oneshot["rclisi"] = compute_rclisi(
    adata=model.adata,
    cell_type_key=cell_type_key,
    spatial_knng_key=spatial_knng_key,
    latent_knng_key=latent_knng_key)
    
# Batch correction metrics
# NOTE(review): the batch silhouette is computed on the 2D UMAP coordinates
# ("X_umap"), not on the latent representation - confirm this is intended
metrics_dict_oneshot["batch_asw"] = scib.me.silhouette_batch(
    adata=model.adata,
    batch_key=condition_key,
    label_key=cell_type_key,
    embed="X_umap")
# NOTE(review): with `type_="knn"` this presumably reads a precomputed
# neighbor graph from the adata - confirm which graph is picked up
metrics_dict_oneshot["ilisi"] = scib.me.ilisi_graph(
    adata=model.adata,
    batch_key=condition_key,
    type_="knn")

print(metrics_dict_oneshot)

# Store metrics to disk
with open(f"{model_artifacts_folder_path}/metrics_oneshot.pickle", "wb") as f:
    pickle.dump(metrics_dict_oneshot, f)

#### 3.2.5 Visualize Conditional Embedding

In [None]:
# Get conditional embeddings and the batch labels used to color them.
# The labels are read via `condition_key`, i.e. from the same column the model
# was conditioned on, instead of a hard-coded column name.
# NOTE(review): this assumes the rows of the conditional embedding are in the
# order in which the batches first appear in `obs` - confirm against the model
cond_embed = model.get_conditional_embeddings()
cond = model.adata.obs[condition_key].unique()

# Get top 2 principal components and plot them
pca = KernelPCA(n_components=2, kernel="linear")
cond_embed_pca = pca.fit_transform(cond_embed)
sns.scatterplot(x=cond_embed_pca[:, 0],
                y=cond_embed_pca[:, 1],
                hue=cond)
plt.title("One-Shot Integration Conditional Embeddings", pad=15)
plt.xlabel("Principal Component 1")
plt.xticks(fontsize=12)
plt.ylabel("Principal Component 2")
plt.yticks(fontsize=12)
# Place the legend outside of the axes, to the right of the plot
plt.legend(bbox_to_anchor=(1.02, 0.75),
           loc=2,
           borderaxespad=0.,
           fontsize=12,
           frameon=False)
plt.savefig(f"{figure_folder_path}/cond_embed_oneshot.png",
            bbox_inches="tight")

### 3.3 Query-to-Reference Mapping

#### 3.3.1 Building the Reference

##### 3.3.1.1 Initialize, Train & Save Model

In [None]:
# Initialize model
model = Autotalker(adata_reference,
                   counts_key=counts_key,
                   adj_key=adj_key,
                   condition_key=condition_key,
                   cond_embed_injection=["encoder",
                                         "gene_expr_decoder",
                                         "graph_decoder"],
                   n_cond_embed=n_cond_embed,
                   gp_names_key=gp_names_key,
                   active_gp_names_key=active_gp_names_key,
                   gp_targets_mask_key=gp_targets_mask_key,
                   gp_sources_mask_key=gp_sources_mask_key,
                   latent_key=latent_key,
                   active_gp_thresh_ratio=active_gp_thresh_ratio,
                   gene_expr_recon_dist=gene_expr_recon_dist,
                   n_hidden_encoder=n_hidden_encoder,
                   log_variational=log_variational)

In [None]:
# Fit the reference model with the configured learning rate, loss weights and
# edge / node batch sizes
model.train(
    n_epochs,
    n_epochs_all_gps=n_epochs_all_gps,
    lr=lr,
    lambda_edge_recon=lambda_edge_recon,
    lambda_gene_expr_recon=lambda_gene_expr_recon,
    edge_batch_size=edge_batch_size,
    node_batch_size=node_batch_size,
    verbose=True,
)

In [None]:
# Write the trained reference model (including its adata) to the "reference"
# subfolder, replacing any previous save
model.save(
    dir_path=f"{model_artifacts_folder_path}/reference",
    overwrite=True,
    save_adata=True,
    adata_file_name=f"{dataset}.h5ad",
)

##### 3.3.1.2 Load Model

In [None]:
# Read from the artifacts of an earlier run if a timestamp was specified,
# otherwise from the artifacts of the current run
model_artifacts_load_folder_path = (
    f"../artifacts/{dataset}/batch_integration/{load_timestamp}"
    if load_timestamp is not None
    else model_artifacts_folder_path)

# Restore the trained reference model; no adata is passed (`adata=None`), only
# the name of the adata file in the model folder
model = Autotalker.load(
    dir_path=f"{model_artifacts_load_folder_path}/reference",
    adata=None,
    adata_file_name=f"{dataset}.h5ad",
    gp_names_key="autotalker_gp_names",
)

##### 3.3.1.3 Visualize Latent Space

In [None]:
# Build the neighbor graph on the Autotalker latent representation and compute
# a UMAP from it
sc.pp.neighbors(adata=model.adata, use_rep=latent_key)
sc.tl.umap(adata=model.adata)

In [None]:
# Reference UMAP colored by batch; the figure is saved to the figure folder
fig = sc.pl.umap(
    model.adata,
    color=[condition_key],
    legend_fontsize=12,
    return_fig=True,
)
plt.title("Reference: Latent Space Batch Annotations", size=20, pad=15)
fig.savefig(
    f"{figure_folder_path}/latent_batches_reference.png",
    bbox_inches="tight",
)

In [None]:
# Reference UMAP colored by cell type; the figure is saved to the figure
# folder
fig = sc.pl.umap(
    model.adata,
    color=[cell_type_key],
    return_fig=True,
)
plt.title("Reference: Latent Space Cell Type Annotations", size=20, pad=15)
fig.savefig(
    f"{figure_folder_path}/latent_cell_types_reference.png",
    bbox_inches="tight",
)

#### 3.3.2 Mapping the Query

##### 3.3.2.1 Initialize, Train & Save Model

In [None]:
# Read from the artifacts of an earlier run if a timestamp was specified,
# otherwise from the artifacts of the current run
model_artifacts_load_folder_path = (
    f"../artifacts/{dataset}/batch_integration/{load_timestamp}"
    if load_timestamp is not None
    else model_artifacts_folder_path)

# Transfer learning: load the reference model with the query data attached.
# Only the conditional embedding weights are unfrozen, all other weights stay
# frozen.
model = Autotalker.load(
    dir_path=f"{model_artifacts_load_folder_path}/reference",
    adata=adata_query,
    adata_file_name=f"{dataset}.h5ad",
    gp_names_key="autotalker_gp_names",
    unfreeze_all_weights=False,
    unfreeze_cond_embed_weights=True,
)

In [None]:
# Fit the model on the query data, using the dedicated learning rate for the
# query conditional embedding
model.train(
    n_epochs,
    n_epochs_all_gps=n_epochs_all_gps,
    lr=query_cond_embed_lr,
    lambda_edge_recon=lambda_edge_recon,
    lambda_gene_expr_recon=lambda_gene_expr_recon,
    verbose=True,
)

In [None]:
# Write the query-trained model (including its adata) to the "query"
# subfolder, replacing any previous save
model.save(
    dir_path=f"{model_artifacts_folder_path}/query",
    overwrite=True,
    save_adata=True,
    adata_file_name=f"{dataset}.h5ad",
)

##### 3.3.2.2 Load & Save Model with Full Dataset & UMAP Features

In [None]:
# Read from the artifacts of an earlier run if a timestamp was specified,
# otherwise from the artifacts of the current run
model_artifacts_load_folder_path = (
    f"../artifacts/{dataset}/batch_integration/{load_timestamp}"
    if load_timestamp is not None
    else model_artifacts_folder_path)

# Restore the query-trained model and attach the combined reference + query
# data to it
model = Autotalker.load(
    dir_path=f"{model_artifacts_load_folder_path}/query",
    adata=adata_reference_query,
    adata_file_name=f"{dataset}.h5ad",
    gp_names_key="autotalker_gp_names",
)

In [None]:
# Compute the latent representation with the model and keep it in `.obsm`
# under `latent_key`
model.adata.obsm[latent_key] = model.get_latent_representation(
    counts_key=counts_key, condition_key=condition_key)

In [None]:
# Retrieve the active gene programs from the model and keep them in `.uns`
# under `active_gp_names_key`
active_gp_names = model.get_active_gps()
model.adata.uns[active_gp_names_key] = active_gp_names

In [None]:
# Nearest neighbor graph in latent space, stored under `latent_knng_key`
sc.pp.neighbors(model.adata, use_rep=latent_key, key_added=latent_knng_key)

# UMAP computed from the latent nearest neighbor graph
sc.tl.umap(model.adata, neighbors_key=latent_knng_key)

# Nearest neighbor graph in physical space, stored under `spatial_knng_key`
sc.pp.neighbors(model.adata, use_rep=spatial_key, key_added=spatial_knng_key)

In [None]:
# Write the model (including its adata) to the "reference_query" subfolder,
# replacing any previous save
model.save(
    dir_path=f"{model_artifacts_folder_path}/reference_query",
    overwrite=True,
    save_adata=True,
    adata_file_name=f"{dataset}.h5ad",
)

In [None]:
# Additionally write the integrated adata as a standalone file
integrated_adata_file_path = f"{model_artifacts_folder_path}/{dataset}_integrated.h5ad"
model.adata.write(integrated_adata_file_path)

##### 3.3.2.3 Visualize Latent Space

In [None]:
# Reference + query UMAP colored by batch; the figure is saved to the figure
# folder
fig = sc.pl.umap(
    model.adata,
    color=[condition_key],
    legend_fontsize=12,
    return_fig=True,
)
plt.title("Reference + Query: Latent Space Batch Annotations", size=20, pad=15)
fig.savefig(
    f"{figure_folder_path}/latent_batches_reference_query.png",
    bbox_inches="tight",
)

In [None]:
# Reference + query UMAP colored by mapping entity; the figure is saved to the
# figure folder
fig = sc.pl.umap(
    model.adata,
    color=[mapping_entity_key],
    legend_fontsize=12,
    return_fig=True,
)
plt.title("Reference + Query: Latent Space Mapping Entity Annotations", size=20, pad=15)
fig.savefig(
    f"{figure_folder_path}/latent_mapping_entities_reference_query.png",
    bbox_inches="tight",
)

In [None]:
# Reference + query UMAP colored by cell type; the figure is saved to the
# figure folder
fig = sc.pl.umap(
    model.adata,
    color=[cell_type_key],
    return_fig=True,
)
plt.title("Reference + Query: Latent Space Cell Type Annotations", size=20, pad=15)
fig.savefig(
    f"{figure_folder_path}/latent_cell_types_reference_query.png",
    bbox_inches="tight",
)

##### 3.3.2.4 Compute Metrics

In [None]:
# Evaluate the integrated reference + query data; the entries are computed in
# the order in which they are listed
metrics_dict_reference_query = {
    # Spatial conservation metrics
    "cad": compute_cad(
        adata=model.adata,
        cell_type_key=cell_type_key,
        spatial_knng_key=spatial_knng_key,
        latent_knng_key=latent_knng_key),
    "rclisi": compute_rclisi(
        adata=model.adata,
        cell_type_key=cell_type_key,
        spatial_knng_key=spatial_knng_key,
        latent_knng_key=latent_knng_key),
    # Batch correction metrics
    "batch_asw": scib.me.silhouette_batch(
        adata=model.adata,
        batch_key=condition_key,
        label_key=cell_type_key,
        embed=latent_key),
    "ilisi": scib.me.ilisi_graph(
        adata=model.adata,
        batch_key=condition_key,
        type_="embed",
        use_rep=latent_key),
}

print(metrics_dict_reference_query)

# Pickle the metrics next to the model artifacts
metrics_file_path = f"{model_artifacts_folder_path}/metrics_reference_query.pickle"
with open(metrics_file_path, "wb") as metrics_file:
    pickle.dump(metrics_dict_reference_query, metrics_file)

##### 3.3.2.5 Visualize Conditional Embedding

In [None]:
# Get conditional embeddings from the model
cond_embed = model.get_conditional_embeddings()

# Condition labels for the legend, read from the same `.obs` column that the
# rest of the notebook uses for batches (`condition_key`) instead of a
# hard-coded column name.
# NOTE(review): this assumes that the order returned by `unique()` matches the
# row order of `cond_embed` - confirm against the model's condition ordering.
cond = model.adata.obs[condition_key].unique()

# Project the conditional embeddings onto their top 2 principal components
pca = KernelPCA(n_components=2, kernel="linear")
cond_embed_pca = pca.fit_transform(cond_embed)

# Draw on a dedicated figure so that the scatter plot cannot end up on axes
# left over from an earlier plot (e.g. the UMAP figures created above)
fig, ax = plt.subplots()
sns.scatterplot(x=cond_embed_pca[:, 0],
                y=cond_embed_pca[:, 1],
                hue=cond,
                ax=ax)
ax.set_title("Reference + Query Conditional Embeddings", pad=15)
ax.set_xlabel("Principal Component 1")
ax.set_ylabel("Principal Component 2")
ax.tick_params(axis="both", labelsize=12)
ax.legend(bbox_to_anchor=(1.02, 0.75),
          loc=2,
          borderaxespad=0.,
          fontsize=12,
          frameon=False)
fig.savefig(f"{figure_folder_path}/cond_embed_reference_query.png",
            bbox_inches="tight")