# SageNet Batch Integration Mouse Organogenesis Imputed

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>).
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 30.01.2023
- **Date of Last Modification:** 01.02.2023

## 1. Setup

### 1.1 Import Libraries

In [13]:
# Add the relative path "../../../autotalker" to the module search path before the
# package imports below run (presumably so a local autotalker checkout can be
# imported without installing it; the path is relative to the working directory).
import sys
sys.path.append("../../../autotalker")

In [14]:
import copy
import os
import random
import warnings
from datetime import datetime

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scarches as sca
import scanpy as sc
import scipy.sparse as sp
import squidpy as sq
import torch
import torch_geometric.data as geo_dt
from scarches.models.sagenet.utils import glasso

from autotalker.utils import (extract_gp_dict_from_mebocost_es_interactions,
                              extract_gp_dict_from_nichenet_ligand_target_mx,
                              extract_gp_dict_from_omnipath_lr_interactions,
                              get_unique_genes_from_gp_dict)

ModuleNotFoundError: No module named 'pyreadr'

### 1.2 Define Parameters

In [3]:
## Dataset
dataset = "seqfish_mouse_organogenesis"
# Batch identifiers (embryo + z-slice)
batch1 = "embryo1_z2"
batch2 = "embryo1_z5"
batch3 = "embryo2_z2"
batch4 = "embryo2_z5"
batch5 = "embryo3_z2"
batch6 = "embryo3_z5"
# Number of neighbors per cell in the spatial neighborhood graphs
n_neighbors = 4
# Cell type that is removed from the reference batches so that it has to be
# recovered from the query batches
reference_removed_cell_type = "Presomitic mesoderm"

## AnnData keys
# Column in adata.obs that holds the cell type annotation; this is the column the
# reference filtering uses to drop `reference_removed_cell_type`, and the key the
# latent space plots color by
cell_type_key = "celltype_mapped_refined"
spatial_key = "spatial"

## Others
random_seed = 42
load_timestamp = None

### 1.3 Run Notebook Setup

In [4]:
# Use a square 6 x 6 default figure size for all scanpy plots in this notebook
sc.set_figure_params(figsize=(6, 6))

In [5]:
# Ignore future warnings and user warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

In [6]:
# Record when the notebook was executed; the formatted timestamp (day, month, year,
# then hour, minute, second) is used to tag the figure and artifact folders
now = datetime.now()
current_timestamp = f"{now:%d%m%Y_%H%M%S}"

### 1.4 Configure Paths and Create Directories

In [7]:
# Define paths
data_folder_path = "../../datasets/srt_data/gold/"
figure_folder_path = f"../../figures/{dataset}/batch_integration/{current_timestamp}"
model_artifacts_folder_path = f"../../artifacts/{dataset}/batch_integration/{current_timestamp}"

# Paths of the gene program resources read by the gene program extraction cells
# below; without these definitions those cells fail with a NameError on a fresh
# kernel.
# NOTE(review): the folder and file names are assumed (a "gp_data" folder next to
# "srt_data") -- confirm them against the actual repository layout.
gp_data_folder_path = "../../datasets/gp_data"
nichenet_ligand_target_mx_file_path = f"{gp_data_folder_path}/nichenet_ligand_target_matrix.csv"
omnipath_lr_interactions_file_path = f"{gp_data_folder_path}/omnipath_lr_interactions.csv"

# Create required directories (no error if they already exist)
os.makedirs(figure_folder_path, exist_ok=True)
os.makedirs(model_artifacts_folder_path, exist_ok=True)

## 2. Data

### 2.1 Load Data

In [8]:
# Load the imputed dataset. os.path.join is used because `data_folder_path` already
# ends with a separator, so plain string formatting would produce a doubled "//".
adata = ad.read_h5ad(os.path.join(data_folder_path, f"{dataset}_imputed.h5ad"))

### 2.2 Prepare Data for Model Training

In [9]:
# Extract the NicheNet gene programs; the ligand target matrix is loaded from
# `nichenet_ligand_target_mx_file_path` and not written back to disk
nichenet_gp_dict = extract_gp_dict_from_nichenet_ligand_target_mx(
    file_path=nichenet_ligand_target_mx_file_path,
    load_from_disk=True,
    save_to_disk=False,
    keep_target_ratio=0.01)

NameError: name 'extract_gp_dict_from_nichenet_ligand_target_mx' is not defined

In [78]:
# Extract the OmniPath ligand-receptor gene programs without a curation effort
# threshold; the interactions are not loaded from disk but saved to
# `omnipath_lr_interactions_file_path`
omnipath_gp_dict = extract_gp_dict_from_omnipath_lr_interactions(
    file_path=omnipath_lr_interactions_file_path,
    load_from_disk=False,
    save_to_disk=True,
    min_curation_effort=0)

In [79]:
# Extract the MEBOCOST metabolite enzyme-sensor gene programs for mouse, with gene
# names converted to upper case
mebocost_gp_dict = extract_gp_dict_from_mebocost_es_interactions(
    species="mouse",
    genes_uppercase=True,
    dir_path=f"{gp_data_folder_path}/metabolite_enzyme_sensor_gps/")

In [80]:
# Merge all gene programs into a single dictionary; for duplicate keys the later
# source wins (NicheNet < OmniPath < MEBOCOST)
combined_gp_dict = {**nichenet_gp_dict,
                    **omnipath_gp_dict,
                    **mebocost_gp_dict}

In [106]:
# Flag the genes that are relevant for the gene programs: ligand, receptor,
# metabolite enzyme and metabolite sensor genes plus highly variable genes
# (potential target genes of nichenet).
# NOTE(review): this cell only writes the boolean column adata.var["gp_relevant"];
# adata itself is not subset here.
omnipath_genes = get_unique_genes_from_gp_dict(
    gp_dict=omnipath_gp_dict,
    retrieved_gene_entities=["sources", "targets"])
mebocost_genes = get_unique_genes_from_gp_dict(
    gp_dict=mebocost_gp_dict,
    retrieved_gene_entities=["sources", "targets"])
nichenet_source_genes = get_unique_genes_from_gp_dict(
    gp_dict=nichenet_gp_dict,
    retrieved_gene_entities=["sources"])

# Deduplicate the gene names collected from the three gene program sources
gp_relevant_genes = list(set(omnipath_genes + mebocost_genes + nichenet_source_genes))

# Annotate the top 1000 highly variable genes across batches; subset=False keeps
# all genes in adata and only adds the "highly_variable" flag
sc.pp.highly_variable_genes(adata,
                            subset=False,
                            batch_key="batch",
                            n_top_genes=1000)

# A gene is relevant if it is highly variable or if its upper-cased name is part of
# a gene program
adata.var["gp_relevant"] = (adata.var["highly_variable"] |
                            adata.var.index.str.upper().isin(gp_relevant_genes))
print(f"Keeping {adata.var['gp_relevant'].sum()} relevant gene program genes.")

Keeping 3475 relevant gene program genes.


In [10]:
# Split adata into one object per batch. A cell belongs to a batch if its
# observation name starts with the embryo identifier and ends with the z-slice
# identifier. The first four batches (embryo1 and embryo2) are the reference, the
# last two (embryo3) are the query.
cell_ids = adata.obs.index
adata_batch_list = [
    adata[cell_ids.str.startswith(embryo) & cell_ids.str.endswith(z_slice)]
    for embryo in ("embryo1", "embryo2", "embryo3")
    for z_slice in ("z2", "z5")]

# Keep the individually named batch objects available as well
(adata_batch1,  # reference
 adata_batch2,  # reference
 adata_batch3,  # reference
 adata_batch4,  # reference
 adata_batch5,  # query
 adata_batch6) = adata_batch_list  # query

### 2.2 Remove Selected Cell Type from Reference for Recovery by Query

In [11]:
# Drop the selected cell type from every reference batch (all but the last two
# batches) so that it can only be recovered from the query batches
for ref_idx, ref_batch in enumerate(adata_batch_list[:-2]):
    is_kept = ref_batch.obs["celltype_mapped_refined"] != reference_removed_cell_type
    adata_batch_list[ref_idx] = ref_batch[is_kept]

### 2.3 Compute Spatial Neighbor Graphs

In [12]:
for adata_batch in adata_batch_list:
    # Build the spatial neighborhood graph of this batch only, so that no edges are
    # created between cells of different batches
    sq.gr.spatial_neighbors(adata_batch,
                            coord_type="generic",
                            spatial_key=spatial_key,
                            n_neighs=n_neighbors)
    # Symmetrize the adjacency matrix: keep an edge if it exists in either direction
    batch_graph = adata_batch.obsp["spatial_connectivities"]
    adata_batch.obsp["spatial_connectivities"] = batch_graph.maximum(batch_graph.T)

### 2.4 Combine Data for SageNet Integration

In [13]:
# Concatenate all batches along the observation axis, keeping only the variables
# shared by every batch
adata_sagenet = ad.concat(adata_batch_list, join="inner")

# Combine the per-batch spatial neighborhood graphs as disconnected components:
# the batch graphs are placed on the diagonal of one block-diagonal matrix, in the
# same order in which the batches were concatenated, and all entries between
# different batches are zero. This replaces manually padding every batch graph with
# zero blocks and stacking the rows, and it works for any number of batches.
# format="csr" makes the storage format explicit instead of relying on the default
# chosen by the stacking functions.
connectivities = sp.block_diag(
    [adata_batch.obsp["spatial_connectivities"] for adata_batch in adata_batch_list],
    format="csr")

adata_sagenet.obsp["spatial_connectivities"] = connectivities

In [14]:
# For every spatial reference batch (all but the last two batches): construct the
# gene interaction network and partition the spatial neighborhood graph with Leiden
# clustering at three resolutions.
# NOTE(review): the clustering uses random_state=0 here rather than `random_seed`.
for ref_batch in adata_batch_list[:-2]:
    print("Computing gene interaction network...")
    glasso(ref_batch, [0.25, 0.5])
    print("Computing Leiden clusters...")
    for resolution, partition_key in ((.05, 'leiden_0.05'),
                                      (.1, 'leiden_0.1'),
                                      (.5, 'leiden_0.5')):
        sc.tl.leiden(ref_batch,
                     resolution=resolution,
                     random_state=0,
                     key_added=partition_key,
                     adjacency=ref_batch.obsp["spatial_connectivities"])

Computing gene interaction network...



KeyboardInterrupt



In [None]:
# Select the compute device in this cell: `device` is needed by the model
# constructor below, so it has to exist before the model is created when the
# notebook is run top to bottom. Use the first GPU if available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define model object
sg_obj = sca.models.sagenet(device=device)

In [None]:
# Add every reference batch (all but the last two batches) to the model, using its
# three Leiden partitionings as community columns; the tag carries the zero-based
# position of the batch in the list
reference_batches = adata_batch_list[:-2]
for ref_idx, ref_batch in enumerate(reference_batches):
    sg_obj.add_ref(ref_batch,
                   tag=f'batch{ref_idx}',
                   comm_columns=['leiden_0.05', 'leiden_0.1', 'leiden_0.5'],
                   epochs=15,
                   verbose=False)

In [None]:
# Select the compute device: the first GPU if available, otherwise the CPU
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
print(device)

# Estimate gene interaction network; the adjacency matrix of the built graph is
# added under 'adata.varm["adj"]'
glasso(adata_sagenet, [0.25, 0.5])

# Compute spatial partitioning with 3 different resolutions to capture different
# granularities; each partitioning is added under 'adata.obs[<key>]'
for resolution, partition_key in ((.05, "leiden_0.05"),
                                  (.1, "leiden_0.1"),
                                  (.5, "leiden_0.5")):
    sc.tl.leiden(adata_sagenet,
                 resolution=resolution,
                 random_state=random_seed,
                 key_added=partition_key,
                 adjacency=adata_sagenet.obsp["spatial_connectivities"])

# Plot the three partitionings side by side in physical space
sc.pl.spatial(adata_sagenet,
              color=["leiden_0.05", "leiden_0.1", "leiden_0.5"],
              title=["leiden_0.05", "leiden_0.1", "leiden_0.5"],
              ncols=3,
              spot_size=.1,
              frameon=False,
              legend_loc=None)

print("Training SageNet model...")

# Define model object
sg_obj = sca.models.sagenet(device=device)

# Train on the combined data with the three partitionings as community columns
sg_obj.train(adata_sagenet,
             comm_columns=["leiden_0.05", "leiden_0.1", "leiden_0.5"],
             tag="integrated",
             epochs=15,
             verbose=False,
             importance=True)

In [None]:
    # NOTE(review): every line of this cell is indented one level although no
    # enclosing loop or function is visible in the notebook -- presumably this was
    # copied from the body of a loop over runs; confirm and dedent or restore the
    # loop header. Several names used here (adata_one_shot, run_number,
    # leiden_resolution, adata_original, latent_key) are not defined in any cell
    # visible above.

    # Map the query data with the trained SageNet model
    sg_obj.load_query_data(adata)
    
    # Use SageNet cell-cell-distances for UMAP generation
    # NOTE(review): the neighbor graph and UMAP are computed on `adata_one_shot`,
    # while the Leiden clustering and the plots below read from `adata` -- confirm
    # that both names are meant to refer to the same object.
    sc.pp.neighbors(adata_one_shot, use_rep="dist_map")
    sc.tl.umap(adata_one_shot)
    fig = sc.pl.umap(adata_one_shot,
                     color=[cell_type_key],
                     title="Latent Space with Cell Types: SageNet",
                     return_fig=True)
    fig.savefig(f"{figure_folder_path}/latent_sagenet_cell_types_run_{run_number}_{current_timestamp}.png",
                bbox_inches="tight")
    
    # Compute latent Leiden clustering
    sc.tl.leiden(adata=adata,
                 resolution=leiden_resolution,
                 random_state=random_seed,
                 key_added=f"latent_sagenet_leiden_{str(leiden_resolution)}")
    
    # Create subplot of latent Leiden cluster annotations in physical and latent space
    # (top panel: UMAP of the latent space, bottom panel: spatial coordinates)
    fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(6, 12))
    title = fig.suptitle(t="Latent and Physical Space with Leiden Clusters: SageNet")
    sc.pl.umap(adata=adata,
               color=[f"latent_sagenet_leiden_{str(leiden_resolution)}"],
               title=f"Latent Space with Leiden Clusters",
               ax=axs[0],
               show=False)
    sc.pl.spatial(adata=adata,
                  color=[f"latent_sagenet_leiden_{str(leiden_resolution)}"],
                  spot_size=0.03,
                  title=f"Physical Space with Leiden Clusters",
                  ax=axs[1],
                  show=False)

    # Create and position shared legend
    # (built from the handles of the top panel; the per-panel legends are removed)
    handles, labels = axs[0].get_legend_handles_labels()
    lgd = fig.legend(handles, labels, bbox_to_anchor=(1.25, 0.9185))
    axs[0].get_legend().remove()
    axs[1].get_legend().remove()

    # Adjust, save and display plot
    # (legend and title are passed as extra artists so that the tight bounding box
    # includes them)
    plt.subplots_adjust(wspace=0, hspace=0.2)
    fig.savefig(f"{figure_folder_path}/latent_physical_comparison_sagenet_leiden_run_{run_number}_{current_timestamp}.png",
                bbox_extra_artists=(lgd, title),
                bbox_inches="tight")
    plt.show()
    
    # Use UMAP embedding of cell-cell distances as latent features
    # (stored per run under the key latent_key + "_run<run_number>")
    adata_original.obsm[latent_key + f"_run{run_number}"] = adata.obsm["X_umap"]