## Tahoe-100M Embeddings

In [53]:
%load_ext autoreload
%autoreload 2

import sys
import os
import warnings
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import torch
import scanpy as sc
from sklearn.metrics import (
    silhouette_score, 
    adjusted_rand_score, 
    normalized_mutual_info_score,
    homogeneity_score,
)
import anndata as AnnData
import scvi
from scvi.model import SCVI
from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection

sys.path.append(os.path.abspath("../src"))
from model import InformedSCVI
from pathway import get_pathway_masks, get_random_masks, filter_genes
from train import plot_loss, split_train_val

# --- Session-wide configuration (affects every later cell) ---
# Silence all Python warnings for the rest of the session.
warnings.simplefilter("ignore")
sc.logging.print_header()
# NOTE(review): `figsize` does not appear to be a documented scanpy setting
# (figure size is normally set through `sc.settings.set_figure_params` or
# matplotlib rcParams) — confirm this assignment has any effect.
sc.settings.figsize = (10, 10)
# Default folder for figures saved through scanpy; compute_metrics() later
# redirects it to a per-model sub-folder.
sc.settings.figdir = "../results/figures/"
# Fix the scvi-tools seed so runs are repeatable.
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)
sns.set_theme()
# PyTorch float32 matmul precision setting ("high" permits faster,
# reduced-precision kernels where the hardware supports them).
torch.set_float32_matmul_precision("high")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


[rank: 0] Seed set to 0


Last run with scvi-tools version: 1.4.0.post1


In [33]:
# Column names in `adata.obs` used throughout the notebook.
# CELL_TYPE_KEY: label column used for the train/val split and for the
# clustering metrics; BATCH_KEY: column registered as batch in setup_anndata.
CELL_TYPE_KEY = 'cell_name'
BATCH_KEY = 'plate'
# KEY_OF_INTEREST = 'condition'

# Check GPU availability
print(f"Torch cuda available: {torch.cuda.is_available()}", flush=True)

Torch cuda available: True


In [34]:
def run_model(adata,
              activation="relu",
              masks=None,
              likelihood="normal",
              layer="RNA",
              key="scVI"):
    """Train an InformedSCVI model on ``adata``, save it and store its latent space.

    Parameters
    ----------
    adata
        AnnData object to train on. It is registered through
        ``InformedSCVI.setup_anndata`` with ``BATCH_KEY`` as batch column and is
        modified in place: the latent representation is written to
        ``adata.obsm["X_<key>"]``.
    activation
        Forwarded to ``InformedSCVI(activation=...)``.
    masks
        Forwarded to ``InformedSCVI(masks=...)``. Defaults to ``None``.
    likelihood
        Forwarded to ``InformedSCVI(gene_likelihood=...)``.
    layer
        Currently ignored: the ``layer=`` argument of ``setup_anndata`` is
        disabled below, so passing a value here has no effect. The parameter is
        kept so existing callers keep working.
    key
        Name of the run. Used for the model folder ``../results/models/<key>``
        and for the ``adata.obsm`` entry ``X_<key>``.

    Returns
    -------
    The same ``adata`` object that was passed in, with ``obsm["X_<key>"]`` added.
    """
    # Register the data with the model class. Only the batch column is declared.
    # NOTE(review): `layer` is intentionally not forwarded (it was commented out
    # in the original call); presumably the model then reads `adata.X` — confirm.
    InformedSCVI.setup_anndata(
        adata,
        batch_key=BATCH_KEY,
    )
    model = InformedSCVI(
        adata,
        gene_likelihood=likelihood,
        activation=activation,
        masks=masks,
    )

    # Train on an explicit train/validation split derived from the cell-type
    # column, stopping once `validation_loss` has not improved for 15 checks.
    train_indices, val_indices = split_train_val(adata, CELL_TYPE_KEY)
    early_stopping_kwargs = {
        "early_stopping": True,
        "early_stopping_patience": 15,
        "early_stopping_monitor": "validation_loss",
    }
    datasplitter_kwargs = {
        "external_indexing": [train_indices, val_indices],
    }
    model.train(datasplitter_kwargs=datasplitter_kwargs, **early_stopping_kwargs)

    # Persist the weights only (no AnnData copy); an existing folder is overwritten.
    model.save(f"../results/models/{key}", overwrite=True, save_anndata=False)

    # Store the latent representation next to the data for downstream analysis.
    adata.obsm[f"X_{key}"] = model.get_latent_representation()

    return adata

In [35]:
def compute_metrics(adata, key):
    """Reload the model saved under ``key`` and write its plots and clustering metrics.

    Parameters
    ----------
    adata
        AnnData object that already holds the latent representation in
        ``adata.obsm["X_<key>"]`` and the ``CELL_TYPE_KEY`` / ``BATCH_KEY``
        columns in ``adata.obs``. It is modified in place: a neighbour graph
        (``neighbors_<key>``), a UMAP embedding (``umap_<key>``) and Leiden
        labels (``leiden_<key>``) are added.
    key
        Name of the run. The model is read from ``../results/models/<key>``,
        figures go to ``../results/figures/<key>/`` and the metrics table to
        ``../results/models/<key>/metrics.csv``.

    Returns
    -------
    None
        All results are written to disk.

    Notes
    -----
    ``sc.settings.figdir`` is redirected to ``../results/figures/<key>/`` and is
    not restored afterwards, so later scanpy saves land in that folder too.
    """
    model = InformedSCVI.load(f"../results/models/{key}", adata=adata)

    figure_dir = f"../results/figures/{key}/"
    # The loss plot is written before any scanpy save, so the per-model figure
    # folder may not exist yet on a fresh run; create it up front.
    os.makedirs(figure_dir, exist_ok=True)
    plot_loss(model, save_path=f"{figure_dir}loss.png")

    # Clustering on the latent space: neighbours -> UMAP -> Leiden.
    sc.pp.neighbors(adata, use_rep=f"X_{key}", key_added=f"neighbors_{key}")
    sc.tl.umap(adata, min_dist=0.3, key_added=f"umap_{key}", neighbors_key=f"neighbors_{key}")
    sc.tl.leiden(adata, key_added=f"leiden_{key}", resolution=0.1, neighbors_key=f"neighbors_{key}")
    sc.settings.figdir = figure_dir
    for feature in [f"leiden_{key}", CELL_TYPE_KEY, BATCH_KEY]:
        sc.pl.embedding(adata, basis=f"umap_{key}", color=feature, title=feature, frameon=False, show=False, save=f"_{feature}.png")

    # Model-level metrics; `.item()` converts the returned tensors to floats.
    elbo = model.get_elbo().item()
    reconstruction_error = model.get_reconstruction_error()['reconstruction_loss'].item()

    # Cluster-level metrics against the Leiden labels.
    leiden_labels = adata.obs[f"leiden_{key}"]
    cell_types = adata.obs[CELL_TYPE_KEY]

    # silhouette_score raises ValueError unless 2 <= n_labels <= n_samples - 1.
    # At resolution 0.1 Leiden can return a single cluster, so record NaN in
    # that case instead of losing every other metric to the exception.
    n_clusters = leiden_labels.nunique()
    if 2 <= n_clusters <= adata.n_obs - 1:
        silhouette = silhouette_score(adata.obsm[f"X_{key}"], leiden_labels, metric="euclidean")
    else:
        print(f"Silhouette undefined for {n_clusters} Leiden cluster(s); storing NaN.", flush=True)
        silhouette = float("nan")

    ARI = adjusted_rand_score(cell_types, leiden_labels)
    NMI = normalized_mutual_info_score(cell_types, leiden_labels)
    homogeneity = homogeneity_score(cell_types, leiden_labels)

    # One-row table, one column per metric.
    metrics = pd.DataFrame(
        {
            "elbo": [elbo],
            "reconstruction_error": [reconstruction_error],
            "silhouette": [silhouette],
            "ARI": [ARI],
            "NMI": [NMI],
            "homogeneity": [homogeneity],
        }
    )
    metrics.to_csv(f"../results/models/{key}/metrics.csv", index=False)

    return

In [36]:
# Load the dataset from disk.
# NOTE(review): absolute, user-specific cluster path — the notebook only runs
# where this file exists; consider a configurable data directory.
adata = sc.read_h5ad("/cluster/work/bewi/members/rquiles/experiments/datasets/3_cells_2_drugs_balanced_counts.h5ad")

In [37]:
# Build the masks that constrain the model, plus random masks as controls.
# NOTE: this cell rebinds `adata` to the output of filter_genes, so re-running
# it operates on the already-filtered object rather than on the loaded file.
genes_per_pathway, genes_per_circuit, circuits_per_pathway = get_pathway_masks()
adata, genes_per_pathway, genes_per_circuit, circuits_per_pathway = filter_genes(adata, genes_per_pathway, genes_per_circuit, circuits_per_pathway)
# Mean row sum of the circuit mask divided by its number of columns; passed to
# get_random_masks as `frac` (presumably the mask density to reproduce — confirm).
frac = genes_per_circuit.sum(axis=1).mean() / genes_per_circuit.shape[1]
rnd_genes_per_pathway, rnd_genes_per_circuit, rnd_circuits_per_pathway = get_random_masks(adata.var_names, genes_per_circuit.shape[0], genes_per_pathway.shape[0], frac=frac, seed=42)

# Two-level masks: genes -> circuits, then circuits -> pathways.
masks_keggNB = [genes_per_circuit, circuits_per_pathway]
print(f"Masks dimensions: {[mask.shape for mask in masks_keggNB]}")
masks_keggNB_rnd = [rnd_genes_per_circuit, rnd_circuits_per_pathway]
# Single-level mask: genes -> pathways.
masks_keggNB_pathways = [genes_per_pathway]
print(f"Masks dimensions: {[mask.shape for mask in masks_keggNB_pathways]}")
masks_keggNB_pathways_rnd = [rnd_genes_per_pathway]

Current directory: /cluster/work/bewi/members/rquiles/piscvi/notebooks
Filtering genes based on pathways and circuits with minimum 4 genes per circuit
Circuits in genes_per_circuit: 1721
Circuits in circuits_per_pathway: 1721
Masks dimensions: [(1721, 1388), (367, 1721)]
Masks dimensions: [(367, 1388)]


In [38]:
# Replace `adata` with an independent copy before training.
# Presumably filter_genes returned a view of the original object — confirm.
adata = adata.copy()

In [39]:
# Train the two-level (genes -> circuits -> pathways) masked model and save it
# under ../results/models/piscVI; the latent space lands in adata.obsm["X_piscVI"].
adata = run_model(
    adata,
    key="piscVI",
    masks=masks_keggNB,
)

setup_anndata: InformedSCVI
<class 'list'>
Hello
Activation function encoder: ReLU
Activation function encoder: ReLU
Activation function decoder: ReLU
2 masks found for module encoder.
Mask dimensions for module encoder: [torch.Size([1721, 1388]), torch.Size([367, 1721])]
Trainer kwargs:  {'early_stopping_patience': 15, 'early_stopping_monitor': 'validation_loss', 'early_stopping': True}


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Training:   0%|          | 0/390 [00:00<?, ?it/s]

Monitored metric validation_loss did not improve in the last 15 records. Best score: -101.445. Signaling Trainer to stop.


In [10]:
# Plot diagnostics and write metrics.csv for the piscVI run.
# NOTE(review): the recorded execution count of this cell is lower than the
# training cell above, so its stored output may come from an earlier run.
compute_metrics(adata, key="piscVI")

[34mINFO    [0m File ..[35m/results/models/piscVI/[0m[95mmodel.pt[0m already downloaded                                                 
setup_anndata: InformedSCVI
Hello
Activation function encoder: ReLU
Activation function encoder: ReLU
Activation function decoder: ReLU


<Figure size 1000x600 with 0 Axes>

## Visualize Circuit and Pathway Activations

In [None]:
# Reload the dataset from disk for the visualisation section.
# NOTE(review): this cell has no execution count, i.e. it was not run in the
# recorded session. Running it replaces the gene-filtered `adata` used for
# training with the unfiltered file — verify that InformedSCVI.load below
# accepts that object, or re-apply filter_genes first.
adata = sc.read_h5ad("/cluster/work/bewi/members/rquiles/experiments/datasets/3_cells_2_drugs_balanced_counts.h5ad")

In [42]:
# Load the trained model saved under ../results/models/<key>, attached to the
# `adata` currently in memory.
key = "piscVI"
model = InformedSCVI.load(f"../results/models/{key}", adata=adata)

[34mINFO    [0m File ..[35m/results/models/piscVI/[0m[95mmodel.pt[0m already downloaded                                                 
setup_anndata: InformedSCVI
Hello
Activation function encoder: ReLU
Activation function encoder: ReLU
Activation function decoder: ReLU


In [43]:
# Hidden activations of the model, returned as a dict keyed by layer name.
# NOTE(review): the call floods the output with a repeated "This has been
# changed" message — a leftover debug print in the model code, most likely.
activations = model.get_hidden_activations()
layer_0 = activations["layer_0"]
layer_1 = activations["layer_1"]

This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has b

In [44]:
# Display the raw activation dict (full array reprs; a shape summary would be
# more compact).
activations

{'layer_0': array([[ 0.00233036, -0.0039171 ,  0.01279042, ...,  0.02626238,
         -0.02367824,  0.05844731],
        [ 0.00233036, -0.0039171 ,  0.01279042, ...,  0.02626238,
         -0.02367824,  0.07871861],
        [ 0.00233036, -0.0039171 ,  0.01279042, ...,  0.00346575,
         -0.02367824, -0.01221614],
        ...,
        [ 0.00233036, -0.0039171 ,  0.01279042, ...,  0.02626238,
         -0.02367824,  0.01648564],
        [ 0.00233036, -0.0039171 ,  0.01279042, ..., -0.01955977,
         -0.08492702, -0.01570278],
        [-0.11961788, -0.09138048,  0.01279042, ...,  0.02626238,
         -0.02367824,  0.01648564]], shape=(20488, 1721), dtype=float32),
 'layer_1': array([[-0.00286387,  0.00526624, -0.00824141, ...,  0.07292737,
          0.02240492, -0.04293919],
        [-0.00286387,  0.00526624, -0.00824141, ..., -0.02267509,
          0.02240492, -0.04625656],
        [-0.00286387,  0.00526624, -0.00824141, ...,  0.06110856,
          0.02240492, -0.03573788],
        .

In [50]:
# Smallest activation in the first hidden layer (sanity check of value range).
layer_0.min()

np.float32(-29.211077)

In [62]:
sys.path.append(os.path.abspath("../experiments/"))
from plot_latent import visualize_piscvi_results

In [63]:
    # Produce the latent-space figures for the saved model.
    # NOTE: this rebinds both `adata` and `activations` (previously the dict
    # from model.get_hidden_activations()), so the cell is not idempotent.
    # Presumably the returned adata carries the per-pathway activation columns
    # shown in adata.obs afterwards — confirm in plot_latent.
    adata, activations = visualize_piscvi_results(
        model_path=f"../results/models/{key}",
        adata=adata,
        drug_key="Agg_Treatment",
        dose_key="dose",
        cell_type_key="cell_name",
        n_top_pathways=10,
        layers_to_plot=[0, 1],  # 0=circuits, 1=pathways
        output_dir=f"../results/figures/{key}/latent_figures/"
    )

Loading model...
[34mINFO    [0m File ..[35m/results/models/piscVI/[0m[95mmodel.pt[0m already downloaded                                                 
setup_anndata: InformedSCVI
Hello
Activation function encoder: ReLU
Activation function encoder: ReLU
Activation function decoder: ReLU
Computing latent representation...
Extracting pathway activations...
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been changed
This has been chang

In [65]:
# Inspect the per-cell annotations after visualisation (shows the full frame;
# `.head()` would keep the output short).
adata.obs

Unnamed: 0_level_0,sample,gene_count,tscp_count,mread_count,drugname_drugconc,Agg_Treatment,covariates,sublibrary,BARCODE,pcnt_mito,...,pathway_404,pathway_382,pathway_221,pathway_387,pathway_159,pathway_129,pathway_169,pathway_153,pathway_156,pathway_218
BARCODE_SUB_LIB_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
69_037_113-lib_2241,smp_2427,862,1097,1314,"[('Trametinib', 0.05, 'uM')]",Trametinib,CVCL_0028,lib_2241,69_037_113,0.051960,...,0.162846,0.187026,0.021679,0.563346,0.207700,-0.019487,0.049504,0.028986,1.596608,-0.043192
43_011_103-lib_1042,smp_1729,1766,2561,3023,"[('Sapanisertib', 5.0, 'uM')]",Sapanisertib,CVCL_0069,lib_1042,43_011_103,0.042952,...,0.273617,0.220565,0.348583,-0.002847,0.051011,0.007353,0.424217,0.228631,0.500833,-0.055742
95_185_127-lib_1699,smp_2453,1223,1709,2055,"[('DMSO_TF', 0.0, 'uM')]",DMSO_TF,CVCL_0069,lib_1699,95_185_127,0.100644,...,0.246975,0.205794,0.185131,-0.062440,0.012405,-0.038922,0.087066,0.575540,-0.087847,-0.051648
35_007_087-lib_870,smp_1529,1453,2096,2511,"[('Trametinib', 0.05, 'uM')]",Trametinib,CVCL_0028,lib_870,35_007_087,0.051527,...,0.539087,0.535994,0.021679,0.388404,0.079010,0.009390,-0.008491,0.876937,4.916120,0.139070
69_186_107-lib_2451,smp_2619,711,863,992,"[('Trametinib', 5.0, 'uM')]",Trametinib,CVCL_0023,lib_2451,69_186_107,0.040556,...,0.043453,0.096934,0.185131,0.194140,0.037377,0.044422,0.131159,0.102234,0.492144,0.019341
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
35_043_115-lib_1007,smp_1625,1522,2253,2703,"[('Trametinib', 0.5, 'uM')]",Trametinib,CVCL_0069,lib_1007,35_043_115,0.134043,...,0.279842,0.207411,-0.063709,-0.112196,0.219924,0.302270,-0.005152,0.039609,-0.117815,0.019701
69_115_188-lib_1914,smp_2619,1785,2835,3232,"[('Trametinib', 5.0, 'uM')]",Trametinib,CVCL_0023,lib_1914,69_115_188,0.068078,...,0.046078,0.405959,0.710400,0.570463,0.253615,-0.010305,1.205122,1.201695,0.468897,0.145599
95_039_098-lib_1723,smp_2453,507,645,773,"[('DMSO_TF', 0.0, 'uM')]",DMSO_TF,CVCL_0069,lib_1723,95_039_098,0.085271,...,0.290203,0.171996,0.039135,-0.173859,2.491680,0.296231,0.017332,0.028986,0.492144,-0.043192
69_020_022-lib_1899,smp_2619,1733,2375,2801,"[('Trametinib', 5.0, 'uM')]",Trametinib,CVCL_0069,lib_1899,69_020_022,0.038316,...,0.292339,0.206404,-0.028797,-0.249792,0.611815,0.286655,0.018918,0.039130,0.588203,1.408273
